use std::error::Error;
use std::fmt;
use std::fs::File;
use std::io;
use std::io::{Seek, SeekFrom, Write};
use std::path::Path;
use num_traits::cast::NumCast;
use crate::indexing::SpIndex;
use crate::num_kinds::{NumKind, PrimitiveKind};
use crate::sparse::{SparseMat, TriMatI};
#[derive(Debug)]
pub enum IoError {
Io(io::Error),
BadMatrixMarketFile,
UnsupportedMatrixMarketFormat,
}
use self::IoError::*;
impl fmt::Display for IoError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Self::Io(ref err) => err.fmt(f),
Self::BadMatrixMarketFile | Self::UnsupportedMatrixMarketFormat => {
write!(f, "Bad matrix market file.")
}
}
}
}
impl Error for IoError {}
impl From<io::Error> for IoError {
fn from(err: io::Error) -> Self {
Self::Io(err)
}
}
impl PartialEq for IoError {
fn eq(&self, rhs: &Self) -> bool {
match *self {
Self::BadMatrixMarketFile => {
matches!(*rhs, Self::BadMatrixMarketFile)
}
Self::UnsupportedMatrixMarketFormat => {
matches!(*rhs, Self::UnsupportedMatrixMarketFormat)
}
Self::Io(..) => false,
}
}
}
#[derive(Debug, PartialEq)]
enum DataType {
Integer,
Real,
Complex,
}
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum SymmetryMode {
General,
Hermitian,
Symmetric,
SkewSymmetric,
}
fn parse_header(header: &str) -> Result<(SymmetryMode, DataType), IoError> {
if !header.starts_with("%%matrixmarket matrix coordinate") {
return Err(BadMatrixMarketFile);
}
let data_type = if header.contains("real") {
DataType::Real
} else if header.contains("integer") {
DataType::Integer
} else if header.contains("complex") {
DataType::Complex
} else {
return Err(BadMatrixMarketFile);
};
let sym_mode = if header.contains("general") {
SymmetryMode::General
} else if header.contains("symmetric") {
SymmetryMode::Symmetric
} else if header.contains("skew-symmetric") {
SymmetryMode::SkewSymmetric
} else if header.contains("hermitian") {
SymmetryMode::Hermitian
} else {
return Err(BadMatrixMarketFile);
};
Ok((sym_mode, data_type))
}
pub fn read_matrix_market<N, I, P>(mm_file: P) -> Result<TriMatI<N, I>, IoError>
where
I: SpIndex,
N: NumCast + Clone,
P: AsRef<Path>,
{
let mm_file = mm_file.as_ref();
let f = File::open(mm_file)?;
let mut reader = io::BufReader::new(f);
read_matrix_market_from_bufread(&mut reader)
}
pub fn read_matrix_market_from_bufread<N, I, R>(
reader: &mut R,
) -> Result<TriMatI<N, I>, IoError>
where
I: SpIndex,
N: NumCast + Clone,
R: io::BufRead,
{
let mut line = String::with_capacity(1024);
reader.read_line(&mut line)?;
let header = line.to_lowercase();
let (sym_mode, data_type) = parse_header(&header)?;
if data_type == DataType::Complex {
return Err(UnsupportedMatrixMarketFormat);
}
if sym_mode == SymmetryMode::Hermitian {
return Err(UnsupportedMatrixMarketFormat);
}
'header: loop {
line.clear();
let len = reader.read_line(&mut line)?;
if len == 0 || line.starts_with('%') {
continue 'header;
}
break;
}
let (rows, cols, entries) = {
let mut infos = line
.split_whitespace()
.filter_map(|s| s.parse::<usize>().ok());
let rows = infos.next().ok_or(BadMatrixMarketFile)?;
let cols = infos.next().ok_or(BadMatrixMarketFile)?;
let entries = infos.next().ok_or(BadMatrixMarketFile)?;
if infos.next().is_some() {
return Err(BadMatrixMarketFile);
}
(rows, cols, entries)
};
let nnz_max = if sym_mode == SymmetryMode::General {
entries
} else {
2 * entries
};
let mut row_inds = Vec::with_capacity(nnz_max);
let mut col_inds = Vec::with_capacity(nnz_max);
let mut data = Vec::with_capacity(nnz_max);
for _ in 0..entries {
'empty_lines: loop {
line.clear();
let len = reader.read_line(&mut line)?;
if len != 0 && line.split_whitespace().next() == None {
continue 'empty_lines;
}
break;
}
let mut entry = line.split_whitespace();
let row = entry
.next()
.ok_or(BadMatrixMarketFile)
.and_then(|s| s.parse::<usize>().or(Err(BadMatrixMarketFile)))?;
let col = entry
.next()
.ok_or(BadMatrixMarketFile)
.and_then(|s| s.parse::<usize>().or(Err(BadMatrixMarketFile)))?;
let row = row.checked_sub(1).ok_or(BadMatrixMarketFile)?;
let col = col.checked_sub(1).ok_or(BadMatrixMarketFile)?;
let val: N = match data_type {
DataType::Integer => {
let val =
entry.next().ok_or(BadMatrixMarketFile).and_then(|s| {
s.parse::<isize>().or(Err(BadMatrixMarketFile))
})?;
NumCast::from(val).unwrap()
}
DataType::Real => {
let val =
entry.next().ok_or(BadMatrixMarketFile).and_then(|s| {
s.parse::<f64>().or(Err(BadMatrixMarketFile))
})?;
NumCast::from(val).unwrap()
}
DataType::Complex => unreachable!(),
};
row_inds.push(I::from_usize(row));
col_inds.push(I::from_usize(col));
data.push(val.clone());
if sym_mode != SymmetryMode::General && row != col {
if sym_mode == SymmetryMode::Hermitian {
unreachable!();
} else {
row_inds.push(I::from_usize(col));
col_inds.push(I::from_usize(row));
data.push(val);
}
}
if sym_mode == SymmetryMode::SkewSymmetric && row == col {
return Err(BadMatrixMarketFile);
}
if entry.next().is_some() {
return Err(BadMatrixMarketFile);
}
}
Ok(TriMatI::from_triplets(
(rows, cols),
row_inds,
col_inds,
data,
))
}
pub fn write_matrix_market<'a, N, I, M, P>(
path: P,
mat: M,
) -> Result<(), io::Error>
where
I: 'a + SpIndex + fmt::Display,
N: 'a + PrimitiveKind + fmt::Display,
M: IntoIterator<Item = (&'a N, (I, I))> + SparseMat,
P: AsRef<Path>,
{
let (rows, cols, nnz) = (mat.rows(), mat.cols(), mat.nnz());
let f = File::create(path)?;
let mut writer = io::BufWriter::new(f);
let data_type = match N::num_kind() {
NumKind::Integer => "integer",
NumKind::Float => "real",
NumKind::Complex => "complex",
};
writeln!(
writer,
"%%MatrixMarket matrix coordinate {} general",
data_type
)?;
writeln!(writer, "% written by sprs")?;
writeln!(writer, "{} {} {}", rows, cols, nnz)?;
for (val, (row, col)) in mat {
writeln!(writer, "{} {} {}", row.index() + 1, col.index() + 1, val)?;
}
Ok(())
}
pub fn write_matrix_market_sym<'a, N, I, M, P>(
path: P,
mat: M,
sym: SymmetryMode,
) -> Result<(), io::Error>
where
I: 'a + SpIndex + fmt::Display,
N: 'a + PrimitiveKind + fmt::Display,
M: IntoIterator<Item = (&'a N, (I, I))> + SparseMat,
P: AsRef<Path>,
{
let (rows, cols, nnz) = (mat.rows(), mat.cols(), mat.nnz());
let f = File::create(path)?;
let mut writer = io::BufWriter::new(f);
let data_type = match N::num_kind() {
NumKind::Integer => "integer",
NumKind::Float => "real",
NumKind::Complex => "complex",
};
let mode = match sym {
SymmetryMode::General => "general",
SymmetryMode::Symmetric => "symmetric",
SymmetryMode::SkewSymmetric => "skew-symmetric",
SymmetryMode::Hermitian => "hermitian",
};
writeln!(
writer,
"%%MatrixMarket matrix coordinate {} {}",
data_type, mode
)?;
writeln!(writer, "% written by sprs")?;
let dim_header_pos = writer.seek(SeekFrom::Current(0))?;
writeln!(writer, "{} {} {}", rows, cols, nnz)?;
let mut entries = 0;
match sym {
SymmetryMode::General => {
for (val, (row, col)) in mat {
writeln!(
writer,
"{} {} {}",
row.index() + 1,
col.index() + 1,
val
)?;
entries += 1;
}
}
SymmetryMode::SkewSymmetric => {
for (val, (row, col)) in
mat.into_iter().filter(|&(_, (r, c))| r < c)
{
writeln!(
writer,
"{} {} {}",
row.index() + 1,
col.index() + 1,
val
)?;
entries += 1;
}
}
_ => {
for (val, (row, col)) in
mat.into_iter().filter(|&(_, (r, c))| r <= c)
{
writeln!(
writer,
"{} {} {}",
row.index() + 1,
col.index() + 1,
val
)?;
entries += 1;
}
}
};
assert!(entries <= nnz);
writer.seek(SeekFrom::Start(dim_header_pos))?;
write!(writer, "{} {} {}", rows, cols, entries)?;
let dim_header_size = format!("{} {} {}", rows, cols, nnz).len();
let new_size = format!("{} {} {}", rows, cols, entries).len();
if new_size < dim_header_size {
let nb_spaces = dim_header_size - new_size;
for _ in 0..nb_spaces {
writer.write_all(b" ")?;
}
}
Ok(())
}
#[cfg(test)]
mod test {
use super::{
read_matrix_market, read_matrix_market_from_bufread,
write_matrix_market, write_matrix_market_sym, IoError, SymmetryMode,
};
use crate::CsMat;
use tempfile::tempdir;
#[cfg_attr(miri, ignore)]
#[test]
fn simple_matrix_market_read() {
let path = "data/matrix_market/simple.mm";
let mat = read_matrix_market::<f64, usize, _>(path).unwrap();
assert_eq!(mat.rows(), 5);
assert_eq!(mat.cols(), 5);
assert_eq!(mat.nnz(), 8);
assert_eq!(mat.row_inds(), &[0, 1, 2, 0, 3, 3, 3, 4]);
assert_eq!(mat.col_inds(), &[0, 1, 2, 3, 1, 3, 4, 4]);
assert_eq!(
mat.data(),
&[1., 10.5, 1.5e-02, 6., 2.505e2, -2.8e2, 3.332e1, 1.2e+1]
);
}
#[test]
#[cfg_attr(miri, ignore)]
fn simple_matrix_market_read_from_bufread() {
let path = "data/matrix_market/simple.mm";
let f = std::fs::File::open(path).unwrap();
let mut reader = std::io::BufReader::new(f);
let mat = read_matrix_market_from_bufread::<f64, usize, _>(&mut reader)
.unwrap();
assert_eq!(mat.rows(), 5);
assert_eq!(mat.cols(), 5);
assert_eq!(mat.nnz(), 8);
assert_eq!(mat.row_inds(), &[0, 1, 2, 0, 3, 3, 3, 4]);
assert_eq!(mat.col_inds(), &[0, 1, 2, 3, 1, 3, 4, 4]);
assert_eq!(
mat.data(),
&[1., 10.5, 1.5e-02, 6., 2.505e2, -2.8e2, 3.332e1, 1.2e+1]
);
}
#[test]
#[cfg_attr(miri, ignore)]
fn int_matrix_market_read() {
let path = "data/matrix_market/simple_int.mm";
let mat = read_matrix_market::<i32, usize, _>(path).unwrap();
assert_eq!(mat.rows(), 5);
assert_eq!(mat.cols(), 5);
assert_eq!(mat.nnz(), 8);
assert_eq!(mat.row_inds(), &[0, 1, 2, 0, 3, 3, 3, 4]);
assert_eq!(mat.col_inds(), &[0, 1, 2, 3, 1, 3, 4, 4]);
assert_eq!(mat.data(), &[1, 1, 1, 6, 2, -2, 3, 1]);
let mat = read_matrix_market::<f32, i16, _>(path).unwrap();
assert_eq!(mat.rows(), 5);
assert_eq!(mat.cols(), 5);
assert_eq!(mat.nnz(), 8);
assert_eq!(mat.row_inds(), &[0, 1, 2, 0, 3, 3, 3, 4]);
assert_eq!(mat.col_inds(), &[0, 1, 2, 3, 1, 3, 4, 4]);
assert_eq!(mat.data(), &[1., 1., 1., 6., 2., -2., 3., 1.]);
}
#[test]
#[cfg_attr(miri, ignore)]
fn matrix_market_read_fail_too_many_in_entry() {
let path = "data/matrix_market/bad_files/too_many_elems_in_entry.mm";
let res = read_matrix_market::<f64, i32, _>(path);
assert_eq!(res.unwrap_err(), IoError::BadMatrixMarketFile);
}
#[test]
#[cfg_attr(miri, ignore)]
fn matrix_market_read_fail_not_enough_entries() {
let path = "data/matrix_market/bad_files/not_enough_entries.mm";
let res = read_matrix_market::<f64, i32, _>(path);
assert_eq!(res.unwrap_err(), IoError::BadMatrixMarketFile);
}
#[test]
#[cfg_attr(miri, ignore)]
fn read_write_read_matrix_market() {
let path = "data/matrix_market/simple.mm";
let mat = read_matrix_market::<f64, usize, _>(path).unwrap();
let tmp_dir = tempdir().unwrap();
let save_path = tmp_dir.path().join("simple.mm");
write_matrix_market(&save_path, mat.view()).unwrap();
let mat2 = read_matrix_market::<f64, usize, _>(&save_path).unwrap();
assert_eq!(mat, mat2);
write_matrix_market(&save_path, &mat2).unwrap();
let mat3 = read_matrix_market::<f64, usize, _>(&save_path).unwrap();
assert_eq!(mat, mat3);
}
#[test]
#[cfg_attr(miri, ignore)]
fn read_write_read_matrix_market_via_csc() {
let path = "data/matrix_market/simple.mm";
let mat = read_matrix_market::<f64, usize, _>(path).unwrap();
let csc: CsMat<_> = mat.to_csc();
let tmp_dir = tempdir().unwrap();
let save_path = tmp_dir.path().join("simple_csc.mm");
write_matrix_market(&save_path, &csc).unwrap();
let mat2 = read_matrix_market::<f64, usize, _>(&save_path).unwrap();
assert_eq!(csc, mat2.to_csc());
}
#[test]
#[cfg_attr(miri, ignore)]
fn read_symmetric_matrix_market() {
let path = "data/matrix_market/symmetric.mm";
let mat = read_matrix_market::<f64, usize, _>(path).unwrap();
let csc = mat.to_csc();
let expected = CsMat::new_csc(
(5, 5),
vec![0, 1, 3, 4, 6, 8],
vec![0, 1, 3, 2, 1, 4, 3, 4],
vec![1., 10.5, 2.505e2, 1.5e-2, 2.505e2, 3.332e1, 3.332e1, 1.2e1],
);
assert_eq!(csc, expected);
let tmp_dir = tempdir().unwrap();
let save_path = tmp_dir.path().join("symmetric.mm");
write_matrix_market_sym(&save_path, &csc, SymmetryMode::Symmetric)
.unwrap();
let mat2 = read_matrix_market::<f64, usize, _>(&save_path).unwrap();
assert_eq!(csc, mat2.to_csc());
}
#[test]
#[cfg_attr(miri, ignore)]
fn tricky_symmetric_matrix_market() {
let mat = CsMat::new(
(5, 5),
vec![0, 2, 4, 6, 8, 10],
vec![1, 4, 0, 2, 1, 3, 2, 4, 0, 3],
vec![2, 1, 2, 3, 3, 5, 5, 4, 1, 4],
);
let tmp_dir = tempdir().unwrap();
let save_path = tmp_dir.path().join("symmetric.mm");
write_matrix_market_sym(&save_path, &mat, SymmetryMode::Symmetric)
.unwrap();
let mat2 = read_matrix_market::<i32, usize, _>(&save_path).unwrap();
assert_eq!(mat, mat2.to_csr());
}
#[test]
#[cfg_attr(miri, ignore)]
fn skew_symmetric_matrix_market() {
let mat = CsMat::new(
(5, 5),
vec![0, 2, 4, 6, 8, 10],
vec![1, 4, 0, 2, 1, 3, 2, 4, 0, 3],
vec![2, 1, 2, 3, 3, 5, 5, 4, 1, 4],
);
let tmp_dir = tempdir().unwrap();
let save_path = tmp_dir.path().join("skew_symmetric.mm");
write_matrix_market_sym(&save_path, &mat, SymmetryMode::SkewSymmetric)
.unwrap();
let mat2 = read_matrix_market::<i32, usize, _>(&save_path).unwrap();
assert_eq!(mat, mat2.to_csr());
}
#[test]
#[cfg_attr(miri, ignore)]
fn general_matrix_via_symmetric_save() {
let mat = CsMat::new(
(5, 5),
vec![0, 2, 4, 6, 8, 10],
vec![0, 3, 0, 2, 1, 3, 2, 4, 0, 3],
vec![2, -1, 2, 3, 3, 5, 5, 4, 1, 4],
);
let tmp_dir = tempdir().unwrap();
let save_path = tmp_dir.path().join("general.mm");
write_matrix_market_sym(&save_path, &mat, SymmetryMode::General)
.unwrap();
let mat2 = read_matrix_market::<i32, usize, _>(&save_path).unwrap();
assert_eq!(mat, mat2.to_csr());
}
}