use loeres::{
ContiguousMatrixAccess, ContiguousVectorAccess, ContiguousVectorAccessMut, Dim2, DimensionKind,
FiniteScalar, MatrixAccess, MatrixAccessMut, SolverError, VectorAccess, VectorAccessMut,
};
use crate::internal::dimension_mismatch;
/// Limits applied when ingesting caller-supplied dense data.
///
/// The derived `Default` leaves `max_elements` as `None`, i.e. no size limit.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct DenseIngestOptions {
    /// Maximum number of scalar elements accepted by a constructor; `None` disables the check.
    pub max_elements: Option<usize>,
}
/// Heap-backed dense vector whose length is determined at runtime.
///
/// Values produced by the checked constructors are never empty.
#[derive(Debug, Clone)]
pub struct DenseVector<S> {
    /// Element storage in index order.
    data: Vec<S>,
}
impl<S: loeres::BaseScalar> DenseVector<S> {
pub fn from_vec(data: Vec<S>) -> Result<Self, SolverError> {
Self::from_vec_with_options(data, DenseIngestOptions::default())
}
pub fn from_vec_with_options(
data: Vec<S>,
options: DenseIngestOptions,
) -> Result<Self, SolverError> {
if data.is_empty() {
return Err(SolverError::InvalidDimension);
}
if let Some(max) = options.max_elements {
if data.len() > max {
return Err(SolverError::InvalidInput);
}
}
Ok(Self { data })
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
impl<S: FiniteScalar> DenseVector<S> {
    /// Checks that every element reports `is_finite()`.
    ///
    /// # Errors
    /// [`SolverError::NonFiniteInput`] if any element is not finite; scanning
    /// stops at the first such element.
    pub fn validate_finite(&self) -> Result<(), SolverError> {
        let all_finite = self.data.iter().all(|value| value.is_finite());
        if all_finite {
            Ok(())
        } else {
            Err(SolverError::NonFiniteInput)
        }
    }
}
impl<S: loeres::BaseScalar> VectorAccess for DenseVector<S> {
    type Scalar = S;

    /// Number of stored elements.
    fn len(&self) -> usize {
        self.data.len()
    }

    /// Always reports [`DimensionKind::Dynamic`].
    fn dimension_kind(&self) -> DimensionKind {
        DimensionKind::Dynamic
    }

    /// Bounds-checked read of element `index`.
    ///
    /// An out-of-range index yields `dimension_mismatch(index, len)`.
    fn get(&self, index: usize) -> Result<S, SolverError> {
        // The error is only built on the miss path, as in a plain `match`.
        self.data
            .get(index)
            .copied()
            .ok_or_else(|| dimension_mismatch(index, self.data.len()))
    }
}
impl<S: loeres::BaseScalar> VectorAccessMut for DenseVector<S> {
    /// Bounds-checked overwrite of element `index` with `value`.
    ///
    /// An out-of-range index yields `dimension_mismatch(index, len)` and leaves
    /// the vector unchanged.
    fn set(&mut self, index: usize, value: S) -> Result<(), SolverError> {
        // Capture the length up front; the mutable borrow below would otherwise
        // prevent reading it on the error path.
        let length = self.data.len();
        if let Some(slot) = self.data.get_mut(index) {
            *slot = value;
            return Ok(());
        }
        Err(dimension_mismatch(index, length))
    }
}
impl<S: loeres::BaseScalar> ContiguousVectorAccess for DenseVector<S> {
    /// Storage is always one contiguous buffer, so this never returns `None`.
    fn as_contiguous(&self) -> Option<&[S]> {
        let view = self.data.as_slice();
        Some(view)
    }
}
impl<S: loeres::BaseScalar> ContiguousVectorAccessMut for DenseVector<S> {
    /// Mutable view of the single contiguous buffer; never returns `None`.
    fn as_contiguous_mut(&mut self) -> Option<&mut [S]> {
        let view = self.data.as_mut_slice();
        Some(view)
    }
}
/// Heap-backed dense matrix stored as one row-major buffer.
///
/// Element (`row`, `col`) lives at offset `row * cols + col` in `data`. The
/// checked constructors guarantee `data.len() == rows * cols` with both
/// extents non-zero.
#[derive(Debug, Clone)]
pub struct DenseMatrix<S> {
    /// Number of rows.
    rows: usize,
    /// Number of columns (also the row stride in `data`).
    cols: usize,
    /// Row-major element storage.
    data: Vec<S>,
}
impl<S: loeres::BaseScalar> DenseMatrix<S> {
pub fn from_row_major_vec(rows: usize, cols: usize, data: Vec<S>) -> Result<Self, SolverError> {
Self::from_row_major_vec_with_options(rows, cols, data, DenseIngestOptions::default())
}
pub fn from_row_major_vec_with_options(
rows: usize,
cols: usize,
data: Vec<S>,
options: DenseIngestOptions,
) -> Result<Self, SolverError> {
if rows == 0 || cols == 0 {
return Err(SolverError::InvalidDimension);
}
let required = rows
.checked_mul(cols)
.ok_or(SolverError::InvalidDimension)?;
if let Some(max) = options.max_elements {
if required > max {
return Err(SolverError::InvalidInput);
}
}
if data.len() != required {
return Err(dimension_mismatch(data.len(), required));
}
Ok(Self { rows, cols, data })
}
pub fn dims(&self) -> Dim2 {
Dim2::new(self.rows, self.cols)
}
}
impl<S: FiniteScalar> DenseMatrix<S> {
    /// Checks that every stored element reports `is_finite()`.
    ///
    /// # Errors
    /// [`SolverError::NonFiniteInput`] if any element is not finite; scanning
    /// stops at the first such element.
    pub fn validate_finite(&self) -> Result<(), SolverError> {
        if self.data.iter().any(|value| !value.is_finite()) {
            return Err(SolverError::NonFiniteInput);
        }
        Ok(())
    }
}
impl<S: loeres::BaseScalar> MatrixAccess for DenseMatrix<S> {
    type Scalar = S;

    /// Shape of the matrix as `Dim2::new(rows, cols)`.
    fn dims(&self) -> Dim2 {
        Dim2::new(self.rows, self.cols)
    }

    /// Always reports [`DimensionKind::Dynamic`].
    fn dimension_kind(&self) -> DimensionKind {
        DimensionKind::Dynamic
    }

    /// Bounds-checked read of the element at (`row`, `col`).
    ///
    /// The row is validated before the column, so when both are out of range
    /// the error describes the row. A missing element after both checks pass
    /// is reported as [`SolverError::InternalInvariantViolation`].
    fn get(&self, row: usize, col: usize) -> Result<S, SolverError> {
        let flat_index = match (row < self.rows, col < self.cols) {
            (false, _) => return Err(dimension_mismatch(row, self.rows)),
            (true, false) => return Err(dimension_mismatch(col, self.cols)),
            // Row-major layout: each row occupies `cols` consecutive slots.
            (true, true) => self.cols * row + col,
        };
        self.data
            .get(flat_index)
            .copied()
            .ok_or(SolverError::InternalInvariantViolation)
    }
}
impl<S: loeres::BaseScalar> MatrixAccessMut for DenseMatrix<S> {
    /// Bounds-checked overwrite of the element at (`row`, `col`) with `value`.
    ///
    /// The row is validated before the column. On any error the matrix is left
    /// unchanged.
    fn set(&mut self, row: usize, col: usize, value: S) -> Result<(), SolverError> {
        if row >= self.rows {
            return Err(dimension_mismatch(row, self.rows));
        }
        if col >= self.cols {
            return Err(dimension_mismatch(col, self.cols));
        }
        let stride = self.cols;
        let slot = self
            .data
            .get_mut(row * stride + col)
            .ok_or(SolverError::InternalInvariantViolation)?;
        *slot = value;
        Ok(())
    }
}
impl<S: loeres::BaseScalar> ContiguousMatrixAccess for DenseMatrix<S> {
    /// The backing buffer is already row-major, so this never returns `None`.
    fn as_row_major(&self) -> Option<&[S]> {
        let view = self.data.as_slice();
        Some(view)
    }
}
#[cfg(test)]
mod tests;