oxiblas_sparse/
mtx.rs

//! Matrix Market format support.
//!
//! This module provides reading and writing of sparse matrices in
//! Matrix Market (MM) format, the standard exchange format used by
//! the SuiteSparse Matrix Collection (formerly the University of Florida
//! Sparse Matrix Collection).
//!
//! # Format Overview
//!
//! Matrix Market files consist of:
//! 1. A header line: `%%MatrixMarket matrix coordinate real general`
//! 2. Optional comment lines starting with `%`
//! 3. A size line: `nrows ncols nnz`
//! 4. Data lines: `row col value` (1-indexed; `row col` only for `pattern` matrices)
//!
//! # Supported Types
//!
//! - **Coordinate format** (sparse): real, complex, pattern, integer
//! - **Array format** (dense): not yet implemented
//!
//! # Symmetry
//!
//! - **general**: No symmetry assumed
//! - **symmetric**: Only lower triangle stored, A = A^T
//! - **skew-symmetric**: Only lower triangle stored, A = -A^T
//! - **hermitian**: Only lower triangle stored, A = A^H (complex only)
//!
//! # Example
//!
//! ```ignore
//! use oxiblas_sparse::mtx::{read_matrix_market, write_matrix_market};
//!
//! // Read a matrix from file
//! let csr = read_matrix_market::<f64, _>("matrix.mtx")?;
//!
//! // Write a matrix to file
//! write_matrix_market(&csr, "output.mtx", Some("My matrix"))?;
//! ```
38
39use crate::coo::CooMatrix;
40use crate::csr::CsrMatrix;
41use num_traits::ToPrimitive;
42use oxiblas_core::scalar::{Field, Real, Scalar};
43use std::io::{BufRead, BufReader, Write};
44use std::path::Path;
45
/// Errors produced while reading or writing Matrix Market data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MtxError {
    /// The `%%MatrixMarket` banner line is malformed.
    InvalidHeader(String),
    /// A size line or data line is malformed or inconsistent.
    InvalidData(String),
    /// The banner names an object, format, field, or symmetry that is not handled.
    UnsupportedType(String),
    /// An underlying I/O operation failed; carries the rendered I/O error.
    IoError(String),
    /// A numeric token could not be parsed.
    ParseError(String),
    /// The input ended before a size line was found.
    MissingSizeLine,
    /// A data line referenced a position outside the declared dimensions.
    ///
    /// `row` and `col` are reported as they appear in the file (1-based).
    IndexOutOfBounds {
        /// Row index as written in the file.
        row: usize,
        /// Column index as written in the file.
        col: usize,
        /// Declared number of rows.
        nrows: usize,
        /// Declared number of columns.
        ncols: usize,
    },
}

impl core::fmt::Display for MtxError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::MissingSizeLine => f.write_str("Missing size line"),
            Self::IndexOutOfBounds {
                row,
                col,
                nrows,
                ncols,
            } => write!(
                f,
                "Index ({row}, {col}) out of bounds for {nrows}×{ncols} matrix"
            ),
            Self::InvalidHeader(s) => write!(f, "Invalid Matrix Market header: {s}"),
            Self::InvalidData(s) => write!(f, "Invalid data: {s}"),
            Self::UnsupportedType(s) => write!(f, "Unsupported matrix type: {s}"),
            Self::IoError(s) => write!(f, "I/O error: {s}"),
            Self::ParseError(s) => write!(f, "Parse error: {s}"),
        }
    }
}

impl std::error::Error for MtxError {}
99
/// Kind of object named in a Matrix Market banner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MtxObject {
    /// A two-dimensional matrix (`matrix` in the banner).
    Matrix,
    /// A vector (`vector` in the banner).
    Vector,
}
108
/// Storage layout named in a Matrix Market banner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MtxFormat {
    /// Sparse triplet layout (`coordinate` in the banner).
    Coordinate,
    /// Dense column listing (`array` in the banner).
    Array,
}
117
/// Scalar kind of the stored entries, as named in a Matrix Market banner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MtxField {
    /// Real-valued entries (`real`; `double` is accepted as a synonym).
    Real,
    /// Complex-valued entries (`complex`).
    Complex,
    /// Structure only: data lines carry indices but no value (`pattern`).
    Pattern,
    /// Integer-valued entries (`integer`).
    Integer,
}
130
/// Symmetry declared in a Matrix Market banner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MtxSymmetry {
    /// No structure assumed; every entry is listed (`general`).
    General,
    /// A equals its transpose (`symmetric`).
    Symmetric,
    /// A equals the negation of its transpose (`skew-symmetric`).
    SkewSymmetric,
    /// A equals its conjugate transpose (`hermitian`).
    Hermitian,
}
143
144/// Matrix Market header information.
145#[derive(Debug, Clone)]
146pub struct MtxHeader {
147    /// Object type (matrix or vector).
148    pub object: MtxObject,
149    /// Format type (coordinate or array).
150    pub format: MtxFormat,
151    /// Field type (real, complex, pattern, integer).
152    pub field: MtxField,
153    /// Symmetry type.
154    pub symmetry: MtxSymmetry,
155    /// Number of rows.
156    pub nrows: usize,
157    /// Number of columns.
158    pub ncols: usize,
159    /// Number of stored entries (for coordinate format).
160    pub nnz: usize,
161    /// Comment lines (without leading %).
162    pub comments: Vec<String>,
163}
164
165/// Parses a Matrix Market header line.
166fn parse_header_line(
167    line: &str,
168) -> Result<(MtxObject, MtxFormat, MtxField, MtxSymmetry), MtxError> {
169    let line = line.to_lowercase();
170
171    if !line.starts_with("%%matrixmarket") {
172        return Err(MtxError::InvalidHeader(
173            "Header must start with %%MatrixMarket".to_string(),
174        ));
175    }
176
177    let parts: Vec<&str> = line.split_whitespace().collect();
178    if parts.len() < 5 {
179        return Err(MtxError::InvalidHeader(
180            "Header must have 5 parts".to_string(),
181        ));
182    }
183
184    let object = match parts[1] {
185        "matrix" => MtxObject::Matrix,
186        "vector" => MtxObject::Vector,
187        other => {
188            return Err(MtxError::UnsupportedType(format!(
189                "Unknown object type: {other}"
190            )));
191        }
192    };
193
194    let format = match parts[2] {
195        "coordinate" => MtxFormat::Coordinate,
196        "array" => MtxFormat::Array,
197        other => {
198            return Err(MtxError::UnsupportedType(format!(
199                "Unknown format: {other}"
200            )));
201        }
202    };
203
204    let field = match parts[3] {
205        "real" => MtxField::Real,
206        "double" => MtxField::Real,
207        "complex" => MtxField::Complex,
208        "pattern" => MtxField::Pattern,
209        "integer" => MtxField::Integer,
210        other => {
211            return Err(MtxError::UnsupportedType(format!(
212                "Unknown field type: {other}"
213            )));
214        }
215    };
216
217    let symmetry = match parts[4] {
218        "general" => MtxSymmetry::General,
219        "symmetric" => MtxSymmetry::Symmetric,
220        "skew-symmetric" => MtxSymmetry::SkewSymmetric,
221        "hermitian" => MtxSymmetry::Hermitian,
222        other => {
223            return Err(MtxError::UnsupportedType(format!(
224                "Unknown symmetry: {other}"
225            )));
226        }
227    };
228
229    Ok((object, format, field, symmetry))
230}
231
232/// Reads a Matrix Market file header.
233///
234/// Returns the header information and the reader positioned after comments.
235pub fn read_header<R: BufRead>(reader: &mut R) -> Result<MtxHeader, MtxError> {
236    let mut line = String::new();
237
238    // Read header line
239    reader
240        .read_line(&mut line)
241        .map_err(|e| MtxError::IoError(e.to_string()))?;
242
243    let (object, format, field, symmetry) = parse_header_line(line.trim())?;
244
245    // Read comments
246    let mut comments = Vec::new();
247    loop {
248        line.clear();
249        reader
250            .read_line(&mut line)
251            .map_err(|e| MtxError::IoError(e.to_string()))?;
252
253        if line.is_empty() {
254            return Err(MtxError::MissingSizeLine);
255        }
256
257        let trimmed = line.trim();
258        if trimmed.starts_with('%') {
259            comments.push(trimmed[1..].trim().to_string());
260        } else {
261            // This is the size line
262            break;
263        }
264    }
265
266    // Parse size line
267    let size_parts: Vec<&str> = line.split_whitespace().collect();
268
269    let (nrows, ncols, nnz) = match format {
270        MtxFormat::Coordinate => {
271            if size_parts.len() < 3 {
272                return Err(MtxError::InvalidData(
273                    "Coordinate size line must have 3 values".to_string(),
274                ));
275            }
276            let nrows = size_parts[0]
277                .parse::<usize>()
278                .map_err(|_| MtxError::ParseError("Invalid nrows".to_string()))?;
279            let ncols = size_parts[1]
280                .parse::<usize>()
281                .map_err(|_| MtxError::ParseError("Invalid ncols".to_string()))?;
282            let nnz = size_parts[2]
283                .parse::<usize>()
284                .map_err(|_| MtxError::ParseError("Invalid nnz".to_string()))?;
285            (nrows, ncols, nnz)
286        }
287        MtxFormat::Array => {
288            if size_parts.len() < 2 {
289                return Err(MtxError::InvalidData(
290                    "Array size line must have 2 values".to_string(),
291                ));
292            }
293            let nrows = size_parts[0]
294                .parse::<usize>()
295                .map_err(|_| MtxError::ParseError("Invalid nrows".to_string()))?;
296            let ncols = size_parts[1]
297                .parse::<usize>()
298                .map_err(|_| MtxError::ParseError("Invalid ncols".to_string()))?;
299            (nrows, ncols, nrows * ncols)
300        }
301    };
302
303    Ok(MtxHeader {
304        object,
305        format,
306        field,
307        symmetry,
308        nrows,
309        ncols,
310        nnz,
311        comments,
312    })
313}
314
315/// Reads a Matrix Market file and returns a CSR matrix.
316///
317/// # Arguments
318///
319/// * `path` - Path to the Matrix Market file
320///
321/// # Errors
322///
323/// Returns an error if:
324/// - The file cannot be read
325/// - The format is not supported
326/// - The data is invalid
327pub fn read_matrix_market<T: Scalar<Real = T> + Clone + Field + Real, P: AsRef<Path>>(
328    path: P,
329) -> Result<CsrMatrix<T>, MtxError> {
330    let file = std::fs::File::open(path).map_err(|e| MtxError::IoError(e.to_string()))?;
331
332    let mut reader = BufReader::new(file);
333    read_matrix_market_from_reader(&mut reader)
334}
335
336/// Reads a Matrix Market file from a reader.
337pub fn read_matrix_market_from_reader<T: Scalar<Real = T> + Clone + Field + Real, R: BufRead>(
338    reader: &mut R,
339) -> Result<CsrMatrix<T>, MtxError> {
340    let header = read_header(reader)?;
341
342    if header.format != MtxFormat::Coordinate {
343        return Err(MtxError::UnsupportedType(
344            "Only coordinate format is supported".to_string(),
345        ));
346    }
347
348    if header.field == MtxField::Complex {
349        return Err(MtxError::UnsupportedType(
350            "Complex matrices not supported for real type".to_string(),
351        ));
352    }
353
354    // Read data entries
355    let mut rows = Vec::with_capacity(header.nnz);
356    let mut cols = Vec::with_capacity(header.nnz);
357    let mut vals = Vec::with_capacity(header.nnz);
358
359    for line_result in reader.lines() {
360        let line = line_result.map_err(|e| MtxError::IoError(e.to_string()))?;
361        let trimmed = line.trim();
362
363        if trimmed.is_empty() || trimmed.starts_with('%') {
364            continue;
365        }
366
367        let parts: Vec<&str> = trimmed.split_whitespace().collect();
368
369        if parts.len() < 2 {
370            return Err(MtxError::InvalidData(format!(
371                "Invalid data line: {trimmed}"
372            )));
373        }
374
375        let row: usize = parts[0]
376            .parse()
377            .map_err(|_| MtxError::ParseError(format!("Invalid row: {}", parts[0])))?;
378        let col: usize = parts[1]
379            .parse()
380            .map_err(|_| MtxError::ParseError(format!("Invalid col: {}", parts[1])))?;
381
382        // Convert from 1-indexed to 0-indexed
383        if row == 0 || col == 0 {
384            return Err(MtxError::IndexOutOfBounds {
385                row,
386                col,
387                nrows: header.nrows,
388                ncols: header.ncols,
389            });
390        }
391        let row = row - 1;
392        let col = col - 1;
393
394        if row >= header.nrows || col >= header.ncols {
395            return Err(MtxError::IndexOutOfBounds {
396                row: row + 1,
397                col: col + 1,
398                nrows: header.nrows,
399                ncols: header.ncols,
400            });
401        }
402
403        let val = if header.field == MtxField::Pattern {
404            T::one()
405        } else {
406            if parts.len() < 3 {
407                return Err(MtxError::InvalidData(format!(
408                    "Missing value on line: {trimmed}"
409                )));
410            }
411            parts[2]
412                .parse::<f64>()
413                .map_err(|_| MtxError::ParseError(format!("Invalid value: {}", parts[2])))
414                .and_then(|v| {
415                    T::from_f64(v)
416                        .ok_or_else(|| MtxError::ParseError(format!("Cannot convert value: {v}")))
417                })?
418        };
419
420        rows.push(row);
421        cols.push(col);
422        vals.push(val.clone());
423
424        // Handle symmetry
425        if row != col {
426            match header.symmetry {
427                MtxSymmetry::Symmetric => {
428                    rows.push(col);
429                    cols.push(row);
430                    vals.push(val);
431                }
432                MtxSymmetry::SkewSymmetric => {
433                    rows.push(col);
434                    cols.push(row);
435                    vals.push(T::zero() - val);
436                }
437                MtxSymmetry::Hermitian => {
438                    // For real matrices, Hermitian = Symmetric
439                    rows.push(col);
440                    cols.push(row);
441                    vals.push(val);
442                }
443                MtxSymmetry::General => {}
444            }
445        }
446    }
447
448    // Build COO and convert to CSR
449    let coo = CooMatrix::new(header.nrows, header.ncols, rows, cols, vals)
450        .map_err(|e| MtxError::InvalidData(format!("Failed to create COO matrix: {e:?}")))?;
451
452    Ok(crate::convert::coo_to_csr(&coo))
453}
454
455/// Reads a Matrix Market file and returns a COO matrix.
456pub fn read_matrix_market_coo<T: Scalar<Real = T> + Clone + Field + Real, P: AsRef<Path>>(
457    path: P,
458) -> Result<CooMatrix<T>, MtxError> {
459    let csr: CsrMatrix<T> = read_matrix_market(path)?;
460    Ok(crate::convert::csr_to_coo(&csr))
461}
462
463/// Writes a CSR matrix to Matrix Market format.
464///
465/// # Arguments
466///
467/// * `csr` - The matrix to write
468/// * `path` - Path to write to
469/// * `comment` - Optional comment to include in the file
470///
471/// # Errors
472///
473/// Returns an error if the file cannot be written.
474pub fn write_matrix_market<T: Scalar + Clone + Field + ToPrimitive, P: AsRef<Path>>(
475    csr: &CsrMatrix<T>,
476    path: P,
477    comment: Option<&str>,
478) -> Result<(), MtxError> {
479    let file = std::fs::File::create(path).map_err(|e| MtxError::IoError(e.to_string()))?;
480
481    let mut writer = std::io::BufWriter::new(file);
482    write_matrix_market_to_writer(csr, &mut writer, comment)
483}
484
485/// Writes a CSR matrix to Matrix Market format using a writer.
486pub fn write_matrix_market_to_writer<T: Scalar + Clone + Field + ToPrimitive, W: Write>(
487    csr: &CsrMatrix<T>,
488    writer: &mut W,
489    comment: Option<&str>,
490) -> Result<(), MtxError> {
491    let eps = <T as Scalar>::epsilon();
492
493    // Count actual non-zeros
494    let mut nnz = 0;
495    for (_, _, val) in csr.iter() {
496        if Scalar::abs(val.clone()) > eps {
497            nnz += 1;
498        }
499    }
500
501    // Write header
502    writeln!(writer, "%%MatrixMarket matrix coordinate real general")
503        .map_err(|e| MtxError::IoError(e.to_string()))?;
504
505    // Write comment if provided
506    if let Some(c) = comment {
507        for line in c.lines() {
508            writeln!(writer, "% {line}").map_err(|e| MtxError::IoError(e.to_string()))?;
509        }
510    }
511
512    // Write size line
513    writeln!(writer, "{} {} {}", csr.nrows(), csr.ncols(), nnz)
514        .map_err(|e| MtxError::IoError(e.to_string()))?;
515
516    // Write data (1-indexed)
517    for (row, col, val) in csr.iter() {
518        if Scalar::abs(val.clone()) > eps {
519            let f = val.to_f64().unwrap_or(0.0);
520            writeln!(writer, "{} {} {}", row + 1, col + 1, f)
521                .map_err(|e| MtxError::IoError(e.to_string()))?;
522        }
523    }
524
525    Ok(())
526}
527
528/// Writes a symmetric CSR matrix to Matrix Market format.
529///
530/// Only writes the lower triangle, with symmetric flag.
531pub fn write_matrix_market_symmetric<T: Scalar + Clone + Field + ToPrimitive, P: AsRef<Path>>(
532    csr: &CsrMatrix<T>,
533    path: P,
534    comment: Option<&str>,
535) -> Result<(), MtxError> {
536    let file = std::fs::File::create(path).map_err(|e| MtxError::IoError(e.to_string()))?;
537
538    let mut writer = std::io::BufWriter::new(file);
539    let eps = <T as Scalar>::epsilon();
540
541    // Count lower triangle non-zeros
542    let mut nnz = 0;
543    for (row, col, val) in csr.iter() {
544        if row >= col && Scalar::abs(val.clone()) > eps {
545            nnz += 1;
546        }
547    }
548
549    // Write header
550    writeln!(writer, "%%MatrixMarket matrix coordinate real symmetric")
551        .map_err(|e| MtxError::IoError(e.to_string()))?;
552
553    if let Some(c) = comment {
554        for line in c.lines() {
555            writeln!(writer, "% {line}").map_err(|e| MtxError::IoError(e.to_string()))?;
556        }
557    }
558
559    writeln!(writer, "{} {} {}", csr.nrows(), csr.ncols(), nnz)
560        .map_err(|e| MtxError::IoError(e.to_string()))?;
561
562    // Write lower triangle only (1-indexed)
563    for (row, col, val) in csr.iter() {
564        if row >= col && Scalar::abs(val.clone()) > eps {
565            let f = val.to_f64().unwrap_or(0.0);
566            writeln!(writer, "{} {} {}", row + 1, col + 1, f)
567                .map_err(|e| MtxError::IoError(e.to_string()))?;
568        }
569    }
570
571    Ok(())
572}
573
574/// Reads Matrix Market format from a string.
575pub fn read_matrix_market_str<T: Scalar<Real = T> + Clone + Field + Real>(
576    s: &str,
577) -> Result<CsrMatrix<T>, MtxError> {
578    let mut reader = BufReader::new(s.as_bytes());
579    read_matrix_market_from_reader(&mut reader)
580}
581
582/// Writes a CSR matrix to Matrix Market format as a string.
583pub fn write_matrix_market_str<T: Scalar + Clone + Field + ToPrimitive>(
584    csr: &CsrMatrix<T>,
585    comment: Option<&str>,
586) -> Result<String, MtxError> {
587    let mut buf = Vec::new();
588    write_matrix_market_to_writer(csr, &mut buf, comment)?;
589    String::from_utf8(buf).map_err(|e| MtxError::IoError(e.to_string()))
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an in-memory fixture that the test expects to be valid.
    fn parse_ok(text: &str) -> CsrMatrix<f64> {
        read_matrix_market_str(text).expect("fixture should parse")
    }

    /// Asserts that every `((row, col), value)` pair is stored in `csr`.
    fn assert_entries(csr: &CsrMatrix<f64>, expected: &[((usize, usize), f64)]) {
        for &((row, col), value) in expected {
            assert_eq!(csr.get(row, col), Some(&value), "entry ({row}, {col})");
        }
    }

    #[test]
    fn test_parse_header() {
        let banner = parse_header_line("%%MatrixMarket matrix coordinate real general")
            .expect("valid banner");

        assert_eq!(
            banner,
            (
                MtxObject::Matrix,
                MtxFormat::Coordinate,
                MtxField::Real,
                MtxSymmetry::General
            )
        );
    }

    #[test]
    fn test_parse_header_symmetric() {
        let (.., symmetry) =
            parse_header_line("%%MatrixMarket matrix coordinate real symmetric").unwrap();

        assert_eq!(symmetry, MtxSymmetry::Symmetric);
    }

    #[test]
    fn test_read_simple_matrix() {
        let csr = parse_ok(
            r#"%%MatrixMarket matrix coordinate real general
% A simple test matrix
3 3 5
1 1 1.0
1 3 2.0
2 2 3.0
3 1 4.0
3 3 5.0
"#,
        );

        assert_eq!((csr.nrows(), csr.ncols(), csr.nnz()), (3, 3, 5));
        assert_entries(
            &csr,
            &[
                ((0, 0), 1.0),
                ((0, 2), 2.0),
                ((1, 1), 3.0),
                ((2, 0), 4.0),
                ((2, 2), 5.0),
            ],
        );
    }

    #[test]
    fn test_read_symmetric_matrix() {
        let csr = parse_ok(
            r#"%%MatrixMarket matrix coordinate real symmetric
3 3 4
1 1 1.0
2 1 2.0
2 2 3.0
3 3 4.0
"#,
        );

        assert_eq!((csr.nrows(), csr.ncols()), (3, 3));
        // (0, 1) is not in the file: it is the mirrored copy of (1, 0).
        assert_entries(
            &csr,
            &[
                ((0, 0), 1.0),
                ((1, 0), 2.0),
                ((0, 1), 2.0),
                ((1, 1), 3.0),
                ((2, 2), 4.0),
            ],
        );
    }

    #[test]
    fn test_read_pattern_matrix() {
        let csr = parse_ok(
            r#"%%MatrixMarket matrix coordinate pattern general
2 2 2
1 1
2 2
"#,
        );

        // Pattern entries carry no value and are stored as 1.
        assert_entries(&csr, &[((0, 0), 1.0), ((1, 1), 1.0)]);
    }

    #[test]
    fn test_write_read_roundtrip() {
        let original = CsrMatrix::new(
            3,
            3,
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0f64, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap();

        let text = write_matrix_market_str(&original, Some("Test matrix")).unwrap();
        let reread = parse_ok(&text);

        assert_eq!(original.nrows(), reread.nrows());
        assert_eq!(original.ncols(), reread.ncols());
        assert_eq!(original.nnz(), reread.nnz());

        // Missing entries compare as zero.
        let value_at = |m: &CsrMatrix<f64>, row: usize, col: usize| {
            m.get(row, col).copied().unwrap_or(0.0)
        };
        for (row, col) in (0..3).flat_map(|r| (0..3).map(move |c| (r, c))) {
            let diff: f64 = value_at(&original, row, col) - value_at(&reread, row, col);
            assert!(diff.abs() < 1e-10);
        }
    }

    #[test]
    fn test_header_parsing_error() {
        assert!(parse_header_line("invalid header").is_err());
    }

    #[test]
    fn test_index_error() {
        let outcome: Result<CsrMatrix<f64>, MtxError> = read_matrix_market_str(
            r#"%%MatrixMarket matrix coordinate real general
2 2 1
3 1 1.0
"#,
        );

        assert!(outcome.is_err());
    }
}