1use crate::hnsw::{DistanceMetric, SearchResult};
11use ipfrs_core::{Cid, Error, Result};
12use ipfrs_network::libp2p::PeerId;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct SemanticDHTConfig {
21 pub embedding_dim: usize,
23 pub replication_factor: usize,
25 pub routing_table_size: usize,
27 pub distance_metric: DistanceMetric,
29 pub max_hops: usize,
31 pub query_timeout_ms: u64,
33}
34
35impl Default for SemanticDHTConfig {
36 fn default() -> Self {
37 Self {
38 embedding_dim: 768,
39 replication_factor: 3,
40 routing_table_size: 20,
41 distance_metric: DistanceMetric::Cosine,
42 max_hops: 5,
43 query_timeout_ms: 5000,
44 }
45 }
46}
47
48#[derive(Debug, Clone)]
50pub struct SemanticPeer {
51 pub peer_id: PeerId,
53 pub embedding: Vec<f32>,
55 pub cluster_id: Option<usize>,
57 pub last_seen: u64,
59 pub load: f32,
61}
62
63impl SemanticPeer {
64 pub fn new(peer_id: PeerId, embedding: Vec<f32>) -> Self {
66 Self {
67 peer_id,
68 embedding,
69 cluster_id: None,
70 last_seen: current_timestamp(),
71 load: 0.0,
72 }
73 }
74
75 pub fn update_last_seen(&mut self) {
77 self.last_seen = current_timestamp();
78 }
79
80 pub fn update_load(&mut self, load: f32) {
82 self.load = load.clamp(0.0, 1.0);
83 }
84}
85
86#[derive(Debug)]
88pub struct SemanticRoutingTable {
89 config: SemanticDHTConfig,
91 peers: Arc<RwLock<HashMap<PeerId, SemanticPeer>>>,
93 clusters: Arc<RwLock<HashMap<usize, Vec<PeerId>>>>,
95 local_embedding: Arc<RwLock<Vec<f32>>>,
97 route_cache: Arc<RwLock<lru::LruCache<u64, Vec<PeerId>>>>,
99}
100
101impl SemanticRoutingTable {
102 pub fn new(config: SemanticDHTConfig) -> Self {
104 let local_embedding = vec![0.0; config.embedding_dim];
105 Self {
106 config,
107 peers: Arc::new(RwLock::new(HashMap::new())),
108 clusters: Arc::new(RwLock::new(HashMap::new())),
109 local_embedding: Arc::new(RwLock::new(local_embedding)),
110 route_cache: Arc::new(RwLock::new(lru::LruCache::new(
111 std::num::NonZeroUsize::new(1000).unwrap(),
112 ))),
113 }
114 }
115
116 pub fn update_local_embedding(&self, embedding: Vec<f32>) -> Result<()> {
118 if embedding.len() != self.config.embedding_dim {
119 return Err(Error::InvalidInput(format!(
120 "Expected embedding dimension {}, got {}",
121 self.config.embedding_dim,
122 embedding.len()
123 )));
124 }
125 *self.local_embedding.write() = embedding;
126 Ok(())
127 }
128
129 pub fn add_peer(&self, peer: SemanticPeer) -> Result<()> {
131 if peer.embedding.len() != self.config.embedding_dim {
132 return Err(Error::InvalidInput(format!(
133 "Expected embedding dimension {}, got {}",
134 self.config.embedding_dim,
135 peer.embedding.len()
136 )));
137 }
138 self.peers.write().insert(peer.peer_id, peer);
139 Ok(())
140 }
141
142 pub fn remove_peer(&self, peer_id: &PeerId) {
144 self.peers.write().remove(peer_id);
145 }
146
147 pub fn find_nearest_peers(&self, embedding: &[f32], k: usize) -> Vec<(PeerId, f32)> {
149 if let Some(cached_peers) = self.get_cached_route(embedding) {
151 let peers = self.peers.read();
153 let result: Vec<(PeerId, f32)> = cached_peers
154 .iter()
155 .filter_map(|peer_id| {
156 peers.get(peer_id).map(|peer| {
157 let distance = self.compute_distance(embedding, &peer.embedding);
158 (*peer_id, distance)
159 })
160 })
161 .take(k)
162 .collect();
163
164 if result.len() == k {
165 return result;
166 }
167 }
169
170 let peers = self.peers.read();
171 let mut distances: Vec<(PeerId, f32)> = peers
172 .values()
173 .map(|peer| {
174 let distance = self.compute_distance(embedding, &peer.embedding);
175 (peer.peer_id, distance)
176 })
177 .collect();
178
179 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
181
182 let result: Vec<(PeerId, f32)> = distances.into_iter().take(k).collect();
183
184 let peer_ids: Vec<PeerId> = result.iter().map(|(id, _)| *id).collect();
186 drop(peers);
187 self.cache_route(embedding, peer_ids);
188
189 result
190 }
191
192 pub fn find_nearest_peers_balanced(&self, embedding: &[f32], k: usize) -> Vec<(PeerId, f32)> {
194 let peers = self.peers.read();
195 let mut scored_peers: Vec<(PeerId, f32)> = peers
196 .values()
197 .map(|peer| {
198 let distance = self.compute_distance(embedding, &peer.embedding);
199 let score = distance * (1.0 + peer.load);
201 (peer.peer_id, score)
202 })
203 .collect();
204
205 scored_peers.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
206 scored_peers.into_iter().take(k).collect()
207 }
208
209 pub fn get_cluster_peers(&self, cluster_id: usize) -> Vec<PeerId> {
211 self.clusters
212 .read()
213 .get(&cluster_id)
214 .cloned()
215 .unwrap_or_default()
216 }
217
218 pub fn num_peers(&self) -> usize {
220 self.peers.read().len()
221 }
222
223 pub fn num_clusters(&self) -> usize {
225 self.clusters.read().len()
226 }
227
228 fn hash_embedding(embedding: &[f32]) -> u64 {
230 use std::collections::hash_map::DefaultHasher;
231 use std::hash::{Hash, Hasher};
232
233 let mut hasher = DefaultHasher::new();
234 for &val in embedding.iter().take(8) {
236 val.to_bits().hash(&mut hasher);
238 }
239 hasher.finish()
240 }
241
242 pub fn get_cached_route(&self, embedding: &[f32]) -> Option<Vec<PeerId>> {
244 let hash = Self::hash_embedding(embedding);
245 self.route_cache.write().get(&hash).cloned()
246 }
247
248 pub fn cache_route(&self, embedding: &[f32], peers: Vec<PeerId>) {
250 let hash = Self::hash_embedding(embedding);
251 self.route_cache.write().put(hash, peers);
252 }
253
254 pub fn clear_route_cache(&self) {
256 self.route_cache.write().clear();
257 }
258
259 pub fn route_cache_stats(&self) -> (usize, usize) {
261 let cache = self.route_cache.read();
262 (cache.len(), cache.cap().get())
263 }
264
265 pub fn update_clusters(&self, num_clusters: usize) -> Result<()> {
267 let peers = self.peers.read();
268 if peers.is_empty() {
269 return Ok(());
270 }
271
272 let embeddings: Vec<Vec<f32>> = peers.values().map(|p| p.embedding.clone()).collect();
273 let peer_ids: Vec<PeerId> = peers.keys().cloned().collect();
274 drop(peers);
275
276 let assignments = self.kmeans_clustering(&embeddings, num_clusters);
278
279 let mut peers_write = self.peers.write();
281 let mut clusters_write = self.clusters.write();
282 clusters_write.clear();
283
284 for (peer_id, cluster_id) in peer_ids.iter().zip(assignments.iter()) {
285 if let Some(peer) = peers_write.get_mut(peer_id) {
286 peer.cluster_id = Some(*cluster_id);
287 }
288 clusters_write
289 .entry(*cluster_id)
290 .or_default()
291 .push(*peer_id);
292 }
293
294 Ok(())
295 }
296
297 fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
299 match self.config.distance_metric {
300 DistanceMetric::L2 => crate::simd::l2_distance(a, b),
301 DistanceMetric::Cosine => crate::simd::cosine_distance(a, b),
302 DistanceMetric::DotProduct => -crate::simd::dot_product(a, b), }
304 }
305
306 fn kmeans_clustering(&self, embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
308 if embeddings.is_empty() || k == 0 {
309 return Vec::new();
310 }
311
312 let k = k.min(embeddings.len());
313 let dim = embeddings[0].len();
314
315 let mut centroids: Vec<Vec<f32>> = (0..k)
317 .map(|i| embeddings[i % embeddings.len()].clone())
318 .collect();
319
320 let mut assignments = vec![0; embeddings.len()];
321 let max_iterations = 10;
322
323 for _ in 0..max_iterations {
324 for (i, embedding) in embeddings.iter().enumerate() {
326 let mut min_dist = f32::MAX;
327 let mut best_cluster = 0;
328
329 for (cluster_id, centroid) in centroids.iter().enumerate() {
330 let dist = self.compute_distance(embedding, centroid);
331 if dist < min_dist {
332 min_dist = dist;
333 best_cluster = cluster_id;
334 }
335 }
336 assignments[i] = best_cluster;
337 }
338
339 let mut new_centroids = vec![vec![0.0; dim]; k];
341 let mut counts = vec![0; k];
342
343 for (embedding, &cluster_id) in embeddings.iter().zip(assignments.iter()) {
344 for (j, &val) in embedding.iter().enumerate() {
345 new_centroids[cluster_id][j] += val;
346 }
347 counts[cluster_id] += 1;
348 }
349
350 for (cluster_id, count) in counts.iter().enumerate() {
351 if *count > 0 {
352 for j in 0..dim {
353 new_centroids[cluster_id][j] /= *count as f32;
354 }
355 }
356 }
357
358 centroids = new_centroids;
359 }
360
361 assignments
362 }
363}
364
365#[derive(Debug, Clone, Serialize, Deserialize)]
367pub struct DHTQuery {
368 pub embedding: Vec<f32>,
370 pub k: usize,
372 pub query_id: String,
374 pub ttl: usize,
376 #[serde(skip)]
378 pub visited: HashSet<PeerId>,
379}
380
381#[derive(Debug, Clone)]
383pub struct DHTQueryResponse {
384 pub query_id: String,
386 pub results: Vec<SearchResult>,
388 pub peer_id: PeerId,
390}
391
392#[derive(Debug, Clone, Serialize, Deserialize)]
394pub enum ReplicationStrategy {
395 NearestPeers(usize),
397 SameCluster,
399 CrossCluster(usize),
401}
402
403#[derive(Debug, Clone)]
405pub struct DHTEntry {
406 pub cid: Cid,
408 pub embedding: Vec<f32>,
410 pub primary_peer: PeerId,
412 pub replicas: Vec<PeerId>,
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize)]
418pub struct SemanticDHTStats {
419 pub num_peers: usize,
421 pub num_clusters: usize,
423 pub num_local_entries: usize,
425 pub queries_processed: u64,
427 pub avg_query_latency_ms: f64,
429 pub multi_hop_queries: u64,
431}
432
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, `0` is returned
/// instead of panicking, so a misconfigured clock cannot crash peer
/// bookkeeping.
fn current_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs())
}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding width of `SemanticDHTConfig::default()`.
    const DIM: usize = 768;

    /// Routing table built from the default configuration.
    fn default_table() -> SemanticRoutingTable {
        SemanticRoutingTable::new(SemanticDHTConfig::default())
    }

    /// Adds `count` random peers; peer `i` gets a constant embedding of
    /// `i * 0.1`.
    fn add_graded_peers(table: &SemanticRoutingTable, count: usize) {
        for i in 0..count {
            let peer = SemanticPeer::new(PeerId::random(), vec![i as f32 * 0.1; DIM]);
            table.add_peer(peer).unwrap();
        }
    }

    #[test]
    fn test_routing_table_creation() {
        let table = default_table();

        assert!(table.update_local_embedding(vec![0.5; DIM]).is_ok());
    }

    #[test]
    fn test_add_peer() {
        let table = default_table();
        let peer = SemanticPeer::new(PeerId::random(), vec![0.5; DIM]);

        assert!(table.add_peer(peer).is_ok());
    }

    #[test]
    fn test_find_nearest_peers() {
        let table = default_table();
        add_graded_peers(&table, 10);

        let query_embedding = vec![0.5; DIM];
        let nearest = table.find_nearest_peers(&query_embedding, 3);

        assert_eq!(nearest.len(), 3);
    }

    #[test]
    fn test_clustering() {
        let table = default_table();

        // Two groups of ten peers that differ only in the sign of the first
        // component.
        for i in 0..20 {
            let mut embedding = vec![0.0; DIM];
            embedding[0] = if i < 10 { 1.0 } else { -1.0 };
            table
                .add_peer(SemanticPeer::new(PeerId::random(), embedding))
                .unwrap();
        }

        assert!(table.update_clusters(2).is_ok());

        let cluster0 = table.get_cluster_peers(0);
        let cluster1 = table.get_cluster_peers(1);

        assert!(!cluster0.is_empty() || !cluster1.is_empty());
    }

    #[test]
    fn test_load_balancing() {
        let table = default_table();

        // Identical embeddings, increasing load.
        for i in 0..5 {
            let mut peer = SemanticPeer::new(PeerId::random(), vec![0.5; DIM]);
            peer.update_load(i as f32 * 0.2);
            table.add_peer(peer).unwrap();
        }

        let query_embedding = vec![0.5; DIM];
        let balanced = table.find_nearest_peers_balanced(&query_embedding, 3);

        assert_eq!(balanced.len(), 3);
    }

    #[test]
    fn test_route_caching() {
        let table = default_table();
        add_graded_peers(&table, 10);

        let query_embedding = vec![0.5; DIM];

        // Nothing has been looked up yet.
        let (cache_size_before, _) = table.route_cache_stats();
        assert_eq!(cache_size_before, 0);

        // The first lookup scans all peers and populates the cache.
        let result1 = table.find_nearest_peers(&query_embedding, 3);
        assert_eq!(result1.len(), 3);

        let (cache_size_after, _) = table.route_cache_stats();
        assert_eq!(cache_size_after, 1);

        // The second lookup must return the same peers in the same order.
        let result2 = table.find_nearest_peers(&query_embedding, 3);
        assert_eq!(result2.len(), 3);

        let ids1: Vec<_> = result1.iter().map(|(id, _)| id).collect();
        let ids2: Vec<_> = result2.iter().map(|(id, _)| id).collect();
        assert_eq!(ids1, ids2);

        table.clear_route_cache();
        let (cache_size_cleared, _) = table.route_cache_stats();
        assert_eq!(cache_size_cleared, 0);
    }
}