use crate::core::{Matrix, Vector, SparseMatrix, Complexity};
use crate::FTLError;
use std::time::{Duration, Instant};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Algorithm that `SublinearSolver::solve` uses to approximate a solution of `A x = b`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SolverMethod {
/// Dispatched to `solve_neumann`, which iterates with the matrix built by
/// `compute_iteration_matrix`.
Neumann,
/// Dispatched to `solve_random_walk`, a randomised estimate built from sampled index walks.
RandomWalk,
/// Dispatched to `solve_forward_push`, which repeatedly processes the largest residual entry.
ForwardPush,
/// Dispatched to `solve_backward_push`, which currently delegates to the forward-push routine.
BackwardPush,
/// Dispatched to `solve_bidirectional`, the average of the forward- and backward-push results.
Bidirectional,
/// Not a solver of its own: `solve` replaces it with the method chosen by
/// `select_best_method` before dispatching.
Adaptive,
}
/// Tunable parameters for a `SublinearSolver`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SolverConfig {
/// Algorithm to run; `SolverMethod::Adaptive` defers the choice to the solver.
pub method: SolverMethod,
/// Convergence tolerance. The Neumann iteration stops once the norm of the step between two
/// iterates falls below it; forward push stops once the largest residual entry falls below
/// `epsilon / sqrt(n)`.
pub epsilon: f64,
/// Upper bound on the number of Neumann iterations. `solve` also copies this value into
/// `SolverResult::iterations`.
pub max_iterations: usize,
/// NOTE(review): not read by any code visible in this file; presumably meant to enable
/// rayon-based parallelism. Confirm against the rest of the crate.
pub parallel: bool,
/// NOTE(review): not read by any code visible in this file; `solve` does not enforce it.
pub timeout: Duration,
}
impl Default for SolverConfig {
fn default() -> Self {
Self {
method: SolverMethod::Adaptive,
epsilon: 1e-6,
max_iterations: 100,
parallel: true,
timeout: Duration::from_millis(100),
}
}
}
/// Approximate solver for square linear systems `A x = b`.
///
/// The solver only holds its [`SolverConfig`]; every call to `solve` is independent.
/// `Default` yields the same value as `SublinearSolver::new()`, because both wrap
/// `SolverConfig::default()`.
#[derive(Debug, Clone, Default)]
pub struct SublinearSolver {
    /// Parameters that steer method selection and the iteration limits.
    config: SolverConfig,
}
impl SublinearSolver {
pub fn new() -> Self {
Self {
config: SolverConfig::default(),
}
}
pub fn with_method(method: SolverMethod) -> Self {
let mut config = SolverConfig::default();
config.method = method;
Self { config }
}
pub fn with_config(config: SolverConfig) -> Self {
Self { config }
}
/// Approximates the solution of `a * x = b` with the configured method.
///
/// When the configured method is `SolverMethod::Adaptive`, the concrete method is chosen by
/// `select_best_method`; the method that actually ran is reported in the result.
///
/// The reported `time` covers input validation, method selection and the solver itself, but
/// not the residual computation. The reported `iterations` is the configured
/// `max_iterations`, not a measured count.
///
/// # Errors
///
/// Returns the error produced by `validate_inputs` (non-square matrix or a right-hand side
/// whose length differs from the matrix dimension) and propagates any error returned by the
/// selected solver routine.
pub fn solve(&self, a: &Matrix, b: &Vector) -> crate::Result<SolverResult> {
    let start = Instant::now();
    self.validate_inputs(a, b)?;

    let method = match self.config.method {
        SolverMethod::Adaptive => self.select_best_method(a),
        explicit => explicit,
    };

    let solution = match method {
        SolverMethod::Neumann => self.solve_neumann(a, b)?,
        SolverMethod::RandomWalk => self.solve_random_walk(a, b)?,
        SolverMethod::ForwardPush => self.solve_forward_push(a, b)?,
        SolverMethod::BackwardPush => self.solve_backward_push(a, b)?,
        SolverMethod::Bidirectional => self.solve_bidirectional(a, b)?,
        // `Adaptive` was resolved to a concrete method above.
        SolverMethod::Adaptive => unreachable!(),
    };

    let elapsed = start.elapsed();
    let complexity = self.estimate_complexity(a.shape().0, elapsed);
    // The residual must be computed while `solution` is still owned here: moving it into the
    // result struct first and borrowing it afterwards is a use-after-move.
    let residual = self.compute_residual(a, &solution, b);

    Ok(SolverResult {
        solution,
        method,
        iterations: self.config.max_iterations,
        residual,
        time: elapsed,
        complexity,
    })
}
/// Checks that `a` is square and that `b` has one entry per matrix row.
///
/// A matrix that is not strictly diagonally dominant is accepted, but a warning is logged
/// because convergence is then not guaranteed.
///
/// # Errors
///
/// Returns `FTLError::MatrixError` when `a` is not square or when the length of `b` differs
/// from the number of rows of `a`.
fn validate_inputs(&self, a: &Matrix, b: &Vector) -> crate::Result<()> {
    let (row_count, col_count) = a.shape();

    if row_count != col_count {
        let message = String::from("Matrix must be square");
        return Err(FTLError::MatrixError(message));
    }

    if b.len() != row_count {
        let message = String::from("Vector dimension mismatch");
        return Err(FTLError::MatrixError(message));
    }

    let dominant = self.is_diagonally_dominant(a);
    if !dominant {
        log::warn!("Matrix is not diagonally dominant - convergence not guaranteed");
    }

    Ok(())
}
/// Reports whether `a` is strictly row diagonally dominant, i.e. whether in every row the
/// absolute diagonal entry exceeds the sum of the absolute off-diagonal entries.
///
/// Only the first `n` columns are inspected, where `n` is the number of rows.
fn is_diagonally_dominant(&self, a: &Matrix) -> bool {
    let (n, _) = a.shape();
    let view = a.view();

    // A row violates dominance when its diagonal is not larger than the off-diagonal sum.
    let violates = |row: usize| {
        let diagonal = view[[row, row]].abs();
        let off_diagonal_sum = (0..n)
            .filter(|&col| col != row)
            .fold(0.0, |acc, col| acc + view[[row, col]].abs());
        diagonal <= off_diagonal_sum
    };

    !(0..n).any(violates)
}
/// Chooses the concrete method used when the configuration asks for `Adaptive`.
///
/// Forward push is chosen when the sparsity reported by the sparse conversion of `a`
/// exceeds `0.95`. Otherwise Neumann is chosen for strictly diagonally dominant matrices and
/// the bidirectional method for everything else.
// NOTE(review): presumably `sparsity()` is the fraction of zero entries; confirm.
fn select_best_method(&self, a: &Matrix) -> SolverMethod {
    let sparsity = a.to_sparse().sparsity();
    if sparsity > 0.95 {
        return SolverMethod::ForwardPush;
    }

    match self.is_diagonally_dominant(a) {
        true => SolverMethod::Neumann,
        false => SolverMethod::Bidirectional,
    }
}
/// Approximates the solution of `a * x = b` with a truncated Neumann series.
///
/// With `D` the diagonal of `a`, the system is rewritten as `x = c + M x`, where
/// `M = I - D^-1 a` (built by `compute_iteration_matrix`) and `c = D^-1 b`. Starting from
/// `x = c`, each pass computes `c + M x`, which adds one more term of the series
/// `c + M c + M^2 c + ...`.
///
/// The number of passes is `ceil(log2(n))`, capped by `max_iterations`. The loop stops
/// early once the norm of the step between two iterates falls below `epsilon`.
///
/// Rows whose diagonal entry is not larger than `1e-10` in magnitude cannot be normalised.
/// Their right-hand-side entry is left unscaled, mirroring the guard used when the
/// iteration matrix is built.
fn solve_neumann(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
    let n = b.len();
    let iteration_matrix = self.compute_iteration_matrix(a)?;

    // c = D^-1 b. Iterating with the unscaled `b` would instead converge towards the
    // solution of `a * x = D * b`.
    let a_view = a.view();
    let mut scaled_rhs = b.clone();
    for i in 0..n {
        let diagonal = a_view[[i, i]];
        if diagonal.abs() > 1e-10 {
            scaled_rhs.data[i] /= diagonal;
        }
    }

    let iterations = ((n as f64).log2().ceil() as usize).min(self.config.max_iterations);
    let mut x = scaled_rhs.clone();
    for _ in 0..iterations {
        let mx = iteration_matrix.multiply_vector(&x);
        let new_x = scaled_rhs.add(&mx);
        let step = new_x.sub(&x).norm();
        if step < self.config.epsilon {
            return Ok(new_x);
        }
        x = new_x;
    }
    Ok(x)
}
/// Estimates the solution of `a * x = b` with Monte Carlo random walks (the
/// Ulam–von Neumann scheme).
///
/// With `D` the diagonal of `a`, the system is rewritten as `x = c + M x`, where
/// `M = I - D^-1 a` and `c = D^-1 b`, so that `x_i` equals the sum over `k` of
/// `(M^k c)_i`. Each walk starts at index `i`, moves to a uniformly random index at every
/// step, and multiplies its weight by `n * M[current, next]`. The factor `n` compensates
/// for the transition probability `1 / n`. The walk's score is the sum of
/// `weight * c[state]` over all visited states, whose expectation is the series truncated
/// after `walk_length` terms.
///
/// Per entry, `100 * log2(n)` walks (at least one) of at most `ceil(log2(n))` steps are
/// averaged. A walk is abandoned once its weight drops below `1e-10` in magnitude. Rows
/// whose diagonal is not larger than `1e-10` in magnitude are treated as identity rows
/// (`x_i = b_i`).
///
/// The result is random: it draws from `rand::thread_rng()`.
fn solve_random_walk(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
    use rand::Rng;

    let n = b.len();
    let mut solution = Vector::zeros(n);
    if n == 0 {
        return Ok(solution);
    }

    let mut rng = rand::thread_rng();
    let a_view = a.view();
    let b_view = b.view();

    // At least one walk: for n == 1 the formula yields 0 walks, and the average below would
    // be 0.0 / 0.0 = NaN.
    let num_walks = (((n as f64).log2() * 100.0) as usize).max(1);
    let walk_length = (n as f64).log2().ceil() as usize;
    // Reciprocal of the uniform transition probability 1 / n.
    let inverse_probability = n as f64;

    // c[row] = b[row] / a[row, row]
    let jacobi_rhs = |row: usize| {
        let diagonal = a_view[[row, row]];
        if diagonal.abs() > 1e-10 {
            b_view[row] / diagonal
        } else {
            b_view[row]
        }
    };
    // M[row, col] = -a[row, col] / a[row, row] off the diagonal, 0 on it.
    let jacobi_entry = |row: usize, col: usize| {
        let diagonal = a_view[[row, row]];
        if row != col && diagonal.abs() > 1e-10 {
            -a_view[[row, col]] / diagonal
        } else {
            0.0
        }
    };

    for i in 0..n {
        let mut total = 0.0;
        for _ in 0..num_walks {
            let mut current = i;
            let mut weight = 1.0;
            // Zeroth series term: the walk has not moved yet.
            let mut score = jacobi_rhs(current);
            for _ in 0..walk_length {
                let next = rng.gen_range(0..n);
                weight *= inverse_probability * jacobi_entry(current, next);
                current = next;
                if weight.abs() < 1e-10 {
                    break;
                }
                score += weight * jacobi_rhs(current);
            }
            total += score;
        }
        solution.data[i] = total / num_walks as f64;
    }
    Ok(solution)
}
fn solve_forward_push(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
let n = b.len();
let mut solution = b.clone();
let mut residual = b.clone();
let threshold = self.config.epsilon / (n as f64).sqrt();
let max_pushes = (n as f64).log2().ceil() as usize * 10;
for _ in 0..max_pushes {
let mut max_residual = 0.0;
let mut max_idx = 0;
for i in 0..n {
if residual.data[i].abs() > max_residual {
max_residual = residual.data[i].abs();
max_idx = i;
}
}
if max_residual < threshold {
break;
}
let push_value = residual.data[max_idx];
solution.data[max_idx] += push_value;
for j in 0..n {
residual.data[j] -= push_value * a.view()[[max_idx, j]];
}
residual.data[max_idx] = 0.0;
}
Ok(solution)
}
/// Backward-push entry point.
///
/// There is no dedicated backward algorithm yet: the call is forwarded unchanged to
/// `solve_forward_push`, and its result or error is returned as is.
fn solve_backward_push(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
    let pushed = self.solve_forward_push(a, b)?;
    Ok(pushed)
}
/// Runs the forward-push and the backward-push solver, in that order, and returns the
/// element-wise mean of their two results.
///
/// An error from either solver is propagated; the backward solver is not run when the
/// forward solver fails.
fn solve_bidirectional(&self, a: &Matrix, b: &Vector) -> crate::Result<Vector> {
    let from_forward = self.solve_forward_push(a, b)?;
    let from_backward = self.solve_backward_push(a, b)?;
    let sum = from_forward.add(&from_backward);
    Ok(sum.scale(0.5))
}
/// Builds the Jacobi iteration matrix `M = I - D^-1 a`, where `D` is the diagonal of `a`.
///
/// The diagonal of `M` is zero and every off-diagonal entry is `-a[i, j] / a[i, i]`.
/// A row whose diagonal entry is not larger than `1e-10` in magnitude cannot be normalised;
/// the whole row of `M` is set to zero, so the result is always fully determined.
///
/// The function currently always returns `Ok`.
fn compute_iteration_matrix(&self, a: &Matrix) -> crate::Result<Matrix> {
    let (n, _) = a.shape();
    let source = a.view();
    // `Matrix::random` only serves as an n x n allocation: every entry is overwritten
    // below, so no random value survives into the result.
    let mut m = Matrix::random(n, n);
    for i in 0..n {
        let diagonal = source[[i, i]];
        let normalisable = diagonal.abs() > 1e-10;
        for j in 0..n {
            m.data[[i, j]] = if i != j && normalisable {
                -source[[i, j]] / diagonal
            } else {
                0.0
            };
        }
    }
    Ok(m)
}
/// Returns the norm of `a * x - b`, i.e. how far `x` is from solving the system.
fn compute_residual(&self, a: &Matrix, x: &Vector, b: &Vector) -> f64 {
    let product = a.multiply_vector(x);
    let difference = product.sub(b);
    difference.norm()
}
/// Guesses a complexity class from a single timing sample.
///
/// The elapsed time in nanoseconds is compared with the reference values `1`, `log2(n)`,
/// `n`, `n^2` and `n^3`. The class whose ratio `nanos / reference` is closest to `1` is
/// returned; ties keep the earlier class in that list. `Complexity::Cubic` is returned
/// when no ratio yields a finite distance.
///
/// This is a heuristic based on one wall-clock measurement, so its result varies with
/// machine load.
fn estimate_complexity(&self, n: usize, elapsed: Duration) -> Complexity {
    let nanos = elapsed.as_nanos() as f64;
    // Powers are formed in f64: `n * n * n` in `usize` overflows for large `n`.
    let size = n as f64;
    let candidates = [
        (Complexity::Constant, 1.0),
        (Complexity::Logarithmic, size.log2()),
        (Complexity::Linear, size),
        (Complexity::Quadratic, size * size),
        (Complexity::Cubic, size * size * size),
    ];

    let mut best_complexity = Complexity::Cubic;
    let mut min_diff = f64::MAX;
    for (complexity, theoretical) in candidates {
        let diff = (nanos / theoretical - 1.0).abs();
        if diff < min_diff {
            min_diff = diff;
            best_complexity = complexity;
        }
    }
    best_complexity
}
}
/// Outcome of one call to `SublinearSolver::solve`.
#[derive(Debug, Clone)]
pub struct SolverResult {
/// Approximate solution vector returned by the solver routine that ran.
pub solution: Vector,
/// Method that actually ran; an `Adaptive` request is reported as the concrete method chosen.
pub method: SolverMethod,
/// Set by `solve` to the configured `max_iterations`; it is not a measured iteration count.
pub iterations: usize,
/// Norm of `A * solution - b`.
pub residual: f64,
/// Wall-clock time measured by `solve`, from before input validation until the solver
/// routine returned.
pub time: Duration,
/// Complexity class guessed by `estimate_complexity` from the matrix size and `time`.
pub complexity: Complexity,
}
impl SolverResult {
    /// Returns `true` when the stored residual is strictly below `tolerance`.
    pub fn converged(&self, tolerance: f64) -> bool {
        let residual = self.residual;
        residual < tolerance
    }

    /// Returns the measured solve time in microseconds, including the fractional part.
    pub fn time_microseconds(&self) -> f64 {
        let seconds = self.time.as_secs_f64();
        seconds * 1_000_000.0
    }

    /// Returns `true` when the estimated complexity class is constant or logarithmic.
    pub fn is_sublinear(&self) -> bool {
        match self.complexity {
            Complexity::Constant | Complexity::Logarithmic => true,
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Solves a 10x10 diagonally dominant system with the Neumann method and expects the
    /// estimated complexity class to be sublinear.
    // NOTE(review): the complexity class is derived from measured wall-clock time, so this
    // assertion depends on machine speed and load.
    #[test]
    fn test_neumann_solver() {
        let matrix = Matrix::diagonally_dominant(10, 2.0);
        let rhs = Vector::ones(10);
        let neumann = SublinearSolver::with_method(SolverMethod::Neumann);

        let outcome = neumann.solve(&matrix, &rhs).unwrap();

        assert!(outcome.is_sublinear());
    }

    /// Solves a 100x100 diagonally dominant system with forward push and expects the solve
    /// to take less than one millisecond.
    // NOTE(review): a wall-clock bound like this is sensitive to machine speed and load.
    #[test]
    fn test_forward_push() {
        let matrix = Matrix::diagonally_dominant(100, 3.0);
        let rhs = Vector::random(100);
        let pusher = SublinearSolver::with_method(SolverMethod::ForwardPush);

        let outcome = pusher.solve(&matrix, &rhs).unwrap();

        assert!(outcome.time_microseconds() < 1000.0);
    }

    /// Lets the default (adaptive) solver pick a method for a 50x50 diagonally dominant
    /// system and expects the residual to be below `1e-3`.
    #[test]
    fn test_adaptive_selection() {
        let matrix = Matrix::diagonally_dominant(50, 5.0);
        let rhs = Vector::ones(50);
        let adaptive = SublinearSolver::new();

        let outcome = adaptive.solve(&matrix, &rhs).unwrap();

        assert!(outcome.converged(1e-3));
    }
}