use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::{
codec::{BamlDecode, BamlEncode},
error::BamlError,
proto::baml_cffi_v1::{cffi_value_holder, CffiStreamState, CffiValueHolder, HostValue},
};
/// A value of type `T` bundled with the checks recorded for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checked<T> {
    /// The wrapped value itself.
    pub value: T,
    /// Recorded checks, looked up by string key.
    ///
    /// When produced by decoding, each entry is stored under the check's own
    /// `name`.
    pub checks: HashMap<String, Check>,
}
impl<T: Default> Default for Checked<T> {
fn default() -> Self {
Self {
value: T::default(),
checks: HashMap::new(),
}
}
}
/// A single recorded check: its name, its expression text, and its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct Check {
    /// Name of the check; decoding also uses it as the key in
    /// `Checked::checks`.
    pub name: String,
    /// Expression string associated with the check, stored verbatim.
    pub expression: String,
    /// Whether the check succeeded or failed.
    pub status: CheckStatus,
}
/// Outcome of a single check.
///
/// Serialized by serde in snake_case (`"succeeded"` / `"failed"`), the same
/// spellings the decoder accepts for the wire status string.
///
/// The enum is fieldless, so it is `Copy` and `Hash`: callers can pass it by
/// value and use it as a map or set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CheckStatus {
    /// The check passed.
    Succeeded,
    /// The check did not pass.
    Failed,
}
impl<T: BamlDecode> BamlDecode for Checked<T> {
fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
match &holder.value {
Some(cffi_value_holder::Value::CheckedValue(checked)) => {
let inner = checked
.value
.as_ref()
.ok_or_else(|| BamlError::internal("missing checked value"))?;
let value = T::baml_decode(inner)?;
let checks = checked
.checks
.iter()
.map(|c| {
Ok((
c.name.clone(),
Check {
name: c.name.clone(),
expression: c.expression.clone(),
status: match c.status.as_str() {
"succeeded" => CheckStatus::Succeeded,
"failed" => CheckStatus::Failed,
_ => {
return Err(BamlError::internal(format!(
"invalid check status: {}",
c.status
)));
}
},
},
))
})
.collect::<Result<HashMap<String, Check>, BamlError>>()?;
Ok(Checked { value, checks })
}
other => Err(BamlError::internal(format!(
"expected checked value, got {:?}",
other.is_some()
))),
}
}
}
impl<T: BamlEncode> BamlEncode for Checked<T> {
    /// Encodes only the wrapped value; the `checks` map is not part of the
    /// produced `HostValue`.
    fn baml_encode(&self) -> HostValue {
        let inner: &T = &self.value;
        T::baml_encode(inner)
    }
}
impl<T> Checked<T> {
    /// Returns `true` when every recorded check has status
    /// `CheckStatus::Succeeded`.
    ///
    /// An empty check map yields `true`.
    pub fn all_passed(&self) -> bool {
        self.checks
            .values()
            .map(|check| &check.status)
            .all(|status| matches!(status, CheckStatus::Succeeded))
    }

    /// Returns `true` when at least one recorded check has status
    /// `CheckStatus::Failed`.
    ///
    /// An empty check map yields `false`.
    pub fn any_failed(&self) -> bool {
        self.checks
            .values()
            .map(|check| &check.status)
            .any(|status| matches!(status, CheckStatus::Failed))
    }

    /// Looks up a check by the key it is stored under in `checks`.
    ///
    /// Returns `None` when no entry exists for `name`.
    pub fn get_check(&self, name: &str) -> Option<&Check> {
        let found = self.checks.get(name);
        found
    }
}
/// A value of type `T` paired with its streaming state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamState<T> {
    /// The wrapped value.
    pub value: T,
    /// Streaming state recorded alongside `value`.
    pub state: StreamingState,
}
/// Streaming state attached to a `StreamState` value.
///
/// Serialized by serde in snake_case (`"pending"`, `"started"`, `"done"`).
///
/// The enum is fieldless, so it is `Copy` and `Hash`: callers can pass it by
/// value and use it as a map or set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum StreamingState {
    /// State assigned by `StreamState::new`; decoded from
    /// `CffiStreamState::Pending`.
    Pending,
    /// Decoded from `CffiStreamState::Started`.
    Started,
    /// Decoded from `CffiStreamState::Done`.
    Done,
}
impl<T> StreamState<T> {
pub fn new(value: T) -> Self {
Self {
value,
state: StreamingState::Pending,
}
}
}
impl<T: Default> Default for StreamState<T> {
    /// Builds a `StreamState` around `T`'s default value by delegating to
    /// `StreamState::new`.
    fn default() -> Self {
        let value: T = Default::default();
        Self::new(value)
    }
}
impl<T: BamlDecode> BamlDecode for StreamState<T> {
    /// Decodes a `StreamingStateValue` holder into a `StreamState<T>`.
    ///
    /// The wrapped value is decoded with `T::baml_decode`, and the state
    /// reported by the holder is mapped variant-for-variant onto
    /// `StreamingState`.
    ///
    /// # Errors
    ///
    /// Returns an internal `BamlError` when the holder does not carry a
    /// `StreamingStateValue` variant or when that value has no inner value.
    /// Errors from decoding the inner value as `T` are propagated.
    fn baml_decode(holder: &CffiValueHolder) -> Result<Self, BamlError> {
        // Guard: anything other than a streaming-state value is rejected.
        let stream = match &holder.value {
            Some(cffi_value_holder::Value::StreamingStateValue(ss)) => ss,
            other => {
                return Err(BamlError::internal(format!(
                    "expected stream state value, got {:?}",
                    other.is_some()
                )));
            }
        };

        let inner = match stream.value.as_ref() {
            Some(inner) => inner,
            None => return Err(BamlError::internal("missing stream state value")),
        };
        let value = T::baml_decode(inner)?;

        // Exhaustive on purpose: a new `CffiStreamState` variant must be
        // handled here explicitly.
        let state = match stream.state() {
            CffiStreamState::Pending => StreamingState::Pending,
            CffiStreamState::Started => StreamingState::Started,
            CffiStreamState::Done => StreamingState::Done,
        };

        Ok(StreamState { value, state })
    }
}