network_protocol/service/
client.rs1use futures::{SinkExt, StreamExt};
2use tokio::time;
3use tracing::{debug, info, instrument, warn};
4
5use crate::config::ClientConfig;
6
7use crate::core::packet::Packet;
8use crate::protocol::message::Message;
9use crate::error::{ProtocolError, Result};
11use crate::protocol::handshake::{
12 client_derive_session_key, client_secure_handshake_init, client_secure_handshake_verify,
13};
14use crate::protocol::heartbeat::{build_ping, is_pong};
15use crate::protocol::keepalive::KeepAliveManager;
16use crate::service::secure::SecureConnection;
17use crate::transport::remote;
18use crate::utils::timeout::with_timeout_error;
19
20pub struct Client {
22 conn: SecureConnection,
23 keep_alive: KeepAliveManager,
24 config: ClientConfig,
25}
26
27impl Client {
28 #[instrument(skip(addr), fields(address = %addr))]
30 pub async fn connect(addr: &str) -> Result<Self> {
31 let config = ClientConfig {
32 address: addr.to_string(),
33 ..Default::default()
34 };
35 Self::connect_with_config(config).await
36 }
37
38 #[instrument(skip(config), fields(address = %config.address))]
40 pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
41 let mut framed = with_timeout_error(
43 async { remote::connect(&config.address).await },
44 config.connection_timeout,
45 )
46 .await?;
47
48 let (client_state, init_msg) = client_secure_handshake_init()?;
58 let init_bytes = bincode::serialize(&init_msg)?;
59 framed
60 .send(Packet {
61 version: 1,
62 payload: init_bytes,
63 })
64 .await?;
65
66 let response = with_timeout_error(
68 async {
69 let packet = framed
70 .next()
71 .await
72 .ok_or(ProtocolError::ConnectionClosed)?
73 .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
74 bincode::deserialize::<Message>(&packet.payload)
75 .map_err(|e| ProtocolError::DeserializeError(e.to_string()))
76 },
77 config.connection_timeout,
78 )
79 .await?;
80
81 let (server_pub_key, server_nonce, nonce_verification) = match response {
84 Message::SecureHandshakeResponse {
85 pub_key,
86 nonce,
87 nonce_verification,
88 } => (pub_key, nonce, nonce_verification),
89 _ => {
90 return Err(ProtocolError::HandshakeError(
91 "Invalid server response message type".into(),
92 ))
93 }
94 };
95
96 let (client_state_verified, verify_msg) = client_secure_handshake_verify(
98 client_state,
99 server_pub_key,
100 server_nonce,
101 nonce_verification,
102 )?;
103
104 let verify_bytes = bincode::serialize(&verify_msg)?;
105 framed
106 .send(Packet {
107 version: 1,
108 payload: verify_bytes,
109 })
110 .await?;
111
112 let key = client_derive_session_key(client_state_verified)?;
114 let conn = SecureConnection::new(framed, key);
115
116 let dead_timeout = config.heartbeat_interval.mul_f32(4.0); let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
119
120 info!("Connection established successfully");
121 Ok(Self {
122 conn,
123 keep_alive,
124 config,
125 })
126 }
127
128 #[instrument(skip(self, msg))]
130 pub async fn send(&mut self, msg: Message) -> Result<()> {
131 let result = self.conn.secure_send(msg).await;
132 if result.is_ok() {
133 self.keep_alive.update_send();
134 }
135 result
136 }
137
138 #[instrument(skip(self))]
140 pub async fn recv(&mut self) -> Result<Message> {
141 let result = self.conn.secure_recv().await;
142 if result.is_ok() {
143 self.keep_alive.update_recv();
144 }
145 result
146 }
147
148 #[instrument(skip(self))]
150 pub async fn send_keepalive(&mut self) -> Result<()> {
151 debug!("Sending keep-alive ping");
152 let ping = build_ping();
153 self.send(ping).await
154 }
155
156 #[instrument(skip(self))]
158 pub async fn recv_with_keepalive(
159 &mut self,
160 timeout_duration: std::time::Duration,
161 ) -> Result<Message> {
162 let mut ping_interval = time::interval(self.keep_alive.ping_interval());
163
164 let timeout = time::sleep(timeout_duration);
165 tokio::pin!(timeout);
166
167 loop {
168 tokio::select! {
169 _ = ping_interval.tick() => {
171 if self.keep_alive.should_ping() {
172 self.send_keepalive().await?;
173 }
174
175 if self.keep_alive.is_connection_dead() {
177 warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(),
178 "Connection appears dead");
179 return Err(ProtocolError::ConnectionTimeout);
180 }
181 }
182
183 recv_result = self.conn.secure_recv::<Message>() => {
185 match recv_result {
186 Ok(msg) => {
187 self.keep_alive.update_recv();
188
189 if !is_pong(&msg) {
191 return Ok(msg);
192 } else {
193 debug!("Received pong response");
194 }
196 }
197 Err(ProtocolError::Timeout) => {
198 continue;
200 }
201 Err(e) => return Err(e),
202 }
203 }
204
205 _ = &mut timeout => {
207 return Err(ProtocolError::Timeout);
208 }
209 }
210 }
211 }
212
213 #[instrument(skip(self, msg))]
215 pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
216 self.send(msg).await?;
217 self.recv_with_keepalive(self.config.response_timeout).await
219 }
220}