use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
use crate::query::ast::{FieldValueModifier, FunctionBoostMode, FunctionScoreMode, ScoreFunction};
use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
/// Wraps another query and rescales the scores of its matches with a list of
/// `ScoreFunction`s.
///
/// The individual function values are merged according to `score_mode`, and
/// the merged value is blended with the wrapped query's own score according
/// to `boost_mode`. Which documents match is decided solely by `query`.
pub struct FunctionScoreQuery {
    /// The query whose matches are rescored.
    pub(crate) query: Box<dyn Query>,
    /// Score functions evaluated for every matching document.
    pub functions: Vec<ScoreFunction>,
    /// How the individual function values are merged into one function score.
    pub score_mode: FunctionScoreMode,
    /// How the merged function score is combined with the query score.
    pub boost_mode: FunctionBoostMode,
}
impl Query for FunctionScoreQuery {
    /// Binds the wrapped query against `searcher` using the caller's
    /// `score_mode`, and snapshots the function configuration alongside it.
    ///
    /// # Errors
    /// Propagates any error returned while binding the wrapped query.
    fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        let bound = BoundFunctionScoreQuery {
            inner: self.query.bind(searcher, score_mode)?,
            functions: self.functions.clone(),
            score_mode: self.score_mode.clone(),
            boost_mode: self.boost_mode.clone(),
        };
        Ok(Box::new(bound))
    }
}
/// A `FunctionScoreQuery` whose wrapped query has been bound to a searcher.
///
/// Holds its own copy of the function configuration so that per-segment
/// scorer suppliers can be built from it.
struct BoundFunctionScoreQuery {
    /// The bound wrapped query; it alone determines the matching documents.
    inner: Box<dyn BoundQuery>,
    /// Score functions evaluated for every matching document.
    functions: Vec<ScoreFunction>,
    /// How the individual function values are merged.
    score_mode: FunctionScoreMode,
    /// How the merged function score is combined with the query score.
    boost_mode: FunctionBoostMode,
}
impl BoundQuery for BoundFunctionScoreQuery {
    /// Builds a function-scoring scorer supplier for one segment.
    ///
    /// Returns `Ok(None)` when the wrapped query yields no supplier for this
    /// segment.
    ///
    /// Otherwise, every `FieldValueFactor` function has its field's column
    /// read eagerly into a dense `Vec<f64>` indexed by doc id. Documents
    /// lacking a value get the function's `missing` value.
    ///
    /// Functions that do not read a field, and fields or columns absent from
    /// this segment, get a `None` entry. This keeps `field_values`
    /// index-aligned with `functions`.
    ///
    /// # Errors
    /// Propagates any error from the wrapped query's `scorer_supplier`.
    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
        let Some(inner) = self.inner.scorer_supplier(reader)? else {
            return Ok(None);
        };

        let field_values: Vec<Option<Vec<f64>>> = self
            .functions
            .iter()
            .map(|func| -> Option<Vec<f64>> {
                match func {
                    ScoreFunction::FieldValueFactor { field, missing, .. } => {
                        // Resolve the field name to this segment's field id. If the
                        // field or its column is absent, `?` yields `None` and the
                        // scorer falls back to `missing` for every document.
                        let field_id = reader
                            .header()
                            .fields
                            .iter()
                            .find(|f| f.field_name == *field)?
                            .field_id;
                        let col = reader.column(field_id)?;
                        Some(
                            (0..col.doc_count())
                                .map(|doc| col.numeric_value(doc).unwrap_or(*missing))
                                .collect(),
                        )
                    }
                    // Listed explicitly rather than via `_` so that adding a new
                    // column-backed function variant forces this loader to be updated.
                    ScoreFunction::Weight(_) | ScoreFunction::RandomScore { .. } => None,
                }
            })
            .collect();

        Ok(Some(Box::new(FunctionScoreScorerSupplier {
            inner,
            functions: self.functions.clone(),
            score_mode: self.score_mode.clone(),
            boost_mode: self.boost_mode.clone(),
            field_values,
        })))
    }
}
/// Per-segment supplier that wraps the inner query's supplier and carries the
/// function configuration plus the segment's pre-loaded field values.
struct FunctionScoreScorerSupplier {
    /// Supplier for the wrapped query's scorer in this segment.
    inner: Box<dyn ScorerSupplier>,
    /// Score functions evaluated for every matching document.
    functions: Vec<ScoreFunction>,
    /// How the individual function values are merged.
    score_mode: FunctionScoreMode,
    /// How the merged function score is combined with the query score.
    boost_mode: FunctionBoostMode,
    /// One entry per function, parallel to `functions`. For `FieldValueFactor`
    /// functions this holds the column values indexed by doc id; otherwise it
    /// is `None`.
    field_values: Vec<Option<Vec<f64>>>,
}
impl ScorerSupplier for FunctionScoreScorerSupplier {
fn cost(&self) -> u64 {
self.inner.cost()
}
fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
let inner = self.inner.scorer()?;
Ok(Box::new(FunctionScoreScorer {
inner,
functions: self.functions,
score_mode: self.score_mode,
boost_mode: self.boost_mode,
field_values: self.field_values,
}))
}
}
/// Scorer that iterates exactly like the wrapped scorer but rewrites each
/// score with the configured score functions.
struct FunctionScoreScorer {
    /// The wrapped query's scorer; drives iteration and supplies the query score.
    inner: Box<dyn Scorer>,
    /// Score functions evaluated for the current document.
    functions: Vec<ScoreFunction>,
    /// How the individual function values are merged.
    score_mode: FunctionScoreMode,
    /// How the merged function score is combined with the query score.
    boost_mode: FunctionBoostMode,
    /// One entry per function, parallel to `functions`. Holds the segment's
    /// column values indexed by doc id for `FieldValueFactor`, else `None`.
    field_values: Vec<Option<Vec<f64>>>,
}
impl FunctionScoreScorer {
fn compute_function_score(&self, doc_id: DocId) -> f32 {
let mut scores: Vec<f32> = Vec::new();
for (i, func) in self.functions.iter().enumerate() {
let s = match func {
ScoreFunction::Weight(w) => *w,
ScoreFunction::FieldValueFactor {
factor,
modifier,
missing,
..
} => {
let val = self.field_values[i]
.as_ref()
.and_then(|vals| vals.get(doc_id.as_u32() as usize).copied())
.unwrap_or(*missing);
let modified = apply_modifier(val, modifier);
(modified * *factor as f64) as f32
}
ScoreFunction::RandomScore { seed } => random_score_hash(*seed, doc_id),
};
scores.push(s);
}
if scores.is_empty() {
return 1.0;
}
match self.score_mode {
FunctionScoreMode::Multiply => scores.iter().product(),
FunctionScoreMode::Sum => scores.iter().sum(),
FunctionScoreMode::Avg => scores.iter().sum::<f32>() / scores.len() as f32,
FunctionScoreMode::First => scores[0],
FunctionScoreMode::Max => scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
FunctionScoreMode::Min => scores.iter().cloned().fold(f32::INFINITY, f32::min),
}
}
}
/// Applies a field-value modifier to `val`.
///
/// - `Log1p` / `Log2p` are base-10 logarithms of `val + 1` / `val + 2`.
/// - `Ln1p` / `Ln2p` are natural logarithms of `val + 1` / `val + 2`.
/// - `Sqrt` and `Square` are what their names say.
/// - `Reciprocal` first clamps `val` to at least `f64::MIN_POSITIVE`. Zero
///   and negative inputs therefore yield `1 / f64::MIN_POSITIVE` (about
///   4.5e307) rather than dividing by zero; that value is far outside the
///   `f32` range.
///
/// No other clamping is done. For example, `Sqrt` of a negative value, or
/// `Log1p` of a value below -1, yields NaN.
fn apply_modifier(val: f64, modifier: &FieldValueModifier) -> f64 {
    match *modifier {
        FieldValueModifier::None => val,
        FieldValueModifier::Log1p => (val + 1.0).log10(),
        FieldValueModifier::Log2p => (val + 2.0).log10(),
        FieldValueModifier::Ln1p => (val + 1.0).ln(),
        FieldValueModifier::Ln2p => (val + 2.0).ln(),
        FieldValueModifier::Sqrt => val.sqrt(),
        FieldValueModifier::Square => val * val,
        FieldValueModifier::Reciprocal => {
            let positive = val.max(f64::MIN_POSITIVE);
            positive.recip()
        }
    }
}
impl Scorer for FunctionScoreScorer {
    /// Current document; all positioning is delegated to the wrapped scorer.
    fn doc_id(&self) -> DocId {
        self.inner.doc_id()
    }

    /// Moves the wrapped scorer to its next match.
    fn next(&mut self) -> DocId {
        self.inner.next()
    }

    /// Moves the wrapped scorer to its first match at or after `target`.
    fn advance(&mut self, target: DocId) -> DocId {
        self.inner.advance(target)
    }

    /// Blends the wrapped query's score for the current document with the
    /// merged function score, as selected by `boost_mode`.
    fn score(&mut self) -> f32 {
        let query = self.inner.score();
        let current = self.inner.doc_id();
        let function = self.compute_function_score(current);
        match self.boost_mode {
            FunctionBoostMode::Replace => function,
            FunctionBoostMode::Multiply => query * function,
            FunctionBoostMode::Sum => query + function,
            FunctionBoostMode::Avg => (query + function) / 2.0,
            FunctionBoostMode::Max => query.max(function),
            FunctionBoostMode::Min => query.min(function),
        }
    }

    /// This wrapper exposes no two-phase view, even if the wrapped scorer
    /// has one.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
/// Deterministic pseudo-random score in `[0, 1)` for `doc_id` under `seed`.
///
/// How the score is computed:
/// 1. The 64-bit seed is folded to 32 bits (high half XOR low half).
/// 2. The folded seed is XORed into the document id.
/// 3. The result goes through the MurmurHash3 32-bit finalizer (`fmix32`),
///    so neighbouring ids map to unrelated values.
/// 4. The low 24 bits (as many as an `f32` mantissa represents exactly) are
///    scaled into the unit interval, so the result is never `1.0`.
///
/// Note: only `doc_id` and `seed` are hashed. If doc ids are segment-local
/// (the per-segment field-value columns are indexed by them), documents with
/// the same local id in different segments get the same score.
fn random_score_hash(seed: u64, doc_id: DocId) -> f32 {
    // (shift, multiplier) pairs of fmix32; a final xor-shift by 16 follows.
    const MIX_ROUNDS: [(u32, u32); 2] = [(16, 0x85eb_ca6b), (13, 0xc2b2_ae35)];
    const MANTISSA_BITS: u32 = 24;

    let folded_seed = (seed as u32) ^ ((seed >> 32) as u32);
    let mut x = folded_seed ^ doc_id.as_u32();
    for &(shift, multiplier) in MIX_ROUNDS.iter() {
        x ^= x >> shift;
        x = x.wrapping_mul(multiplier);
    }
    x ^= x >> 16;

    let unit = (1u32 << MANTISSA_BITS) as f32;
    (x & ((1u32 << MANTISSA_BITS) - 1)) as f32 / unit
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::columnar::writer::ColumnValue;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    /// Builds one `Token` per term, passing `0`, the term's byte length, and
    /// the term's index in `terms` as the remaining `Token::new` arguments.
    /// NOTE(review): presumably (start_offset, end_offset, position) — confirm
    /// against `Token::new`.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        terms
            .iter()
            .enumerate()
            .map(|(i, t)| Token::new(*t, 0, t.len(), i as u32))
            .collect()
    }

    /// End-to-end: a constant `Weight` function leaves the match set
    /// unchanged and yields a positive score.
    #[test]
    fn function_score_weight() {
        // Single-document segment whose field 0 ("text") holds "hello world".
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(
            &[(FieldId::new(0), make_tokens(&["hello", "world"]))],
            b"{}",
        );
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        // Match "hello" and multiply its score by a constant weight of 2.0.
        let query = FunctionScoreQuery {
            query: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "hello".into(),
                analyzer: None,
            }),
            functions: vec![ScoreFunction::Weight(2.0)],
            score_mode: FunctionScoreMode::Multiply,
            boost_mode: FunctionBoostMode::Multiply,
        };
        // NOTE(review): the 10 and 0 arguments look like (limit, offset) — confirm.
        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!(results.hits[0].score > 0.0);
    }

    /// Consecutive doc ids must not produce identical or monotonic scores,
    /// i.e. the hash must scramble neighbouring inputs.
    #[test]
    fn random_score_adjacent_docs_uncorrelated() {
        let scores: Vec<f32> = (0..100u32)
            .map(|i| random_score_hash(42, DocId::new(i)))
            .collect();
        let equal_pairs = scores.windows(2).filter(|w| w[0] == w[1]).count();
        assert!(
            equal_pairs < 5,
            "{equal_pairs}/99 adjacent pairs have identical scores — \
            hash function does not avalanche"
        );
        // Roughly half of the adjacent pairs should be ascending.
        let ascending = scores.windows(2).filter(|w| w[1] > w[0]).count();
        assert!(
            ascending > 30 && ascending < 70,
            "{ascending}/99 ascending pairs — expected ~50 for random distribution"
        );
    }

    /// The same (seed, doc) pair always hashes to the same score.
    #[test]
    fn random_score_deterministic() {
        let s1 = random_score_hash(42, DocId::new(100));
        let s2 = random_score_hash(42, DocId::new(100));
        assert_eq!(s1, s2);
    }

    /// Changing the seed changes the score for the same document.
    #[test]
    fn random_score_different_seeds() {
        let s1 = random_score_hash(1, DocId::new(100));
        let s2 = random_score_hash(2, DocId::new(100));
        assert_ne!(s1, s2);
    }

    /// Chi-squared goodness-of-fit over 10 equal-width buckets for 10,000
    /// consecutive doc ids.
    #[test]
    fn random_score_uniform_distribution() {
        let mut buckets = [0u32; 10];
        for i in 0..10_000u32 {
            let score = random_score_hash(42, DocId::new(i));
            // `min(9)` guards the top bucket against a score of exactly 1.0.
            let bucket = ((score * 10.0) as usize).min(9);
            buckets[bucket] += 1;
        }
        let expected = 1000.0f64;
        let chi_sq: f64 = buckets
            .iter()
            .map(|&b| {
                let diff = b as f64 - expected;
                diff * diff / expected
            })
            .sum();
        // 21.67 is approximately the 99th-percentile chi-squared value for
        // 9 degrees of freedom (10 buckets).
        assert!(
            chi_sq < 21.67,
            "distribution not uniform: chi_sq={chi_sq}, buckets={buckets:?}"
        );
    }

    /// Every score lies in the half-open interval [0, 1).
    #[test]
    fn random_score_in_range() {
        for i in 0..10_000u32 {
            let score = random_score_hash(42, DocId::new(i));
            assert!(
                (0.0..1.0).contains(&score),
                "out of range: {score} for doc {i}"
            );
        }
    }

    /// End-to-end: with identical text matches, a higher `popularity` column
    /// value must produce a higher final score under `Log1p`.
    #[test]
    fn function_score_field_value_factor() {
        let schema = Mapping::builder()
            .field("text", FieldType::Text)
            .field("popularity", FieldType::Integer)
            .build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        // Two documents with the same text; popularity 10 and 100 respectively.
        builder.add_document(&[(FieldId::new(0), make_tokens(&["search"]))], b"{}");
        builder.add_column_value(FieldId::new(1), ColumnValue::I64(10));
        builder.add_document(&[(FieldId::new(0), make_tokens(&["search"]))], b"{}");
        builder.add_column_value(FieldId::new(1), ColumnValue::I64(100));
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        // factor 1.0, so the modifier/factor ordering does not affect this test.
        let query = FunctionScoreQuery {
            query: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "search".into(),
                analyzer: None,
            }),
            functions: vec![ScoreFunction::FieldValueFactor {
                field: "popularity".into(),
                factor: 1.0,
                modifier: FieldValueModifier::Log1p,
                missing: 1.0,
            }],
            score_mode: FunctionScoreMode::Multiply,
            boost_mode: FunctionBoostMode::Multiply,
        };
        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 2);
        // NOTE(review): assumes hits are returned in descending score order.
        assert!(results.hits[0].score > results.hits[1].score);
    }
}