Skip to main content

objectiveai_sdk/functions/expression/
params.rs

1//! Parameters and context for expression evaluation.
2//!
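//! Provides the context available to expressions (JMESPath or Starlark) during
//! compilation, including the function input, task outputs, the current map
//! index, and, for invention prompt expressions, the resolved task-count bounds,
//! recursion depth, function name, and specification text.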
5
6use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
7use objectiveai_sdk_macros::schema_override;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use starlark::values::{
11    Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
12};
13
14/// Context for evaluating expressions (JMESPath or Starlark).
15///
16/// Contains all data accessible within expressions: `input`, `output`, and `map`.
17#[schema_override(RefOwnedEnum)]
18#[derive(Debug, Clone, PartialEq, Serialize)]
19#[serde(untagged)]
20pub enum Params<'i, 'to> {
21    /// Owned version (for deserialization).
22    Owned(ParamsOwned),
23    /// Borrowed version (for efficient evaluation).
24    Ref(ParamsRef<'i, 'to>),
25}
26
27impl JsonSchema for Params<'static, 'static> {
28    fn schema_name() -> std::borrow::Cow<'static, str> {
29        ParamsOwned::schema_name()
30    }
31    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
32        ParamsOwned::json_schema(generator)
33    }
34}
35
36impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
37    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38    where
39        D: serde::Deserializer<'de>,
40    {
41        let owned = ParamsOwned::deserialize(deserializer)?;
42        Ok(Params::Owned(owned))
43    }
44}
45
46/// Owned version of expression parameters.
47#[schema_override(Owned)]
48#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
49#[schemars(rename = "functions.expression.Params")]
50pub struct ParamsOwned {
51    /// The function's input data.
52    pub input: super::InputValue,
53    /// Results from executed tasks. Only populated for task output expressions.
54    pub output: Option<TaskOutputOwned>,
55    /// Current map index. Only populated for mapped task expressions.
56    pub map: Option<u64>,
57    /// Resolved minimum task count for this node type.
58    /// Only provided for invention prompt expressions.
59    pub tasks_min: Option<u64>,
60    /// Resolved maximum task count for this node type.
61    /// Only provided for invention prompt expressions.
62    pub tasks_max: Option<u64>,
63    /// Current recursion depth.
64    /// Only provided for invention prompt expressions.
65    pub depth: Option<u64>,
66    /// The function's name.
67    /// Only provided for invention prompt expressions.
68    pub name: Option<String>,
69    /// The specification text.
70    /// Only provided for invention prompt expressions.
71    pub spec: Option<String>,
72}
73
74/// Borrowed version of expression parameters.
75#[schema_override(Ref)]
76#[derive(Debug, Clone, PartialEq, Serialize)]
77pub struct ParamsRef<'i, 'to> {
78    /// The function's input data.
79    pub input: &'i super::InputValue,
80    /// Results from executed tasks. Only populated for task output expressions.
81    pub output: Option<TaskOutput<'to>>,
82    /// Current map index. Only populated for mapped task expressions.
83    pub map: Option<u64>,
84    /// Resolved minimum task count for this node type.
85    /// Only provided for invention prompt expressions.
86    pub tasks_min: Option<u64>,
87    /// Resolved maximum task count for this node type.
88    /// Only provided for invention prompt expressions.
89    pub tasks_max: Option<u64>,
90    /// Current recursion depth.
91    /// Only provided for invention prompt expressions.
92    pub depth: Option<u64>,
93    /// The function's name.
94    /// Only provided for invention prompt expressions.
95    pub name: Option<&'i str>,
96    /// The specification text.
97    /// Only provided for invention prompt expressions.
98    pub spec: Option<&'i str>,
99}
100
101/// Output from an executed task.
102#[schema_override(RefOwnedEnum)]
103#[derive(Debug, Clone, PartialEq, Serialize)]
104#[serde(untagged)]
105pub enum TaskOutput<'a> {
106    /// Owned version.
107    Owned(TaskOutputOwned),
108    /// Borrowed version.
109    Ref(TaskOutputRef<'a>),
110}
111
112impl JsonSchema for TaskOutput<'static> {
113    fn schema_name() -> std::borrow::Cow<'static, str> {
114        TaskOutputOwned::schema_name()
115    }
116    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
117        TaskOutputOwned::json_schema(generator)
118    }
119}
120
121impl<'a> super::ToStarlarkValue for TaskOutput<'a> {
122    fn to_starlark_value<'v>(
123        &self,
124        heap: &'v StarlarkHeap,
125    ) -> StarlarkValue<'v> {
126        match self {
127            TaskOutput::Owned(o) => o.to_starlark_value(heap),
128            TaskOutput::Ref(r) => r.to_starlark_value(heap),
129        }
130    }
131}
132
133impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
134    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135    where
136        D: serde::Deserializer<'de>,
137    {
138        let owned = TaskOutputOwned::deserialize(deserializer)?;
139        Ok(TaskOutput::Owned(owned))
140    }
141}
142
143/// Owned task output variants.
144#[schema_override(Owned)]
145#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
146#[serde(untagged)]
147#[schemars(rename = "functions.expression.TaskOutput")]
148pub enum TaskOutputOwned {
149    /// A single scalar score.
150    #[schemars(title = "Scalar")]
151    Scalar(#[serde(deserialize_with = "crate::serde_util::decimal")] #[schemars(with = "f64")] #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)] rust_decimal::Decimal),
152    /// A vector of scores.
153    #[schemars(title = "Vector")]
154    Vector(#[serde(deserialize_with = "crate::serde_util::vec_decimal")] #[schemars(with = "Vec<f64>")] #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)] Vec<rust_decimal::Decimal>),
155    /// Multiple vectors of scores (from mapped tasks).
156    #[schemars(title = "Vectors")]
157    Vectors(#[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")] #[schemars(with = "Vec<Vec<f64>>")] #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)] Vec<Vec<rust_decimal::Decimal>>),
158    /// An error occurred during execution.
159    #[schemars(title = "Err")]
160    Err {
161        #[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
162        error: serde_json::Value,
163    },
164}
165
166impl ToStarlarkValue for TaskOutputOwned {
167    fn to_starlark_value<'v>(
168        &self,
169        heap: &'v StarlarkHeap,
170    ) -> StarlarkValue<'v> {
171        match self {
172            TaskOutputOwned::Scalar(d) => d.to_starlark_value(heap),
173            TaskOutputOwned::Vector(ds) => ds.to_starlark_value(heap),
174            TaskOutputOwned::Vectors(vecs) => vecs.to_starlark_value(heap),
175            TaskOutputOwned::Err { error } => error.to_starlark_value(heap),
176        }
177    }
178}
179
180impl FromStarlarkValue for TaskOutputOwned {
181    fn from_starlark_value(
182        value: &StarlarkValue,
183    ) -> Result<Self, ExpressionError> {
184        use starlark::values::float::UnpackFloat;
185        if value.is_none() {
186            return Ok(TaskOutputOwned::Err {
187                error: serde_json::Value::Null,
188            });
189        }
190        if let Some(list) = starlark::values::list::ListRef::from_value(*value)
191        {
192            // Check if it's a list of lists (Vectors) or list of numbers (Vector)
193            let mut all_numeric = true;
194            let mut all_lists = true;
195            let mut decimals = Vec::with_capacity(list.len());
196            let mut vecs = Vec::with_capacity(list.len());
197
198            for v in list.iter() {
199                if let Some(inner_list) =
200                    starlark::values::list::ListRef::from_value(v)
201                {
202                    // Try to parse inner list as numbers
203                    let mut inner_decimals =
204                        Vec::with_capacity(inner_list.len());
205                    let mut inner_all_numeric = true;
206                    for iv in inner_list.iter() {
207                        if let Ok(Some(i)) = i64::unpack_value(iv) {
208                            inner_decimals
209                                .push(rust_decimal::Decimal::from(i));
210                        } else if let Ok(Some(UnpackFloat(f))) =
211                            UnpackFloat::unpack_value(iv)
212                        {
213                            match rust_decimal::Decimal::try_from(f) {
214                                Ok(d) => inner_decimals.push(d),
215                                Err(_) => {
216                                    inner_all_numeric = false;
217                                    break;
218                                }
219                            }
220                        } else {
221                            inner_all_numeric = false;
222                            break;
223                        }
224                    }
225                    if inner_all_numeric {
226                        vecs.push(inner_decimals);
227                    } else {
228                        all_lists = false;
229                    }
230                    all_numeric = false;
231                } else if let Ok(Some(i)) = i64::unpack_value(v) {
232                    decimals.push(rust_decimal::Decimal::from(i));
233                    all_lists = false;
234                } else if let Ok(Some(UnpackFloat(f))) =
235                    UnpackFloat::unpack_value(v)
236                {
237                    match rust_decimal::Decimal::try_from(f) {
238                        Ok(d) => {
239                            decimals.push(d);
240                            all_lists = false;
241                        }
242                        Err(_) => {
243                            all_numeric = false;
244                            all_lists = false;
245                            break;
246                        }
247                    }
248                } else {
249                    all_numeric = false;
250                    all_lists = false;
251                    break;
252                }
253            }
254            if all_numeric && !decimals.is_empty() {
255                return Ok(TaskOutputOwned::Vector(decimals));
256            }
257            if all_numeric && decimals.is_empty() && list.len() == 0 {
258                return Ok(TaskOutputOwned::Vector(Vec::new()));
259            }
260            if all_lists && !vecs.is_empty() {
261                return Ok(TaskOutputOwned::Vectors(vecs));
262            }
263            if all_lists && vecs.is_empty() && list.len() == 0 {
264                return Ok(TaskOutputOwned::Vectors(Vec::new()));
265            }
266        }
267        if let Ok(Some(i)) = i64::unpack_value(*value) {
268            return Ok(TaskOutputOwned::Scalar(
269                rust_decimal::Decimal::from(i),
270            ));
271        }
272        if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
273            if let Ok(d) = rust_decimal::Decimal::try_from(f) {
274                return Ok(TaskOutputOwned::Scalar(d));
275            }
276        }
277        let v = serde_json::Value::from_starlark_value(value)?;
278        Ok(TaskOutputOwned::Err { error: v })
279    }
280}
281
282impl super::FromSpecial for TaskOutputOwned {
283    fn from_special(
284        special: &super::Special,
285        params: &super::Params,
286    ) -> Result<Self, super::ExpressionError> {
287        match special {
288            super::Special::Output => {
289                let output = params_output(params)?;
290                Ok(output.clone())
291            }
292            super::Special::TaskOutputL1Normalized => {
293                let output = params_output(params)?;
294                match output {
295                    TaskOutputOwned::Scalar(_) => Ok(output.clone()),
296                    TaskOutputOwned::Vector(v) => {
297                        Ok(TaskOutputOwned::Vector(l1_normalize(v)))
298                    }
299                    TaskOutputOwned::Vectors(vecs) => {
300                        Ok(TaskOutputOwned::Vectors(
301                            vecs.iter().map(|v| l1_normalize(v)).collect(),
302                        ))
303                    }
304                    TaskOutputOwned::Err { .. } => Ok(output.clone()),
305                }
306            }
307            super::Special::TaskOutputWeightedSum => {
308                let output = params_output(params)?;
309                match output {
310                    TaskOutputOwned::Vector(scores) => {
311                        Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
312                    }
313                    TaskOutputOwned::Vectors(vecs) => {
314                        Ok(TaskOutputOwned::Vector(
315                            vecs.iter()
316                                .map(|scores| weighted_sum(scores))
317                                .collect(),
318                        ))
319                    }
320                    _ => Err(super::ExpressionError::UnsupportedSpecial),
321                }
322            }
323            _ => Err(super::ExpressionError::UnsupportedSpecial),
324        }
325    }
326}
327
328impl TaskOutputOwned {
329    /// Converts the output into an error variant (wrapping the value as JSON).
330    pub fn into_err(self) -> Self {
331        match self {
332            Self::Scalar(scalar) => Self::Err {
333                error: serde_json::to_value(scalar).unwrap(),
334            },
335            Self::Vector(vector) => Self::Err {
336                error: serde_json::to_value(vector).unwrap(),
337            },
338            Self::Vectors(vectors) => Self::Err {
339                error: serde_json::to_value(vectors).unwrap(),
340            },
341            Self::Err { error } => Self::Err { error },
342        }
343    }
344}
345
346/// Borrowed task output variants.
347#[schema_override(Ref)]
348#[derive(Debug, Clone, PartialEq, Serialize)]
349#[serde(untagged)]
350pub enum TaskOutputRef<'a> {
351    /// A single scalar score.
352    Scalar(&'a rust_decimal::Decimal),
353    /// A vector of scores.
354    Vector(&'a [rust_decimal::Decimal]),
355    /// Multiple vectors of scores (from mapped tasks).
356    Vectors(&'a [Vec<rust_decimal::Decimal>]),
357    /// An error occurred during execution.
358    Err { error: &'a serde_json::Value },
359}
360
361impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
362    fn to_starlark_value<'v>(
363        &self,
364        heap: &'v StarlarkHeap,
365    ) -> StarlarkValue<'v> {
366        match self {
367            TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
368            TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
369            TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
370            TaskOutputRef::Err { error } => error.to_starlark_value(heap),
371        }
372    }
373}
374
375fn params_output<'a>(
376    params: &'a super::Params,
377) -> Result<&'a TaskOutputOwned, super::ExpressionError> {
378    match params {
379        super::Params::Owned(o) => o
380            .output
381            .as_ref()
382            .ok_or(super::ExpressionError::UnsupportedSpecial),
383        super::Params::Ref(r) => match &r.output {
384            Some(TaskOutput::Owned(o)) => Ok(o),
385            Some(TaskOutput::Ref(_)) => {
386                // We can't return a reference to TaskOutputRef as TaskOutputOwned,
387                // but in practice this path uses Owned. If we hit Ref, it's unsupported.
388                Err(super::ExpressionError::UnsupportedSpecial)
389            }
390            None => Err(super::ExpressionError::UnsupportedSpecial),
391        },
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        // JSON number → Scalar
        let parsed: TaskOutputOwned = serde_json::from_str("94").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Scalar(_)));

        // JSON array of numbers → Vector
        let parsed: TaskOutputOwned = serde_json::from_str("[1, 2, 3]").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Vector(_)));

        // JSON array of arrays → Vectors
        let parsed: TaskOutputOwned = serde_json::from_str("[[1, 2], [3, 4]]").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Vectors(_)));

        // Bare values that previously fell through to Err must now FAIL,
        // since Err is wire-formatted as `{"error": ...}`.
        assert!(serde_json::from_str::<TaskOutputOwned>("null").is_err());
        assert!(serde_json::from_str::<TaskOutputOwned>("true").is_err());
        assert!(serde_json::from_str::<TaskOutputOwned>(r#""94""#).is_err());

        // `{"error": ...}` is now the canonical Err wire form. The inner value
        // unwraps by exactly one level.
        let parsed: TaskOutputOwned =
            serde_json::from_str(r#"{"error": "something"}"#).unwrap();
        assert!(matches!(
            parsed,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));

        let parsed: TaskOutputOwned =
            serde_json::from_str(r#"{"error": null}"#).unwrap();
        assert!(matches!(
            parsed,
            TaskOutputOwned::Err { error: serde_json::Value::Null }
        ));

        // Round-trip: Err { error: String("94") } ↔ {"error":"94"}.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        let roundtripped: TaskOutputOwned = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            roundtripped,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // Empty array → Vector, not Vectors. Untagged variants are tried in
        // declaration order and `Vector` precedes `Vectors`. This also matches
        // the Starlark conversion, which maps `[]` to an empty `Vector`.
        let parsed: TaskOutputOwned = serde_json::from_str("[]").unwrap();
        assert!(matches!(parsed, TaskOutputOwned::Vector(ref v) if v.is_empty()));
    }

    #[test]
    fn test_l1_normalize() {
        use rust_decimal::Decimal;

        assert!(l1_normalize(&[]).is_empty());
        // Divides by the sum of absolute values; signs are preserved.
        assert_eq!(
            l1_normalize(&[Decimal::from(1), Decimal::from(3)]),
            vec![Decimal::new(25, 2), Decimal::new(75, 2)]
        );
        assert_eq!(
            l1_normalize(&[Decimal::from(-1), Decimal::from(3)]),
            vec![Decimal::new(-25, 2), Decimal::new(75, 2)]
        );
        // All-zero input falls back to a uniform distribution.
        assert_eq!(
            l1_normalize(&[Decimal::ZERO, Decimal::ZERO]),
            vec![Decimal::new(5, 1), Decimal::new(5, 1)]
        );
    }

    #[test]
    fn test_weighted_sum() {
        use rust_decimal::Decimal;

        assert_eq!(weighted_sum(&[]), Decimal::ZERO);
        assert_eq!(weighted_sum(&[Decimal::from(7)]), Decimal::from(7));
        // Weights 0, 0.5, 1 → 0 + 1 + 3.
        assert_eq!(
            weighted_sum(&[Decimal::from(1), Decimal::from(2), Decimal::from(3)]),
            Decimal::from(4)
        );
    }
}
455
456fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
457    if v.is_empty() {
458        return Vec::new();
459    }
460    let sum: rust_decimal::Decimal = v.iter().map(|d| d.abs()).sum();
461    if sum.is_zero() {
462        let uniform =
463            rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(v.len());
464        vec![uniform; v.len()]
465    } else {
466        v.iter().map(|d| d / sum).collect()
467    }
468}
469
470/// Computes a weighted sum of scores where the first element has weight 0,
471/// the last element has weight 1, and intermediate elements are evenly spaced.
472fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
473    let len = scores.len();
474    if len <= 1 {
475        return scores.iter().sum();
476    }
477    let mut ws = rust_decimal::Decimal::ZERO;
478    let last = len - 1;
479    for (i, score) in scores.iter().enumerate() {
480        let weight = rust_decimal::Decimal::from(i)
481            / rust_decimal::Decimal::from(last);
482        ws += score * weight;
483    }
484    ws
485}