use std::path::Path;
use faer::sparse::{SparseColMat, Triplet};
use crate::error::SparseError;
/// Loads a real, symmetric Matrix Market *coordinate* file into a sparse column matrix.
///
/// The file must start with a `%%MatrixMarket matrix coordinate real symmetric` header
/// (compared case-insensitively). Comment lines (starting with `%`) and blank lines are
/// skipped both before the size line and between data lines. Every stored off-diagonal
/// entry `(i, j, v)` is mirrored to `(j, i, v)` so the returned matrix holds both
/// triangles. Entries whose value is exactly `0.0` are counted as data lines but are not
/// stored.
///
/// If the number of data lines differs from the `nnz` declared on the size line, a
/// warning is printed to standard error and loading continues.
///
/// # Errors
///
/// * [`SparseError::IoError`] if the file cannot be read.
/// * [`SparseError::ParseError`] if the header is missing or unsupported, the size line
///   is missing or malformed, a data line is malformed, an index is zero or outside the
///   declared dimensions, the mirrored position of a stored entry is outside the
///   declared dimensions, or the sparse matrix cannot be constructed. Line numbers in
///   errors are 1-based.
pub fn load_mtx(path: &Path) -> Result<SparseColMat<usize, f64>, SparseError> {
    let path_str = path.display().to_string();
    let io_err = |e: std::io::Error| SparseError::IoError {
        source: e.to_string(),
        path: path_str.clone(),
    };
    let parse_err = |reason: String, line: Option<usize>| SparseError::ParseError {
        reason,
        path: path_str.clone(),
        line,
    };

    let content = std::fs::read_to_string(path).map_err(io_err)?;
    let mut lines = content.lines().enumerate();

    // --- Header -----------------------------------------------------------------------
    let (_, header) = lines
        .next()
        .ok_or_else(|| parse_err("empty file".to_string(), Some(1)))?;
    if !header
        .to_lowercase()
        .starts_with("%%matrixmarket matrix coordinate real symmetric")
    {
        return Err(parse_err(
            format!(
                "unsupported format: expected '%%MatrixMarket matrix coordinate real symmetric', got '{}'",
                header
            ),
            Some(1),
        ));
    }

    // --- Size line --------------------------------------------------------------------
    // Skip comments AND blank lines (after trimming), exactly like the data loop below;
    // otherwise a blank line after the header would be mistaken for the size line.
    // `by_ref` leaves the iterator positioned on the first data line afterwards.
    let (size_line_idx, size_line_str) = lines
        .by_ref()
        .map(|(idx, line)| (idx, line.trim()))
        .find(|(_, line)| !line.is_empty() && !line.starts_with('%'))
        .ok_or_else(|| parse_err("missing size line".to_string(), None))?;
    let size_line_no = size_line_idx + 1;

    let size_parts: Vec<&str> = size_line_str.split_whitespace().collect();
    let (nrows_txt, ncols_txt, nnz_txt) = match size_parts.as_slice() {
        [r, c, n] => (*r, *c, *n),
        other => {
            return Err(parse_err(
                format!(
                    "size line must have 3 values (nrows ncols nnz), got {}",
                    other.len()
                ),
                Some(size_line_no),
            ))
        }
    };
    let parse_size_field = |what: &str, text: &str| -> Result<usize, SparseError> {
        text.parse::<usize>().map_err(|_| {
            parse_err(format!("invalid {}: '{}'", what, text), Some(size_line_no))
        })
    };
    let nrows = parse_size_field("nrows", nrows_txt)?;
    let ncols = parse_size_field("ncols", ncols_txt)?;
    let declared_nnz = parse_size_field("nnz", nnz_txt)?;

    // --- Data lines -------------------------------------------------------------------
    let mut triplets: Vec<Triplet<usize, usize, f64>> = Vec::new();
    let mut data_lines = 0usize;
    for (line_idx, raw_line) in lines {
        let line_no = line_idx + 1;
        let line = raw_line.trim();
        if line.is_empty() || line.starts_with('%') {
            continue;
        }

        let parts: Vec<&str> = line.split_whitespace().collect();
        let (row_txt, col_txt, val_txt) = match parts.as_slice() {
            [r, c, v] => (*r, *c, *v),
            other => {
                return Err(parse_err(
                    format!(
                        "data line must have 3 values (row col val), got {}",
                        other.len()
                    ),
                    Some(line_no),
                ))
            }
        };

        // `row` / `col` stay 1-based (as written in the file) until validation is done,
        // so that error messages quote the same numbers the user sees in the file.
        let row: usize = row_txt.parse().map_err(|_| {
            parse_err(format!("invalid row index: '{}'", row_txt), Some(line_no))
        })?;
        let col: usize = col_txt.parse().map_err(|_| {
            parse_err(format!("invalid col index: '{}'", col_txt), Some(line_no))
        })?;
        let val: f64 = val_txt
            .parse()
            .map_err(|_| parse_err(format!("invalid value: '{}'", val_txt), Some(line_no)))?;

        if row == 0 || col == 0 {
            return Err(parse_err(
                "Matrix Market indices are 1-based; got 0".to_string(),
                Some(line_no),
            ));
        }
        if row > nrows || col > ncols {
            return Err(parse_err(
                format!(
                    "index ({}, {}) out of bounds for {}x{} matrix",
                    row, col, nrows, ncols
                ),
                Some(line_no),
            ));
        }

        data_lines += 1;
        // Explicit zeros count towards the declared entry total but are not stored.
        if val == 0.0 {
            continue;
        }

        // With a non-square declared shape the transposed position may not exist.
        // Report that here, with a line number, rather than leaving it to matrix
        // construction where the offending line is no longer known.
        if row != col && (col > nrows || row > ncols) {
            return Err(parse_err(
                format!(
                    "mirrored entry ({}, {}) out of bounds for {}x{} matrix",
                    col, row, nrows, ncols
                ),
                Some(line_no),
            ));
        }

        // Convert to the zero-based indices used by the triplet list.
        let (row, col) = (row - 1, col - 1);
        triplets.push(Triplet { row, col, val });
        if row != col {
            // Symmetric storage lists one triangle only; add the transposed entry.
            triplets.push(Triplet {
                row: col,
                col: row,
                val,
            });
        }
    }

    // A count mismatch is reported but is not fatal.
    if data_lines < declared_nnz {
        eprintln!(
            "WARNING: {}: file declares {} entries but only {} data lines found (truncated file?)",
            path_str, declared_nnz, data_lines
        );
    } else if data_lines > declared_nnz {
        eprintln!(
            "WARNING: {}: file declares {} entries but {} data lines found",
            path_str, declared_nnz, data_lines
        );
    }

    SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
        .map_err(|e| parse_err(format!("failed to construct sparse matrix: {:?}", e), None))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Root of the checked-in test matrices, relative to the crate manifest directory.
    fn test_data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("test-data");
        dir
    }

    /// Writes `contents` to `target/test-tmp/<file_name>` under the crate manifest
    /// directory and returns the path of the written file.
    fn write_scratch_file(file_name: &str, contents: &str) -> PathBuf {
        let scratch_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("target/test-tmp");
        // A failure to create the directory is ignored here; the write below panics
        // if the location is not usable.
        std::fs::create_dir_all(&scratch_dir).ok();
        let file_path = scratch_dir.join(file_name);
        std::fs::write(&file_path, contents).unwrap();
        file_path
    }

    /// The hand-constructed 5x5 arrow matrix loads with the expected shape and entry count.
    #[test]
    fn load_arrow_5_pd() {
        let fixture = test_data_dir().join("hand-constructed/arrow-5-pd.mtx");
        let loaded = load_mtx(&fixture).expect("failed to load arrow-5-pd.mtx");
        assert_eq!((loaded.nrows(), loaded.ncols()), (5, 5));
        assert_eq!(loaded.compute_nnz(), 13);
    }

    /// A file without a Matrix Market header is rejected as a parse error.
    #[test]
    fn malformed_input_returns_parse_error() {
        let file_path = write_scratch_file("malformed.mtx", "this is not a matrix market file\n");
        assert!(matches!(
            load_mtx(&file_path),
            Err(SparseError::ParseError { .. })
        ));
    }

    /// A valid Matrix Market header of an unsupported kind is rejected as a parse error.
    #[test]
    fn unsupported_format_returns_error() {
        let file_path = write_scratch_file(
            "unsupported.mtx",
            "%%MatrixMarket matrix array real general\n2 2\n1.0\n2.0\n3.0\n4.0\n",
        );
        assert!(matches!(
            load_mtx(&file_path),
            Err(SparseError::ParseError { .. })
        ));
    }

    /// A path that does not exist yields an I/O error rather than a parse error.
    #[test]
    fn nonexistent_file_returns_io_error() {
        let missing = Path::new("/nonexistent/path/matrix.mtx");
        assert!(matches!(
            load_mtx(missing),
            Err(SparseError::IoError { .. })
        ));
    }
}