rivrs_sparse/io/
reference.rs1use std::path::Path;
8
9use serde::Deserialize;
10
11use crate::error::SparseError;
12use crate::validate;
13
14pub use crate::symmetric::Inertia;
16
17#[derive(Debug, Clone, Deserialize)]
19pub struct LEntry {
20 pub row: usize,
22 pub col: usize,
24 pub value: f64,
26}
27
/// One element of the `d_blocks` list in a reference factorization file.
///
/// A block is either a single scalar or a dense 2x2 matrix. Deserialization is
/// implemented by hand (see the `Deserialize` impl in this file) because the
/// on-disk form is an object with a `size` tag and a `values` payload rather
/// than a serde-style tagged enum.
///
/// `PartialEq` is derived so blocks can be compared directly (for example with
/// `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq)]
pub enum DBlock {
    /// A 1x1 block.
    OneByOne {
        /// The single value of the block.
        value: f64,
    },
    /// A 2x2 block.
    TwoByTwo {
        /// Block values; `values[i][j]` is element `j` of row `i` as it
        /// appeared in the nested JSON array.
        values: [[f64; 2]; 2],
    },
}
44
45#[derive(Debug, Clone, Deserialize)]
47pub struct ReferenceFactorization {
48 pub matrix_name: String,
50 pub permutation: Vec<usize>,
52 pub l_entries: Vec<LEntry>,
54 pub d_blocks: Vec<DBlock>,
56 pub inertia: Inertia,
58 #[serde(default)]
60 pub notes: String,
61}
62
63pub fn load_reference(path: &Path) -> Result<ReferenceFactorization, SparseError> {
71 let path_str = path.display().to_string();
72
73 let content = std::fs::read_to_string(path).map_err(|e| SparseError::IoError {
74 source: e.to_string(),
75 path: path_str.clone(),
76 })?;
77
78 let refdata: ReferenceFactorization =
79 serde_json::from_str(&content).map_err(|e| SparseError::ParseError {
80 reason: e.to_string(),
81 path: path_str.clone(),
82 line: None,
83 })?;
84
85 for (i, entry) in refdata.l_entries.iter().enumerate() {
87 if entry.col >= entry.row {
88 return Err(SparseError::ParseError {
89 reason: format!(
90 "l_entry[{}] has col ({}) >= row ({}); must be strict lower triangle",
91 i, entry.col, entry.row
92 ),
93 path: path_str,
94 line: None,
95 });
96 }
97 }
98
99 let n = refdata.permutation.len();
101 validate::validate_permutation(&refdata.permutation, n).map_err(|e| {
102 SparseError::ParseError {
103 reason: format!("invalid permutation: {}", e),
104 path: path_str,
105 line: None,
106 }
107 })?;
108
109 Ok(refdata)
110}
111
112impl<'de> Deserialize<'de> for DBlock {
114 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115 where
116 D: serde::Deserializer<'de>,
117 {
118 use serde::de::Error;
119
120 let raw: serde_json::Value = Deserialize::deserialize(deserializer)?;
121 let obj = raw
122 .as_object()
123 .ok_or_else(|| D::Error::custom("d_block must be an object"))?;
124
125 let size = obj
126 .get("size")
127 .and_then(|v| v.as_u64())
128 .ok_or_else(|| D::Error::custom("d_block must have integer 'size' field"))?;
129
130 let values = obj
131 .get("values")
132 .ok_or_else(|| D::Error::custom("d_block must have 'values' field"))?;
133
134 match size {
135 1 => {
136 let arr = values
138 .as_array()
139 .ok_or_else(|| D::Error::custom("1x1 d_block values must be an array"))?;
140 if arr.len() != 1 {
141 return Err(D::Error::custom(format!(
142 "1x1 d_block values must have exactly 1 element, got {}",
143 arr.len()
144 )));
145 }
146 let value = arr[0]
147 .as_f64()
148 .ok_or_else(|| D::Error::custom("1x1 d_block value must be a number"))?;
149 Ok(DBlock::OneByOne { value })
150 }
151 2 => {
152 let arr = values
154 .as_array()
155 .ok_or_else(|| D::Error::custom("2x2 d_block values must be an array"))?;
156 if arr.len() != 2 {
157 return Err(D::Error::custom(format!(
158 "2x2 d_block values must have exactly 2 rows, got {}",
159 arr.len()
160 )));
161 }
162 let mut vals = [[0.0f64; 2]; 2];
163 for (i, row) in arr.iter().enumerate() {
164 let row_arr = row.as_array().ok_or_else(|| {
165 D::Error::custom(format!("2x2 d_block row {} must be an array", i))
166 })?;
167 if row_arr.len() != 2 {
168 return Err(D::Error::custom(format!(
169 "2x2 d_block row {} must have exactly 2 elements, got {}",
170 i,
171 row_arr.len()
172 )));
173 }
174 for (j, val) in row_arr.iter().enumerate() {
175 vals[i][j] = val.as_f64().ok_or_else(|| {
176 D::Error::custom(format!(
177 "2x2 d_block value at ({}, {}) must be a number",
178 i, j
179 ))
180 })?;
181 }
182 }
183 Ok(DBlock::TwoByTwo { values: vals })
184 }
185 _ => Err(D::Error::custom(format!(
186 "d_block size must be 1 or 2, got {}",
187 size
188 ))),
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Directory holding the checked-in fixtures: `<crate root>/test-data`.
    fn test_data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("test-data");
        dir
    }

    /// The arrow-5-pd fixture loads and contains only 1x1 blocks.
    #[test]
    fn load_arrow_5_pd_reference() {
        let fixture = test_data_dir().join("hand-constructed/arrow-5-pd.json");
        let loaded = load_reference(&fixture).expect("failed to load arrow-5-pd.json");

        let expected_inertia = Inertia {
            positive: 5,
            negative: 0,
            zero: 0,
        };
        assert_eq!(loaded.inertia, expected_inertia);
        assert_eq!(loaded.l_entries.len(), 10);
        assert_eq!(loaded.permutation.len(), 5);
        assert_eq!(loaded.d_blocks.len(), 5);
        assert!(loaded
            .d_blocks
            .iter()
            .all(|block| matches!(block, DBlock::OneByOne { .. })));
    }

    /// The stress-delayed-pivots fixture loads and contains only 2x2 blocks.
    #[test]
    fn load_stress_delayed_pivots_2x2_blocks() {
        let fixture = test_data_dir().join("hand-constructed/stress-delayed-pivots.json");
        let loaded =
            load_reference(&fixture).expect("failed to load stress-delayed-pivots.json");

        assert_eq!(loaded.d_blocks.len(), 5);
        assert!(loaded
            .d_blocks
            .iter()
            .all(|block| matches!(block, DBlock::TwoByTwo { .. })));

        let expected_inertia = Inertia {
            positive: 5,
            negative: 5,
            zero: 0,
        };
        assert_eq!(loaded.inertia, expected_inertia);
    }

    /// A file that is not valid JSON must make `load_reference` return an error.
    #[test]
    fn invalid_json_returns_error() {
        let scratch = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("target/test-tmp");
        // Best effort: if the directory cannot be created, the write below
        // fails and reports the problem.
        std::fs::create_dir_all(&scratch).ok();

        let bad_file = scratch.join("invalid.json");
        std::fs::write(&bad_file, "{ not valid json }").unwrap();

        assert!(load_reference(&bad_file).is_err());
    }
}