use crate::baml_parser::*;
pub fn generate(module: &BamlModule) -> String {
let mut out = String::new();
out.push_str("//! Auto-generated from .baml files by sgr-agent codegen.\n");
out.push_str("//! Do not edit manually — edit the .baml source and re-run.\n\n");
out.push_str("#![allow(dead_code, clippy::derivable_impls)]\n\n");
out.push_str("use serde::{Deserialize, Serialize};\n");
out.push_str("use schemars::JsonSchema;\n\n");
let mut enum_map: Vec<(String, Vec<String>)> = Vec::new();
for class in &module.classes {
generate_struct(&mut out, class, module, &mut enum_map);
}
for (name, variants) in &enum_map {
generate_string_enum(&mut out, name, variants);
}
generate_tool_registry(&mut out, module);
generate_prompts(&mut out, module);
out
}
fn generate_struct(
out: &mut String,
class: &BamlClass,
module: &BamlModule,
enum_map: &mut Vec<(String, Vec<String>)>,
) {
if let Some(desc) = &class.description {
out.push_str(&format!("/// {}\n", desc));
}
out.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
out.push_str(&format!("pub struct {} {{\n", class.name));
for field in &class.fields {
if let Some(desc) = &field.description {
out.push_str(&format!(" /// {}\n", desc));
}
if let Some(fixed) = &field.fixed_value {
out.push_str(&format!(" /// Fixed value: \"{}\"\n", fixed));
out.push_str(&format!(
" #[serde(default = \"default_{}__{}\")]\n",
snake_case(&class.name),
field.name
));
out.push_str(&format!(" pub {}: String,\n", field.name));
continue;
}
let rust_type = baml_type_to_rust(&field.ty, &class.name, &field.name, module, enum_map);
if matches!(&field.ty, BamlType::Optional(_)) {
out.push_str(" #[serde(skip_serializing_if = \"Option::is_none\")]\n");
}
out.push_str(&format!(" pub {}: {},\n", field.name, rust_type));
}
out.push_str("}\n\n");
for field in &class.fields {
if let Some(fixed) = &field.fixed_value {
out.push_str(&format!(
"fn default_{}__{}() -> String {{ \"{}\".to_string() }}\n",
snake_case(&class.name),
field.name,
fixed
));
}
}
if class.fields.iter().any(|f| f.fixed_value.is_some()) {
out.push('\n');
}
}
fn generate_string_enum(out: &mut String, name: &str, variants: &[String]) {
out.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
out.push_str(&format!("pub enum {} {{\n", name));
for variant in variants {
let rust_variant = pascal_case(variant);
out.push_str(&format!(" #[serde(rename = \"{}\")]\n", variant));
out.push_str(&format!(" {},\n", rust_variant));
}
out.push_str("}\n\n");
}
fn generate_tool_registry(out: &mut String, module: &BamlModule) {
let tool_classes: Vec<&BamlClass> = module
.classes
.iter()
.filter(|c| {
c.fields
.iter()
.any(|f| f.name == "task" && f.fixed_value.is_some())
})
.collect();
if tool_classes.is_empty() {
return;
}
out.push_str("// --- Tool Registry ---\n\n");
out.push_str("use crate::tool::ToolDef;\n\n");
out.push_str("/// All tools extracted from BAML definitions.\n");
out.push_str("pub fn all_tools() -> Vec<ToolDef> {\n");
out.push_str(" vec![\n");
for class in &tool_classes {
let task_field = class.fields.iter().find(|f| f.name == "task").unwrap();
let tool_name = task_field.fixed_value.as_deref().unwrap();
let description = task_field.description.as_deref().unwrap_or(&class.name);
out.push_str(&format!(
" crate::tool::tool::<{}>(\"{}\", \"{}\"),\n",
class.name,
tool_name,
escape_string(description),
));
}
out.push_str(" ]\n");
out.push_str("}\n\n");
out.push_str("/// Union of all tool types (for dispatching tool calls).\n");
out.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
out.push_str("#[serde(tag = \"task\")]\n");
out.push_str("pub enum ActionUnion {\n");
for class in &tool_classes {
let task_field = class.fields.iter().find(|f| f.name == "task").unwrap();
let tool_name = task_field.fixed_value.as_deref().unwrap();
out.push_str(&format!(" #[serde(rename = \"{}\")]\n", tool_name));
out.push_str(&format!(" {}({}),\n", class.name, class.name));
}
out.push_str("}\n\n");
}
fn generate_prompts(out: &mut String, module: &BamlModule) {
if module.functions.is_empty() {
return;
}
out.push_str("// --- Prompt Constants ---\n\n");
for func in &module.functions {
let const_name = screaming_snake_case(&func.name);
out.push_str(&format!(
"pub const {}_PROMPT: &str = r##\"\n{}\"##;\n\n",
const_name,
func.prompt.trim(),
));
}
}
fn baml_type_to_rust(
ty: &BamlType,
class_name: &str,
field_name: &str,
module: &BamlModule,
enum_map: &mut Vec<(String, Vec<String>)>,
) -> String {
match ty {
BamlType::String => "String".to_string(),
BamlType::Int => "i64".to_string(),
BamlType::Float => "f64".to_string(),
BamlType::Bool => "bool".to_string(),
BamlType::Image => "String".to_string(), BamlType::Ref(name) => {
if module.find_class(name).is_some() {
name.clone()
} else {
"String".to_string()
}
}
BamlType::Optional(inner) => {
let inner_rust = baml_type_to_rust(inner, class_name, field_name, module, enum_map);
format!("Option<{}>", inner_rust)
}
BamlType::Array(inner) => {
let inner_rust = baml_type_to_rust(inner, class_name, field_name, module, enum_map);
format!("Vec<{}>", inner_rust)
}
BamlType::StringEnum(variants) => {
let enum_name = format!("{}{}", class_name, pascal_case(field_name));
if !enum_map.iter().any(|(n, _)| n == &enum_name) {
enum_map.push((enum_name.clone(), variants.clone()));
}
enum_name
}
BamlType::Union(variants) => {
if variants.len() <= 4 {
let enum_name = format!("{}{}", class_name, pascal_case(field_name));
if !enum_map.iter().any(|(n, _)| n == &enum_name) {
}
}
"serde_json::Value".to_string()
}
}
}
fn snake_case(s: &str) -> String {
let mut result = String::new();
for (i, c) in s.chars().enumerate() {
if c.is_uppercase() {
if i > 0 {
result.push('_');
}
result.push(c.to_ascii_lowercase());
} else {
result.push(c);
}
}
result
}
fn pascal_case(s: &str) -> String {
s.split('_')
.map(|word| {
let mut chars = word.chars();
match chars.next() {
Some(c) => c.to_uppercase().to_string() + chars.as_str(),
None => String::new(),
}
})
.collect()
}
fn screaming_snake_case(s: &str) -> String {
snake_case(s).to_uppercase()
}
fn escape_string(s: &str) -> String {
s.replace('\\', "\\\\").replace('"', "\\\"")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn generates_from_simple_baml() {
let source = r#"
class CutDecision {
action "trim" | "keep" | "highlight" @description("Editing action")
reason string @description("Short reasoning")
}
"#;
let mut module = BamlModule::default();
module.parse_source(source);
let code = generate(&module);
assert!(code.contains("pub struct CutDecision"));
assert!(code.contains("pub action: CutDecisionAction"));
assert!(code.contains("pub reason: String"));
assert!(code.contains("pub enum CutDecisionAction"));
assert!(code.contains("#[serde(rename = \"trim\")]"));
}
#[test]
fn generates_tools_from_baml() {
let source = r#"
class FfmpegTask {
task "ffmpeg_operation" @description("FFmpeg operations") @stream.not_null
operation "convert" | "trim"
input_path string | null
}
"#;
let mut module = BamlModule::default();
module.parse_source(source);
let code = generate(&module);
assert!(code.contains("pub fn all_tools()"));
assert!(code.contains("\"ffmpeg_operation\""));
assert!(code.contains("pub enum ActionUnion"));
assert!(code.contains("FfmpegTask(FfmpegTask)"));
}
#[test]
fn generates_from_real_montage_baml() {
let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.pop(); path.pop(); path.pop(); path.push("startups");
path.push("active");
path.push("video-analyzer");
path.push("crates");
path.push("va-agent");
path.push("baml_src");
path.push("montage");
path.set_extension("baml");
if !path.exists() {
eprintln!("Skipping: montage.baml not found at {}", path.display());
return;
}
let source = std::fs::read_to_string(&path).unwrap();
let mut module = BamlModule::default();
module.parse_source(&source);
let code = generate(&module);
assert!(code.contains("pub struct MontageAgentNextStep"));
assert!(code.contains("pub struct AnalysisTask"));
assert!(code.contains("pub struct FfmpegTask"));
assert!(code.contains("pub struct ProjectTask"));
assert!(code.contains("pub fn all_tools()"));
assert!(code.contains("\"analysis_operation\""));
assert!(code.contains("\"ffmpeg_operation\""));
assert!(code.contains("\"project_operation\""));
assert!(code.contains("DECIDE_MONTAGE_NEXT_STEP_SGR_PROMPT"));
assert!(code.contains("ANALYZE_SEGMENT_SGR_PROMPT"));
}
}