// Source file: qdrant_edge/shard/query/formula.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use crate::common::types::ScoreType;
5use itertools::Itertools;
6use crate::segment::common::operation_error::{OperationError, OperationResult};
7use crate::segment::index::query_optimization::rescore_formula::parsed_formula::*;
8use crate::segment::json_path::JsonPath;
9use crate::segment::types::{Condition, GeoPoint};
10use serde::Serialize;
11use serde_json::Value;
12
13#[derive(Clone, Debug, PartialEq, Serialize)]
14pub struct FormulaInternal {
15    pub formula: ExpressionInternal,
16    pub defaults: HashMap<String, Value>,
17}
18
19impl TryFrom<FormulaInternal> for ParsedFormula {
20    type Error = OperationError;
21
22    fn try_from(value: FormulaInternal) -> Result<Self, Self::Error> {
23        let FormulaInternal { formula, defaults } = value;
24
25        let mut payload_vars = HashSet::new();
26        let mut conditions = Vec::new();
27
28        let parsed_expression = formula.parse_and_convert(&mut payload_vars, &mut conditions)?;
29
30        let defaults = defaults
31            .into_iter()
32            .map(|(key, value)| {
33                let key = key
34                    .as_str()
35                    .parse()
36                    .map_err(|msg| failed_to_parse("variable ID", &key, &msg))?;
37                OperationResult::Ok((key, value))
38            })
39            .try_collect()?;
40
41        Ok(ParsedFormula {
42            formula: parsed_expression,
43            payload_vars,
44            conditions,
45            defaults,
46        })
47    }
48}
49
50#[derive(Clone, Debug, PartialEq, Serialize)]
51pub enum ExpressionInternal {
52    Constant(f32),
53    Variable(String),
54    Condition(Box<Condition>),
55    GeoDistance {
56        origin: GeoPoint,
57        to: JsonPath,
58    },
59    Datetime(String),
60    DatetimeKey(JsonPath),
61    Mult(Vec<ExpressionInternal>),
62    Sum(Vec<ExpressionInternal>),
63    Neg(Box<ExpressionInternal>),
64    Div {
65        left: Box<ExpressionInternal>,
66        right: Box<ExpressionInternal>,
67        by_zero_default: Option<ScoreType>,
68    },
69    Sqrt(Box<ExpressionInternal>),
70    Pow {
71        base: Box<ExpressionInternal>,
72        exponent: Box<ExpressionInternal>,
73    },
74    Exp(Box<ExpressionInternal>),
75    Log10(Box<ExpressionInternal>),
76    Ln(Box<ExpressionInternal>),
77    Abs(Box<ExpressionInternal>),
78    Decay {
79        kind: DecayKind,
80        x: Box<ExpressionInternal>,
81        target: Option<Box<ExpressionInternal>>,
82        midpoint: Option<f32>,
83        scale: Option<f32>,
84    },
85}
86
87impl ExpressionInternal {
88    fn parse_and_convert(
89        self,
90        payload_vars: &mut HashSet<JsonPath>,
91        conditions: &mut Vec<Condition>,
92    ) -> OperationResult<ParsedExpression> {
93        let expr = match self {
94            ExpressionInternal::Constant(c) => {
95                ParsedExpression::Constant(PreciseScoreOrdered::from(PreciseScore::from(c)))
96            }
97            ExpressionInternal::Variable(var) => {
98                let var: VariableId = var
99                    .parse()
100                    .map_err(|msg| failed_to_parse("variable ID", &var, &msg))?;
101                if let VariableId::Payload(payload_var) = var.clone() {
102                    payload_vars.insert(payload_var);
103                }
104                ParsedExpression::Variable(var)
105            }
106            ExpressionInternal::Condition(condition) => {
107                let condition_id = conditions.len();
108                conditions.push(*condition);
109                ParsedExpression::new_condition_id(condition_id)
110            }
111            ExpressionInternal::GeoDistance { origin, to } => {
112                payload_vars.insert(to.clone());
113                ParsedExpression::new_geo_distance(origin, to)
114            }
115            ExpressionInternal::Datetime(dt_str) => {
116                ParsedExpression::Datetime(DatetimeExpression::Constant(
117                    dt_str
118                        .parse()
119                        .map_err(|err| failed_to_parse("date-time", &dt_str, err))?,
120                ))
121            }
122            ExpressionInternal::DatetimeKey(json_path) => {
123                payload_vars.insert(json_path.clone());
124                ParsedExpression::Datetime(DatetimeExpression::PayloadVariable(json_path))
125            }
126            ExpressionInternal::Mult(internal_expressions) => ParsedExpression::Mult(
127                internal_expressions
128                    .into_iter()
129                    .map(|expr| expr.parse_and_convert(payload_vars, conditions))
130                    .try_collect()?,
131            ),
132            ExpressionInternal::Sum(expression_internals) => ParsedExpression::Sum(
133                expression_internals
134                    .into_iter()
135                    .map(|expr| expr.parse_and_convert(payload_vars, conditions))
136                    .try_collect()?,
137            ),
138            ExpressionInternal::Neg(expression_internal) => ParsedExpression::new_neg(
139                expression_internal.parse_and_convert(payload_vars, conditions)?,
140            ),
141            ExpressionInternal::Div {
142                left,
143                right,
144                by_zero_default,
145            } => ParsedExpression::new_div(
146                left.parse_and_convert(payload_vars, conditions)?,
147                right.parse_and_convert(payload_vars, conditions)?,
148                by_zero_default.map(PreciseScore::from),
149            ),
150            ExpressionInternal::Sqrt(expression_internal) => ParsedExpression::Sqrt(Box::new(
151                expression_internal.parse_and_convert(payload_vars, conditions)?,
152            )),
153            ExpressionInternal::Pow { base, exponent } => ParsedExpression::Pow {
154                base: Box::new(base.parse_and_convert(payload_vars, conditions)?),
155                exponent: Box::new(exponent.parse_and_convert(payload_vars, conditions)?),
156            },
157            ExpressionInternal::Exp(expression_internal) => ParsedExpression::Exp(Box::new(
158                expression_internal.parse_and_convert(payload_vars, conditions)?,
159            )),
160            ExpressionInternal::Log10(expression_internal) => ParsedExpression::Log10(Box::new(
161                expression_internal.parse_and_convert(payload_vars, conditions)?,
162            )),
163            ExpressionInternal::Ln(expression_internal) => ParsedExpression::Ln(Box::new(
164                expression_internal.parse_and_convert(payload_vars, conditions)?,
165            )),
166            ExpressionInternal::Abs(expression_internal) => ParsedExpression::Abs(Box::new(
167                expression_internal.parse_and_convert(payload_vars, conditions)?,
168            )),
169            ExpressionInternal::Decay {
170                kind,
171                x,
172                target,
173                midpoint,
174                scale,
175            } => {
176                let lambda = ParsedExpression::decay_params_to_lambda(midpoint, scale, kind)?;
177
178                let x = x.parse_and_convert(payload_vars, conditions)?;
179
180                let target = target
181                    .map(|t| t.parse_and_convert(payload_vars, conditions))
182                    .transpose()?
183                    .map(Box::new);
184
185                ParsedExpression::Decay {
186                    kind,
187                    x: Box::new(x),
188                    target,
189                    lambda: PreciseScoreOrdered::from(lambda),
190                }
191            }
192        };
193
194        Ok(expr)
195    }
196}
197
198fn failed_to_parse(what: &str, value: &str, message: impl fmt::Display) -> OperationError {
199    OperationError::validation_error(format!("failed to parse {what} {value}: {message}"))
200}