1use std::path::Path;
11
/// A `class` block read from BAML source.
#[derive(Debug, Clone)]
pub struct BamlClass {
    /// Identifier that follows the `class` keyword, with any trailing `{` removed.
    pub name: String,
    /// Fields in the order they appear inside the class body.
    pub fields: Vec<BamlField>,
    /// Class-level description. `parse_class` in this file always leaves this as `None`.
    pub description: Option<String>,
}
19
/// One field line inside a BAML class body.
#[derive(Debug, Clone)]
pub struct BamlField {
    /// First whitespace-delimited token of the field line.
    pub name: String,
    /// Parsed type of the field; `BamlType::String` when the field is a fixed literal.
    pub ty: BamlType,
    /// Text of the first `@description("...")` annotation on the line, if present.
    pub description: Option<String>,
    /// Set when the type position holds a single quoted literal (no `|`),
    /// e.g. `task "ffmpeg_operation"`; holds the literal without its quotes.
    pub fixed_value: Option<String>,
}
29
/// A BAML type expression as recognised by this parser.
///
/// The type derives `PartialEq`/`Eq` so parsed types can be compared directly
/// (for example with `assert_eq!`) instead of being picked apart with
/// `matches!` or nested `match` expressions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BamlType {
    /// The `string` primitive; also the fallback for unrecognised lowercase type text.
    String,
    /// The `int` primitive.
    Int,
    /// The `float` primitive.
    Float,
    /// The `bool` primitive.
    Bool,
    /// Alternatives written as quoted literals, e.g. `"trim" | "keep"`; quotes are removed.
    StringEnum(Vec<String>),
    /// Reference to another named type (text starting with an uppercase letter).
    Ref(String),
    /// A type combined with `null`, e.g. `string | null`.
    Optional(Box<BamlType>),
    /// A type followed by `[]`.
    Array(Box<BamlType>),
    /// Alternatives that all start with an uppercase letter, e.g. `A | B`.
    Union(Vec<String>),
    /// The `image` primitive.
    Image,
}
50
/// A `function` block read from BAML source.
#[derive(Debug, Clone)]
pub struct BamlFunction {
    /// Identifier between the `function` keyword and the opening parenthesis.
    pub name: String,
    /// Declared parameters as `(name, type)` pairs, in declaration order.
    pub params: Vec<(String, BamlType)>,
    /// Raw text after `->` on the header line, with any trailing `{` removed (not parsed further).
    pub return_type: String,
    /// Value of the `client` line in the body, without surrounding quotes; empty if absent.
    pub client: String,
    /// Lines collected from the `prompt #"` ... `"#` block, joined with `\n`.
    pub prompt: String,
}
60
/// Accumulated result of parsing one or more BAML sources.
#[derive(Debug, Clone, Default)]
pub struct BamlModule {
    /// Every class parsed so far, in the order encountered.
    pub classes: Vec<BamlClass>,
    /// Every function parsed so far, in the order encountered.
    pub functions: Vec<BamlFunction>,
}
67
68impl BamlModule {
69 pub fn parse_dir(dir: &Path) -> Result<Self, String> {
71 let mut module = BamlModule::default();
72
73 let entries =
74 std::fs::read_dir(dir).map_err(|e| format!("Cannot read {}: {}", dir.display(), e))?;
75
76 for entry in entries.flatten() {
77 let path = entry.path();
78 if path.extension().is_some_and(|ext| ext == "baml") {
79 let source = std::fs::read_to_string(&path)
80 .map_err(|e| format!("Cannot read {}: {}", path.display(), e))?;
81 module.parse_source(&source);
82 }
83 }
84
85 Ok(module)
86 }
87
88 pub fn parse_source(&mut self, source: &str) {
90 let lines: Vec<&str> = source.lines().collect();
91 let mut i = 0;
92
93 while i < lines.len() {
94 let line = lines[i].trim();
95
96 if line.is_empty() || line.starts_with("//") {
98 i += 1;
99 continue;
100 }
101
102 if line.starts_with("class ") {
104 if let Some((class, consumed)) = parse_class(&lines[i..]) {
105 self.classes.push(class);
106 i += consumed;
107 continue;
108 }
109 }
110
111 if line.starts_with("function ") {
113 if let Some((func, consumed)) = parse_function(&lines[i..]) {
114 self.functions.push(func);
115 i += consumed;
116 continue;
117 }
118 }
119
120 i += 1;
121 }
122 }
123
124 pub fn find_class(&self, name: &str) -> Option<&BamlClass> {
126 self.classes.iter().find(|c| c.name == name)
127 }
128
129 pub fn find_function(&self, name: &str) -> Option<&BamlFunction> {
131 self.functions.iter().find(|f| f.name == name)
132 }
133}
134
135fn parse_class(lines: &[&str]) -> Option<(BamlClass, usize)> {
138 let header = lines[0].trim();
139 let name = header
140 .strip_prefix("class ")?
141 .trim()
142 .trim_end_matches('{')
143 .trim()
144 .to_string();
145
146 let mut fields = Vec::new();
147 let mut i = 1;
148
149 while i < lines.len() {
150 let line = lines[i].trim();
151 i += 1;
152
153 if line == "}" {
154 break;
155 }
156 if line.is_empty() || line.starts_with("//") {
157 continue;
158 }
159
160 if let Some(field) = parse_field(line) {
161 fields.push(field);
162 }
163 }
164
165 Some((
166 BamlClass {
167 name,
168 fields,
169 description: None,
170 },
171 i,
172 ))
173}
174
175fn parse_field(line: &str) -> Option<BamlField> {
176 let line = line.trim();
184
185 let description = extract_description(line);
187
188 let clean = remove_annotations(line);
190 let clean = clean.trim();
191
192 let mut parts = clean.splitn(2, char::is_whitespace);
194 let name = parts.next()?.trim().to_string();
195 let type_str = parts.next()?.trim();
196
197 if type_str.starts_with('"') && !type_str.contains('|') {
199 let value = type_str.trim_matches('"').to_string();
200 return Some(BamlField {
201 name,
202 ty: BamlType::String,
203 description,
204 fixed_value: Some(value),
205 });
206 }
207
208 let ty = parse_type(type_str);
209
210 Some(BamlField {
211 name,
212 ty,
213 description,
214 fixed_value: None,
215 })
216}
217
218fn parse_type(s: &str) -> BamlType {
219 let s = s.trim();
220
221 if s.ends_with("[]") {
223 let inner = s.trim_end_matches("[]").trim();
224 if inner.starts_with('(') && inner.ends_with(')') {
226 let inner_types = &inner[1..inner.len() - 1];
227 let variants: Vec<String> = inner_types
228 .split('|')
229 .map(|v| v.trim().to_string())
230 .collect();
231 if variants
233 .iter()
234 .all(|v| v.starts_with(|c: char| c.is_uppercase()))
235 {
236 return BamlType::Array(Box::new(BamlType::Union(variants)));
237 }
238 }
239 let inner_type = parse_type(inner);
240 return BamlType::Array(Box::new(inner_type));
241 }
242
243 if s.contains("| null") || s.contains("null |") {
245 let base = s
246 .replace("| null", "")
247 .replace("null |", "")
248 .trim()
249 .to_string();
250 return BamlType::Optional(Box::new(parse_type(&base)));
251 }
252
253 if s.contains('"') && s.contains('|') {
255 let variants: Vec<String> = s
256 .split('|')
257 .map(|v| v.trim().trim_matches('"').to_string())
258 .filter(|v| !v.is_empty())
259 .collect();
260 return BamlType::StringEnum(variants);
261 }
262
263 if s.contains('|') {
265 let variants: Vec<String> = s.split('|').map(|v| v.trim().to_string()).collect();
266 if variants
267 .iter()
268 .all(|v| v.starts_with(|c: char| c.is_uppercase()))
269 {
270 return BamlType::Union(variants);
271 }
272 }
273
274 match s {
276 "string" => BamlType::String,
277 "int" => BamlType::Int,
278 "float" => BamlType::Float,
279 "bool" => BamlType::Bool,
280 "image" => BamlType::Image,
281 _ => {
282 if s.starts_with(|c: char| c.is_uppercase()) {
284 BamlType::Ref(s.to_string())
285 } else {
286 BamlType::String }
288 }
289 }
290}
291
/// Returns the text inside the first `@description("...")` annotation on
/// `line`, or `None` when the annotation is missing or never closed by `")`.
fn extract_description(line: &str) -> Option<String> {
    const OPENER: &str = "@description(\"";
    const CLOSER: &str = "\")";

    let (_, tail) = line.split_once(OPENER)?;
    let (text, _) = tail.split_once(CLOSER)?;
    Some(text.to_owned())
}
302
/// Returns `line` with all annotations cut out.
///
/// Complete `@description("...")` annotations are removed first, because
/// their text may contain spaces. Every remaining `@word` is then removed up
/// to, but not including, the next whitespace character. Surrounding
/// whitespace is left untouched.
fn remove_annotations(line: &str) -> String {
    const OPENER: &str = "@description(\"";
    const CLOSER: &str = "\")";

    let mut text = String::from(line);

    while let Some(open_at) = text.find(OPENER) {
        match text[open_at..].find(CLOSER) {
            Some(offset) => text.replace_range(open_at..open_at + offset + CLOSER.len(), ""),
            // Unterminated description: leave it for the generic `@word` pass.
            None => break,
        }
    }

    while let Some(at) = text.find('@') {
        let tail = &text[at + 1..];
        let word_len = tail.find(char::is_whitespace).unwrap_or(tail.len());
        let stop = at + 1 + word_len;
        text.replace_range(at..stop, "");
    }

    text
}
321
322fn parse_function(lines: &[&str]) -> Option<(BamlFunction, usize)> {
323 let header = lines[0].trim();
324
325 let rest = header.strip_prefix("function ")?;
327
328 let paren_start = rest.find('(')?;
330 let name = rest[..paren_start].trim().to_string();
331
332 let paren_end = rest.find(')')?;
334 let params_str = &rest[paren_start + 1..paren_end];
335 let params: Vec<(String, BamlType)> = if params_str.trim().is_empty() {
336 vec![]
337 } else {
338 params_str
339 .split(',')
340 .filter_map(|p| {
341 let p = p.trim();
342 let mut parts = p.splitn(2, ':');
343 let pname = parts.next()?.trim().to_string();
344 let ptype = parse_type(parts.next()?.trim());
345 Some((pname, ptype))
346 })
347 .collect()
348 };
349
350 let arrow = rest.find("->")?;
352 let return_rest = rest[arrow + 2..].trim();
353 let return_type = return_rest.trim_end_matches('{').trim().to_string();
354
355 let mut client = String::new();
357 let mut prompt_lines = Vec::new();
358 let mut in_prompt = false;
359 let mut i = 1;
360
361 while i < lines.len() {
362 let line = lines[i].trim();
363 i += 1;
364
365 if line == "}" && !in_prompt {
366 break;
367 }
368
369 if line.starts_with("client ") {
370 client = line
371 .strip_prefix("client ")
372 .unwrap_or("")
373 .trim()
374 .trim_matches('"')
375 .to_string();
376 continue;
377 }
378
379 if line.starts_with("prompt #\"") {
380 in_prompt = true;
381 let after = line.strip_prefix("prompt #\"").unwrap_or("");
383 if !after.is_empty() {
384 prompt_lines.push(after.to_string());
385 }
386 continue;
387 }
388
389 if in_prompt {
390 if line.contains("\"#") {
391 let before = line.trim_end_matches("\"#").trim_end();
392 if !before.is_empty() {
393 prompt_lines.push(before.to_string());
394 }
395 in_prompt = false;
396 continue;
397 }
398 prompt_lines.push(lines[i - 1].to_string());
399 }
400 }
401
402 Some((
403 BamlFunction {
404 name,
405 params,
406 return_type,
407 client,
408 prompt: prompt_lines.join("\n"),
409 },
410 i,
411 ))
412}
413
// Unit tests for the BAML parser: inline fixtures plus one optional test
// against a real `montage.baml` file when it exists on disk.
#[cfg(test)]
mod tests {
    use super::*;

    // A class with a quoted-literal enum field and a plain string field,
    // both carrying `@description` annotations.
    #[test]
    fn parses_simple_class() {
        let source = r#"
class CutDecision {
 action "trim" | "keep" | "highlight" @description("Editing action")
 reason string @description("Short reasoning")
}
"#;
        let mut module = BamlModule::default();
        module.parse_source(source);

        assert_eq!(module.classes.len(), 1);
        let cls = &module.classes[0];
        assert_eq!(cls.name, "CutDecision");
        assert_eq!(cls.fields.len(), 2);

        let action = &cls.fields[0];
        assert_eq!(action.name, "action");
        match &action.ty {
            BamlType::StringEnum(variants) => {
                assert_eq!(variants, &["trim", "keep", "highlight"]);
            }
            other => panic!("Expected StringEnum, got {:?}", other),
        }
        assert_eq!(action.description.as_deref(), Some("Editing action"));
    }

    // Covers fixed literal values, `T | null` optionals and `T[] | null`.
    #[test]
    fn parses_class_with_optional_and_array() {
        let source = r#"
class FfmpegTask {
 task "ffmpeg_operation" @description("FFmpeg ops") @stream.not_null
 operation "convert" | "trim" | "concat"
 input_path string | null
 custom_args string[] | null
 overwrite bool | null
}
"#;
        let mut module = BamlModule::default();
        module.parse_source(source);

        let cls = &module.classes[0];
        assert_eq!(cls.name, "FfmpegTask");

        // A lone quoted literal becomes a fixed value.
        assert_eq!(
            cls.fields[0].fixed_value.as_deref(),
            Some("ffmpeg_operation")
        );

        // `string | null`
        assert!(matches!(cls.fields[2].ty, BamlType::Optional(_)));

        // `string[] | null` must be an optional wrapping an array.
        match &cls.fields[3].ty {
            BamlType::Optional(inner) => {
                assert!(matches!(inner.as_ref(), BamlType::Array(_)));
            }
            other => panic!("Expected Optional(Array), got {:?}", other),
        }
    }

    // `(A | B | C)[]` must parse as an array of a union of class names.
    #[test]
    fn parses_union_array() {
        let source = r#"
class MontageAgentNextStep {
 intent "display" | "montage"
 next_actions (AnalysisTask | FfmpegTask | ProjectTask)[] @description("Tools to execute")
}
"#;
        let mut module = BamlModule::default();
        module.parse_source(source);

        let cls = &module.classes[0];
        let actions_field = &cls.fields[1];
        match &actions_field.ty {
            BamlType::Array(inner) => match inner.as_ref() {
                BamlType::Union(variants) => {
                    assert_eq!(variants, &["AnalysisTask", "FfmpegTask", "ProjectTask"]);
                }
                other => panic!("Expected Union inside Array, got {:?}", other),
            },
            other => panic!("Expected Array, got {:?}", other),
        }
    }

    // Header, client line and a multi-line prompt block.
    #[test]
    fn parses_function() {
        let source = r##"
function AnalyzeSegmentSgr(genre: string, scene: string) -> SgrSegmentDecision {
 client AgentFallback
 prompt #"
 You are a video editor.
 Genre: {{ genre }}
 {{ ctx.output_format }}
 "#
}
"##;
        let mut module = BamlModule::default();
        module.parse_source(source);

        assert_eq!(module.functions.len(), 1);
        let func = &module.functions[0];
        assert_eq!(func.name, "AnalyzeSegmentSgr");
        assert_eq!(func.params.len(), 2);
        assert_eq!(func.return_type, "SgrSegmentDecision");
        assert_eq!(func.client, "AgentFallback");
        assert!(func.prompt.contains("video editor"));
    }

    // Parses a real `montage.baml` located relative to this crate's manifest
    // directory. The test returns early (and passes) when the file is absent.
    #[test]
    fn parses_real_montage_baml() {
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        // Walk three directories up from the crate before descending again.
        path.pop();
        path.pop();
        path.pop();
        path.push("startups");
        path.push("active");
        path.push("video-analyzer");
        path.push("crates");
        path.push("va-agent");
        path.push("baml_src");
        path.push("montage");
        path.set_extension("baml");
        if !path.exists() {
            eprintln!("Skipping: montage.baml not found at {}", path.display());
            return;
        }
        let source = std::fs::read_to_string(&path).unwrap();
        let mut module = BamlModule::default();
        module.parse_source(&source);

        // Classes expected in the real schema.
        assert!(module.find_class("CutDecision").is_some());
        assert!(module.find_class("MontageAgentNextStep").is_some());
        assert!(module.find_class("AnalysisTask").is_some());
        assert!(module.find_class("FfmpegTask").is_some());
        assert!(module.find_class("ProjectTask").is_some());
        assert!(module.find_class("ReportTaskCompletion").is_some());

        // Functions expected in the real schema.
        assert!(module.find_function("AnalyzeSegmentSgr").is_some());
        assert!(module.find_function("DecideMontageNextStepSgr").is_some());
        assert!(module.find_function("SummarizeTranscriptSgr").is_some());

        // `next_actions` must be an array of a union of tool classes.
        let step = module.find_class("MontageAgentNextStep").unwrap();
        let actions = step
            .fields
            .iter()
            .find(|f| f.name == "next_actions")
            .unwrap();
        match &actions.ty {
            BamlType::Array(inner) => match inner.as_ref() {
                BamlType::Union(variants) => {
                    assert!(variants.contains(&"AnalysisTask".to_string()));
                    assert!(variants.contains(&"FfmpegTask".to_string()));
                    // NOTE(review): the message says 16 but the check only
                    // requires at least 10 — confirm which number is intended.
                    assert!(
                        variants.len() >= 10,
                        "Should have 16 tool types, got {}",
                        variants.len()
                    );
                }
                other => panic!("Expected Union, got {:?}", other),
            },
            other => panic!("Expected Array, got {:?}", other),
        }
    }
}