mod helpers;
use ruvector_solver::router::{RouterConfig, SolverOrchestrator, SolverRouter};
use ruvector_solver::types::{Algorithm, ComputeBudget, CsrMatrix, QueryType, SparsityProfile};
use helpers::{random_diag_dominant_csr, random_spd_csr, random_vector};
fn default_budget() -> ComputeBudget {
ComputeBudget {
max_time: std::time::Duration::from_secs(30),
max_iterations: 10_000,
tolerance: 1e-8,
}
}
#[test]
fn test_router_selects_neumann_for_diag_dominant() {
let router = SolverRouter::new(RouterConfig::default());
let profile = SparsityProfile {
rows: 1000,
cols: 1000,
nnz: 3000,
density: 0.003,
is_diag_dominant: true,
estimated_spectral_radius: 0.5,
estimated_condition: 10.0,
is_symmetric_structure: true,
avg_nnz_per_row: 3.0,
max_nnz_per_row: 5,
};
let algo = router.select_algorithm(&profile, &QueryType::LinearSystem);
assert_eq!(
algo,
Algorithm::Neumann,
"diag-dominant, sparse, low spectral radius should route to Neumann"
);
let matrix = random_diag_dominant_csr(20, 0.2, 42);
let real_profile = SolverOrchestrator::analyze_sparsity(&matrix);
assert!(
real_profile.is_diag_dominant,
"random_diag_dominant_csr should produce a diag-dominant matrix"
);
assert!(
real_profile.estimated_spectral_radius < 1.0,
"spectral radius should be < 1 for diag-dominant matrix"
);
}
#[test]
fn test_router_selects_cg_for_general_spd() {
let router = SolverRouter::new(RouterConfig::default());
let profile = SparsityProfile {
rows: 500,
cols: 500,
nnz: 25_000,
density: 0.10,
is_diag_dominant: false,
estimated_spectral_radius: 0.8,
estimated_condition: 50.0,
is_symmetric_structure: true,
avg_nnz_per_row: 50.0,
max_nnz_per_row: 80,
};
let algo = router.select_algorithm(&profile, &QueryType::LinearSystem);
assert_eq!(
algo,
Algorithm::CG,
"well-conditioned, non-diag-dominant should route to CG"
);
let ill_conditioned = SparsityProfile {
estimated_condition: 500.0,
..profile.clone()
};
let algo_ill = router.select_algorithm(&ill_conditioned, &QueryType::LinearSystem);
assert_eq!(
algo_ill,
Algorithm::BMSSP,
"ill-conditioned should route to BMSSP"
);
}
#[test]
fn test_router_selects_push_for_pagerank() {
let router = SolverRouter::new(RouterConfig::default());
let profile = SparsityProfile {
rows: 5000,
cols: 5000,
nnz: 20_000,
density: 0.0008,
is_diag_dominant: false,
estimated_spectral_radius: 0.85,
estimated_condition: 100.0,
is_symmetric_structure: false,
avg_nnz_per_row: 4.0,
max_nnz_per_row: 50,
};
let algo_single = router.select_algorithm(&profile, &QueryType::PageRankSingle { source: 0 });
assert_eq!(
algo_single,
Algorithm::ForwardPush,
"single-source PageRank should route to ForwardPush"
);
let algo_pairwise_large = router.select_algorithm(
&profile,
&QueryType::PageRankPairwise {
source: 0,
target: 100,
},
);
assert_eq!(
algo_pairwise_large,
Algorithm::HybridRandomWalk,
"pairwise PageRank on large graph should route to HybridRandomWalk"
);
let small_profile = SparsityProfile {
rows: 500,
cols: 500,
nnz: 2000,
density: 0.008,
..profile.clone()
};
let algo_pairwise_small = router.select_algorithm(
&small_profile,
&QueryType::PageRankPairwise {
source: 0,
target: 10,
},
);
assert_eq!(
algo_pairwise_small,
Algorithm::ForwardPush,
"pairwise PageRank on small graph should route to ForwardPush"
);
}
#[test]
fn test_router_fallback_chain() {
let orchestrator = SolverOrchestrator::new(RouterConfig::default());
let matrix = CsrMatrix::<f64>::from_coo(
4,
4,
vec![
(0, 0, 4.0),
(0, 1, -1.0),
(1, 0, -1.0),
(1, 1, 4.0),
(1, 2, -1.0),
(2, 1, -1.0),
(2, 2, 4.0),
(2, 3, -1.0),
(3, 2, -1.0),
(3, 3, 4.0),
],
);
let rhs = vec![1.0, 0.0, 0.0, 1.0];
let budget = default_budget();
let result = orchestrator
.solve_with_fallback(&matrix, &rhs, QueryType::LinearSystem, &budget)
.unwrap();
assert!(
result.residual_norm < 1e-4,
"fallback chain should produce a good solution, residual={}",
result.residual_norm
);
let profile = SolverOrchestrator::analyze_sparsity(&matrix);
let selected = orchestrator
.router()
.select_algorithm(&profile, &QueryType::LinearSystem);
assert!(result.solution.len() == 4, "solution should have 4 entries");
let spd = random_spd_csr(10, 0.3, 42);
let rhs2 = random_vector(10, 43);
let result2 = orchestrator
.solve_with_fallback(&spd, &rhs2, QueryType::LinearSystem, &budget)
.unwrap();
assert!(
result2.residual_norm < 1e-3,
"fallback on SPD should converge, residual={}",
result2.residual_norm
);
}