use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
/// A precision (inverse covariance) matrix together with a count of its
/// non-zero off-diagonal entries.
#[derive(Debug, Clone)]
pub struct PrecisionMatrix {
    /// The precision matrix itself. The helpers below take its size from
    /// `nrows()` and assume it is square.
    pub matrix: Array2<f64>,
    /// Number of entries strictly above the diagonal whose absolute value
    /// exceeds `1e-10`, i.e. each undirected edge counted once.
    pub num_edges: usize,
    /// Regularization parameter recorded alongside the matrix (stored as given).
    pub lambda: f64,
}
impl PrecisionMatrix {
    /// Magnitude above which an off-diagonal entry is counted in `num_edges`.
    const EDGE_TOLERANCE: f64 = 1e-10;

    /// Wraps `matrix` and counts its edges: upper-triangular off-diagonal
    /// entries with `|θ_ij| > 1e-10`.
    pub fn new(matrix: Array2<f64>, lambda: f64) -> Self {
        let p = matrix.nrows();
        let num_edges = (0..p)
            .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
            .filter(|&(i, j)| matrix[[i, j]].abs() > Self::EDGE_TOLERANCE)
            .count();
        PrecisionMatrix {
            matrix,
            num_edges,
            lambda,
        }
    }

    /// Returns the symmetric boolean adjacency matrix: `true` at `(i, j)`,
    /// `i != j`, when `|θ_ij| > threshold`. The diagonal is always `false`.
    ///
    /// Note that `threshold` is independent of the fixed `1e-10` cut-off used
    /// for `num_edges`.
    pub fn adjacency(&self, threshold: f64) -> Array2<bool> {
        let p = self.matrix.nrows();
        Array2::from_shape_fn((p, p), |(i, j)| {
            i != j && self.matrix[[i, j]].abs() > threshold
        })
    }

    /// Fraction of the `p(p-1)/2` possible undirected edges that are present.
    ///
    /// Returns `0.0` when fewer than two variables exist (no edge is possible).
    pub fn density(&self) -> f64 {
        let p = self.matrix.nrows();
        // `saturating_sub` keeps an empty (0×0) matrix from underflowing `usize`.
        let max_edges = p * p.saturating_sub(1) / 2;
        if max_edges == 0 {
            return 0.0;
        }
        self.num_edges as f64 / max_edges as f64
    }
}
/// Settings for [`GraphicalLasso`].
#[derive(Debug, Clone)]
pub struct GraphicalLassoConfig {
    /// Weight of the L1 penalty (the soft-threshold level used in every
    /// lasso sub-problem).
    pub lambda: f64,
    /// Maximum number of outer sweeps over all variables.
    pub max_iter: usize,
    /// The fit stops once the largest absolute entry-wise change of the
    /// covariance estimate between two consecutive sweeps is below this value.
    pub tolerance: f64,
    /// Optional starting precision matrix; the fit seeds its covariance
    /// estimate with the inverse of this matrix.
    pub warm_start: Option<Array2<f64>>,
}
impl Default for GraphicalLassoConfig {
    /// `lambda = 0.1`, `max_iter = 100`, `tolerance = 1e-4`, no warm start.
    fn default() -> Self {
        GraphicalLassoConfig {
            lambda: 0.1,
            max_iter: 100,
            tolerance: 1e-4,
            warm_start: None,
        }
    }
}
impl GraphicalLassoConfig {
    /// Creates a configuration with penalty `lambda` and default values for
    /// every other setting.
    ///
    /// # Errors
    /// Returns `StatsError::InvalidArgument` unless `lambda` is finite and
    /// strictly positive.
    pub fn new(lambda: f64) -> StatsResult<Self> {
        // Phrased as a positive check so NaN (for which every comparison is
        // false) is rejected as well; an infinite penalty is meaningless.
        if !(lambda.is_finite() && lambda > 0.0) {
            return Err(StatsError::InvalidArgument(
                "lambda must be positive and finite".to_string(),
            ));
        }
        Ok(GraphicalLassoConfig {
            lambda,
            ..Default::default()
        })
    }

    /// Sets the maximum number of outer sweeps.
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the convergence tolerance on the covariance estimate.
    pub fn with_tolerance(mut self, tolerance: f64) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Supplies an initial precision matrix to start the fit from.
    pub fn with_warm_start(mut self, initial: Array2<f64>) -> Self {
        self.warm_start = Some(initial);
        self
    }
}
/// Output of `GraphicalLasso::fit`.
#[derive(Debug, Clone)]
pub struct GraphicalLassoResult {
// Estimated precision matrix; its `matrix` is the inverse of `covariance`.
pub precision: PrecisionMatrix,
// Final covariance estimate `W` (for a single variable, the input variance).
pub covariance: Array2<f64>,
// Number of outer sweeps performed (0 for the closed-form single-variable case).
pub n_iter: usize,
// Whether the largest change fell below the tolerance within `max_iter` sweeps.
pub converged: bool,
// Objective -log det(Θ) + tr(SΘ) + λ·Σ_{i≠j}|Θ_ij| evaluated at the returned precision.
pub objective: f64,
}
#[derive(Debug, Clone)]
pub struct GraphicalLasso {
config: GraphicalLassoConfig,
}
impl GraphicalLasso {
pub fn new(config: GraphicalLassoConfig) -> Self {
GraphicalLasso { config }
}
pub fn with_lambda(lambda: f64) -> StatsResult<Self> {
Ok(GraphicalLasso {
config: GraphicalLassoConfig::new(lambda)?,
})
}
pub fn fit(&self, s: &ArrayView2<f64>) -> StatsResult<GraphicalLassoResult> {
let p = s.nrows();
if s.ncols() != p {
return Err(StatsError::DimensionMismatch(
"Sample covariance matrix must be square".to_string(),
));
}
if p == 0 {
return Err(StatsError::InvalidArgument(
"Sample covariance matrix must be non-empty".to_string(),
));
}
for i in 0..p {
for j in (i + 1)..p {
if (s[[i, j]] - s[[j, i]]).abs() > 1e-10 {
return Err(StatsError::InvalidArgument(
"Sample covariance matrix must be symmetric".to_string(),
));
}
}
}
let lambda = self.config.lambda;
if p == 1 {
let s_val = s[[0, 0]];
if s_val <= 0.0 {
return Err(StatsError::InvalidArgument(
"Diagonal of covariance must be positive".to_string(),
));
}
let theta_val = 1.0 / s_val;
let precision_mat = Array2::from_elem((1, 1), theta_val);
let cov_mat = Array2::from_elem((1, 1), s_val);
let obj = -theta_val.ln() + 1.0; return Ok(GraphicalLassoResult {
precision: PrecisionMatrix::new(precision_mat, lambda),
covariance: cov_mat,
n_iter: 0,
converged: true,
objective: obj,
});
}
let mut w = if let Some(ref warm) = self.config.warm_start {
if warm.nrows() != p || warm.ncols() != p {
return Err(StatsError::DimensionMismatch(
"Warm start matrix dimension mismatch".to_string(),
));
}
invert_symmetric(warm)?
} else {
let mut w_init = s.to_owned();
for i in 0..p {
w_init[[i, i]] += lambda;
}
w_init
};
let mut converged = false;
let mut n_iter = 0;
for iter in 0..self.config.max_iter {
let w_old = w.clone();
for j in 0..p {
let (w11, s12) = partition_matrix(&w, s, j);
let beta = solve_lasso_subproblem(&w11, &s12, lambda, 100, 1e-6)?;
let w12 = w11.dot(&beta);
let mut idx = 0;
for i in 0..p {
if i != j {
w[[i, j]] = w12[idx];
w[[j, i]] = w12[idx];
idx += 1;
}
}
}
let mut max_change: f64 = 0.0;
for i in 0..p {
for k in 0..p {
let diff = (w[[i, k]] - w_old[[i, k]]).abs();
if diff > max_change {
max_change = diff;
}
}
}
n_iter = iter + 1;
if max_change < self.config.tolerance {
converged = true;
break;
}
}
let theta = invert_symmetric(&w)?;
let objective = compute_objective(&theta, s, lambda)?;
Ok(GraphicalLassoResult {
precision: PrecisionMatrix::new(theta, lambda),
covariance: w,
n_iter,
converged,
objective,
})
}
}
/// Evaluates `-log det(Θ) + Σ_ij S_ij·Θ_ij + λ·Σ_{i≠j} |Θ_ij|` for `theta = Θ`.
///
/// The middle term equals `tr(SΘ)` when `S` is symmetric. The L1 penalty
/// covers both triangles of the off-diagonal but not the diagonal.
///
/// NOTE(review): `GraphicalLasso::fit` starts from `S + λI` for `p >= 2`,
/// which is the starting point associated with also penalizing the diagonal.
/// Confirm which objective is intended.
///
/// # Errors
/// Propagates the error from `log_determinant` when `theta` is not positive
/// definite.
pub(crate) fn compute_objective(
    theta: &Array2<f64>,
    s: &ArrayView2<f64>,
    lambda: f64,
) -> StatsResult<f64> {
    let dim = theta.nrows();
    let neg_log_det = -log_determinant(theta)?;
    // Every (row, col) pair of the dim × dim grid in row-major order, so both
    // sums accumulate in exactly the order of a nested row/column loop.
    let cells = move || (0..dim).flat_map(move |r| (0..dim).map(move |c| (r, c)));
    let trace_st = cells().fold(0.0_f64, |acc, (r, c)| acc + s[[r, c]] * theta[[r, c]]);
    let l1_off_diag = cells()
        .filter(|&(r, c)| r != c)
        .fold(0.0_f64, |acc, (r, c)| acc + theta[[r, c]].abs());
    Ok(neg_log_det + trace_st + lambda * l1_off_diag)
}
/// Natural log of the determinant of a symmetric positive-definite matrix.
///
/// Computed as twice the sum of the logs of the Cholesky factor's diagonal.
///
/// # Errors
/// Propagates the error from `cholesky_decomp`. Also returns
/// `StatsError::ComputationError` if any factor diagonal is `<= 0`.
fn log_determinant(a: &Array2<f64>) -> StatsResult<f64> {
    let factor = cholesky_decomp(a)?;
    let half_log_det = (0..a.nrows()).try_fold(0.0_f64, |acc, idx| {
        let pivot = factor[[idx, idx]];
        if pivot <= 0.0 {
            Err(StatsError::ComputationError(
                "Matrix is not positive definite (Cholesky failed)".to_string(),
            ))
        } else {
            Ok(acc + pivot.ln())
        }
    })?;
    Ok(2.0 * half_log_det)
}
/// Lower-triangular Cholesky factor `L` with `a = L·Lᵀ`, computed row by row.
///
/// Only the lower triangle of `a` (entries `a[[i, j]]` with `j <= i`) is read.
/// The size is taken from `a.nrows()`, and `a` is assumed to be square.
///
/// # Errors
/// Returns `StatsError::ComputationError` when a diagonal pivot is not
/// positive, i.e. `a` is not positive definite. Also returns it when a
/// previously computed pivot is smaller than `1e-15` in magnitude.
fn cholesky_decomp(a: &Array2<f64>) -> StatsResult<Array2<f64>> {
    let n = a.nrows();
    let mut factor = Array2::<f64>::zeros((n, n));
    for row in 0..n {
        // Strictly-lower entries of this row, left to right.
        for col in 0..row {
            let dot = (0..col).fold(0.0_f64, |acc, k| acc + factor[[row, k]] * factor[[col, k]]);
            let pivot = factor[[col, col]];
            if pivot.abs() < 1e-15 {
                return Err(StatsError::ComputationError(
                    "Near-zero diagonal in Cholesky decomposition".to_string(),
                ));
            }
            factor[[row, col]] = (a[[row, col]] - dot) / pivot;
        }
        // The diagonal entry comes last because it depends on the whole row.
        let dot = (0..row).fold(0.0_f64, |acc, k| acc + factor[[row, k]] * factor[[row, k]]);
        let residual = a[[row, row]] - dot;
        if residual <= 0.0 {
            return Err(StatsError::ComputationError(format!(
                "Matrix is not positive definite at index {}",
                row
            )));
        }
        factor[[row, row]] = residual.sqrt();
    }
    Ok(factor)
}
/// Inverts a symmetric positive-definite matrix using its Cholesky factor `L`.
///
/// For each column `e_c` of the identity, it solves `L·y = e_c` by forward
/// substitution and then `Lᵀ·x = y` by back substitution. The solution `x`
/// becomes column `c` of the result.
///
/// # Errors
/// Propagates the error from `cholesky_decomp` when `a` is not positive definite.
fn invert_symmetric(a: &Array2<f64>) -> StatsResult<Array2<f64>> {
    let n = a.nrows();
    let factor = cholesky_decomp(a)?;
    let mut inverse = Array2::<f64>::zeros((n, n));
    let mut forward: Vec<f64> = Vec::with_capacity(n);
    for col in 0..n {
        forward.clear();
        for row in 0..n {
            let rhs = if row == col { 1.0 } else { 0.0 };
            let acc = (0..row).fold(rhs, |acc, k| acc - factor[[row, k]] * forward[k]);
            forward.push(acc / factor[[row, row]]);
        }
        for row in (0..n).rev() {
            let acc = ((row + 1)..n).fold(forward[row], |acc, k| {
                acc - factor[[k, row]] * inverse[[k, col]]
            });
            inverse[[row, col]] = acc / factor[[row, row]];
        }
    }
    Ok(inverse)
}
/// Splits out variable `j` for the block update.
///
/// Returns `W11`, which is `w` with row and column `j` removed, and `s12`,
/// which is column `j` of `s` without its `j`-th entry. Both keep the
/// original order of the remaining indices.
fn partition_matrix(w: &Array2<f64>, s: &ArrayView2<f64>, j: usize) -> (Array2<f64>, Array1<f64>) {
    let reduced = w.nrows() - 1;
    let kept: Vec<usize> = (0..w.nrows()).filter(|&idx| idx != j).collect();
    let mut w11 = Array2::<f64>::zeros((reduced, reduced));
    let mut s12 = Array1::<f64>::zeros(reduced);
    for (dst_row, &src_row) in kept.iter().enumerate() {
        s12[dst_row] = s[[src_row, j]];
        for (dst_col, &src_col) in kept.iter().enumerate() {
            w11[[dst_row, dst_col]] = w[[src_row, src_col]];
        }
    }
    (w11, s12)
}
/// Cyclic coordinate descent for the inner graphical-lasso problem.
///
/// Starting from `β = 0`, it repeatedly minimizes
/// `½·βᵀ·W11·β − s12ᵀ·β + λ·‖β‖₁` one coordinate at a time. Each update is
/// `β_k = soft_threshold(s12_k − Σ_{m≠k} W11_km·β_m, λ) / W11_kk`.
/// Coordinates with `|W11_kk| < 1e-15` are skipped and keep their current
/// value. Iteration stops after a sweep whose largest coordinate change is
/// below `tol`, or after `max_iter` sweeps.
///
/// The current implementation has no error path; it always returns `Ok`.
fn solve_lasso_subproblem(
    w11: &Array2<f64>,
    s12: &Array1<f64>,
    lambda: f64,
    max_iter: usize,
    tol: f64,
) -> StatsResult<Array1<f64>> {
    let dim = s12.len();
    let mut beta = Array1::<f64>::zeros(dim);
    for _ in 0..max_iter {
        let mut largest_step: f64 = 0.0;
        for k in 0..dim {
            let diag = w11[[k, k]];
            if diag.abs() < 1e-15 {
                continue;
            }
            let partial_residual = (0..dim)
                .filter(|&m| m != k)
                .fold(s12[k], |acc, m| acc - w11[[k, m]] * beta[m]);
            let updated = soft_threshold(partial_residual, lambda) / diag;
            largest_step = largest_step.max((updated - beta[k]).abs());
            beta[k] = updated;
        }
        if largest_step < tol {
            break;
        }
    }
    Ok(beta)
}
/// Soft-thresholding operator `sign(x)·max(|x| − λ, 0)`.
///
/// Values inside `[-λ, λ]`, and NaN, map to `0.0`.
#[inline]
fn soft_threshold(x: f64, lambda: f64) -> f64 {
    match x {
        v if v > lambda => v - lambda,
        v if v < -lambda => v + lambda,
        _ => 0.0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_glasso_identity_covariance() {
        let s = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let config = GraphicalLassoConfig::new(0.01).expect("config creation failed");
        let glasso = GraphicalLasso::new(config);
        let result = glasso.fit(&s.view()).expect("GLASSO fit failed");
        assert!(result.converged, "GLASSO should converge");
        let theta = &result.precision.matrix;
        for i in 0..3 {
            assert!(
                (theta[[i, i]] - 1.0).abs() < 0.1,
                "Diagonal should be near 1.0, got {}",
                theta[[i, i]]
            );
            for j in 0..3 {
                if i != j {
                    assert!(
                        theta[[i, j]].abs() < 0.1,
                        "Off-diagonal should be near 0.0, got {}",
                        theta[[i, j]]
                    );
                }
            }
        }
    }

    #[test]
    fn test_glasso_sparse_structure_recovery() {
        let true_theta = array![[1.0, -0.5, 0.0], [-0.5, 1.25, -0.5], [0.0, -0.5, 1.0]];
        let sigma = invert_symmetric(&true_theta).expect("inversion failed");
        let config = GraphicalLassoConfig::new(0.05).expect("config creation failed");
        let glasso = GraphicalLasso::new(config);
        let result = glasso.fit(&sigma.view()).expect("GLASSO fit failed");
        let theta = &result.precision.matrix;
        assert!(
            theta[[0, 2]].abs() < 0.2,
            "Expected near-zero for (0,2), got {}",
            theta[[0, 2]]
        );
        assert!(
            theta[[2, 0]].abs() < 0.2,
            "Expected near-zero for (2,0), got {}",
            theta[[2, 0]]
        );
        assert!(
            theta[[0, 1]].abs() > 0.1,
            "Expected non-zero for (0,1), got {}",
            theta[[0, 1]]
        );
    }

    #[test]
    fn test_glasso_lambda_sparsity() {
        let s = array![[1.0, 0.5, 0.3], [0.5, 1.0, 0.4], [0.3, 0.4, 1.0]];
        let config_small = GraphicalLassoConfig::new(0.01).expect("config creation failed");
        let result_small = GraphicalLasso::new(config_small)
            .fit(&s.view())
            .expect("GLASSO fit failed");
        let config_large = GraphicalLassoConfig::new(0.5).expect("config creation failed");
        let result_large = GraphicalLasso::new(config_large)
            .fit(&s.view())
            .expect("GLASSO fit failed");
        assert!(
            result_large.precision.num_edges <= result_small.precision.num_edges,
            "Larger lambda should give sparser result: {} edges (large) vs {} edges (small)",
            result_large.precision.num_edges,
            result_small.precision.num_edges
        );
    }

    #[test]
    fn test_glasso_convergence() {
        let s = array![
            [2.0, 0.8, 0.3, 0.1],
            [0.8, 2.0, 0.5, 0.2],
            [0.3, 0.5, 2.0, 0.6],
            [0.1, 0.2, 0.6, 2.0]
        ];
        let config = GraphicalLassoConfig::new(0.1)
            .expect("config creation failed")
            .with_max_iter(200)
            .with_tolerance(1e-6);
        let result = GraphicalLasso::new(config)
            .fit(&s.view())
            .expect("GLASSO fit failed");
        assert!(
            result.converged,
            "Should converge with sufficient iterations"
        );
        assert!(result.n_iter > 0, "Should take at least one iteration");
    }

    #[test]
    fn test_glasso_warm_start() {
        let s = array![[1.0, 0.4, 0.2], [0.4, 1.0, 0.3], [0.2, 0.3, 1.0]];
        let config1 = GraphicalLassoConfig::new(0.1).expect("config creation failed");
        let result1 = GraphicalLasso::new(config1)
            .fit(&s.view())
            .expect("GLASSO fit failed");
        let config2 = GraphicalLassoConfig::new(0.1)
            .expect("config creation failed")
            .with_warm_start(result1.precision.matrix.clone());
        let result2 = GraphicalLasso::new(config2)
            .fit(&s.view())
            .expect("GLASSO fit failed");
        let p = result1.precision.matrix.nrows();
        for i in 0..p {
            for j in 0..p {
                assert!(
                    (result1.precision.matrix[[i, j]] - result2.precision.matrix[[i, j]]).abs()
                        < 0.05,
                    "Warm start result should be close to cold start"
                );
            }
        }
        assert!(
            result2.n_iter <= result1.n_iter + 1,
            "Warm start should not take significantly more iterations"
        );
    }

    #[test]
    fn test_glasso_warm_start_dimension_mismatch() {
        let s = array![[1.0, 0.2], [0.2, 1.0]];
        let config = GraphicalLassoConfig::new(0.1)
            .expect("config creation failed")
            .with_warm_start(Array2::<f64>::eye(3));
        let result = GraphicalLasso::new(config).fit(&s.view());
        assert!(
            matches!(result, Err(StatsError::DimensionMismatch(_))),
            "Should error when the warm start does not match the covariance size"
        );
    }

    #[test]
    fn test_glasso_single_variable() {
        let s = array![[4.0]];
        let config = GraphicalLassoConfig::new(0.1).expect("config creation failed");
        let result = GraphicalLasso::new(config)
            .fit(&s.view())
            .expect("GLASSO fit failed");
        assert!(
            (result.precision.matrix[[0, 0]] - 0.25).abs() < 1e-10,
            "Precision should be 1/variance"
        );
        assert!(result.converged);
    }

    #[test]
    fn test_glasso_asymmetric_error() {
        let s = array![[1.0, 0.5], [0.6, 1.0]];
        let config = GraphicalLassoConfig::new(0.1).expect("config creation failed");
        let result = GraphicalLasso::new(config).fit(&s.view());
        assert!(
            matches!(result, Err(StatsError::InvalidArgument(_))),
            "Should error on asymmetric matrix"
        );
    }

    #[test]
    fn test_glasso_non_square_error() {
        let s = Array2::<f64>::zeros((2, 3));
        let config = GraphicalLassoConfig::new(0.1).expect("config creation failed");
        let result = GraphicalLasso::new(config).fit(&s.view());
        assert!(
            matches!(result, Err(StatsError::DimensionMismatch(_))),
            "Should error on non-square matrix"
        );
    }

    #[test]
    fn test_glasso_empty_error() {
        let s = Array2::<f64>::zeros((0, 0));
        let config = GraphicalLassoConfig::new(0.1).expect("config creation failed");
        let result = GraphicalLasso::new(config).fit(&s.view());
        assert!(
            matches!(result, Err(StatsError::InvalidArgument(_))),
            "Should error on empty matrix"
        );
    }

    #[test]
    fn test_glasso_invalid_lambda() {
        let result = GraphicalLassoConfig::new(-0.1);
        assert!(result.is_err(), "Should error on negative lambda");
        let result = GraphicalLassoConfig::new(0.0);
        assert!(result.is_err(), "Should error on zero lambda");
    }

    #[test]
    fn test_precision_matrix_helpers() {
        let mat = array![[1.0, -0.3, 0.0], [-0.3, 1.0, -0.2], [0.0, -0.2, 1.0]];
        let pm = PrecisionMatrix::new(mat, 0.1);
        assert_eq!(pm.num_edges, 2, "Should have 2 edges (0-1 and 1-2)");
        let adj = pm.adjacency(0.1);
        assert!(adj[[0, 1]]);
        assert!(adj[[1, 0]]);
        assert!(adj[[1, 2]]);
        assert!(adj[[2, 1]]);
        assert!(!adj[[0, 2]]);
        assert!(!adj[[2, 0]]);
        let density = pm.density();
        assert!((density - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_precision_matrix_single_variable_density() {
        let pm = PrecisionMatrix::new(array![[2.0]], 0.1);
        assert_eq!(pm.num_edges, 0);
        assert_eq!(pm.density(), 0.0);
    }
}