1use super::GeneratedCode;
7use crate::Language;
8
/// A node of a small Python abstract syntax tree.
///
/// Modules, statements and expressions share this one enum; the role a node
/// plays is determined by where it is placed in the tree.
/// [`PythonNode::to_code`] renders a tree back to Python source text and
/// [`PythonNode::depth`] measures its height.
#[derive(Debug, Clone, PartialEq)]
#[allow(missing_docs)]
pub enum PythonNode {
    /// A module: the top-level statements, rendered one per line.
    Module(Vec<PythonNode>),
    /// Assignment statement `target = value`.
    Assign {
        /// Name being assigned to.
        target: String,
        /// Expression whose value is assigned.
        value: Box<PythonNode>,
    },
    /// Binary operation `left op right` (rendered in parentheses).
    BinOp {
        /// Left operand.
        left: Box<PythonNode>,
        /// Operator.
        op: BinaryOp,
        /// Right operand.
        right: Box<PythonNode>,
    },
    /// Prefix unary operation `op operand` (rendered in parentheses).
    UnaryOp {
        /// Operator.
        op: UnaryOp,
        /// Operand.
        operand: Box<PythonNode>,
    },
    /// Integer literal.
    IntLit(i64),
    /// Floating-point literal.
    FloatLit(f64),
    /// String literal; the payload is the text without surrounding quotes.
    StrLit(String),
    /// Boolean literal, rendered as `True` / `False`.
    BoolLit(bool),
    /// The `None` literal.
    NoneLit,
    /// Reference to a variable by name.
    Name(String),
    /// `if test:` statement; an empty `orelse` means there is no `else:` branch.
    If {
        /// Condition expression.
        test: Box<PythonNode>,
        /// Statements run when the condition holds.
        body: Vec<PythonNode>,
        /// Statements of the `else:` branch (empty for none).
        orelse: Vec<PythonNode>,
    },
    /// `while test:` loop.
    While {
        /// Loop condition.
        test: Box<PythonNode>,
        /// Loop body.
        body: Vec<PythonNode>,
    },
    /// `for target in iter:` loop.
    For {
        /// Loop variable name.
        target: String,
        /// Expression being iterated over.
        iter: Box<PythonNode>,
        /// Loop body.
        body: Vec<PythonNode>,
    },
    /// Function definition `def name(args):`.
    FuncDef {
        /// Function name.
        name: String,
        /// Parameter names, in order.
        args: Vec<String>,
        /// Function body.
        body: Vec<PythonNode>,
    },
    /// Call of a function by name: `func(args...)`.
    Call {
        /// Name of the called function.
        func: String,
        /// Argument expressions, in order.
        args: Vec<PythonNode>,
    },
    /// `return` statement, optionally with a value.
    Return(Option<Box<PythonNode>>),
    /// `pass` statement.
    Pass,
    /// `break` statement.
    Break,
    /// `continue` statement.
    Continue,
    /// List display `[a, b, ...]`.
    List(Vec<PythonNode>),
    /// Single (non-chained) comparison `left op right` (rendered in parentheses).
    Compare {
        /// Left operand.
        left: Box<PythonNode>,
        /// Comparison operator.
        op: CompareOp,
        /// Right operand.
        right: Box<PythonNode>,
    },
}
111
/// Binary operators that can appear in a [`PythonNode::BinOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition, `+`.
    Add,
    /// Subtraction, `-`.
    Sub,
    /// Multiplication, `*`.
    Mult,
    /// True division, `/`.
    Div,
    /// Remainder, `%`.
    Mod,
    /// Floor division, `//`.
    FloorDiv,
    /// Exponentiation, `**`.
    Pow,
    /// Boolean conjunction, `and`.
    And,
    /// Boolean disjunction, `or`.
    Or,
}

impl BinaryOp {
    /// Returns the arithmetic operators in declaration order.
    ///
    /// The boolean operators [`BinaryOp::And`] and [`BinaryOp::Or`] are not
    /// part of this list, so enumerating over it only yields arithmetic
    /// expressions.
    #[must_use]
    pub fn all() -> &'static [Self] {
        const ARITHMETIC: &[BinaryOp] = &[
            BinaryOp::Add,
            BinaryOp::Sub,
            BinaryOp::Mult,
            BinaryOp::Div,
            BinaryOp::Mod,
            BinaryOp::FloorDiv,
            BinaryOp::Pow,
        ];
        ARITHMETIC
    }

    /// Returns the Python spelling of this operator.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            // Arithmetic symbols.
            BinaryOp::Add => "+",
            BinaryOp::Sub => "-",
            BinaryOp::Mult => "*",
            BinaryOp::Div => "/",
            BinaryOp::Mod => "%",
            BinaryOp::FloorDiv => "//",
            BinaryOp::Pow => "**",
            // Boolean keywords.
            BinaryOp::And => "and",
            BinaryOp::Or => "or",
        }
    }
}
166
/// Prefix operators that can appear in a [`PythonNode::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation, `-x`.
    Neg,
    /// Boolean negation, `not x`.
    Not,
    /// Unary plus, `+x`.
    Pos,
}

impl UnaryOp {
    /// Returns every unary operator, in declaration order.
    #[must_use]
    pub fn all() -> &'static [Self] {
        static EVERY: [UnaryOp; 3] = [UnaryOp::Neg, UnaryOp::Not, UnaryOp::Pos];
        &EVERY
    }

    /// Returns the Python prefix text for this operator.
    ///
    /// The text for `Not` ends in a space so that it can be concatenated
    /// directly with its operand.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            UnaryOp::Neg => "-",
            UnaryOp::Pos => "+",
            UnaryOp::Not => "not ",
        }
    }
}
195
/// Comparison operators that can appear in a [`PythonNode::Compare`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompareOp {
    /// Equality, `==`.
    Eq,
    /// Inequality, `!=`.
    NotEq,
    /// Strictly less than, `<`.
    Lt,
    /// Less than or equal, `<=`.
    LtE,
    /// Strictly greater than, `>`.
    Gt,
    /// Greater than or equal, `>=`.
    GtE,
}

impl CompareOp {
    /// Returns every comparison operator, in declaration order.
    #[must_use]
    pub fn all() -> &'static [Self] {
        const ORDERED: [CompareOp; 6] = [
            CompareOp::Eq,
            CompareOp::NotEq,
            CompareOp::Lt,
            CompareOp::LtE,
            CompareOp::Gt,
            CompareOp::GtE,
        ];
        &ORDERED
    }

    /// Returns the Python spelling of this comparison.
    #[must_use]
    pub fn to_str(self) -> &'static str {
        match self {
            // Equality tests.
            CompareOp::Eq => "==",
            CompareOp::NotEq => "!=",
            // Ordering tests.
            CompareOp::Lt => "<",
            CompareOp::LtE => "<=",
            CompareOp::Gt => ">",
            CompareOp::GtE => ">=",
        }
    }
}
240
241impl PythonNode {
242 #[allow(clippy::too_many_lines)]
244 pub fn to_code(&self, indent: usize) -> String {
245 let indent_str = " ".repeat(indent);
246 match self {
247 Self::Module(stmts) => stmts
248 .iter()
249 .map(|s| s.to_code(0))
250 .collect::<Vec<_>>()
251 .join("\n"),
252 Self::Assign { target, value } => {
253 let val = value.to_code(0);
254 format!("{indent_str}{target} = {val}")
255 }
256 Self::BinOp { left, op, right } => {
257 let l = left.to_code(0);
258 let r = right.to_code(0);
259 let o = op.to_str();
260 format!("({l} {o} {r})")
261 }
262 Self::UnaryOp { op, operand } => {
263 let o = op.to_str();
264 let e = operand.to_code(0);
265 format!("({o}{e})")
266 }
267 Self::IntLit(n) => n.to_string(),
268 Self::FloatLit(f) => format!("{f:.1}"),
269 Self::StrLit(s) => format!("\"{s}\""),
270 Self::BoolLit(b) => if *b { "True" } else { "False" }.to_string(),
271 Self::NoneLit => "None".to_string(),
272 Self::Name(name) => name.clone(),
273 Self::If { test, body, orelse } => {
274 self.if_to_code(&indent_str, indent, test, body, orelse)
275 }
276 Self::While { test, body } => self.while_to_code(&indent_str, indent, test, body),
277 Self::For { target, iter, body } => {
278 self.for_to_code(&indent_str, indent, target, iter, body)
279 }
280 Self::FuncDef { name, args, body } => {
281 self.funcdef_to_code(&indent_str, indent, name, args, body)
282 }
283 Self::Call { func, args } => {
284 let args_str = args
285 .iter()
286 .map(|a| a.to_code(0))
287 .collect::<Vec<_>>()
288 .join(", ");
289 format!("{func}({args_str})")
290 }
291 Self::Return(Some(value)) => {
292 let val = value.to_code(0);
293 format!("{indent_str}return {val}")
294 }
295 Self::Return(None) => format!("{indent_str}return"),
296 Self::Pass => format!("{indent_str}pass"),
297 Self::Break => format!("{indent_str}break"),
298 Self::Continue => format!("{indent_str}continue"),
299 Self::List(items) => {
300 let items_str = items
301 .iter()
302 .map(|i| i.to_code(0))
303 .collect::<Vec<_>>()
304 .join(", ");
305 format!("[{items_str}]")
306 }
307 Self::Compare { left, op, right } => {
308 let l = left.to_code(0);
309 let r = right.to_code(0);
310 let o = op.to_str();
311 format!("({l} {o} {r})")
312 }
313 }
314 }
315
316 fn if_to_code(
317 &self,
318 indent_str: &str,
319 indent: usize,
320 test: &PythonNode,
321 body: &[PythonNode],
322 orelse: &[PythonNode],
323 ) -> String {
324 let body_code = body
325 .iter()
326 .map(|s| s.to_code(indent + 1))
327 .collect::<Vec<_>>()
328 .join("\n");
329 let test_code = test.to_code(0);
330 if orelse.is_empty() {
331 format!("{indent_str}if {test_code}:\n{body_code}")
332 } else {
333 let else_code = orelse
334 .iter()
335 .map(|s| s.to_code(indent + 1))
336 .collect::<Vec<_>>()
337 .join("\n");
338 format!("{indent_str}if {test_code}:\n{body_code}\n{indent_str}else:\n{else_code}")
339 }
340 }
341
342 fn while_to_code(
343 &self,
344 indent_str: &str,
345 indent: usize,
346 test: &PythonNode,
347 body: &[PythonNode],
348 ) -> String {
349 let body_code = body
350 .iter()
351 .map(|s| s.to_code(indent + 1))
352 .collect::<Vec<_>>()
353 .join("\n");
354 let test_code = test.to_code(0);
355 format!("{indent_str}while {test_code}:\n{body_code}")
356 }
357
358 fn for_to_code(
359 &self,
360 indent_str: &str,
361 indent: usize,
362 target: &str,
363 iter: &PythonNode,
364 body: &[PythonNode],
365 ) -> String {
366 let body_code = body
367 .iter()
368 .map(|s| s.to_code(indent + 1))
369 .collect::<Vec<_>>()
370 .join("\n");
371 let iter_code = iter.to_code(0);
372 format!("{indent_str}for {target} in {iter_code}:\n{body_code}")
373 }
374
375 fn funcdef_to_code(
376 &self,
377 indent_str: &str,
378 indent: usize,
379 name: &str,
380 args: &[String],
381 body: &[PythonNode],
382 ) -> String {
383 let args_str = args.join(", ");
384 let body_code = if body.is_empty() {
385 format!("{indent_str} pass")
386 } else {
387 body.iter()
388 .map(|s| s.to_code(indent + 1))
389 .collect::<Vec<_>>()
390 .join("\n")
391 };
392 format!("{indent_str}def {name}({args_str}):\n{body_code}")
393 }
394
395 pub fn depth(&self) -> usize {
397 match self {
398 Self::Module(stmts) => 1 + stmts.iter().map(Self::depth).max().unwrap_or(0),
399 Self::Assign { value, .. } => 1 + value.depth(),
400 Self::BinOp { left, right, .. } | Self::Compare { left, right, .. } => {
401 1 + left.depth().max(right.depth())
402 }
403 Self::UnaryOp { operand, .. } => 1 + operand.depth(),
404 Self::If { test, body, orelse } => {
405 let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
406 let else_depth = orelse.iter().map(Self::depth).max().unwrap_or(0);
407 1 + test.depth().max(body_depth).max(else_depth)
408 }
409 Self::While { test, body } => {
410 let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
411 1 + test.depth().max(body_depth)
412 }
413 Self::For { iter, body, .. } => {
414 let body_depth = body.iter().map(Self::depth).max().unwrap_or(0);
415 1 + iter.depth().max(body_depth)
416 }
417 Self::FuncDef { body, .. } => 1 + body.iter().map(Self::depth).max().unwrap_or(0),
418 Self::Call { args, .. } => 1 + args.iter().map(Self::depth).max().unwrap_or(0),
419 Self::Return(Some(v)) => 1 + v.depth(),
420 Self::List(items) => 1 + items.iter().map(Self::depth).max().unwrap_or(0),
421 Self::Return(None)
423 | Self::IntLit(_)
424 | Self::FloatLit(_)
425 | Self::StrLit(_)
426 | Self::BoolLit(_)
427 | Self::NoneLit
428 | Self::Name(_)
429 | Self::Pass
430 | Self::Break
431 | Self::Continue => 1,
432 }
433 }
434}
435
436#[derive(Debug)]
438pub struct PythonEnumerator {
439 max_depth: usize,
440 var_names: Vec<String>,
441 int_values: Vec<i64>,
442}
443
444impl Default for PythonEnumerator {
445 fn default() -> Self {
446 Self::new(3)
447 }
448}
449
450impl PythonEnumerator {
451 #[must_use]
453 pub fn new(max_depth: usize) -> Self {
454 Self {
455 max_depth,
456 var_names: vec!["x".to_string(), "y".to_string(), "z".to_string()],
457 int_values: vec![0, 1, -1, 2, 10],
458 }
459 }
460
461 pub fn enumerate_expressions(&self, depth: usize) -> Vec<PythonNode> {
463 if depth == 0 {
464 return vec![];
465 }
466
467 let mut results = Vec::new();
468
469 for &val in &self.int_values {
471 results.push(PythonNode::IntLit(val));
472 }
473 for name in &self.var_names {
474 results.push(PythonNode::Name(name.clone()));
475 }
476 results.push(PythonNode::BoolLit(true));
477 results.push(PythonNode::BoolLit(false));
478 results.push(PythonNode::NoneLit);
479
480 if depth == 1 {
481 return results;
482 }
483
484 let subexprs = self.enumerate_expressions(depth - 1);
486
487 for op in UnaryOp::all() {
489 for subexpr in &subexprs {
490 if subexpr.depth() < depth {
491 results.push(PythonNode::UnaryOp {
492 op: *op,
493 operand: Box::new(subexpr.clone()),
494 });
495 }
496 }
497 }
498
499 let limited_subexprs: Vec<_> = subexprs.iter().take(10).collect();
501 for op in BinaryOp::all() {
502 for left in &limited_subexprs {
503 for right in &limited_subexprs {
504 if left.depth() + right.depth() < depth {
505 results.push(PythonNode::BinOp {
506 left: Box::new((*left).clone()),
507 op: *op,
508 right: Box::new((*right).clone()),
509 });
510 }
511 }
512 }
513 }
514
515 for op in CompareOp::all() {
517 for left in &limited_subexprs {
518 for right in &limited_subexprs {
519 if left.depth() + right.depth() < depth {
520 results.push(PythonNode::Compare {
521 left: Box::new((*left).clone()),
522 op: *op,
523 right: Box::new((*right).clone()),
524 });
525 }
526 }
527 }
528 }
529
530 results
531 }
532
533 pub fn enumerate_statements(&self, depth: usize) -> Vec<PythonNode> {
535 if depth == 0 {
536 return vec![];
537 }
538
539 let mut results = Vec::new();
540
541 results.push(PythonNode::Pass);
543
544 let exprs = self.enumerate_expressions(depth - 1);
545 let limited_exprs: Vec<_> = exprs.iter().take(20).collect();
546
547 for target in &self.var_names {
549 for value in &limited_exprs {
550 results.push(PythonNode::Assign {
551 target: target.clone(),
552 value: Box::new((*value).clone()),
553 });
554 }
555 }
556
557 results.push(PythonNode::Return(None));
559 for expr in limited_exprs.iter().take(10) {
560 results.push(PythonNode::Return(Some(Box::new((*expr).clone()))));
561 }
562
563 if depth >= 2 {
564 let conditions: Vec<_> = exprs
566 .iter()
567 .filter(|e| {
568 matches!(
569 e,
570 PythonNode::Compare { .. } | PythonNode::BoolLit(_) | PythonNode::Name(_)
571 )
572 })
573 .take(5)
574 .collect();
575
576 let body_stmts = self.enumerate_statements(depth - 1);
577 let limited_body: Vec<_> = body_stmts.iter().take(5).collect();
578
579 for cond in &conditions {
580 for body in &limited_body {
581 results.push(PythonNode::If {
582 test: Box::new((*cond).clone()),
583 body: vec![(*body).clone()],
584 orelse: vec![],
585 });
586 }
587 }
588
589 for cond in &conditions {
591 results.push(PythonNode::While {
592 test: Box::new((*cond).clone()),
593 body: vec![PythonNode::Break],
594 });
595 }
596 }
597
598 if depth >= 3 {
599 for name in &["foo", "bar"] {
601 results.push(PythonNode::FuncDef {
602 name: (*name).to_string(),
603 args: vec![],
604 body: vec![PythonNode::Pass],
605 });
606 results.push(PythonNode::FuncDef {
607 name: (*name).to_string(),
608 args: vec!["a".to_string()],
609 body: vec![PythonNode::Return(Some(Box::new(PythonNode::Name(
610 "a".to_string(),
611 ))))],
612 });
613 }
614 }
615
616 results
617 }
618
619 pub fn enumerate_programs(&self) -> Vec<GeneratedCode> {
621 let mut results = Vec::new();
622
623 let stmts = self.enumerate_statements(self.max_depth);
624
625 for stmt in &stmts {
627 let module = PythonNode::Module(vec![stmt.clone()]);
628 let code = module.to_code(0);
629 results.push(GeneratedCode {
630 code,
631 language: Language::Python,
632 ast_depth: stmt.depth(),
633 features: self.extract_features(stmt),
634 });
635 }
636
637 let limited_stmts: Vec<_> = stmts.iter().take(20).collect();
639 for s1 in &limited_stmts {
640 for s2 in limited_stmts.iter().take(10) {
641 let module = PythonNode::Module(vec![(*s1).clone(), (*s2).clone()]);
642 let code = module.to_code(0);
643 let depth = s1.depth().max(s2.depth());
644 results.push(GeneratedCode {
645 code,
646 language: Language::Python,
647 ast_depth: depth,
648 features: self.extract_features(s1),
649 });
650 }
651 }
652
653 results
654 }
655
656 fn extract_features(&self, node: &PythonNode) -> Vec<String> {
658 let mut features = Vec::new();
659
660 match node {
661 PythonNode::Assign { .. } => features.push("assignment".to_string()),
662 PythonNode::BinOp { op, .. } => {
663 features.push("binop".to_string());
664 features.push(format!("op_{}", op.to_str()));
665 }
666 PythonNode::If { orelse, .. } => {
667 features.push("if".to_string());
668 if !orelse.is_empty() {
669 features.push("else".to_string());
670 }
671 }
672 PythonNode::While { .. } => features.push("while".to_string()),
673 PythonNode::For { .. } => features.push("for".to_string()),
674 PythonNode::FuncDef { .. } => features.push("funcdef".to_string()),
675 PythonNode::Return(_) => features.push("return".to_string()),
676 PythonNode::Compare { op, .. } => {
677 features.push("compare".to_string());
678 features.push(format!("cmp_{}", op.to_str()));
679 }
680 _ => {}
681 }
682
683 features
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `left op right` over two integer literals.
    fn int_binop(left: i64, op: BinaryOp, right: i64) -> PythonNode {
        PythonNode::BinOp {
            left: Box::new(PythonNode::IntLit(left)),
            op,
            right: Box::new(PythonNode::IntLit(right)),
        }
    }

    /// Counts the leading spaces of `line`.
    fn leading_spaces(line: &str) -> usize {
        line.len() - line.trim_start_matches(' ').len()
    }

    /// Cheap structural checks on generated Python source. The program must
    /// be non-empty, contain no blank lines, and have balanced `()` / `[]`.
    /// Every line ending in `:` must be followed by a more deeply indented line.
    fn assert_well_formed(code: &str) {
        assert!(!code.is_empty(), "empty program");
        let lines: Vec<&str> = code.lines().collect();
        for (i, line) in lines.iter().enumerate() {
            assert!(!line.trim().is_empty(), "blank line in {code:?}");
            if line.ends_with(':') {
                let next = lines
                    .get(i + 1)
                    .expect("a block header must be followed by its suite");
                assert!(
                    leading_spaces(next) > leading_spaces(line),
                    "suite not indented in {code:?}"
                );
            }
        }
        for (open, close) in [('(', ')'), ('[', ']')] {
            let mut balance = 0_i64;
            for c in code.chars() {
                if c == open {
                    balance += 1;
                } else if c == close {
                    balance -= 1;
                }
                assert!(balance >= 0, "unbalanced {open}{close} in {code:?}");
            }
            assert_eq!(balance, 0, "unbalanced {open}{close} in {code:?}");
        }
    }

    #[test]
    fn test_int_lit_to_code() {
        assert_eq!(PythonNode::IntLit(42).to_code(0), "42");
    }

    #[test]
    fn test_assign_to_code() {
        let node = PythonNode::Assign {
            target: "x".to_string(),
            value: Box::new(PythonNode::IntLit(1)),
        };
        assert_eq!(node.to_code(0), "x = 1");
    }

    #[test]
    fn test_binop_to_code() {
        assert_eq!(int_binop(1, BinaryOp::Add, 2).to_code(0), "(1 + 2)");
    }

    #[test]
    fn test_if_to_code() {
        let node = PythonNode::If {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Pass],
            orelse: vec![],
        };
        assert_eq!(node.to_code(0), "if True:\n pass");
    }

    #[test]
    fn test_funcdef_to_code() {
        let node = PythonNode::FuncDef {
            name: "foo".to_string(),
            args: vec!["a".to_string(), "b".to_string()],
            body: vec![PythonNode::Return(Some(Box::new(PythonNode::Name(
                "a".to_string(),
            ))))],
        };
        assert_eq!(node.to_code(0), "def foo(a, b):\n return a");
    }

    #[test]
    fn test_depth_calculation() {
        assert_eq!(PythonNode::IntLit(1).depth(), 1);

        let nested = PythonNode::BinOp {
            left: Box::new(PythonNode::IntLit(1)),
            op: BinaryOp::Add,
            right: Box::new(int_binop(2, BinaryOp::Mult, 3)),
        };
        assert_eq!(nested.depth(), 3);
    }

    #[test]
    fn test_enumerator_expressions() {
        let enum_ = PythonEnumerator::new(2);
        let exprs = enum_.enumerate_expressions(1);
        assert!(!exprs.is_empty());
        assert!(exprs.iter().any(|e| matches!(e, PythonNode::IntLit(_))));
        assert!(exprs.iter().any(|e| matches!(e, PythonNode::Name(_))));
        assert!(exprs.iter().all(|e| e.depth() == 1));
    }

    #[test]
    fn test_enumerator_statements() {
        let enum_ = PythonEnumerator::new(2);
        let stmts = enum_.enumerate_statements(2);
        assert!(!stmts.is_empty());
        assert!(stmts.iter().any(|s| matches!(s, PythonNode::Pass)));
        assert!(stmts.iter().any(|s| matches!(s, PythonNode::Assign { .. })));
    }

    #[test]
    fn test_enumerator_programs() {
        let enum_ = PythonEnumerator::new(2);
        let programs = enum_.enumerate_programs();
        assert!(!programs.is_empty());
        for prog in &programs {
            assert!(!prog.code.is_empty());
            assert_eq!(prog.language, Language::Python);
        }
    }

    #[test]
    fn test_generated_code_is_valid_python() {
        let enum_ = PythonEnumerator::new(2);
        for prog in &enum_.enumerate_programs() {
            assert_well_formed(&prog.code);
        }
    }

    #[test]
    fn test_binary_op_all() {
        assert_eq!(BinaryOp::all().len(), 7);
    }

    #[test]
    fn test_binary_op_to_str_all() {
        assert_eq!(BinaryOp::Add.to_str(), "+");
        assert_eq!(BinaryOp::Sub.to_str(), "-");
        assert_eq!(BinaryOp::Mult.to_str(), "*");
        assert_eq!(BinaryOp::Div.to_str(), "/");
        assert_eq!(BinaryOp::Mod.to_str(), "%");
        assert_eq!(BinaryOp::FloorDiv.to_str(), "//");
        assert_eq!(BinaryOp::Pow.to_str(), "**");
        assert_eq!(BinaryOp::And.to_str(), "and");
        assert_eq!(BinaryOp::Or.to_str(), "or");
    }

    #[test]
    fn test_unary_op_all() {
        assert_eq!(UnaryOp::all().len(), 3);
    }

    #[test]
    fn test_unary_op_to_str_all() {
        assert_eq!(UnaryOp::Neg.to_str(), "-");
        assert_eq!(UnaryOp::Not.to_str(), "not ");
        assert_eq!(UnaryOp::Pos.to_str(), "+");
    }

    #[test]
    fn test_compare_op_all() {
        assert_eq!(CompareOp::all().len(), 6);
    }

    #[test]
    fn test_compare_op_to_str_all() {
        assert_eq!(CompareOp::Eq.to_str(), "==");
        assert_eq!(CompareOp::NotEq.to_str(), "!=");
        assert_eq!(CompareOp::Lt.to_str(), "<");
        assert_eq!(CompareOp::LtE.to_str(), "<=");
        assert_eq!(CompareOp::Gt.to_str(), ">");
        assert_eq!(CompareOp::GtE.to_str(), ">=");
    }

    #[test]
    fn test_float_lit_to_code() {
        // 2.5 is exactly representable, and unlike 3.14 it does not trip
        // clippy's deny-by-default `approx_constant` lint.
        assert_eq!(PythonNode::FloatLit(2.5).to_code(0), "2.5");
    }

    #[test]
    fn test_str_lit_to_code() {
        let node = PythonNode::StrLit("hello".to_string());
        assert_eq!(node.to_code(0), "\"hello\"");
    }

    #[test]
    fn test_bool_lit_to_code() {
        assert_eq!(PythonNode::BoolLit(true).to_code(0), "True");
        assert_eq!(PythonNode::BoolLit(false).to_code(0), "False");
    }

    #[test]
    fn test_none_lit_to_code() {
        assert_eq!(PythonNode::NoneLit.to_code(0), "None");
    }

    #[test]
    fn test_name_to_code() {
        assert_eq!(PythonNode::Name("x".to_string()).to_code(0), "x");
    }

    #[test]
    fn test_unary_op_to_code() {
        let node = PythonNode::UnaryOp {
            op: UnaryOp::Neg,
            operand: Box::new(PythonNode::IntLit(5)),
        };
        assert_eq!(node.to_code(0), "(-5)");
    }

    #[test]
    fn test_if_with_else_to_code() {
        let node = PythonNode::If {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Pass],
            orelse: vec![PythonNode::Pass],
        };
        let code = node.to_code(0);
        assert!(code.contains("if True:"));
        assert!(code.contains("else:"));
        assert_well_formed(&code);
    }

    #[test]
    fn test_while_to_code() {
        let node = PythonNode::While {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Break],
        };
        let code = node.to_code(0);
        assert!(code.contains("while True:"));
        assert!(code.contains("break"));
        assert_well_formed(&code);
    }

    #[test]
    fn test_for_to_code() {
        let node = PythonNode::For {
            target: "i".to_string(),
            iter: Box::new(PythonNode::List(vec![PythonNode::IntLit(1)])),
            body: vec![PythonNode::Continue],
        };
        let code = node.to_code(0);
        assert!(code.contains("for i in"));
        assert!(code.contains("continue"));
        assert_well_formed(&code);
    }

    #[test]
    fn test_call_to_code() {
        let node = PythonNode::Call {
            func: "print".to_string(),
            args: vec![PythonNode::IntLit(1), PythonNode::IntLit(2)],
        };
        assert_eq!(node.to_code(0), "print(1, 2)");
    }

    #[test]
    fn test_return_none_to_code() {
        assert_eq!(PythonNode::Return(None).to_code(0), "return");
    }

    #[test]
    fn test_break_to_code() {
        assert_eq!(PythonNode::Break.to_code(0), "break");
    }

    #[test]
    fn test_continue_to_code() {
        assert_eq!(PythonNode::Continue.to_code(0), "continue");
    }

    #[test]
    fn test_list_to_code() {
        let node = PythonNode::List(vec![
            PythonNode::IntLit(1),
            PythonNode::IntLit(2),
            PythonNode::IntLit(3),
        ]);
        assert_eq!(node.to_code(0), "[1, 2, 3]");
    }

    #[test]
    fn test_empty_list_to_code() {
        assert_eq!(PythonNode::List(vec![]).to_code(0), "[]");
    }

    #[test]
    fn test_compare_to_code() {
        let node = PythonNode::Compare {
            left: Box::new(PythonNode::IntLit(1)),
            op: CompareOp::Lt,
            right: Box::new(PythonNode::IntLit(2)),
        };
        assert_eq!(node.to_code(0), "(1 < 2)");
    }

    #[test]
    fn test_module_to_code() {
        let node = PythonNode::Module(vec![
            PythonNode::Assign {
                target: "x".to_string(),
                value: Box::new(PythonNode::IntLit(1)),
            },
            PythonNode::Pass,
        ]);
        assert_eq!(node.to_code(0), "x = 1\npass");
    }

    #[test]
    fn test_python_node_debug() {
        let node = PythonNode::IntLit(42);
        assert!(format!("{node:?}").contains("IntLit"));
    }

    #[test]
    fn test_python_node_clone() {
        let node = PythonNode::IntLit(42);
        let cloned = node.clone();
        assert_eq!(cloned, node);
    }

    #[test]
    fn test_binary_op_debug() {
        let op = BinaryOp::Add;
        assert!(format!("{op:?}").contains("Add"));
    }

    #[test]
    fn test_binary_op_clone() {
        // `BinaryOp` is `Copy`; a plain copy exercises the same contract as
        // `.clone()` without clippy's `clone_on_copy` warning.
        let op = BinaryOp::Add;
        let copied = op;
        assert_eq!(copied, op);
    }

    #[test]
    fn test_unary_op_debug() {
        let op = UnaryOp::Neg;
        assert!(format!("{op:?}").contains("Neg"));
    }

    #[test]
    fn test_compare_op_debug() {
        let op = CompareOp::Lt;
        assert!(format!("{op:?}").contains("Lt"));
    }

    #[test]
    fn test_extract_features_binop() {
        let enum_ = PythonEnumerator::new(2);
        let features = enum_.extract_features(&int_binop(1, BinaryOp::Add, 2));
        assert!(features.contains(&"binop".to_string()));
        assert!(features.contains(&"op_+".to_string()));
    }

    #[test]
    fn test_extract_features_if_with_else() {
        let enum_ = PythonEnumerator::new(2);
        let node = PythonNode::If {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Pass],
            orelse: vec![PythonNode::Pass],
        };
        let features = enum_.extract_features(&node);
        assert!(features.contains(&"if".to_string()));
        assert!(features.contains(&"else".to_string()));
    }

    #[test]
    fn test_extract_features_while() {
        let enum_ = PythonEnumerator::new(2);
        let node = PythonNode::While {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Pass],
        };
        assert!(enum_.extract_features(&node).contains(&"while".to_string()));
    }

    #[test]
    fn test_extract_features_for() {
        let enum_ = PythonEnumerator::new(2);
        let node = PythonNode::For {
            target: "i".to_string(),
            iter: Box::new(PythonNode::List(vec![])),
            body: vec![PythonNode::Pass],
        };
        assert!(enum_.extract_features(&node).contains(&"for".to_string()));
    }

    #[test]
    fn test_extract_features_compare() {
        let enum_ = PythonEnumerator::new(2);
        let node = PythonNode::Compare {
            left: Box::new(PythonNode::IntLit(1)),
            op: CompareOp::Lt,
            right: Box::new(PythonNode::IntLit(2)),
        };
        let features = enum_.extract_features(&node);
        assert!(features.contains(&"compare".to_string()));
        assert!(features.contains(&"cmp_<".to_string()));
    }

    #[test]
    fn test_depth_if() {
        let node = PythonNode::If {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Pass],
            orelse: vec![],
        };
        assert_eq!(node.depth(), 2);
    }

    #[test]
    fn test_depth_while() {
        let node = PythonNode::While {
            test: Box::new(PythonNode::BoolLit(true)),
            body: vec![PythonNode::Pass],
        };
        assert_eq!(node.depth(), 2);
    }

    #[test]
    fn test_depth_for() {
        let node = PythonNode::For {
            target: "i".to_string(),
            iter: Box::new(PythonNode::List(vec![])),
            body: vec![PythonNode::Pass],
        };
        assert_eq!(node.depth(), 2);
    }

    #[test]
    fn test_depth_funcdef() {
        let node = PythonNode::FuncDef {
            name: "f".to_string(),
            args: vec![],
            body: vec![PythonNode::Pass],
        };
        assert_eq!(node.depth(), 2);
    }
}
1131}