network_protocol/service/
client.rs1use futures::{SinkExt, StreamExt};
2use tokio::time;
3use tracing::{debug, info, instrument, warn};
4
5use crate::config::ClientConfig;
6
7use crate::core::packet::Packet;
8use crate::protocol::message::Message;
9use crate::error::{ProtocolError, Result};
11use crate::protocol::handshake::{
12 client_derive_session_key, client_secure_handshake_init, client_secure_handshake_verify,
13};
14use crate::protocol::heartbeat::{build_ping, is_pong};
15use crate::protocol::keepalive::KeepAliveManager;
16use crate::service::secure::SecureConnection;
17use crate::transport::remote;
18use crate::utils::replay_cache::ReplayCache;
19use crate::utils::timeout::with_timeout_error;
20
21pub struct Client {
23 conn: SecureConnection,
24 keep_alive: KeepAliveManager,
25 config: ClientConfig,
26 #[allow(dead_code)]
27 replay_cache: ReplayCache,
28}
29
30impl Client {
31 #[instrument(skip(addr), fields(address = %addr))]
33 pub async fn connect(addr: &str) -> Result<Self> {
34 let config = ClientConfig {
35 address: addr.to_string(),
36 ..Default::default()
37 };
38 Self::connect_with_config(config).await
39 }
40
41 #[instrument(skip(config), fields(address = %config.address))]
43 pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
44 let mut framed = with_timeout_error(
46 async { remote::connect(&config.address).await },
47 config.connection_timeout,
48 )
49 .await?;
50
51 let (client_state, init_msg) = client_secure_handshake_init()?;
61 let init_bytes = bincode::serialize(&init_msg)?;
62 framed
63 .send(Packet {
64 version: 1,
65 payload: init_bytes,
66 })
67 .await?;
68
69 let response = with_timeout_error(
71 async {
72 let packet = framed
73 .next()
74 .await
75 .ok_or(ProtocolError::ConnectionClosed)?
76 .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
77 bincode::deserialize::<Message>(&packet.payload)
78 .map_err(|e| ProtocolError::DeserializeError(e.to_string()))
79 },
80 config.connection_timeout,
81 )
82 .await?;
83
84 let (server_pub_key, server_nonce, nonce_verification) = match response {
87 Message::SecureHandshakeResponse {
88 pub_key,
89 nonce,
90 nonce_verification,
91 } => (pub_key, nonce, nonce_verification),
92 _ => {
93 return Err(ProtocolError::HandshakeError(
94 "Invalid server response message type".into(),
95 ))
96 }
97 };
98
99 let (client_state_verified, verify_msg) = client_secure_handshake_verify(
101 client_state,
102 server_pub_key,
103 server_nonce,
104 nonce_verification,
105 &config.address,
106 &mut ReplayCache::new(),
107 )?;
108
109 let verify_bytes = bincode::serialize(&verify_msg)?;
110 framed
111 .send(Packet {
112 version: 1,
113 payload: verify_bytes,
114 })
115 .await?;
116
117 let key = client_derive_session_key(client_state_verified)?;
119 let conn = SecureConnection::new(framed, key);
120
121 let dead_timeout = config.heartbeat_interval.mul_f32(4.0); let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
124
125 info!("Connection established successfully");
126 Ok(Self {
127 conn,
128 keep_alive,
129 config,
130 replay_cache: ReplayCache::new(),
131 })
132 }
133
134 #[instrument(skip(self, msg))]
136 pub async fn send(&mut self, msg: Message) -> Result<()> {
137 let result = self.conn.secure_send(msg).await;
138 if result.is_ok() {
139 self.keep_alive.update_send();
140 }
141 result
142 }
143
144 #[instrument(skip(self))]
146 pub async fn recv(&mut self) -> Result<Message> {
147 let result = self.conn.secure_recv().await;
148 if result.is_ok() {
149 self.keep_alive.update_recv();
150 }
151 result
152 }
153
154 #[instrument(skip(self))]
156 pub async fn send_keepalive(&mut self) -> Result<()> {
157 debug!("Sending keep-alive ping");
158 let ping = build_ping();
159 self.send(ping).await
160 }
161
162 #[instrument(skip(self))]
164 pub async fn recv_with_keepalive(
165 &mut self,
166 timeout_duration: std::time::Duration,
167 ) -> Result<Message> {
168 let mut ping_interval = time::interval(self.keep_alive.ping_interval());
169
170 let timeout = time::sleep(timeout_duration);
171 tokio::pin!(timeout);
172
173 loop {
174 tokio::select! {
175 _ = ping_interval.tick() => {
177 if self.keep_alive.should_ping() {
178 self.send_keepalive().await?;
179 }
180
181 if self.keep_alive.is_connection_dead() {
183 warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(),
184 "Connection appears dead");
185 return Err(ProtocolError::ConnectionTimeout);
186 }
187 }
188
189 recv_result = self.conn.secure_recv::<Message>() => {
191 match recv_result {
192 Ok(msg) => {
193 self.keep_alive.update_recv();
194
195 if !is_pong(&msg) {
197 return Ok(msg);
198 } else {
199 debug!("Received pong response");
200 }
202 }
203 Err(ProtocolError::Timeout) => {
204 continue;
206 }
207 Err(e) => return Err(e),
208 }
209 }
210
211 _ = &mut timeout => {
213 return Err(ProtocolError::Timeout);
214 }
215 }
216 }
217 }
218
219 #[instrument(skip(self, msg))]
221 pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
222 self.send(msg).await?;
223 self.recv_with_keepalive(self.config.response_timeout).await
225 }
226}