use crate::ir::{BamlValue, Class, FieldType, IR};
use serde::{Serialize, Serializer};
use std::collections::HashMap;
/// A [`BamlValue`] paired with the streaming state it has reached.
///
/// The manual `Serialize` impl for this type emits the two fields under the
/// keys `value` and `state`.
#[derive(Debug, Clone)]
pub struct StreamingBamlValue {
/// The value accumulated so far; fields not yet received are `BamlValue::Null`.
pub value: BamlValue,
/// Whether `value` is a fresh skeleton, partially filled, or complete.
pub completion_state: CompletionState,
}
/// Progress marker for a [`StreamingBamlValue`].
///
/// Serialized as the lowercase variant name (`"complete"`, `"partial"`,
/// `"pending"`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CompletionState {
/// State assigned by `StreamingBamlValue::mark_complete`.
Complete,
/// State assigned by `StreamingBamlValue::update_from_partial` after each merge.
Partial,
/// State of a freshly built skeleton (`StreamingBamlValue::from_ir_skeleton`).
Pending,
}
impl Serialize for StreamingBamlValue {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
use serde::ser::SerializeStruct;
let mut state = serializer.serialize_struct("StreamingValue", 2)?;
state.serialize_field("value", &self.value)?;
state.serialize_field("state", &self.completion_state)?;
state.end()
}
}
impl StreamingBamlValue {
    /// Pairs `value` with an explicit completion `state`.
    pub fn new(value: BamlValue, state: CompletionState) -> Self {
        Self { value, completion_state: state }
    }

    /// Builds a `Pending` placeholder whose shape follows `field_type`, as
    /// produced by `create_skeleton_value` for the classes known to `ir`.
    pub fn from_ir_skeleton(ir: &IR, field_type: &FieldType) -> Self {
        let skeleton = create_skeleton_value(ir, field_type);
        Self {
            value: skeleton,
            completion_state: CompletionState::Pending,
        }
    }

    /// Merges a newly streamed `partial_value` (of type `field_type`) into the
    /// accumulated value via `merge_values`, then sets the state to `Partial`.
    ///
    /// The state is overwritten unconditionally, so calling this after
    /// [`mark_complete`](Self::mark_complete) moves the value back to `Partial`.
    pub fn update_from_partial(&mut self, ir: &IR, partial_value: BamlValue, field_type: &FieldType) {
        let Self {
            value,
            completion_state,
        } = self;
        merge_values(value, partial_value, ir, field_type);
        *completion_state = CompletionState::Partial;
    }

    /// Sets the state to `Complete`; the value itself is left untouched.
    pub fn mark_complete(&mut self) {
        self.completion_state = CompletionState::Complete;
    }
}
/// Builds the placeholder value used for `field_type` before any streamed data
/// has arrived.
///
/// * String, int, float, bool, enum and tagged-enum types start as `Null`.
/// * Lists start as an empty list and maps as an empty map.
/// * A class resolvable in `ir` becomes a map containing every declared field
///   (see [`create_skeleton_class`]); an unknown class name yields `Null`.
/// * A union uses the skeleton of its first member; an empty union is `Null`.
///
/// Recursive schemas (a class reachable from one of its own required fields,
/// directly or through other classes/unions) are handled: a class that is
/// already being expanded on the current path is emitted as `Null` instead of
/// being expanded again, which previously recursed until the stack overflowed.
fn create_skeleton_value(ir: &IR, field_type: &FieldType) -> BamlValue {
    skeleton_for(ir, field_type, &mut Vec::new())
}

/// Recursive worker for [`create_skeleton_value`].
///
/// `expanding` holds the names of the classes currently being expanded on the
/// path from the root to this type; it is the cycle guard.
fn skeleton_for(ir: &IR, field_type: &FieldType, expanding: &mut Vec<String>) -> BamlValue {
    match field_type {
        FieldType::String
        | FieldType::Int
        | FieldType::Float
        | FieldType::Bool
        | FieldType::Enum(_)
        | FieldType::TaggedEnum(_) => BamlValue::Null,
        FieldType::Class(class_name) => {
            if let Some(class) = ir.find_class(class_name) {
                create_skeleton_class(ir, class, expanding)
            } else {
                BamlValue::Null
            }
        }
        FieldType::List(_) => BamlValue::List(vec![]),
        FieldType::Map(_, _) => BamlValue::Map(HashMap::new()),
        FieldType::Union(types) => match types.first() {
            Some(first_type) => skeleton_for(ir, first_type, expanding),
            None => BamlValue::Null,
        },
    }
}

/// Builds the placeholder map for `class`: optional fields are `Null`, and
/// required fields get the skeleton of their declared type.
///
/// Returns `Null` when `class` is already on the `expanding` path, so that
/// recursive class definitions terminate.
fn create_skeleton_class(ir: &IR, class: &Class, expanding: &mut Vec<String>) -> BamlValue {
    // Re-entering a class that is still being expanded would never terminate;
    // the recursive occurrence simply starts out empty.
    if expanding.iter().any(|name| *name == class.name) {
        return BamlValue::Null;
    }
    expanding.push(class.name.clone());

    let mut fields = HashMap::new();
    for field in &class.fields {
        let field_value = if field.optional {
            BamlValue::Null
        } else {
            skeleton_for(ir, &field.field_type, expanding)
        };
        fields.insert(field.name.clone(), field_value);
    }

    // Pop on the way out so sibling fields can still expand the same class
    // when it is not actually recursive.
    expanding.pop();
    BamlValue::Map(fields)
}
/// Folds a newly streamed `source` value into the accumulated `target`.
///
/// Rules, checked in order:
/// * a `Null` source is ignored, so data already accumulated is never erased
///   by a field that has not arrived yet;
/// * two maps are merged entry by entry (see below);
/// * any other combination (scalars, lists, or a change of shape) replaces
///   `target` wholesale with `source`.
///
/// For map/map merges, when `field_type` names a class that `ir` resolves,
/// only keys matching a declared field are used: an existing entry is merged
/// recursively with that field's type, a missing entry is inserted, and keys
/// that are not fields of the class are dropped. Otherwise every source entry
/// is inserted, overwriting any entry with the same key.
fn merge_values(target: &mut BamlValue, source: BamlValue, ir: &IR, field_type: &FieldType) {
    match (target, source) {
        (_, BamlValue::Null) => {}
        (BamlValue::Map(target_map), BamlValue::Map(source_map)) => {
            let class = match field_type {
                FieldType::Class(class_name) => ir.find_class(class_name),
                _ => None,
            };
            match class {
                Some(class) => {
                    for (key, incoming) in source_map {
                        if let Some(field) = class.fields.iter().find(|f| f.name == key) {
                            match target_map.get_mut(&key) {
                                Some(existing) => {
                                    merge_values(existing, incoming, ir, &field.field_type)
                                }
                                None => {
                                    target_map.insert(key, incoming);
                                }
                            }
                        }
                    }
                }
                None => target_map.extend(source_map),
            }
        }
        (slot, incoming) => *slot = incoming,
    }
}
/// Types that can produce an initial streaming placeholder for a schema type.
pub trait StreamingCapable {
/// Returns a `Pending` skeleton shaped after `field_type`, built by
/// delegating to [`StreamingBamlValue::from_ir_skeleton`].
fn create_skeleton(ir: &IR, field_type: &FieldType) -> StreamingBamlValue {
StreamingBamlValue::from_ir_skeleton(ir, field_type)
}
}
// `StreamingBamlValue` uses the provided `create_skeleton` unchanged.
impl StreamingCapable for StreamingBamlValue {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{Field, FieldType};

    /// A required (non-optional) field without a description.
    fn required(name: &str, field_type: FieldType) -> Field {
        Field {
            name: name.to_string(),
            field_type,
            optional: false,
            description: None,
        }
    }

    /// An IR containing a single class: `Person { name: String, age: Int }`.
    fn person_ir() -> IR {
        let mut ir = IR::new();
        ir.classes.push(Class {
            name: "Person".to_string(),
            description: None,
            fields: vec![
                required("name", FieldType::String),
                required("age", FieldType::Int),
            ],
        });
        ir
    }

    /// The field type referring to the `Person` class.
    fn person() -> FieldType {
        FieldType::Class("Person".to_string())
    }

    #[test]
    fn test_create_skeleton_primitives() {
        let ir = IR::new();
        for primitive in &[FieldType::String, FieldType::Int] {
            let skeleton = create_skeleton_value(&ir, primitive);
            assert!(matches!(skeleton, BamlValue::Null));
        }
    }

    #[test]
    fn test_create_skeleton_class() {
        let ir = person_ir();
        let skeleton = StreamingBamlValue::from_ir_skeleton(&ir, &person());
        let map = match &skeleton.value {
            BamlValue::Map(map) => map,
            _ => panic!("Expected Map"),
        };
        for &key in &["name", "age"] {
            assert!(map.contains_key(key));
            assert!(matches!(map.get(key), Some(BamlValue::Null)));
        }
    }

    #[test]
    fn test_merge_values() {
        let ir = person_ir();
        let mut skeleton = StreamingBamlValue::from_ir_skeleton(&ir, &person());

        // First chunk: only `name` has arrived.
        let first_chunk: HashMap<_, _> =
            std::iter::once(("name".to_string(), BamlValue::String("John".to_string())))
                .collect();
        skeleton.update_from_partial(&ir, BamlValue::Map(first_chunk), &person());
        match &skeleton.value {
            BamlValue::Map(map) => {
                assert_eq!(map.get("name").and_then(|v| v.as_string()), Some("John"));
                assert!(matches!(map.get("age"), Some(BamlValue::Null)));
            }
            _ => panic!("Expected Map"),
        }
        assert_eq!(skeleton.completion_state, CompletionState::Partial);

        // Second chunk: both fields are present.
        let second_chunk: HashMap<_, _> = vec![
            ("name".to_string(), BamlValue::String("John".to_string())),
            ("age".to_string(), BamlValue::Int(30)),
        ]
        .into_iter()
        .collect();
        skeleton.update_from_partial(&ir, BamlValue::Map(second_chunk), &person());
        match &skeleton.value {
            BamlValue::Map(map) => {
                assert_eq!(map.get("name").and_then(|v| v.as_string()), Some("John"));
                assert_eq!(map.get("age").and_then(|v| v.as_int()), Some(30));
            }
            _ => panic!("Expected Map"),
        }

        skeleton.mark_complete();
        assert_eq!(skeleton.completion_state, CompletionState::Complete);
    }

    #[test]
    fn test_serialization() {
        let streamed = StreamingBamlValue::new(
            BamlValue::String("test".to_string()),
            CompletionState::Partial,
        );
        let json = serde_json::to_string(&streamed).unwrap();
        for &fragment in &["\"state\":\"partial\"", "\"value\""] {
            assert!(json.contains(fragment));
        }
    }
}