use anyhow::{Context, Result};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use crate::ir::{BamlValue, FieldType, IR};
/// Turns raw LLM response text into typed `BamlValue`s.
///
/// The parser borrows the IR for its whole lifetime (`'a`) and does not own
/// or mutate it.
pub struct Parser<'a> {
/// IR consulted for class and enum definitions during coercion.
ir: &'a IR,
}
impl<'a> Parser<'a> {
pub fn new(ir: &'a IR) -> Self {
Self { ir }
}
/// Parses a raw LLM response into a value of `target_type`.
///
/// The JSON payload is first located inside the response text, then decoded,
/// and finally coerced to the requested type.
///
/// # Errors
///
/// Returns an error if no decodable JSON can be extracted from
/// `raw_response`, or if the decoded value cannot be coerced to
/// `target_type`.
pub fn parse(&self, raw_response: &str, target_type: &FieldType) -> Result<BamlValue> {
    let payload = self.extract_json(raw_response)?;
    let decoded = serde_json::from_str::<JsonValue>(&payload)
        .context("Failed to parse JSON from LLM response")?;
    self.coerce(&decoded, target_type)
}
fn extract_json(&self, response: &str) -> Result<String> {
let response = response.trim();
if let Some(start) = response.find("```json") {
let json_start = start + 7; if let Some(end_offset) = response[json_start..].find("```") {
let json_end = json_start + end_offset;
return Ok(response[json_start..json_end].trim().to_string());
}
}
if let Some(start) = response.find("```") {
if let Some(end) = response[start + 3..].find("```") {
let json_start = start + 3;
let json_end = start + 3 + end;
let content = response[json_start..json_end].trim();
if content.starts_with('{') || content.starts_with('[') {
return Ok(content.to_string());
}
}
}
if let Some(start) = response.find('{') {
if let Some(end) = response.rfind('}') {
if end > start {
return Ok(response[start..=end].to_string());
}
}
}
if let Some(start) = response.find('[') {
if let Some(end) = response.rfind(']') {
if end > start {
return Ok(response[start..=end].to_string());
}
}
}
Ok(response.to_string())
}
/// Coerces `value` to `target_type` by dispatching to the matching
/// type-specific routine.
///
/// A `FieldType::Class` name is looked up as a class first and as an enum
/// second; an error is returned when the IR knows neither.
fn coerce(&self, value: &JsonValue, target_type: &FieldType) -> Result<BamlValue> {
    match target_type {
        // Scalars.
        FieldType::Bool => self.coerce_bool(value),
        FieldType::Int => self.coerce_int(value),
        FieldType::Float => self.coerce_float(value),
        FieldType::String => self.coerce_string(value),

        // Composite types.
        FieldType::List(inner) => self.coerce_list(value, inner),
        FieldType::Map(k, v) => self.coerce_map(value, k, v),
        FieldType::Union(types) => self.coerce_union(value, types),

        // Named types resolved through the IR.
        FieldType::Enum(enum_name) => self.coerce_enum(value, enum_name),
        FieldType::Class(class_name) => {
            if self.ir.find_class(class_name).is_some() {
                return self.coerce_class(value, class_name);
            }
            if self.ir.find_enum(class_name).is_some() {
                return self.coerce_enum(value, class_name);
            }
            anyhow::bail!("Type '{}' not found (neither class nor enum)", class_name)
        }
    }
}
/// Coerces a JSON scalar to a string value.
///
/// Strings are copied, numbers and booleans are rendered with their textual
/// form, and `null` becomes the empty string. Arrays and objects are
/// rejected.
fn coerce_string(&self, value: &JsonValue) -> Result<BamlValue> {
    let text = match value {
        JsonValue::Null => String::new(),
        JsonValue::Bool(flag) => flag.to_string(),
        JsonValue::Number(number) => number.to_string(),
        JsonValue::String(existing) => existing.clone(),
        other => anyhow::bail!("Cannot coerce {:?} to string", other),
    };
    Ok(BamlValue::String(text))
}
/// Coerces a JSON number or numeric string to an integer value.
///
/// * Integers that fit in `i64` are taken as-is.
/// * Other numbers (floats, or unsigned integers above `i64::MAX`) are
///   accepted only when they lie within the `i64` range; any fractional part
///   is truncated toward zero.
/// * Strings must parse as a base-10 `i64`.
///
/// # Errors
///
/// Returns an error for out-of-range numbers, unparsable strings, and any
/// other JSON kind.
fn coerce_int(&self, value: &JsonValue) -> Result<BamlValue> {
    // Bounds of the i64 range expressed as f64. `i64::MAX` itself is not
    // representable as f64 (it rounds up to 2^63), so the upper bound is
    // exclusive.
    const LOWER_INCLUSIVE: f64 = i64::MIN as f64;
    const UPPER_EXCLUSIVE: f64 = -(i64::MIN as f64);

    match value {
        JsonValue::Number(n) => {
            if let Some(i) = n.as_i64() {
                return Ok(BamlValue::Int(i));
            }
            match n.as_f64() {
                // In range: the cast cannot saturate, it only drops the
                // fractional part.
                Some(f) if f >= LOWER_INCLUSIVE && f < UPPER_EXCLUSIVE => {
                    Ok(BamlValue::Int(f as i64))
                }
                // An `as` cast would silently clamp to i64::MIN / i64::MAX
                // here, producing a value the model never sent.
                Some(_) => anyhow::bail!("Number {} is out of range for int", n),
                None => anyhow::bail!("Cannot coerce number to int"),
            }
        }
        JsonValue::String(s) => {
            let i = s.parse::<i64>()
                .context("Cannot parse string as int")?;
            Ok(BamlValue::Int(i))
        }
        _ => anyhow::bail!("Cannot coerce {:?} to int", value),
    }
}
/// Coerces a JSON number or numeric string to a float value.
///
/// # Errors
///
/// Returns an error when a number has no `f64` representation, when a string
/// does not parse as `f64`, or when `value` is any other JSON kind.
fn coerce_float(&self, value: &JsonValue) -> Result<BamlValue> {
    let parsed = match value {
        JsonValue::Number(number) => number
            .as_f64()
            .ok_or_else(|| anyhow::anyhow!("Cannot coerce number to float"))?,
        JsonValue::String(text) => text
            .parse::<f64>()
            .context("Cannot parse string as float")?,
        other => anyhow::bail!("Cannot coerce {:?} to float", other),
    };
    Ok(BamlValue::Float(parsed))
}
/// Coerces a JSON value to a boolean.
///
/// Accepted inputs:
/// * JSON booleans;
/// * the strings `true`/`yes`/`1` and `false`/`no`/`0`, compared
///   case-insensitively;
/// * numbers representable as `i64`, where zero is `false` and anything else
///   is `true`.
fn coerce_bool(&self, value: &JsonValue) -> Result<BamlValue> {
    let flag = match value {
        JsonValue::Bool(b) => *b,
        JsonValue::String(text) => match text.to_lowercase().as_str() {
            "true" | "yes" | "1" => true,
            "false" | "no" | "0" => false,
            _ => anyhow::bail!("Cannot parse '{}' as bool", text),
        },
        JsonValue::Number(number) => match number.as_i64() {
            Some(i) => i != 0,
            None => anyhow::bail!("Cannot coerce number to bool"),
        },
        other => anyhow::bail!("Cannot coerce {:?} to bool", other),
    };
    Ok(BamlValue::Bool(flag))
}
/// Coerces a JSON value to a variant of the enum named `enum_name`.
///
/// Non-string values are compared using their JSON text. An exact match
/// against any variant is preferred; otherwise the first variant that matches
/// case-insensitively is returned with its declared spelling.
///
/// # Errors
///
/// Returns an error if the enum is unknown to the IR or no variant matches.
fn coerce_enum(&self, value: &JsonValue, enum_name: &str) -> Result<BamlValue> {
    let definition = match self.ir.find_enum(enum_name) {
        Some(definition) => definition,
        None => anyhow::bail!("Enum '{}' not found", enum_name),
    };

    let candidate = match value {
        JsonValue::String(text) => text.clone(),
        other => other.to_string(),
    };

    // Exact spelling wins over any case-insensitive match.
    if definition.values.contains(&candidate) {
        return Ok(BamlValue::String(candidate));
    }

    let folded = candidate.to_lowercase();
    let relaxed_match = definition
        .values
        .iter()
        .find(|variant| variant.to_lowercase() == folded);

    match relaxed_match {
        Some(variant) => Ok(BamlValue::String(variant.clone())),
        None => anyhow::bail!("'{}' is not a valid variant of enum '{}'", candidate, enum_name),
    }
}
/// Coerces a JSON object to an instance of the class named `class_name`.
///
/// Each declared field is looked up by name and coerced to its declared
/// type. Keys in the object that the class does not declare are ignored.
/// An optional field that is missing, or present as an explicit `null`, is
/// left out of the resulting map.
///
/// # Errors
///
/// Returns an error if the class is unknown to the IR, `value` is not an
/// object, a required field is missing, or a field value cannot be coerced.
fn coerce_class(&self, value: &JsonValue, class_name: &str) -> Result<BamlValue> {
    let class = self.ir.find_class(class_name)
        .ok_or_else(|| anyhow::anyhow!("Class '{}' not found", class_name))?;
    let obj = value.as_object()
        .ok_or_else(|| anyhow::anyhow!("Expected object for class '{}'", class_name))?;

    let mut result = HashMap::new();
    for field in &class.fields {
        // Models routinely emit `"field": null` for values they do not have.
        // For an optional field that means "absent"; coercing the null would
        // fail for most types and reject an otherwise valid object.
        let supplied = obj
            .get(&field.name)
            .filter(|raw| !(field.optional && raw.is_null()));

        match supplied {
            Some(field_value) => {
                let coerced = self.coerce(field_value, &field.field_type)?;
                result.insert(field.name.clone(), coerced);
            }
            None if field.optional => {}
            None => {
                anyhow::bail!("Missing required field '{}' in class '{}'", field.name, class_name);
            }
        }
    }
    Ok(BamlValue::Map(result))
}
/// Coerces a JSON array to a list, coercing every element to `inner_type`.
///
/// # Errors
///
/// Returns an error if `value` is not an array; otherwise stops at, and
/// returns, the first element that fails to coerce.
fn coerce_list(&self, value: &JsonValue, inner_type: &FieldType) -> Result<BamlValue> {
    let items = match value.as_array() {
        Some(items) => items,
        None => anyhow::bail!("Expected array"),
    };

    let mut elements = Vec::with_capacity(items.len());
    for item in items {
        elements.push(self.coerce(item, inner_type)?);
    }
    Ok(BamlValue::List(elements))
}
/// Coerces a JSON object to a map whose values are all of `value_type`.
///
/// Keys are kept as the strings found in the object; `_key_type` is accepted
/// but not used.
///
/// # Errors
///
/// Returns an error if `value` is not an object; otherwise stops at, and
/// returns, the first entry whose value fails to coerce.
fn coerce_map(&self, value: &JsonValue, _key_type: &FieldType, value_type: &FieldType) -> Result<BamlValue> {
    let entries = match value.as_object() {
        Some(entries) => entries,
        None => anyhow::bail!("Expected object for map"),
    };

    let mut converted: HashMap<String, BamlValue> = HashMap::with_capacity(entries.len());
    for (key, raw) in entries.iter() {
        let item = self.coerce(raw, value_type)?;
        converted.insert(key.clone(), item);
    }
    Ok(BamlValue::Map(converted))
}
fn coerce_union(&self, value: &JsonValue, types: &[FieldType]) -> Result<BamlValue> {
for t in types {
if let Ok(coerced) = self.coerce(value, t) {
return Ok(coerced);
}
}
anyhow::bail!("Cannot coerce {:?} to any of the union types", value)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::*;
/// A JSON payload wrapped in a fenced block tagged `json` is extracted
/// without the surrounding prose or fence markers.
#[test]
fn test_extract_json_from_markdown() {
    let ir = IR::new();
    let llm_output = r#"
Here's the result:
```json
{"name": "John", "age": 30}
```
"#;
    let extracted = Parser::new(&ir).extract_json(llm_output).unwrap();
    assert_eq!(extracted.trim(), r#"{"name": "John", "age": 30}"#);
}
/// A JSON string holding decimal digits is accepted as an integer.
#[test]
fn test_coerce_int_from_string() {
    let ir = IR::new();
    let numeric_text = JsonValue::String(String::from("42"));
    let coerced = Parser::new(&ir).coerce_int(&numeric_text).unwrap();
    assert_eq!(coerced.as_int(), Some(42));
}
/// A plain JSON object is parsed into a map holding every declared field of
/// the target class, each coerced to its declared type.
#[test]
fn test_parse_class() {
    // Shorthand for a required, undocumented field.
    let required = |name: &str, field_type: FieldType| Field {
        name: name.to_string(),
        field_type,
        optional: false,
        description: None,
    };

    let mut ir = IR::new();
    ir.classes.push(Class {
        name: "Person".to_string(),
        description: None,
        fields: vec![
            required("name", FieldType::String),
            required("age", FieldType::Int),
        ],
    });

    let target = FieldType::Class("Person".to_string());
    let parsed = Parser::new(&ir)
        .parse(r#"{"name": "John", "age": 30}"#, &target)
        .unwrap();

    let map = match parsed {
        BamlValue::Map(map) => map,
        _ => panic!("Expected Map"),
    };
    assert_eq!(map.get("name").and_then(|v| v.as_string()), Some("John"));
    assert_eq!(map.get("age").and_then(|v| v.as_int()), Some(30));
}
}