network_protocol/service/
secure.rs1use crate::core::packet::Packet;
2use crate::utils::crypto::Crypto;
3use crate::error::{Result, ProtocolError};
4use crate::utils::timeout::{with_timeout_error, DEFAULT_TIMEOUT};
5
6use tokio_util::codec::Framed;
7use tokio::net::TcpStream;
8use futures::{SinkExt, StreamExt};
9use std::time::Duration;
10use tracing::{debug, instrument};
11
12pub struct SecureConnection {
13 framed: Framed<TcpStream, crate::core::codec::PacketCodec>,
14 crypto: Crypto,
15 send_timeout: Duration,
16 recv_timeout: Duration,
17 last_activity: std::time::Instant,
18}
19
20impl SecureConnection {
21 pub fn new(framed: Framed<TcpStream, crate::core::codec::PacketCodec>, key: [u8; 32]) -> Self {
22 Self {
23 framed,
24 crypto: Crypto::new(&key),
25 send_timeout: DEFAULT_TIMEOUT,
26 recv_timeout: DEFAULT_TIMEOUT,
27 last_activity: std::time::Instant::now(),
28 }
29 }
30
31 pub fn with_timeouts(mut self, send_timeout: Duration, recv_timeout: Duration) -> Self {
33 self.send_timeout = send_timeout;
34 self.recv_timeout = recv_timeout;
35 self
36 }
37
38 pub fn time_since_last_activity(&self) -> Duration {
40 self.last_activity.elapsed()
41 }
42
43 fn update_activity(&mut self) {
45 self.last_activity = std::time::Instant::now();
46 }
47
48 #[instrument(skip(self, msg), level = "debug")]
49 pub async fn secure_send(&mut self, msg: impl serde::Serialize) -> Result<()> {
50 let data = bincode::serialize(&msg)?;
51 let nonce = Crypto::generate_nonce();
52 let ciphertext = self.crypto.encrypt(&data, &nonce)?;
53
54 let mut final_payload = nonce.to_vec();
55 final_payload.extend(ciphertext);
56
57 let packet = Packet {
58 version: 1,
59 payload: final_payload,
60 };
61
62 debug!(timeout_ms = ?self.send_timeout.as_millis(), "Sending packet with timeout");
63
64 with_timeout_error(
65 async {
66 self.framed.send(packet).await?;
67 Ok(())
68 },
69 self.send_timeout
70 ).await?;
71
72 self.update_activity();
73 Ok(())
74 }
75
76 #[instrument(skip(self), level = "debug")]
77 pub async fn secure_recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<T> {
78 debug!(timeout_ms = ?self.recv_timeout.as_millis(), "Receiving packet with timeout");
79
80 let pkt = with_timeout_error(
81 async {
82 let pkt = self.framed.next().await
83 .ok_or(ProtocolError::ConnectionClosed)??;
84 Ok(pkt)
85 },
86 self.recv_timeout
87 ).await?;
88
89 if pkt.payload.len() < 24 {
90 return Err(ProtocolError::DecryptionFailure);
91 }
92
93 let (nonce_bytes, ciphertext) = pkt.payload.split_at(24);
94 let mut nonce = [0u8; 24];
95 nonce.copy_from_slice(nonce_bytes);
96
97 let plaintext = self.crypto.decrypt(ciphertext, &nonce)?;
98 let msg = bincode::deserialize(&plaintext)?;
99
100 self.update_activity();
101 Ok(msg)
102 }
103}