objectiveai_sdk/functions/expression/
params.rs1use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
7use objectiveai_sdk_macros::schema_override;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use starlark::values::{
11 Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
12};
13
// Expression parameters, held either by value or by reference.
//
// Serialization is `untagged`: a value is written exactly as the wrapped
// `ParamsOwned` / `ParamsRef` would be, with no variant wrapper around it.
// Deserialization is only implemented for `Params<'static, 'static>` and
// always yields the `Owned` variant.
//
// NOTE(review): plain `//` comments are used on purpose. `///` doc attributes
// are visible to attribute/derive macros (schemars copies them into schema
// descriptions) — confirm `schema_override` ignores them before converting.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Params<'i, 'to> {
    // Parameters that own all of their data.
    Owned(ParamsOwned),
    // Parameters borrowing their input for `'i` and their task output for `'to`.
    Ref(ParamsRef<'i, 'to>),
}
26
27impl JsonSchema for Params<'static, 'static> {
28 fn schema_name() -> std::borrow::Cow<'static, str> {
29 ParamsOwned::schema_name()
30 }
31 fn json_schema(
32 generator: &mut schemars::SchemaGenerator,
33 ) -> schemars::Schema {
34 ParamsOwned::json_schema(generator)
35 }
36}
37
38impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
39 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40 where
41 D: serde::Deserializer<'de>,
42 {
43 let owned = ParamsOwned::deserialize(deserializer)?;
44 Ok(Params::Owned(owned))
45 }
46}
47
// Owned expression parameters; this is the form `Params` deserializes into.
// Appears in generated schemas under the name `functions.expression.Params`.
//
// NOTE(review): comments here are `//` rather than `///` on purpose —
// `#[derive(JsonSchema)]` copies doc comments into the generated schema as
// descriptions, so switching to `///` would change the emitted schema.
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "functions.expression.Params")]
pub struct ParamsOwned {
    // The expression's input value.
    pub input: super::InputValue,
    // Task output, if any. `params_output` in this file reads it to serve the
    // `Output`, `TaskOutputL1Normalized` and `TaskOutputWeightedSum` specials.
    pub output: Option<TaskOutputOwned>,
    // NOTE(review): presumably the index of the current mapped element —
    // semantics are not visible in this file; confirm against callers.
    pub map: Option<u64>,
    // NOTE(review): presumably a lower bound on the number of tasks — confirm.
    pub tasks_min: Option<u64>,
    // NOTE(review): presumably an upper bound on the number of tasks — confirm.
    pub tasks_max: Option<u64>,
    // NOTE(review): presumably a nesting depth — confirm.
    pub depth: Option<u64>,
    // Optional name string; its meaning is defined by callers (TODO confirm).
    pub name: Option<String>,
    // Optional spec string; its meaning is defined by callers (TODO confirm).
    pub spec: Option<String>,
}
75
// Borrowed counterpart of `ParamsOwned`: the same fields, but `input`, `name`
// and `spec` are borrowed for `'i` and the task output may borrow for `'to`.
// Only `Serialize` is derived; there is no way to deserialize into this type.
//
// NOTE(review): `//` comments are used so that no doc attributes reach the
// `schema_override` macro — confirm whether that matters before converting.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ParamsRef<'i, 'to> {
    // The expression's input value, borrowed.
    pub input: &'i super::InputValue,
    // Task output, if any, in owned or borrowed form. Note that
    // `params_output` in this file only accepts the owned form.
    pub output: Option<TaskOutput<'to>>,
    // NOTE(review): presumably the index of the current mapped element — confirm.
    pub map: Option<u64>,
    // NOTE(review): presumably a lower bound on the number of tasks — confirm.
    pub tasks_min: Option<u64>,
    // NOTE(review): presumably an upper bound on the number of tasks — confirm.
    pub tasks_max: Option<u64>,
    // NOTE(review): presumably a nesting depth — confirm.
    pub depth: Option<u64>,
    // Optional name string, borrowed; meaning defined by callers (TODO confirm).
    pub name: Option<&'i str>,
    // Optional spec string, borrowed; meaning defined by callers (TODO confirm).
    pub spec: Option<&'i str>,
}
102
// A task output held either by value or by reference.
//
// Serialization is `untagged`, so the wire form is that of the wrapped
// `TaskOutputOwned` / `TaskOutputRef`. Deserialization is only implemented
// for `TaskOutput<'static>` and always yields the `Owned` variant.
//
// NOTE(review): `//` comments are used so that no doc attributes reach the
// `schema_override` macro — confirm whether that matters before converting.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutput<'a> {
    // Output that owns its data.
    Owned(TaskOutputOwned),
    // Output borrowing its data for `'a`.
    Ref(TaskOutputRef<'a>),
}
113
114impl JsonSchema for TaskOutput<'static> {
115 fn schema_name() -> std::borrow::Cow<'static, str> {
116 TaskOutputOwned::schema_name()
117 }
118 fn json_schema(
119 generator: &mut schemars::SchemaGenerator,
120 ) -> schemars::Schema {
121 TaskOutputOwned::json_schema(generator)
122 }
123}
124
125impl<'a> super::ToStarlarkValue for TaskOutput<'a> {
126 fn to_starlark_value<'v>(
127 &self,
128 heap: &'v StarlarkHeap,
129 ) -> StarlarkValue<'v> {
130 match self {
131 TaskOutput::Owned(o) => o.to_starlark_value(heap),
132 TaskOutput::Ref(r) => r.to_starlark_value(heap),
133 }
134 }
135}
136
137impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
138 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
139 where
140 D: serde::Deserializer<'de>,
141 {
142 let owned = TaskOutputOwned::deserialize(deserializer)?;
143 Ok(TaskOutput::Owned(owned))
144 }
145}
146
// Owned task output: a single number, a list of numbers, a list of lists of
// numbers, or an error payload.
//
// The enum is `untagged`, so the wire form is the bare payload: a number, an
// array, an array of arrays, or an object with an `error` field. serde tries
// untagged variants in declaration order when deserializing, which is why the
// tests in this file accept either `Vector` or `Vectors` for `[]`.
// Appears in generated schemas under `functions.expression.TaskOutput`.
//
// NOTE(review): comments here are `//` rather than `///` on purpose —
// `#[derive(JsonSchema)]` copies doc comments into the generated schema as
// descriptions, so switching to `///` would change the emitted schema.
#[schema_override(Owned)]
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.TaskOutput")]
pub enum TaskOutputOwned {
    // A single decimal; described as `f64` in the schema and deserialized
    // through `crate::serde_util::decimal`.
    #[schemars(title = "Scalar")]
    Scalar(
        #[serde(deserialize_with = "crate::serde_util::decimal")]
        #[schemars(with = "f64")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)]
        rust_decimal::Decimal,
    ),
    // A flat list of decimals; described as `Vec<f64>` in the schema.
    #[schemars(title = "Vector")]
    Vector(
        #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
        #[schemars(with = "Vec<f64>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)]
        Vec<rust_decimal::Decimal>,
    ),
    // A list of decimal lists; described as `Vec<Vec<f64>>` in the schema.
    #[schemars(title = "Vectors")]
    Vectors(
        #[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")]
        #[schemars(with = "Vec<Vec<f64>>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)]
        Vec<Vec<rust_decimal::Decimal>>,
    ),
    // An error payload: an object whose `error` field holds arbitrary JSON
    // (including `null`).
    #[schemars(title = "Err")]
    Err {
        #[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
        error: serde_json::Value,
    },
}
192
193impl ToStarlarkValue for TaskOutputOwned {
194 fn to_starlark_value<'v>(
195 &self,
196 heap: &'v StarlarkHeap,
197 ) -> StarlarkValue<'v> {
198 match self {
199 TaskOutputOwned::Scalar(d) => d.to_starlark_value(heap),
200 TaskOutputOwned::Vector(ds) => ds.to_starlark_value(heap),
201 TaskOutputOwned::Vectors(vecs) => vecs.to_starlark_value(heap),
202 TaskOutputOwned::Err { error } => error.to_starlark_value(heap),
203 }
204 }
205}
206
207impl FromStarlarkValue for TaskOutputOwned {
208 fn from_starlark_value(
209 value: &StarlarkValue,
210 ) -> Result<Self, ExpressionError> {
211 use starlark::values::float::UnpackFloat;
212 if value.is_none() {
213 return Ok(TaskOutputOwned::Err {
214 error: serde_json::Value::Null,
215 });
216 }
217 if let Some(list) = starlark::values::list::ListRef::from_value(*value)
218 {
219 let mut all_numeric = true;
221 let mut all_lists = true;
222 let mut decimals = Vec::with_capacity(list.len());
223 let mut vecs = Vec::with_capacity(list.len());
224
225 for v in list.iter() {
226 if let Some(inner_list) =
227 starlark::values::list::ListRef::from_value(v)
228 {
229 let mut inner_decimals =
231 Vec::with_capacity(inner_list.len());
232 let mut inner_all_numeric = true;
233 for iv in inner_list.iter() {
234 if let Ok(Some(i)) = i64::unpack_value(iv) {
235 inner_decimals.push(rust_decimal::Decimal::from(i));
236 } else if let Ok(Some(UnpackFloat(f))) =
237 UnpackFloat::unpack_value(iv)
238 {
239 match rust_decimal::Decimal::try_from(f) {
240 Ok(d) => inner_decimals.push(d),
241 Err(_) => {
242 inner_all_numeric = false;
243 break;
244 }
245 }
246 } else {
247 inner_all_numeric = false;
248 break;
249 }
250 }
251 if inner_all_numeric {
252 vecs.push(inner_decimals);
253 } else {
254 all_lists = false;
255 }
256 all_numeric = false;
257 } else if let Ok(Some(i)) = i64::unpack_value(v) {
258 decimals.push(rust_decimal::Decimal::from(i));
259 all_lists = false;
260 } else if let Ok(Some(UnpackFloat(f))) =
261 UnpackFloat::unpack_value(v)
262 {
263 match rust_decimal::Decimal::try_from(f) {
264 Ok(d) => {
265 decimals.push(d);
266 all_lists = false;
267 }
268 Err(_) => {
269 all_numeric = false;
270 all_lists = false;
271 break;
272 }
273 }
274 } else {
275 all_numeric = false;
276 all_lists = false;
277 break;
278 }
279 }
280 if all_numeric && !decimals.is_empty() {
281 return Ok(TaskOutputOwned::Vector(decimals));
282 }
283 if all_numeric && decimals.is_empty() && list.len() == 0 {
284 return Ok(TaskOutputOwned::Vector(Vec::new()));
285 }
286 if all_lists && !vecs.is_empty() {
287 return Ok(TaskOutputOwned::Vectors(vecs));
288 }
289 if all_lists && vecs.is_empty() && list.len() == 0 {
290 return Ok(TaskOutputOwned::Vectors(Vec::new()));
291 }
292 }
293 if let Ok(Some(i)) = i64::unpack_value(*value) {
294 return Ok(TaskOutputOwned::Scalar(rust_decimal::Decimal::from(i)));
295 }
296 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
297 if let Ok(d) = rust_decimal::Decimal::try_from(f) {
298 return Ok(TaskOutputOwned::Scalar(d));
299 }
300 }
301 let v = serde_json::Value::from_starlark_value(value)?;
302 Ok(TaskOutputOwned::Err { error: v })
303 }
304}
305
306impl super::FromSpecial for TaskOutputOwned {
307 fn from_special(
308 special: &super::Special,
309 params: &super::Params,
310 ) -> Result<Self, super::ExpressionError> {
311 match special {
312 super::Special::Output => {
313 let output = params_output(params)?;
314 Ok(output.clone())
315 }
316 super::Special::TaskOutputL1Normalized => {
317 let output = params_output(params)?;
318 match output {
319 TaskOutputOwned::Scalar(_) => Ok(output.clone()),
320 TaskOutputOwned::Vector(v) => {
321 Ok(TaskOutputOwned::Vector(l1_normalize(v)))
322 }
323 TaskOutputOwned::Vectors(vecs) => {
324 Ok(TaskOutputOwned::Vectors(
325 vecs.iter().map(|v| l1_normalize(v)).collect(),
326 ))
327 }
328 TaskOutputOwned::Err { .. } => Ok(output.clone()),
329 }
330 }
331 super::Special::TaskOutputWeightedSum => {
332 let output = params_output(params)?;
333 match output {
334 TaskOutputOwned::Vector(scores) => {
335 Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
336 }
337 TaskOutputOwned::Vectors(vecs) => {
338 Ok(TaskOutputOwned::Vector(
339 vecs.iter()
340 .map(|scores| weighted_sum(scores))
341 .collect(),
342 ))
343 }
344 _ => Err(super::ExpressionError::UnsupportedSpecial),
345 }
346 }
347 _ => Err(super::ExpressionError::UnsupportedSpecial),
348 }
349 }
350}
351
352impl TaskOutputOwned {
353 pub fn into_err(self) -> Self {
355 match self {
356 Self::Scalar(scalar) => Self::Err {
357 error: serde_json::to_value(scalar).unwrap(),
358 },
359 Self::Vector(vector) => Self::Err {
360 error: serde_json::to_value(vector).unwrap(),
361 },
362 Self::Vectors(vectors) => Self::Err {
363 error: serde_json::to_value(vectors).unwrap(),
364 },
365 Self::Err { error } => Self::Err { error },
366 }
367 }
368}
369
// Borrowed counterpart of `TaskOutputOwned`: the same four shapes, each
// holding a reference valid for `'a` instead of owned data.
//
// Serialization is `untagged`, so the wire form is the bare payload; the
// `Err` variant serializes as an object with an `error` field. Only
// `Serialize` is derived; there is no way to deserialize into this type.
//
// NOTE(review): `//` comments are used so that no doc attributes reach the
// `schema_override` macro — confirm whether that matters before converting.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutputRef<'a> {
    // A single borrowed decimal.
    Scalar(&'a rust_decimal::Decimal),
    // A borrowed slice of decimals.
    Vector(&'a [rust_decimal::Decimal]),
    // A borrowed slice of decimal vectors.
    Vectors(&'a [Vec<rust_decimal::Decimal>]),
    // A borrowed error payload holding arbitrary JSON.
    Err { error: &'a serde_json::Value },
}
384
385impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
386 fn to_starlark_value<'v>(
387 &self,
388 heap: &'v StarlarkHeap,
389 ) -> StarlarkValue<'v> {
390 match self {
391 TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
392 TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
393 TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
394 TaskOutputRef::Err { error } => error.to_starlark_value(heap),
395 }
396 }
397}
398
399fn params_output<'a>(
400 params: &'a super::Params,
401) -> Result<&'a TaskOutputOwned, super::ExpressionError> {
402 match params {
403 super::Params::Owned(o) => o
404 .output
405 .as_ref()
406 .ok_or(super::ExpressionError::UnsupportedSpecial),
407 super::Params::Ref(r) => match &r.output {
408 Some(TaskOutput::Owned(o)) => Ok(o),
409 Some(TaskOutput::Ref(_)) => {
410 Err(super::ExpressionError::UnsupportedSpecial)
413 }
414 None => Err(super::ExpressionError::UnsupportedSpecial),
415 },
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a `TaskOutputOwned`, panicking if it is rejected.
    fn parse(json: &str) -> TaskOutputOwned {
        serde_json::from_str(json).unwrap()
    }

    /// Pins the wire format: bare numbers and arrays map to the numeric
    /// variants, errors must be wrapped in an `{"error": ...}` object, and
    /// other bare JSON values are rejected.
    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        // Numeric shapes.
        assert!(matches!(parse("94"), TaskOutputOwned::Scalar(_)));
        assert!(matches!(parse("[1, 2, 3]"), TaskOutputOwned::Vector(_)));
        assert!(matches!(
            parse("[[1, 2], [3, 4]]"),
            TaskOutputOwned::Vectors(_)
        ));

        // Bare non-numeric values are not accepted as errors.
        for rejected in ["null", "true", r#""94""#] {
            assert!(serde_json::from_str::<TaskOutputOwned>(rejected).is_err());
        }

        // Errors must use the explicit object form.
        let with_message = parse(r#"{"error": "something"}"#);
        assert!(matches!(
            with_message,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));

        let with_null = parse(r#"{"error": null}"#);
        assert!(matches!(
            with_null,
            TaskOutputOwned::Err {
                error: serde_json::Value::Null
            }
        ));

        // A string payload that looks numeric survives a round trip as Err.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        let roundtripped = parse(&json);
        assert!(matches!(
            roundtripped,
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // The empty array is ambiguous; either list variant is acceptable.
        let empty = parse("[]");
        assert!(matches!(
            empty,
            TaskOutputOwned::Vector(_) | TaskOutputOwned::Vectors(_)
        ));
    }
}
484
485fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
486 if v.is_empty() {
487 return Vec::new();
488 }
489 let sum: rust_decimal::Decimal = v.iter().map(|d| d.abs()).sum();
490 if sum.is_zero() {
491 let uniform =
492 rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(v.len());
493 vec![uniform; v.len()]
494 } else {
495 v.iter().map(|d| d / sum).collect()
496 }
497}
498
499fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
502 let len = scores.len();
503 if len <= 1 {
504 return scores.iter().sum();
505 }
506 let mut ws = rust_decimal::Decimal::ZERO;
507 let last = len - 1;
508 for (i, score) in scores.iter().enumerate() {
509 let weight =
510 rust_decimal::Decimal::from(i) / rust_decimal::Decimal::from(last);
511 ws += score * weight;
512 }
513 ws
514}