use scirs2_core::ndarray::{Array1, Array2};
use std::fs::File;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::Path;
use std::str::FromStr;
use crate::error::{IoError, Result};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HBMatrixType {
RealUnsymmetric,
RealSymmetric,
RealSymmetricPositiveDefinite,
RealSkewSymmetric,
ComplexUnsymmetric,
ComplexSymmetric,
ComplexHermitian,
ComplexSkewSymmetric,
Pattern,
}
impl FromStr for HBMatrixType {
type Err = IoError;
fn from_str(s: &str) -> Result<Self> {
if s.len() < 3 {
return Err(IoError::FormatError(
"Matrix type string too short".to_string(),
));
}
let chars: Vec<char> = s.chars().collect();
match (chars[0], chars[1], chars[2]) {
('R', 'U', 'A') => Ok(HBMatrixType::RealUnsymmetric),
('R', 'S', 'A') => Ok(HBMatrixType::RealSymmetric),
('R', 'S', 'P') => Ok(HBMatrixType::RealSymmetricPositiveDefinite),
('R', 'S', 'S') => Ok(HBMatrixType::RealSkewSymmetric),
('C', 'U', 'A') => Ok(HBMatrixType::ComplexUnsymmetric),
('C', 'S', 'A') => Ok(HBMatrixType::ComplexSymmetric),
('C', 'H', 'A') => Ok(HBMatrixType::ComplexHermitian),
('C', 'S', 'S') => Ok(HBMatrixType::ComplexSkewSymmetric),
('P', _, _) => Ok(HBMatrixType::Pattern),
_ => Err(IoError::FormatError(format!("Unknown matrix type: {s}"))),
}
}
}
impl std::fmt::Display for HBMatrixType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HBMatrixType::RealUnsymmetric => write!(f, "RUA"),
HBMatrixType::RealSymmetric => write!(f, "RSA"),
HBMatrixType::RealSymmetricPositiveDefinite => write!(f, "RSP"),
HBMatrixType::RealSkewSymmetric => write!(f, "RSS"),
HBMatrixType::ComplexUnsymmetric => write!(f, "CUA"),
HBMatrixType::ComplexSymmetric => write!(f, "CSA"),
HBMatrixType::ComplexHermitian => write!(f, "CHA"),
HBMatrixType::ComplexSkewSymmetric => write!(f, "CSS"),
HBMatrixType::Pattern => write!(f, "PUA"),
}
}
}
#[derive(Debug, Clone)]
pub struct HBHeader {
pub title: String,
pub key: String,
pub totcrd: usize,
pub ptrcrd: usize,
pub indcrd: usize,
pub valcrd: usize,
pub rhscrd: usize,
pub mxtype: HBMatrixType,
pub nrow: usize,
pub ncol: usize,
pub nnzero: usize,
pub neltvl: usize,
pub ptrfmt: String,
pub indfmt: String,
pub valfmt: String,
pub rhsfmt: String,
}
#[derive(Debug, Clone)]
pub struct HBSparseMatrix<T> {
pub header: HBHeader,
pub colptr: Vec<usize>,
pub rowind: Vec<usize>,
pub values: Option<Vec<T>>,
pub rhs: Option<Array2<T>>,
}
impl HBHeader {
pub fn parse_header<R: BufRead>(reader: &mut R) -> Result<Self> {
let mut lines = Vec::new();
for _ in 0..5 {
let mut line = String::new();
reader
.read_line(&mut line)
.map_err(|e| IoError::FileError(e.to_string()))?;
lines.push(line.trim_end().to_string());
}
if lines.len() < 4 {
return Err(IoError::FormatError("Incomplete header".to_string()));
}
let title = if lines[0].len() > 72 {
lines[0][..72].to_string()
} else {
lines[0].clone()
};
let line2_parts: Vec<&str> = lines[1].split_whitespace().collect();
if line2_parts.len() < 5 {
return Err(IoError::FormatError("Invalid line 2 format".to_string()));
}
let key = line2_parts[0].to_string();
let totcrd = line2_parts[1]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid totcrd".to_string()))?;
let ptrcrd = line2_parts[2]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid ptrcrd".to_string()))?;
let indcrd = line2_parts[3]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid indcrd".to_string()))?;
let valcrd = line2_parts[4]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid valcrd".to_string()))?;
let rhscrd = if line2_parts.len() > 5 {
line2_parts[5]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid rhscrd".to_string()))?
} else {
0
};
let line3_parts: Vec<&str> = lines[2].split_whitespace().collect();
if line3_parts.len() < 4 {
return Err(IoError::FormatError("Invalid line 3 format".to_string()));
}
let mxtype = HBMatrixType::from_str(line3_parts[0])?;
let nrow = line3_parts[1]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid nrow".to_string()))?;
let ncol = line3_parts[2]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid ncol".to_string()))?;
let nnzero = line3_parts[3]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid nnzero".to_string()))?;
let neltvl = if line3_parts.len() > 4 {
line3_parts[4]
.parse::<usize>()
.map_err(|_| IoError::FormatError("Invalid neltvl".to_string()))?
} else {
0
};
let line4_parts: Vec<&str> = lines[3].split_whitespace().collect();
if line4_parts.len() < 2 {
return Err(IoError::FormatError("Invalid line 4 format".to_string()));
}
let ptrfmt = line4_parts[0].to_string();
let indfmt = line4_parts[1].to_string();
let valfmt = if line4_parts.len() > 2 {
line4_parts[2].to_string()
} else {
String::new()
};
let rhsfmt = if line4_parts.len() > 3 {
line4_parts[3].to_string()
} else {
String::new()
};
Ok(HBHeader {
title,
key,
totcrd,
ptrcrd,
indcrd,
valcrd,
rhscrd,
mxtype,
nrow,
ncol,
nnzero,
neltvl,
ptrfmt,
indfmt,
valfmt,
rhsfmt,
})
}
pub fn write_header<W: Write>(&self, writer: &mut W) -> Result<()> {
let padded_title = format!("{:<72}", self.title);
writeln!(writer, "{}", &padded_title[..72.min(padded_title.len())])
.map_err(|e| IoError::FileError(e.to_string()))?;
writeln!(
writer,
"{:<8} {:>7} {:>7} {:>7} {:>7} {:>7}",
self.key, self.totcrd, self.ptrcrd, self.indcrd, self.valcrd, self.rhscrd
)
.map_err(|e| IoError::FileError(e.to_string()))?;
writeln!(
writer,
"{:<3} {:>11} {:>11} {:>11} {:>11}",
self.mxtype, self.nrow, self.ncol, self.nnzero, self.neltvl
)
.map_err(|e| IoError::FileError(e.to_string()))?;
writeln!(
writer,
"{:<16} {:<16} {:<20} {:<20}",
self.ptrfmt, self.indfmt, self.valfmt, self.rhsfmt
)
.map_err(|e| IoError::FileError(e.to_string()))?;
Ok(())
}
}
#[allow(dead_code)]
pub fn read_harwell_boeing<P: AsRef<Path>>(path: P) -> Result<HBSparseMatrix<f64>> {
let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut reader = BufReader::new(file);
let header = HBHeader::parse_header(&mut reader)?;
let mut colptr = Vec::with_capacity(header.ncol + 1);
read_integer_data(&mut reader, header.ptrcrd, &mut colptr)?;
if colptr.len() != header.ncol + 1 {
return Err(IoError::FormatError(format!(
"Expected {} column pointers, got {}",
header.ncol + 1,
colptr.len()
)));
}
let mut rowind = Vec::with_capacity(header.nnzero);
read_integer_data(&mut reader, header.indcrd, &mut rowind)?;
if rowind.len() != header.nnzero {
return Err(IoError::FormatError(format!(
"Expected {} row indices, got {}",
header.nnzero,
rowind.len()
)));
}
for ptr in &mut colptr {
*ptr -= 1;
}
for idx in &mut rowind {
*idx -= 1;
}
let values = if header.mxtype == HBMatrixType::Pattern {
None
} else {
let mut vals = Vec::with_capacity(header.nnzero);
read_real_data(&mut reader, header.valcrd, &mut vals)?;
if vals.len() != header.nnzero {
return Err(IoError::FormatError(format!(
"Expected {} values, got {}",
header.nnzero,
vals.len()
)));
}
Some(vals)
};
let rhs = if header.rhscrd > 0 {
if !header.rhsfmt.is_empty() {
let mut rhs_data = Vec::new();
read_real_data(&mut reader, header.rhscrd, &mut rhs_data)?;
let nrhs = 1;
if rhs_data.len() >= header.nrow * nrhs {
let mut rhsmatrix = Array2::zeros((header.nrow, nrhs));
for i in 0..header.nrow {
for j in 0..nrhs {
let idx = j * header.nrow + i; if idx < rhs_data.len() {
rhsmatrix[[i, j]] = rhs_data[idx];
}
}
}
Some(rhsmatrix)
} else {
return Err(IoError::FormatError(format!(
"Insufficient RHS data: expected at least {}, got {}",
header.nrow * nrhs,
rhs_data.len()
)));
}
} else {
None
}
} else {
None
};
Ok(HBSparseMatrix {
header,
colptr,
rowind,
values,
rhs,
})
}
#[allow(dead_code)]
pub fn write_harwell_boeing<P: AsRef<Path>>(path: P, matrix: &HBSparseMatrix<f64>) -> Result<()> {
let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
let mut writer = BufWriter::new(file);
matrix.header.write_header(&mut writer)?;
let colptr_1based: Vec<usize> = matrix.colptr.iter().map(|&x| x + 1).collect();
write_integer_data(&mut writer, &colptr_1based, 8)?;
let rowind_1based: Vec<usize> = matrix.rowind.iter().map(|&x| x + 1).collect();
write_integer_data(&mut writer, &rowind_1based, 8)?;
if let Some(ref values) = matrix.values {
write_real_data(&mut writer, values, 16)?;
}
if let Some(ref rhsmatrix) = matrix.rhs {
if matrix.header.rhscrd > 0 && !matrix.header.rhsfmt.is_empty() {
let mut rhs_data = Vec::new();
for j in 0..rhsmatrix.ncols() {
for i in 0..rhsmatrix.nrows() {
rhs_data.push(rhsmatrix[[i, j]]);
}
}
write_real_data(&mut writer, &rhs_data, 20)?; }
}
writer
.flush()
.map_err(|e| IoError::FileError(e.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn hb_to_ccs(matrix: &HBSparseMatrix<f64>) -> (Array1<usize>, Array1<usize>, Array1<f64>) {
let colptr = Array1::from(matrix.colptr.clone());
let rowind = Array1::from(matrix.rowind.clone());
let values = if let Some(ref vals) = matrix.values {
Array1::from(vals.clone())
} else {
Array1::from(vec![1.0; matrix.rowind.len()]) };
(colptr, rowind, values)
}
#[allow(dead_code)]
pub fn ccs_to_hb(
colptr: &Array1<usize>,
rowind: &Array1<usize>,
values: &Array1<f64>,
shape: (usize, usize),
title: String,
key: String,
mxtype: HBMatrixType,
) -> HBSparseMatrix<f64> {
ccs_to_hb_with_rhs(colptr, rowind, values, shape, title, key, mxtype, None)
}
#[allow(dead_code)]
pub fn ccs_to_hb_with_rhs(
colptr: &Array1<usize>,
rowind: &Array1<usize>,
values: &Array1<f64>,
shape: (usize, usize),
title: String,
key: String,
mxtype: HBMatrixType,
rhs: Option<Array2<f64>>,
) -> HBSparseMatrix<f64> {
let (nrow, ncol) = shape;
let nnzero = rowind.len();
let rhscrd = if let Some(ref rhsmatrix) = rhs {
let total_rhs_elements = rhsmatrix.nrows() * rhsmatrix.ncols();
(total_rhs_elements + 3) / 4 } else {
0
};
let ptrcrd = ((ncol + 1) + 7) / 8; let indcrd = (nnzero + 7) / 8; let valcrd = if mxtype == HBMatrixType::Pattern {
0
} else {
(nnzero + 3) / 4 };
let header_lines = if rhscrd > 0 { 5 } else { 4 }; let totcrd = header_lines + ptrcrd + indcrd + valcrd + rhscrd;
let header = HBHeader {
title,
key,
totcrd,
ptrcrd,
indcrd,
valcrd,
rhscrd,
mxtype,
nrow,
ncol,
nnzero,
neltvl: 0,
ptrfmt: "(8I10)".to_string(),
indfmt: "(8I10)".to_string(),
valfmt: if mxtype == HBMatrixType::Pattern {
String::new()
} else {
"(4E20.12)".to_string()
},
rhsfmt: if rhscrd > 0 {
"(4E20.12)".to_string()
} else {
String::new()
},
};
HBSparseMatrix {
header,
colptr: colptr.to_vec(),
rowind: rowind.to_vec(),
values: if mxtype == HBMatrixType::Pattern {
None
} else {
Some(values.to_vec())
},
rhs,
}
}
#[allow(dead_code)]
fn read_integer_data<R: BufRead>(
reader: &mut R,
num_lines: usize,
data: &mut Vec<usize>,
) -> Result<()> {
for _ in 0..num_lines {
let mut line = String::new();
reader
.read_line(&mut line)
.map_err(|e| IoError::FileError(e.to_string()))?;
for token in line.split_whitespace() {
if let Ok(value) = token.parse::<usize>() {
data.push(value);
}
}
}
Ok(())
}
#[allow(dead_code)]
fn read_real_data<R: BufRead>(reader: &mut R, num_lines: usize, data: &mut Vec<f64>) -> Result<()> {
for _ in 0..num_lines {
let mut line = String::new();
reader
.read_line(&mut line)
.map_err(|e| IoError::FileError(e.to_string()))?;
for token in line.split_whitespace() {
if let Ok(value) = token.parse::<f64>() {
data.push(value);
}
}
}
Ok(())
}
#[allow(dead_code)]
fn write_integer_data<W: Write>(writer: &mut W, data: &[usize], fieldwidth: usize) -> Result<()> {
const INTS_PER_LINE: usize = 8;
for chunk in data.chunks(INTS_PER_LINE) {
for (i, &value) in chunk.iter().enumerate() {
if i > 0 {
write!(writer, " ").map_err(|e| IoError::FileError(e.to_string()))?;
}
write!(writer, "{value:fieldwidth$}").map_err(|e| IoError::FileError(e.to_string()))?;
}
writeln!(writer).map_err(|e| IoError::FileError(e.to_string()))?;
}
Ok(())
}
#[allow(dead_code)]
fn write_real_data<W: Write>(writer: &mut W, data: &[f64], fieldwidth: usize) -> Result<()> {
const REALS_PER_LINE: usize = 4;
for chunk in data.chunks(REALS_PER_LINE) {
for (i, &value) in chunk.iter().enumerate() {
if i > 0 {
write!(writer, " ").map_err(|e| IoError::FileError(e.to_string()))?;
}
write!(writer, "{value:fieldwidth$.6E}")
.map_err(|e| IoError::FileError(e.to_string()))?;
}
writeln!(writer).map_err(|e| IoError::FileError(e.to_string()))?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn testmatrix_type_parsing() {
assert_eq!(
HBMatrixType::from_str("RUA").expect("Operation failed"),
HBMatrixType::RealUnsymmetric
);
assert_eq!(
HBMatrixType::from_str("RSA").expect("Operation failed"),
HBMatrixType::RealSymmetric
);
assert_eq!(
HBMatrixType::from_str("CUA").expect("Operation failed"),
HBMatrixType::ComplexUnsymmetric
);
assert_eq!(
HBMatrixType::from_str("PUA").expect("Operation failed"),
HBMatrixType::Pattern
);
}
#[test]
fn testmatrix_type_display() {
assert_eq!(HBMatrixType::RealUnsymmetric.to_string(), "RUA");
assert_eq!(HBMatrixType::ComplexHermitian.to_string(), "CHA");
assert_eq!(HBMatrixType::Pattern.to_string(), "PUA");
}
#[test]
fn test_ccs_conversion() {
let colptr = Array1::from(vec![0, 2, 4]);
let rowind = Array1::from(vec![0, 1, 0, 1]);
let values = Array1::from(vec![1.0, 2.0, 3.0, 4.0]);
let hbmatrix = ccs_to_hb(
&colptr,
&rowind,
&values,
(2, 2),
"Test matrix".to_string(),
"TEST".to_string(),
HBMatrixType::RealUnsymmetric,
);
assert_eq!(hbmatrix.header.nrow, 2);
assert_eq!(hbmatrix.header.ncol, 2);
assert_eq!(hbmatrix.header.nnzero, 4);
assert_eq!(hbmatrix.colptr, vec![0, 2, 4]);
assert_eq!(hbmatrix.rowind, vec![0, 1, 0, 1]);
assert_eq!(
hbmatrix.values.as_ref().expect("Operation failed"),
&vec![1.0, 2.0, 3.0, 4.0]
);
let (new_colptr, new_rowind, new_values) = hb_to_ccs(&hbmatrix);
assert_eq!(new_colptr, colptr);
assert_eq!(new_rowind, rowind);
assert_eq!(new_values, values);
}
#[test]
fn test_patternmatrix() {
let colptr = Array1::from(vec![0, 1, 2]);
let rowind = Array1::from(vec![0, 1]);
let values = Array1::from(vec![1.0, 1.0]);
let hbmatrix = ccs_to_hb(
&colptr,
&rowind,
&values,
(2, 2),
"Pattern matrix".to_string(),
"PATTERN".to_string(),
HBMatrixType::Pattern,
);
assert_eq!(hbmatrix.header.mxtype, HBMatrixType::Pattern);
assert!(hbmatrix.values.is_none());
assert_eq!(hbmatrix.header.valcrd, 0);
}
#[test]
fn test_header_fields() {
let header = HBHeader {
title: "Test matrix for unit testing".to_string(),
key: "TESTKEY".to_string(),
totcrd: 10,
ptrcrd: 1,
indcrd: 1,
valcrd: 1,
rhscrd: 0,
mxtype: HBMatrixType::RealSymmetric,
nrow: 100,
ncol: 100,
nnzero: 500,
neltvl: 0,
ptrfmt: "(8I10)".to_string(),
indfmt: "(8I10)".to_string(),
valfmt: "(4E20.12)".to_string(),
rhsfmt: String::new(),
};
assert_eq!(header.title, "Test matrix for unit testing");
assert_eq!(header.key, "TESTKEY");
assert_eq!(header.mxtype, HBMatrixType::RealSymmetric);
assert_eq!(header.nrow, 100);
assert_eq!(header.ncol, 100);
assert_eq!(header.nnzero, 500);
}
}