1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use crate::common::types::ScoreType;
5use itertools::Itertools;
6use crate::segment::common::operation_error::{OperationError, OperationResult};
7use crate::segment::index::query_optimization::rescore_formula::parsed_formula::*;
8use crate::segment::json_path::JsonPath;
9use crate::segment::types::{Condition, GeoPoint};
10use serde::Serialize;
11use serde_json::Value;
12
13#[derive(Clone, Debug, PartialEq, Serialize)]
14pub struct FormulaInternal {
15 pub formula: ExpressionInternal,
16 pub defaults: HashMap<String, Value>,
17}
18
19impl TryFrom<FormulaInternal> for ParsedFormula {
20 type Error = OperationError;
21
22 fn try_from(value: FormulaInternal) -> Result<Self, Self::Error> {
23 let FormulaInternal { formula, defaults } = value;
24
25 let mut payload_vars = HashSet::new();
26 let mut conditions = Vec::new();
27
28 let parsed_expression = formula.parse_and_convert(&mut payload_vars, &mut conditions)?;
29
30 let defaults = defaults
31 .into_iter()
32 .map(|(key, value)| {
33 let key = key
34 .as_str()
35 .parse()
36 .map_err(|msg| failed_to_parse("variable ID", &key, &msg))?;
37 OperationResult::Ok((key, value))
38 })
39 .try_collect()?;
40
41 Ok(ParsedFormula {
42 formula: parsed_expression,
43 payload_vars,
44 conditions,
45 defaults,
46 })
47 }
48}
49
50#[derive(Clone, Debug, PartialEq, Serialize)]
51pub enum ExpressionInternal {
52 Constant(f32),
53 Variable(String),
54 Condition(Box<Condition>),
55 GeoDistance {
56 origin: GeoPoint,
57 to: JsonPath,
58 },
59 Datetime(String),
60 DatetimeKey(JsonPath),
61 Mult(Vec<ExpressionInternal>),
62 Sum(Vec<ExpressionInternal>),
63 Neg(Box<ExpressionInternal>),
64 Div {
65 left: Box<ExpressionInternal>,
66 right: Box<ExpressionInternal>,
67 by_zero_default: Option<ScoreType>,
68 },
69 Sqrt(Box<ExpressionInternal>),
70 Pow {
71 base: Box<ExpressionInternal>,
72 exponent: Box<ExpressionInternal>,
73 },
74 Exp(Box<ExpressionInternal>),
75 Log10(Box<ExpressionInternal>),
76 Ln(Box<ExpressionInternal>),
77 Abs(Box<ExpressionInternal>),
78 Decay {
79 kind: DecayKind,
80 x: Box<ExpressionInternal>,
81 target: Option<Box<ExpressionInternal>>,
82 midpoint: Option<f32>,
83 scale: Option<f32>,
84 },
85}
86
87impl ExpressionInternal {
88 fn parse_and_convert(
89 self,
90 payload_vars: &mut HashSet<JsonPath>,
91 conditions: &mut Vec<Condition>,
92 ) -> OperationResult<ParsedExpression> {
93 let expr = match self {
94 ExpressionInternal::Constant(c) => {
95 ParsedExpression::Constant(PreciseScoreOrdered::from(PreciseScore::from(c)))
96 }
97 ExpressionInternal::Variable(var) => {
98 let var: VariableId = var
99 .parse()
100 .map_err(|msg| failed_to_parse("variable ID", &var, &msg))?;
101 if let VariableId::Payload(payload_var) = var.clone() {
102 payload_vars.insert(payload_var);
103 }
104 ParsedExpression::Variable(var)
105 }
106 ExpressionInternal::Condition(condition) => {
107 let condition_id = conditions.len();
108 conditions.push(*condition);
109 ParsedExpression::new_condition_id(condition_id)
110 }
111 ExpressionInternal::GeoDistance { origin, to } => {
112 payload_vars.insert(to.clone());
113 ParsedExpression::new_geo_distance(origin, to)
114 }
115 ExpressionInternal::Datetime(dt_str) => {
116 ParsedExpression::Datetime(DatetimeExpression::Constant(
117 dt_str
118 .parse()
119 .map_err(|err| failed_to_parse("date-time", &dt_str, err))?,
120 ))
121 }
122 ExpressionInternal::DatetimeKey(json_path) => {
123 payload_vars.insert(json_path.clone());
124 ParsedExpression::Datetime(DatetimeExpression::PayloadVariable(json_path))
125 }
126 ExpressionInternal::Mult(internal_expressions) => ParsedExpression::Mult(
127 internal_expressions
128 .into_iter()
129 .map(|expr| expr.parse_and_convert(payload_vars, conditions))
130 .try_collect()?,
131 ),
132 ExpressionInternal::Sum(expression_internals) => ParsedExpression::Sum(
133 expression_internals
134 .into_iter()
135 .map(|expr| expr.parse_and_convert(payload_vars, conditions))
136 .try_collect()?,
137 ),
138 ExpressionInternal::Neg(expression_internal) => ParsedExpression::new_neg(
139 expression_internal.parse_and_convert(payload_vars, conditions)?,
140 ),
141 ExpressionInternal::Div {
142 left,
143 right,
144 by_zero_default,
145 } => ParsedExpression::new_div(
146 left.parse_and_convert(payload_vars, conditions)?,
147 right.parse_and_convert(payload_vars, conditions)?,
148 by_zero_default.map(PreciseScore::from),
149 ),
150 ExpressionInternal::Sqrt(expression_internal) => ParsedExpression::Sqrt(Box::new(
151 expression_internal.parse_and_convert(payload_vars, conditions)?,
152 )),
153 ExpressionInternal::Pow { base, exponent } => ParsedExpression::Pow {
154 base: Box::new(base.parse_and_convert(payload_vars, conditions)?),
155 exponent: Box::new(exponent.parse_and_convert(payload_vars, conditions)?),
156 },
157 ExpressionInternal::Exp(expression_internal) => ParsedExpression::Exp(Box::new(
158 expression_internal.parse_and_convert(payload_vars, conditions)?,
159 )),
160 ExpressionInternal::Log10(expression_internal) => ParsedExpression::Log10(Box::new(
161 expression_internal.parse_and_convert(payload_vars, conditions)?,
162 )),
163 ExpressionInternal::Ln(expression_internal) => ParsedExpression::Ln(Box::new(
164 expression_internal.parse_and_convert(payload_vars, conditions)?,
165 )),
166 ExpressionInternal::Abs(expression_internal) => ParsedExpression::Abs(Box::new(
167 expression_internal.parse_and_convert(payload_vars, conditions)?,
168 )),
169 ExpressionInternal::Decay {
170 kind,
171 x,
172 target,
173 midpoint,
174 scale,
175 } => {
176 let lambda = ParsedExpression::decay_params_to_lambda(midpoint, scale, kind)?;
177
178 let x = x.parse_and_convert(payload_vars, conditions)?;
179
180 let target = target
181 .map(|t| t.parse_and_convert(payload_vars, conditions))
182 .transpose()?
183 .map(Box::new);
184
185 ParsedExpression::Decay {
186 kind,
187 x: Box::new(x),
188 target,
189 lambda: PreciseScoreOrdered::from(lambda),
190 }
191 }
192 };
193
194 Ok(expr)
195 }
196}
197
198fn failed_to_parse(what: &str, value: &str, message: impl fmt::Display) -> OperationError {
199 OperationError::validation_error(format!("failed to parse {what} {value}: {message}"))
200}