1use crate::baml_parser::*;
10
11pub fn generate(module: &BamlModule) -> String {
13 let mut out = String::new();
14
15 out.push_str("//! Auto-generated from .baml files by sgr-agent codegen.\n");
17 out.push_str("//! Do not edit manually — edit the .baml source and re-run.\n\n");
18 out.push_str("#![allow(dead_code, clippy::derivable_impls)]\n\n");
19 out.push_str("use serde::{Deserialize, Serialize};\n");
20 out.push_str("use schemars::JsonSchema;\n\n");
21
22 let mut enum_map: Vec<(String, Vec<String>)> = Vec::new();
24
25 for class in &module.classes {
27 generate_struct(&mut out, class, module, &mut enum_map);
28 }
29
30 for (name, variants) in &enum_map {
32 generate_string_enum(&mut out, name, variants);
33 }
34
35 generate_tool_registry(&mut out, module);
37
38 generate_prompts(&mut out, module);
40
41 out
42}
43
44fn generate_struct(
45 out: &mut String,
46 class: &BamlClass,
47 module: &BamlModule,
48 enum_map: &mut Vec<(String, Vec<String>)>,
49) {
50 if let Some(desc) = &class.description {
52 out.push_str(&format!("/// {}\n", desc));
53 }
54
55 out.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
56 out.push_str(&format!("pub struct {} {{\n", class.name));
57
58 for field in &class.fields {
59 if let Some(desc) = &field.description {
61 out.push_str(&format!(" /// {}\n", desc));
62 }
63
64 if let Some(fixed) = &field.fixed_value {
66 out.push_str(&format!(" /// Fixed value: \"{}\"\n", fixed));
67 out.push_str(&format!(
68 " #[serde(default = \"default_{}__{}\")]\n",
69 snake_case(&class.name),
70 field.name
71 ));
72 out.push_str(&format!(" pub {}: String,\n", field.name));
73 continue;
74 }
75
76 let rust_type = baml_type_to_rust(&field.ty, &class.name, &field.name, module, enum_map);
77
78 if matches!(&field.ty, BamlType::Optional(_)) {
80 out.push_str(" #[serde(skip_serializing_if = \"Option::is_none\")]\n");
81 }
82
83 out.push_str(&format!(" pub {}: {},\n", field.name, rust_type));
84 }
85
86 out.push_str("}\n\n");
87
88 for field in &class.fields {
90 if let Some(fixed) = &field.fixed_value {
91 out.push_str(&format!(
92 "fn default_{}__{}() -> String {{ \"{}\".to_string() }}\n",
93 snake_case(&class.name),
94 field.name,
95 fixed
96 ));
97 }
98 }
99 if class.fields.iter().any(|f| f.fixed_value.is_some()) {
101 out.push('\n');
102 }
103}
104
/// Emit a serde-friendly Rust enum for a BAML string-literal union.
///
/// Each BAML literal becomes a unit variant with a PascalCase identifier derived from the
/// literal. `#[serde(rename = ...)]` keeps the original literal as the wire value.
///
/// * `name` - Rust identifier of the enum (e.g. `CutDecisionAction`).
/// * `variants` - the literal strings, in declaration order.
fn generate_string_enum(out: &mut String, name: &str, variants: &[String]) {
    out.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
    out.push_str(&format!("pub enum {} {{\n", name));

    let mut used: Vec<String> = Vec::with_capacity(variants.len());
    for variant in variants {
        let ident = unique_variant_ident(variant, &used);
        // `{:?}` renders the literal as a correctly escaped Rust string literal.
        out.push_str(&format!(" #[serde(rename = {:?})]\n", variant));
        out.push_str(&format!(" {},\n", ident));
        used.push(ident);
    }

    out.push_str("}\n\n");
}

/// Turn an arbitrary BAML literal into a PascalCase Rust identifier not present in `taken`.
///
/// Non-alphanumeric characters act as word separators. A `V` prefix is added when the
/// result would be empty, start with a non-letter, or be the keyword `Self`. Collisions get
/// a numeric suffix (`AB`, `AB2`, ...).
fn unique_variant_ident(literal: &str, taken: &[String]) -> String {
    let mut base = String::with_capacity(literal.len() + 1);
    for word in literal.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            base.extend(first.to_uppercase());
            base.push_str(chars.as_str());
        }
    }
    if base == "Self" || !matches!(base.chars().next(), Some(c) if c.is_alphabetic()) {
        base.insert(0, 'V');
    }

    let mut candidate = base.clone();
    let mut suffix = 2;
    while taken.iter().any(|t| *t == candidate) {
        candidate = format!("{}{}", base, suffix);
        suffix += 1;
    }
    candidate
}
115
116fn generate_tool_registry(out: &mut String, module: &BamlModule) {
117 let tool_classes: Vec<&BamlClass> = module
119 .classes
120 .iter()
121 .filter(|c| {
122 c.fields
123 .iter()
124 .any(|f| f.name == "task" && f.fixed_value.is_some())
125 })
126 .collect();
127
128 if tool_classes.is_empty() {
129 return;
130 }
131
132 out.push_str("// --- Tool Registry ---\n\n");
133 out.push_str("use crate::tool::ToolDef;\n\n");
134 out.push_str("/// All tools extracted from BAML definitions.\n");
135 out.push_str("pub fn all_tools() -> Vec<ToolDef> {\n");
136 out.push_str(" vec![\n");
137
138 for class in &tool_classes {
139 let task_field = class.fields.iter().find(|f| f.name == "task").unwrap();
140 let tool_name = task_field.fixed_value.as_deref().unwrap();
141 let description = task_field.description.as_deref().unwrap_or(&class.name);
142
143 out.push_str(&format!(
144 " crate::tool::tool::<{}>(\"{}\", \"{}\"),\n",
145 class.name,
146 tool_name,
147 escape_string(description),
148 ));
149 }
150
151 out.push_str(" ]\n");
152 out.push_str("}\n\n");
153
154 out.push_str("/// Union of all tool types (for dispatching tool calls).\n");
156 out.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
157 out.push_str("#[serde(tag = \"task\")]\n");
158 out.push_str("pub enum ActionUnion {\n");
159
160 for class in &tool_classes {
161 let task_field = class.fields.iter().find(|f| f.name == "task").unwrap();
162 let tool_name = task_field.fixed_value.as_deref().unwrap();
163 out.push_str(&format!(" #[serde(rename = \"{}\")]\n", tool_name));
164 out.push_str(&format!(" {}({}),\n", class.name, class.name));
165 }
166
167 out.push_str("}\n\n");
168}
169
170fn generate_prompts(out: &mut String, module: &BamlModule) {
171 if module.functions.is_empty() {
172 return;
173 }
174
175 out.push_str("// --- Prompt Constants ---\n\n");
176
177 for func in &module.functions {
178 let const_name = screaming_snake_case(&func.name);
179 out.push_str(&format!(
181 "pub const {}_PROMPT: &str = r##\"\n{}\"##;\n\n",
182 const_name,
183 func.prompt.trim(),
184 ));
185 }
186}
187
188fn baml_type_to_rust(
191 ty: &BamlType,
192 class_name: &str,
193 field_name: &str,
194 module: &BamlModule,
195 enum_map: &mut Vec<(String, Vec<String>)>,
196) -> String {
197 match ty {
198 BamlType::String => "String".to_string(),
199 BamlType::Int => "i64".to_string(),
200 BamlType::Float => "f64".to_string(),
201 BamlType::Bool => "bool".to_string(),
202 BamlType::Image => "String".to_string(), BamlType::Ref(name) => {
204 if module.find_class(name).is_some() {
205 name.clone()
206 } else {
207 "String".to_string()
209 }
210 }
211 BamlType::Optional(inner) => {
212 let inner_rust = baml_type_to_rust(inner, class_name, field_name, module, enum_map);
213 format!("Option<{}>", inner_rust)
214 }
215 BamlType::Array(inner) => {
216 let inner_rust = baml_type_to_rust(inner, class_name, field_name, module, enum_map);
217 format!("Vec<{}>", inner_rust)
218 }
219 BamlType::StringEnum(variants) => {
220 let enum_name = format!("{}{}", class_name, pascal_case(field_name));
222 if !enum_map.iter().any(|(n, _)| n == &enum_name) {
223 enum_map.push((enum_name.clone(), variants.clone()));
224 }
225 enum_name
226 }
227 BamlType::Union(variants) => {
228 if variants.len() <= 4 {
231 let enum_name = format!("{}{}", class_name, pascal_case(field_name));
233 if !enum_map.iter().any(|(n, _)| n == &enum_name) {
234 }
237 }
238 "serde_json::Value".to_string()
239 }
240 }
241}
242
/// Convert a PascalCase / camelCase identifier to `snake_case`.
///
/// A `_` is inserted before an uppercase letter when it starts a new word:
/// * after a lowercase letter or digit (`aB` → `a_b`), or
/// * when it ends an acronym that is followed by a lowercase letter
///   (`HTTPServer` → `http_server`).
///
/// Consecutive capitals stay together (`StepSGR` → `step_sgr`). No extra `_` is added
/// after an existing underscore. Lowercasing is Unicode-aware.
fn snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        let prev = if i == 0 { None } else { Some(chars[i - 1]) };
        let next = chars.get(i + 1).copied();
        let starts_word = match prev {
            None | Some('_') => false,
            // Inside an acronym: only the last capital before a lowercase letter starts a word.
            Some(p) if p.is_uppercase() => matches!(next, Some(n) if n.is_lowercase()),
            Some(_) => true,
        };

        if starts_word {
            result.push('_');
        }
        result.extend(c.to_lowercase());
    }

    result
}
259
/// Convert a `snake_case` name to PascalCase.
///
/// The input is split on `_`. The first character of each piece is upper-cased (Unicode
/// aware, so it may expand) and the rest is kept verbatim. Empty pieces contribute nothing.
fn pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for piece in s.split('_') {
        let mut rest = piece.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
271
272fn screaming_snake_case(s: &str) -> String {
273 snake_case(s).to_uppercase()
274}
275
/// Escape `s` for use inside a double-quoted Rust string literal.
///
/// Only backslashes and double quotes are escaped; every other character is copied as-is.
fn escape_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `source` as a standalone BAML file and run the generator over it.
    fn generate_from(source: &str) -> String {
        let mut module = BamlModule::default();
        module.parse_source(source);
        generate(&module)
    }

    #[test]
    fn generates_from_simple_baml() {
        let code = generate_from(
            r#"
class CutDecision {
 action "trim" | "keep" | "highlight" @description("Editing action")
 reason string @description("Short reasoning")
}
"#,
        );

        for needle in [
            "pub struct CutDecision",
            "pub action: CutDecisionAction",
            "pub reason: String",
            "pub enum CutDecisionAction",
            "#[serde(rename = \"trim\")]",
        ] {
            assert!(code.contains(needle));
        }
    }

    #[test]
    fn generates_tools_from_baml() {
        let code = generate_from(
            r#"
class FfmpegTask {
 task "ffmpeg_operation" @description("FFmpeg operations") @stream.not_null
 operation "convert" | "trim"
 input_path string | null
}
"#,
        );

        for needle in [
            "pub fn all_tools()",
            "\"ffmpeg_operation\"",
            "pub enum ActionUnion",
            "FfmpegTask(FfmpegTask)",
        ] {
            assert!(code.contains(needle));
        }
    }

    #[test]
    fn generates_from_real_montage_baml() {
        // Go three directories up from this crate's manifest dir, then down into the
        // sibling project that owns `montage.baml`.
        let mut path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        for _ in 0..3 {
            path.pop();
        }
        path.extend([
            "startups",
            "active",
            "video-analyzer",
            "crates",
            "va-agent",
            "baml_src",
            "montage.baml",
        ]);

        if !path.exists() {
            eprintln!("Skipping: montage.baml not found at {}", path.display());
            return;
        }

        let source = std::fs::read_to_string(&path).unwrap();
        let code = generate_from(&source);

        let structs = [
            "pub struct MontageAgentNextStep",
            "pub struct AnalysisTask",
            "pub struct FfmpegTask",
            "pub struct ProjectTask",
        ];
        let tools = [
            "pub fn all_tools()",
            "\"analysis_operation\"",
            "\"ffmpeg_operation\"",
            "\"project_operation\"",
        ];
        let prompts = [
            "DECIDE_MONTAGE_NEXT_STEP_SGR_PROMPT",
            "ANALYZE_SEGMENT_SGR_PROMPT",
        ];

        for needle in structs.iter().chain(&tools).chain(&prompts) {
            assert!(code.contains(needle));
        }
    }
}