objectiveai_sdk/functions/expression/
starlark.rs1use starlark::environment::{Globals, GlobalsBuilder, Module};
6use starlark::eval::Evaluator;
7use starlark::starlark_module;
8use starlark::syntax::{AstModule, Dialect};
9use starlark::values::dict::DictRef;
10use starlark::values::float::UnpackFloat;
11use starlark::values::list::ListRef;
12use starlark::values::{Heap, UnpackValue, Value as StarlarkValue};
13use std::sync::LazyLock;
14
15use super::{ExpressionError, OneOrMany};
16
17pub static STARLARK_GLOBALS: LazyLock<Globals> = LazyLock::new(|| {
19 let mut builder = GlobalsBuilder::standard();
20 register_custom_functions(&mut builder);
21 builder.build()
22});
23
24#[starlark_module]
26fn register_custom_functions(builder: &mut GlobalsBuilder) {
27 fn sum<'v>(
29 #[starlark(require = pos)] xs: &ListRef<'v>,
30 ) -> starlark::Result<f64> {
31 let mut total = 0.0;
32 for x in xs.iter() {
33 let n = UnpackFloat::unpack_value(x)
34 .map_err(|e| {
35 starlark::Error::new_other(anyhow::anyhow!("{}", e))
36 })?
37 .ok_or_else(|| {
38 starlark::Error::new_other(anyhow::anyhow!(
39 "sum: expected number, got {}",
40 x.get_type()
41 ))
42 })?;
43 total += n.0;
44 }
45 Ok(total)
46 }
47
48 fn abs(#[starlark(require = pos)] x: UnpackFloat) -> starlark::Result<f64> {
50 Ok(x.0.abs())
51 }
52
53 fn float(
55 #[starlark(require = pos)] x: UnpackFloat,
56 ) -> starlark::Result<f64> {
57 Ok(x.0)
58 }
59
60 fn round(
62 #[starlark(require = pos)] x: UnpackFloat,
63 ) -> starlark::Result<i64> {
64 Ok(x.0.round() as i64)
65 }
66}
67
68pub trait ToStarlarkValue {
70 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v>;
71}
72
73impl ToStarlarkValue for str {
74 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
75 heap.alloc_str(self).to_value()
76 }
77}
78
79impl ToStarlarkValue for String {
80 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
81 heap.alloc_str(self).to_value()
82 }
83}
84
85impl ToStarlarkValue for i32 {
86 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
87 heap.alloc(*self as i64)
88 }
89}
90
91impl ToStarlarkValue for i64 {
92 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
93 heap.alloc(*self)
94 }
95}
96
97impl ToStarlarkValue for u32 {
98 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
99 heap.alloc(*self as i64)
100 }
101}
102
103impl ToStarlarkValue for u64 {
104 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
105 heap.alloc(*self as i64)
106 }
107}
108
109impl ToStarlarkValue for f64 {
110 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
111 heap.alloc(*self)
112 }
113}
114
115impl ToStarlarkValue for bool {
116 fn to_starlark_value<'v>(&self, _heap: &'v Heap) -> StarlarkValue<'v> {
117 StarlarkValue::new_bool(*self)
118 }
119}
120
121impl ToStarlarkValue for rust_decimal::Decimal {
122 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
123 use rust_decimal::prelude::ToPrimitive;
124 heap.alloc(self.to_f64().unwrap_or(0.0))
125 }
126}
127
128impl<T: ToStarlarkValue> ToStarlarkValue for Vec<T> {
129 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
130 let items: Vec<StarlarkValue> =
131 self.iter().map(|v| v.to_starlark_value(heap)).collect();
132 heap.alloc(items)
133 }
134}
135
136impl<T: ToStarlarkValue> ToStarlarkValue for [T] {
137 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
138 let items: Vec<StarlarkValue> =
139 self.iter().map(|v| v.to_starlark_value(heap)).collect();
140 heap.alloc(items)
141 }
142}
143
144impl<T: ToStarlarkValue> ToStarlarkValue for Option<T> {
145 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
146 match self {
147 Some(v) => v.to_starlark_value(heap),
148 None => StarlarkValue::new_none(),
149 }
150 }
151}
152
153impl<T: ToStarlarkValue> ToStarlarkValue for indexmap::IndexMap<String, T> {
154 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
155 let pairs: Vec<(&str, StarlarkValue)> = self
156 .iter()
157 .map(|(k, v)| (k.as_str(), v.to_starlark_value(heap)))
158 .collect();
159 heap.alloc(starlark::values::dict::AllocDict(pairs))
160 }
161}
162
163impl ToStarlarkValue for serde_json::Value {
164 fn to_starlark_value<'v>(&self, heap: &'v Heap) -> StarlarkValue<'v> {
165 match self {
166 serde_json::Value::Null => StarlarkValue::new_none(),
167 serde_json::Value::Bool(b) => b.to_starlark_value(heap),
168 serde_json::Value::Number(n) => {
169 if let Some(i) = n.as_i64() {
170 i.to_starlark_value(heap)
171 } else {
172 n.as_f64().unwrap_or(0.0).to_starlark_value(heap)
173 }
174 }
175 serde_json::Value::String(s) => s.to_starlark_value(heap),
176 serde_json::Value::Array(arr) => {
177 let items: Vec<StarlarkValue> =
178 arr.iter().map(|v| v.to_starlark_value(heap)).collect();
179 heap.alloc(items)
180 }
181 serde_json::Value::Object(obj) => {
182 let pairs: Vec<(&str, StarlarkValue)> = obj
183 .iter()
184 .map(|(k, v)| (k.as_str(), v.to_starlark_value(heap)))
185 .collect();
186 heap.alloc(starlark::values::dict::AllocDict(pairs))
187 }
188 }
189 }
190}
191pub trait FromStarlarkValue: Sized {
196 fn from_starlark_value(
197 value: &StarlarkValue,
198 ) -> Result<Self, ExpressionError>;
199}
200
201impl FromStarlarkValue for rust_decimal::Decimal {
203 fn from_starlark_value(
204 value: &StarlarkValue,
205 ) -> Result<Self, ExpressionError> {
206 if let Ok(Some(i)) = i64::unpack_value(*value) {
207 return Ok(rust_decimal::Decimal::from(i));
208 }
209 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
210 return rust_decimal::Decimal::try_from(f).map_err(|e| {
211 ExpressionError::StarlarkConversionError(format!(
212 "Decimal: {}",
213 e
214 ))
215 });
216 }
217 Err(ExpressionError::StarlarkConversionError(
218 "Decimal: expected number".into(),
219 ))
220 }
221}
222
223impl FromStarlarkValue for bool {
224 fn from_starlark_value(
225 value: &StarlarkValue,
226 ) -> Result<Self, ExpressionError> {
227 bool::unpack_value(*value)
228 .map_err(|e| {
229 ExpressionError::StarlarkConversionError(e.to_string())
230 })
231 .and_then(|o| {
232 o.ok_or_else(|| {
233 ExpressionError::StarlarkConversionError(
234 "expected bool".to_string(),
235 )
236 })
237 })
238 }
239}
240
241impl FromStarlarkValue for i64 {
242 fn from_starlark_value(
243 value: &StarlarkValue,
244 ) -> Result<Self, ExpressionError> {
245 i64::unpack_value(*value)
246 .map_err(|e| {
247 ExpressionError::StarlarkConversionError(e.to_string())
248 })
249 .and_then(|o| {
250 o.ok_or_else(|| {
251 ExpressionError::StarlarkConversionError(
252 "expected int".to_string(),
253 )
254 })
255 })
256 }
257}
258
259impl FromStarlarkValue for u64 {
260 fn from_starlark_value(
261 value: &StarlarkValue,
262 ) -> Result<Self, ExpressionError> {
263 let i = i64::unpack_value(*value)
264 .map_err(|e| {
265 ExpressionError::StarlarkConversionError(e.to_string())
266 })?
267 .ok_or_else(|| {
268 ExpressionError::StarlarkConversionError(
269 "expected int".to_string(),
270 )
271 })?;
272 if i < 0 {
273 return Err(ExpressionError::StarlarkConversionError(
274 "expected non-negative int".to_string(),
275 ));
276 }
277 Ok(i as u64)
278 }
279}
280
281impl FromStarlarkValue for f64 {
282 fn from_starlark_value(
283 value: &StarlarkValue,
284 ) -> Result<Self, ExpressionError> {
285 if let Ok(Some(i)) = i64::unpack_value(*value) {
286 return Ok(i as f64);
287 }
288 UnpackFloat::unpack_value(*value)
289 .map_err(|e| {
290 ExpressionError::StarlarkConversionError(e.to_string())
291 })
292 .and_then(|o| {
293 o.ok_or_else(|| {
294 ExpressionError::StarlarkConversionError(
295 "expected number".to_string(),
296 )
297 })
298 })
299 .map(|u| u.0)
300 }
301}
302
303impl FromStarlarkValue for String {
304 fn from_starlark_value(
305 value: &StarlarkValue,
306 ) -> Result<Self, ExpressionError> {
307 <&str as UnpackValue>::unpack_value(*value)
308 .map_err(|e| {
309 ExpressionError::StarlarkConversionError(e.to_string())
310 })?
311 .map(|s| s.to_owned())
312 .ok_or_else(|| {
313 ExpressionError::StarlarkConversionError(
314 "expected string".to_string(),
315 )
316 })
317 }
318}
319
320impl<T: FromStarlarkValue> FromStarlarkValue for Option<T> {
321 fn from_starlark_value(
322 value: &StarlarkValue,
323 ) -> Result<Self, ExpressionError> {
324 if value.is_none() {
325 return Ok(None);
326 }
327 T::from_starlark_value(value).map(Some)
328 }
329}
330
331impl<T: FromStarlarkValue> FromStarlarkValue for Vec<T> {
332 fn from_starlark_value(
333 value: &StarlarkValue,
334 ) -> Result<Self, ExpressionError> {
335 let list = ListRef::from_value(*value).ok_or_else(|| {
336 ExpressionError::StarlarkConversionError(
337 "expected list".to_string(),
338 )
339 })?;
340 let mut out = Vec::with_capacity(list.len());
341 for v in list.iter() {
342 out.push(T::from_starlark_value(&v)?);
343 }
344 Ok(out)
345 }
346}
347
348impl<V: FromStarlarkValue> FromStarlarkValue for indexmap::IndexMap<String, V> {
349 fn from_starlark_value(
350 value: &StarlarkValue,
351 ) -> Result<Self, ExpressionError> {
352 let dict = DictRef::from_value(*value).ok_or_else(|| {
353 ExpressionError::StarlarkConversionError(
354 "expected dict".to_string(),
355 )
356 })?;
357 let mut map = indexmap::IndexMap::with_capacity(dict.len());
358 for (k, v) in dict.iter() {
359 let key = <&str as UnpackValue>::unpack_value(k)
360 .map_err(|e| {
361 ExpressionError::StarlarkConversionError(e.to_string())
362 })?
363 .ok_or_else(|| {
364 ExpressionError::StarlarkConversionError(
365 "expected string key".to_string(),
366 )
367 })?
368 .to_owned();
369 map.insert(key, V::from_starlark_value(&v)?);
370 }
371 Ok(map)
372 }
373}
374
375impl FromStarlarkValue for serde_json::Value {
376 fn from_starlark_value(
377 value: &StarlarkValue,
378 ) -> Result<Self, ExpressionError> {
379 if value.is_none() {
380 return Ok(serde_json::Value::Null);
381 }
382 if let Ok(Some(b)) = bool::unpack_value(*value) {
383 return Ok(serde_json::Value::Bool(b));
384 }
385 if let Ok(Some(i)) = i64::unpack_value(*value) {
386 return Ok(serde_json::Value::Number(serde_json::Number::from(i)));
387 }
388 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
389 if let Some(n) = serde_json::Number::from_f64(f) {
390 return Ok(serde_json::Value::Number(n));
391 }
392 }
393 if let Ok(Some(s)) = <&str as UnpackValue>::unpack_value(*value) {
394 return Ok(serde_json::Value::String(s.to_owned()));
395 }
396 if let Some(list) = ListRef::from_value(*value) {
397 let mut items = Vec::with_capacity(list.len());
398 for v in list.iter() {
399 items.push(serde_json::Value::from_starlark_value(&v)?);
400 }
401 return Ok(serde_json::Value::Array(items));
402 }
403 if let Some(dict) = DictRef::from_value(*value) {
404 let mut obj = serde_json::Map::with_capacity(dict.len());
405 for (k, v) in dict.iter() {
406 let key = <&str as UnpackValue>::unpack_value(k)
407 .map_err(|e| {
408 ExpressionError::StarlarkConversionError(e.to_string())
409 })?
410 .ok_or_else(|| {
411 ExpressionError::StarlarkConversionError(
412 "expected string key".to_string(),
413 )
414 })?;
415 obj.insert(
416 key.to_owned(),
417 serde_json::Value::from_starlark_value(&v)?,
418 );
419 }
420 return Ok(serde_json::Value::Object(obj));
421 }
422 Err(ExpressionError::StarlarkConversionError(format!(
423 "unsupported type: {}",
424 value.get_type()
425 )))
426 }
427}
428
429pub(crate) fn with_eval_result<F, R>(
431 code: &str,
432 params: &super::Params,
433 f: F,
434) -> Result<R, ExpressionError>
435where
436 F: FnOnce(&StarlarkValue) -> Result<R, ExpressionError>,
437{
438 let module = Module::new();
439 {
440 let heap = module.heap();
441 match params {
442 super::Params::Owned(owned) => {
443 module.set("input", owned.input.to_starlark_value(heap));
444 module.set(
445 "output",
446 owned
447 .output
448 .as_ref()
449 .map_or(StarlarkValue::new_none(), |o| {
450 o.to_starlark_value(heap)
451 }),
452 );
453 module.set(
454 "map",
455 owned.map.map_or(StarlarkValue::new_none(), |m| {
456 heap.alloc(m as i64)
457 }),
458 );
459 module.set(
460 "tasks_min",
461 owned.tasks_min.map_or(StarlarkValue::new_none(), |v| {
462 heap.alloc(v as i64)
463 }),
464 );
465 module.set(
466 "tasks_max",
467 owned.tasks_max.map_or(StarlarkValue::new_none(), |v| {
468 heap.alloc(v as i64)
469 }),
470 );
471 module.set(
472 "depth",
473 owned.depth.map_or(StarlarkValue::new_none(), |v| {
474 heap.alloc(v as i64)
475 }),
476 );
477 module.set(
478 "name",
479 owned.name.as_ref().map_or(StarlarkValue::new_none(), |v| {
480 heap.alloc(v.as_str())
481 }),
482 );
483 module.set(
484 "spec",
485 owned.spec.as_ref().map_or(StarlarkValue::new_none(), |v| {
486 heap.alloc(v.as_str())
487 }),
488 );
489 }
490 super::Params::Ref(r) => {
491 module.set("input", r.input.to_starlark_value(heap));
492 module.set(
493 "output",
494 r.output.as_ref().map_or(StarlarkValue::new_none(), |o| {
495 o.to_starlark_value(heap)
496 }),
497 );
498 module.set(
499 "map",
500 r.map.map_or(StarlarkValue::new_none(), |m| {
501 heap.alloc(m as i64)
502 }),
503 );
504 module.set(
505 "tasks_min",
506 r.tasks_min.map_or(StarlarkValue::new_none(), |v| {
507 heap.alloc(v as i64)
508 }),
509 );
510 module.set(
511 "tasks_max",
512 r.tasks_max.map_or(StarlarkValue::new_none(), |v| {
513 heap.alloc(v as i64)
514 }),
515 );
516 module.set(
517 "depth",
518 r.depth.map_or(StarlarkValue::new_none(), |v| {
519 heap.alloc(v as i64)
520 }),
521 );
522 module.set(
523 "name",
524 r.name.map_or(StarlarkValue::new_none(), |v| {
525 heap.alloc(v)
526 }),
527 );
528 module.set(
529 "spec",
530 r.spec.map_or(StarlarkValue::new_none(), |v| {
531 heap.alloc(v)
532 }),
533 );
534 }
535 }
536 }
537 let ast =
538 AstModule::parse("expression", code.to_string(), &Dialect::Extended)
539 .map_err(|e| ExpressionError::StarlarkParseError(e.to_string()))?;
540 let mut eval = Evaluator::new(&module);
541 let result = eval
542 .eval_module(ast, &STARLARK_GLOBALS)
543 .map_err(|e| ExpressionError::StarlarkEvalError(e.to_string()))?;
544 f(&result)
545}
546
547fn svalue_to_one_or_many<T: FromStarlarkValue>(
548 value: &StarlarkValue,
549) -> Result<OneOrMany<T>, ExpressionError> {
550 if value.is_none() {
551 return Ok(OneOrMany::Many(Vec::new()));
552 }
553 if let Ok(v) = T::from_starlark_value(value) {
554 return Ok(OneOrMany::One(v));
555 }
556 if let Some(list) = ListRef::from_value(*value) {
557 let mut vs: Vec<T> = Vec::with_capacity(list.len());
558 for v in list.iter() {
559 if let Some(opt) = Option::<T>::from_starlark_value(&v)? {
560 vs.push(opt);
561 }
562 }
563 return Ok(if vs.is_empty() {
564 OneOrMany::Many(Vec::new())
565 } else if vs.len() == 1 {
566 OneOrMany::One(vs.into_iter().next().unwrap())
567 } else {
568 OneOrMany::Many(vs)
569 });
570 }
571 match Option::<T>::from_starlark_value(value)? {
572 Some(v) => Ok(OneOrMany::One(v)),
573 None => Ok(OneOrMany::Many(Vec::new())),
574 }
575}
576
577impl<T: FromStarlarkValue> FromStarlarkValue for OneOrMany<T> {
578 fn from_starlark_value(
579 value: &StarlarkValue,
580 ) -> Result<Self, ExpressionError> {
581 svalue_to_one_or_many(value)
582 }
583}
584
585impl<T: FromStarlarkValue> OneOrMany<T> {
586 pub fn from_starlark(
587 code: &str,
588 params: &super::Params,
589 ) -> Result<Self, ExpressionError> {
590 with_eval_result(code, params, svalue_to_one_or_many)
591 }
592}
593