use std::path::{Path, PathBuf};
use faer::sparse::SparseColMat;
use serde::Deserialize;
use crate::error::SparseError;
use crate::io::mtx;
use crate::io::reference::{self, ReferenceFactorization};
/// Properties object attached to each registry entry in `metadata.json`.
///
/// Only `symmetric` is required when deserializing; every field marked
/// `#[serde(default)]` falls back to `false`, an empty string, or `None` when
/// the key is absent. The loader in this file never interprets these fields —
/// they are carried through on [`MatrixMetadata`] for tests to inspect.
#[derive(Debug, Clone, Deserialize)]
pub struct MatrixProperties {
/// Symmetry flag as recorded in the metadata (required key).
pub symmetric: bool,
/// Positive-definiteness flag as recorded in the metadata; `false` if omitted.
#[serde(default)]
pub positive_definite: bool,
/// Indefiniteness flag as recorded in the metadata; `false` if omitted.
#[serde(default)]
pub indefinite: bool,
/// Free-form difficulty label (e.g. `"trivial"`); empty string if omitted.
#[serde(default)]
pub difficulty: String,
/// Optional free-form structure description; vocabulary not defined here.
#[serde(default)]
pub structure: Option<String>,
/// Optional free-form classification string; vocabulary not defined here.
#[serde(default)]
pub kind: Option<String>,
/// Optional value stored verbatim as a string.
/// NOTE(review): the name suggests an expected delayed-pivot count/range, but
/// its format is not defined in this file — confirm against metadata.json.
#[serde(default)]
pub expected_delayed_pivots: Option<String>,
}
/// One entry of the `matrices` array in `test-data/metadata.json`.
///
/// Fields without `#[serde(default)]` (`name`, `source`, `category`, `path`,
/// `size`, `nnz`, `properties`) must be present or [`load_registry`] fails
/// with a parse error.
#[derive(Debug, Clone, Deserialize)]
pub struct MatrixMetadata {
/// Identifier matched by [`load_test_matrix`] (first match in the registry wins).
pub name: String,
/// Free-form origin label from the metadata.
pub source: String,
/// Free-form category label from the metadata.
pub category: String,
/// Location of the `.mtx` file, relative to the crate's `test-data/` directory.
pub path: String,
/// Matrix dimension as recorded in the metadata; not cross-checked against
/// the loaded matrix in this file.
pub size: usize,
/// Nonzero count as recorded in the metadata; not cross-checked here.
/// NOTE(review): unclear whether this counts one triangle or the full
/// symmetric pattern — confirm against the generator.
pub nnz: usize,
/// Flag from the metadata; `false` if omitted. Not read by the loader here.
#[serde(default)]
pub in_repo: bool,
/// When `true`, the loader first tries the trimmed copy under
/// `test-data/suitesparse-ci/` and falls back to `path` if that copy is
/// missing. `false` if omitted.
#[serde(default)]
pub ci_subset: bool,
/// Numerical/structural properties recorded for this matrix.
pub properties: MatrixProperties,
/// Free-form citation strings; empty if omitted.
#[serde(default)]
pub paper_references: Vec<String>,
/// Arbitrary JSON blob; `Value::Null` if omitted. Not interpreted here.
#[serde(default)]
pub reference_results: serde_json::Value,
/// Optional path, relative to `test-data/`, of a reference factorization
/// JSON file loaded via `reference::load_reference` when it exists.
#[serde(default)]
pub factorization_path: Option<String>,
}
/// A registry entry together with its loaded matrix and optional reference data.
#[derive(Debug)]
pub struct TestMatrix {
/// Copy of the registry entry this matrix was loaded from.
pub metadata: MatrixMetadata,
/// Matrix parsed from the entry's `.mtx` file by `mtx::load_mtx`.
pub matrix: SparseColMat<usize, f64>,
/// Reference factorization. It is present only when the entry names a
/// `factorization_path` whose file exists. When loaded, its permutation
/// length has been checked to equal `matrix.nrows()`.
pub reference: Option<ReferenceFactorization>,
}
/// Top-level shape of `test-data/metadata.json`.
///
/// Only `matrices` is consumed by [`load_registry`]. The other fields are
/// still declared, and because they lack `#[serde(default)]` a file missing
/// them fails to parse. They are never read after deserialization, which is
/// why each carries `#[allow(dead_code)]`.
#[derive(Debug, Deserialize)]
struct MetadataFile {
// Required by the schema, never read.
#[allow(dead_code)]
schema_version: String,
// Required by the schema, never read.
#[allow(dead_code)]
generated: String,
// Required by the schema, never read; not checked against `matrices.len()`.
#[allow(dead_code)]
total_count: usize,
/// The registry entries themselves.
matrices: Vec<MatrixMetadata>,
}
fn test_data_dir() -> PathBuf {
let manifest_dir = env!("CARGO_MANIFEST_DIR");
Path::new(manifest_dir).join("test-data")
}
/// Reads and parses `test-data/metadata.json`, returning its matrix entries.
///
/// # Errors
///
/// * [`SparseError::IoError`] if the file cannot be read.
/// * [`SparseError::ParseError`] if the JSON does not match the expected
///   schema. `line` is always `None`; the serde message is in `reason`.
pub fn load_registry() -> Result<Vec<MatrixMetadata>, SparseError> {
    let registry_path = test_data_dir().join("metadata.json");
    let shown = registry_path.display().to_string();

    let raw = match std::fs::read_to_string(&registry_path) {
        Ok(text) => text,
        Err(err) => {
            return Err(SparseError::IoError {
                source: err.to_string(),
                path: shown,
            });
        }
    };

    serde_json::from_str::<MetadataFile>(&raw)
        .map(|file| file.matrices)
        .map_err(|err| SparseError::ParseError {
            reason: err.to_string(),
            path: shown,
            line: None,
        })
}
/// Picks the on-disk `.mtx` file for a registry entry.
///
/// Entries flagged `ci_subset` prefer their copy under `suitesparse-ci/` when
/// that file is present. Every other entry, and any CI entry whose copy is
/// absent, resolves to `test-data/<entry.path>`. The returned path is not
/// guaranteed to exist.
fn resolve_mtx_path(entry: &MatrixMetadata) -> PathBuf {
    entry
        .ci_subset
        .then(|| ci_subset_path(entry))
        .flatten()
        .filter(|candidate| candidate.exists())
        .unwrap_or_else(|| test_data_dir().join(&entry.path))
}
/// Maps a SuiteSparse registry path onto its copy in the CI subset directory.
///
/// `suitesparse/<category>/.../<file>` becomes
/// `<test-data>/suitesparse-ci/<category>/<file>`. Any directories between
/// the category and the file name are dropped.
///
/// Returns `None` in three cases:
/// * the path is not under `suitesparse/`;
/// * it has no (non-empty) category directory;
/// * it has no file name.
fn ci_subset_path(entry: &MatrixMetadata) -> Option<PathBuf> {
    let rest = entry.path.strip_prefix("suitesparse/")?;
    // `split_once` genuinely fails when there is no separator. The previous
    // `split('/').next()` always returned `Some`, so a bare
    // `suitesparse/foo.mtx` was mistaken for category `foo.mtx`.
    let (category, _) = rest.split_once('/')?;
    if category.is_empty() {
        return None;
    }
    let file_name = Path::new(&entry.path).file_name()?;
    Some(
        test_data_dir()
            .join("suitesparse-ci")
            .join(category)
            .join(file_name),
    )
}
pub fn load_test_matrix_from_entry(
entry: &MatrixMetadata,
) -> Result<Option<TestMatrix>, SparseError> {
let mtx_path = resolve_mtx_path(entry);
if !mtx_path.exists() {
return Ok(None);
}
let matrix = mtx::load_mtx(&mtx_path)?;
let reference = if let Some(ref fact_path) = entry.factorization_path {
let json_path = test_data_dir().join(fact_path);
if json_path.exists() {
let refdata = reference::load_reference(&json_path)?;
if refdata.permutation.len() != matrix.nrows() {
return Err(SparseError::ParseError {
reason: format!(
"reference factorization permutation length ({}) != matrix dimension ({})",
refdata.permutation.len(),
matrix.nrows()
),
path: json_path.display().to_string(),
line: None,
});
}
Some(refdata)
} else {
None
}
} else {
None
};
Ok(Some(TestMatrix {
metadata: entry.clone(),
matrix,
reference,
}))
}
/// Looks up `name` in the registry and loads that matrix.
///
/// Returns `Ok(None)` when the entry exists but its `.mtx` file is not on
/// disk (see [`load_test_matrix_from_entry`]).
///
/// # Errors
///
/// * [`SparseError::MatrixNotFound`] if no registry entry has this name.
/// * Any error from [`load_registry`] or [`load_test_matrix_from_entry`].
pub fn load_test_matrix(name: &str) -> Result<Option<TestMatrix>, SparseError> {
    let entries = load_registry()?;
    match entries.iter().find(|candidate| candidate.name == name) {
        Some(entry) => load_test_matrix_from_entry(entry),
        None => Err(SparseError::MatrixNotFound {
            name: name.to_string(),
        }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal registry entry pointing at `path`; the file need not exist.
    fn fake_entry(name: &str, path: &str, ci_subset: bool) -> MatrixMetadata {
        MatrixMetadata {
            name: name.to_string(),
            source: "test".to_string(),
            category: "test".to_string(),
            path: path.to_string(),
            size: 5,
            nnz: 10,
            in_repo: false,
            ci_subset,
            properties: MatrixProperties {
                symmetric: true,
                positive_definite: false,
                indefinite: false,
                difficulty: "trivial".to_string(),
                structure: None,
                kind: None,
                expected_delayed_pivots: None,
            },
            paper_references: vec![],
            reference_results: serde_json::Value::Null,
            factorization_path: None,
        }
    }

    #[test]
    fn load_arrow_5_pd_returns_some() {
        let test = load_test_matrix("arrow-5-pd")
            .expect("registry error")
            .expect("matrix should exist on disk");
        assert_eq!(test.matrix.nrows(), 5);
        assert_eq!(test.matrix.ncols(), 5);
        assert!(test.reference.is_some());
    }

    #[test]
    fn nonexistent_matrix_returns_error() {
        let result = load_test_matrix("nonexistent-matrix-name");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            SparseError::MatrixNotFound { .. }
        ));
    }

    #[test]
    fn missing_mtx_file_returns_none() {
        let entry = fake_entry("fake-missing-matrix", "nonexistent/path/fake.mtx", false);
        let result =
            load_test_matrix_from_entry(&entry).expect("should not error for missing file");
        assert!(result.is_none(), "missing .mtx file should return None");
    }

    #[test]
    fn load_via_entry_matches_load_by_name() {
        let registry = load_registry().expect("failed to load registry");
        let entry = registry.iter().find(|m| m.name == "arrow-5-pd").unwrap();
        let by_entry = load_test_matrix_from_entry(entry)
            .expect("entry load error")
            .expect("should exist");
        let by_name = load_test_matrix("arrow-5-pd")
            .expect("name load error")
            .expect("should exist");
        assert_eq!(by_entry.matrix.nrows(), by_name.matrix.nrows());
        assert_eq!(by_entry.matrix.ncols(), by_name.matrix.ncols());
        assert_eq!(by_entry.matrix.compute_nnz(), by_name.matrix.compute_nnz());
    }

    #[test]
    fn ci_subset_path_maps_category_and_file_name() {
        let entry = fake_entry(
            "fake-ci",
            "suitesparse/hard-indefinite/fake-ci/fake-ci.mtx",
            true,
        );
        assert_eq!(
            ci_subset_path(&entry),
            Some(
                test_data_dir()
                    .join("suitesparse-ci")
                    .join("hard-indefinite")
                    .join("fake-ci.mtx")
            )
        );
    }

    #[test]
    fn ci_subset_path_ignores_non_suitesparse_paths() {
        let entry = fake_entry("fake-hand", "hand-constructed/fake.mtx", true);
        assert_eq!(ci_subset_path(&entry), None);
    }

    #[test]
    fn ci_subset_entry_falls_back_when_ci_copy_missing() {
        let entry = fake_entry(
            "fake-no-ci-copy",
            "suitesparse/no-such-category/no-such-matrix/no-such-matrix.mtx",
            true,
        );
        assert_eq!(resolve_mtx_path(&entry), test_data_dir().join(&entry.path));
    }
}