// network_protocol/service/client.rs

use std::time::{Duration, Instant};

use futures::{SinkExt, StreamExt};
use tokio::time;
use tracing::{debug, info, instrument, warn};

use crate::config::ClientConfig;
use crate::core::packet::Packet;
use crate::error::{ProtocolError, Result};
// Secure handshake functions
use crate::protocol::handshake::{
    client_derive_session_key, client_secure_handshake_init, client_secure_handshake_verify,
};
use crate::protocol::heartbeat::{build_ping, is_pong};
use crate::protocol::keepalive::KeepAliveManager;
use crate::protocol::message::Message;
use crate::service::secure::SecureConnection;
use crate::transport::remote;
use crate::utils::timeout::with_timeout_error;
18/// High-level protocol client with post-handshake encryption
19pub struct Client {
20    conn: SecureConnection,
21    keep_alive: KeepAliveManager,
22    config: ClientConfig,
23}
24
25impl Client {
26    /// Connect and perform secure handshake with timeout using default configuration
27    #[instrument(skip(addr), fields(address = %addr))]
28    pub async fn connect(addr: &str) -> Result<Self> {
29        let config = ClientConfig {
30            address: addr.to_string(),
31            ..Default::default()
32        };
33        Self::connect_with_config(config).await
34    }
35    
36    /// Connect and perform secure handshake with custom configuration
37    #[instrument(skip(config), fields(address = %config.address))]
38    pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
39        // Connect with timeout
40        let mut framed = with_timeout_error(
41            async {
42                remote::connect(&config.address).await
43            },
44            config.connection_timeout
45        ).await?;
46        
47        // --- Legacy Handshake Support ---
48        // Commented out legacy code for reference
49        // #[allow(deprecated)]
50        // async fn legacy_handshake(framed: &mut remote::RemoteFramed) -> Result<[u8; 32]> {
51        //     // Legacy handshake process
52        //     let (client_nonce, handshake) = client_handshake_init();
53        //     
54        // --- Secure Handshake Process ---
55        // Step 1: Send client init with public key, nonce, and timestamp
56        let init_msg = client_secure_handshake_init()?;
57        let init_bytes = bincode::serialize(&init_msg)?;
58        framed.send(Packet {
59            version: 1,
60            payload: init_bytes,
61        }).await?;
62        
63        // Step 2: Receive server response with timeout
64        let response = with_timeout_error(
65            async {
66                let packet = framed.next().await
67                    .ok_or(ProtocolError::ConnectionClosed)?
68                    .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
69                bincode::deserialize::<Message>(&packet.payload)
70                    .map_err(|e| ProtocolError::DeserializeError(e.to_string()))
71            },
72            config.connection_timeout
73        ).await?;
74        
75        // Step 3: Verify server response and send confirmation
76        // Extract data from server response
77        let (server_pub_key, server_nonce, nonce_verification) = match response {
78            Message::SecureHandshakeResponse { pub_key, nonce, nonce_verification } => {
79                (pub_key, nonce, nonce_verification)
80            },
81            _ => return Err(ProtocolError::HandshakeError("Invalid server response message type".into())),
82        };
83        
84        // Verify the server's response and prepare confirmation message
85        let verify_msg = client_secure_handshake_verify(server_pub_key, server_nonce, nonce_verification)?;
86        
87        let verify_bytes = bincode::serialize(&verify_msg)?;
88        framed.send(Packet {
89            version: 1,
90            payload: verify_bytes,
91        }).await?;
92        
93        // Step 4: Derive shared session key
94        let key = client_derive_session_key()?;
95        let conn = SecureConnection::new(framed, key);
96        
97        // Create keep-alive manager with configured interval
98        let dead_timeout = config.heartbeat_interval.mul_f32(4.0); // 4x the heartbeat interval
99        let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
100        
101        info!("Connection established successfully");
102        Ok(Self { conn, keep_alive, config })
103    }
104
105    /// Securely send a message
106    #[instrument(skip(self, msg))]
107    pub async fn send(&mut self, msg: Message) -> Result<()> {
108        let result = self.conn.secure_send(msg).await;
109        if result.is_ok() {
110            self.keep_alive.update_send();
111        }
112        result
113    }
114
115    /// Securely receive a message
116    #[instrument(skip(self))]
117    pub async fn recv(&mut self) -> Result<Message> {
118        let result = self.conn.secure_recv().await;
119        if result.is_ok() {
120            self.keep_alive.update_recv();
121        }
122        result
123    }
124    
125    /// Send a keep-alive ping to the server
126    #[instrument(skip(self))]
127    pub async fn send_keepalive(&mut self) -> Result<()> {
128        debug!("Sending keep-alive ping");
129        let ping = build_ping();
130        self.send(ping).await
131    }
132    
133    /// Wait for messages with keep-alive handling using custom timeout
134    #[instrument(skip(self))]
135    pub async fn recv_with_keepalive(&mut self, timeout_duration: std::time::Duration) -> Result<Message> {
136        let mut ping_interval = time::interval(self.keep_alive.ping_interval());
137        
138        let timeout = time::sleep(timeout_duration);
139        tokio::pin!(timeout);
140        
141        loop {
142            tokio::select! {
143                // Check if we need to send a ping
144                _ = ping_interval.tick() => {
145                    if self.keep_alive.should_ping() {
146                        self.send_keepalive().await?;
147                    }
148                    
149                    // Check if connection is dead
150                    if self.keep_alive.is_connection_dead() {
151                        warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(), 
152                              "Connection appears dead");
153                        return Err(ProtocolError::ConnectionTimeout);
154                    }
155                }
156                
157                // Try to receive a message
158                recv_result = self.conn.secure_recv::<Message>() => {
159                    match recv_result {
160                        Ok(msg) => {
161                            self.keep_alive.update_recv();
162                            
163                            // Filter out pong messages, return everything else
164                            if !is_pong(&msg) {
165                                return Ok(msg);
166                            } else {
167                                debug!("Received pong response");
168                                // Continue waiting for non-pong messages
169                            }
170                        }
171                        Err(ProtocolError::Timeout) => {
172                            // Timeout is expected, just continue the loop
173                            continue;
174                        }
175                        Err(e) => return Err(e),
176                    }
177                }
178                
179                // User-provided timeout
180                _ = &mut timeout => {
181                    return Err(ProtocolError::Timeout);
182                }
183            }
184        }
185    }
186
187    /// Send a message and wait for a response with keep-alive handling
188    #[instrument(skip(self, msg))]
189    pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
190        self.send(msg).await?;
191        // Use configured response timeout
192        self.recv_with_keepalive(self.config.response_timeout).await
193    }
194}