1use crate::dht::{
10 ReplicationStrategy, SemanticDHTConfig, SemanticDHTStats, SemanticPeer, SemanticRoutingTable,
11};
12use crate::hnsw::{SearchResult, VectorIndex};
13use ipfrs_core::{Cid, Result};
14use ipfrs_network::libp2p::PeerId;
15use parking_lot::RwLock;
16use std::collections::HashMap;
17use std::sync::Arc;
18use std::time::Instant;
19
20pub struct SemanticDHTNode {
22 config: SemanticDHTConfig,
24 local_peer_id: PeerId,
26 local_index: Arc<RwLock<VectorIndex>>,
28 routing_table: Arc<SemanticRoutingTable>,
30 replication_strategy: ReplicationStrategy,
32 stats: Arc<RwLock<SemanticDHTStats>>,
34 pending_queries: Arc<RwLock<HashMap<String, Instant>>>,
36 last_sync_timestamp: Arc<RwLock<u64>>,
38 pending_syncs: Arc<RwLock<usize>>,
40}
41
42impl SemanticDHTNode {
43 pub fn new(config: SemanticDHTConfig, local_peer_id: PeerId, local_index: VectorIndex) -> Self {
45 let routing_table = Arc::new(SemanticRoutingTable::new(config.clone()));
46
47 let stats = SemanticDHTStats {
48 num_peers: 0,
49 num_clusters: 0,
50 num_local_entries: 0,
51 queries_processed: 0,
52 avg_query_latency_ms: 0.0,
53 multi_hop_queries: 0,
54 };
55
56 Self {
57 config,
58 local_peer_id,
59 local_index: Arc::new(RwLock::new(local_index)),
60 routing_table,
61 replication_strategy: ReplicationStrategy::NearestPeers(3),
62 stats: Arc::new(RwLock::new(stats)),
63 pending_queries: Arc::new(RwLock::new(HashMap::new())),
64 last_sync_timestamp: Arc::new(RwLock::new(0)),
65 pending_syncs: Arc::new(RwLock::new(0)),
66 }
67 }
68
69 pub async fn insert(&self, cid: &Cid, embedding: &[f32]) -> Result<()> {
71 self.local_index.write().insert(cid, embedding)?;
73
74 self.update_local_embedding().await?;
76
77 let replica_peers = self.select_replica_peers(embedding).await?;
79
80 tracing::debug!("Would replicate {:?} to {} peers", cid, replica_peers.len());
83
84 Ok(())
85 }
86
87 pub fn search_local(&self, embedding: &[f32], k: usize) -> Result<Vec<SearchResult>> {
89 let index = self.local_index.read();
90 let ef_search = self.config.max_hops * 10; index.search(embedding, k, ef_search)
92 }
93
94 pub async fn search_distributed(
96 &self,
97 embedding: &[f32],
98 k: usize,
99 ) -> Result<Vec<SearchResult>> {
100 let query_id = format!("{:?}-{}", self.local_peer_id, uuid::Uuid::new_v4());
101 let start_time = Instant::now();
102
103 self.pending_queries
105 .write()
106 .insert(query_id.clone(), start_time);
107
108 let mut all_results = self.search_local(embedding, k)?;
110
111 let nearest_peers = self
113 .routing_table
114 .find_nearest_peers_balanced(embedding, self.config.routing_table_size);
115
116 if !nearest_peers.is_empty() && self.config.max_hops > 0 {
118 let remote_results = self
119 .multi_hop_search(embedding, k, query_id.clone(), 0)
120 .await?;
121 all_results.extend(remote_results);
122 }
123
124 let final_results = self.aggregate_results(all_results, k);
126
127 let latency = start_time.elapsed().as_millis() as f64;
129 self.update_query_stats(latency, !nearest_peers.is_empty());
130
131 self.pending_queries.write().remove(&query_id);
133
134 Ok(final_results)
135 }
136
137 async fn multi_hop_search(
139 &self,
140 embedding: &[f32],
141 _k: usize,
142 _query_id: String,
143 hop: usize,
144 ) -> Result<Vec<SearchResult>> {
145 if hop >= self.config.max_hops {
146 return Ok(Vec::new());
147 }
148
149 let nearest_peers = self.routing_table.find_nearest_peers_balanced(embedding, 3); let all_results = Vec::new();
152
153 for (peer_id, _distance) in nearest_peers {
154 if peer_id != self.local_peer_id {
157 tracing::debug!("Would query peer {:?} at hop {}", peer_id, hop);
158
159 }
163 }
164
165 Ok(all_results)
166 }
167
168 fn aggregate_results(&self, results: Vec<SearchResult>, k: usize) -> Vec<SearchResult> {
170 let mut seen = HashMap::new();
172 let mut deduplicated = Vec::new();
173
174 for result in results {
175 if let Some(&existing_score) = seen.get(&result.cid) {
176 if result.score < existing_score {
178 if let Some(pos) = deduplicated
180 .iter()
181 .position(|r: &SearchResult| r.cid == result.cid)
182 {
183 deduplicated[pos] = result.clone();
184 seen.insert(result.cid, result.score);
185 }
186 }
187 } else {
188 seen.insert(result.cid, result.score);
189 deduplicated.push(result);
190 }
191 }
192
193 deduplicated.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
195 deduplicated.into_iter().take(k).collect()
196 }
197
198 async fn select_replica_peers(&self, embedding: &[f32]) -> Result<Vec<PeerId>> {
200 match &self.replication_strategy {
201 ReplicationStrategy::NearestPeers(n) => {
202 let peers = self.routing_table.find_nearest_peers(embedding, *n);
203 Ok(peers.into_iter().map(|(peer_id, _)| peer_id).collect())
204 }
205 ReplicationStrategy::SameCluster => {
206 Ok(Vec::new())
209 }
210 ReplicationStrategy::CrossCluster(_n) => {
211 Ok(Vec::new())
214 }
215 }
216 }
217
218 async fn update_local_embedding(&self) -> Result<()> {
220 let index = self.local_index.read();
221 let dim = self.config.embedding_dim;
222
223 let mut centroid = vec![0.0; dim];
225 let _count = 0;
226
227 drop(index);
230
231 let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
233 if norm > 1e-6 {
234 for x in &mut centroid {
235 *x /= norm;
236 }
237 }
238
239 self.routing_table.update_local_embedding(centroid)?;
240
241 Ok(())
242 }
243
244 fn update_query_stats(&self, latency_ms: f64, is_multi_hop: bool) {
246 let mut stats = self.stats.write();
247 stats.queries_processed += 1;
248
249 let alpha = 0.1; stats.avg_query_latency_ms =
252 alpha * latency_ms + (1.0 - alpha) * stats.avg_query_latency_ms;
253
254 if is_multi_hop {
255 stats.multi_hop_queries += 1;
256 }
257 }
258
259 pub fn add_peer(&self, peer: SemanticPeer) -> Result<()> {
261 self.routing_table.add_peer(peer)?;
262
263 let mut stats = self.stats.write();
265 stats.num_peers = self.routing_table.num_peers();
266
267 Ok(())
268 }
269
270 pub fn remove_peer(&self, peer_id: &PeerId) {
272 self.routing_table.remove_peer(peer_id);
273
274 let mut stats = self.stats.write();
276 stats.num_peers = self.routing_table.num_peers();
277 }
278
279 pub fn update_clusters(&self, num_clusters: usize) -> Result<()> {
281 self.routing_table.update_clusters(num_clusters)?;
282
283 let mut stats = self.stats.write();
285 stats.num_clusters = self.routing_table.num_clusters();
286
287 Ok(())
288 }
289
290 pub fn stats(&self) -> SemanticDHTStats {
292 let mut stats = self.stats.read().clone();
293 stats.num_local_entries = self.local_index.read().len();
294 stats
295 }
296
297 pub fn get_stats(&self) -> SemanticDHTStats {
299 self.stats()
300 }
301
302 pub fn routing_table(&self) -> &SemanticRoutingTable {
304 &self.routing_table
305 }
306
307 pub fn set_replication_strategy(&mut self, strategy: ReplicationStrategy) {
309 self.replication_strategy = strategy;
310 }
311
312 pub fn get_index_snapshot(&self) -> Vec<Cid> {
315 let index = self.local_index.read();
316 index.get_all_cids()
318 }
319
320 pub fn has_entry(&self, cid: &Cid) -> bool {
322 let index = self.local_index.read();
323 index.contains(cid)
324 }
325
326 pub fn prepare_sync_delta(&self, peer_snapshot: &[Cid]) -> Vec<Cid> {
329 let local_snapshot = self.get_index_snapshot();
330 let peer_set: std::collections::HashSet<_> = peer_snapshot.iter().collect();
331
332 local_snapshot
333 .into_iter()
334 .filter(|cid| !peer_set.contains(cid))
335 .collect()
336 }
337
338 pub async fn apply_sync_delta(&self, delta_cids: Vec<Cid>) -> Result<usize> {
341 Ok(delta_cids.len())
346 }
347
348 pub async fn apply_sync_delta_with_embeddings(
351 &self,
352 delta_entries: Vec<(Cid, Vec<f32>)>,
353 ) -> Result<usize> {
354 *self.pending_syncs.write() += 1;
356
357 let mut synced_count = 0;
358
359 for (cid, embedding) in delta_entries {
361 match self.local_index.write().insert(&cid, &embedding) {
362 Ok(_) => {
363 synced_count += 1;
364 }
365 Err(e) => {
366 tracing::warn!("Failed to insert CID {:?} during sync: {}", cid, e);
367 }
368 }
369 }
370
371 let now = std::time::SystemTime::now()
373 .duration_since(std::time::UNIX_EPOCH)
374 .unwrap_or_default()
375 .as_secs();
376 *self.last_sync_timestamp.write() = now;
377
378 *self.pending_syncs.write() -= 1;
380
381 self.update_local_embedding().await?;
383
384 tracing::debug!("Synced {} entries from peer", synced_count);
385
386 Ok(synced_count)
387 }
388
389 pub fn sync_stats(&self) -> SyncStats {
391 SyncStats {
392 local_entries: self.local_index.read().len(),
393 last_sync_timestamp: *self.last_sync_timestamp.read(),
394 pending_syncs: *self.pending_syncs.read(),
395 }
396 }
397}
398
/// Snapshot of a node's index-synchronisation state, as produced by
/// `SemanticDHTNode::sync_stats`.
///
/// The default value describes a node with an empty index that has never synced.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SyncStats {
    /// Number of entries in the local index at snapshot time.
    pub local_entries: usize,
    /// Unix timestamp (seconds) of the last completed sync, or 0 if none has run.
    pub last_sync_timestamp: u64,
    /// Number of sync operations in progress at snapshot time.
    pub pending_syncs: usize,
}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::DistanceMetric;
    use multihash_codetable::{Code, MultihashDigest};

    /// Embedding dimension shared by every index in these tests.
    const DIM: usize = 768;

    /// Builds an empty cosine index with the parameters used throughout the tests.
    fn empty_index() -> VectorIndex {
        VectorIndex::new(DIM, DistanceMetric::Cosine, 16, 200).unwrap()
    }

    /// Creates a node with the default config, a random peer id and an empty index.
    fn fresh_node() -> SemanticDHTNode {
        SemanticDHTNode::new(SemanticDHTConfig::default(), PeerId::random(), empty_index())
    }

    /// Derives a CIDv1 (codec 0x55) from the SHA2-256 digest of `label`.
    fn cid_for(label: &str) -> Cid {
        Cid::new_v1(0x55, Code::Sha2_256.digest(label.as_bytes()))
    }

    #[tokio::test]
    async fn test_dht_node_creation() {
        let stats = fresh_node().stats();

        assert_eq!(stats.num_peers, 0);
        assert_eq!(stats.queries_processed, 0);
    }

    #[tokio::test]
    async fn test_local_insert_and_search() {
        let node = fresh_node();

        for i in 0..10 {
            let cid = cid_for(&format!("test_vector_{}", i));
            let embedding = vec![i as f32 * 0.1; DIM];
            node.insert(&cid, &embedding).await.unwrap();
        }

        let query = vec![0.5_f32; DIM];
        let hits = node.search_local(&query, 5).unwrap();

        assert!(!hits.is_empty());
        assert!(hits.len() <= 5);
    }

    #[tokio::test]
    async fn test_add_peers() {
        let node = fresh_node();

        for i in 0..5 {
            let peer = SemanticPeer::new(PeerId::random(), vec![i as f32 * 0.2; DIM]);
            node.add_peer(peer).unwrap();
        }

        assert_eq!(node.stats().num_peers, 5);
    }

    #[tokio::test]
    async fn test_clustering() {
        let node = fresh_node();

        // Two well-separated groups along the first axis.
        for i in 0..20 {
            let mut embedding = vec![0.0; DIM];
            embedding[0] = if i < 10 { 1.0 } else { -1.0 };
            node.add_peer(SemanticPeer::new(PeerId::random(), embedding))
                .unwrap();
        }

        node.update_clusters(2).unwrap();

        assert!(node.stats().num_clusters > 0);
    }

    #[tokio::test]
    async fn test_index_synchronization() {
        let node1 = fresh_node();
        let node2 = fresh_node();

        let mut node1_cids = Vec::new();
        for i in 0..5 {
            let cid = cid_for(&format!("node1_vector_{}", i));
            let embedding = vec![i as f32 * 0.1; DIM];
            node1.insert(&cid, &embedding).await.unwrap();
            node1_cids.push(cid);
        }

        for i in 5..10 {
            let cid = cid_for(&format!("node2_vector_{}", i));
            let embedding = vec![i as f32 * 0.1; DIM];
            node2.insert(&cid, &embedding).await.unwrap();
        }

        let snapshot1 = node1.get_index_snapshot();
        let snapshot2 = node2.get_index_snapshot();
        assert_eq!(snapshot1.len(), 5);
        assert_eq!(snapshot2.len(), 5);

        assert!(node1_cids.iter().all(|cid| node1.has_entry(cid)));

        // The two nodes share nothing, so node1's whole index is the delta.
        let delta = node1.prepare_sync_delta(&snapshot2);
        assert_eq!(delta.len(), 5);

        let synced = node2.apply_sync_delta(delta).await.unwrap();
        assert_eq!(synced, 5);

        assert_eq!(node1.sync_stats().local_entries, 5);
    }

    #[tokio::test]
    async fn test_sync_with_embeddings() {
        let node1 = fresh_node();
        let node2 = fresh_node();

        let mut entries_to_sync = Vec::new();
        for i in 0..5 {
            let cid = cid_for(&format!("sync_test_vector_{}", i));
            let embedding = vec![i as f32 * 0.1; DIM];
            node1.insert(&cid, &embedding).await.unwrap();
            entries_to_sync.push((cid, embedding));
        }

        let before = node2.sync_stats();
        assert_eq!(before.local_entries, 0);
        assert_eq!(before.last_sync_timestamp, 0);
        assert_eq!(before.pending_syncs, 0);

        let synced = node2
            .apply_sync_delta_with_embeddings(entries_to_sync.clone())
            .await
            .unwrap();
        assert_eq!(synced, 5);

        let after = node2.sync_stats();
        assert_eq!(after.local_entries, 5);
        assert!(after.last_sync_timestamp > 0);
        assert_eq!(after.pending_syncs, 0);

        for (cid, _) in &entries_to_sync {
            assert!(node2.has_entry(cid));
        }

        let query = vec![0.15_f32; DIM];
        let hits = node2.search_local(&query, 3).unwrap();
        assert!(!hits.is_empty());
    }
}