use super::ultra_fast::{UltraFastCSR, UltraFastCG};
use std::collections::{BinaryHeap, HashMap};
use std::cmp::Ordering;
#[derive(Debug, Clone)]
pub struct BMSSPConfig {
pub max_iterations: usize,
pub tolerance: f64,
pub bound: f64,
pub use_neural: bool,
pub enable_simd: bool,
}
impl Default for BMSSPConfig {
fn default() -> Self {
Self {
max_iterations: 1000,
tolerance: 1e-10,
bound: f64::INFINITY,
use_neural: false,
enable_simd: true,
}
}
}
#[derive(Copy, Clone)]
struct BMSSPNode {
cost: f64,
index: usize,
source_id: usize,
}
impl Eq for BMSSPNode {}
impl PartialEq for BMSSPNode {
fn eq(&self, other: &Self) -> bool {
self.cost == other.cost
}
}
impl Ord for BMSSPNode {
fn cmp(&self, other: &Self) -> Ordering {
other.cost.partial_cmp(&self.cost).unwrap_or(Ordering::Equal)
}
}
impl PartialOrd for BMSSPNode {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
other.cost.partial_cmp(&self.cost)
}
}
pub struct BMSSPSolver {
config: BMSSPConfig,
neural_cache: Option<HashMap<usize, Vec<f64>>>,
}
impl BMSSPSolver {
pub fn new(config: BMSSPConfig) -> Self {
Self {
config,
neural_cache: if config.use_neural { Some(HashMap::new()) } else { None },
}
}
pub fn solve(&mut self, matrix: &UltraFastCSR, b: &[f64]) -> Vec<f64> {
let n = matrix.rows();
if n < 100 || matrix.nnz() > n * n / 10 {
let cg = UltraFastCG::new(self.config.max_iterations, self.config.tolerance);
return cg.solve(matrix, b);
}
self.solve_bmssp(matrix, b)
}
fn solve_bmssp(&mut self, matrix: &UltraFastCSR, b: &[f64]) -> Vec<f64> {
let n = matrix.rows();
let mut solution = vec![0.0; n];
let sources: Vec<usize> = b.iter()
.enumerate()
.filter(|(_, &val)| val.abs() > 1e-10)
.map(|(i, _)| i)
.collect();
if sources.is_empty() {
return solution;
}
let mut distances = vec![f64::INFINITY; n];
let mut heap = BinaryHeap::new();
for &source in &sources {
distances[source] = 0.0;
heap.push(BMSSPNode {
cost: 0.0,
index: source,
source_id: source,
});
}
let mut visited = 0;
while let Some(node) = heap.pop() {
if node.cost > self.config.bound {
break; }
if node.cost > distances[node.index] {
continue;
}
visited += 1;
if visited > n / 2 {
let cg = UltraFastCG::new(self.config.max_iterations, self.config.tolerance);
return cg.solve(matrix, b);
}
solution[node.index] = b[node.source_id] / (1.0 + node.cost);
let (row_start, row_end) = matrix.get_row_range(node.index);
for idx in row_start..row_end {
let (col, val) = matrix.get_entry(idx);
let new_cost = node.cost + 1.0 / val.abs().max(1e-10);
if new_cost < distances[col] {
distances[col] = new_cost;
heap.push(BMSSPNode {
cost: new_cost,
index: col,
source_id: node.source_id,
});
}
}
}
if self.config.use_neural {
self.neural_refine(&mut solution, matrix, b);
}
solution
}
fn neural_refine(&mut self, solution: &mut [f64], matrix: &UltraFastCSR, b: &[f64]) {
if let Some(cache) = &self.neural_cache {
let pattern_key = (matrix.rows() / 100) * 100;
if let Some(pattern) = cache.get(&pattern_key) {
for i in 0..solution.len().min(pattern.len()) {
solution[i] *= 1.0 + pattern[i] * 0.1;
}
}
}
let mut residual = vec![0.0; matrix.rows()];
matrix.multiply_vector_ultra_fast(solution, &mut residual);
let mut error = 0.0;
for i in 0..residual.len() {
let diff = residual[i] - b[i];
error += diff * diff;
solution[i] -= diff * 0.1;
}
if error < self.config.tolerance && self.neural_cache.is_some() {
let pattern_key = (matrix.rows() / 100) * 100;
let pattern: Vec<f64> = solution.iter().map(|&x| x / (x.abs() + 1.0)).collect();
if let Some(cache) = &mut self.neural_cache {
cache.insert(pattern_key, pattern);
}
}
}
pub fn analyze_matrix(matrix: &UltraFastCSR) -> String {
let n = matrix.rows();
let nnz = matrix.nnz();
let sparsity = nnz as f64 / (n * n) as f64;
if sparsity < 0.001 {
"ultra-sparse: BMSSP optimal".to_string()
} else if sparsity < 0.01 {
"sparse: BMSSP recommended".to_string()
} else if sparsity < 0.1 {
"moderate: Hybrid approach".to_string()
} else {
"dense: Direct CG recommended".to_string()
}
}
}
impl UltraFastCSR {
pub fn get_row_range(&self, row: usize) -> (usize, usize) {
let start = self.row_ptr[row] as usize;
let end = self.row_ptr[row + 1] as usize;
(start, end)
}
pub fn get_entry(&self, idx: usize) -> (usize, f64) {
(self.col_indices[idx] as usize, self.values[idx])
}
}
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
pub struct WasmBMSSP {
solver: BMSSPSolver,
}
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl WasmBMSSP {
#[wasm_bindgen(constructor)]
pub fn new(max_iterations: usize, tolerance: f64, use_neural: bool) -> Self {
let config = BMSSPConfig {
max_iterations,
tolerance,
use_neural,
..Default::default()
};
Self {
solver: BMSSPSolver::new(config),
}
}
#[wasm_bindgen]
pub fn solve(&mut self, matrix_data: &[f64], rows: usize, cols: usize, b: &[f64]) -> Vec<f64> {
let mut triplets = Vec::new();
for i in 0..rows {
for j in 0..cols {
let val = matrix_data[i * cols + j];
if val.abs() > 1e-10 {
triplets.push((i, j, val));
}
}
}
let matrix = UltraFastCSR::from_triplets(triplets, rows, cols);
self.solver.solve(&matrix, b)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ultra_fast::generate_test_matrix;
#[test]
fn test_bmssp_solver() {
let (matrix, b) = generate_test_matrix(1000, 0.001);
let config = BMSSPConfig::default();
let mut solver = BMSSPSolver::new(config);
let solution = solver.solve(&matrix, &b);
let mut result = vec![0.0; 1000];
matrix.multiply_vector_ultra_fast(&solution, &mut result);
let mut error = 0.0;
for i in 0..1000 {
error += (result[i] - b[i]).powi(2);
}
error = error.sqrt();
assert!(error < 1e-6, "BMSSP solution error too large: {}", error);
}
#[test]
fn test_method_selection() {
let (small_matrix, _) = generate_test_matrix(50, 0.1);
let (sparse_matrix, _) = generate_test_matrix(5000, 0.0001);
assert_eq!(BMSSPSolver::analyze_matrix(&small_matrix), "dense: Direct CG recommended");
assert_eq!(BMSSPSolver::analyze_matrix(&sparse_matrix), "ultra-sparse: BMSSP optimal");
}
}