Skip to main content

objectiveai_sdk/functions/expression/
starlark.rs

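//! Starlark expression evaluation engine.
//! Provides a sandboxed Starlark runtime for evaluating expressions.
//! The variables `input`, `output`, `map`, `tasks_min`, `tasks_max`, `depth`,
//! `name`, and `spec` are injected into the global scope of each evaluation,
//! and the standard library is extended with `sum`, `abs`, `float`, and `round`.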
4
5use starlark::environment::{Globals, GlobalsBuilder, Module};
6use starlark::eval::Evaluator;
7use starlark::starlark_module;
8use starlark::syntax::{AstModule, Dialect};
9use starlark::values::dict::DictRef;
10use starlark::values::float::UnpackFloat;
11use starlark::values::list::ListRef;
12use starlark::values::{Heap, UnpackValue, Value as StarlarkValue};
13use std::sync::LazyLock;
14
15use super::{ExpressionError, OneOrMany};
16
17/// Global Starlark globals with custom functions.
18pub static STARLARK_GLOBALS: LazyLock<Globals> = LazyLock::new(|| {
19    let mut builder = GlobalsBuilder::standard();
20    register_custom_functions(&mut builder);
21    builder.build()
22});
23
24/// Register custom functions that extend Starlark's standard library.
25#[starlark_module]
26fn register_custom_functions(builder: &mut GlobalsBuilder) {
27    /// Sum of a list of numbers. Returns 0 for empty list.
28    fn sum<'v>(
29        #[starlark(require = pos)] xs: &ListRef<'v>,
30    ) -> starlark::Result<f64> {
31        let mut total = 0.0;
32        for x in xs.iter() {
33            let n = UnpackFloat::unpack_value(x)
34                .map_err(|e| {
35                    starlark::Error::new_other(anyhow::anyhow!("{}", e))
36                })?
37                .ok_or_else(|| {
38                    starlark::Error::new_other(anyhow::anyhow!(
39                        "sum: expected number, got {}",
40                        x.get_type()
41                    ))
42                })?;
43            total += n.0;
44        }
45        Ok(total)
46    }
47
48    /// Absolute value of a number.
49    fn abs(#[starlark(require = pos)] x: UnpackFloat) -> starlark::Result<f64> {
50        Ok(x.0.abs())
51    }
52
53    /// Convert to float.
54    fn float(
55        #[starlark(require = pos)] x: UnpackFloat,
56    ) -> starlark::Result<f64> {
57        Ok(x.0)
58    }
59
60    /// Round a number to the nearest integer.
61    fn round(
62        #[starlark(require = pos)] x: UnpackFloat,
63    ) -> starlark::Result<i64> {
64        Ok(x.0.round() as i64)
65    }
66}
67
68/// Trait for direct conversion to Starlark values (bypassing serde_json).
69pub trait ToStarlarkValue {
70    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v>;
71}
72
73impl ToStarlarkValue for str {
74    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
75        heap.alloc_str(self).to_value()
76    }
77}
78
79impl ToStarlarkValue for String {
80    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
81        heap.alloc_str(self).to_value()
82    }
83}
84
85impl ToStarlarkValue for i32 {
86    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
87        heap.alloc(*self as i64)
88    }
89}
90
91impl ToStarlarkValue for i64 {
92    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
93        heap.alloc(*self)
94    }
95}
96
97impl ToStarlarkValue for u32 {
98    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
99        heap.alloc(*self as i64)
100    }
101}
102
103impl ToStarlarkValue for u64 {
104    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
105        heap.alloc(*self as i64)
106    }
107}
108
109impl ToStarlarkValue for f64 {
110    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
111        heap.alloc(*self)
112    }
113}
114
115impl ToStarlarkValue for bool {
116    fn to_starlark_value<'v>(&self, _heap: &'v Heap) -> StarlarkValue<'v> {
117        StarlarkValue::new_bool(*self)
118    }
119}
120
121impl ToStarlarkValue for rust_decimal::Decimal {
122    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
123        use rust_decimal::prelude::ToPrimitive;
124        heap.alloc(self.to_f64().unwrap_or(0.0))
125    }
126}
127
128impl<T: ToStarlarkValue> ToStarlarkValue for Vec<T> {
129    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
130        let items: Vec<StarlarkValue> =
131            self.iter().map(|v| v.to_starlark_value(heap)).collect();
132        heap.alloc(items)
133    }
134}
135
136impl<T: ToStarlarkValue> ToStarlarkValue for [T] {
137    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
138        let items: Vec<StarlarkValue> =
139            self.iter().map(|v| v.to_starlark_value(heap)).collect();
140        heap.alloc(items)
141    }
142}
143
144impl<T: ToStarlarkValue> ToStarlarkValue for Option<T> {
145    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
146        match self {
147            Some(v) => v.to_starlark_value(heap),
148            None => StarlarkValue::new_none(),
149        }
150    }
151}
152
153impl<T: ToStarlarkValue> ToStarlarkValue for indexmap::IndexMap<String, T> {
154    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
155        let pairs: Vec<(&str, StarlarkValue)> = self
156            .iter()
157            .map(|(k, v)| (k.as_str(), v.to_starlark_value(heap)))
158            .collect();
159        heap.alloc(starlark::values::dict::AllocDict(pairs))
160    }
161}
162
163impl ToStarlarkValue for serde_json::Value {
164    fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
165        match self {
166            serde_json::Value::Null => StarlarkValue::new_none(),
167            serde_json::Value::Bool(b) => b.to_starlark_value(heap),
168            serde_json::Value::Number(n) => {
169                if let Some(i) = n.as_i64() {
170                    i.to_starlark_value(heap)
171                } else {
172                    n.as_f64().unwrap_or(0.0).to_starlark_value(heap)
173                }
174            }
175            serde_json::Value::String(s) => s.to_starlark_value(heap),
176            serde_json::Value::Array(arr) => {
177                let items: Vec<StarlarkValue> =
178                    arr.iter().map(|v| v.to_starlark_value(heap)).collect();
179                heap.alloc(items)
180            }
181            serde_json::Value::Object(obj) => {
182                let pairs: Vec<(&str, StarlarkValue)> = obj
183                    .iter()
184                    .map(|(k, v)| (k.as_str(), v.to_starlark_value(heap)))
185                    .collect();
186                heap.alloc(starlark::values::dict::AllocDict(pairs))
187            }
188        }
189    }
190}
191/// Trait for converting a Starlark runtime value into a Rust type.
192///
193/// Used by [`Expression`](super::Expression) to compile Starlark expressions
194/// directly from `starlark::values::Value` to the target type.
195pub trait FromStarlarkValue: Sized {
196    fn from_starlark_value(
197        value: &StarlarkValue,
198    ) -> Result<Self, ExpressionError>;
199}
200
201// Primitives and common types
202impl FromStarlarkValue for rust_decimal::Decimal {
203    fn from_starlark_value(
204        value: &StarlarkValue,
205    ) -> Result<Self, ExpressionError> {
206        if let Ok(Some(i)) = i64::unpack_value(*value) {
207            return Ok(rust_decimal::Decimal::from(i));
208        }
209        if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
210            return rust_decimal::Decimal::try_from(f).map_err(|e| {
211                ExpressionError::StarlarkConversionError(format!(
212                    "Decimal: {}",
213                    e
214                ))
215            });
216        }
217        Err(ExpressionError::StarlarkConversionError(
218            "Decimal: expected number".into(),
219        ))
220    }
221}
222
223impl FromStarlarkValue for bool {
224    fn from_starlark_value(
225        value: &StarlarkValue,
226    ) -> Result<Self, ExpressionError> {
227        bool::unpack_value(*value)
228            .map_err(|e| {
229                ExpressionError::StarlarkConversionError(e.to_string())
230            })
231            .and_then(|o| {
232                o.ok_or_else(|| {
233                    ExpressionError::StarlarkConversionError(
234                        "expected bool".to_string(),
235                    )
236                })
237            })
238    }
239}
240
241impl FromStarlarkValue for i64 {
242    fn from_starlark_value(
243        value: &StarlarkValue,
244    ) -> Result<Self, ExpressionError> {
245        i64::unpack_value(*value)
246            .map_err(|e| {
247                ExpressionError::StarlarkConversionError(e.to_string())
248            })
249            .and_then(|o| {
250                o.ok_or_else(|| {
251                    ExpressionError::StarlarkConversionError(
252                        "expected int".to_string(),
253                    )
254                })
255            })
256    }
257}
258
259impl FromStarlarkValue for u64 {
260    fn from_starlark_value(
261        value: &StarlarkValue,
262    ) -> Result<Self, ExpressionError> {
263        let i = i64::unpack_value(*value)
264            .map_err(|e| {
265                ExpressionError::StarlarkConversionError(e.to_string())
266            })?
267            .ok_or_else(|| {
268                ExpressionError::StarlarkConversionError(
269                    "expected int".to_string(),
270                )
271            })?;
272        if i < 0 {
273            return Err(ExpressionError::StarlarkConversionError(
274                "expected non-negative int".to_string(),
275            ));
276        }
277        Ok(i as u64)
278    }
279}
280
281impl FromStarlarkValue for f64 {
282    fn from_starlark_value(
283        value: &StarlarkValue,
284    ) -> Result<Self, ExpressionError> {
285        if let Ok(Some(i)) = i64::unpack_value(*value) {
286            return Ok(i as f64);
287        }
288        UnpackFloat::unpack_value(*value)
289            .map_err(|e| {
290                ExpressionError::StarlarkConversionError(e.to_string())
291            })
292            .and_then(|o| {
293                o.ok_or_else(|| {
294                    ExpressionError::StarlarkConversionError(
295                        "expected number".to_string(),
296                    )
297                })
298            })
299            .map(|u| u.0)
300    }
301}
302
303impl FromStarlarkValue for String {
304    fn from_starlark_value(
305        value: &StarlarkValue,
306    ) -> Result<Self, ExpressionError> {
307        <&str as UnpackValue>::unpack_value(*value)
308            .map_err(|e| {
309                ExpressionError::StarlarkConversionError(e.to_string())
310            })?
311            .map(|s| s.to_owned())
312            .ok_or_else(|| {
313                ExpressionError::StarlarkConversionError(
314                    "expected string".to_string(),
315                )
316            })
317    }
318}
319
320impl<T: FromStarlarkValue> FromStarlarkValue for Option<T> {
321    fn from_starlark_value(
322        value: &StarlarkValue,
323    ) -> Result<Self, ExpressionError> {
324        if value.is_none() {
325            return Ok(None);
326        }
327        T::from_starlark_value(value).map(Some)
328    }
329}
330
331impl<T: FromStarlarkValue> FromStarlarkValue for Vec<T> {
332    fn from_starlark_value(
333        value: &StarlarkValue,
334    ) -> Result<Self, ExpressionError> {
335        let list = ListRef::from_value(*value).ok_or_else(|| {
336            ExpressionError::StarlarkConversionError(
337                "expected list".to_string(),
338            )
339        })?;
340        let mut out = Vec::with_capacity(list.len());
341        for v in list.iter() {
342            out.push(T::from_starlark_value(&v)?);
343        }
344        Ok(out)
345    }
346}
347
348impl<V: FromStarlarkValue> FromStarlarkValue for indexmap::IndexMap<String, V> {
349    fn from_starlark_value(
350        value: &StarlarkValue,
351    ) -> Result<Self, ExpressionError> {
352        let dict = DictRef::from_value(*value).ok_or_else(|| {
353            ExpressionError::StarlarkConversionError(
354                "expected dict".to_string(),
355            )
356        })?;
357        let mut map = indexmap::IndexMap::with_capacity(dict.len());
358        for (k, v) in dict.iter() {
359            let key = <&str as UnpackValue>::unpack_value(k)
360                .map_err(|e| {
361                    ExpressionError::StarlarkConversionError(e.to_string())
362                })?
363                .ok_or_else(|| {
364                    ExpressionError::StarlarkConversionError(
365                        "expected string key".to_string(),
366                    )
367                })?
368                .to_owned();
369            map.insert(key, V::from_starlark_value(&v)?);
370        }
371        Ok(map)
372    }
373}
374
375impl FromStarlarkValue for serde_json::Value {
376    fn from_starlark_value(
377        value: &StarlarkValue,
378    ) -> Result<Self, ExpressionError> {
379        if value.is_none() {
380            return Ok(serde_json::Value::Null);
381        }
382        if let Ok(Some(b)) = bool::unpack_value(*value) {
383            return Ok(serde_json::Value::Bool(b));
384        }
385        if let Ok(Some(i)) = i64::unpack_value(*value) {
386            return Ok(serde_json::Value::Number(serde_json::Number::from(i)));
387        }
388        if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
389            if let Some(n) = serde_json::Number::from_f64(f) {
390                return Ok(serde_json::Value::Number(n));
391            }
392        }
393        if let Ok(Some(s)) = <&str as UnpackValue>::unpack_value(*value) {
394            return Ok(serde_json::Value::String(s.to_owned()));
395        }
396        if let Some(list) = ListRef::from_value(*value) {
397            let mut items = Vec::with_capacity(list.len());
398            for v in list.iter() {
399                items.push(serde_json::Value::from_starlark_value(&v)?);
400            }
401            return Ok(serde_json::Value::Array(items));
402        }
403        if let Some(dict) = DictRef::from_value(*value) {
404            let mut obj = serde_json::Map::with_capacity(dict.len());
405            for (k, v) in dict.iter() {
406                let key = <&str as UnpackValue>::unpack_value(k)
407                    .map_err(|e| {
408                        ExpressionError::StarlarkConversionError(e.to_string())
409                    })?
410                    .ok_or_else(|| {
411                        ExpressionError::StarlarkConversionError(
412                            "expected string key".to_string(),
413                        )
414                    })?;
415                obj.insert(
416                    key.to_owned(),
417                    serde_json::Value::from_starlark_value(&v)?,
418                );
419            }
420            return Ok(serde_json::Value::Object(obj));
421        }
422        Err(ExpressionError::StarlarkConversionError(format!(
423            "unsupported type: {}",
424            value.get_type()
425        )))
426    }
427}
428
429/// Run a Starlark expression and pass the result (while still valid) to `f`.
430pub(crate) fn with_eval_result<F, R>(
431    code: &str,
432    params: &super::Params,
433    f: F,
434) -> Result<R, ExpressionError>
435where
436    F: FnOnce(&StarlarkValue) -> Result<R, ExpressionError>,
437{
438    let module = Module::new();
439    {
440        let heap = module.heap();
441        match params {
442            super::Params::Owned(owned) => {
443                module.set("input", owned.input.to_starlark_value(heap));
444                module.set(
445                    "output",
446                    owned
447                        .output
448                        .as_ref()
449                        .map_or(StarlarkValue::new_none(), |o| {
450                            o.to_starlark_value(heap)
451                        }),
452                );
453                module.set(
454                    "map",
455                    owned.map.map_or(StarlarkValue::new_none(), |m| {
456                        heap.alloc(m as i64)
457                    }),
458                );
459                module.set(
460                    "tasks_min",
461                    owned.tasks_min.map_or(StarlarkValue::new_none(), |v| {
462                        heap.alloc(v as i64)
463                    }),
464                );
465                module.set(
466                    "tasks_max",
467                    owned.tasks_max.map_or(StarlarkValue::new_none(), |v| {
468                        heap.alloc(v as i64)
469                    }),
470                );
471                module.set(
472                    "depth",
473                    owned.depth.map_or(StarlarkValue::new_none(), |v| {
474                        heap.alloc(v as i64)
475                    }),
476                );
477                module.set(
478                    "name",
479                    owned.name.as_ref().map_or(StarlarkValue::new_none(), |v| {
480                        heap.alloc(v.as_str())
481                    }),
482                );
483                module.set(
484                    "spec",
485                    owned.spec.as_ref().map_or(StarlarkValue::new_none(), |v| {
486                        heap.alloc(v.as_str())
487                    }),
488                );
489            }
490            super::Params::Ref(r) => {
491                module.set("input", r.input.to_starlark_value(heap));
492                module.set(
493                    "output",
494                    r.output.as_ref().map_or(StarlarkValue::new_none(), |o| {
495                        o.to_starlark_value(heap)
496                    }),
497                );
498                module.set(
499                    "map",
500                    r.map.map_or(StarlarkValue::new_none(), |m| {
501                        heap.alloc(m as i64)
502                    }),
503                );
504                module.set(
505                    "tasks_min",
506                    r.tasks_min.map_or(StarlarkValue::new_none(), |v| {
507                        heap.alloc(v as i64)
508                    }),
509                );
510                module.set(
511                    "tasks_max",
512                    r.tasks_max.map_or(StarlarkValue::new_none(), |v| {
513                        heap.alloc(v as i64)
514                    }),
515                );
516                module.set(
517                    "depth",
518                    r.depth.map_or(StarlarkValue::new_none(), |v| {
519                        heap.alloc(v as i64)
520                    }),
521                );
522                module.set(
523                    "name",
524                    r.name.map_or(StarlarkValue::new_none(), |v| {
525                        heap.alloc(v)
526                    }),
527                );
528                module.set(
529                    "spec",
530                    r.spec.map_or(StarlarkValue::new_none(), |v| {
531                        heap.alloc(v)
532                    }),
533                );
534            }
535        }
536    }
537    let ast =
538        AstModule::parse("expression", code.to_string(), &Dialect::Extended)
539            .map_err(|e| ExpressionError::StarlarkParseError(e.to_string()))?;
540    let mut eval = Evaluator::new(&module);
541    let result = eval
542        .eval_module(ast, &STARLARK_GLOBALS)
543        .map_err(|e| ExpressionError::StarlarkEvalError(e.to_string()))?;
544    f(&result)
545}
546
547fn svalue_to_one_or_many<T: FromStarlarkValue>(
548    value: &StarlarkValue,
549) -> Result<OneOrMany<T>, ExpressionError> {
550    if value.is_none() {
551        return Ok(OneOrMany::Many(Vec::new()));
552    }
553    if let Ok(v) = T::from_starlark_value(value) {
554        return Ok(OneOrMany::One(v));
555    }
556    if let Some(list) = ListRef::from_value(*value) {
557        let mut vs: Vec<T> = Vec::with_capacity(list.len());
558        for v in list.iter() {
559            if let Some(opt) = Option::<T>::from_starlark_value(&v)? {
560                vs.push(opt);
561            }
562        }
563        return Ok(if vs.is_empty() {
564            OneOrMany::Many(Vec::new())
565        } else if vs.len() == 1 {
566            OneOrMany::One(vs.into_iter().next().unwrap())
567        } else {
568            OneOrMany::Many(vs)
569        });
570    }
571    match Option::<T>::from_starlark_value(value)? {
572        Some(v) => Ok(OneOrMany::One(v)),
573        None => Ok(OneOrMany::Many(Vec::new())),
574    }
575}
576
577impl<T: FromStarlarkValue> FromStarlarkValue for OneOrMany<T> {
578    fn from_starlark_value(
579        value: &StarlarkValue,
580    ) -> Result<Self, ExpressionError> {
581        svalue_to_one_or_many(value)
582    }
583}
584
585impl<T: FromStarlarkValue> OneOrMany<T> {
586    pub fn from_starlark(
587        code: &str,
588        params: &super::Params,
589    ) -> Result<Self, ExpressionError> {
590        with_eval_result(code, params, svalue_to_one_or_many)
591    }
592}
593