Skip to main content

objectiveai_sdk/functions/expression/
params.rs

1//! Parameters and context for expression evaluation.
2//!
3//! Provides the context available to expressions (JMESPath or Starlark) during
4//! compilation, including the function input, task outputs, and current map element.
5
6use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
7use objectiveai_sdk_macros::schema_override;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use starlark::values::{
11    Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
12};
13
/// Context for evaluating expressions (JMESPath or Starlark).
///
/// Contains all data accessible within expressions: `input`, `output`, and `map`.
//
// `#[serde(untagged)]`: each variant serializes exactly as its payload, so
// `Owned` and `Ref` both produce the `input` / `output` / `map` field layout.
// `Deserialize` and `JsonSchema` are implemented by hand for
// `Params<'static, 'static>` and always go through `ParamsOwned`.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Params<'i, 'to> {
    /// Owned version (for deserialization).
    Owned(ParamsOwned),
    /// Borrowed version (for efficient evaluation).
    // `'i` is the lifetime of the borrowed input, `'to` that of the task output
    // (see the fields of `ParamsRef`).
    Ref(ParamsRef<'i, 'to>),
}
26
27impl JsonSchema for Params<'static, 'static> {
28    fn schema_name() -> std::borrow::Cow<'static, str> {
29        ParamsOwned::schema_name()
30    }
31    fn json_schema(
32        generator: &mut schemars::SchemaGenerator,
33    ) -> schemars::Schema {
34        ParamsOwned::json_schema(generator)
35    }
36}
37
38impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
39    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40    where
41        D: serde::Deserializer<'de>,
42    {
43        let owned = ParamsOwned::deserialize(deserializer)?;
44        Ok(Params::Owned(owned))
45    }
46}
47
/// Owned version of expression parameters.
// This is the deserialization target for `Params`, and its schema (published
// as `functions.expression.Params`) is reused by `Params`' manual `JsonSchema`
// impl.
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "functions.expression.Params")]
pub struct ParamsOwned {
    /// The function's input data.
    pub input: super::InputValue,
    /// Results from executed tasks. Only populated for task output expressions.
    pub output: Option<TaskOutputOwned>,
    /// Current map index. Only populated for mapped task expressions.
    pub map: Option<u64>,
}
60
/// Borrowed version of expression parameters.
// Field-for-field mirror of `ParamsOwned` that borrows the input instead of
// owning it. It derives `Serialize` only; there is no `Deserialize`.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ParamsRef<'i, 'to> {
    /// The function's input data.
    pub input: &'i super::InputValue,
    /// Results from executed tasks. Only populated for task output expressions.
    // May itself be `TaskOutput::Owned` or `TaskOutput::Ref`; note that
    // `params_output` in this file only accepts the `Owned` form.
    pub output: Option<TaskOutput<'to>>,
    /// Current map index. Only populated for mapped task expressions.
    pub map: Option<u64>,
}
72
/// Output from an executed task.
// `#[serde(untagged)]`: each variant serializes exactly as its payload.
// `Deserialize` (implemented by hand for `TaskOutput<'static>`) always
// yields `Owned`.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutput<'a> {
    /// Owned version.
    Owned(TaskOutputOwned),
    /// Borrowed version.
    Ref(TaskOutputRef<'a>),
}
83
84impl JsonSchema for TaskOutput<'static> {
85    fn schema_name() -> std::borrow::Cow<'static, str> {
86        TaskOutputOwned::schema_name()
87    }
88    fn json_schema(
89        generator: &mut schemars::SchemaGenerator,
90    ) -> schemars::Schema {
91        TaskOutputOwned::json_schema(generator)
92    }
93}
94
95impl<'a> super::ToStarlarkValue for TaskOutput<'a> {
96    fn to_starlark_value<'v>(
97        &self,
98        heap: &'v StarlarkHeap,
99    ) -> StarlarkValue<'v> {
100        match self {
101            TaskOutput::Owned(o) => o.to_starlark_value(heap),
102            TaskOutput::Ref(r) => r.to_starlark_value(heap),
103        }
104    }
105}
106
107impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
108    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109    where
110        D: serde::Deserializer<'de>,
111    {
112        let owned = TaskOutputOwned::deserialize(deserializer)?;
113        Ok(TaskOutput::Owned(owned))
114    }
115}
116
/// Owned task output variants.
// Untagged: when deserializing, serde tries the variants in declaration
// order. A JSON number becomes `Scalar`, an array of numbers `Vector`, an
// array of number arrays `Vectors`, and an object `{"error": ...}` `Err`.
// Bare non-numeric values (`null`, `true`, strings) match no variant and are
// rejected (see the `tests` module in this file).
#[schema_override(Owned)]
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.TaskOutput")]
pub enum TaskOutputOwned {
    /// A single scalar score.
    // Decimals are read through `crate::serde_util` helpers and advertised in
    // the JSON schema as plain `f64` numbers.
    #[schemars(title = "Scalar")]
    Scalar(
        #[serde(deserialize_with = "crate::serde_util::decimal")]
        #[schemars(with = "f64")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)]
        rust_decimal::Decimal,
    ),
    /// A vector of scores.
    // Declared before `Vectors`, so an empty JSON array deserializes as an
    // empty `Vector`.
    #[schemars(title = "Vector")]
    Vector(
        #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
        #[schemars(with = "Vec<f64>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)]
        Vec<rust_decimal::Decimal>,
    ),
    /// Multiple vectors of scores (from mapped tasks).
    #[schemars(title = "Vectors")]
    Vectors(
        #[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")]
        #[schemars(with = "Vec<Vec<f64>>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)]
        Vec<Vec<rust_decimal::Decimal>>,
    ),
    /// An error occurred during execution.
    // Struct variant: serialized as `{"error": <value>}`, which is what keeps
    // it distinguishable from the numeric variants on the wire.
    #[schemars(title = "Err")]
    Err {
        #[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
        error: serde_json::Value,
    },
}
162
163impl ToStarlarkValue for TaskOutputOwned {
164    fn to_starlark_value<'v>(
165        &self,
166        heap: &'v StarlarkHeap,
167    ) -> StarlarkValue<'v> {
168        match self {
169            TaskOutputOwned::Scalar(d) => d.to_starlark_value(heap),
170            TaskOutputOwned::Vector(ds) => ds.to_starlark_value(heap),
171            TaskOutputOwned::Vectors(vecs) => vecs.to_starlark_value(heap),
172            TaskOutputOwned::Err { error } => error.to_starlark_value(heap),
173        }
174    }
175}
176
177impl FromStarlarkValue for TaskOutputOwned {
178    fn from_starlark_value(
179        value: &StarlarkValue,
180    ) -> Result<Self, ExpressionError> {
181        use starlark::values::float::UnpackFloat;
182        if value.is_none() {
183            return Ok(TaskOutputOwned::Err {
184                error: serde_json::Value::Null,
185            });
186        }
187        if let Some(list) = starlark::values::list::ListRef::from_value(*value)
188        {
189            // Check if it's a list of lists (Vectors) or list of numbers (Vector)
190            let mut all_numeric = true;
191            let mut all_lists = true;
192            let mut decimals = Vec::with_capacity(list.len());
193            let mut vecs = Vec::with_capacity(list.len());
194
195            for v in list.iter() {
196                if let Some(inner_list) =
197                    starlark::values::list::ListRef::from_value(v)
198                {
199                    // Try to parse inner list as numbers
200                    let mut inner_decimals =
201                        Vec::with_capacity(inner_list.len());
202                    let mut inner_all_numeric = true;
203                    for iv in inner_list.iter() {
204                        if let Ok(Some(i)) = i64::unpack_value(iv) {
205                            inner_decimals.push(rust_decimal::Decimal::from(i));
206                        } else if let Ok(Some(UnpackFloat(f))) =
207                            UnpackFloat::unpack_value(iv)
208                        {
209                            match rust_decimal::Decimal::try_from(f) {
210                                Ok(d) => inner_decimals.push(d),
211                                Err(_) => {
212                                    inner_all_numeric = false;
213                                    break;
214                                }
215                            }
216                        } else {
217                            inner_all_numeric = false;
218                            break;
219                        }
220                    }
221                    if inner_all_numeric {
222                        vecs.push(inner_decimals);
223                    } else {
224                        all_lists = false;
225                    }
226                    all_numeric = false;
227                } else if let Ok(Some(i)) = i64::unpack_value(v) {
228                    decimals.push(rust_decimal::Decimal::from(i));
229                    all_lists = false;
230                } else if let Ok(Some(UnpackFloat(f))) =
231                    UnpackFloat::unpack_value(v)
232                {
233                    match rust_decimal::Decimal::try_from(f) {
234                        Ok(d) => {
235                            decimals.push(d);
236                            all_lists = false;
237                        }
238                        Err(_) => {
239                            all_numeric = false;
240                            all_lists = false;
241                            break;
242                        }
243                    }
244                } else {
245                    all_numeric = false;
246                    all_lists = false;
247                    break;
248                }
249            }
250            if all_numeric && !decimals.is_empty() {
251                return Ok(TaskOutputOwned::Vector(decimals));
252            }
253            if all_numeric && decimals.is_empty() && list.len() == 0 {
254                return Ok(TaskOutputOwned::Vector(Vec::new()));
255            }
256            if all_lists && !vecs.is_empty() {
257                return Ok(TaskOutputOwned::Vectors(vecs));
258            }
259            if all_lists && vecs.is_empty() && list.len() == 0 {
260                return Ok(TaskOutputOwned::Vectors(Vec::new()));
261            }
262        }
263        if let Ok(Some(i)) = i64::unpack_value(*value) {
264            return Ok(TaskOutputOwned::Scalar(rust_decimal::Decimal::from(i)));
265        }
266        if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
267            if let Ok(d) = rust_decimal::Decimal::try_from(f) {
268                return Ok(TaskOutputOwned::Scalar(d));
269            }
270        }
271        let v = serde_json::Value::from_starlark_value(value)?;
272        Ok(TaskOutputOwned::Err { error: v })
273    }
274}
275
276impl super::FromSpecial for TaskOutputOwned {
277    fn from_special(
278        special: &super::Special,
279        params: &super::Params,
280    ) -> Result<Self, super::ExpressionError> {
281        match special {
282            super::Special::Output => {
283                let output = params_output(params)?;
284                Ok(output.clone())
285            }
286            super::Special::TaskOutputL1Normalized => {
287                let output = params_output(params)?;
288                match output {
289                    TaskOutputOwned::Scalar(_) => Ok(output.clone()),
290                    TaskOutputOwned::Vector(v) => {
291                        Ok(TaskOutputOwned::Vector(l1_normalize(v)))
292                    }
293                    TaskOutputOwned::Vectors(vecs) => {
294                        Ok(TaskOutputOwned::Vectors(
295                            vecs.iter().map(|v| l1_normalize(v)).collect(),
296                        ))
297                    }
298                    TaskOutputOwned::Err { .. } => Ok(output.clone()),
299                }
300            }
301            super::Special::TaskOutputWeightedSum => {
302                let output = params_output(params)?;
303                match output {
304                    TaskOutputOwned::Vector(scores) => {
305                        Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
306                    }
307                    TaskOutputOwned::Vectors(vecs) => {
308                        Ok(TaskOutputOwned::Vector(
309                            vecs.iter()
310                                .map(|scores| weighted_sum(scores))
311                                .collect(),
312                        ))
313                    }
314                    _ => Err(super::ExpressionError::UnsupportedSpecial),
315                }
316            }
317            _ => Err(super::ExpressionError::UnsupportedSpecial),
318        }
319    }
320}
321
322impl TaskOutputOwned {
323    /// Converts the output into an error variant (wrapping the value as JSON).
324    pub fn into_err(self) -> Self {
325        match self {
326            Self::Scalar(scalar) => Self::Err {
327                error: serde_json::to_value(scalar).unwrap(),
328            },
329            Self::Vector(vector) => Self::Err {
330                error: serde_json::to_value(vector).unwrap(),
331            },
332            Self::Vectors(vectors) => Self::Err {
333                error: serde_json::to_value(vectors).unwrap(),
334            },
335            Self::Err { error } => Self::Err { error },
336        }
337    }
338}
339
/// Borrowed task output variants.
// Mirrors `TaskOutputOwned` variant-for-variant, with each payload borrowed
// instead of owned. It is also untagged, so serializing it emits the same
// shapes as the owned enum without cloning. It derives `Serialize` only.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutputRef<'a> {
    /// A single scalar score.
    Scalar(&'a rust_decimal::Decimal),
    /// A vector of scores.
    Vector(&'a [rust_decimal::Decimal]),
    /// Multiple vectors of scores (from mapped tasks).
    Vectors(&'a [Vec<rust_decimal::Decimal>]),
    /// An error occurred during execution.
    Err { error: &'a serde_json::Value },
}
354
355impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
356    fn to_starlark_value<'v>(
357        &self,
358        heap: &'v StarlarkHeap,
359    ) -> StarlarkValue<'v> {
360        match self {
361            TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
362            TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
363            TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
364            TaskOutputRef::Err { error } => error.to_starlark_value(heap),
365        }
366    }
367}
368
369fn params_output<'a>(
370    params: &'a super::Params,
371) -> Result<&'a TaskOutputOwned, super::ExpressionError> {
372    match params {
373        super::Params::Owned(o) => o
374            .output
375            .as_ref()
376            .ok_or(super::ExpressionError::UnsupportedSpecial),
377        super::Params::Ref(r) => match &r.output {
378            Some(TaskOutput::Owned(o)) => Ok(o),
379            Some(TaskOutput::Ref(_)) => {
380                // We can't return a reference to TaskOutputRef as TaskOutputOwned,
381                // but in practice this path uses Owned. If we hit Ref, it's unsupported.
382                Err(super::ExpressionError::UnsupportedSpecial)
383            }
384            None => Err(super::ExpressionError::UnsupportedSpecial),
385        },
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        // JSON number → Scalar
        let parsed: TaskOutputOwned = serde_json::from_str("94").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Scalar(_)));

        // JSON array of numbers → Vector
        let parsed: TaskOutputOwned =
            serde_json::from_str("[1, 2, 3]").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Vector(_)));

        // JSON array of arrays → Vectors
        let parsed: TaskOutputOwned =
            serde_json::from_str("[[1, 2], [3, 4]]").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Vectors(_)));

        // Bare values that previously fell through to Err must now FAIL,
        // since Err is wire-formatted as `{"error": ...}`.
        assert!(serde_json::from_str::<TaskOutputOwned>("null").is_err());
        assert!(serde_json::from_str::<TaskOutputOwned>("true").is_err());
        assert!(serde_json::from_str::<TaskOutputOwned>(r#""94""#).is_err());

        // `{"error": ...}` is now the canonical Err wire form. The inner value
        // unwraps by exactly one level.
        let parsed: TaskOutputOwned =
            serde_json::from_str(r#"{"error": "something"}"#).unwrap();
        assert!(matches!(
            parsed,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));

        let parsed: TaskOutputOwned =
            serde_json::from_str(r#"{"error": null}"#).unwrap();
        assert!(matches!(
            parsed,
            TaskOutputOwned::Err {
                error: serde_json::Value::Null
            }
        ));

        // Round-trip: Err { error: String("94") } ↔ {"error":"94"}.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        let roundtripped: TaskOutputOwned =
            serde_json::from_str(&json).unwrap();
        assert!(matches!(
            roundtripped,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // Empty array → empty Vector: untagged variants are tried in
        // declaration order and `Vector` precedes `Vectors`.
        let parsed: TaskOutputOwned = serde_json::from_str("[]").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Vector(ref v) if v.is_empty()));
    }

    #[test]
    fn test_l1_normalize() {
        use rust_decimal::Decimal;
        // Empty input stays empty.
        assert!(l1_normalize(&[]).is_empty());
        // Divides by the sum of absolute values (|1| + |-3| = 4), keeping signs.
        assert_eq!(
            l1_normalize(&[Decimal::from(1), Decimal::from(-3)]),
            vec![Decimal::new(25, 2), Decimal::new(-75, 2)]
        );
        // All-zero input falls back to a uniform distribution.
        assert_eq!(
            l1_normalize(&[Decimal::ZERO, Decimal::ZERO]),
            vec![Decimal::new(5, 1); 2]
        );
    }

    #[test]
    fn test_weighted_sum() {
        use rust_decimal::Decimal;
        assert_eq!(weighted_sum(&[]), Decimal::ZERO);
        assert_eq!(weighted_sum(&[Decimal::from(7)]), Decimal::from(7));
        // Weights 0, 0.5, 1 → 0·1 + 0.5·2 + 1·3 = 4.
        assert_eq!(
            weighted_sum(&[Decimal::from(1), Decimal::from(2), Decimal::from(3)]),
            Decimal::from(4)
        );
    }
}
454
455fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
456    if v.is_empty() {
457        return Vec::new();
458    }
459    let sum: rust_decimal::Decimal = v.iter().map(|d| d.abs()).sum();
460    if sum.is_zero() {
461        let uniform =
462            rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(v.len());
463        vec![uniform; v.len()]
464    } else {
465        v.iter().map(|d| d / sum).collect()
466    }
467}
468
469/// Computes a weighted sum of scores where the first element has weight 0,
470/// the last element has weight 1, and intermediate elements are evenly spaced.
471fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
472    let len = scores.len();
473    if len <= 1 {
474        return scores.iter().sum();
475    }
476    let mut ws = rust_decimal::Decimal::ZERO;
477    let last = len - 1;
478    for (i, score) in scores.iter().enumerate() {
479        let weight =
480            rust_decimal::Decimal::from(i) / rust_decimal::Decimal::from(last);
481        ws += score * weight;
482    }
483    ws
484}