1use depyler_hir::hir::HirModule;
26use anyhow::Result;
27use proc_macro2::TokenStream;
28use std::hash::{Hash, Hasher};
29
30#[cfg(feature = "generative")]
31use entrenar::search::{Action, ActionSpace, MctsConfig, MctsSearch, Reward, State, StateSpace};
32
/// Tunable parameters for [`GenerativeRepair`]'s Monte Carlo tree search.
#[derive(Debug, Clone)]
pub struct GenerativeRepairConfig {
    /// Number of search iterations; forwarded to `MctsConfig::max_iterations`.
    pub max_iterations: usize,
    /// Exploration constant forwarded to `MctsConfig::exploration_constant`
    /// (defaults to √2).
    pub exploration_constant: f64,
    /// Maximum rollout depth; forwarded to `MctsConfig::max_simulation_depth`.
    pub max_simulation_depth: usize,
    /// Reserved flag. NOTE(review): not currently read by `synthesize`.
    pub use_discriminator: bool,
    /// Seed handed to `MctsSearch::with_seed`; `0` means "no explicit seed"
    /// and selects `MctsSearch::new` instead.
    pub seed: u64,
}

impl Default for GenerativeRepairConfig {
    /// 100 iterations, exploration constant √2, depth 50, discriminator off,
    /// and no explicit seed.
    fn default() -> Self {
        GenerativeRepairConfig {
            seed: 0,
            use_discriminator: false,
            max_simulation_depth: 50,
            exploration_constant: std::f64::consts::SQRT_2,
            max_iterations: 100,
        }
    }
}
59
/// A partial program under construction: the tokens emitted so far.
///
/// Equality and hashing look only at `tokens`. The constructors compute
/// `is_complete` from the tokens, so it carries no independent information.
#[derive(Debug, Clone)]
// `is_complete` is only read by the `generative`-gated `State` impl (and by
// tests), so without that feature it would be reported as unused.
#[allow(dead_code)]
pub struct CodeState {
    /// Tokens emitted so far, in order.
    tokens: Vec<String>,
    /// Whether the `"EOF"` sentinel appears in `tokens`.
    is_complete: bool,
}

impl CodeState {
    /// Sentinel token that marks a program as finished.
    const EOF_TOKEN: &'static str = "EOF";

    /// Builds a state from `tokens`.
    ///
    /// The state is complete if any token equals `"EOF"`.
    pub fn new(tokens: Vec<String>) -> Self {
        let is_complete = tokens.iter().any(|tok| tok.as_str() == Self::EOF_TOKEN);
        CodeState {
            tokens,
            is_complete,
        }
    }

    /// The empty starting state: no tokens and not complete.
    pub fn initial() -> Self {
        Self::new(Vec::new())
    }

    /// Borrows the tokens emitted so far.
    pub fn tokens(&self) -> &[String] {
        self.tokens.as_slice()
    }
}

impl PartialEq for CodeState {
    /// Two states are equal when they hold the same token sequence.
    fn eq(&self, other: &Self) -> bool {
        self.tokens() == other.tokens()
    }
}

impl Eq for CodeState {}

impl Hash for CodeState {
    /// Hashes only the token sequence, matching `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.tokens().hash(state);
    }
}
107
#[cfg(feature = "generative")]
impl State for CodeState {
    /// A state is terminal once the `"EOF"` sentinel has been emitted.
    fn is_terminal(&self) -> bool {
        let CodeState { is_complete, .. } = self;
        *is_complete
    }
}
114
/// One generation step: a named action that appends a single token.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CodeAction {
    /// Identifier of the action (e.g. `"add_fn"`).
    name: String,
    /// Token appended to the program when this action is applied.
    token: String,
}

impl CodeAction {
    /// Creates an action called `name` that emits `token`.
    pub fn new(name: impl Into<String>, token: impl Into<String>) -> Self {
        let name: String = name.into();
        let token: String = token.into();
        CodeAction { name, token }
    }
}
133
#[cfg(feature = "generative")]
impl Action for CodeAction {
    /// Returns the action's identifier (e.g. `"add_fn"`).
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
140
/// State space whose reward measures how many target patterns a program contains.
#[cfg(feature = "generative")]
pub struct CodeStateSpace {
    /// Patterns the generated code is rewarded for containing.
    target_patterns: Vec<String>,
}

#[cfg(feature = "generative")]
impl CodeStateSpace {
    /// Creates a state space that rewards the given `target_patterns`.
    pub fn new(target_patterns: Vec<String>) -> Self {
        CodeStateSpace { target_patterns }
    }
}
155
#[cfg(feature = "generative")]
impl StateSpace<CodeState, CodeAction> for CodeStateSpace {
    /// Returns a new state: `state`'s tokens followed by `action`'s token.
    fn apply(&self, state: &CodeState, action: &CodeAction) -> CodeState {
        let mut new_tokens = state.tokens.clone();
        new_tokens.push(action.token.clone());
        CodeState::new(new_tokens)
    }

    /// Scores `state` as the fraction of target patterns it contains.
    ///
    /// Each pattern is split on whitespace into tokens. It counts as present
    /// only when that exact token sequence appears contiguously in the state.
    /// Whole tokens are compared, not raw substrings of the joined text, so a
    /// short pattern (e.g. a parameter named `i`) is not credited by unrelated
    /// tokens such as `if` or `i32`.
    ///
    /// Returns a neutral `0.5` when there are no target patterns.
    fn evaluate(&self, state: &CodeState) -> Reward {
        if self.target_patterns.is_empty() {
            return 0.5;
        }

        let tokens = state.tokens();
        let matches = self
            .target_patterns
            .iter()
            .filter(|pattern| {
                let needle: Vec<&str> = pattern.split_whitespace().collect();
                // An empty pattern trivially matches, as `str::contains("")` did.
                needle.is_empty()
                    || tokens.windows(needle.len()).any(|window| {
                        window
                            .iter()
                            .map(String::as_str)
                            .eq(needle.iter().copied())
                    })
            })
            .count();

        matches as f64 / self.target_patterns.len() as f64
    }

    /// Boxes an independent copy of this state space.
    fn clone_space(&self) -> Box<dyn StateSpace<CodeState, CodeAction> + Send + Sync> {
        Box::new(Self {
            target_patterns: self.target_patterns.clone(),
        })
    }
}
186
/// The fixed vocabulary of token-appending actions the search may take.
#[cfg(feature = "generative")]
pub struct CodeActionSpace {
    /// Every action offered in a non-terminal state, in a fixed order.
    available_actions: Vec<CodeAction>,
}

#[cfg(feature = "generative")]
impl CodeActionSpace {
    /// `(action name, emitted token)` pairs, in the order actions are offered.
    /// The final `"complete"` action emits the `"EOF"` sentinel.
    const VOCABULARY: &'static [(&'static str, &'static str)] = &[
        ("add_fn", "fn"),
        ("add_let", "let"),
        ("add_return", "return"),
        ("add_if", "if"),
        ("add_else", "else"),
        ("add_for", "for"),
        ("add_while", "while"),
        ("add_match", "match"),
        ("add_struct", "struct"),
        ("add_impl", "impl"),
        ("add_pub", "pub"),
        ("add_mut", "mut"),
        ("add_ref", "&"),
        ("add_semicolon", ";"),
        ("add_brace_open", "{"),
        ("add_brace_close", "}"),
        ("add_paren_open", "("),
        ("add_paren_close", ")"),
        ("add_arrow", "->"),
        ("add_i32", "i32"),
        ("add_bool", "bool"),
        ("add_string", "String"),
        ("complete", "EOF"),
    ];

    /// Builds the action space from the built-in vocabulary.
    pub fn new() -> Self {
        let available_actions = Self::VOCABULARY
            .iter()
            .map(|&(name, token)| CodeAction::new(name, token))
            .collect();
        CodeActionSpace { available_actions }
    }
}

#[cfg(feature = "generative")]
impl Default for CodeActionSpace {
    /// Same as [`CodeActionSpace::new`].
    fn default() -> Self {
        CodeActionSpace::new()
    }
}
234
#[cfg(feature = "generative")]
impl ActionSpace<CodeState, CodeAction> for CodeActionSpace {
    /// Every action is legal until the state is terminal; then none are.
    fn legal_actions(&self, state: &CodeState) -> Vec<CodeAction> {
        if !state.is_terminal() {
            return self.available_actions.to_vec();
        }
        Vec::new()
    }
}
245
246pub struct GenerativeRepair {
248 config: GenerativeRepairConfig,
249}
250
251impl GenerativeRepair {
252 pub fn new() -> Self {
254 Self {
255 config: GenerativeRepairConfig::default(),
256 }
257 }
258
259 pub fn with_config(config: GenerativeRepairConfig) -> Self {
261 Self { config }
262 }
263
264 #[cfg(feature = "generative")]
274 pub fn synthesize(&self, hir: &HirModule) -> Result<TokenStream> {
275 let target_patterns = self.extract_target_patterns(hir);
277
278 let mcts_config = MctsConfig {
280 max_iterations: self.config.max_iterations,
281 exploration_constant: self.config.exploration_constant,
282 max_simulation_depth: self.config.max_simulation_depth,
283 ..Default::default()
284 };
285
286 let initial_state = CodeState::initial();
287 let action_space = CodeActionSpace::new();
288 let state_space = CodeStateSpace::new(target_patterns);
289
290 let mut mcts = if self.config.seed > 0 {
292 MctsSearch::with_seed(initial_state, &action_space, mcts_config, self.config.seed)
293 } else {
294 MctsSearch::new(initial_state, &action_space, mcts_config)
295 };
296
297 let result = mcts.search(&state_space, &action_space, None);
298
299 if let Some(state) = result.resulting_state {
301 self.tokens_to_stream(&state)
302 } else {
303 Ok(TokenStream::new())
304 }
305 }
306
307 #[cfg(not(feature = "generative"))]
309 pub fn synthesize(&self, _hir: &HirModule) -> Result<TokenStream> {
310 Ok(TokenStream::new())
312 }
313
314 #[cfg(feature = "generative")]
316 fn extract_target_patterns(&self, hir: &HirModule) -> Vec<String> {
317 let mut patterns = Vec::new();
318
319 for func in &hir.functions {
321 patterns.push(format!("fn {}", func.name));
322
323 for param in &func.params {
325 patterns.push(param.name.clone());
326 }
327
328 if !matches!(
330 func.ret_type,
331 depyler_hir::hir::Type::Unknown | depyler_hir::hir::Type::None
332 ) {
333 patterns.push("->".to_string());
334 }
335 }
336
337 for class in &hir.classes {
339 patterns.push(format!("struct {}", class.name));
340 }
341
342 patterns
343 }
344
345 #[cfg(feature = "generative")]
347 fn tokens_to_stream(&self, state: &CodeState) -> Result<TokenStream> {
348 let code = state
349 .tokens()
350 .iter()
351 .filter(|t| *t != "EOF")
352 .cloned()
353 .collect::<Vec<_>>()
354 .join(" ");
355
356 match code.parse::<TokenStream>() {
358 Ok(ts) => Ok(ts),
359 Err(_) => {
360 Ok(TokenStream::new())
362 }
363 }
364 }
365
366 pub fn config(&self) -> &GenerativeRepairConfig {
368 &self.config
369 }
370}
371
372impl Default for GenerativeRepair {
373 fn default() -> Self {
374 Self::new()
375 }
376}
377
/// Summary of a synthesis run.
#[derive(Debug, Clone)]
pub struct SynthesisResult {
    /// Whether synthesis produced usable code.
    pub success: bool,
    /// The generated source, if any.
    pub code: Option<String>,
    /// Number of search iterations performed.
    pub iterations: usize,
    /// Reward reported for the chosen result.
    pub expected_reward: f64,
}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `HirModule` whose collections are all empty.
    fn create_empty_hir() -> HirModule {
        HirModule {
            functions: vec![],
            imports: vec![],
            type_aliases: vec![],
            protocols: vec![],
            classes: vec![],
            constants: vec![],
            top_level_stmts: vec![],
        }
    }

    #[test]
    fn test_generative_synthesis_stub() {
        let repair = GenerativeRepair::new();
        let hir = create_empty_hir();

        let result = repair.synthesize(&hir);

        assert!(result.is_ok(), "synthesize should return Ok for empty HIR");
    }

    #[test]
    fn test_generative_repair_config_default() {
        let config = GenerativeRepairConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert!(config.exploration_constant > 0.0);
        assert_eq!(config.max_simulation_depth, 50);
        assert!(!config.use_discriminator);
        assert_eq!(config.seed, 0);
    }

    #[test]
    fn test_generative_repair_with_config() {
        let config = GenerativeRepairConfig {
            max_iterations: 500,
            exploration_constant: 2.0,
            max_simulation_depth: 100,
            use_discriminator: true,
            seed: 42,
        };

        let repair = GenerativeRepair::with_config(config);
        assert_eq!(repair.config().max_iterations, 500);
        assert!(repair.config().use_discriminator);
        assert_eq!(repair.config().seed, 42);
    }

    #[test]
    fn test_code_state_creation() {
        let state = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        assert_eq!(state.tokens().len(), 2);
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_state_terminal() {
        let state = CodeState::new(vec!["fn".to_string(), "EOF".to_string()]);
        assert!(state.is_complete);
    }

    #[test]
    fn test_code_action_creation() {
        let action = CodeAction::new("add_fn", "fn");
        assert_eq!(action.name, "add_fn");
        assert_eq!(action.token, "fn");
    }

    // `SynthesisResult` has no `Default` impl; this checks plain field access.
    #[test]
    fn test_synthesis_result_fields() {
        let result = SynthesisResult {
            success: true,
            code: Some("fn test() {}".to_string()),
            iterations: 100,
            expected_reward: 0.95,
        };
        assert!(result.success);
        assert!(result.code.is_some());
        assert_eq!(result.iterations, 100);
    }

    #[test]
    fn test_code_state_initial() {
        let state = CodeState::initial();
        assert!(state.tokens().is_empty());
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_state_tokens_accessor() {
        let state = CodeState::new(vec!["let".to_string(), "x".to_string(), "=".to_string()]);
        let tokens = state.tokens();
        assert_eq!(tokens.len(), 3);
        assert_eq!(tokens[0], "let");
        assert_eq!(tokens[1], "x");
        assert_eq!(tokens[2], "=");
    }

    #[test]
    fn test_code_state_partial_eq() {
        let state1 = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        let state2 = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        let state3 = CodeState::new(vec!["fn".to_string(), "other".to_string()]);

        assert_eq!(state1, state2);
        assert_ne!(state1, state3);
    }

    #[test]
    fn test_code_state_hash() {
        use std::collections::HashSet;

        let state1 = CodeState::new(vec!["fn".to_string()]);
        let state2 = CodeState::new(vec!["fn".to_string()]);
        let state3 = CodeState::new(vec!["let".to_string()]);

        // Insert the states themselves so the set goes through CodeState's own
        // Hash/Eq impls; hashing `tokens` directly would bypass them.
        let mut set = HashSet::new();
        set.insert(state1);
        set.insert(state2);
        set.insert(state3);

        assert_eq!(set.len(), 2);
        assert!(set.contains(&CodeState::new(vec!["fn".to_string()])));
    }

    #[test]
    fn test_code_action_partial_eq() {
        let action1 = CodeAction::new("add_fn", "fn");
        let action2 = CodeAction::new("add_fn", "fn");
        let action3 = CodeAction::new("add_let", "let");

        assert_eq!(action1, action2);
        assert_ne!(action1, action3);
    }

    #[test]
    fn test_code_action_hash() {
        use std::collections::HashSet;

        let action1 = CodeAction::new("add_fn", "fn");
        let action2 = CodeAction::new("add_fn", "fn");
        let action3 = CodeAction::new("add_let", "let");

        let mut set = HashSet::new();
        set.insert(action1);
        set.insert(action2);
        set.insert(action3);

        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_code_action_debug() {
        let action = CodeAction::new("add_struct", "struct");
        let debug_str = format!("{:?}", action);
        assert!(debug_str.contains("CodeAction"));
        assert!(debug_str.contains("add_struct"));
        assert!(debug_str.contains("struct"));
    }

    #[test]
    fn test_code_action_clone() {
        let action = CodeAction::new("add_impl", "impl");
        let cloned = action.clone();
        assert_eq!(action, cloned);
        assert_eq!(cloned.name, "add_impl");
        assert_eq!(cloned.token, "impl");
    }

    #[test]
    fn test_generative_repair_default() {
        let repair: GenerativeRepair = Default::default();
        assert_eq!(repair.config().max_iterations, 100);
        assert!(!repair.config().use_discriminator);
    }

    #[test]
    fn test_generative_repair_config_debug() {
        let config = GenerativeRepairConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("GenerativeRepairConfig"));
        assert!(debug_str.contains("max_iterations"));
        assert!(debug_str.contains("exploration_constant"));
    }

    #[test]
    fn test_generative_repair_config_clone() {
        let config = GenerativeRepairConfig {
            max_iterations: 200,
            exploration_constant: 1.5,
            max_simulation_depth: 75,
            use_discriminator: true,
            seed: 123,
        };
        let cloned = config.clone();
        assert_eq!(cloned.max_iterations, 200);
        assert_eq!(cloned.exploration_constant, 1.5);
        assert_eq!(cloned.max_simulation_depth, 75);
        assert!(cloned.use_discriminator);
        assert_eq!(cloned.seed, 123);
    }

    #[test]
    fn test_synthesis_result_debug() {
        let result = SynthesisResult {
            success: false,
            code: None,
            iterations: 50,
            expected_reward: 0.25,
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("SynthesisResult"));
        assert!(debug_str.contains("success"));
        assert!(debug_str.contains("false"));
    }

    #[test]
    fn test_synthesis_result_clone() {
        let result = SynthesisResult {
            success: true,
            code: Some("pub fn foo() -> i32 { 42 }".to_string()),
            iterations: 150,
            expected_reward: 0.85,
        };
        let cloned = result.clone();
        assert!(cloned.success);
        assert_eq!(cloned.code, Some("pub fn foo() -> i32 { 42 }".to_string()));
        assert_eq!(cloned.iterations, 150);
        assert_eq!(cloned.expected_reward, 0.85);
    }

    #[test]
    fn test_synthesis_result_no_code() {
        let result = SynthesisResult {
            success: false,
            code: None,
            iterations: 0,
            expected_reward: 0.0,
        };
        assert!(!result.success);
        assert!(result.code.is_none());
        assert_eq!(result.iterations, 0);
        assert_eq!(result.expected_reward, 0.0);
    }

    #[test]
    fn test_code_state_complete_with_eof() {
        let state = CodeState::new(vec![
            "fn".to_string(),
            "main".to_string(),
            "(".to_string(),
            ")".to_string(),
            "{".to_string(),
            "}".to_string(),
            "EOF".to_string(),
        ]);
        assert!(state.is_complete);
        assert_eq!(state.tokens().len(), 7);
    }

    #[test]
    fn test_code_state_not_complete_without_eof() {
        let state = CodeState::new(vec![
            "fn".to_string(),
            "main".to_string(),
            "(".to_string(),
            ")".to_string(),
        ]);
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_action_with_special_characters() {
        let action1 = CodeAction::new("add_arrow", "->");
        assert_eq!(action1.token, "->");

        let action2 = CodeAction::new("add_ref", "&");
        assert_eq!(action2.token, "&");

        let action3 = CodeAction::new("add_semicolon", ";");
        assert_eq!(action3.token, ";");
    }

    #[test]
    fn test_generative_repair_config_exploration_constant() {
        // The default is √2 ≈ 1.4142.
        let config = GenerativeRepairConfig::default();
        assert!(config.exploration_constant > 1.4);
        assert!(config.exploration_constant < 1.5);
    }

    #[test]
    fn test_generative_repair_config_custom_seed() {
        let config = GenerativeRepairConfig {
            seed: 12345,
            ..Default::default()
        };
        let repair = GenerativeRepair::with_config(config);
        assert_eq!(repair.config().seed, 12345);
    }

    #[test]
    fn test_generative_repair_config_method() {
        let repair = GenerativeRepair::new();
        let config = repair.config();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.max_simulation_depth, 50);
    }

    #[test]
    fn test_code_state_debug() {
        let state = CodeState::new(vec!["let".to_string(), "mut".to_string()]);
        let debug_str = format!("{:?}", state);
        assert!(debug_str.contains("CodeState"));
        assert!(debug_str.contains("tokens"));
    }

    #[test]
    fn test_code_state_clone() {
        let state = CodeState::new(vec!["struct".to_string(), "Point".to_string()]);
        let cloned = state.clone();
        assert_eq!(state, cloned);
        assert_eq!(cloned.tokens().len(), 2);
    }

    #[cfg(feature = "generative")]
    mod generative_tests {
        use super::*;

        #[test]
        fn test_code_action_space_default() {
            let action_space = CodeActionSpace::new();
            let state = CodeState::initial();
            let actions = action_space.legal_actions(&state);

            assert!(!actions.is_empty());

            let action_names: Vec<_> = actions.iter().map(|a| a.name.as_str()).collect();
            assert!(action_names.contains(&"add_fn"));
            assert!(action_names.contains(&"add_let"));
            assert!(action_names.contains(&"complete"));
        }

        #[test]
        fn test_code_state_space_evaluate() {
            let state_space = CodeStateSpace::new(vec!["fn".to_string(), "test".to_string()]);

            let empty = CodeState::initial();
            let reward_empty = state_space.evaluate(&empty);
            assert_eq!(reward_empty, 0.0);

            let partial = CodeState::new(vec!["fn".to_string()]);
            let reward_partial = state_space.evaluate(&partial);
            assert!(reward_partial > 0.0);
            assert!(reward_partial < 1.0);

            let full = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
            let reward_full = state_space.evaluate(&full);
            assert_eq!(reward_full, 1.0);
        }

        #[test]
        fn test_mcts_integration() {
            let config = GenerativeRepairConfig {
                max_iterations: 10,
                seed: 42,
                ..Default::default()
            };

            let repair = GenerativeRepair::with_config(config);
            let hir = create_empty_hir();

            let result = repair.synthesize(&hir);
            assert!(result.is_ok());
        }
    }
}