rivrs_sparse/io/
registry.rs1use std::path::{Path, PathBuf};
7
8use faer::sparse::SparseColMat;
9use serde::Deserialize;
10
11use crate::error::SparseError;
12use crate::io::mtx;
13use crate::io::reference::{self, ReferenceFactorization};
14
15#[derive(Debug, Clone, Deserialize)]
17pub struct MatrixProperties {
18 pub symmetric: bool,
20 #[serde(default)]
22 pub positive_definite: bool,
23 #[serde(default)]
25 pub indefinite: bool,
26 #[serde(default)]
28 pub difficulty: String,
29 #[serde(default)]
31 pub structure: Option<String>,
32 #[serde(default)]
34 pub kind: Option<String>,
35 #[serde(default)]
37 pub expected_delayed_pivots: Option<String>,
38}
39
40#[derive(Debug, Clone, Deserialize)]
42pub struct MatrixMetadata {
43 pub name: String,
45 pub source: String,
47 pub category: String,
49 pub path: String,
51 pub size: usize,
53 pub nnz: usize,
55 #[serde(default)]
57 pub in_repo: bool,
58 #[serde(default)]
60 pub ci_subset: bool,
61 pub properties: MatrixProperties,
63 #[serde(default)]
65 pub paper_references: Vec<String>,
66 #[serde(default)]
68 pub reference_results: serde_json::Value,
69 #[serde(default)]
71 pub factorization_path: Option<String>,
72}
73
74#[derive(Debug)]
76pub struct TestMatrix {
77 pub metadata: MatrixMetadata,
79 pub matrix: SparseColMat<usize, f64>,
81 pub reference: Option<ReferenceFactorization>,
83}
84
85#[derive(Debug, Deserialize)]
90struct MetadataFile {
91 #[allow(dead_code)]
92 schema_version: String,
93 #[allow(dead_code)]
94 generated: String,
95 #[allow(dead_code)]
96 total_count: usize,
97 matrices: Vec<MatrixMetadata>,
98}
99
100fn test_data_dir() -> PathBuf {
102 let manifest_dir = env!("CARGO_MANIFEST_DIR");
103 Path::new(manifest_dir).join("test-data")
104}
105
106pub fn load_registry() -> Result<Vec<MatrixMetadata>, SparseError> {
113 let path = test_data_dir().join("metadata.json");
114 let path_str = path.display().to_string();
115 let content = std::fs::read_to_string(&path).map_err(|e| SparseError::IoError {
116 source: e.to_string(),
117 path: path_str.clone(),
118 })?;
119 let metadata: MetadataFile =
120 serde_json::from_str(&content).map_err(|e| SparseError::ParseError {
121 reason: e.to_string(),
122 path: path_str,
123 line: None,
124 })?;
125 Ok(metadata.matrices)
126}
127
128fn resolve_mtx_path(entry: &MatrixMetadata) -> PathBuf {
133 if entry.ci_subset {
134 if let Some(ci_path) = ci_subset_path(entry) {
135 if ci_path.exists() {
136 return ci_path;
137 }
138 }
139 }
140 test_data_dir().join(&entry.path)
141}
142
143fn ci_subset_path(entry: &MatrixMetadata) -> Option<PathBuf> {
148 let rest = entry.path.strip_prefix("suitesparse/")?;
149 let category = rest.split('/').next()?;
150 let file_name = Path::new(&entry.path).file_name()?;
151 Some(
152 test_data_dir()
153 .join("suitesparse-ci")
154 .join(category)
155 .join(file_name),
156 )
157}
158
159pub fn load_test_matrix_from_entry(
171 entry: &MatrixMetadata,
172) -> Result<Option<TestMatrix>, SparseError> {
173 let mtx_path = resolve_mtx_path(entry);
174
175 if !mtx_path.exists() {
176 return Ok(None);
177 }
178
179 let matrix = mtx::load_mtx(&mtx_path)?;
180
181 let reference = if let Some(ref fact_path) = entry.factorization_path {
183 let json_path = test_data_dir().join(fact_path);
184 if json_path.exists() {
185 let refdata = reference::load_reference(&json_path)?;
186 if refdata.permutation.len() != matrix.nrows() {
188 return Err(SparseError::ParseError {
189 reason: format!(
190 "reference factorization permutation length ({}) != matrix dimension ({})",
191 refdata.permutation.len(),
192 matrix.nrows()
193 ),
194 path: json_path.display().to_string(),
195 line: None,
196 });
197 }
198 Some(refdata)
199 } else {
200 None
201 }
202 } else {
203 None
204 };
205
206 Ok(Some(TestMatrix {
207 metadata: entry.clone(),
208 matrix,
209 reference,
210 }))
211}
212
213pub fn load_test_matrix(name: &str) -> Result<Option<TestMatrix>, SparseError> {
225 let registry = load_registry()?;
226 let entry =
227 registry
228 .iter()
229 .find(|m| m.name == name)
230 .ok_or_else(|| SparseError::MatrixNotFound {
231 name: name.to_string(),
232 })?;
233
234 load_test_matrix_from_entry(entry)
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Small registry matrix that these tests expect to find on disk.
    const ARROW: &str = "arrow-5-pd";

    /// Registry entry whose `.mtx` path does not exist under `test-data/`.
    fn entry_with_missing_file() -> MatrixMetadata {
        let properties = MatrixProperties {
            symmetric: true,
            positive_definite: false,
            indefinite: false,
            difficulty: "trivial".to_string(),
            structure: None,
            kind: None,
            expected_delayed_pivots: None,
        };
        MatrixMetadata {
            name: "fake-missing-matrix".to_string(),
            source: "test".to_string(),
            category: "test".to_string(),
            path: "nonexistent/path/fake.mtx".to_string(),
            size: 5,
            nnz: 10,
            in_repo: false,
            ci_subset: false,
            properties,
            paper_references: Vec::new(),
            reference_results: serde_json::Value::Null,
            factorization_path: None,
        }
    }

    #[test]
    fn load_arrow_5_pd_returns_some() {
        let loaded = load_test_matrix(ARROW)
            .expect("registry error")
            .expect("matrix should exist on disk");
        assert_eq!((loaded.matrix.nrows(), loaded.matrix.ncols()), (5, 5));
        assert!(loaded.reference.is_some());
    }

    #[test]
    fn nonexistent_matrix_returns_error() {
        let outcome = load_test_matrix("nonexistent-matrix-name");
        assert!(matches!(
            outcome,
            Err(SparseError::MatrixNotFound { .. })
        ));
    }

    #[test]
    fn missing_mtx_file_returns_none() {
        let entry = entry_with_missing_file();
        let loaded =
            load_test_matrix_from_entry(&entry).expect("should not error for missing file");
        assert!(loaded.is_none(), "missing .mtx file should return None");
    }

    #[test]
    fn load_via_entry_matches_load_by_name() {
        let registry = load_registry().expect("failed to load registry");
        let entry = registry
            .iter()
            .find(|candidate| candidate.name == ARROW)
            .expect("registry should list arrow-5-pd");

        let by_entry = load_test_matrix_from_entry(entry)
            .expect("entry load error")
            .expect("should exist");
        let by_name = load_test_matrix(ARROW)
            .expect("name load error")
            .expect("should exist");

        // Rows, columns and stored-entry count must agree between both paths.
        let fingerprint =
            |t: &TestMatrix| (t.matrix.nrows(), t.matrix.ncols(), t.matrix.compute_nnz());
        assert_eq!(fingerprint(&by_entry), fingerprint(&by_name));
    }
}