use anyhow::{Context, Result};
use std::collections::HashMap;
use crate::{
client::LLMClient,
ir::{BamlValue, FieldType, IR},
parser::Parser,
partial_parser::try_parse_partial_json,
renderer::PromptRenderer,
streaming_value::StreamingBamlValue,
};
/// Renders a prompt string for the given template and parameters using a
/// `PromptRenderer` built from the IR.
///
/// # Errors
/// Returns an error (with context) when the renderer fails on the template,
/// parameters, or output type.
pub fn generate_prompt_from_ir(
    ir: &IR,
    template: &str,
    params: &HashMap<String, BamlValue>,
    output_type: &FieldType,
) -> Result<String> {
    PromptRenderer::new(ir)
        .render(template, params, output_type)
        .context("Failed to render prompt from IR")
}
/// Parses a raw LLM response into a `BamlValue` of `target_type` using a
/// `Parser` built from the IR.
///
/// # Errors
/// Returns an error (with context) when the response cannot be coerced to
/// the target type.
pub fn parse_llm_response_with_ir(
    ir: &IR,
    raw_response: &str,
    target_type: &FieldType,
) -> Result<BamlValue> {
    Parser::new(ir)
        .parse(raw_response, target_type)
        .context("Failed to parse LLM response using IR")
}
/// Attempts to parse a *partial* (possibly truncated) LLM response.
///
/// Returns `Ok(None)` when the text does not yet form parseable JSON or when
/// the JSON cannot be coerced to `target_type` — both are expected while a
/// stream is still in flight, not errors.
///
/// # Errors
/// Propagates only failures from the partial-JSON probe itself and from
/// re-serializing the recovered JSON value.
pub fn try_parse_partial_response(
    ir: &IR,
    partial_response: &str,
    target_type: &FieldType,
) -> Result<Option<BamlValue>> {
    let Some(json_value) = try_parse_partial_json(partial_response)? else {
        return Ok(None);
    };
    let json_str = serde_json::to_string(&json_value)?;
    // Coercion failures on incomplete data are expected; map them to "nothing yet".
    Ok(parse_llm_response_with_ir(ir, &json_str, target_type).ok())
}
/// Feeds one streaming chunk's accumulated text into `streaming_value`.
///
/// If the partial text parses to a value, the streaming value is updated;
/// otherwise it is left untouched. When `is_final` is set, the value is
/// marked complete regardless of whether this chunk parsed.
///
/// # Errors
/// Propagates errors from `try_parse_partial_response` (JSON probing /
/// re-serialization); a mere failure to coerce the partial value is not
/// an error.
pub fn update_streaming_response(
    streaming_value: &mut StreamingBamlValue,
    ir: &IR,
    partial_response: &str,
    target_type: &FieldType,
    is_final: bool,
) -> Result<()> {
    let parsed = try_parse_partial_response(ir, partial_response, target_type)?;
    if let Some(value) = parsed {
        streaming_value.update_from_partial(ir, value, target_type);
    }
    if is_final {
        streaming_value.mark_complete();
    }
    Ok(())
}
/// Executable BAML runtime: owns the compiled IR plus a registry of named
/// LLM clients that functions dispatch to via their `client` field.
pub struct BamlRuntime {
// Compiled intermediate representation (classes, functions, templates).
ir: IR,
// Clients keyed by the name referenced in each function's `client` field.
clients: HashMap<String, LLMClient>,
}
impl BamlRuntime {
    /// Creates a runtime around an already-built IR with no clients registered.
    pub fn new(ir: IR) -> Self {
        Self {
            ir,
            clients: HashMap::new(),
        }
    }

    /// Registers (or replaces) an LLM client under `name`.
    pub fn register_client(&mut self, name: impl Into<String>, client: LLMClient) {
        self.clients.insert(name.into(), client);
    }

    /// Executes the named BAML function end-to-end: renders its prompt from
    /// `params`, calls the function's configured client, and parses the raw
    /// response into a typed `BamlValue`.
    ///
    /// # Errors
    /// Fails when the function or its client is unknown, or when prompt
    /// rendering, the LLM call, or response parsing fails.
    pub async fn execute(
        &self,
        function_name: &str,
        params: HashMap<String, BamlValue>,
    ) -> Result<BamlValue> {
        let function = self
            .ir
            .find_function(function_name)
            .ok_or_else(|| anyhow::anyhow!("Function '{}' not found", function_name))?;
        let client = self
            .clients
            .get(&function.client)
            .ok_or_else(|| anyhow::anyhow!("Client '{}' not found", function.client))?;
        // FIX: original source read `¶ms` (mojibake of `&params`, likely an
        // HTML-entity corruption of `&p`), which does not compile.
        let prompt = generate_prompt_from_ir(
            &self.ir,
            &function.prompt_template,
            &params,
            &function.output,
        )?;
        let raw_response = client
            .call(&prompt)
            .await
            .context("Failed to call LLM")?;
        parse_llm_response_with_ir(&self.ir, &raw_response, &function.output)
    }

    /// Read-only access to the runtime's IR.
    pub fn ir(&self) -> &IR {
        &self.ir
    }
}
/// Fluent builder for `BamlRuntime`: set the IR, register any number of
/// clients, then call `build()`.
pub struct RuntimeBuilder {
// IR to install in the runtime; starts empty (`IR::new()`).
ir: IR,
// Clients to register, keyed by name; later insertions with the same name win.
clients: HashMap<String, LLMClient>,
}
impl RuntimeBuilder {
    /// Starts a builder with an empty IR and no clients.
    pub fn new() -> Self {
        Self {
            ir: IR::new(),
            clients: HashMap::new(),
        }
    }

    /// Replaces the builder's IR.
    pub fn ir(mut self, ir: IR) -> Self {
        self.ir = ir;
        self
    }

    /// Registers a client under `name`; a later call with the same name wins.
    pub fn client(mut self, name: impl Into<String>, client: LLMClient) -> Self {
        self.clients.insert(name.into(), client);
        self
    }

    /// Consumes the builder and produces the configured `BamlRuntime`.
    pub fn build(self) -> BamlRuntime {
        let mut runtime = BamlRuntime::new(self.ir);
        for (name, client) in self.clients {
            runtime.register_client(name, client);
        }
        runtime
    }
}

// A no-argument `new` should be mirrored by `Default` so generic code can
// construct the builder (clippy: `new_without_default`).
impl Default for RuntimeBuilder {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;

    /// Builds a minimal IR (a `Person` class and an `ExtractPerson` function)
    /// and checks that the runtime can look the function up by name.
    #[tokio::test]
    async fn test_runtime_execution() {
        let mut ir = IR::new();

        let person_fields = vec![
            Field {
                name: "name".to_string(),
                field_type: FieldType::String,
                optional: false,
                description: None,
            },
            Field {
                name: "age".to_string(),
                field_type: FieldType::Int,
                optional: false,
                description: None,
            },
        ];
        ir.classes.push(Class {
            name: "Person".to_string(),
            description: None,
            fields: person_fields,
        });

        let text_input = Field {
            name: "text".to_string(),
            field_type: FieldType::String,
            optional: false,
            description: None,
        };
        ir.functions.push(Function {
            name: "ExtractPerson".to_string(),
            inputs: vec![text_input],
            output: FieldType::Class("Person".to_string()),
            prompt_template: "Extract person info from: {{ text }}".to_string(),
            client: "test_client".to_string(),
        });

        let runtime = BamlRuntime::new(ir);
        assert!(runtime.ir().find_function("ExtractPerson").is_some());
    }

    /// Verifies that clients registered through the builder end up in the
    /// runtime's client map.
    #[test]
    fn test_runtime_builder() {
        let client = LLMClient::openai("test-key".to_string(), "gpt-4".to_string());
        let runtime = RuntimeBuilder::new()
            .ir(IR::new())
            .client("openai", client)
            .build();
        assert!(runtime.clients.contains_key("openai"));
    }
}