#![doc = include_str!("../../docs/matrix/update.md")]
use super::{items::Coordinates, iter::OnesCoordinatesColumnCursor, CSVBinaryMatrix};
use thiserror::Error;
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Error)]
pub enum UpdateMatrixError {
#[error(
"The extension's number of columns ({0}) must match the self's number of columns ({1})"
)]
ColumnsMismatch(usize, usize),
#[error("The extension's number of rows ({0}) must match the self's number of rows ({1})")]
RowsMismatch(usize, usize),
}
impl CSVBinaryMatrix {
    /// Appends the rows of `extension` below the rows of `self`, in place.
    ///
    /// Both matrices must have the same number of columns; `extension` is consumed.
    ///
    /// Judging by the arithmetic here and in [`Self::extend_with_columns`], `distances` appears
    /// to hold, in row-major order, the index of the first one, the gaps between consecutive
    /// ones, and the distance from the last one to the last cell.
    /// NOTE(review): inferred from this file only — confirm against the type's documentation.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateMatrixError::ColumnsMismatch`] with (`extension` columns, `self` columns)
    /// when the column counts differ; `self` is left untouched in that case.
    ///
    /// # Panics
    ///
    /// Panics if a matrix that has cells comes with an empty `distances` vector, which would
    /// violate the representation invariant.
    ///
    /// NOTE(review): `is_reversed` is neither read nor updated here; confirm callers never pass
    /// a reversed matrix.
    pub fn extend_with_rows(&mut self, mut extension: Self) -> Result<(), UpdateMatrixError> {
        if self.number_of_columns != extension.number_of_columns {
            return Err(UpdateMatrixError::ColumnsMismatch(
                extension.number_of_columns,
                self.number_of_columns,
            ));
        }
        // An extension without cells contributes no ones: only the row count changes.
        if extension.number_of_rows == 0 || extension.number_of_columns == 0 {
            self.number_of_rows += extension.number_of_rows;
            return Ok(());
        }
        // `self` has no cells, so the stacked matrix is exactly `extension`.
        if self.number_of_rows == 0 {
            self.number_of_rows = extension.number_of_rows;
            self.distances = std::mem::take(&mut extension.distances);
            return Ok(());
        }
        // Gap between `self`'s last one and its last cell.
        let trailing_gap = self
            .distances
            .pop()
            .expect("a matrix with cells always has at least one distance");
        // Slot where the extension's leading gap (index of its first one) ends up.
        let junction = self.distances.len();
        self.distances.append(&mut extension.distances);
        // The gap across the junction spans `self`'s trailing gap, the step into the first
        // extension cell, and the extension's leading gap.
        self.distances[junction] += trailing_gap + 1;
        self.number_of_rows += extension.number_of_rows;
        Ok(())
    }

    /// Appends the columns of `extension` to the right of the columns of `self`, in place.
    ///
    /// Both matrices must have the same number of rows; `extension` is consumed. The ones of
    /// both matrices are merged in a single backward pass (from the last cell towards the
    /// first) in the row-major order of the extended matrix, and `distances` is rebuilt from
    /// the resulting positions.
    ///
    /// # Errors
    ///
    /// Returns [`UpdateMatrixError::RowsMismatch`] with (`extension` rows, `self` rows) when the
    /// row counts differ; `self` is left untouched in that case.
    ///
    /// # Panics
    ///
    /// Panics if a matrix that has cells comes with an empty `distances` vector, which would
    /// violate the representation invariant.
    ///
    /// NOTE(review): `is_reversed` is neither read nor updated here; confirm callers never pass
    /// a reversed matrix.
    pub fn extend_with_columns(&mut self, mut extension: Self) -> Result<(), UpdateMatrixError> {
        if self.number_of_rows != extension.number_of_rows {
            return Err(UpdateMatrixError::RowsMismatch(
                extension.number_of_rows,
                self.number_of_rows,
            ));
        }
        // An extension without cells contributes no ones: only the column count changes.
        // (This also covers the case where both matrices have zero rows.)
        if extension.number_of_rows == 0 || extension.number_of_columns == 0 {
            self.number_of_columns += extension.number_of_columns;
            return Ok(());
        }
        // `self` has no cells (rows > 0 here), so the side-by-side matrix is exactly `extension`.
        if self.number_of_columns == 0 {
            self.number_of_columns = extension.number_of_columns;
            self.distances = std::mem::take(&mut extension.distances);
            return Ok(());
        }

        let left_columns = self.number_of_columns;
        let extended_number_of_columns = left_columns + extension.number_of_columns;
        // Gaps are produced from the end of the matrix towards its start, then reversed.
        let mut reversed_distances =
            Vec::with_capacity(self.number_of_ones() + extension.number_of_ones() + 1);
        // Row-major index, in the extended matrix, of the position the next gap is measured
        // from; it starts at the last cell so the first gap pushed is the trailing one.
        let mut prev_index = self.number_of_rows * extended_number_of_columns - 1;

        // Each cursor starts at its own matrix's last cell and steps backward over its ones.
        let mut a_ones_cursor = OnesCoordinatesColumnCursor::new(
            left_columns,
            Coordinates::new(self.number_of_rows - 1, left_columns - 1),
        );
        let mut b_ones_cursor = OnesCoordinatesColumnCursor::new(
            extension.number_of_columns,
            Coordinates::new(extension.number_of_rows - 1, extension.number_of_columns - 1),
        );
        let mut a_current_coord = a_ones_cursor.unchecked_backward(
            self.distances
                .pop()
                .expect("a matrix with cells always has at least one distance"),
        );
        let mut b_current_coord = b_ones_cursor.unchecked_backward(
            extension
                .distances
                .pop()
                .expect("a matrix with cells always has at least one distance"),
        );

        // Once a side's distances are exhausted its cursor has stepped back past its first one
        // (presumably onto cell (0, 0)); the selection rule below never picks that side while
        // it has no distances left, because it cannot win on row and loses every tie.
        while !self.distances.is_empty() || !extension.distances.is_empty() {
            // Walking backward in row-major order: the later row goes first, and within a
            // shared row the extension's ones (right of `self`'s columns) go first.
            let take_from_extension = a_current_coord.row() < b_current_coord.row()
                || (a_current_coord.row() == b_current_coord.row()
                    && !extension.distances.is_empty());
            let curr_index = if take_from_extension {
                let index = b_current_coord.row() * extended_number_of_columns
                    + left_columns
                    + b_current_coord.column();
                b_current_coord = b_ones_cursor.unchecked_backward(
                    extension
                        .distances
                        .pop()
                        .expect("the extension still has ones to visit"),
                );
                index
            } else {
                let index =
                    a_current_coord.row() * extended_number_of_columns + a_current_coord.column();
                a_current_coord = a_ones_cursor.unchecked_backward(
                    self.distances
                        .pop()
                        .expect("`self` still has ones to visit"),
                );
                index
            };
            // Each visited one lies at or before `prev_index`, so this cannot underflow.
            reversed_distances.push(prev_index - curr_index);
            prev_index = curr_index;
        }
        // Leading gap: index of the first one (or of the last cell when there are no ones).
        reversed_distances.push(prev_index);
        reversed_distances.reverse();

        self.number_of_columns = extended_number_of_columns;
        self.distances = reversed_distances;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::super::tests::{matrix_a, zeros_matrix};
    use super::super::CSVBinaryMatrix;
    use super::UpdateMatrixError;
    use pretty_assertions::{assert_eq, assert_str_eq};
    use rstest::rstest;

    /// Stacking one extra row under each 3×3 fixture.
    #[rstest]
    fn extend_with_rows(mut zeros_matrix: CSVBinaryMatrix, mut matrix_a: CSVBinaryMatrix) {
        let zero_row = CSVBinaryMatrix::try_from(&[[0, 0, 0]]).unwrap();
        zeros_matrix.extend_with_rows(zero_row).unwrap();
        let expected_zeros =
            CSVBinaryMatrix::try_from(&[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]]).unwrap();
        assert_eq!(zeros_matrix, expected_zeros);

        let middle_one_row = CSVBinaryMatrix::try_from(&[[0, 1, 0]]).unwrap();
        matrix_a.extend_with_rows(middle_one_row).unwrap();
        let expected_a =
            CSVBinaryMatrix::try_from(&[[1, 1, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0]]).unwrap();
        assert_eq!(matrix_a, expected_a);
    }

    /// A 2-column extension is rejected by the 3-column zeros fixture.
    #[rstest]
    fn extend_with_rows_columns_mismatch(mut zeros_matrix: CSVBinaryMatrix) {
        let two_columns = CSVBinaryMatrix {
            number_of_rows: 3,
            number_of_columns: 2,
            distances: vec![1],
            is_reversed: false,
        };
        let error = match zeros_matrix.extend_with_rows(two_columns) {
            Ok(()) => panic!("should panic"),
            Err(error) => error,
        };
        assert!(matches!(error, UpdateMatrixError::ColumnsMismatch(2, 3)));
        assert_str_eq!(format!("{error:?}"), "ColumnsMismatch(2, 3)");
        assert_str_eq!(
            format!("{error}"),
            "The extension's number of columns (2) must match the self's number of columns (3)"
        );
    }

    /// Adding a single column on the right, and prepending a column on the left.
    #[rstest]
    fn extend_with_columns(mut matrix_a: CSVBinaryMatrix, mut zeros_matrix: CSVBinaryMatrix) {
        let zero_column = CSVBinaryMatrix::try_from(&[[0], [0], [0]]).unwrap();
        zeros_matrix.extend_with_columns(zero_column).unwrap();
        let expected_zeros =
            CSVBinaryMatrix::try_from(&[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unwrap();
        assert_eq!(zeros_matrix, expected_zeros);

        let mut pre_matrix = CSVBinaryMatrix::try_from(&[[1], [0], [1]]).unwrap();
        pre_matrix.extend_with_columns(matrix_a.clone()).unwrap();
        let expected_prefixed =
            CSVBinaryMatrix::try_from(&[[1, 1, 1, 1], [0, 1, 0, 0], [1, 0, 1, 0]]).unwrap();
        assert_eq!(pre_matrix, expected_prefixed);

        let outer_column = CSVBinaryMatrix::try_from(&[[1], [0], [1]]).unwrap();
        matrix_a.extend_with_columns(outer_column).unwrap();
        let expected_suffixed =
            CSVBinaryMatrix::try_from(&[[1, 1, 1, 1], [1, 0, 0, 0], [0, 1, 0, 1]]).unwrap();
        assert_eq!(matrix_a, expected_suffixed);
    }

    /// A 1-row extension is rejected by the 3-row `matrix_a` fixture.
    #[rstest]
    fn extend_with_columns_rows_mismatch_error(mut matrix_a: CSVBinaryMatrix) {
        let single_cell = CSVBinaryMatrix::try_from(&[[0]]).unwrap();
        let error = match matrix_a.extend_with_columns(single_cell) {
            Ok(()) => panic!("Expected error"),
            Err(error) => error,
        };
        assert!(matches!(error, UpdateMatrixError::RowsMismatch(1, 3)));
        assert_str_eq!(format!("{error:?}"), "RowsMismatch(1, 3)");
        assert_str_eq!(
            format!("{error}"),
            "The extension's number of rows (1) must match the self's number of rows (3)"
        );
    }
}