Skip to main content

content_index/
query.rs

1use crate::{IndexError, IndexRecord, UfpIndex};
2use hashbrown::HashSet;
3use std::cmp::Ordering;
4
5/// Result entry for a similarity query.
6#[derive(Debug, Clone)]
7pub struct QueryResult {
8    /// Canonical hash of the matched document.
9    pub canonical_hash: String,
10    /// Similarity score (0.0 to 1.0, higher is more similar).
11    pub score: f32,
12    /// Metadata associated with the matched document.
13    pub metadata: serde_json::Value,
14}
15
/// Selects which fingerprint a search compares against.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryMode {
    /// Score candidates by cosine similarity of their quantized embeddings.
    Semantic,
    /// Score candidates by Jaccard similarity of their perceptual MinHash
    /// signatures.
    Perceptual,
}
24
25/// Provides semantic & perceptual retrieval methods
26impl UfpIndex {
27    /// Compute cosine similarity between two quantized vectors.
28    /// Simple implementation that the compiler can auto-vectorize.
29    #[inline]
30    fn cosine_similarity(a: &[i8], b: &[i8]) -> f32 {
31        if a.len() != b.len() || a.is_empty() {
32            return 0.0;
33        }
34
35        let mut dot: i32 = 0;
36        let mut norm_a: i32 = 0;
37        let mut norm_b: i32 = 0;
38
39        for (&x, &y) in a.iter().zip(b.iter()) {
40            let xi = x as i32;
41            let yi = y as i32;
42            dot += xi * yi;
43            norm_a += xi * xi;
44            norm_b += yi * yi;
45        }
46
47        let norm_a_f = (norm_a as f32).sqrt();
48        let norm_b_f = (norm_b as f32).sqrt();
49
50        if norm_a_f == 0.0 || norm_b_f == 0.0 {
51            return 0.0;
52        }
53
54        dot as f32 / (norm_a_f * norm_b_f)
55    }
56
57    /// Compute Jaccard similarity for perceptual fingerprints (MinHash).
58    /// This is the size of the intersection divided by the size of the union.
59    #[inline]
60    fn jaccard_similarity(
61        query: &HashSet<u64>,
62        candidate: &[u64],
63        scratch: &mut HashSet<u64>,
64    ) -> f32 {
65        if query.is_empty() || candidate.is_empty() {
66            return 0.0;
67        }
68        // The scratch space is used to avoid re-allocating a HashSet for each candidate.
69        scratch.clear();
70
71        let mut intersection = 0usize;
72        let mut union = query.len();
73
74        for &value in candidate {
75            // If the value is already in the scratch set, it's a duplicate in the candidate.
76            if !scratch.insert(value) {
77                continue;
78            }
79
80            if query.contains(&value) {
81                intersection += 1;
82            } else {
83                union += 1;
84            }
85        }
86
87        if union == 0 {
88            0.0
89        } else {
90            intersection as f32 / union as f32
91        }
92    }
93
94    /// Search for top-k most similar entries.
95    pub fn search(
96        &self,
97        query: &IndexRecord,
98        mode: QueryMode,
99        top_k: usize,
100    ) -> Result<Vec<QueryResult>, IndexError> {
101        if top_k == 0 {
102            return Ok(Vec::new());
103        }
104
105        // Extract the query vectors, returning early if they are empty for the selected mode.
106        let query_embedding = query.embedding.as_ref().filter(|emb| !emb.is_empty());
107        let query_perceptual = query.perceptual.as_ref().filter(|mh| !mh.is_empty());
108
109        if matches!(mode, QueryMode::Semantic) && query_embedding.is_none() {
110            return Ok(Vec::new());
111        }
112        if matches!(mode, QueryMode::Perceptual) && query_perceptual.is_none() {
113            return Ok(Vec::new());
114        }
115
116        // For perceptual search, convert the query MinHash vector to a HashSet for efficient lookups.
117        let perceptual_set = query_perceptual.map(|mh| {
118            let mut set = HashSet::with_capacity(mh.len());
119            set.extend(mh.iter().copied());
120            set
121        });
122
123        let mut results = Vec::new();
124        let mut scratch = HashSet::new();
125        let mut processed_hashes = std::collections::HashSet::new();
126
127        match mode {
128            QueryMode::Perceptual => {
129                if let (Some(query_set), Some(_)) = (perceptual_set.as_ref(), query_perceptual) {
130                    // Count candidate frequencies using the lock-free inverted index
131                    let mut candidate_counts = std::collections::HashMap::new();
132                    for &hash_val in query_set {
133                        if let Some(candidates) = self.perceptual_index.get(&hash_val) {
134                            for candidate_hash in candidates.value() {
135                                *candidate_counts.entry(candidate_hash.clone()).or_insert(0) += 1;
136                            }
137                        }
138                    }
139
140                    // Calculate Jaccard similarity for candidates
141                    for (candidate_hash, intersection_size) in candidate_counts {
142                        if intersection_size > 0 {
143                            if let Some(rec_data) = self.backend.get(&candidate_hash)? {
144                                let rec = self.decode_record(&rec_data);
145                                if let Ok(record) = rec {
146                                    if let Some(rp) = &record.perceptual {
147                                        let score =
148                                            Self::jaccard_similarity(query_set, rp, &mut scratch);
149                                        if score > 0.0 {
150                                            results.push(QueryResult {
151                                                canonical_hash: candidate_hash.clone(),
152                                                score,
153                                                metadata: record.metadata.clone(),
154                                            });
155                                            processed_hashes.insert(candidate_hash);
156                                        }
157                                    }
158                                }
159                            }
160                        }
161                    }
162                }
163            }
164            QueryMode::Semantic => {
165                // Try to use ANN if available and dataset is large enough
166                self.rebuild_ann_if_needed();
167
168                if let Some(query_embedding) = query_embedding {
169                    if self.should_use_ann() {
170                        // Use ANN for approximate search
171                        if let Ok(ann_lock) = self.ann_index.try_lock() {
172                            if let Some(ref ann) = *ann_lock {
173                                // Convert query from i8 to f32
174                                let query_f32: Vec<f32> =
175                                    query_embedding.iter().map(|&v| v as f32 / 100.0).collect();
176
177                                if let Ok(ann_results) = ann.search(&query_f32, top_k * 2) {
178                                    for ann_result in ann_results {
179                                        if let Some(candidate_hash) = ann.get_id(ann_result.index) {
180                                            if let Some(rec_data) =
181                                                self.backend.get(candidate_hash)?
182                                            {
183                                                if let Ok(record) = self.decode_record(&rec_data) {
184                                                    // Convert distance back to similarity
185                                                    let score =
186                                                        1.0 - ann_result.distance.clamp(0.0, 1.0);
187                                                    if score > 0.0 {
188                                                        results.push(QueryResult {
189                                                            canonical_hash: candidate_hash.clone(),
190                                                            score,
191                                                            metadata: record.metadata.clone(),
192                                                        });
193                                                        processed_hashes
194                                                            .insert(candidate_hash.clone());
195                                                    }
196                                                }
197                                            }
198                                        }
199                                    }
200                                }
201                            }
202                        }
203                    }
204
205                    // Fall back to linear scan if ANN not available or didn't return enough results
206                    if results.is_empty() {
207                        // Simple vector search using lock-free DashMap
208                        for entry in self.semantic_index.iter() {
209                            let candidate_hash = entry.key();
210                            let candidate_embedding = entry.value();
211                            let score =
212                                Self::cosine_similarity(query_embedding, candidate_embedding);
213                            if score > 0.0 && !processed_hashes.contains(candidate_hash) {
214                                if let Some(rec_data) = self.backend.get(candidate_hash)? {
215                                    let rec = self.decode_record(&rec_data);
216                                    if let Ok(record) = rec {
217                                        results.push(QueryResult {
218                                            canonical_hash: candidate_hash.clone(),
219                                            score,
220                                            metadata: record.metadata.clone(),
221                                        });
222                                        processed_hashes.insert(candidate_hash.clone());
223                                    }
224                                }
225                            }
226                        }
227                    }
228                }
229            }
230        }
231
232        // Sort results by score in descending order.
233        // Ties are broken by the canonical hash to ensure deterministic ordering.
234        results.sort_unstable_by(|a, b| {
235            b.score
236                .partial_cmp(&a.score)
237                .unwrap_or(Ordering::Equal)
238                .then_with(|| a.canonical_hash.cmp(&b.canonical_hash))
239        });
240        // Return only the top-k results.
241        results.truncate(top_k);
242        Ok(results)
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BackendConfig, IndexConfig, INDEX_SCHEMA_VERSION};
    use serde_json::json;

    /// Record with the given hash and metadata but no fingerprints attached.
    fn bare_record(hash: &str, metadata: serde_json::Value) -> IndexRecord {
        IndexRecord {
            schema_version: INDEX_SCHEMA_VERSION,
            canonical_hash: hash.into(),
            perceptual: None,
            embedding: None,
            metadata,
        }
    }

    /// Indexed document carrying only a quantized embedding.
    fn semantic_record(hash: &str, embedding: &[i8]) -> IndexRecord {
        IndexRecord {
            embedding: Some(embedding.to_vec()),
            ..bare_record(hash, json!({ "hash": hash }))
        }
    }

    /// Indexed document carrying only a perceptual MinHash signature.
    fn perceptual_record(hash: &str, fingerprint: &[u64]) -> IndexRecord {
        IndexRecord {
            perceptual: Some(fingerprint.to_vec()),
            ..bare_record(hash, json!({ "hash": hash }))
        }
    }

    /// Query record (hash `"query"`, empty metadata) with the given vectors.
    fn query_record(perceptual: Option<Vec<u64>>, embedding: Option<Vec<i8>>) -> IndexRecord {
        IndexRecord {
            perceptual,
            embedding,
            ..bare_record("query", json!({}))
        }
    }

    /// In-memory index pre-populated with `records`.
    fn seed_index(records: Vec<IndexRecord>) -> UfpIndex {
        let cfg = IndexConfig::new().with_backend(BackendConfig::in_memory());
        let index = UfpIndex::new(cfg).expect("index init");
        for record in &records {
            index.upsert(record).expect("seed record");
        }
        index
    }

    #[test]
    fn jaccard_similarity_counts_value_matches() {
        let query_set: HashSet<u64> = [1_u64, 2, 3, 4].iter().copied().collect();
        let candidate = vec![4_u64, 2, 8, 9];
        let mut scratch = HashSet::new();

        let score = UfpIndex::jaccard_similarity(&query_set, &candidate, &mut scratch);

        // Two shared values (2 and 4) out of six distinct values overall.
        let expected: f32 = 2.0 / 6.0;
        assert!((score - expected).abs() < f32::EPSILON);
    }

    #[test]
    fn semantic_search_orders_by_score_and_tie_breaks_hashes() {
        let index = seed_index(vec![
            semantic_record("doc-b", &[5, 0, 0, 0]),
            semantic_record("doc-a", &[5, 0, 0, 0]),
            semantic_record("doc-c", &[1, 1, 1, 1]),
        ]);
        let query = query_record(None, Some(vec![5, 0, 0, 0]));

        let hits = index
            .search(&query, QueryMode::Semantic, 3)
            .expect("semantic search");

        // doc-a and doc-b tie on score, so the hash decides their order.
        let order: Vec<&str> = hits.iter().map(|hit| hit.canonical_hash.as_str()).collect();
        assert_eq!(order, vec!["doc-a", "doc-b", "doc-c"]);
        assert!((hits[0].score - hits[1].score).abs() < f32::EPSILON);
    }

    #[test]
    fn perceptual_search_respects_top_k_and_filters_zero_scores() {
        let index = seed_index(vec![
            perceptual_record("doc-a", &[1, 2, 9, 10]),
            perceptual_record("doc-b", &[3, 4, 7, 8]),
            perceptual_record("doc-c", &[10, 11, 12, 13]),
        ]);
        let query = query_record(Some(vec![3, 4, 7, 8]), None);

        let hits = index
            .search(&query, QueryMode::Perceptual, 1)
            .expect("perceptual search");

        assert_eq!(hits.len(), 1);
        let best = &hits[0];
        assert_eq!(best.canonical_hash, "doc-b");
        assert!(best.score > 0.0);
    }

    #[test]
    fn zero_top_k_short_circuits() {
        let index = seed_index(vec![semantic_record("doc-a", &[1, 0, 0, 0])]);
        let query = query_record(None, Some(vec![1, 0, 0, 0]));

        let hits = index
            .search(&query, QueryMode::Semantic, 0)
            .expect("semantic search");

        assert!(hits.is_empty());
    }

    #[test]
    fn cosine_similarity_basic() {
        let a = vec![10_i8, 20, 30, 40, 50];
        let b = vec![5_i8, 10, 15, 20, 25];

        let result = UfpIndex::cosine_similarity(&a, &b);

        // Independent reference computation of dot / (|a| * |b|).
        let magnitude = |v: &[i8]| -> f32 {
            let squares: i32 = v.iter().map(|&x| (x as i32) * (x as i32)).sum();
            (squares as f32).sqrt()
        };
        let mut dot: i32 = 0;
        for (&x, &y) in a.iter().zip(&b) {
            dot += (x as i32) * (y as i32);
        }
        let expected = dot as f32 / (magnitude(&a) * magnitude(&b));

        assert!((result - expected).abs() < 0.0001);
    }

    #[test]
    fn cosine_similarity_large_vector() {
        // Longer input with only non-negative components, so the similarity
        // must land inside [0, 1].
        let a: Vec<i8> = (0..100).map(|i| (i % 127) as i8).collect();
        let b: Vec<i8> = (0..100).map(|i| ((i + 10) % 127) as i8).collect();

        let result = UfpIndex::cosine_similarity(&a, &b);

        assert!((0.0..=1.0).contains(&result));
    }
}