use std::collections::{HashMap, HashSet};
use std::fmt;
use crate::ast::{BinaryOp, Equation, Expression, Variable};
use crate::numerical::{NumericalConfig, SmartNumericalSolver};
use crate::resolution_path::{Operation, ResolutionPath};
use crate::solver::{SmartSolver, Solution, Solver, SolverError};
use crate::integration::integrate;
use crate::ode::{solve_linear as solve_linear_ode, solve_separable, FirstOrderODE};
/// Failure modes reported while analysing or solving a system of equations.
///
/// The payload of each variant carries the data that the `Display`
/// implementation interpolates into its message.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SystemError {
    /// Variables that depend on one another in a cycle, in chain order.
    CircularDependency(Vec<String>),
    /// Fewer equations are available than are required.
    InsufficientEquations {
        /// Number of equations required.
        needed: usize,
        /// Number of equations actually available.
        have: usize,
    },
    /// More equations than unknowns.
    OverdeterminedSystem {
        /// Number of equations in the system.
        equations: usize,
        /// Number of unknown variables.
        unknowns: usize,
    },
    /// A specific equation could not be solved.
    UnsolvableEquation {
        /// Identifier of the equation (or of the solving stage) that failed.
        id: String,
        /// Why it could not be solved.
        reason: String,
    },
    /// The listed equations contradict one another.
    InconsistentSystem {
        /// Identifiers of the contradicting equations.
        equations: Vec<String>,
    },
    /// A numerical method failed for the given variable.
    NumericalFailure {
        /// Variable being solved for.
        variable: String,
        /// Description of the failure.
        reason: String,
    },
    /// The named variable does not occur in the system.
    VariableNotFound(String),
    /// No equation with the given identifier exists in the system.
    EquationNotFound(String),
    /// No solving order or method could be determined.
    NoStrategyFound(String),
    /// An error bubbled up from the single-equation solver.
    SolverError(String),
    /// An equation could not be parsed.
    ParseError(String),
}
/// Renders every error variant as a single human-readable sentence.
impl fmt::Display for SystemError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CircularDependency(chain) => {
                // The chain is shown as `a -> b -> c`.
                let rendered = chain.join(" -> ");
                write!(f, "Circular dependency detected: {}", rendered)
            }
            Self::InsufficientEquations { needed, have } => write!(
                f,
                "Insufficient equations: need {} but have {}",
                needed, have
            ),
            Self::OverdeterminedSystem { equations, unknowns } => write!(
                f,
                "Overdetermined system: {} equations for {} unknowns",
                equations, unknowns
            ),
            Self::UnsolvableEquation { id, reason } => {
                write!(f, "Cannot solve equation '{}': {}", id, reason)
            }
            Self::InconsistentSystem { equations } => {
                let listed = equations.join(", ");
                write!(f, "Inconsistent system: equations {} contradict", listed)
            }
            Self::NumericalFailure { variable, reason } => write!(
                f,
                "Numerical failure solving for '{}': {}",
                variable, reason
            ),
            Self::VariableNotFound(name) => {
                write!(f, "Variable '{}' not found in system", name)
            }
            Self::EquationNotFound(eq_id) => {
                write!(f, "Equation '{}' not found in system", eq_id)
            }
            Self::NoStrategyFound(why) => {
                write!(f, "No solving strategy found: {}", why)
            }
            Self::SolverError(message) => write!(f, "Solver error: {}", message),
            Self::ParseError(message) => write!(f, "Parse error: {}", message),
        }
    }
}
// Marker impl: lets `SystemError` be used wherever a `std::error::Error` is expected.
impl std::error::Error for SystemError {}
impl From<SolverError> for SystemError {
fn from(err: SolverError) -> Self {
Self::SolverError(format!("{:?}", err))
}
}
/// Metadata attached to an equation classified as an ODE.
#[derive(Debug, Clone)]
pub struct ODEInfo {
    /// Dependent variable; the only variable an ODE equation is treated as
    /// solvable for when the dependency graph is built.
    pub dependent_var: String,
    /// Independent variable of the ODE.
    pub independent_var: String,
    /// Order of the ODE (presumably the highest derivative order; it is not
    /// read anywhere in this file).
    pub order: usize,
}
/// Metadata attached to an equation classified as an integral equation.
#[derive(Debug, Clone)]
pub struct IntegralInfo {
    /// Variable of integration; excluded from the set of variables the
    /// equation can be solved for.
    pub integration_var: String,
    /// Lower bound, if any (`None` presumably means an indefinite integral — TODO confirm).
    pub lower_bound: Option<Expression>,
    /// Upper bound, if any.
    pub upper_bound: Option<Expression>,
}
/// Classification of an equation, used to decide which variables it can be
/// solved for and which solving method is planned for it.
#[derive(Debug, Clone)]
pub enum EquationType {
    /// Plain algebraic equation; solvable for any of its variables.
    Algebraic,
    /// Ordinary differential equation; solvable only for its dependent variable.
    ODE(ODEInfo),
    /// Integral equation; solvable for any variable except the integration variable.
    Integral(IntegralInfo),
    /// Contains a derivative but no ODE metadata could be extracted.
    Differential,
    /// Implicit equation; planned with the numerical method.
    Implicit,
    /// Not classified; handled like [`EquationType::Algebraic`].
    Unknown,
}
impl Default for EquationType {
fn default() -> Self {
Self::Unknown
}
}
/// An equation together with its identifier, classification and an optional
/// free-text description.
#[derive(Debug, Clone)]
pub struct NamedEquation {
    /// Identifier of the equation.
    pub id: String,
    /// The equation itself.
    pub equation: Equation,
    /// How the equation was classified (or the type supplied by the caller).
    pub equation_type: EquationType,
    /// Optional human-readable description.
    pub description: Option<String>,
}
impl NamedEquation {
    /// Wraps `equation` under `id`, inferring its [`EquationType`] from its structure.
    pub fn new(id: impl Into<String>, equation: Equation) -> Self {
        let equation_type = Self::classify(&equation);
        Self {
            id: id.into(),
            equation,
            equation_type,
            description: None,
        }
    }

    /// Wraps `equation` under `id` with a caller-supplied classification
    /// (no automatic classification is performed).
    pub fn with_type(
        id: impl Into<String>,
        equation: Equation,
        equation_type: EquationType,
    ) -> Self {
        Self {
            id: id.into(),
            equation,
            equation_type,
            description: None,
        }
    }

    /// Attaches a human-readable description (builder style).
    #[must_use]
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Returns the names of all variables appearing on either side of the equation.
    pub fn variables(&self) -> HashSet<String> {
        let mut vars = self.equation.left.variables();
        vars.extend(self.equation.right.variables());
        vars
    }

    /// Classifies an equation by scanning both sides for derivative and
    /// integral function calls.
    ///
    /// Derivatives take precedence: an equation containing one is `ODE` when
    /// ODE metadata can be extracted and `Differential` otherwise. An equation
    /// containing an integral is `Integral` only when integral metadata can be
    /// extracted; everything else is `Algebraic`.
    fn classify(equation: &Equation) -> EquationType {
        if Self::contains_derivative(&equation.left) || Self::contains_derivative(&equation.right) {
            return match Self::extract_ode_info(equation) {
                Some(info) => EquationType::ODE(info),
                None => EquationType::Differential,
            };
        }
        if Self::contains_integral(&equation.left) || Self::contains_integral(&equation.right) {
            if let Some(info) = Self::extract_integral_info(equation) {
                return EquationType::Integral(info);
            }
        }
        EquationType::Algebraic
    }

    /// Returns `true` if `expr` contains a call to a custom function whose
    /// name looks like a derivative (starts with `d` and contains `/d`,
    /// e.g. `dy/dx`).
    ///
    /// Function arguments are searched too, so a derivative nested inside
    /// another function call is detected as well.
    fn contains_derivative(expr: &Expression) -> bool {
        match expr {
            Expression::Function(func, args) => {
                let is_derivative = matches!(func, crate::ast::Function::Custom(name)
                    if name.starts_with('d') && name.contains("/d"));
                is_derivative || args.iter().any(|arg| Self::contains_derivative(arg))
            }
            Expression::Binary(_, left, right) => {
                Self::contains_derivative(left) || Self::contains_derivative(right)
            }
            Expression::Unary(_, inner) => Self::contains_derivative(inner),
            Expression::Power(base, exp) => {
                Self::contains_derivative(base) || Self::contains_derivative(exp)
            }
            _ => false,
        }
    }

    /// Returns `true` if `expr` contains a call to a custom function named
    /// `integral`, `int` or `integrate`.
    ///
    /// Function arguments are searched too, so an integral nested inside
    /// another function call is detected as well.
    fn contains_integral(expr: &Expression) -> bool {
        match expr {
            Expression::Function(func, args) => {
                let is_integral = matches!(func, crate::ast::Function::Custom(name)
                    if name == "integral" || name == "int" || name == "integrate");
                is_integral || args.iter().any(|arg| Self::contains_integral(arg))
            }
            Expression::Binary(_, left, right) => {
                Self::contains_integral(left) || Self::contains_integral(right)
            }
            Expression::Unary(_, inner) => Self::contains_integral(inner),
            Expression::Power(base, exp) => {
                Self::contains_integral(base) || Self::contains_integral(exp)
            }
            _ => false,
        }
    }

    /// Placeholder: ODE metadata extraction is not implemented and always
    /// yields `None`, so `classify` never produces `EquationType::ODE`.
    fn extract_ode_info(_equation: &Equation) -> Option<ODEInfo> {
        None
    }

    /// Placeholder: integral metadata extraction is not implemented and always
    /// yields `None`, so `classify` never produces `EquationType::Integral`.
    fn extract_integral_info(_equation: &Equation) -> Option<IntegralInfo> {
        None
    }
}
/// A collection of named equations keyed by identifier.
#[derive(Debug, Clone, Default)]
pub struct EquationSystem {
    /// Equations indexed by their identifier.
    equations: HashMap<String, NamedEquation>,
    /// Optional human-readable description of the system.
    pub description: Option<String>,
}
impl EquationSystem {
    /// Creates an empty system without a description.
    pub fn new() -> Self {
        Self {
            equations: HashMap::new(),
            description: None,
        }
    }

    /// Creates an empty system carrying the given description.
    pub fn with_description(description: impl Into<String>) -> Self {
        let mut system = Self::new();
        system.description = Some(description.into());
        system
    }

    /// Adds `equation` under `id`, classifying it automatically.
    ///
    /// An equation already stored under the same id is replaced.
    pub fn add_equation(&mut self, id: impl Into<String>, equation: Equation) -> &mut Self {
        let key = id.into();
        let named = NamedEquation::new(key.clone(), equation);
        self.equations.insert(key, named);
        self
    }

    /// Adds an already-built [`NamedEquation`], keyed by its own `id`.
    ///
    /// An equation already stored under the same id is replaced.
    pub fn add_named_equation(&mut self, named_eq: NamedEquation) -> &mut Self {
        let key = named_eq.id.clone();
        self.equations.insert(key, named_eq);
        self
    }

    /// By-value builder variant of [`EquationSystem::add_equation`].
    #[must_use]
    pub fn with_equation(mut self, id: impl Into<String>, equation: Equation) -> Self {
        self.add_equation(id, equation);
        self
    }

    /// Looks up an equation by identifier.
    pub fn get(&self, id: &str) -> Option<&NamedEquation> {
        self.equations.get(id)
    }

    /// Iterates over the identifiers of all equations (in the map's iteration order).
    pub fn equation_ids(&self) -> impl Iterator<Item = &String> {
        self.equations.keys()
    }

    /// Iterates over all equations (in the map's iteration order).
    pub fn equations(&self) -> impl Iterator<Item = &NamedEquation> {
        self.equations.values()
    }

    /// Number of equations in the system.
    pub fn len(&self) -> usize {
        self.equations.len()
    }

    /// Returns `true` when the system holds no equations.
    pub fn is_empty(&self) -> bool {
        self.equations.is_empty()
    }

    /// Returns the union of the variables of every equation.
    pub fn all_variables(&self) -> HashSet<String> {
        self.equations
            .values()
            .flat_map(|eq| eq.variables())
            .collect()
    }

    /// Removes and returns the equation stored under `id`, if any.
    pub fn remove(&mut self, id: &str) -> Option<NamedEquation> {
        self.equations.remove(id)
    }
}
/// A restriction on a variable's value.
///
/// The first `String` of each variant is presumably the variable name;
/// constraints are only stored in this file, never evaluated.
#[derive(Debug, Clone)]
pub enum Constraint {
    /// Variable must be greater than the bound.
    GreaterThan(String, f64),
    /// Variable must be less than the bound.
    LessThan(String, f64),
    /// Variable must lie between the two bounds (inclusivity not specified here).
    InRange(String, f64, f64),
    /// Variable must be positive.
    Positive(String),
    /// Variable must be non-negative.
    NonNegative(String),
    /// Variable must be an integer.
    Integer(String),
    /// Free-form constraint text.
    Custom(String),
}
/// Everything the caller supplies alongside the equations: known values,
/// the variables to solve for, constraints and verification settings.
#[derive(Debug, Clone)]
pub struct SystemContext {
    /// Variables with a known numeric value.
    pub known_values: HashMap<String, f64>,
    /// Variables known as a symbolic expression.
    pub known_expressions: HashMap<String, Expression>,
    /// Variables the caller wants solved, in the order they were added.
    pub target_variables: Vec<String>,
    /// Constraints on variables.
    pub constraints: Vec<Constraint>,
    /// Whether solutions should be verified (defaults to `true`).
    pub verify_solutions: bool,
    /// Numeric tolerance (defaults to `1e-10`).
    pub tolerance: f64,
}

/// `Default` is implemented by hand so that it agrees with
/// [`SystemContext::new`]; a derived impl would yield
/// `verify_solutions = false` and `tolerance = 0.0`.
impl Default for SystemContext {
    fn default() -> Self {
        Self::new()
    }
}

impl SystemContext {
    /// Creates an empty context with verification enabled and a tolerance of `1e-10`.
    pub fn new() -> Self {
        Self {
            known_values: HashMap::new(),
            known_expressions: HashMap::new(),
            target_variables: Vec::new(),
            constraints: Vec::new(),
            verify_solutions: true,
            tolerance: 1e-10,
        }
    }

    /// Records a known numeric value for `variable` (replacing any previous one).
    #[must_use]
    pub fn with_known_value(mut self, variable: impl Into<String>, value: f64) -> Self {
        self.known_values.insert(variable.into(), value);
        self
    }

    /// Records a known symbolic expression for `variable` (replacing any previous one).
    #[must_use]
    pub fn with_known_expression(mut self, variable: impl Into<String>, expr: Expression) -> Self {
        self.known_expressions.insert(variable.into(), expr);
        self
    }

    /// Appends one target variable.
    #[must_use]
    pub fn with_target(mut self, variable: impl Into<String>) -> Self {
        self.target_variables.push(variable.into());
        self
    }

    /// Appends several target variables, preserving their order.
    #[must_use]
    pub fn with_targets(mut self, variables: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.target_variables
            .extend(variables.into_iter().map(Into::into));
        self
    }

    /// Appends a constraint.
    #[must_use]
    pub fn with_constraint(mut self, constraint: Constraint) -> Self {
        self.constraints.push(constraint);
        self
    }

    /// Enables or disables solution verification.
    #[must_use]
    pub fn with_verification(mut self, verify: bool) -> Self {
        self.verify_solutions = verify;
        self
    }

    /// Sets the numeric tolerance.
    #[must_use]
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Names of every variable known either numerically or symbolically.
    pub fn known_variable_names(&self) -> HashSet<String> {
        self.known_values
            .keys()
            .chain(self.known_expressions.keys())
            .cloned()
            .collect()
    }

    /// Returns `true` if `var` has a known value or a known expression.
    pub fn is_known(&self, var: &str) -> bool {
        self.known_values.contains_key(var) || self.known_expressions.contains_key(var)
    }
}
/// Bipartite index between variables and the equations that mention them.
#[derive(Debug, Clone)]
pub struct DependencyGraph {
    /// For each variable, the ids of the equations it appears in.
    variable_to_equations: HashMap<String, HashSet<String>>,
    /// For each equation id, the variables appearing in that equation.
    equation_to_variables: HashMap<String, HashSet<String>>,
    /// For each equation id, the variables that equation can be solved for.
    equation_can_solve: HashMap<String, HashSet<String>>,
}
impl DependencyGraph {
    /// Builds the graph from every equation in `system`.
    ///
    /// The set of variables an equation can be solved for depends on its type:
    /// all of them for `Algebraic`/`Unknown`, only the dependent variable for
    /// `ODE`, all but the integration variable for `Integral`, and none for
    /// `Differential`/`Implicit`.
    pub fn build(system: &EquationSystem) -> Self {
        let mut variable_to_equations: HashMap<String, HashSet<String>> = HashMap::new();
        let mut equation_to_variables: HashMap<String, HashSet<String>> = HashMap::new();
        let mut equation_can_solve: HashMap<String, HashSet<String>> = HashMap::new();
        for eq in system.equations() {
            let vars = eq.variables();
            let solvable: HashSet<String> = match &eq.equation_type {
                EquationType::Algebraic | EquationType::Unknown => vars.clone(),
                EquationType::ODE(info) => {
                    std::iter::once(info.dependent_var.clone()).collect()
                }
                EquationType::Integral(info) => vars
                    .iter()
                    .filter(|v| **v != info.integration_var)
                    .cloned()
                    .collect(),
                EquationType::Differential | EquationType::Implicit => HashSet::new(),
            };
            for var in &vars {
                variable_to_equations
                    .entry(var.clone())
                    .or_default()
                    .insert(eq.id.clone());
            }
            equation_can_solve.insert(eq.id.clone(), solvable);
            equation_to_variables.insert(eq.id.clone(), vars);
        }
        Self {
            variable_to_equations,
            equation_to_variables,
            equation_can_solve,
        }
    }

    /// Returns `(equation id, variable)` pairs for every equation that has
    /// exactly one variable outside `known` and can be solved for it.
    ///
    /// The result is sorted so that callers see the same order on every run,
    /// independent of hash-map iteration order.
    pub fn find_solvable(&self, known: &HashSet<String>) -> Vec<(String, String)> {
        let mut solvable: Vec<(String, String)> = self
            .equation_to_variables
            .iter()
            .filter_map(|(eq_id, vars)| {
                let mut unknowns = vars.iter().filter(|v| !known.contains(*v));
                let only = unknowns.next()?;
                if unknowns.next().is_some() {
                    // More than one unknown: not directly solvable.
                    return None;
                }
                let can_solve = self.equation_can_solve.get(eq_id)?;
                if can_solve.contains(only) {
                    Some((eq_id.clone(), only.clone()))
                } else {
                    None
                }
            })
            .collect();
        solvable.sort();
        solvable
    }

    /// Ids of the equations that mention `var`, if it occurs in the graph.
    pub fn equations_with_variable(&self, var: &str) -> Option<&HashSet<String>> {
        self.variable_to_equations.get(var)
    }

    /// Variables mentioned by equation `eq_id`, if that equation is in the graph.
    pub fn variables_in_equation(&self, eq_id: &str) -> Option<&HashSet<String>> {
        self.equation_to_variables.get(eq_id)
    }

    /// Returns `true` if, starting from any unknown target, the unknown
    /// variables are coupled in a cycle through at least two different
    /// equations.
    ///
    /// Known variables are never traversed. A single equation that merely has
    /// several unknowns is not a cycle on its own.
    pub fn has_circular_dependency(&self, targets: &[String], known: &HashSet<String>) -> bool {
        let mut visited = HashSet::new();
        let mut rec_stack = HashSet::new();
        targets.iter().any(|target| {
            !known.contains(target)
                && self.has_cycle_dfs(target, None, known, &mut visited, &mut rec_stack)
        })
    }

    /// Depth-first search over unknown variables.
    ///
    /// `via_equation` is the equation through which `var` was reached; it is
    /// skipped when expanding `var`, because stepping straight back through
    /// the same equation would otherwise make every pair of unknowns sharing
    /// one equation look like a cycle.
    fn has_cycle_dfs(
        &self,
        var: &str,
        via_equation: Option<&str>,
        known: &HashSet<String>,
        visited: &mut HashSet<String>,
        rec_stack: &mut HashSet<String>,
    ) -> bool {
        if rec_stack.contains(var) {
            return true;
        }
        if visited.contains(var) || known.contains(var) {
            return false;
        }
        visited.insert(var.to_string());
        rec_stack.insert(var.to_string());
        if let Some(eqs) = self.variable_to_equations.get(var) {
            for eq_id in eqs {
                if via_equation == Some(eq_id.as_str()) {
                    continue;
                }
                if let Some(eq_vars) = self.equation_to_variables.get(eq_id) {
                    for dep_var in eq_vars {
                        if dep_var != var
                            && self.has_cycle_dfs(
                                dep_var,
                                Some(eq_id.as_str()),
                                known,
                                visited,
                                rec_stack,
                            )
                        {
                            return true;
                        }
                    }
                }
            }
        }
        rec_stack.remove(var);
        false
    }
}
/// The technique planned for a single solving step.
#[derive(Debug, Clone)]
pub enum SolveMethod {
    /// Symbolic solution via the algebraic solver.
    Algebraic,
    /// Direct read-off when the target is isolated on one side, falling back
    /// to the algebraic solver otherwise.
    Substitution,
    /// Solve as an ordinary differential equation.
    ODE {
        /// Name of the ODE method requested (the planner uses `"auto"`).
        method: String,
    },
    /// Solve by symbolic integration.
    Integration,
    /// Solve by differentiation (no executor is implemented in this file).
    Differentiation,
    /// Solve with the numerical solver.
    Numerical,
    /// Caller-defined method (no executor is implemented in this file).
    Custom(String),
}
/// Short lower-case label for each method, used in step descriptions.
impl fmt::Display for SolveMethod {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed labels are written directly; only the two variants carrying
        // data need formatting.
        let fixed_label = match self {
            Self::ODE { method } => return write!(f, "ODE ({})", method),
            Self::Custom(name) => return write!(f, "custom: {}", name),
            Self::Algebraic => "algebraic",
            Self::Substitution => "substitution",
            Self::Integration => "integration",
            Self::Differentiation => "differentiation",
            Self::Numerical => "numerical",
        };
        f.write_str(fixed_label)
    }
}
/// One planned step: solve a given equation for a given variable.
#[derive(Debug, Clone)]
pub struct SolveStep {
    /// Identifier of the equation to use.
    pub equation_id: String,
    /// Variable to solve for.
    pub solve_for: String,
    /// Technique to apply.
    pub method: SolveMethod,
    /// Variables that were already known when the step was planned.
    pub dependencies: Vec<String>,
}
/// An ordered plan of solving steps.
#[derive(Debug, Clone)]
pub struct SolutionStrategy {
    /// Steps in the order they should be executed.
    pub steps: Vec<SolveStep>,
    /// Groups of step indices (presumably steps that could run in parallel;
    /// nothing in this file populates or reads it).
    pub parallel_groups: Vec<Vec<usize>>,
}
impl SolutionStrategy {
    /// Creates a strategy with no steps and no parallel groups.
    pub fn new() -> Self {
        let steps = Vec::new();
        let parallel_groups = Vec::new();
        Self {
            steps,
            parallel_groups,
        }
    }

    /// Appends `step` to the end of the plan.
    pub fn add_step(&mut self, step: SolveStep) {
        self.steps.push(step);
    }

    /// Returns `true` when the plan contains no steps.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of steps in the plan.
    pub fn len(&self) -> usize {
        self.steps.len()
    }
}
// `Default` simply delegates to `SolutionStrategy::new`, i.e. an empty plan.
impl Default for SolutionStrategy {
fn default() -> Self {
Self::new()
}
}
/// The value found for a variable.
#[derive(Debug, Clone)]
pub enum SolutionValue {
    /// A concrete number.
    Numeric(f64),
    /// A symbolic expression.
    Symbolic(Expression),
    /// Several candidate solutions.
    Multiple(Vec<Expression>),
    /// A solution expressed in terms of a free parameter.
    Parametric {
        /// The parametric expression.
        expr: Expression,
        /// Name of the free parameter.
        parameter: String,
    },
}
impl SolutionValue {
    /// Returns the value as a number when it is `Numeric`, or `Symbolic`
    /// wrapping a literal float or integer; `None` otherwise.
    pub fn as_numeric(&self) -> Option<f64> {
        match self {
            Self::Numeric(number) => Some(*number),
            Self::Symbolic(Expression::Float(value)) => Some(*value),
            Self::Symbolic(Expression::Integer(value)) => Some(*value as f64),
            _ => None,
        }
    }

    /// Borrows the underlying expression when there is exactly one:
    /// a `Symbolic` value or a `Multiple` with a single entry.
    pub fn as_expression(&self) -> Option<&Expression> {
        match self {
            Self::Symbolic(expr) => Some(expr),
            Self::Multiple(candidates) => match candidates.as_slice() {
                [only] => Some(only),
                _ => None,
            },
            Self::Numeric(_) | Self::Parametric { .. } => None,
        }
    }

    /// Converts the value into an owned expression.
    ///
    /// `Numeric` becomes a float literal, `Multiple` yields its first entry,
    /// and an empty `Multiple` falls back to the integer literal `0`.
    pub fn to_expression(&self) -> Expression {
        match self {
            Self::Numeric(number) => Expression::Float(*number),
            Self::Symbolic(expr) | Self::Parametric { expr, .. } => expr.clone(),
            Self::Multiple(candidates) => candidates
                .first()
                .cloned()
                .unwrap_or(Expression::Integer(0)),
        }
    }
}
/// The kind of action recorded by one step of a system resolution path.
#[derive(Debug, Clone)]
pub enum SystemOperation {
    /// An equation was chosen.
    SelectEquation {
        /// Why it was chosen.
        reason: String,
    },
    /// A variable was solved for.
    SolveFor {
        /// The variable solved for.
        variable: String,
        /// The technique used.
        method: SolveMethod,
    },
    /// A solved variable was substituted into other equations.
    SubstituteResult {
        /// The variable substituted.
        variable: String,
        /// Ids of the equations it was substituted into.
        into_equations: Vec<String>,
    },
    /// A solution was verified.
    VerifySolution {
        /// The variable whose solution was checked.
        variable: String,
    },
    /// An operation from a single-equation resolution path.
    EquationOperation(Operation),
}
/// One-line description of the operation; wrapped single-equation operations
/// are shown using their `Debug` form.
impl fmt::Display for SystemOperation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SelectEquation { reason } => {
                write!(f, "Select equation: {}", reason)
            }
            Self::SolveFor { variable, method } => {
                write!(f, "Solve for {} using {}", variable, method)
            }
            Self::SubstituteResult {
                variable,
                into_equations,
            } => {
                let destinations = into_equations.join(", ");
                write!(
                    f,
                    "Substitute {} into equations: {}",
                    variable, destinations
                )
            }
            Self::VerifySolution { variable } => {
                write!(f, "Verify solution for {}", variable)
            }
            Self::EquationOperation(inner) => write!(f, "{:?}", inner),
        }
    }
}
/// Outcome attached to a recorded step.
#[derive(Debug, Clone)]
pub enum StepResult {
    /// The step produced an expression (the form recorded by the solver in this file).
    Expression(Expression),
    /// The step produced a plain number.
    Value(f64),
    /// Snapshot of the expressions known so far.
    Intermediate {
        /// Variable name to expression for everything known at that point.
        known_so_far: HashMap<String, Expression>,
    },
}
/// One recorded step of solving the system.
#[derive(Debug, Clone)]
pub struct SystemStep {
    /// Position of the step; the solver numbers steps starting at 1.
    pub step_number: usize,
    /// Identifier of the equation the step worked on.
    pub equation_id: String,
    /// What was done.
    pub operation: SystemOperation,
    /// Human-readable explanation of the step.
    pub explanation: String,
    /// What the step produced.
    pub result: StepResult,
}
/// Full trace of how a system was solved.
#[derive(Debug, Clone)]
pub struct SystemResolutionPath {
    /// The context the solve started from.
    pub initial_context: SystemContext,
    /// Per-equation resolution paths, keyed by equation id.
    pub equation_paths: HashMap<String, ResolutionPath>,
    /// System-level steps in execution order.
    pub steps: Vec<SystemStep>,
    /// Final expression recorded for each solved variable.
    pub final_solutions: HashMap<String, Expression>,
}
impl SystemResolutionPath {
    /// Creates an empty path that remembers the starting `context`.
    pub fn new(context: SystemContext) -> Self {
        Self {
            initial_context: context,
            equation_paths: HashMap::new(),
            steps: Vec::new(),
            final_solutions: HashMap::new(),
        }
    }

    /// Appends a system-level step.
    pub fn add_step(&mut self, step: SystemStep) {
        self.steps.push(step);
    }

    /// Stores the single-equation resolution path for `eq_id`
    /// (replacing any path stored earlier under the same id).
    pub fn add_equation_path(&mut self, eq_id: String, path: ResolutionPath) {
        self.equation_paths.insert(eq_id, path);
    }

    /// Records the final expression for `variable`
    /// (replacing any earlier record for the same variable).
    pub fn record_solution(&mut self, variable: String, value: Expression) {
        self.final_solutions.insert(variable, value);
    }

    /// Returns the entries of `map` ordered by key.
    ///
    /// Hash-map iteration order is arbitrary, so rendering straight from the
    /// map would produce differently ordered output from run to run.
    fn sorted_by_name<V>(map: &HashMap<String, V>) -> Vec<(&String, &V)> {
        let mut entries: Vec<(&String, &V)> = map.iter().collect();
        entries.sort_by(|a, b| a.0.cmp(b.0));
        entries
    }

    /// Renders the path as plain text: known values, target variables,
    /// numbered steps and final solutions.
    ///
    /// Known values and final solutions are listed in alphabetical order of
    /// the variable name, so the output is stable across runs.
    pub fn format_text(&self) -> String {
        let mut output = String::new();
        output.push_str("=== Multi-Equation System Solution ===\n\n");
        output.push_str("Known values:\n");
        for (var, val) in Self::sorted_by_name(&self.initial_context.known_values) {
            output.push_str(&format!(" {} = {}\n", var, val));
        }
        output.push_str("\nTarget variables: ");
        output.push_str(&self.initial_context.target_variables.join(", "));
        output.push_str("\n\n");
        output.push_str("Solution steps:\n");
        for step in &self.steps {
            output.push_str(&format!(
                "{}. [{}] {}\n {}\n",
                step.step_number, step.equation_id, step.operation, step.explanation
            ));
        }
        output.push_str("\nFinal solutions:\n");
        for (var, expr) in Self::sorted_by_name(&self.final_solutions) {
            output.push_str(&format!(" {} = {}\n", var, expr));
        }
        output
    }

    /// Renders the given values and the final solutions as two LaTeX
    /// `align*` blocks.
    ///
    /// Entries are listed in alphabetical order of the variable name, so the
    /// output is stable across runs.
    pub fn to_latex(&self) -> String {
        let mut output = String::new();
        output.push_str("\\section*{Multi-Equation System Solution}\n\n");
        output.push_str("\\subsection*{Given}\n");
        output.push_str("\\begin{align*}\n");
        for (var, val) in Self::sorted_by_name(&self.initial_context.known_values) {
            output.push_str(&format!("{} &= {} \\\\\n", var, val));
        }
        output.push_str("\\end{align*}\n\n");
        output.push_str("\\subsection*{Solutions}\n");
        output.push_str("\\begin{align*}\n");
        for (var, expr) in Self::sorted_by_name(&self.final_solutions) {
            output.push_str(&format!("{} &= {} \\\\\n", var, expr.to_latex()));
        }
        output.push_str("\\end{align*}\n");
        output
    }
}
/// Result of solving a system: the values found plus diagnostics.
#[derive(Debug, Clone)]
pub struct MultiEquationSolution {
    /// Value found for each solved variable.
    pub solutions: HashMap<String, SolutionValue>,
    /// Trace of how the solutions were obtained.
    pub resolution_path: SystemResolutionPath,
    /// Variables that could not be solved (no duplicates when filled through
    /// `mark_unsolved`).
    pub unsolved: Vec<String>,
    /// Non-fatal problems encountered while solving.
    pub warnings: Vec<String>,
}
impl MultiEquationSolution {
    /// Creates an empty solution whose resolution path starts from `context`.
    pub fn new(context: SystemContext) -> Self {
        let resolution_path = SystemResolutionPath::new(context);
        Self {
            solutions: HashMap::new(),
            resolution_path,
            unsolved: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Looks up the value found for `variable`.
    pub fn get(&self, variable: &str) -> Option<&SolutionValue> {
        self.solutions.get(variable)
    }

    /// Numeric value of `variable`, if it was solved and is numeric.
    pub fn get_numeric(&self, variable: &str) -> Option<f64> {
        self.get(variable)?.as_numeric()
    }

    /// Expression for `variable`, if it was solved and has a single expression.
    pub fn get_expression(&self, variable: &str) -> Option<&Expression> {
        self.get(variable)?.as_expression()
    }

    /// Returns `true` when no variable has been marked unsolved.
    pub fn is_complete(&self) -> bool {
        self.unsolved.is_empty()
    }

    /// Stores `value` for `variable` and records its expression form in the
    /// resolution path.
    pub fn add_solution(&mut self, variable: String, value: SolutionValue) {
        let as_expr = value.to_expression();
        self.resolution_path
            .record_solution(variable.clone(), as_expr);
        self.solutions.insert(variable, value);
    }

    /// Appends a warning message.
    pub fn add_warning(&mut self, warning: String) {
        self.warnings.push(warning);
    }

    /// Marks `variable` as unsolved; a variable is listed at most once.
    pub fn mark_unsolved(&mut self, variable: String) {
        let already_listed = self.unsolved.iter().any(|name| *name == variable);
        if !already_listed {
            self.unsolved.push(variable);
        }
    }
}
/// Tunable settings for [`MultiEquationSolver`].
#[derive(Debug, Clone)]
pub struct SolverConfig {
    /// Upper bound on the number of planning passes.
    pub max_iterations: usize,
    /// Whether to plan numerical steps when no equation is directly solvable.
    pub use_numerical_fallback: bool,
    /// Settings handed to the numerical solver.
    pub numerical_config: NumericalConfig,
    /// Whether solutions are checked against the equations after solving.
    pub verify_solutions: bool,
    /// Maximum absolute difference between both sides of an equation that
    /// still counts as verified.
    pub tolerance: f64,
}
impl Default for SolverConfig {
    /// Defaults: 100 planning passes, numerical fallback and verification
    /// enabled, tolerance `1e-10`, default numerical settings.
    fn default() -> Self {
        let numerical_config = NumericalConfig::default();
        Self {
            max_iterations: 100,
            use_numerical_fallback: true,
            numerical_config,
            verify_solutions: true,
            tolerance: 1e-10,
        }
    }
}
/// Solves systems of equations by planning an order of single-variable
/// solves and executing them one after another.
pub struct MultiEquationSolver {
    /// Solver used for algebraic (symbolic) steps.
    algebraic_solver: SmartSolver,
    /// Solver used for numerical steps.
    numerical_solver: SmartNumericalSolver,
    /// Settings controlling planning, fallback and verification.
    config: SolverConfig,
}
impl MultiEquationSolver {
/// Creates a solver with the default configuration and a numerical solver
/// built from its own default settings.
pub fn new() -> Self {
    let algebraic_solver = SmartSolver::new();
    let numerical_solver = SmartNumericalSolver::with_default_config();
    let config = SolverConfig::default();
    Self {
        algebraic_solver,
        numerical_solver,
        config,
    }
}
/// Creates a solver using `config`; the numerical solver is built from a
/// copy of `config.numerical_config`.
pub fn with_config(config: SolverConfig) -> Self {
    let algebraic_solver = SmartSolver::new();
    let numerical_solver = SmartNumericalSolver::new(config.numerical_config.clone());
    Self {
        algebraic_solver,
        numerical_solver,
        config,
    }
}
/// Solves `system` for the target variables listed in `context`.
///
/// The dependency graph is built, a solving order is planned, and the plan
/// is executed.
///
/// # Errors
///
/// Returns [`SystemError::NoStrategyFound`] when the system has no
/// equations or the context names no target variables, and propagates any
/// error from planning or execution.
pub fn solve(
    &self,
    system: &EquationSystem,
    context: &SystemContext,
) -> Result<MultiEquationSolution, SystemError> {
    if system.is_empty() {
        return Err(SystemError::NoStrategyFound(
            "No equations in system".to_string(),
        ));
    }
    if context.target_variables.is_empty() {
        return Err(SystemError::NoStrategyFound(
            "No target variables specified".to_string(),
        ));
    }
    let graph = self.analyze_dependencies(system);
    let strategy = self.plan_solution(&graph, system, context)?;
    self.execute_strategy(&strategy, system, context)
}
/// Builds the variable/equation dependency graph for `system`
/// (thin wrapper around `DependencyGraph::build`).
fn analyze_dependencies(&self, system: &EquationSystem) -> DependencyGraph {
DependencyGraph::build(system)
}
/// Plans the order in which equations are solved.
///
/// Each pass asks the graph for equations with exactly one unknown, plans
/// a step for each, and treats the solved variable as known for later
/// passes. Planning stops when every target is covered or after
/// `config.max_iterations` passes.
///
/// Targets that are already known in `context` need no step and are not
/// planned. When no equation is directly solvable and numerical fallback
/// is enabled, one numerical step is planned per remaining target, using
/// the lexicographically smallest equation id that mentions the target;
/// targets mentioned by no equation get no step.
///
/// # Errors
///
/// Returns [`SystemError::NoStrategyFound`] when planning stalls with
/// numerical fallback disabled, or when targets remain and no step at all
/// could be planned.
fn plan_solution(
    &self,
    graph: &DependencyGraph,
    system: &EquationSystem,
    context: &SystemContext,
) -> Result<SolutionStrategy, SystemError> {
    // Sorted snapshot, so recorded dependencies do not depend on hash order.
    fn snapshot(known: &HashSet<String>) -> Vec<String> {
        let mut names: Vec<String> = known.iter().cloned().collect();
        names.sort();
        names
    }

    let mut strategy = SolutionStrategy::new();
    let mut known = context.known_variable_names();
    // A target that is already known would never become "solvable" and
    // would otherwise end up as a doomed numerical fallback step.
    let mut remaining_targets: HashSet<String> = context
        .target_variables
        .iter()
        .filter(|target| !known.contains(*target))
        .cloned()
        .collect();
    let mut iterations = 0;
    while !remaining_targets.is_empty() && iterations < self.config.max_iterations {
        iterations += 1;
        let solvable = graph.find_solvable(&known);
        if solvable.is_empty() {
            if !self.config.use_numerical_fallback {
                return Err(SystemError::NoStrategyFound(format!(
                    "Cannot determine solving order for: {:?}",
                    remaining_targets
                )));
            }
            // Visit targets in a fixed order so the plan is reproducible.
            let mut stalled: Vec<&String> = remaining_targets.iter().collect();
            stalled.sort();
            for target in stalled {
                let chosen = graph
                    .equations_with_variable(target)
                    .and_then(|eqs| eqs.iter().min());
                if let Some(eq_id) = chosen {
                    strategy.add_step(SolveStep {
                        equation_id: eq_id.clone(),
                        solve_for: target.clone(),
                        method: SolveMethod::Numerical,
                        dependencies: snapshot(&known),
                    });
                }
            }
            break;
        }
        for (eq_id, var) in solvable {
            // `solvable` was computed before this pass started, so a
            // variable may already have been claimed by an earlier pair.
            if remaining_targets.contains(&var) || !known.contains(&var) {
                let method = match system.get(&eq_id).map(|eq| &eq.equation_type) {
                    None | Some(EquationType::Algebraic) | Some(EquationType::Unknown) => {
                        SolveMethod::Algebraic
                    }
                    Some(EquationType::ODE(_)) => SolveMethod::ODE {
                        method: "auto".to_string(),
                    },
                    Some(EquationType::Integral(_)) => SolveMethod::Integration,
                    Some(EquationType::Differential) => SolveMethod::Differentiation,
                    Some(EquationType::Implicit) => SolveMethod::Numerical,
                };
                let dependencies = snapshot(&known);
                known.insert(var.clone());
                remaining_targets.remove(&var);
                strategy.add_step(SolveStep {
                    equation_id: eq_id,
                    solve_for: var,
                    method,
                    dependencies,
                });
            }
        }
    }
    if remaining_targets.is_empty() || !strategy.is_empty() {
        Ok(strategy)
    } else {
        Err(SystemError::NoStrategyFound(format!(
            "Could not find strategy for: {:?}",
            remaining_targets
        )))
    }
}
/// Executes `strategy` step by step and collects the results.
///
/// Before each step the equation has all currently known expressions
/// substituted in. A successful step records the solution and makes it
/// available to later steps; a failed step adds a warning and marks the
/// variable unsolved, and execution continues with the next step.
///
/// After the last step, every target that was neither solved nor already
/// known in `context` is marked unsolved, so `is_complete()` cannot report
/// success while a requested variable is missing.
///
/// Verification runs only when both `config.verify_solutions` and
/// `context.verify_solutions` are set.
///
/// # Errors
///
/// Returns [`SystemError::EquationNotFound`] if a step refers to an
/// equation id that is not in `system`.
fn execute_strategy(
    &self,
    strategy: &SolutionStrategy,
    system: &EquationSystem,
    context: &SystemContext,
) -> Result<MultiEquationSolution, SystemError> {
    let mut solution = MultiEquationSolution::new(context.clone());
    let mut known_values: HashMap<String, f64> = context.known_values.clone();
    // Known expressions are inserted after the numeric values, so they win
    // when a variable is present in both maps.
    let mut known_exprs: HashMap<String, Expression> = context
        .known_values
        .iter()
        .map(|(var, val)| (var.clone(), Expression::Float(*val)))
        .collect();
    known_exprs.extend(
        context
            .known_expressions
            .iter()
            .map(|(var, expr)| (var.clone(), expr.clone())),
    );

    for (index, step) in strategy.steps.iter().enumerate() {
        let step_number = index + 1;
        let eq = system
            .get(&step.equation_id)
            .ok_or_else(|| SystemError::EquationNotFound(step.equation_id.clone()))?;
        let substituted_eq = self.substitute_known(&eq.equation, &known_exprs);
        let result = match &step.method {
            SolveMethod::Algebraic => {
                self.solve_algebraic(&substituted_eq, &step.solve_for, &known_values)
            }
            SolveMethod::Numerical => {
                self.solve_numerical(&substituted_eq, &step.solve_for, &known_values)
            }
            SolveMethod::ODE { method } => {
                self.solve_ode(&substituted_eq, &step.solve_for, method)
            }
            SolveMethod::Integration => {
                self.solve_integration(&substituted_eq, &step.solve_for)
            }
            SolveMethod::Substitution => {
                self.solve_by_substitution(&substituted_eq, &step.solve_for, &known_values)
            }
            // No executor exists for the remaining methods.
            _ => Err(SystemError::UnsolvableEquation {
                id: step.equation_id.clone(),
                reason: format!("Method {:?} not implemented", step.method),
            }),
        };
        match result {
            Ok((value, eq_path)) => {
                let expr = value.to_expression();
                solution.resolution_path.add_step(SystemStep {
                    step_number,
                    equation_id: step.equation_id.clone(),
                    operation: SystemOperation::SolveFor {
                        variable: step.solve_for.clone(),
                        method: step.method.clone(),
                    },
                    explanation: format!(
                        "From equation '{}', solved {} = {}",
                        step.equation_id, step.solve_for, expr
                    ),
                    result: StepResult::Expression(expr.clone()),
                });
                known_exprs.insert(step.solve_for.clone(), expr);
                if let Some(num) = value.as_numeric() {
                    known_values.insert(step.solve_for.clone(), num);
                }
                if let Some(path) = eq_path {
                    solution
                        .resolution_path
                        .add_equation_path(step.equation_id.clone(), path);
                }
                solution.add_solution(step.solve_for.clone(), value);
            }
            Err(e) => {
                solution.add_warning(format!(
                    "Failed to solve for {} in {}: {}",
                    step.solve_for, step.equation_id, e
                ));
                solution.mark_unsolved(step.solve_for.clone());
            }
        }
    }

    // Targets for which no step was ever planned must not vanish silently.
    for target in &context.target_variables {
        if solution.solutions.contains_key(target) || context.is_known(target) {
            continue;
        }
        if !solution.unsolved.contains(target) {
            solution.add_warning(format!(
                "No solving step was planned for target '{}'",
                target
            ));
        }
        solution.mark_unsolved(target.clone());
    }

    if self.config.verify_solutions && context.verify_solutions {
        self.verify_solutions(&solution, system, context);
    }
    Ok(solution)
}
/// Returns a copy of `equation` (same id) in which every variable found in
/// `known` is replaced by its expression on both sides.
fn substitute_known(
    &self,
    equation: &Equation,
    known: &HashMap<String, Expression>,
) -> Equation {
    let left = self.substitute_expr(&equation.left, known);
    let right = self.substitute_expr(&equation.right, known);
    Equation {
        id: equation.id.clone(),
        left,
        right,
    }
}
/// Recursively rebuilds `expr`, replacing each variable that has an entry
/// in `known` with a clone of that entry.
///
/// Binary, unary, power and function nodes are rebuilt from their
/// substituted children; every other node is cloned unchanged. Substituted
/// expressions are inserted as-is and not substituted again.
fn substitute_expr(
    &self,
    expr: &Expression,
    known: &HashMap<String, Expression>,
) -> Expression {
    match expr {
        Expression::Variable(var) => known
            .get(&var.name)
            .cloned()
            .unwrap_or_else(|| expr.clone()),
        Expression::Binary(op, left, right) => {
            let new_left = self.substitute_expr(left, known);
            let new_right = self.substitute_expr(right, known);
            Expression::Binary(*op, Box::new(new_left), Box::new(new_right))
        }
        Expression::Unary(op, inner) => {
            let new_inner = self.substitute_expr(inner, known);
            Expression::Unary(*op, Box::new(new_inner))
        }
        Expression::Power(base, exp) => {
            let new_base = self.substitute_expr(base, known);
            let new_exp = self.substitute_expr(exp, known);
            Expression::Power(Box::new(new_base), Box::new(new_exp))
        }
        Expression::Function(func, args) => {
            let new_args = args
                .iter()
                .map(|arg| self.substitute_expr(arg, known))
                .collect();
            Expression::Function(func.clone(), new_args)
        }
        _ => expr.clone(),
    }
}
/// Solves `equation` for `variable` with the algebraic solver.
///
/// A unique solution that evaluates to a number under `known_values`
/// becomes `Numeric`; otherwise it stays `Symbolic`. Multiple solutions
/// are passed through, and a parametric solution is reported with the
/// fixed parameter name `"t"`.
///
/// # Errors
///
/// Returns [`SystemError::UnsolvableEquation`] (with id `"algebraic"`)
/// when the solver reports no solution or infinitely many, and
/// [`SystemError::SolverError`] when the solver itself fails.
fn solve_algebraic(
    &self,
    equation: &Equation,
    variable: &str,
    known_values: &HashMap<String, f64>,
) -> Result<(SolutionValue, Option<ResolutionPath>), SystemError> {
    let var = Variable::new(variable);
    let (sol, path) = self
        .algebraic_solver
        .solve(equation, &var)
        .map_err(|e| SystemError::SolverError(format!("{:?}", e)))?;

    // Shared shape of the two "not a usable solution" errors.
    let unsolvable = |reason: &str| SystemError::UnsolvableEquation {
        id: "algebraic".to_string(),
        reason: reason.to_string(),
    };

    let value = match sol {
        Solution::Unique(expr) => match expr.evaluate(known_values) {
            Some(num) => SolutionValue::Numeric(num),
            None => SolutionValue::Symbolic(expr),
        },
        Solution::Multiple(exprs) => SolutionValue::Multiple(exprs),
        Solution::None => return Err(unsolvable("No solution exists")),
        Solution::Infinite => return Err(unsolvable("Infinite solutions")),
        Solution::Parametric { expression, .. } => SolutionValue::Parametric {
            expr: expression,
            parameter: "t".to_string(),
        },
    };
    Ok((value, Some(path)))
}
/// Solves `equation` for `variable` with the numerical solver and returns
/// the resulting number together with the solver's resolution path.
///
/// `_known_values` is not used; known values are expected to have been
/// substituted into `equation` already.
///
/// # Errors
///
/// Returns [`SystemError::NumericalFailure`] carrying the `Debug`
/// rendering of the solver's error.
fn solve_numerical(
    &self,
    equation: &Equation,
    variable: &str,
    _known_values: &HashMap<String, f64>,
) -> Result<(SolutionValue, Option<ResolutionPath>), SystemError> {
    let var = Variable::new(variable);
    self.numerical_solver
        .solve(equation, &var)
        .map(|(sol, path)| (SolutionValue::Numeric(sol.value), Some(path)))
        .map_err(|e| SystemError::NumericalFailure {
            variable: variable.to_string(),
            reason: format!("{:?}", e),
        })
}
/// Solves a first-order ODE whose right-hand side is `equation.right`,
/// with `variable` as the dependent variable.
///
/// The separable solver is tried first, then the linear solver. The
/// general solution is returned symbolically, with no resolution path.
///
/// NOTE(review): the independent variable is fixed to `"x"` and `_method`
/// is ignored; confirm whether that is intended.
///
/// # Errors
///
/// Returns [`SystemError::UnsolvableEquation`] (with id `"ode"`) when both
/// solvers fail; the reported reason is the linear solver's error.
fn solve_ode(
    &self,
    equation: &Equation,
    variable: &str,
    _method: &str,
) -> Result<(SolutionValue, Option<ResolutionPath>), SystemError> {
    let ode = FirstOrderODE {
        dependent: variable.to_string(),
        independent: "x".to_string(),
        rhs: equation.right.clone(),
    };
    solve_separable(&ode)
        .or_else(|_| solve_linear_ode(&ode))
        .map(|sol| (SolutionValue::Symbolic(sol.general_solution), None))
        .map_err(|e| SystemError::UnsolvableEquation {
            id: "ode".to_string(),
            reason: format!("ODE solver failed: {e:?}"),
        })
}
/// Integrates the left-hand side of `equation` with respect to `variable`
/// and returns the result symbolically, with no resolution path.
///
/// NOTE(review): the right-hand side is not used, and the integration
/// variable is the variable being solved for; confirm this is intended.
///
/// # Errors
///
/// Returns [`SystemError::UnsolvableEquation`] (with id `"integration"`)
/// when integration fails.
fn solve_integration(
    &self,
    equation: &Equation,
    variable: &str,
) -> Result<(SolutionValue, Option<ResolutionPath>), SystemError> {
    integrate(&equation.left, variable)
        .map(|antiderivative| (SolutionValue::Symbolic(antiderivative), None))
        .map_err(|e| SystemError::UnsolvableEquation {
            id: "integration".to_string(),
            reason: format!("Integration failed: {e:?}"),
        })
}
fn solve_by_substitution(
&self,
equation: &Equation,
variable: &str,
known_values: &HashMap<String, f64>,
) -> Result<(SolutionValue, Option<ResolutionPath>), SystemError> {
if let Expression::Variable(var) = &equation.left {
if var.name == variable {
match equation.right.evaluate(known_values) {
Some(val) => return Ok((SolutionValue::Numeric(val), None)),
None => return Ok((SolutionValue::Symbolic(equation.right.clone()), None)),
}
}
}
if let Expression::Variable(var) = &equation.right {
if var.name == variable {
match equation.left.evaluate(known_values) {
Some(val) => return Ok((SolutionValue::Numeric(val), None)),
None => return Ok((SolutionValue::Symbolic(equation.left.clone()), None)),
}
}
}
self.solve_algebraic(equation, variable, known_values)
}
/// Checks every equation against the numeric values available and prints a
/// warning to stderr for each equation whose two sides differ by more than
/// `config.tolerance`.
///
/// The values used are the context's known numeric values plus every
/// solved variable that has a numeric value (a solved value wins if a
/// variable appears in both). Without the known values, any equation
/// mentioning a given variable could not be evaluated and would never be
/// checked. Equations for which either side cannot be evaluated are
/// skipped.
fn verify_solutions(
    &self,
    solution: &MultiEquationSolution,
    system: &EquationSystem,
    context: &SystemContext,
) {
    let mut all_values: HashMap<String, f64> = context.known_values.clone();
    all_values.extend(
        solution
            .solutions
            .iter()
            .filter_map(|(var, val)| val.as_numeric().map(|num| (var.clone(), num))),
    );
    for eq in system.equations() {
        let lhs_val = eq.equation.left.evaluate(&all_values);
        let rhs_val = eq.equation.right.evaluate(&all_values);
        if let (Some(l), Some(r)) = (lhs_val, rhs_val) {
            let diff = (l - r).abs();
            if diff > self.config.tolerance {
                eprintln!(
                    "Warning: Equation '{}' verification failed: {} != {} (diff: {})",
                    eq.id, l, r, diff
                );
            }
        }
    }
}
}
// `Default` simply delegates to `MultiEquationSolver::new`.
impl Default for MultiEquationSolver {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::parse_equation;
#[test]
fn test_equation_system_creation() {
let mut system = EquationSystem::new();
system.add_equation("eq1", parse_equation("x + y = 10").unwrap());
system.add_equation("eq2", parse_equation("x - y = 2").unwrap());
assert_eq!(system.len(), 2);
assert!(system.get("eq1").is_some());
assert!(system.get("eq2").is_some());
}
#[test]
fn test_equation_system_variables() {
let system = EquationSystem::new()
.with_equation("eq1", parse_equation("F = m * a").unwrap())
.with_equation("eq2", parse_equation("v = u + a * t").unwrap());
let vars = system.all_variables();
assert!(vars.contains("F"));
assert!(vars.contains("m"));
assert!(vars.contains("a"));
assert!(vars.contains("v"));
assert!(vars.contains("u"));
assert!(vars.contains("t"));
}
#[test]
fn test_context_builder() {
let context = SystemContext::new()
.with_known_value("F", 100.0)
.with_known_value("m", 20.0)
.with_target("a");
assert_eq!(context.known_values.get("F"), Some(&100.0));
assert_eq!(context.known_values.get("m"), Some(&20.0));
assert!(context.target_variables.contains(&"a".to_string()));
}
#[test]
fn test_dependency_graph() {
let system = EquationSystem::new()
.with_equation("eq1", parse_equation("F = m * a").unwrap())
.with_equation("eq2", parse_equation("v = u + a * t").unwrap());
let graph = DependencyGraph::build(&system);
let a_eqs = graph.equations_with_variable("a").unwrap();
assert!(a_eqs.contains("eq1"));
assert!(a_eqs.contains("eq2"));
let eq1_vars = graph.variables_in_equation("eq1").unwrap();
assert!(eq1_vars.contains("F"));
assert!(eq1_vars.contains("m"));
assert!(eq1_vars.contains("a"));
}
#[test]
fn test_find_solvable() {
let system = EquationSystem::new()
.with_equation("eq1", parse_equation("F = m * a").unwrap())
.with_equation("eq2", parse_equation("v = u + a * t").unwrap());
let graph = DependencyGraph::build(&system);
let mut known = HashSet::new();
known.insert("F".to_string());
known.insert("m".to_string());
let solvable = graph.find_solvable(&known);
assert!(solvable.iter().any(|(eq, var)| eq == "eq1" && var == "a"));
}
#[test]
fn test_simple_linear_system() {
let system =
EquationSystem::new().with_equation("eq1", parse_equation("F = m * a").unwrap());
let context = SystemContext::new()
.with_known_value("F", 100.0)
.with_known_value("m", 20.0)
.with_target("a");
let solver = MultiEquationSolver::new();
let solution = solver.solve(&system, &context).unwrap();
let a = solution.get_numeric("a").unwrap();
assert!((a - 5.0).abs() < 1e-10);
}
#[test]
fn test_chained_equations() {
    // `a` comes out of eq1 and is then fed into eq2 to obtain `v`.
    let system = EquationSystem::new()
        .with_equation("eq1", parse_equation("F = m * a").unwrap())
        .with_equation("eq2", parse_equation("v = u + a * t").unwrap());
    let context = SystemContext::new()
        .with_known_value("F", 100.0)
        .with_known_value("m", 20.0)
        .with_known_value("u", 0.0)
        .with_known_value("t", 5.0)
        .with_target("a")
        .with_target("v");

    let solver = MultiEquationSolver::new();
    let solution = solver.solve(&system, &context).unwrap();

    // a = 100 / 20 = 5, then v = 0 + 5 * 5 = 25.
    for (name, expected) in [("a", 5.0), ("v", 25.0)] {
        let actual = solution.get_numeric(name).unwrap();
        assert!((actual - expected).abs() < 1e-10);
    }
}
#[test]
fn test_resolution_path() {
    let newton_law = parse_equation("F = m * a").unwrap();
    let system = EquationSystem::new().with_equation("eq1", newton_law);
    let context = SystemContext::new()
        .with_known_value("F", 100.0)
        .with_known_value("m", 20.0)
        .with_target("a");

    let solution = MultiEquationSolver::new()
        .solve(&system, &context)
        .unwrap();

    // Solving must leave at least one recorded step ...
    let path = &solution.resolution_path;
    assert!(!path.steps.is_empty());

    // ... and the text rendering must carry the multi-equation heading.
    let rendered = path.format_text();
    assert!(rendered.contains("Multi-Equation System Solution"));
}
#[test]
fn test_insufficient_equations() {
    // One equation, two unknowns: the solver may either fail outright or
    // hand back a solution that it reports as incomplete.
    let single = parse_equation("x + y = 10").unwrap();
    let system = EquationSystem::new().with_equation("eq1", single);
    let context = SystemContext::new().with_target("x").with_target("y");

    let solver = MultiEquationSolver::new();
    match solver.solve(&system, &context) {
        Err(_) => {}
        Ok(partial) => assert!(!partial.is_complete()),
    }
}
#[test]
fn test_solution_value_conversion() {
    // A numeric value converts to a number.
    let as_number = SolutionValue::Numeric(42.0).as_numeric();
    assert_eq!(as_number, Some(42.0));

    // A symbolic value has no numeric form but does expose its expression.
    let x = Expression::Variable(Variable::new("x"));
    let symbolic = SolutionValue::Symbolic(x);
    assert!(symbolic.as_numeric().is_none());
    assert!(symbolic.as_expression().is_some());
}
#[test]
fn test_quadratic_in_system() {
    let parabola = parse_equation("y = x * x").unwrap();
    let system = EquationSystem::new().with_equation("eq1", parabola);
    let context = SystemContext::new()
        .with_known_value("y", 16.0)
        .with_target("x");

    let outcome = MultiEquationSolver::new().solve(&system, &context);

    // A failed solve or a non-numeric answer is tolerated here; only a numeric
    // answer is checked. Either root of x^2 = 16 is accepted, hence the `abs`.
    let numeric_x = outcome
        .as_ref()
        .ok()
        .and_then(|solved| solved.get_numeric("x"));
    if let Some(x) = numeric_x {
        assert!((x.abs() - 4.0).abs() < 1e-10);
    }
}
}
/// Errors reported by the nonlinear-system solvers and their linear-algebra helper.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum NonlinearSystemSolverError {
/// The iteration stopped without reaching the residual tolerance.
NoConvergence {
/// Iteration count reported by the solver when it gave up.
iterations: usize,
/// Residual norm at the point where the solver gave up.
final_residual: f64,
},
/// A linear solve found no usable pivot (the matrix is numerically singular).
SingularJacobian {
/// Optional condition-number estimate; the LU solver in this file always passes `None`.
condition_estimate: Option<f64>,
},
/// The sizes of the equations, variables, point or right-hand side do not agree.
DimensionMismatch {
num_equations: usize,
num_variables: usize,
},
/// An equation or Jacobian entry could not be evaluated numerically.
EvaluationFailed {
/// The point at which evaluation was attempted.
point: Vec<f64>,
/// Which expression failed.
reason: String,
},
/// NOTE(review): not constructed in this part of the file; presumably reserved for config validation.
InvalidConfig(String),
/// NOTE(review): not constructed in this part of the file; presumably reserved for symbolic differentiation failures.
DifferentiationFailed(String),
}
impl fmt::Display for NonlinearSystemSolverError {
    /// Writes a one-line, human-readable description of the error.
    ///
    /// Floating-point diagnostics (residual, condition estimate) use
    /// two-decimal scientific notation; the evaluation point uses `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NoConvergence {
                iterations,
                final_residual,
            } => write!(
                f,
                "No convergence after {} iterations (residual: {:.2e})",
                iterations, final_residual
            ),
            // The condition estimate is appended only when one is available.
            Self::SingularJacobian {
                condition_estimate: Some(cond),
            } => write!(f, "Singular Jacobian (condition ~{:.2e})", cond),
            Self::SingularJacobian {
                condition_estimate: None,
            } => f.write_str("Singular Jacobian"),
            Self::DimensionMismatch {
                num_equations,
                num_variables,
            } => write!(
                f,
                "Dimension mismatch: {} equations, {} variables",
                num_equations, num_variables
            ),
            Self::EvaluationFailed { point, reason } => {
                write!(f, "Evaluation failed at {:?}: {}", point, reason)
            }
            Self::InvalidConfig(msg) => write!(f, "Invalid configuration: {}", msg),
            Self::DifferentiationFailed(msg) => write!(f, "Differentiation failed: {}", msg),
        }
    }
}
// Empty impl: no trait methods are overridden, so the trait's defaults apply.
impl std::error::Error for NonlinearSystemSolverError {}
/// Tuning parameters shared by the nonlinear-system solvers in this file.
#[derive(Debug, Clone)]
pub struct NonlinearSystemConfig {
/// Upper bound on the number of solver iterations.
pub max_iterations: usize,
/// A residual norm below this value counts as converged.
pub tolerance: f64,
/// Newton-Raphson also stops (reporting success) when the step norm drops below this value.
pub step_tolerance: f64,
/// Multiplier applied to each step; Newton-Raphson ignores it when `use_line_search` is set.
pub damping_factor: f64,
/// When true, Newton-Raphson chooses the step length by backtracking instead of `damping_factor`.
pub use_line_search: bool,
/// NOTE(review): not read by any solver in this part of the file; presumably a finite-difference step size.
pub finite_diff_epsilon: f64,
/// The backtracking line search stops shrinking once the step length falls below this value.
pub min_step_size: f64,
/// Value added to the Jacobian diagonal before each linear solve (Newton-Raphson and Broyden).
pub regularization: f64,
}
impl Default for NonlinearSystemConfig {
fn default() -> Self {
Self {
max_iterations: 100,
tolerance: 1e-10,
step_tolerance: 1e-12,
damping_factor: 1.0,
use_line_search: false,
finite_diff_epsilon: 1e-8,
min_step_size: 1e-10,
regularization: 1e-12,
}
}
}
impl NonlinearSystemConfig {
    /// Preset with half-length steps and the backtracking line search enabled.
    ///
    /// All other fields keep their [`Default`] values.
    pub fn damped() -> Self {
        let base = Self::default();
        Self {
            use_line_search: true,
            damping_factor: 0.5,
            ..base
        }
    }

    /// Preset with a larger iteration budget (200) and a looser residual
    /// tolerance (`1e-8`).
    ///
    /// All other fields keep their [`Default`] values.
    pub fn for_broyden() -> Self {
        let base = Self::default();
        Self {
            tolerance: 1e-8,
            max_iterations: 200,
            ..base
        }
    }
}
/// Outcome of a nonlinear-system solve, as returned by the solvers in this file.
#[derive(Debug, Clone)]
pub struct NonlinearSystemSolverResult {
/// Final point; entries are paired by position with `variable_names`.
pub solution: Vec<f64>,
/// Iteration index at which the solver returned.
pub iterations: usize,
/// Residual norm at `solution`.
pub final_residual: f64,
/// Residual norms recorded by the solver, oldest first.
pub convergence_history: Vec<f64>,
/// Whether the solver reported convergence.
pub converged: bool,
/// Name of the method that produced the result (e.g. "Newton-Raphson").
pub method: String,
/// Variable names of the solved system, in the same order as `solution`.
pub variable_names: Vec<String>,
}
impl NonlinearSystemSolverResult {
    /// Returns the solution as a `name -> value` map.
    ///
    /// Names and values are paired by position; surplus entries on either
    /// side are dropped, and a repeated name keeps its last value.
    pub fn as_map(&self) -> HashMap<String, f64> {
        let mut by_name = HashMap::with_capacity(self.variable_names.len());
        for (name, value) in self.variable_names.iter().zip(&self.solution) {
            by_name.insert(name.clone(), *value);
        }
        by_name
    }

    /// Ratio of the last two recorded residuals.
    ///
    /// Returns `None` when fewer than three residuals were recorded, or when
    /// either of the two residuals preceding the last one is not above
    /// `1e-15`.
    pub fn convergence_rate(&self) -> Option<f64> {
        match self.convergence_history.as_slice() {
            [.., older, previous, latest] if *older > 1e-15 && *previous > 1e-15 => {
                Some(latest / previous)
            }
            _ => None,
        }
    }
}
/// Summary of how a residual sequence behaved, as produced by `ConvergenceDiagnostics::analyze`.
#[derive(Debug, Clone)]
pub struct ConvergenceDiagnostics {
/// Copy of the residual sequence that was analysed.
pub residual_history: Vec<f64>,
/// Copy of the step sequence passed to the analysis; stored as given.
pub step_history: Vec<f64>,
/// Ratio of the last residual to the one before it, when that ratio is defined.
pub estimated_rate: Option<f64>,
/// Classification derived from the last three residuals.
pub behavior: ConvergenceBehavior,
}
/// Qualitative classification of a residual sequence.
#[derive(Debug, Clone, PartialEq)]
pub enum ConvergenceBehavior {
/// The latest reduction ratio is below twice the square of the previous one.
Quadratic,
/// The latest reduction ratio is below 0.9; also the fallback when too little data is available.
Linear,
/// The residual still decreases, but the latest reduction ratio is 0.9 or more.
Sublinear,
/// NOTE(review): never produced by `ConvergenceDiagnostics::analyze` as written in this file.
Oscillating,
/// The last residual exceeds the previous one by more than 10%.
Diverging,
/// The last residual is above 99% of the previous one, without counting as diverging.
Stalled,
}
impl ConvergenceDiagnostics {
pub fn analyze(residuals: &[f64], steps: &[f64]) -> Self {
let behavior = if residuals.len() < 3 {
ConvergenceBehavior::Linear
} else {
let n = residuals.len();
let r1 = residuals[n - 2];
let r2 = residuals[n - 1];
let r0 = residuals[n - 3];
if r2 > r1 * 1.1 {
ConvergenceBehavior::Diverging
} else if r2 > r1 * 0.99 {
ConvergenceBehavior::Stalled
} else if r0 > 1e-15 && r1 > 1e-15 {
let rate1 = r1 / r0;
let rate2 = r2 / r1;
if rate2 < rate1 * rate1 * 2.0 {
ConvergenceBehavior::Quadratic
} else if rate2 < 0.9 {
ConvergenceBehavior::Linear
} else {
ConvergenceBehavior::Sublinear
}
} else {
ConvergenceBehavior::Linear
}
};
let estimated_rate = if residuals.len() >= 2 {
let n = residuals.len();
if residuals[n - 2] > 1e-15 {
Some(residuals[n - 1] / residuals[n - 2])
} else {
None
}
} else {
None
};
Self {
residual_history: residuals.to_vec(),
step_history: steps.to_vec(),
estimated_rate,
behavior,
}
}
}
/// A system of expressions treated as residuals `f_i(x) = 0` over an ordered list of variables.
#[derive(Debug, Clone)]
pub struct NonlinearSystem {
/// Residual expressions, one per equation.
pub equations: Vec<Expression>,
/// Unknowns; their order fixes the meaning of each coordinate of an evaluation point.
pub variables: Vec<Variable>,
}
impl NonlinearSystem {
    /// Creates a system from residual expressions and the ordered unknowns.
    pub fn new(equations: Vec<Expression>, variables: Vec<Variable>) -> Self {
        Self {
            equations,
            variables,
        }
    }

    /// Creates a system from equations by rewriting each `left = right` as the
    /// residual expression `left - right`.
    pub fn from_equations(equations: Vec<Equation>, variables: Vec<Variable>) -> Self {
        let residuals: Vec<Expression> = equations
            .into_iter()
            .map(|eq| Expression::Binary(BinaryOp::Sub, Box::new(eq.left), Box::new(eq.right)))
            .collect();
        Self::new(residuals, variables)
    }

    /// Number of residual expressions.
    pub fn num_equations(&self) -> usize {
        self.equations.len()
    }

    /// Number of unknowns.
    pub fn num_variables(&self) -> usize {
        self.variables.len()
    }

    /// Whether there are exactly as many equations as unknowns.
    pub fn is_square(&self) -> bool {
        self.num_equations() == self.num_variables()
    }

    /// Pairs each variable name with the matching coordinate of `point`.
    ///
    /// Shared by [`Self::evaluate`] and [`Self::evaluate_jacobian`] so that
    /// both reject a point whose length differs from the number of variables.
    fn bind_point(
        &self,
        point: &[f64],
    ) -> Result<HashMap<String, f64>, NonlinearSystemSolverError> {
        if point.len() != self.variables.len() {
            return Err(NonlinearSystemSolverError::DimensionMismatch {
                num_equations: self.equations.len(),
                num_variables: point.len(),
            });
        }
        Ok(self
            .variables
            .iter()
            .zip(point.iter())
            .map(|(var, &value)| (var.name.clone(), value))
            .collect())
    }

    /// Evaluates every residual at `point`.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if `point` does not have one entry per variable.
    /// * `EvaluationFailed` if any residual cannot be evaluated.
    pub fn evaluate(&self, point: &[f64]) -> Result<Vec<f64>, NonlinearSystemSolverError> {
        let bindings = self.bind_point(point)?;
        self.equations
            .iter()
            .enumerate()
            .map(|(i, eq)| {
                eq.evaluate(&bindings).ok_or_else(|| {
                    NonlinearSystemSolverError::EvaluationFailed {
                        point: point.to_vec(),
                        reason: format!("Could not evaluate equation {}", i),
                    }
                })
            })
            .collect()
    }

    /// Builds the symbolic Jacobian: entry `[i][j]` is the derivative of
    /// equation `i` with respect to variable `j`.
    ///
    /// As written this never returns `Err`; the `Result` is kept for callers.
    pub fn jacobian(&self) -> Result<Vec<Vec<Expression>>, NonlinearSystemSolverError> {
        let rows: Vec<Vec<Expression>> = self
            .equations
            .iter()
            .map(|eq| {
                self.variables
                    .iter()
                    .map(|var| eq.differentiate(&var.name))
                    .collect::<Vec<Expression>>()
            })
            .collect();
        Ok(rows)
    }

    /// Evaluates the symbolic Jacobian at `point`.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if `point` does not have one entry per variable.
    ///   Previously a wrong-length point was silently truncated here, unlike
    ///   in [`Self::evaluate`].
    /// * `EvaluationFailed` if any Jacobian entry cannot be evaluated.
    pub fn evaluate_jacobian(
        &self,
        point: &[f64],
    ) -> Result<Vec<Vec<f64>>, NonlinearSystemSolverError> {
        let bindings = self.bind_point(point)?;
        let symbolic = self.jacobian()?;
        symbolic
            .iter()
            .enumerate()
            .map(|(i, row)| {
                row.iter()
                    .enumerate()
                    .map(|(j, entry)| {
                        entry.evaluate(&bindings).ok_or_else(|| {
                            NonlinearSystemSolverError::EvaluationFailed {
                                point: point.to_vec(),
                                reason: format!("Could not evaluate J[{}][{}]", i, j),
                            }
                        })
                    })
                    .collect::<Result<Vec<f64>, NonlinearSystemSolverError>>()
            })
            .collect()
    }

    /// Names of the unknowns, in system order.
    pub fn variable_names(&self) -> Vec<String> {
        self.variables.iter().map(|v| v.name.clone()).collect()
    }
}
/// Euclidean (L2) norm of a residual vector: the square root of the sum of
/// squared entries.
pub fn residual_norm(residuals: &[f64]) -> f64 {
    let sum_of_squares: f64 = residuals.iter().map(|&r| r * r).sum();
    sum_of_squares.sqrt()
}
/// Solves the dense linear system `matrix * x = rhs` by LU decomposition with
/// partial (row) pivoting.
///
/// * `matrix` - square coefficient matrix, one `Vec` per row.
/// * `rhs` - right-hand side, one entry per row.
///
/// Returns the solution vector `x`.
///
/// # Errors
///
/// * `DimensionMismatch` if the matrix is empty, `rhs` does not have one entry
///   per row, or any row's length differs from the number of rows. Ragged or
///   non-square input used to panic on an out-of-bounds index.
/// * `SingularJacobian` if the largest available pivot in some column has
///   magnitude below `1e-15`.
pub fn solve_linear_system_lu(
    matrix: &[Vec<f64>],
    rhs: &[f64],
) -> Result<Vec<f64>, NonlinearSystemSolverError> {
    let n = matrix.len();
    if n == 0 || rhs.len() != n {
        return Err(NonlinearSystemSolverError::DimensionMismatch {
            num_equations: n,
            num_variables: rhs.len(),
        });
    }
    // Every row must have exactly `n` columns; the elimination below indexes
    // all rows up to column `n - 1`.
    if let Some(bad_row) = matrix.iter().find(|row| row.len() != n) {
        return Err(NonlinearSystemSolverError::DimensionMismatch {
            num_equations: n,
            num_variables: bad_row.len(),
        });
    }

    // `lu` holds both factors in place: the multipliers of L below the
    // diagonal, U on and above it. `perm[i]` is the original index of row `i`.
    let mut lu: Vec<Vec<f64>> = matrix.to_vec();
    let mut perm: Vec<usize> = (0..n).collect();

    for k in 0..n {
        // Partial pivoting: pick the row with the largest magnitude in column k.
        let mut pivot_abs = lu[k][k].abs();
        let mut pivot_row = k;
        for i in (k + 1)..n {
            let candidate = lu[i][k].abs();
            if candidate > pivot_abs {
                pivot_abs = candidate;
                pivot_row = i;
            }
        }
        if pivot_abs < 1e-15 {
            return Err(NonlinearSystemSolverError::SingularJacobian {
                condition_estimate: None,
            });
        }
        if pivot_row != k {
            lu.swap(k, pivot_row);
            perm.swap(k, pivot_row);
        }

        // Eliminate column k from every row below the pivot row.
        let (upper, lower) = lu.split_at_mut(k + 1);
        let pivot = &upper[k];
        for row in lower.iter_mut() {
            row[k] /= pivot[k];
            let factor = row[k];
            for j in (k + 1)..n {
                row[j] -= factor * pivot[j];
            }
        }
    }

    // Forward substitution on the permuted right-hand side (L has a unit diagonal).
    let mut y: Vec<f64> = perm.iter().map(|&source| rhs[source]).collect();
    for i in 1..n {
        for j in 0..i {
            y[i] -= lu[i][j] * y[j];
        }
    }

    // Back substitution with U.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        x[i] = y[i];
        for j in (i + 1)..n {
            x[i] -= lu[i][j] * x[j];
        }
        x[i] /= lu[i][i];
    }
    Ok(x)
}
/// Solves a square nonlinear system with Newton-Raphson iteration.
///
/// Each iteration evaluates the residuals and the analytic Jacobian, adds
/// `config.regularization` to the Jacobian diagonal, solves `J * delta = -f`,
/// and moves `x` along `delta`. The step length is `config.damping_factor`,
/// or, when `config.use_line_search` is set, the result of a backtracking
/// search that halves the length (at most 20 times, and not below
/// `config.min_step_size`) until the residual norm decreases sufficiently.
///
/// The iteration succeeds when the residual norm falls below
/// `config.tolerance` or the step norm falls below `config.step_tolerance`.
/// The point reached after the final iteration is also checked against
/// `config.tolerance`, so a solve that converges on its last permitted step
/// is no longer reported as a failure.
///
/// # Errors
///
/// * `DimensionMismatch` if the system is not square or `initial_guess` has
///   the wrong length.
/// * `EvaluationFailed` if a residual or Jacobian entry cannot be evaluated.
/// * `SingularJacobian` if the linear solve finds no usable pivot.
/// * `NoConvergence` if the iteration budget runs out.
pub fn newton_raphson_system(
    system: &NonlinearSystem,
    initial_guess: &[f64],
    config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
    if !system.is_square() {
        return Err(NonlinearSystemSolverError::DimensionMismatch {
            num_equations: system.num_equations(),
            num_variables: system.num_variables(),
        });
    }

    // Builds the success value; every successful exit reports the same
    // method name and variable list.
    let success = |solution: Vec<f64>,
                   iterations: usize,
                   final_residual: f64,
                   convergence_history: Vec<f64>| NonlinearSystemSolverResult {
        solution,
        iterations,
        final_residual,
        convergence_history,
        converged: true,
        method: "Newton-Raphson".to_string(),
        variable_names: system.variable_names(),
    };

    let mut x = initial_guess.to_vec();
    let mut convergence_history = Vec::new();

    for iter in 0..config.max_iterations {
        let f = system.evaluate(&x)?;
        let residual = residual_norm(&f);
        convergence_history.push(residual);
        if residual < config.tolerance {
            return Ok(success(x, iter, residual, convergence_history));
        }

        // Shift the diagonal by the regularization term before solving.
        let mut reg_jacobian = system.evaluate_jacobian(&x)?;
        for (i, row) in reg_jacobian.iter_mut().enumerate() {
            row[i] += config.regularization;
        }
        let neg_f: Vec<f64> = f.iter().map(|v| -v).collect();
        let delta = solve_linear_system_lu(&reg_jacobian, &neg_f)?;

        let step_norm: f64 = delta.iter().map(|d| d * d).sum::<f64>().sqrt();
        if step_norm < config.step_tolerance {
            return Ok(success(x, iter, residual, convergence_history));
        }

        let alpha = if config.use_line_search {
            // Backtracking: accept the first length that reduces the residual
            // by the required fraction; trial points that cannot be evaluated
            // count as rejected.
            let mut alpha = 1.0;
            let c = 0.5;
            let rho = 0.5;
            for _ in 0..20 {
                let x_new: Vec<f64> = x
                    .iter()
                    .zip(delta.iter())
                    .map(|(xi, di)| xi + alpha * di)
                    .collect();
                if let Ok(f_new) = system.evaluate(&x_new) {
                    let new_residual = residual_norm(&f_new);
                    if new_residual < residual * (1.0 - c * alpha) {
                        break;
                    }
                }
                alpha *= rho;
                if alpha < config.min_step_size {
                    break;
                }
            }
            alpha
        } else {
            config.damping_factor
        };

        for (xi, di) in x.iter_mut().zip(delta.iter()) {
            *xi += alpha * di;
        }
    }

    // The point produced by the last step has not been tested yet.
    let f = system.evaluate(&x)?;
    let final_residual = residual_norm(&f);
    if final_residual < config.tolerance {
        convergence_history.push(final_residual);
        return Ok(success(
            x,
            config.max_iterations,
            final_residual,
            convergence_history,
        ));
    }
    Err(NonlinearSystemSolverError::NoConvergence {
        iterations: config.max_iterations,
        final_residual,
    })
}
/// Solves a square nonlinear system with the damped residual iteration
/// `x <- x - damping_factor * f(x)`.
///
/// Succeeds as soon as the residual norm drops below `config.tolerance`.
///
/// # Errors
///
/// * `DimensionMismatch` if the system is not square or `initial_guess` has
///   the wrong length.
/// * `EvaluationFailed` if the residuals cannot be evaluated at the current
///   iterate.
/// * `NoConvergence` if a proposed iterate has a residual more than ten times
///   the current one, or if the iteration budget runs out.
pub fn fixed_point_system(
    system: &NonlinearSystem,
    initial_guess: &[f64],
    config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
    if !system.is_square() {
        return Err(NonlinearSystemSolverError::DimensionMismatch {
            num_equations: system.num_equations(),
            num_variables: system.num_variables(),
        });
    }

    let mut current = initial_guess.to_vec();
    let mut history = Vec::new();

    for iteration in 0..config.max_iterations {
        let values = system.evaluate(&current)?;
        let norm = residual_norm(&values);
        history.push(norm);

        if norm < config.tolerance {
            return Ok(NonlinearSystemSolverResult {
                solution: current,
                iterations: iteration,
                final_residual: norm,
                convergence_history: history,
                converged: true,
                method: "Fixed-Point".to_string(),
                variable_names: system.variable_names(),
            });
        }

        // Move against the residual, scaled by the damping factor.
        let proposal: Vec<f64> = current
            .iter()
            .zip(values.iter())
            .map(|(xi, fi)| xi - config.damping_factor * fi)
            .collect();

        // Give up early when the proposal makes things an order of magnitude
        // worse. A proposal that cannot be evaluated is not judged here; the
        // evaluation at the top of the next pass reports it.
        let blow_up = system
            .evaluate(&proposal)
            .ok()
            .map(|next_values| residual_norm(&next_values))
            .filter(|next_norm| *next_norm > norm * 10.0);
        if let Some(next_norm) = blow_up {
            return Err(NonlinearSystemSolverError::NoConvergence {
                iterations: iteration,
                final_residual: next_norm,
            });
        }

        current = proposal;
    }

    let final_residual = residual_norm(&system.evaluate(&current)?);
    Err(NonlinearSystemSolverError::NoConvergence {
        iterations: config.max_iterations,
        final_residual,
    })
}
/// Solves a square nonlinear system with Broyden's quasi-Newton method.
///
/// The Jacobian approximation `B` starts as the analytic Jacobian at the
/// initial guess and is then maintained by rank-one secant updates, so the
/// analytic Jacobian is evaluated only once. Each iteration adds
/// `config.regularization` to the diagonal of a copy of `B`, solves
/// `B * s = -f`, and takes the step `config.damping_factor * s`.
///
/// The secant update is built from the step actually taken, so that the
/// updated `B` maps that step onto the observed change in the residuals.
/// The undamped direction `s` was used before, which violates the secant
/// condition whenever the damping factor is not 1. The residual after the
/// final iteration is also tested against the tolerance.
///
/// # Errors
///
/// * `DimensionMismatch` if the system is not square or `initial_guess` has
///   the wrong length.
/// * `EvaluationFailed` if the residuals or the initial Jacobian cannot be
///   evaluated.
/// * `SingularJacobian` if the linear solve finds no usable pivot.
/// * `NoConvergence` if the iteration budget runs out.
pub fn broyden_system(
    system: &NonlinearSystem,
    initial_guess: &[f64],
    config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
    if !system.is_square() {
        return Err(NonlinearSystemSolverError::DimensionMismatch {
            num_equations: system.num_equations(),
            num_variables: system.num_variables(),
        });
    }

    // Builds the success value shared by both successful exits.
    let success = |solution: Vec<f64>,
                   iterations: usize,
                   final_residual: f64,
                   convergence_history: Vec<f64>| NonlinearSystemSolverResult {
        solution,
        iterations,
        final_residual,
        convergence_history,
        converged: true,
        method: "Broyden".to_string(),
        variable_names: system.variable_names(),
    };

    let mut x = initial_guess.to_vec();
    let mut convergence_history = Vec::new();
    let mut b = system.evaluate_jacobian(&x)?;
    let mut f = system.evaluate(&x)?;
    let mut residual = residual_norm(&f);
    convergence_history.push(residual);

    for iter in 0..config.max_iterations {
        if residual < config.tolerance {
            return Ok(success(x, iter, residual, convergence_history));
        }

        // Solve with a regularized copy; `b` itself stays unshifted.
        let mut reg_b = b.clone();
        for (i, row) in reg_b.iter_mut().enumerate() {
            row[i] += config.regularization;
        }
        let neg_f: Vec<f64> = f.iter().map(|v| -v).collect();
        let s = solve_linear_system_lu(&reg_b, &neg_f)?;

        // The step actually taken is the damped direction.
        let step: Vec<f64> = s.iter().map(|si| config.damping_factor * si).collect();
        let x_new: Vec<f64> = x.iter().zip(step.iter()).map(|(xi, di)| xi + di).collect();
        let f_new = system.evaluate(&x_new)?;
        let new_residual = residual_norm(&f_new);

        // Rank-one secant update: B += (y - B*step) * step^T / (step . step),
        // where y is the observed change in the residuals.
        let y: Vec<f64> = f_new
            .iter()
            .zip(f.iter())
            .map(|(new, old)| new - old)
            .collect();
        let b_step: Vec<f64> = b
            .iter()
            .map(|row| {
                row.iter()
                    .zip(step.iter())
                    .map(|(bij, dj)| bij * dj)
                    .sum::<f64>()
            })
            .collect();
        let step_dot_step: f64 = step.iter().map(|d| d * d).sum();
        // Skip the update for a vanishing step to avoid dividing by ~0.
        if step_dot_step > 1e-15 {
            for (row, (yi, bsi)) in b.iter_mut().zip(y.iter().zip(b_step.iter())) {
                let mismatch = yi - bsi;
                for (bij, dj) in row.iter_mut().zip(step.iter()) {
                    *bij += mismatch * dj / step_dot_step;
                }
            }
        }

        x = x_new;
        f = f_new;
        residual = new_residual;
        convergence_history.push(residual);
    }

    // The residual produced by the last update has not been tested yet.
    if residual < config.tolerance {
        return Ok(success(
            x,
            config.max_iterations,
            residual,
            convergence_history,
        ));
    }
    Err(NonlinearSystemSolverError::NoConvergence {
        iterations: config.max_iterations,
        final_residual: residual,
    })
}
/// Common interface for the iterative solvers, so they can be used interchangeably (e.g. behind `Box<dyn ...>`).
pub trait NonlinearSystemSolver {
/// Attempts to solve `system` starting from `initial_guess` using the settings in `config`.
fn solve(
&self,
system: &NonlinearSystem,
initial_guess: &[f64],
config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError>;
/// Human-readable name of the method.
fn method_name(&self) -> &str;
}
/// Stateless adapter exposing `newton_raphson_system` through `NonlinearSystemSolver`.
pub struct NewtonRaphsonSolver;
impl NonlinearSystemSolver for NewtonRaphsonSolver {
// Delegates directly to the free function; no state is kept between calls.
fn solve(
&self,
system: &NonlinearSystem,
initial_guess: &[f64],
config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
newton_raphson_system(system, initial_guess, config)
}
fn method_name(&self) -> &str {
"Newton-Raphson"
}
}
/// Stateless adapter exposing `broyden_system` through `NonlinearSystemSolver`.
pub struct BroydenSolver;
impl NonlinearSystemSolver for BroydenSolver {
// Delegates directly to the free function; no state is kept between calls.
fn solve(
&self,
system: &NonlinearSystem,
initial_guess: &[f64],
config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
broyden_system(system, initial_guess, config)
}
fn method_name(&self) -> &str {
"Broyden"
}
}
/// Stateless adapter exposing `fixed_point_system` through `NonlinearSystemSolver`.
pub struct FixedPointSolver;
impl NonlinearSystemSolver for FixedPointSolver {
// Delegates directly to the free function; no state is kept between calls.
fn solve(
&self,
system: &NonlinearSystem,
initial_guess: &[f64],
config: &NonlinearSystemConfig,
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
fixed_point_system(system, initial_guess, config)
}
fn method_name(&self) -> &str {
"Fixed-Point"
}
}
/// Solver front end that tries several methods in turn (see its `solve` method).
pub struct SmartNonlinearSystemSolver {
// Optional caller-supplied settings; `None` means `NonlinearSystemConfig::default()` is used.
config: Option<NonlinearSystemConfig>,
}
impl SmartNonlinearSystemSolver {
pub fn new() -> Self {
Self { config: None }
}
pub fn with_config(config: NonlinearSystemConfig) -> Self {
Self {
config: Some(config),
}
}
pub fn solve(
&self,
system: &NonlinearSystem,
initial_guess: &[f64],
) -> Result<NonlinearSystemSolverResult, NonlinearSystemSolverError> {
let config = self.config.clone().unwrap_or_default();
match newton_raphson_system(system, initial_guess, &config) {
Ok(result) if result.converged => return Ok(result),
_ => {}
}
let broyden_config = NonlinearSystemConfig {
max_iterations: config.max_iterations * 2,
damping_factor: 0.8,
..config.clone()
};
match broyden_system(system, initial_guess, &broyden_config) {
Ok(result) if result.converged => return Ok(result),
_ => {}
}
let damped_config = NonlinearSystemConfig {
damping_factor: 0.5,
use_line_search: true,
max_iterations: config.max_iterations * 2,
..config
};
newton_raphson_system(system, initial_guess, &damped_config)
}
pub fn find_all_solutions(
&self,
system: &NonlinearSystem,
initial_guesses: &[Vec<f64>],
) -> Vec<NonlinearSystemSolverResult> {
let config = self.config.clone().unwrap_or_default();
let mut solutions = Vec::new();
let tolerance = config.tolerance * 100.0;
for guess in initial_guesses {
if let Ok(result) = self.solve(system, guess) {
if result.converged {
let is_new = solutions
.iter()
.all(|existing: &NonlinearSystemSolverResult| {
let diff: f64 = existing
.solution
.iter()
.zip(result.solution.iter())
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt();
diff > tolerance
});
if is_new {
solutions.push(result);
}
}
}
}
solutions
}
}
// `Default` mirrors `new()`: a solver without a caller-supplied configuration.
impl Default for SmartNonlinearSystemSolver {
fn default() -> Self {
Self::new()
}
}
/// Compares the analytic Jacobian of `system` at `point` with a
/// forward-difference approximation.
///
/// * `epsilon` - perturbation added to one coordinate at a time.
///
/// Returns `(analytic, numeric, max_diff)`, where `max_diff` is the largest
/// absolute entry-wise difference between the two matrices.
///
/// # Errors
///
/// Propagates any error from evaluating the system or its Jacobian.
pub fn validate_jacobian(
    system: &NonlinearSystem,
    point: &[f64],
    epsilon: f64,
) -> Result<(Vec<Vec<f64>>, Vec<Vec<f64>>, f64), NonlinearSystemSolverError> {
    let analytic = system.evaluate_jacobian(point)?;
    let cols = system.num_variables();
    let rows = system.num_equations();
    let baseline = system.evaluate(point)?;

    // Column j of the numeric Jacobian is (f(x + eps * e_j) - f(x)) / eps.
    let mut numeric = vec![vec![0.0; cols]; rows];
    let mut shifted = point.to_vec();
    for j in 0..cols {
        let original = shifted[j];
        shifted[j] = original + epsilon;
        let perturbed = system.evaluate(&shifted)?;
        // Restore the coordinate so only one is perturbed at a time.
        shifted[j] = original;

        for (row, (fp, f0)) in numeric
            .iter_mut()
            .zip(perturbed.iter().zip(baseline.iter()))
        {
            row[j] = (fp - f0) / epsilon;
        }
    }

    // Largest absolute disagreement across all entries.
    let max_diff = analytic
        .iter()
        .zip(numeric.iter())
        .flat_map(|(a_row, n_row)| a_row.iter().zip(n_row.iter()))
        .map(|(a, n)| (a - n).abs())
        .fold(0.0, |worst, diff| if diff > worst { diff } else { worst });

    Ok((analytic, numeric, max_diff))
}
#[cfg(test)]
mod nonlinear_tests {
    use super::*;

    /// Builds `lhs <op> rhs`, boxing both operands.
    fn binary(op: BinaryOp, lhs: Expression, rhs: Expression) -> Expression {
        Expression::Binary(op, Box::new(lhs), Box::new(rhs))
    }

    /// Builds `var^2`.
    fn squared(var: &Variable) -> Expression {
        Expression::Power(
            Box::new(Expression::Variable(var.clone())),
            Box::new(Expression::Integer(2)),
        )
    }

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!((actual - expected).abs() < tol);
    }

    /// Unit circle intersected with the diagonal:
    /// `x^2 + y^2 - 1 = 0` and `x - y = 0`.
    fn make_circle_line_system() -> NonlinearSystem {
        let x = Variable::new("x");
        let y = Variable::new("y");
        let circle = binary(
            BinaryOp::Sub,
            binary(BinaryOp::Add, squared(&x), squared(&y)),
            Expression::Integer(1),
        );
        let line = binary(
            BinaryOp::Sub,
            Expression::Variable(x.clone()),
            Expression::Variable(y.clone()),
        );
        NonlinearSystem::new(vec![circle, line], vec![x, y])
    }

    /// Hyperbola intersected with a line:
    /// `x * y - 1 = 0` and `x + y - 3 = 0`.
    fn make_hyperbola_line_system() -> NonlinearSystem {
        let x = Variable::new("x");
        let y = Variable::new("y");
        let hyperbola = binary(
            BinaryOp::Sub,
            binary(
                BinaryOp::Mul,
                Expression::Variable(x.clone()),
                Expression::Variable(y.clone()),
            ),
            Expression::Integer(1),
        );
        let line = binary(
            BinaryOp::Sub,
            binary(
                BinaryOp::Add,
                Expression::Variable(x.clone()),
                Expression::Variable(y.clone()),
            ),
            Expression::Integer(3),
        );
        NonlinearSystem::new(vec![hyperbola, line], vec![x, y])
    }

    #[test]
    fn test_nonlinear_system_creation() {
        let system = make_circle_line_system();
        assert_eq!(system.num_equations(), 2);
        assert_eq!(system.num_variables(), 2);
        assert!(system.is_square());
    }

    #[test]
    fn test_evaluate() {
        let system = make_circle_line_system();

        // (1, 0) lies on the circle but off the diagonal.
        let on_circle = system.evaluate(&[1.0, 0.0]).unwrap();
        assert_close(on_circle[0], 0.0, 1e-10);
        assert_close(on_circle[1], 1.0, 1e-10);

        // (0.5, 0.5) lies on the diagonal but inside the circle.
        let on_line = system.evaluate(&[0.5, 0.5]).unwrap();
        assert_close(on_line[0], -0.5, 1e-10);
        assert_close(on_line[1], 0.0, 1e-10);
    }

    #[test]
    fn test_residual_norm() {
        // 3-4-5 triangle.
        let residuals = vec![3.0, 4.0];
        assert_close(residual_norm(&residuals), 5.0, 1e-10);
    }

    #[test]
    fn test_solve_linear_system_lu() {
        // x + y = 3, x - y = 1  =>  x = 2, y = 1.
        let matrix = vec![vec![1.0, 1.0], vec![1.0, -1.0]];
        let rhs = vec![3.0, 1.0];
        let solution = solve_linear_system_lu(&matrix, &rhs).unwrap();
        assert_close(solution[0], 2.0, 1e-10);
        assert_close(solution[1], 1.0, 1e-10);
    }

    #[test]
    fn test_newton_raphson_circle_line() {
        let system = make_circle_line_system();
        let config = NonlinearSystemConfig::default();
        let solved = newton_raphson_system(&system, &[0.5, 0.5], &config).unwrap();
        assert!(solved.converged);

        let root = std::f64::consts::FRAC_1_SQRT_2;
        assert_close(solved.solution[0], root, 1e-8);
        assert_close(solved.solution[1], root, 1e-8);
    }

    #[test]
    fn test_newton_raphson_negative_solution() {
        let system = make_circle_line_system();
        let config = NonlinearSystemConfig::default();
        let solved = newton_raphson_system(&system, &[-0.5, -0.5], &config).unwrap();
        assert!(solved.converged);

        let root = -std::f64::consts::FRAC_1_SQRT_2;
        assert_close(solved.solution[0], root, 1e-8);
        assert_close(solved.solution[1], root, 1e-8);
    }

    #[test]
    fn test_hyperbola_line_solution() {
        let system = make_hyperbola_line_system();
        let config = NonlinearSystemConfig::default();
        let solved = newton_raphson_system(&system, &[1.5, 1.5], &config).unwrap();
        assert!(solved.converged);

        // Check the defining relations rather than one specific root.
        let (x, y) = (solved.solution[0], solved.solution[1]);
        assert_close(x * y, 1.0, 1e-8);
        assert_close(x + y, 3.0, 1e-8);
    }

    #[test]
    fn test_broyden_circle_line() {
        let system = make_circle_line_system();
        let config = NonlinearSystemConfig::for_broyden();
        let solved = broyden_system(&system, &[0.5, 0.5], &config).unwrap();
        assert!(solved.converged);

        let root = std::f64::consts::FRAC_1_SQRT_2;
        assert_close(solved.solution[0], root, 1e-6);
        assert_close(solved.solution[1], root, 1e-6);
    }

    #[test]
    fn test_jacobian_validation() {
        let system = make_circle_line_system();
        let point = vec![0.5, 0.5];
        let epsilon = 1e-6;
        let (analytic, _numeric, max_diff) =
            validate_jacobian(&system, &point, epsilon).unwrap();

        // J = [[2x, 2y], [1, -1]] evaluated at (0.5, 0.5).
        assert_close(analytic[0][0], 1.0, 1e-8);
        assert_close(analytic[0][1], 1.0, 1e-8);
        assert_close(analytic[1][0], 1.0, 1e-8);
        assert_close(analytic[1][1], -1.0, 1e-8);
        assert!(max_diff < 1e-4);
    }

    #[test]
    fn test_smart_solver() {
        let system = make_circle_line_system();
        let solved = SmartNonlinearSystemSolver::new()
            .solve(&system, &[0.5, 0.5])
            .unwrap();
        assert!(solved.converged);
    }

    #[test]
    fn test_find_all_solutions() {
        let system = make_circle_line_system();
        let solver = SmartNonlinearSystemSolver::new();
        let guesses = vec![
            vec![0.5, 0.5],
            vec![-0.5, -0.5],
            vec![1.0, 0.0],
            vec![-1.0, 0.0],
        ];

        // Four starting points, but the system only has two roots.
        let solutions = solver.find_all_solutions(&system, &guesses);
        assert_eq!(solutions.len(), 2);

        let sqrt2_over_2 = std::f64::consts::FRAC_1_SQRT_2;
        let has_positive = solutions
            .iter()
            .any(|s| (s.solution[0] - sqrt2_over_2).abs() < 1e-6);
        assert!(has_positive);
        let has_negative = solutions
            .iter()
            .any(|s| (s.solution[0] + sqrt2_over_2).abs() < 1e-6);
        assert!(has_negative);
    }

    #[test]
    fn test_convergence_diagnostics() {
        // Each residual is a tenth of the previous one: steady linear decay.
        let residuals = vec![1.0, 0.1, 0.01, 0.001];
        let steps = vec![1.0, 0.5, 0.25, 0.125];
        let diagnostics = ConvergenceDiagnostics::analyze(&residuals, &steps);

        assert_eq!(diagnostics.behavior, ConvergenceBehavior::Linear);
        assert!(diagnostics.estimated_rate.is_some());
        assert!((diagnostics.estimated_rate.unwrap() - 0.1).abs() < 0.01);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let system = make_circle_line_system();
        let config = NonlinearSystemConfig::default();

        // One coordinate for a two-variable system.
        let outcome = newton_raphson_system(&system, &[0.5], &config);
        let is_dimension_error = matches!(
            outcome,
            Err(NonlinearSystemSolverError::DimensionMismatch { .. })
        );
        assert!(is_dimension_error);
    }

    #[test]
    fn test_compare_newton_vs_broyden_iterations() {
        let system = make_hyperbola_line_system();
        let initial_guess = [2.1, 0.9];
        let newton_config = NonlinearSystemConfig::default();
        let broyden_config = NonlinearSystemConfig {
            max_iterations: 200,
            tolerance: 1e-8,
            ..Default::default()
        };

        let newton_result =
            newton_raphson_system(&system, &initial_guess, &newton_config).unwrap();
        assert!(newton_result.converged);
        let broyden_result = broyden_system(&system, &initial_guess, &broyden_config).unwrap();
        assert!(broyden_result.converged);

        println!(
            "Newton iterations: {}, Broyden iterations: {}",
            newton_result.iterations, broyden_result.iterations
        );
    }

    #[test]
    fn test_solution_as_map() {
        let system = make_circle_line_system();
        let config = NonlinearSystemConfig::default();
        let solved = newton_raphson_system(&system, &[0.5, 0.5], &config).unwrap();

        let by_name = solved.as_map();
        assert!(by_name.contains_key("x"));
        assert!(by_name.contains_key("y"));
    }

    #[test]
    fn test_nonlinear_system_solver_trait() {
        let system = make_circle_line_system();
        let config = NonlinearSystemConfig::default();

        // Drive both solvers through the shared trait object interface.
        let solvers: Vec<Box<dyn NonlinearSystemSolver>> =
            vec![Box::new(NewtonRaphsonSolver), Box::new(BroydenSolver)];
        for solver in solvers.iter() {
            let result = solver.solve(&system, &[0.5, 0.5], &config).unwrap();
            assert!(result.converged);
            println!("{}: {} iterations", solver.method_name(), result.iterations);
        }
    }
}