Skip to main content

contextdb_engine/
rank_formula.rs

1use contextdb_core::{Error, Result, Value};
2use std::collections::BTreeSet;
3
4#[derive(Debug, Clone)]
5pub struct RankFormula {
6    root: FormulaNode,
7    refs: Vec<String>,
8}
9
10impl RankFormula {
11    pub fn compile(formula: &str) -> Result<Self> {
12        Self::compile_for_index("", formula)
13    }
14
15    pub fn compile_for_index(index: &str, formula: &str) -> Result<Self> {
16        let mut parser = FormulaParser::new(index, formula);
17        let root = parser.parse_expr()?;
18        parser.skip_ws();
19        if !parser.is_eof() {
20            return Err(parser.unexpected_at_current());
21        }
22        let mut refs = BTreeSet::new();
23        collect_refs(&root, &mut refs);
24        Ok(Self {
25            root,
26            refs: refs.into_iter().collect(),
27        })
28    }
29
30    pub fn const_one() -> Self {
31        Self {
32            root: FormulaNode::Literal(1.0),
33            refs: Vec::new(),
34        }
35    }
36
37    pub fn column_refs(&self) -> &[String] {
38        &self.refs
39    }
40
41    pub fn eval_with_resolver(
42        &self,
43        vector_score: f32,
44        mut resolver: impl FnMut(&str) -> std::result::Result<Option<f32>, FormulaEvalError>,
45    ) -> std::result::Result<Option<f32>, FormulaEvalError> {
46        eval_node(&self.root, vector_score, &mut resolver)
47    }
48
49    pub fn eval(
50        &self,
51        anchor: &std::collections::HashMap<String, Value>,
52        joined: Option<&std::collections::HashMap<String, Value>>,
53        vector_score: f32,
54    ) -> std::result::Result<Option<f32>, FormulaEvalError> {
55        self.eval_with_resolver(vector_score, |column| {
56            let value = anchor
57                .get(column)
58                .or_else(|| joined.and_then(|row| row.get(column)))
59                .unwrap_or(&Value::Null);
60            value_to_rank_number(value, column)
61        })
62    }
63}
64
/// Node of a parsed rank-formula expression tree.
#[derive(Clone, Debug)]
enum FormulaNode {
    /// Numeric constant.
    Literal(f32),
    /// Column reference written as `{name}` in the formula text.
    ColRef(String),
    /// `coalesce(expr, fallback)`: the inner expression plus the literal used
    /// when that expression has no value.
    Coalesce(Box<FormulaNode>, f32),
    /// Product of two sub-expressions.
    Mul(Box<FormulaNode>, Box<FormulaNode>),
    /// Sum of two sub-expressions.
    Add(Box<FormulaNode>, Box<FormulaNode>),
}
73
/// Failure raised while evaluating a rank formula against row data.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// formatted with `{}` and boxed as `dyn Error`; [`FormulaEvalError::reason`]
/// returns the same text as `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormulaEvalError {
    /// A referenced column held a value that cannot be used as a number.
    UnsupportedType {
        /// Name of the offending column.
        column: String,
        /// Label of the runtime type that was encountered (e.g. `TEXT`).
        actual: &'static str,
    },
    /// A joined column could not be decoded.
    CorruptJoinedColumn {
        /// Name of the column that failed to decode.
        column: String,
    },
}

impl FormulaEvalError {
    /// Human-readable description of the failure (identical to the `Display`
    /// output).
    pub fn reason(&self) -> String {
        self.to_string()
    }
}

// Fully-qualified paths are used because the file imports a project-level
// `Result`/`Error` that would otherwise shadow the std names.
impl std::fmt::Display for FormulaEvalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FormulaEvalError::UnsupportedType { column, actual } => write!(
                f,
                "rank formula column `{column}` has unsupported runtime type {actual}"
            ),
            FormulaEvalError::CorruptJoinedColumn { column } => {
                write!(f, "failed to decode joined column `{column}`")
            }
        }
    }
}

impl std::error::Error for FormulaEvalError {}
97
98fn eval_node(
99    node: &FormulaNode,
100    vector_score: f32,
101    resolver: &mut impl FnMut(&str) -> std::result::Result<Option<f32>, FormulaEvalError>,
102) -> std::result::Result<Option<f32>, FormulaEvalError> {
103    match node {
104        FormulaNode::Literal(value) => Ok(Some(*value)),
105        FormulaNode::ColRef(column) if column == "vector_score" => Ok(Some(vector_score)),
106        FormulaNode::ColRef(column) => resolver(column),
107        FormulaNode::Coalesce(expr, fallback) => {
108            Ok(eval_node(expr, vector_score, resolver)?.or(Some(*fallback)))
109        }
110        FormulaNode::Mul(left, right) => {
111            match (
112                eval_node(left, vector_score, resolver)?,
113                eval_node(right, vector_score, resolver)?,
114            ) {
115                (Some(left), Some(right)) => Ok(Some(left * right)),
116                _ => Ok(None),
117            }
118        }
119        FormulaNode::Add(left, right) => {
120            match (
121                eval_node(left, vector_score, resolver)?,
122                eval_node(right, vector_score, resolver)?,
123            ) {
124                (Some(left), Some(right)) => Ok(Some(left + right)),
125                _ => Ok(None),
126            }
127        }
128    }
129}
130
131fn value_to_rank_number(
132    value: &Value,
133    column: &str,
134) -> std::result::Result<Option<f32>, FormulaEvalError> {
135    match value {
136        Value::Null => Ok(None),
137        Value::Float64(value) => Ok(Some(*value as f32)),
138        Value::Int64(value) => Ok(Some(*value as f32)),
139        Value::Bool(value) => Ok(Some(if *value { 1.0 } else { 0.0 })),
140        Value::Text(_) => Err(FormulaEvalError::UnsupportedType {
141            column: column.to_string(),
142            actual: "TEXT",
143        }),
144        Value::Json(_) => Err(FormulaEvalError::UnsupportedType {
145            column: column.to_string(),
146            actual: "JSON",
147        }),
148        Value::Uuid(_) => Err(FormulaEvalError::UnsupportedType {
149            column: column.to_string(),
150            actual: "UUID",
151        }),
152        Value::Vector(_) => Err(FormulaEvalError::UnsupportedType {
153            column: column.to_string(),
154            actual: "VECTOR",
155        }),
156        Value::Timestamp(_) => Err(FormulaEvalError::UnsupportedType {
157            column: column.to_string(),
158            actual: "TIMESTAMP",
159        }),
160        Value::TxId(_) => Err(FormulaEvalError::UnsupportedType {
161            column: column.to_string(),
162            actual: "TXID",
163        }),
164    }
165}
166
167fn collect_refs(node: &FormulaNode, refs: &mut BTreeSet<String>) {
168    match node {
169        FormulaNode::Literal(_) => {}
170        FormulaNode::ColRef(column) => {
171            refs.insert(column.clone());
172        }
173        FormulaNode::Coalesce(expr, _) => collect_refs(expr, refs),
174        FormulaNode::Mul(left, right) | FormulaNode::Add(left, right) => {
175            collect_refs(left, refs);
176            collect_refs(right, refs);
177        }
178    }
179}
180
181struct FormulaParser<'a> {
182    index: &'a str,
183    input: &'a str,
184    pos: usize,
185}
186
187impl<'a> FormulaParser<'a> {
188    fn new(index: &'a str, input: &'a str) -> Self {
189        Self {
190            index,
191            input,
192            pos: 0,
193        }
194    }
195
196    fn parse_expr(&mut self) -> Result<FormulaNode> {
197        self.parse_add()
198    }
199
200    fn parse_add(&mut self) -> Result<FormulaNode> {
201        let mut node = self.parse_mul()?;
202        loop {
203            self.skip_ws();
204            if !self.consume('+') {
205                break;
206            }
207            let right = self.parse_mul()?;
208            node = FormulaNode::Add(Box::new(node), Box::new(right));
209        }
210        Ok(node)
211    }
212
213    fn parse_mul(&mut self) -> Result<FormulaNode> {
214        let mut node = self.parse_primary()?;
215        loop {
216            self.skip_ws();
217            if !self.consume('*') {
218                break;
219            }
220            let right = self.parse_primary()?;
221            node = FormulaNode::Mul(Box::new(node), Box::new(right));
222        }
223        Ok(node)
224    }
225
226    fn parse_primary(&mut self) -> Result<FormulaNode> {
227        self.skip_ws();
228        if self.is_eof() {
229            return Err(self.error(self.position(), "expected expression"));
230        }
231        if self.consume('(') {
232            let expr = self.parse_expr()?;
233            self.skip_ws();
234            if !self.consume(')') {
235                return Err(self.error(self.position(), "expected ')'"));
236            }
237            return Ok(expr);
238        }
239        if self.peek() == Some('{') {
240            return self.parse_col_ref();
241        }
242        if self.starts_ident("coalesce") {
243            return self.parse_coalesce();
244        }
245        if self
246            .peek()
247            .is_some_and(|ch| ch.is_ascii_digit() || ch == '.')
248        {
249            return self.parse_number().map(FormulaNode::Literal);
250        }
251        if self.starts_ident("CASE") {
252            return Err(self.error(self.position(), "CASE expressions are not supported"));
253        }
254        if self.starts_ident("SELECT") {
255            return Err(self.error(self.position(), "SELECT subqueries are not supported"));
256        }
257        if self
258            .peek()
259            .is_some_and(|ch| ch.is_ascii_alphabetic() || ch == '_')
260        {
261            let start = self.pos;
262            let ident = self.read_identifier();
263            self.skip_ws();
264            if self.peek() == Some('(') {
265                return Err(self.error_at(start, "function calls are not supported"));
266            }
267            return Err(self.error_at(start, &format!("unsupported token `{ident}`")));
268        }
269        Err(self.unexpected_at_current())
270    }
271
272    fn parse_coalesce(&mut self) -> Result<FormulaNode> {
273        let start = self.pos;
274        self.pos += "coalesce".len();
275        self.skip_ws();
276        if !self.consume('(') {
277            return Err(self.error_at(start, "coalesce requires '('"));
278        }
279        self.skip_ws();
280        if self.is_eof() {
281            return Err(self.error(self.input.len() + 2, "coalesce requires expression"));
282        }
283        let expr = self.parse_expr()?;
284        self.skip_ws();
285        if !self.consume(',') {
286            return Err(self.error(self.input.len() + 2, "coalesce requires fallback literal"));
287        }
288        let fallback = self.parse_number()?;
289        self.skip_ws();
290        if !self.consume(')') {
291            return Err(self.error(self.input.len() + 2, "coalesce requires closing ')'"));
292        }
293        Ok(FormulaNode::Coalesce(Box::new(expr), fallback))
294    }
295
296    fn parse_col_ref(&mut self) -> Result<FormulaNode> {
297        let start = self.pos;
298        self.pos += 1;
299        let body_start = self.pos;
300        while let Some(ch) = self.peek() {
301            if ch == '}' {
302                let name = &self.input[body_start..self.pos];
303                self.pos += 1;
304                if let Some(offset) = name.find('.') {
305                    return Err(self.error_at(
306                        body_start + offset,
307                        "table-qualified column references are not supported",
308                    ));
309                }
310                if name.is_empty()
311                    || !name
312                        .chars()
313                        .all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
314                {
315                    return Err(self.error_at(start, "invalid column reference"));
316                }
317                return Ok(FormulaNode::ColRef(name.to_string()));
318            }
319            self.pos += ch.len_utf8();
320        }
321        Err(self.error_at(start, "unterminated column reference"))
322    }
323
324    fn parse_number(&mut self) -> Result<f32> {
325        self.skip_ws();
326        let start = self.pos;
327        let mut seen_digit = false;
328        while let Some(ch) = self.peek() {
329            if ch.is_ascii_digit() {
330                seen_digit = true;
331                self.pos += 1;
332            } else if ch == '.' {
333                self.pos += 1;
334            } else {
335                break;
336            }
337        }
338        if !seen_digit {
339            return Err(self.error_at(start, "expected number literal"));
340        }
341        self.input[start..self.pos]
342            .parse::<f32>()
343            .map_err(|err| self.error_at(start, &format!("invalid number literal: {err}")))
344    }
345
346    fn unexpected_at_current(&self) -> Error {
347        match self.peek() {
348            Some('/') => self.error(self.position(), "unsupported operator `/`"),
349            Some('-') => self.error(self.position(), "unsupported operator `-`"),
350            Some(ch) => self.error(self.position(), &format!("unexpected token `{ch}`")),
351            None => self.error(self.position(), "unexpected end of formula"),
352        }
353    }
354
355    fn skip_ws(&mut self) {
356        while let Some(ch) = self.peek() {
357            if !ch.is_whitespace() {
358                break;
359            }
360            self.pos += ch.len_utf8();
361        }
362    }
363
364    fn consume(&mut self, expected: char) -> bool {
365        if self.peek() == Some(expected) {
366            self.pos += expected.len_utf8();
367            true
368        } else {
369            false
370        }
371    }
372
373    fn starts_ident(&self, ident: &str) -> bool {
374        self.input[self.pos..]
375            .get(..ident.len())
376            .is_some_and(|s| s.eq_ignore_ascii_case(ident))
377            && self.input[self.pos + ident.len()..]
378                .chars()
379                .next()
380                .is_none_or(|ch| !ch.is_ascii_alphanumeric() && ch != '_')
381    }
382
383    fn read_identifier(&mut self) -> &'a str {
384        let start = self.pos;
385        while let Some(ch) = self.peek() {
386            if ch.is_ascii_alphanumeric() || ch == '_' {
387                self.pos += ch.len_utf8();
388            } else {
389                break;
390            }
391        }
392        &self.input[start..self.pos]
393    }
394
395    fn peek(&self) -> Option<char> {
396        self.input[self.pos..].chars().next()
397    }
398
399    fn is_eof(&self) -> bool {
400        self.pos >= self.input.len()
401    }
402
403    fn position(&self) -> usize {
404        self.pos + 1
405    }
406
407    fn error_at(&self, zero_based: usize, reason: &str) -> Error {
408        self.error(zero_based + 1, reason)
409    }
410
411    fn error(&self, position: usize, reason: &str) -> Error {
412        Error::RankPolicyFormulaParse {
413            index: self.index.to_string(),
414            position,
415            reason: reason.to_string(),
416        }
417    }
418}