use std::io::{BufRead, BufReader, Write};
use std::path::Path;

use num_traits::ToPrimitive;
use oxiblas_core::scalar::{Field, Real, Scalar};

use crate::coo::CooMatrix;
use crate::csr::CsrMatrix;
45
/// Errors produced while reading or writing Matrix Market data.
///
/// I/O failures are stored as their message text (not the original
/// `std::io::Error`) so the type can stay `Clone + Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MtxError {
    /// The `%%MatrixMarket` banner line is malformed.
    InvalidHeader(String),
    /// A size or data line is structurally invalid.
    InvalidData(String),
    /// The banner names an object/format/field/symmetry this module cannot handle.
    UnsupportedType(String),
    /// An underlying read/write operation failed (message of the original error).
    IoError(String),
    /// A numeric token could not be parsed.
    ParseError(String),
    /// The input ended before a size line was found.
    MissingSizeLine,
    /// An entry index lies outside the declared matrix shape (indices are 1-based).
    IndexOutOfBounds {
        /// Offending 1-based row index.
        row: usize,
        /// Offending 1-based column index.
        col: usize,
        /// Declared number of rows.
        nrows: usize,
        /// Declared number of columns.
        ncols: usize,
    },
}

impl core::fmt::Display for MtxError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Most variants share the "<prefix>: <detail>" shape; the two that
        // don't return directly from their arm.
        let (prefix, detail) = match self {
            Self::InvalidHeader(msg) => ("Invalid Matrix Market header", msg),
            Self::InvalidData(msg) => ("Invalid data", msg),
            Self::UnsupportedType(msg) => ("Unsupported matrix type", msg),
            Self::IoError(msg) => ("I/O error", msg),
            Self::ParseError(msg) => ("Parse error", msg),
            Self::MissingSizeLine => return f.write_str("Missing size line"),
            Self::IndexOutOfBounds {
                row,
                col,
                nrows,
                ncols,
            } => {
                return write!(
                    f,
                    "Index ({row}, {col}) out of bounds for {nrows}×{ncols} matrix"
                )
            }
        };
        write!(f, "{prefix}: {detail}")
    }
}

impl std::error::Error for MtxError {}
99
/// Object kind named in the Matrix Market banner (second token).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MtxObject {
    /// Banner keyword `matrix`.
    Matrix,
    /// Banner keyword `vector`.
    Vector,
}
108
/// Storage layout named in the Matrix Market banner (third token).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MtxFormat {
    /// Banner keyword `coordinate`: sparse `row col [value]` triplets.
    Coordinate,
    /// Banner keyword `array`: dense values.
    Array,
}
117
/// Element type named in the Matrix Market banner (fourth token).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MtxField {
    /// Banner keyword `real` (the header parser also accepts `double`).
    Real,
    /// Banner keyword `complex`.
    Complex,
    /// Banner keyword `pattern`: entries carry no value, only a position.
    Pattern,
    /// Banner keyword `integer`.
    Integer,
}
130
/// Symmetry qualifier named in the Matrix Market banner (fifth token).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MtxSymmetry {
    /// Banner keyword `general`: every entry is stored explicitly.
    General,
    /// Banner keyword `symmetric`.
    Symmetric,
    /// Banner keyword `skew-symmetric`.
    SkewSymmetric,
    /// Banner keyword `hermitian`.
    Hermitian,
}
143
144#[derive(Debug, Clone)]
146pub struct MtxHeader {
147 pub object: MtxObject,
149 pub format: MtxFormat,
151 pub field: MtxField,
153 pub symmetry: MtxSymmetry,
155 pub nrows: usize,
157 pub ncols: usize,
159 pub nnz: usize,
161 pub comments: Vec<String>,
163}
164
165fn parse_header_line(
167 line: &str,
168) -> Result<(MtxObject, MtxFormat, MtxField, MtxSymmetry), MtxError> {
169 let line = line.to_lowercase();
170
171 if !line.starts_with("%%matrixmarket") {
172 return Err(MtxError::InvalidHeader(
173 "Header must start with %%MatrixMarket".to_string(),
174 ));
175 }
176
177 let parts: Vec<&str> = line.split_whitespace().collect();
178 if parts.len() < 5 {
179 return Err(MtxError::InvalidHeader(
180 "Header must have 5 parts".to_string(),
181 ));
182 }
183
184 let object = match parts[1] {
185 "matrix" => MtxObject::Matrix,
186 "vector" => MtxObject::Vector,
187 other => {
188 return Err(MtxError::UnsupportedType(format!(
189 "Unknown object type: {other}"
190 )));
191 }
192 };
193
194 let format = match parts[2] {
195 "coordinate" => MtxFormat::Coordinate,
196 "array" => MtxFormat::Array,
197 other => {
198 return Err(MtxError::UnsupportedType(format!(
199 "Unknown format: {other}"
200 )));
201 }
202 };
203
204 let field = match parts[3] {
205 "real" => MtxField::Real,
206 "double" => MtxField::Real,
207 "complex" => MtxField::Complex,
208 "pattern" => MtxField::Pattern,
209 "integer" => MtxField::Integer,
210 other => {
211 return Err(MtxError::UnsupportedType(format!(
212 "Unknown field type: {other}"
213 )));
214 }
215 };
216
217 let symmetry = match parts[4] {
218 "general" => MtxSymmetry::General,
219 "symmetric" => MtxSymmetry::Symmetric,
220 "skew-symmetric" => MtxSymmetry::SkewSymmetric,
221 "hermitian" => MtxSymmetry::Hermitian,
222 other => {
223 return Err(MtxError::UnsupportedType(format!(
224 "Unknown symmetry: {other}"
225 )));
226 }
227 };
228
229 Ok((object, format, field, symmetry))
230}
231
232pub fn read_header<R: BufRead>(reader: &mut R) -> Result<MtxHeader, MtxError> {
236 let mut line = String::new();
237
238 reader
240 .read_line(&mut line)
241 .map_err(|e| MtxError::IoError(e.to_string()))?;
242
243 let (object, format, field, symmetry) = parse_header_line(line.trim())?;
244
245 let mut comments = Vec::new();
247 loop {
248 line.clear();
249 reader
250 .read_line(&mut line)
251 .map_err(|e| MtxError::IoError(e.to_string()))?;
252
253 if line.is_empty() {
254 return Err(MtxError::MissingSizeLine);
255 }
256
257 let trimmed = line.trim();
258 if trimmed.starts_with('%') {
259 comments.push(trimmed[1..].trim().to_string());
260 } else {
261 break;
263 }
264 }
265
266 let size_parts: Vec<&str> = line.split_whitespace().collect();
268
269 let (nrows, ncols, nnz) = match format {
270 MtxFormat::Coordinate => {
271 if size_parts.len() < 3 {
272 return Err(MtxError::InvalidData(
273 "Coordinate size line must have 3 values".to_string(),
274 ));
275 }
276 let nrows = size_parts[0]
277 .parse::<usize>()
278 .map_err(|_| MtxError::ParseError("Invalid nrows".to_string()))?;
279 let ncols = size_parts[1]
280 .parse::<usize>()
281 .map_err(|_| MtxError::ParseError("Invalid ncols".to_string()))?;
282 let nnz = size_parts[2]
283 .parse::<usize>()
284 .map_err(|_| MtxError::ParseError("Invalid nnz".to_string()))?;
285 (nrows, ncols, nnz)
286 }
287 MtxFormat::Array => {
288 if size_parts.len() < 2 {
289 return Err(MtxError::InvalidData(
290 "Array size line must have 2 values".to_string(),
291 ));
292 }
293 let nrows = size_parts[0]
294 .parse::<usize>()
295 .map_err(|_| MtxError::ParseError("Invalid nrows".to_string()))?;
296 let ncols = size_parts[1]
297 .parse::<usize>()
298 .map_err(|_| MtxError::ParseError("Invalid ncols".to_string()))?;
299 (nrows, ncols, nrows * ncols)
300 }
301 };
302
303 Ok(MtxHeader {
304 object,
305 format,
306 field,
307 symmetry,
308 nrows,
309 ncols,
310 nnz,
311 comments,
312 })
313}
314
315pub fn read_matrix_market<T: Scalar<Real = T> + Clone + Field + Real, P: AsRef<Path>>(
328 path: P,
329) -> Result<CsrMatrix<T>, MtxError> {
330 let file = std::fs::File::open(path).map_err(|e| MtxError::IoError(e.to_string()))?;
331
332 let mut reader = BufReader::new(file);
333 read_matrix_market_from_reader(&mut reader)
334}
335
336pub fn read_matrix_market_from_reader<T: Scalar<Real = T> + Clone + Field + Real, R: BufRead>(
338 reader: &mut R,
339) -> Result<CsrMatrix<T>, MtxError> {
340 let header = read_header(reader)?;
341
342 if header.format != MtxFormat::Coordinate {
343 return Err(MtxError::UnsupportedType(
344 "Only coordinate format is supported".to_string(),
345 ));
346 }
347
348 if header.field == MtxField::Complex {
349 return Err(MtxError::UnsupportedType(
350 "Complex matrices not supported for real type".to_string(),
351 ));
352 }
353
354 let mut rows = Vec::with_capacity(header.nnz);
356 let mut cols = Vec::with_capacity(header.nnz);
357 let mut vals = Vec::with_capacity(header.nnz);
358
359 for line_result in reader.lines() {
360 let line = line_result.map_err(|e| MtxError::IoError(e.to_string()))?;
361 let trimmed = line.trim();
362
363 if trimmed.is_empty() || trimmed.starts_with('%') {
364 continue;
365 }
366
367 let parts: Vec<&str> = trimmed.split_whitespace().collect();
368
369 if parts.len() < 2 {
370 return Err(MtxError::InvalidData(format!(
371 "Invalid data line: {trimmed}"
372 )));
373 }
374
375 let row: usize = parts[0]
376 .parse()
377 .map_err(|_| MtxError::ParseError(format!("Invalid row: {}", parts[0])))?;
378 let col: usize = parts[1]
379 .parse()
380 .map_err(|_| MtxError::ParseError(format!("Invalid col: {}", parts[1])))?;
381
382 if row == 0 || col == 0 {
384 return Err(MtxError::IndexOutOfBounds {
385 row,
386 col,
387 nrows: header.nrows,
388 ncols: header.ncols,
389 });
390 }
391 let row = row - 1;
392 let col = col - 1;
393
394 if row >= header.nrows || col >= header.ncols {
395 return Err(MtxError::IndexOutOfBounds {
396 row: row + 1,
397 col: col + 1,
398 nrows: header.nrows,
399 ncols: header.ncols,
400 });
401 }
402
403 let val = if header.field == MtxField::Pattern {
404 T::one()
405 } else {
406 if parts.len() < 3 {
407 return Err(MtxError::InvalidData(format!(
408 "Missing value on line: {trimmed}"
409 )));
410 }
411 parts[2]
412 .parse::<f64>()
413 .map_err(|_| MtxError::ParseError(format!("Invalid value: {}", parts[2])))
414 .and_then(|v| {
415 T::from_f64(v)
416 .ok_or_else(|| MtxError::ParseError(format!("Cannot convert value: {v}")))
417 })?
418 };
419
420 rows.push(row);
421 cols.push(col);
422 vals.push(val.clone());
423
424 if row != col {
426 match header.symmetry {
427 MtxSymmetry::Symmetric => {
428 rows.push(col);
429 cols.push(row);
430 vals.push(val);
431 }
432 MtxSymmetry::SkewSymmetric => {
433 rows.push(col);
434 cols.push(row);
435 vals.push(T::zero() - val);
436 }
437 MtxSymmetry::Hermitian => {
438 rows.push(col);
440 cols.push(row);
441 vals.push(val);
442 }
443 MtxSymmetry::General => {}
444 }
445 }
446 }
447
448 let coo = CooMatrix::new(header.nrows, header.ncols, rows, cols, vals)
450 .map_err(|e| MtxError::InvalidData(format!("Failed to create COO matrix: {e:?}")))?;
451
452 Ok(crate::convert::coo_to_csr(&coo))
453}
454
455pub fn read_matrix_market_coo<T: Scalar<Real = T> + Clone + Field + Real, P: AsRef<Path>>(
457 path: P,
458) -> Result<CooMatrix<T>, MtxError> {
459 let csr: CsrMatrix<T> = read_matrix_market(path)?;
460 Ok(crate::convert::csr_to_coo(&csr))
461}
462
463pub fn write_matrix_market<T: Scalar + Clone + Field + ToPrimitive, P: AsRef<Path>>(
475 csr: &CsrMatrix<T>,
476 path: P,
477 comment: Option<&str>,
478) -> Result<(), MtxError> {
479 let file = std::fs::File::create(path).map_err(|e| MtxError::IoError(e.to_string()))?;
480
481 let mut writer = std::io::BufWriter::new(file);
482 write_matrix_market_to_writer(csr, &mut writer, comment)
483}
484
485pub fn write_matrix_market_to_writer<T: Scalar + Clone + Field + ToPrimitive, W: Write>(
487 csr: &CsrMatrix<T>,
488 writer: &mut W,
489 comment: Option<&str>,
490) -> Result<(), MtxError> {
491 let eps = <T as Scalar>::epsilon();
492
493 let mut nnz = 0;
495 for (_, _, val) in csr.iter() {
496 if Scalar::abs(val.clone()) > eps {
497 nnz += 1;
498 }
499 }
500
501 writeln!(writer, "%%MatrixMarket matrix coordinate real general")
503 .map_err(|e| MtxError::IoError(e.to_string()))?;
504
505 if let Some(c) = comment {
507 for line in c.lines() {
508 writeln!(writer, "% {line}").map_err(|e| MtxError::IoError(e.to_string()))?;
509 }
510 }
511
512 writeln!(writer, "{} {} {}", csr.nrows(), csr.ncols(), nnz)
514 .map_err(|e| MtxError::IoError(e.to_string()))?;
515
516 for (row, col, val) in csr.iter() {
518 if Scalar::abs(val.clone()) > eps {
519 let f = val.to_f64().unwrap_or(0.0);
520 writeln!(writer, "{} {} {}", row + 1, col + 1, f)
521 .map_err(|e| MtxError::IoError(e.to_string()))?;
522 }
523 }
524
525 Ok(())
526}
527
528pub fn write_matrix_market_symmetric<T: Scalar + Clone + Field + ToPrimitive, P: AsRef<Path>>(
532 csr: &CsrMatrix<T>,
533 path: P,
534 comment: Option<&str>,
535) -> Result<(), MtxError> {
536 let file = std::fs::File::create(path).map_err(|e| MtxError::IoError(e.to_string()))?;
537
538 let mut writer = std::io::BufWriter::new(file);
539 let eps = <T as Scalar>::epsilon();
540
541 let mut nnz = 0;
543 for (row, col, val) in csr.iter() {
544 if row >= col && Scalar::abs(val.clone()) > eps {
545 nnz += 1;
546 }
547 }
548
549 writeln!(writer, "%%MatrixMarket matrix coordinate real symmetric")
551 .map_err(|e| MtxError::IoError(e.to_string()))?;
552
553 if let Some(c) = comment {
554 for line in c.lines() {
555 writeln!(writer, "% {line}").map_err(|e| MtxError::IoError(e.to_string()))?;
556 }
557 }
558
559 writeln!(writer, "{} {} {}", csr.nrows(), csr.ncols(), nnz)
560 .map_err(|e| MtxError::IoError(e.to_string()))?;
561
562 for (row, col, val) in csr.iter() {
564 if row >= col && Scalar::abs(val.clone()) > eps {
565 let f = val.to_f64().unwrap_or(0.0);
566 writeln!(writer, "{} {} {}", row + 1, col + 1, f)
567 .map_err(|e| MtxError::IoError(e.to_string()))?;
568 }
569 }
570
571 Ok(())
572}
573
574pub fn read_matrix_market_str<T: Scalar<Real = T> + Clone + Field + Real>(
576 s: &str,
577) -> Result<CsrMatrix<T>, MtxError> {
578 let mut reader = BufReader::new(s.as_bytes());
579 read_matrix_market_from_reader(&mut reader)
580}
581
582pub fn write_matrix_market_str<T: Scalar + Clone + Field + ToPrimitive>(
584 csr: &CsrMatrix<T>,
585 comment: Option<&str>,
586) -> Result<String, MtxError> {
587 let mut buf = Vec::new();
588 write_matrix_market_to_writer(csr, &mut buf, comment)?;
589 String::from_utf8(buf).map_err(|e| MtxError::IoError(e.to_string()))
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_header() {
        let (obj, fmt, field, sym) =
            parse_header_line("%%MatrixMarket matrix coordinate real general").unwrap();

        assert_eq!(obj, MtxObject::Matrix);
        assert_eq!(fmt, MtxFormat::Coordinate);
        assert_eq!(field, MtxField::Real);
        assert_eq!(sym, MtxSymmetry::General);
    }

    #[test]
    fn test_parse_header_symmetric() {
        let (_, _, _, sym) =
            parse_header_line("%%MatrixMarket matrix coordinate real symmetric").unwrap();

        assert_eq!(sym, MtxSymmetry::Symmetric);
    }

    #[test]
    fn test_parse_header_is_case_insensitive() {
        let (obj, fmt, field, sym) =
            parse_header_line("%%matrixmarket MATRIX Array Integer Skew-Symmetric").unwrap();

        assert_eq!(obj, MtxObject::Matrix);
        assert_eq!(fmt, MtxFormat::Array);
        assert_eq!(field, MtxField::Integer);
        assert_eq!(sym, MtxSymmetry::SkewSymmetric);
    }

    #[test]
    fn test_parse_header_unknown_symmetry() {
        let result = parse_header_line("%%MatrixMarket matrix coordinate real banana");
        assert!(matches!(result, Err(MtxError::UnsupportedType(_))));
    }

    #[test]
    fn test_read_header_collects_comments() {
        let text = "%%MatrixMarket matrix coordinate real general\n% hello\n%world \n2 3 4\n";
        let header = read_header(&mut text.as_bytes()).unwrap();

        assert_eq!(
            header.comments,
            vec!["hello".to_string(), "world".to_string()]
        );
        assert_eq!((header.nrows, header.ncols, header.nnz), (2, 3, 4));
        assert_eq!(header.symmetry, MtxSymmetry::General);
    }

    #[test]
    fn test_read_header_missing_size_line() {
        let text = "%%MatrixMarket matrix coordinate real general\n% only a comment\n";
        assert_eq!(
            read_header(&mut text.as_bytes()).unwrap_err(),
            MtxError::MissingSizeLine
        );
    }

    #[test]
    fn test_read_simple_matrix() {
        let mtx = r#"%%MatrixMarket matrix coordinate real general
% A simple test matrix
3 3 5
1 1 1.0
1 3 2.0
2 2 3.0
3 1 4.0
3 3 5.0
"#;

        let csr: CsrMatrix<f64> = read_matrix_market_str(mtx).unwrap();

        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);
        assert_eq!(csr.nnz(), 5);

        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(0, 2), Some(&2.0));
        assert_eq!(csr.get(1, 1), Some(&3.0));
        assert_eq!(csr.get(2, 0), Some(&4.0));
        assert_eq!(csr.get(2, 2), Some(&5.0));
    }

    #[test]
    fn test_read_symmetric_matrix() {
        let mtx = r#"%%MatrixMarket matrix coordinate real symmetric
3 3 4
1 1 1.0
2 1 2.0
2 2 3.0
3 3 4.0
"#;

        let csr: CsrMatrix<f64> = read_matrix_market_str(mtx).unwrap();

        assert_eq!(csr.nrows(), 3);
        assert_eq!(csr.ncols(), 3);

        assert_eq!(csr.get(0, 0), Some(&1.0));
        // The stored lower entry and its mirrored upper twin.
        assert_eq!(csr.get(1, 0), Some(&2.0));
        assert_eq!(csr.get(0, 1), Some(&2.0));
        assert_eq!(csr.get(1, 1), Some(&3.0));
        assert_eq!(csr.get(2, 2), Some(&4.0));
    }

    #[test]
    fn test_read_skew_symmetric_matrix() {
        let mtx = r#"%%MatrixMarket matrix coordinate real skew-symmetric
3 3 2
2 1 1.5
3 2 -2.0
"#;

        let csr: CsrMatrix<f64> = read_matrix_market_str(mtx).unwrap();

        assert_eq!(csr.get(1, 0), Some(&1.5));
        assert_eq!(csr.get(0, 1), Some(&-1.5));
        assert_eq!(csr.get(2, 1), Some(&-2.0));
        assert_eq!(csr.get(1, 2), Some(&2.0));
    }

    #[test]
    fn test_read_pattern_matrix() {
        let mtx = r#"%%MatrixMarket matrix coordinate pattern general
2 2 2
1 1
2 2
"#;

        let csr: CsrMatrix<f64> = read_matrix_market_str(mtx).unwrap();

        assert_eq!(csr.get(0, 0), Some(&1.0));
        assert_eq!(csr.get(1, 1), Some(&1.0));
    }

    #[test]
    fn test_write_read_roundtrip() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let col_indices = vec![0, 2, 1, 0, 2];
        let row_ptrs = vec![0, 2, 3, 5];

        let csr = CsrMatrix::new(3, 3, row_ptrs, col_indices, values).unwrap();

        let mtx_str = write_matrix_market_str(&csr, Some("Test matrix")).unwrap();

        let csr2: CsrMatrix<f64> = read_matrix_market_str(&mtx_str).unwrap();

        assert_eq!(csr.nrows(), csr2.nrows());
        assert_eq!(csr.ncols(), csr2.ncols());
        assert_eq!(csr.nnz(), csr2.nnz());

        for row in 0..3 {
            for col in 0..3 {
                let v1 = csr.get(row, col).cloned().unwrap_or(0.0);
                let v2 = csr2.get(row, col).cloned().unwrap_or(0.0);
                assert!((v1 - v2).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_write_includes_header_and_comments() {
        let csr = CsrMatrix::new(2, 2, vec![0, 1, 2], vec![0, 1], vec![1.0f64, 2.0]).unwrap();

        let text = write_matrix_market_str(&csr, Some("first\nsecond")).unwrap();

        assert!(text.starts_with(
            "%%MatrixMarket matrix coordinate real general\n% first\n% second\n2 2 2\n"
        ));
    }

    #[test]
    fn test_header_parsing_error() {
        let result = parse_header_line("invalid header");
        assert!(result.is_err());
    }

    #[test]
    fn test_index_error() {
        let mtx = r#"%%MatrixMarket matrix coordinate real general
2 2 1
3 1 1.0
"#;

        let result: Result<CsrMatrix<f64>, _> = read_matrix_market_str(mtx);
        assert!(result.is_err());
    }

    #[test]
    fn test_zero_index_is_rejected() {
        let mtx = r#"%%MatrixMarket matrix coordinate real general
2 2 1
0 1 1.0
"#;

        let result: Result<CsrMatrix<f64>, _> = read_matrix_market_str(mtx);
        assert!(matches!(
            result,
            Err(MtxError::IndexOutOfBounds { row: 0, col: 1, .. })
        ));
    }

    #[test]
    fn test_error_display() {
        let err = MtxError::IndexOutOfBounds {
            row: 3,
            col: 1,
            nrows: 2,
            ncols: 2,
        };
        assert_eq!(err.to_string(), "Index (3, 1) out of bounds for 2×2 matrix");
        assert_eq!(MtxError::MissingSizeLine.to_string(), "Missing size line");
    }
}