Skip to main content

luci/query/
script_score.rs

//! Script score query: custom scoring via a small expression language.
//!
//! The script is parsed once, when the query is bound, into an expression tree
//! that is then evaluated (tree-walked) for every matching document.
//!
//! Supports: `_score`, numeric literals, `+`, `-`, `*`, `/`, `%`, unary minus,
//! parentheses, `Math.log()`, `Math.log10()`, `Math.ln()`, `Math.sqrt()`,
//! `Math.pow()`, `Math.max()`, `Math.min()`, `Math.abs()`, and user-defined params.
//! Note that `Math.log(x)` evaluates `ln(1 + x)`, unlike Java's `Math.log`, and that
//! division or remainder by zero yields `0.0`.
//!
//! See [[elasticsearch-parity]] and [[investigation-20260320-01-monty-script-score-overhead]].
11
12use std::collections::HashMap;
13
14use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
15
16use crate::query::{BoundQuery, Query, ScorerSupplier};
17use crate::search::searcher::Searcher;
18use crate::segment::reader::SegmentReader;
19
// ---------------------------------------------------------------------------
// Expression AST and compiler
// ---------------------------------------------------------------------------
23
/// Expression tree produced by [`compile_script`] and evaluated once per scored document.
#[derive(Clone, Debug)]
enum Expr {
    /// `_score`: the score of the wrapped query for the current document.
    Score,
    /// Numeric literal, e.g. `42.0` or `1e-3`.
    Literal(f64),
    /// User param; its value is looked up by name in the params map on every evaluation.
    Param(String),
    /// Binary arithmetic: `left op right`.
    BinOp(Box<Expr>, BinOp, Box<Expr>),
    /// Arithmetic negation, e.g. `-_score`.
    UnaryMinus(Box<Expr>),
    /// One-argument `Math` function, e.g. `Math.sqrt(x)`.
    Fn1(MathFn1, Box<Expr>),
    /// Two-argument `Math` function, e.g. `Math.pow(x, y)`.
    Fn2(MathFn2, Box<Expr>, Box<Expr>),
}

/// Binary arithmetic operators.
#[derive(Clone, Debug)]
enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`; a zero divisor yields `0.0` instead of infinity/NaN.
    Div,
    /// `%` (remainder takes the sign of the dividend); a zero divisor yields `0.0`.
    Mod,
}

/// One-argument `Math.*` functions.
#[derive(Clone, Debug)]
enum MathFn1 {
    /// `Math.log(x)`: natural log of `1 + x` (NOT Java's `Math.log`, which is `ln(x)`).
    Log,
    /// `Math.log10(x)`: base-10 logarithm.
    Log10,
    /// `Math.sqrt(x)`.
    Sqrt,
    /// `Math.abs(x)`.
    Abs,
    /// `Math.ln(x)`: natural logarithm.
    Ln,
}

/// Two-argument `Math.*` functions.
#[derive(Clone, Debug)]
enum MathFn2 {
    /// `Math.pow(x, y)`: `x` raised to the power `y`.
    Pow,
    /// `Math.max(x, y)`.
    Max,
    /// `Math.min(x, y)`.
    Min,
}

impl Expr {
    /// Evaluates the expression for one document.
    ///
    /// `score` is substituted for `_score`; `Param` nodes read `params` and fall back to
    /// `0.0` when the name is absent.
    fn eval(&self, score: f64, params: &HashMap<String, f64>) -> f64 {
        match self {
            Expr::Score => score,
            Expr::Literal(v) => *v,
            Expr::Param(name) => params.get(name).copied().unwrap_or(0.0),
            Expr::BinOp(l, op, r) => {
                let lv = l.eval(score, params);
                let rv = r.eval(score, params);
                match op {
                    BinOp::Add => lv + rv,
                    BinOp::Sub => lv - rv,
                    BinOp::Mul => lv * rv,
                    // Deliberate guard: a zero divisor scores 0 rather than inf/NaN.
                    BinOp::Div | BinOp::Mod if rv == 0.0 => 0.0,
                    BinOp::Div => lv / rv,
                    BinOp::Mod => lv % rv,
                }
            }
            Expr::UnaryMinus(e) => -e.eval(score, params),
            Expr::Fn1(f, arg) => {
                let v = arg.eval(score, params);
                match f {
                    MathFn1::Log => (1.0 + v).ln(),
                    MathFn1::Log10 => v.log10(),
                    MathFn1::Sqrt => v.sqrt(),
                    MathFn1::Abs => v.abs(),
                    MathFn1::Ln => v.ln(),
                }
            }
            Expr::Fn2(f, a, b) => {
                let av = a.eval(score, params);
                let bv = b.eval(score, params);
                match f {
                    MathFn2::Pow => av.powf(bv),
                    MathFn2::Max => av.max(bv),
                    MathFn2::Min => av.min(bv),
                }
            }
        }
    }
}

// ---------------------------------------------------------------------------
// Simple recursive descent parser
// ---------------------------------------------------------------------------

/// Recursive-descent parser over the tokens of one script.
///
/// Grammar, lowest to highest precedence:
///
/// ```text
/// expr    := term (('+' | '-') term)*
/// term    := unary (('*' | '/' | '%') unary)*
/// unary   := ('-' | '+') unary | primary
/// primary := NUMBER | '(' expr ')' | '_score'
///          | 'Math' '.' IDENT '(' [expr (',' expr)*] ')'
///          | 'params' '.' IDENT | IDENT
/// ```
struct Parser<'a> {
    /// Tokens of the whole script, produced up front by [`tokenize`].
    tokens: Vec<Token>,
    /// Index of the next unread token; may run past the end, which reads as `None`.
    pos: usize,
    /// Params known at compile time; identifiers not found here compile to `0.0`.
    params: &'a HashMap<String, f64>,
}

/// Lexical tokens of the script language.
#[derive(Clone, Debug)]
enum Token {
    /// Numeric literal; a malformed number (e.g. `1.2.3`) lexes as `0.0`.
    Num(f64),
    /// Identifier: `_score`, `Math`, `params`, a function name, or a param name.
    Ident(String),
    Plus,
    Minus,
    Star,
    Slash,
    Percent,
    LParen,
    RParen,
    Comma,
    /// `.` used for member access (`Math.sqrt`, `params.x`), not a decimal point.
    Dot,
}

/// Splits `s` into tokens.
///
/// Whitespace and unrecognized bytes (including non-ASCII) are skipped. A number is a
/// run of digits and `.`s, optionally followed by an exponent such as `e-3` or `E4`.
fn tokenize(s: &str) -> Vec<Token> {
    let bytes = s.as_bytes();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        let punct = match bytes[i] {
            b'+' => Some(Token::Plus),
            b'-' => Some(Token::Minus),
            b'*' => Some(Token::Star),
            b'/' => Some(Token::Slash),
            b'%' => Some(Token::Percent),
            b'(' => Some(Token::LParen),
            b')' => Some(Token::RParen),
            b',' => Some(Token::Comma),
            // A dot directly followed by a name is member access, not a decimal point.
            b'.' if matches!(
                bytes.get(i + 1),
                Some(c) if c.is_ascii_alphabetic() || *c == b'_'
            ) =>
            {
                Some(Token::Dot)
            }
            _ => None,
        };
        if let Some(tok) = punct {
            tokens.push(tok);
            i += 1;
            continue;
        }
        match bytes[i] {
            b'0'..=b'9' | b'.' => {
                let start = i;
                while i < bytes.len() && (bytes[i].is_ascii_digit() || bytes[i] == b'.') {
                    i += 1;
                }
                i = exponent_end(bytes, i);
                // `start..i` spans ASCII bytes only, so both ends are char boundaries.
                tokens.push(Token::Num(s[start..i].parse::<f64>().unwrap_or(0.0)));
            }
            b'a'..=b'z' | b'A'..=b'Z' | b'_' => {
                let start = i;
                while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                    i += 1;
                }
                tokens.push(Token::Ident(s[start..i].to_string()));
            }
            // Whitespace and unknown bytes are skipped.
            _ => i += 1,
        }
    }
    tokens
}

/// If `bytes[i..]` starts with a complete exponent (`e`/`E`, optional sign, at least one
/// digit), returns the index just past it; otherwise returns `i` unchanged, so `2e` still
/// lexes as `2` followed by the identifier `e`.
fn exponent_end(bytes: &[u8], i: usize) -> usize {
    if !matches!(bytes.get(i), Some(b'e') | Some(b'E')) {
        return i;
    }
    let mut j = i + 1;
    if matches!(bytes.get(j), Some(b'+') | Some(b'-')) {
        j += 1;
    }
    let digits_start = j;
    while matches!(bytes.get(j), Some(c) if c.is_ascii_digit()) {
        j += 1;
    }
    if j > digits_start {
        j
    } else {
        i
    }
}

/// Maps a one-argument `Math` function name to its operator.
fn math_fn1(name: &str) -> Option<MathFn1> {
    match name {
        "log" => Some(MathFn1::Log),
        "log10" => Some(MathFn1::Log10),
        "sqrt" => Some(MathFn1::Sqrt),
        "abs" => Some(MathFn1::Abs),
        "ln" => Some(MathFn1::Ln),
        _ => None,
    }
}

/// Maps a two-argument `Math` function name to its operator.
fn math_fn2(name: &str) -> Option<MathFn2> {
    match name {
        "pow" => Some(MathFn2::Pow),
        "max" => Some(MathFn2::Max),
        "min" => Some(MathFn2::Min),
        _ => None,
    }
}

impl<'a> Parser<'a> {
    fn new(source: &str, params: &'a HashMap<String, f64>) -> Self {
        Self {
            tokens: tokenize(source),
            pos: 0,
            params,
        }
    }

    /// Returns the next token without consuming it.
    fn peek(&self) -> Option<&Token> {
        self.tokens.get(self.pos)
    }

    /// Consumes and returns the next token (`None` past the end; `pos` still advances).
    fn advance(&mut self) -> Option<Token> {
        let t = self.tokens.get(self.pos).cloned();
        self.pos += 1;
        t
    }

    /// Consumes the next token only if it is the same variant as `expected`.
    ///
    /// Meant for payload-free punctuation tokens, whose variant is their whole identity.
    fn eat(&mut self, expected: &Token) -> bool {
        let hit = matches!(
            self.peek(),
            Some(t) if std::mem::discriminant(t) == std::mem::discriminant(expected)
        );
        if hit {
            self.pos += 1;
        }
        hit
    }

    fn parse_expr(&mut self) -> Expr {
        self.parse_additive()
    }

    /// `term (('+' | '-') term)*`, left-associative.
    fn parse_additive(&mut self) -> Expr {
        let mut left = self.parse_multiplicative();
        loop {
            let op = match self.peek() {
                Some(Token::Plus) => BinOp::Add,
                Some(Token::Minus) => BinOp::Sub,
                _ => break,
            };
            self.advance();
            let right = self.parse_multiplicative();
            left = Expr::BinOp(Box::new(left), op, Box::new(right));
        }
        left
    }

    /// `unary (('*' | '/' | '%') unary)*`, left-associative.
    fn parse_multiplicative(&mut self) -> Expr {
        let mut left = self.parse_unary();
        loop {
            let op = match self.peek() {
                Some(Token::Star) => BinOp::Mul,
                Some(Token::Slash) => BinOp::Div,
                Some(Token::Percent) => BinOp::Mod,
                _ => break,
            };
            self.advance();
            let right = self.parse_unary();
            left = Expr::BinOp(Box::new(left), op, Box::new(right));
        }
        left
    }

    /// Prefix `-` / `+`, which may be repeated (`--x`, `-+x`).
    fn parse_unary(&mut self) -> Expr {
        match self.peek() {
            Some(Token::Minus) => {
                self.advance();
                Expr::UnaryMinus(Box::new(self.parse_unary()))
            }
            Some(Token::Plus) => {
                self.advance();
                self.parse_unary()
            }
            _ => self.parse_primary(),
        }
    }

    /// Literal, parenthesized expression, or identifier. Anything else (including end of
    /// input) consumes one token and yields `0.0`.
    fn parse_primary(&mut self) -> Expr {
        match self.advance() {
            Some(Token::Num(n)) => Expr::Literal(n),
            Some(Token::LParen) => {
                let inner = self.parse_expr();
                self.eat(&Token::RParen);
                inner
            }
            Some(Token::Ident(name)) => self.parse_identifier(name),
            _ => Expr::Literal(0.0),
        }
    }

    /// Resolves an identifier that has just been consumed.
    fn parse_identifier(&mut self, name: String) -> Expr {
        if name == "_score" {
            Expr::Score
        } else if name == "Math" {
            self.parse_math_call()
        } else if name == "params" && self.at_member_access() {
            // Elasticsearch-style `params.<name>`.
            self.pos += 1; // skip the `.`
            match self.advance() {
                Some(Token::Ident(key)) => self.param_or_zero(key),
                // Not reachable: `at_member_access` saw an identifier here.
                _ => Expr::Literal(0.0),
            }
        } else {
            self.param_or_zero(name)
        }
    }

    /// True when the next two tokens are `.` followed by an identifier.
    fn at_member_access(&self) -> bool {
        matches!(
            (self.tokens.get(self.pos), self.tokens.get(self.pos + 1)),
            (Some(Token::Dot), Some(Token::Ident(_)))
        )
    }

    /// `Param(name)` if `name` is a known param, otherwise the literal `0.0`.
    fn param_or_zero(&self, name: String) -> Expr {
        if self.params.contains_key(&name) {
            Expr::Param(name)
        } else {
            Expr::Literal(0.0)
        }
    }

    /// Parses `.func(args...)` after a `Math` identifier.
    ///
    /// A missing `.func`, an unknown function, or the wrong number of arguments yields
    /// `0.0`; the argument list is still consumed so parsing can continue after it.
    fn parse_math_call(&mut self) -> Expr {
        if !self.eat(&Token::Dot) {
            return Expr::Literal(0.0);
        }
        let func = match self.advance() {
            Some(Token::Ident(func)) => func,
            _ => return Expr::Literal(0.0),
        };
        let mut args = self.parse_call_args();
        match (math_fn1(&func), math_fn2(&func), args.len()) {
            (Some(f), _, 1) => {
                let x = args.pop().expect("exactly one argument");
                Expr::Fn1(f, Box::new(x))
            }
            (_, Some(f), 2) => {
                let y = args.pop().expect("exactly two arguments");
                let x = args.pop().expect("exactly two arguments");
                Expr::Fn2(f, Box::new(x), Box::new(y))
            }
            _ => Expr::Literal(0.0),
        }
    }

    /// Parses `'(' [expr (',' expr)*] ')'`. Returns no arguments if `(` is missing;
    /// a missing closing `)` is tolerated.
    fn parse_call_args(&mut self) -> Vec<Expr> {
        let mut args = Vec::new();
        if !self.eat(&Token::LParen) || self.eat(&Token::RParen) {
            return args;
        }
        loop {
            args.push(self.parse_expr());
            if !self.eat(&Token::Comma) {
                break;
            }
        }
        self.eat(&Token::RParen);
        args
    }
}

/// Parses `source` into an [`Expr`], resolving identifiers against `params`.
///
/// Parsing is lenient and never returns an error. Unknown identifiers, unknown `Math`
/// functions, calls with the wrong number of arguments, and missing operands all become
/// `0.0`. Tokens left over after the first complete expression are ignored.
// NOTE(review): evaluation and drop recurse once per tree level, and long `a + b + ...`
// chains build left-deep trees — consider a script size limit if scripts are untrusted.
fn compile_script(source: &str, params: &HashMap<String, f64>) -> Expr {
    Parser::new(source, params).parse_expr()
}
348
// ---------------------------------------------------------------------------
// Query implementation
// ---------------------------------------------------------------------------
352
353pub struct ScriptScoreQuery {
354    pub(crate) query: Box<dyn Query>,
355    pub script: String,
356    pub params: HashMap<String, f64>,
357}
358
359impl Query for ScriptScoreQuery {
360    fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
361        let inner = self.query.bind(searcher, score_mode)?;
362        let expr = compile_script(&self.script, &self.params);
363        Ok(Box::new(BoundScriptScoreQuery {
364            inner,
365            expr,
366            params: self.params.clone(),
367        }))
368    }
369}
370
371struct BoundScriptScoreQuery {
372    inner: Box<dyn BoundQuery>,
373    expr: Expr,
374    params: HashMap<String, f64>,
375}
376
377impl BoundQuery for BoundScriptScoreQuery {
378    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
379        let inner = match self.inner.scorer_supplier(reader)? {
380            Some(s) => s,
381            None => return Ok(None),
382        };
383        Ok(Some(Box::new(ScriptScoreScorerSupplier {
384            inner,
385            expr: self.expr.clone(),
386            params: self.params.clone(),
387        })))
388    }
389}
390
391struct ScriptScoreScorerSupplier {
392    inner: Box<dyn ScorerSupplier>,
393    expr: Expr,
394    params: HashMap<String, f64>,
395}
396
397impl ScorerSupplier for ScriptScoreScorerSupplier {
398    fn cost(&self) -> u64 {
399        self.inner.cost()
400    }
401    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
402        let inner = self.inner.scorer()?;
403        Ok(Box::new(ScriptScoreScorer {
404            inner,
405            expr: self.expr,
406            params: self.params,
407        }))
408    }
409}
410
411struct ScriptScoreScorer {
412    inner: Box<dyn Scorer>,
413    expr: Expr,
414    params: HashMap<String, f64>,
415}
416
417impl Scorer for ScriptScoreScorer {
418    fn doc_id(&self) -> DocId {
419        self.inner.doc_id()
420    }
421    fn next(&mut self) -> DocId {
422        self.inner.next()
423    }
424    fn advance(&mut self, target: DocId) -> DocId {
425        self.inner.advance(target)
426    }
427
428    fn score(&mut self) -> f32 {
429        let base = self.inner.score() as f64;
430        self.expr.eval(base, &self.params) as f32
431    }
432
433    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
434        None
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        terms
            .iter()
            .enumerate()
            .map(|(i, t)| Token::new(*t, 0, t.len(), i as u32))
            .collect()
    }

    /// Compiles `script` against `params` and evaluates it for `score`.
    fn eval_script(script: &str, params: &HashMap<String, f64>, score: f64) -> f64 {
        compile_script(script, params).eval(score, params)
    }

    /// `match` query for the term `hello` on the `text` field.
    fn hello_query() -> MatchQuery {
        MatchQuery {
            field: "text".into(),
            query_text: "hello".into(),
            analyzer: None,
        }
    }

    #[test]
    fn expr_eval_basic() {
        let params = HashMap::from([("factor".to_string(), 3.0)]);
        let expr = compile_script("_score * factor", &params);
        assert_eq!(expr.eval(2.0, &params), 6.0);
    }

    #[test]
    fn expr_eval_math_functions() {
        let params = HashMap::new();
        let expr = compile_script("Math.sqrt(_score)", &params);
        assert!((expr.eval(4.0, &params) - 2.0).abs() < 0.001);

        let expr2 = compile_script("Math.log(_score)", &params);
        assert!((expr2.eval(1.0, &params) - (2.0f64).ln()).abs() < 0.001); // log(1+1)

        let expr3 = compile_script("Math.max(_score, 10.0)", &params);
        assert_eq!(expr3.eval(5.0, &params), 10.0);
    }

    #[test]
    fn expr_eval_remaining_math_functions() {
        let params = HashMap::new();
        assert!((eval_script("Math.pow(_score, 2)", &params, 3.0) - 9.0).abs() < 1e-12);
        assert_eq!(eval_script("Math.min(_score, 1.5)", &params, 3.0), 1.5);
        assert_eq!(eval_script("Math.abs(-_score)", &params, 2.0), 2.0);
        assert!((eval_script("Math.log10(100)", &params, 0.0) - 2.0).abs() < 1e-12);
        assert!((eval_script("Math.ln(_score)", &params, std::f64::consts::E) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn expr_eval_complex() {
        let params = HashMap::from([("boost".to_string(), 1.5)]);
        let expr = compile_script("(_score + 1.0) * boost", &params);
        assert_eq!(expr.eval(2.0, &params), 4.5);
    }

    #[test]
    fn expr_eval_precedence_and_unary_minus() {
        let params = HashMap::new();
        assert_eq!(eval_script("1 + 2 * 3 - 4 / 2", &params, 0.0), 5.0);
        assert_eq!(eval_script("(1 + 2) * (3 - 1)", &params, 0.0), 6.0);
        assert_eq!(eval_script("-_score + 1", &params, 3.0), -2.0);
    }

    #[test]
    fn expr_eval_modulo_and_zero_divisor() {
        let params = HashMap::new();
        assert_eq!(eval_script("10 % 4", &params, 0.0), 2.0);
        // Division and remainder by zero are defined to score 0.
        assert_eq!(eval_script("_score / 0", &params, 3.0), 0.0);
        assert_eq!(eval_script("_score % 0", &params, 3.0), 0.0);
    }

    #[test]
    fn expr_eval_unknown_identifier_is_zero() {
        let params = HashMap::from([("known".to_string(), 2.0)]);
        assert_eq!(eval_script("missing + 1", &params, 0.0), 1.0);
        assert_eq!(eval_script("known * missing", &params, 0.0), 0.0);
    }

    #[test]
    fn expr_eval_constant() {
        let params = HashMap::new();
        let expr = compile_script("42.0", &params);
        assert_eq!(expr.eval(999.0, &params), 42.0);
    }

    #[test]
    fn script_score_query() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(
            &[(FieldId::new(0), make_tokens(&["hello", "world"]))],
            b"{}",
        );
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        let base_query = hello_query();
        let base = searcher.search_query(&base_query, 10, 0).unwrap();

        let query = ScriptScoreQuery {
            query: Box::new(hello_query()),
            script: "_score * factor".to_string(),
            params: HashMap::from([("factor".to_string(), 3.0)]),
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!(results.hits[0].score > 0.0);
        // The script must actually rescale the wrapped query's score.
        let expected = base.hits[0].score * 3.0;
        assert!((results.hits[0].score - expected).abs() < 1e-4);
    }

    #[test]
    fn script_score_constant_42() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(&[(FieldId::new(0), make_tokens(&["hello"]))], b"{}");
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        let query = ScriptScoreQuery {
            query: Box::new(hello_query()),
            script: "42.0".to_string(),
            params: HashMap::new(),
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!((results.hits[0].score - 42.0).abs() < 0.01);
    }
}