network_protocol/service/client.rs

1use futures::{SinkExt, StreamExt};
2use tokio::time;
3use tracing::{debug, info, instrument, warn};
4
5use crate::config::ClientConfig;
6
7use crate::core::packet::Packet;
8use crate::protocol::message::Message;
9// Import secure handshake functions
10use crate::error::{ProtocolError, Result};
11use crate::protocol::handshake::{
12    client_derive_session_key, client_secure_handshake_init, client_secure_handshake_verify,
13};
14use crate::protocol::heartbeat::{build_ping, is_pong};
15use crate::protocol::keepalive::KeepAliveManager;
16use crate::service::secure::SecureConnection;
17use crate::transport::remote;
18use crate::utils::timeout::with_timeout_error;
19
20/// High-level protocol client with post-handshake encryption
21pub struct Client {
22    conn: SecureConnection,
23    keep_alive: KeepAliveManager,
24    config: ClientConfig,
25}
26
27impl Client {
28    /// Connect and perform secure handshake with timeout using default configuration
29    #[instrument(skip(addr), fields(address = %addr))]
30    pub async fn connect(addr: &str) -> Result<Self> {
31        let config = ClientConfig {
32            address: addr.to_string(),
33            ..Default::default()
34        };
35        Self::connect_with_config(config).await
36    }
37
38    /// Connect and perform secure handshake with custom configuration
39    #[instrument(skip(config), fields(address = %config.address))]
40    pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
41        // Connect with timeout
42        let mut framed = with_timeout_error(
43            async { remote::connect(&config.address).await },
44            config.connection_timeout,
45        )
46        .await?;
47
48        // --- Legacy Handshake Support ---
49        // Commented out legacy code for reference
50        // #[allow(deprecated)]
51        // async fn legacy_handshake(framed: &mut remote::RemoteFramed) -> Result<[u8; 32]> {
52        //     // Legacy handshake process
53        //     let (client_nonce, handshake) = client_handshake_init();
54        //
55        // --- Secure Handshake Process ---
56        // Step 1: Send client init with public key, nonce, and timestamp
57        let (client_state, init_msg) = client_secure_handshake_init()?;
58        let init_bytes = bincode::serialize(&init_msg)?;
59        framed
60            .send(Packet {
61                version: 1,
62                payload: init_bytes,
63            })
64            .await?;
65
66        // Step 2: Receive server response with timeout
67        let response = with_timeout_error(
68            async {
69                let packet = framed
70                    .next()
71                    .await
72                    .ok_or(ProtocolError::ConnectionClosed)?
73                    .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
74                bincode::deserialize::<Message>(&packet.payload)
75                    .map_err(|e| ProtocolError::DeserializeError(e.to_string()))
76            },
77            config.connection_timeout,
78        )
79        .await?;
80
81        // Step 3: Verify server response and send confirmation
82        // Extract data from server response
83        let (server_pub_key, server_nonce, nonce_verification) = match response {
84            Message::SecureHandshakeResponse {
85                pub_key,
86                nonce,
87                nonce_verification,
88            } => (pub_key, nonce, nonce_verification),
89            _ => {
90                return Err(ProtocolError::HandshakeError(
91                    "Invalid server response message type".into(),
92                ))
93            }
94        };
95
96        // Verify the server's response and prepare confirmation message
97        let (client_state_verified, verify_msg) = client_secure_handshake_verify(
98            client_state,
99            server_pub_key,
100            server_nonce,
101            nonce_verification,
102        )?;
103
104        let verify_bytes = bincode::serialize(&verify_msg)?;
105        framed
106            .send(Packet {
107                version: 1,
108                payload: verify_bytes,
109            })
110            .await?;
111
112        // Step 4: Derive shared session key
113        let key = client_derive_session_key(client_state_verified)?;
114        let conn = SecureConnection::new(framed, key);
115
116        // Create keep-alive manager with configured interval
117        let dead_timeout = config.heartbeat_interval.mul_f32(4.0); // 4x the heartbeat interval
118        let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
119
120        info!("Connection established successfully");
121        Ok(Self {
122            conn,
123            keep_alive,
124            config,
125        })
126    }
127
128    /// Securely send a message
129    #[instrument(skip(self, msg))]
130    pub async fn send(&mut self, msg: Message) -> Result<()> {
131        let result = self.conn.secure_send(msg).await;
132        if result.is_ok() {
133            self.keep_alive.update_send();
134        }
135        result
136    }
137
138    /// Securely receive a message
139    #[instrument(skip(self))]
140    pub async fn recv(&mut self) -> Result<Message> {
141        let result = self.conn.secure_recv().await;
142        if result.is_ok() {
143            self.keep_alive.update_recv();
144        }
145        result
146    }
147
148    /// Send a keep-alive ping to the server
149    #[instrument(skip(self))]
150    pub async fn send_keepalive(&mut self) -> Result<()> {
151        debug!("Sending keep-alive ping");
152        let ping = build_ping();
153        self.send(ping).await
154    }
155
156    /// Wait for messages with keep-alive handling using custom timeout
157    #[instrument(skip(self))]
158    pub async fn recv_with_keepalive(
159        &mut self,
160        timeout_duration: std::time::Duration,
161    ) -> Result<Message> {
162        let mut ping_interval = time::interval(self.keep_alive.ping_interval());
163
164        let timeout = time::sleep(timeout_duration);
165        tokio::pin!(timeout);
166
167        loop {
168            tokio::select! {
169                // Check if we need to send a ping
170                _ = ping_interval.tick() => {
171                    if self.keep_alive.should_ping() {
172                        self.send_keepalive().await?;
173                    }
174
175                    // Check if connection is dead
176                    if self.keep_alive.is_connection_dead() {
177                        warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(),
178                              "Connection appears dead");
179                        return Err(ProtocolError::ConnectionTimeout);
180                    }
181                }
182
183                // Try to receive a message
184                recv_result = self.conn.secure_recv::<Message>() => {
185                    match recv_result {
186                        Ok(msg) => {
187                            self.keep_alive.update_recv();
188
189                            // Filter out pong messages, return everything else
190                            if !is_pong(&msg) {
191                                return Ok(msg);
192                            } else {
193                                debug!("Received pong response");
194                                // Continue waiting for non-pong messages
195                            }
196                        }
197                        Err(ProtocolError::Timeout) => {
198                            // Timeout is expected, just continue the loop
199                            continue;
200                        }
201                        Err(e) => return Err(e),
202                    }
203                }
204
205                // User-provided timeout
206                _ = &mut timeout => {
207                    return Err(ProtocolError::Timeout);
208                }
209            }
210        }
211    }
212
213    /// Send a message and wait for a response with keep-alive handling
214    #[instrument(skip(self, msg))]
215    pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
216        self.send(msg).await?;
217        // Use configured response timeout
218        self.recv_with_keepalive(self.config.response_timeout).await
219    }
220}