Skip to main content

objectiveai_sdk/functions/expression/
params.rs

//! Parameters and context for expression evaluation.
//!
//! Provides the context available to expressions (JMESPath or Starlark) during
//! compilation, including the function input, task outputs, and current map index.
5
6use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
7use objectiveai_sdk_macros::schema_override;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use starlark::values::{
11    Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
12};
13
14/// Context for evaluating expressions (JMESPath or Starlark).
15///
16/// Contains all data accessible within expressions: `input`, `output`, and `map`.
17#[schema_override(RefOwnedEnum)]
18#[derive(Debug, Clone, PartialEq, Serialize)]
19#[serde(untagged)]
20pub enum Params<'i, 'to> {
21    /// Owned version (for deserialization).
22    Owned(ParamsOwned),
23    /// Borrowed version (for efficient evaluation).
24    Ref(ParamsRef<'i, 'to>),
25}
26
27impl JsonSchema for Params<'static, 'static> {
28    fn schema_name() -> std::borrow::Cow<'static, str> {
29        ParamsOwned::schema_name()
30    }
31    fn json_schema(
32        generator: &mut schemars::SchemaGenerator,
33    ) -> schemars::Schema {
34        ParamsOwned::json_schema(generator)
35    }
36}
37
38impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
39    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40    where
41        D: serde::Deserializer<'de>,
42    {
43        let owned = ParamsOwned::deserialize(deserializer)?;
44        Ok(Params::Owned(owned))
45    }
46}
47
48/// Owned version of expression parameters.
49#[schema_override(Owned)]
50#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
51#[schemars(rename = "functions.expression.Params")]
52pub struct ParamsOwned {
53    /// The function's input data.
54    pub input: super::InputValue,
55    /// Results from executed tasks. Only populated for task output expressions.
56    pub output: Option<TaskOutputOwned>,
57    /// Current map index. Only populated for mapped task expressions.
58    pub map: Option<u64>,
59    /// Resolved minimum task count for this node type.
60    /// Only provided for invention prompt expressions.
61    pub tasks_min: Option<u64>,
62    /// Resolved maximum task count for this node type.
63    /// Only provided for invention prompt expressions.
64    pub tasks_max: Option<u64>,
65    /// Current recursion depth.
66    /// Only provided for invention prompt expressions.
67    pub depth: Option<u64>,
68    /// The function's name.
69    /// Only provided for invention prompt expressions.
70    pub name: Option<String>,
71    /// The specification text.
72    /// Only provided for invention prompt expressions.
73    pub spec: Option<String>,
74}
75
76/// Borrowed version of expression parameters.
77#[schema_override(Ref)]
78#[derive(Debug, Clone, PartialEq, Serialize)]
79pub struct ParamsRef<'i, 'to> {
80    /// The function's input data.
81    pub input: &'i super::InputValue,
82    /// Results from executed tasks. Only populated for task output expressions.
83    pub output: Option<TaskOutput<'to>>,
84    /// Current map index. Only populated for mapped task expressions.
85    pub map: Option<u64>,
86    /// Resolved minimum task count for this node type.
87    /// Only provided for invention prompt expressions.
88    pub tasks_min: Option<u64>,
89    /// Resolved maximum task count for this node type.
90    /// Only provided for invention prompt expressions.
91    pub tasks_max: Option<u64>,
92    /// Current recursion depth.
93    /// Only provided for invention prompt expressions.
94    pub depth: Option<u64>,
95    /// The function's name.
96    /// Only provided for invention prompt expressions.
97    pub name: Option<&'i str>,
98    /// The specification text.
99    /// Only provided for invention prompt expressions.
100    pub spec: Option<&'i str>,
101}
102
103/// Output from an executed task.
104#[schema_override(RefOwnedEnum)]
105#[derive(Debug, Clone, PartialEq, Serialize)]
106#[serde(untagged)]
107pub enum TaskOutput<'a> {
108    /// Owned version.
109    Owned(TaskOutputOwned),
110    /// Borrowed version.
111    Ref(TaskOutputRef<'a>),
112}
113
114impl JsonSchema for TaskOutput<'static> {
115    fn schema_name() -> std::borrow::Cow<'static, str> {
116        TaskOutputOwned::schema_name()
117    }
118    fn json_schema(
119        generator: &mut schemars::SchemaGenerator,
120    ) -> schemars::Schema {
121        TaskOutputOwned::json_schema(generator)
122    }
123}
124
125impl<'a> super::ToStarlarkValue for TaskOutput<'a> {
126    fn to_starlark_value<'v>(
127        &self,
128        heap: &'v StarlarkHeap,
129    ) -> StarlarkValue<'v> {
130        match self {
131            TaskOutput::Owned(o) => o.to_starlark_value(heap),
132            TaskOutput::Ref(r) => r.to_starlark_value(heap),
133        }
134    }
135}
136
137impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
138    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139    where
140        D: serde::Deserializer<'de>,
141    {
142        let owned = TaskOutputOwned::deserialize(deserializer)?;
143        Ok(TaskOutput::Owned(owned))
144    }
145}
146
147/// Owned task output variants.
148#[schema_override(Owned)]
149#[derive(
150    Debug,
151    Clone,
152    PartialEq,
153    Serialize,
154    Deserialize,
155    JsonSchema,
156    arbitrary::Arbitrary,
157)]
158#[serde(untagged)]
159#[schemars(rename = "functions.expression.TaskOutput")]
160pub enum TaskOutputOwned {
161    /// A single scalar score.
162    #[schemars(title = "Scalar")]
163    Scalar(
164        #[serde(deserialize_with = "crate::serde_util::decimal")]
165        #[schemars(with = "f64")]
166        #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)]
167        rust_decimal::Decimal,
168    ),
169    /// A vector of scores.
170    #[schemars(title = "Vector")]
171    Vector(
172        #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
173        #[schemars(with = "Vec<f64>")]
174        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)]
175        Vec<rust_decimal::Decimal>,
176    ),
177    /// Multiple vectors of scores (from mapped tasks).
178    #[schemars(title = "Vectors")]
179    Vectors(
180        #[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")]
181        #[schemars(with = "Vec<Vec<f64>>")]
182        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)]
183        Vec<Vec<rust_decimal::Decimal>>,
184    ),
185    /// An error occurred during execution.
186    #[schemars(title = "Err")]
187    Err {
188        #[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
189        error: serde_json::Value,
190    },
191}
192
193impl ToStarlarkValue for TaskOutputOwned {
194    fn to_starlark_value<'v>(
195        &self,
196        heap: &'v StarlarkHeap,
197    ) -> StarlarkValue<'v> {
198        match self {
199            TaskOutputOwned::Scalar(d) => d.to_starlark_value(heap),
200            TaskOutputOwned::Vector(ds) => ds.to_starlark_value(heap),
201            TaskOutputOwned::Vectors(vecs) => vecs.to_starlark_value(heap),
202            TaskOutputOwned::Err { error } => error.to_starlark_value(heap),
203        }
204    }
205}
206
207impl FromStarlarkValue for TaskOutputOwned {
208    fn from_starlark_value(
209        value: &StarlarkValue,
210    ) -> Result<Self, ExpressionError> {
211        use starlark::values::float::UnpackFloat;
212        if value.is_none() {
213            return Ok(TaskOutputOwned::Err {
214                error: serde_json::Value::Null,
215            });
216        }
217        if let Some(list) = starlark::values::list::ListRef::from_value(*value)
218        {
219            // Check if it's a list of lists (Vectors) or list of numbers (Vector)
220            let mut all_numeric = true;
221            let mut all_lists = true;
222            let mut decimals = Vec::with_capacity(list.len());
223            let mut vecs = Vec::with_capacity(list.len());
224
225            for v in list.iter() {
226                if let Some(inner_list) =
227                    starlark::values::list::ListRef::from_value(v)
228                {
229                    // Try to parse inner list as numbers
230                    let mut inner_decimals =
231                        Vec::with_capacity(inner_list.len());
232                    let mut inner_all_numeric = true;
233                    for iv in inner_list.iter() {
234                        if let Ok(Some(i)) = i64::unpack_value(iv) {
235                            inner_decimals.push(rust_decimal::Decimal::from(i));
236                        } else if let Ok(Some(UnpackFloat(f))) =
237                            UnpackFloat::unpack_value(iv)
238                        {
239                            match rust_decimal::Decimal::try_from(f) {
240                                Ok(d) => inner_decimals.push(d),
241                                Err(_) => {
242                                    inner_all_numeric = false;
243                                    break;
244                                }
245                            }
246                        } else {
247                            inner_all_numeric = false;
248                            break;
249                        }
250                    }
251                    if inner_all_numeric {
252                        vecs.push(inner_decimals);
253                    } else {
254                        all_lists = false;
255                    }
256                    all_numeric = false;
257                } else if let Ok(Some(i)) = i64::unpack_value(v) {
258                    decimals.push(rust_decimal::Decimal::from(i));
259                    all_lists = false;
260                } else if let Ok(Some(UnpackFloat(f))) =
261                    UnpackFloat::unpack_value(v)
262                {
263                    match rust_decimal::Decimal::try_from(f) {
264                        Ok(d) => {
265                            decimals.push(d);
266                            all_lists = false;
267                        }
268                        Err(_) => {
269                            all_numeric = false;
270                            all_lists = false;
271                            break;
272                        }
273                    }
274                } else {
275                    all_numeric = false;
276                    all_lists = false;
277                    break;
278                }
279            }
280            if all_numeric && !decimals.is_empty() {
281                return Ok(TaskOutputOwned::Vector(decimals));
282            }
283            if all_numeric && decimals.is_empty() && list.len() == 0 {
284                return Ok(TaskOutputOwned::Vector(Vec::new()));
285            }
286            if all_lists && !vecs.is_empty() {
287                return Ok(TaskOutputOwned::Vectors(vecs));
288            }
289            if all_lists && vecs.is_empty() && list.len() == 0 {
290                return Ok(TaskOutputOwned::Vectors(Vec::new()));
291            }
292        }
293        if let Ok(Some(i)) = i64::unpack_value(*value) {
294            return Ok(TaskOutputOwned::Scalar(rust_decimal::Decimal::from(i)));
295        }
296        if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
297            if let Ok(d) = rust_decimal::Decimal::try_from(f) {
298                return Ok(TaskOutputOwned::Scalar(d));
299            }
300        }
301        let v = serde_json::Value::from_starlark_value(value)?;
302        Ok(TaskOutputOwned::Err { error: v })
303    }
304}
305
306impl super::FromSpecial for TaskOutputOwned {
307    fn from_special(
308        special: &super::Special,
309        params: &super::Params,
310    ) -> Result<Self, super::ExpressionError> {
311        match special {
312            super::Special::Output => {
313                let output = params_output(params)?;
314                Ok(output.clone())
315            }
316            super::Special::TaskOutputL1Normalized => {
317                let output = params_output(params)?;
318                match output {
319                    TaskOutputOwned::Scalar(_) => Ok(output.clone()),
320                    TaskOutputOwned::Vector(v) => {
321                        Ok(TaskOutputOwned::Vector(l1_normalize(v)))
322                    }
323                    TaskOutputOwned::Vectors(vecs) => {
324                        Ok(TaskOutputOwned::Vectors(
325                            vecs.iter().map(|v| l1_normalize(v)).collect(),
326                        ))
327                    }
328                    TaskOutputOwned::Err { .. } => Ok(output.clone()),
329                }
330            }
331            super::Special::TaskOutputWeightedSum => {
332                let output = params_output(params)?;
333                match output {
334                    TaskOutputOwned::Vector(scores) => {
335                        Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
336                    }
337                    TaskOutputOwned::Vectors(vecs) => {
338                        Ok(TaskOutputOwned::Vector(
339                            vecs.iter()
340                                .map(|scores| weighted_sum(scores))
341                                .collect(),
342                        ))
343                    }
344                    _ => Err(super::ExpressionError::UnsupportedSpecial),
345                }
346            }
347            _ => Err(super::ExpressionError::UnsupportedSpecial),
348        }
349    }
350}
351
352impl TaskOutputOwned {
353    /// Converts the output into an error variant (wrapping the value as JSON).
354    pub fn into_err(self) -> Self {
355        match self {
356            Self::Scalar(scalar) => Self::Err {
357                error: serde_json::to_value(scalar).unwrap(),
358            },
359            Self::Vector(vector) => Self::Err {
360                error: serde_json::to_value(vector).unwrap(),
361            },
362            Self::Vectors(vectors) => Self::Err {
363                error: serde_json::to_value(vectors).unwrap(),
364            },
365            Self::Err { error } => Self::Err { error },
366        }
367    }
368}
369
370/// Borrowed task output variants.
371#[schema_override(Ref)]
372#[derive(Debug, Clone, PartialEq, Serialize)]
373#[serde(untagged)]
374pub enum TaskOutputRef<'a> {
375    /// A single scalar score.
376    Scalar(&'a rust_decimal::Decimal),
377    /// A vector of scores.
378    Vector(&'a [rust_decimal::Decimal]),
379    /// Multiple vectors of scores (from mapped tasks).
380    Vectors(&'a [Vec<rust_decimal::Decimal>]),
381    /// An error occurred during execution.
382    Err { error: &'a serde_json::Value },
383}
384
385impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
386    fn to_starlark_value<'v>(
387        &self,
388        heap: &'v StarlarkHeap,
389    ) -> StarlarkValue<'v> {
390        match self {
391            TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
392            TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
393            TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
394            TaskOutputRef::Err { error } => error.to_starlark_value(heap),
395        }
396    }
397}
398
399fn params_output<'a>(
400    params: &'a super::Params,
401) -> Result<&'a TaskOutputOwned, super::ExpressionError> {
402    match params {
403        super::Params::Owned(o) => o
404            .output
405            .as_ref()
406            .ok_or(super::ExpressionError::UnsupportedSpecial),
407        super::Params::Ref(r) => match &r.output {
408            Some(TaskOutput::Owned(o)) => Ok(o),
409            Some(TaskOutput::Ref(_)) => {
410                // We can't return a reference to TaskOutputRef as TaskOutputOwned,
411                // but in practice this path uses Owned. If we hit Ref, it's unsupported.
412                Err(super::ExpressionError::UnsupportedSpecial)
413            }
414            None => Err(super::ExpressionError::UnsupportedSpecial),
415        },
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the JSON wire format of `TaskOutputOwned`: bare numbers and
    /// arrays map to the score variants, and only `{"error": ...}` maps to
    /// `Err`.
    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        // A bare JSON number is a Scalar.
        let scalar: TaskOutputOwned = serde_json::from_str("94").unwrap();
        assert!(matches!(scalar, TaskOutputOwned::Scalar(_)));

        // A JSON array of numbers is a Vector.
        let vector: TaskOutputOwned =
            serde_json::from_str("[1, 2, 3]").unwrap();
        assert!(matches!(vector, TaskOutputOwned::Vector(_)));

        // A JSON array of arrays is a Vectors.
        let vectors: TaskOutputOwned =
            serde_json::from_str("[[1, 2], [3, 4]]").unwrap();
        assert!(matches!(vectors, TaskOutputOwned::Vectors(_)));

        // Bare values that are neither numbers nor arrays must be rejected:
        // Err is only ever written as `{"error": ...}`.
        assert!(serde_json::from_str::<TaskOutputOwned>("null").is_err());
        assert!(serde_json::from_str::<TaskOutputOwned>("true").is_err());
        assert!(serde_json::from_str::<TaskOutputOwned>(r#""94""#).is_err());

        // `{"error": ...}` is the canonical Err form; the inner value is
        // unwrapped by exactly one level.
        let err_string: TaskOutputOwned =
            serde_json::from_str(r#"{"error": "something"}"#).unwrap();
        assert!(matches!(
            err_string,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));

        let err_null: TaskOutputOwned =
            serde_json::from_str(r#"{"error": null}"#).unwrap();
        assert!(matches!(
            err_null,
            TaskOutputOwned::Err {
                error: serde_json::Value::Null
            }
        ));

        // Round-trip: Err { error: String("94") } <-> {"error":"94"}.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        let roundtripped: TaskOutputOwned =
            serde_json::from_str(&json).unwrap();
        assert!(matches!(
            roundtripped,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // An empty array must parse as one of the list variants; this test
        // accepts either Vector or Vectors.
        let empty: TaskOutputOwned = serde_json::from_str("[]").unwrap();
        assert!(matches!(
            empty,
            TaskOutputOwned::Vector(_) | TaskOutputOwned::Vectors(_)
        ));
    }
}
484
485fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
486    if v.is_empty() {
487        return Vec::new();
488    }
489    let sum: rust_decimal::Decimal = v.iter().map(|d| d.abs()).sum();
490    if sum.is_zero() {
491        let uniform =
492            rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(v.len());
493        vec![uniform; v.len()]
494    } else {
495        v.iter().map(|d| d / sum).collect()
496    }
497}
498
499/// Computes a weighted sum of scores where the first element has weight 0,
500/// the last element has weight 1, and intermediate elements are evenly spaced.
501fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
502    let len = scores.len();
503    if len <= 1 {
504        return scores.iter().sum();
505    }
506    let mut ws = rust_decimal::Decimal::ZERO;
507    let last = len - 1;
508    for (i, score) in scores.iter().enumerate() {
509        let weight =
510            rust_decimal::Decimal::from(i) / rust_decimal::Decimal::from(last);
511        ws += score * weight;
512    }
513    ws
514}