ipfrs_semantic/
dht.rs

1//! Distributed Semantic DHT
2//!
3//! This module provides a distributed hash table optimized for semantic search:
4//! - Embedding-based routing to nearest peers in vector space
5//! - Clustering of similar nodes for locality optimization
6//! - Distributed k-NN search across multiple peers
7//! - Replication for fault tolerance
8//! - Load balancing and query routing optimization
9
10use crate::hnsw::{DistanceMetric, SearchResult};
11use ipfrs_core::{Cid, Error, Result};
12use ipfrs_network::libp2p::PeerId;
13use parking_lot::RwLock;
14use serde::{Deserialize, Serialize};
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17
18/// Configuration for the semantic DHT
/// Configuration for the semantic DHT
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticDHTConfig {
    /// Dimensionality of embeddings used for routing; every peer and query
    /// embedding must have exactly this many components.
    pub embedding_dim: usize,
    /// Number of replicas maintained for each entry (fault tolerance).
    pub replication_factor: usize,
    /// Number of closest peers to consider for routing decisions.
    pub routing_table_size: usize,
    /// Distance metric used when comparing peer/query embeddings.
    pub distance_metric: DistanceMetric,
    /// Maximum number of hops allowed for multi-hop search propagation.
    pub max_hops: usize,
    /// Timeout for individual peer queries, in milliseconds.
    pub query_timeout_ms: u64,
}
34
35impl Default for SemanticDHTConfig {
36    fn default() -> Self {
37        Self {
38            embedding_dim: 768,
39            replication_factor: 3,
40            routing_table_size: 20,
41            distance_metric: DistanceMetric::Cosine,
42            max_hops: 5,
43            query_timeout_ms: 5000,
44        }
45    }
46}
47
48/// Represents a peer in the semantic DHT with its embedding
49#[derive(Debug, Clone)]
50pub struct SemanticPeer {
51    /// Peer identifier
52    pub peer_id: PeerId,
53    /// Embedding representing this peer's data distribution
54    pub embedding: Vec<f32>,
55    /// Cluster ID this peer belongs to
56    pub cluster_id: Option<usize>,
57    /// Last seen timestamp
58    pub last_seen: u64,
59    /// Load metric (0.0 = idle, 1.0 = overloaded)
60    pub load: f32,
61}
62
63impl SemanticPeer {
64    /// Create a new semantic peer
65    pub fn new(peer_id: PeerId, embedding: Vec<f32>) -> Self {
66        Self {
67            peer_id,
68            embedding,
69            cluster_id: None,
70            last_seen: current_timestamp(),
71            load: 0.0,
72        }
73    }
74
75    /// Update the last seen timestamp
76    pub fn update_last_seen(&mut self) {
77        self.last_seen = current_timestamp();
78    }
79
80    /// Update the load metric
81    pub fn update_load(&mut self, load: f32) {
82        self.load = load.clamp(0.0, 1.0);
83    }
84}
85
86/// Routing table for semantic DHT
/// Routing table for semantic DHT
///
/// All mutable state lives behind `Arc<RwLock<...>>`, so the table can be
/// shared across threads and mutated through `&self` methods.
#[derive(Debug)]
pub struct SemanticRoutingTable {
    /// Configuration
    config: SemanticDHTConfig,
    /// Known peers with their embeddings, keyed by peer ID.
    peers: Arc<RwLock<HashMap<PeerId, SemanticPeer>>>,
    /// Cluster ID -> member peers; rebuilt by `update_clusters`.
    clusters: Arc<RwLock<HashMap<usize, Vec<PeerId>>>>,
    /// Local peer's embedding (zero vector until `update_local_embedding`).
    local_embedding: Arc<RwLock<Vec<f32>>>,
    /// LRU route cache mapping an embedding hash to the best peers, used to
    /// skip repeated nearest-peer scans for recurring queries.
    route_cache: Arc<RwLock<lru::LruCache<u64, Vec<PeerId>>>>,
}
100
101impl SemanticRoutingTable {
102    /// Create a new semantic routing table
103    pub fn new(config: SemanticDHTConfig) -> Self {
104        let local_embedding = vec![0.0; config.embedding_dim];
105        Self {
106            config,
107            peers: Arc::new(RwLock::new(HashMap::new())),
108            clusters: Arc::new(RwLock::new(HashMap::new())),
109            local_embedding: Arc::new(RwLock::new(local_embedding)),
110            route_cache: Arc::new(RwLock::new(lru::LruCache::new(
111                std::num::NonZeroUsize::new(1000).unwrap(),
112            ))),
113        }
114    }
115
116    /// Update local peer's embedding based on stored data
117    pub fn update_local_embedding(&self, embedding: Vec<f32>) -> Result<()> {
118        if embedding.len() != self.config.embedding_dim {
119            return Err(Error::InvalidInput(format!(
120                "Expected embedding dimension {}, got {}",
121                self.config.embedding_dim,
122                embedding.len()
123            )));
124        }
125        *self.local_embedding.write() = embedding;
126        Ok(())
127    }
128
129    /// Add or update a peer in the routing table
130    pub fn add_peer(&self, peer: SemanticPeer) -> Result<()> {
131        if peer.embedding.len() != self.config.embedding_dim {
132            return Err(Error::InvalidInput(format!(
133                "Expected embedding dimension {}, got {}",
134                self.config.embedding_dim,
135                peer.embedding.len()
136            )));
137        }
138        self.peers.write().insert(peer.peer_id, peer);
139        Ok(())
140    }
141
142    /// Remove a peer from the routing table
143    pub fn remove_peer(&self, peer_id: &PeerId) {
144        self.peers.write().remove(peer_id);
145    }
146
147    /// Find k nearest peers to a given embedding (greedy routing)
148    pub fn find_nearest_peers(&self, embedding: &[f32], k: usize) -> Vec<(PeerId, f32)> {
149        // Check route cache first
150        if let Some(cached_peers) = self.get_cached_route(embedding) {
151            // Return cached peers with recomputed distances for accuracy
152            let peers = self.peers.read();
153            let result: Vec<(PeerId, f32)> = cached_peers
154                .iter()
155                .filter_map(|peer_id| {
156                    peers.get(peer_id).map(|peer| {
157                        let distance = self.compute_distance(embedding, &peer.embedding);
158                        (*peer_id, distance)
159                    })
160                })
161                .take(k)
162                .collect();
163
164            if result.len() == k {
165                return result;
166            }
167            // Cache was stale, fall through to recompute
168        }
169
170        let peers = self.peers.read();
171        let mut distances: Vec<(PeerId, f32)> = peers
172            .values()
173            .map(|peer| {
174                let distance = self.compute_distance(embedding, &peer.embedding);
175                (peer.peer_id, distance)
176            })
177            .collect();
178
179        // Sort by distance (ascending for L2, descending for cosine similarity)
180        distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
181
182        let result: Vec<(PeerId, f32)> = distances.into_iter().take(k).collect();
183
184        // Cache the routing decision
185        let peer_ids: Vec<PeerId> = result.iter().map(|(id, _)| *id).collect();
186        drop(peers);
187        self.cache_route(embedding, peer_ids);
188
189        result
190    }
191
192    /// Find k nearest peers with load balancing consideration
193    pub fn find_nearest_peers_balanced(&self, embedding: &[f32], k: usize) -> Vec<(PeerId, f32)> {
194        let peers = self.peers.read();
195        let mut scored_peers: Vec<(PeerId, f32)> = peers
196            .values()
197            .map(|peer| {
198                let distance = self.compute_distance(embedding, &peer.embedding);
199                // Penalize overloaded peers: score = distance * (1 + load)
200                let score = distance * (1.0 + peer.load);
201                (peer.peer_id, score)
202            })
203            .collect();
204
205        scored_peers.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
206        scored_peers.into_iter().take(k).collect()
207    }
208
209    /// Get peers in a specific cluster
210    pub fn get_cluster_peers(&self, cluster_id: usize) -> Vec<PeerId> {
211        self.clusters
212            .read()
213            .get(&cluster_id)
214            .cloned()
215            .unwrap_or_default()
216    }
217
218    /// Get number of peers
219    pub fn num_peers(&self) -> usize {
220        self.peers.read().len()
221    }
222
223    /// Get number of clusters
224    pub fn num_clusters(&self) -> usize {
225        self.clusters.read().len()
226    }
227
228    /// Hash an embedding for route caching
229    fn hash_embedding(embedding: &[f32]) -> u64 {
230        use std::collections::hash_map::DefaultHasher;
231        use std::hash::{Hash, Hasher};
232
233        let mut hasher = DefaultHasher::new();
234        // Hash first 8 dimensions for efficiency (representative sample)
235        for &val in embedding.iter().take(8) {
236            // Convert to bits for consistent hashing
237            val.to_bits().hash(&mut hasher);
238        }
239        hasher.finish()
240    }
241
242    /// Check route cache for cached routing decision
243    pub fn get_cached_route(&self, embedding: &[f32]) -> Option<Vec<PeerId>> {
244        let hash = Self::hash_embedding(embedding);
245        self.route_cache.write().get(&hash).cloned()
246    }
247
248    /// Cache a routing decision for future queries
249    pub fn cache_route(&self, embedding: &[f32], peers: Vec<PeerId>) {
250        let hash = Self::hash_embedding(embedding);
251        self.route_cache.write().put(hash, peers);
252    }
253
254    /// Clear the route cache (useful when network topology changes significantly)
255    pub fn clear_route_cache(&self) {
256        self.route_cache.write().clear();
257    }
258
259    /// Get route cache statistics
260    pub fn route_cache_stats(&self) -> (usize, usize) {
261        let cache = self.route_cache.read();
262        (cache.len(), cache.cap().get())
263    }
264
265    /// Update peer clusters using k-means clustering
266    pub fn update_clusters(&self, num_clusters: usize) -> Result<()> {
267        let peers = self.peers.read();
268        if peers.is_empty() {
269            return Ok(());
270        }
271
272        let embeddings: Vec<Vec<f32>> = peers.values().map(|p| p.embedding.clone()).collect();
273        let peer_ids: Vec<PeerId> = peers.keys().cloned().collect();
274        drop(peers);
275
276        // Simple k-means clustering
277        let assignments = self.kmeans_clustering(&embeddings, num_clusters);
278
279        // Update peer cluster assignments
280        let mut peers_write = self.peers.write();
281        let mut clusters_write = self.clusters.write();
282        clusters_write.clear();
283
284        for (peer_id, cluster_id) in peer_ids.iter().zip(assignments.iter()) {
285            if let Some(peer) = peers_write.get_mut(peer_id) {
286                peer.cluster_id = Some(*cluster_id);
287            }
288            clusters_write
289                .entry(*cluster_id)
290                .or_default()
291                .push(*peer_id);
292        }
293
294        Ok(())
295    }
296
297    /// Compute distance between two embeddings
298    fn compute_distance(&self, a: &[f32], b: &[f32]) -> f32 {
299        match self.config.distance_metric {
300            DistanceMetric::L2 => crate::simd::l2_distance(a, b),
301            DistanceMetric::Cosine => crate::simd::cosine_distance(a, b),
302            DistanceMetric::DotProduct => -crate::simd::dot_product(a, b), // Negative for similarity
303        }
304    }
305
306    /// Simple k-means clustering implementation
307    fn kmeans_clustering(&self, embeddings: &[Vec<f32>], k: usize) -> Vec<usize> {
308        if embeddings.is_empty() || k == 0 {
309            return Vec::new();
310        }
311
312        let k = k.min(embeddings.len());
313        let dim = embeddings[0].len();
314
315        // Initialize centroids randomly
316        let mut centroids: Vec<Vec<f32>> = (0..k)
317            .map(|i| embeddings[i % embeddings.len()].clone())
318            .collect();
319
320        let mut assignments = vec![0; embeddings.len()];
321        let max_iterations = 10;
322
323        for _ in 0..max_iterations {
324            // Assignment step
325            for (i, embedding) in embeddings.iter().enumerate() {
326                let mut min_dist = f32::MAX;
327                let mut best_cluster = 0;
328
329                for (cluster_id, centroid) in centroids.iter().enumerate() {
330                    let dist = self.compute_distance(embedding, centroid);
331                    if dist < min_dist {
332                        min_dist = dist;
333                        best_cluster = cluster_id;
334                    }
335                }
336                assignments[i] = best_cluster;
337            }
338
339            // Update step
340            let mut new_centroids = vec![vec![0.0; dim]; k];
341            let mut counts = vec![0; k];
342
343            for (embedding, &cluster_id) in embeddings.iter().zip(assignments.iter()) {
344                for (j, &val) in embedding.iter().enumerate() {
345                    new_centroids[cluster_id][j] += val;
346                }
347                counts[cluster_id] += 1;
348            }
349
350            for (cluster_id, count) in counts.iter().enumerate() {
351                if *count > 0 {
352                    for j in 0..dim {
353                        new_centroids[cluster_id][j] /= *count as f32;
354                    }
355                }
356            }
357
358            centroids = new_centroids;
359        }
360
361        assignments
362    }
363}
364
365/// DHT query for distributed search
/// DHT query for distributed search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DHTQuery {
    /// Query embedding
    pub embedding: Vec<f32>,
    /// Number of results requested
    pub k: usize,
    /// Query ID for tracking
    pub query_id: String,
    /// TTL (time to live) for query propagation
    pub ttl: usize,
    /// Peers already visited (to prevent loops). Skipped during
    /// (de)serialization — a deserialized query starts with an empty set.
    #[serde(skip)]
    pub visited: HashSet<PeerId>,
}
380
381/// DHT query response
/// DHT query response
#[derive(Debug, Clone)]
pub struct DHTQueryResponse {
    /// ID of the query this response answers (matches `DHTQuery::query_id`).
    pub query_id: String,
    /// Search results produced by the responding peer.
    pub results: Vec<SearchResult>,
    /// Identifier of the peer that produced these results.
    pub peer_id: PeerId,
}
391
392/// Replication strategy for fault tolerance
/// Replication strategy for fault tolerance
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReplicationStrategy {
    /// Replicate to the k nearest peers in embedding space.
    NearestPeers(usize),
    /// Replicate to peers in the same cluster.
    SameCluster,
    /// Replicate to peers spread across k different clusters.
    CrossCluster(usize),
}
402
403/// Entry in the distributed index
/// Entry in the distributed index
#[derive(Debug, Clone)]
pub struct DHTEntry {
    /// Content identifier of the indexed data.
    pub cid: Cid,
    /// Embedding associated with the content.
    pub embedding: Vec<f32>,
    /// Primary peer responsible for this entry.
    pub primary_peer: PeerId,
    /// Peers holding replicas of this entry.
    pub replicas: Vec<PeerId>,
}
415
416/// Statistics for the semantic DHT
/// Statistics for the semantic DHT
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticDHTStats {
    /// Number of known peers
    pub num_peers: usize,
    /// Number of clusters
    pub num_clusters: usize,
    /// Number of entries in the local index
    pub num_local_entries: usize,
    /// Total number of queries processed
    pub queries_processed: u64,
    /// Average query latency in milliseconds
    pub avg_query_latency_ms: f64,
    /// Number of queries that required more than one hop
    pub multi_hop_queries: u64,
}
432
433/// Get current timestamp in seconds
/// Get the current Unix time in whole seconds.
///
/// Returns 0 instead of panicking when the system clock reports a time
/// before the Unix epoch (possible on misconfigured hosts); the previous
/// `.unwrap()` would have aborted the process in that case.
fn current_timestamp() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh routing table with default settings (768-dim, cosine).
    fn default_table() -> SemanticRoutingTable {
        SemanticRoutingTable::new(SemanticDHTConfig::default())
    }

    #[test]
    fn test_routing_table_creation() {
        let table = default_table();
        assert!(table.update_local_embedding(vec![0.5; 768]).is_ok());
    }

    #[test]
    fn test_add_peer() {
        let table = default_table();
        let peer = SemanticPeer::new(PeerId::random(), vec![0.5; 768]);
        assert!(table.add_peer(peer).is_ok());
    }

    #[test]
    fn test_find_nearest_peers() {
        let table = default_table();

        // Populate with ten peers at increasing embedding magnitudes.
        for i in 0..10 {
            let peer = SemanticPeer::new(PeerId::random(), vec![i as f32 * 0.1; 768]);
            table.add_peer(peer).unwrap();
        }

        let query = vec![0.5; 768];
        let nearest = table.find_nearest_peers(&query, 3);
        assert_eq!(nearest.len(), 3);
    }

    #[test]
    fn test_clustering() {
        let table = default_table();

        // Two well-separated groups of ten peers, differing only in dim 0.
        for i in 0..20 {
            let mut embedding = vec![0.0; 768];
            embedding[0] = if i < 10 { 1.0 } else { -1.0 };
            table
                .add_peer(SemanticPeer::new(PeerId::random(), embedding))
                .unwrap();
        }

        assert!(table.update_clusters(2).is_ok());

        // At least one of the two clusters must have received members.
        let cluster0 = table.get_cluster_peers(0);
        let cluster1 = table.get_cluster_peers(1);
        assert!(!cluster0.is_empty() || !cluster1.is_empty());
    }

    #[test]
    fn test_load_balancing() {
        let table = default_table();

        // Five identical embeddings with loads 0.0, 0.2, 0.4, 0.6, 0.8.
        for i in 0..5 {
            let mut peer = SemanticPeer::new(PeerId::random(), vec![0.5; 768]);
            peer.update_load(i as f32 * 0.2);
            table.add_peer(peer).unwrap();
        }

        let query = vec![0.5; 768];
        let picked = table.find_nearest_peers_balanced(&query, 3);

        assert_eq!(picked.len(), 3);
        // Scoring penalizes load, so lightly-loaded peers sort first.
    }

    #[test]
    fn test_route_caching() {
        let table = default_table();

        for i in 0..10 {
            table
                .add_peer(SemanticPeer::new(
                    PeerId::random(),
                    vec![i as f32 * 0.1; 768],
                ))
                .unwrap();
        }

        let query = vec![0.5; 768];

        // The cache starts out empty.
        let (size_before, _) = table.route_cache_stats();
        assert_eq!(size_before, 0);

        let first = table.find_nearest_peers(&query, 3);
        assert_eq!(first.len(), 3);

        // The first lookup populates exactly one cache entry.
        let (size_after, _) = table.route_cache_stats();
        assert_eq!(size_after, 1);

        // A repeated identical query is served from the cache...
        let second = table.find_nearest_peers(&query, 3);
        assert_eq!(second.len(), 3);

        // ...and yields the same peers in the same order.
        let first_ids: Vec<_> = first.iter().map(|(id, _)| id).collect();
        let second_ids: Vec<_> = second.iter().map(|(id, _)| id).collect();
        assert_eq!(first_ids, second_ids);

        // Clearing drops every cached route.
        table.clear_route_cache();
        let (size_cleared, _) = table.route_cache_stats();
        assert_eq!(size_cleared, 0);
    }
}