1use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
9
10use crate::query::ast::{FieldValueModifier, FunctionBoostMode, FunctionScoreMode, ScoreFunction};
11use crate::query::{BoundQuery, Query, ScorerSupplier};
12use crate::search::searcher::Searcher;
13use crate::segment::reader::SegmentReader;
14
15pub struct FunctionScoreQuery {
16 pub(crate) query: Box<dyn Query>,
17 pub functions: Vec<ScoreFunction>,
18 pub score_mode: FunctionScoreMode,
19 pub boost_mode: FunctionBoostMode,
20}
21
22impl Query for FunctionScoreQuery {
23 fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
24 let inner = self.query.bind(searcher, score_mode)?;
25 Ok(Box::new(BoundFunctionScoreQuery {
26 inner,
27 functions: self.functions.clone(),
28 score_mode: self.score_mode.clone(),
29 boost_mode: self.boost_mode.clone(),
30 }))
31 }
32}
33
34struct BoundFunctionScoreQuery {
35 inner: Box<dyn BoundQuery>,
36 functions: Vec<ScoreFunction>,
37 score_mode: FunctionScoreMode,
38 boost_mode: FunctionBoostMode,
39}
40
41impl BoundQuery for BoundFunctionScoreQuery {
42 fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
43 let inner = match self.inner.scorer_supplier(reader)? {
44 Some(s) => s,
45 None => return Ok(None),
46 };
47
48 let mut field_values: Vec<Option<Vec<f64>>> = Vec::new();
50 for func in &self.functions {
51 match func {
52 ScoreFunction::FieldValueFactor { field, missing, .. } => {
53 let field_id = reader
54 .header()
55 .fields
56 .iter()
57 .find(|f| f.field_name == *field)
58 .map(|f| f.field_id);
59 if let Some(fid) = field_id {
60 if let Some(col) = reader.column(fid) {
61 let doc_count = col.doc_count();
62 let vals: Vec<f64> = (0..doc_count)
63 .map(|i| col.numeric_value(i).unwrap_or(*missing))
64 .collect();
65 field_values.push(Some(vals));
66 } else {
67 field_values.push(None);
68 }
69 } else {
70 field_values.push(None);
71 }
72 }
73 _ => field_values.push(None),
74 }
75 }
76
77 Ok(Some(Box::new(FunctionScoreScorerSupplier {
78 inner,
79 functions: self.functions.clone(),
80 score_mode: self.score_mode.clone(),
81 boost_mode: self.boost_mode.clone(),
82 field_values,
83 })))
84 }
85}
86
87struct FunctionScoreScorerSupplier {
88 inner: Box<dyn ScorerSupplier>,
89 functions: Vec<ScoreFunction>,
90 score_mode: FunctionScoreMode,
91 boost_mode: FunctionBoostMode,
92 field_values: Vec<Option<Vec<f64>>>,
93}
94
95impl ScorerSupplier for FunctionScoreScorerSupplier {
96 fn cost(&self) -> u64 {
97 self.inner.cost()
98 }
99 fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
100 let inner = self.inner.scorer()?;
101 Ok(Box::new(FunctionScoreScorer {
102 inner,
103 functions: self.functions,
104 score_mode: self.score_mode,
105 boost_mode: self.boost_mode,
106 field_values: self.field_values,
107 }))
108 }
109}
110
111struct FunctionScoreScorer {
112 inner: Box<dyn Scorer>,
113 functions: Vec<ScoreFunction>,
114 score_mode: FunctionScoreMode,
115 boost_mode: FunctionBoostMode,
116 field_values: Vec<Option<Vec<f64>>>,
117}
118
119impl FunctionScoreScorer {
120 fn compute_function_score(&self, doc_id: DocId) -> f32 {
121 let mut scores: Vec<f32> = Vec::new();
122
123 for (i, func) in self.functions.iter().enumerate() {
124 let s = match func {
125 ScoreFunction::Weight(w) => *w,
126 ScoreFunction::FieldValueFactor {
127 factor,
128 modifier,
129 missing,
130 ..
131 } => {
132 let val = self.field_values[i]
133 .as_ref()
134 .and_then(|vals| vals.get(doc_id.as_u32() as usize).copied())
135 .unwrap_or(*missing);
136 let modified = apply_modifier(val, modifier);
137 (modified * *factor as f64) as f32
138 }
139 ScoreFunction::RandomScore { seed } => random_score_hash(*seed, doc_id),
140 };
141 scores.push(s);
142 }
143
144 if scores.is_empty() {
145 return 1.0;
146 }
147
148 match self.score_mode {
149 FunctionScoreMode::Multiply => scores.iter().product(),
150 FunctionScoreMode::Sum => scores.iter().sum(),
151 FunctionScoreMode::Avg => scores.iter().sum::<f32>() / scores.len() as f32,
152 FunctionScoreMode::First => scores[0],
153 FunctionScoreMode::Max => scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
154 FunctionScoreMode::Min => scores.iter().cloned().fold(f32::INFINITY, f32::min),
155 }
156 }
157}
158
159fn apply_modifier(val: f64, modifier: &FieldValueModifier) -> f64 {
160 match modifier {
161 FieldValueModifier::None => val,
162 FieldValueModifier::Log1p => (1.0 + val).log10(),
163 FieldValueModifier::Log2p => (2.0 + val).log10(),
164 FieldValueModifier::Ln1p => (1.0 + val).ln(),
165 FieldValueModifier::Ln2p => (2.0 + val).ln(),
166 FieldValueModifier::Sqrt => val.sqrt(),
167 FieldValueModifier::Square => val * val,
168 FieldValueModifier::Reciprocal => 1.0 / val.max(f64::MIN_POSITIVE),
169 }
170}
171
172impl Scorer for FunctionScoreScorer {
173 fn doc_id(&self) -> DocId {
174 self.inner.doc_id()
175 }
176 fn next(&mut self) -> DocId {
177 self.inner.next()
178 }
179 fn advance(&mut self, target: DocId) -> DocId {
180 self.inner.advance(target)
181 }
182
183 fn score(&mut self) -> f32 {
184 let query_score = self.inner.score();
185 let func_score = self.compute_function_score(self.inner.doc_id());
186
187 match self.boost_mode {
188 FunctionBoostMode::Multiply => query_score * func_score,
189 FunctionBoostMode::Replace => func_score,
190 FunctionBoostMode::Sum => query_score + func_score,
191 FunctionBoostMode::Avg => (query_score + func_score) / 2.0,
192 FunctionBoostMode::Max => query_score.max(func_score),
193 FunctionBoostMode::Min => query_score.min(func_score),
194 }
195 }
196
197 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
198 None
199 }
200}
201
202fn random_score_hash(seed: u64, doc_id: DocId) -> f32 {
211 let seed32 = ((seed >> 32) as u32) ^ (seed as u32);
213 let mut h = doc_id.as_u32() ^ seed32;
214 h ^= h >> 16;
215 h = h.wrapping_mul(0x85ebca6b);
216 h ^= h >> 13;
217 h = h.wrapping_mul(0xc2b2ae35);
218 h ^= h >> 16;
219 (h & 0x00FFFFFF) as f32 / (1u32 << 24) as f32
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::{AnalyzerRegistry, Token};
    use crate::columnar::writer::ColumnValue;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::search::segment_store::SegmentStore;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    /// Turns each term into a token whose position is its index in `terms`.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(terms.len());
        for (position, term) in terms.iter().enumerate() {
            tokens.push(Token::new(*term, 0, term.len(), position as u32));
        }
        tokens
    }

    /// A match query for `term` on the `text` field, boxed as a wrapped query.
    fn match_text(term: &'static str) -> Box<dyn Query> {
        Box::new(MatchQuery {
            field: "text".into(),
            query_text: term.into(),
            analyzer: None,
        })
    }

    /// A constant weight on a matching document yields one hit with a positive score.
    #[test]
    fn function_score_weight() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(
            &[(FieldId::new(0), make_tokens(&["hello", "world"]))],
            b"{}",
        );

        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = SegmentStore::new(vec![reader], AnalyzerRegistry::new(), None, None);
        let searcher = Searcher::new(&store);

        let query = FunctionScoreQuery {
            query: match_text("hello"),
            functions: vec![ScoreFunction::Weight(2.0)],
            score_mode: FunctionScoreMode::Multiply,
            boost_mode: FunctionBoostMode::Multiply,
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!(results.hits[0].score > 0.0);
    }

    /// Scores of consecutive doc ids must neither repeat nor trend upwards.
    #[test]
    fn random_score_adjacent_docs_uncorrelated() {
        let mut scores: Vec<f32> = Vec::with_capacity(100);
        for doc in 0..100u32 {
            scores.push(random_score_hash(42, DocId::new(doc)));
        }

        let equal_pairs = scores
            .windows(2)
            .filter(|pair| pair[0] == pair[1])
            .count();
        assert!(
            equal_pairs < 5,
            "{equal_pairs}/99 adjacent pairs have identical scores — \
             hash function does not avalanche"
        );

        let ascending = scores
            .windows(2)
            .filter(|pair| pair[1] > pair[0])
            .count();
        assert!(
            ascending > 30 && ascending < 70,
            "{ascending}/99 ascending pairs — expected ~50 for random distribution"
        );
    }

    /// The same seed and doc id always produce the same score.
    #[test]
    fn random_score_deterministic() {
        let first = random_score_hash(42, DocId::new(100));
        let second = random_score_hash(42, DocId::new(100));
        assert_eq!(first, second);
    }

    /// Seeds 1 and 2 give different scores for the same doc id.
    #[test]
    fn random_score_different_seeds() {
        let with_seed_one = random_score_hash(1, DocId::new(100));
        let with_seed_two = random_score_hash(2, DocId::new(100));
        assert_ne!(with_seed_one, with_seed_two);
    }

    /// Chi-square check of 10,000 scores over ten equal-width buckets.
    #[test]
    fn random_score_uniform_distribution() {
        let mut buckets = [0u32; 10];
        for doc in 0..10_000u32 {
            let score = random_score_hash(42, DocId::new(doc));
            let bucket = ((score * 10.0) as usize).min(9);
            buckets[bucket] += 1;
        }

        let expected = 1000.0f64;
        let chi_sq: f64 = buckets
            .iter()
            .map(|&observed| (observed as f64 - expected) * (observed as f64 - expected) / expected)
            .sum();
        assert!(
            chi_sq < 21.67,
            "distribution not uniform: chi_sq={chi_sq}, buckets={buckets:?}"
        );
    }

    /// Every score lies in `[0.0, 1.0)`.
    #[test]
    fn random_score_in_range() {
        for i in 0..10_000u32 {
            let score = random_score_hash(42, DocId::new(i));
            assert!(
                (0.0..1.0).contains(&score),
                "out of range: {score} for doc {i}"
            );
        }
    }

    /// With equal text matches, the document with the larger `popularity` ranks first.
    #[test]
    fn function_score_field_value_factor() {
        let schema = Mapping::builder()
            .field("text", FieldType::Text)
            .field("popularity", FieldType::Integer)
            .build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);

        builder.add_document(&[(FieldId::new(0), make_tokens(&["search"]))], b"{}");
        builder.add_column_value(FieldId::new(1), ColumnValue::I64(10));

        builder.add_document(&[(FieldId::new(0), make_tokens(&["search"]))], b"{}");
        builder.add_column_value(FieldId::new(1), ColumnValue::I64(100));

        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = SegmentStore::new(vec![reader], AnalyzerRegistry::new(), None, None);
        let searcher = Searcher::new(&store);

        let query = FunctionScoreQuery {
            query: match_text("search"),
            functions: vec![ScoreFunction::FieldValueFactor {
                field: "popularity".into(),
                factor: 1.0,
                modifier: FieldValueModifier::Log1p,
                missing: 1.0,
            }],
            score_mode: FunctionScoreMode::Multiply,
            boost_mode: FunctionBoostMode::Multiply,
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 2);
        assert!(results.hits[0].score > results.hits[1].score);
    }
}