//! In-memory BM25 full-text search (`search.rs`).
use std::collections::{HashMap, HashSet};
2
/// Result of a search query, containing the document ID and relevance score.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Document ID, as originally passed to `SearchEngine::add_document`.
    pub id: String,
    /// BM25 relevance score; higher means a better match.
    pub score: f64,
}
9
/// Trait for search backends. Implemented by `BM25Index` (in-memory fallback)
/// and can be implemented by FTS5-backed stores when the `db` feature is enabled.
pub trait SearchEngine {
    /// Add a document with weighted fields.
    /// Each field is `(field_name, field_value, field_weight)`; a higher
    /// weight makes matches in that field count more toward the score.
    fn add_document(&mut self, id: &str, fields: &[(&str, &str, f64)]);

    /// Search for documents matching the query, returning at most `limit` results
    /// sorted by descending relevance score.
    fn search(&self, query: &str, limit: usize) -> Vec<SearchResult>;
}
21
/// Per-field data stored for a single document.
#[derive(Debug)]
struct FieldData {
    /// Tokens produced by `BM25Index::tokenize` from this field's text.
    tokens: Vec<String>,
    /// Score multiplier applied to matches in this field.
    weight: f64,
}
28
/// Per-document data: maps each field name to its tokenized contents and weight.
#[derive(Debug)]
struct Document {
    fields: HashMap<String, FieldData>,
}
34
/// Zero-dependency BM25 search index.
///
/// Uses the Okapi BM25 ranking function with configurable field weights.
/// This is the fallback backend when no database/FTS5 is available.
#[derive(Debug)]
pub struct BM25Index {
    /// Indexed documents keyed by document ID.
    documents: HashMap<String, Document>,
    /// For each field name, tracks the total token count across all documents
    /// (used to compute average field length).
    field_total_tokens: HashMap<String, usize>,
    /// For each field name, tracks how many documents have that field.
    field_doc_count: HashMap<String, usize>,
    /// Document frequency: for each term, the document IDs containing it.
    /// Stored as a `Vec` (not a set); only its length is read when computing IDF.
    doc_freq: HashMap<String, Vec<String>>,
}
50
51impl Default for BM25Index {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl BM25Index {
58    /// BM25 term frequency saturation parameter.
59    const K1: f64 = 1.2;
60    /// BM25 document length normalization parameter.
61    const B: f64 = 0.75;
62
63    pub fn new() -> Self {
64        Self {
65            documents: HashMap::new(),
66            field_total_tokens: HashMap::new(),
67            field_doc_count: HashMap::new(),
68            doc_freq: HashMap::new(),
69        }
70    }
71
72    /// Tokenize text: split on non-alphanumeric/underscore, lowercase, filter len <= 1.
73    fn tokenize(text: &str) -> Vec<String> {
74        text.split(|c: char| !c.is_alphanumeric() && c != '_')
75            .map(|s| s.to_lowercase())
76            .filter(|s| s.len() > 1)
77            .collect()
78    }
79
80    /// Compute IDF for a term.
81    /// IDF(t) = ln((N - df + 0.5) / (df + 0.5) + 1.0)
82    fn idf(&self, term: &str) -> f64 {
83        let n = self.documents.len() as f64;
84        let df = self.doc_freq.get(term).map_or(0, |docs| docs.len()) as f64;
85        f64::ln((n - df + 0.5) / (df + 0.5) + 1.0)
86    }
87
88    /// Average field length for a given field name.
89    fn avg_field_len(&self, field_name: &str) -> f64 {
90        let total = *self.field_total_tokens.get(field_name).unwrap_or(&0) as f64;
91        let count = *self.field_doc_count.get(field_name).unwrap_or(&0) as f64;
92        if count == 0.0 {
93            return 0.0;
94        }
95        total / count
96    }
97}
98
99impl SearchEngine for BM25Index {
100    fn add_document(&mut self, id: &str, fields: &[(&str, &str, f64)]) {
101        let mut doc_fields = HashMap::new();
102        let mut seen_terms: HashMap<String, bool> = HashMap::new();
103
104        for &(field_name, field_value, weight) in fields {
105            let tokens = Self::tokenize(field_value);
106
107            // Track field statistics
108            *self
109                .field_total_tokens
110                .entry(field_name.to_string())
111                .or_insert(0) += tokens.len();
112            *self
113                .field_doc_count
114                .entry(field_name.to_string())
115                .or_insert(0) += 1;
116
117            // Track unique terms in this document for document frequency
118            for token in &tokens {
119                seen_terms.entry(token.clone()).or_insert(true);
120            }
121
122            doc_fields.insert(field_name.to_string(), FieldData { tokens, weight });
123        }
124
125        // Update document frequency: each term gets this doc ID once
126        for term in seen_terms.keys() {
127            self.doc_freq
128                .entry(term.clone())
129                .or_default()
130                .push(id.to_string());
131        }
132
133        self.documents
134            .insert(id.to_string(), Document { fields: doc_fields });
135    }
136
137    fn search(&self, query: &str, limit: usize) -> Vec<SearchResult> {
138        let query_tokens = Self::tokenize(query);
139        if query_tokens.is_empty() || self.documents.is_empty() {
140            return Vec::new();
141        }
142
143        let mut scores: HashMap<&str, f64> = HashMap::new();
144
145        for term in &query_tokens {
146            let idf = self.idf(term);
147
148            for (doc_id, doc) in &self.documents {
149                let mut doc_term_score = 0.0;
150
151                for (field_name, field_data) in &doc.fields {
152                    let tf = field_data.tokens.iter().filter(|t| *t == term).count() as f64;
153
154                    if tf == 0.0 {
155                        continue;
156                    }
157
158                    let field_len = field_data.tokens.len() as f64;
159                    let avg_fl = self.avg_field_len(field_name);
160
161                    let tf_norm = if avg_fl == 0.0 {
162                        0.0
163                    } else {
164                        (tf * (Self::K1 + 1.0))
165                            / (tf + Self::K1 * (1.0 - Self::B + Self::B * field_len / avg_fl))
166                    };
167
168                    doc_term_score += idf * tf_norm * field_data.weight;
169                }
170
171                if doc_term_score > 0.0 {
172                    *scores.entry(doc_id.as_str()).or_insert(0.0) += doc_term_score;
173                }
174            }
175        }
176
177        let mut results: Vec<SearchResult> = scores
178            .into_iter()
179            .map(|(id, score)| SearchResult {
180                id: id.to_string(),
181                score,
182            })
183            .collect();
184
185        results.sort_by(|a, b| {
186            b.score
187                .partial_cmp(&a.score)
188                .unwrap_or(std::cmp::Ordering::Equal)
189        });
190        results.truncate(limit);
191        results
192    }
193}