1use cid::Cid;
27use dashmap::DashMap;
28use libp2p::PeerId;
29use multihash_codetable::{Code, MultihashDigest};
30use parking_lot::RwLock;
31use serde::{Deserialize, Serialize};
32use std::collections::HashMap;
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use thiserror::Error;
36
37#[derive(Error, Debug)]
39pub enum SemanticDhtError {
40 #[error("Invalid embedding dimension: expected {expected}, got {actual}")]
41 InvalidDimension { expected: usize, actual: usize },
42
43 #[error("Unknown namespace: {0}")]
44 UnknownNamespace(String),
45
46 #[error("No peers found for embedding region")]
47 NoPeersFound,
48
49 #[error("Query timeout after {0:?}")]
50 QueryTimeout(Duration),
51
52 #[error("Embedding encoding error: {0}")]
53 EncodingError(String),
54}
55
/// Tunables for a [`SemanticDht`] instance.
///
/// Only `enable_caching`, `cache_ttl` and `max_cache_size` are read by the
/// code in this file; the remaining fields are carried for callers.
/// Per-namespace LSH parameters come from [`LshConfig`], not from the `lsh_*`
/// fields here.
#[derive(Clone, Debug)]
pub struct SemanticDhtConfig {
    /// Number of LSH hash functions per table.
    pub lsh_hash_functions: usize,

    /// Number of LSH hash tables.
    pub lsh_hash_tables: usize,

    /// Width of an LSH quantization bucket.
    pub lsh_bucket_width: f32,

    /// Upper bound on the number of peers contacted per query.
    pub max_query_peers: usize,

    /// Time budget for a query.
    pub query_timeout: Duration,

    /// Whether query results are cached.
    pub enable_caching: bool,

    /// How long a cached result stays valid.
    pub cache_ttl: Duration,

    /// Cache size above which a cleanup pass is triggered.
    pub max_cache_size: usize,

    /// Default number of results to return.
    pub top_k: usize,
}
86
87impl Default for SemanticDhtConfig {
88 fn default() -> Self {
89 Self {
90 lsh_hash_functions: 8,
91 lsh_hash_tables: 4,
92 lsh_bucket_width: 4.0,
93 max_query_peers: 20,
94 query_timeout: Duration::from_secs(10),
95 enable_caching: true,
96 cache_ttl: Duration::from_secs(300),
97 max_cache_size: 1000,
98 top_k: 10,
99 }
100 }
101}
102
103#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
105pub struct NamespaceId(pub String);
106
107impl NamespaceId {
108 pub fn new(name: impl Into<String>) -> Self {
109 Self(name.into())
110 }
111
112 pub fn text() -> Self {
114 Self("text".to_string())
115 }
116
117 pub fn image() -> Self {
119 Self("image".to_string())
120 }
121
122 pub fn audio() -> Self {
124 Self("audio".to_string())
125 }
126}
127
128#[derive(Debug, Clone)]
130pub struct SemanticNamespace {
131 pub id: NamespaceId,
133
134 pub dimension: usize,
136
137 pub distance_metric: DistanceMetric,
139
140 pub lsh_config: LshConfig,
142}
143
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146pub enum DistanceMetric {
147 Euclidean,
149
150 Cosine,
152
153 Manhattan,
155
156 DotProduct,
158}
159
/// Locality-sensitive-hashing parameters of a namespace.
#[derive(Clone, Debug)]
pub struct LshConfig {
    /// Number of projections (hash functions) combined into one bucket.
    pub hash_functions: usize,

    /// Number of independent hash tables; one bucket is produced per table.
    pub num_tables: usize,

    /// Divisor applied to each projection value before it is floored to an
    /// integer bucket coordinate.
    pub bucket_width: f32,
}
172
173impl Default for LshConfig {
174 fn default() -> Self {
175 Self {
176 hash_functions: 8,
177 num_tables: 4,
178 bucket_width: 4.0,
179 }
180 }
181}
182
183#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
185pub struct LshHash {
186 pub table: usize,
188
189 pub bucket: Vec<i32>,
191}
192
193impl LshHash {
194 pub fn to_cid(&self) -> Cid {
196 let mut data = Vec::new();
198 data.push(self.table as u8);
199 for &val in &self.bucket {
200 data.extend_from_slice(&val.to_le_bytes());
201 }
202
203 let hash = Code::Sha2_256.digest(&data);
205
206 Cid::new_v1(0x55, hash) }
209}
210
211#[derive(Debug, Clone)]
213pub struct SemanticQuery {
214 pub embedding: Vec<f32>,
216
217 pub namespace: NamespaceId,
219
220 pub top_k: usize,
222
223 pub metadata_filter: Option<HashMap<String, String>>,
225
226 pub timeout: Duration,
228}
229
230#[derive(Debug, Clone)]
232pub struct SemanticResult {
233 pub cid: Cid,
235
236 pub score: f32,
238
239 pub peer: PeerId,
241
242 pub metadata: HashMap<String, String>,
244}
245
246#[derive(Debug, Clone)]
248struct CacheEntry {
249 results: Vec<SemanticResult>,
250 timestamp: Instant,
251}
252
253pub struct SemanticDht {
255 config: SemanticDhtConfig,
257
258 namespaces: Arc<DashMap<NamespaceId, SemanticNamespace>>,
260
261 lsh_projections: Arc<DashMap<NamespaceId, Vec<Vec<f32>>>>,
263
264 hash_to_peers: Arc<DashMap<LshHash, Vec<PeerId>>>,
266
267 local_index: Arc<DashMap<Cid, (Vec<f32>, NamespaceId)>>,
269
270 query_cache: Arc<DashMap<Vec<u8>, CacheEntry>>,
272
273 stats: Arc<RwLock<SemanticDhtStats>>,
275}
276
277#[derive(Debug, Clone, Default, Serialize, Deserialize)]
279pub struct SemanticDhtStats {
280 pub total_queries: u64,
282
283 pub successful_queries: u64,
285
286 pub failed_queries: u64,
288
289 pub cache_hits: u64,
291
292 pub cache_misses: u64,
294
295 pub avg_query_latency_ms: f64,
297
298 pub indexed_content: u64,
300
301 pub queries_per_namespace: HashMap<String, u64>,
303}
304
305impl SemanticDht {
306 pub fn new(config: SemanticDhtConfig) -> Self {
308 Self {
309 config,
310 namespaces: Arc::new(DashMap::new()),
311 lsh_projections: Arc::new(DashMap::new()),
312 hash_to_peers: Arc::new(DashMap::new()),
313 local_index: Arc::new(DashMap::new()),
314 query_cache: Arc::new(DashMap::new()),
315 stats: Arc::new(RwLock::new(SemanticDhtStats::default())),
316 }
317 }
318
319 pub fn register_namespace(&self, namespace: SemanticNamespace) -> Result<(), SemanticDhtError> {
321 let namespace_id = namespace.id.clone();
322
323 let projections = self.generate_lsh_projections(
325 namespace.dimension,
326 namespace.lsh_config.hash_functions,
327 namespace.lsh_config.num_tables,
328 );
329
330 self.lsh_projections
331 .insert(namespace_id.clone(), projections);
332 self.namespaces.insert(namespace_id, namespace);
333
334 Ok(())
335 }
336
337 fn generate_lsh_projections(
339 &self,
340 dimension: usize,
341 hash_functions: usize,
342 num_tables: usize,
343 ) -> Vec<Vec<f32>> {
344 use std::f32::consts::PI;
345
346 let mut projections = Vec::new();
347 let total_projections = hash_functions * num_tables;
348
349 for i in 0..total_projections {
351 let mut projection = Vec::with_capacity(dimension);
352
353 for j in 0..dimension {
354 let seed = (i * dimension + j) as f32;
356 let angle = seed * 2.0 * PI / 1000.0;
357 let value = angle.sin();
358 projection.push(value);
359 }
360
361 let norm: f32 = projection.iter().map(|x| x * x).sum::<f32>().sqrt();
363 if norm > 0.0 {
364 for val in &mut projection {
365 *val /= norm;
366 }
367 }
368
369 projections.push(projection);
370 }
371
372 projections
373 }
374
375 pub fn compute_lsh_hashes(
377 &self,
378 embedding: &[f32],
379 namespace: &NamespaceId,
380 ) -> Result<Vec<LshHash>, SemanticDhtError> {
381 let ns = self
382 .namespaces
383 .get(namespace)
384 .ok_or_else(|| SemanticDhtError::UnknownNamespace(namespace.0.clone()))?;
385
386 if embedding.len() != ns.dimension {
387 return Err(SemanticDhtError::InvalidDimension {
388 expected: ns.dimension,
389 actual: embedding.len(),
390 });
391 }
392
393 let projections = self
394 .lsh_projections
395 .get(namespace)
396 .ok_or_else(|| SemanticDhtError::UnknownNamespace(namespace.0.clone()))?;
397
398 let mut hashes = Vec::new();
399 let hash_functions = ns.lsh_config.hash_functions;
400
401 for table in 0..ns.lsh_config.num_tables {
402 let mut bucket = Vec::with_capacity(hash_functions);
403
404 for func in 0..hash_functions {
405 let proj_idx = table * hash_functions + func;
406 let projection = &projections[proj_idx];
407
408 let dot_product: f32 = embedding
410 .iter()
411 .zip(projection.iter())
412 .map(|(a, b)| a * b)
413 .sum();
414
415 let quantized = (dot_product / ns.lsh_config.bucket_width).floor() as i32;
417 bucket.push(quantized);
418 }
419
420 hashes.push(LshHash { table, bucket });
421 }
422
423 Ok(hashes)
424 }
425
426 pub fn index_content(
428 &self,
429 cid: Cid,
430 embedding: Vec<f32>,
431 namespace: NamespaceId,
432 ) -> Result<(), SemanticDhtError> {
433 let ns = self
435 .namespaces
436 .get(&namespace)
437 .ok_or_else(|| SemanticDhtError::UnknownNamespace(namespace.0.clone()))?;
438
439 if embedding.len() != ns.dimension {
440 return Err(SemanticDhtError::InvalidDimension {
441 expected: ns.dimension,
442 actual: embedding.len(),
443 });
444 }
445
446 self.local_index
448 .insert(cid, (embedding.clone(), namespace.clone()));
449
450 let hashes = self.compute_lsh_hashes(&embedding, &namespace)?;
452
453 for hash in hashes {
455 let _ = hash.to_cid();
457 }
458
459 let mut stats = self.stats.write();
461 stats.indexed_content += 1;
462
463 Ok(())
464 }
465
466 pub fn query(&self, query: SemanticQuery) -> Result<Vec<SemanticResult>, SemanticDhtError> {
468 let start = Instant::now();
469
470 if self.config.enable_caching {
472 let cache_key = self.compute_cache_key(&query);
473 if let Some(entry) = self.query_cache.get(&cache_key) {
474 if start.duration_since(entry.timestamp) < self.config.cache_ttl {
475 let mut stats = self.stats.write();
476 stats.cache_hits += 1;
477 return Ok(entry.results.clone());
478 }
479 }
480 }
481
482 let _ns = self
484 .namespaces
485 .get(&query.namespace)
486 .ok_or_else(|| SemanticDhtError::UnknownNamespace(query.namespace.0.clone()))?;
487
488 let hashes = self.compute_lsh_hashes(&query.embedding, &query.namespace)?;
490
491 let mut candidate_peers = Vec::new();
493 for hash in &hashes {
494 if let Some(peers) = self.hash_to_peers.get(hash) {
495 candidate_peers.extend(peers.iter().cloned());
496 }
497 }
498
499 let mut results = Vec::new();
501 for entry in self.local_index.iter() {
502 let (cid, (embedding, ns)) = entry.pair();
503
504 if ns != &query.namespace {
505 continue;
506 }
507
508 let distance = self.compute_distance(&query.embedding, embedding, &query.namespace)?;
509 let score = 1.0 / (1.0 + distance); results.push(SemanticResult {
512 cid: *cid,
513 score,
514 peer: PeerId::random(), metadata: HashMap::new(),
516 });
517 }
518
519 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
521 results.truncate(query.top_k);
522
523 if self.config.enable_caching {
525 let cache_key = self.compute_cache_key(&query);
526 let entry = CacheEntry {
527 results: results.clone(),
528 timestamp: start,
529 };
530 self.query_cache.insert(cache_key, entry);
531
532 self.cleanup_cache();
534 }
535
536 let latency = start.elapsed().as_millis() as f64;
538 let mut stats = self.stats.write();
539 stats.total_queries += 1;
540 stats.successful_queries += 1;
541 stats.cache_misses += 1;
542
543 let alpha = 0.1;
545 stats.avg_query_latency_ms = alpha * latency + (1.0 - alpha) * stats.avg_query_latency_ms;
546
547 *stats
548 .queries_per_namespace
549 .entry(query.namespace.0.clone())
550 .or_insert(0) += 1;
551
552 Ok(results)
553 }
554
555 fn compute_distance(
557 &self,
558 a: &[f32],
559 b: &[f32],
560 namespace: &NamespaceId,
561 ) -> Result<f32, SemanticDhtError> {
562 let ns = self
563 .namespaces
564 .get(namespace)
565 .ok_or_else(|| SemanticDhtError::UnknownNamespace(namespace.0.clone()))?;
566
567 let distance = match ns.distance_metric {
568 DistanceMetric::Euclidean => a
569 .iter()
570 .zip(b.iter())
571 .map(|(x, y)| (x - y).powi(2))
572 .sum::<f32>()
573 .sqrt(),
574 DistanceMetric::Cosine => {
575 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
576 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
577 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
578 1.0 - (dot / (norm_a * norm_b))
579 }
580 DistanceMetric::Manhattan => a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum(),
581 DistanceMetric::DotProduct => {
582 -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>() }
584 };
585
586 Ok(distance)
587 }
588
589 fn compute_cache_key(&self, query: &SemanticQuery) -> Vec<u8> {
591 let mut data = Vec::new();
593 data.extend_from_slice(query.namespace.0.as_bytes());
594 for &val in &query.embedding {
595 data.extend_from_slice(&val.to_le_bytes());
596 }
597 data
598 }
599
600 fn cleanup_cache(&self) {
602 if self.query_cache.len() <= self.config.max_cache_size {
603 return;
604 }
605
606 let now = Instant::now();
607 let ttl = self.config.cache_ttl;
608
609 self.query_cache
610 .retain(|_, entry| now.duration_since(entry.timestamp) < ttl);
611 }
612
613 pub fn stats(&self) -> SemanticDhtStats {
615 self.stats.read().clone()
616 }
617
618 pub fn get_namespace(&self, id: &NamespaceId) -> Option<SemanticNamespace> {
620 self.namespaces.get(id).map(|ns| ns.clone())
621 }
622
623 pub fn list_namespaces(&self) -> Vec<NamespaceId> {
625 self.namespaces
626 .iter()
627 .map(|entry| entry.key().clone())
628 .collect()
629 }
630}
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic embedding: component `i` is `sin((i + seed) * 0.1)`.
    fn create_test_embedding(dim: usize, seed: f32) -> Vec<f32> {
        (0..dim).map(|i| ((i as f32 + seed) * 0.1).sin()).collect()
    }

    /// Namespace with default LSH parameters.
    fn namespace(id: NamespaceId, dimension: usize, distance_metric: DistanceMetric) -> SemanticNamespace {
        SemanticNamespace {
            id,
            dimension,
            distance_metric,
            lsh_config: LshConfig::default(),
        }
    }

    /// DHT with the given configuration and a registered `text` namespace.
    fn text_dht(config: SemanticDhtConfig, dimension: usize, metric: DistanceMetric) -> SemanticDht {
        let dht = SemanticDht::new(config);
        dht.register_namespace(namespace(NamespaceId::text(), dimension, metric))
            .unwrap();
        dht
    }

    /// A CID that differs for every `tag`, so indexed items do not overwrite
    /// each other.
    fn distinct_cid(tag: u8) -> Cid {
        Cid::new_v1(0x55, Code::Sha2_256.digest(&[tag]))
    }

    /// Query against the `text` namespace asking for three results.
    fn text_query(embedding: Vec<f32>) -> SemanticQuery {
        SemanticQuery {
            embedding,
            namespace: NamespaceId::text(),
            top_k: 3,
            metadata_filter: None,
            timeout: Duration::from_secs(5),
        }
    }

    #[test]
    fn test_semantic_dht_creation() {
        let dht = SemanticDht::new(SemanticDhtConfig::default());
        assert!(dht.list_namespaces().is_empty());
    }

    #[test]
    fn test_namespace_registration() {
        let dht = text_dht(SemanticDhtConfig::default(), 128, DistanceMetric::Cosine);

        assert_eq!(dht.list_namespaces().len(), 1);
        assert_eq!(
            dht.get_namespace(&NamespaceId::text()).unwrap().dimension,
            128
        );
    }

    #[test]
    fn test_lsh_hash_computation() {
        let dht = text_dht(SemanticDhtConfig::default(), 64, DistanceMetric::Euclidean);

        let hashes = dht
            .compute_lsh_hashes(&create_test_embedding(64, 1.0), &NamespaceId::text())
            .unwrap();

        // One bucket per table, one coordinate per hash function.
        assert_eq!(hashes.len(), 4);
        assert_eq!(hashes[0].bucket.len(), 8);
    }

    #[test]
    fn test_content_indexing() {
        let dht = text_dht(SemanticDhtConfig::default(), 64, DistanceMetric::Cosine);

        dht.index_content(
            Cid::default(),
            create_test_embedding(64, 1.0),
            NamespaceId::text(),
        )
        .unwrap();

        assert_eq!(dht.stats().indexed_content, 1);
    }

    #[test]
    fn test_semantic_query() {
        let dht = text_dht(SemanticDhtConfig::default(), 64, DistanceMetric::Cosine);

        // Distinct CIDs: indexing the same CID repeatedly would leave a
        // single entry and make the ranking assertions below vacuous.
        for tag in 0..5u8 {
            dht.index_content(
                distinct_cid(tag),
                create_test_embedding(64, f32::from(tag)),
                NamespaceId::text(),
            )
            .unwrap();
        }
        assert_eq!(dht.stats().indexed_content, 5);

        let results = dht.query(text_query(create_test_embedding(64, 2.5))).unwrap();
        assert_eq!(results.len(), 3);

        // Best match first.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }

        // Every result refers to a different piece of content.
        let unique: std::collections::HashSet<Cid> = results.iter().map(|r| r.cid).collect();
        assert_eq!(unique.len(), results.len());
    }

    #[test]
    fn test_distance_metrics() {
        let dht = SemanticDht::new(SemanticDhtConfig::default());

        dht.register_namespace(namespace(
            NamespaceId::new("euclidean"),
            3,
            DistanceMetric::Euclidean,
        ))
        .unwrap();

        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let dist = dht
            .compute_distance(&a, &b, &NamespaceId::new("euclidean"))
            .unwrap();
        // sqrt(2) between two orthogonal unit vectors.
        assert!((dist - 1.414).abs() < 0.01);

        dht.register_namespace(namespace(
            NamespaceId::new("cosine"),
            2,
            DistanceMetric::Cosine,
        ))
        .unwrap();

        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        let dist = dht
            .compute_distance(&a, &b, &NamespaceId::new("cosine"))
            .unwrap();
        // Identical direction means zero cosine distance.
        assert!(dist.abs() < 0.01);
    }

    #[test]
    fn test_query_caching() {
        let config = SemanticDhtConfig {
            enable_caching: true,
            cache_ttl: Duration::from_secs(60),
            ..Default::default()
        };
        let dht = text_dht(config, 64, DistanceMetric::Cosine);

        let embedding = create_test_embedding(64, 1.0);
        dht.index_content(Cid::default(), embedding.clone(), NamespaceId::text())
            .unwrap();

        let query = text_query(embedding);

        // First run executes the search.
        let _ = dht.query(query.clone()).unwrap();
        assert_eq!(dht.stats().cache_misses, 1);

        // Second run is served from the cache.
        let _ = dht.query(query).unwrap();
        assert_eq!(dht.stats().cache_hits, 1);
    }

    #[test]
    fn test_invalid_dimension() {
        let dht = text_dht(SemanticDhtConfig::default(), 64, DistanceMetric::Cosine);

        // 32 components against a 64-dimensional namespace.
        let result = dht.index_content(
            Cid::default(),
            create_test_embedding(32, 1.0),
            NamespaceId::text(),
        );

        assert!(matches!(
            result,
            Err(SemanticDhtError::InvalidDimension { .. })
        ));
    }

    #[test]
    fn test_unknown_namespace() {
        let dht = SemanticDht::new(SemanticDhtConfig::default());

        let result = dht.compute_lsh_hashes(&create_test_embedding(64, 1.0), &NamespaceId::text());

        assert!(matches!(result, Err(SemanticDhtError::UnknownNamespace(_))));
    }

    #[test]
    fn test_lsh_hash_to_cid() {
        let hash = LshHash {
            table: 0,
            bucket: vec![1, 2, 3, 4],
        };

        assert_eq!(hash.to_cid().version(), cid::Version::V1);
    }

    #[test]
    fn test_namespace_ids() {
        assert_eq!(NamespaceId::text().0, "text");
        assert_eq!(NamespaceId::image().0, "image");
        assert_eq!(NamespaceId::audio().0, "audio");
    }
}