use anyhow::Result;
use rand::Rng;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use std::time::Instant;
use eenn::{SelfLearningConfig, SelfLearningLightningStrike};
use theory_core::{
BinaryOp, CognitiveConfig, ConstValue, Constraint, ConstraintIR, Domain, Expr, TheoryTag,
Variable, VariableMetadata,
};
/// Path passed to the engine config as `neural_save_path` (see `load_or_create_engine`).
const TRAINING_DATA_FILE: &str = "neural_training_state.json";
/// Plain-text file holding cumulative run statistics, one value per line
/// (written by `save_session_stats`, read back by `load_previous_stats`).
const STATS_FILE: &str = "learning_progress.txt";
/// Runs one persistent-learning session over a random batch of constraint problems.
///
/// Builds the self-learning engine and prints statistics persisted by earlier runs, if any.
/// It then generates and solves 8–12 random problems, reporting each result and the
/// engine's neural status. Finally it saves the neural training state and the updated
/// cumulative statistics so that the next run can continue from them.
///
/// # Errors
/// Returns an error if building the engine, solving a problem, or saving state fails.
fn main() -> Result<()> {
    println!("🧠📈 Persistent Neural Learning Demo");
    println!("===================================\n");
    println!("This demo generates RANDOM problems each run and saves neural training!");
    println!("Run it multiple times to see the neural network get smarter! 🚀\n");

    let mut self_learning_engine = load_or_create_engine()?;
    let session_start = Instant::now();
    let previous_stats = load_previous_stats();

    if previous_stats.total_runs > 0 {
        println!("📊 CONTINUING LEARNING from previous runs:");
        println!(" Previous runs: {}", previous_stats.total_runs);
        println!(
            " Neural predictions used: {} ({:.1}%)",
            previous_stats.neural_predictions,
            previous_stats.neural_success_rate * 100.0
        );
        println!(
            " Training examples collected: {}",
            previous_stats.training_examples
        );
        println!(
            " Average solving time: {:.2}ms\n",
            previous_stats.avg_solve_time
        );
    } else {
        println!("🆕 Starting fresh neural learning journey!\n");
    }

    // Between 8 and 12 problems (inclusive) per session.
    let num_problems = 8 + rand::rng().random_range(0..5);
    println!(
        "🎲 Generating {} random constraint problems...\n",
        num_problems
    );

    let mut session_stats = SessionStats::new();
    for problem_num in 1..=num_problems {
        let (problem_ir, description) = generate_random_constraint_problem(problem_num);
        println!("🏁 Problem {}: {}", problem_num, description);

        let start_time = Instant::now();
        let result = self_learning_engine.solve_and_learn(&problem_ir, false)?;
        let duration = start_time.elapsed();
        let solve_time_ms = duration.as_secs_f64() * 1000.0;

        // `record_solve` is the single place where solve time AND neural wins are tallied.
        session_stats.record_solve(solve_time_ms, &result);

        println!(
            " ✅ Solution: {} in {:.2}ms (confidence: {:.1}%)",
            if result.satisfiable {
                "SATISFIABLE"
            } else {
                "UNSATISFIABLE"
            },
            solve_time_ms,
            result.confidence * 100.0
        );

        if let (Some(branch_id), Some(strategy), Some(backend)) = (
            &result.winning_branch_id,
            &result.winning_strategy,
            &result.winning_backend,
        ) {
            println!(
                " 🏆 Winner: '{}' ({} strategy, {} backend)",
                branch_id, strategy, backend
            );
            if strategy == "Neural" {
                // Only report here: `record_solve` above already counted this win.
                // Incrementing `neural_wins` again would double-count it, and the
                // session win rate could exceed 100%.
                println!(" 🧠 NEURAL BRANCH WON! AI is getting smarter! 🎉");
            }
        }

        let engine_stats = self_learning_engine.get_performance_stats();
        if engine_stats.is_neural_trained {
            let neural_usage = if engine_stats.neural_predictions_used > 0 {
                "🧠 PREDICTING!"
            } else {
                "trained (low confidence)"
            };
            println!(
                " 📊 Neural status: {} ({} examples)",
                neural_usage, engine_stats.training_examples_collected
            );
        } else {
            // "3" mirrors `min_examples_for_training` set in `load_or_create_engine`.
            println!(
                " 📊 Neural status: collecting data ({}/3 examples)",
                engine_stats.training_examples_collected
            );
        }

        println!();
        // Short pause so the demo output is paced and readable.
        std::thread::sleep(std::time::Duration::from_millis(300));
    }

    let session_duration = session_start.elapsed();
    let final_engine_stats = self_learning_engine.get_performance_stats();

    println!("🎯 Session Complete!");
    println!("==================");
    println!(" Problems solved: {}", num_problems);
    println!(" Session time: {:.1}s", session_duration.as_secs_f64());
    println!(
        " Average solve time: {:.2}ms",
        session_stats.avg_solve_time()
    );
    println!(
        " Neural wins this session: {} ({:.1}%)",
        session_stats.neural_wins,
        (session_stats.neural_wins as f64 / num_problems as f64) * 100.0
    );
    println!("\n 🧠 Neural Network Progress:");
    println!(
        " - Training examples: {}",
        final_engine_stats.training_examples_collected
    );
    println!(
        " - Neural predictions: {} ({:.1}%)",
        final_engine_stats.neural_predictions_used,
        final_engine_stats.neural_success_rate * 100.0
    );
    println!(
        " - SMT fallbacks: {} ({:.1}%)",
        final_engine_stats.smt_fallbacks_used,
        (1.0 - final_engine_stats.neural_success_rate) * 100.0
    );

    save_training_state(&self_learning_engine)?;
    save_session_stats(&previous_stats, &session_stats, &final_engine_stats)?;

    if final_engine_stats.neural_predictions_used > 0 {
        println!("\n 🎉 SUCCESS: Neural network is actively predicting!");
        println!(" 🚀 Run again to see continued learning!");
    } else if final_engine_stats.is_neural_trained {
        println!("\n 📈 Neural network is trained but being cautious");
        println!(" 🔄 Run more times to build confidence!");
    } else {
        println!("\n 🌱 Neural network is still collecting training data");
        println!(" 📚 Run again to continue learning!");
    }
    println!("\n✅ Progress saved for next run! 💾");
    Ok(())
}
/// Accumulates solve timings and neural-branch wins for the current session only.
#[derive(Debug, Default)]
struct SessionStats {
    /// How many solves have been recorded via `record_solve`.
    total_problems: usize,
    /// Sum of all recorded solve times, in milliseconds.
    total_solve_time: f64,
    /// How many recorded solves had `"Neural"` as their winning strategy.
    neural_wins: usize,
}

impl SessionStats {
    /// Returns an accumulator with every counter at zero.
    fn new() -> Self {
        Self::default()
    }

    /// Records one solve: adds its duration and counts it as a neural win when the
    /// result's winning strategy is `"Neural"`.
    fn record_solve(&mut self, solve_time_ms: f64, result: &theory_core::SolutionResult) {
        self.total_solve_time += solve_time_ms;
        self.total_problems += 1;

        let neural_won = matches!(&result.winning_strategy, Some(s) if s == "Neural");
        if neural_won {
            self.neural_wins += 1;
        }
    }

    /// Mean solve time in milliseconds, or `0.0` when nothing has been recorded yet.
    fn avg_solve_time(&self) -> f64 {
        if self.total_problems == 0 {
            return 0.0;
        }
        self.total_solve_time / self.total_problems as f64
    }
}
/// Cumulative statistics carried between runs via the stats file.
///
/// `Default` (all zeros) represents "no previous runs".
#[derive(Default, Debug)]
struct PersistentStats {
    /// Number of completed runs so far.
    total_runs: usize,
    /// Cumulative count of neural predictions across runs.
    neural_predictions: usize,
    /// Neural success rate as reported by the engine at the end of the last run (fraction).
    neural_success_rate: f64,
    /// Training-example count reported by the engine at the end of the last run.
    training_examples: usize,
    /// Running mean of per-run average solve times, in milliseconds.
    avg_solve_time: f64,
}
/// Builds the self-learning engine for this session with demo-friendly settings.
///
/// Starting from `SelfLearningConfig::default()`, it sets:
/// - the fastest cognitive profile;
/// - training after 3 collected examples;
/// - `retrain_frequency = 2`;
/// - a low confidence threshold (0.3);
/// - `TRAINING_DATA_FILE` as the neural save path.
///
/// NOTE(review): whether previously saved weights are actually *loaded* from that path is
/// up to `SelfLearningLightningStrike::with_config` — confirm in `eenn`.
///
/// # Errors
/// Propagates any error from `SelfLearningLightningStrike::with_config`.
fn load_or_create_engine() -> Result<SelfLearningLightningStrike> {
    let mut cfg = SelfLearningConfig::default();
    cfg.cognitive_config = CognitiveConfig::fastest();
    cfg.min_examples_for_training = 3;
    cfg.retrain_frequency = 2;
    cfg.confidence_threshold = 0.3;
    cfg.neural_save_path = Some(TRAINING_DATA_FILE.into());

    let engine = SelfLearningLightningStrike::with_config(cfg)?;
    Ok(engine)
}
/// Builds a small random linear problem over two real variables.
///
/// Declares `x` and `y` with real domains `[0, 20]`. It picks one of five constraint shapes
/// uniformly at random, fills in bounds drawn from `rand::rng()` as integer values cast to
/// `f64`, and tags the IR with `TheoryTag::LRA`.
///
/// `seed` is only used to label the description (`"#<seed>: ..."`). It does not seed the
/// RNG, so the generated problem is not reproducible from it.
///
/// Returns the IR together with a human-readable description of the chosen constraint(s).
fn generate_random_constraint_problem(seed: usize) -> (ConstraintIR, String) {
    // Fixed set of shapes: a const array avoids the needless heap `Vec` (clippy `useless_vec`).
    const CONSTRAINT_TYPES: [&str; 5] = [
        "sum_bound",
        "individual_bounds",
        "equality",
        "difference",
        "product_bound",
    ];

    let mut rng = rand::rng();
    let mut ir = ConstraintIR::new();

    let x_var = Variable {
        name: "x".to_string(),
        domain: Domain::Real {
            min: Some(0.0),
            max: Some(20.0),
        },
        metadata: VariableMetadata::default(),
    };
    let y_var = Variable {
        name: "y".to_string(),
        domain: Domain::Real {
            min: Some(0.0),
            max: Some(20.0),
        },
        metadata: VariableMetadata::default(),
    };
    let x_id = ir.add_variable(x_var);
    let y_id = ir.add_variable(y_var);

    let constraint_type = CONSTRAINT_TYPES[rng.random_range(0..CONSTRAINT_TYPES.len())];
    let description = match constraint_type {
        "sum_bound" => {
            let bound = rng.random_range(5..25) as f64;
            let constraint = Constraint::LessEqual {
                lhs: Expr::Binary {
                    op: BinaryOp::Add,
                    lhs: Box::new(Expr::Var(x_id)),
                    rhs: Box::new(Expr::Var(y_id)),
                },
                rhs: Expr::Const(ConstValue::Real(bound)),
            };
            ir.add_constraint(constraint);
            format!("Sum constraint: x + y <= {}", bound)
        }
        "individual_bounds" => {
            let x_bound = rng.random_range(3..15) as f64;
            let y_bound = rng.random_range(3..15) as f64;
            let constraint1 = Constraint::LessEqual {
                lhs: Expr::Var(x_id),
                rhs: Expr::Const(ConstValue::Real(x_bound)),
            };
            let constraint2 = Constraint::LessEqual {
                lhs: Expr::Var(y_id),
                rhs: Expr::Const(ConstValue::Real(y_bound)),
            };
            ir.add_constraint(constraint1);
            ir.add_constraint(constraint2);
            format!("Individual bounds: x <= {}, y <= {}", x_bound, y_bound)
        }
        "equality" => {
            let target = rng.random_range(8..20) as f64;
            let constraint = Constraint::Equal {
                lhs: Expr::Binary {
                    op: BinaryOp::Add,
                    lhs: Box::new(Expr::Var(x_id)),
                    rhs: Box::new(Expr::Var(y_id)),
                },
                rhs: Expr::Const(ConstValue::Real(target)),
            };
            ir.add_constraint(constraint);
            format!("Equality: x + y = {}", target)
        }
        "difference" => {
            let max_diff = rng.random_range(2..8) as f64;
            let constraint = Constraint::LessEqual {
                lhs: Expr::Binary {
                    op: BinaryOp::Sub,
                    lhs: Box::new(Expr::Var(x_id)),
                    rhs: Box::new(Expr::Var(y_id)),
                },
                rhs: Expr::Const(ConstValue::Real(max_diff)),
            };
            ir.add_constraint(constraint);
            format!("Difference bound: x - y <= {}", max_diff)
        }
        // NOTE(review): despite its name, "product_bound" emits a (looser) linear sum bound.
        // This is presumably because x * y is not expressible in LRA; confirm the intent.
        "product_bound" => {
            let bound = rng.random_range(10..30) as f64;
            let constraint = Constraint::LessEqual {
                lhs: Expr::Binary {
                    op: BinaryOp::Add,
                    lhs: Box::new(Expr::Var(x_id)),
                    rhs: Box::new(Expr::Var(y_id)),
                },
                rhs: Expr::Const(ConstValue::Real(bound)),
            };
            ir.add_constraint(constraint);
            format!("Random sum bound: x + y <= {}", bound)
        }
        // Every entry of CONSTRAINT_TYPES is matched above.
        other => unreachable!("unknown constraint type: {}", other),
    };

    ir.add_theory_tag(TheoryTag::LRA);
    let final_description = format!("#{}: {}", seed, description);
    (ir, final_description)
}
/// Saves the engine's neural weights, then stamps `session_info.json` with the save time.
///
/// The weights are written by `SelfLearningLightningStrike::save_neural_weights`. The
/// destination is presumably the configured `neural_save_path`; confirm in `eenn`. After
/// that, `session_info.json` is (re)created containing a one-line JSON object with
/// `"saved": true` and the current Unix time in seconds, encoded as a string.
///
/// # Errors
/// Returns an error in any of these cases:
/// - saving the weights fails;
/// - the info file cannot be created or written;
/// - the system clock is earlier than the Unix epoch.
fn save_training_state(engine: &SelfLearningLightningStrike) -> Result<()> {
    const SESSION_INFO_FILE: &str = "session_info.json";

    engine.save_neural_weights()?;

    let mut info = File::create(SESSION_INFO_FILE)?;
    let unix_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)?
        .as_secs();
    writeln!(
        info,
        "{{\"saved\": true, \"timestamp\": \"{}\"}}",
        unix_secs
    )?;

    Ok(())
}
/// Loads cumulative statistics written by a previous run, best-effort.
///
/// Expects at least five lines in `STATS_FILE`, in this order:
/// total runs, neural predictions, neural success rate, training examples, and
/// average solve time.
///
/// A line that fails to parse falls back to zero. If the file is missing, unreadable,
/// or has fewer than five lines, `PersistentStats::default()` is returned, meaning a
/// fresh start. Errors are swallowed on purpose.
fn load_previous_stats() -> PersistentStats {
    File::open(STATS_FILE)
        .ok()
        .and_then(|mut file| {
            let mut text = String::new();
            file.read_to_string(&mut text).ok().map(|_| text)
        })
        .and_then(|text| {
            let fields: Vec<&str> = text.lines().collect();
            match fields.as_slice() {
                [runs, predictions, success_rate, examples, avg_time, ..] => {
                    Some(PersistentStats {
                        total_runs: runs.parse().unwrap_or(0),
                        neural_predictions: predictions.parse().unwrap_or(0),
                        neural_success_rate: success_rate.parse().unwrap_or(0.0),
                        training_examples: examples.parse().unwrap_or(0),
                        avg_solve_time: avg_time.parse().unwrap_or(0.0),
                    })
                }
                _ => None,
            }
        })
        .unwrap_or_default()
}
fn save_session_stats(
previous: &PersistentStats,
session: &SessionStats,
engine_stats: &eenn::PerformanceStats,
) -> Result<()> {
let mut file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(STATS_FILE)?;
let new_total_runs = previous.total_runs + 1;
let new_neural_predictions = previous.neural_predictions + engine_stats.neural_predictions_used;
let new_avg_solve_time = if new_total_runs > 1 {
(previous.avg_solve_time * (new_total_runs - 1) as f64 + session.avg_solve_time())
/ new_total_runs as f64
} else {
session.avg_solve_time()
};
writeln!(file, "{}", new_total_runs)?;
writeln!(file, "{}", new_neural_predictions)?;
writeln!(file, "{}", engine_stats.neural_success_rate)?;
writeln!(file, "{}", engine_stats.training_examples_collected)?;
writeln!(file, "{}", new_avg_solve_time)?;
Ok(())
}