use std::collections::HashMap;
use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
/// Abstract syntax tree of a compiled score script.
#[derive(Clone, Debug)]
enum Expr {
    /// The score passed to [`Expr::eval`] (written `_score` in scripts).
    Score,
    /// A numeric constant.
    Literal(f64),
    /// A named parameter, looked up in the parameter map at evaluation time.
    Param(String),
    /// A binary arithmetic operation: left operand, operator, right operand.
    BinOp(Box<Expr>, BinOp, Box<Expr>),
    /// Arithmetic negation of the operand.
    UnaryMinus(Box<Expr>),
    /// A one-argument math function applied to its argument.
    Fn1(MathFn1, Box<Expr>),
    /// A two-argument math function applied to its arguments.
    Fn2(MathFn2, Box<Expr>, Box<Expr>),
}

/// Binary arithmetic operators.
#[derive(Clone, Debug)]
enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
}

/// One-argument math functions.
#[derive(Clone, Debug)]
enum MathFn1 {
    /// Natural logarithm of `1 + x`.
    Log,
    /// Base-10 logarithm.
    Log10,
    /// Square root.
    Sqrt,
    /// Absolute value.
    Abs,
    /// Natural logarithm.
    Ln,
}

/// Two-argument math functions.
#[derive(Clone, Debug)]
enum MathFn2 {
    Pow,
    Max,
    Min,
}

impl BinOp {
    /// Applies the operator to two already-evaluated operands.
    ///
    /// Division and remainder by zero yield `0.0` instead of an infinity or NaN.
    fn apply(&self, lhs: f64, rhs: f64) -> f64 {
        match self {
            BinOp::Add => lhs + rhs,
            BinOp::Sub => lhs - rhs,
            BinOp::Mul => lhs * rhs,
            BinOp::Div | BinOp::Mod if rhs == 0.0 => 0.0,
            BinOp::Div => lhs / rhs,
            BinOp::Mod => lhs % rhs,
        }
    }
}

impl MathFn1 {
    /// Applies the function to an already-evaluated argument.
    fn apply(&self, x: f64) -> f64 {
        match self {
            MathFn1::Log => (1.0 + x).ln(),
            MathFn1::Log10 => x.log10(),
            MathFn1::Sqrt => x.sqrt(),
            MathFn1::Abs => x.abs(),
            MathFn1::Ln => x.ln(),
        }
    }
}

impl MathFn2 {
    /// Applies the function to two already-evaluated arguments.
    fn apply(&self, x: f64, y: f64) -> f64 {
        match self {
            MathFn2::Pow => x.powf(y),
            MathFn2::Max => x.max(y),
            MathFn2::Min => x.min(y),
        }
    }
}

impl Expr {
    /// Evaluates the expression tree.
    ///
    /// `score` is substituted for every [`Expr::Score`] node and `params`
    /// supplies the values of [`Expr::Param`] nodes; a parameter that is
    /// missing from the map evaluates to `0.0`.
    fn eval(&self, score: f64, params: &HashMap<String, f64>) -> f64 {
        match self {
            Expr::Score => score,
            Expr::Literal(value) => *value,
            Expr::Param(name) => match params.get(name) {
                Some(value) => *value,
                None => 0.0,
            },
            Expr::UnaryMinus(operand) => -operand.eval(score, params),
            Expr::BinOp(lhs, op, rhs) => {
                let left = lhs.eval(score, params);
                let right = rhs.eval(score, params);
                op.apply(left, right)
            }
            Expr::Fn1(func, arg) => func.apply(arg.eval(score, params)),
            Expr::Fn2(func, first, second) => {
                let x = first.eval(score, params);
                let y = second.eval(score, params);
                func.apply(x, y)
            }
        }
    }
}
/// Recursive-descent parser state for a score script.
struct Parser<'a> {
    /// All tokens of the script, in source order.
    tokens: Vec<Token>,
    /// Index into `tokens` of the next token to be consumed.
    pos: usize,
    /// Script parameters; used while parsing to tell known parameter names
    /// apart from unknown identifiers.
    params: &'a HashMap<String, f64>,
}
/// Lexical tokens of the score-script language.
#[derive(Clone, Debug)]
enum Token {
    /// A numeric literal.
    Num(f64),
    /// An identifier made of ASCII letters, digits and `_`.
    Ident(String),
    Plus,
    Minus,
    Star,
    Slash,
    Percent,
    LParen,
    RParen,
    Comma,
    /// A `.` that is immediately followed by an ASCII letter.
    Dot,
}

/// Splits a script into tokens.
///
/// The scan is byte-oriented and never fails:
/// * a maximal run of ASCII digits and `.` becomes one [`Token::Num`]; a run
///   that is not a valid float (for example `1.2.3` or a lone `.`) becomes
///   `Num(0.0)`;
/// * a `.` directly followed by an ASCII letter becomes [`Token::Dot`];
/// * whitespace and every byte that starts no token are skipped.
fn tokenize(s: &str) -> Vec<Token> {
    let bytes = s.as_bytes();
    // End (exclusive) of the run of bytes accepted by `keep`, starting at `from`.
    let run_end = |from: usize, keep: fn(u8) -> bool| -> usize {
        from + bytes[from..].iter().take_while(|&&b| keep(b)).count()
    };

    let mut out = Vec::new();
    let mut cursor = 0;
    while let Some(&byte) = bytes.get(cursor) {
        let punctuation = match byte {
            b'+' => Some(Token::Plus),
            b'-' => Some(Token::Minus),
            b'*' => Some(Token::Star),
            b'/' => Some(Token::Slash),
            b'%' => Some(Token::Percent),
            b'(' => Some(Token::LParen),
            b')' => Some(Token::RParen),
            b',' => Some(Token::Comma),
            b'.' if bytes
                .get(cursor + 1)
                .map_or(false, u8::is_ascii_alphabetic) =>
            {
                Some(Token::Dot)
            }
            _ => None,
        };
        if let Some(token) = punctuation {
            out.push(token);
            cursor += 1;
            continue;
        }

        if byte.is_ascii_digit() || byte == b'.' {
            let end = run_end(cursor, |b| b.is_ascii_digit() || b == b'.');
            // The run is pure ASCII, so slicing the `str` at these offsets is valid.
            let value = s[cursor..end].parse::<f64>().unwrap_or(0.0);
            out.push(Token::Num(value));
            cursor = end;
        } else if byte.is_ascii_alphabetic() || byte == b'_' {
            let end = run_end(cursor, |b| b.is_ascii_alphanumeric() || b == b'_');
            out.push(Token::Ident(s[cursor..end].to_string()));
            cursor = end;
        } else {
            // Whitespace, or a byte that cannot start any token: skip it.
            cursor += 1;
        }
    }
    out
}
impl<'a> Parser<'a> {
fn new(source: &str, params: &'a HashMap<String, f64>) -> Self {
Self {
tokens: tokenize(source),
pos: 0,
params,
}
}
fn peek(&self) -> Option<&Token> {
self.tokens.get(self.pos)
}
fn advance(&mut self) -> Option<Token> {
let t = self.tokens.get(self.pos).cloned();
self.pos += 1;
t
}
fn parse_expr(&mut self) -> Expr {
self.parse_additive()
}
fn parse_additive(&mut self) -> Expr {
let mut left = self.parse_multiplicative();
loop {
match self.peek() {
Some(Token::Plus) => {
self.advance();
left = Expr::BinOp(
Box::new(left),
BinOp::Add,
Box::new(self.parse_multiplicative()),
);
}
Some(Token::Minus) => {
self.advance();
left = Expr::BinOp(
Box::new(left),
BinOp::Sub,
Box::new(self.parse_multiplicative()),
);
}
_ => break,
}
}
left
}
fn parse_multiplicative(&mut self) -> Expr {
let mut left = self.parse_unary();
loop {
match self.peek() {
Some(Token::Star) => {
self.advance();
left = Expr::BinOp(Box::new(left), BinOp::Mul, Box::new(self.parse_unary()));
}
Some(Token::Slash) => {
self.advance();
left = Expr::BinOp(Box::new(left), BinOp::Div, Box::new(self.parse_unary()));
}
Some(Token::Percent) => {
self.advance();
left = Expr::BinOp(Box::new(left), BinOp::Mod, Box::new(self.parse_unary()));
}
_ => break,
}
}
left
}
fn parse_unary(&mut self) -> Expr {
if matches!(self.peek(), Some(Token::Minus)) {
self.advance();
Expr::UnaryMinus(Box::new(self.parse_primary()))
} else {
self.parse_primary()
}
}
fn parse_primary(&mut self) -> Expr {
match self.advance() {
Some(Token::Num(n)) => Expr::Literal(n),
Some(Token::LParen) => {
let e = self.parse_expr();
self.advance(); e
}
Some(Token::Ident(name)) => {
if name == "_score" {
Expr::Score
} else if name == "Math" {
self.advance(); if let Some(Token::Ident(func)) = self.advance() {
self.advance(); let arg1 = self.parse_expr();
match func.as_str() {
"sqrt" | "abs" | "log" | "log10" | "ln" => {
self.advance(); let f = match func.as_str() {
"sqrt" => MathFn1::Sqrt,
"abs" => MathFn1::Abs,
"log" => MathFn1::Log,
"log10" => MathFn1::Log10,
"ln" => MathFn1::Ln,
_ => unreachable!(),
};
Expr::Fn1(f, Box::new(arg1))
}
"pow" | "max" | "min" => {
self.advance(); let arg2 = self.parse_expr();
self.advance(); let f = match func.as_str() {
"pow" => MathFn2::Pow,
"max" => MathFn2::Max,
"min" => MathFn2::Min,
_ => unreachable!(),
};
Expr::Fn2(f, Box::new(arg1), Box::new(arg2))
}
_ => Expr::Literal(0.0),
}
} else {
Expr::Literal(0.0)
}
} else if self.params.contains_key(&name) {
Expr::Param(name)
} else {
Expr::Literal(0.0) }
}
_ => Expr::Literal(0.0),
}
}
}
/// Compiles a score script into an expression tree.
///
/// `params` is handed to the parser, which uses it to recognize parameter
/// names appearing in `source`.
fn compile_script(source: &str, params: &HashMap<String, f64>) -> Expr {
    Parser::new(source, params).parse_expr()
}
/// A query that wraps another query and rescores it with an arithmetic script.
pub struct ScriptScoreQuery {
    /// The wrapped query.
    pub(crate) query: Box<dyn Query>,
    /// Script source; compiled each time the query is bound.
    pub script: String,
    /// Named numeric parameters available to the script.
    pub params: HashMap<String, f64>,
}

impl Query for ScriptScoreQuery {
    /// Binds the wrapped query, then compiles the script and pairs the
    /// resulting expression with a copy of the parameters.
    ///
    /// Returns the wrapped query's error if binding it fails.
    fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        let inner = self.query.bind(searcher, score_mode)?;
        let bound = BoundScriptScoreQuery {
            inner,
            expr: compile_script(&self.script, &self.params),
            params: self.params.clone(),
        };
        Ok(Box::new(bound))
    }
}
/// A [`ScriptScoreQuery`] after binding: the bound wrapped query plus the
/// compiled script and its parameters.
struct BoundScriptScoreQuery {
    /// The bound wrapped query.
    inner: Box<dyn BoundQuery>,
    /// Compiled script.
    expr: Expr,
    /// Parameter values used when evaluating `expr`.
    params: HashMap<String, f64>,
}

impl BoundQuery for BoundScriptScoreQuery {
    /// Wraps the inner query's scorer supplier for `reader`.
    ///
    /// Returns `Ok(None)` when the inner query has no supplier for this
    /// reader; otherwise the wrapper receives its own clone of the
    /// expression and parameters.
    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
        let wrapped = self.inner.scorer_supplier(reader)?.map(|inner| {
            let supplier = ScriptScoreScorerSupplier {
                inner,
                expr: self.expr.clone(),
                params: self.params.clone(),
            };
            Box::new(supplier) as Box<dyn ScorerSupplier>
        });
        Ok(wrapped)
    }
}
/// Scorer supplier that wraps another supplier and attaches the compiled
/// script to the scorer it produces.
struct ScriptScoreScorerSupplier {
    /// The wrapped supplier.
    inner: Box<dyn ScorerSupplier>,
    /// Compiled script.
    expr: Expr,
    /// Parameter values used when evaluating `expr`.
    params: HashMap<String, f64>,
}

impl ScorerSupplier for ScriptScoreScorerSupplier {
    /// Reports the wrapped supplier's cost unchanged.
    fn cost(&self) -> u64 {
        self.inner.cost()
    }

    /// Builds the wrapped scorer and moves the expression and parameters
    /// into a [`ScriptScoreScorer`] around it.
    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
        let ScriptScoreScorerSupplier {
            inner,
            expr,
            params,
        } = *self;
        let scorer = ScriptScoreScorer {
            inner: inner.scorer()?,
            expr,
            params,
        };
        Ok(Box::new(scorer))
    }
}
/// Scorer that iterates exactly like the wrapped scorer but reports the
/// script's value as the score.
struct ScriptScoreScorer {
    /// The wrapped scorer; drives iteration and provides the base score.
    inner: Box<dyn Scorer>,
    /// Compiled script.
    expr: Expr,
    /// Parameter values used when evaluating `expr`.
    params: HashMap<String, f64>,
}

impl Scorer for ScriptScoreScorer {
    /// Current document of the wrapped scorer.
    fn doc_id(&self) -> DocId {
        self.inner.doc_id()
    }

    /// Moves the wrapped scorer to its next document.
    fn next(&mut self) -> DocId {
        self.inner.next()
    }

    /// Forwards `advance(target)` to the wrapped scorer.
    fn advance(&mut self, target: DocId) -> DocId {
        self.inner.advance(target)
    }

    /// Evaluates the script with the wrapped scorer's score as `_score`.
    ///
    /// The computation is done in `f64` and narrowed to `f32` at the end.
    fn score(&mut self) -> f32 {
        let inner_score = f64::from(self.inner.score());
        let scripted = self.expr.eval(inner_score, &self.params);
        scripted as f32
    }

    /// Always `None`.
    ///
    /// NOTE(review): the wrapped scorer's two-phase iterator, if it has one,
    /// is not forwarded — confirm that `next`/`advance` on the wrapped scorer
    /// only yield confirmed matches.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    /// Builds one analysis token per term, using the term's index as the
    /// last `Token::new` argument.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        terms
            .iter()
            .enumerate()
            .map(|(i, t)| Token::new(*t, 0, t.len(), i as u32))
            .collect()
    }

    /// `_score` multiplied by a named parameter.
    #[test]
    fn expr_eval_basic() {
        let params = HashMap::from([("factor".to_string(), 3.0)]);
        let expr = compile_script("_score * factor", &params);
        assert_eq!(expr.eval(2.0, &params), 6.0);
    }

    /// `Math.*` calls; note that `Math.log(x)` is `ln(1 + x)`.
    #[test]
    fn expr_eval_math_functions() {
        let params = HashMap::new();
        let expr = compile_script("Math.sqrt(_score)", &params);
        assert!((expr.eval(4.0, &params) - 2.0).abs() < 0.001);
        let expr2 = compile_script("Math.log(_score)", &params);
        assert!((expr2.eval(1.0, &params) - (2.0f64).ln()).abs() < 0.001);
        let expr3 = compile_script("Math.max(_score, 10.0)", &params);
        assert_eq!(expr3.eval(5.0, &params), 10.0);
    }

    /// Parentheses take precedence over `*`.
    #[test]
    fn expr_eval_complex() {
        let params = HashMap::from([("boost".to_string(), 1.5)]);
        let expr = compile_script("(_score + 1.0) * boost", &params);
        assert_eq!(expr.eval(2.0, &params), 4.5);
    }

    /// A constant script ignores the incoming score.
    #[test]
    fn expr_eval_constant() {
        let params = HashMap::new();
        let expr = compile_script("42.0", &params);
        assert_eq!(expr.eval(999.0, &params), 42.0);
    }

    /// End-to-end: a scripted match query still finds the document and
    /// reports a positive score.
    #[test]
    fn script_score_query() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(
            &[(FieldId::new(0), make_tokens(&["hello", "world"]))],
            b"{}",
        );
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = ScriptScoreQuery {
            query: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "hello".into(),
                analyzer: None,
            }),
            script: "_score * factor".to_string(),
            params: HashMap::from([("factor".to_string(), 3.0)]),
        };
        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!(results.hits[0].score > 0.0);
    }

    /// End-to-end: a constant script overrides the wrapped query's score.
    #[test]
    fn script_score_constant_42() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(&[(FieldId::new(0), make_tokens(&["hello"]))], b"{}");
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = ScriptScoreQuery {
            query: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "hello".into(),
                analyzer: None,
            }),
            script: "42.0".to_string(),
            params: HashMap::new(),
        };
        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!((results.hits[0].score - 42.0).abs() < 0.01);
    }
}