objectiveai_sdk/functions/expression/
params.rs1use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
7use objectiveai_sdk_macros::schema_override;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use starlark::values::{
11 Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
12};
13
// Parameters for a function expression, held either owned or borrowed.
//
// `#[serde(untagged)]`: each variant serializes as the struct it wraps, with
// no variant tag. `Deserialize` and `JsonSchema` are implemented by hand for
// `Params<'static, 'static>` and both delegate to `ParamsOwned`, so
// deserialization always yields `Params::Owned`.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Params<'i, 'to> {
    // Fully owned parameters; the only variant produced by deserialization.
    Owned(ParamsOwned),
    // Borrowed parameters: `input`/`name`/`spec` live for `'i`, the task
    // output for `'to`.
    Ref(ParamsRef<'i, 'to>),
}
26
27impl JsonSchema for Params<'static, 'static> {
28 fn schema_name() -> std::borrow::Cow<'static, str> {
29 ParamsOwned::schema_name()
30 }
31 fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
32 ParamsOwned::json_schema(generator)
33 }
34}
35
36impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
37 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
38 where
39 D: serde::Deserializer<'de>,
40 {
41 let owned = ParamsOwned::deserialize(deserializer)?;
42 Ok(Params::Owned(owned))
43 }
44}
45
// Owned parameter set for a function expression; also the wire/JSON-schema
// representation used for every `Params` (schema name
// "functions.expression.Params").
//
// Plain `//` comments are used here on purpose: `///` doc comments would be
// picked up by the `JsonSchema` derive as schema descriptions.
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "functions.expression.Params")]
pub struct ParamsOwned {
    // NOTE(review): presumably the input the expression is evaluated
    // against — confirm against `super::InputValue`.
    pub input: super::InputValue,
    // Task output, if any; read by `params_output` for the task-output
    // specials (`Output`, `TaskOutputL1Normalized`, `TaskOutputWeightedSum`).
    pub output: Option<TaskOutputOwned>,
    // NOTE(review): the meanings of `map`, `tasks_min`, `tasks_max`, `depth`,
    // `name` and `spec` are defined by code outside this file — confirm
    // before documenting them further.
    pub map: Option<u64>,
    pub tasks_min: Option<u64>,
    pub tasks_max: Option<u64>,
    pub depth: Option<u64>,
    pub name: Option<String>,
    pub spec: Option<String>,
}
73
// Borrowed counterpart of `ParamsOwned`, with the same field names and order.
// It is serialize-only: `input`, `name` and `spec` borrow for `'i`, and the task
// output (owned or borrowed) for `'to`.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ParamsRef<'i, 'to> {
    pub input: &'i super::InputValue,
    // May itself hold a `TaskOutputRef`; note that `params_output` rejects
    // that case with `UnsupportedSpecial`.
    pub output: Option<TaskOutput<'to>>,
    pub map: Option<u64>,
    pub tasks_min: Option<u64>,
    pub tasks_max: Option<u64>,
    pub depth: Option<u64>,
    pub name: Option<&'i str>,
    pub spec: Option<&'i str>,
}
100
// A task output held either owned or borrowed for `'a`.
//
// `#[serde(untagged)]`: serializes as the wrapped value with no variant tag.
// `Deserialize`/`JsonSchema` are implemented by hand for `TaskOutput<'static>`
// via `TaskOutputOwned`, so deserialization always yields `TaskOutput::Owned`.
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutput<'a> {
    Owned(TaskOutputOwned),
    Ref(TaskOutputRef<'a>),
}
111
112impl JsonSchema for TaskOutput<'static> {
113 fn schema_name() -> std::borrow::Cow<'static, str> {
114 TaskOutputOwned::schema_name()
115 }
116 fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
117 TaskOutputOwned::json_schema(generator)
118 }
119}
120
121impl<'a> super::ToStarlarkValue for TaskOutput<'a> {
122 fn to_starlark_value<'v>(
123 &self,
124 heap: &'v StarlarkHeap,
125 ) -> StarlarkValue<'v> {
126 match self {
127 TaskOutput::Owned(o) => o.to_starlark_value(heap),
128 TaskOutput::Ref(r) => r.to_starlark_value(heap),
129 }
130 }
131}
132
133impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
134 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
135 where
136 D: serde::Deserializer<'de>,
137 {
138 let owned = TaskOutputOwned::deserialize(deserializer)?;
139 Ok(TaskOutput::Owned(owned))
140 }
141}
142
// Owned result of a task: a scalar, a vector, a list of vectors, or an error.
//
// `#[serde(untagged)]`: when deserializing, serde tries the variants in
// declaration order (number, array of numbers, array of number arrays, then
// an object with an `error` key), so the variant order below is significant.
// Bare `null`, booleans and strings match no variant and fail to deserialize.
// Serialization emits the payload with no variant tag (`Err` as
// `{"error": ...}`).
//
// Plain `//` comments are used so the `JsonSchema` derive's output is
// unaffected.
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.TaskOutput")]
pub enum TaskOutputOwned {
    // Single number, parsed by `crate::serde_util::decimal`; the schema
    // advertises it as `f64`.
    #[schemars(title = "Scalar")]
    Scalar(
        #[serde(deserialize_with = "crate::serde_util::decimal")]
        #[schemars(with = "f64")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)]
        rust_decimal::Decimal
    ),
    // Flat list of numbers.
    #[schemars(title = "Vector")]
    Vector(
        #[serde(deserialize_with = "crate::serde_util::vec_decimal")]
        #[schemars(with = "Vec<f64>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)]
        Vec<rust_decimal::Decimal>
    ),
    // List of number lists.
    #[schemars(title = "Vectors")]
    Vectors(
        #[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")]
        #[schemars(with = "Vec<Vec<f64>>")]
        #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)]
        Vec<Vec<rust_decimal::Decimal>>
    ),
    // Error payload: any JSON value wrapped as `{"error": <value>}`.
    #[schemars(title = "Err")]
    Err {
        #[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
        error: serde_json::Value,
    },
}
165
166impl ToStarlarkValue for TaskOutputOwned {
167 fn to_starlark_value<'v>(
168 &self,
169 heap: &'v StarlarkHeap,
170 ) -> StarlarkValue<'v> {
171 match self {
172 TaskOutputOwned::Scalar(d) => d.to_starlark_value(heap),
173 TaskOutputOwned::Vector(ds) => ds.to_starlark_value(heap),
174 TaskOutputOwned::Vectors(vecs) => vecs.to_starlark_value(heap),
175 TaskOutputOwned::Err { error } => error.to_starlark_value(heap),
176 }
177 }
178}
179
180impl FromStarlarkValue for TaskOutputOwned {
181 fn from_starlark_value(
182 value: &StarlarkValue,
183 ) -> Result<Self, ExpressionError> {
184 use starlark::values::float::UnpackFloat;
185 if value.is_none() {
186 return Ok(TaskOutputOwned::Err {
187 error: serde_json::Value::Null,
188 });
189 }
190 if let Some(list) = starlark::values::list::ListRef::from_value(*value)
191 {
192 let mut all_numeric = true;
194 let mut all_lists = true;
195 let mut decimals = Vec::with_capacity(list.len());
196 let mut vecs = Vec::with_capacity(list.len());
197
198 for v in list.iter() {
199 if let Some(inner_list) =
200 starlark::values::list::ListRef::from_value(v)
201 {
202 let mut inner_decimals =
204 Vec::with_capacity(inner_list.len());
205 let mut inner_all_numeric = true;
206 for iv in inner_list.iter() {
207 if let Ok(Some(i)) = i64::unpack_value(iv) {
208 inner_decimals
209 .push(rust_decimal::Decimal::from(i));
210 } else if let Ok(Some(UnpackFloat(f))) =
211 UnpackFloat::unpack_value(iv)
212 {
213 match rust_decimal::Decimal::try_from(f) {
214 Ok(d) => inner_decimals.push(d),
215 Err(_) => {
216 inner_all_numeric = false;
217 break;
218 }
219 }
220 } else {
221 inner_all_numeric = false;
222 break;
223 }
224 }
225 if inner_all_numeric {
226 vecs.push(inner_decimals);
227 } else {
228 all_lists = false;
229 }
230 all_numeric = false;
231 } else if let Ok(Some(i)) = i64::unpack_value(v) {
232 decimals.push(rust_decimal::Decimal::from(i));
233 all_lists = false;
234 } else if let Ok(Some(UnpackFloat(f))) =
235 UnpackFloat::unpack_value(v)
236 {
237 match rust_decimal::Decimal::try_from(f) {
238 Ok(d) => {
239 decimals.push(d);
240 all_lists = false;
241 }
242 Err(_) => {
243 all_numeric = false;
244 all_lists = false;
245 break;
246 }
247 }
248 } else {
249 all_numeric = false;
250 all_lists = false;
251 break;
252 }
253 }
254 if all_numeric && !decimals.is_empty() {
255 return Ok(TaskOutputOwned::Vector(decimals));
256 }
257 if all_numeric && decimals.is_empty() && list.len() == 0 {
258 return Ok(TaskOutputOwned::Vector(Vec::new()));
259 }
260 if all_lists && !vecs.is_empty() {
261 return Ok(TaskOutputOwned::Vectors(vecs));
262 }
263 if all_lists && vecs.is_empty() && list.len() == 0 {
264 return Ok(TaskOutputOwned::Vectors(Vec::new()));
265 }
266 }
267 if let Ok(Some(i)) = i64::unpack_value(*value) {
268 return Ok(TaskOutputOwned::Scalar(
269 rust_decimal::Decimal::from(i),
270 ));
271 }
272 if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
273 if let Ok(d) = rust_decimal::Decimal::try_from(f) {
274 return Ok(TaskOutputOwned::Scalar(d));
275 }
276 }
277 let v = serde_json::Value::from_starlark_value(value)?;
278 Ok(TaskOutputOwned::Err { error: v })
279 }
280}
281
282impl super::FromSpecial for TaskOutputOwned {
283 fn from_special(
284 special: &super::Special,
285 params: &super::Params,
286 ) -> Result<Self, super::ExpressionError> {
287 match special {
288 super::Special::Output => {
289 let output = params_output(params)?;
290 Ok(output.clone())
291 }
292 super::Special::TaskOutputL1Normalized => {
293 let output = params_output(params)?;
294 match output {
295 TaskOutputOwned::Scalar(_) => Ok(output.clone()),
296 TaskOutputOwned::Vector(v) => {
297 Ok(TaskOutputOwned::Vector(l1_normalize(v)))
298 }
299 TaskOutputOwned::Vectors(vecs) => {
300 Ok(TaskOutputOwned::Vectors(
301 vecs.iter().map(|v| l1_normalize(v)).collect(),
302 ))
303 }
304 TaskOutputOwned::Err { .. } => Ok(output.clone()),
305 }
306 }
307 super::Special::TaskOutputWeightedSum => {
308 let output = params_output(params)?;
309 match output {
310 TaskOutputOwned::Vector(scores) => {
311 Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
312 }
313 TaskOutputOwned::Vectors(vecs) => {
314 Ok(TaskOutputOwned::Vector(
315 vecs.iter()
316 .map(|scores| weighted_sum(scores))
317 .collect(),
318 ))
319 }
320 _ => Err(super::ExpressionError::UnsupportedSpecial),
321 }
322 }
323 _ => Err(super::ExpressionError::UnsupportedSpecial),
324 }
325 }
326}
327
328impl TaskOutputOwned {
329 pub fn into_err(self) -> Self {
331 match self {
332 Self::Scalar(scalar) => Self::Err {
333 error: serde_json::to_value(scalar).unwrap(),
334 },
335 Self::Vector(vector) => Self::Err {
336 error: serde_json::to_value(vector).unwrap(),
337 },
338 Self::Vectors(vectors) => Self::Err {
339 error: serde_json::to_value(vectors).unwrap(),
340 },
341 Self::Err { error } => Self::Err { error },
342 }
343 }
344}
345
// Borrowed view of a task output, mirroring `TaskOutputOwned`'s variants.
// Like the owned enum it is `#[serde(untagged)]` (no variant tag; `Err`
// serializes as `{"error": ...}`). It is serialize-only.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutputRef<'a> {
    Scalar(&'a rust_decimal::Decimal),
    Vector(&'a [rust_decimal::Decimal]),
    Vectors(&'a [Vec<rust_decimal::Decimal>]),
    Err { error: &'a serde_json::Value },
}
360
361impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
362 fn to_starlark_value<'v>(
363 &self,
364 heap: &'v StarlarkHeap,
365 ) -> StarlarkValue<'v> {
366 match self {
367 TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
368 TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
369 TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
370 TaskOutputRef::Err { error } => error.to_starlark_value(heap),
371 }
372 }
373}
374
375fn params_output<'a>(
376 params: &'a super::Params,
377) -> Result<&'a TaskOutputOwned, super::ExpressionError> {
378 match params {
379 super::Params::Owned(o) => o
380 .output
381 .as_ref()
382 .ok_or(super::ExpressionError::UnsupportedSpecial),
383 super::Params::Ref(r) => match &r.output {
384 Some(TaskOutput::Owned(o)) => Ok(o),
385 Some(TaskOutput::Ref(_)) => {
386 Err(super::ExpressionError::UnsupportedSpecial)
389 }
390 None => Err(super::ExpressionError::UnsupportedSpecial),
391 },
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `json` as a `TaskOutputOwned`.
    fn parse(json: &str) -> serde_json::Result<TaskOutputOwned> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        // Numbers and (nested) arrays of numbers map onto the numeric variants.
        assert!(matches!(parse("94").unwrap(), TaskOutputOwned::Scalar(_)));
        assert!(matches!(parse("[1, 2, 3]").unwrap(), TaskOutputOwned::Vector(_)));
        assert!(matches!(
            parse("[[1, 2], [3, 4]]").unwrap(),
            TaskOutputOwned::Vectors(_)
        ));

        // Bare non-numeric JSON values are rejected instead of being coerced
        // into `Err`.
        for rejected in ["null", "true", r#""94""#] {
            assert!(parse(rejected).is_err());
        }

        // Only an explicit `{"error": ...}` object produces `Err`.
        assert!(matches!(
            parse(r#"{"error": "something"}"#).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));
        assert!(matches!(
            parse(r#"{"error": null}"#).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::Null }
        ));

        // A string error that looks numeric survives a round trip unchanged.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        assert!(matches!(
            parse(json.as_str()).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // An empty array is accepted as one of the vector shapes.
        assert!(matches!(
            parse("[]").unwrap(),
            TaskOutputOwned::Vector(_) | TaskOutputOwned::Vectors(_)
        ));
    }
}
455
456fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
457 if v.is_empty() {
458 return Vec::new();
459 }
460 let sum: rust_decimal::Decimal = v.iter().map(|d| d.abs()).sum();
461 if sum.is_zero() {
462 let uniform =
463 rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(v.len());
464 vec![uniform; v.len()]
465 } else {
466 v.iter().map(|d| d / sum).collect()
467 }
468}
469
470fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
473 let len = scores.len();
474 if len <= 1 {
475 return scores.iter().sum();
476 }
477 let mut ws = rust_decimal::Decimal::ZERO;
478 let last = len - 1;
479 for (i, score) in scores.iter().enumerate() {
480 let weight = rust_decimal::Decimal::from(i)
481 / rust_decimal::Decimal::from(last);
482 ws += score * weight;
483 }
484 ws
485}