use std::collections::HashMap;
use serde_json::Value as JsonValue;
use super::{
helpers::variant_name,
traits::{BamlDecode, BamlEncode},
};
use crate::{
__internal::cffi_field_type_literal,
error::BamlError,
proto::baml_cffi_v1::{
cffi_value_holder, host_map_entry, host_value, CffiValueHolder, HostListValue,
HostMapEntry, HostMapValue, HostValue,
},
};
impl BamlDecode for String {
    /// Decodes a string from a plain string value or from a string literal.
    ///
    /// # Errors
    /// Returns an internal error naming the received variant when the holder
    /// carries anything else, including a literal of a non-string kind.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        let found = holder.value.as_ref();
        if let Some(cffi_value_holder::Value::StringValue(text)) = found {
            return Ok(text.clone());
        }
        if let Some(cffi_value_holder::Value::LiteralValue(lit)) = found {
            if let Some(cffi_field_type_literal::Literal::StringLiteral(text)) = &lit.literal {
                return Ok(text.value.clone());
            }
        }
        Err(BamlError::internal(format!(
            "expected string, got {:?}",
            found.map(variant_name)
        )))
    }
}
impl BamlDecode for i64 {
    /// Decodes a 64-bit integer from an int value or from an int literal.
    ///
    /// # Errors
    /// Returns an internal error naming the received variant when the holder
    /// carries anything else, including a literal of a non-int kind.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        let decoded = match holder.value.as_ref() {
            Some(cffi_value_holder::Value::IntValue(n)) => Some(*n),
            Some(cffi_value_holder::Value::LiteralValue(lit)) => match lit.literal.as_ref() {
                Some(cffi_field_type_literal::Literal::IntLiteral(n)) => Some(n.value),
                _ => None,
            },
            _ => None,
        };
        decoded.ok_or_else(|| {
            BamlError::internal(format!(
                "expected int, got {:?}",
                holder.value.as_ref().map(variant_name)
            ))
        })
    }
}
impl BamlDecode for f64 {
    /// Decodes a 64-bit float. Only a float value is accepted; literals and
    /// int values are rejected.
    ///
    /// # Errors
    /// Returns an internal error naming the received variant for any other payload.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        if let Some(cffi_value_holder::Value::FloatValue(x)) = holder.value.as_ref() {
            Ok(*x)
        } else {
            Err(BamlError::internal(format!(
                "expected float, got {:?}",
                holder.value.as_ref().map(variant_name)
            )))
        }
    }
}
impl BamlDecode for bool {
    /// Decodes a boolean from a bool value or from a bool literal.
    ///
    /// # Errors
    /// Returns an internal error naming the received variant when the holder
    /// carries anything else, including a literal of a non-bool kind.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        let mismatch = || {
            BamlError::internal(format!(
                "expected bool, got {:?}",
                holder.value.as_ref().map(variant_name)
            ))
        };
        match holder.value.as_ref() {
            Some(cffi_value_holder::Value::BoolValue(flag)) => Ok(*flag),
            Some(cffi_value_holder::Value::LiteralValue(lit)) => {
                if let Some(cffi_field_type_literal::Literal::BoolLiteral(flag)) = &lit.literal {
                    Ok(flag.value)
                } else {
                    Err(mismatch())
                }
            }
            _ => Err(mismatch()),
        }
    }
}
impl BamlDecode for () {
    /// Decodes "no value": succeeds when the holder is unset or carries an
    /// explicit null.
    ///
    /// # Errors
    /// Returns an internal error containing both the variant name and the full
    /// holder (debug-formatted) for any other payload.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        let is_empty = matches!(
            holder.value,
            None | Some(cffi_value_holder::Value::NullValue(_))
        );
        if is_empty {
            return Ok(());
        }
        Err(BamlError::internal(format!(
            "expected null/void, got {:?} - {:?}",
            holder.value.as_ref().map(variant_name),
            holder,
        )))
    }
}
impl BamlEncode for String {
    /// Encodes the string as a host string value (the text is copied).
    fn baml_encode(&self) -> HostValue {
        let text = self.to_owned();
        HostValue {
            value: Some(host_value::Value::StringValue(text)),
        }
    }
}
impl BamlEncode for &str {
    /// Encodes a borrowed string slice as an owned host string value.
    fn baml_encode(&self) -> HostValue {
        let owned = String::from(*self);
        HostValue {
            value: Some(host_value::Value::StringValue(owned)),
        }
    }
}
impl BamlEncode for i64 {
    /// Encodes the integer as a host int value.
    fn baml_encode(&self) -> HostValue {
        let payload = host_value::Value::IntValue(*self);
        HostValue {
            value: Some(payload),
        }
    }
}
impl BamlEncode for i32 {
fn baml_encode(&self) -> HostValue {
i64::from(*self).baml_encode()
}
}
impl BamlEncode for f64 {
    /// Encodes the float as a host float value (NaN/infinity are passed through as-is).
    fn baml_encode(&self) -> HostValue {
        HostValue {
            value: Some(*self).map(host_value::Value::FloatValue),
        }
    }
}
impl BamlEncode for bool {
    /// Encodes the flag as a host bool value.
    fn baml_encode(&self) -> HostValue {
        let flag = *self;
        HostValue {
            value: Some(host_value::Value::BoolValue(flag)),
        }
    }
}
impl BamlEncode for () {
    /// Encodes unit as a host value whose payload is left unset.
    fn baml_encode(&self) -> HostValue {
        // Unit carries no data, so the value oneof stays empty.
        let value = None;
        HostValue { value }
    }
}
impl<T: BamlEncode> BamlEncode for &T {
    /// Encodes through the reference by delegating to the referent's own impl.
    fn baml_encode(&self) -> HostValue {
        <T as BamlEncode>::baml_encode(*self)
    }
}
impl BamlEncode for HostValue {
    /// A `HostValue` is already encoded; return a copy of it unchanged.
    fn baml_encode(&self) -> HostValue {
        <HostValue as Clone>::clone(self)
    }
}
impl BamlEncode for JsonValue {
    /// Recursively encodes an arbitrary JSON value as a host value.
    ///
    /// - `null` → unset payload
    /// - bool → bool value
    /// - number → int value when it fits in `i64`, otherwise a float value when it
    ///   can be represented as `f64` (e.g. `u64` above `i64::MAX`, lossily), and as a
    ///   last resort its textual form as a string value
    /// - string → string value
    /// - array → list value, element by element
    /// - object → map value with string keys, entry by entry
    fn baml_encode(&self) -> HostValue {
        let value = match self {
            JsonValue::Null => None,
            JsonValue::Bool(flag) => Some(host_value::Value::BoolValue(*flag)),
            JsonValue::Number(num) => Some(
                num.as_i64()
                    .map(host_value::Value::IntValue)
                    .or_else(|| num.as_f64().map(host_value::Value::FloatValue))
                    .unwrap_or_else(|| host_value::Value::StringValue(num.to_string())),
            ),
            JsonValue::String(text) => Some(host_value::Value::StringValue(text.to_owned())),
            JsonValue::Array(items) => Some(host_value::Value::ListValue(HostListValue {
                values: items.iter().map(|item| item.baml_encode()).collect(),
            })),
            JsonValue::Object(fields) => Some(host_value::Value::MapValue(HostMapValue {
                entries: fields
                    .iter()
                    .map(|(key, val)| HostMapEntry {
                        key: Some(host_map_entry::Key::StringKey(key.to_owned())),
                        value: Some(val.baml_encode()),
                    })
                    .collect(),
            })),
        };
        HostValue { value }
    }
}
impl BamlDecode for JsonValue {
fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
match &holder.value {
Some(cffi_value_holder::Value::NullValue(_)) | None => Ok(JsonValue::Null),
Some(cffi_value_holder::Value::StringValue(s)) => Ok(JsonValue::String(s.clone())),
Some(cffi_value_holder::Value::IntValue(i)) => Ok(JsonValue::Number((*i).into())),
Some(cffi_value_holder::Value::FloatValue(f)) => {
if let Some(number) = serde_json::Number::from_f64(*f) {
Ok(JsonValue::Number(number))
} else {
Err(BamlError::internal(format!(
"failed to convert float to json number: {}",
*f
)))
}
}
Some(cffi_value_holder::Value::BoolValue(b)) => Ok(JsonValue::Bool(*b)),
Some(cffi_value_holder::Value::ListValue(list)) => {
let items = list
.items
.iter()
.map(Self::baml_decode)
.collect::<Result<Vec<_>, _>>()?;
Ok(JsonValue::Array(items))
}
Some(cffi_value_holder::Value::MapValue(map)) => {
let mut result = serde_json::Map::new();
for entry in &map.entries {
let value = entry
.value
.as_ref()
.ok_or_else(|| BamlError::internal("map entry missing value"))?;
result.insert(entry.key.clone(), Self::baml_decode(value)?);
}
Ok(JsonValue::Object(result))
}
Some(cffi_value_holder::Value::ClassValue(class)) => {
let mut fields = serde_json::Map::new();
for entry in &class.fields {
if let Some(value) = &entry.value {
fields.insert(entry.key.clone(), Self::baml_decode(value)?);
}
}
Ok(JsonValue::Object(fields))
}
Some(cffi_value_holder::Value::EnumValue(e)) => {
Ok(JsonValue::String(e.value.clone()))
}
Some(cffi_value_holder::Value::UnionVariantValue(union)) => {
let inner = union
.value
.as_ref()
.ok_or_else(|| BamlError::internal("union variant missing value"))?;
let decoded_value = Self::baml_decode(inner)?;
Ok(decoded_value)
}
Some(cffi_value_holder::Value::CheckedValue(checked)) => {
let inner = checked
.value
.as_ref()
.ok_or_else(|| BamlError::internal("checked value missing inner"))?;
let value = Box::new(Self::baml_decode(inner)?);
let checks = checked
.checks
.iter()
.map(|c| {
(
c.name.clone(),
crate::Check {
name: c.name.clone(),
expression: c.expression.clone(),
status: match c.status.as_str() {
"passed" | "PASSED" => crate::CheckStatus::Succeeded,
_ => crate::CheckStatus::Failed,
},
},
)
})
.collect::<HashMap<String, crate::Check>>();
Ok(serde_json::json!({ "value": value, "checks": checks }))
}
Some(cffi_value_holder::Value::StreamingStateValue(ss)) => {
let inner = ss
.value
.as_ref()
.ok_or_else(|| BamlError::internal("stream state missing value"))?;
let value = Box::new(Self::baml_decode(inner)?);
let state = match ss.state() {
crate::__internal::CffiStreamState::Pending => crate::StreamingState::Pending,
crate::__internal::CffiStreamState::Started => crate::StreamingState::Started,
crate::__internal::CffiStreamState::Done => crate::StreamingState::Done,
};
Ok(serde_json::json!({ "value": value, "state": state }))
}
Some(cffi_value_holder::Value::LiteralValue(lit)) => {
match &lit.literal {
Some(cffi_field_type_literal::Literal::StringLiteral(s)) => {
Ok(JsonValue::String(s.value.clone()))
}
Some(cffi_field_type_literal::Literal::IntLiteral(i)) => {
Ok(JsonValue::Number(i.value.into()))
}
Some(cffi_field_type_literal::Literal::BoolLiteral(b)) => {
Ok(JsonValue::Bool(b.value))
}
None => Ok(JsonValue::Null),
}
}
Some(cffi_value_holder::Value::ObjectValue(_)) => {
Err(BamlError::internal(
"ObjectValue cannot be decoded to BamlValue",
))
}
}
}
}