/// Applies the rank-1 update `matrix += alpha * x * yᵀ` in place.
///
/// `matrix` is interpreted as a row-major `rows × cols` matrix, so element
/// `(i, j)` lives at index `i * cols + j`. Every element is incremented by
/// `alpha * x[i] * y[j]`; existing contents are accumulated into, not
/// overwritten.
///
/// # Panics
///
/// Panics if `matrix.len() != rows * cols`, `x.len() != rows`, or
/// `y.len() != cols`. These checks are active in every build profile.
#[inline]
pub fn rank1_outer_update(
    matrix: &mut [f64],
    rows: usize,
    cols: usize,
    alpha: f64,
    x: &[f64],
    y: &[f64],
) {
    // Real asserts, not debug_asserts: this is a safe fn, so mismatched shapes
    // must be rejected in release builds as well. They cost O(1) per call.
    assert_eq!(matrix.len(), rows * cols, "matrix len != rows*cols");
    assert_eq!(x.len(), rows, "x len != rows");
    assert_eq!(y.len(), cols, "y len != cols");
    for (i, &xi) in x.iter().enumerate() {
        // Fold alpha into the row scale once so the inner loop is a single
        // multiply-add per element.
        let alpha_xi = alpha * xi;
        let row_start = i * cols;
        // One bounds check per row; the zipped inner loop needs none.
        let row = &mut matrix[row_start..row_start + cols];
        for (m, &yj) in row.iter_mut().zip(y) {
            *m += alpha_xi * yj;
        }
    }
}
/// Applies the unscaled rank-1 update `matrix += x * yᵀ` in place.
///
/// This is the same update as `rank1_outer_update` with `alpha` fixed at
/// `1.0`. `matrix` is interpreted as a row-major `rows × cols` matrix, so
/// element `(i, j)` lives at index `i * cols + j` and is incremented by
/// `x[i] * y[j]`.
///
/// # Panics
///
/// Panics if `matrix.len() != rows * cols`, `x.len() != rows`, or
/// `y.len() != cols`. These checks are active in every build profile.
#[inline]
pub fn rank1_outer_update_inplace(
    matrix: &mut [f64],
    rows: usize,
    cols: usize,
    x: &[f64],
    y: &[f64],
) {
    // Real asserts, not debug_asserts: this is a safe fn, so mismatched shapes
    // must be rejected in release builds as well.
    assert_eq!(matrix.len(), rows * cols, "matrix len != rows*cols");
    assert_eq!(x.len(), rows, "x len != rows");
    assert_eq!(y.len(), cols, "y len != cols");
    for (i, &xi) in x.iter().enumerate() {
        let row_start = i * cols;
        // One bounds check per row; the zipped inner loop needs none.
        let row = &mut matrix[row_start..row_start + cols];
        for (m, &yj) in row.iter_mut().zip(y) {
            *m += xi * yj;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Straightforward index-based reference for `matrix += alpha * x * yᵀ`
    /// on a row-major `rows × cols` matrix. It is deliberately written with
    /// plain index loops so it shares no structure with the code under test.
    fn naive_rank1(
        matrix: &mut [f64],
        rows: usize,
        cols: usize,
        alpha: f64,
        x: &[f64],
        y: &[f64],
    ) {
        for i in 0..rows {
            for j in 0..cols {
                matrix[i * cols + j] += alpha * x[i] * y[j];
            }
        }
    }

    #[test]
    fn rank1_matches_naive_reference() {
        let rows = 4;
        let cols = 3;
        let alpha = 2.5;
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [0.5, -1.0, 0.25];
        let mut m_fast = [0.0f64; 12];
        let mut m_naive = [0.0f64; 12];
        rank1_outer_update(&mut m_fast, rows, cols, alpha, &x, &y);
        naive_rank1(&mut m_naive, rows, cols, alpha, &x, &y);
        for (a, b) in m_fast.iter().zip(m_naive.iter()) {
            assert!((a - b).abs() < 1e-12, "mismatch: fast={a}, naive={b}");
        }
    }

    // The other reference test starts from zeros; this one checks that the
    // update accumulates into existing (non-zero) contents.
    #[test]
    fn rank1_accumulates_into_nonzero_matrix() {
        let rows = 2;
        let cols = 3;
        let alpha = -0.75;
        let x = [3.0, -1.5];
        let y = [2.0, 0.0, -4.0];
        let initial = [1.0, -2.0, 3.5, 0.25, 8.0, -6.0];
        let mut m_fast = initial;
        let mut m_naive = initial;
        rank1_outer_update(&mut m_fast, rows, cols, alpha, &x, &y);
        naive_rank1(&mut m_naive, rows, cols, alpha, &x, &y);
        for (a, b) in m_fast.iter().zip(m_naive.iter()) {
            assert!((a - b).abs() < 1e-12, "mismatch: fast={a}, naive={b}");
        }
    }

    #[test]
    fn rank1_inplace_matches_alpha_one() {
        let rows = 3;
        let cols = 5;
        let x = [1.0, -2.0, 0.5];
        let y = [0.1, 0.2, 0.3, 0.4, 0.5];
        let mut m_alpha = [0.0f64; 15];
        let mut m_inplace = [0.0f64; 15];
        rank1_outer_update(&mut m_alpha, rows, cols, 1.0, &x, &y);
        rank1_outer_update_inplace(&mut m_inplace, rows, cols, &x, &y);
        for (a, b) in m_alpha.iter().zip(m_inplace.iter()) {
            assert!((a - b).abs() < 1e-15, "alpha=1 and inplace must match");
        }
    }

    #[test]
    fn rank1_accumulates_correctly() {
        let rows = 2;
        let cols = 2;
        let mut m = [0.0f64; 4];
        rank1_outer_update_inplace(&mut m, rows, cols, &[1.0, 2.0], &[1.0, 2.0]);
        rank1_outer_update_inplace(&mut m, rows, cols, &[2.0, 4.0], &[2.0, 4.0]);
        assert!(
            (m[0] - 5.0).abs() < 1e-12,
            "m[0][0] expected 5, got {}",
            m[0]
        );
        assert!(
            (m[1] - 10.0).abs() < 1e-12,
            "m[0][1] expected 10, got {}",
            m[1]
        );
        assert!(
            (m[2] - 10.0).abs() < 1e-12,
            "m[1][0] expected 10, got {}",
            m[2]
        );
        assert!(
            (m[3] - 20.0).abs() < 1e-12,
            "m[1][1] expected 20, got {}",
            m[3]
        );
    }

    #[test]
    fn rank1_with_zero_alpha_leaves_matrix_unchanged() {
        let rows = 2;
        let cols = 2;
        let initial = [1.0, 2.0, 3.0, 4.0];
        let mut m = initial;
        rank1_outer_update(&mut m, rows, cols, 0.0, &[100.0, 200.0], &[300.0, 400.0]);
        for (a, &b) in m.iter().zip(initial.iter()) {
            assert!(
                (a - b).abs() < 1e-15,
                "zero-alpha update should not change matrix"
            );
        }
    }

    #[test]
    fn rank1_1x1_matrix() {
        let mut m = [0.0f64; 1];
        rank1_outer_update(&mut m, 1, 1, 3.0, &[4.0], &[5.0]);
        assert!((m[0] - 60.0).abs() < 1e-12, "1x1: 3*4*5 = 60");
    }

    // A zero row or column count means an empty matrix; both kernels must
    // accept it and touch nothing.
    #[test]
    fn rank1_empty_dimensions_are_noops() {
        let mut empty: [f64; 0] = [];
        rank1_outer_update(&mut empty, 0, 4, 1.5, &[], &[1.0, 2.0, 3.0, 4.0]);
        rank1_outer_update(&mut empty, 3, 0, 1.5, &[1.0, 2.0, 3.0], &[]);
        rank1_outer_update_inplace(&mut empty, 0, 0, &[], &[]);
        rank1_outer_update_inplace(&mut empty, 2, 0, &[1.0, 2.0], &[]);
        assert!(empty.is_empty());
    }
}