use anyhow::Result;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::Instant;
use crate::{Adam, Dataset, MSELoss, TrainableNeuron, Trainer, TrainingConfig, TrainingExample};
use theory_core::{
CognitiveConfig, ConstraintIR, ConstraintLearner, Domain, LearningConfig,
LightningStrikeEngine, MetaReasoningConfig, MetaReasoningEngine, OptimizationConfig,
PerformanceOptimizer, SolutionResult, SolutionValue,
};
/// Configuration for the self-learning constraint solver.
#[derive(Debug, Clone)]
pub struct SelfLearningConfig {
    /// Settings forwarded to the underlying [`LightningStrikeEngine`].
    pub cognitive_config: CognitiveConfig,
    /// Sliding-window capacity for buffered training examples.
    pub max_training_examples: usize,
    /// Minimum buffered examples before the first retraining is attempted.
    pub min_examples_for_training: usize,
    /// Retrain after this many new SMT-solved examples.
    pub retrain_frequency: usize,
    /// Layer sizes of the neural surrogate network.
    pub network_layers: Vec<usize>,
    /// Hyperparameters used when (re)training the surrogate.
    pub training_config: TrainingConfig,
    /// Minimum confidence required to accept a neural prediction.
    pub confidence_threshold: f32,
    /// Where to persist/load neural weights; `None` disables persistence.
    pub neural_save_path: Option<std::path::PathBuf>,
    /// Emit progress messages to stdout.
    pub verbose: bool,
}
impl Default for SelfLearningConfig {
    /// Conservative defaults: a fast cognitive profile, a small surrogate
    /// network, and a low confidence gate so the surrogate is exercised early.
    fn default() -> Self {
        let training_config = TrainingConfig {
            epochs: 30,
            batch_size: 16,
            learning_rate: 0.001,
            validation_split: 0.2,
            early_stopping_patience: Some(8),
            verbose: false,
        };
        Self {
            cognitive_config: CognitiveConfig::fastest(),
            max_training_examples: 1000,
            min_examples_for_training: 25,
            retrain_frequency: 10,
            network_layers: vec![6, 12, 8, 2],
            training_config,
            confidence_threshold: 0.25,
            neural_save_path: Some(std::path::PathBuf::from("neural_weights.json")),
            verbose: false,
        }
    }
}
/// One SMT-solved problem captured as a surrogate training example.
#[derive(Debug, Clone)]
struct ConstraintTrainingData {
    /// Feature vector produced by `extract_constraint_features`.
    features: Vec<f32>,
    /// Ground-truth satisfiability reported by the SMT engine.
    satisfiable: bool,
    /// Wall-clock SMT solving time, in milliseconds.
    solving_time_ms: f32,
}
/// Hybrid solver that answers with a trained neural surrogate when confident
/// enough, and otherwise falls back to the SMT engine while learning from the
/// fallback's result.
pub struct SelfLearningLightningStrike {
    /// Authoritative SMT engine (fallback and source of training labels).
    lightning_strike: LightningStrikeEngine,
    /// Neural surrogate used for fast satisfiability predictions.
    neural_predictor: Arc<Mutex<TrainableNeuron>>,
    /// Sliding window of SMT-solved examples used for retraining.
    training_data: Arc<Mutex<VecDeque<ConstraintTrainingData>>>,
    config: SelfLearningConfig,
    /// New examples collected since the last retraining run.
    examples_since_training: usize,
    /// Whether the surrogate has completed at least one training run.
    is_neural_trained: bool,
    /// Count of answers served by the neural surrogate.
    neural_predictions_used: usize,
    /// Count of answers served by the SMT fallback.
    smt_fallbacks_used: usize,
}
impl SelfLearningLightningStrike {
pub fn new() -> Result<Self> {
Self::with_config(SelfLearningConfig::default())
}
pub fn with_config(config: SelfLearningConfig) -> Result<Self> {
let learner = ConstraintLearner::new(LearningConfig::default());
let meta_reasoner = MetaReasoningEngine::new(MetaReasoningConfig::default(), learner);
let optimizer = PerformanceOptimizer::new(OptimizationConfig::default());
let lightning_strike =
LightningStrikeEngine::new(config.cognitive_config.clone(), meta_reasoner, optimizer);
let neural_predictor = if let Some(save_path) = &config.neural_save_path {
TrainableNeuron::new_or_load(config.network_layers.clone(), save_path, config.verbose)
} else {
TrainableNeuron::new(config.network_layers.clone())
};
Ok(Self {
lightning_strike,
neural_predictor: Arc::new(Mutex::new(neural_predictor)),
training_data: Arc::new(Mutex::new(VecDeque::new())),
config,
examples_since_training: 0,
is_neural_trained: false,
neural_predictions_used: 0,
smt_fallbacks_used: 0,
})
}
pub fn solve_and_learn(
&mut self,
constraint_ir: &ConstraintIR,
verbose: bool,
) -> Result<SolutionResult> {
let start_time = Instant::now();
if self.is_neural_trained {
if let Ok(neural_result) = self.try_neural_prediction(constraint_ir) {
if neural_result.confidence >= self.config.confidence_threshold {
self.neural_predictions_used += 1;
return Ok(neural_result);
}
}
}
let smt_result = self.lightning_strike.solve(constraint_ir, verbose)?;
let solving_duration = start_time.elapsed();
self.smt_fallbacks_used += 1;
self.learn_from_smt_result(
constraint_ir,
&smt_result,
solving_duration.as_millis() as f32,
)?;
Ok(smt_result)
}
fn try_neural_prediction(&self, constraint_ir: &ConstraintIR) -> Result<SolutionResult> {
let features = self.extract_constraint_features(constraint_ir);
let primary_input = features[0];
let (satisfiable_pred, time_pred) = {
let mut network = self.neural_predictor.lock().unwrap();
let output = network.forward(primary_input);
let satisfiable = output > 0.5;
let estimated_time = if satisfiable {
(features[1] * features[2] * 20.0).max(5.0).min(500.0)
} else {
2.0 };
(satisfiable, estimated_time)
};
let mut assignment = HashMap::new();
if satisfiable_pred {
for (i, var) in constraint_ir.variables.values().enumerate() {
let var_key = format!("v{}", i);
let value = match &var.domain {
Domain::Real { min, max } => {
let val = match (min, max) {
(Some(min_val), Some(max_val)) => (min_val + max_val) / 2.0,
(Some(min_val), None) => min_val + 1.0,
(None, Some(max_val)) => max_val - 1.0,
(None, None) => 0.0,
};
SolutionValue::Real(val)
}
Domain::Integer { min, max } => {
let val = match (min, max) {
(Some(min_val), Some(max_val)) => (min_val + max_val) / 2,
(Some(min_val), None) => min_val + 1,
(None, Some(max_val)) => max_val - 1,
(None, None) => 0,
};
SolutionValue::Integer(val)
}
Domain::Boolean => SolutionValue::Bool(true),
Domain::BitVector { width } => SolutionValue::BitVector {
value: 0,
width: *width,
},
Domain::String { .. } => SolutionValue::String("neural".to_string()),
Domain::Array { .. } => SolutionValue::String("array".to_string()),
Domain::Enum { values } => {
SolutionValue::String(values.first().cloned().unwrap_or_default())
}
};
assignment.insert(var_key, value);
}
}
let confidence = if self.is_neural_trained {
let training_size = self.training_data.lock().unwrap().len();
let total_experience = self.neural_predictions_used + self.smt_fallbacks_used;
let experience_confidence = (total_experience as f32 / 50.0).min(0.95); let training_confidence = (training_size as f32 / 10.0).min(0.9);
experience_confidence.max(training_confidence)
} else {
0.1
};
Ok(SolutionResult {
satisfiable: satisfiable_pred,
assignment,
confidence,
estimated_time_ms: time_pred as f64,
method: "NeuralPredictor".to_string(),
winning_branch_id: Some("NeuralPredictor".to_string()),
winning_strategy: Some("Neural".to_string()),
winning_backend: Some("CPU".to_string()),
variable_ranges: None,
})
}
fn learn_from_smt_result(
&mut self,
constraint_ir: &ConstraintIR,
smt_result: &SolutionResult,
solving_time_ms: f32,
) -> Result<()> {
let features = self.extract_constraint_features(constraint_ir);
let training_example = ConstraintTrainingData {
features,
satisfiable: smt_result.satisfiable,
solving_time_ms,
};
{
let mut data = self.training_data.lock().unwrap();
data.push_back(training_example);
if data.len() > self.config.max_training_examples {
data.pop_front();
}
}
self.examples_since_training += 1;
if self.examples_since_training >= self.config.retrain_frequency {
let data_size = self.training_data.lock().unwrap().len();
if data_size >= self.config.min_examples_for_training {
self.retrain_neural_network()?;
self.examples_since_training = 0;
}
}
Ok(())
}
fn extract_constraint_features(&self, constraint_ir: &ConstraintIR) -> Vec<f32> {
let num_vars = constraint_ir.variables.len() as f32;
let num_constraints = constraint_ir.constraints.len() as f32;
let num_theories = constraint_ir.theory_tags.len() as f32;
let avg_domain_size: f32 = constraint_ir
.variables
.values()
.map(|var| match &var.domain {
Domain::Real { min, max } => match (min, max) {
(Some(min_val), Some(max_val)) => (max_val - min_val) as f32,
_ => 100.0,
},
Domain::Integer { min, max } => match (min, max) {
(Some(min_val), Some(max_val)) => (max_val - min_val) as f32,
_ => 1000.0,
},
Domain::Boolean => 2.0,
Domain::BitVector { width } => (*width as f32).exp2(),
Domain::String { max_length } => max_length.unwrap_or(255) as f32,
Domain::Array { .. } => 100.0,
Domain::Enum { values } => values.len() as f32,
})
.sum::<f32>()
/ num_vars.max(1.0);
vec![
num_vars.ln(), num_constraints.ln(), (num_constraints / num_vars.max(1.0)).ln(), avg_domain_size.ln(), num_theories, (num_vars * num_constraints).ln(), ]
}
fn retrain_neural_network(&mut self) -> Result<()> {
let training_data = {
let data = self.training_data.lock().unwrap();
data.clone().into_iter().collect::<Vec<_>>()
};
if training_data.is_empty() {
return Ok(());
}
let sat_examples: Vec<TrainingExample> = training_data
.iter()
.map(|example| TrainingExample {
input: example.features[0], target: if example.satisfiable { 1.0 } else { 0.0 },
})
.collect();
if sat_examples.is_empty() {
return Ok(());
}
{
let mut network = self.neural_predictor.lock().unwrap();
let dataset = Dataset::new(sat_examples);
let loss_fn = MSELoss;
let mut optimizer = Adam::new(self.config.training_config.learning_rate);
let trainer = Trainer::new(self.config.training_config.clone());
let _history = trainer.train(&mut *network, &dataset, &loss_fn, &mut optimizer);
self.is_neural_trained = true;
}
if self.config.verbose {
println!(
"🧠Neural surrogate retrained on {} examples",
training_data.len()
);
}
Ok(())
}
pub fn get_performance_stats(&self) -> PerformanceStats {
let training_examples = self.training_data.lock().unwrap().len();
let total_solves = self.neural_predictions_used + self.smt_fallbacks_used;
PerformanceStats {
total_problems_solved: total_solves,
neural_predictions_used: self.neural_predictions_used,
smt_fallbacks_used: self.smt_fallbacks_used,
neural_success_rate: if total_solves > 0 {
self.neural_predictions_used as f32 / total_solves as f32
} else {
0.0
},
training_examples_collected: training_examples,
is_neural_trained: self.is_neural_trained,
}
}
pub fn force_retrain(&mut self) -> Result<()> {
self.retrain_neural_network()
}
}
/// Aggregated runtime statistics for a self-learning solver instance.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Total problems answered (neural + SMT).
    pub total_problems_solved: usize,
    /// Answers served by the neural surrogate.
    pub neural_predictions_used: usize,
    /// Answers served by the SMT fallback.
    pub smt_fallbacks_used: usize,
    /// Fraction of answers served neurally (0.0 when nothing solved yet).
    pub neural_success_rate: f32,
    /// Examples currently buffered for training.
    pub training_examples_collected: usize,
    /// Whether the surrogate has been trained at least once.
    pub is_neural_trained: bool,
}
impl PerformanceStats {
    /// Renders a one-line, human-readable progress summary.
    pub fn summary(&self) -> String {
        if !self.is_neural_trained {
            // NOTE(review): the 25 mirrors SelfLearningConfig::default()'s
            // min_examples_for_training and will not track a custom config.
            return format!(
                "Collecting training data: {}/{} examples needed",
                self.training_examples_collected, 25
            );
        }
        let neural_pct = self.neural_success_rate * 100.0;
        let smt_pct = (1.0 - self.neural_success_rate) * 100.0;
        format!(
            "Solved {} problems ({:.1}% neural, {:.1}% SMT) - Trained on {} examples",
            self.total_problems_solved, neural_pct, smt_pct, self.training_examples_collected
        )
    }
}
impl SelfLearningLightningStrike {
    /// Persists the surrogate's weights to the configured save path.
    ///
    /// No-op when `neural_save_path` is `None`; underlying save errors are
    /// wrapped with context.
    pub fn save_neural_weights(&self) -> Result<()> {
        let save_path = match &self.config.neural_save_path {
            Some(path) => path,
            None => return Ok(()),
        };
        let network = self.neural_predictor.lock().unwrap();
        network
            .save_to_file(save_path)
            .map_err(|e| anyhow::anyhow!("Failed to save neural weights: {}", e))?;
        if self.config.verbose {
            println!("💾 Neural network weights saved to {:?}", save_path);
        }
        Ok(())
    }
}
impl Drop for SelfLearningLightningStrike {
    /// Best-effort auto-save on drop; failures are reported, never panicked.
    fn drop(&mut self) {
        match self.save_neural_weights() {
            Ok(()) => {}
            Err(e) => eprintln!("Warning: Failed to auto-save neural weights: {}", e),
        }
    }
}