network_protocol/service/
client.rs1use futures::{SinkExt, StreamExt};
2use tokio::time;
3use tracing::{debug, info, warn, instrument};
4
5use crate::config::ClientConfig;
6
7use crate::core::packet::Packet;
8use crate::protocol::message::Message;
9use crate::protocol::handshake::{client_secure_handshake_init, client_secure_handshake_verify, client_derive_session_key};
11use crate::protocol::heartbeat::{build_ping, is_pong};
12use crate::protocol::keepalive::KeepAliveManager;
13use crate::service::secure::SecureConnection;
14use crate::transport::remote;
15use crate::error::{Result, ProtocolError};
16use crate::utils::timeout::with_timeout_error;
17
18pub struct Client {
20 conn: SecureConnection,
21 keep_alive: KeepAliveManager,
22 config: ClientConfig,
23}
24
25impl Client {
26 #[instrument(skip(addr), fields(address = %addr))]
28 pub async fn connect(addr: &str) -> Result<Self> {
29 let config = ClientConfig {
30 address: addr.to_string(),
31 ..Default::default()
32 };
33 Self::connect_with_config(config).await
34 }
35
36 #[instrument(skip(config), fields(address = %config.address))]
38 pub async fn connect_with_config(config: ClientConfig) -> Result<Self> {
39 let mut framed = with_timeout_error(
41 async {
42 remote::connect(&config.address).await
43 },
44 config.connection_timeout
45 ).await?;
46
47 let init_msg = client_secure_handshake_init()?;
57 let init_bytes = bincode::serialize(&init_msg)?;
58 framed.send(Packet {
59 version: 1,
60 payload: init_bytes,
61 }).await?;
62
63 let response = with_timeout_error(
65 async {
66 let packet = framed.next().await
67 .ok_or(ProtocolError::ConnectionClosed)?
68 .map_err(|e| ProtocolError::TransportError(e.to_string()))?;
69 bincode::deserialize::<Message>(&packet.payload)
70 .map_err(|e| ProtocolError::DeserializeError(e.to_string()))
71 },
72 config.connection_timeout
73 ).await?;
74
75 let (server_pub_key, server_nonce, nonce_verification) = match response {
78 Message::SecureHandshakeResponse { pub_key, nonce, nonce_verification } => {
79 (pub_key, nonce, nonce_verification)
80 },
81 _ => return Err(ProtocolError::HandshakeError("Invalid server response message type".into())),
82 };
83
84 let verify_msg = client_secure_handshake_verify(server_pub_key, server_nonce, nonce_verification)?;
86
87 let verify_bytes = bincode::serialize(&verify_msg)?;
88 framed.send(Packet {
89 version: 1,
90 payload: verify_bytes,
91 }).await?;
92
93 let key = client_derive_session_key()?;
95 let conn = SecureConnection::new(framed, key);
96
97 let dead_timeout = config.heartbeat_interval.mul_f32(4.0); let keep_alive = KeepAliveManager::with_settings(config.heartbeat_interval, dead_timeout);
100
101 info!("Connection established successfully");
102 Ok(Self { conn, keep_alive, config })
103 }
104
105 #[instrument(skip(self, msg))]
107 pub async fn send(&mut self, msg: Message) -> Result<()> {
108 let result = self.conn.secure_send(msg).await;
109 if result.is_ok() {
110 self.keep_alive.update_send();
111 }
112 result
113 }
114
115 #[instrument(skip(self))]
117 pub async fn recv(&mut self) -> Result<Message> {
118 let result = self.conn.secure_recv().await;
119 if result.is_ok() {
120 self.keep_alive.update_recv();
121 }
122 result
123 }
124
125 #[instrument(skip(self))]
127 pub async fn send_keepalive(&mut self) -> Result<()> {
128 debug!("Sending keep-alive ping");
129 let ping = build_ping();
130 self.send(ping).await
131 }
132
133 #[instrument(skip(self))]
135 pub async fn recv_with_keepalive(&mut self, timeout_duration: std::time::Duration) -> Result<Message> {
136 let mut ping_interval = time::interval(self.keep_alive.ping_interval());
137
138 let timeout = time::sleep(timeout_duration);
139 tokio::pin!(timeout);
140
141 loop {
142 tokio::select! {
143 _ = ping_interval.tick() => {
145 if self.keep_alive.should_ping() {
146 self.send_keepalive().await?;
147 }
148
149 if self.keep_alive.is_connection_dead() {
151 warn!(dead_seconds = ?self.keep_alive.time_since_last_recv().as_secs(),
152 "Connection appears dead");
153 return Err(ProtocolError::ConnectionTimeout);
154 }
155 }
156
157 recv_result = self.conn.secure_recv::<Message>() => {
159 match recv_result {
160 Ok(msg) => {
161 self.keep_alive.update_recv();
162
163 if !is_pong(&msg) {
165 return Ok(msg);
166 } else {
167 debug!("Received pong response");
168 }
170 }
171 Err(ProtocolError::Timeout) => {
172 continue;
174 }
175 Err(e) => return Err(e),
176 }
177 }
178
179 _ = &mut timeout => {
181 return Err(ProtocolError::Timeout);
182 }
183 }
184 }
185 }
186
187 #[instrument(skip(self, msg))]
189 pub async fn send_and_wait(&mut self, msg: Message) -> Result<Message> {
190 self.send(msg).await?;
191 self.recv_with_keepalive(self.config.response_timeout).await
193 }
194}