Skip to main content

rivrs_sparse/io/
reference.rs

1//! Reference factorization data loader (JSON format).
2//!
3//! Loads companion `.json` files for hand-constructed test matrices that contain
4//! analytically known LDL^T factorizations (L factor, D diagonal, permutation,
5//! and inertia).
6
7use std::path::Path;
8
9use serde::Deserialize;
10
11use crate::error::SparseError;
12use crate::validate;
13
14// Inertia relocated to symmetric module; re-export for backward compatibility.
15pub use crate::symmetric::Inertia;
16
17/// Single entry in the strict lower triangle of L.
18#[derive(Debug, Clone, Deserialize)]
19pub struct LEntry {
20    /// Row index (0-indexed).
21    pub row: usize,
22    /// Column index (0-indexed), must satisfy col < row.
23    pub col: usize,
24    /// Entry value.
25    pub value: f64,
26}
27
/// A diagonal block of D: either a scalar (1×1) pivot or a 2×2 pivot.
///
/// In the reference JSON a block is written as `{"size": 1, "values": [v]}`
/// or `{"size": 2, "values": [[a, b], [c, d]]}` (decoded by the hand-written
/// `Deserialize` impl for this type).
#[derive(Debug, Clone)]
pub enum DBlock {
    /// Scalar pivot.
    OneByOne {
        /// Pivot value.
        value: f64,
    },
    /// 2×2 pivot block, intended to be symmetric (symmetry is not checked on load).
    TwoByTwo {
        /// Block entries in row-major layout: `values[i][j]` is row `i`, column `j`.
        values: [[f64; 2]; 2],
    },
}
44
45/// The known-correct LDL^T factorization of a hand-constructed matrix.
46#[derive(Debug, Clone, Deserialize)]
47pub struct ReferenceFactorization {
48    /// Matrix name (must match the MatrixMetadata name).
49    pub matrix_name: String,
50    /// Pivot permutation (0-indexed).
51    pub permutation: Vec<usize>,
52    /// Strict lower-triangular entries of L.
53    pub l_entries: Vec<LEntry>,
54    /// Block diagonal D (1×1 or 2×2 blocks).
55    pub d_blocks: Vec<DBlock>,
56    /// Eigenvalue sign counts.
57    pub inertia: Inertia,
58    /// Human-readable description.
59    #[serde(default)]
60    pub notes: String,
61}
62
63/// Load a reference factorization from a companion JSON file.
64///
65/// # Errors
66///
67/// - File not found or unreadable
68/// - Invalid JSON structure
69/// - Inconsistent data (l_entry indices out of bounds, invalid permutation)
70pub fn load_reference(path: &Path) -> Result<ReferenceFactorization, SparseError> {
71    let path_str = path.display().to_string();
72
73    let content = std::fs::read_to_string(path).map_err(|e| SparseError::IoError {
74        source: e.to_string(),
75        path: path_str.clone(),
76    })?;
77
78    let refdata: ReferenceFactorization =
79        serde_json::from_str(&content).map_err(|e| SparseError::ParseError {
80            reason: e.to_string(),
81            path: path_str.clone(),
82            line: None,
83        })?;
84
85    // Validate l_entries: col < row
86    for (i, entry) in refdata.l_entries.iter().enumerate() {
87        if entry.col >= entry.row {
88            return Err(SparseError::ParseError {
89                reason: format!(
90                    "l_entry[{}] has col ({}) >= row ({}); must be strict lower triangle",
91                    i, entry.col, entry.row
92                ),
93                path: path_str,
94                line: None,
95            });
96        }
97    }
98
99    // Validate permutation using shared utility
100    let n = refdata.permutation.len();
101    validate::validate_permutation(&refdata.permutation, n).map_err(|e| {
102        SparseError::ParseError {
103            reason: format!("invalid permutation: {}", e),
104            path: path_str,
105            line: None,
106        }
107    })?;
108
109    Ok(refdata)
110}
111
112// Custom serde for DBlock: the JSON uses {"size": 1, "values": [...]} or {"size": 2, "values": [[...], [...]]}
113impl<'de> Deserialize<'de> for DBlock {
114    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
115    where
116        D: serde::Deserializer<'de>,
117    {
118        use serde::de::Error;
119
120        let raw: serde_json::Value = Deserialize::deserialize(deserializer)?;
121        let obj = raw
122            .as_object()
123            .ok_or_else(|| D::Error::custom("d_block must be an object"))?;
124
125        let size = obj
126            .get("size")
127            .and_then(|v| v.as_u64())
128            .ok_or_else(|| D::Error::custom("d_block must have integer 'size' field"))?;
129
130        let values = obj
131            .get("values")
132            .ok_or_else(|| D::Error::custom("d_block must have 'values' field"))?;
133
134        match size {
135            1 => {
136                // values is an array with one element: [scalar]
137                let arr = values
138                    .as_array()
139                    .ok_or_else(|| D::Error::custom("1x1 d_block values must be an array"))?;
140                if arr.len() != 1 {
141                    return Err(D::Error::custom(format!(
142                        "1x1 d_block values must have exactly 1 element, got {}",
143                        arr.len()
144                    )));
145                }
146                let value = arr[0]
147                    .as_f64()
148                    .ok_or_else(|| D::Error::custom("1x1 d_block value must be a number"))?;
149                Ok(DBlock::OneByOne { value })
150            }
151            2 => {
152                // values is a 2x2 array: [[a, b], [c, d]]
153                let arr = values
154                    .as_array()
155                    .ok_or_else(|| D::Error::custom("2x2 d_block values must be an array"))?;
156                if arr.len() != 2 {
157                    return Err(D::Error::custom(format!(
158                        "2x2 d_block values must have exactly 2 rows, got {}",
159                        arr.len()
160                    )));
161                }
162                let mut vals = [[0.0f64; 2]; 2];
163                for (i, row) in arr.iter().enumerate() {
164                    let row_arr = row.as_array().ok_or_else(|| {
165                        D::Error::custom(format!("2x2 d_block row {} must be an array", i))
166                    })?;
167                    if row_arr.len() != 2 {
168                        return Err(D::Error::custom(format!(
169                            "2x2 d_block row {} must have exactly 2 elements, got {}",
170                            i,
171                            row_arr.len()
172                        )));
173                    }
174                    for (j, val) in row_arr.iter().enumerate() {
175                        vals[i][j] = val.as_f64().ok_or_else(|| {
176                            D::Error::custom(format!(
177                                "2x2 d_block value at ({}, {}) must be a number",
178                                i, j
179                            ))
180                        })?;
181                    }
182                }
183                Ok(DBlock::TwoByTwo { values: vals })
184            }
185            _ => Err(D::Error::custom(format!(
186                "d_block size must be 1 or 2, got {}",
187                size
188            ))),
189        }
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn test_data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data")
    }

    /// Write `contents` to a scratch file named `name` and return its path.
    ///
    /// Callers pass a module-prefixed, per-test name so that tests running in
    /// parallel (here or in other modules) never share a file.
    fn write_scratch(name: &str, contents: &str) -> PathBuf {
        let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("target/test-tmp");
        std::fs::create_dir_all(&dir).expect("failed to create scratch directory");
        let path = dir.join(name);
        std::fs::write(&path, contents).expect("failed to write scratch file");
        path
    }

    #[test]
    fn load_arrow_5_pd_reference() {
        let path = test_data_dir().join("hand-constructed/arrow-5-pd.json");
        let refdata = load_reference(&path).expect("failed to load arrow-5-pd.json");
        assert_eq!(
            refdata.inertia,
            Inertia {
                positive: 5,
                negative: 0,
                zero: 0
            }
        );
        assert_eq!(refdata.l_entries.len(), 10);
        assert_eq!(refdata.permutation.len(), 5);
        assert_eq!(refdata.d_blocks.len(), 5);
        // All d_blocks should be 1x1 for this PD matrix
        for block in &refdata.d_blocks {
            assert!(matches!(block, DBlock::OneByOne { .. }));
        }
    }

    #[test]
    fn load_stress_delayed_pivots_2x2_blocks() {
        let path = test_data_dir().join("hand-constructed/stress-delayed-pivots.json");
        let refdata = load_reference(&path).expect("failed to load stress-delayed-pivots.json");
        assert_eq!(refdata.d_blocks.len(), 5);
        // All d_blocks should be 2x2 for this stress test
        for block in &refdata.d_blocks {
            assert!(matches!(block, DBlock::TwoByTwo { .. }));
        }
        assert_eq!(
            refdata.inertia,
            Inertia {
                positive: 5,
                negative: 5,
                zero: 0
            }
        );
    }

    #[test]
    fn invalid_json_returns_error() {
        let path = write_scratch("reference-invalid.json", "{ not valid json }");
        let result = load_reference(&path);
        assert!(
            matches!(result, Err(SparseError::ParseError { .. })),
            "expected ParseError, got {:?}",
            result
        );
    }

    #[test]
    fn missing_file_returns_io_error() {
        let path = test_data_dir().join("hand-constructed/does-not-exist.json");
        let result = load_reference(&path);
        assert!(
            matches!(result, Err(SparseError::IoError { .. })),
            "expected IoError, got {:?}",
            result
        );
    }

    #[test]
    fn l_entry_on_or_above_diagonal_is_rejected() {
        let json = r#"{
            "matrix_name": "upper-entry",
            "permutation": [0, 1],
            "l_entries": [{"row": 0, "col": 1, "value": 1.0}],
            "d_blocks": [{"size": 1, "values": [1.0]}, {"size": 1, "values": [1.0]}],
            "inertia": {"positive": 2, "negative": 0, "zero": 0}
        }"#;
        let path = write_scratch("reference-upper-entry.json", json);
        match load_reference(&path) {
            Err(SparseError::ParseError { reason, .. }) => assert!(
                reason.contains("strict lower triangle"),
                "unexpected reason: {}",
                reason
            ),
            other => panic!("expected ParseError, got {:?}", other),
        }
    }

    #[test]
    fn d_block_with_unsupported_size_is_rejected() {
        // `d_blocks` precedes `inertia`, so the block error surfaces first.
        let json = r#"{
            "matrix_name": "bad-block",
            "permutation": [0, 1, 2],
            "l_entries": [],
            "d_blocks": [{"size": 3, "values": [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]}],
            "inertia": {"positive": 3, "negative": 0, "zero": 0}
        }"#;
        let path = write_scratch("reference-bad-block.json", json);
        match load_reference(&path) {
            Err(SparseError::ParseError { reason, .. }) => assert!(
                reason.contains("d_block size must be 1 or 2"),
                "unexpected reason: {}",
                reason
            ),
            other => panic!("expected ParseError, got {:?}", other),
        }
    }
}