use super::dim_u32;
use crate::dimension::{Dim2, DimensionKind};
use crate::error::SolverError;
use crate::scalar::BaseScalar;
/// Read access to a two-dimensional collection of scalars addressed by `(row, col)`.
pub trait MatrixAccess {
/// Element type handed back by [`MatrixAccess::get`].
type Scalar: BaseScalar;
/// Returns the matrix dimensions as a [`Dim2`].
fn dims(&self) -> Dim2;
/// Reports the [`DimensionKind`] associated with this matrix.
fn dimension_kind(&self) -> DimensionKind;
/// Returns the element stored at (`row`, `col`) by value.
///
/// # Errors
///
/// Returns a [`SolverError`] when the element cannot be read. The view types in this
/// file fail when `row` or `col` lies outside the matrix dimensions.
fn get(&self, row: usize, col: usize) -> Result<Self::Scalar, SolverError>;
}
/// Write access to a matrix, layered on top of the read access in [`MatrixAccess`].
pub trait MatrixAccessMut: MatrixAccess {
/// Stores `value` at (`row`, `col`).
///
/// # Errors
///
/// Returns a [`SolverError`] when the element cannot be written. The mutable view in
/// this file fail when `row` or `col` lies outside the matrix dimensions.
fn set(&mut self, row: usize, col: usize, value: Self::Scalar) -> Result<(), SolverError>;
}
/// Extension of [`MatrixAccess`] for matrices that may expose their elements as one slice.
pub trait ContiguousMatrixAccess: MatrixAccess {
/// Returns the elements as a single slice, or `None` when no such slice is available.
///
/// The view types in this file always return `Some` with their backing slice.
/// NOTE(review): presumably other implementors return `None` when their storage is not
/// contiguous row-major — confirm against those implementations.
fn as_row_major(&self) -> Option<&[Self::Scalar]>;
}
/// Confirms that a buffer of `len` elements exactly fills a `rows` x `cols` matrix.
///
/// # Errors
///
/// * `SolverError::InvalidDimension` when `rows * cols` overflows `usize`.
/// * `SolverError::DimensionMismatch` when `len` differs from `rows * cols`; `lhs` carries
///   the actual length and `rhs` the required one.
/// * Whatever `dim_u32` reports if either value cannot be converted for the mismatch payload.
#[inline]
fn validate_row_major(len: usize, rows: usize, cols: usize) -> Result<(), SolverError> {
    let expected = match rows.checked_mul(cols) {
        Some(count) => count,
        None => return Err(SolverError::InvalidDimension),
    };

    if len == expected {
        return Ok(());
    }

    // Convert the actual length first, then the expected one, so a conversion failure
    // surfaces in the same order as the fields are declared.
    let actual_dim = dim_u32(len)?;
    let expected_dim = dim_u32(expected)?;
    Err(SolverError::DimensionMismatch {
        lhs: actual_dim,
        rhs: expected_dim,
    })
}
/// Translates a `(row, col)` pair into a linear row-major offset (`row * cols + col`).
///
/// # Errors
///
/// * `SolverError::DimensionMismatch` when `row >= rows` (reported first) or `col >= cols`;
///   `lhs` is the offending index and `rhs` the bound it violated.
/// * `SolverError::InternalInvariantViolation` when the offset arithmetic overflows `usize`.
/// * Whatever `dim_u32` reports if an index or bound cannot be converted for the payload.
#[inline]
fn checked_offset(row: usize, col: usize, rows: usize, cols: usize) -> Result<usize, SolverError> {
    if row < rows && col < cols {
        let linear = row
            .checked_mul(cols)
            .and_then(|row_start| row_start.checked_add(col));
        return match linear {
            Some(offset) => Ok(offset),
            None => Err(SolverError::InternalInvariantViolation),
        };
    }

    // At least one index is out of range; the row takes precedence when both are.
    let (index, bound) = if row >= rows { (row, rows) } else { (col, cols) };
    Err(SolverError::DimensionMismatch {
        lhs: dim_u32(index)?,
        rhs: dim_u32(bound)?,
    })
}
/// Borrowed, read-only view that treats a slice as a `rows` x `cols` row-major matrix.
///
/// Built through `MatrixView::from_row_major`, which checks that the slice length equals
/// `rows * cols`.
#[derive(Copy, Clone, Debug)]
pub struct MatrixView<'a, S: BaseScalar> {
// Backing elements; element (r, c) is read from offset `r * cols + c`.
data: &'a [S],
// Number of rows.
rows: usize,
// Number of columns.
cols: usize,
}
impl<'a, S: BaseScalar> MatrixView<'a, S> {
    /// Wraps `data` as a `rows` x `cols` matrix stored in row-major order.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_row_major` when `data.len()` does not equal
    /// `rows * cols` or when that product overflows.
    #[inline]
    pub fn from_row_major(data: &'a [S], rows: usize, cols: usize) -> Result<Self, SolverError> {
        validate_row_major(data.len(), rows, cols).map(|()| Self { data, rows, cols })
    }
}
impl<S: BaseScalar> MatrixAccess for MatrixView<'_, S> {
    type Scalar = S;

    /// Returns the `(rows, cols)` shape this view was built with.
    #[inline]
    fn dims(&self) -> Dim2 {
        let (rows, cols) = (self.rows, self.cols);
        Dim2::new(rows, cols)
    }

    /// Always reports `DimensionKind::Dynamic`.
    #[inline]
    fn dimension_kind(&self) -> DimensionKind {
        DimensionKind::Dynamic
    }

    /// Copies out the element at (`row`, `col`).
    ///
    /// # Errors
    ///
    /// Propagates the error from `checked_offset` for out-of-range indices, and returns
    /// `SolverError::InternalInvariantViolation` if the computed offset is not inside the
    /// backing slice.
    #[inline]
    fn get(&self, row: usize, col: usize) -> Result<S, SolverError> {
        checked_offset(row, col, self.rows, self.cols).and_then(|index| {
            self.data
                .get(index)
                .copied()
                .ok_or(SolverError::InternalInvariantViolation)
        })
    }
}
impl<S: BaseScalar> ContiguousMatrixAccess for MatrixView<'_, S> {
    /// Exposes the backing slice; this view always has one, so the result is always `Some`.
    #[inline]
    fn as_row_major(&self) -> Option<&[S]> {
        let elements: &[S] = self.data;
        Some(elements)
    }
}
/// Borrowed, mutable view that treats a slice as a `rows` x `cols` row-major matrix.
///
/// Built through `MatrixViewMut::from_row_major_mut`, which checks that the slice length
/// equals `rows * cols`.
#[derive(Debug)]
pub struct MatrixViewMut<'a, S: BaseScalar> {
// Backing elements; element (r, c) lives at offset `r * cols + c`.
data: &'a mut [S],
// Number of rows.
rows: usize,
// Number of columns.
cols: usize,
}
impl<'a, S: BaseScalar> MatrixViewMut<'a, S> {
    /// Wraps `data` as a mutable `rows` x `cols` matrix stored in row-major order.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_row_major` when `data.len()` does not equal
    /// `rows * cols` or when that product overflows.
    #[inline]
    pub fn from_row_major_mut(
        data: &'a mut [S],
        rows: usize,
        cols: usize,
    ) -> Result<Self, SolverError> {
        match validate_row_major(data.len(), rows, cols) {
            Ok(()) => Ok(Self { data, rows, cols }),
            Err(reason) => Err(reason),
        }
    }
}
impl<S: BaseScalar> MatrixAccess for MatrixViewMut<'_, S> {
    type Scalar = S;

    /// Returns the `(rows, cols)` shape this view was built with.
    #[inline]
    fn dims(&self) -> Dim2 {
        let (rows, cols) = (self.rows, self.cols);
        Dim2::new(rows, cols)
    }

    /// Always reports `DimensionKind::Dynamic`.
    #[inline]
    fn dimension_kind(&self) -> DimensionKind {
        DimensionKind::Dynamic
    }

    /// Copies out the element at (`row`, `col`).
    ///
    /// # Errors
    ///
    /// Propagates the error from `checked_offset` for out-of-range indices, and returns
    /// `SolverError::InternalInvariantViolation` if the computed offset is not inside the
    /// backing slice.
    #[inline]
    fn get(&self, row: usize, col: usize) -> Result<S, SolverError> {
        checked_offset(row, col, self.rows, self.cols).and_then(|index| {
            self.data
                .get(index)
                .copied()
                .ok_or(SolverError::InternalInvariantViolation)
        })
    }
}
impl<S: BaseScalar> MatrixAccessMut for MatrixViewMut<'_, S> {
    /// Overwrites the element at (`row`, `col`) with `value`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `checked_offset` for out-of-range indices, and returns
    /// `SolverError::InternalInvariantViolation` if the computed offset is not inside the
    /// backing slice. Nothing is written when an error is returned.
    #[inline]
    fn set(&mut self, row: usize, col: usize, value: S) -> Result<(), SolverError> {
        let index = checked_offset(row, col, self.rows, self.cols)?;
        let cell = self
            .data
            .get_mut(index)
            .ok_or(SolverError::InternalInvariantViolation)?;
        *cell = value;
        Ok(())
    }
}
impl<S: BaseScalar> ContiguousMatrixAccess for MatrixViewMut<'_, S> {
    /// Exposes the backing slice as shared; this view always has one, so the result is
    /// always `Some`.
    #[inline]
    fn as_row_major(&self) -> Option<&[S]> {
        // A full-range index reborrows the mutable slice immutably and cannot panic.
        Some(&self.data[..])
    }
}