network_protocol/service/
secure.rs1use crate::core::packet::Packet;
2use crate::error::{ProtocolError, Result};
3use crate::utils::crypto::Crypto;
4use crate::utils::timeout::{with_timeout_error, DEFAULT_TIMEOUT};
5
6use futures::{SinkExt, StreamExt};
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio_util::codec::Framed;
10use tracing::{debug, instrument};
11use zeroize::Zeroize;
12
13pub struct SecureConnection {
14 framed: Framed<TcpStream, crate::core::codec::PacketCodec>,
15 crypto: Crypto,
16 send_timeout: Duration,
17 recv_timeout: Duration,
18 last_activity: std::time::Instant,
19}
20
21impl SecureConnection {
22 pub fn new(
23 framed: Framed<TcpStream, crate::core::codec::PacketCodec>,
24 mut key: [u8; 32],
25 ) -> Self {
26 let conn = Self {
27 framed,
28 crypto: Crypto::new(&key),
29 send_timeout: DEFAULT_TIMEOUT,
30 recv_timeout: DEFAULT_TIMEOUT,
31 last_activity: std::time::Instant::now(),
32 };
33
34 key.zeroize();
36
37 conn
38 }
39
40 pub fn with_timeouts(mut self, send_timeout: Duration, recv_timeout: Duration) -> Self {
42 self.send_timeout = send_timeout;
43 self.recv_timeout = recv_timeout;
44 self
45 }
46
47 pub fn time_since_last_activity(&self) -> Duration {
49 self.last_activity.elapsed()
50 }
51
52 fn update_activity(&mut self) {
54 self.last_activity = std::time::Instant::now();
55 }
56
57 #[instrument(skip(self, msg), level = "debug")]
58 pub async fn secure_send(&mut self, msg: impl serde::Serialize) -> Result<()> {
59 let data = bincode::serialize(&msg)?;
60 let mut nonce = Crypto::generate_nonce();
61 let ciphertext = self.crypto.encrypt(&data, &nonce)?;
62
63 let data_to_zero = data;
65 drop(data_to_zero); let mut final_payload = nonce.to_vec();
68 final_payload.extend(ciphertext);
69
70 nonce.zeroize();
72
73 let packet = Packet {
74 version: 1,
75 payload: final_payload,
76 };
77
78 debug!(timeout_ms = ?self.send_timeout.as_millis(), "Sending packet with timeout");
79
80 with_timeout_error(
81 async {
82 self.framed.send(packet).await?;
83 Ok(())
84 },
85 self.send_timeout,
86 )
87 .await?;
88
89 self.update_activity();
90 Ok(())
91 }
92
93 #[instrument(skip(self), level = "debug")]
94 pub async fn secure_recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<T> {
95 debug!(timeout_ms = ?self.recv_timeout.as_millis(), "Receiving packet with timeout");
96
97 let pkt = with_timeout_error(
98 async {
99 let pkt = self
100 .framed
101 .next()
102 .await
103 .ok_or(ProtocolError::ConnectionClosed)??;
104 Ok(pkt)
105 },
106 self.recv_timeout,
107 )
108 .await?;
109
110 if pkt.payload.len() < 24 {
111 return Err(ProtocolError::DecryptionFailure);
112 }
113
114 let (nonce_bytes, ciphertext) = pkt.payload.split_at(24);
115 let mut nonce = [0u8; 24];
116 nonce.copy_from_slice(nonce_bytes);
117
118 let plaintext = self.crypto.decrypt(ciphertext, &nonce)?;
119
120 nonce.zeroize();
122
123 let msg = bincode::deserialize(&plaintext)?;
124
125 let plaintext_to_zero = plaintext;
127 drop(plaintext_to_zero); self.update_activity();
130 Ok(msg)
131 }
132}