1use crate::onion::{CircuitManager, DirectoryClient, MLKEMOnionRouter};
2use crate::types::{NetworkError, NetworkMessage, PeerId, RoutingStrategy};
3use rand::seq::{IteratorRandom, SliceRandom};
4use rand::thread_rng;
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7use tokio::sync::{Mutex, RwLock};
8
9#[derive(Debug, Clone)]
11pub struct HopInfo {
12 #[allow(dead_code)]
13 peer_id: PeerId,
14 known_peers: HashSet<PeerId>,
15 layer_keys: HashMap<usize, Vec<u8>>,
16}
17
18impl HopInfo {
19 pub fn can_decrypt_layer(&self, layer: usize) -> bool {
21 self.layer_keys.contains_key(&layer)
22 }
23
24 pub fn knows_peer(&self, peer: &PeerId) -> bool {
26 self.known_peers.contains(peer)
27 }
28}
29
30#[derive(Clone)]
32pub struct Router {
33 peers: Arc<RwLock<HashSet<PeerId>>>,
35 hop_info: Arc<RwLock<HashMap<PeerId, HopInfo>>>,
37 onion_router: Arc<Mutex<MLKEMOnionRouter>>,
39 circuit_manager: Arc<Mutex<CircuitManager>>,
41 directory_client: Arc<DirectoryClient>,
43}
44
45impl Router {
46 pub async fn new() -> Result<Self, NetworkError> {
48 let onion_router = MLKEMOnionRouter::new().await.map_err(|e| {
49 NetworkError::RoutingError(format!("Failed to create onion router: {:?}", e))
50 })?;
51
52 let circuit_manager = Arc::new(Mutex::new(CircuitManager::new()));
53 let directory_client = Arc::new(DirectoryClient::new());
54
55 Ok(Self {
56 peers: Arc::new(RwLock::new(HashSet::new())),
57 hop_info: Arc::new(RwLock::new(HashMap::new())),
58 onion_router: Arc::new(Mutex::new(onion_router)),
59 circuit_manager,
60 directory_client,
61 })
62 }
63
64 pub async fn add_peer(&self, peer_id: PeerId) {
66 let mut peers = self.peers.write().await;
67 peers.insert(peer_id);
68
69 let mut hop_info = self.hop_info.write().await;
71 let mut known_peers = HashSet::new();
72
73 let all_peers: Vec<_> = peers.iter().filter(|&&p| p != peer_id).cloned().collect();
75 let mut rng = thread_rng();
76 let subset_size = (all_peers.len() / 2).clamp(1, 3); let known_subset: Vec<_> = all_peers
78 .choose_multiple(&mut rng, subset_size)
79 .cloned()
80 .collect();
81
82 for peer in known_subset {
83 known_peers.insert(peer);
84 }
85
86 let mut layer_keys = HashMap::new();
88 for i in 0..5 {
89 layer_keys.insert(i, vec![i as u8; 32]); }
92
93 hop_info.insert(
94 peer_id,
95 HopInfo {
96 peer_id,
97 known_peers,
98 layer_keys,
99 },
100 );
101 }
102
103 pub async fn route(
105 &self,
106 message: &NetworkMessage,
107 strategy: RoutingStrategy,
108 ) -> Result<Vec<PeerId>, NetworkError> {
109 match strategy {
110 RoutingStrategy::Anonymous { hops } => self.route_anonymous(message, hops).await,
111 RoutingStrategy::Direct(peer_bytes) => {
112 if peer_bytes.len() == 32 {
114 let mut peer_id_bytes = [0u8; 32];
115 peer_id_bytes.copy_from_slice(&peer_bytes);
116 Ok(vec![PeerId::from_bytes(peer_id_bytes)])
117 } else {
118 Err(NetworkError::RoutingError("Invalid peer ID format".into()))
119 }
120 }
121 RoutingStrategy::Flood => {
122 let peers = self.peers.read().await;
123 Ok(peers.iter().cloned().collect())
124 }
125 RoutingStrategy::RandomSubset(count) => {
126 let peers = self.peers.read().await;
127 let mut rng = thread_rng();
128 let selected: Vec<_> = peers
129 .iter()
130 .choose_multiple(&mut rng, count)
131 .into_iter()
132 .cloned()
133 .collect();
134 Ok(selected)
135 }
136 }
137 }
138
139 async fn route_anonymous(
141 &self,
142 message: &NetworkMessage,
143 hops: usize,
144 ) -> Result<Vec<PeerId>, NetworkError> {
145 let actual_hops = hops.max(3);
147 let peers = self.peers.read().await;
148
149 let source_peer = if message.source.len() == 32 {
151 let mut bytes = [0u8; 32];
152 bytes.copy_from_slice(&message.source);
153 Some(PeerId::from_bytes(bytes))
154 } else {
155 None
156 };
157
158 let dest_peer = if message.destination.len() == 32 {
159 let mut bytes = [0u8; 32];
160 bytes.copy_from_slice(&message.destination);
161 Some(PeerId::from_bytes(bytes))
162 } else {
163 None
164 };
165
166 let available_peers: Vec<_> = peers
167 .iter()
168 .filter(|&&p| Some(p) != source_peer && Some(p) != dest_peer)
169 .cloned()
170 .collect();
171
172 if available_peers.len() < hops {
173 return Err(NetworkError::RoutingError(
174 "Not enough peers for anonymous routing".into(),
175 ));
176 }
177
178 let mut circuit_mgr = self.circuit_manager.lock().await;
180 let circuit_id = circuit_mgr
181 .build_circuit(actual_hops, &self.directory_client)
182 .await
183 .map_err(|e| NetworkError::RoutingError(format!("Circuit build failed: {:?}", e)))?;
184
185 circuit_mgr.activate_circuit(circuit_id).map_err(|e| {
187 NetworkError::RoutingError(format!("Circuit activation failed: {:?}", e))
188 })?;
189
190 let circuit = circuit_mgr
192 .get_active_circuit()
193 .ok_or_else(|| NetworkError::RoutingError("No active circuit available".into()))?;
194
195 let onion_router = self.onion_router.lock().await;
197 let _layers = onion_router
198 .encrypt_layers(message.payload.clone(), circuit.hops.clone())
199 .await
200 .map_err(|e| NetworkError::RoutingError(format!("Onion encryption failed: {:?}", e)))?;
201
202 let route: Vec<PeerId> = circuit
204 .hops
205 .iter()
206 .filter_map(|node_id| {
207 if node_id.len() == 32 {
208 let mut peer_id_bytes = [0u8; 32];
209 peer_id_bytes.copy_from_slice(&node_id[..32]);
210 Some(PeerId::from_bytes(peer_id_bytes))
211 } else {
212 None
213 }
214 })
215 .collect();
216
217 circuit_mgr.update_circuit_metrics(circuit_id, message.payload.len() as u64, true);
219
220 Ok(route)
221 }
222
223 #[allow(dead_code)]
225 async fn update_hop_knowledge(&self, route: &[PeerId]) {
226 let mut hop_info = self.hop_info.write().await;
227
228 for (i, &peer_id) in route.iter().enumerate() {
229 if let Some(info) = hop_info.get_mut(&peer_id) {
230 info.known_peers.clear();
232
233 if i > 0 {
235 info.known_peers.insert(route[i - 1]);
236 }
237 if i < route.len() - 1 {
238 info.known_peers.insert(route[i + 1]);
239 }
240
241 info.layer_keys.clear();
243 info.layer_keys.insert(i, vec![i as u8; 32]);
244 }
245 }
246 }
247
248 pub async fn get_hop_info(&self, peer_id: &PeerId) -> Result<HopInfo, NetworkError> {
250 let hop_info = self.hop_info.read().await;
251 hop_info
252 .get(peer_id)
253 .cloned()
254 .ok_or_else(|| NetworkError::RoutingError("Hop information not found".into()))
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MessagePriority;
    use std::time::Duration;

    /// Builds the small high-priority message used throughout these tests.
    fn message_between(source: PeerId, destination: PeerId) -> NetworkMessage {
        NetworkMessage {
            id: "test".into(),
            source: source.to_bytes().to_vec(),
            destination: destination.to_bytes().to_vec(),
            payload: vec![1, 2, 3],
            priority: MessagePriority::High,
            ttl: Duration::from_secs(60),
        }
    }

    /// Creates a router and registers `count` freshly generated peers in order.
    async fn router_with_peers(count: usize) -> (Router, Vec<PeerId>) {
        let router = Router::new().await.unwrap();
        let peers: Vec<PeerId> = (0..count).map(|_| PeerId::random()).collect();
        for peer in &peers {
            router.add_peer(*peer).await;
        }
        (router, peers)
    }

    #[tokio::test]
    async fn test_router_creation() {
        let router = Router::new().await.unwrap();
        let peers = router.peers.read().await;
        assert!(peers.is_empty());
    }

    #[tokio::test]
    async fn test_add_peer() {
        let router = Router::new().await.unwrap();
        let peer_id = PeerId::random();

        router.add_peer(peer_id).await;

        let peers = router.peers.read().await;
        assert!(peers.contains(&peer_id));
    }

    #[tokio::test]
    async fn test_anonymous_routing() {
        let (router, peers) = router_with_peers(5).await;
        let msg = message_between(peers[0], peers[4]);

        let route = router
            .route(&msg, RoutingStrategy::Anonymous { hops: 3 })
            .await
            .unwrap();

        assert_eq!(route.len(), 3);
        assert!(!route.contains(&peers[0]));
        assert!(!route.contains(&peers[4]));
    }

    #[tokio::test]
    async fn test_anonymous_routing_rejects_too_few_peers() {
        // Only the source and destination are registered, so no peer is
        // eligible to act as a hop.
        let (router, peers) = router_with_peers(2).await;
        let msg = message_between(peers[0], peers[1]);

        let result = router
            .route(&msg, RoutingStrategy::Anonymous { hops: 3 })
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_flood_routes_to_every_peer() {
        let (router, peers) = router_with_peers(4).await;
        let msg = message_between(peers[0], peers[1]);

        let route = router.route(&msg, RoutingStrategy::Flood).await.unwrap();
        assert_eq!(route.len(), peers.len());
        assert!(peers.iter().all(|p| route.contains(p)));
    }

    #[tokio::test]
    async fn test_random_subset_is_bounded_by_peer_count() {
        let (router, peers) = router_with_peers(5).await;
        let msg = message_between(peers[0], peers[1]);

        let some = router
            .route(&msg, RoutingStrategy::RandomSubset(2))
            .await
            .unwrap();
        assert_eq!(some.len(), 2);
        assert!(some.iter().all(|p| peers.contains(p)));

        let all = router
            .route(&msg, RoutingStrategy::RandomSubset(10))
            .await
            .unwrap();
        assert_eq!(all.len(), peers.len());
    }

    #[tokio::test]
    async fn test_direct_routing() {
        let (router, peers) = router_with_peers(2).await;
        let msg = message_between(peers[0], peers[1]);

        let route = router
            .route(&msg, RoutingStrategy::Direct(peers[1].to_bytes().to_vec()))
            .await
            .unwrap();
        assert_eq!(route, vec![peers[1]]);

        let invalid = router
            .route(&msg, RoutingStrategy::Direct(vec![0u8; 16]))
            .await;
        assert!(invalid.is_err());
    }

    #[tokio::test]
    async fn test_hop_info_after_add_peer() {
        let (router, peers) = router_with_peers(2).await;

        // The first peer was registered alone, so it knows nobody. The second
        // peer's only candidate is the first.
        let first = router.get_hop_info(&peers[0]).await.unwrap();
        let second = router.get_hop_info(&peers[1]).await.unwrap();
        assert!(!first.knows_peer(&peers[1]));
        assert!(second.knows_peer(&peers[0]));

        assert!(second.can_decrypt_layer(0));
        assert!(second.can_decrypt_layer(4));
        assert!(!second.can_decrypt_layer(5));

        assert!(router.get_hop_info(&PeerId::random()).await.is_err());
    }
}