use std::path::Path;
use serde::Deserialize;
use crate::error::SparseError;
use crate::validate;
pub use crate::symmetric::Inertia;
/// One off-diagonal entry of the reference `L` factor, as read from JSON.
///
/// `load_reference` rejects any entry that is not strictly lower-triangular
/// (it requires `col < row`). NOTE(review): the diagonal of `L` is not stored
/// here; presumably it is implicit (unit diagonal) — confirm with consumers.
#[derive(Debug, Clone, Deserialize)]
pub struct LEntry {
/// Row index of the entry.
pub row: usize,
/// Column index of the entry; must be strictly less than `row`.
pub col: usize,
/// Numerical value of the entry.
pub value: f64,
}
/// One diagonal block of the reference `D` factor.
///
/// Deserialized by the hand-written `Deserialize` impl in this module from
/// either `{"size": 1, "values": [d]}` or
/// `{"size": 2, "values": [[a, b], [c, d]]}`; any other size is rejected.
#[derive(Debug, Clone)]
pub enum DBlock {
/// A 1x1 diagonal block.
OneByOne {
/// The single diagonal value.
value: f64,
},
/// A 2x2 diagonal block.
TwoByTwo {
/// Block entries indexed as `values[row][col]`, copied row by row from the JSON.
values: [[f64; 2]; 2],
},
}
/// A reference factorization (permutation, strictly lower `L` entries, `D` blocks,
/// and expected inertia) deserialized from JSON by `load_reference`.
#[derive(Debug, Clone, Deserialize)]
pub struct ReferenceFactorization {
/// Name of the matrix this reference belongs to.
pub matrix_name: String,
/// Permutation vector; `load_reference` treats its length as the matrix dimension `n`.
pub permutation: Vec<usize>,
/// Strictly lower-triangular entries of `L` (`col < row` is enforced on load).
pub l_entries: Vec<LEntry>,
/// Diagonal blocks of `D` (1x1 or 2x2).
pub d_blocks: Vec<DBlock>,
/// Expected inertia: counts of positive, negative and zero values.
pub inertia: Inertia,
/// Free-form notes; an empty string when the field is absent from the JSON.
#[serde(default)]
pub notes: String,
}
pub fn load_reference(path: &Path) -> Result<ReferenceFactorization, SparseError> {
let path_str = path.display().to_string();
let content = std::fs::read_to_string(path).map_err(|e| SparseError::IoError {
source: e.to_string(),
path: path_str.clone(),
})?;
let refdata: ReferenceFactorization =
serde_json::from_str(&content).map_err(|e| SparseError::ParseError {
reason: e.to_string(),
path: path_str.clone(),
line: None,
})?;
for (i, entry) in refdata.l_entries.iter().enumerate() {
if entry.col >= entry.row {
return Err(SparseError::ParseError {
reason: format!(
"l_entry[{}] has col ({}) >= row ({}); must be strict lower triangle",
i, entry.col, entry.row
),
path: path_str,
line: None,
});
}
}
let n = refdata.permutation.len();
validate::validate_permutation(&refdata.permutation, n).map_err(|e| {
SparseError::ParseError {
reason: format!("invalid permutation: {}", e),
path: path_str,
line: None,
}
})?;
Ok(refdata)
}
impl<'de> Deserialize<'de> for DBlock {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
let raw: serde_json::Value = Deserialize::deserialize(deserializer)?;
let obj = raw
.as_object()
.ok_or_else(|| D::Error::custom("d_block must be an object"))?;
let size = obj
.get("size")
.and_then(|v| v.as_u64())
.ok_or_else(|| D::Error::custom("d_block must have integer 'size' field"))?;
let values = obj
.get("values")
.ok_or_else(|| D::Error::custom("d_block must have 'values' field"))?;
match size {
1 => {
let arr = values
.as_array()
.ok_or_else(|| D::Error::custom("1x1 d_block values must be an array"))?;
if arr.len() != 1 {
return Err(D::Error::custom(format!(
"1x1 d_block values must have exactly 1 element, got {}",
arr.len()
)));
}
let value = arr[0]
.as_f64()
.ok_or_else(|| D::Error::custom("1x1 d_block value must be a number"))?;
Ok(DBlock::OneByOne { value })
}
2 => {
let arr = values
.as_array()
.ok_or_else(|| D::Error::custom("2x2 d_block values must be an array"))?;
if arr.len() != 2 {
return Err(D::Error::custom(format!(
"2x2 d_block values must have exactly 2 rows, got {}",
arr.len()
)));
}
let mut vals = [[0.0f64; 2]; 2];
for (i, row) in arr.iter().enumerate() {
let row_arr = row.as_array().ok_or_else(|| {
D::Error::custom(format!("2x2 d_block row {} must be an array", i))
})?;
if row_arr.len() != 2 {
return Err(D::Error::custom(format!(
"2x2 d_block row {} must have exactly 2 elements, got {}",
i,
row_arr.len()
)));
}
for (j, val) in row_arr.iter().enumerate() {
vals[i][j] = val.as_f64().ok_or_else(|| {
D::Error::custom(format!(
"2x2 d_block value at ({}, {}) must be a number",
i, j
))
})?;
}
}
Ok(DBlock::TwoByTwo { values: vals })
}
_ => Err(D::Error::custom(format!(
"d_block size must be 1 or 2, got {}",
size
))),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn test_data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test-data")
    }

    /// Writes `contents` to `target/test-tmp/<name>` and returns the path.
    ///
    /// Each test uses its own `name`, so tests running in parallel never share a file.
    fn write_temp(name: &str, contents: &str) -> PathBuf {
        let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("target/test-tmp");
        std::fs::create_dir_all(&dir).expect("failed to create target/test-tmp");
        let path = dir.join(name);
        std::fs::write(&path, contents).expect("failed to write temp fixture");
        path
    }

    /// Asserts that `result` is a `ParseError` whose reason contains `needle`.
    fn assert_parse_error(result: Result<ReferenceFactorization, SparseError>, needle: &str) {
        match result {
            Err(SparseError::ParseError { reason, .. }) => assert!(
                reason.contains(needle),
                "expected reason containing {:?}, got {:?}",
                needle,
                reason
            ),
            other => panic!("expected ParseError, got {:?}", other),
        }
    }

    #[test]
    fn load_arrow_5_pd_reference() {
        let path = test_data_dir().join("hand-constructed/arrow-5-pd.json");
        let refdata = load_reference(&path).expect("failed to load arrow-5-pd.json");
        assert_eq!(
            refdata.inertia,
            Inertia {
                positive: 5,
                negative: 0,
                zero: 0
            }
        );
        assert_eq!(refdata.l_entries.len(), 10);
        assert_eq!(refdata.permutation.len(), 5);
        assert_eq!(refdata.d_blocks.len(), 5);
        for block in &refdata.d_blocks {
            assert!(matches!(block, DBlock::OneByOne { .. }));
        }
    }

    #[test]
    fn load_stress_delayed_pivots_2x2_blocks() {
        let path = test_data_dir().join("hand-constructed/stress-delayed-pivots.json");
        let refdata = load_reference(&path).expect("failed to load stress-delayed-pivots.json");
        assert_eq!(refdata.d_blocks.len(), 5);
        for block in &refdata.d_blocks {
            assert!(matches!(block, DBlock::TwoByTwo { .. }));
        }
        assert_eq!(
            refdata.inertia,
            Inertia {
                positive: 5,
                negative: 5,
                zero: 0
            }
        );
    }

    #[test]
    fn invalid_json_returns_error() {
        let path = write_temp("invalid.json", "{ not valid json }");
        let result = load_reference(&path);
        // Must be a parse failure, not e.g. an I/O failure on the fixture.
        assert!(
            matches!(&result, Err(SparseError::ParseError { .. })),
            "expected ParseError, got {:?}",
            result
        );
    }

    #[test]
    fn upper_triangular_l_entry_is_rejected() {
        let path = write_temp(
            "upper-l-entry.json",
            r#"{
                "matrix_name": "upper-l-entry",
                "permutation": [0, 1],
                "l_entries": [{"row": 0, "col": 1, "value": 1.0}],
                "d_blocks": [{"size": 1, "values": [1.0]}, {"size": 1, "values": [1.0]}],
                "inertia": {"positive": 2, "negative": 0, "zero": 0}
            }"#,
        );
        assert_parse_error(load_reference(&path), "strict lower triangle");
    }

    #[test]
    fn d_block_with_unsupported_size_is_rejected() {
        let path = write_temp(
            "bad-d-block-size.json",
            r#"{
                "matrix_name": "bad-d-block-size",
                "permutation": [0],
                "l_entries": [],
                "d_blocks": [{"size": 3, "values": [1.0]}],
                "inertia": {"positive": 1, "negative": 0, "zero": 0}
            }"#,
        );
        assert_parse_error(load_reference(&path), "d_block size must be 1 or 2, got 3");
    }

    #[test]
    fn missing_file_returns_io_error() {
        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("target/test-tmp/definitely-missing-reference.json");
        let result = load_reference(&path);
        assert!(
            matches!(&result, Err(SparseError::IoError { .. })),
            "expected IoError, got {:?}",
            result
        );
    }
}