use anyhow::{Context, Result};
use std::collections::HashMap;
use thiserror::Error;
use crate::{
client::{DynLLMClient, LLMClientTrait},
ir::{BamlValue, FieldType, IR},
parser::Parser,
partial_parser::try_parse_partial_json,
renderer::PromptRenderer,
streaming_value::StreamingBamlValue,
};
use std::sync::Arc;
/// Errors produced by the BAML runtime pipeline.
///
/// The variants categorize *where* a single execution attempt failed, so the
/// retry loop (`execute_with_retry_config`) can decide which failures are
/// worth another attempt — see [`BamlError::is_retryable`].
#[derive(Debug, Error)]
pub enum BamlError {
    /// Transport-level failure while calling the LLM client.
    #[error("Network error: {0}")]
    Network(String),
    /// The LLM reply could not be parsed into the target type.
    #[error("Parse error: {0}")]
    Parse(String),
    /// The parsed value failed post-parse validation.
    #[error("Validation error: {0}")]
    Validation(String),
    /// Anything else (unknown function/client, prompt render failure, ...).
    #[error("{0}")]
    Other(String),
}
impl BamlError {
    /// Whether this failure class is worth re-prompting the model.
    ///
    /// Parse and validation errors can plausibly be fixed by another LLM
    /// attempt; network and unclassified errors are surfaced to the caller.
    pub fn is_retryable(&self) -> bool {
        // Exhaustive match (no `_` arm) so adding a variant forces a decision.
        match self {
            BamlError::Parse(_) | BamlError::Validation(_) => true,
            BamlError::Network(_) | BamlError::Other(_) => false,
        }
    }
}
/// Tunables for `execute_with_retry_config`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    // Maximum number of *additional* attempts after the first try.
    pub max_retries: usize,
    // When true, an otherwise-successful result that contains an empty array
    // is retried — heuristic for truncated/incomplete LLM output.
    pub retry_empty_arrays: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
retry_empty_arrays: true,
}
}
}
impl RetryConfig {
    /// Build a config with the given retry budget; all other options keep
    /// their `Default` values.
    pub fn new(max_retries: usize) -> Self {
        let mut config = Self::default();
        config.max_retries = max_retries;
        config
    }

    /// Builder-style toggle: disable the empty-array retry heuristic.
    pub fn without_empty_array_retry(mut self) -> Self {
        self.retry_empty_arrays = false;
        self
    }
}
/// Recursively sanity-check a parsed `value` against the IR type it is
/// expected to have.
///
/// NOTE(review): this walk never actually fails — class fields missing from
/// the map are skipped, and an empty list only emits a warning on stderr.
/// Confirm whether stricter enforcement is intended before relying on the
/// `Validation` error path.
fn validate_result(ir: &IR, value: &BamlValue, expected_type: &FieldType) -> Result<()> {
    if let FieldType::Class(class_name) = expected_type {
        // Unknown class names are silently accepted.
        let class = match ir.find_class(class_name) {
            Some(c) => c,
            None => return Ok(()),
        };
        if let BamlValue::Map(map) = value {
            // Recurse into each field that is actually present in the value.
            for field in &class.fields {
                if let Some(field_value) = map.get(&field.name) {
                    validate_result(ir, field_value, &field.field_type)?;
                }
            }
        }
    } else if let FieldType::List(_inner) = expected_type {
        if let BamlValue::List(items) = value {
            if items.is_empty() {
                eprintln!("Warning: LLM returned empty array. This might indicate incomplete output.");
            }
        }
    }
    Ok(())
}
/// Render `template` into a prompt string using the IR's renderer,
/// substituting `params` and describing `output_type` to the model.
pub fn generate_prompt_from_ir(
    ir: &IR,
    template: &str,
    params: &HashMap<String, BamlValue>,
    output_type: &FieldType,
) -> Result<String> {
    PromptRenderer::new(ir)
        .render(template, params, output_type)
        .context("Failed to render prompt from IR")
}
/// Parse a raw LLM reply into a `BamlValue` coerced to `target_type`.
pub fn parse_llm_response_with_ir(
    ir: &IR,
    raw_response: &str,
    target_type: &FieldType,
) -> Result<BamlValue> {
    Parser::new(ir)
        .parse(raw_response, target_type)
        .context("Failed to parse LLM response using IR")
}
/// Best-effort parse of an incomplete (streaming) LLM response.
///
/// Returns `Ok(None)` when the partial text cannot yet be completed into
/// JSON, or when the completed JSON does not (yet) coerce to `target_type`.
/// Only JSON-serialization failures propagate as `Err`.
pub fn try_parse_partial_response(
    ir: &IR,
    partial_response: &str,
    target_type: &FieldType,
) -> Result<Option<BamlValue>> {
    // Bail out early while the text is not completable JSON.
    let json_value = match try_parse_partial_json(partial_response)? {
        Some(v) => v,
        None => return Ok(None),
    };
    let json_str = serde_json::to_string(&json_value)?;
    // A coercion failure on a partial chunk is expected — swallow it and
    // wait for more tokens rather than surfacing an error.
    Ok(parse_llm_response_with_ir(ir, &json_str, target_type).ok())
}
/// Feed one streaming chunk into `streaming_value`.
///
/// Non-final chunks are parsed best-effort and silently skipped when they
/// don't parse yet. The final chunk must parse fully; on failure this
/// returns an error *without* marking the value complete.
pub fn update_streaming_response(
    streaming_value: &mut StreamingBamlValue,
    ir: &IR,
    partial_response: &str,
    target_type: &FieldType,
    is_final: bool,
) -> Result<()> {
    if !is_final {
        if let Some(partial) = try_parse_partial_response(ir, partial_response, target_type)? {
            streaming_value.update_from_partial(ir, partial, target_type);
        }
        return Ok(());
    }
    // Final chunk: parsing is mandatory, and only success flips completion.
    let final_value = parse_llm_response_with_ir(ir, partial_response, target_type)
        .context("Final streaming chunk failed to parse")?;
    streaming_value.update_from_partial(ir, final_value, target_type);
    streaming_value.mark_complete();
    Ok(())
}
/// Executes BAML functions: resolves the function in the IR, renders its
/// prompt, calls the registered LLM client, then parses and validates the
/// reply.
pub struct BamlRuntime {
    // Intermediate representation holding class and function definitions.
    ir: IR,
    // LLM clients keyed by the name referenced in each function's `client`.
    clients: HashMap<String, DynLLMClient>,
}
impl BamlRuntime {
pub fn new(ir: IR) -> Self {
Self {
ir,
clients: HashMap::new(),
}
}
pub fn register_client<C: LLMClientTrait + 'static>(&mut self, name: impl Into<String>, client: C) {
self.clients.insert(name.into(), Arc::new(client));
}
pub fn register_dyn_client(&mut self, name: impl Into<String>, client: DynLLMClient) {
self.clients.insert(name.into(), client);
}
pub async fn execute(
&self,
function_name: &str,
params: HashMap<String, BamlValue>,
) -> Result<BamlValue> {
let function = self.ir.find_function(function_name)
.ok_or_else(|| anyhow::anyhow!("Function '{}' not found", function_name))?;
let client = self.clients.get(&function.client)
.ok_or_else(|| anyhow::anyhow!("Client '{}' not found", function.client))?;
let prompt = generate_prompt_from_ir(
&self.ir,
&function.prompt_template,
¶ms,
&function.output
)?;
let raw_response = client.call(&prompt)
.await
.context("Failed to call LLM")?;
let result = parse_llm_response_with_ir(
&self.ir,
&raw_response,
&function.output
)?;
validate_result(&self.ir, &result, &function.output)?;
Ok(result)
}
pub async fn execute_with_retry(
&self,
function_name: &str,
params: HashMap<String, BamlValue>,
max_retries: usize,
) -> Result<BamlValue> {
self.execute_with_retry_config(function_name, params, RetryConfig {
max_retries,
retry_empty_arrays: true,
}).await
}
pub async fn execute_with_retry_config(
&self,
function_name: &str,
params: HashMap<String, BamlValue>,
config: RetryConfig,
) -> Result<BamlValue> {
let mut attempts = 0;
loop {
match self.try_execute(function_name, params.clone()).await {
Ok(result) => {
if config.retry_empty_arrays && has_empty_arrays(&result) && attempts < config.max_retries {
eprintln!("Attempt {}: LLM returned empty arrays, retrying...", attempts + 1);
attempts += 1;
continue;
}
return Ok(result);
}
Err(baml_error) => {
if !baml_error.is_retryable() || attempts >= config.max_retries {
return Err(anyhow::anyhow!(baml_error));
}
attempts += 1;
eprintln!("Attempt {} failed ({}), retrying...", attempts, baml_error);
}
}
}
}
async fn try_execute(
&self,
function_name: &str,
params: HashMap<String, BamlValue>,
) -> std::result::Result<BamlValue, BamlError> {
let function = self.ir.find_function(function_name)
.ok_or_else(|| BamlError::Other(format!("Function '{}' not found", function_name)))?;
let client = self.clients.get(&function.client)
.ok_or_else(|| BamlError::Other(format!("Client '{}' not found", function.client)))?;
let prompt = generate_prompt_from_ir(
&self.ir,
&function.prompt_template,
¶ms,
&function.output
).map_err(|e| BamlError::Other(e.to_string()))?;
let raw_response = client.call(&prompt)
.await
.map_err(|e| BamlError::Network(e.to_string()))?;
let result = parse_llm_response_with_ir(
&self.ir,
&raw_response,
&function.output
).map_err(|e| BamlError::Parse(e.to_string()))?;
validate_result(&self.ir, &result, &function.output)
.map_err(|e| BamlError::Validation(e.to_string()))?;
Ok(result)
}
pub fn ir(&self) -> &IR {
&self.ir
}
}
/// True if `value` contains an empty list anywhere in its value tree.
///
/// Used by the retry loop as a heuristic for truncated LLM output.
fn has_empty_arrays(value: &BamlValue) -> bool {
    match value {
        // Fix: also recurse into list *elements*, so an empty array nested
        // inside a non-empty list (e.g. a list of objects each holding a
        // list field) is detected. Previously only the outer list's own
        // emptiness was checked.
        BamlValue::List(items) => items.is_empty() || items.iter().any(has_empty_arrays),
        BamlValue::Map(map) => map.values().any(has_empty_arrays),
        _ => false,
    }
}
/// Fluent builder for [`BamlRuntime`]: set an IR, register clients, `build()`.
pub struct RuntimeBuilder {
    // IR the built runtime will execute against (starts empty).
    ir: IR,
    // Clients collected before construction, keyed by registration name.
    clients: HashMap<String, DynLLMClient>,
}
impl RuntimeBuilder {
    /// Start with an empty IR and no clients.
    pub fn new() -> Self {
        Self {
            ir: IR::new(),
            clients: HashMap::new(),
        }
    }

    /// Set the IR the runtime will execute against.
    pub fn ir(mut self, ir: IR) -> Self {
        self.ir = ir;
        self
    }

    /// Register a concrete client under `name`, wrapping it in an `Arc`.
    pub fn client<C: LLMClientTrait + 'static>(mut self, name: impl Into<String>, client: C) -> Self {
        self.clients.insert(name.into(), Arc::new(client));
        self
    }

    /// Register an already type-erased client under `name`.
    pub fn dyn_client(mut self, name: impl Into<String>, client: DynLLMClient) -> Self {
        self.clients.insert(name.into(), client);
        self
    }

    /// Consume the builder and produce a configured runtime.
    pub fn build(self) -> BamlRuntime {
        let mut runtime = BamlRuntime::new(self.ir);
        for (name, client) in self.clients {
            runtime.register_dyn_client(name, client);
        }
        runtime
    }
}

// `new()` takes no arguments, so also provide `Default`, as clippy's
// `new_without_default` lint recommends for API consistency.
impl Default for RuntimeBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::LLMClient;
    use crate::ir::*;

    // A runtime constructed from an IR exposes that IR's functions via `ir()`.
    #[tokio::test]
    async fn test_runtime_execution() {
        let mut ir = IR::new();
        ir.classes.push(Class {
            name: "Person".to_string(),
            description: None,
            fields: vec![
                Field {
                    name: "name".to_string(),
                    field_type: FieldType::String,
                    optional: false,
                    description: None,
                },
                Field {
                    name: "age".to_string(),
                    field_type: FieldType::Int,
                    optional: false,
                    description: None,
                },
            ],
        });
        ir.functions.push(Function {
            name: "ExtractPerson".to_string(),
            inputs: vec![
                Field {
                    name: "text".to_string(),
                    field_type: FieldType::String,
                    optional: false,
                    description: None,
                }
            ],
            output: FieldType::Class("Person".to_string()),
            prompt_template: "Extract person info from: {{ text }}".to_string(),
            client: "test_client".to_string(),
        });
        let runtime = BamlRuntime::new(ir);
        assert!(runtime.ir().find_function("ExtractPerson").is_some());
    }

    // Builder wires registered clients into the runtime's client map.
    // (Accesses the private `clients` field — legal from this child module.)
    #[test]
    fn test_runtime_builder() {
        let ir = IR::new();
        let client = LLMClient::openai("test-key".to_string(), "gpt-4".to_string());
        let runtime = RuntimeBuilder::new()
            .ir(ir)
            .client("openai", client)
            .build();
        assert!(runtime.clients.contains_key("openai"));
    }

    // A final streaming chunk that is not valid JSON must error and must NOT
    // transition the streaming value to `Complete`.
    #[test]
    fn test_final_chunk_with_invalid_json_returns_error() {
        use crate::streaming_value::StreamingBamlValue;
        let mut ir = IR::new();
        ir.classes.push(Class {
            name: "Person".to_string(),
            description: None,
            fields: vec![
                Field {
                    name: "name".to_string(),
                    field_type: FieldType::String,
                    optional: false,
                    description: None,
                },
                Field {
                    name: "age".to_string(),
                    field_type: FieldType::Int,
                    optional: false,
                    description: None,
                },
            ],
        });
        let target_type = FieldType::Class("Person".to_string());
        let mut streaming = StreamingBamlValue::from_ir_skeleton(&ir, &target_type);
        // Truncated JSON: `age` has no value and the object is unterminated.
        let invalid_json = r#"{"name": "John", "age":"#;
        let result = update_streaming_response(
            &mut streaming,
            &ir,
            invalid_json,
            &target_type,
            true,
        );
        assert!(result.is_err(), "Final chunk with invalid JSON should return error");
        assert_ne!(
            streaming.completion_state,
            crate::streaming_value::CompletionState::Complete,
            "Should not mark as complete on parse failure"
        );
    }

    // A final streaming chunk with valid JSON parses and marks completion.
    #[test]
    fn test_final_chunk_with_valid_json_succeeds() {
        use crate::streaming_value::StreamingBamlValue;
        let mut ir = IR::new();
        ir.classes.push(Class {
            name: "Person".to_string(),
            description: None,
            fields: vec![
                Field {
                    name: "name".to_string(),
                    field_type: FieldType::String,
                    optional: false,
                    description: None,
                },
                Field {
                    name: "age".to_string(),
                    field_type: FieldType::Int,
                    optional: false,
                    description: None,
                },
            ],
        });
        let target_type = FieldType::Class("Person".to_string());
        let mut streaming = StreamingBamlValue::from_ir_skeleton(&ir, &target_type);
        let valid_json = r#"{"name": "John", "age": 30}"#;
        let result = update_streaming_response(
            &mut streaming,
            &ir,
            valid_json,
            &target_type,
            true,
        );
        assert!(result.is_ok(), "Final chunk with valid JSON should succeed");
        assert_eq!(
            streaming.completion_state,
            crate::streaming_value::CompletionState::Complete,
            "Should mark as complete on success"
        );
    }

    // Only Parse and Validation errors are retryable; Network/Other are not.
    #[test]
    fn test_baml_error_retryable() {
        assert!(BamlError::Parse("invalid json".to_string()).is_retryable());
        assert!(BamlError::Validation("missing field".to_string()).is_retryable());
        assert!(!BamlError::Network("connection refused".to_string()).is_retryable());
        assert!(!BamlError::Other("unknown error".to_string()).is_retryable());
    }

    // Default config: 3 retries, empty-array heuristic on.
    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert!(config.retry_empty_arrays);
    }

    // Builder-style config: custom retry count, heuristic disabled.
    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new(5).without_empty_array_retry();
        assert_eq!(config.max_retries, 5);
        assert!(!config.retry_empty_arrays);
    }

    // End-to-end: mock client answers the rendered prompt; execute() returns
    // the parsed map with both fields coerced to their declared types.
    #[tokio::test]
    async fn test_mock_client_with_runtime() {
        use crate::client::MockLLMClient;
        let mut ir = IR::new();
        ir.classes.push(Class {
            name: "Person".to_string(),
            description: None,
            fields: vec![
                Field {
                    name: "name".to_string(),
                    field_type: FieldType::String,
                    optional: false,
                    description: None,
                },
                Field {
                    name: "age".to_string(),
                    field_type: FieldType::Int,
                    optional: false,
                    description: None,
                },
            ],
        });
        ir.functions.push(Function {
            name: "ExtractPerson".to_string(),
            inputs: vec![Field {
                name: "text".to_string(),
                field_type: FieldType::String,
                optional: false,
                description: None,
            }],
            output: FieldType::Class("Person".to_string()),
            prompt_template: "Extract person info from: {{ text }}".to_string(),
            client: "mock".to_string(),
        });
        let mut mock_client = MockLLMClient::new();
        // Keyed on a substring of the rendered prompt — presumably the mock
        // matches responses by prompt content; confirm against MockLLMClient.
        mock_client.add_response("Extract person", r#"{"name": "Alice", "age": 25}"#);
        let runtime = RuntimeBuilder::new()
            .ir(ir)
            .client("mock", mock_client)
            .build();
        let mut params = HashMap::new();
        params.insert("text".to_string(), BamlValue::String("Alice is 25".to_string()));
        let result = runtime.execute("ExtractPerson", params).await.unwrap();
        let map = result.as_map().unwrap();
        assert_eq!(map.get("name").unwrap().as_string(), Some("Alice"));
        assert_eq!(map.get("age").unwrap().as_int(), Some(25));
    }
}