use crate::error::FeralError;
/// Dense square `m x m` matrix of `f64` values stored in column-major order.
///
/// Entry `(i, j)` (row `i`, column `j`) lives at `data[j * m + i]`, so each
/// column occupies a contiguous run of `m` entries in `data`.
///
/// The checked constructors `from_columns` and `from_column_major` call
/// `validate`, which requires `data.len() == m * m` and every entry to be
/// finite. Both fields are `pub`, so code that builds or mutates the struct
/// directly is responsible for keeping those invariants.
#[derive(Debug, Clone)]
pub struct GeneralMatrix {
    /// Number of rows, which is also the number of columns.
    pub m: usize,
    /// Column-major entries; expected to hold exactly `m * m` values.
    pub data: Vec<f64>,
}
impl GeneralMatrix {
    /// Creates an `m x m` matrix whose entries are all `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `m * m` overflows `usize`.
    pub fn zeros(m: usize) -> Self {
        // Checked so release builds cannot wrap to a buffer that disagrees with `m`.
        let len = m
            .checked_mul(m)
            .expect("GeneralMatrix::zeros: m * m overflows usize");
        GeneralMatrix {
            m,
            data: vec![0.0; len],
        }
    }

    /// Builds an `m x m` matrix from `m` columns of length `m`.
    ///
    /// `cols[j][i]` becomes entry `(i, j)`.
    ///
    /// # Errors
    ///
    /// - `FeralError::DimensionMismatch { expected: m, got: cols.len() }` if the
    ///   number of columns is not `m`.
    /// - `FeralError::DimensionMismatch { expected: m, got: len }` for the first
    ///   column (in order) whose length `len` is not `m`.
    /// - Any error from [`GeneralMatrix::validate`] (e.g. non-finite entries).
    pub fn from_columns(m: usize, cols: &[Vec<f64>]) -> Result<Self, FeralError> {
        if cols.len() != m {
            return Err(FeralError::DimensionMismatch {
                expected: m,
                got: cols.len(),
            });
        }
        // Reject malformed columns before allocating the backing buffer.
        if let Some(bad) = cols.iter().find(|col| col.len() != m) {
            return Err(FeralError::DimensionMismatch {
                expected: m,
                got: bad.len(),
            });
        }
        // The columns are already in column-major order, so concatenating them
        // produces the backing buffer directly.
        let mat = GeneralMatrix {
            m,
            data: cols.concat(),
        };
        mat.validate()?;
        Ok(mat)
    }

    /// Wraps an existing column-major buffer as an `m x m` matrix.
    ///
    /// # Errors
    ///
    /// - `FeralError::InvalidInput` if `m * m` overflows `usize`.
    /// - `FeralError::DimensionMismatch { expected: m * m, got: data.len() }` if
    ///   the buffer has the wrong length.
    /// - Any error from [`GeneralMatrix::validate`] (e.g. non-finite entries).
    pub fn from_column_major(m: usize, data: Vec<f64>) -> Result<Self, FeralError> {
        match m.checked_mul(m) {
            Some(expected) if expected == data.len() => {}
            Some(expected) => {
                return Err(FeralError::DimensionMismatch {
                    expected,
                    got: data.len(),
                });
            }
            None => {
                return Err(FeralError::InvalidInput(format!(
                    "GeneralMatrix dimension {} overflows m*m",
                    m
                )));
            }
        }
        let mat = GeneralMatrix { m, data };
        mat.validate()?;
        Ok(mat)
    }

    /// Returns entry `(i, j)` (row `i`, column `j`).
    ///
    /// # Panics
    ///
    /// Panics if `i >= m` or `j >= m`. Both indices are checked so that an
    /// out-of-range row cannot silently alias an entry of the next column.
    #[inline]
    pub fn get(&self, i: usize, j: usize) -> f64 {
        assert!(
            i < self.m && j < self.m,
            "GeneralMatrix index ({}, {}) out of bounds for {}x{} matrix",
            i,
            j,
            self.m,
            self.m
        );
        self.data[j * self.m + i]
    }

    /// Overwrites entry `(i, j)` with `v`.
    ///
    /// No finiteness check is performed; storing NaN/inf breaks the invariant
    /// that [`GeneralMatrix::validate`] enforces.
    ///
    /// # Panics
    ///
    /// Panics if `i >= m` or `j >= m`.
    #[inline]
    pub fn set(&mut self, i: usize, j: usize, v: f64) {
        assert!(
            i < self.m && j < self.m,
            "GeneralMatrix index ({}, {}) out of bounds for {}x{} matrix",
            i,
            j,
            self.m,
            self.m
        );
        self.data[j * self.m + i] = v;
    }

    /// Returns column `j` as a contiguous slice of length `m`.
    ///
    /// # Panics
    ///
    /// Panics if `j >= m`.
    #[inline]
    pub fn col(&self, j: usize) -> &[f64] {
        assert!(
            j < self.m,
            "GeneralMatrix column {} out of bounds for {}x{} matrix",
            j,
            self.m,
            self.m
        );
        let start = j * self.m;
        &self.data[start..start + self.m]
    }

    /// Returns column `j` as a mutable contiguous slice of length `m`.
    ///
    /// # Panics
    ///
    /// Panics if `j >= m`.
    #[inline]
    pub fn col_mut(&mut self, j: usize) -> &mut [f64] {
        assert!(
            j < self.m,
            "GeneralMatrix column {} out of bounds for {}x{} matrix",
            j,
            self.m,
            self.m
        );
        let start = j * self.m;
        &mut self.data[start..start + self.m]
    }

    /// Checks the structural invariants: `data.len() == m * m` and every
    /// entry is finite.
    ///
    /// # Errors
    ///
    /// Returns `FeralError::InvalidInput` if `m * m` overflows, the length is
    /// wrong, or any entry is NaN or infinite.
    pub fn validate(&self) -> Result<(), FeralError> {
        let expected = self.m.checked_mul(self.m).ok_or_else(|| {
            FeralError::InvalidInput(format!(
                "GeneralMatrix dimension {} overflows m*m",
                self.m
            ))
        })?;
        if self.data.len() != expected {
            return Err(FeralError::InvalidInput(format!(
                "GeneralMatrix data length {} != m*m = {}",
                self.data.len(),
                expected
            )));
        }
        if self.data.iter().any(|x| !x.is_finite()) {
            return Err(FeralError::InvalidInput(
                "GeneralMatrix contains non-finite entries".to_string(),
            ));
        }
        Ok(())
    }

    /// Computes `y = A * x`.
    ///
    /// Length handling (no panics on mismatch):
    /// - every element of `y` is first set to `0.0`, including any beyond `m`;
    /// - only the first `min(x.len(), m)` entries of `x` are used (missing ones
    ///   act as zero);
    /// - only the first `min(y.len(), m)` rows are accumulated.
    pub fn matvec(&self, x: &[f64], y: &mut [f64]) {
        y.fill(0.0);
        for (j, &xj) in x.iter().enumerate().take(self.m) {
            // Skipping zero coefficients saves a column pass. The result is the same
            // only when the column is finite (0 * inf would be NaN), which
            // `validate` guarantees.
            if xj == 0.0 {
                continue;
            }
            for (yi, &aij) in y.iter_mut().zip(self.col(j)) {
                *yi += aij * xj;
            }
        }
    }

    /// Computes `y = A^T * x`, i.e. `y[j] = dot(column j, x)`.
    ///
    /// Length handling (no panics on mismatch):
    /// - only `y[..min(y.len(), m)]` is written; any tail of `y` is left untouched;
    /// - each dot product runs over the first `min(x.len(), m)` entries.
    pub fn matvec_transpose(&self, x: &[f64], y: &mut [f64]) {
        for (j, yj) in y.iter_mut().enumerate().take(self.m) {
            // Explicit fold from +0.0 (not `sum()`) keeps the exact accumulation order
            // and the sign of an empty sum.
            *yj = self
                .col(j)
                .iter()
                .zip(x)
                .fold(0.0, |acc, (&aij, &xi)| acc + aij * xi);
        }
    }
}