Skip to main content

verificar/generator/
c_enum.rs

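//! C exhaustive enumeration
//!
//! Generates C programs up to a specified AST depth from a simplified C
//! grammar. Each level of the combinatorial generation is capped at a fixed
//! number of candidates, so the result is a bounded sample of the grammar
//! rather than every possible program.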
5
6use super::GeneratedCode;
7use crate::Language;
8
9/// C AST node types for generation
10#[derive(Debug, Clone, PartialEq)]
11#[allow(missing_docs)]
12pub enum CNode {
13    /// Translation unit (root node)
14    TranslationUnit(Vec<CNode>),
15    /// Function definition
16    FuncDef {
17        /// Return type
18        return_type: CType,
19        /// Function name
20        name: String,
21        /// Parameters
22        params: Vec<(CType, String)>,
23        /// Function body
24        body: Vec<CNode>,
25    },
26    /// Variable declaration
27    VarDecl {
28        /// Variable type
29        var_type: CType,
30        /// Variable name
31        name: String,
32        /// Optional initializer
33        init: Option<Box<CNode>>,
34    },
35    /// Assignment: `lhs = rhs`
36    Assign {
37        /// Left-hand side
38        lhs: Box<CNode>,
39        /// Right-hand side
40        rhs: Box<CNode>,
41    },
42    /// Binary operation
43    BinOp {
44        /// Left operand
45        left: Box<CNode>,
46        /// Operator
47        op: CBinaryOp,
48        /// Right operand
49        right: Box<CNode>,
50    },
51    /// Unary operation
52    UnaryOp {
53        /// Operator
54        op: CUnaryOp,
55        /// Operand
56        operand: Box<CNode>,
57    },
58    /// Integer literal
59    IntLit(i64),
60    /// Float literal
61    FloatLit(f64),
62    /// Character literal
63    CharLit(char),
64    /// String literal
65    StrLit(String),
66    /// Variable reference
67    Ident(String),
68    /// If statement
69    If {
70        /// Condition
71        cond: Box<CNode>,
72        /// Then body
73        then_body: Vec<CNode>,
74        /// Else body
75        else_body: Vec<CNode>,
76    },
77    /// While loop
78    While {
79        /// Condition
80        cond: Box<CNode>,
81        /// Loop body
82        body: Vec<CNode>,
83    },
84    /// For loop
85    For {
86        /// Initialization
87        init: Option<Box<CNode>>,
88        /// Condition
89        cond: Option<Box<CNode>>,
90        /// Increment
91        incr: Option<Box<CNode>>,
92        /// Loop body
93        body: Vec<CNode>,
94    },
95    /// Return statement
96    Return(Option<Box<CNode>>),
97    /// Break statement
98    Break,
99    /// Continue statement
100    Continue,
101    /// Function call
102    Call {
103        /// Function name
104        func: String,
105        /// Arguments
106        args: Vec<CNode>,
107    },
108    /// Array access
109    ArrayAccess {
110        /// Array expression
111        array: Box<CNode>,
112        /// Index expression
113        index: Box<CNode>,
114    },
115    /// Comparison operation
116    Compare {
117        /// Left operand
118        left: Box<CNode>,
119        /// Operator
120        op: CCompareOp,
121        /// Right operand
122        right: Box<CNode>,
123    },
124    /// Expression statement
125    ExprStmt(Box<CNode>),
126    /// Compound statement (block)
127    Block(Vec<CNode>),
128    /// Ternary operator: `cond ? then : else`
129    Ternary {
130        /// Condition
131        cond: Box<CNode>,
132        /// Then expression
133        then_expr: Box<CNode>,
134        /// Else expression
135        else_expr: Box<CNode>,
136    },
137    /// Sizeof expression
138    Sizeof(CType),
139    /// Cast expression
140    Cast {
141        /// Target type
142        target_type: CType,
143        /// Expression
144        expr: Box<CNode>,
145    },
146    /// Pointer dereference
147    Deref(Box<CNode>),
148    /// Address-of
149    AddrOf(Box<CNode>),
150    /// Struct access: `expr.field`
151    StructAccess {
152        /// Expression
153        expr: Box<CNode>,
154        /// Field name
155        field: String,
156    },
157    /// Pointer struct access: `expr->field`
158    PtrAccess {
159        /// Expression
160        expr: Box<CNode>,
161        /// Field name
162        field: String,
163    },
164    /// Increment: `++x` or `x++`
165    Increment {
166        /// Operand
167        operand: Box<CNode>,
168        /// Pre-increment (++x) or post-increment (x++)
169        pre: bool,
170    },
171    /// Decrement: `--x` or `x--`
172    Decrement {
173        /// Operand
174        operand: Box<CNode>,
175        /// Pre-decrement (--x) or post-decrement (x--)
176        pre: bool,
177    },
178}
179
/// Scalar and pointer types the generator can emit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CType {
    /// `void`
    Void,
    /// `int`
    Int,
    /// `char`
    Char,
    /// `float`
    Float,
    /// `double`
    Double,
    /// `long`
    Long,
    /// `unsigned int`
    UInt,
    /// `int*`
    IntPtr,
    /// `char*`
    CharPtr,
    /// `void*`
    VoidPtr,
}

impl CType {
    /// The basic types used for variable declarations, in the order
    /// `int`, `char`, `float`, `double`, `long`.
    #[must_use]
    pub fn all_basic() -> &'static [Self] {
        static BASIC: [CType; 5] = [
            CType::Int,
            CType::Char,
            CType::Float,
            CType::Double,
            CType::Long,
        ];
        &BASIC
    }

    /// C spelling of this type, e.g. `"unsigned int"` or `"char*"`.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            CType::Int => "int",
            CType::UInt => "unsigned int",
            CType::Long => "long",
            CType::Char => "char",
            CType::Float => "float",
            CType::Double => "double",
            CType::Void => "void",
            CType::IntPtr => "int*",
            CType::CharPtr => "char*",
            CType::VoidPtr => "void*",
        }
    }
}
229
/// Binary operators: arithmetic, bitwise and logical.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CBinaryOp {
    /// Addition, `+`
    Add,
    /// Subtraction, `-`
    Sub,
    /// Multiplication, `*`
    Mul,
    /// Division, `/`
    Div,
    /// Remainder, `%`
    Mod,
    /// Bitwise AND, `&`
    BitAnd,
    /// Bitwise OR, `|`
    BitOr,
    /// Bitwise XOR, `^`
    BitXor,
    /// Left shift, `<<`
    Shl,
    /// Right shift, `>>`
    Shr,
    /// Logical AND, `&&`
    LogAnd,
    /// Logical OR, `||`
    LogOr,
}

impl CBinaryOp {
    /// The arithmetic operators `+ - * / %`, in that order.
    #[must_use]
    pub fn arithmetic() -> &'static [Self] {
        static ARITHMETIC: [CBinaryOp; 5] = [
            CBinaryOp::Add,
            CBinaryOp::Sub,
            CBinaryOp::Mul,
            CBinaryOp::Div,
            CBinaryOp::Mod,
        ];
        &ARITHMETIC
    }

    /// The bitwise operators `& | ^ << >>`, in that order.
    #[must_use]
    pub fn bitwise() -> &'static [Self] {
        static BITWISE: [CBinaryOp; 5] = [
            CBinaryOp::BitAnd,
            CBinaryOp::BitOr,
            CBinaryOp::BitXor,
            CBinaryOp::Shl,
            CBinaryOp::Shr,
        ];
        &BITWISE
    }

    /// C token for this operator.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            CBinaryOp::LogAnd => "&&",
            CBinaryOp::LogOr => "||",
            CBinaryOp::Shl => "<<",
            CBinaryOp::Shr => ">>",
            CBinaryOp::BitAnd => "&",
            CBinaryOp::BitOr => "|",
            CBinaryOp::BitXor => "^",
            CBinaryOp::Add => "+",
            CBinaryOp::Sub => "-",
            CBinaryOp::Mul => "*",
            CBinaryOp::Div => "/",
            CBinaryOp::Mod => "%",
        }
    }
}
297
/// Prefix unary operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CUnaryOp {
    /// Arithmetic negation, `-`
    Neg,
    /// Logical NOT, `!`
    Not,
    /// Bitwise complement, `~`
    BitNot,
}

impl CUnaryOp {
    /// Every unary operator, in the order `-`, `!`, `~`.
    #[must_use]
    pub fn all() -> &'static [Self] {
        static ALL: [CUnaryOp; 3] = [CUnaryOp::Neg, CUnaryOp::Not, CUnaryOp::BitNot];
        &ALL
    }

    /// C token for this operator.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            CUnaryOp::BitNot => "~",
            CUnaryOp::Not => "!",
            CUnaryOp::Neg => "-",
        }
    }
}
326
/// Relational and equality operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CCompareOp {
    /// Equality, `==`
    Eq,
    /// Inequality, `!=`
    Ne,
    /// Less than, `<`
    Lt,
    /// Greater than, `>`
    Gt,
    /// Less than or equal, `<=`
    Le,
    /// Greater than or equal, `>=`
    Ge,
}

impl CCompareOp {
    /// Every comparison operator, in the order `== != < > <= >=`.
    #[must_use]
    pub fn all() -> &'static [Self] {
        static ALL: [CCompareOp; 6] = [
            CCompareOp::Eq,
            CCompareOp::Ne,
            CCompareOp::Lt,
            CCompareOp::Gt,
            CCompareOp::Le,
            CCompareOp::Ge,
        ];
        &ALL
    }

    /// C token for this operator.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            CCompareOp::Le => "<=",
            CCompareOp::Ge => ">=",
            CCompareOp::Lt => "<",
            CCompareOp::Gt => ">",
            CCompareOp::Eq => "==",
            CCompareOp::Ne => "!=",
        }
    }
}
364
365impl CNode {
366    /// Convert AST node to C code string
367    #[must_use]
368    #[allow(clippy::too_many_lines)]
369    pub fn to_code(&self, indent: usize) -> String {
370        let indent_str = "    ".repeat(indent);
371        match self {
372            Self::TranslationUnit(items) => items
373                .iter()
374                .map(|item| item.to_code(0))
375                .collect::<Vec<_>>()
376                .join("\n\n"),
377            Self::FuncDef {
378                return_type,
379                name,
380                params,
381                body,
382            } => {
383                let params_str = if params.is_empty() {
384                    "void".to_string()
385                } else {
386                    params
387                        .iter()
388                        .map(|(t, n)| format!("{} {}", t.to_str(), n))
389                        .collect::<Vec<_>>()
390                        .join(", ")
391                };
392                let body_str = body
393                    .iter()
394                    .map(|s| s.to_code(indent + 1))
395                    .collect::<Vec<_>>()
396                    .join("\n");
397                format!(
398                    "{}{} {}({}) {{\n{}\n{}}}",
399                    indent_str,
400                    return_type.to_str(),
401                    name,
402                    params_str,
403                    body_str,
404                    indent_str
405                )
406            }
407            Self::VarDecl {
408                var_type,
409                name,
410                init,
411            } => {
412                if let Some(init_expr) = init {
413                    format!(
414                        "{}{} {} = {};",
415                        indent_str,
416                        var_type.to_str(),
417                        name,
418                        init_expr.to_code(0)
419                    )
420                } else {
421                    format!("{}{} {};", indent_str, var_type.to_str(), name)
422                }
423            }
424            Self::Assign { lhs, rhs } => {
425                format!("{}{} = {};", indent_str, lhs.to_code(0), rhs.to_code(0))
426            }
427            Self::BinOp { left, op, right } => {
428                format!("({} {} {})", left.to_code(0), op.to_str(), right.to_code(0))
429            }
430            Self::UnaryOp { op, operand } => {
431                format!("({}{})", op.to_str(), operand.to_code(0))
432            }
433            Self::IntLit(n) => n.to_string(),
434            Self::FloatLit(f) => format!("{f:.1}"),
435            Self::CharLit(c) => format!("'{c}'"),
436            Self::StrLit(s) => format!("\"{s}\""),
437            Self::Ident(name) => name.clone(),
438            Self::If {
439                cond,
440                then_body,
441                else_body,
442            } => {
443                let then_str = then_body
444                    .iter()
445                    .map(|s| s.to_code(indent + 1))
446                    .collect::<Vec<_>>()
447                    .join("\n");
448                if else_body.is_empty() {
449                    format!(
450                        "{}if ({}) {{\n{}\n{}}}",
451                        indent_str,
452                        cond.to_code(0),
453                        then_str,
454                        indent_str
455                    )
456                } else {
457                    let else_str = else_body
458                        .iter()
459                        .map(|s| s.to_code(indent + 1))
460                        .collect::<Vec<_>>()
461                        .join("\n");
462                    format!(
463                        "{}if ({}) {{\n{}\n{}}} else {{\n{}\n{}}}",
464                        indent_str,
465                        cond.to_code(0),
466                        then_str,
467                        indent_str,
468                        else_str,
469                        indent_str
470                    )
471                }
472            }
473            Self::While { cond, body } => {
474                let body_str = body
475                    .iter()
476                    .map(|s| s.to_code(indent + 1))
477                    .collect::<Vec<_>>()
478                    .join("\n");
479                format!(
480                    "{}while ({}) {{\n{}\n{}}}",
481                    indent_str,
482                    cond.to_code(0),
483                    body_str,
484                    indent_str
485                )
486            }
487            Self::For {
488                init,
489                cond,
490                incr,
491                body,
492            } => {
493                let init_str = init.as_ref().map_or(String::new(), |i| i.to_code(0));
494                let cond_str = cond.as_ref().map_or(String::new(), |c| c.to_code(0));
495                let incr_str = incr.as_ref().map_or(String::new(), |i| i.to_code(0));
496                let body_str = body
497                    .iter()
498                    .map(|s| s.to_code(indent + 1))
499                    .collect::<Vec<_>>()
500                    .join("\n");
501                format!(
502                    "{indent_str}for ({init_str}; {cond_str}; {incr_str}) {{\n{body_str}\n{indent_str}}}"
503                )
504            }
505            Self::Return(expr) => {
506                if let Some(e) = expr {
507                    format!("{}return {};", indent_str, e.to_code(0))
508                } else {
509                    format!("{indent_str}return;")
510                }
511            }
512            Self::Break => format!("{indent_str}break;"),
513            Self::Continue => format!("{indent_str}continue;"),
514            Self::Call { func, args } => {
515                let args_str = args
516                    .iter()
517                    .map(|a| a.to_code(0))
518                    .collect::<Vec<_>>()
519                    .join(", ");
520                format!("{func}({args_str})")
521            }
522            Self::ArrayAccess { array, index } => {
523                format!("{}[{}]", array.to_code(0), index.to_code(0))
524            }
525            Self::Compare { left, op, right } => {
526                format!("({} {} {})", left.to_code(0), op.to_str(), right.to_code(0))
527            }
528            Self::ExprStmt(expr) => format!("{}{};", indent_str, expr.to_code(0)),
529            Self::Block(stmts) => {
530                let stmts_str = stmts
531                    .iter()
532                    .map(|s| s.to_code(indent + 1))
533                    .collect::<Vec<_>>()
534                    .join("\n");
535                format!("{indent_str}{{\n{stmts_str}\n{indent_str}}}")
536            }
537            Self::Ternary {
538                cond,
539                then_expr,
540                else_expr,
541            } => {
542                format!(
543                    "({} ? {} : {})",
544                    cond.to_code(0),
545                    then_expr.to_code(0),
546                    else_expr.to_code(0)
547                )
548            }
549            Self::Sizeof(t) => format!("sizeof({})", t.to_str()),
550            Self::Cast { target_type, expr } => {
551                format!("(({}){})", target_type.to_str(), expr.to_code(0))
552            }
553            Self::Deref(expr) => format!("(*{})", expr.to_code(0)),
554            Self::AddrOf(expr) => format!("(&{})", expr.to_code(0)),
555            Self::StructAccess { expr, field } => format!("{}.{}", expr.to_code(0), field),
556            Self::PtrAccess { expr, field } => format!("{}->{}", expr.to_code(0), field),
557            Self::Increment { operand, pre } => {
558                if *pre {
559                    format!("++{}", operand.to_code(0))
560                } else {
561                    format!("{}++", operand.to_code(0))
562                }
563            }
564            Self::Decrement { operand, pre } => {
565                if *pre {
566                    format!("--{}", operand.to_code(0))
567                } else {
568                    format!("{}--", operand.to_code(0))
569                }
570            }
571        }
572    }
573
574    /// Calculate AST depth
575    #[must_use]
576    pub fn depth(&self) -> usize {
577        match self {
578            Self::TranslationUnit(items) => 1 + items.iter().map(Self::depth).max().unwrap_or(0),
579            Self::FuncDef { body, .. } => 1 + body.iter().map(Self::depth).max().unwrap_or(0),
580            Self::VarDecl { init, .. } => 1 + init.as_ref().map_or(0, |i| i.depth()),
581            Self::Assign { lhs, rhs } => 1 + lhs.depth().max(rhs.depth()),
582            Self::BinOp { left, right, .. } | Self::Compare { left, right, .. } => {
583                1 + left.depth().max(right.depth())
584            }
585            Self::UnaryOp { operand, .. } => 1 + operand.depth(),
586            Self::If {
587                cond,
588                then_body,
589                else_body,
590            } => {
591                let then_depth = then_body.iter().map(Self::depth).max().unwrap_or(0);
592                let else_depth = else_body.iter().map(Self::depth).max().unwrap_or(0);
593                1 + cond.depth().max(then_depth).max(else_depth)
594            }
595            Self::While { cond, body }
596            | Self::For {
597                cond: Some(cond),
598                body,
599                ..
600            } => {
601                let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
602                1 + cond.depth().max(body_depth)
603            }
604            Self::For { body, .. } => 1 + body.iter().map(Self::depth).max().unwrap_or(0),
605            Self::Return(Some(e)) => 1 + e.depth(),
606            Self::Call { args, .. } => 1 + args.iter().map(Self::depth).max().unwrap_or(0),
607            Self::ArrayAccess { array, index } => 1 + array.depth().max(index.depth()),
608            Self::ExprStmt(e) => 1 + e.depth(),
609            Self::Block(stmts) => 1 + stmts.iter().map(Self::depth).max().unwrap_or(0),
610            Self::Ternary {
611                cond,
612                then_expr,
613                else_expr,
614            } => 1 + cond.depth().max(then_expr.depth()).max(else_expr.depth()),
615            Self::Cast { expr, .. } | Self::Deref(expr) | Self::AddrOf(expr) => 1 + expr.depth(),
616            Self::StructAccess { expr, .. } | Self::PtrAccess { expr, .. } => 1 + expr.depth(),
617            Self::Increment { operand, .. } | Self::Decrement { operand, .. } => {
618                1 + operand.depth()
619            }
620            // Terminal nodes
621            Self::Return(None)
622            | Self::Break
623            | Self::Continue
624            | Self::IntLit(_)
625            | Self::FloatLit(_)
626            | Self::CharLit(_)
627            | Self::StrLit(_)
628            | Self::Ident(_)
629            | Self::Sizeof(_) => 1,
630        }
631    }
632}
633
634/// Exhaustive C program enumerator
635#[derive(Debug)]
636pub struct CEnumerator {
637    max_depth: usize,
638    var_names: Vec<String>,
639    int_values: Vec<i64>,
640}
641
642impl Default for CEnumerator {
643    fn default() -> Self {
644        Self::new(3)
645    }
646}
647
648impl CEnumerator {
649    /// Create a new enumerator with specified max depth
650    #[must_use]
651    pub fn new(max_depth: usize) -> Self {
652        Self {
653            max_depth,
654            var_names: vec!["x".to_string(), "y".to_string(), "n".to_string()],
655            int_values: vec![0, 1, -1, 2, 10, 42],
656        }
657    }
658
659    /// Enumerate all expressions up to the given depth
660    pub fn enumerate_expressions(&self, depth: usize) -> Vec<CNode> {
661        if depth == 0 {
662            return vec![];
663        }
664
665        let mut results = Vec::new();
666
667        // Depth 1: literals and names
668        for val in &self.int_values {
669            results.push(CNode::IntLit(*val));
670        }
671        for name in &self.var_names {
672            results.push(CNode::Ident(name.clone()));
673        }
674
675        if depth >= 2 {
676            // Binary operations
677            let sub_exprs = self.enumerate_expressions(depth - 1);
678            let limited: Vec<_> = sub_exprs.iter().take(15).collect();
679
680            for left in &limited {
681                for right in limited.iter().take(10) {
682                    for op in CBinaryOp::arithmetic() {
683                        results.push(CNode::BinOp {
684                            left: Box::new((*left).clone()),
685                            op: *op,
686                            right: Box::new((*right).clone()),
687                        });
688                    }
689                }
690            }
691
692            // Comparisons
693            for left in &limited {
694                for right in limited.iter().take(10) {
695                    for op in CCompareOp::all() {
696                        results.push(CNode::Compare {
697                            left: Box::new((*left).clone()),
698                            op: *op,
699                            right: Box::new((*right).clone()),
700                        });
701                    }
702                }
703            }
704
705            // Unary operations
706            for operand in limited.iter().take(8) {
707                for op in CUnaryOp::all() {
708                    results.push(CNode::UnaryOp {
709                        op: *op,
710                        operand: Box::new((*operand).clone()),
711                    });
712                }
713            }
714
715            // Increment/Decrement
716            for name in &self.var_names {
717                results.push(CNode::Increment {
718                    operand: Box::new(CNode::Ident(name.clone())),
719                    pre: true,
720                });
721                results.push(CNode::Increment {
722                    operand: Box::new(CNode::Ident(name.clone())),
723                    pre: false,
724                });
725            }
726        }
727
728        results
729    }
730
731    /// Enumerate all statements up to the given depth
732    pub fn enumerate_statements(&self, depth: usize) -> Vec<CNode> {
733        if depth == 0 {
734            return vec![];
735        }
736
737        let mut results = Vec::new();
738
739        let exprs = self.enumerate_expressions(depth - 1);
740        let limited_exprs: Vec<_> = exprs.iter().take(20).collect();
741
742        // Variable declarations
743        for var_type in CType::all_basic() {
744            for name in &self.var_names {
745                results.push(CNode::VarDecl {
746                    var_type: *var_type,
747                    name: name.clone(),
748                    init: None,
749                });
750                // With initialization
751                for val in self.int_values.iter().take(3) {
752                    results.push(CNode::VarDecl {
753                        var_type: *var_type,
754                        name: name.clone(),
755                        init: Some(Box::new(CNode::IntLit(*val))),
756                    });
757                }
758            }
759        }
760
761        // Assignments
762        for name in &self.var_names {
763            for value in &limited_exprs {
764                results.push(CNode::Assign {
765                    lhs: Box::new(CNode::Ident(name.clone())),
766                    rhs: Box::new((*value).clone()),
767                });
768            }
769        }
770
771        // Return statements
772        results.push(CNode::Return(None));
773        for expr in limited_exprs.iter().take(10) {
774            results.push(CNode::Return(Some(Box::new((*expr).clone()))));
775        }
776
777        // Break and continue
778        results.push(CNode::Break);
779        results.push(CNode::Continue);
780
781        if depth >= 2 {
782            // If statements
783            let conditions: Vec<_> = exprs
784                .iter()
785                .filter(|e| matches!(e, CNode::Compare { .. } | CNode::Ident(_)))
786                .take(5)
787                .collect();
788
789            let body_stmts = self.enumerate_statements(depth - 1);
790            let limited_body: Vec<_> = body_stmts.iter().take(5).collect();
791
792            for cond in &conditions {
793                for body in &limited_body {
794                    results.push(CNode::If {
795                        cond: Box::new((*cond).clone()),
796                        then_body: vec![(*body).clone()],
797                        else_body: vec![],
798                    });
799                }
800            }
801
802            // While loops
803            for cond in &conditions {
804                results.push(CNode::While {
805                    cond: Box::new((*cond).clone()),
806                    body: vec![CNode::Break],
807                });
808            }
809
810            // For loops
811            for name in self.var_names.iter().take(2) {
812                results.push(CNode::For {
813                    init: Some(Box::new(CNode::VarDecl {
814                        var_type: CType::Int,
815                        name: name.clone(),
816                        init: Some(Box::new(CNode::IntLit(0))),
817                    })),
818                    cond: Some(Box::new(CNode::Compare {
819                        left: Box::new(CNode::Ident(name.clone())),
820                        op: CCompareOp::Lt,
821                        right: Box::new(CNode::IntLit(10)),
822                    })),
823                    incr: Some(Box::new(CNode::Increment {
824                        operand: Box::new(CNode::Ident(name.clone())),
825                        pre: false,
826                    })),
827                    body: vec![CNode::Break],
828                });
829            }
830        }
831
832        results
833    }
834
835    /// Enumerate complete programs
836    #[must_use]
837    pub fn enumerate_programs(&self) -> Vec<GeneratedCode> {
838        let mut results = Vec::new();
839
840        let stmts = self.enumerate_statements(self.max_depth);
841
842        // Simple main functions with single statement
843        for stmt in stmts.iter().take(50) {
844            let func = CNode::FuncDef {
845                return_type: CType::Int,
846                name: "main".to_string(),
847                params: vec![],
848                body: vec![
849                    stmt.clone(),
850                    CNode::Return(Some(Box::new(CNode::IntLit(0)))),
851                ],
852            };
853            let unit = CNode::TranslationUnit(vec![func]);
854            let code = unit.to_code(0);
855            results.push(GeneratedCode {
856                code,
857                language: Language::C,
858                ast_depth: stmt.depth() + 2,
859                features: self.extract_features(stmt),
860            });
861        }
862
863        // Functions with parameters
864        for stmt in stmts.iter().take(20) {
865            let func = CNode::FuncDef {
866                return_type: CType::Int,
867                name: "compute".to_string(),
868                params: vec![(CType::Int, "a".to_string()), (CType::Int, "b".to_string())],
869                body: vec![stmt.clone()],
870            };
871            let unit = CNode::TranslationUnit(vec![func]);
872            let code = unit.to_code(0);
873            results.push(GeneratedCode {
874                code,
875                language: Language::C,
876                ast_depth: stmt.depth() + 2,
877                features: self.extract_features(stmt),
878            });
879        }
880
881        results
882    }
883
884    /// Extract feature labels from an AST node
885    fn extract_features(&self, node: &CNode) -> Vec<String> {
886        let mut features = Vec::new();
887
888        match node {
889            CNode::VarDecl { .. } => features.push("var_decl".to_string()),
890            CNode::Assign { .. } => features.push("assignment".to_string()),
891            CNode::BinOp { op, .. } => {
892                features.push("binop".to_string());
893                features.push(format!("op_{}", op.to_str()));
894            }
895            CNode::If { else_body, .. } => {
896                features.push("if".to_string());
897                if !else_body.is_empty() {
898                    features.push("else".to_string());
899                }
900            }
901            CNode::While { .. } => features.push("while".to_string()),
902            CNode::For { .. } => features.push("for".to_string()),
903            CNode::Return(_) => features.push("return".to_string()),
904            CNode::Compare { op, .. } => {
905                features.push("compare".to_string());
906                features.push(format!("cmp_{}", op.to_str()));
907            }
908            CNode::Increment { .. } => features.push("increment".to_string()),
909            CNode::Decrement { .. } => features.push("decrement".to_string()),
910            _ => {}
911        }
912
913        features
914    }
915}
916
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an identifier node.
    fn ident(name: &str) -> CNode {
        CNode::Ident(name.to_string())
    }

    /// Build a boxed integer literal.
    fn lit(value: i64) -> Box<CNode> {
        Box::new(CNode::IntLit(value))
    }

    #[test]
    fn test_int_lit_to_code() {
        assert_eq!(CNode::IntLit(42).to_code(0), "42");
    }

    #[test]
    fn test_var_decl_to_code() {
        let decl = CNode::VarDecl {
            var_type: CType::Int,
            name: "x".into(),
            init: Some(lit(0)),
        };
        assert_eq!(decl.to_code(0), "int x = 0;");
    }

    #[test]
    fn test_binop_to_code() {
        let sum = CNode::BinOp {
            left: Box::new(ident("x")),
            op: CBinaryOp::Add,
            right: lit(1),
        };
        assert_eq!(sum.to_code(0), "(x + 1)");
    }

    #[test]
    fn test_func_def_to_code() {
        let main_fn = CNode::FuncDef {
            return_type: CType::Int,
            name: "main".into(),
            params: Vec::new(),
            body: vec![CNode::Return(Some(lit(0)))],
        };
        let rendered = main_fn.to_code(0);
        for needle in ["int main(void)", "return 0;"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_if_to_code() {
        let branch = CNode::If {
            cond: Box::new(ident("x")),
            then_body: vec![CNode::Return(Some(lit(1)))],
            else_body: Vec::new(),
        };
        let rendered = branch.to_code(0);
        for needle in ["if (x)", "return 1;"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_for_to_code() {
        let counter = || Box::new(ident("i"));
        let loop_node = CNode::For {
            init: Some(Box::new(CNode::VarDecl {
                var_type: CType::Int,
                name: "i".into(),
                init: Some(lit(0)),
            })),
            cond: Some(Box::new(CNode::Compare {
                left: counter(),
                op: CCompareOp::Lt,
                right: lit(10),
            })),
            incr: Some(Box::new(CNode::Increment {
                operand: counter(),
                pre: false,
            })),
            body: vec![CNode::Break],
        };
        let rendered = loop_node.to_code(0);
        for needle in ["for (int i = 0;", "(i < 10)", "i++"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_enumerator_creates_programs() {
        let programs = CEnumerator::new(2).enumerate_programs();
        assert!(!programs.is_empty(), "Should generate programs");
    }

    #[test]
    fn test_programs_are_c() {
        for prog in CEnumerator::new(2).enumerate_programs() {
            assert_eq!(prog.language, Language::C);
        }
    }

    #[test]
    fn test_depth_calculation() {
        let product = CNode::BinOp {
            left: lit(2),
            op: CBinaryOp::Mul,
            right: lit(3),
        };
        let nested = CNode::BinOp {
            left: lit(1),
            op: CBinaryOp::Add,
            right: Box::new(product),
        };
        assert_eq!(nested.depth(), 3);
    }

    #[test]
    fn test_compare_to_code() {
        let cmp = CNode::Compare {
            left: Box::new(ident("x")),
            op: CCompareOp::Lt,
            right: lit(10),
        };
        assert_eq!(cmp.to_code(0), "(x < 10)");
    }

    #[test]
    fn test_increment_to_code() {
        for (pre, expected) in [(true, "++x"), (false, "x++")] {
            let node = CNode::Increment {
                operand: Box::new(ident("x")),
                pre,
            };
            assert_eq!(node.to_code(0), expected);
        }
    }

    #[test]
    fn test_c_type_to_str() {
        for (ty, expected) in [
            (CType::Int, "int"),
            (CType::Void, "void"),
            (CType::IntPtr, "int*"),
        ] {
            assert_eq!(ty.to_str(), expected);
        }
    }

    #[test]
    fn test_extract_features() {
        let branch = CNode::If {
            cond: Box::new(ident("x")),
            then_body: vec![CNode::Break],
            else_body: vec![CNode::Continue],
        };
        let features = CEnumerator::new(2).extract_features(&branch);
        for label in ["if", "else"] {
            assert!(features.iter().any(|f| f == label));
        }
    }
}