objectiveai_sdk/functions/expression/
params.rs1use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
7use objectiveai_sdk_macros::schema_override;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use starlark::values::{
11 Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
12};
13
// Parameters handed to expression evaluation, either fully owned (`Owned`)
// or borrowing their input/output from the caller (`Ref`). Serialized
// untagged, so only the inner struct appears on the wire; `ParamsOwned` and
// `ParamsRef` share field names, so both variants produce the same JSON
// object shape. The manual `Deserialize` impl for this type always yields
// `Owned`.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Params<'i, 'to> {
    // Owns its input and output.
    Owned(ParamsOwned),
    // Borrows its input for `'i`; its output may borrow for `'to`.
    Ref(ParamsRef<'i, 'to>),
}
26
27impl JsonSchema for Params<'static, 'static> {
28 fn schema_name() -> std::borrow::Cow<'static, str> {
29 ParamsOwned::schema_name()
30 }
31 fn json_schema(
32 generator: &mut schemars::SchemaGenerator,
33 ) -> schemars::Schema {
34 ParamsOwned::json_schema(generator)
35 }
36}
37
38impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
39 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
40 where
41 D: serde::Deserializer<'de>,
42 {
43 let owned = ParamsOwned::deserialize(deserializer)?;
44 Ok(Params::Owned(owned))
45 }
46}
47
// Fully owned expression parameters. Its JSON schema is published under the
// name "functions.expression.Params".
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "functions.expression.Params")]
pub struct ParamsOwned {
    // The task input.
    pub input: super::InputValue,
    // The task output, if any. `params_output` in this file reports
    // `UnsupportedSpecial` when this is `None`.
    pub output: Option<TaskOutputOwned>,
    // NOTE(review): the meaning of `map` is not visible in this file;
    // presumably an index/count for mapped tasks — confirm.
    pub map: Option<u64>,
}
60
// Borrowing counterpart of `ParamsOwned`: `input` is borrowed for `'i`, and
// `output` is a `TaskOutput<'to>` that may itself borrow. Field names match
// `ParamsOwned`, so it serializes to the same JSON object shape. Derives
// `Serialize` but not `Deserialize`.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ParamsRef<'i, 'to> {
    pub input: &'i super::InputValue,
    pub output: Option<TaskOutput<'to>>,
    pub map: Option<u64>,
}
72
// A task output that is either owned or borrowed. Serialized untagged, so
// only the inner value appears on the wire; the manual `Deserialize` impl
// for this type always produces `Owned`.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutput<'a> {
    Owned(TaskOutputOwned),
    Ref(TaskOutputRef<'a>),
}
83
84impl JsonSchema for TaskOutput<'static> {
85 fn schema_name() -> std::borrow::Cow<'static, str> {
86 TaskOutputOwned::schema_name()
87 }
88 fn json_schema(
89 generator: &mut schemars::SchemaGenerator,
90 ) -> schemars::Schema {
91 TaskOutputOwned::json_schema(generator)
92 }
93}
94
95impl<'a> super::ToStarlarkValue for TaskOutput<'a> {
96 fn to_starlark_value<'v>(
97 &self,
98 heap: &'v StarlarkHeap,
99 ) -> StarlarkValue<'v> {
100 match self {
101 TaskOutput::Owned(o) => o.to_starlark_value(heap),
102 TaskOutput::Ref(r) => r.to_starlark_value(heap),
103 }
104 }
105}
106
107impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
108 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
109 where
110 D: serde::Deserializer<'de>,
111 {
112 let owned = TaskOutputOwned::deserialize(deserializer)?;
113 Ok(TaskOutput::Owned(owned))
114 }
115}
116
// Owned result of a task: a single number, a vector, a vector of vectors, or
// an error payload. Serialized untagged, so numbers and arrays appear bare
// while `Err` appears as `{"error": ...}`. On deserialization the variants
// are tried in declaration order (Scalar, Vector, Vectors, Err).
#[schema_override(Owned)]
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    JsonSchema,
    arbitrary::Arbitrary,
)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.TaskOutput")]
pub enum TaskOutputOwned {
    // A single decimal, parsed by `crate::serde_util::decimal` and described
    // as `f64` in the JSON schema.
    #[schemars(title = "Scalar")]
    Scalar(
        #[serde(deserialize_with = "crate::serde_util::decimal")]
        #[schemars(with = "f64")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)]
        rust_decimal::Decimal,
    ),
    // A flat list of decimals (schema: array of `f64`).
    #[schemars(title = "Vector")]
    Vector(
        #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
        #[schemars(with = "Vec<f64>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)]
        Vec<rust_decimal::Decimal>,
    ),
    // A list of decimal lists (schema: array of arrays of `f64`).
    #[schemars(title = "Vectors")]
    Vectors(
        #[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")]
        #[schemars(with = "Vec<Vec<f64>>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)]
        Vec<Vec<rust_decimal::Decimal>>,
    ),
    // An arbitrary JSON error payload; must be wrapped in an object with an
    // `error` key, so bare `null`/bools/strings do not parse as this variant.
    #[schemars(title = "Err")]
    Err {
        #[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
        error: serde_json::Value,
    },
}
162
163impl ToStarlarkValue for TaskOutputOwned {
164 fn to_starlark_value<'v>(
165 &self,
166 heap: &'v StarlarkHeap,
167 ) -> StarlarkValue<'v> {
168 match self {
169 TaskOutputOwned::Scalar(d) => d.to_starlark_value(heap),
170 TaskOutputOwned::Vector(ds) => ds.to_starlark_value(heap),
171 TaskOutputOwned::Vectors(vecs) => vecs.to_starlark_value(heap),
172 TaskOutputOwned::Err { error } => error.to_starlark_value(heap),
173 }
174 }
175}
176
177impl FromStarlarkValue for TaskOutputOwned {
178 fn from_starlark_value(
179 value: &StarlarkValue,
180 ) -> Result<Self, ExpressionError> {
181 use starlark::values::float::UnpackFloat;
182 if value.is_none() {
183 return Ok(TaskOutputOwned::Err {
184 error: serde_json::Value::Null,
185 });
186 }
187 if let Some(list) = starlark::values::list::ListRef::from_value(*value)
188 {
189 let mut all_numeric = true;
191 let mut all_lists = true;
192 let mut decimals = Vec::with_capacity(list.len());
193 let mut vecs = Vec::with_capacity(list.len());
194
195 for v in list.iter() {
196 if let Some(inner_list) =
197 starlark::values::list::ListRef::from_value(v)
198 {
199 let mut inner_decimals =
201 Vec::with_capacity(inner_list.len());
202 let mut inner_all_numeric = true;
203 for iv in inner_list.iter() {
204 if let Ok(Some(i)) = i64::unpack_value(iv) {
205 inner_decimals.push(rust_decimal::Decimal::from(i));
206 } else if let Ok(Some(UnpackFloat(f))) =
207 UnpackFloat::unpack_value(iv)
208 {
209 match rust_decimal::Decimal::try_from(f) {
210 Ok(d) => inner_decimals.push(d),
211 Err(_) => {
212 inner_all_numeric = false;
213 break;
214 }
215 }
216 } else {
217 inner_all_numeric = false;
218 break;
219 }
220 }
221 if inner_all_numeric {
222 vecs.push(inner_decimals);
223 } else {
224 all_lists = false;
225 }
226 all_numeric = false;
227 } else if let Ok(Some(i)) = i64::unpack_value(v) {
228 decimals.push(rust_decimal::Decimal::from(i));
229 all_lists = false;
230 } else if let Ok(Some(UnpackFloat(f))) =
231 UnpackFloat::unpack_value(v)
232 {
233 match rust_decimal::Decimal::try_from(f) {
234 Ok(d) => {
235 decimals.push(d);
236 all_lists = false;
237 }
238 Err(_) => {
239 all_numeric = false;
240 all_lists = false;
241 break;
242 }
243 }
244 } else {
245 all_numeric = false;
246 all_lists = false;
247 break;
248 }
249 }
250 if all_numeric && !decimals.is_empty() {
251 return Ok(TaskOutputOwned::Vector(decimals));
252 }
253 if all_numeric && decimals.is_empty() && list.len() == 0 {
254 return Ok(TaskOutputOwned::Vector(Vec::new()));
255 }
256 if all_lists && !vecs.is_empty() {
257 return Ok(TaskOutputOwned::Vectors(vecs));
258 }
259 if all_lists && vecs.is_empty() && list.len() == 0 {
260 return Ok(TaskOutputOwned::Vectors(Vec::new()));
261 }
262 }
263 if let Ok(Some(i)) = i64::unpack_value(*value) {
264 return Ok(TaskOutputOwned::Scalar(rust_decimal::Decimal::from(i)));
265 }
266 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
267 if let Ok(d) = rust_decimal::Decimal::try_from(f) {
268 return Ok(TaskOutputOwned::Scalar(d));
269 }
270 }
271 let v = serde_json::Value::from_starlark_value(value)?;
272 Ok(TaskOutputOwned::Err { error: v })
273 }
274}
275
276impl super::FromSpecial for TaskOutputOwned {
277 fn from_special(
278 special: &super::Special,
279 params: &super::Params,
280 ) -> Result<Self, super::ExpressionError> {
281 match special {
282 super::Special::Output => {
283 let output = params_output(params)?;
284 Ok(output.clone())
285 }
286 super::Special::TaskOutputL1Normalized => {
287 let output = params_output(params)?;
288 match output {
289 TaskOutputOwned::Scalar(_) => Ok(output.clone()),
290 TaskOutputOwned::Vector(v) => {
291 Ok(TaskOutputOwned::Vector(l1_normalize(v)))
292 }
293 TaskOutputOwned::Vectors(vecs) => {
294 Ok(TaskOutputOwned::Vectors(
295 vecs.iter().map(|v| l1_normalize(v)).collect(),
296 ))
297 }
298 TaskOutputOwned::Err { .. } => Ok(output.clone()),
299 }
300 }
301 super::Special::TaskOutputWeightedSum => {
302 let output = params_output(params)?;
303 match output {
304 TaskOutputOwned::Vector(scores) => {
305 Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
306 }
307 TaskOutputOwned::Vectors(vecs) => {
308 Ok(TaskOutputOwned::Vector(
309 vecs.iter()
310 .map(|scores| weighted_sum(scores))
311 .collect(),
312 ))
313 }
314 _ => Err(super::ExpressionError::UnsupportedSpecial),
315 }
316 }
317 _ => Err(super::ExpressionError::UnsupportedSpecial),
318 }
319 }
320}
321
322impl TaskOutputOwned {
323 pub fn into_err(self) -> Self {
325 match self {
326 Self::Scalar(scalar) => Self::Err {
327 error: serde_json::to_value(scalar).unwrap(),
328 },
329 Self::Vector(vector) => Self::Err {
330 error: serde_json::to_value(vector).unwrap(),
331 },
332 Self::Vectors(vectors) => Self::Err {
333 error: serde_json::to_value(vectors).unwrap(),
334 },
335 Self::Err { error } => Self::Err { error },
336 }
337 }
338}
339
// Borrowed view of a `TaskOutputOwned`: each variant mirrors the owned one
// but references its payload for `'a`. Serialized untagged, so `Err`
// appears as `{"error": ...}` just like the owned form.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutputRef<'a> {
    Scalar(&'a rust_decimal::Decimal),
    Vector(&'a [rust_decimal::Decimal]),
    Vectors(&'a [Vec<rust_decimal::Decimal>]),
    Err { error: &'a serde_json::Value },
}
354
355impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
356 fn to_starlark_value<'v>(
357 &self,
358 heap: &'v StarlarkHeap,
359 ) -> StarlarkValue<'v> {
360 match self {
361 TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
362 TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
363 TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
364 TaskOutputRef::Err { error } => error.to_starlark_value(heap),
365 }
366 }
367}
368
369fn params_output<'a>(
370 params: &'a super::Params,
371) -> Result<&'a TaskOutputOwned, super::ExpressionError> {
372 match params {
373 super::Params::Owned(o) => o
374 .output
375 .as_ref()
376 .ok_or(super::ExpressionError::UnsupportedSpecial),
377 super::Params::Ref(r) => match &r.output {
378 Some(TaskOutput::Owned(o)) => Ok(o),
379 Some(TaskOutput::Ref(_)) => {
380 Err(super::ExpressionError::UnsupportedSpecial)
383 }
384 None => Err(super::ExpressionError::UnsupportedSpecial),
385 },
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        let parse = |json: &str| serde_json::from_str::<TaskOutputOwned>(json);

        // Bare numbers and (nested) numeric arrays map to the numeric variants.
        assert!(matches!(parse("94").unwrap(), TaskOutputOwned::Scalar(_)));
        assert!(matches!(
            parse("[1, 2, 3]").unwrap(),
            TaskOutputOwned::Vector(_)
        ));
        assert!(matches!(
            parse("[[1, 2], [3, 4]]").unwrap(),
            TaskOutputOwned::Vectors(_)
        ));

        // Other bare JSON values are rejected: an error payload must be
        // wrapped in an object with an `error` key.
        for rejected in ["null", "true", r#""94""#] {
            assert!(parse(rejected).is_err());
        }

        // The wrapped form accepts any JSON payload, including null.
        assert!(matches!(
            parse(r#"{"error": "something"}"#).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));
        assert!(matches!(
            parse(r#"{"error": null}"#).unwrap(),
            TaskOutputOwned::Err {
                error: serde_json::Value::Null
            }
        ));

        // A numeric-looking string payload must survive a round trip as Err
        // rather than being reinterpreted as a number.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        assert!(matches!(
            parse(json.as_str()).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // An empty array is a valid numeric output; either array variant is
        // acceptable here.
        let empty = parse("[]").unwrap();
        assert!(matches!(
            empty,
            TaskOutputOwned::Vector(_) | TaskOutputOwned::Vectors(_)
        ));
    }
}
454
455fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
456 if v.is_empty() {
457 return Vec::new();
458 }
459 let sum: rust_decimal::Decimal = v.iter().map(|d| d.abs()).sum();
460 if sum.is_zero() {
461 let uniform =
462 rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(v.len());
463 vec![uniform; v.len()]
464 } else {
465 v.iter().map(|d| d / sum).collect()
466 }
467}
468
469fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
472 let len = scores.len();
473 if len <= 1 {
474 return scores.iter().sum();
475 }
476 let mut ws = rust_decimal::Decimal::ZERO;
477 let last = len - 1;
478 for (i, score) in scores.iter().enumerate() {
479 let weight =
480 rust_decimal::Decimal::from(i) / rust_decimal::Decimal::from(last);
481 ws += score * weight;
482 }
483 ws
484}