use std::path::Path;
/// A `class` declaration parsed from BAML source.
#[derive(Debug, Clone)]
pub struct BamlClass {
/// Class name taken from the `class <Name> {` header line.
pub name: String,
/// Fields in the order they appear in the class body.
pub fields: Vec<BamlField>,
/// Class-level description. `parse_class` in this file always sets this to `None`.
pub description: Option<String>,
}
/// A single field line inside a BAML class body.
#[derive(Debug, Clone)]
pub struct BamlField {
/// Field name: the first whitespace-delimited token of the line.
pub name: String,
/// Parsed field type. When `fixed_value` is set this is `BamlType::String`.
pub ty: BamlType,
/// Text of the first `@description("...")` annotation on the line, if any.
pub description: Option<String>,
/// Set when the type position holds a single quoted literal (no `|`),
/// e.g. `task "ffmpeg_operation"`; holds the literal without its quotes.
pub fixed_value: Option<String>,
}
/// The type of a class field or function parameter, as recognised by `parse_type`.
///
/// Derives `PartialEq`/`Eq` so parsed types can be compared structurally with
/// `==` and `assert_eq!` instead of hand-written `match` blocks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BamlType {
    /// `string`; also the fallback for type text that is not otherwise recognised
    /// and does not start with an uppercase letter.
    String,
    /// `int`
    Int,
    /// `float`
    Float,
    /// `bool`
    Bool,
    /// A union of quoted string literals, e.g. `"trim" | "keep"`; holds the
    /// literals without their quotes.
    StringEnum(Vec<String>),
    /// A reference to a named type (a name starting with an uppercase letter).
    Ref(String),
    /// `T | null` or `null | T`.
    Optional(Box<BamlType>),
    /// `T[]`.
    Array(Box<BamlType>),
    /// A union of named types, e.g. `A | B`; every member starts with an
    /// uppercase letter.
    Union(Vec<String>),
    /// `image`
    Image,
}
/// A `function` declaration parsed from BAML source.
#[derive(Debug, Clone)]
pub struct BamlFunction {
/// Function name: the text between `function ` and the opening parenthesis.
pub name: String,
/// Parameters as `(name, type)` pairs, in declaration order.
pub params: Vec<(String, BamlType)>,
/// Return type kept as raw text (whatever follows `->`, minus a trailing `{`);
/// it is not run through `parse_type`.
pub return_type: String,
/// Value of the `client ...` line with surrounding double quotes removed;
/// empty when the function body has no such line.
pub client: String,
/// Prompt text collected between `prompt #"` and `"#`, lines joined with `\n`.
pub prompt: String,
}
/// Accumulated result of parsing one or more BAML sources.
#[derive(Debug, Clone, Default)]
pub struct BamlModule {
/// Classes in the order `parse_source` appended them.
pub classes: Vec<BamlClass>,
/// Functions in the order `parse_source` appended them.
pub functions: Vec<BamlFunction>,
}
impl BamlModule {
    /// Parses every `*.baml` file directly inside `dir` into one module.
    ///
    /// Files are processed in sorted path order. `read_dir` yields entries in a
    /// platform- and filesystem-dependent order, so without sorting the order of
    /// `classes`/`functions` (and which duplicate `find_class` returns) would
    /// vary between machines.
    ///
    /// Directory entries that cannot be read are skipped (best effort).
    ///
    /// # Errors
    ///
    /// Returns a message if `dir` cannot be listed or a `.baml` file cannot be
    /// read as UTF-8 text.
    pub fn parse_dir(dir: &Path) -> Result<Self, String> {
        let entries =
            std::fs::read_dir(dir).map_err(|e| format!("Cannot read {}: {}", dir.display(), e))?;

        let mut paths: Vec<_> = entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| path.extension().is_some_and(|ext| ext == "baml"))
            .collect();
        paths.sort();

        let mut module = BamlModule::default();
        for path in &paths {
            let source = std::fs::read_to_string(path)
                .map_err(|e| format!("Cannot read {}: {}", path.display(), e))?;
            module.parse_source(&source);
        }
        Ok(module)
    }

    /// Parses `source` and appends every class and function found to `self`.
    ///
    /// The parser is line-based: a trimmed line starting with `class ` or
    /// `function ` starts a declaration, and the matching helper reports how
    /// many lines it consumed. Every other line (blank lines, `//` comments,
    /// unrecognised declarations) is skipped.
    pub fn parse_source(&mut self, source: &str) {
        let lines: Vec<&str> = source.lines().collect();
        let mut i = 0;
        while i < lines.len() {
            let line = lines[i].trim();
            // Number of lines to advance; stays 1 when nothing was recognised.
            let mut step = 1;
            if line.starts_with("class ") {
                if let Some((class, consumed)) = parse_class(&lines[i..]) {
                    self.classes.push(class);
                    step = consumed;
                }
            } else if line.starts_with("function ") {
                if let Some((func, consumed)) = parse_function(&lines[i..]) {
                    self.functions.push(func);
                    step = consumed;
                }
            }
            // Always make progress, even if a helper were to report zero lines.
            i += step.max(1);
        }
    }

    /// Returns the first parsed class named `name`, if any.
    pub fn find_class(&self, name: &str) -> Option<&BamlClass> {
        self.classes.iter().find(|c| c.name == name)
    }

    /// Returns the first parsed function named `name`, if any.
    pub fn find_function(&self, name: &str) -> Option<&BamlFunction> {
        self.functions.iter().find(|f| f.name == name)
    }
}
/// Parses a class declaration whose header is `lines[0]`.
///
/// Returns the class together with the number of lines consumed, counting the
/// header and the closing `}` line. If no `}` line is found, every remaining
/// line is consumed. Returns `None` when the header does not start with
/// `class `.
fn parse_class(lines: &[&str]) -> Option<(BamlClass, usize)> {
    let declared = lines[0].trim().strip_prefix("class ")?;
    let name = declared.trim().trim_end_matches('{').trim().to_string();

    let mut fields = Vec::new();
    let mut consumed = 1;
    for raw in &lines[1..] {
        consumed += 1;
        let body_line = raw.trim();
        if body_line == "}" {
            break;
        }
        let skippable = body_line.is_empty() || body_line.starts_with("//");
        if !skippable {
            // Lines that do not parse as a field are silently ignored.
            fields.extend(parse_field(body_line));
        }
    }

    let class = BamlClass {
        name,
        fields,
        description: None,
    };
    Some((class, consumed))
}
/// Parses one class-body line of the form `<name> <type> [@annotations...]`.
///
/// Returns `None` when, after annotations are removed, the line has no
/// whitespace-separated type part. A type that is a single quoted literal
/// (no `|`) becomes a `BamlType::String` field with `fixed_value` set to the
/// literal.
fn parse_field(line: &str) -> Option<BamlField> {
    let trimmed = line.trim();
    let description = extract_description(trimmed);
    let without_annotations = remove_annotations(trimmed);

    // First whitespace separates the field name from the type text.
    let (name, type_text) = without_annotations
        .trim()
        .split_once(char::is_whitespace)?;
    let type_text = type_text.trim();

    let is_literal = type_text.starts_with('"') && !type_text.contains('|');
    let (ty, fixed_value) = if is_literal {
        (
            BamlType::String,
            Some(type_text.trim_matches('"').to_string()),
        )
    } else {
        (parse_type(type_text), None)
    };

    Some(BamlField {
        name: name.to_string(),
        ty,
        description,
        fixed_value,
    })
}
/// Converts BAML type text into a [`BamlType`].
///
/// Rules are tried in this order:
/// 1. a trailing `[]` makes an `Array` of whatever precedes it; a
///    parenthesised list of uppercase-initial names becomes an array of `Union`;
/// 2. `| null` / `null |` makes an `Optional` of the remainder;
/// 3. quoted alternatives separated by `|` make a `StringEnum`;
/// 4. uppercase-initial alternatives separated by `|` make a `Union`;
/// 5. `string`, `int`, `float`, `bool`, `image` map to their variants;
/// 6. any other uppercase-initial text is a `Ref`; everything else falls back
///    to `String`.
///
/// Exactly one `[]` is peeled per recursion step, so `string[][]` yields
/// `Array(Array(String))` instead of collapsing to a single array level.
fn parse_type(s: &str) -> BamlType {
    fn starts_uppercase(text: &str) -> bool {
        text.starts_with(|c: char| c.is_uppercase())
    }

    let s = s.trim();

    // `strip_suffix` removes a single `[]`; `trim_end_matches` would remove all
    // of them and lose the nesting depth.
    if let Some(element) = s.strip_suffix("[]") {
        let element = element.trim();
        // Using strip_prefix/strip_suffix (instead of slicing 1..len-1) also
        // avoids a slice panic when the element text is a lone `(`.
        let members = element
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'));
        if let Some(members) = members {
            let variants: Vec<String> = members
                .split('|')
                .map(|v| v.trim().to_string())
                .collect();
            if variants.iter().all(|v| starts_uppercase(v)) {
                return BamlType::Array(Box::new(BamlType::Union(variants)));
            }
        }
        return BamlType::Array(Box::new(parse_type(element)));
    }

    if s.contains("| null") || s.contains("null |") {
        let base = s
            .replace("| null", "")
            .replace("null |", "")
            .trim()
            .to_string();
        return BamlType::Optional(Box::new(parse_type(&base)));
    }

    if s.contains('"') && s.contains('|') {
        let variants: Vec<String> = s
            .split('|')
            .map(|v| v.trim().trim_matches('"').to_string())
            .filter(|v| !v.is_empty())
            .collect();
        return BamlType::StringEnum(variants);
    }

    if s.contains('|') {
        let variants: Vec<String> = s.split('|').map(|v| v.trim().to_string()).collect();
        if variants.iter().all(|v| starts_uppercase(v)) {
            return BamlType::Union(variants);
        }
    }

    match s {
        "string" => BamlType::String,
        "int" => BamlType::Int,
        "float" => BamlType::Float,
        "bool" => BamlType::Bool,
        "image" => BamlType::Image,
        other if starts_uppercase(other) => BamlType::Ref(other.to_string()),
        // Unknown lowercase text: treated as a plain string.
        _ => BamlType::String,
    }
}
/// Returns the text of the first `@description("...")` annotation on `line`.
///
/// The text runs from just after `@description("` up to the next `")`.
/// Returns `None` when the opening marker is absent or is never closed.
fn extract_description(line: &str) -> Option<String> {
    const OPEN: &str = "@description(\"";
    const CLOSE: &str = "\")";

    let (_, after_open) = line.split_once(OPEN)?;
    let (text, _) = after_open.split_once(CLOSE)?;
    Some(text.to_owned())
}
/// Returns `line` with every `@annotation` removed, leaving the field name and
/// type text (surrounding whitespace is left for the caller to trim).
///
/// An annotation starts at an `@` outside a double-quoted string and covers:
/// * its name, up to the next whitespace, and
/// * if the name is directly followed by `(`, the whole argument list up to the
///   matching `)`. Nested parentheses and quoted strings inside the arguments
///   are skipped, so arguments containing spaces (`@alias("full name")`,
///   `@check(ok, {{ this > 0 }})`) are removed completely.
///
/// An `@` inside a quoted literal such as `"a@b.com"` is part of the value and
/// is kept. An unterminated argument list removes the rest of the line.
fn remove_annotations(line: &str) -> String {
    let mut out = String::with_capacity(line.len());
    let mut chars = line.chars().peekable();
    let mut in_string = false;

    while let Some(c) = chars.next() {
        if in_string {
            out.push(c);
            match c {
                // Keep an escaped character verbatim so `\"` does not end the string.
                '\\' => {
                    if let Some(escaped) = chars.next() {
                        out.push(escaped);
                    }
                }
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }
        match c {
            '"' => {
                in_string = true;
                out.push(c);
            }
            '@' => skip_annotation(&mut chars),
            _ => out.push(c),
        }
    }
    out
}

/// Advances `chars` past one annotation; the leading `@` has already been consumed.
fn skip_annotation<I: Iterator<Item = char>>(chars: &mut std::iter::Peekable<I>) {
    // Name part: stop before whitespace, or step into the argument list at `(`.
    loop {
        match chars.peek().copied() {
            None => return,
            Some(c) if c.is_whitespace() => return,
            Some('(') => {
                chars.next();
                break;
            }
            Some(_) => {
                chars.next();
            }
        }
    }

    // Argument list: consume through the parenthesis that closes the one just opened.
    let mut depth = 1usize;
    let mut in_string = false;
    while let Some(c) = chars.next() {
        if in_string {
            match c {
                '\\' => {
                    chars.next();
                }
                '"' => in_string = false,
                _ => {}
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    return;
                }
            }
            _ => {}
        }
    }
}
/// Parses a function declaration whose header is `lines[0]`, e.g.
/// `function Name(a: string, b: int) -> ReturnType {`.
///
/// Returns the function together with the number of lines consumed, counting
/// the header and the closing `}` line (or every remaining line if no closing
/// `}` is found). Returns `None` when the header lacks the `function ` prefix,
/// a well-ordered `(`...`)` pair, or `->`.
///
/// Inside the body, `client <name>` sets the client and `prompt #"` ... `"#`
/// delimits the prompt. While a prompt is open every line belongs to the
/// prompt: the keyword checks are not applied, so prompt text that happens to
/// start with `client ` or equals `}` is kept as prompt text. A prompt opened
/// and closed on the same line (`prompt #"text"#`) is closed immediately.
fn parse_function(lines: &[&str]) -> Option<(BamlFunction, usize)> {
    const PROMPT_OPEN: &str = "prompt #\"";
    const PROMPT_CLOSE: &str = "\"#";

    let header = lines.first()?.trim();
    let rest = header.strip_prefix("function ")?;

    let paren_start = rest.find('(')?;
    let paren_end = rest.find(')')?;
    let name = rest[..paren_start].trim().to_string();
    // `get` yields `None` (instead of panicking) if `)` appears before `(`.
    let params_str = rest.get(paren_start + 1..paren_end)?;
    // Pieces without a `:` (including the single empty piece of `()`) are dropped.
    let params: Vec<(String, BamlType)> = params_str
        .split(',')
        .filter_map(|param| {
            let (pname, ptype) = param.split_once(':')?;
            Some((pname.trim().to_string(), parse_type(ptype)))
        })
        .collect();

    let arrow = rest.find("->")?;
    let return_type = rest[arrow + 2..]
        .trim()
        .trim_end_matches('{')
        .trim()
        .to_string();

    let mut client = String::new();
    let mut prompt_lines: Vec<String> = Vec::new();
    let mut in_prompt = false;
    let mut consumed = 1;

    for &raw in &lines[1..] {
        consumed += 1;
        let line = raw.trim();

        // Prompt content takes priority over every keyword check below.
        if in_prompt {
            match line.find(PROMPT_CLOSE) {
                Some(end) => {
                    let before = line[..end].trim_end();
                    if !before.is_empty() {
                        prompt_lines.push(before.to_string());
                    }
                    in_prompt = false;
                }
                // Interior prompt lines keep their original indentation.
                None => prompt_lines.push(raw.to_string()),
            }
            continue;
        }

        if line == "}" {
            break;
        }

        if let Some(value) = line.strip_prefix("client ") {
            client = value.trim().trim_matches('"').to_string();
            continue;
        }

        if let Some(after) = line.strip_prefix(PROMPT_OPEN) {
            match after.find(PROMPT_CLOSE) {
                // Opened and closed on the same line.
                Some(end) => {
                    let text = after[..end].trim_end();
                    if !text.is_empty() {
                        prompt_lines.push(text.to_string());
                    }
                }
                None => {
                    in_prompt = true;
                    if !after.is_empty() {
                        prompt_lines.push(after.to_string());
                    }
                }
            }
        }
    }

    Some((
        BamlFunction {
            name,
            params,
            return_type,
            client,
            prompt: prompt_lines.join("\n"),
        },
        consumed,
    ))
}
// Unit tests for the line-based BAML parser in this file. Fixtures are inline
// raw strings, so their contents are part of the tests and must not be edited.
#[cfg(test)]
mod tests {
use super::*;
// A class whose first field is a union of quoted literals and whose fields
// carry `@description` annotations.
#[test]
fn parses_simple_class() {
let source = r#"
class CutDecision {
action "trim" | "keep" | "highlight" @description("Editing action")
reason string @description("Short reasoning")
}
"#;
let mut module = BamlModule::default();
module.parse_source(source);
assert_eq!(module.classes.len(), 1);
let cls = &module.classes[0];
assert_eq!(cls.name, "CutDecision");
assert_eq!(cls.fields.len(), 2);
let action = &cls.fields[0];
assert_eq!(action.name, "action");
match &action.ty {
BamlType::StringEnum(variants) => {
assert_eq!(variants, &["trim", "keep", "highlight"]);
}
other => panic!("Expected StringEnum, got {:?}", other),
}
assert_eq!(action.description.as_deref(), Some("Editing action"));
}
// Covers a single quoted literal (fixed value), `T | null` (Optional) and
// `T[] | null` (Optional wrapping Array).
#[test]
fn parses_class_with_optional_and_array() {
let source = r#"
class FfmpegTask {
task "ffmpeg_operation" @description("FFmpeg ops") @stream.not_null
operation "convert" | "trim" | "concat"
input_path string | null
custom_args string[] | null
overwrite bool | null
}
"#;
let mut module = BamlModule::default();
module.parse_source(source);
let cls = &module.classes[0];
assert_eq!(cls.name, "FfmpegTask");
// A lone quoted literal in type position is recorded as the field's fixed value.
assert_eq!(
cls.fields[0].fixed_value.as_deref(),
Some("ffmpeg_operation")
);
assert!(matches!(cls.fields[2].ty, BamlType::Optional(_)));
match &cls.fields[3].ty {
BamlType::Optional(inner) => {
assert!(matches!(inner.as_ref(), BamlType::Array(_)));
}
other => panic!("Expected Optional(Array), got {:?}", other),
}
}
// `(A | B | C)[]` must parse as an Array whose element is a Union of the names.
#[test]
fn parses_union_array() {
let source = r#"
class MontageAgentNextStep {
intent "display" | "montage"
next_actions (AnalysisTask | FfmpegTask | ProjectTask)[] @description("Tools to execute")
}
"#;
let mut module = BamlModule::default();
module.parse_source(source);
let cls = &module.classes[0];
let actions_field = &cls.fields[1];
match &actions_field.ty {
BamlType::Array(inner) => match inner.as_ref() {
BamlType::Union(variants) => {
assert_eq!(variants, &["AnalysisTask", "FfmpegTask", "ProjectTask"]);
}
other => panic!("Expected Union inside Array, got {:?}", other),
},
other => panic!("Expected Array, got {:?}", other),
}
}
// A function with two parameters, a client line and a multi-line prompt.
// The fixture uses `r##"` because the BAML prompt itself contains `"#`.
#[test]
fn parses_function() {
let source = r##"
function AnalyzeSegmentSgr(genre: string, scene: string) -> SgrSegmentDecision {
client AgentFallback
prompt #"
You are a video editor.
Genre: {{ genre }}
{{ ctx.output_format }}
"#
}
"##;
let mut module = BamlModule::default();
module.parse_source(source);
assert_eq!(module.functions.len(), 1);
let func = &module.functions[0];
assert_eq!(func.name, "AnalyzeSegmentSgr");
assert_eq!(func.params.len(), 2);
assert_eq!(func.return_type, "SgrSegmentDecision");
assert_eq!(func.client, "AgentFallback");
assert!(func.prompt.contains("video editor"));
}
// Parses a BAML file from elsewhere in the repository checkout, if present.
// The test returns early (and passes) when the file does not exist.
#[test]
fn parses_real_montage_baml() {
// Start at this crate's manifest directory, go up three levels, then descend
// to `startups/active/video-analyzer/crates/va-agent/baml_src/montage`.
let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.pop(); path.pop(); path.pop(); path.push("startups");
path.push("active");
path.push("video-analyzer");
path.push("crates");
path.push("va-agent");
path.push("baml_src");
path.push("montage");
// Turns the trailing `montage` component into `montage.baml`.
path.set_extension("baml");
if !path.exists() {
eprintln!("Skipping: montage.baml not found at {}", path.display());
return;
}
let source = std::fs::read_to_string(&path).unwrap();
let mut module = BamlModule::default();
module.parse_source(&source);
assert!(module.find_class("CutDecision").is_some());
assert!(module.find_class("MontageAgentNextStep").is_some());
assert!(module.find_class("AnalysisTask").is_some());
assert!(module.find_class("FfmpegTask").is_some());
assert!(module.find_class("ProjectTask").is_some());
assert!(module.find_class("ReportTaskCompletion").is_some());
assert!(module.find_function("AnalyzeSegmentSgr").is_some());
assert!(module.find_function("DecideMontageNextStepSgr").is_some());
assert!(module.find_function("SummarizeTranscriptSgr").is_some());
let step = module.find_class("MontageAgentNextStep").unwrap();
let actions = step
.fields
.iter()
.find(|f| f.name == "next_actions")
.unwrap();
match &actions.ty {
BamlType::Array(inner) => match inner.as_ref() {
BamlType::Union(variants) => {
assert!(variants.contains(&"AnalysisTask".to_string()));
assert!(variants.contains(&"FfmpegTask".to_string()));
// NOTE(review): the failure message mentions 16 tool types, but the
// assertion only requires at least 10 — confirm which is intended.
assert!(
variants.len() >= 10,
"Should have 16 tool types, got {}",
variants.len()
);
}
other => panic!("Expected Union, got {:?}", other),
},
other => panic!("Expected Array, got {:?}", other),
}
}
}