use std::collections::HashMap;
use super::{
dynamic_types::{DynamicClass, DynamicEnum, DynamicUnion},
from_baml_value::FromBamlValue,
from_baml_value_ref::FromBamlValueRef,
known_types::KnownTypes,
};
use crate::{
error::{BamlError, BamlTypeName, FullTypeName},
types::{Checked, StreamState},
};
/// A dynamically-typed BAML value.
///
/// Primitive, list and map variants hold plain Rust data. `Known` / `StreamKnown`
/// carry values of the statically known types `T` / `S` (NOTE(review): presumably
/// generated final and streaming types — confirm against `KnownTypes`). The
/// `Dynamic*` variants hold classes, enums and unions whose shape is only known at
/// runtime, and `Checked` / `StreamState` wrap another value with extra metadata.
#[derive(Debug, Clone)]
pub enum BamlValue<T: KnownTypes, S: KnownTypes> {
    /// A string value.
    String(String),
    /// A 64-bit signed integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// A boolean.
    Bool(bool),
    /// The absence of a value.
    Null,
    /// An ordered sequence of values.
    List(Vec<BamlValue<T, S>>),
    /// A map from string keys to values.
    Map(HashMap<String, BamlValue<T, S>>),
    /// A value of the known type `T`.
    Known(T),
    /// A value of the known type `S` (the stream-side counterpart of `Known`, by name).
    StreamKnown(S),
    /// A value together with the results of the checks run against it.
    Checked(Checked<Box<BamlValue<T, S>>>),
    /// A value together with its streaming state.
    StreamState(StreamState<Box<BamlValue<T, S>>>),
    /// A class instance whose name and fields are only known at runtime.
    DynamicClass(DynamicClass<T, S>),
    /// An enum value whose type is only known at runtime.
    DynamicEnum(DynamicEnum),
    /// A union value whose type and selected variant are only known at runtime.
    DynamicUnion(DynamicUnion<T, S>),
}
impl<T: KnownTypes, S: KnownTypes> FullTypeName for BamlValue<T, S> {
    /// Returns a human-readable name for the runtime type of this value.
    ///
    /// Primitives use their `BamlTypeName` names. Lists and maps do not inspect their
    /// elements, so the element type is reported as `?`. `Checked` and `StreamState`
    /// recurse into the wrapped value. Known and dynamic values report their own names.
    fn full_type_name(&self) -> String {
        match self {
            Self::Null => <()>::baml_type_name(),
            Self::Bool(_) => bool::baml_type_name(),
            Self::Int(_) => i64::baml_type_name(),
            Self::Float(_) => f64::baml_type_name(),
            Self::String(_) => String::baml_type_name(),
            // Element types are not inspected, hence the `?` placeholder.
            Self::List(_) => String::from("List<?>"),
            Self::Map(_) => String::from("Map<String, ?>"),
            Self::Known(known) => known.type_name().to_string(),
            Self::StreamKnown(stream_known) => stream_known.type_name().to_string(),
            Self::Checked(checked) => {
                let inner = checked.value.full_type_name();
                format!("Checked<{}>", inner)
            }
            Self::StreamState(streamed) => {
                let inner = streamed.value.full_type_name();
                format!("StreamState<{}>", inner)
            }
            Self::DynamicClass(class) => class.full_type_name(),
            Self::DynamicEnum(enumeration) => enumeration.full_type_name(),
            Self::DynamicUnion(chosen) => chosen.full_type_name(),
        }
    }
}
impl<T: KnownTypes, S: KnownTypes> BamlValue<T, S> {
    /// Consumes this value and converts it into `V` via [`FromBamlValue`].
    ///
    /// # Errors
    ///
    /// Propagates whatever [`BamlError`] `V`'s `from_baml_value` returns.
    pub fn get<V: FromBamlValue<T, S>>(self) -> Result<V, BamlError> {
        <V as FromBamlValue<T, S>>::from_baml_value(self)
    }

    /// Converts a borrow of this value into `V` via [`FromBamlValueRef`]; `V` may
    /// borrow from `self` for the lifetime `'a`.
    ///
    /// # Errors
    ///
    /// Propagates whatever [`BamlError`] `V`'s `from_baml_value_ref` returns.
    pub fn get_ref<'a, V: FromBamlValueRef<'a, T, S>>(&'a self) -> Result<V, BamlError> {
        <V as FromBamlValueRef<'a, T, S>>::from_baml_value_ref(self)
    }
}
use super::traits::{BamlDecode, BamlEncode};
use crate::proto::baml_cffi_v1::{
host_map_entry, host_value, HostClassValue, HostEnumValue, HostListValue, HostMapEntry,
HostMapValue, HostValue,
};
impl<T: KnownTypes, S: KnownTypes> BamlEncode for BamlValue<T, S> {
fn baml_encode(&self) -> HostValue {
match self {
BamlValue::Null => HostValue { value: None },
BamlValue::String(s) => HostValue {
value: Some(host_value::Value::StringValue(s.clone())),
},
BamlValue::Int(i) => HostValue {
value: Some(host_value::Value::IntValue(*i)),
},
BamlValue::Float(f) => HostValue {
value: Some(host_value::Value::FloatValue(*f)),
},
BamlValue::Bool(b) => HostValue {
value: Some(host_value::Value::BoolValue(*b)),
},
BamlValue::List(items) => HostValue {
value: Some(host_value::Value::ListValue(HostListValue {
values: items
.iter()
.map(super::traits::BamlEncode::baml_encode)
.collect(),
})),
},
BamlValue::Map(map) => HostValue {
value: Some(host_value::Value::MapValue(HostMapValue {
entries: map
.iter()
.map(|(k, v)| HostMapEntry {
key: Some(host_map_entry::Key::StringKey(k.clone())),
value: Some(v.baml_encode()),
})
.collect(),
})),
},
BamlValue::DynamicClass(dc) => {
let entries = dc
.fields()
.map(|(k, v)| HostMapEntry {
key: Some(host_map_entry::Key::StringKey(k.to_string())),
value: Some(v.baml_encode()),
})
.collect();
HostValue {
value: Some(host_value::Value::ClassValue(HostClassValue {
name: dc.name().to_string(),
fields: entries,
})),
}
}
BamlValue::DynamicEnum(de) => HostValue {
value: Some(host_value::Value::EnumValue(HostEnumValue {
name: de.name().to_string(),
value: de.value.clone(),
})),
},
BamlValue::DynamicUnion(du) => {
du.value.baml_encode()
}
BamlValue::Known(_) => {
HostValue { value: None }
}
BamlValue::StreamKnown(_) => {
HostValue { value: None }
}
BamlValue::Checked(c) => {
c.value.baml_encode()
}
BamlValue::StreamState(ss) => {
ss.value.baml_encode()
}
}
}
}
use crate::{
proto::baml_cffi_v1::{
cffi_field_type_literal, cffi_value_holder, CffiStreamState, CffiValueHolder,
},
types::{Check, CheckStatus, StreamingState},
};
impl<T: KnownTypes, S: KnownTypes> BamlDecode for BamlValue<T, S> {
    /// Decodes a [`CffiValueHolder`] into a dynamically-typed [`BamlValue`].
    ///
    /// * A holder with no payload, an explicit null, or a literal with no literal
    ///   payload decodes to [`BamlValue::Null`].
    /// * Classes, enums and non-single-pattern unions decode to the `Dynamic*`
    ///   variants. A single-pattern union is unwrapped to its inner value.
    /// * Checked and streaming values keep their wrapper metadata.
    ///
    /// # Errors
    ///
    /// Returns an internal [`BamlError`] in these cases:
    /// * a map entry, union variant, checked value or stream state is missing its
    ///   inner value;
    /// * the holder carries an `ObjectValue`;
    /// * any nested value fails to decode.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        match &holder.value {
            Some(cffi_value_holder::Value::NullValue(_)) | None => Ok(BamlValue::Null),
            Some(cffi_value_holder::Value::StringValue(s)) => Ok(BamlValue::String(s.clone())),
            Some(cffi_value_holder::Value::IntValue(i)) => Ok(BamlValue::Int(*i)),
            Some(cffi_value_holder::Value::FloatValue(f)) => Ok(BamlValue::Float(*f)),
            Some(cffi_value_holder::Value::BoolValue(b)) => Ok(BamlValue::Bool(*b)),
            Some(cffi_value_holder::Value::ListValue(list)) => {
                let items = list
                    .items
                    .iter()
                    .map(Self::baml_decode)
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(BamlValue::List(items))
            }
            Some(cffi_value_holder::Value::MapValue(map)) => {
                let mut result = HashMap::with_capacity(map.entries.len());
                for entry in &map.entries {
                    let value = entry
                        .value
                        .as_ref()
                        .ok_or_else(|| BamlError::internal("map entry missing value"))?;
                    result.insert(entry.key.clone(), Self::baml_decode(value)?);
                }
                Ok(BamlValue::Map(result))
            }
            Some(cffi_value_holder::Value::ClassValue(class)) => {
                let name = class
                    .name
                    .as_ref()
                    .map(|n| n.name.clone())
                    .unwrap_or_default();
                let mut fields = HashMap::with_capacity(class.fields.len());
                for entry in &class.fields {
                    // Unlike map entries, class fields with no value are skipped
                    // rather than treated as an error.
                    if let Some(value) = &entry.value {
                        fields.insert(entry.key.clone(), Self::baml_decode(value)?);
                    }
                }
                Ok(BamlValue::DynamicClass(DynamicClass::with_fields(
                    name, fields,
                )))
            }
            Some(cffi_value_holder::Value::EnumValue(e)) => {
                let name = e.name.as_ref().map(|n| n.name.clone()).unwrap_or_default();
                Ok(BamlValue::DynamicEnum(DynamicEnum {
                    name,
                    value: e.value.clone(),
                }))
            }
            Some(cffi_value_holder::Value::UnionVariantValue(union)) => {
                let inner = union
                    .value
                    .as_ref()
                    .ok_or_else(|| BamlError::internal("union variant missing value"))?;
                let decoded_value = Self::baml_decode(inner)?;
                if union.is_single_pattern {
                    // A single-pattern union carries no useful variant information,
                    // so the inner value is returned directly.
                    Ok(decoded_value)
                } else {
                    let name = union
                        .name
                        .as_ref()
                        .map(|n| n.name.clone())
                        .unwrap_or_default();
                    let variant_name = union.value_option_name.clone();
                    Ok(BamlValue::DynamicUnion(DynamicUnion {
                        name,
                        variant_name,
                        value: Box::new(decoded_value),
                    }))
                }
            }
            Some(cffi_value_holder::Value::CheckedValue(checked)) => {
                let inner = checked
                    .value
                    .as_ref()
                    .ok_or_else(|| BamlError::internal("checked value missing inner"))?;
                let value = Box::new(Self::baml_decode(inner)?);
                let checks = checked
                    .checks
                    .iter()
                    .map(|c| {
                        (
                            c.name.clone(),
                            Check {
                                name: c.name.clone(),
                                expression: c.expression.clone(),
                                status: decode_check_status(&c.status),
                            },
                        )
                    })
                    .collect();
                Ok(BamlValue::Checked(Checked { value, checks }))
            }
            Some(cffi_value_holder::Value::StreamingStateValue(ss)) => {
                let inner = ss
                    .value
                    .as_ref()
                    .ok_or_else(|| BamlError::internal("stream state missing value"))?;
                let value = Box::new(Self::baml_decode(inner)?);
                let state = match ss.state() {
                    CffiStreamState::Pending => StreamingState::Pending,
                    CffiStreamState::Started => StreamingState::Started,
                    CffiStreamState::Done => StreamingState::Done,
                };
                Ok(BamlValue::StreamState(StreamState { value, state }))
            }
            Some(cffi_value_holder::Value::LiteralValue(lit)) => match &lit.literal {
                Some(cffi_field_type_literal::Literal::StringLiteral(s)) => {
                    Ok(BamlValue::String(s.value.clone()))
                }
                Some(cffi_field_type_literal::Literal::IntLiteral(i)) => {
                    Ok(BamlValue::Int(i.value))
                }
                Some(cffi_field_type_literal::Literal::BoolLiteral(b)) => {
                    Ok(BamlValue::Bool(b.value))
                }
                None => Ok(BamlValue::Null),
            },
            Some(cffi_value_holder::Value::ObjectValue(_)) => Err(BamlError::internal(
                "ObjectValue cannot be decoded to BamlValue",
            )),
        }
    }
}

/// Maps a check-status string from the CFFI layer onto [`CheckStatus`].
///
/// `"succeeded"` (the name of the success variant) and `"passed"` both map to
/// [`CheckStatus::Succeeded`]. The comparison ignores ASCII case, so `"PASSED"` and
/// `"Succeeded"` are accepted too. Every other string is treated as
/// [`CheckStatus::Failed`].
/// NOTE(review): confirm which spelling the engine-side encoder actually emits.
fn decode_check_status(status: &str) -> CheckStatus {
    if status.eq_ignore_ascii_case("succeeded") || status.eq_ignore_ascii_case("passed") {
        CheckStatus::Succeeded
    } else {
        CheckStatus::Failed
    }
}
impl<T: KnownTypes + serde::Serialize, S: KnownTypes + serde::Serialize> serde::Serialize
    for BamlValue<T, S>
{
    /// Serializes the value in its natural, untagged shape.
    ///
    /// * `Null` becomes serde's `none`; the other primitives map to the matching
    ///   serde primitive.
    /// * Lists and maps serialize element by element.
    /// * A dynamic class serializes as its field map and a dynamic enum as its value
    ///   string.
    /// * Every other variant delegates to the wrapped value's own `Serialize` impl.
    fn serialize<Ser: serde::Serializer>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error> {
        match self {
            // Scalars.
            Self::Null => serializer.serialize_none(),
            Self::Bool(flag) => serializer.serialize_bool(*flag),
            Self::Int(number) => serializer.serialize_i64(*number),
            Self::Float(number) => serializer.serialize_f64(*number),
            Self::String(text) => serializer.serialize_str(text),
            Self::DynamicEnum(enumeration) => serializer.serialize_str(&enumeration.value),
            // Containers and wrappers delegate to their contents.
            Self::List(items) => serde::Serialize::serialize(items, serializer),
            Self::Map(entries) => serde::Serialize::serialize(entries, serializer),
            Self::DynamicClass(class) => serde::Serialize::serialize(&class.fields, serializer),
            Self::DynamicUnion(chosen) => serde::Serialize::serialize(chosen, serializer),
            Self::Known(known) => serde::Serialize::serialize(known, serializer),
            Self::StreamKnown(stream_known) => {
                serde::Serialize::serialize(stream_known, serializer)
            }
            Self::Checked(checked) => serde::Serialize::serialize(checked, serializer),
            Self::StreamState(streamed) => serde::Serialize::serialize(streamed, serializer),
        }
    }
}
impl<'de, T: KnownTypes + serde::Deserialize<'de>, S: KnownTypes + serde::Deserialize<'de>>
    serde::Deserialize<'de> for BamlValue<T, S>
{
    /// Deserializes any self-describing input into a `BamlValue`.
    ///
    /// Only the untyped variants are produced:
    /// * `Null` (from none or unit) and `Bool`;
    /// * `Int` (from signed integers, or unsigned integers that fit in an `i64`);
    /// * `Float` and `String`;
    /// * `List` and `Map` (string keys).
    ///
    /// `Known`, `StreamKnown` and the wrapper/dynamic variants are never constructed
    /// here.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use std::marker::PhantomData;

        // Declares its own `T`/`S` because items nested in a function cannot use the
        // enclosing function's generic parameters.
        struct ValueVisitor<'de, T, S>(PhantomData<(&'de (), T, S)>)
        where
            T: KnownTypes + serde::Deserialize<'de>,
            S: KnownTypes + serde::Deserialize<'de>;

        impl<'de, T, S> serde::de::Visitor<'de> for ValueVisitor<'de, T, S>
        where
            T: KnownTypes + serde::Deserialize<'de>,
            S: KnownTypes + serde::Deserialize<'de>,
        {
            type Value = BamlValue<T, S>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("BAML value")
            }

            fn visit_unit<E: serde::de::Error>(self) -> Result<Self::Value, E> {
                Ok(BamlValue::Null)
            }

            fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
                Ok(BamlValue::Null)
            }

            // `Some(x)` is represented as plain `x`: keep visiting the inner value.
            fn visit_some<D: serde::Deserializer<'de>>(
                self,
                inner: D,
            ) -> Result<Self::Value, D::Error> {
                inner.deserialize_any(self)
            }

            fn visit_bool<E: serde::de::Error>(self, v: bool) -> Result<Self::Value, E> {
                Ok(BamlValue::Bool(v))
            }

            fn visit_i64<E: serde::de::Error>(self, v: i64) -> Result<Self::Value, E> {
                Ok(BamlValue::Int(v))
            }

            // Unsigned integers above `i64::MAX` cannot be represented and are rejected.
            fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
                match i64::try_from(v) {
                    Ok(signed) => Ok(BamlValue::Int(signed)),
                    Err(_) => Err(E::invalid_value(
                        serde::de::Unexpected::Unsigned(v),
                        &"a value fitting in an i64",
                    )),
                }
            }

            fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
                Ok(BamlValue::Float(v))
            }

            fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
                Ok(BamlValue::String(v.to_string()))
            }

            fn visit_seq<A: serde::de::SeqAccess<'de>>(
                self,
                mut access: A,
            ) -> Result<Self::Value, A::Error> {
                let mut items: Vec<BamlValue<T, S>> = Vec::new();
                while let Some(item) = access.next_element()? {
                    items.push(item);
                }
                Ok(BamlValue::List(items))
            }

            fn visit_map<A: serde::de::MapAccess<'de>>(
                self,
                mut access: A,
            ) -> Result<Self::Value, A::Error> {
                let mut entries: HashMap<String, BamlValue<T, S>> = HashMap::new();
                while let Some((key, value)) = access.next_entry()? {
                    entries.insert(key, value);
                }
                Ok(BamlValue::Map(entries))
            }
        }

        deserializer.deserialize_any(ValueVisitor(PhantomData))
    }
}