use anyhow::{Context, Result, anyhow};
use clap::{Parser, Subcommand};
use console::style;
use dialoguer::{Input, theme::ColorfulTheme};
use directories::ProjectDirs;
use eenn::{SelfLearningConfig, SelfLearningLightningStrike, constraint_parser};
use hf_hub::api::sync::Api;
use lazy_static::lazy_static;
use llama_cpp_2::{
llama_backend::LlamaBackend,
model::{AddBos, LlamaModel},
};
use parking_lot::Mutex;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::Instant;
use theory_core::{ConstraintIR, VarId};
// HuggingFace repository and quantized GGUF file for Phi-3 Mini; downloaded
// by `init` and loaded by `extract_constraints_with_llm`.
const PHI3_REPO_ID: &str = "microsoft/Phi-3-mini-4k-instruct-gguf";
const PHI3_MINI_MODEL_FILENAME: &str = "Phi-3-mini-4k-instruct-q4.gguf";
lazy_static! {
    // Process-wide cache of the LLaMA backend plus loaded model so repeated
    // solves in one process skip the expensive model load (the backend must
    // outlive the model, hence they are cached together as a pair).
    static ref MODEL_CACHE: Mutex<Option<(LlamaBackend, LlamaModel)>> = Mutex::new(None);
}
// Top-level command-line interface for the solver binary.
// (Plain `//` comments on purpose: clap turns `///` doc comments into
// `--help` text, which would change the program's user-visible output.)
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
    // Selected subcommand; `main` falls back to an interactive Solve when omitted.
    #[command(subcommand)]
    command: Option<Commands>,
}
// Subcommands understood by the CLI.
// (Plain `//` comments on purpose: clap would harvest `///` into help text.)
#[derive(Subcommand)]
enum Commands {
    // Solve a math word problem.
    Solve {
        // Problem text; when absent the user is prompted interactively.
        #[arg(short, long)]
        problem: Option<String>,
        // Print intermediate progress and diagnostics.
        #[arg(short, long)]
        verbose: bool,
    },
    // Download the Phi-3 model into the data directory.
    Init {
        // Re-download even if the model file already exists.
        #[arg(short, long)]
        force: bool,
    },
    // Show accumulated neural-network learning statistics.
    Stats,
}
/// Entry point: parses CLI arguments and dispatches to the matching
/// subcommand. A missing subcommand behaves like `solve` with an
/// interactive prompt.
fn main() -> Result<()> {
    let cli = Cli::parse();
    // Default to interactive solving when no subcommand was supplied.
    let command = cli.command.unwrap_or(Commands::Solve {
        problem: None,
        verbose: false,
    });
    match command {
        Commands::Solve { problem, verbose } => {
            // Ask on stdin when the problem wasn't given on the command line.
            let problem_text = if let Some(text) = problem {
                text
            } else {
                Input::<String>::with_theme(&ColorfulTheme::default())
                    .with_prompt("Enter math problem to solve")
                    .interact_text()?
            };
            solve_problem(&problem_text, verbose)
        }
        Commands::Init { force } => initialize_solver(force),
        Commands::Stats => show_statistics(),
    }
}
/// Downloads the Phi-3 Mini GGUF model into the application data directory.
/// With `force`, the download happens even when a copy is already on disk.
///
/// # Errors
/// Fails when the data directory cannot be created, the HuggingFace API
/// cannot be reached, or the downloaded file cannot be copied into place.
fn initialize_solver(force: bool) -> Result<()> {
    println!(
        "{}",
        style("🧠 Initializing EENN Math Solver").cyan().bold()
    );
    let data_dir = get_data_directory()?;
    fs::create_dir_all(&data_dir).context("Failed to create data directory")?;
    let destination = data_dir.join(PHI3_MINI_MODEL_FILENAME);
    // Nothing to do when a copy already exists and no re-download was requested.
    if !force && destination.exists() {
        println!("✅ Model already downloaded at: {}", destination.display());
        return Ok(());
    }
    println!("📥 Downloading Phi-3 Mini model from HuggingFace...");
    println!(" Repository: {}", PHI3_REPO_ID);
    println!(" File: {}", PHI3_MINI_MODEL_FILENAME);
    let hub = Api::new().context("Failed to initialize HuggingFace API")?;
    let repo = hub.model(PHI3_REPO_ID.to_string());
    println!("⏳ Downloading... (this may take a while, ~2.4GB)");
    // hf-hub downloads into its own cache; copy the file to our data dir.
    let fetched = repo
        .get(PHI3_MINI_MODEL_FILENAME)
        .context("Failed to download model from HuggingFace")?;
    fs::copy(&fetched, &destination).context("Failed to copy model to data directory")?;
    println!(
        "✅ Model downloaded successfully to: {}",
        destination.display()
    );
    println!("🚀 EENN Math Solver is ready to use!");
    Ok(())
}
fn solve_problem(problem_text: &str, verbose: bool) -> Result<()> {
if verbose {
println!("{}", style("🧮 Math Problem:").cyan().bold());
println!("{}\n", problem_text);
println!("{}", style("🤔 Thinking...").yellow());
}
let start_time = Instant::now();
let model_path = get_data_directory()?.join(PHI3_MINI_MODEL_FILENAME);
if !model_path.exists() {
return Err(anyhow!(
"Model not found. Please run 'eenn-solve init' to download the Phi-3 model."
));
}
let (constraints, var_name_map) =
extract_constraints_with_llm(problem_text, &model_path, verbose)?;
if verbose {
println!(
"{}",
style("⚡ Applying Lightning Strike algorithm...").yellow()
);
}
let mut config = SelfLearningConfig::default();
config.neural_save_path = Some(get_data_directory()?.join("neural_weights.json"));
config.verbose = verbose; let mut solver = SelfLearningLightningStrike::with_config(config)?;
let solution_result = match solver.solve_and_learn(&constraints, verbose) {
Ok(result) => result,
Err(e) => {
let err_msg = e.to_string();
if err_msg.contains("Division by zero") {
eprintln!("{}", style("❌ Error: Division by zero detected").red().bold());
eprintln!("💡 Hint: Check your problem for division by zero (e.g., x/0)");
} else if err_msg.contains("Square root of negative") {
eprintln!("{}", style("❌ Error: Square root of negative number").red().bold());
eprintln!("💡 Hint: Square roots of negative numbers are not supported");
} else if err_msg.contains("Exponent too large") {
eprintln!("{}", style("❌ Error: Exponent too large").red().bold());
eprintln!("💡 Hint: Try using smaller exponents (max: 100)");
} else {
eprintln!("{}", style(format!("❌ Error: {}", err_msg)).red().bold());
}
return Err(e);
}
};
let solve_time = start_time.elapsed();
if solution_result.satisfiable {
if verbose {
println!("\n{}", style("🎯 Solution:").green().bold());
println!(
"{} (in {}ms)",
style("SATISFIABLE").green().bold(),
solve_time.as_millis()
);
}
if let Some(ranges) = &solution_result.variable_ranges {
for (var_name, range_str) in ranges {
let display_name = var_name_map
.iter()
.find(|(_, id_str)| id_str.as_str() == var_name)
.map(|(name, _)| name.as_str())
.unwrap_or(var_name);
let display_range = range_str.replace(var_name, display_name);
if verbose {
println!(" {}", style(&display_range).cyan());
} else {
println!("{}", display_range);
}
}
} else {
for (var_name, value) in &solution_result.assignment {
let display_name = var_name_map
.iter()
.find(|(_, id_str)| id_str.as_str() == var_name)
.map(|(name, _)| name.as_str())
.unwrap_or(var_name);
if verbose {
println!(
" {} = {}",
style(display_name).cyan(),
style(value).yellow()
);
} else {
println!("{} = {}", display_name, value);
}
}
}
if verbose {
if solution_result.winning_strategy.as_deref() == Some("neural") {
println!(
"\n{} (confidence: {:.1}%)",
style("✓ Used neural prediction").green(),
solution_result.confidence * 100.0
);
} else {
println!("\n{}", style("✓ Used SMT solving").yellow());
}
}
} else {
if verbose {
println!("\n{}", style("🎯 Solution:").green().bold());
println!(
"{} (in {}ms)",
style("UNSATISFIABLE").red().bold(),
solve_time.as_millis()
);
println!("\n💡 {}", style("Possible reasons:").yellow());
println!(" • The constraints may be contradictory (e.g., x = 5 and x = 10)");
println!(" • The problem may have no solution in the given domain");
println!(" • Try rephrasing or simplifying your constraints");
} else {
println!("No solution exists.");
println!("💡 Hint: Use -v flag for more details on why this is unsatisfiable.");
}
}
Ok(())
}
/// Runs the local Phi-3 model over `problem_text` and parses the model's
/// response into a `ConstraintIR` plus a map from the user's variable names
/// to internal solver ids.
///
/// The backend and model are loaded from `model_path` on first use and
/// cached in `MODEL_CACHE`; subsequent calls reuse the cached pair.
fn extract_constraints_with_llm(
    problem_text: &str,
    model_path: &Path,
    verbose: bool,
) -> Result<(ConstraintIR, HashMap<String, String>)> {
    use llama_cpp_2::{
        LogOptions, context::params::LlamaContextParams, llama_batch::LlamaBatch, model::Special,
        send_logs_to_tracing,
    };
    use std::num::NonZeroU32;
    // Route llama.cpp's internal logging through tracing only in verbose mode.
    send_logs_to_tracing(LogOptions::default().with_logs_enabled(verbose));
    if verbose {
        println!("🤖 Using Phi-3 Mini for natural language understanding...");
    }
    // Load (or reuse) the backend + model; loading is the expensive step.
    let mut cache = MODEL_CACHE.lock();
    if cache.is_none() {
        if verbose {
            println!(" Loading model (first time only)...");
        }
        let backend = LlamaBackend::init().context("Failed to initialize LLaMA backend")?;
        let model = LlamaModel::load_from_file(&backend, model_path, &Default::default())
            .context("Failed to load model")?;
        *cache = Some((backend, model));
    } else if verbose {
        println!(" ✓ Using cached model (instant load)");
    }
    let (backend, model) = cache.as_ref().unwrap();
    // Fresh inference context per call with a 2048-token window.
    let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));
    let mut ctx = model
        .new_context(&backend, ctx_params)
        .context("Failed to create context")?;
    // Phi-3 chat template; the system message pins the exact output format
    // the downstream constraint parser expects. The literal below is runtime
    // text — its lines are intentionally unindented.
    let prompt = format!(
        "<|system|>You are a mathematical problem analyzer. Extract variables and constraints from word problems.
Format your response EXACTLY as:
Variables:
- name: result, type: integer, domain: 0..1000
Constraints:
- result = <expression><|end|>
<|user|>{}?<|end|>
<|assistant|>",
        problem_text.trim_end_matches('?')
    );
    let tokens_list = model
        .str_to_token(&prompt, llama_cpp_2::model::AddBos::Always)
        .context("Failed to tokenize prompt")?;
    if verbose {
        println!(" Prompt: {} tokens", tokens_list.len());
    }
    // NOTE(review): batch capacity is 512 tokens; a prompt longer than that
    // would make `batch.add` fail. Also `tokens_list.len() - 1` underflows on
    // an empty tokenization — presumably impossible for a non-empty prompt,
    // but worth confirming upstream.
    let mut batch = LlamaBatch::new(512, 1);
    let last_index = (tokens_list.len() - 1) as i32;
    // Queue the whole prompt; only the final token requests logits.
    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
        let is_last = i == last_index;
        batch.add(token, i, &[0], is_last)?;
    }
    ctx.decode(&mut batch).context("Failed to decode prompt")?;
    let mut n_cur = batch.n_tokens();
    let max_tokens = 256;
    let mut response_parts = Vec::new();
    // Greedy decoding: pick the highest-logit candidate each step, stopping
    // at `max_tokens` or at an end-of-generation token.
    for _ in 0..max_tokens {
        let candidates: Vec<_> = ctx.candidates().collect();
        if candidates.is_empty() {
            break;
        }
        let mut best_token = candidates[0].id();
        let mut best_logit = candidates[0].logit();
        for candidate in &candidates {
            if candidate.logit() > best_logit {
                best_logit = candidate.logit();
                best_token = candidate.id();
            }
        }
        if model.is_eog_token(best_token) {
            break;
        }
        // Tokens that fail to stringify are silently skipped (best-effort).
        if let Ok(token_str) = model.token_to_str(best_token, Special::Tokenize) {
            response_parts.push(token_str);
        }
        // Feed the chosen token back in for the next decoding step.
        batch.clear();
        batch.add(best_token, n_cur, &[0], true)?;
        ctx.decode(&mut batch).context("Failed to decode token")?;
        n_cur += 1;
    }
    let response = response_parts.join("");
    if verbose {
        println!("✓ LLM response: {} tokens", response_parts.len());
        println!("{}", response.trim());
    }
    // The shared parser also returns the variable-name → solver-id mapping
    // used later for display.
    constraint_parser::parse_llm_response_with_mapping(&response, verbose)
}
/// Prints the solver's accumulated learning statistics, read from
/// `learning_progress.txt` in the application data directory.
///
/// # Errors
/// Fails when the data directory cannot be resolved or the statistics
/// file exists but cannot be read.
fn show_statistics() -> Result<()> {
    let progress_file = get_data_directory()?.join("learning_progress.txt");
    // No file yet means nothing has been solved — not an error.
    if !progress_file.exists() {
        println!("No statistics available yet. Solve some problems first!");
        return Ok(());
    }
    let contents = fs::read_to_string(&progress_file).context("Failed to read statistics file")?;
    println!(
        "{}",
        style("📊 Neural Network Training Statistics").cyan().bold()
    );
    println!("{}", contents);
    Ok(())
}
/// Resolves (and creates, if needed) the per-user application data directory
/// used for the model file, neural weights, and statistics.
///
/// # Errors
/// Fails when the platform data directory cannot be determined or created.
fn get_data_directory() -> Result<PathBuf> {
    let project = ProjectDirs::from("com", "eenn", "math-solver")
        .ok_or_else(|| anyhow!("Failed to determine project directories"))?;
    let dir: PathBuf = project.data_dir().into();
    fs::create_dir_all(&dir).context("Failed to create data directory")?;
    Ok(dir)
}