Skip to main content

agentic_memory/v3/indexes/
semantic.rs

//! Semantic similarity index: embedding-based search with a keyword-matching fallback over block text.
2
3use super::{Index, IndexResult};
4use crate::v3::block::{Block, BlockHash};
5use std::collections::HashMap;
6
7/// Semantic similarity index.
8/// Supports both embedding-based and text-based search.
9pub struct SemanticIndex {
10    /// Embeddings storage: sequence -> embedding vector
11    embeddings: HashMap<u64, Vec<f32>>,
12
13    /// Text content for fallback search
14    text_content: HashMap<u64, String>,
15
16    /// Block hashes
17    hashes: HashMap<u64, BlockHash>,
18
19    /// Embedding dimension
20    dimension: usize,
21}
22
23impl SemanticIndex {
24    pub fn new(dimension: usize) -> Self {
25        Self {
26            embeddings: HashMap::new(),
27            text_content: HashMap::new(),
28            hashes: HashMap::new(),
29            dimension,
30        }
31    }
32
33    /// Add embedding for a block
34    pub fn add_embedding(&mut self, sequence: u64, embedding: Vec<f32>) {
35        if embedding.len() == self.dimension {
36            self.embeddings.insert(sequence, embedding);
37        }
38    }
39
40    /// Search by embedding vector
41    pub fn search_by_embedding(&self, query: &[f32], limit: usize) -> Vec<IndexResult> {
42        if query.len() != self.dimension {
43            return vec![];
44        }
45
46        let mut scores: Vec<(u64, f32)> = self
47            .embeddings
48            .iter()
49            .map(|(seq, emb)| {
50                let score = cosine_similarity(query, emb);
51                (*seq, score)
52            })
53            .collect();
54
55        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
56
57        scores
58            .into_iter()
59            .take(limit)
60            .filter_map(|(seq, score)| {
61                self.hashes.get(&seq).map(|hash| IndexResult {
62                    block_sequence: seq,
63                    block_hash: *hash,
64                    score,
65                })
66            })
67            .collect()
68    }
69
70    /// Search by text (fallback when no embeddings available)
71    pub fn search_by_text(&self, query: &str, limit: usize) -> Vec<IndexResult> {
72        let query_lower = query.to_lowercase();
73        let query_words: Vec<&str> = query_lower.split_whitespace().collect();
74
75        if query_words.is_empty() {
76            return vec![];
77        }
78
79        let mut scores: Vec<(u64, f32)> = self
80            .text_content
81            .iter()
82            .map(|(seq, text)| {
83                let text_lower = text.to_lowercase();
84                let matches = query_words
85                    .iter()
86                    .filter(|w| text_lower.contains(*w))
87                    .count();
88                let score = matches as f32 / query_words.len() as f32;
89                (*seq, score)
90            })
91            .filter(|(_, score)| *score > 0.0)
92            .collect();
93
94        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
95
96        scores
97            .into_iter()
98            .take(limit)
99            .filter_map(|(seq, score)| {
100                self.hashes.get(&seq).map(|hash| IndexResult {
101                    block_sequence: seq,
102                    block_hash: *hash,
103                    score,
104                })
105            })
106            .collect()
107    }
108
109    /// Get indexed block count
110    pub fn len(&self) -> usize {
111        self.hashes.len()
112    }
113
114    /// Check if empty
115    pub fn is_empty(&self) -> bool {
116        self.hashes.is_empty()
117    }
118}
119
120impl Index for SemanticIndex {
121    fn index(&mut self, block: &Block) {
122        self.hashes.insert(block.sequence, block.hash);
123
124        if let Some(text) = block.extract_text() {
125            self.text_content.insert(block.sequence, text);
126        }
127    }
128
129    fn remove(&mut self, sequence: u64) {
130        self.embeddings.remove(&sequence);
131        self.text_content.remove(&sequence);
132        self.hashes.remove(&sequence);
133    }
134
135    fn rebuild(&mut self, blocks: impl Iterator<Item = Block>) {
136        self.embeddings.clear();
137        self.text_content.clear();
138        self.hashes.clear();
139        for block in blocks {
140            self.index(&block);
141        }
142    }
143}
144
/// Cosine similarity of `a` and `b`: `a·b / (|a| * |b|)`.
///
/// Yields `0.0` instead of dividing by zero when either vector has zero
/// magnitude (including empty slices). If the slices differ in length, the
/// dot product covers only their common prefix while the magnitudes use the
/// full slices. The searches in this file only pass equal-length vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    dot / (mag_a * mag_b)
}
155}