atm0s_small_p2p/
secure.rs

1use crate::PeerId;
2use serde::{Deserialize, Serialize};
3
/// Pluggable authentication for the two-message connection handshake.
///
/// One side produces bytes with [`create_request`](Self::create_request) that the other side
/// checks with [`verify_request`](Self::verify_request); the reply is produced with
/// [`create_response`](Self::create_response) and checked with
/// [`verify_response`](Self::verify_response).
///
/// `now` is the caller's current timestamp (presumably milliseconds, as this file's tests pass
/// `now_ms()` — confirm for other implementations).
pub trait HandshakeProtocol: Send + Sync + 'static {
    /// Builds the request bytes that peer `from` sends to peer `to` at time `now`.
    fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8>;
    /// Checks request bytes that should come from `expected_from` and be addressed to
    /// `expected_to`; returns `Err` with a human-readable reason when they must be rejected.
    fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>;
    /// Builds the response bytes that peer `from` sends back to peer `to` at time `now`.
    fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8>;
    /// Checks response bytes that should come from `expected_from` and be addressed to
    /// `expected_to`; returns `Err` with a human-readable reason when they must be rejected.
    fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>;
}
10
/// Fixed application string appended after the secure key in every handshake hash input.
const HASH_SEED: &str = "atm0s-small-p2p";
/// Maximum accepted age of a handshake, in the unit of the `now`/`timestamp` values
/// (milliseconds in this file's tests, which use `now_ms()`).
const HANDSHAKE_TIMEOUT: u64 = 30 * 1_000;
13
/// Wire envelope of a handshake: the serialized [`HandshakeData`] plus its hash signature.
///
/// Field order is significant: bincode encodes struct fields positionally, so reordering them
/// changes the wire format.
#[derive(Debug, Serialize, Deserialize)]
struct HandshakeMessage {
    /// bincode-serialized `HandshakeData`.
    payload: Vec<u8>,
    /// blake3 hash of `payload || secure_key || HASH_SEED` (see `SharedKeyHandshake`).
    signature: Vec<u8>,
}
19
/// Signed contents of a handshake; serialized with bincode into `HandshakeMessage::payload`.
///
/// Field order is part of the wire format (bincode encodes struct fields positionally).
#[derive(Debug, Serialize, Deserialize)]
struct HandshakeData {
    /// Peer that created the handshake.
    from: PeerId,
    /// Peer the handshake is addressed to.
    to: PeerId,
    /// Sender's clock at creation time; the receiver rejects the handshake once it is older
    /// than `HANDSHAKE_TIMEOUT`.
    timestamp: u64,
    /// `true` for requests, `false` for responses, so that one cannot be accepted as the other.
    is_initiator: bool,
}
27
/// Handshake authentication based on a pre-shared secret (`secure_key`).
///
/// The handshake fields ([`HandshakeData`]) are serialized with bincode into a payload, and the
/// signature is `blake3(payload || secure_key || HASH_SEED)`. The receiver recomputes the hash
/// with its own key and compares it, so validation only succeeds when both nodes use the same
/// `secure_key`.
///
/// The embedded `timestamp` mitigates replay attacks: a handshake that is older than
/// [`HANDSHAKE_TIMEOUT`] at verification time is rejected.
// NOTE(review): keep this type without a `Debug` derive — a derived impl would print
// `secure_key` into logs.
pub struct SharedKeyHandshake {
    /// Pre-shared secret mixed into every handshake hash; it is never placed in the message.
    secure_key: String,
}
35
36impl From<&str> for SharedKeyHandshake {
37    fn from(value: &str) -> Self {
38        Self { secure_key: value.to_owned() }
39    }
40}
41
42impl SharedKeyHandshake {
43    fn generate_handshake(&self, from: PeerId, to: PeerId, is_client: bool, now: u64) -> Vec<u8> {
44        let handshake_data = HandshakeData {
45            from,
46            to,
47            timestamp: now,
48            is_initiator: is_client,
49        };
50
51        let data = bincode::serialize(&handshake_data).unwrap();
52        let mut hash_input = data.clone();
53        hash_input.extend_from_slice(self.secure_key.as_bytes());
54        hash_input.extend_from_slice(HASH_SEED.as_bytes());
55
56        let hash = blake3::hash(&hash_input).as_bytes().to_vec();
57
58        let handshake = HandshakeMessage { payload: data, signature: hash };
59        bincode::serialize(&handshake).unwrap()
60    }
61
62    fn validate_handshake(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, expected_is_client: bool, current_ts: u64) -> Result<(), String> {
63        let handshake: HandshakeMessage = bincode::deserialize(&data).map_err(|_| "Invalid handshake format".to_string())?;
64
65        let handshake_data: HandshakeData = bincode::deserialize(&handshake.payload).map_err(|_| "Invalid handshake data format".to_string())?;
66
67        // Verify timestamp
68        if current_ts > handshake_data.timestamp + HANDSHAKE_TIMEOUT {
69            return Err(format!("Handshake timeout {} vs {}", current_ts, handshake_data.timestamp));
70        }
71
72        // Verify peer IDs
73        if handshake_data.from != expected_from || handshake_data.to != expected_to {
74            return Err("Invalid peer IDs".to_string());
75        }
76
77        // Verify client/server role
78        if handshake_data.is_initiator != expected_is_client {
79            return Err("Invalid client/server role".to_string());
80        }
81
82        // Verify hash
83        let mut hash_input = handshake.payload;
84        hash_input.extend_from_slice(self.secure_key.as_bytes());
85        hash_input.extend_from_slice(HASH_SEED.as_bytes());
86        let expected_hash = blake3::hash(&hash_input).as_bytes().to_vec();
87
88        if handshake.signature != expected_hash {
89            return Err("Invalid handshake hash".to_string());
90        }
91
92        Ok(())
93    }
94}
95
96impl HandshakeProtocol for SharedKeyHandshake {
97    fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8> {
98        self.generate_handshake(from, to, true, now)
99    }
100
101    fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> {
102        self.validate_handshake(data, expected_from, expected_to, true, now)
103    }
104
105    fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8> {
106        self.generate_handshake(from, to, false, now)
107    }
108
109    fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> {
110        self.validate_handshake(data, expected_from, expected_to, false, now)
111    }
112}
113
#[cfg(test)]
mod tests {
    use crate::now_ms;

    use super::*;

    #[test]
    fn test_handshake_flow() {
        let secure = SharedKeyHandshake::from("test_key");
        let peer1 = PeerId::from(1);
        let peer2 = PeerId::from(2);

        // Test request handshake
        let request = secure.create_request(peer1, peer2, now_ms());
        assert!(secure.verify_request(request, peer1, peer2, now_ms()).is_ok());

        // Test response handshake
        let response = secure.create_response(peer2, peer1, now_ms());
        assert!(secure.verify_response(response, peer2, peer1, now_ms()).is_ok());
    }

    #[test]
    fn test_invalid_handshake() {
        let secure1 = SharedKeyHandshake::from("key1");
        let secure2 = SharedKeyHandshake::from("key2");
        let peer1 = PeerId::from(1);
        let peer2 = PeerId::from(2);

        let request = secure1.create_request(peer1, peer2, now_ms());
        assert!(secure2.verify_request(request, peer1, peer2, now_ms()).is_err());
    }

    #[test]
    fn test_handshake_timeout() {
        let secure = SharedKeyHandshake::from("test_key");
        let peer1 = PeerId::from(1);
        let peer2 = PeerId::from(2);

        // Sender clock slightly ahead of the receiver: the handshake is accepted.
        let request = secure.create_request(peer2, peer1, 1000);
        assert!(secure.verify_request(request, peer2, peer1, 980).is_ok());

        // Exactly HANDSHAKE_TIMEOUT old: still accepted (the limit is inclusive).
        let request = secure.create_request(peer2, peer1, 1000);
        assert!(secure.verify_request(request, peer2, peer1, 1000 + HANDSHAKE_TIMEOUT).is_ok());

        // Older than HANDSHAKE_TIMEOUT: rejected.
        let request = secure.create_request(peer2, peer1, 1000);
        assert!(secure.verify_request(request, peer2, peer1, 1000 + HANDSHAKE_TIMEOUT + 1).is_err());
    }

    #[test]
    fn test_peer_id_mismatch() {
        let secure = SharedKeyHandshake::from("test_key");
        let peer1 = PeerId::from(1);
        let peer2 = PeerId::from(2);
        let peer3 = PeerId::from(3);
        let now = now_ms();

        let request = secure.create_request(peer1, peer2, now);
        // Wrong sender.
        assert!(secure.verify_request(request.clone(), peer3, peer2, now).is_err());
        // Wrong receiver.
        assert!(secure.verify_request(request, peer1, peer3, now).is_err());
    }

    #[test]
    fn test_role_mismatch() {
        let secure = SharedKeyHandshake::from("test_key");
        let peer1 = PeerId::from(1);
        let peer2 = PeerId::from(2);
        let now = now_ms();

        // A request must not be accepted as a response.
        let request = secure.create_request(peer1, peer2, now);
        assert!(secure.verify_response(request, peer1, peer2, now).is_err());

        // A response must not be accepted as a request.
        let response = secure.create_response(peer2, peer1, now);
        assert!(secure.verify_request(response, peer2, peer1, now).is_err());
    }

    #[test]
    fn test_tampered_or_garbage_handshake() {
        let secure = SharedKeyHandshake::from("test_key");
        let peer1 = PeerId::from(1);
        let peer2 = PeerId::from(2);
        let now = now_ms();

        // The signature bytes are the tail of the serialized message; flipping one must fail.
        let mut tampered = secure.create_request(peer1, peer2, now);
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(secure.verify_request(tampered, peer1, peer2, now).is_err());

        // Malformed input is rejected rather than panicking.
        assert!(secure.verify_request(Vec::new(), peer1, peer2, now).is_err());
        assert!(secure.verify_request(vec![0xff; 8], peer1, peer2, now).is_err());
    }
}
160}