ipfrs_semantic/dht_node.rs

//! Distributed Semantic DHT Node
//!
//! This module implements the main DHT node that coordinates:
//! - Local vector index management
//! - Distributed k-NN search across peers
//! - Replication and fault tolerance
//! - Query routing and result aggregation
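//!
//! # Example
//!
//! A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest);
//! it assumes `cid`, `embedding`, and `query` are built elsewhere, and the index
//! dimension mirrors the tests at the bottom of this file.
//!
//! ```ignore
//! let config = SemanticDHTConfig::default();
//! let peer_id = PeerId::random();
//! let index = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200)?;
//! let node = SemanticDHTNode::new(config, peer_id, index);
//!
//! // Insert a vector, then query the local index.
//! node.insert(&cid, &embedding).await?;
//! let results = node.search_local(&query, 5)?;
//! ```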

use crate::dht::{
    ReplicationStrategy, SemanticDHTConfig, SemanticDHTStats, SemanticPeer, SemanticRoutingTable,
};
use crate::hnsw::{SearchResult, VectorIndex};
use ipfrs_core::{Cid, Result};
use ipfrs_network::libp2p::PeerId;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;

/// Main semantic DHT node
pub struct SemanticDHTNode {
    /// Configuration
    config: SemanticDHTConfig,
    /// Local peer ID
    local_peer_id: PeerId,
    /// Local vector index
    local_index: Arc<RwLock<VectorIndex>>,
    /// Routing table
    routing_table: Arc<SemanticRoutingTable>,
    /// Replication strategy
    replication_strategy: ReplicationStrategy,
    /// Query statistics
    stats: Arc<RwLock<SemanticDHTStats>>,
    /// Pending queries
    pending_queries: Arc<RwLock<HashMap<String, Instant>>>,
    /// Last successful synchronization timestamp (unix timestamp in seconds)
    last_sync_timestamp: Arc<RwLock<u64>>,
    /// Number of pending synchronization operations
    pending_syncs: Arc<RwLock<usize>>,
}

impl SemanticDHTNode {
    /// Create a new semantic DHT node
    pub fn new(config: SemanticDHTConfig, local_peer_id: PeerId, local_index: VectorIndex) -> Self {
        let routing_table = Arc::new(SemanticRoutingTable::new(config.clone()));

        let stats = SemanticDHTStats {
            num_peers: 0,
            num_clusters: 0,
            num_local_entries: 0,
            queries_processed: 0,
            avg_query_latency_ms: 0.0,
            multi_hop_queries: 0,
        };

        Self {
            config,
            local_peer_id,
            local_index: Arc::new(RwLock::new(local_index)),
            routing_table,
            replication_strategy: ReplicationStrategy::NearestPeers(3),
            stats: Arc::new(RwLock::new(stats)),
            pending_queries: Arc::new(RwLock::new(HashMap::new())),
            last_sync_timestamp: Arc::new(RwLock::new(0)),
            pending_syncs: Arc::new(RwLock::new(0)),
        }
    }

    /// Insert a vector into the local index and replicate to peers
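    ///
    /// # Example
    ///
    /// A sketch of inserting a content-addressed embedding (marked `ignore`, so it
    /// is not run as a doctest); the CID construction mirrors the tests at the
    /// bottom of this file and `node` is assumed to exist already.
    ///
    /// ```ignore
    /// use multihash_codetable::{Code, MultihashDigest};
    ///
    /// let hash = Code::Sha2_256.digest(b"some content");
    /// let cid = Cid::new_v1(0x55, hash);
    /// let embedding = vec![0.1_f32; 768];
    /// node.insert(&cid, &embedding).await?;
    /// ```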
    pub async fn insert(&self, cid: &Cid, embedding: &[f32]) -> Result<()> {
        // Insert into local index
        self.local_index.write().insert(cid, embedding)?;

        // Update local embedding (aggregate of stored vectors)
        self.update_local_embedding().await?;

        // Determine replica peers based on strategy
        let replica_peers = self.select_replica_peers(embedding).await?;

        // TODO: Send replication requests to peers
        // For now, just log the intended replicas
        tracing::debug!("Would replicate {:?} to {} peers", cid, replica_peers.len());

        Ok(())
    }

    /// Search for nearest neighbors locally
    pub fn search_local(&self, embedding: &[f32], k: usize) -> Result<Vec<SearchResult>> {
        let index = self.local_index.read();
        let ef_search = self.config.max_hops * 10; // Heuristic
        index.search(embedding, k, ef_search)
    }

    /// Distributed k-NN search across multiple peers
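    ///
    /// The query always runs against the local index; remote hops are attempted only
    /// when the routing table knows about other peers and `max_hops > 0`.
    ///
    /// # Example
    ///
    /// A usage sketch (marked `ignore`, so it is not run as a doctest), assuming a
    /// `node` built as in the module-level example:
    ///
    /// ```ignore
    /// let results = node.search_distributed(&query, 10).await?;
    /// for result in &results {
    ///     println!("{:?} scored {}", result.cid, result.score);
    /// }
    /// ```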
    pub async fn search_distributed(
        &self,
        embedding: &[f32],
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        let query_id = format!("{:?}-{}", self.local_peer_id, uuid::Uuid::new_v4());
        let start_time = Instant::now();

        // Record pending query
        self.pending_queries
            .write()
            .insert(query_id.clone(), start_time);

        // Local search
        let mut all_results = self.search_local(embedding, k)?;

        // Find nearest peers to forward query to
        let nearest_peers = self
            .routing_table
            .find_nearest_peers_balanced(embedding, self.config.routing_table_size);

        // Multi-hop search
        if !nearest_peers.is_empty() && self.config.max_hops > 0 {
            let remote_results = self
                .multi_hop_search(embedding, k, query_id.clone(), 0)
                .await?;
            all_results.extend(remote_results);
        }

        // Aggregate and rank results
        let final_results = self.aggregate_results(all_results, k);

        // Update statistics
        let latency = start_time.elapsed().as_millis() as f64;
        self.update_query_stats(latency, !nearest_peers.is_empty());

        // Clean up pending query
        self.pending_queries.write().remove(&query_id);

        Ok(final_results)
    }

    /// Multi-hop search with TTL
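    ///
    /// `hop` counts up from 0; the search stops once `hop >= max_hops`, which bounds
    /// how far a query can travel through the overlay.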
    async fn multi_hop_search(
        &self,
        embedding: &[f32],
        _k: usize,
        _query_id: String,
        hop: usize,
    ) -> Result<Vec<SearchResult>> {
        if hop >= self.config.max_hops {
            return Ok(Vec::new());
        }

        let nearest_peers = self.routing_table.find_nearest_peers_balanced(embedding, 3); // Top 3 peers

        let all_results = Vec::new();

        for (peer_id, _distance) in nearest_peers {
            // TODO: Send query to remote peer
            // For now, just log the peers that would be queried
            if peer_id != self.local_peer_id {
                tracing::debug!("Would query peer {:?} at hop {}", peer_id, hop);

                // In real implementation:
                // let response = self.send_query_to_peer(peer_id, query).await?;
                // all_results.extend(response.results);
            }
        }

        Ok(all_results)
    }

    /// Aggregate and deduplicate results from multiple sources
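    /// Scores are treated as distances, so lower is better: duplicates keep the
    /// smaller score, and the final list is sorted ascending before truncating to `k`.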
    fn aggregate_results(&self, results: Vec<SearchResult>, k: usize) -> Vec<SearchResult> {
        // Deduplicate by CID
        let mut seen = HashMap::new();
        let mut deduplicated = Vec::new();

        for result in results {
            if let Some(&existing_score) = seen.get(&result.cid) {
                // Keep better score
                if result.score < existing_score {
                    // Find and update
                    if let Some(pos) = deduplicated
                        .iter()
                        .position(|r: &SearchResult| r.cid == result.cid)
                    {
                        deduplicated[pos] = result.clone();
                        seen.insert(result.cid, result.score);
                    }
                }
            } else {
                seen.insert(result.cid, result.score);
                deduplicated.push(result);
            }
        }

        // Sort by score and take top k
        deduplicated.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap());
        deduplicated.into_iter().take(k).collect()
    }

    /// Select replica peers based on replication strategy
    async fn select_replica_peers(&self, embedding: &[f32]) -> Result<Vec<PeerId>> {
        match &self.replication_strategy {
            ReplicationStrategy::NearestPeers(n) => {
                let peers = self.routing_table.find_nearest_peers(embedding, *n);
                Ok(peers.into_iter().map(|(peer_id, _)| peer_id).collect())
            }
            ReplicationStrategy::SameCluster => {
                // Find local peer's cluster
                // For now, return empty
                Ok(Vec::new())
            }
            ReplicationStrategy::CrossCluster(_n) => {
                // Select n peers from different clusters
                // For now, return empty
                Ok(Vec::new())
            }
        }
    }

    /// Update local peer's embedding based on stored vectors
    async fn update_local_embedding(&self) -> Result<()> {
        let index = self.local_index.read();
        let dim = self.config.embedding_dim;

        // Compute centroid of all local vectors
        let mut centroid = vec![0.0; dim];
        let _count = 0;

        // This is a simplified version - in practice, we'd iterate over actual vectors
        // For now, just use a placeholder
        drop(index);

        // Normalize centroid
        let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-6 {
            for x in &mut centroid {
                *x /= norm;
            }
        }

        self.routing_table.update_local_embedding(centroid)?;

        Ok(())
    }

    /// Update query statistics
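    ///
    /// Latency is tracked as an exponential moving average,
    /// `avg = alpha * latency + (1 - alpha) * avg`, with `alpha = 0.1`.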
    fn update_query_stats(&self, latency_ms: f64, is_multi_hop: bool) {
        let mut stats = self.stats.write();
        stats.queries_processed += 1;

        // Update running average
        let alpha = 0.1; // Exponential moving average factor
        stats.avg_query_latency_ms =
            alpha * latency_ms + (1.0 - alpha) * stats.avg_query_latency_ms;

        if is_multi_hop {
            stats.multi_hop_queries += 1;
        }
    }

    /// Add a peer to the routing table
    pub fn add_peer(&self, peer: SemanticPeer) -> Result<()> {
        self.routing_table.add_peer(peer)?;

        // Update stats
        let mut stats = self.stats.write();
        stats.num_peers = self.routing_table.num_peers();

        Ok(())
    }

    /// Remove a peer from the routing table
    pub fn remove_peer(&self, peer_id: &PeerId) {
        self.routing_table.remove_peer(peer_id);

        // Update stats
        let mut stats = self.stats.write();
        stats.num_peers = self.routing_table.num_peers();
    }

    /// Update peer clustering
    pub fn update_clusters(&self, num_clusters: usize) -> Result<()> {
        self.routing_table.update_clusters(num_clusters)?;

        // Update stats
        let mut stats = self.stats.write();
        stats.num_clusters = self.routing_table.num_clusters();

        Ok(())
    }

    /// Get DHT statistics
    pub fn stats(&self) -> SemanticDHTStats {
        let mut stats = self.stats.read().clone();
        stats.num_local_entries = self.local_index.read().len();
        stats
    }

    /// Get DHT statistics (alias for stats)
    pub fn get_stats(&self) -> SemanticDHTStats {
        self.stats()
    }

    /// Get reference to the routing table
    pub fn routing_table(&self) -> &SemanticRoutingTable {
        &self.routing_table
    }

    /// Set replication strategy
    pub fn set_replication_strategy(&mut self, strategy: ReplicationStrategy) {
        self.replication_strategy = strategy;
    }

    /// Get a snapshot of local index entries for synchronization
    /// Returns CIDs that can be used for delta synchronization
    pub fn get_index_snapshot(&self) -> Vec<Cid> {
        let index = self.local_index.read();
        // Get all CIDs from the index
        index.get_all_cids()
    }

    /// Check if local index has a specific CID
    pub fn has_entry(&self, cid: &Cid) -> bool {
        let index = self.local_index.read();
        index.contains(cid)
    }

    /// Prepare a synchronization delta: the entries that the peer is missing
    /// Returns CIDs that are in our index but not in the peer's snapshot
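    ///
    /// # Example
    ///
    /// A sketch of one sync round between two nodes (marked `ignore`, so it is not
    /// run as a doctest); `fetch_embeddings` is a hypothetical helper standing in
    /// for the network protocol that would ship the actual vectors.
    ///
    /// ```ignore
    /// // node1 works out which of its entries node2 is missing...
    /// let delta = node1.prepare_sync_delta(&node2.get_index_snapshot());
    /// // ...then the embeddings for those CIDs are fetched and inserted on node2.
    /// let entries: Vec<(Cid, Vec<f32>)> = fetch_embeddings(&delta); // hypothetical helper
    /// let synced = node2.apply_sync_delta_with_embeddings(entries).await?;
    /// ```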
    pub fn prepare_sync_delta(&self, peer_snapshot: &[Cid]) -> Vec<Cid> {
        let local_snapshot = self.get_index_snapshot();
        let peer_set: std::collections::HashSet<_> = peer_snapshot.iter().collect();

        local_snapshot
            .into_iter()
            .filter(|cid| !peer_set.contains(cid))
            .collect()
    }

    /// Apply a synchronization delta: add entries received from a peer
    /// This is a foundation; the full implementation would fetch the embeddings from the peer
    pub async fn apply_sync_delta(&self, delta_cids: Vec<Cid>) -> Result<usize> {
        // NOTE: In a full implementation with a network protocol, this would:
        // 1. Request embeddings for delta_cids from the peer
        // 2. Call apply_sync_delta_with_embeddings with the fetched data
        // For now, just return the count of CIDs that would be synced
        Ok(delta_cids.len())
    }

    /// Apply a synchronization delta with embeddings: add entries received from a peer
    /// This method actually inserts the embeddings into the local index
    pub async fn apply_sync_delta_with_embeddings(
        &self,
        delta_entries: Vec<(Cid, Vec<f32>)>,
    ) -> Result<usize> {
        // Increment pending syncs counter
        *self.pending_syncs.write() += 1;

        let mut synced_count = 0;

        // Insert each entry into local index
        for (cid, embedding) in delta_entries {
            match self.local_index.write().insert(&cid, &embedding) {
                Ok(_) => {
                    synced_count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to insert CID {:?} during sync: {}", cid, e);
                }
            }
        }

        // Update last sync timestamp
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        *self.last_sync_timestamp.write() = now;

        // Decrement pending syncs counter
        *self.pending_syncs.write() -= 1;

        // Update local embedding after sync
        self.update_local_embedding().await?;

        tracing::debug!("Synced {} entries from peer", synced_count);

        Ok(synced_count)
    }

    /// Get synchronization statistics
    pub fn sync_stats(&self) -> SyncStats {
        SyncStats {
            local_entries: self.local_index.read().len(),
            last_sync_timestamp: *self.last_sync_timestamp.read(),
            pending_syncs: *self.pending_syncs.read(),
        }
    }
}

/// Statistics for index synchronization
#[derive(Debug, Clone)]
pub struct SyncStats {
    /// Number of entries in local index
    pub local_entries: usize,
    /// Timestamp of last successful sync
    pub last_sync_timestamp: u64,
    /// Number of pending sync operations
    pub pending_syncs: usize,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::DistanceMetric;

    #[tokio::test]
    async fn test_dht_node_creation() {
        let config = SemanticDHTConfig::default();
        let peer_id = PeerId::random();
        let index = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();

        let node = SemanticDHTNode::new(config, peer_id, index);
        let stats = node.stats();

        assert_eq!(stats.num_peers, 0);
        assert_eq!(stats.queries_processed, 0);
    }

    #[tokio::test]
    async fn test_local_insert_and_search() {
        let config = SemanticDHTConfig::default();
        let peer_id = PeerId::random();
        let index = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();

        let node = SemanticDHTNode::new(config, peer_id, index);

        // Insert some vectors
        for i in 0..10 {
            use multihash_codetable::{Code, MultihashDigest};
            let data = format!("test_vector_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);
            let embedding = vec![i as f32 * 0.1; 768];
            node.insert(&cid, &embedding).await.unwrap();
        }

        // Search
        let query = vec![0.5; 768];
        let results = node.search_local(&query, 5).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 5);
    }

    #[tokio::test]
    async fn test_add_peers() {
        let config = SemanticDHTConfig::default();
        let peer_id = PeerId::random();
        let index = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();

        let node = SemanticDHTNode::new(config, peer_id, index);

        // Add some peers
        for i in 0..5 {
            let peer_id = PeerId::random();
            let embedding = vec![i as f32 * 0.2; 768];
            let peer = SemanticPeer::new(peer_id, embedding);
            node.add_peer(peer).unwrap();
        }

        let stats = node.stats();
        assert_eq!(stats.num_peers, 5);
    }

    #[tokio::test]
    async fn test_clustering() {
        let config = SemanticDHTConfig::default();
        let peer_id = PeerId::random();
        let index = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();

        let node = SemanticDHTNode::new(config, peer_id, index);

        // Add peers
        for i in 0..20 {
            let peer_id = PeerId::random();
            let mut embedding = vec![0.0; 768];
            embedding[0] = if i < 10 { 1.0 } else { -1.0 };
            let peer = SemanticPeer::new(peer_id, embedding);
            node.add_peer(peer).unwrap();
        }

        // Update clusters
        node.update_clusters(2).unwrap();

        let stats = node.stats();
        assert!(stats.num_clusters > 0);
    }

    #[tokio::test]
    async fn test_index_synchronization() {
        use multihash_codetable::{Code, MultihashDigest};

        let config = SemanticDHTConfig::default();
        let peer_id1 = PeerId::random();
        let peer_id2 = PeerId::random();

        let index1 = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();
        let index2 = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();

        let node1 = SemanticDHTNode::new(config.clone(), peer_id1, index1);
        let node2 = SemanticDHTNode::new(config, peer_id2, index2);

        // Insert data into node1
        let mut cids1 = Vec::new();
        for i in 0..5 {
            let data = format!("node1_vector_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);
            let embedding = vec![i as f32 * 0.1; 768];
            node1.insert(&cid, &embedding).await.unwrap();
            cids1.push(cid);
        }

        // Insert different data into node2
        let mut cids2 = Vec::new();
        for i in 5..10 {
            let data = format!("node2_vector_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);
            let embedding = vec![i as f32 * 0.1; 768];
            node2.insert(&cid, &embedding).await.unwrap();
            cids2.push(cid);
        }

        // Get snapshots
        let snapshot1 = node1.get_index_snapshot();
        let snapshot2 = node2.get_index_snapshot();

        assert_eq!(snapshot1.len(), 5);
        assert_eq!(snapshot2.len(), 5);

        // Check that node1 has its entries
        for cid in &cids1 {
            assert!(node1.has_entry(cid));
        }

        // Prepare delta: what node2 needs from node1
        let delta = node1.prepare_sync_delta(&snapshot2);
        assert_eq!(delta.len(), 5); // All of node1's entries are missing from node2

        // Apply delta (in real implementation, this would fetch and insert)
        let synced_count = node2.apply_sync_delta(delta).await.unwrap();
        assert_eq!(synced_count, 5);

        // Check sync stats
        let sync_stats = node1.sync_stats();
        assert_eq!(sync_stats.local_entries, 5);
    }

    #[tokio::test]
    async fn test_sync_with_embeddings() {
        use multihash_codetable::{Code, MultihashDigest};

        let config = SemanticDHTConfig::default();
        let peer_id1 = PeerId::random();
        let peer_id2 = PeerId::random();

        let index1 = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();
        let index2 = VectorIndex::new(768, DistanceMetric::Cosine, 16, 200).unwrap();

        let node1 = SemanticDHTNode::new(config.clone(), peer_id1, index1);
        let node2 = SemanticDHTNode::new(config, peer_id2, index2);

        // Insert data into node1
        let mut entries_to_sync = Vec::new();
        for i in 0..5 {
            let data = format!("sync_test_vector_{}", i);
            let hash = Code::Sha2_256.digest(data.as_bytes());
            let cid = Cid::new_v1(0x55, hash);
            let embedding = vec![i as f32 * 0.1; 768];
            node1.insert(&cid, &embedding).await.unwrap();
            entries_to_sync.push((cid, embedding));
        }

        // Check initial state
        let sync_stats_before = node2.sync_stats();
        assert_eq!(sync_stats_before.local_entries, 0);
        assert_eq!(sync_stats_before.last_sync_timestamp, 0);
        assert_eq!(sync_stats_before.pending_syncs, 0);

        // Apply sync with embeddings to node2
        let synced_count = node2
            .apply_sync_delta_with_embeddings(entries_to_sync.clone())
            .await
            .unwrap();
        assert_eq!(synced_count, 5);

        // Check that node2 now has the entries
        let sync_stats_after = node2.sync_stats();
        assert_eq!(sync_stats_after.local_entries, 5);
        assert!(sync_stats_after.last_sync_timestamp > 0); // Should be updated
        assert_eq!(sync_stats_after.pending_syncs, 0); // Should be back to 0

        // Verify all CIDs are present in node2
        for (cid, _) in &entries_to_sync {
            assert!(node2.has_entry(cid));
        }

        // Search should work on node2 now
        let query = vec![0.15; 768];
        let results = node2.search_local(&query, 3).unwrap();
        assert!(!results.is_empty());
    }
}