use std::alloc::{self, Layout};
use std::convert::TryFrom;
use std::io::{Read, Write};
use crate::error::{FastTextError, Result};
use crate::model::MinstdRng;
use crate::utils;
use crate::vector::Vector;
/// Byte alignment requested for the `DenseMatrix` backing allocation.
const ALIGNMENT: usize = 64;
/// Interface shared by matrix implementations: dimension queries, row/vector
/// arithmetic, and binary (de)serialization.
pub trait Matrix {
/// Returns the number of rows.
fn rows(&self) -> i64;
/// Returns the number of columns.
fn cols(&self) -> i64;
/// Returns the dot product of row `i` with `vec`.
fn dot_row(&self, vec: &Vector, i: i64) -> f32;
/// Adds `scale * vec` to row `i` in place.
fn add_vector_to_row(&mut self, vec: &Vector, i: i64, scale: f32);
/// Adds `scale * row(i)` to `x` in place.
///
/// NOTE(review): the row index is `i32` here while the other methods use `i64`.
fn add_row_to_vector(&self, x: &mut Vector, i: i32, scale: f32);
/// Stores the average of the listed `rows` in `x`.
fn average_rows_to_vector(&self, x: &mut Vector, rows: &[i32]);
/// Serializes the matrix to `writer`.
fn save<W: Write>(&self, writer: &mut W) -> Result<()>;
/// Deserializes a matrix from `reader`.
fn load<R: Read>(reader: &mut R) -> Result<Self>
where
Self: Sized;
}
/// Dense, row-major `f32` matrix stored in one manually allocated buffer that
/// is aligned to `ALIGNMENT` bytes.
#[derive(Debug)]
pub struct DenseMatrix {
// Start of the element buffer; null when the matrix holds no elements.
ptr: *mut f32,
// Number of rows.
m: i64,
// Number of columns.
n: i64,
// Total number of elements (`m * n`).
size: usize,
}
// NOTE(review): the raw pointer suppresses the auto traits, so they are
// asserted by hand. The buffer is allocated in `new` and freed in `Drop`, so
// the value owns it exclusively. `Sync` also lets `add_vector_to_row_unsync`
// write through `&self` from several threads; that method is `unsafe` and
// presumably leaves data-race handling to its callers — confirm.
unsafe impl Send for DenseMatrix {}
unsafe impl Sync for DenseMatrix {}
/// Converts signed matrix dimensions into a total element count.
///
/// # Panics
///
/// Panics if either dimension is negative (or does not fit in `usize`), or if
/// `m * n` overflows `usize`.
#[inline]
fn checked_dim_size(m: i64, n: i64) -> usize {
    // Shared conversion so both dimensions are validated the same way.
    let to_dim = |value: i64, msg: &str| usize::try_from(value).expect(msg);
    let rows = to_dim(m, "DenseMatrix row count (m) must be non-negative");
    let cols = to_dim(n, "DenseMatrix column count (n) must be non-negative");
    let total = rows.checked_mul(cols);
    total.expect("DenseMatrix dimensions m*n overflow usize")
}
impl DenseMatrix {
/// Creates an `m` x `n` matrix with every element set to zero.
///
/// A matrix with zero elements performs no allocation and keeps a null
/// pointer.
///
/// # Panics
///
/// Panics if a dimension is negative, if `m * n` overflows, or if the
/// allocation layout cannot be built. Allocation failure is reported through
/// `handle_alloc_error`.
pub fn new(m: i64, n: i64) -> Self {
    let size = checked_dim_size(m, n);
    let ptr = if size == 0 {
        std::ptr::null_mut()
    } else {
        let layout = Layout::array::<f32>(size)
            .and_then(|elems| elems.align_to(ALIGNMENT))
            .expect("Invalid layout");
        // SAFETY: `size > 0`, so the layout has a non-zero size.
        let raw = unsafe { alloc::alloc_zeroed(layout) }.cast::<f32>();
        if raw.is_null() {
            alloc::handle_alloc_error(layout);
        }
        raw
    };
    DenseMatrix { ptr, m, n, size }
}
/// Builds an `m` x `n` matrix whose elements are copied from the first
/// `m * n` values of `data` (row-major order).
///
/// # Panics
///
/// Panics under the same conditions as [`DenseMatrix::new`], or if `data`
/// holds fewer than `m * n` values.
pub fn from_data(m: i64, n: i64, data: &[f32]) -> Self {
    let mut matrix = DenseMatrix::new(m, n);
    // `new` has already validated the dimensions, so `size` equals `m * n`.
    let count = matrix.size;
    if count != 0 {
        matrix.data_mut().copy_from_slice(&data[..count]);
    }
    matrix
}
/// Returns all elements as one row-major slice (empty for a zero-sized matrix).
#[inline]
pub fn data(&self) -> &[f32] {
    if self.size != 0 {
        // SAFETY: a non-zero `size` means `ptr` came from `new`, which
        // allocated exactly `size` initialised `f32` values.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    } else {
        &[]
    }
}
/// Returns all elements as one mutable row-major slice (empty for a
/// zero-sized matrix).
#[inline]
pub fn data_mut(&mut self) -> &mut [f32] {
    if self.size != 0 {
        // SAFETY: a non-zero `size` means `ptr` came from `new`, which
        // allocated exactly `size` initialised `f32` values; `&mut self`
        // gives exclusive access.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    } else {
        &mut []
    }
}
/// Returns the element at row `i`, column `j`.
///
/// Only the flattened index is bounds-checked: `j` is not compared against
/// the column count.
///
/// # Panics
///
/// Panics if the flattened index `i * n + j` is outside the buffer.
#[inline]
pub fn at(&self, i: i64, j: i64) -> f32 {
    let cols = self.n as usize;
    let flat = (i as usize) * cols + (j as usize);
    self.data()[flat]
}
/// Returns a mutable reference to the element at row `i`, column `j`.
///
/// Only the flattened index is bounds-checked: `j` is not compared against
/// the column count.
///
/// # Panics
///
/// Panics if the flattened index `i * n + j` is outside the buffer.
#[inline]
pub fn at_mut(&mut self, i: i64, j: i64) -> &mut f32 {
    let cols = self.n as usize;
    let flat = (i as usize) * cols + (j as usize);
    &mut self.data_mut()[flat]
}
/// Returns row `i` as a slice of `n` elements.
///
/// # Panics
///
/// Panics if the row lies outside the buffer.
#[inline]
pub fn row(&self, i: i64) -> &[f32] {
    let width = self.n as usize;
    let offset = (i as usize) * width;
    &self.data()[offset..offset + width]
}
/// Returns row `i` as a mutable slice of `n` elements.
///
/// # Panics
///
/// Panics if the row lies outside the buffer.
#[inline]
pub fn row_mut(&mut self, i: i64) -> &mut [f32] {
    let width = self.n as usize;
    let offset = (i as usize) * width;
    &mut self.data_mut()[offset..offset + width]
}
/// Adds `scale * vec` to row `i`, writing through the raw buffer pointer
/// while only holding `&self`.
///
/// # Safety
///
/// * `i` must satisfy `0 <= i < rows`; this is only verified by a
///   `debug_assert!`, so release builds perform no row check.
/// * No synchronisation is performed. The caller is responsible for any
///   concurrent access to the same row.
///
/// # Panics
///
/// Panics if `vec` holds fewer than `n` elements.
pub unsafe fn add_vector_to_row_unsync(&self, vec: &Vector, i: i64, scale: f32) {
    debug_assert!(i >= 0 && i < self.m, "Row index out of bounds");
    debug_assert_eq!(vec.len(), self.n as usize);
    let width = self.n as usize;
    let row_start = (i as usize) * width;
    let source = &vec.data()[..width];
    for (offset, &delta) in source.iter().enumerate() {
        // Element `offset` of row `i`; in bounds per the caller's contract.
        let cell = self.ptr.add(row_start + offset);
        *cell = *cell + scale * delta;
    }
}
/// Resets every element to `0.0`.
pub fn zero(&mut self) {
    for cell in self.data_mut().iter_mut() {
        *cell = 0.0;
    }
}
/// Fills the matrix with values computed as `u * 2a - a`, where `u` comes
/// from `MinstdRng::uniform_real`.
///
/// The buffer is split into ten consecutive blocks, with the last block
/// absorbing the remainder. Block `k` uses a fresh generator seeded with
/// `k + seed`, where `seed` is reinterpreted as `u32`, so the result is
/// deterministic for a given seed.
pub fn uniform(&mut self, a: f32, seed: i32) {
    const BLOCKS: usize = 10;
    let cells = self.data_mut();
    let total = cells.len();
    if total == 0 {
        return;
    }
    let block_size = total / BLOCKS;
    let bound = a as f64;
    let seed_base = seed as u32 as u64;
    for block in 0..BLOCKS {
        let start = block_size * block;
        let end = if block + 1 == BLOCKS {
            total
        } else {
            (start + block_size).min(total)
        };
        let mut rng = MinstdRng::new((block as u64).wrapping_add(seed_base));
        for cell in &mut cells[start..end] {
            let u = rng.uniform_real();
            *cell = (u * 2.0 * bound - bound) as f32;
        }
    }
}
/// Returns the Euclidean (L2) norm of row `i`, accumulated in `f64`.
///
/// # Errors
///
/// Returns `FastTextError::EncounteredNaN` if the sum of squares is NaN.
///
/// # Panics
///
/// Panics if `i` is not a valid row index.
pub fn l2_norm_row(&self, i: i64) -> Result<f32> {
    assert!((0..self.m).contains(&i), "Row index out of bounds");
    let sum_of_squares = self.row(i).iter().fold(0.0f64, |acc, &value| {
        let wide = value as f64;
        acc + wide * wide
    });
    if sum_of_squares.is_nan() {
        Err(FastTextError::EncounteredNaN)
    } else {
        Ok(sum_of_squares.sqrt() as f32)
    }
}
/// Multiplies each row in `ib..ie` by its factor `nums[row - ib]`.
///
/// `ie` defaults to the row count. Rows whose factor equals `0.0` are left
/// untouched.
///
/// # Panics
///
/// Panics if `nums` is shorter than the row range or a row is out of bounds.
pub fn multiply_row(&mut self, nums: &[f32], ib: i64, ie: Option<i64>) {
    let end = ie.unwrap_or(self.m);
    for (offset, row_index) in (ib..end).enumerate() {
        let factor = nums[offset];
        if factor == 0.0 {
            continue;
        }
        self.row_mut(row_index)
            .iter_mut()
            .for_each(|value| *value *= factor);
    }
}
/// Divides each row in `ib..ie` by its denominator `denoms[row - ib]`.
///
/// `ie` defaults to the row count. Rows whose denominator equals `0.0` are
/// left untouched, which avoids division by zero.
///
/// # Panics
///
/// Panics if `denoms` is shorter than the row range or a row is out of bounds.
pub fn divide_row(&mut self, denoms: &[f32], ib: i64, ie: Option<i64>) {
    let end = ie.unwrap_or(self.m);
    for (offset, row_index) in (ib..end).enumerate() {
        let denom = denoms[offset];
        if denom == 0.0 {
            continue;
        }
        self.row_mut(row_index)
            .iter_mut()
            .for_each(|value| *value /= denom);
    }
}
}
/// Deep copy: the clone gets its own allocation with the same contents.
impl Clone for DenseMatrix {
    fn clone(&self) -> Self {
        let source = self.data();
        let mut copy = DenseMatrix::new(self.m, self.n);
        if !source.is_empty() {
            copy.data_mut().copy_from_slice(source);
        }
        copy
    }
}
/// Frees the backing buffer using the same layout that `new` allocated it with.
impl Drop for DenseMatrix {
    fn drop(&mut self) {
        // A zero-sized matrix never allocated, so there is nothing to free.
        if self.ptr.is_null() || self.size == 0 {
            return;
        }
        let layout = Layout::array::<f32>(self.size)
            .and_then(|elems| elems.align_to(ALIGNMENT))
            .expect("Invalid layout in Drop");
        // SAFETY: `ptr` was returned by `alloc_zeroed` with this exact layout.
        unsafe {
            alloc::dealloc(self.ptr.cast::<u8>(), layout);
        }
    }
}
impl Matrix for DenseMatrix {
/// Returns the number of rows (`m`).
#[inline]
fn rows(&self) -> i64 {
self.m
}
/// Returns the number of columns (`n`).
#[inline]
fn cols(&self) -> i64 {
self.n
}
/// Returns the dot product of row `i` with `vec`.
///
/// # Panics
///
/// Panics if `i` is not a valid row index or if `vec` does not have exactly
/// `n` elements.
fn dot_row(&self, vec: &Vector, i: i64) -> f32 {
    assert!((0..self.m).contains(&i), "Row index out of bounds");
    let cols = self.n as usize;
    assert_eq!(
        vec.len(),
        cols,
        "Vector size {} does not match matrix columns {}",
        vec.len(),
        self.n
    );
    crate::simd::dot_impl(self.row(i), vec.data())
}
/// Adds `scale * vec` to row `i` in place.
///
/// # Panics
///
/// Panics if `i` is not a valid row index or if `vec` does not have exactly
/// `n` elements.
fn add_vector_to_row(&mut self, vec: &Vector, i: i64, scale: f32) {
    assert!((0..self.m).contains(&i), "Row index out of bounds");
    let cols = self.n as usize;
    assert_eq!(
        vec.len(),
        cols,
        "Vector size {} does not match matrix columns {}",
        vec.len(),
        self.n
    );
    let target = self.row_mut(i);
    crate::simd::add_vector_impl(target, vec.data(), scale);
}
/// Adds `scale * row(i)` to `x` in place.
///
/// # Panics
///
/// Panics if `i` is not a valid row index or if `x` does not have exactly
/// `n` elements.
fn add_row_to_vector(&self, x: &mut Vector, i: i32, scale: f32) {
    let row_index = i64::from(i);
    assert!((0..self.m).contains(&row_index), "Row index out of bounds");
    let cols = self.n as usize;
    assert_eq!(
        x.len(),
        cols,
        "Vector size {} does not match matrix columns {}",
        x.len(),
        self.n
    );
    let source = self.row(row_index);
    crate::simd::add_vector_impl(x.data_mut(), source, scale);
}
/// Stores the average of the listed `rows` in `x`; the computation itself is
/// delegated to `crate::simd::average_rows_impl`.
///
/// # Panics
///
/// Panics if `x` does not have exactly `n` elements.
fn average_rows_to_vector(&self, x: &mut Vector, rows: &[i32]) {
    let cols = self.n as usize;
    assert_eq!(
        x.len(),
        cols,
        "Vector size {} does not match matrix columns {}",
        x.len(),
        self.n
    );
    crate::simd::average_rows_impl(x, rows, self);
}
/// Writes the matrix as: row count (`i64`), column count (`i64`), then every
/// element as `f32` in row-major order.
///
/// # Errors
///
/// Propagates the first error reported by the `utils` write helpers.
fn save<W: Write>(&self, writer: &mut W) -> Result<()> {
    utils::write_i64(writer, self.m)?;
    utils::write_i64(writer, self.n)?;
    self.data().iter().try_for_each(|&value| -> Result<()> {
        utils::write_f32(writer, value)?;
        Ok(())
    })
}
fn load<R: Read>(reader: &mut R) -> Result<Self> {
let m = utils::read_i64(reader)?;
let n = utils::read_i64(reader)?;
if m < 0 || n < 0 {
return Err(FastTextError::InvalidModel(format!(
"Invalid matrix dimensions: {}x{}",
m, n
)));
}
let m_u = usize::try_from(m).map_err(|_| {
FastTextError::InvalidModel(format!("Matrix row count {} is too large", m))
})?;
let n_u = usize::try_from(n).map_err(|_| {
FastTextError::InvalidModel(format!("Matrix column count {} is too large", n))
})?;
m_u.checked_mul(n_u).ok_or_else(|| {
FastTextError::InvalidModel(format!(
"Matrix dimensions {}x{} would overflow usize",
m, n
))
})?;
let mut mat = DenseMatrix::new(m, n);
let data = mat.data_mut();
let byte_slice =
unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut u8, data.len() * 4) };
reader
.read_exact(byte_slice)
.map_err(FastTextError::IoError)?;
Ok(mat)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::simd::{add_vector_scalar, dot_scalar};
use std::io::Cursor;
#[test]
fn test_dense_matrix_alloc_safety_zero_size() {
    // Every zero-sized shape keeps its dimensions but exposes no elements.
    for &(rows, cols) in &[(0i64, 0i64), (0, 100), (100, 0)] {
        let mat = DenseMatrix::new(rows, cols);
        assert_eq!(mat.rows(), rows);
        assert_eq!(mat.cols(), cols);
        assert!(mat.data().is_empty());
    }
    // Cloning an empty matrix must work without touching an allocation.
    let empty = DenseMatrix::new(0, 0);
    let copy = empty.clone();
    assert_eq!(copy.rows(), 0);
    assert_eq!(copy.cols(), 0);
}
#[test]
fn test_dense_matrix_layout_overflow_check() {
    // Both counts need more than `isize::MAX` bytes and must be rejected.
    let elem_size = std::mem::size_of::<f32>();
    let first_too_large = isize::MAX as usize / elem_size + 1;
    for &count in &[usize::MAX, first_too_large] {
        assert!(Layout::array::<f32>(count).is_err());
    }
}
#[test]
#[should_panic(expected = "must be non-negative")]
fn test_dense_matrix_new_negative_m_panics() {
    // A negative row count must be rejected before any allocation.
    drop(DenseMatrix::new(-1, 4));
}
#[test]
#[should_panic(expected = "must be non-negative")]
fn test_dense_matrix_new_negative_n_panics() {
    // A negative column count must be rejected before any allocation.
    drop(DenseMatrix::new(4, -1));
}
#[test]
fn test_dense_matrix_load_overflow_error() {
    // The header claims i64::MAX x i64::MAX elements, which cannot be sized.
    let mut header = Vec::with_capacity(16);
    for dim in &[i64::MAX, i64::MAX] {
        header.extend_from_slice(&dim.to_le_bytes());
    }
    let mut cursor = Cursor::new(header);
    let result = DenseMatrix::load(&mut cursor);
    assert!(
        result.is_err(),
        "Expected error for overflow dimensions, got Ok"
    );
    match result {
        Err(FastTextError::InvalidModel(_)) => {}
        other => panic!("Expected InvalidModel error, got {:?}", other),
    }
}
#[test]
fn test_dense_matrix_load_negative_dims_error() {
    // A negative row count in the header must be reported as an invalid model.
    let mut header = Vec::with_capacity(16);
    for dim in &[-5i64, 10i64] {
        header.extend_from_slice(&dim.to_le_bytes());
    }
    let mut cursor = Cursor::new(header);
    match DenseMatrix::load(&mut cursor) {
        Err(FastTextError::InvalidModel(_)) => {}
        other => panic!("Expected InvalidModel error, got {:?}", other),
    }
}
#[test]
fn test_dense_matrix_new() {
    // A fresh matrix reports its shape and is zero-initialised.
    let mat = DenseMatrix::new(3, 4);
    assert_eq!(mat.rows(), 3);
    assert_eq!(mat.cols(), 4);
    for flat in 0..12i64 {
        assert_eq!(mat.at(flat / 4, flat % 4), 0.0);
    }
}
#[test]
fn test_dense_matrix_new_zero_rows() {
    // Zero rows: the column count is kept but there is no data.
    let mat = DenseMatrix::new(0, 4);
    assert_eq!(mat.rows(), 0);
    assert_eq!(mat.cols(), 4);
    assert!(mat.data().is_empty());
}
#[test]
fn test_dense_matrix_new_zero_cols() {
    // Zero columns: the row count is kept but there is no data.
    let mat = DenseMatrix::new(3, 0);
    assert_eq!(mat.rows(), 3);
    assert_eq!(mat.cols(), 0);
    assert!(mat.data().is_empty());
}
#[test]
fn test_dense_matrix_zero() {
    // Fill every cell with a non-zero value, then check `zero` clears them all.
    let mut mat = DenseMatrix::new(2, 3);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {
        *mat.at_mut((k / 3) as i64, (k % 3) as i64) = val;
    }
    mat.zero();
    for flat in 0..6i64 {
        let (i, j) = (flat / 3, flat % 3);
        assert_eq!(mat.at(i, j), 0.0, "Element ({},{}) should be zero", i, j);
    }
}
#[test]
fn test_dense_matrix_row_major_layout() {
    // Write 1..=6 through (row, col) indexing, then read the flat buffer back.
    let mut mat = DenseMatrix::new(2, 3);
    let mut next = 1.0f32;
    for i in 0..2 {
        for j in 0..3 {
            *mat.at_mut(i, j) = next;
            next += 1.0;
        }
    }
    let flat = mat.data();
    for (k, &want) in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {
        assert_eq!(flat[k], want);
    }
}
#[test]
fn test_dense_matrix_dot_row() {
    // With an all-ones vector the dot product equals the row sum.
    let mut mat = DenseMatrix::new(2, 3);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {
        *mat.at_mut((k / 3) as i64, (k % 3) as i64) = val;
    }
    let mut ones = Vector::new(3);
    for j in 0..3usize {
        ones[j] = 1.0;
    }
    for &(row, want) in &[(0i64, 6.0f32), (1, 15.0)] {
        assert!((mat.dot_row(&ones, row) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_dot_row_known_values() {
    // 2*1 + 3*2 + 4*3 + 5*4 = 40
    let mut mat = DenseMatrix::new(1, 4);
    let mut vec = Vector::new(4);
    let pairs = [(2.0f32, 1.0f32), (3.0, 2.0), (4.0, 3.0), (5.0, 4.0)];
    for (j, &(cell, weight)) in pairs.iter().enumerate() {
        *mat.at_mut(0, j as i64) = cell;
        vec[j] = weight;
    }
    assert!((mat.dot_row(&vec, 0) - 40.0).abs() < 1e-6);
}
#[test]
fn test_dense_matrix_dot_row_zero_vec() {
    // A freshly created vector is all zeros, so the dot product is zero.
    let mut mat = DenseMatrix::new(1, 3);
    for (j, &val) in [1.0f32, 2.0, 3.0].iter().enumerate() {
        *mat.at_mut(0, j as i64) = val;
    }
    let zeros = Vector::new(3);
    assert_eq!(mat.dot_row(&zeros, 0), 0.0);
}
#[test]
fn test_dense_matrix_dot_row_nan() {
    // A NaN in the matrix row must propagate into the dot product.
    let mut mat = DenseMatrix::new(1, 3);
    for (j, &val) in [f32::NAN, 2.0, 3.0].iter().enumerate() {
        *mat.at_mut(0, j as i64) = val;
    }
    let mut ones = Vector::new(3);
    for j in 0..3usize {
        ones[j] = 1.0;
    }
    assert!(mat.dot_row(&ones, 0).is_nan());
}
#[test]
fn test_dense_matrix_add_vector_to_row() {
    let mut mat = DenseMatrix::new(2, 3);
    let mut delta = Vector::new(3);
    let pairs = [(1.0f32, 10.0f32), (2.0, 20.0), (3.0, 30.0)];
    for (j, &(cell, add)) in pairs.iter().enumerate() {
        *mat.at_mut(0, j as i64) = cell;
        delta[j] = add;
    }
    mat.add_vector_to_row(&delta, 0, 1.0);
    for (j, &want) in [11.0f32, 22.0, 33.0].iter().enumerate() {
        assert!((mat.at(0, j as i64) - want).abs() < 1e-6);
    }
    // Row 1 was never targeted and must remain exactly zero.
    for j in 0..3 {
        assert_eq!(mat.at(1, j), 0.0);
    }
}
#[test]
fn test_dense_matrix_add_vector_to_row_scaled() {
    // row + 0.5 * vec: [1, 2, 3] + 0.5 * [10, 20, 30] = [6, 12, 18]
    let mut mat = DenseMatrix::new(1, 3);
    let mut delta = Vector::new(3);
    let pairs = [(1.0f32, 10.0f32), (2.0, 20.0), (3.0, 30.0)];
    for (j, &(cell, add)) in pairs.iter().enumerate() {
        *mat.at_mut(0, j as i64) = cell;
        delta[j] = add;
    }
    mat.add_vector_to_row(&delta, 0, 0.5);
    for (j, &want) in [6.0f32, 12.0, 18.0].iter().enumerate() {
        assert!((mat.at(0, j as i64) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_add_row_to_vector() {
    let mut mat = DenseMatrix::new(2, 3);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {
        *mat.at_mut((k / 3) as i64, (k % 3) as i64) = val;
    }
    let mut acc = Vector::new(3);
    for (j, &val) in [10.0f32, 20.0, 30.0].iter().enumerate() {
        acc[j] = val;
    }
    // Only row 0 is added, with unit scale.
    mat.add_row_to_vector(&mut acc, 0, 1.0);
    for (j, &want) in [11.0f32, 22.0, 33.0].iter().enumerate() {
        assert!((acc[j] - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_add_row_to_vector_scaled() {
    // vec + 0.5 * row: [1, 1, 1] + 0.5 * [2, 4, 6] = [2, 3, 4]
    let mut mat = DenseMatrix::new(1, 3);
    let mut acc = Vector::new(3);
    for (j, &cell) in [2.0f32, 4.0, 6.0].iter().enumerate() {
        *mat.at_mut(0, j as i64) = cell;
        acc[j] = 1.0;
    }
    mat.add_row_to_vector(&mut acc, 0, 0.5);
    for (j, &want) in [2.0f32, 3.0, 4.0].iter().enumerate() {
        assert!((acc[j] - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_average_rows_single() {
    // Averaging a single row must reproduce that row.
    let row0 = [1.0f32, 2.0, 3.0];
    let mut mat = DenseMatrix::new(3, 3);
    for (j, &val) in row0.iter().enumerate() {
        *mat.at_mut(0, j as i64) = val;
    }
    let mut avg = Vector::new(3);
    mat.average_rows_to_vector(&mut avg, &[0]);
    for (j, &want) in row0.iter().enumerate() {
        assert!((avg[j] - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_average_rows_multiple() {
    // Rows [1,2,3], [4,5,6], [7,8,9] average to [4,5,6].
    let mut mat = DenseMatrix::new(3, 3);
    let cells = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    for (k, &val) in cells.iter().enumerate() {
        *mat.at_mut((k / 3) as i64, (k % 3) as i64) = val;
    }
    let mut avg = Vector::new(3);
    mat.average_rows_to_vector(&mut avg, &[0, 1, 2]);
    for (j, &want) in [4.0f32, 5.0, 6.0].iter().enumerate() {
        assert!((avg[j] - want).abs() < 1e-5);
    }
}
#[test]
fn test_dense_matrix_average_rows_subset() {
    // Rows 0..3 form an identity-like pattern; row 3 is all ones.
    let mut mat = DenseMatrix::new(4, 3);
    for &(i, j) in &[(0i64, 0i64), (1, 1), (2, 2), (3, 0), (3, 1), (3, 2)] {
        *mat.at_mut(i, j) = 1.0;
    }
    let mut avg = Vector::new(3);
    // Only rows 0 and 3 take part: ([1,0,0] + [1,1,1]) / 2.
    mat.average_rows_to_vector(&mut avg, &[0, 3]);
    for (j, &want) in [1.0f32, 0.5, 0.5].iter().enumerate() {
        assert!((avg[j] - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_average_rows_empty() {
    // With no rows selected the output must be cleared, not left untouched.
    let mat = DenseMatrix::new(3, 3);
    let mut out = Vector::new(3);
    out[0] = 999.0;
    mat.average_rows_to_vector(&mut out, &[]);
    for j in 0..3usize {
        assert_eq!(out[j], 0.0);
    }
}
#[test]
fn test_dense_matrix_l2_norm_row() {
    // Row 0 is the 3-4-5 triangle; row 1 is all zeros.
    let mut mat = DenseMatrix::new(2, 2);
    for (k, &val) in [3.0f32, 4.0, 0.0, 0.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    let norm0 = mat.l2_norm_row(0).unwrap();
    assert!((norm0 - 5.0).abs() < 1e-6);
    let norm1 = mat.l2_norm_row(1).unwrap();
    assert_eq!(norm1, 0.0);
}
#[test]
fn test_dense_matrix_l2_norm_row_nan_detection() {
    // A NaN element must surface as an error rather than a NaN norm.
    let mut mat = DenseMatrix::new(1, 3);
    for (j, &val) in [f32::NAN, 1.0, 2.0].iter().enumerate() {
        *mat.at_mut(0, j as i64) = val;
    }
    match mat.l2_norm_row(0) {
        Err(FastTextError::EncounteredNaN) => {}
        other => panic!("Expected EncounteredNaN, got {:?}", other),
    }
}
#[test]
fn test_dense_matrix_multiply_row() {
    let mut mat = DenseMatrix::new(3, 2);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    // One factor per row.
    let factors = [2.0, 3.0, 0.5];
    mat.multiply_row(&factors, 0, Some(3));
    for (k, &want) in [2.0f32, 4.0, 9.0, 12.0, 2.5, 3.0].iter().enumerate() {
        assert!((mat.at((k / 2) as i64, (k % 2) as i64) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_multiply_row_zero_skip() {
    let mut mat = DenseMatrix::new(2, 2);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    let factors = [0.0, 2.0];
    mat.multiply_row(&factors, 0, Some(2));
    // A zero factor means "skip": row 0 keeps its exact original values.
    assert_eq!(mat.at(0, 0), 1.0);
    assert_eq!(mat.at(0, 1), 2.0);
    // Row 1 is doubled.
    for (j, &want) in [6.0f32, 8.0].iter().enumerate() {
        assert!((mat.at(1, j as i64) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_multiply_row_default_ie() {
    let mut mat = DenseMatrix::new(2, 2);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    let factors = [2.0, 3.0];
    // `None` as the end index means "through the last row".
    mat.multiply_row(&factors, 0, None);
    for (k, &want) in [2.0f32, 4.0, 9.0, 12.0].iter().enumerate() {
        assert!((mat.at((k / 2) as i64, (k % 2) as i64) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_divide_row() {
    let mut mat = DenseMatrix::new(2, 2);
    for (k, &val) in [4.0f32, 6.0, 8.0, 10.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    // One denominator per row.
    let denoms = [2.0, 5.0];
    mat.divide_row(&denoms, 0, Some(2));
    for (k, &want) in [2.0f32, 3.0, 1.6, 2.0].iter().enumerate() {
        assert!((mat.at((k / 2) as i64, (k % 2) as i64) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_divide_row_zero_denom() {
    let mut mat = DenseMatrix::new(2, 2);
    for (k, &val) in [4.0f32, 6.0, 8.0, 10.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    let denoms = [0.0, 2.0];
    mat.divide_row(&denoms, 0, Some(2));
    // A zero denominator means "skip": row 0 keeps its exact original values.
    assert_eq!(mat.at(0, 0), 4.0);
    assert_eq!(mat.at(0, 1), 6.0);
    // Row 1 is halved.
    for (j, &want) in [4.0f32, 5.0].iter().enumerate() {
        assert!((mat.at(1, j as i64) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_divide_row_default_ie() {
    let mut mat = DenseMatrix::new(2, 2);
    for (k, &val) in [4.0f32, 6.0, 8.0, 10.0].iter().enumerate() {
        *mat.at_mut((k / 2) as i64, (k % 2) as i64) = val;
    }
    let denoms = [2.0, 4.0];
    // `None` as the end index means "through the last row".
    mat.divide_row(&denoms, 0, None);
    for (k, &want) in [2.0f32, 3.0, 2.0, 2.5].iter().enumerate() {
        assert!((mat.at((k / 2) as i64, (k % 2) as i64) - want).abs() < 1e-6);
    }
}
// Compares each `Matrix` trait method on `DenseMatrix` against a scalar
// reference computation for several column counts.
#[test]
fn test_dense_matrix_simd_consistency() {
for &dim in &[16, 32, 64, 256, 512] {
let mut m = DenseMatrix::new(3, dim);
let n = dim as usize;
// Deterministic, non-trivial fill for the matrix and the vector.
for i in 0..3 {
for j in 0..n {
*m.at_mut(i as i64, j as i64) = ((i * n + j) as f32) * 0.01;
}
}
let mut v = Vector::new(n);
for j in 0..n {
v[j] = ((n - j) as f32) * 0.01;
}
// dot_row vs. scalar dot product; the tolerance scales with the vector
// length because both sides accumulate `n` terms.
let simd_dot = m.dot_row(&v, 1);
let row1 = m.row(1);
let scalar_dot = dot_scalar(row1, v.data());
let magnitude = simd_dot.abs().max(scalar_dot.abs()).max(1.0);
let tolerance = magnitude * f32::EPSILON * n as f32;
assert!(
(simd_dot - scalar_dot).abs() < tolerance,
"dot_row SIMD vs scalar mismatch for dim={}: SIMD={}, scalar={}",
dim,
simd_dot,
scalar_dot,
);
// add_vector_to_row vs. scalar add, each applied to its own clone.
let mut m_simd = m.clone();
let mut m_scalar = m.clone();
m_simd.add_vector_to_row(&v, 0, 0.5);
let row0_scalar = m_scalar.row_mut(0);
add_vector_scalar(row0_scalar, v.data(), 0.5);
for j in 0..n {
let s = m_simd.at(0, j as i64);
let sc = m_scalar.at(0, j as i64);
let mag = s.abs().max(sc.abs()).max(1.0);
let tol = mag * f32::EPSILON * 4.0;
assert!(
(s - sc).abs() < tol,
"add_vector_to_row mismatch at j={} for dim={}: SIMD={}, scalar={}",
j,
dim,
s,
sc,
);
}
// add_row_to_vector vs. scalar add, starting from identical vectors.
let mut v_simd = Vector::new(n);
let mut v_scalar = Vector::new(n);
for j in 0..n {
v_simd[j] = j as f32 * 0.1;
v_scalar[j] = j as f32 * 0.1;
}
m.add_row_to_vector(&mut v_simd, 2, 0.7);
add_vector_scalar(v_scalar.data_mut(), m.row(2), 0.7);
for j in 0..n {
let s = v_simd[j];
let sc = v_scalar[j];
let mag = s.abs().max(sc.abs()).max(1.0);
let tol = mag * f32::EPSILON * 4.0;
assert!(
(s - sc).abs() < tol,
"add_row_to_vector mismatch at j={} for dim={}: SIMD={}, scalar={}",
j,
dim,
s,
sc,
);
}
// average_rows_to_vector vs. a manual sum of the three rows scaled by 1/3.
let mut v_avg_simd = Vector::new(n);
let mut v_avg_scalar = Vector::new(n);
m.average_rows_to_vector(&mut v_avg_simd, &[0, 1, 2]);
v_avg_scalar.zero();
for i in 0..3 {
add_vector_scalar(v_avg_scalar.data_mut(), m.row(i), 1.0);
}
v_avg_scalar.mul(1.0 / 3.0);
for j in 0..n {
let s = v_avg_simd[j];
let sc = v_avg_scalar[j];
let mag = s.abs().max(sc.abs()).max(1.0);
let tol = mag * f32::EPSILON * n as f32;
assert!(
(s - sc).abs() < tol,
"average_rows_to_vector mismatch at j={} for dim={}: SIMD={}, scalar={}",
j,
dim,
s,
sc,
);
}
}
}
#[test]
fn test_dense_matrix_alignment() {
    // The buffer must start on an ALIGNMENT boundary whatever the shape.
    let shapes: [(i64, i64); 6] = [(1, 1), (1, 16), (4, 64), (10, 100), (100, 256), (8, 512)];
    for &(rows, cols) in shapes.iter() {
        let mat = DenseMatrix::new(rows, cols);
        let addr = mat.data().as_ptr() as usize;
        assert_eq!(
            addr % ALIGNMENT,
            0,
            "DenseMatrix {}x{} is not 64-byte aligned (addr: 0x{:x})",
            rows,
            cols,
            addr,
        );
    }
}
#[test]
#[should_panic(expected = "Row index out of bounds")]
fn test_dense_matrix_dot_row_out_of_bounds() {
    // Row 2 is one past the last valid row of a 2x3 matrix.
    let mat = DenseMatrix::new(2, 3);
    let probe = Vector::new(3);
    mat.dot_row(&probe, 2);
}
#[test]
#[should_panic(expected = "Vector size")]
fn test_dense_matrix_dot_row_size_mismatch() {
    // A 4-element vector cannot be dotted with a 3-column row.
    let mat = DenseMatrix::new(2, 3);
    let too_long = Vector::new(4);
    mat.dot_row(&too_long, 0);
}
#[test]
#[should_panic(expected = "Row index out of bounds")]
fn test_dense_matrix_add_vector_to_row_out_of_bounds() {
    // Row 2 is one past the last valid row of a 2x3 matrix.
    let mut mat = DenseMatrix::new(2, 3);
    let delta = Vector::new(3);
    mat.add_vector_to_row(&delta, 2, 1.0);
}
#[test]
#[should_panic(expected = "Row index out of bounds")]
fn test_dense_matrix_add_row_to_vector_out_of_bounds() {
    // Row 2 is one past the last valid row of a 2x3 matrix.
    let mat = DenseMatrix::new(2, 3);
    let mut acc = Vector::new(3);
    mat.add_row_to_vector(&mut acc, 2, 1.0);
}
#[test]
#[should_panic(expected = "Row index out of bounds")]
fn test_dense_matrix_l2_norm_row_out_of_bounds() {
    // Row 2 is one past the last valid row of a 2x3 matrix.
    let mat = DenseMatrix::new(2, 3);
    let _unused = mat.l2_norm_row(2);
}
#[test]
fn test_dense_matrix_save_load_roundtrip() {
    let mut original = DenseMatrix::new(3, 4);
    for flat in 0..12i64 {
        *original.at_mut(flat / 4, flat % 4) = flat as f32 * 1.5;
    }
    let mut bytes = Vec::new();
    original.save(&mut bytes).unwrap();
    // Two i64 header fields followed by 3 * 4 four-byte floats.
    assert_eq!(bytes.len(), 8 + 8 + 3 * 4 * 4);
    let mut cursor = Cursor::new(&bytes);
    let restored = DenseMatrix::load(&mut cursor).unwrap();
    assert_eq!(restored.rows(), 3);
    assert_eq!(restored.cols(), 4);
    for flat in 0..12i64 {
        let (i, j) = (flat / 4, flat % 4);
        assert_eq!(restored.at(i, j), original.at(i, j), "Mismatch at ({}, {})", i, j,);
    }
}
#[test]
fn test_dense_matrix_save_load_empty() {
    // An empty matrix must survive a save/load cycle with its shape intact.
    let empty = DenseMatrix::new(0, 0);
    let mut bytes = Vec::new();
    empty.save(&mut bytes).unwrap();
    let mut cursor = Cursor::new(&bytes);
    let restored = DenseMatrix::load(&mut cursor).unwrap();
    assert_eq!(restored.rows(), 0);
    assert_eq!(restored.cols(), 0);
}
#[test]
fn test_dense_matrix_save_load_header_format() {
    let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut mat = DenseMatrix::new(2, 3);
    for (k, &val) in values.iter().enumerate() {
        *mat.at_mut((k / 3) as i64, (k % 3) as i64) = val;
    }
    let mut bytes = Vec::new();
    mat.save(&mut bytes).unwrap();
    // Decode the stream by hand: rows, cols, then the elements in row-major order.
    let mut cursor = Cursor::new(&bytes);
    let rows = utils::read_i64(&mut cursor).unwrap();
    let cols = utils::read_i64(&mut cursor).unwrap();
    assert_eq!(rows, 2);
    assert_eq!(cols, 3);
    for &want in values.iter() {
        let got = utils::read_f32(&mut cursor).unwrap();
        assert_eq!(got, want);
    }
}
#[test]
fn test_dense_matrix_load_truncated() {
    // Four bytes are not even enough for the first i64 header field.
    let short_input = [0u8; 4].to_vec();
    let mut cursor = Cursor::new(&short_input);
    let outcome = DenseMatrix::load(&mut cursor);
    assert!(outcome.is_err());
}
#[test]
fn test_dense_matrix_uniform() {
    let mut mat = DenseMatrix::new(10, 20);
    mat.uniform(0.5, 0);
    let values = mat.data();
    // Every sample must fall inside [-a, a].
    for &val in values.iter() {
        assert!(
            val >= -0.5 && val <= 0.5,
            "Uniform value {} out of range [-0.5, 0.5]",
            val,
        );
    }
    // The initialisation must actually have written something.
    let wrote_something = values.iter().any(|&v| v != 0.0);
    assert!(wrote_something, "All values are zero after uniform init");
}
#[test]
fn test_dense_matrix_uniform_deterministic() {
    // Two matrices initialised with the same seed must be identical.
    let mut first = DenseMatrix::new(10, 20);
    first.uniform(0.5, 42);
    let mut second = DenseMatrix::new(10, 20);
    second.uniform(0.5, 42);
    assert_eq!(first.data(), second.data(), "Same seed should produce same values");
}
#[test]
fn test_dense_matrix_uniform_different_seeds() {
    // Neighbouring seeds must not yield the same contents.
    let mut first = DenseMatrix::new(10, 20);
    first.uniform(0.5, 42);
    let mut second = DenseMatrix::new(10, 20);
    second.uniform(0.5, 43);
    assert_ne!(
        first.data(),
        second.data(),
        "Different seeds should produce different values"
    );
}
#[test]
fn test_dense_matrix_clone() {
    let mut source = DenseMatrix::new(2, 3);
    *source.at_mut(0, 0) = 1.0;
    *source.at_mut(1, 2) = 5.0;
    let copy = source.clone();
    assert_eq!(copy.rows(), 2);
    assert_eq!(copy.cols(), 3);
    assert_eq!(copy.at(0, 0), 1.0);
    assert_eq!(copy.at(1, 2), 5.0);
    // Mutating the source afterwards must not leak into the copy.
    *source.at_mut(0, 0) = 99.0;
    assert_eq!(copy.at(0, 0), 1.0);
}
#[test]
fn test_dense_matrix_row_access() {
    // Each row slice must expose exactly that row's elements, in order.
    let mut mat = DenseMatrix::new(2, 3);
    for (k, &val) in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0].iter().enumerate() {
        *mat.at_mut((k / 3) as i64, (k % 3) as i64) = val;
    }
    assert_eq!(mat.row(0), &[1.0, 2.0, 3.0]);
    assert_eq!(mat.row(1), &[4.0, 5.0, 6.0]);
}
#[test]
fn test_dense_matrix_multiply_row_with_offset() {
    // Row i starts as [(i + 1), (i + 1) * 10].
    let mut mat = DenseMatrix::new(4, 2);
    for i in 0..4i64 {
        let base = (i + 1) as f32;
        *mat.at_mut(i, 0) = base;
        *mat.at_mut(i, 1) = base * 10.0;
    }
    // Only rows 1 and 2 are scaled; the factors are indexed from the range start.
    let factors = [2.0, 3.0];
    mat.multiply_row(&factors, 1, Some(3));
    // Rows outside the range keep their exact values.
    for &(i, j, want) in &[(0i64, 0i64, 1.0f32), (0, 1, 10.0), (3, 0, 4.0), (3, 1, 40.0)] {
        assert_eq!(mat.at(i, j), want);
    }
    for &(i, j, want) in &[(1i64, 0i64, 4.0f32), (1, 1, 40.0), (2, 0, 9.0), (2, 1, 90.0)] {
        assert!((mat.at(i, j) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_divide_row_with_offset() {
    // Row i starts as [(i + 1) * 10, (i + 1) * 20].
    let mut mat = DenseMatrix::new(4, 2);
    for i in 0..4i64 {
        let base = (i + 1) as f32;
        *mat.at_mut(i, 0) = base * 10.0;
        *mat.at_mut(i, 1) = base * 20.0;
    }
    // Only rows 2 and 3 are divided; the denominators are indexed from the range start.
    let denoms = [5.0, 10.0];
    mat.divide_row(&denoms, 2, Some(4));
    // Rows outside the range keep their exact values.
    for &(i, j, want) in &[(0i64, 0i64, 10.0f32), (0, 1, 20.0), (1, 0, 20.0), (1, 1, 40.0)] {
        assert_eq!(mat.at(i, j), want);
    }
    for &(i, j, want) in &[(2i64, 0i64, 6.0f32), (2, 1, 12.0), (3, 0, 4.0), (3, 1, 8.0)] {
        assert!((mat.at(i, j) - want).abs() < 1e-6);
    }
}
#[test]
fn test_dense_matrix_dot_row_nan_in_vector() {
    // A NaN in the vector operand must propagate into the dot product.
    let mut mat = DenseMatrix::new(1, 3);
    let mut vec = Vector::new(3);
    let pairs = [(1.0f32, 1.0f32), (2.0, f32::NAN), (3.0, 1.0)];
    for (j, &(cell, weight)) in pairs.iter().enumerate() {
        *mat.at_mut(0, j as i64) = cell;
        vec[j] = weight;
    }
    assert!(mat.dot_row(&vec, 0).is_nan());
}
#[test]
fn test_dense_matrix_large_dot_row() {
    // All-ones row dotted with an all-ones vector of length 100 gives 100.
    let dim: i64 = 100;
    let width = dim as usize;
    let mut mat = DenseMatrix::new(10, dim);
    for i in 0..10 {
        for j in 0..dim {
            *mat.at_mut(i, j) = 1.0;
        }
    }
    let mut ones = Vector::new(width);
    for j in 0..width {
        ones[j] = 1.0;
    }
    let dot = mat.dot_row(&ones, 0);
    assert!((dot - 100.0).abs() < 0.01);
}
// Calls the `crate::simd` `*_impl` entry points directly and compares them
// with their scalar counterparts, including small and odd dimensions.
#[test]
fn test_dense_matrix_simd_dispatch_consistency() {
for &dim in &[1_i64, 3, 7, 15, 16, 32, 64, 100, 256, 512] {
let mut m = DenseMatrix::new(3, dim);
let n = dim as usize;
// Deterministic, non-trivial fill for the matrix and the vector.
for i in 0..3 {
for j in 0..n {
*m.at_mut(i as i64, j as i64) = ((i * n + j) as f32) * 0.01;
}
}
let mut v = Vector::new(n);
for j in 0..n {
v[j] = ((n - j) as f32) * 0.01;
}
// dot_impl vs. dot_scalar; the tolerance scales with the vector length.
let row1 = m.row(1);
let simd_dot = crate::simd::dot_impl(row1, v.data());
let scalar_dot = dot_scalar(row1, v.data());
let tolerance = scalar_dot.abs().max(simd_dot.abs()).max(1.0) * f32::EPSILON * n as f32;
assert!(
(simd_dot - scalar_dot).abs() < tolerance,
"dot mismatch for dim={}: simd={}, scalar={}",
dim,
simd_dot,
scalar_dot,
);
// add_vector_impl vs. add_vector_scalar on identical plain buffers.
let mut dest_simd: Vec<f32> = (0..n).map(|j| j as f32 * 0.5).collect();
let mut dest_scalar: Vec<f32> = (0..n).map(|j| j as f32 * 0.5).collect();
let src: Vec<f32> = (0..n).map(|j| j as f32 * 0.1).collect();
crate::simd::add_vector_impl(&mut dest_simd, &src, 2.0);
add_vector_scalar(&mut dest_scalar, &src, 2.0);
for j in 0..n {
let mag = dest_simd[j].abs().max(dest_scalar[j].abs()).max(1.0);
let tol = mag * f32::EPSILON * 4.0;
assert!(
(dest_simd[j] - dest_scalar[j]).abs() < tol,
"add_vector mismatch at j={} for dim={}: simd={}, scalar={}",
j,
dim,
dest_simd[j],
dest_scalar[j],
);
}
// average_rows_impl vs. average_rows_scalar over all three rows.
let mut v_avg_simd = Vector::new(n);
let mut v_avg_scalar = Vector::new(n);
crate::simd::average_rows_impl(&mut v_avg_simd, &[0, 1, 2], &m);
crate::simd::average_rows_scalar(&mut v_avg_scalar, &[0, 1, 2], &m);
for j in 0..n {
let mag = v_avg_simd[j].abs().max(v_avg_scalar[j].abs()).max(1.0);
let tol = mag * f32::EPSILON * n as f32;
assert!(
(v_avg_simd[j] - v_avg_scalar[j]).abs() < tol,
"average_rows mismatch at j={} for dim={}: simd={}, scalar={}",
j,
dim,
v_avg_simd[j],
v_avg_scalar[j],
);
}
}
}
}