// csvbinmatrix 0.8.0
//
// Binary matrix Compressed Sparse Vector
// Documentation
#![doc = include_str!("../../docs/matrix/update.md")]

use super::{items::Coordinates, iter::OnesCoordinatesColumnCursor, CSVBinaryMatrix};

use thiserror::Error;

/// Error when updating the matrix
///
/// Returned by [`CSVBinaryMatrix::extend_with_rows`] and
/// [`CSVBinaryMatrix::extend_with_columns`] when the dimensions of the extension are not
/// compatible with the dimensions of the matrix being extended.
// NOTE(review): presumably allowed because the type name repeats the name of the enclosing
// module (`update`, judging by the included documentation path) — confirm.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Error)]
pub enum UpdateMatrixError {
    /// Error when the extension's number of columns does not match the self's number of columns.
    ///
    /// The fields are, in order: the extension's number of columns, then the self's number of
    /// columns.
    #[error(
        "The extension's number of columns ({0}) must match the self's number of columns ({1})"
    )]
    ColumnsMismatch(usize, usize),
    /// Error when the extension's number of rows does not match the self's number of rows.
    ///
    /// The fields are, in order: the extension's number of rows, then the self's number of rows.
    #[error("The extension's number of rows ({0}) must match the self's number of rows ({1})")]
    RowsMismatch(usize, usize),
}

impl CSVBinaryMatrix {
    // NOTE(review): neither method below touches `is_reversed` directly; both appear to assume
    // that self and the extension store their distances in the same order — confirm.

    /// Extends the rows of the current matrix with the given matrix.
    ///
    /// Move the values of extension into self: the distances of the extension are appended to
    /// the distances of self, and the distance at the junction is enlarged so that it spans the
    /// cells between the last one of self and the first one of the extension.
    ///
    /// Time complexity is linear in the number of ones of the extension (its distances are
    /// copied at the end of the self's distances, which may also trigger a reallocation).
    /// It does not depend on the dimensions of the matrices.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use csvbinmatrix::prelude::CSVBinaryMatrix;
    /// #
    /// # let mut matrix = CSVBinaryMatrix::try_from(&[
    /// #     [0, 0, 0],
    /// #     [0, 0, 1],
    /// #     [0, 1, 1],
    /// #     [1, 1, 1],
    /// # ]).unwrap();
    /// let extension = CSVBinaryMatrix::try_from(&[[0, 1, 0]]).unwrap();
    ///
    /// match matrix.extend_with_rows(extension) {
    ///     Ok(()) => (),
    ///     Err(err) => panic!("[ERROR] {err}"),
    /// }
    /// assert_eq!(
    ///     matrix,
    ///     CSVBinaryMatrix::try_from(&[
    ///         [0, 0, 0],
    ///         [0, 0, 1],
    ///         [0, 1, 1],
    ///         [1, 1, 1],
    ///         [0, 1, 0]
    ///     ]).unwrap()
    /// );
    /// ```
    ///
    /// # Errors
    ///
    /// * [`UpdateMatrixError::ColumnsMismatch`] - The extension's number of columns must match
    ///   the self's number of columns. Self is left untouched in that case.
    ///
    /// # Panics
    ///
    /// Panics if self or the extension has not been built correctly, i.e. if one of them has
    /// no distance at all.
    ///
    pub fn extend_with_rows(&mut self, mut extension: Self) -> Result<(), UpdateMatrixError> {
        if self.number_of_columns != extension.number_of_columns {
            return Err(UpdateMatrixError::ColumnsMismatch(
                extension.number_of_columns,
                self.number_of_columns,
            ));
        }
        // The last distance of self separates its last one from its last cell. It is removed
        // and merged with the first distance of the extension below.
        let trailing_distance = self
            .distances
            .pop()
            .expect("a correctly built matrix has at least one distance");
        // Index at which the first distance of the extension lands once appended.
        let junction_index = self.distances.len();
        self.distances.append(&mut extension.distances);
        // The `+ 1` accounts for the step from the last cell of self to the first cell of the
        // extension.
        let junction_distance = self
            .distances
            .get_mut(junction_index)
            .expect("a correctly built extension has at least one distance");
        *junction_distance += trailing_distance + 1;
        self.number_of_rows += extension.number_of_rows;
        Ok(())
    }

    /// Extends the columns of the current matrix with the given matrix.
    ///
    /// The memory and the time complexities are linear according to the number of ones of the
    /// two matrices.
    ///
    /// # Example
    ///
    /// Let A be the following matrix:
    /// ```text
    /// 0 0 0
    /// 0 0 1
    /// 0 1 1
    /// 1 1 1
    /// ```
    ///
    /// Let B be the extension matrix:
    /// ```text
    /// 1
    /// 1
    /// 1
    /// 1
    /// ```
    ///
    /// The resulting extended matrix will be
    /// ```text
    /// 0 0 0 1
    /// 0 0 1 1
    /// 0 1 1 1
    /// 1 1 1 1
    /// ```
    ///
    /// ```rust
    /// # use csvbinmatrix::prelude::CSVBinaryMatrix;
    /// #
    /// # let mut matrix = CSVBinaryMatrix::try_from(&[
    /// #     [0, 0, 0],
    /// #     [0, 0, 1],
    /// #     [0, 1, 1],
    /// #     [1, 1, 1],
    /// # ]).unwrap();
    /// let extension = CSVBinaryMatrix::try_from(&[[1], [1], [1], [1]]).unwrap();
    ///
    /// match matrix.extend_with_columns(extension) {
    ///     Ok(()) => { },
    ///     Err(err) => panic!("[ERROR] {err}"),
    /// }
    /// assert_eq!(matrix, CSVBinaryMatrix::try_from(&[
    ///     [0, 0, 0, 1],
    ///     [0, 0, 1, 1],
    ///     [0, 1, 1, 1],
    ///     [1, 1, 1, 1],
    /// ]).unwrap());
    /// ```
    ///
    /// # Errors
    ///
    /// * [`UpdateMatrixError::RowsMismatch`] - The extension's number of rows must match the
    ///   self's number of rows. Self is left untouched in that case.
    ///
    /// # Panics
    ///
    /// Unreachable: if self or the extension have not been built correctly (a matrix without
    /// any distance). The `- 1` subtractions on the dimensions would also overflow for a matrix
    /// with zero rows or zero columns — presumably such a matrix cannot be built; to be
    /// confirmed against the constructors.
    ///
    pub fn extend_with_columns(&mut self, mut extension: Self) -> Result<(), UpdateMatrixError> {
        if self.number_of_rows != extension.number_of_rows {
            return Err(UpdateMatrixError::RowsMismatch(
                extension.number_of_rows,
                self.number_of_rows,
            ));
        }
        let extended_number_of_columns = self.number_of_columns + extension.number_of_columns;
        // One distance per one of both matrices, plus one. Must be sized before any distance is
        // popped from self or the extension.
        let mut extended_distances =
            vec![0; self.number_of_ones() + extension.number_of_ones() + 1];

        // The merge walks both matrices backward, starting from the last cell of the extended
        // matrix.
        let mut extended_prev_coord =
            Coordinates::new(self.number_of_rows - 1, extended_number_of_columns - 1);

        let mut a_ones_cursor = OnesCoordinatesColumnCursor::new(
            self.number_of_columns,
            Coordinates::new(self.number_of_rows - 1, self.number_of_columns - 1),
        );
        let mut b_ones_cursor = OnesCoordinatesColumnCursor::new(
            extension.number_of_columns,
            Coordinates::new(
                extension.number_of_rows - 1,
                extension.number_of_columns - 1,
            ),
        );

        // Consuming the last distance of each matrix moves its cursor from the last cell onto
        // its last one. After that, a matrix still has ones to emit exactly as long as its
        // distances are not empty.
        let mut a_current_coord = a_ones_cursor.unchecked_backward(
            self.distances
                .pop()
                .expect("a correctly built matrix has at least one distance"),
        );
        let mut b_current_coord = b_ones_cursor.unchecked_backward(
            extension
                .distances
                .pop()
                .expect("a correctly built extension has at least one distance"),
        );

        // The extended distances are filled from the back to the front.
        let mut extended_distances_index = extended_distances.len() - 1;
        while !self.distances.is_empty() || !extension.distances.is_empty() {
            // Pick the one that comes last in the extended matrix: the one on the greatest row
            // and, on equal rows, the one of the extension (its columns are on the right of the
            // self's columns) as long as the extension still has ones.
            let take_from_extension = match a_current_coord.row().cmp(&b_current_coord.row()) {
                std::cmp::Ordering::Less => true,
                std::cmp::Ordering::Greater => false,
                std::cmp::Ordering::Equal => !extension.distances.is_empty(),
            };
            let extended_curr_coord = if take_from_extension {
                // Columns of the extension are shifted by the self's number of columns.
                let coord = Coordinates::new(
                    b_current_coord.row(),
                    b_current_coord.column() + self.number_of_columns,
                );
                b_current_coord = b_ones_cursor.unchecked_backward(
                    extension
                        .distances
                        .pop()
                        .expect("the extension still has a distance for each remaining one"),
                );
                coord
            } else {
                let coord = Coordinates::new(a_current_coord.row(), a_current_coord.column());
                a_current_coord = a_ones_cursor.unchecked_backward(
                    self.distances
                        .pop()
                        .expect("self still has a distance for each remaining one"),
                );
                coord
            };

            // Row-major distance between the current one and the previously emitted position.
            // The column difference is applied with the right sign to stay within unsigned
            // arithmetic.
            let rows_distance =
                (extended_prev_coord.row() - extended_curr_coord.row()) * extended_number_of_columns;
            let extended_sub_distance =
                if extended_curr_coord.column() >= extended_prev_coord.column() {
                    rows_distance - (extended_curr_coord.column() - extended_prev_coord.column())
                } else {
                    rows_distance + (extended_prev_coord.column() - extended_curr_coord.column())
                };
            extended_distances[extended_distances_index] = extended_sub_distance;
            extended_prev_coord = extended_curr_coord;
            extended_distances_index -= 1;
        }

        // First distance: row-major index of the first one (or of the last cell when there is
        // no one at all).
        extended_distances[extended_distances_index] =
            extended_prev_coord.row() * extended_number_of_columns + extended_prev_coord.column();
        self.number_of_columns = extended_number_of_columns;
        self.distances = extended_distances;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::super::tests::{matrix_a, zeros_matrix};
    use super::super::CSVBinaryMatrix;
    use super::UpdateMatrixError;
    use pretty_assertions::{assert_eq, assert_str_eq};
    use rstest::rstest;

    // Appends a single row to each fixture and compares with the expected matrix.
    #[rstest]
    fn extend_with_rows(mut zeros_matrix: CSVBinaryMatrix, mut matrix_a: CSVBinaryMatrix) {
        let zero_row = CSVBinaryMatrix::try_from(&[[0, 0, 0]]).unwrap();
        zeros_matrix.extend_with_rows(zero_row).unwrap();
        let expected_zeros =
            CSVBinaryMatrix::try_from(&[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]).unwrap();
        assert_eq!(zeros_matrix, expected_zeros);

        let one_in_the_middle_row = CSVBinaryMatrix::try_from(&[[0, 1, 0]]).unwrap();
        matrix_a.extend_with_rows(one_in_the_middle_row).unwrap();
        let expected_a =
            CSVBinaryMatrix::try_from(&[[1, 1, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0]]).unwrap();
        assert_eq!(matrix_a, expected_a);
    }

    // An extension with a different number of columns must be rejected with `ColumnsMismatch`.
    #[rstest]
    fn extend_with_rows_columns_mismatch(mut zeros_matrix: CSVBinaryMatrix) {
        let two_columns_extension = CSVBinaryMatrix {
            number_of_rows: 3,
            number_of_columns: 2,
            distances: vec![1],
            is_reversed: false,
        };
        let e = match zeros_matrix.extend_with_rows(two_columns_extension) {
            Err(e) => e,
            Ok(()) => panic!("should panic"),
        };
        assert!(matches!(e, UpdateMatrixError::ColumnsMismatch(2, 3)));
        assert_str_eq!(format!("{e:?}"), "ColumnsMismatch(2, 3)");
        assert_str_eq!(
            format!("{e}"),
            "The extension's number of columns (2) must match the self's number of columns (3)"
        );
    }

    // Appends columns on the right of the fixtures, and the fixture on the right of a column.
    #[rstest]
    fn extend_with_columns(mut matrix_a: CSVBinaryMatrix, mut zeros_matrix: CSVBinaryMatrix) {
        let zero_column = CSVBinaryMatrix::try_from(&[[0], [0], [0]]).unwrap();
        zeros_matrix.extend_with_columns(zero_column).unwrap();
        let expected_zeros =
            CSVBinaryMatrix::try_from(&[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unwrap();
        assert_eq!(zeros_matrix, expected_zeros);

        // The fixture used as the extension: the single column ends up on the left.
        let mut pre_matrix = CSVBinaryMatrix::try_from(&[[1], [0], [1]]).unwrap();
        pre_matrix.extend_with_columns(matrix_a.clone()).unwrap();
        let expected_pre =
            CSVBinaryMatrix::try_from(&[[1, 1, 1, 1], [0, 1, 0, 0], [1, 0, 1, 0]]).unwrap();
        assert_eq!(pre_matrix, expected_pre);

        // The fixture used as self: the single column ends up on the right.
        let post_column = CSVBinaryMatrix::try_from(&[[1], [0], [1]]).unwrap();
        matrix_a.extend_with_columns(post_column).unwrap();
        let expected_post =
            CSVBinaryMatrix::try_from(&[[1, 1, 1, 1], [1, 0, 0, 0], [0, 1, 0, 1]]).unwrap();
        assert_eq!(matrix_a, expected_post);
    }

    // An extension with a different number of rows must be rejected with `RowsMismatch`.
    #[rstest]
    fn extend_with_columns_rows_mismatch_error(mut matrix_a: CSVBinaryMatrix) {
        let single_row_extension = CSVBinaryMatrix::try_from(&[[0]]).unwrap();
        let e = match matrix_a.extend_with_columns(single_row_extension) {
            Err(e) => e,
            Ok(()) => panic!("Expected error"),
        };
        assert!(matches!(e, UpdateMatrixError::RowsMismatch(1, 3)));
        assert_str_eq!(format!("{e:?}"), "RowsMismatch(1, 3)");
        assert_str_eq!(
            format!("{e}"),
            "The extension's number of rows (1) must match the self's number of rows (3)"
        );
    }
}