1use std::collections::HashMap;
13
14use crate::core::{DocId, Result, ScoreMode, Scorer, TwoPhaseIterator};
15
16use crate::query::{BoundQuery, Query, ScorerSupplier};
17use crate::search::searcher::Searcher;
18use crate::segment::reader::SegmentReader;
19
20#[derive(Clone, Debug)]
25enum Expr {
26 Score, Literal(f64), Param(String), BinOp(Box<Expr>, BinOp, Box<Expr>),
30 UnaryMinus(Box<Expr>),
31 Fn1(MathFn1, Box<Expr>), Fn2(MathFn2, Box<Expr>, Box<Expr>), }
34
/// Binary arithmetic operators of the expression language.
#[derive(Debug, Clone)]
enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/` (a zero divisor evaluates to `0.0`).
    Div,
    /// `%` (a zero divisor evaluates to `0.0`).
    Mod,
}
43
/// One-argument functions callable as `Math.<name>(x)`.
#[derive(Debug, Clone)]
enum MathFn1 {
    /// `Math.log(x)`: evaluated as `ln(1 + x)`, not `ln(x)`.
    Log,
    /// `Math.log10(x)`: base-10 logarithm.
    Log10,
    /// `Math.sqrt(x)`
    Sqrt,
    /// `Math.abs(x)`
    Abs,
    /// `Math.ln(x)`: natural logarithm.
    Ln,
}
52
/// Two-argument functions callable as `Math.<name>(a, b)`.
#[derive(Debug, Clone)]
enum MathFn2 {
    /// `Math.pow(a, b)`: `a` raised to the power `b`.
    Pow,
    /// `Math.max(a, b)`
    Max,
    /// `Math.min(a, b)`
    Min,
}
59
60impl Expr {
61 fn eval(&self, score: f64, params: &HashMap<String, f64>) -> f64 {
63 match self {
64 Expr::Score => score,
65 Expr::Literal(v) => *v,
66 Expr::Param(name) => params.get(name).copied().unwrap_or(0.0),
67 Expr::BinOp(l, op, r) => {
68 let lv = l.eval(score, params);
69 let rv = r.eval(score, params);
70 match op {
71 BinOp::Add => lv + rv,
72 BinOp::Sub => lv - rv,
73 BinOp::Mul => lv * rv,
74 BinOp::Div => {
75 if rv != 0.0 {
76 lv / rv
77 } else {
78 0.0
79 }
80 }
81 BinOp::Mod => {
82 if rv != 0.0 {
83 lv % rv
84 } else {
85 0.0
86 }
87 }
88 }
89 }
90 Expr::UnaryMinus(e) => -e.eval(score, params),
91 Expr::Fn1(f, arg) => {
92 let v = arg.eval(score, params);
93 match f {
94 MathFn1::Log => (1.0 + v).ln(),
95 MathFn1::Log10 => v.log10(),
96 MathFn1::Sqrt => v.sqrt(),
97 MathFn1::Abs => v.abs(),
98 MathFn1::Ln => v.ln(),
99 }
100 }
101 Expr::Fn2(f, a, b) => {
102 let av = a.eval(score, params);
103 let bv = b.eval(score, params);
104 match f {
105 MathFn2::Pow => av.powf(bv),
106 MathFn2::Max => av.max(bv),
107 MathFn2::Min => av.min(bv),
108 }
109 }
110 }
111 }
112}
113
114struct Parser<'a> {
119 tokens: Vec<Token>,
120 pos: usize,
121 params: &'a HashMap<String, f64>,
122}
123
/// Lexical tokens of the script-score expression language.
#[derive(Clone, Debug)]
enum Token {
    /// Numeric literal, e.g. `42`, `.5`, `1.5e-3`.
    Num(f64),
    /// Identifier: `_score`, `Math`, a function name or a parameter name.
    Ident(String),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `%`
    Percent,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// `.` used as a member separator (`Math.sqrt`). A `.` that is part of a
    /// number such as `.5` is lexed into [`Token::Num`] instead.
    Dot,
}

/// Splits `s` into [`Token`]s.
///
/// Lexing is deliberately lenient:
/// - whitespace and any unrecognised byte are skipped;
/// - a malformed numeric literal (e.g. `1.2.3`) becomes `Num(0.0)`.
///
/// Numeric literals may carry a decimal exponent (`2e3`, `1.5E-2`). An
/// `e`/`E` that is not followed by an optionally signed digit is not part of
/// the number and is lexed as an identifier.
fn tokenize(s: &str) -> Vec<Token> {
    let bytes = s.as_bytes();
    let mut tokens = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        let c = bytes[i];
        let next_is_alpha = matches!(bytes.get(i + 1), Some(n) if n.is_ascii_alphabetic());
        let single = match c {
            b'+' => Some(Token::Plus),
            b'-' => Some(Token::Minus),
            b'*' => Some(Token::Star),
            b'/' => Some(Token::Slash),
            b'%' => Some(Token::Percent),
            b'(' => Some(Token::LParen),
            b')' => Some(Token::RParen),
            b',' => Some(Token::Comma),
            // `.` directly before a letter is member access (`Math.sqrt`);
            // otherwise it belongs to a numeric literal such as `.5`.
            b'.' if next_is_alpha => Some(Token::Dot),
            _ => None,
        };

        if let Some(tok) = single {
            tokens.push(tok);
            i += 1;
        } else if c.is_ascii_digit() || c == b'.' {
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_digit() || bytes[i] == b'.') {
                i += 1;
            }
            // Optional exponent: `e`/`E`, an optional sign, then at least one digit.
            if matches!(bytes.get(i), Some(b'e' | b'E')) {
                let mut j = i + 1;
                if matches!(bytes.get(j), Some(b'+' | b'-')) {
                    j += 1;
                }
                if matches!(bytes.get(j), Some(d) if d.is_ascii_digit()) {
                    while matches!(bytes.get(j), Some(d) if d.is_ascii_digit()) {
                        j += 1;
                    }
                    i = j;
                }
            }
            // The slice is pure ASCII, so `from_utf8` cannot fail.
            let literal = std::str::from_utf8(&bytes[start..i]).unwrap_or("");
            tokens.push(Token::Num(literal.parse().unwrap_or(0.0)));
        } else if c.is_ascii_alphabetic() || c == b'_' {
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                i += 1;
            }
            let ident = std::str::from_utf8(&bytes[start..i]).unwrap_or("").to_string();
            tokens.push(Token::Ident(ident));
        } else {
            // Whitespace and unrecognised bytes are skipped.
            i += 1;
        }
    }
    tokens
}
206
207impl<'a> Parser<'a> {
208 fn new(source: &str, params: &'a HashMap<String, f64>) -> Self {
209 Self {
210 tokens: tokenize(source),
211 pos: 0,
212 params,
213 }
214 }
215
216 fn peek(&self) -> Option<&Token> {
217 self.tokens.get(self.pos)
218 }
219 fn advance(&mut self) -> Option<Token> {
220 let t = self.tokens.get(self.pos).cloned();
221 self.pos += 1;
222 t
223 }
224
225 fn parse_expr(&mut self) -> Expr {
226 self.parse_additive()
227 }
228
229 fn parse_additive(&mut self) -> Expr {
230 let mut left = self.parse_multiplicative();
231 loop {
232 match self.peek() {
233 Some(Token::Plus) => {
234 self.advance();
235 left = Expr::BinOp(
236 Box::new(left),
237 BinOp::Add,
238 Box::new(self.parse_multiplicative()),
239 );
240 }
241 Some(Token::Minus) => {
242 self.advance();
243 left = Expr::BinOp(
244 Box::new(left),
245 BinOp::Sub,
246 Box::new(self.parse_multiplicative()),
247 );
248 }
249 _ => break,
250 }
251 }
252 left
253 }
254
255 fn parse_multiplicative(&mut self) -> Expr {
256 let mut left = self.parse_unary();
257 loop {
258 match self.peek() {
259 Some(Token::Star) => {
260 self.advance();
261 left = Expr::BinOp(Box::new(left), BinOp::Mul, Box::new(self.parse_unary()));
262 }
263 Some(Token::Slash) => {
264 self.advance();
265 left = Expr::BinOp(Box::new(left), BinOp::Div, Box::new(self.parse_unary()));
266 }
267 Some(Token::Percent) => {
268 self.advance();
269 left = Expr::BinOp(Box::new(left), BinOp::Mod, Box::new(self.parse_unary()));
270 }
271 _ => break,
272 }
273 }
274 left
275 }
276
277 fn parse_unary(&mut self) -> Expr {
278 if matches!(self.peek(), Some(Token::Minus)) {
279 self.advance();
280 Expr::UnaryMinus(Box::new(self.parse_primary()))
281 } else {
282 self.parse_primary()
283 }
284 }
285
286 fn parse_primary(&mut self) -> Expr {
287 match self.advance() {
288 Some(Token::Num(n)) => Expr::Literal(n),
289 Some(Token::LParen) => {
290 let e = self.parse_expr();
291 self.advance(); e
293 }
294 Some(Token::Ident(name)) => {
295 if name == "_score" {
296 Expr::Score
297 } else if name == "Math" {
298 self.advance(); if let Some(Token::Ident(func)) = self.advance() {
301 self.advance(); let arg1 = self.parse_expr();
303 match func.as_str() {
304 "sqrt" | "abs" | "log" | "log10" | "ln" => {
305 self.advance(); let f = match func.as_str() {
307 "sqrt" => MathFn1::Sqrt,
308 "abs" => MathFn1::Abs,
309 "log" => MathFn1::Log,
310 "log10" => MathFn1::Log10,
311 "ln" => MathFn1::Ln,
312 _ => unreachable!(),
313 };
314 Expr::Fn1(f, Box::new(arg1))
315 }
316 "pow" | "max" | "min" => {
317 self.advance(); let arg2 = self.parse_expr();
319 self.advance(); let f = match func.as_str() {
321 "pow" => MathFn2::Pow,
322 "max" => MathFn2::Max,
323 "min" => MathFn2::Min,
324 _ => unreachable!(),
325 };
326 Expr::Fn2(f, Box::new(arg1), Box::new(arg2))
327 }
328 _ => Expr::Literal(0.0),
329 }
330 } else {
331 Expr::Literal(0.0)
332 }
333 } else if self.params.contains_key(&name) {
334 Expr::Param(name)
335 } else {
336 Expr::Literal(0.0) }
338 }
339 _ => Expr::Literal(0.0),
340 }
341 }
342}
343
344fn compile_script(source: &str, params: &HashMap<String, f64>) -> Expr {
345 let mut parser = Parser::new(source, params);
346 parser.parse_expr()
347}
348
349pub struct ScriptScoreQuery {
354 pub(crate) query: Box<dyn Query>,
355 pub script: String,
356 pub params: HashMap<String, f64>,
357}
358
359impl Query for ScriptScoreQuery {
360 fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
361 let inner = self.query.bind(searcher, score_mode)?;
362 let expr = compile_script(&self.script, &self.params);
363 Ok(Box::new(BoundScriptScoreQuery {
364 inner,
365 expr,
366 params: self.params.clone(),
367 }))
368 }
369}
370
371struct BoundScriptScoreQuery {
372 inner: Box<dyn BoundQuery>,
373 expr: Expr,
374 params: HashMap<String, f64>,
375}
376
377impl BoundQuery for BoundScriptScoreQuery {
378 fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
379 let inner = match self.inner.scorer_supplier(reader)? {
380 Some(s) => s,
381 None => return Ok(None),
382 };
383 Ok(Some(Box::new(ScriptScoreScorerSupplier {
384 inner,
385 expr: self.expr.clone(),
386 params: self.params.clone(),
387 })))
388 }
389}
390
391struct ScriptScoreScorerSupplier {
392 inner: Box<dyn ScorerSupplier>,
393 expr: Expr,
394 params: HashMap<String, f64>,
395}
396
397impl ScorerSupplier for ScriptScoreScorerSupplier {
398 fn cost(&self) -> u64 {
399 self.inner.cost()
400 }
401 fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
402 let inner = self.inner.scorer()?;
403 Ok(Box::new(ScriptScoreScorer {
404 inner,
405 expr: self.expr,
406 params: self.params,
407 }))
408 }
409}
410
411struct ScriptScoreScorer {
412 inner: Box<dyn Scorer>,
413 expr: Expr,
414 params: HashMap<String, f64>,
415}
416
417impl Scorer for ScriptScoreScorer {
418 fn doc_id(&self) -> DocId {
419 self.inner.doc_id()
420 }
421 fn next(&mut self) -> DocId {
422 self.inner.next()
423 }
424 fn advance(&mut self, target: DocId) -> DocId {
425 self.inner.advance(target)
426 }
427
428 fn score(&mut self) -> f32 {
429 let base = self.inner.score() as f64;
430 self.expr.eval(base, &self.params) as f32
431 }
432
433 fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
434 None
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::Token;
    use crate::core::{FieldId, SegmentId};
    use crate::mapping::{FieldType, Mapping};
    use crate::query::match_query::MatchQuery;
    use crate::segment::builder::SegmentBuilder;
    use crate::segment::reader::SegmentReader;

    /// Builds analysis tokens for `terms` at consecutive positions.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        terms
            .iter()
            .enumerate()
            .map(|(i, t)| Token::new(*t, 0, t.len(), i as u32))
            .collect()
    }

    #[test]
    fn expr_eval_basic() {
        let params = HashMap::from([("factor".to_string(), 3.0)]);
        let expr = compile_script("_score * factor", &params);
        assert_eq!(expr.eval(2.0, &params), 6.0);
    }

    #[test]
    fn expr_eval_math_functions() {
        let params = HashMap::new();
        let expr = compile_script("Math.sqrt(_score)", &params);
        assert!((expr.eval(4.0, &params) - 2.0).abs() < 0.001);

        // `Math.log` is ln(1 + x), so log(1) == ln(2).
        let expr2 = compile_script("Math.log(_score)", &params);
        assert!((expr2.eval(1.0, &params) - (2.0f64).ln()).abs() < 0.001);

        let expr3 = compile_script("Math.max(_score, 10.0)", &params);
        assert_eq!(expr3.eval(5.0, &params), 10.0);
    }

    #[test]
    fn expr_eval_complex() {
        let params = HashMap::from([("boost".to_string(), 1.5)]);
        let expr = compile_script("(_score + 1.0) * boost", &params);
        assert_eq!(expr.eval(2.0, &params), 4.5);
    }

    #[test]
    fn expr_eval_constant() {
        let params = HashMap::new();
        let expr = compile_script("42.0", &params);
        assert_eq!(expr.eval(999.0, &params), 42.0);
    }

    #[test]
    fn expr_eval_precedence_and_zero_guards() {
        let params = HashMap::new();
        assert_eq!(compile_script("1 + 2 * 3", &params).eval(0.0, &params), 7.0);
        assert_eq!(compile_script("-_score", &params).eval(2.0, &params), -2.0);
        // Zero divisors and unknown identifiers evaluate to 0.0.
        assert_eq!(compile_script("_score / 0", &params).eval(5.0, &params), 0.0);
        assert_eq!(compile_script("_score % 0", &params).eval(5.0, &params), 0.0);
        assert_eq!(compile_script("missing + 1", &params).eval(5.0, &params), 1.0);
    }

    #[test]
    fn script_score_query() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(
            &[(FieldId::new(0), make_tokens(&["hello", "world"]))],
            b"{}",
        );
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        let query = ScriptScoreQuery {
            query: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "hello".into(),
                analyzer: None,
            }),
            script: "_score * factor".to_string(),
            params: HashMap::from([("factor".to_string(), 3.0)]),
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!(results.hits[0].score > 0.0);
    }

    #[test]
    fn script_score_constant_42() {
        let schema = Mapping::builder().field("text", FieldType::Text).build();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);
        builder.add_document(&[(FieldId::new(0), make_tokens(&["hello"]))], b"{}");
        let reader = SegmentReader::open(builder.build()).unwrap();
        let store = crate::search::segment_store::SegmentStore::new(
            vec![reader],
            crate::analysis::AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);

        let query = ScriptScoreQuery {
            query: Box::new(MatchQuery {
                field: "text".into(),
                query_text: "hello".into(),
                analyzer: None,
            }),
            script: "42.0".to_string(),
            params: HashMap::new(),
        };

        let results = searcher.search_query(&query, 10, 0).unwrap();
        assert_eq!(results.total_hits.value, 1);
        assert!((results.hits[0].score - 42.0).abs() < 0.01);
    }
}