// network_protocol/service/secure.rs

use std::time::{Duration, Instant};

use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio_util::codec::Framed;
use tracing::{debug, instrument};

use crate::core::packet::Packet;
use crate::error::{ProtocolError, Result};
use crate::utils::crypto::Crypto;
use crate::utils::timeout::{with_timeout_error, DEFAULT_TIMEOUT};
11
12pub struct SecureConnection {
13    framed: Framed<TcpStream, crate::core::codec::PacketCodec>,
14    crypto: Crypto,
15    send_timeout: Duration,
16    recv_timeout: Duration,
17    last_activity: std::time::Instant,
18}
19
20impl SecureConnection {
21    pub fn new(framed: Framed<TcpStream, crate::core::codec::PacketCodec>, key: [u8; 32]) -> Self {
22        Self {
23            framed,
24            crypto: Crypto::new(&key),
25            send_timeout: DEFAULT_TIMEOUT,
26            recv_timeout: DEFAULT_TIMEOUT,
27            last_activity: std::time::Instant::now(),
28        }
29    }
30    
31    /// Set custom timeout durations
32    pub fn with_timeouts(mut self, send_timeout: Duration, recv_timeout: Duration) -> Self {
33        self.send_timeout = send_timeout;
34        self.recv_timeout = recv_timeout;
35        self
36    }
37    
38    /// Get the time since the last activity (send or receive)
39    pub fn time_since_last_activity(&self) -> Duration {
40        self.last_activity.elapsed()
41    }
42    
43    /// Update the last activity timestamp
44    fn update_activity(&mut self) {
45        self.last_activity = std::time::Instant::now();
46    }
47
48    #[instrument(skip(self, msg), level = "debug")]
49    pub async fn secure_send(&mut self, msg: impl serde::Serialize) -> Result<()> {
50        let data = bincode::serialize(&msg)?;
51        let nonce = Crypto::generate_nonce();
52        let ciphertext = self.crypto.encrypt(&data, &nonce)?;
53
54        let mut final_payload = nonce.to_vec();
55        final_payload.extend(ciphertext);
56
57        let packet = Packet {
58            version: 1,
59            payload: final_payload,
60        };
61        
62        debug!(timeout_ms = ?self.send_timeout.as_millis(), "Sending packet with timeout");
63        
64        with_timeout_error(
65            async {
66                self.framed.send(packet).await?;
67                Ok(())
68            },
69            self.send_timeout
70        ).await?;
71        
72        self.update_activity();
73        Ok(())
74    }
75
76    #[instrument(skip(self), level = "debug")]
77    pub async fn secure_recv<T: serde::de::DeserializeOwned>(&mut self) -> Result<T> {
78        debug!(timeout_ms = ?self.recv_timeout.as_millis(), "Receiving packet with timeout");
79        
80        let pkt = with_timeout_error(
81            async {
82                let pkt = self.framed.next().await
83                    .ok_or(ProtocolError::ConnectionClosed)??;
84                Ok(pkt)
85            },
86            self.recv_timeout
87        ).await?;
88
89        if pkt.payload.len() < 24 {
90            return Err(ProtocolError::DecryptionFailure);
91        }
92
93        let (nonce_bytes, ciphertext) = pkt.payload.split_at(24);
94        let mut nonce = [0u8; 24];
95        nonce.copy_from_slice(nonce_bytes);
96
97        let plaintext = self.crypto.decrypt(ciphertext, &nonce)?;
98        let msg = bincode::deserialize(&plaintext)?;
99        
100        self.update_activity();
101        Ok(msg)
102    }
103}