use super::{ExpressionError, FromStarlarkValue, ToStarlarkValue};
use objectiveai_sdk_macros::schema_override;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use starlark::values::{
Heap as StarlarkHeap, UnpackValue, Value as StarlarkValue,
};
// Expression parameters, held either by value (`Owned`) or as borrows of
// caller-owned data (`Ref`). Serialized `untagged`; the two structs list the
// same fields in the same order, so both forms are meant to produce the same
// JSON shape. JSON schema and deserialization are implemented manually and
// always go through `ParamsOwned` (deserialization only yields `Owned`).
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum Params<'i, 'to> {
Owned(ParamsOwned),
Ref(ParamsRef<'i, 'to>),
}
impl JsonSchema for Params<'static, 'static> {
fn schema_name() -> std::borrow::Cow<'static, str> {
ParamsOwned::schema_name()
}
fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
ParamsOwned::json_schema(generator)
}
}
impl<'de> serde::Deserialize<'de> for Params<'static, 'static> {
    /// Reads a [`ParamsOwned`] and wraps it in [`Params::Owned`]; the
    /// borrowed `Ref` form is never produced by deserialization.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        ParamsOwned::deserialize(deserializer).map(Self::Owned)
    }
}
// Owned form of the expression parameters; this is the type that is
// deserialized and whose JSON schema is published
// (as "functions.expression.Params").
// `output` is what the output-related specials read (see `params_output`);
// `None` there makes those specials fail with `UnsupportedSpecial`.
// NOTE(review): the meaning of `map`, `tasks_min`, `tasks_max`, `depth`,
// `name` and `spec` is not visible in this file — confirm with callers.
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema)]
#[schemars(rename = "functions.expression.Params")]
pub struct ParamsOwned {
pub input: super::InputValue,
pub output: Option<TaskOutputOwned>,
pub map: Option<u64>,
pub tasks_min: Option<u64>,
pub tasks_max: Option<u64>,
pub depth: Option<u64>,
pub name: Option<String>,
pub spec: Option<String>,
}
// Borrowed counterpart of `ParamsOwned` with the same fields in the same
// order. `input`, `name` and `spec` borrow for `'i`; `output` is a
// `TaskOutput`, which may itself be owned or borrowed for `'to`.
// Serialize-only: this type does not derive `Deserialize`.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
pub struct ParamsRef<'i, 'to> {
pub input: &'i super::InputValue,
pub output: Option<TaskOutput<'to>>,
pub map: Option<u64>,
pub tasks_min: Option<u64>,
pub tasks_max: Option<u64>,
pub depth: Option<u64>,
pub name: Option<&'i str>,
pub spec: Option<&'i str>,
}
// A task output, held either by value (`Owned`) or as borrows of
// caller-owned data (`Ref`). Serialized `untagged`. JSON schema and
// deserialization are implemented manually via `TaskOutputOwned`
// (deserialization only yields `Owned`).
#[schema_override(RefOwnedEnum)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutput<'a> {
Owned(TaskOutputOwned),
Ref(TaskOutputRef<'a>),
}
impl JsonSchema for TaskOutput<'static> {
fn schema_name() -> std::borrow::Cow<'static, str> {
TaskOutputOwned::schema_name()
}
fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
TaskOutputOwned::json_schema(generator)
}
}
impl super::ToStarlarkValue for TaskOutput<'_> {
    /// Converts to a Starlark value by delegating to whichever representation
    /// (owned or borrowed) this output holds.
    fn to_starlark_value<'v>(&self, heap: &'v StarlarkHeap) -> StarlarkValue<'v> {
        match self {
            Self::Owned(owned) => owned.to_starlark_value(heap),
            Self::Ref(borrowed) => borrowed.to_starlark_value(heap),
        }
    }
}
impl<'de> serde::Deserialize<'de> for TaskOutput<'static> {
    /// Reads a [`TaskOutputOwned`] and wraps it in [`TaskOutput::Owned`];
    /// the borrowed `Ref` form is never produced by deserialization.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        TaskOutputOwned::deserialize(deserializer).map(Self::Owned)
    }
}
// Owned task output. Deserialization is `untagged`, so serde tries the
// variants in declaration order:
// - a JSON number → `Scalar`;
// - an array of numbers → `Vector`;
// - an array of arrays of numbers → `Vectors`;
// - an object `{"error": <any JSON>}` → `Err`.
// Bare non-numeric JSON (null, booleans, strings) is rejected rather than
// coerced into `Err` (pinned by the test module below).
// Numbers are parsed by the `crate::serde_util` decimal helpers, but are
// advertised as `f64` in the JSON schema.
#[schema_override(Owned)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, JsonSchema, arbitrary::Arbitrary)]
#[serde(untagged)]
#[schemars(rename = "functions.expression.TaskOutput")]
pub enum TaskOutputOwned {
#[schemars(title = "Scalar")]
Scalar(#[serde(deserialize_with = "crate::serde_util::decimal")] #[schemars(with = "f64")] #[arbitrary(with = crate::arbitrary_util::arbitrary_rust_decimal)] rust_decimal::Decimal),
#[schemars(title = "Vector")]
Vector(#[serde(deserialize_with = "crate::serde_util::vec_decimal")] #[schemars(with = "Vec<f64>")] #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_rust_decimal)] Vec<rust_decimal::Decimal>),
#[schemars(title = "Vectors")]
Vectors(#[serde(deserialize_with = "crate::serde_util::vec_vec_decimal")] #[schemars(with = "Vec<Vec<f64>>")] #[arbitrary(with = crate::arbitrary_util::arbitrary_vec_vec_rust_decimal)] Vec<Vec<rust_decimal::Decimal>>),
#[schemars(title = "Err")]
Err {
#[arbitrary(with = crate::arbitrary_util::arbitrary_json_value)]
error: serde_json::Value,
},
}
impl ToStarlarkValue for TaskOutputOwned {
    /// Converts the output into a Starlark value.
    ///
    /// Each variant's payload (a decimal, a vector of decimals, nested
    /// vectors, or the JSON error value) is converted with its own
    /// `ToStarlarkValue` implementation.
    fn to_starlark_value<'v>(&self, heap: &'v StarlarkHeap) -> StarlarkValue<'v> {
        match self {
            Self::Scalar(scalar) => scalar.to_starlark_value(heap),
            Self::Vector(vector) => vector.to_starlark_value(heap),
            Self::Vectors(rows) => rows.to_starlark_value(heap),
            Self::Err { error } => error.to_starlark_value(heap),
        }
    }
}
impl FromStarlarkValue for TaskOutputOwned {
fn from_starlark_value(
value: &StarlarkValue,
) -> Result<Self, ExpressionError> {
use starlark::values::float::UnpackFloat;
if value.is_none() {
return Ok(TaskOutputOwned::Err {
error: serde_json::Value::Null,
});
}
if let Some(list) = starlark::values::list::ListRef::from_value(*value)
{
let mut all_numeric = true;
let mut all_lists = true;
let mut decimals = Vec::with_capacity(list.len());
let mut vecs = Vec::with_capacity(list.len());
for v in list.iter() {
if let Some(inner_list) =
starlark::values::list::ListRef::from_value(v)
{
let mut inner_decimals =
Vec::with_capacity(inner_list.len());
let mut inner_all_numeric = true;
for iv in inner_list.iter() {
if let Ok(Some(i)) = i64::unpack_value(iv) {
inner_decimals
.push(rust_decimal::Decimal::from(i));
} else if let Ok(Some(UnpackFloat(f))) =
UnpackFloat::unpack_value(iv)
{
match rust_decimal::Decimal::try_from(f) {
Ok(d) => inner_decimals.push(d),
Err(_) => {
inner_all_numeric = false;
break;
}
}
} else {
inner_all_numeric = false;
break;
}
}
if inner_all_numeric {
vecs.push(inner_decimals);
} else {
all_lists = false;
}
all_numeric = false;
} else if let Ok(Some(i)) = i64::unpack_value(v) {
decimals.push(rust_decimal::Decimal::from(i));
all_lists = false;
} else if let Ok(Some(UnpackFloat(f))) =
UnpackFloat::unpack_value(v)
{
match rust_decimal::Decimal::try_from(f) {
Ok(d) => {
decimals.push(d);
all_lists = false;
}
Err(_) => {
all_numeric = false;
all_lists = false;
break;
}
}
} else {
all_numeric = false;
all_lists = false;
break;
}
}
if all_numeric && !decimals.is_empty() {
return Ok(TaskOutputOwned::Vector(decimals));
}
if all_numeric && decimals.is_empty() && list.len() == 0 {
return Ok(TaskOutputOwned::Vector(Vec::new()));
}
if all_lists && !vecs.is_empty() {
return Ok(TaskOutputOwned::Vectors(vecs));
}
if all_lists && vecs.is_empty() && list.len() == 0 {
return Ok(TaskOutputOwned::Vectors(Vec::new()));
}
}
if let Ok(Some(i)) = i64::unpack_value(*value) {
return Ok(TaskOutputOwned::Scalar(
rust_decimal::Decimal::from(i),
));
}
if let Ok(Some(UnpackFloat(f))) = UnpackFloat::unpack_value(*value) {
if let Ok(d) = rust_decimal::Decimal::try_from(f) {
return Ok(TaskOutputOwned::Scalar(d));
}
}
let v = serde_json::Value::from_starlark_value(value)?;
Ok(TaskOutputOwned::Err { error: v })
}
}
impl super::FromSpecial for TaskOutputOwned {
    /// Resolves an output-related special from the task output stored in
    /// `params`.
    ///
    /// - `Output` returns a copy of the output unchanged.
    /// - `TaskOutputL1Normalized` L1-normalizes a `Vector`, or each row of a
    ///   `Vectors`. `Scalar` and `Err` outputs are returned unchanged.
    /// - `TaskOutputWeightedSum` reduces a `Vector` to a `Scalar`, or each
    ///   row of a `Vectors` to one entry of a `Vector`, using `weighted_sum`.
    ///
    /// The output may be stored either owned or borrowed (`TaskOutput::Ref`).
    ///
    /// # Errors
    /// Returns `ExpressionError::UnsupportedSpecial` in three cases:
    /// - `params` carries no output;
    /// - `TaskOutputWeightedSum` is applied to a `Scalar` or `Err` output;
    /// - `special` is any other special.
    fn from_special(
        special: &super::Special,
        params: &super::Params,
    ) -> Result<Self, super::ExpressionError> {
        match special {
            super::Special::Output => params_output(params).map(std::borrow::Cow::into_owned),
            super::Special::TaskOutputL1Normalized => {
                let output = params_output(params)?;
                match &*output {
                    TaskOutputOwned::Vector(v) => Ok(TaskOutputOwned::Vector(l1_normalize(v))),
                    TaskOutputOwned::Vectors(vecs) => Ok(TaskOutputOwned::Vectors(
                        vecs.iter().map(|v| l1_normalize(v)).collect(),
                    )),
                    // Nothing to normalize: scalars and errors pass through.
                    TaskOutputOwned::Scalar(_) | TaskOutputOwned::Err { .. } => {
                        Ok((*output).clone())
                    }
                }
            }
            super::Special::TaskOutputWeightedSum => {
                let output = params_output(params)?;
                match &*output {
                    TaskOutputOwned::Vector(scores) => {
                        Ok(TaskOutputOwned::Scalar(weighted_sum(scores)))
                    }
                    TaskOutputOwned::Vectors(vecs) => Ok(TaskOutputOwned::Vector(
                        vecs.iter().map(|scores| weighted_sum(scores)).collect(),
                    )),
                    TaskOutputOwned::Scalar(_) | TaskOutputOwned::Err { .. } => {
                        Err(super::ExpressionError::UnsupportedSpecial)
                    }
                }
            }
            _ => Err(super::ExpressionError::UnsupportedSpecial),
        }
    }
}

impl TaskOutputOwned {
    /// Converts this output into the `Err` variant.
    ///
    /// A numeric payload is serialized with `serde_json::to_value` and becomes
    /// the error value. An existing `Err` is returned as is.
    ///
    /// # Panics
    /// Panics if serializing the numeric payload fails. This is not expected
    /// for decimal values.
    pub fn into_err(self) -> Self {
        let error = match self {
            Self::Scalar(scalar) => serde_json::to_value(scalar).unwrap(),
            Self::Vector(vector) => serde_json::to_value(vector).unwrap(),
            Self::Vectors(vectors) => serde_json::to_value(vectors).unwrap(),
            Self::Err { error } => error,
        };
        Self::Err { error }
    }
}

// Borrowed counterpart of `TaskOutputOwned` with the same variant shapes.
// It is serialized `untagged` like the owned form, and does not derive
// `Deserialize`.
#[schema_override(Ref)]
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(untagged)]
pub enum TaskOutputRef<'a> {
    Scalar(&'a rust_decimal::Decimal),
    Vector(&'a [rust_decimal::Decimal]),
    Vectors(&'a [Vec<rust_decimal::Decimal>]),
    Err { error: &'a serde_json::Value },
}

impl<'a> ToStarlarkValue for TaskOutputRef<'a> {
    /// Converts the borrowed payload into a Starlark value, variant by variant.
    fn to_starlark_value<'v>(&self, heap: &'v StarlarkHeap) -> StarlarkValue<'v> {
        match self {
            TaskOutputRef::Scalar(d) => d.to_starlark_value(heap),
            TaskOutputRef::Vector(ds) => ds.to_starlark_value(heap),
            TaskOutputRef::Vectors(vecs) => vecs.to_starlark_value(heap),
            TaskOutputRef::Err { error } => error.to_starlark_value(heap),
        }
    }
}

/// Returns the task output carried by `params`.
///
/// An owned output is borrowed without copying. A borrowed
/// `TaskOutput::Ref` is converted into an owned `TaskOutputOwned`.
/// Previously such outputs were rejected, so output specials failed on them.
///
/// # Errors
/// Returns `ExpressionError::UnsupportedSpecial` when `params` has no output.
fn params_output<'a>(
    params: &'a super::Params,
) -> Result<std::borrow::Cow<'a, TaskOutputOwned>, super::ExpressionError> {
    use std::borrow::Cow;
    let output = match params {
        super::Params::Owned(owned) => owned.output.as_ref().map(Cow::Borrowed),
        super::Params::Ref(borrowed) => match &borrowed.output {
            Some(TaskOutput::Owned(owned)) => Some(Cow::Borrowed(owned)),
            Some(TaskOutput::Ref(output_ref)) => {
                Some(Cow::Owned(task_output_ref_to_owned(output_ref)))
            }
            None => None,
        },
    };
    output.ok_or(super::ExpressionError::UnsupportedSpecial)
}

/// Copies a borrowed task output into its owned form.
fn task_output_ref_to_owned(output: &TaskOutputRef<'_>) -> TaskOutputOwned {
    match *output {
        TaskOutputRef::Scalar(scalar) => TaskOutputOwned::Scalar(*scalar),
        TaskOutputRef::Vector(vector) => TaskOutputOwned::Vector(vector.to_vec()),
        TaskOutputRef::Vectors(rows) => TaskOutputOwned::Vectors(rows.to_vec()),
        TaskOutputRef::Err { error } => TaskOutputOwned::Err {
            error: error.clone(),
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the JSON wire format of [`TaskOutputOwned`]:
    /// - bare numbers and arrays map to the numeric variants;
    /// - the error variant must be wrapped as `{"error": ...}`;
    /// - other bare JSON values are rejected.
    #[test]
    fn test_task_output_deserialize_strict_err_wire_format() {
        fn parse(json: &str) -> serde_json::Result<TaskOutputOwned> {
            serde_json::from_str(json)
        }

        // Numeric shapes land in their matching variants.
        assert!(matches!(parse("94").unwrap(), TaskOutputOwned::Scalar(_)));
        assert!(matches!(parse("[1, 2, 3]").unwrap(), TaskOutputOwned::Vector(_)));
        assert!(matches!(
            parse("[[1, 2], [3, 4]]").unwrap(),
            TaskOutputOwned::Vectors(_)
        ));

        // Bare non-numeric values are not silently coerced into `Err`.
        for rejected in ["null", "true", r#""94""#] {
            assert!(parse(rejected).is_err());
        }

        // The wrapped error form accepts arbitrary JSON, including null.
        assert!(matches!(
            parse(r#"{"error": "something"}"#).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "something"
        ));
        assert!(matches!(
            parse(r#"{"error": null}"#).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::Null }
        ));

        // An error that looks numeric survives a serialize/deserialize round trip.
        let original = TaskOutputOwned::Err {
            error: serde_json::Value::String("94".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"error":"94"}"#);
        assert!(matches!(
            parse(&json).unwrap(),
            TaskOutputOwned::Err { error: serde_json::Value::String(ref s) } if s == "94"
        ));

        // An empty array is accepted as one of the vector shapes.
        assert!(matches!(
            parse("[]").unwrap(),
            TaskOutputOwned::Vector(_) | TaskOutputOwned::Vectors(_)
        ));
    }
}
/// Scales `v` so that the absolute values of its entries sum to one.
///
/// Each entry is divided by the L1 norm, so signs are preserved. An empty
/// slice yields an empty vector. When the norm is zero (every entry is
/// zero), a uniform vector of `1 / len` is returned instead.
fn l1_normalize(v: &[rust_decimal::Decimal]) -> Vec<rust_decimal::Decimal> {
    let len = v.len();
    if len == 0 {
        return Vec::new();
    }
    let norm: rust_decimal::Decimal = v.iter().map(rust_decimal::Decimal::abs).sum();
    if !norm.is_zero() {
        return v.iter().map(|entry| entry / norm).collect();
    }
    let share = rust_decimal::Decimal::ONE / rust_decimal::Decimal::from(len);
    vec![share; len]
}
/// Sums `scores` with weights that grow linearly with position.
///
/// Entry `i` is weighted by `i / (len - 1)`, so the first entry contributes
/// nothing and the last contributes its full value. With zero or one
/// entries the plain sum is returned: zero, or the single score.
fn weighted_sum(scores: &[rust_decimal::Decimal]) -> rust_decimal::Decimal {
    let last_index = match scores.len() {
        0 | 1 => return scores.iter().sum(),
        len => rust_decimal::Decimal::from(len - 1),
    };
    scores
        .iter()
        .enumerate()
        .fold(rust_decimal::Decimal::ZERO, |total, (index, score)| {
            total + score * (rust_decimal::Decimal::from(index) / last_index)
        })
}