hermes_core/
wand.rs

1//! WAND (Weak AND) data structures for efficient top-k retrieval
2//!
3//! This module provides pre-computed data structures that enable WAND and BlockMax WAND
4//! query optimization. The key insight is that we can pre-compute upper bound scores
5//! for each term, allowing us to skip documents that can't possibly make it into the
6//! top-k results.
7//!
8//! # Usage
9//!
10//! 1. Pre-compute term statistics using `hermes-tool term-stats`:
11//!    ```bash
12//!    cat docs.jsonl | hermes-tool term-stats --field content > wand_stats.json
13//!    ```
14//!
15//! 2. Load the statistics during indexing or query time:
16//!    ```rust,ignore
17//!    let wand_data = WandData::from_json_file("wand_stats.json")?;
18//!    let idf = wand_data.get_idf("content", "hello").unwrap_or(0.0);
19//!    ```
20
21use std::collections::HashMap;
22use std::io::{Read, Write};
23use std::path::Path;
24
25use serde::{Deserialize, Serialize};
26
27use crate::error::{Error, Result};
28
29/// Per-term statistics for WAND optimization
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TermWandInfo {
32    /// Document frequency (number of documents containing this term)
33    pub df: u32,
34    /// Total term frequency across all documents
35    pub total_tf: u64,
36    /// Maximum term frequency in any single document
37    pub max_tf: u32,
38    /// IDF value: log((N - df + 0.5) / (df + 0.5))
39    pub idf: f32,
40    /// Upper bound score for this term (BM25 with max_tf and conservative length norm)
41    pub upper_bound: f32,
42}
43
44/// Collection-level WAND data
45///
46/// Contains pre-computed statistics needed for efficient WAND query processing.
47/// This data is typically computed offline using `hermes-tool term-stats` and
48/// loaded at index open time.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct WandData {
51    /// Total number of documents in the collection
52    pub total_docs: u64,
53    /// Total number of tokens across all documents
54    pub total_tokens: u64,
55    /// Average document length (tokens per document)
56    pub avg_doc_len: f32,
57    /// BM25 k1 parameter used for computing upper bounds
58    pub bm25_k1: f32,
59    /// BM25 b parameter used for computing upper bounds
60    pub bm25_b: f32,
61    /// Per-term statistics, keyed by "field:term"
62    #[serde(skip)]
63    term_map: HashMap<String, TermWandInfo>,
64    /// Raw term list (for serialization)
65    terms: Vec<TermEntry>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69struct TermEntry {
70    term: String,
71    df: u32,
72    total_tf: u64,
73    max_tf: u32,
74    idf: f32,
75    upper_bound: f32,
76}
77
78impl WandData {
79    /// Create empty WAND data
80    pub fn new(total_docs: u64, avg_doc_len: f32) -> Self {
81        Self {
82            total_docs,
83            total_tokens: (total_docs as f32 * avg_doc_len) as u64,
84            avg_doc_len,
85            bm25_k1: 1.2,
86            bm25_b: 0.75,
87            term_map: HashMap::new(),
88            terms: Vec::new(),
89        }
90    }
91
92    /// Load WAND data from a JSON file
93    pub fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
94        let file = std::fs::File::open(path).map_err(Error::Io)?;
95        let reader = std::io::BufReader::new(file);
96        Self::from_json_reader(reader)
97    }
98
99    /// Load WAND data from a JSON reader
100    pub fn from_json_reader<R: Read>(reader: R) -> Result<Self> {
101        let mut data: WandData =
102            serde_json::from_reader(reader).map_err(|e| Error::Serialization(e.to_string()))?;
103        data.build_term_map();
104        Ok(data)
105    }
106
107    /// Load WAND data from JSON bytes
108    pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
109        let mut data: WandData =
110            serde_json::from_slice(bytes).map_err(|e| Error::Serialization(e.to_string()))?;
111        data.build_term_map();
112        Ok(data)
113    }
114
115    /// Save WAND data to a JSON file
116    pub fn to_json_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
117        let file = std::fs::File::create(path).map_err(Error::Io)?;
118        let writer = std::io::BufWriter::new(file);
119        self.to_json_writer(writer)
120    }
121
122    /// Write WAND data to a JSON writer
123    pub fn to_json_writer<W: Write>(&self, writer: W) -> Result<()> {
124        serde_json::to_writer_pretty(writer, self)
125            .map_err(|e| Error::Serialization(e.to_string()))?;
126        Ok(())
127    }
128
129    /// Build the term map from the terms vector
130    fn build_term_map(&mut self) {
131        self.term_map.clear();
132        for entry in &self.terms {
133            self.term_map.insert(
134                entry.term.clone(),
135                TermWandInfo {
136                    df: entry.df,
137                    total_tf: entry.total_tf,
138                    max_tf: entry.max_tf,
139                    idf: entry.idf,
140                    upper_bound: entry.upper_bound,
141                },
142            );
143        }
144    }
145
146    /// Get IDF for a term in a field
147    ///
148    /// Returns None if the term is not found in the pre-computed data.
149    /// In that case, you should compute IDF on-the-fly using the segment's
150    /// document count and term document frequency.
151    pub fn get_idf(&self, field: &str, term: &str) -> Option<f32> {
152        let key = format!("{}:{}", field, term);
153        self.term_map.get(&key).map(|info| info.idf)
154    }
155
156    /// Get full term info for a term in a field
157    pub fn get_term_info(&self, field: &str, term: &str) -> Option<&TermWandInfo> {
158        let key = format!("{}:{}", field, term);
159        self.term_map.get(&key)
160    }
161
162    /// Get upper bound score for a term
163    pub fn get_upper_bound(&self, field: &str, term: &str) -> Option<f32> {
164        let key = format!("{}:{}", field, term);
165        self.term_map.get(&key).map(|info| info.upper_bound)
166    }
167
168    /// Compute IDF for a term given its document frequency
169    ///
170    /// Uses the BM25 IDF formula: log((N - df + 0.5) / (df + 0.5))
171    pub fn compute_idf(&self, df: u32) -> f32 {
172        let n = self.total_docs as f32;
173        let df = df as f32;
174        ((n - df + 0.5) / (df + 0.5)).ln()
175    }
176
177    /// Compute upper bound score for a term given max_tf and IDF
178    ///
179    /// Uses conservative length normalization (assumes shortest possible document)
180    pub fn compute_upper_bound(&self, max_tf: u32, idf: f32) -> f32 {
181        let tf = max_tf as f32;
182        let min_length_norm = 1.0 - self.bm25_b;
183        let tf_norm = (tf * (self.bm25_k1 + 1.0)) / (tf + self.bm25_k1 * min_length_norm);
184        idf * tf_norm
185    }
186
187    /// Add or update term statistics
188    pub fn add_term(&mut self, field: &str, term: &str, df: u32, total_tf: u64, max_tf: u32) {
189        let idf = self.compute_idf(df);
190        let upper_bound = self.compute_upper_bound(max_tf, idf);
191        let key = format!("{}:{}", field, term);
192
193        let info = TermWandInfo {
194            df,
195            total_tf,
196            max_tf,
197            idf,
198            upper_bound,
199        };
200
201        self.term_map.insert(key.clone(), info.clone());
202        self.terms.push(TermEntry {
203            term: key,
204            df,
205            total_tf,
206            max_tf,
207            idf,
208            upper_bound,
209        });
210    }
211
212    /// Get the number of terms in the WAND data
213    pub fn num_terms(&self) -> usize {
214        self.term_map.len()
215    }
216
217    /// Check if WAND data is available for a term
218    pub fn has_term(&self, field: &str, term: &str) -> bool {
219        let key = format!("{}:{}", field, term);
220        self.term_map.contains_key(&key)
221    }
222}
223
224impl Default for WandData {
225    fn default() -> Self {
226        Self::new(0, 0.0)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wand_data_basic() {
        let mut wand = WandData::new(1000, 100.0);
        wand.add_term("content", "hello", 100, 500, 10);
        wand.add_term("content", "world", 50, 200, 5);

        assert!(wand.has_term("content", "hello"));
        assert!(wand.has_term("content", "world"));
        assert!(!wand.has_term("content", "missing"));

        let hello_idf = wand.get_idf("content", "hello").unwrap();
        let world_idf = wand.get_idf("content", "world").unwrap();

        // "world" appears in fewer docs, so should have higher IDF
        assert!(world_idf > hello_idf);
    }

    #[test]
    fn test_wand_data_serialization() {
        let mut wand = WandData::new(1000, 100.0);
        wand.add_term("title", "test", 50, 100, 3);

        let json = serde_json::to_string(&wand).unwrap();
        let restored = WandData::from_json_bytes(json.as_bytes()).unwrap();

        assert_eq!(restored.total_docs, wand.total_docs);
        assert_eq!(restored.avg_doc_len, wand.avg_doc_len);
        assert!(restored.has_term("title", "test"));
    }

    #[test]
    fn test_compute_idf() {
        let wand = WandData::new(1000, 100.0);

        // Rare term (df=10) should have high IDF
        let rare_idf = wand.compute_idf(10);
        // Common term (df=500) should have low IDF
        let common_idf = wand.compute_idf(500);

        assert!(rare_idf > common_idf);
        assert!(rare_idf > 0.0);
    }

    #[test]
    fn test_missing_term_lookups_return_none() {
        let wand = WandData::new(1000, 100.0);

        assert!(wand.get_idf("content", "missing").is_none());
        assert!(wand.get_upper_bound("content", "missing").is_none());
        assert!(wand.get_term_info("content", "missing").is_none());
        assert_eq!(wand.num_terms(), 0);
    }

    #[test]
    fn test_terms_are_namespaced_by_field() {
        let mut wand = WandData::new(1000, 100.0);
        wand.add_term("title", "rust", 10, 20, 2);

        assert!(wand.has_term("title", "rust"));
        assert!(!wand.has_term("body", "rust"));
    }

    #[test]
    fn test_upper_bound_exceeds_idf_for_rare_term() {
        let mut wand = WandData::new(1000, 100.0);
        wand.add_term("content", "rare", 10, 40, 4);

        let idf = wand.get_idf("content", "rare").unwrap();
        let upper_bound = wand.get_upper_bound("content", "rare").unwrap();

        // With the default k1/b, tf_norm > 1 whenever max_tf >= 1, so a positive
        // IDF is strictly exceeded by the bound.
        assert!(idf > 0.0);
        assert!(upper_bound > idf);
    }

    #[test]
    fn test_add_term_update_overwrites() {
        let mut wand = WandData::new(1000, 100.0);
        wand.add_term("content", "hello", 100, 500, 10);
        wand.add_term("content", "hello", 200, 900, 12);

        assert_eq!(wand.num_terms(), 1);
        let info = wand.get_term_info("content", "hello").unwrap();
        assert_eq!(info.df, 200);
        assert_eq!(info.max_tf, 12);

        // The latest statistics must also survive a serialization round trip.
        let json = serde_json::to_string(&wand).unwrap();
        let restored = WandData::from_json_bytes(json.as_bytes()).unwrap();
        assert_eq!(restored.num_terms(), 1);
        assert_eq!(restored.get_term_info("content", "hello").unwrap().df, 200);
    }

    #[test]
    fn test_json_writer_reader_roundtrip() {
        let mut wand = WandData::new(1000, 100.0);
        wand.add_term("content", "hello", 100, 500, 10);

        let mut buf = Vec::new();
        wand.to_json_writer(&mut buf).unwrap();
        let restored = WandData::from_json_reader(buf.as_slice()).unwrap();

        assert_eq!(restored.total_docs, 1000);
        assert_eq!(restored.total_tokens, wand.total_tokens);

        let original = wand.get_term_info("content", "hello").unwrap();
        let loaded = restored.get_term_info("content", "hello").unwrap();
        assert_eq!(loaded.df, original.df);
        assert_eq!(loaded.total_tf, original.total_tf);
        assert_eq!(loaded.max_tf, original.max_tf);
        assert!((loaded.idf - original.idf).abs() < 1e-6);
        assert!((loaded.upper_bound - original.upper_bound).abs() < 1e-5);
    }
}