// network_protocol/service/client.rs

use std::time::Duration;

use futures::{SinkExt, StreamExt};
use tokio::time;
use tracing::{debug, info, instrument, warn};

use crate::config::ClientConfig;
use crate::core::packet::Packet;
use crate::error::{ProtocolError, Result};
// Secure handshake functions
use crate::protocol::handshake::{
    client_derive_session_key, client_secure_handshake_init, client_secure_handshake_verify,
};
use crate::protocol::heartbeat::{build_ping, is_pong};
use crate::protocol::keepalive::KeepAliveManager;
use crate::protocol::message::Message;
use crate::service::secure::SecureConnection;
use crate::transport::remote;
use crate::utils::replay_cache::ReplayCache;
use crate::utils::timeout::with_timeout_error;

21/// High-level protocol client with post-handshake encryption
22pub struct Client {
23    conn: SecureConnection,
24    keep_alive: KeepAliveManager,
25    config: ClientConfig,
26    #[allow(dead_code)]
27    replay_cache: ReplayCache,
28}
29
30impl Client {
31    /// Connect and perform secure handshake with timeout using default configuration
32    #[instrument(skip(addr), fields(address = %addr))]
33    pub async fn connect(addr: &str) -> Result<Self> {
34        let config = ClientConfig {
35            address: addr.to_string(),
36            ..Default::default()
37        };
38        Self::connect_with_config(config).await
39    }
40
41    /// Connect and perform secure handshake with custom configuration
42    #[instrument(skip(config), fields(address = %config.address))]
43    pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
44        // Connect with timeout
45        let mut framed = with_timeout_error(
46            async { remote::connect(&config.address).await },
47            config.connection_timeout,
48        )
49        .await?;
50
51        // --- Legacy Handshake Support ---
52        // Commented out legacy code for reference
53        // #[allow(deprecated)]
54        // async fn legacy_handshake(framed: &mut remote::RemoteFramed) -> Result<[u8; 32]> {
55        //     // Legacy handshake process
56        //     let (client_nonce, handshake) = client_handshake_init();
57        //
58        // --- Secure Handshake Process ---
59        // Step 1: Send client init with public key, nonce, and timestamp
60        let (client_state, init_msg) = client_secure_handshake_init()?;
61        let init_bytes = bincode::serialize(&init_msg)?;
62        framed
63            .send(Packet {
64                version: 1,
65                payload: init_bytes,
66            })
67            .await?;
68
69        // Step 2: Receive server response with timeout
70        let response = with_timeout_error(
71            async {
72                let packet = framed
73                    .next()
74                    .await
75                    .ok_or(ProtocolError::ConnectionClosed)?
76                    .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
77                bincode::deserialize::<Message>(&packet.payload)
78                    .map_err(|e| ProtocolError::DeserializeError(e.to_string()))
79            },
80            config.connection_timeout,
81        )
82        .await?;
83
84        // Step 3: Verify server response and send confirmation
85        // Extract data from server response
86        let (server_pub_key, server_nonce, nonce_verification) = match response {
87            Message::SecureHandshakeResponse {
88                pub_key,
89                nonce,
90                nonce_verification,
91            } => (pub_key, nonce, nonce_verification),
92            _ => {
93                return Err(ProtocolError::HandshakeError(
94                    "Invalid server response message type".into(),
95                ))
96            }
97        };
98
99        // Verify the server's response and prepare confirmation message
100        let (client_state_verified, verify_msg) = client_secure_handshake_verify(
101            client_state,
102            server_pub_key,
103            server_nonce,
104            nonce_verification,
105            &config.address,
106            &mut ReplayCache::new(),
107        )?;
108
109        let verify_bytes = bincode::serialize(&verify_msg)?;
110        framed
111            .send(Packet {
112                version: 1,
113                payload: verify_bytes,
114            })
115            .await?;
116
117        // Step 4: Derive shared session key
118        let key = client_derive_session_key(client_state_verified)?;
119        let conn = SecureConnection::new(framed, key);
120
121        // Create keep-alive manager with configured interval
122        let dead_timeout = config.heartbeat_interval.mul_f32(4.0); // 4x the heartbeat interval
123        let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
124
125        info!("Connection established successfully");
126        Ok(Self {
127            conn,
128            keep_alive,
129            config,
130            replay_cache: ReplayCache::new(),
131        })
132    }
133
134    /// Securely send a message
135    #[instrument(skip(self, msg))]
136    pub async fn send(&mut self, msg: Message) -> Result<()> {
137        let result = self.conn.secure_send(msg).await;
138        if result.is_ok() {
139            self.keep_alive.update_send();
140        }
141        result
142    }
143
144    /// Securely receive a message
145    #[instrument(skip(self))]
146    pub async fn recv(&mut self) -> Result<Message> {
147        let result = self.conn.secure_recv().await;
148        if result.is_ok() {
149            self.keep_alive.update_recv();
150        }
151        result
152    }
153
154    /// Send a keep-alive ping to the server
155    #[instrument(skip(self))]
156    pub async fn send_keepalive(&mut self) -> Result<()> {
157        debug!("Sending keep-alive ping");
158        let ping = build_ping();
159        self.send(ping).await
160    }
161
162    /// Wait for messages with keep-alive handling using custom timeout
163    #[instrument(skip(self))]
164    pub async fn recv_with_keepalive(
165        &mut self,
166        timeout_duration: std::time::Duration,
167    ) -> Result<Message> {
168        let mut ping_interval = time::interval(self.keep_alive.ping_interval());
169
170        let timeout = time::sleep(timeout_duration);
171        tokio::pin!(timeout);
172
173        loop {
174            tokio::select! {
175                // Check if we need to send a ping
176                _ = ping_interval.tick() => {
177                    if self.keep_alive.should_ping() {
178                        self.send_keepalive().await?;
179                    }
180
181                    // Check if connection is dead
182                    if self.keep_alive.is_connection_dead() {
183                        warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(),
184                              "Connection appears dead");
185                        return Err(ProtocolError::ConnectionTimeout);
186                    }
187                }
188
189                // Try to receive a message
190                recv_result = self.conn.secure_recv::<Message>() => {
191                    match recv_result {
192                        Ok(msg) => {
193                            self.keep_alive.update_recv();
194
195                            // Filter out pong messages, return everything else
196                            if !is_pong(&msg) {
197                                return Ok(msg);
198                            } else {
199                                debug!("Received pong response");
200                                // Continue waiting for non-pong messages
201                            }
202                        }
203                        Err(ProtocolError::Timeout) => {
204                            // Timeout is expected, just continue the loop
205                            continue;
206                        }
207                        Err(e) => return Err(e),
208                    }
209                }
210
211                // User-provided timeout
212                _ = &mut timeout => {
213                    return Err(ProtocolError::Timeout);
214                }
215            }
216        }
217    }
218
219    /// Send a message and wait for a response with keep-alive handling
220    #[instrument(skip(self, msg))]
221    pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
222        self.send(msg).await?;
223        // Use configured response timeout
224        self.recv_with_keepalive(self.config.response_timeout).await
225    }
226}