1use contextdb_core::{Error, Result, Value};
2use std::collections::BTreeSet;
3
4#[derive(Debug, Clone)]
5pub struct RankFormula {
6 root: FormulaNode,
7 refs: Vec<String>,
8}
9
10impl RankFormula {
11 pub fn compile(formula: &str) -> Result<Self> {
12 Self::compile_for_index("", formula)
13 }
14
15 pub fn compile_for_index(index: &str, formula: &str) -> Result<Self> {
16 let mut parser = FormulaParser::new(index, formula);
17 let root = parser.parse_expr()?;
18 parser.skip_ws();
19 if !parser.is_eof() {
20 return Err(parser.unexpected_at_current());
21 }
22 let mut refs = BTreeSet::new();
23 collect_refs(&root, &mut refs);
24 Ok(Self {
25 root,
26 refs: refs.into_iter().collect(),
27 })
28 }
29
30 pub fn const_one() -> Self {
31 Self {
32 root: FormulaNode::Literal(1.0),
33 refs: Vec::new(),
34 }
35 }
36
37 pub fn column_refs(&self) -> &[String] {
38 &self.refs
39 }
40
41 pub fn eval_with_resolver(
42 &self,
43 vector_score: f32,
44 mut resolver: impl FnMut(&str) -> std::result::Result<Option<f32>, FormulaEvalError>,
45 ) -> std::result::Result<Option<f32>, FormulaEvalError> {
46 eval_node(&self.root, vector_score, &mut resolver)
47 }
48
49 pub fn eval(
50 &self,
51 anchor: &std::collections::HashMap<String, Value>,
52 joined: Option<&std::collections::HashMap<String, Value>>,
53 vector_score: f32,
54 ) -> std::result::Result<Option<f32>, FormulaEvalError> {
55 self.eval_with_resolver(vector_score, |column| {
56 let value = anchor
57 .get(column)
58 .or_else(|| joined.and_then(|row| row.get(column)))
59 .unwrap_or(&Value::Null);
60 value_to_rank_number(value, column)
61 })
62 }
63}
64
/// Node of a parsed rank-formula expression tree.
#[derive(Debug, Clone)]
enum FormulaNode {
    /// Numeric literal.
    Literal(f32),
    /// Column reference written as `{name}`. During evaluation the name
    /// `vector_score` yields the supplied vector score; every other name is
    /// passed to the resolver.
    ColRef(String),
    /// `coalesce(expr, fallback)`: the fallback literal replaces a NULL
    /// (`None`) result of `expr`.
    Coalesce(Box<FormulaNode>, f32),
    /// Product of two sub-expressions; NULL when either operand is NULL.
    Mul(Box<FormulaNode>, Box<FormulaNode>),
    /// Sum of two sub-expressions; NULL when either operand is NULL.
    Add(Box<FormulaNode>, Box<FormulaNode>),
}
73
/// Error raised while evaluating a rank formula against row data.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// printed with `{}` and boxed as `dyn Error`; [`FormulaEvalError::reason`]
/// returns the same text as `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FormulaEvalError {
    /// A referenced column holds a value that cannot be used as a number.
    UnsupportedType {
        /// Name of the offending column.
        column: String,
        /// Name of the runtime type that was found (for example `TEXT`).
        actual: &'static str,
    },
    /// A joined column could not be decoded.
    CorruptJoinedColumn {
        /// Name of the offending column.
        column: String,
    },
}

impl FormulaEvalError {
    /// Returns the human-readable description of the error.
    pub fn reason(&self) -> String {
        self.to_string()
    }
}

impl std::fmt::Display for FormulaEvalError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FormulaEvalError::UnsupportedType { column, actual } => write!(
                f,
                "rank formula column `{column}` has unsupported runtime type {actual}"
            ),
            FormulaEvalError::CorruptJoinedColumn { column } => {
                write!(f, "failed to decode joined column `{column}`")
            }
        }
    }
}

impl std::error::Error for FormulaEvalError {}
97
98fn eval_node(
99 node: &FormulaNode,
100 vector_score: f32,
101 resolver: &mut impl FnMut(&str) -> std::result::Result<Option<f32>, FormulaEvalError>,
102) -> std::result::Result<Option<f32>, FormulaEvalError> {
103 match node {
104 FormulaNode::Literal(value) => Ok(Some(*value)),
105 FormulaNode::ColRef(column) if column == "vector_score" => Ok(Some(vector_score)),
106 FormulaNode::ColRef(column) => resolver(column),
107 FormulaNode::Coalesce(expr, fallback) => {
108 Ok(eval_node(expr, vector_score, resolver)?.or(Some(*fallback)))
109 }
110 FormulaNode::Mul(left, right) => {
111 match (
112 eval_node(left, vector_score, resolver)?,
113 eval_node(right, vector_score, resolver)?,
114 ) {
115 (Some(left), Some(right)) => Ok(Some(left * right)),
116 _ => Ok(None),
117 }
118 }
119 FormulaNode::Add(left, right) => {
120 match (
121 eval_node(left, vector_score, resolver)?,
122 eval_node(right, vector_score, resolver)?,
123 ) {
124 (Some(left), Some(right)) => Ok(Some(left + right)),
125 _ => Ok(None),
126 }
127 }
128 }
129}
130
131fn value_to_rank_number(
132 value: &Value,
133 column: &str,
134) -> std::result::Result<Option<f32>, FormulaEvalError> {
135 match value {
136 Value::Null => Ok(None),
137 Value::Float64(value) => Ok(Some(*value as f32)),
138 Value::Int64(value) => Ok(Some(*value as f32)),
139 Value::Bool(value) => Ok(Some(if *value { 1.0 } else { 0.0 })),
140 Value::Text(_) => Err(FormulaEvalError::UnsupportedType {
141 column: column.to_string(),
142 actual: "TEXT",
143 }),
144 Value::Json(_) => Err(FormulaEvalError::UnsupportedType {
145 column: column.to_string(),
146 actual: "JSON",
147 }),
148 Value::Uuid(_) => Err(FormulaEvalError::UnsupportedType {
149 column: column.to_string(),
150 actual: "UUID",
151 }),
152 Value::Vector(_) => Err(FormulaEvalError::UnsupportedType {
153 column: column.to_string(),
154 actual: "VECTOR",
155 }),
156 Value::Timestamp(_) => Err(FormulaEvalError::UnsupportedType {
157 column: column.to_string(),
158 actual: "TIMESTAMP",
159 }),
160 Value::TxId(_) => Err(FormulaEvalError::UnsupportedType {
161 column: column.to_string(),
162 actual: "TXID",
163 }),
164 }
165}
166
167fn collect_refs(node: &FormulaNode, refs: &mut BTreeSet<String>) {
168 match node {
169 FormulaNode::Literal(_) => {}
170 FormulaNode::ColRef(column) => {
171 refs.insert(column.clone());
172 }
173 FormulaNode::Coalesce(expr, _) => collect_refs(expr, refs),
174 FormulaNode::Mul(left, right) | FormulaNode::Add(left, right) => {
175 collect_refs(left, refs);
176 collect_refs(right, refs);
177 }
178 }
179}
180
181struct FormulaParser<'a> {
182 index: &'a str,
183 input: &'a str,
184 pos: usize,
185}
186
187impl<'a> FormulaParser<'a> {
188 fn new(index: &'a str, input: &'a str) -> Self {
189 Self {
190 index,
191 input,
192 pos: 0,
193 }
194 }
195
196 fn parse_expr(&mut self) -> Result<FormulaNode> {
197 self.parse_add()
198 }
199
200 fn parse_add(&mut self) -> Result<FormulaNode> {
201 let mut node = self.parse_mul()?;
202 loop {
203 self.skip_ws();
204 if !self.consume('+') {
205 break;
206 }
207 let right = self.parse_mul()?;
208 node = FormulaNode::Add(Box::new(node), Box::new(right));
209 }
210 Ok(node)
211 }
212
213 fn parse_mul(&mut self) -> Result<FormulaNode> {
214 let mut node = self.parse_primary()?;
215 loop {
216 self.skip_ws();
217 if !self.consume('*') {
218 break;
219 }
220 let right = self.parse_primary()?;
221 node = FormulaNode::Mul(Box::new(node), Box::new(right));
222 }
223 Ok(node)
224 }
225
226 fn parse_primary(&mut self) -> Result<FormulaNode> {
227 self.skip_ws();
228 if self.is_eof() {
229 return Err(self.error(self.position(), "expected expression"));
230 }
231 if self.consume('(') {
232 let expr = self.parse_expr()?;
233 self.skip_ws();
234 if !self.consume(')') {
235 return Err(self.error(self.position(), "expected ')'"));
236 }
237 return Ok(expr);
238 }
239 if self.peek() == Some('{') {
240 return self.parse_col_ref();
241 }
242 if self.starts_ident("coalesce") {
243 return self.parse_coalesce();
244 }
245 if self
246 .peek()
247 .is_some_and(|ch| ch.is_ascii_digit() || ch == '.')
248 {
249 return self.parse_number().map(FormulaNode::Literal);
250 }
251 if self.starts_ident("CASE") {
252 return Err(self.error(self.position(), "CASE expressions are not supported"));
253 }
254 if self.starts_ident("SELECT") {
255 return Err(self.error(self.position(), "SELECT subqueries are not supported"));
256 }
257 if self
258 .peek()
259 .is_some_and(|ch| ch.is_ascii_alphabetic() || ch == '_')
260 {
261 let start = self.pos;
262 let ident = self.read_identifier();
263 self.skip_ws();
264 if self.peek() == Some('(') {
265 return Err(self.error_at(start, "function calls are not supported"));
266 }
267 return Err(self.error_at(start, &format!("unsupported token `{ident}`")));
268 }
269 Err(self.unexpected_at_current())
270 }
271
272 fn parse_coalesce(&mut self) -> Result<FormulaNode> {
273 let start = self.pos;
274 self.pos += "coalesce".len();
275 self.skip_ws();
276 if !self.consume('(') {
277 return Err(self.error_at(start, "coalesce requires '('"));
278 }
279 self.skip_ws();
280 if self.is_eof() {
281 return Err(self.error(self.input.len() + 2, "coalesce requires expression"));
282 }
283 let expr = self.parse_expr()?;
284 self.skip_ws();
285 if !self.consume(',') {
286 return Err(self.error(self.input.len() + 2, "coalesce requires fallback literal"));
287 }
288 let fallback = self.parse_number()?;
289 self.skip_ws();
290 if !self.consume(')') {
291 return Err(self.error(self.input.len() + 2, "coalesce requires closing ')'"));
292 }
293 Ok(FormulaNode::Coalesce(Box::new(expr), fallback))
294 }
295
296 fn parse_col_ref(&mut self) -> Result<FormulaNode> {
297 let start = self.pos;
298 self.pos += 1;
299 let body_start = self.pos;
300 while let Some(ch) = self.peek() {
301 if ch == '}' {
302 let name = &self.input[body_start..self.pos];
303 self.pos += 1;
304 if let Some(offset) = name.find('.') {
305 return Err(self.error_at(
306 body_start + offset,
307 "table-qualified column references are not supported",
308 ));
309 }
310 if name.is_empty()
311 || !name
312 .chars()
313 .all(|ch| ch.is_ascii_alphanumeric() || ch == '_')
314 {
315 return Err(self.error_at(start, "invalid column reference"));
316 }
317 return Ok(FormulaNode::ColRef(name.to_string()));
318 }
319 self.pos += ch.len_utf8();
320 }
321 Err(self.error_at(start, "unterminated column reference"))
322 }
323
324 fn parse_number(&mut self) -> Result<f32> {
325 self.skip_ws();
326 let start = self.pos;
327 let mut seen_digit = false;
328 while let Some(ch) = self.peek() {
329 if ch.is_ascii_digit() {
330 seen_digit = true;
331 self.pos += 1;
332 } else if ch == '.' {
333 self.pos += 1;
334 } else {
335 break;
336 }
337 }
338 if !seen_digit {
339 return Err(self.error_at(start, "expected number literal"));
340 }
341 self.input[start..self.pos]
342 .parse::<f32>()
343 .map_err(|err| self.error_at(start, &format!("invalid number literal: {err}")))
344 }
345
346 fn unexpected_at_current(&self) -> Error {
347 match self.peek() {
348 Some('/') => self.error(self.position(), "unsupported operator `/`"),
349 Some('-') => self.error(self.position(), "unsupported operator `-`"),
350 Some(ch) => self.error(self.position(), &format!("unexpected token `{ch}`")),
351 None => self.error(self.position(), "unexpected end of formula"),
352 }
353 }
354
355 fn skip_ws(&mut self) {
356 while let Some(ch) = self.peek() {
357 if !ch.is_whitespace() {
358 break;
359 }
360 self.pos += ch.len_utf8();
361 }
362 }
363
364 fn consume(&mut self, expected: char) -> bool {
365 if self.peek() == Some(expected) {
366 self.pos += expected.len_utf8();
367 true
368 } else {
369 false
370 }
371 }
372
373 fn starts_ident(&self, ident: &str) -> bool {
374 self.input[self.pos..]
375 .get(..ident.len())
376 .is_some_and(|s| s.eq_ignore_ascii_case(ident))
377 && self.input[self.pos + ident.len()..]
378 .chars()
379 .next()
380 .is_none_or(|ch| !ch.is_ascii_alphanumeric() && ch != '_')
381 }
382
383 fn read_identifier(&mut self) -> &'a str {
384 let start = self.pos;
385 while let Some(ch) = self.peek() {
386 if ch.is_ascii_alphanumeric() || ch == '_' {
387 self.pos += ch.len_utf8();
388 } else {
389 break;
390 }
391 }
392 &self.input[start..self.pos]
393 }
394
395 fn peek(&self) -> Option<char> {
396 self.input[self.pos..].chars().next()
397 }
398
399 fn is_eof(&self) -> bool {
400 self.pos >= self.input.len()
401 }
402
403 fn position(&self) -> usize {
404 self.pos + 1
405 }
406
407 fn error_at(&self, zero_based: usize, reason: &str) -> Error {
408 self.error(zero_based + 1, reason)
409 }
410
411 fn error(&self, position: usize, reason: &str) -> Error {
412 Error::RankPolicyFormulaParse {
413 index: self.index.to_string(),
414 position,
415 reason: reason.to_string(),
416 }
417 }
418}