use crate::ModelError;
/// Read-only access to a design matrix `X`: its shape, products of a single
/// row with a coefficient vector, and the transposed product `X^T * weights`.
pub trait DesignMatrix {
/// Number of rows of `X`.
fn nrows(&self) -> usize;
/// Number of columns of `X`; coefficient vectors passed to `dot_row` have
/// this length.
fn ncols(&self) -> usize;
/// Returns the inner product of row `row` of `X` with `beta`.
///
/// Callers are expected to pass `row < self.nrows()` and
/// `beta.len() == self.ncols()`.
fn dot_row(&self, row: usize, beta: &[f64]) -> f64;
/// Adds the transposed product `X^T * weights` to `out`, i.e.
/// `out[col] += sum over row of X[row][col] * weights[row]`. Existing
/// contents of `out` are accumulated into, not overwritten.
///
/// Callers are expected to pass `weights.len() == self.nrows()` and
/// `out.len() == self.ncols()`.
fn add_t_mul_vec(&self, weights: &[f64], out: &mut [f64]);
}
/// Dense design matrix stored in row-major order.
///
/// The constructors maintain `values.len() == nrows * ncols`, with entry
/// `(row, col)` stored at `values[row * ncols + col]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DenseDesign {
/// Number of rows.
nrows: usize,
/// Number of columns.
ncols: usize,
/// Matrix entries, one full row after another.
values: Vec<f64>,
}
impl DenseDesign {
    /// Builds a design from entries laid out row by row: `nrows` rows of
    /// `ncols` entries each, so entry `(row, col)` is `values[row * ncols + col]`.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::DesignSize`] when `values.len() != nrows * ncols`.
    /// If `nrows * ncols` overflows `usize`, `expected_values` is reported as
    /// `usize::MAX`; no `Vec<f64>` can be that long, so such a shape is always
    /// rejected.
    pub fn from_row_major(
        nrows: usize,
        ncols: usize,
        values: Vec<f64>,
    ) -> Result<Self, ModelError> {
        // Saturate instead of using `*`. An overflowing product panics in debug
        // builds and wraps in release builds. The wrapped value could then match
        // `values.len()` and accept a shape the buffer cannot back.
        let expected_values = nrows.saturating_mul(ncols);
        let actual_values = values.len();
        if actual_values != expected_values {
            return Err(ModelError::DesignSize {
                expected_values,
                actual_values,
            });
        }
        Ok(Self {
            nrows,
            ncols,
            values,
        })
    }

    /// Builds a design from fixed-width rows; the array width `C` becomes the
    /// number of columns.
    pub fn from_rows<const C: usize>(rows: &[[f64; C]]) -> Self {
        Self {
            nrows: rows.len(),
            ncols: C,
            // `concat` allocates the flattened buffer once, at its exact size.
            values: rows.concat(),
        }
    }

    /// Builds an `nrows x 1` design whose single column is all ones.
    pub fn intercept(nrows: usize) -> Self {
        Self {
            nrows,
            ncols: 1,
            values: vec![1.0; nrows],
        }
    }

    /// Builds a single-column design holding a copy of `values`.
    pub fn column(values: &[f64]) -> Self {
        Self {
            nrows: values.len(),
            ncols: 1,
            values: values.to_vec(),
        }
    }

    /// Builds a design from column slices, each of which must have length `nrows`.
    ///
    /// When `include_intercept` is true, a column of ones is placed before the
    /// supplied columns.
    ///
    /// # Errors
    ///
    /// Returns [`ModelError::DesignRowMismatch`] if any column's length differs
    /// from `nrows`.
    pub fn from_columns(
        nrows: usize,
        include_intercept: bool,
        columns: &[&[f64]],
    ) -> Result<Self, ModelError> {
        for column in columns {
            if column.len() != nrows {
                return Err(ModelError::DesignRowMismatch {
                    parameter: "column",
                    expected_rows: nrows,
                    actual_rows: column.len(),
                });
            }
        }
        let ncols = columns.len() + usize::from(include_intercept);
        let mut values = Vec::with_capacity(nrows * ncols);
        // Transpose the column slices into row-major storage.
        for row in 0..nrows {
            if include_intercept {
                values.push(1.0);
            }
            values.extend(columns.iter().map(|column| column[row]));
        }
        // Lengths were validated above, so this cannot fail. Routing through
        // `from_row_major` keeps the size invariant checked in one place.
        Self::from_row_major(nrows, ncols, values)
    }

    /// Returns the matrix entries in row-major order.
    pub fn values(&self) -> &[f64] {
        &self.values
    }
}
impl DesignMatrix for DenseDesign {
    fn nrows(&self) -> usize {
        self.nrows
    }

    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns the inner product of row `row` with `beta`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows` or `beta.len() != self.ncols`. These
    /// checks also run in release builds. Without them, `zip` would stop at
    /// the shorter operand and silently return a wrong value.
    fn dot_row(&self, row: usize, beta: &[f64]) -> f64 {
        assert!(
            row < self.nrows,
            "row index {} out of range for a design with {} rows",
            row,
            self.nrows
        );
        assert_eq!(
            beta.len(),
            self.ncols,
            "coefficient vector length must equal the number of columns"
        );
        let offset = row * self.ncols;
        self.values[offset..offset + self.ncols]
            .iter()
            .zip(beta)
            .map(|(x, b)| x * b)
            .sum()
    }

    /// Adds `X^T * weights` to `out`:
    /// `out[col] += sum over row of X[row][col] * weights[row]`.
    ///
    /// # Panics
    ///
    /// Panics if `weights.len() != self.nrows` or `out.len() != self.ncols`.
    /// These checks also run in release builds. Without them, a short
    /// `weights` would skip rows, and a long `out` would read entries
    /// belonging to the next row.
    fn add_t_mul_vec(&self, weights: &[f64], out: &mut [f64]) {
        assert_eq!(
            weights.len(),
            self.nrows,
            "weight vector length must equal the number of rows"
        );
        assert_eq!(
            out.len(),
            self.ncols,
            "output length must equal the number of columns"
        );
        // `chunks_exact(0)` panics, and a zero-column design contributes nothing.
        if self.ncols == 0 {
            return;
        }
        // `values` holds exactly `nrows` rows of `ncols` entries each.
        for (row_values, &weight) in self.values.chunks_exact(self.ncols).zip(weights) {
            for (out_value, &x) in out.iter_mut().zip(row_values) {
                *out_value += x * weight;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{DenseDesign, DesignMatrix};
    use crate::ModelError;
    use approx::assert_relative_eq;

    #[test]
    fn dense_design_multiplies_rows_and_transpose() {
        let design = DenseDesign::from_rows(&[[1.0, 2.0], [3.0, 4.0]]);
        assert_relative_eq!(design.dot_row(1, &[10.0, 1.0]), 34.0);
        let mut out = vec![0.0, 0.0];
        design.add_t_mul_vec(&[0.5, 2.0], &mut out);
        assert_relative_eq!(out[0], 6.5);
        assert_relative_eq!(out[1], 9.0);
    }

    #[test]
    fn from_row_major_rejects_wrong_length() {
        let result = DenseDesign::from_row_major(2, 2, vec![1.0, 2.0, 3.0]);
        assert!(matches!(
            result,
            Err(ModelError::DesignSize {
                expected_values: 4,
                actual_values: 3,
                ..
            })
        ));
    }

    #[test]
    fn from_columns_prepends_intercept_in_row_major_order() {
        let x1 = [1.0, 2.0, 3.0];
        let x2 = [4.0, 5.0, 6.0];
        let columns: [&[f64]; 2] = [&x1, &x2];
        // Matched on explicitly so the test does not require `ModelError: Debug`.
        let design = match DenseDesign::from_columns(3, true, &columns) {
            Ok(design) => design,
            Err(_) => panic!("columns of matching length must be accepted"),
        };
        assert_eq!(design.nrows(), 3);
        assert_eq!(design.ncols(), 3);
        assert_eq!(
            design.values(),
            &[1.0, 1.0, 4.0, 1.0, 2.0, 5.0, 1.0, 3.0, 6.0]
        );
    }

    #[test]
    fn from_columns_rejects_short_column() {
        let x1 = [1.0, 2.0, 3.0];
        let x2 = [4.0, 5.0];
        let columns: [&[f64]; 2] = [&x1, &x2];
        assert!(matches!(
            DenseDesign::from_columns(3, false, &columns),
            Err(ModelError::DesignRowMismatch {
                expected_rows: 3,
                actual_rows: 2,
                ..
            })
        ));
    }

    #[test]
    fn simple_constructors_match_from_rows() {
        assert_eq!(
            DenseDesign::intercept(2),
            DenseDesign::from_rows(&[[1.0], [1.0]])
        );
        assert_eq!(
            DenseDesign::column(&[2.0, 3.0]),
            DenseDesign::from_rows(&[[2.0], [3.0]])
        );
    }

    #[test]
    fn zero_column_design_is_a_no_op() {
        let design = DenseDesign::from_rows::<0>(&[[], []]);
        assert_eq!(design.nrows(), 2);
        assert_eq!(design.ncols(), 0);
        assert_relative_eq!(design.dot_row(1, &[]), 0.0);
        let mut out: Vec<f64> = Vec::new();
        design.add_t_mul_vec(&[1.0, 2.0], &mut out);
        assert!(out.is_empty());
    }
}