use crate::hir::HirModule;
use anyhow::Result;
use proc_macro2::TokenStream;
use std::hash::{Hash, Hasher};
#[cfg(feature = "generative")]
use entrenar::search::{Action, ActionSpace, MctsConfig, MctsSearch, Reward, State, StateSpace};
/// Tuning knobs for `GenerativeRepair`'s Monte-Carlo tree search.
#[derive(Debug, Clone)]
pub struct GenerativeRepairConfig {
    /// Search budget; copied into `MctsConfig::max_iterations` by `synthesize`.
    pub max_iterations: usize,
    /// Copied into `MctsConfig::exploration_constant` (presumably the UCT
    /// exploration weight — confirm against `entrenar::search`).
    pub exploration_constant: f64,
    /// Copied into `MctsConfig::max_simulation_depth`.
    pub max_simulation_depth: usize,
    /// Reserved flag: not read anywhere in this module at present.
    pub use_discriminator: bool,
    /// Search RNG seed. `synthesize` only calls `MctsSearch::with_seed` for a
    /// non-zero value; `0` selects the unseeded `MctsSearch::new` path.
    pub seed: u64,
}

impl Default for GenerativeRepairConfig {
    /// 100 iterations, exploration weight √2, depth 50, no discriminator, seed 0.
    fn default() -> Self {
        // √2 is the conventional UCB1 exploration weight.
        let exploration_constant = std::f64::consts::SQRT_2;
        let (max_iterations, max_simulation_depth) = (100, 50);
        GenerativeRepairConfig {
            max_iterations,
            exploration_constant,
            max_simulation_depth,
            use_discriminator: false,
            seed: 0,
        }
    }
}
/// A partial program under construction: the tokens emitted so far.
///
/// A state is complete (terminal) once its tokens contain the `"EOF"` sentinel,
/// which is the token appended by the `complete` action.
#[derive(Debug, Clone)]
// `is_complete` is read only by the `State` impl below, which exists only with the
// `generative` feature. Without that feature the field would trip `dead_code`
// (the manual `PartialEq`/`Hash` impls deliberately ignore it), so the allow is
// scoped to exactly that configuration.
#[cfg_attr(not(feature = "generative"), allow(dead_code))]
pub struct CodeState {
    /// Tokens emitted so far, in emission order.
    tokens: Vec<String>,
    /// Cached "tokens contain `EOF`" flag; always derived from `tokens`.
    is_complete: bool,
}

impl CodeState {
    /// Wraps `tokens`, marking the state complete if any token is `"EOF"`.
    pub fn new(tokens: Vec<String>) -> Self {
        let is_complete = tokens.iter().any(|t| t == "EOF");
        Self {
            tokens,
            is_complete,
        }
    }

    /// The empty starting state (no tokens, not complete).
    pub fn initial() -> Self {
        Self::new(Vec::new())
    }

    /// The tokens emitted so far.
    pub fn tokens(&self) -> &[String] {
        &self.tokens
    }
}

// Equality and hashing look only at `tokens`. `is_complete` is a pure function of
// `tokens` for every constructor above, so this agrees with a field-wise comparison.
impl PartialEq for CodeState {
    fn eq(&self, other: &Self) -> bool {
        self.tokens == other.tokens
    }
}

impl Eq for CodeState {}

impl Hash for CodeState {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.tokens.hash(state);
    }
}

#[cfg(feature = "generative")]
impl State for CodeState {
    /// Terminal once `EOF` has been emitted.
    fn is_terminal(&self) -> bool {
        self.is_complete
    }
}
/// One search step: append a single token to the program being built.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CodeAction {
    /// Label for the action (e.g. `"add_fn"`), reported through `Action::name`.
    name: String,
    /// Token pushed onto the state's token list when the action is applied.
    token: String,
}

impl CodeAction {
    /// Creates an action labelled `name` that emits `token`.
    pub fn new(name: impl Into<String>, token: impl Into<String>) -> Self {
        let (name, token) = (name.into(), token.into());
        CodeAction { name, token }
    }
}

#[cfg(feature = "generative")]
impl Action for CodeAction {
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// State-space model that scores partial programs against patterns extracted from the HIR.
#[cfg(feature = "generative")]
pub struct CodeStateSpace {
    /// Patterns the program should contain. Each is one or more whitespace-separated
    /// tokens (e.g. `"fn foo"`, `"->"`, or a bare parameter name).
    target_patterns: Vec<String>,
}

#[cfg(feature = "generative")]
impl CodeStateSpace {
    /// Creates a state space that rewards states containing `target_patterns`.
    pub fn new(target_patterns: Vec<String>) -> Self {
        Self { target_patterns }
    }

    /// Returns `true` if the whitespace-separated tokens of `pattern` appear as a
    /// contiguous run of whole tokens in `tokens`.
    ///
    /// Matching whole tokens (instead of substrings of the space-joined text) avoids
    /// spurious hits such as a parameter named `e` being "found" inside `else`.
    /// An empty or whitespace-only pattern trivially matches.
    fn contains_token_run(tokens: &[String], pattern: &str) -> bool {
        let needle: Vec<&str> = pattern.split_whitespace().collect();
        if needle.is_empty() {
            // `windows(0)` would panic; an empty pattern is vacuously present.
            return true;
        }
        tokens.windows(needle.len()).any(|window| {
            window
                .iter()
                .zip(&needle)
                .all(|(tok, want)| tok.as_str() == *want)
        })
    }
}

#[cfg(feature = "generative")]
impl StateSpace<CodeState, CodeAction> for CodeStateSpace {
    /// Returns a new state with the action's token appended to a copy of `state`'s tokens.
    fn apply(&self, state: &CodeState, action: &CodeAction) -> CodeState {
        let mut new_tokens = state.tokens.clone();
        new_tokens.push(action.token.clone());
        CodeState::new(new_tokens)
    }

    /// Scores `state` as the fraction of target patterns it contains (in `[0, 1]`).
    ///
    /// With no target patterns a constant `0.5` is returned.
    fn evaluate(&self, state: &CodeState) -> Reward {
        if self.target_patterns.is_empty() {
            return 0.5;
        }
        let matches = self
            .target_patterns
            .iter()
            .filter(|p| Self::contains_token_run(&state.tokens, p.as_str()))
            .count();
        matches as f64 / self.target_patterns.len() as f64
    }

    fn clone_space(&self) -> Box<dyn StateSpace<CodeState, CodeAction> + Send + Sync> {
        Box::new(Self {
            target_patterns: self.target_patterns.clone(),
        })
    }
}
/// The fixed token vocabulary available to the search, as `(action name, token)`
/// pairs in offer order. `"EOF"` (the `complete` action) ends a program.
#[cfg(feature = "generative")]
const CODE_ACTION_VOCABULARY: &[(&str, &str)] = &[
    ("add_fn", "fn"),
    ("add_let", "let"),
    ("add_return", "return"),
    ("add_if", "if"),
    ("add_else", "else"),
    ("add_for", "for"),
    ("add_while", "while"),
    ("add_match", "match"),
    ("add_struct", "struct"),
    ("add_impl", "impl"),
    ("add_pub", "pub"),
    ("add_mut", "mut"),
    ("add_ref", "&"),
    ("add_semicolon", ";"),
    ("add_brace_open", "{"),
    ("add_brace_close", "}"),
    ("add_paren_open", "("),
    ("add_paren_close", ")"),
    ("add_arrow", "->"),
    ("add_i32", "i32"),
    ("add_bool", "bool"),
    ("add_string", "String"),
    ("complete", "EOF"),
];

/// Action space offering the whole vocabulary in every non-terminal state.
#[cfg(feature = "generative")]
pub struct CodeActionSpace {
    available_actions: Vec<CodeAction>,
}

#[cfg(feature = "generative")]
impl CodeActionSpace {
    /// Builds one `CodeAction` per vocabulary entry, keeping the table's order.
    pub fn new() -> Self {
        let available_actions = CODE_ACTION_VOCABULARY
            .iter()
            .map(|&(label, tok)| CodeAction::new(label, tok))
            .collect();
        CodeActionSpace { available_actions }
    }
}

#[cfg(feature = "generative")]
impl Default for CodeActionSpace {
    fn default() -> Self {
        CodeActionSpace::new()
    }
}

#[cfg(feature = "generative")]
impl ActionSpace<CodeState, CodeAction> for CodeActionSpace {
    /// All vocabulary actions while the state is open; none once it is terminal.
    fn legal_actions(&self, state: &CodeState) -> Vec<CodeAction> {
        if !state.is_terminal() {
            return self.available_actions.clone();
        }
        Vec::new()
    }
}
/// Monte-Carlo-tree-search synthesizer that tries to emit Rust tokens matching a HIR module.
///
/// Real synthesis only exists with the `generative` feature; without it
/// `synthesize` always returns an empty `TokenStream`.
#[derive(Debug, Clone)]
pub struct GenerativeRepair {
    config: GenerativeRepairConfig,
}

impl GenerativeRepair {
    /// Creates a synthesizer with `GenerativeRepairConfig::default()`.
    pub fn new() -> Self {
        Self {
            config: GenerativeRepairConfig::default(),
        }
    }

    /// Creates a synthesizer with an explicit configuration.
    pub fn with_config(config: GenerativeRepairConfig) -> Self {
        Self { config }
    }

    /// Runs an MCTS search over token sequences and returns the best program found.
    ///
    /// The reward is the fraction of HIR-derived target patterns present in the
    /// emitted tokens (see `CodeStateSpace`). Returns an empty stream when the search
    /// yields no state or when the emitted tokens do not lex.
    ///
    /// # Errors
    /// Errors only if `tokens_to_stream` does; as written it always returns `Ok`.
    #[cfg(feature = "generative")]
    pub fn synthesize(&self, hir: &HirModule) -> Result<TokenStream> {
        let target_patterns = self.extract_target_patterns(hir);
        let mcts_config = MctsConfig {
            max_iterations: self.config.max_iterations,
            exploration_constant: self.config.exploration_constant,
            max_simulation_depth: self.config.max_simulation_depth,
            ..Default::default()
        };
        let initial_state = CodeState::initial();
        let action_space = CodeActionSpace::new();
        let state_space = CodeStateSpace::new(target_patterns);
        // A seed of 0 means "unseeded": only non-zero seeds use the reproducible path.
        let mut mcts = if self.config.seed > 0 {
            MctsSearch::with_seed(initial_state, &action_space, mcts_config, self.config.seed)
        } else {
            MctsSearch::new(initial_state, &action_space, mcts_config)
        };
        let result = mcts.search(&state_space, &action_space, None);
        match result.resulting_state {
            Some(state) => self.tokens_to_stream(&state),
            None => Ok(TokenStream::new()),
        }
    }

    /// Stub used when the `generative` feature is disabled: always an empty stream.
    #[cfg(not(feature = "generative"))]
    pub fn synthesize(&self, _hir: &HirModule) -> Result<TokenStream> {
        Ok(TokenStream::new())
    }

    /// Collects the textual patterns the search is rewarded for producing.
    ///
    /// For each function there are three kinds of pattern:
    /// - `"fn <name>"`;
    /// - each parameter name;
    /// - `"->"`, when the return type is neither `Unknown` nor `None`.
    ///
    /// Each class contributes `"struct <name>"`.
    ///
    /// NOTE(review): `CodeActionSpace` only emits keywords, punctuation and a few type
    /// names — never identifiers. Name-bearing patterns are therefore largely
    /// unreachable by the search as it stands.
    #[cfg(feature = "generative")]
    fn extract_target_patterns(&self, hir: &HirModule) -> Vec<String> {
        let mut patterns = Vec::new();
        for func in &hir.functions {
            patterns.push(format!("fn {}", func.name));
            for param in &func.params {
                patterns.push(param.name.clone());
            }
            if !matches!(
                func.ret_type,
                crate::hir::Type::Unknown | crate::hir::Type::None
            ) {
                patterns.push("->".to_string());
            }
        }
        for class in &hir.classes {
            patterns.push(format!("struct {}", class.name));
        }
        patterns
    }

    /// Renders a state's tokens, minus the `EOF` sentinel, as a `TokenStream`.
    ///
    /// Best-effort by design: a token sequence that fails to lex (e.g. unbalanced
    /// delimiters) becomes an empty stream instead of an error. A poor search result
    /// therefore never fails `synthesize`.
    #[cfg(feature = "generative")]
    fn tokens_to_stream(&self, state: &CodeState) -> Result<TokenStream> {
        // Borrow the tokens as `&str`; no need to clone each `String` just to join.
        let code = state
            .tokens()
            .iter()
            .map(String::as_str)
            .filter(|t| *t != "EOF")
            .collect::<Vec<_>>()
            .join(" ");
        Ok(code
            .parse::<TokenStream>()
            .unwrap_or_else(|_| TokenStream::new()))
    }

    /// The active configuration.
    pub fn config(&self) -> &GenerativeRepairConfig {
        &self.config
    }
}

impl Default for GenerativeRepair {
    fn default() -> Self {
        Self::new()
    }
}
/// Summary of a synthesis run.
///
/// Note: nothing in this module constructs this type at present.
#[derive(Debug, Clone)]
pub struct SynthesisResult {
    /// Whether synthesis produced an acceptable program.
    pub success: bool,
    /// The synthesized source text, if any.
    pub code: Option<String>,
    /// Number of search iterations performed.
    pub iterations: usize,
    /// Reward the search expected for the returned program.
    pub expected_reward: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HirModule` with every collection empty.
    fn create_empty_hir() -> HirModule {
        HirModule {
            functions: vec![],
            imports: vec![],
            type_aliases: vec![],
            protocols: vec![],
            classes: vec![],
            constants: vec![],
            top_level_stmts: vec![],
        }
    }

    #[test]
    fn test_generative_synthesis_stub() {
        let repair = GenerativeRepair::new();
        let hir = create_empty_hir();
        let result = repair.synthesize(&hir);
        assert!(result.is_ok(), "synthesize should return Ok for empty HIR");
    }

    #[test]
    fn test_generative_repair_config_default() {
        let config = GenerativeRepairConfig::default();
        assert_eq!(config.max_iterations, 100);
        assert!(config.exploration_constant > 0.0);
        assert_eq!(config.max_simulation_depth, 50);
        assert!(!config.use_discriminator);
        assert_eq!(config.seed, 0);
    }

    #[test]
    fn test_generative_repair_with_config() {
        let config = GenerativeRepairConfig {
            max_iterations: 500,
            exploration_constant: 2.0,
            max_simulation_depth: 100,
            use_discriminator: true,
            seed: 42,
        };
        let repair = GenerativeRepair::with_config(config);
        assert_eq!(repair.config().max_iterations, 500);
        assert!(repair.config().use_discriminator);
        assert_eq!(repair.config().seed, 42);
    }

    #[test]
    fn test_code_state_creation() {
        let state = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        assert_eq!(state.tokens().len(), 2);
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_state_terminal() {
        let state = CodeState::new(vec!["fn".to_string(), "EOF".to_string()]);
        assert!(state.is_complete);
    }

    #[test]
    fn test_code_action_creation() {
        let action = CodeAction::new("add_fn", "fn");
        assert_eq!(action.name, "add_fn");
        assert_eq!(action.token, "fn");
    }

    // There is no `Default` impl for `SynthesisResult`; this checks plain construction.
    #[test]
    fn test_synthesis_result_construction() {
        let result = SynthesisResult {
            success: true,
            code: Some("fn test() {}".to_string()),
            iterations: 100,
            expected_reward: 0.95,
        };
        assert!(result.success);
        assert!(result.code.is_some());
        assert_eq!(result.iterations, 100);
    }

    #[test]
    fn test_code_state_initial() {
        let state = CodeState::initial();
        assert!(state.tokens().is_empty());
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_state_tokens_accessor() {
        let state = CodeState::new(vec!["let".to_string(), "x".to_string(), "=".to_string()]);
        let tokens = state.tokens();
        assert_eq!(tokens.len(), 3);
        assert_eq!(tokens[0], "let");
        assert_eq!(tokens[1], "x");
        assert_eq!(tokens[2], "=");
    }

    #[test]
    fn test_code_state_partial_eq() {
        let state1 = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        let state2 = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
        let state3 = CodeState::new(vec!["fn".to_string(), "other".to_string()]);
        assert_eq!(state1, state2);
        assert_ne!(state1, state3);
    }

    #[test]
    fn test_code_state_hash() {
        use std::collections::HashSet;
        // Insert the states themselves (not their token vectors) so this actually
        // exercises `CodeState`'s `Hash`/`Eq` impls.
        let state1 = CodeState::new(vec!["fn".to_string()]);
        let state2 = CodeState::new(vec!["fn".to_string()]);
        let state3 = CodeState::new(vec!["let".to_string()]);
        let mut set = HashSet::new();
        set.insert(state1);
        set.insert(state2);
        set.insert(state3);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_code_action_partial_eq() {
        let action1 = CodeAction::new("add_fn", "fn");
        let action2 = CodeAction::new("add_fn", "fn");
        let action3 = CodeAction::new("add_let", "let");
        assert_eq!(action1, action2);
        assert_ne!(action1, action3);
    }

    #[test]
    fn test_code_action_hash() {
        use std::collections::HashSet;
        let action1 = CodeAction::new("add_fn", "fn");
        let action2 = CodeAction::new("add_fn", "fn");
        let action3 = CodeAction::new("add_let", "let");
        let mut set = HashSet::new();
        set.insert(action1);
        set.insert(action2);
        set.insert(action3);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_code_action_debug() {
        let action = CodeAction::new("add_struct", "struct");
        let debug_str = format!("{:?}", action);
        assert!(debug_str.contains("CodeAction"));
        assert!(debug_str.contains("add_struct"));
        assert!(debug_str.contains("struct"));
    }

    #[test]
    fn test_code_action_clone() {
        let action = CodeAction::new("add_impl", "impl");
        let cloned = action.clone();
        assert_eq!(action, cloned);
        assert_eq!(cloned.name, "add_impl");
        assert_eq!(cloned.token, "impl");
    }

    #[test]
    fn test_generative_repair_default() {
        let repair: GenerativeRepair = Default::default();
        assert_eq!(repair.config().max_iterations, 100);
        assert!(!repair.config().use_discriminator);
    }

    #[test]
    fn test_generative_repair_config_debug() {
        let config = GenerativeRepairConfig::default();
        let debug_str = format!("{:?}", config);
        assert!(debug_str.contains("GenerativeRepairConfig"));
        assert!(debug_str.contains("max_iterations"));
        assert!(debug_str.contains("exploration_constant"));
    }

    #[test]
    fn test_generative_repair_config_clone() {
        let config = GenerativeRepairConfig {
            max_iterations: 200,
            exploration_constant: 1.5,
            max_simulation_depth: 75,
            use_discriminator: true,
            seed: 123,
        };
        let cloned = config.clone();
        assert_eq!(cloned.max_iterations, 200);
        assert_eq!(cloned.exploration_constant, 1.5);
        assert_eq!(cloned.max_simulation_depth, 75);
        assert!(cloned.use_discriminator);
        assert_eq!(cloned.seed, 123);
    }

    #[test]
    fn test_synthesis_result_debug() {
        let result = SynthesisResult {
            success: false,
            code: None,
            iterations: 50,
            expected_reward: 0.25,
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("SynthesisResult"));
        assert!(debug_str.contains("success"));
        assert!(debug_str.contains("false"));
    }

    #[test]
    fn test_synthesis_result_clone() {
        let result = SynthesisResult {
            success: true,
            code: Some("pub fn foo() -> i32 { 42 }".to_string()),
            iterations: 150,
            expected_reward: 0.85,
        };
        let cloned = result.clone();
        assert!(cloned.success);
        assert_eq!(cloned.code, Some("pub fn foo() -> i32 { 42 }".to_string()));
        assert_eq!(cloned.iterations, 150);
        assert_eq!(cloned.expected_reward, 0.85);
    }

    #[test]
    fn test_synthesis_result_no_code() {
        let result = SynthesisResult {
            success: false,
            code: None,
            iterations: 0,
            expected_reward: 0.0,
        };
        assert!(!result.success);
        assert!(result.code.is_none());
        assert_eq!(result.iterations, 0);
        assert_eq!(result.expected_reward, 0.0);
    }

    #[test]
    fn test_code_state_complete_with_eof() {
        let state = CodeState::new(vec![
            "fn".to_string(),
            "main".to_string(),
            "(".to_string(),
            ")".to_string(),
            "{".to_string(),
            "}".to_string(),
            "EOF".to_string(),
        ]);
        assert!(state.is_complete);
        assert_eq!(state.tokens().len(), 7);
    }

    #[test]
    fn test_code_state_not_complete_without_eof() {
        let state = CodeState::new(vec![
            "fn".to_string(),
            "main".to_string(),
            "(".to_string(),
            ")".to_string(),
        ]);
        assert!(!state.is_complete);
    }

    #[test]
    fn test_code_action_with_special_characters() {
        let action1 = CodeAction::new("add_arrow", "->");
        assert_eq!(action1.token, "->");
        let action2 = CodeAction::new("add_ref", "&");
        assert_eq!(action2.token, "&");
        let action3 = CodeAction::new("add_semicolon", ";");
        assert_eq!(action3.token, ";");
    }

    #[test]
    fn test_generative_repair_config_exploration_constant() {
        let config = GenerativeRepairConfig::default();
        assert!(config.exploration_constant > 1.4);
        assert!(config.exploration_constant < 1.5);
    }

    #[test]
    fn test_generative_repair_config_custom_seed() {
        let config = GenerativeRepairConfig {
            seed: 12345,
            ..Default::default()
        };
        let repair = GenerativeRepair::with_config(config);
        assert_eq!(repair.config().seed, 12345);
    }

    #[test]
    fn test_generative_repair_config_method() {
        let repair = GenerativeRepair::new();
        let config = repair.config();
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.max_simulation_depth, 50);
    }

    #[test]
    fn test_code_state_debug() {
        let state = CodeState::new(vec!["let".to_string(), "mut".to_string()]);
        let debug_str = format!("{:?}", state);
        assert!(debug_str.contains("CodeState"));
        assert!(debug_str.contains("tokens"));
    }

    #[test]
    fn test_code_state_clone() {
        let state = CodeState::new(vec!["struct".to_string(), "Point".to_string()]);
        let cloned = state.clone();
        assert_eq!(state, cloned);
        assert_eq!(cloned.tokens().len(), 2);
    }

    #[cfg(feature = "generative")]
    mod generative_tests {
        use super::*;

        #[test]
        fn test_code_action_space_default() {
            let action_space = CodeActionSpace::new();
            let state = CodeState::initial();
            let actions = action_space.legal_actions(&state);
            assert!(!actions.is_empty());
            let action_names: Vec<_> = actions.iter().map(|a| a.name.as_str()).collect();
            assert!(action_names.contains(&"add_fn"));
            assert!(action_names.contains(&"add_let"));
            assert!(action_names.contains(&"complete"));
        }

        #[test]
        fn test_code_action_space_terminal_has_no_actions() {
            let action_space = CodeActionSpace::new();
            let done = CodeState::new(vec!["fn".to_string(), "EOF".to_string()]);
            assert!(action_space.legal_actions(&done).is_empty());
        }

        #[test]
        fn test_code_state_space_apply_appends_token() {
            let state_space = CodeStateSpace::new(vec![]);
            let next = state_space.apply(&CodeState::initial(), &CodeAction::new("add_fn", "fn"));
            assert_eq!(next.tokens().len(), 1);
            assert_eq!(next.tokens()[0], "fn");
            assert!(!next.is_complete);
            let done = state_space.apply(&next, &CodeAction::new("complete", "EOF"));
            assert!(done.is_complete);
        }

        #[test]
        fn test_code_state_space_evaluate() {
            let state_space = CodeStateSpace::new(vec!["fn".to_string(), "test".to_string()]);
            let empty = CodeState::initial();
            let reward_empty = state_space.evaluate(&empty);
            assert_eq!(reward_empty, 0.0);
            let partial = CodeState::new(vec!["fn".to_string()]);
            let reward_partial = state_space.evaluate(&partial);
            assert!(reward_partial > 0.0);
            assert!(reward_partial < 1.0);
            let full = CodeState::new(vec!["fn".to_string(), "test".to_string()]);
            let reward_full = state_space.evaluate(&full);
            assert_eq!(reward_full, 1.0);
        }

        #[test]
        fn test_mcts_integration() {
            let config = GenerativeRepairConfig {
                max_iterations: 10,
                seed: 42,
                ..Default::default()
            };
            let repair = GenerativeRepair::with_config(config);
            let hir = create_empty_hir();
            let result = repair.synthesize(&hir);
            assert!(result.is_ok());
        }
    }
}