1use crate::PeerId;
2use serde::{Deserialize, Serialize};
3
4pub trait HandshakeProtocol: Send + Sync + 'static {
5 fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8>;
6 fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>;
7 fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8>;
8 fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String>;
9}
10
/// Fixed string appended to every hash input, after the shared key.
const HASH_SEED: &str = "atm0s-small-p2p";

/// Maximum accepted age of a handshake, in the same unit as the `now` timestamps.
// NOTE(review): presumably milliseconds (30 s) since the tests pass `now_ms()` — confirm.
const HANDSHAKE_TIMEOUT: u64 = 30_000;
13
14#[derive(Debug, Serialize, Deserialize)]
15struct HandshakeMessage {
16 payload: Vec<u8>,
17 signature: Vec<u8>,
18}
19
20#[derive(Debug, Serialize, Deserialize)]
21struct HandshakeData {
22 from: PeerId,
23 to: PeerId,
24 timestamp: u64,
25 is_initiator: bool,
26}
27
/// [`HandshakeProtocol`] implementation in which both peers hold the same secret string.
///
/// A message is accepted only when its digest matches the one recomputed locally from the
/// payload and this secret.
pub struct SharedKeyHandshake {
    /// Secret mixed into every handshake digest.
    secure_key: String,
}
35
36impl From<&str> for SharedKeyHandshake {
37 fn from(value: &str) -> Self {
38 Self { secure_key: value.to_owned() }
39 }
40}
41
42impl SharedKeyHandshake {
43 fn generate_handshake(&self, from: PeerId, to: PeerId, is_client: bool, now: u64) -> Vec<u8> {
44 let handshake_data = HandshakeData {
45 from,
46 to,
47 timestamp: now,
48 is_initiator: is_client,
49 };
50
51 let data = bincode::serialize(&handshake_data).unwrap();
52 let mut hash_input = data.clone();
53 hash_input.extend_from_slice(self.secure_key.as_bytes());
54 hash_input.extend_from_slice(HASH_SEED.as_bytes());
55
56 let hash = blake3::hash(&hash_input).as_bytes().to_vec();
57
58 let handshake = HandshakeMessage { payload: data, signature: hash };
59 bincode::serialize(&handshake).unwrap()
60 }
61
62 fn validate_handshake(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, expected_is_client: bool, current_ts: u64) -> Result<(), String> {
63 let handshake: HandshakeMessage = bincode::deserialize(&data).map_err(|_| "Invalid handshake format".to_string())?;
64
65 let handshake_data: HandshakeData = bincode::deserialize(&handshake.payload).map_err(|_| "Invalid handshake data format".to_string())?;
66
67 if current_ts > handshake_data.timestamp + HANDSHAKE_TIMEOUT {
69 return Err(format!("Handshake timeout {} vs {}", current_ts, handshake_data.timestamp));
70 }
71
72 if handshake_data.from != expected_from || handshake_data.to != expected_to {
74 return Err("Invalid peer IDs".to_string());
75 }
76
77 if handshake_data.is_initiator != expected_is_client {
79 return Err("Invalid client/server role".to_string());
80 }
81
82 let mut hash_input = handshake.payload;
84 hash_input.extend_from_slice(self.secure_key.as_bytes());
85 hash_input.extend_from_slice(HASH_SEED.as_bytes());
86 let expected_hash = blake3::hash(&hash_input).as_bytes().to_vec();
87
88 if handshake.signature != expected_hash {
89 return Err("Invalid handshake hash".to_string());
90 }
91
92 Ok(())
93 }
94}
95
96impl HandshakeProtocol for SharedKeyHandshake {
97 fn create_request(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8> {
98 self.generate_handshake(from, to, true, now)
99 }
100
101 fn verify_request(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> {
102 self.validate_handshake(data, expected_from, expected_to, true, now)
103 }
104
105 fn create_response(&self, from: PeerId, to: PeerId, now: u64) -> Vec<u8> {
106 self.generate_handshake(from, to, false, now)
107 }
108
109 fn verify_response(&self, data: Vec<u8>, expected_from: PeerId, expected_to: PeerId, now: u64) -> Result<(), String> {
110 self.validate_handshake(data, expected_from, expected_to, false, now)
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::now_ms;

    /// The two peer identities shared by every test.
    fn peers() -> (PeerId, PeerId) {
        (PeerId::from(1), PeerId::from(2))
    }

    /// A request and the matching response both verify under the same key.
    #[test]
    fn test_handshake_flow() {
        let secure = SharedKeyHandshake::from("test_key");
        let (peer1, peer2) = peers();

        let request = secure.create_request(peer1, peer2, now_ms());
        let request_check = secure.verify_request(request, peer1, peer2, now_ms());
        assert!(request_check.is_ok());

        let response = secure.create_response(peer2, peer1, now_ms());
        let response_check = secure.verify_response(response, peer2, peer1, now_ms());
        assert!(response_check.is_ok());
    }

    /// A request created with one key is rejected by a verifier holding another key.
    #[test]
    fn test_invalid_handshake() {
        let secure1 = SharedKeyHandshake::from("key1");
        let secure2 = SharedKeyHandshake::from("key2");
        let (peer1, peer2) = peers();

        let request = secure1.create_request(peer1, peer2, now_ms());
        let outcome = secure2.verify_request(request, peer1, peer2, now_ms());
        assert!(outcome.is_err());
    }

    /// A verifier clock slightly behind the sender is accepted; a message older than
    /// `HANDSHAKE_TIMEOUT` is rejected.
    #[test]
    fn test_handshake_timeout() {
        let secure = SharedKeyHandshake::from("test_key");
        let (peer1, peer2) = peers();

        let request = secure.create_request(peer2, peer1, 1000);
        let early = secure.verify_request(request, peer2, peer1, 980);
        assert!(early.is_ok());

        let request = secure.create_request(peer2, peer1, 1000);
        let expired = secure.verify_request(request, peer2, peer1, 1000 + HANDSHAKE_TIMEOUT + 1);
        assert!(expired.is_err());
    }
}