Skip to main content

verificar/generator/
python_enum.rs

1//! Python exhaustive enumeration
2//!
3//! Generates all valid Python programs up to a specified AST depth.
4//! Uses a simplified Python grammar for combinatorial generation.
5
6use super::GeneratedCode;
7use crate::Language;
8
/// Python AST node types for generation.
///
/// A deliberately small subset of Python's grammar. Statement variants (`Assign`, `If`,
/// `While`, `For`, `FuncDef`, `Return`, `Pass`, `Break`, `Continue`) and expression
/// variants share this one type; `to_code` renders a node as source text and `depth`
/// measures its nesting. Only `PartialEq` (not `Eq`) is derived because `FloatLit` holds
/// an `f64`.
#[derive(Debug, Clone, PartialEq)]
// NOTE(review): every variant and field below is documented, so this allow looks
// redundant — consider removing it.
#[allow(missing_docs)]
pub enum PythonNode {
    /// Module (root node): the sequence of top-level statements of a program
    Module(Vec<PythonNode>),
    /// Assignment statement: `name = expr`
    Assign {
        /// Variable name being assigned to (a plain identifier, not an attribute or subscript)
        target: String,
        /// Expression value
        value: Box<PythonNode>,
    },
    /// Binary operation: `left op right`
    BinOp {
        /// Left operand
        left: Box<PythonNode>,
        /// Binary operator
        op: BinaryOp,
        /// Right operand
        right: Box<PythonNode>,
    },
    /// Unary operation: `op operand`
    UnaryOp {
        /// Unary operator
        op: UnaryOp,
        /// Operand expression
        operand: Box<PythonNode>,
    },
    /// Integer literal
    IntLit(i64),
    /// Float literal
    FloatLit(f64),
    /// String literal; holds the string's contents, without surrounding quotes
    StrLit(String),
    /// Boolean literal (`True` / `False`)
    BoolLit(bool),
    /// None literal
    NoneLit,
    /// Variable reference
    Name(String),
    /// If statement
    If {
        /// Condition expression
        test: Box<PythonNode>,
        /// Statements executed when the condition holds
        body: Vec<PythonNode>,
        /// Else-branch statements; an empty vector means there is no `else`
        orelse: Vec<PythonNode>,
    },
    /// While loop
    While {
        /// Loop condition
        test: Box<PythonNode>,
        /// Loop body statements
        body: Vec<PythonNode>,
    },
    /// For loop: `for target in iter:`
    For {
        /// Loop variable name
        target: String,
        /// Iterable expression
        iter: Box<PythonNode>,
        /// Loop body statements
        body: Vec<PythonNode>,
    },
    /// Function definition: `def name(args):`
    FuncDef {
        /// Function name
        name: String,
        /// Positional parameter names (no defaults or annotations)
        args: Vec<String>,
        /// Function body statements
        body: Vec<PythonNode>,
    },
    /// Function call on a plain function name: `func(args)`
    Call {
        /// Function name
        func: String,
        /// Positional argument expressions
        args: Vec<PythonNode>,
    },
    /// Return statement, with an optional value
    Return(Option<Box<PythonNode>>),
    /// Pass statement
    Pass,
    /// Break statement
    Break,
    /// Continue statement
    Continue,
    /// List literal
    List(Vec<PythonNode>),
    /// Comparison: `left op right`
    Compare {
        /// Left operand
        left: Box<PythonNode>,
        /// Comparison operator
        op: CompareOp,
        /// Right operand
        right: Box<PythonNode>,
    },
}
111
/// Binary operators usable in a `PythonNode::BinOp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition (`+`)
    Add,
    /// Subtraction (`-`)
    Sub,
    /// Multiplication (`*`)
    Mult,
    /// Division (`/`)
    Div,
    /// Modulo (`%`)
    Mod,
    /// Floor division (`//`)
    FloorDiv,
    /// Power (`**`)
    Pow,
    /// Logical and (`and`)
    And,
    /// Logical or (`or`)
    Or,
}

impl BinaryOp {
    /// The arithmetic operators, in declaration order.
    ///
    /// The logical operators `And` and `Or` are intentionally not part of this list.
    #[must_use]
    pub fn all() -> &'static [Self] {
        const ARITHMETIC: &[BinaryOp] = &[
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mult,
            BinaryOp::Div,
            BinaryOp::Mod,
            BinaryOp::FloorDiv,
            BinaryOp::Pow,
        ];
        ARITHMETIC
    }

    /// The operator exactly as it is spelled in Python source.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            BinaryOp::Add => "+",
            BinaryOp::Sub => "-",
            BinaryOp::Mult => "*",
            BinaryOp::Div => "/",
            BinaryOp::Mod => "%",
            BinaryOp::FloorDiv => "//",
            BinaryOp::Pow => "**",
            BinaryOp::And => "and",
            BinaryOp::Or => "or",
        }
    }
}
166
/// Unary (prefix) operators usable in a `PythonNode::UnaryOp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Negation (`-x`)
    Neg,
    /// Logical not (`not x`)
    Not,
    /// Positive (`+x`)
    Pos,
}

impl UnaryOp {
    /// Every unary operator, in declaration order.
    #[must_use]
    pub fn all() -> &'static [Self] {
        const EVERY: &[UnaryOp] = &[UnaryOp::Neg, UnaryOp::Not, UnaryOp::Pos];
        EVERY
    }

    /// The operator as written in Python source.
    ///
    /// `Not` carries a trailing space because the operator text is glued directly onto its
    /// operand when rendering (`not x`, but `-x`).
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            UnaryOp::Neg => "-",
            UnaryOp::Not => "not ",
            UnaryOp::Pos => "+",
        }
    }
}
195
/// Comparison operators usable in a `PythonNode::Compare`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// Equal (`==`)
    Eq,
    /// Not equal (`!=`)
    NotEq,
    /// Less than (`<`)
    Lt,
    /// Less than or equal (`<=`)
    LtE,
    /// Greater than (`>`)
    Gt,
    /// Greater than or equal (`>=`)
    GtE,
}

impl CompareOp {
    /// Every comparison operator, in declaration order.
    #[must_use]
    pub fn all() -> &'static [Self] {
        const EVERY: &[CompareOp] = &[
            CompareOp::Eq,
            CompareOp::NotEq,
            CompareOp::Lt,
            CompareOp::LtE,
            CompareOp::Gt,
            CompareOp::GtE,
        ];
        EVERY
    }

    /// The operator exactly as it is spelled in Python source.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            CompareOp::Eq => "==",
            CompareOp::NotEq => "!=",
            CompareOp::Lt => "<",
            CompareOp::LtE => "<=",
            CompareOp::Gt => ">",
            CompareOp::GtE => ">=",
        }
    }
}
240
241impl PythonNode {
242    /// Convert AST node to Python source code
243    #[allow(clippy::too_many_lines)]
244    pub fn to_code(&self, indent: usize) -> String {
245        let indent_str = "    ".repeat(indent);
246        match self {
247            Self::Module(stmts) => stmts
248                .iter()
249                .map(|s| s.to_code(0))
250                .collect::<Vec<_>>()
251                .join("\n"),
252            Self::Assign { target, value } => {
253                let val = value.to_code(0);
254                format!("{indent_str}{target} = {val}")
255            }
256            Self::BinOp { left, op, right } => {
257                let l = left.to_code(0);
258                let r = right.to_code(0);
259                let o = op.to_str();
260                format!("({l} {o} {r})")
261            }
262            Self::UnaryOp { op, operand } => {
263                let o = op.to_str();
264                let e = operand.to_code(0);
265                format!("({o}{e})")
266            }
267            Self::IntLit(n) => n.to_string(),
268            Self::FloatLit(f) => format!("{f:.1}"),
269            Self::StrLit(s) => format!("\"{s}\""),
270            Self::BoolLit(b) => if *b { "True" } else { "False" }.to_string(),
271            Self::NoneLit => "None".to_string(),
272            Self::Name(name) => name.clone(),
273            Self::If { test, body, orelse } => {
274                self.if_to_code(&indent_str, indent, test, body, orelse)
275            }
276            Self::While { test, body } => self.while_to_code(&indent_str, indent, test, body),
277            Self::For { target, iter, body } => {
278                self.for_to_code(&indent_str, indent, target, iter, body)
279            }
280            Self::FuncDef { name, args, body } => {
281                self.funcdef_to_code(&indent_str, indent, name, args, body)
282            }
283            Self::Call { func, args } => {
284                let args_str = args
285                    .iter()
286                    .map(|a| a.to_code(0))
287                    .collect::<Vec<_>>()
288                    .join(", ");
289                format!("{func}({args_str})")
290            }
291            Self::Return(Some(value)) => {
292                let val = value.to_code(0);
293                format!("{indent_str}return {val}")
294            }
295            Self::Return(None) => format!("{indent_str}return"),
296            Self::Pass => format!("{indent_str}pass"),
297            Self::Break => format!("{indent_str}break"),
298            Self::Continue => format!("{indent_str}continue"),
299            Self::List(items) => {
300                let items_str = items
301                    .iter()
302                    .map(|i| i.to_code(0))
303                    .collect::<Vec<_>>()
304                    .join(", ");
305                format!("[{items_str}]")
306            }
307            Self::Compare { left, op, right } => {
308                let l = left.to_code(0);
309                let r = right.to_code(0);
310                let o = op.to_str();
311                format!("({l} {o} {r})")
312            }
313        }
314    }
315
316    fn if_to_code(
317        &self,
318        indent_str: &str,
319        indent: usize,
320        test: &PythonNode,
321        body: &[PythonNode],
322        orelse: &[PythonNode],
323    ) -> String {
324        let body_code = body
325            .iter()
326            .map(|s| s.to_code(indent + 1))
327            .collect::<Vec<_>>()
328            .join("\n");
329        let test_code = test.to_code(0);
330        if orelse.is_empty() {
331            format!("{indent_str}if {test_code}:\n{body_code}")
332        } else {
333            let else_code = orelse
334                .iter()
335                .map(|s| s.to_code(indent + 1))
336                .collect::<Vec<_>>()
337                .join("\n");
338            format!("{indent_str}if {test_code}:\n{body_code}\n{indent_str}else:\n{else_code}")
339        }
340    }
341
342    fn while_to_code(
343        &self,
344        indent_str: &str,
345        indent: usize,
346        test: &PythonNode,
347        body: &[PythonNode],
348    ) -> String {
349        let body_code = body
350            .iter()
351            .map(|s| s.to_code(indent + 1))
352            .collect::<Vec<_>>()
353            .join("\n");
354        let test_code = test.to_code(0);
355        format!("{indent_str}while {test_code}:\n{body_code}")
356    }
357
358    fn for_to_code(
359        &self,
360        indent_str: &str,
361        indent: usize,
362        target: &str,
363        iter: &PythonNode,
364        body: &[PythonNode],
365    ) -> String {
366        let body_code = body
367            .iter()
368            .map(|s| s.to_code(indent + 1))
369            .collect::<Vec<_>>()
370            .join("\n");
371        let iter_code = iter.to_code(0);
372        format!("{indent_str}for {target} in {iter_code}:\n{body_code}")
373    }
374
375    fn funcdef_to_code(
376        &self,
377        indent_str: &str,
378        indent: usize,
379        name: &str,
380        args: &[String],
381        body: &[PythonNode],
382    ) -> String {
383        let args_str = args.join(", ");
384        let body_code = if body.is_empty() {
385            format!("{indent_str}    pass")
386        } else {
387            body.iter()
388                .map(|s| s.to_code(indent + 1))
389                .collect::<Vec<_>>()
390                .join("\n")
391        };
392        format!("{indent_str}def {name}({args_str}):\n{body_code}")
393    }
394
395    /// Calculate AST depth
396    pub fn depth(&self) -> usize {
397        match self {
398            Self::Module(stmts) => 1 + stmts.iter().map(Self::depth).max().unwrap_or(0),
399            Self::Assign { value, .. } => 1 + value.depth(),
400            Self::BinOp { left, right, .. } | Self::Compare { left, right, .. } => {
401                1 + left.depth().max(right.depth())
402            }
403            Self::UnaryOp { operand, .. } => 1 + operand.depth(),
404            Self::If { test, body, orelse } => {
405                let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
406                let else_depth = orelse.iter().map(Self::depth).max().unwrap_or(0);
407                1 + test.depth().max(body_depth).max(else_depth)
408            }
409            Self::While { test, body } => {
410                let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
411                1 + test.depth().max(body_depth)
412            }
413            Self::For { iter, body, .. } => {
414                let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
415                1 + iter.depth().max(body_depth)
416            }
417            Self::FuncDef { body, .. } => 1 + body.iter().map(Self::depth).max().unwrap_or(0),
418            Self::Call { args, .. } => 1 + args.iter().map(Self::depth).max().unwrap_or(0),
419            Self::Return(Some(v)) => 1 + v.depth(),
420            Self::List(items) => 1 + items.iter().map(Self::depth).max().unwrap_or(0),
421            // Terminal nodes - depth 1
422            Self::Return(None)
423            | Self::IntLit(_)
424            | Self::FloatLit(_)
425            | Self::StrLit(_)
426            | Self::BoolLit(_)
427            | Self::NoneLit
428            | Self::Name(_)
429            | Self::Pass
430            | Self::Break
431            | Self::Continue => 1,
432        }
433    }
434}
435
436/// Exhaustive Python program enumerator
437#[derive(Debug)]
438pub struct PythonEnumerator {
439    max_depth: usize,
440    var_names: Vec<String>,
441    int_values: Vec<i64>,
442}
443
444impl Default for PythonEnumerator {
445    fn default() -> Self {
446        Self::new(3)
447    }
448}
449
450impl PythonEnumerator {
451    /// Create a new enumerator with specified max depth
452    #[must_use]
453    pub fn new(max_depth: usize) -> Self {
454        Self {
455            max_depth,
456            var_names: vec!["x".to_string(), "y".to_string(), "z".to_string()],
457            int_values: vec![0, 1, -1, 2, 10],
458        }
459    }
460
461    /// Enumerate all expressions up to the given depth
462    pub fn enumerate_expressions(&self, depth: usize) -> Vec<PythonNode> {
463        if depth == 0 {
464            return vec![];
465        }
466
467        let mut results = Vec::new();
468
469        // Depth 1: literals and names
470        for &val in &self.int_values {
471            results.push(PythonNode::IntLit(val));
472        }
473        for name in &self.var_names {
474            results.push(PythonNode::Name(name.clone()));
475        }
476        results.push(PythonNode::BoolLit(true));
477        results.push(PythonNode::BoolLit(false));
478        results.push(PythonNode::NoneLit);
479
480        if depth == 1 {
481            return results;
482        }
483
484        // Depth 2+: compound expressions
485        let subexprs = self.enumerate_expressions(depth - 1);
486
487        // Unary operations
488        for op in UnaryOp::all() {
489            for subexpr in &subexprs {
490                if subexpr.depth() < depth {
491                    results.push(PythonNode::UnaryOp {
492                        op: *op,
493                        operand: Box::new(subexpr.clone()),
494                    });
495                }
496            }
497        }
498
499        // Binary operations (limited to prevent explosion)
500        let limited_subexprs: Vec<_> = subexprs.iter().take(10).collect();
501        for op in BinaryOp::all() {
502            for left in &limited_subexprs {
503                for right in &limited_subexprs {
504                    if left.depth() + right.depth() < depth {
505                        results.push(PythonNode::BinOp {
506                            left: Box::new((*left).clone()),
507                            op: *op,
508                            right: Box::new((*right).clone()),
509                        });
510                    }
511                }
512            }
513        }
514
515        // Comparisons
516        for op in CompareOp::all() {
517            for left in &limited_subexprs {
518                for right in &limited_subexprs {
519                    if left.depth() + right.depth() < depth {
520                        results.push(PythonNode::Compare {
521                            left: Box::new((*left).clone()),
522                            op: *op,
523                            right: Box::new((*right).clone()),
524                        });
525                    }
526                }
527            }
528        }
529
530        results
531    }
532
533    /// Enumerate all statements up to the given depth
534    pub fn enumerate_statements(&self, depth: usize) -> Vec<PythonNode> {
535        if depth == 0 {
536            return vec![];
537        }
538
539        let mut results = Vec::new();
540
541        // Simple statements
542        results.push(PythonNode::Pass);
543
544        let exprs = self.enumerate_expressions(depth - 1);
545        let limited_exprs: Vec<_> = exprs.iter().take(20).collect();
546
547        // Assignments
548        for target in &self.var_names {
549            for value in &limited_exprs {
550                results.push(PythonNode::Assign {
551                    target: target.clone(),
552                    value: Box::new((*value).clone()),
553                });
554            }
555        }
556
557        // Return statements
558        results.push(PythonNode::Return(None));
559        for expr in limited_exprs.iter().take(10) {
560            results.push(PythonNode::Return(Some(Box::new((*expr).clone()))));
561        }
562
563        if depth >= 2 {
564            // If statements
565            let conditions: Vec<_> = exprs
566                .iter()
567                .filter(|e| {
568                    matches!(
569                        e,
570                        PythonNode::Compare { .. } | PythonNode::BoolLit(_) | PythonNode::Name(_)
571                    )
572                })
573                .take(5)
574                .collect();
575
576            let body_stmts = self.enumerate_statements(depth - 1);
577            let limited_body: Vec<_> = body_stmts.iter().take(5).collect();
578
579            for cond in &conditions {
580                for body in &limited_body {
581                    results.push(PythonNode::If {
582                        test: Box::new((*cond).clone()),
583                        body: vec![(*body).clone()],
584                        orelse: vec![],
585                    });
586                }
587            }
588
589            // While loops
590            for cond in &conditions {
591                results.push(PythonNode::While {
592                    test: Box::new((*cond).clone()),
593                    body: vec![PythonNode::Break],
594                });
595            }
596        }
597
598        if depth >= 3 {
599            // Function definitions
600            for name in &["foo", "bar"] {
601                results.push(PythonNode::FuncDef {
602                    name: (*name).to_string(),
603                    args: vec![],
604                    body: vec![PythonNode::Pass],
605                });
606                results.push(PythonNode::FuncDef {
607                    name: (*name).to_string(),
608                    args: vec!["a".to_string()],
609                    body: vec![PythonNode::Return(Some(Box::new(PythonNode::Name(
610                        "a".to_string(),
611                    ))))],
612                });
613            }
614        }
615
616        results
617    }
618
619    /// Enumerate complete programs (modules)
620    pub fn enumerate_programs(&self) -> Vec<GeneratedCode> {
621        let mut results = Vec::new();
622
623        let stmts = self.enumerate_statements(self.max_depth);
624
625        // Single statement programs
626        for stmt in &stmts {
627            let module = PythonNode::Module(vec![stmt.clone()]);
628            let code = module.to_code(0);
629            results.push(GeneratedCode {
630                code,
631                language: Language::Python,
632                ast_depth: stmt.depth(),
633                features: self.extract_features(stmt),
634            });
635        }
636
637        // Two statement programs (limited)
638        let limited_stmts: Vec<_> = stmts.iter().take(20).collect();
639        for s1 in &limited_stmts {
640            for s2 in limited_stmts.iter().take(10) {
641                let module = PythonNode::Module(vec![(*s1).clone(), (*s2).clone()]);
642                let code = module.to_code(0);
643                let depth = s1.depth().max(s2.depth());
644                results.push(GeneratedCode {
645                    code,
646                    language: Language::Python,
647                    ast_depth: depth,
648                    features: self.extract_features(s1),
649                });
650            }
651        }
652
653        results
654    }
655
656    /// Extract feature labels from an AST node
657    fn extract_features(&self, node: &PythonNode) -> Vec<String> {
658        let mut features = Vec::new();
659
660        match node {
661            PythonNode::Assign { .. } => features.push("assignment".to_string()),
662            PythonNode::BinOp { op, .. } => {
663                features.push("binop".to_string());
664                features.push(format!("op_{}", op.to_str()));
665            }
666            PythonNode::If { orelse, .. } => {
667                features.push("if".to_string());
668                if !orelse.is_empty() {
669                    features.push("else".to_string());
670                }
671            }
672            PythonNode::While { .. } => features.push("while".to_string()),
673            PythonNode::For { .. } => features.push("for".to_string()),
674            PythonNode::FuncDef { .. } => features.push("funcdef".to_string()),
675            PythonNode::Return(_) => features.push("return".to_string()),
676            PythonNode::Compare { op, .. } => {
677                features.push("compare".to_string());
678                features.push(format!("cmp_{}", op.to_str()));
679            }
680            _ => {}
681        }
682
683        features
684    }
685}
686
687#[cfg(test)]
688mod tests {
689    use super::*;
690
691    #[test]
692    fn test_int_lit_to_code() {
693        let node = PythonNode::IntLit(42);
694        assert_eq!(node.to_code(0), "42");
695    }
696
697    #[test]
698    fn test_assign_to_code() {
699        let node = PythonNode::Assign {
700            target: "x".to_string(),
701            value: Box::new(PythonNode::IntLit(1)),
702        };
703        assert_eq!(node.to_code(0), "x = 1");
704    }
705
706    #[test]
707    fn test_binop_to_code() {
708        let node = PythonNode::BinOp {
709            left: Box::new(PythonNode::IntLit(1)),
710            op: BinaryOp::Add,
711            right: Box::new(PythonNode::IntLit(2)),
712        };
713        assert_eq!(node.to_code(0), "(1 + 2)");
714    }
715
716    #[test]
717    fn test_if_to_code() {
718        let node = PythonNode::If {
719            test: Box::new(PythonNode::BoolLit(true)),
720            body: vec![PythonNode::Pass],
721            orelse: vec![],
722        };
723        assert_eq!(node.to_code(0), "if True:\n    pass");
724    }
725
726    #[test]
727    fn test_funcdef_to_code() {
728        let node = PythonNode::FuncDef {
729            name: "foo".to_string(),
730            args: vec!["a".to_string(), "b".to_string()],
731            body: vec![PythonNode::Return(Some(Box::new(PythonNode::Name(
732                "a".to_string(),
733            ))))],
734        };
735        assert_eq!(node.to_code(0), "def foo(a, b):\n    return a");
736    }
737
738    #[test]
739    fn test_depth_calculation() {
740        let simple = PythonNode::IntLit(1);
741        assert_eq!(simple.depth(), 1);
742
743        let nested = PythonNode::BinOp {
744            left: Box::new(PythonNode::IntLit(1)),
745            op: BinaryOp::Add,
746            right: Box::new(PythonNode::BinOp {
747                left: Box::new(PythonNode::IntLit(2)),
748                op: BinaryOp::Mult,
749                right: Box::new(PythonNode::IntLit(3)),
750            }),
751        };
752        assert_eq!(nested.depth(), 3);
753    }
754
755    #[test]
756    fn test_enumerator_expressions() {
757        let enum_ = PythonEnumerator::new(2);
758        let exprs = enum_.enumerate_expressions(1);
759        assert!(!exprs.is_empty());
760        // Should have integers, names, booleans, None
761        assert!(exprs.iter().any(|e| matches!(e, PythonNode::IntLit(_))));
762        assert!(exprs.iter().any(|e| matches!(e, PythonNode::Name(_))));
763    }
764
765    #[test]
766    fn test_enumerator_statements() {
767        let enum_ = PythonEnumerator::new(2);
768        let stmts = enum_.enumerate_statements(2);
769        assert!(!stmts.is_empty());
770        // Should have pass, assignments, etc.
771        assert!(stmts.iter().any(|s| matches!(s, PythonNode::Pass)));
772        assert!(stmts.iter().any(|s| matches!(s, PythonNode::Assign { .. })));
773    }
774
775    #[test]
776    fn test_enumerator_programs() {
777        let enum_ = PythonEnumerator::new(2);
778        let programs = enum_.enumerate_programs();
779        assert!(!programs.is_empty());
780        // All programs should have valid Python code
781        for prog in &programs {
782            assert!(!prog.code.is_empty());
783            assert_eq!(prog.language, Language::Python);
784        }
785    }
786
787    #[test]
788    fn test_generated_code_is_valid_python() {
789        let enum_ = PythonEnumerator::new(2);
790        let programs = enum_.enumerate_programs();
791
792        // Test a few programs to ensure they look like valid Python
793        for prog in programs.iter().take(10) {
794            // Should not contain syntax errors (basic check)
795            assert!(
796                !prog.code.contains("():")
797                    || prog.code.contains("def ")
798                    || prog.code.contains("if ")
799            );
800        }
801    }
802
803    #[test]
804    fn test_binary_op_all() {
805        let ops = BinaryOp::all();
806        assert_eq!(ops.len(), 7);
807    }
808
809    #[test]
810    fn test_binary_op_to_str_all() {
811        assert_eq!(BinaryOp::Add.to_str(), "+");
812        assert_eq!(BinaryOp::Sub.to_str(), "-");
813        assert_eq!(BinaryOp::Mult.to_str(), "*");
814        assert_eq!(BinaryOp::Div.to_str(), "/");
815        assert_eq!(BinaryOp::Mod.to_str(), "%");
816        assert_eq!(BinaryOp::FloorDiv.to_str(), "//");
817        assert_eq!(BinaryOp::Pow.to_str(), "**");
818        assert_eq!(BinaryOp::And.to_str(), "and");
819        assert_eq!(BinaryOp::Or.to_str(), "or");
820    }
821
822    #[test]
823    fn test_unary_op_all() {
824        let ops = UnaryOp::all();
825        assert_eq!(ops.len(), 3);
826    }
827
828    #[test]
829    fn test_unary_op_to_str_all() {
830        assert_eq!(UnaryOp::Neg.to_str(), "-");
831        assert_eq!(UnaryOp::Not.to_str(), "not ");
832        assert_eq!(UnaryOp::Pos.to_str(), "+");
833    }
834
835    #[test]
836    fn test_compare_op_all() {
837        let ops = CompareOp::all();
838        assert_eq!(ops.len(), 6);
839    }
840
841    #[test]
842    fn test_compare_op_to_str_all() {
843        assert_eq!(CompareOp::Eq.to_str(), "==");
844        assert_eq!(CompareOp::NotEq.to_str(), "!=");
845        assert_eq!(CompareOp::Lt.to_str(), "<");
846        assert_eq!(CompareOp::LtE.to_str(), "<=");
847        assert_eq!(CompareOp::Gt.to_str(), ">");
848        assert_eq!(CompareOp::GtE.to_str(), ">=");
849    }
850
851    #[test]
852    fn test_float_lit_to_code() {
853        let node = PythonNode::FloatLit(3.14);
854        assert!(node.to_code(0).starts_with("3.1"));
855    }
856
857    #[test]
858    fn test_str_lit_to_code() {
859        let node = PythonNode::StrLit("hello".to_string());
860        assert_eq!(node.to_code(0), "\"hello\"");
861    }
862
863    #[test]
864    fn test_bool_lit_to_code() {
865        assert_eq!(PythonNode::BoolLit(true).to_code(0), "True");
866        assert_eq!(PythonNode::BoolLit(false).to_code(0), "False");
867    }
868
869    #[test]
870    fn test_none_lit_to_code() {
871        assert_eq!(PythonNode::NoneLit.to_code(0), "None");
872    }
873
874    #[test]
875    fn test_name_to_code() {
876        let node = PythonNode::Name("x".to_string());
877        assert_eq!(node.to_code(0), "x");
878    }
879
880    #[test]
881    fn test_unary_op_to_code() {
882        let node = PythonNode::UnaryOp {
883            op: UnaryOp::Neg,
884            operand: Box::new(PythonNode::IntLit(5)),
885        };
886        assert_eq!(node.to_code(0), "(-5)");
887    }
888
889    #[test]
890    fn test_if_with_else_to_code() {
891        let node = PythonNode::If {
892            test: Box::new(PythonNode::BoolLit(true)),
893            body: vec![PythonNode::Pass],
894            orelse: vec![PythonNode::Pass],
895        };
896        let code = node.to_code(0);
897        assert!(code.contains("if True:"));
898        assert!(code.contains("else:"));
899    }
900
901    #[test]
902    fn test_while_to_code() {
903        let node = PythonNode::While {
904            test: Box::new(PythonNode::BoolLit(true)),
905            body: vec![PythonNode::Break],
906        };
907        let code = node.to_code(0);
908        assert!(code.contains("while True:"));
909        assert!(code.contains("break"));
910    }
911
912    #[test]
913    fn test_for_to_code() {
914        let node = PythonNode::For {
915            target: "i".to_string(),
916            iter: Box::new(PythonNode::List(vec![PythonNode::IntLit(1)])),
917            body: vec![PythonNode::Continue],
918        };
919        let code = node.to_code(0);
920        assert!(code.contains("for i in"));
921        assert!(code.contains("continue"));
922    }
923
924    #[test]
925    fn test_call_to_code() {
926        let node = PythonNode::Call {
927            func: "print".to_string(),
928            args: vec![PythonNode::IntLit(1), PythonNode::IntLit(2)],
929        };
930        assert_eq!(node.to_code(0), "print(1, 2)");
931    }
932
933    #[test]
934    fn test_return_none_to_code() {
935        let node = PythonNode::Return(None);
936        assert_eq!(node.to_code(0), "return");
937    }
938
939    #[test]
940    fn test_break_to_code() {
941        let node = PythonNode::Break;
942        assert_eq!(node.to_code(0), "break");
943    }
944
945    #[test]
946    fn test_continue_to_code() {
947        let node = PythonNode::Continue;
948        assert_eq!(node.to_code(0), "continue");
949    }
950
951    #[test]
952    fn test_list_to_code() {
953        let node = PythonNode::List(vec![
954            PythonNode::IntLit(1),
955            PythonNode::IntLit(2),
956            PythonNode::IntLit(3),
957        ]);
958        assert_eq!(node.to_code(0), "[1, 2, 3]");
959    }
960
961    #[test]
962    fn test_empty_list_to_code() {
963        let node = PythonNode::List(vec![]);
964        assert_eq!(node.to_code(0), "[]");
965    }
966
967    #[test]
968    fn test_compare_to_code() {
969        let node = PythonNode::Compare {
970            left: Box::new(PythonNode::IntLit(1)),
971            op: CompareOp::Lt,
972            right: Box::new(PythonNode::IntLit(2)),
973        };
974        assert_eq!(node.to_code(0), "(1 < 2)");
975    }
976
977    #[test]
978    fn test_module_to_code() {
979        let node = PythonNode::Module(vec![
980            PythonNode::Assign {
981                target: "x".to_string(),
982                value: Box::new(PythonNode::IntLit(1)),
983            },
984            PythonNode::Pass,
985        ]);
986        let code = node.to_code(0);
987        assert!(code.contains("x = 1"));
988        assert!(code.contains("pass"));
989    }
990
991    #[test]
992    fn test_python_node_debug() {
993        let node = PythonNode::IntLit(42);
994        let debug = format!("{:?}", node);
995        assert!(debug.contains("IntLit"));
996    }
997
998    #[test]
999    fn test_python_node_clone() {
1000        let node = PythonNode::IntLit(42);
1001        let cloned = node.clone();
1002        assert_eq!(cloned, node);
1003    }
1004
1005    #[test]
1006    fn test_binary_op_debug() {
1007        let op = BinaryOp::Add;
1008        let debug = format!("{:?}", op);
1009        assert!(debug.contains("Add"));
1010    }
1011
1012    #[test]
1013    fn test_binary_op_clone() {
1014        let op = BinaryOp::Add;
1015        let cloned = op.clone();
1016        assert_eq!(cloned, op);
1017    }
1018
1019    #[test]
1020    fn test_unary_op_debug() {
1021        let op = UnaryOp::Neg;
1022        let debug = format!("{:?}", op);
1023        assert!(debug.contains("Neg"));
1024    }
1025
1026    #[test]
1027    fn test_compare_op_debug() {
1028        let op = CompareOp::Lt;
1029        let debug = format!("{:?}", op);
1030        assert!(debug.contains("Lt"));
1031    }
1032
1033    #[test]
1034    fn test_extract_features_binop() {
1035        let enum_ = PythonEnumerator::new(2);
1036        let node = PythonNode::BinOp {
1037            left: Box::new(PythonNode::IntLit(1)),
1038            op: BinaryOp::Add,
1039            right: Box::new(PythonNode::IntLit(2)),
1040        };
1041        let features = enum_.extract_features(&node);
1042        assert!(features.contains(&"binop".to_string()));
1043    }
1044
1045    #[test]
1046    fn test_extract_features_if_with_else() {
1047        let enum_ = PythonEnumerator::new(2);
1048        let node = PythonNode::If {
1049            test: Box::new(PythonNode::BoolLit(true)),
1050            body: vec![PythonNode::Pass],
1051            orelse: vec![PythonNode::Pass],
1052        };
1053        let features = enum_.extract_features(&node);
1054        assert!(features.contains(&"if".to_string()));
1055        assert!(features.contains(&"else".to_string()));
1056    }
1057
1058    #[test]
1059    fn test_extract_features_while() {
1060        let enum_ = PythonEnumerator::new(2);
1061        let node = PythonNode::While {
1062            test: Box::new(PythonNode::BoolLit(true)),
1063            body: vec![PythonNode::Pass],
1064        };
1065        let features = enum_.extract_features(&node);
1066        assert!(features.contains(&"while".to_string()));
1067    }
1068
1069    #[test]
1070    fn test_extract_features_for() {
1071        let enum_ = PythonEnumerator::new(2);
1072        let node = PythonNode::For {
1073            target: "i".to_string(),
1074            iter: Box::new(PythonNode::List(vec![])),
1075            body: vec![PythonNode::Pass],
1076        };
1077        let features = enum_.extract_features(&node);
1078        assert!(features.contains(&"for".to_string()));
1079    }
1080
1081    #[test]
1082    fn test_extract_features_compare() {
1083        let enum_ = PythonEnumerator::new(2);
1084        let node = PythonNode::Compare {
1085            left: Box::new(PythonNode::IntLit(1)),
1086            op: CompareOp::Lt,
1087            right: Box::new(PythonNode::IntLit(2)),
1088        };
1089        let features = enum_.extract_features(&node);
1090        assert!(features.contains(&"compare".to_string()));
1091    }
1092
1093    #[test]
1094    fn test_depth_if() {
1095        let node = PythonNode::If {
1096            test: Box::new(PythonNode::BoolLit(true)),
1097            body: vec![PythonNode::Pass],
1098            orelse: vec![],
1099        };
1100        assert!(node.depth() >= 2);
1101    }
1102
1103    #[test]
1104    fn test_depth_while() {
1105        let node = PythonNode::While {
1106            test: Box::new(PythonNode::BoolLit(true)),
1107            body: vec![PythonNode::Pass],
1108        };
1109        assert!(node.depth() >= 2);
1110    }
1111
1112    #[test]
1113    fn test_depth_for() {
1114        let node = PythonNode::For {
1115            target: "i".to_string(),
1116            iter: Box::new(PythonNode::List(vec![])),
1117            body: vec![PythonNode::Pass],
1118        };
1119        assert!(node.depth() >= 2);
1120    }
1121
1122    #[test]
1123    fn test_depth_funcdef() {
1124        let node = PythonNode::FuncDef {
1125            name: "f".to_string(),
1126            args: vec![],
1127            body: vec![PythonNode::Pass],
1128        };
1129        assert!(node.depth() >= 2);
1130    }
1131}