Skip to main content

depyler_tooling/
generative_repair.rs

//! Generative Code Repair Engine
//!
//! Integrates entrenar's MCTS (Monte Carlo Tree Search) and GAN capabilities
//! for generative code synthesis and repair.
//!
//! # Overview
//!
//! The generative repair engine uses:
//! - **MCTS Search**: For exploring the space of possible AST transformations
//! - **GAN Discriminator**: For validating generated Rust code (future)
//!
//! # Architecture
//!
//! ```text
//! HIR → CodeState → MCTS Search → Best Action → RustAst
//!                       ↓
//!              GAN Discriminator (validation)
//! ```
//!
//! # Feature Flag
//!
//! MCTS functionality is only compiled when the `generative` feature is enabled.
//! Without the feature, `synthesize` is a stub that returns an empty token stream.
24
25use depyler_hir::hir::HirModule;
26use anyhow::Result;
27use proc_macro2::TokenStream;
28use std::hash::{Hash, Hasher};
29
30#[cfg(feature = "generative")]
31use entrenar::search::{Action, ActionSpace, MctsConfig, MctsSearch, Reward, State, StateSpace};
32
/// Tunable settings for the generative repair engine.
#[derive(Clone, Debug)]
pub struct GenerativeRepairConfig {
    /// Upper bound on the number of MCTS iterations.
    pub max_iterations: usize,
    /// Exploration constant used by UCB1.
    pub exploration_constant: f64,
    /// Upper bound on the depth of a simulation.
    pub max_simulation_depth: usize,
    /// Whether the GAN discriminator should be used for validation.
    pub use_discriminator: bool,
    /// Random seed for reproducibility (0 = random).
    pub seed: u64,
}

impl Default for GenerativeRepairConfig {
    /// Returns the baseline settings: 100 iterations, `SQRT_2` exploration,
    /// simulation depth 50, discriminator off, and seed 0.
    fn default() -> Self {
        GenerativeRepairConfig {
            seed: 0,
            use_discriminator: false,
            max_simulation_depth: 50,
            max_iterations: 100,
            exploration_constant: std::f64::consts::SQRT_2,
        }
    }
}
59
/// A partially generated program, stored as a flat list of tokens.
#[derive(Clone, Debug)]
#[allow(dead_code)] // is_complete is used only when "generative" feature is enabled
pub struct CodeState {
    /// Tokens produced so far, in order.
    tokens: Vec<String>,
    /// Set when the token list contains the `"EOF"` marker (terminal state).
    is_complete: bool,
}

impl CodeState {
    /// Builds a state from `tokens`.
    ///
    /// The state is marked complete when any token equals `"EOF"`.
    pub fn new(tokens: Vec<String>) -> Self {
        let is_complete = tokens.iter().map(String::as_str).any(|tok| tok == "EOF");
        CodeState {
            tokens,
            is_complete,
        }
    }

    /// Returns the starting state: no tokens and not complete.
    pub fn initial() -> Self {
        CodeState {
            tokens: Vec::new(),
            is_complete: false,
        }
    }

    /// Borrows the tokens held by this state.
    pub fn tokens(&self) -> &[String] {
        self.tokens.as_slice()
    }
}

// Equality and hashing both look at `tokens` only, so the two stay consistent.
impl PartialEq for CodeState {
    fn eq(&self, other: &Self) -> bool {
        other.tokens == self.tokens
    }
}

impl Eq for CodeState {}

impl Hash for CodeState {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.tokens, state);
    }
}
107
#[cfg(feature = "generative")]
impl State for CodeState {
    /// Reports whether this state is terminal, i.e. its `is_complete` flag is set.
    fn is_terminal(&self) -> bool {
        let Self { is_complete, .. } = self;
        *is_complete
    }
}
114
/// A single transformation step: a named action carrying one token.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct CodeAction {
    /// Identifier of the transformation (e.g. `"add_fn"`).
    name: String,
    /// Token that the transformation adds or modifies.
    token: String,
}

impl CodeAction {
    /// Builds an action from anything convertible into owned strings.
    pub fn new(name: impl Into<String>, token: impl Into<String>) -> Self {
        let (name, token) = (name.into(), token.into());
        CodeAction { name, token }
    }
}
133
#[cfg(feature = "generative")]
impl Action for CodeAction {
    /// Returns the transformation name given at construction.
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
140
/// State space used when searching over generated code.
#[cfg(feature = "generative")]
pub struct CodeStateSpace {
    /// Patterns a state is scored against when computing its reward.
    target_patterns: Vec<String>,
}

#[cfg(feature = "generative")]
impl CodeStateSpace {
    /// Builds a state space that rewards states matching `target_patterns`.
    pub fn new(target_patterns: Vec<String>) -> Self {
        CodeStateSpace { target_patterns }
    }
}
155
#[cfg(feature = "generative")]
impl StateSpace<CodeState, CodeAction> for CodeStateSpace {
    /// Returns the successor of `state` obtained by appending the action's token.
    ///
    /// The input state is left untouched; completeness of the new state is
    /// recomputed by `CodeState::new`.
    fn apply(&self, state: &CodeState, action: &CodeAction) -> CodeState {
        // Reserve room for the extra token up front so the push does not
        // trigger a second allocation right after the copy.
        let mut new_tokens = Vec::with_capacity(state.tokens.len() + 1);
        new_tokens.extend_from_slice(&state.tokens);
        new_tokens.push(action.token.clone());
        CodeState::new(new_tokens)
    }

    /// Scores `state` by the fraction of target patterns it contains.
    ///
    /// The tokens are joined with single spaces and each pattern is looked up
    /// as a substring of that text. The result is `matches / total`, so it is
    /// 1.0 only when every pattern is present and 0.0 when none is. With no
    /// target patterns at all the reward is a neutral 0.5.
    fn evaluate(&self, state: &CodeState) -> Reward {
        if self.target_patterns.is_empty() {
            // Nothing to match against: skip building the joined text.
            return 0.5;
        }

        let tokens_str = state.tokens.join(" ");
        let matches = self
            .target_patterns
            .iter()
            .filter(|pattern| tokens_str.contains(pattern.as_str()))
            .count();

        matches as f64 / self.target_patterns.len() as f64
    }

    /// Returns a boxed copy of this state space with its own pattern list.
    fn clone_space(&self) -> Box<dyn StateSpace<CodeState, CodeAction> + Send + Sync> {
        Box::new(Self {
            target_patterns: self.target_patterns.clone(),
        })
    }
}
186
/// Action space used when searching over generated code.
#[cfg(feature = "generative")]
pub struct CodeActionSpace {
    /// Actions offered from every non-terminal state.
    available_actions: Vec<CodeAction>,
}

#[cfg(feature = "generative")]
impl CodeActionSpace {
    /// Builds the action space pre-populated with the default Rust tokens.
    ///
    /// The final entry, `("complete", "EOF")`, is the action that appends the
    /// `EOF` marker.
    pub fn new() -> Self {
        // (action name, token) pairs, in the order they are offered.
        const DEFAULT_ACTIONS: &[(&str, &str)] = &[
            ("add_fn", "fn"),
            ("add_let", "let"),
            ("add_return", "return"),
            ("add_if", "if"),
            ("add_else", "else"),
            ("add_for", "for"),
            ("add_while", "while"),
            ("add_match", "match"),
            ("add_struct", "struct"),
            ("add_impl", "impl"),
            ("add_pub", "pub"),
            ("add_mut", "mut"),
            ("add_ref", "&"),
            ("add_semicolon", ";"),
            ("add_brace_open", "{"),
            ("add_brace_close", "}"),
            ("add_paren_open", "("),
            ("add_paren_close", ")"),
            ("add_arrow", "->"),
            ("add_i32", "i32"),
            ("add_bool", "bool"),
            ("add_string", "String"),
            ("complete", "EOF"),
        ];

        let available_actions = DEFAULT_ACTIONS
            .iter()
            .map(|&(name, token)| CodeAction::new(name, token))
            .collect();

        CodeActionSpace { available_actions }
    }
}

#[cfg(feature = "generative")]
impl Default for CodeActionSpace {
    /// Same as [`CodeActionSpace::new`].
    fn default() -> Self {
        CodeActionSpace::new()
    }
}
234
#[cfg(feature = "generative")]
impl ActionSpace<CodeState, CodeAction> for CodeActionSpace {
    /// Lists the actions allowed from `state`: none once the state is
    /// terminal, otherwise a copy of every available action.
    fn legal_actions(&self, state: &CodeState) -> Vec<CodeAction> {
        if state.is_terminal() {
            return Vec::new();
        }
        self.available_actions.clone()
    }
}
245
246/// Generative repair engine for synthesizing Rust code from HIR
247pub struct GenerativeRepair {
248    config: GenerativeRepairConfig,
249}
250
251impl GenerativeRepair {
252    /// Create a new generative repair engine with default config
253    pub fn new() -> Self {
254        Self {
255            config: GenerativeRepairConfig::default(),
256        }
257    }
258
259    /// Create a new generative repair engine with custom config
260    pub fn with_config(config: GenerativeRepairConfig) -> Self {
261        Self { config }
262    }
263
264    /// Synthesize Rust code from HIR using MCTS-guided search
265    ///
266    /// # Arguments
267    ///
268    /// * `hir` - The High-level Intermediate Representation to synthesize from
269    ///
270    /// # Returns
271    ///
272    /// Returns the synthesized Rust code as a TokenStream
273    #[cfg(feature = "generative")]
274    pub fn synthesize(&self, hir: &HirModule) -> Result<TokenStream> {
275        // Extract target patterns from HIR
276        let target_patterns = self.extract_target_patterns(hir);
277
278        // Create MCTS components
279        let mcts_config = MctsConfig {
280            max_iterations: self.config.max_iterations,
281            exploration_constant: self.config.exploration_constant,
282            max_simulation_depth: self.config.max_simulation_depth,
283            ..Default::default()
284        };
285
286        let initial_state = CodeState::initial();
287        let action_space = CodeActionSpace::new();
288        let state_space = CodeStateSpace::new(target_patterns);
289
290        // Run MCTS search
291        let mut mcts = if self.config.seed > 0 {
292            MctsSearch::with_seed(initial_state, &action_space, mcts_config, self.config.seed)
293        } else {
294            MctsSearch::new(initial_state, &action_space, mcts_config)
295        };
296
297        let result = mcts.search(&state_space, &action_space, None);
298
299        // Convert resulting state to TokenStream
300        if let Some(state) = result.resulting_state {
301            self.tokens_to_stream(&state)
302        } else {
303            Ok(TokenStream::new())
304        }
305    }
306
307    /// Stub implementation when generative feature is disabled
308    #[cfg(not(feature = "generative"))]
309    pub fn synthesize(&self, _hir: &HirModule) -> Result<TokenStream> {
310        // Stub implementation - requires "generative" feature for MCTS
311        Ok(TokenStream::new())
312    }
313
314    /// Extract target patterns from HIR for guiding MCTS search
315    #[cfg(feature = "generative")]
316    fn extract_target_patterns(&self, hir: &HirModule) -> Vec<String> {
317        let mut patterns = Vec::new();
318
319        // Add function names as targets
320        for func in &hir.functions {
321            patterns.push(format!("fn {}", func.name));
322
323            // Add parameter patterns
324            for param in &func.params {
325                patterns.push(param.name.clone());
326            }
327
328            // Add return type pattern if present
329            if !matches!(
330                func.ret_type,
331                depyler_hir::hir::Type::Unknown | depyler_hir::hir::Type::None
332            ) {
333                patterns.push("->".to_string());
334            }
335        }
336
337        // Add struct names
338        for class in &hir.classes {
339            patterns.push(format!("struct {}", class.name));
340        }
341
342        patterns
343    }
344
345    /// Convert code state tokens to TokenStream
346    #[cfg(feature = "generative")]
347    fn tokens_to_stream(&self, state: &CodeState) -> Result<TokenStream> {
348        let code = state
349            .tokens()
350            .iter()
351            .filter(|t| *t != "EOF")
352            .cloned()
353            .collect::<Vec<_>>()
354            .join(" ");
355
356        // Try to parse as Rust code
357        match code.parse::<TokenStream>() {
358            Ok(ts) => Ok(ts),
359            Err(_) => {
360                // Return empty if parsing fails
361                Ok(TokenStream::new())
362            }
363        }
364    }
365
366    /// Get the current configuration
367    pub fn config(&self) -> &GenerativeRepairConfig {
368        &self.config
369    }
370}
371
372impl Default for GenerativeRepair {
373    fn default() -> Self {
374        Self::new()
375    }
376}
377
/// Outcome of one generative synthesis run.
#[derive(Clone, Debug)]
pub struct SynthesisResult {
    /// True when synthesis succeeded.
    pub success: bool,
    /// The generated code, present when synthesis succeeded.
    pub code: Option<String>,
    /// How many MCTS iterations were performed.
    pub iterations: usize,
    /// Expected reward of the best path.
    pub expected_reward: f64,
}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HirModule` with every collection empty.
    fn create_empty_hir() -> HirModule {
        HirModule {
            functions: vec![],
            imports: vec![],
            type_aliases: vec![],
            protocols: vec![],
            classes: vec![],
            constants: vec![],
            top_level_stmts: vec![],
        }
    }

    #[test]
    fn test_generative_synthesis_stub() {
        // Smoke test of the public API: runs against whichever `synthesize`
        // variant is compiled in (stub or MCTS-backed).
        let repair = GenerativeRepair::new();
        let hir = create_empty_hir();

        let result = repair.synthesize(&hir);

        // Only checks that the call completes and reports success.
        assert!(result.is_ok(), "synthesize should return Ok for empty HIR");
    }

    #[test]
    fn test_generative_repair_config_default() {
        let config = GenerativeRepairConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert!(config.exploration_constant > 0.0);
        assert_eq!(config.max_simulation_depth, 50);
        assert!(!config.use_discriminator);
        assert_eq!(config.seed, 0);
    }

    #[test]
    fn test_generative_repair_with_config() {
        let config = GenerativeRepairConfig {
            max_iterations: 500,
            exploration_constant: 2.0,
            max_simulation_depth: 100,
            use_discriminator: true,
            seed: 42,
        };

        let repair = GenerativeRepair::with_config(config);
        assert_eq!(repair.config().max_iterations, 500);
        assert!(repair.config().use_discriminator);
        assert_eq!(repair.config().seed, 42);
    }

    #[test]
    fn test_code_state_creation() {
        let state = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        assert_eq!(state.tokens().len(), 2);
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_state_terminal() {
        let state = CodeState::new(vec!["fn".to_string(), "EOF".to_string()]);
        assert!(state.is_complete);
    }

    #[test]
    fn test_code_action_creation() {
        let action = CodeAction::new("add_fn", "fn");
        assert_eq!(action.name, "add_fn");
        assert_eq!(action.token, "fn");
    }

    #[test]
    fn test_synthesis_result_default() {
        let result = SynthesisResult {
            success: true,
            code: Some("fn test() {}".to_string()),
            iterations: 100,
            expected_reward: 0.95,
        };
        assert!(result.success);
        assert!(result.code.is_some());
        assert_eq!(result.iterations, 100);
    }

    // DEPYLER-COVERAGE-95: Additional tests for untested components

    #[test]
    fn test_code_state_initial() {
        let state = CodeState::initial();
        assert!(state.tokens().is_empty());
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_state_tokens_accessor() {
        let state = CodeState::new(vec!["let".to_string(), "x".to_string(), "=".to_string()]);
        let tokens = state.tokens();
        assert_eq!(tokens.len(), 3);
        assert_eq!(tokens[0], "let");
        assert_eq!(tokens[1], "x");
        assert_eq!(tokens[2], "=");
    }

    #[test]
    fn test_code_state_partial_eq() {
        let state1 = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        let state2 = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        let state3 = CodeState::new(vec!["fn".to_string(), "other".to_string()]);

        assert_eq!(state1, state2);
        assert_ne!(state1, state3);
    }

    #[test]
    fn test_code_state_hash() {
        use std::collections::HashSet;

        let state1 = CodeState::new(vec!["fn".to_string()]);
        let state2 = CodeState::new(vec!["fn".to_string()]);
        let state3 = CodeState::new(vec!["let".to_string()]);

        // Insert the states themselves (not their token vectors) so that the
        // hand-written `Hash` and `Eq` impls of `CodeState` are what is tested.
        let mut set = HashSet::new();
        set.insert(state1);
        set.insert(state2);
        set.insert(state3);

        // state1 and state2 are equal and must collapse into one entry
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_code_action_partial_eq() {
        let action1 = CodeAction::new("add_fn", "fn");
        let action2 = CodeAction::new("add_fn", "fn");
        let action3 = CodeAction::new("add_let", "let");

        assert_eq!(action1, action2);
        assert_ne!(action1, action3);
    }

    #[test]
    fn test_code_action_hash() {
        use std::collections::HashSet;

        let action1 = CodeAction::new("add_fn", "fn");
        let action2 = CodeAction::new("add_fn", "fn");
        let action3 = CodeAction::new("add_let", "let");

        let mut set = HashSet::new();
        set.insert(action1);
        set.insert(action2);
        set.insert(action3);

        // action1 and action2 should be the same
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_code_action_debug() {
        let action = CodeAction::new("add_struct", "struct");
        let debug_str = format!("{:?}", action);
        assert!(debug_str.contains("CodeAction"));
        assert!(debug_str.contains("add_struct"));
        assert!(debug_str.contains("struct"));
    }

    #[test]
    fn test_code_action_clone() {
        let action = CodeAction::new("add_impl", "impl");
        let cloned = action.clone();
        assert_eq!(action, cloned);
        assert_eq!(cloned.name, "add_impl");
        assert_eq!(cloned.token, "impl");
    }

    #[test]
    fn test_generative_repair_default() {
        let repair: GenerativeRepair = Default::default();
        assert_eq!(repair.config().max_iterations, 100);
        assert!(!repair.config().use_discriminator);
    }

    #[test]
    fn test_generative_repair_config_debug() {
        let config = GenerativeRepairConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("GenerativeRepairConfig"));
        assert!(debug_str.contains("max_iterations"));
        assert!(debug_str.contains("exploration_constant"));
    }

    #[test]
    fn test_generative_repair_config_clone() {
        let config = GenerativeRepairConfig {
            max_iterations: 200,
            exploration_constant: 1.5,
            max_simulation_depth: 75,
            use_discriminator: true,
            seed: 123,
        };
        let cloned = config.clone();
        assert_eq!(cloned.max_iterations, 200);
        assert_eq!(cloned.exploration_constant, 1.5);
        assert_eq!(cloned.max_simulation_depth, 75);
        assert!(cloned.use_discriminator);
        assert_eq!(cloned.seed, 123);
    }

    #[test]
    fn test_synthesis_result_debug() {
        let result = SynthesisResult {
            success: false,
            code: None,
            iterations: 50,
            expected_reward: 0.25,
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("SynthesisResult"));
        assert!(debug_str.contains("success"));
        assert!(debug_str.contains("false"));
    }

    #[test]
    fn test_synthesis_result_clone() {
        let result = SynthesisResult {
            success: true,
            code: Some("pub fn foo() -> i32 { 42 }".to_string()),
            iterations: 150,
            expected_reward: 0.85,
        };
        let cloned = result.clone();
        assert!(cloned.success);
        assert_eq!(cloned.code, Some("pub fn foo() -> i32 { 42 }".to_string()));
        assert_eq!(cloned.iterations, 150);
        assert_eq!(cloned.expected_reward, 0.85);
    }

    #[test]
    fn test_synthesis_result_no_code() {
        let result = SynthesisResult {
            success: false,
            code: None,
            iterations: 0,
            expected_reward: 0.0,
        };
        assert!(!result.success);
        assert!(result.code.is_none());
        assert_eq!(result.iterations, 0);
        assert_eq!(result.expected_reward, 0.0);
    }

    #[test]
    fn test_code_state_complete_with_eof() {
        let state = CodeState::new(vec![
            "fn".to_string(),
            "main".to_string(),
            "(".to_string(),
            ")".to_string(),
            "{".to_string(),
            "}".to_string(),
            "EOF".to_string(),
        ]);
        assert!(state.is_complete);
        assert_eq!(state.tokens().len(), 7);
    }

    #[test]
    fn test_code_state_not_complete_without_eof() {
        let state = CodeState::new(vec![
            "fn".to_string(),
            "main".to_string(),
            "(".to_string(),
            ")".to_string(),
        ]);
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_action_with_special_characters() {
        let action1 = CodeAction::new("add_arrow", "->");
        assert_eq!(action1.token, "->");

        let action2 = CodeAction::new("add_ref", "&");
        assert_eq!(action2.token, "&");

        let action3 = CodeAction::new("add_semicolon", ";");
        assert_eq!(action3.token, ";");
    }

    #[test]
    fn test_generative_repair_config_exploration_constant() {
        let config = GenerativeRepairConfig::default();
        // SQRT_2 ≈ 1.414...
        assert!(config.exploration_constant > 1.4);
        assert!(config.exploration_constant < 1.5);
    }

    #[test]
    fn test_generative_repair_config_custom_seed() {
        let config = GenerativeRepairConfig {
            seed: 12345,
            ..Default::default()
        };
        let repair = GenerativeRepair::with_config(config);
        assert_eq!(repair.config().seed, 12345);
    }

    #[test]
    fn test_generative_repair_config_method() {
        let repair = GenerativeRepair::new();
        let config = repair.config();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.max_simulation_depth, 50);
    }

    #[test]
    fn test_code_state_debug() {
        let state = CodeState::new(vec!["let".to_string(), "mut".to_string()]);
        let debug_str = format!("{:?}", state);
        assert!(debug_str.contains("CodeState"));
        assert!(debug_str.contains("tokens"));
    }

    #[test]
    fn test_code_state_clone() {
        let state = CodeState::new(vec!["struct".to_string(), "Point".to_string()]);
        let cloned = state.clone();
        assert_eq!(state, cloned);
        assert_eq!(cloned.tokens().len(), 2);
    }

    // Tests that need the MCTS-backed types; compiled only with the
    // "generative" feature.
    #[cfg(feature = "generative")]
    mod generative_tests {
        use super::*;

        #[test]
        fn test_code_action_space_default() {
            let action_space = CodeActionSpace::new();
            let state = CodeState::initial();
            let actions = action_space.legal_actions(&state);

            // Should have available actions
            assert!(!actions.is_empty());

            // Should include common Rust tokens
            let action_names: Vec<_> = actions.iter().map(|a| a.name.as_str()).collect();
            assert!(action_names.contains(&"add_fn"));
            assert!(action_names.contains(&"add_let"));
            assert!(action_names.contains(&"complete"));
        }

        #[test]
        fn test_code_state_space_evaluate() {
            let state_space = CodeStateSpace::new(vec!["fn".to_string(), "test".to_string()]);

            // Empty state: no pattern matches
            let empty = CodeState::initial();
            let reward_empty = state_space.evaluate(&empty);
            assert_eq!(reward_empty, 0.0);

            // Partial match: one of two patterns present
            let partial = CodeState::new(vec!["fn".to_string()]);
            let reward_partial = state_space.evaluate(&partial);
            assert!(reward_partial > 0.0);
            assert!(reward_partial < 1.0);

            // Full match: both patterns present
            let full = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
            let reward_full = state_space.evaluate(&full);
            assert_eq!(reward_full, 1.0);
        }

        #[test]
        fn test_mcts_integration() {
            let config = GenerativeRepairConfig {
                max_iterations: 10,
                seed: 42,
                ..Default::default()
            };

            let repair = GenerativeRepair::with_config(config);
            let hir = create_empty_hir();

            let result = repair.synthesize(&hir);
            assert!(result.is_ok());
        }
    }
}