1use std::collections::HashMap;
22use std::io::{Read, Write};
23use std::path::Path;
24
25use serde::{Deserialize, Serialize};
26
27use crate::error::{Error, Result};
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct TermWandInfo {
32 pub df: u32,
34 pub total_tf: u64,
36 pub max_tf: u32,
38 pub idf: f32,
40 pub upper_bound: f32,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct WandData {
51 pub total_docs: u64,
53 pub total_tokens: u64,
55 pub avg_doc_len: f32,
57 pub bm25_k1: f32,
59 pub bm25_b: f32,
61 #[serde(skip)]
63 term_map: HashMap<String, TermWandInfo>,
64 terms: Vec<TermEntry>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69struct TermEntry {
70 term: String,
71 df: u32,
72 total_tf: u64,
73 max_tf: u32,
74 idf: f32,
75 upper_bound: f32,
76}
77
78impl WandData {
79 pub fn new(total_docs: u64, avg_doc_len: f32) -> Self {
81 Self {
82 total_docs,
83 total_tokens: (total_docs as f32 * avg_doc_len) as u64,
84 avg_doc_len,
85 bm25_k1: 1.2,
86 bm25_b: 0.75,
87 term_map: HashMap::new(),
88 terms: Vec::new(),
89 }
90 }
91
92 pub fn from_json_file<P: AsRef<Path>>(path: P) -> Result<Self> {
94 let file = std::fs::File::open(path).map_err(Error::Io)?;
95 let reader = std::io::BufReader::new(file);
96 Self::from_json_reader(reader)
97 }
98
99 pub fn from_json_reader<R: Read>(reader: R) -> Result<Self> {
101 let mut data: WandData =
102 serde_json::from_reader(reader).map_err(|e| Error::Serialization(e.to_string()))?;
103 data.build_term_map();
104 Ok(data)
105 }
106
107 pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
109 let mut data: WandData =
110 serde_json::from_slice(bytes).map_err(|e| Error::Serialization(e.to_string()))?;
111 data.build_term_map();
112 Ok(data)
113 }
114
115 pub fn to_json_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
117 let file = std::fs::File::create(path).map_err(Error::Io)?;
118 let writer = std::io::BufWriter::new(file);
119 self.to_json_writer(writer)
120 }
121
122 pub fn to_json_writer<W: Write>(&self, writer: W) -> Result<()> {
124 serde_json::to_writer_pretty(writer, self)
125 .map_err(|e| Error::Serialization(e.to_string()))?;
126 Ok(())
127 }
128
129 fn build_term_map(&mut self) {
131 self.term_map.clear();
132 for entry in &self.terms {
133 self.term_map.insert(
134 entry.term.clone(),
135 TermWandInfo {
136 df: entry.df,
137 total_tf: entry.total_tf,
138 max_tf: entry.max_tf,
139 idf: entry.idf,
140 upper_bound: entry.upper_bound,
141 },
142 );
143 }
144 }
145
146 pub fn get_idf(&self, field: &str, term: &str) -> Option<f32> {
152 let key = format!("{}:{}", field, term);
153 self.term_map.get(&key).map(|info| info.idf)
154 }
155
156 pub fn get_term_info(&self, field: &str, term: &str) -> Option<&TermWandInfo> {
158 let key = format!("{}:{}", field, term);
159 self.term_map.get(&key)
160 }
161
162 pub fn get_upper_bound(&self, field: &str, term: &str) -> Option<f32> {
164 let key = format!("{}:{}", field, term);
165 self.term_map.get(&key).map(|info| info.upper_bound)
166 }
167
168 pub fn compute_idf(&self, df: u32) -> f32 {
172 let n = self.total_docs as f32;
173 let df = df as f32;
174 ((n - df + 0.5) / (df + 0.5)).ln()
175 }
176
177 pub fn compute_upper_bound(&self, max_tf: u32, idf: f32) -> f32 {
181 let tf = max_tf as f32;
182 let min_length_norm = 1.0 - self.bm25_b;
183 let tf_norm = (tf * (self.bm25_k1 + 1.0)) / (tf + self.bm25_k1 * min_length_norm);
184 idf * tf_norm
185 }
186
187 pub fn add_term(&mut self, field: &str, term: &str, df: u32, total_tf: u64, max_tf: u32) {
189 let idf = self.compute_idf(df);
190 let upper_bound = self.compute_upper_bound(max_tf, idf);
191 let key = format!("{}:{}", field, term);
192
193 let info = TermWandInfo {
194 df,
195 total_tf,
196 max_tf,
197 idf,
198 upper_bound,
199 };
200
201 self.term_map.insert(key.clone(), info.clone());
202 self.terms.push(TermEntry {
203 term: key,
204 df,
205 total_tf,
206 max_tf,
207 idf,
208 upper_bound,
209 });
210 }
211
212 pub fn num_terms(&self) -> usize {
214 self.term_map.len()
215 }
216
217 pub fn has_term(&self, field: &str, term: &str) -> bool {
219 let key = format!("{}:{}", field, term);
220 self.term_map.contains_key(&key)
221 }
222}
223
224impl Default for WandData {
225 fn default() -> Self {
226 Self::new(0, 0.0)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Registered terms are found, unknown ones are not, and the term with
    /// the smaller document frequency receives the larger IDF.
    #[test]
    fn test_wand_data_basic() {
        let field = "content";
        let mut data = WandData::new(1000, 100.0);
        data.add_term(field, "hello", 100, 500, 10);
        data.add_term(field, "world", 50, 200, 5);

        for present in ["hello", "world"].iter() {
            assert!(data.has_term(field, present));
        }
        assert!(!data.has_term(field, "missing"));

        let idf_hello = data.get_idf(field, "hello").unwrap();
        let idf_world = data.get_idf(field, "world").unwrap();

        // "world" appears in fewer documents than "hello".
        assert!(idf_world > idf_hello);
    }

    /// A JSON round trip keeps the collection statistics and the terms.
    #[test]
    fn test_wand_data_serialization() {
        let mut original = WandData::new(1000, 100.0);
        original.add_term("title", "test", 50, 100, 3);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = WandData::from_json_bytes(encoded.as_bytes()).unwrap();

        assert_eq!(decoded.total_docs, original.total_docs);
        assert_eq!(decoded.avg_doc_len, original.avg_doc_len);
        assert!(decoded.has_term("title", "test"));
    }

    /// IDF falls as document frequency rises and is positive for rare terms.
    #[test]
    fn test_compute_idf() {
        let data = WandData::new(1000, 100.0);

        let idf_rare = data.compute_idf(10);
        let idf_common = data.compute_idf(500);

        assert!(idf_rare > idf_common);
        assert!(idf_rare > 0.0);
    }
}