network_protocol/service/secure.rs

1use crate::core::packet::Packet;
2use crate::error::{ProtocolError, Result};
3use crate::utils::crypto::Crypto;
4use crate::utils::timeout::{with_timeout_error, DEFAULT_TIMEOUT};
5
6use futures::{SinkExt, StreamExt};
7use std::time::Duration;
8use tokio::net::TcpStream;
9use tokio_util::codec::Framed;
10use tracing::{debug, instrument};
11use zeroize::Zeroize;
12
13pub struct SecureConnection {
14    framed: Framed<TcpStream, crate::core::codec::PacketCodec>,
15    crypto: Crypto,
16    send_timeout: Duration,
17    recv_timeout: Duration,
18    last_activity: std::time::Instant,
19}
20
21impl SecureConnection {
22    pub fn new(
23        framed: Framed<TcpStream, crate::core::codec::PacketCodec>,
24        mut key: [u8; 32],
25    ) -> Self {
26        let conn = Self {
27            framed,
28            crypto: Crypto::new(&key),
29            send_timeout: DEFAULT_TIMEOUT,
30            recv_timeout: DEFAULT_TIMEOUT,
31            last_activity: std::time::Instant::now(),
32        };
33
34        // Zeroize the key after it's been used to initialize the crypto object
35        key.zeroize();
36
37        conn
38    }
39
40    /// Set custom timeout durations
41    pub fn with_timeouts(mut self, send_timeout: Duration, recv_timeout: Duration) -> Self {
42        self.send_timeout = send_timeout;
43        self.recv_timeout = recv_timeout;
44        self
45    }
46
47    /// Get the time since the last activity (send or receive)
48    pub fn time_since_last_activity(&self) -> Duration {
49        self.last_activity.elapsed()
50    }
51
52    /// Update the last activity timestamp
53    fn update_activity(&mut self) {
54        self.last_activity = std::time::Instant::now();
55    }
56
57    #[instrument(skip(self, msg), level = "debug")]
58    pub async fn secure_send(&mut self, msg: impl serde::Serialize) -> Result<()> {
59        let data = bincode::serialize(&msg)?;
60        let mut nonce = Crypto::generate_nonce();
61        let ciphertext = self.crypto.encrypt(&data, &nonce)?;
62
63        // Zeroize the plaintext data after encryption to prevent lingering in memory
64        let data_to_zero = data;
65        drop(data_to_zero); // Drop will be handled, but we explicitly mark intent
66
67        let mut final_payload = nonce.to_vec();
68        final_payload.extend(ciphertext);
69
70        // Zeroize nonce after use
71        nonce.zeroize();
72
73        let packet = Packet {
74            version: 1,
75            payload: final_payload,
76        };
77
78        debug!(timeout_ms = ?self.send_timeout.as_millis(), "Sending packet with timeout");
79
80        with_timeout_error(
81            async {
82                self.framed.send(packet).await?;
83                Ok(())
84            },
85            self.send_timeout,
86        )
87        .await?;
88
89        self.update_activity();
90        Ok(())
91    }
92
93    #[instrument(skip(self), level = "debug")]
94    pub async fn secure_recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<T> {
95        debug!(timeout_ms = ?self.recv_timeout.as_millis(), "Receiving packet with timeout");
96
97        let pkt = with_timeout_error(
98            async {
99                let pkt = self
100                    .framed
101                    .next()
102                    .await
103                    .ok_or(ProtocolError::ConnectionClosed)??;
104                Ok(pkt)
105            },
106            self.recv_timeout,
107        )
108        .await?;
109
110        if pkt.payload.len() < 24 {
111            return Err(ProtocolError::DecryptionFailure);
112        }
113
114        let (nonce_bytes, ciphertext) = pkt.payload.split_at(24);
115        let mut nonce = [0u8; 24];
116        nonce.copy_from_slice(nonce_bytes);
117
118        let plaintext = self.crypto.decrypt(ciphertext, &nonce)?;
119
120        // Zeroize nonce after decryption
121        nonce.zeroize();
122
123        let msg = bincode::deserialize(&plaintext)?;
124
125        // Zeroize plaintext after deserialization
126        let plaintext_to_zero = plaintext;
127        drop(plaintext_to_zero); // Explicitly mark for zeroization
128
129        self.update_activity();
130        Ok(msg)
131    }
132}