Skip to main content

sage_codegen/
generator.rs

1//! Main code generator.
2
3use crate::emit::Emitter;
4use sage_loader::ModuleTree;
5use sage_parser::{
6    AgentDecl, BinOp, Block, ConstDecl, EnumDecl, EventKind, Expr, FnDecl, Literal, Program,
7    RecordDecl, Stmt, StringPart, UnaryOp,
8};
9use sage_types::TypeExpr;
10
/// Generated Rust project files.
///
/// Produced by [`generate`] and [`generate_module_tree`]; the caller is
/// responsible for writing both strings to disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedProject {
    /// The main.rs content.
    pub main_rs: String,
    /// The Cargo.toml content.
    pub cargo_toml: String,
}
18
19/// Generate Rust code from a Sage program (single file).
20pub fn generate(program: &Program, project_name: &str) -> GeneratedProject {
21    let mut gen = Generator::new();
22    let main_rs = gen.generate_program(program);
23    let cargo_toml = gen.generate_cargo_toml(project_name);
24    GeneratedProject {
25        main_rs,
26        cargo_toml,
27    }
28}
29
30/// Generate Rust code from a module tree (multi-file project).
31///
32/// This flattens all modules into a single Rust file, generating all agents
33/// and functions with appropriate visibility modifiers.
34pub fn generate_module_tree(tree: &ModuleTree, project_name: &str) -> GeneratedProject {
35    let mut gen = Generator::new();
36    let main_rs = gen.generate_module_tree(tree);
37    let cargo_toml = gen.generate_cargo_toml(project_name);
38    GeneratedProject {
39        main_rs,
40        cargo_toml,
41    }
42}
43
44struct Generator {
45    emit: Emitter,
46}
47
48impl Generator {
49    fn new() -> Self {
50        Self {
51            emit: Emitter::new(),
52        }
53    }
54
55    fn generate_program(&mut self, program: &Program) -> String {
56        // Prelude
57        self.emit
58            .writeln("//! Generated by Sage compiler. Do not edit.");
59        self.emit.blank_line();
60        self.emit.writeln("use sage_runtime::prelude::*;");
61        self.emit.blank_line();
62
63        // Constants
64        for const_decl in &program.consts {
65            self.generate_const(const_decl);
66            self.emit.blank_line();
67        }
68
69        // Enums
70        for enum_decl in &program.enums {
71            self.generate_enum(enum_decl);
72            self.emit.blank_line();
73        }
74
75        // Records
76        for record in &program.records {
77            self.generate_record(record);
78            self.emit.blank_line();
79        }
80
81        // Functions
82        for func in &program.functions {
83            self.generate_function(func);
84            self.emit.blank_line();
85        }
86
87        // Agents
88        for agent in &program.agents {
89            self.generate_agent(agent);
90            self.emit.blank_line();
91        }
92
93        // Entry point (required for executables)
94        if let Some(run_agent) = &program.run_agent {
95            // RFC-0007: Check if entry agent has an error handler
96            let has_error_handler = program
97                .agents
98                .iter()
99                .find(|a| a.name.name == run_agent.name)
100                .map_or(false, |a| {
101                    a.handlers
102                        .iter()
103                        .any(|h| matches!(h.event, EventKind::Error { .. }))
104                });
105            self.generate_main(&run_agent.name, has_error_handler);
106        }
107
108        std::mem::take(&mut self.emit).finish()
109    }
110
111    fn generate_module_tree(&mut self, tree: &ModuleTree) -> String {
112        // Prelude
113        self.emit
114            .writeln("//! Generated by Sage compiler. Do not edit.");
115        self.emit.blank_line();
116        self.emit.writeln("use sage_runtime::prelude::*;");
117        self.emit.blank_line();
118
119        // Generate all modules, starting with the root
120        // We flatten everything into one file for simplicity
121        // (A more advanced implementation would generate mod.rs files)
122
123        // First, generate non-root modules
124        for (path, module) in &tree.modules {
125            if path != &tree.root {
126                self.emit.write("// Module: ");
127                if path.is_empty() {
128                    self.emit.writeln("(root)");
129                } else {
130                    self.emit.writeln(&path.join("::"));
131                }
132
133                for const_decl in &module.program.consts {
134                    self.generate_const(const_decl);
135                    self.emit.blank_line();
136                }
137
138                for enum_decl in &module.program.enums {
139                    self.generate_enum(enum_decl);
140                    self.emit.blank_line();
141                }
142
143                for record in &module.program.records {
144                    self.generate_record(record);
145                    self.emit.blank_line();
146                }
147
148                for func in &module.program.functions {
149                    self.generate_function(func);
150                    self.emit.blank_line();
151                }
152
153                for agent in &module.program.agents {
154                    self.generate_agent(agent);
155                    self.emit.blank_line();
156                }
157            }
158        }
159
160        // Then, generate the root module
161        if let Some(root_module) = tree.modules.get(&tree.root) {
162            self.emit.writeln("// Root module");
163
164            for const_decl in &root_module.program.consts {
165                self.generate_const(const_decl);
166                self.emit.blank_line();
167            }
168
169            for enum_decl in &root_module.program.enums {
170                self.generate_enum(enum_decl);
171                self.emit.blank_line();
172            }
173
174            for record in &root_module.program.records {
175                self.generate_record(record);
176                self.emit.blank_line();
177            }
178
179            for func in &root_module.program.functions {
180                self.generate_function(func);
181                self.emit.blank_line();
182            }
183
184            for agent in &root_module.program.agents {
185                self.generate_agent(agent);
186                self.emit.blank_line();
187            }
188
189            // Entry point (only in root module)
190            if let Some(run_agent) = &root_module.program.run_agent {
191                // RFC-0007: Check if entry agent has an error handler
192                let has_error_handler = root_module
193                    .program
194                    .agents
195                    .iter()
196                    .find(|a| a.name.name == run_agent.name)
197                    .map_or(false, |a| {
198                        a.handlers
199                            .iter()
200                            .any(|h| matches!(h.event, EventKind::Error { .. }))
201                    });
202                self.generate_main(&run_agent.name, has_error_handler);
203            }
204        }
205
206        std::mem::take(&mut self.emit).finish()
207    }
208
209    fn generate_cargo_toml(&self, name: &str) -> String {
210        // Use a relative path that works from target/sage/<project>/
211        // This assumes the standard project layout
212        format!(
213            r#"[package]
214name = "{name}"
215version = "0.1.0"
216edition = "2021"
217
218[dependencies]
219sage-runtime = {{ path = "../../../crates/sage-runtime" }}
220tokio = {{ version = "1", features = ["full"] }}
221serde = {{ version = "1", features = ["derive"] }}
222serde_json = "1"
223
224# Standalone project, not part of parent workspace
225[workspace]
226"#
227        )
228    }
229
230    fn generate_const(&mut self, const_decl: &ConstDecl) {
231        if const_decl.is_pub {
232            self.emit.write("pub ");
233        }
234        self.emit.write("const ");
235        self.emit.write(&const_decl.name.name);
236        self.emit.write(": ");
237        self.emit_type(&const_decl.ty);
238        self.emit.write(" = ");
239        self.generate_expr(&const_decl.value);
240        self.emit.writeln(";");
241    }
242
243    fn generate_enum(&mut self, enum_decl: &EnumDecl) {
244        if enum_decl.is_pub {
245            self.emit.write("pub ");
246        }
247        self.emit
248            .writeln("#[derive(Debug, Clone, Copy, PartialEq, Eq)]");
249        self.emit.write("enum ");
250        self.emit.write(&enum_decl.name.name);
251        self.emit.writeln(" {");
252        self.emit.indent();
253        for variant in &enum_decl.variants {
254            self.emit.write(&variant.name.name);
255            if let Some(payload_ty) = &variant.payload {
256                self.emit.write("(");
257                self.emit_type(payload_ty);
258                self.emit.write(")");
259            }
260            self.emit.writeln(",");
261        }
262        self.emit.dedent();
263        self.emit.writeln("}");
264    }
265
266    fn generate_record(&mut self, record: &RecordDecl) {
267        if record.is_pub {
268            self.emit.write("pub ");
269        }
270        self.emit.writeln("#[derive(Debug, Clone)]");
271        self.emit.write("struct ");
272        self.emit.write(&record.name.name);
273        self.emit.writeln(" {");
274        self.emit.indent();
275        for field in &record.fields {
276            self.emit.write(&field.name.name);
277            self.emit.write(": ");
278            self.emit_type(&field.ty);
279            self.emit.writeln(",");
280        }
281        self.emit.dedent();
282        self.emit.writeln("}");
283    }
284
285    fn generate_function(&mut self, func: &FnDecl) {
286        // Function signature with visibility
287        if func.is_pub {
288            self.emit.write("pub ");
289        }
290        self.emit.write("fn ");
291        self.emit.write(&func.name.name);
292        self.emit.write("(");
293
294        for (i, param) in func.params.iter().enumerate() {
295            if i > 0 {
296                self.emit.write(", ");
297            }
298            self.emit.write(&param.name.name);
299            self.emit.write(": ");
300            self.emit_type(&param.ty);
301        }
302
303        self.emit.write(") -> ");
304
305        // RFC-0007: Wrap return type in SageResult if fallible
306        if func.is_fallible {
307            self.emit.write("SageResult<");
308            self.emit_type(&func.return_ty);
309            self.emit.write(">");
310        } else {
311            self.emit_type(&func.return_ty);
312        }
313
314        self.emit.write(" ");
315        self.generate_block(&func.body);
316    }
317
318    fn generate_agent(&mut self, agent: &AgentDecl) {
319        let name = &agent.name.name;
320
321        // Struct definition with visibility
322        if agent.is_pub {
323            self.emit.write("pub ");
324        }
325        self.emit.write("struct ");
326        self.emit.write(name);
327        if agent.beliefs.is_empty() {
328            self.emit.writeln(";");
329        } else {
330            self.emit.writeln(" {");
331            self.emit.indent();
332            for belief in &agent.beliefs {
333                self.emit.write(&belief.name.name);
334                self.emit.write(": ");
335                self.emit_type(&belief.ty);
336                self.emit.writeln(",");
337            }
338            self.emit.dedent();
339            self.emit.writeln("}");
340        }
341        self.emit.blank_line();
342
343        // Find the output type from the start handler
344        let output_type = self.infer_agent_output_type(agent);
345
346        // Impl block
347        self.emit.write("impl ");
348        self.emit.write(name);
349        self.emit.writeln(" {");
350        self.emit.indent();
351
352        // Generate handlers
353        for handler in &agent.handlers {
354            match &handler.event {
355                EventKind::Start => {
356                    self.emit
357                        .write("async fn on_start(self, ctx: AgentContext<");
358                    self.emit.write(&output_type);
359                    self.emit.write(">) -> SageResult<");
360                    self.emit.write(&output_type);
361                    self.emit.writeln("> {");
362                    self.emit.indent();
363                    self.generate_block_contents(&handler.body);
364                    self.emit.dedent();
365                    self.emit.writeln("}");
366                }
367
368                // RFC-0007: Generate on_error handler
369                EventKind::Error { param_name } => {
370                    self.emit.write("async fn on_error(self, ");
371                    self.emit.write(&param_name.name);
372                    self.emit.write(": SageError, ctx: AgentContext<");
373                    self.emit.write(&output_type);
374                    self.emit.write(">) -> SageResult<");
375                    self.emit.write(&output_type);
376                    self.emit.writeln("> {");
377                    self.emit.indent();
378                    self.generate_block_contents(&handler.body);
379                    self.emit.dedent();
380                    self.emit.writeln("}");
381                }
382
383                // Other handlers (message, stop) - future work
384                _ => {}
385            }
386        }
387
388        self.emit.dedent();
389        self.emit.writeln("}");
390    }
391
392    fn generate_main(&mut self, entry_agent: &str, has_error_handler: bool) {
393        self.emit.writeln("#[tokio::main]");
394        self.emit
395            .writeln("async fn main() -> Result<(), Box<dyn std::error::Error>> {");
396        self.emit.indent();
397
398        if has_error_handler {
399            // RFC-0007: Generate error dispatch code
400            self.emit
401                .writeln("let handle = sage_runtime::spawn(|ctx| async move {");
402            self.emit.indent();
403            self.emit.write("match ");
404            self.emit.write(entry_agent);
405            self.emit.writeln(".on_start(ctx.clone()).await {");
406            self.emit.indent();
407            self.emit.writeln("Ok(result) => Ok(result),");
408            self.emit.write("Err(e) => ");
409            self.emit.write(entry_agent);
410            self.emit.writeln(".on_error(e, ctx).await,");
411            self.emit.dedent();
412            self.emit.writeln("}");
413            self.emit.dedent();
414            self.emit.writeln("});");
415        } else {
416            self.emit.write("let handle = sage_runtime::spawn(|ctx| ");
417            self.emit.write(entry_agent);
418            self.emit.writeln(".on_start(ctx));");
419        }
420
421        self.emit.writeln("let result = handle.result().await?;");
422        self.emit.writeln("println!(\"{:?}\", result);");
423        self.emit.writeln("Ok(())");
424
425        self.emit.dedent();
426        self.emit.writeln("}");
427    }
428
429    fn generate_block(&mut self, block: &Block) {
430        self.emit.open_brace();
431        self.generate_block_contents(block);
432        self.emit.close_brace();
433    }
434
435    fn generate_block_inline(&mut self, block: &Block) {
436        self.emit.open_brace();
437        self.generate_block_contents(block);
438        self.emit.close_brace_inline();
439    }
440
441    fn generate_block_contents(&mut self, block: &Block) {
442        for stmt in &block.stmts {
443            self.generate_stmt(stmt);
444        }
445    }
446
447    fn generate_stmt(&mut self, stmt: &Stmt) {
448        match stmt {
449            Stmt::Let {
450                name, ty, value, ..
451            } => {
452                self.emit.write("let ");
453                if ty.is_some() {
454                    self.emit.write(&name.name);
455                    self.emit.write(": ");
456                    self.emit_type(ty.as_ref().unwrap());
457                } else {
458                    self.emit.write(&name.name);
459                }
460                self.emit.write(" = ");
461                self.generate_expr(value);
462                self.emit.writeln(";");
463            }
464
465            Stmt::Assign { name, value, .. } => {
466                self.emit.write(&name.name);
467                self.emit.write(" = ");
468                self.generate_expr(value);
469                self.emit.writeln(";");
470            }
471
472            Stmt::Return { value, .. } => {
473                self.emit.write("return ");
474                if let Some(expr) = value {
475                    self.generate_expr(expr);
476                }
477                self.emit.writeln(";");
478            }
479
480            Stmt::If {
481                condition,
482                then_block,
483                else_block,
484                ..
485            } => {
486                self.emit.write("if ");
487                self.generate_expr(condition);
488                self.emit.write(" ");
489                if else_block.is_some() {
490                    self.generate_block_inline(then_block);
491                    self.emit.write(" else ");
492                    match else_block.as_ref().unwrap() {
493                        sage_parser::ElseBranch::Block(block) => {
494                            self.generate_block(block);
495                        }
496                        sage_parser::ElseBranch::ElseIf(stmt) => {
497                            self.generate_stmt(stmt);
498                        }
499                    }
500                } else {
501                    self.generate_block(then_block);
502                }
503            }
504
505            Stmt::For {
506                pattern,
507                iter,
508                body,
509                ..
510            } => {
511                self.emit.write("for ");
512                self.emit_pattern(pattern);
513                self.emit.write(" in ");
514                self.generate_expr(iter);
515                self.emit.write(" ");
516                self.generate_block(body);
517            }
518
519            Stmt::While {
520                condition, body, ..
521            } => {
522                self.emit.write("while ");
523                self.generate_expr(condition);
524                self.emit.write(" ");
525                self.generate_block(body);
526            }
527
528            Stmt::Loop { body, .. } => {
529                self.emit.write("loop ");
530                self.generate_block(body);
531            }
532
533            Stmt::Break { .. } => {
534                self.emit.writeln("break;");
535            }
536
537            Stmt::Expr { expr, .. } => {
538                // Handle emit specially
539                if let Expr::Emit { value, .. } = expr {
540                    self.emit.write("return ctx.emit(");
541                    self.generate_expr(value);
542                    self.emit.writeln(");");
543                } else {
544                    self.generate_expr(expr);
545                    self.emit.writeln(";");
546                }
547            }
548
549            Stmt::LetTuple { names, value, .. } => {
550                self.emit.write("let (");
551                for (i, name) in names.iter().enumerate() {
552                    if i > 0 {
553                        self.emit.write(", ");
554                    }
555                    self.emit.write(&name.name);
556                }
557                self.emit.write(") = ");
558                self.generate_expr(value);
559                self.emit.writeln(";");
560            }
561        }
562    }
563
    /// Emit the Rust expression for a Sage expression.
    ///
    /// Output is written straight to `self.emit` with no trailing `;` or newline;
    /// callers (e.g. `generate_stmt`) decide how the expression is terminated.
    /// Each arm writes its tokens strictly left-to-right, so statement order matters.
    fn generate_expr(&mut self, expr: &Expr) {
        match expr {
            Expr::Literal { value, .. } => {
                self.emit_literal(value);
            }

            Expr::Var { name, .. } => {
                self.emit.write(&name.name);
            }

            Expr::Binary {
                op, left, right, ..
            } => {
                // String concatenation is lowered to `format!("{}{}", left, right)`.
                if matches!(op, BinOp::Concat) {
                    self.emit.write("format!(\"{}{}\", ");
                    self.generate_expr(left);
                    self.emit.write(", ");
                    self.generate_expr(right);
                    self.emit.write(")");
                } else {
                    // Every other binary operation is fully parenthesized, so the
                    // grouping from the Sage AST survives regardless of Rust precedence.
                    self.emit.write("(");
                    self.generate_expr(left);
                    self.emit.write(" ");
                    self.emit_binop(op);
                    self.emit.write(" ");
                    self.generate_expr(right);
                    self.emit.write(")");
                }
            }

            Expr::Unary { op, operand, .. } => {
                // The operand is not parenthesized here; binary operands already
                // carry their own parentheses from the arm above.
                self.emit_unaryop(op);
                self.generate_expr(operand);
            }

            Expr::Call { name, args, .. } => {
                let fn_name = &name.name;

                // Builtins map to Rust macros/methods; anything else is a plain call.
                // NOTE(review): builtins index `args[0]` directly and would panic on an
                // empty argument list — assumes arity was checked earlier.
                match fn_name.as_str() {
                    "print" => {
                        self.emit.write("println!(\"{}\", ");
                        self.generate_expr(&args[0]);
                        self.emit.write(")");
                    }
                    "str" => {
                        self.generate_expr(&args[0]);
                        self.emit.write(".to_string()");
                    }
                    "len" => {
                        // `usize` length cast to `i64`, the integer type used for
                        // Sage int literals (see the `_i64` suffix in `emit_literal`).
                        self.generate_expr(&args[0]);
                        self.emit.write(".len() as i64");
                    }
                    _ => {
                        self.emit.write(fn_name);
                        self.emit.write("(");
                        for (i, arg) in args.iter().enumerate() {
                            if i > 0 {
                                self.emit.write(", ");
                            }
                            self.generate_expr(arg);
                        }
                        self.emit.write(")");
                    }
                }
            }

            Expr::SelfField { field, .. } => {
                self.emit.write("self.");
                self.emit.write(&field.name);
            }

            Expr::SelfMethodCall { method, args, .. } => {
                self.emit.write("self.");
                self.emit.write(&method.name);
                self.emit.write("(");
                for (i, arg) in args.iter().enumerate() {
                    if i > 0 {
                        self.emit.write(", ");
                    }
                    self.generate_expr(arg);
                }
                self.emit.write(")");
            }

            Expr::List { elements, .. } => {
                self.emit.write("vec![");
                for (i, elem) in elements.iter().enumerate() {
                    if i > 0 {
                        self.emit.write(", ");
                    }
                    self.generate_expr(elem);
                }
                self.emit.write("]");
            }

            Expr::Paren { inner, .. } => {
                self.emit.write("(");
                self.generate_expr(inner);
                self.emit.write(")");
            }

            Expr::Infer { template, .. } => {
                // `ctx.infer_string(&<template>).await?` — errors propagate with `?`.
                self.emit.write("ctx.infer_string(&");
                self.emit_string_template(template);
                self.emit.write(").await?");
            }

            Expr::Spawn { agent, fields, .. } => {
                // `sage_runtime::spawn(|ctx| Agent.on_start(ctx))` for a unit agent, or
                // `sage_runtime::spawn(|ctx| Agent { f: v, .. }.on_start(ctx))` with beliefs.
                // NOTE(review): the closure is not `move`; if `spawn` requires a
                // `'static` closure, field values that borrow locals may not compile —
                // confirm against `sage_runtime::spawn`.
                self.emit.write("sage_runtime::spawn(|ctx| ");
                self.emit.write(&agent.name);
                if fields.is_empty() {
                    self.emit.write(".on_start(ctx))");
                } else {
                    self.emit.write(" { ");
                    for (i, field) in fields.iter().enumerate() {
                        if i > 0 {
                            self.emit.write(", ");
                        }
                        self.emit.write(&field.name.name);
                        self.emit.write(": ");
                        self.generate_expr(&field.value);
                    }
                    self.emit.write(" }.on_start(ctx))");
                }
            }

            Expr::Await { handle, .. } => {
                self.generate_expr(handle);
                self.emit.write(".result().await?");
            }

            Expr::Send {
                handle, message, ..
            } => {
                // Both message construction and the send itself are fallible (`?`).
                self.generate_expr(handle);
                self.emit.write(".send(sage_runtime::Message::new(");
                self.generate_expr(message);
                self.emit.write(")?).await?");
            }

            Expr::Emit { value, .. } => {
                // Expression form only; `generate_stmt` turns a statement-level
                // `emit` into `return ctx.emit(..);`.
                self.emit.write("ctx.emit(");
                self.generate_expr(value);
                self.emit.write(")");
            }

            Expr::StringInterp { template, .. } => {
                self.emit_string_template(template);
            }

            // RFC-0005: `match` expression, one `pattern => body,` line per arm.
            Expr::Match {
                scrutinee, arms, ..
            } => {
                self.emit.write("match ");
                self.generate_expr(scrutinee);
                self.emit.writeln(" {");
                self.emit.indent();
                for arm in arms {
                    self.emit_pattern(&arm.pattern);
                    self.emit.write(" => ");
                    self.generate_expr(&arm.body);
                    self.emit.writeln(",");
                }
                self.emit.dedent();
                self.emit.write("}");
            }

            // RFC-0005: record construction as a Rust struct literal `Name { f: v, .. }`.
            Expr::RecordConstruct { name, fields, .. } => {
                self.emit.write(&name.name);
                self.emit.write(" { ");
                for (i, field) in fields.iter().enumerate() {
                    if i > 0 {
                        self.emit.write(", ");
                    }
                    self.emit.write(&field.name.name);
                    self.emit.write(": ");
                    self.generate_expr(&field.value);
                }
                self.emit.write(" }");
            }

            // RFC-0005: field access `object.field`.
            Expr::FieldAccess { object, field, .. } => {
                self.generate_expr(object);
                self.emit.write(".");
                self.emit.write(&field.name);
            }

            Expr::Receive { .. } => {
                self.emit.write("ctx.receive().await?");
            }

            // RFC-0007: Error handling
            Expr::Try { expr, .. } => {
                // Generate the inner expression with ? for error propagation
                self.generate_expr(expr);
                self.emit.write("?");
            }

            Expr::Catch {
                expr,
                error_bind,
                recovery,
                ..
            } => {
                // Generate a match expression to handle the Result
                self.emit.write("match ");
                self.generate_expr(expr);
                self.emit.writeln(" {");
                self.emit.indent();

                // Ok arm - unwrap the value
                self.emit.writeln("Ok(__val) => __val,");

                // Err arm - run recovery, binding the error only if the source named it
                if let Some(err_name) = error_bind {
                    self.emit.write("Err(");
                    self.emit.write(&err_name.name);
                    self.emit.write(") => ");
                } else {
                    self.emit.write("Err(_) => ");
                }
                self.generate_expr(recovery);
                self.emit.writeln(",");

                self.emit.dedent();
                self.emit.write("}");
            }

            // RFC-0009: Closures
            Expr::Closure { params, body, .. } => {
                // Generate: Box::new(move |param1: Type1, param2| body)
                // (type annotations only where the source gave one; the body is
                // emitted as-is, without extra braces)
                self.emit.write("Box::new(move |");
                for (i, param) in params.iter().enumerate() {
                    if i > 0 {
                        self.emit.write(", ");
                    }
                    self.emit.write(&param.name.name);
                    if let Some(ty) = &param.ty {
                        self.emit.write(": ");
                        self.emit_type(ty);
                    }
                }
                self.emit.write("| ");
                self.generate_expr(body);
                self.emit.write(")");
            }

            // RFC-0010: Tuples and Maps
            Expr::Tuple { elements, .. } => {
                self.emit.write("(");
                for (i, elem) in elements.iter().enumerate() {
                    if i > 0 {
                        self.emit.write(", ");
                    }
                    self.generate_expr(elem);
                }
                self.emit.write(")");
            }

            Expr::TupleIndex { tuple, index, .. } => {
                self.generate_expr(tuple);
                self.emit.write(&format!(".{index}"));
            }

            Expr::Map { entries, .. } => {
                // Empty map literal -> `HashMap::new()`; otherwise build from an
                // array of `(key, value)` tuples.
                if entries.is_empty() {
                    self.emit.write("std::collections::HashMap::new()");
                } else {
                    self.emit.write("std::collections::HashMap::from([");
                    for (i, entry) in entries.iter().enumerate() {
                        if i > 0 {
                            self.emit.write(", ");
                        }
                        self.emit.write("(");
                        self.generate_expr(&entry.key);
                        self.emit.write(", ");
                        self.generate_expr(&entry.value);
                        self.emit.write(")");
                    }
                    self.emit.write("])");
                }
            }

            Expr::VariantConstruct {
                enum_name,
                variant,
                payload,
                ..
            } => {
                // `Enum::Variant` or `Enum::Variant(payload)`.
                self.emit.write(&enum_name.name);
                self.emit.write("::");
                self.emit.write(&variant.name);
                if let Some(payload_expr) = payload {
                    self.emit.write("(");
                    self.generate_expr(payload_expr);
                    self.emit.write(")");
                }
            }
        }
    }
869
870    fn emit_pattern(&mut self, pattern: &sage_parser::Pattern) {
871        use sage_parser::Pattern;
872        match pattern {
873            Pattern::Wildcard { .. } => {
874                self.emit.write("_");
875            }
876            Pattern::Variant {
877                enum_name,
878                variant,
879                payload,
880                ..
881            } => {
882                if let Some(enum_name) = enum_name {
883                    self.emit.write(&enum_name.name);
884                    self.emit.write("::");
885                }
886                self.emit.write(&variant.name);
887                if let Some(inner_pattern) = payload {
888                    self.emit.write("(");
889                    self.emit_pattern(inner_pattern);
890                    self.emit.write(")");
891                }
892            }
893            Pattern::Literal { value, .. } => {
894                self.emit_literal(value);
895            }
896            Pattern::Binding { name, .. } => {
897                self.emit.write(&name.name);
898            }
899            Pattern::Tuple { elements, .. } => {
900                self.emit.write("(");
901                for (i, elem) in elements.iter().enumerate() {
902                    if i > 0 {
903                        self.emit.write(", ");
904                    }
905                    self.emit_pattern(elem);
906                }
907                self.emit.write(")");
908            }
909        }
910    }
911
912    fn emit_literal(&mut self, lit: &Literal) {
913        match lit {
914            Literal::Int(n) => {
915                self.emit.write(&format!("{n}_i64"));
916            }
917            Literal::Float(f) => {
918                self.emit.write(&format!("{f}_f64"));
919            }
920            Literal::Bool(b) => {
921                self.emit.write(if *b { "true" } else { "false" });
922            }
923            Literal::String(s) => {
924                // Escape the string for Rust
925                self.emit.write("\"");
926                for c in s.chars() {
927                    match c {
928                        '"' => self.emit.write_raw("\\\""),
929                        '\\' => self.emit.write_raw("\\\\"),
930                        '\n' => self.emit.write_raw("\\n"),
931                        '\r' => self.emit.write_raw("\\r"),
932                        '\t' => self.emit.write_raw("\\t"),
933                        _ => self.emit.write_raw(&c.to_string()),
934                    }
935                }
936                self.emit.write("\".to_string()");
937            }
938        }
939    }
940
941    fn emit_string_template(&mut self, template: &sage_parser::StringTemplate) {
942        if !template.has_interpolations() {
943            // Simple string literal
944            if let Some(StringPart::Literal(s)) = template.parts.first() {
945                self.emit.write("\"");
946                self.emit.write_raw(s);
947                self.emit.write("\".to_string()");
948            }
949            return;
950        }
951
952        // Build format string and args
953        self.emit.write("format!(\"");
954        for part in &template.parts {
955            match part {
956                StringPart::Literal(s) => {
957                    // Escape braces for format string
958                    let escaped = s.replace('{', "{{").replace('}', "}}");
959                    self.emit.write_raw(&escaped);
960                }
961                StringPart::Interpolation(_) => {
962                    self.emit.write_raw("{}");
963                }
964            }
965        }
966        self.emit.write("\"");
967
968        // Add the interpolation args
969        for part in &template.parts {
970            if let StringPart::Interpolation(ident) = part {
971                self.emit.write(", ");
972                self.emit.write(&ident.name);
973            }
974        }
975        self.emit.write(")");
976    }
977
978    fn emit_type(&mut self, ty: &TypeExpr) {
979        match ty {
980            TypeExpr::Int => self.emit.write("i64"),
981            TypeExpr::Float => self.emit.write("f64"),
982            TypeExpr::Bool => self.emit.write("bool"),
983            TypeExpr::String => self.emit.write("String"),
984            TypeExpr::Unit => self.emit.write("()"),
985            TypeExpr::List(inner) => {
986                self.emit.write("Vec<");
987                self.emit_type(inner);
988                self.emit.write(">");
989            }
990            TypeExpr::Option(inner) => {
991                self.emit.write("Option<");
992                self.emit_type(inner);
993                self.emit.write(">");
994            }
995            TypeExpr::Inferred(inner) => {
996                // Inferred<T> just becomes T at runtime
997                self.emit_type(inner);
998            }
999            TypeExpr::Agent(agent_name) => {
1000                // Agent handles use the agent's output type, but we don't know it here
1001                // For now, just use a generic output type
1002                self.emit.write("AgentHandle<");
1003                self.emit.write(&agent_name.name);
1004                self.emit.write("Output>");
1005            }
1006            TypeExpr::Named(name) => {
1007                self.emit.write(&name.name);
1008            }
1009
1010            // RFC-0007: Error handling
1011            TypeExpr::Error => {
1012                self.emit.write("sage_runtime::SageError");
1013            }
1014
1015            // RFC-0009: Function types
1016            TypeExpr::Fn(params, ret) => {
1017                self.emit.write("Box<dyn Fn(");
1018                for (i, param) in params.iter().enumerate() {
1019                    if i > 0 {
1020                        self.emit.write(", ");
1021                    }
1022                    self.emit_type(param);
1023                }
1024                self.emit.write(") -> ");
1025                self.emit_type(ret);
1026                self.emit.write(" + Send + 'static>");
1027            }
1028
1029            // RFC-0010: Maps, tuples, Result
1030            TypeExpr::Map(key, value) => {
1031                self.emit.write("std::collections::HashMap<");
1032                self.emit_type(key);
1033                self.emit.write(", ");
1034                self.emit_type(value);
1035                self.emit.write(">");
1036            }
1037            TypeExpr::Tuple(elems) => {
1038                self.emit.write("(");
1039                for (i, elem) in elems.iter().enumerate() {
1040                    if i > 0 {
1041                        self.emit.write(", ");
1042                    }
1043                    self.emit_type(elem);
1044                }
1045                self.emit.write(")");
1046            }
1047            TypeExpr::Result(ok, err) => {
1048                self.emit.write("Result<");
1049                self.emit_type(ok);
1050                self.emit.write(", ");
1051                self.emit_type(err);
1052                self.emit.write(">");
1053            }
1054        }
1055    }
1056
1057    fn emit_binop(&mut self, op: &BinOp) {
1058        let s = match op {
1059            BinOp::Add => "+",
1060            BinOp::Sub => "-",
1061            BinOp::Mul => "*",
1062            BinOp::Div => "/",
1063            BinOp::Eq => "==",
1064            BinOp::Ne => "!=",
1065            BinOp::Lt => "<",
1066            BinOp::Gt => ">",
1067            BinOp::Le => "<=",
1068            BinOp::Ge => ">=",
1069            BinOp::And => "&&",
1070            BinOp::Or => "||",
1071            BinOp::Concat => "++", // Handled specially above
1072        };
1073        self.emit.write(s);
1074    }
1075
1076    fn emit_unaryop(&mut self, op: &UnaryOp) {
1077        let s = match op {
1078            UnaryOp::Neg => "-",
1079            UnaryOp::Not => "!",
1080        };
1081        self.emit.write(s);
1082    }
1083
1084    fn infer_agent_output_type(&self, agent: &AgentDecl) -> String {
1085        // Look for emit expression in start handler to infer return type
1086        // For now, default to i64
1087        for handler in &agent.handlers {
1088            if let EventKind::Start = &handler.event {
1089                if let Some(ty) = self.find_emit_type(&handler.body) {
1090                    return ty;
1091                }
1092            }
1093        }
1094        "i64".to_string()
1095    }
1096
1097    fn find_emit_type(&self, block: &Block) -> Option<String> {
1098        for stmt in &block.stmts {
1099            if let Stmt::Expr { expr, .. } = stmt {
1100                if let Expr::Emit { value, .. } = expr {
1101                    return Some(self.infer_expr_type(value));
1102                }
1103            }
1104            // Check nested blocks
1105            if let Stmt::If {
1106                then_block,
1107                else_block,
1108                ..
1109            } = stmt
1110            {
1111                if let Some(ty) = self.find_emit_type(then_block) {
1112                    return Some(ty);
1113                }
1114                if let Some(else_branch) = else_block {
1115                    if let sage_parser::ElseBranch::Block(block) = else_branch {
1116                        if let Some(ty) = self.find_emit_type(block) {
1117                            return Some(ty);
1118                        }
1119                    }
1120                }
1121            }
1122        }
1123        None
1124    }
1125
1126    fn infer_expr_type(&self, expr: &Expr) -> String {
1127        match expr {
1128            Expr::Literal { value, .. } => match value {
1129                Literal::Int(_) => "i64".to_string(),
1130                Literal::Float(_) => "f64".to_string(),
1131                Literal::Bool(_) => "bool".to_string(),
1132                Literal::String(_) => "String".to_string(),
1133            },
1134            Expr::Var { .. } => "i64".to_string(), // Conservative default
1135            Expr::Binary { op, .. } => {
1136                if matches!(
1137                    op,
1138                    BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Gt | BinOp::Le | BinOp::Ge
1139                ) {
1140                    "bool".to_string()
1141                } else if matches!(op, BinOp::Concat) {
1142                    "String".to_string()
1143                } else {
1144                    "i64".to_string()
1145                }
1146            }
1147            Expr::Infer { .. } | Expr::StringInterp { .. } => "String".to_string(),
1148            Expr::Call { name, .. } if name.name == "str" => "String".to_string(),
1149            Expr::Call { name, .. } if name.name == "len" => "i64".to_string(),
1150            _ => "i64".to_string(),
1151        }
1152    }
1153}
1154
// Unit tests for the code generator. Each test feeds a small Sage program
// through the lexer and parser, then checks the emitted Rust for expected
// substrings. The generated code itself is not compiled or run here.
#[cfg(test)]
mod tests {
    use super::*;
    use sage_lexer::lex;
    use sage_parser::parse;
    use std::sync::Arc;

    /// Lexes and parses `source`, failing the test on any lex or parse error,
    /// and returns the generated `main.rs` text for a project named `"test"`.
    fn generate_source(source: &str) -> String {
        let lex_result = lex(source).expect("lexing failed");
        let source_arc: Arc<str> = Arc::from(source);
        let (program, errors) = parse(lex_result.tokens(), source_arc);
        assert!(errors.is_empty(), "parse errors: {errors:?}");
        let program = program.expect("should parse");
        generate(&program, "test").main_rs
    }

    #[test]
    fn generate_minimal_program() {
        let source = r#"
            agent Main {
                on start {
                    emit(42);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("struct Main;"));
        assert!(output.contains("async fn on_start"));
        assert!(output.contains("ctx.emit(42_i64)"));
        assert!(output.contains("#[tokio::main]"));
    }

    #[test]
    fn generate_function() {
        let source = r#"
            fn add(a: Int, b: Int) -> Int {
                return a + b;
            }
            agent Main {
                on start {
                    emit(add(1, 2));
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("fn add(a: i64, b: i64) -> i64"));
        assert!(output.contains("return (a + b);"));
    }

    #[test]
    fn generate_agent_with_beliefs() {
        let source = r#"
            agent Worker {
                value: Int

                on start {
                    emit(self.value * 2);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("struct Worker {"));
        assert!(output.contains("value: i64,"));
        assert!(output.contains("self.value"));
    }

    #[test]
    fn generate_string_interpolation() {
        let source = r#"
            agent Main {
                on start {
                    let name = "World";
                    let msg = "Hello, {name}!";
                    print(msg);
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("format!(\"Hello, {}!\", name)"));
    }

    #[test]
    fn generate_control_flow() {
        let source = r#"
            agent Main {
                on start {
                    let x = 10;
                    if x > 5 {
                        emit(1);
                    } else {
                        emit(0);
                    }
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("if (x > 5_i64)"), "output:\n{output}");
        // else is on the same line after close brace
        assert!(output.contains("else"), "output:\n{output}");
    }

    #[test]
    fn generate_loops() {
        let source = r#"
            agent Main {
                on start {
                    for x in [1, 2, 3] {
                        print(str(x));
                    }
                    let n = 0;
                    while n < 5 {
                        n = n + 1;
                    }
                    emit(n);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("for x in vec![1_i64, 2_i64, 3_i64]"));
        assert!(output.contains("while (n < 5_i64)"));
    }

    #[test]
    fn generate_pub_function() {
        let source = r#"
            pub fn helper(x: Int) -> Int {
                return x * 2;
            }
            agent Main {
                on start {
                    emit(helper(21));
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("pub fn helper(x: i64) -> i64"));
    }

    #[test]
    fn generate_pub_agent() {
        let source = r#"
            pub agent Worker {
                on start {
                    emit(42);
                }
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("pub struct Worker;"));
    }

    // Writes the program to a temp file so `sage_loader` builds a real
    // `ModuleTree`, then checks the multi-file entry point.
    #[test]
    fn generate_module_tree_simple() {
        use sage_loader::load_single_file;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test.sg");
        fs::write(
            &file,
            r#"
agent Main {
    on start {
        emit(42);
    }
}
run Main;
"#,
        )
        .unwrap();

        let tree = load_single_file(&file).unwrap();
        let project = generate_module_tree(&tree, "test");

        assert!(project.main_rs.contains("struct Main;"));
        assert!(project.main_rs.contains("async fn on_start"));
        assert!(project.main_rs.contains("#[tokio::main]"));
    }

    #[test]
    fn generate_record_declaration() {
        let source = r#"
            record Point {
                x: Int,
                y: Int,
            }
            agent Main {
                on start {
                    let p = Point { x: 10, y: 20 };
                    emit(p.x);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("#[derive(Debug, Clone)]"));
        assert!(output.contains("struct Point {"));
        assert!(output.contains("x: i64,"));
        assert!(output.contains("y: i64,"));
        assert!(output.contains("Point { x: 10_i64, y: 20_i64 }"));
        assert!(output.contains("p.x"));
    }

    #[test]
    fn generate_enum_declaration() {
        let source = r#"
            enum Status {
                Active,
                Inactive,
                Pending,
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("#[derive(Debug, Clone, Copy, PartialEq, Eq)]"));
        assert!(output.contains("enum Status {"));
        assert!(output.contains("Active,"));
        assert!(output.contains("Inactive,"));
        assert!(output.contains("Pending,"));
    }

    #[test]
    fn generate_const_declaration() {
        let source = r#"
            const MAX_SIZE: Int = 100;
            const GREETING: String = "Hello";
            agent Main {
                on start {
                    emit(MAX_SIZE);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("const MAX_SIZE: i64 = 100_i64;"));
        assert!(output.contains("const GREETING: String = \"Hello\".to_string();"));
    }

    #[test]
    fn generate_match_expression() {
        let source = r#"
            enum Status {
                Active,
                Inactive,
            }
            fn check_status(s: Status) -> Int {
                return match s {
                    Active => 1,
                    Inactive => 0,
                };
            }
            agent Main {
                on start {
                    emit(0);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        assert!(output.contains("match s {"));
        assert!(output.contains("Active => 1_i64,"));
        assert!(output.contains("Inactive => 0_i64,"));
    }

    // =========================================================================
    // RFC-0007: Error handling codegen tests
    // =========================================================================

    #[test]
    fn generate_fallible_function() {
        let source = r#"
            fn get_data(url: String) -> String fails {
                return url;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // Fallible function should return SageResult<T>
        assert!(output.contains("fn get_data(url: String) -> SageResult<String>"));
    }

    #[test]
    fn generate_try_expression() {
        let source = r#"
            fn fallible() -> Int fails { return 42; }
            fn caller() -> Int fails {
                let x = try fallible();
                return x;
            }
            agent Main {
                on start { emit(0); }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // try should generate ? operator
        assert!(output.contains("fallible()?"));
    }

    #[test]
    fn generate_catch_expression() {
        let source = r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch { 0 };
                    emit(x);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // catch should generate match expression
        assert!(output.contains("match fallible()"));
        assert!(output.contains("Ok(__val) => __val"));
        assert!(output.contains("Err(_) => 0_i64"));
    }

    #[test]
    fn generate_catch_with_binding() {
        let source = r#"
            fn fallible() -> Int fails { return 42; }
            agent Main {
                on start {
                    let x = fallible() catch(e) { 0 };
                    emit(x);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // catch with binding should capture the error
        assert!(output.contains("Err(e) => 0_i64"));
    }

    #[test]
    fn generate_on_error_handler() {
        let source = r#"
            agent Main {
                on start {
                    emit(0);
                }
                on error(e) {
                    emit(1);
                }
            }
            run Main;
        "#;

        let output = generate_source(source);
        // Should generate on_error method
        assert!(output.contains("async fn on_error(self, e: SageError"));
        // Main should dispatch to on_error on failure
        assert!(output.contains(".on_error(e, ctx)"));
    }
}