use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{Distribution, Normal, Uniform};
use std::iter::Sum;
/// Scalar bound shared by every sketching routine in this module.
///
/// Bundles floating-point arithmetic (`Float`, `NumAssign`), iterator summation
/// (`Sum`), scalar-with-array operations (`ScalarOperand`) and thread-safety
/// (`Send + Sync + 'static`). The blanket impl that follows makes every type
/// meeting these bounds a `SketchFloat` automatically.
pub trait SketchFloat: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static {}

impl<T: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static> SketchFloat for T {}
/// Dense product `C = A · B` computed with a naive triple loop.
///
/// The loop order is row-of-`A` → inner index → row-of-`B`, so the innermost
/// loop streams one row of `B` into one row of `C`. Zero entries of `A` are
/// skipped, which is cheap when `A` is sparse.
///
/// # Errors
/// `LinalgError::ShapeError` when `a.ncols() != b.nrows()`.
fn matmul_nn<F: SketchFloat>(a: &Array2<F>, b: &Array2<F>) -> LinalgResult<Array2<F>> {
    let (rows, inner) = a.dim();
    let (b_rows, cols) = b.dim();
    if inner != b_rows {
        return Err(LinalgError::ShapeError(format!(
            "sketch matmul: inner dims {} vs {}",
            inner, b_rows
        )));
    }

    let mut out = Array2::<F>::zeros((rows, cols));
    for r in 0..rows {
        let mut out_row = out.row_mut(r);
        for (t, &coef) in a.row(r).iter().enumerate() {
            if coef == F::zero() {
                continue;
            }
            for (dst, &bv) in out_row.iter_mut().zip(b.row(t).iter()) {
                *dst += coef * bv;
            }
        }
    }
    Ok(out)
}
/// Dense product `C = Aᵀ · B` without materialising `Aᵀ`.
///
/// `a` is `m × ka` and `b` is `m × n`; the result is `ka × n`. Row `l` of `A`
/// and row `l` of `B` are consumed together as an outer-product update, and
/// zero entries of `A` are skipped.
///
/// # Errors
/// `LinalgError::ShapeError` when `a.nrows() != b.nrows()`.
fn matmul_transpose_a<F: SketchFloat>(a: &Array2<F>, b: &Array2<F>) -> LinalgResult<Array2<F>> {
    let (shared, out_rows) = a.dim();
    let (b_rows, out_cols) = b.dim();
    if shared != b_rows {
        return Err(LinalgError::ShapeError(format!(
            "sketch A^T B: dims mismatch {} vs {}",
            shared, b_rows
        )));
    }

    let mut out = Array2::<F>::zeros((out_rows, out_cols));
    for (a_row, b_row) in a.outer_iter().zip(b.outer_iter()) {
        for (i, &coef) in a_row.iter().enumerate() {
            if coef == F::zero() {
                continue;
            }
            for (dst, &bv) in out.row_mut(i).iter_mut().zip(b_row.iter()) {
                *dst += coef * bv;
            }
        }
    }
    Ok(out)
}
/// Gaussian sketch `S · A`, where `S` is `k × m` with i.i.d. `N(0, 1/k)` entries.
///
/// `S` is never stored. Each entry `S[i, l]` is drawn in row-major order over
/// `S` and immediately folded into row `i` of the result. Draws that come out
/// exactly zero are skipped.
///
/// # Errors
/// - `InvalidInputError` unless `1 <= k <= a.nrows()`.
/// - `ComputationError` if the `1/sqrt(k)` scale cannot be represented in `F`
///   or the normal distribution cannot be constructed.
pub fn gaussian_sketch<F: SketchFloat>(
    a: &ArrayView2<F>,
    k: usize,
    rng: &mut impl Rng,
) -> LinalgResult<Array2<F>> {
    let (rows, cols) = a.dim();
    if !(1..=rows).contains(&k) {
        return Err(LinalgError::InvalidInputError(format!(
            "gaussian_sketch: k={} must be in [1, {}]",
            k, rows
        )));
    }

    let scale = F::from(1.0 / (k as f64).sqrt())
        .ok_or_else(|| LinalgError::ComputationError("Cannot convert scale".into()))?;
    let normal = Normal::new(0.0_f64, 1.0)
        .map_err(|e| LinalgError::ComputationError(format!("Normal dist: {e}")))?;

    let mut sketch = Array2::<F>::zeros((k, cols));
    for mut out_row in sketch.outer_iter_mut() {
        for a_row in a.outer_iter() {
            let weight = F::from(normal.sample(rng)).unwrap_or(F::zero()) * scale;
            if weight != F::zero() {
                for (dst, &x) in out_row.iter_mut().zip(a_row.iter()) {
                    *dst += weight * x;
                }
            }
        }
    }
    Ok(sketch)
}
/// Subsampled randomized Hadamard transform (SRHT) sketch of `A`.
///
/// Steps:
/// 1. Zero-pad `A` to `M = next_power_of_two(m)` rows.
/// 2. Flip the sign of each original row at random (`D · A`).
/// 3. Apply the unnormalised Walsh–Hadamard transform and scale by `1/sqrt(M)`.
/// 4. Keep `k` distinct rows chosen uniformly at random, each scaled by `sqrt(M/k)`.
///
/// # Errors
/// - `InvalidInputError` unless `1 <= k <= a.nrows()`.
/// - `ComputationError` if a distribution or scale factor cannot be built.
pub fn subsampled_rht<F: SketchFloat>(
    a: &ArrayView2<F>,
    k: usize,
    rng: &mut impl Rng,
) -> LinalgResult<Array2<F>> {
    let (rows, cols) = a.dim();
    if k == 0 || k > rows {
        return Err(LinalgError::InvalidInputError(format!(
            "subsampled_rht: k={} must be in [1, {}]",
            k, rows
        )));
    }

    let padded_rows = next_power_of_two(rows);
    let mut work = Array2::<F>::zeros((padded_rows, cols));
    let coin = Uniform::new(0u8, 2)
        .map_err(|e| LinalgError::ComputationError(format!("Uniform dist: {e}")))?;

    // D · A: one random sign per original row; padding rows stay zero.
    for (mut dst, src) in work.outer_iter_mut().zip(a.outer_iter()) {
        let sign = if coin.sample(rng) == 0 {
            F::one()
        } else {
            -F::one()
        };
        for (d, &x) in dst.iter_mut().zip(src.iter()) {
            *d = sign * x;
        }
    }

    hadamard_transform_rows(&mut work, padded_rows, cols);
    let norm = F::from(1.0 / (padded_rows as f64).sqrt())
        .ok_or_else(|| LinalgError::ComputationError("scale convert".into()))?;
    work.mapv_inplace(|v| v * norm);

    let picked = sample_without_replacement(padded_rows, k, rng);
    let rescale = F::from((padded_rows as f64 / k as f64).sqrt())
        .ok_or_else(|| LinalgError::ComputationError("out_scale convert".into()))?;

    let mut sketch = Array2::<F>::zeros((k, cols));
    for (mut out_row, &src) in sketch.outer_iter_mut().zip(picked.iter()) {
        for (d, &x) in out_row.iter_mut().zip(work.row(src).iter()) {
            *d = rescale * x;
        }
    }
    Ok(sketch)
}
/// In-place, unnormalised fast Walsh–Hadamard transform applied to each of the
/// first `n` columns of `a`, acting on its first `m` rows.
///
/// Butterflies of width `half = 1, 2, 4, …` replace `(top, bottom)` with
/// `(top + bottom, top - bottom)`. Applying the transform twice multiplies the
/// data by `m`.
///
/// Callers are expected to pass a power-of-two `m` no larger than `a.nrows()`.
/// Otherwise a butterfly can index past the last row and panic.
fn hadamard_transform_rows<F: SketchFloat>(a: &mut Array2<F>, m: usize, n: usize) {
    let mut half = 1usize;
    while half < m {
        for block in (0..m).step_by(2 * half) {
            for col in 0..n {
                for off in block..block + half {
                    let top = a[[off, col]];
                    let bottom = a[[off + half, col]];
                    a[[off, col]] = top + bottom;
                    a[[off + half, col]] = top - bottom;
                }
            }
        }
        half <<= 1;
    }
}
/// Smallest power of two that is `>= n`; `next_power_of_two(0)` returns 1.
///
/// Delegates to the standard library, which already maps 0 to 1.
///
/// # Panics
/// Panics if the result does not fit in `usize`. This is unreachable for any
/// real matrix dimension, and failing loudly is preferable to spinning forever
/// when a shift wraps to zero.
fn next_power_of_two(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_power_of_two: result overflows usize")
}
/// Draws `min(k, m)` distinct indices from `0..m` uniformly at random.
///
/// Uses a partial Fisher–Yates shuffle. Slot `i` is swapped with a random
/// position in `i..m`, and the first `min(k, m)` slots are returned in the
/// order they were filled. This consumes exactly one `random_range` call per
/// returned index.
fn sample_without_replacement(m: usize, k: usize, rng: &mut impl Rng) -> Vec<usize> {
    let take = k.min(m);
    let mut pool: Vec<usize> = (0..m).collect();
    for slot in 0..take {
        let remaining = m - slot;
        let pick = slot + rng.random_range(0..remaining);
        pool.swap(slot, pick);
    }
    pool.truncate(take);
    pool
}
/// Sparse sign sketch `S · A`, where each column of `S` (`k × m`) has exactly
/// `s` non-zeros equal to `±1/sqrt(s)`.
///
/// For every row of `A`, `s` distinct target rows of the sketch are chosen
/// uniformly. The row is then added to each target with an independent random
/// sign. `s` is clamped to `[1, k]`.
///
/// # Errors
/// - `InvalidInputError` when `k == 0`.
/// - `ComputationError` if the scale or the sign distribution cannot be built.
pub fn sparse_sign_sketch<F: SketchFloat>(
    a: &ArrayView2<F>,
    k: usize,
    s: usize,
    rng: &mut impl Rng,
) -> LinalgResult<Array2<F>> {
    if k == 0 {
        return Err(LinalgError::InvalidInputError(
            "sparse_sign_sketch: k must be >= 1".into(),
        ));
    }

    let nnz_per_col = s.clamp(1, k);
    let magnitude = F::from(1.0 / (nnz_per_col as f64).sqrt())
        .ok_or_else(|| LinalgError::ComputationError("scale convert".into()))?;
    let coin = Uniform::new(0u8, 2)
        .map_err(|e| LinalgError::ComputationError(format!("Uniform sign dist: {e}")))?;

    let mut sketch = Array2::<F>::zeros((k, a.ncols()));
    for a_row in a.outer_iter() {
        for target in sample_without_replacement(k, nnz_per_col, rng) {
            let signed = if coin.sample(rng) == 0 {
                magnitude
            } else {
                -magnitude
            };
            for (d, &x) in sketch.row_mut(target).iter_mut().zip(a_row.iter()) {
                *d += signed * x;
            }
        }
    }
    Ok(sketch)
}
/// Randomized approximation of `A · Bᵀ` through a shared Gaussian sketch.
///
/// `a` is `m × p` and `b` is `n × p`; the result is `m × n`. A matrix `S`
/// (`k × p`) with i.i.d. `N(0, 1/k)` entries is drawn in row-major order, and
/// the result is `(S·Aᵀ)ᵀ (S·Bᵀ) = A·Sᵀ·S·Bᵀ`. Because `E[SᵀS] = I`, this is an
/// unbiased estimate of `A·Bᵀ`.
///
/// # Errors
/// - `ShapeError` when `a.ncols() != b.ncols()`.
/// - `InvalidInputError` when `k == 0`.
/// - `ComputationError` if the distribution or scale factor cannot be built.
pub fn sketch_multiply<F: SketchFloat>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    k: usize,
    rng: &mut impl Rng,
) -> LinalgResult<Array2<F>> {
    /// Returns `S·Xᵀ` (`k × x.nrows()`): column `i` is the sketch of row `i` of `x`.
    /// Zero entries of `x` are skipped.
    fn sketch_rows<F: SketchFloat>(s: &Array2<F>, x: &ArrayView2<F>) -> Array2<F> {
        let k = s.nrows();
        let mut out = Array2::<F>::zeros((k, x.nrows()));
        for (i, x_row) in x.outer_iter().enumerate() {
            for (l, &x_il) in x_row.iter().enumerate() {
                if x_il == F::zero() {
                    continue;
                }
                for j in 0..k {
                    out[[j, i]] += x_il * s[[j, l]];
                }
            }
        }
        out
    }

    let p = a.ncols();
    let pb = b.ncols();
    if p != pb {
        return Err(LinalgError::ShapeError(format!(
            "sketch_multiply: inner dimensions {} vs {} must match",
            p, pb
        )));
    }
    // Consistent with the other sketching routines. k = 0 would otherwise give
    // an infinite scale and a meaningless all-zero "approximation".
    if k == 0 {
        return Err(LinalgError::InvalidInputError(
            "sketch_multiply: k must be >= 1".into(),
        ));
    }

    let normal = Normal::new(0.0_f64, 1.0)
        .map_err(|e| LinalgError::ComputationError(format!("Normal dist: {e}")))?;
    let inv_sqrt_k = F::from(1.0 / (k as f64).sqrt())
        .ok_or_else(|| LinalgError::ComputationError("inv_sqrt_k convert".into()))?;
    let mut s = Array2::<F>::zeros((k, p));
    for i in 0..k {
        for j in 0..p {
            s[[i, j]] = F::from(normal.sample(rng)).unwrap_or(F::zero()) * inv_sqrt_k;
        }
    }

    let sa_t = sketch_rows(&s, a); // k × m
    let sb_t = sketch_rows(&s, b); // k × n
    // (S·Aᵀ)ᵀ · (S·Bᵀ): both operands have k rows, so the shapes always agree.
    // The previous version passed (A·Sᵀ)ᵀ and B·Sᵀ here. That computed
    // A·Sᵀ·B·Sᵀ whenever b.nrows() == k, instead of A·Sᵀ·S·Bᵀ.
    matmul_transpose_a(&sa_t, &sb_t)
}
/// Randomized approximation of `A · B` through a Gaussian sketch of the inner
/// dimension.
///
/// `a` is `m × p` and `b` is `p × n`; the result is `m × n`. A matrix `S`
/// (`k × p`) with i.i.d. `N(0, 1/k)` entries is drawn in row-major order, and
/// the result is `(A·Sᵀ)(S·B)`. Because `E[SᵀS] = I`, this is an unbiased
/// estimate of `A·B`.
///
/// # Errors
/// - `ShapeError` when `a.ncols() != b.nrows()`.
/// - `InvalidInputError` when `k == 0`.
/// - `ComputationError` if the distribution or scale factor cannot be built.
pub fn sketch_multiply_direct<F: SketchFloat>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    k: usize,
    rng: &mut impl Rng,
) -> LinalgResult<Array2<F>> {
    let (m, p) = (a.nrows(), a.ncols());
    let (pb, n) = (b.nrows(), b.ncols());
    if p != pb {
        return Err(LinalgError::ShapeError(format!(
            "sketch_multiply_direct: p={} vs pb={}",
            p, pb
        )));
    }
    // Consistent with the other sketching routines. k = 0 would otherwise give
    // an infinite scale and a meaningless all-zero "approximation".
    if k == 0 {
        return Err(LinalgError::InvalidInputError(
            "sketch_multiply_direct: k must be >= 1".into(),
        ));
    }

    let normal = Normal::new(0.0_f64, 1.0)
        .map_err(|e| LinalgError::ComputationError(format!("Normal dist: {e}")))?;
    let inv_sqrt_k = F::from(1.0 / (k as f64).sqrt())
        .ok_or_else(|| LinalgError::ComputationError("inv_sqrt_k convert".into()))?;
    let mut s_mat = Array2::<F>::zeros((k, p));
    for i in 0..k {
        for j in 0..p {
            s_mat[[i, j]] = F::from(normal.sample(rng)).unwrap_or(F::zero()) * inv_sqrt_k;
        }
    }

    // A · Sᵀ (m × k), skipping zero entries of A.
    let mut a_sk = Array2::<F>::zeros((m, k));
    for i in 0..m {
        for l in 0..p {
            let a_il = a[[i, l]];
            if a_il == F::zero() {
                continue;
            }
            for j in 0..k {
                a_sk[[i, j]] += a_il * s_mat[[j, l]];
            }
        }
    }

    // S · B (k × n), skipping zero entries of S.
    let mut s_b = Array2::<F>::zeros((k, n));
    for i in 0..k {
        for l in 0..p {
            let s_il = s_mat[[i, l]];
            if s_il == F::zero() {
                continue;
            }
            for j in 0..n {
                s_b[[i, j]] += s_il * b[[l, j]];
            }
        }
    }

    matmul_nn(&a_sk, &s_b)
}
/// Squared Euclidean norm of every row of `q`.
///
/// When `q` holds an orthonormal basis of a matrix's column space (for example,
/// the thin `Q` of a QR factorisation), these values are that matrix's
/// statistical leverage scores. This function does not check orthonormality.
pub fn leverage_scores<F: SketchFloat>(q: &ArrayView2<F>) -> Array1<F> {
    q.outer_iter()
        .map(|row| row.iter().fold(F::zero(), |acc, &x| acc + x * x))
        .collect()
}
/// Draws `k` row indices i.i.d. with replacement, with probability proportional
/// to `|scores[i]|`.
///
/// Returns an empty vector when `scores` is empty or `k == 0`. If the total
/// weight is zero, NaN or infinite, the function falls back to uniform
/// sampling. Indices whose weight is zero are never returned, including when
/// `u == 0` or when floating-point rounding leaves the last CDF value just
/// below 1.
pub fn sample_by_leverage<F: SketchFloat>(
    scores: &Array1<F>,
    k: usize,
    rng: &mut impl Rng,
) -> Vec<usize> {
    let m = scores.len();
    if m == 0 || k == 0 {
        return Vec::new();
    }

    // Non-negative weights in f64; values that cannot be converted count as zero.
    let raw: Vec<f64> = scores
        .iter()
        .map(|s| s.abs().to_f64().unwrap_or(0.0))
        .collect();
    let raw_total: f64 = raw.iter().sum();
    let (weights, total) = if raw_total.is_finite() && raw_total > 0.0 {
        (raw, raw_total)
    } else {
        (vec![1.0; m], m as f64)
    };

    // cdf[i] = P(index <= i). Zero-weight entries repeat the preceding value,
    // so the strict search below can never land on them.
    let cdf: Vec<f64> = weights
        .iter()
        .scan(0.0_f64, |acc, &w| {
            *acc += w / total;
            Some(*acc)
        })
        .collect();
    // Draws beyond cdf[m - 1] (a rounding artefact) map to the last positive-weight index.
    let last_positive = weights.iter().rposition(|&w| w > 0.0).unwrap_or(m - 1);

    let uniform = Uniform::new(0.0_f64, 1.0)
        .expect("Uniform::new(0.0, 1.0) has valid, finite bounds");
    (0..k)
        .map(|_| {
            let u = uniform.sample(rng);
            // Inverse CDF for u in [0, 1): first index whose cumulative mass exceeds u.
            cdf.partition_point(|&c| c <= u).min(last_positive)
        })
        .collect()
}
// Unit tests for the sketching routines. All randomness comes from fixed seeds,
// so each test observes a reproducible sample sequence.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
use scirs2_core::random::prelude::*;
// Deterministic RNG shared by most tests (seed 12345).
fn make_rng() -> impl Rng {
scirs2_core::random::seeded_rng(12345)
}
// 4x3 input sketched to k = 2 rows must yield a 2x3 result.
#[test]
fn test_gaussian_sketch_dimensions() {
let a = array![
[1.0_f64, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0]
];
let mut rng = make_rng();
let sketch = gaussian_sketch(&a.view(), 2, &mut rng).expect("gaussian_sketch dims");
assert_eq!(sketch.nrows(), 2);
assert_eq!(sketch.ncols(), 3);
}
// Single-column input (m = 100, k = 50): the squared norm of the sketch should
// be within a loose factor of the original squared norm. The bounds (0.5, 2.5)
// are checked for the fixed seed 99.
#[test]
fn test_gaussian_sketch_preserves_norms_approx() {
let m = 100;
let n = 1;
let k = 50;
let mut a = Array2::<f64>::zeros((m, n));
for i in 0..m {
a[[i, 0]] = (i as f64 + 1.0).sqrt();
}
let mut rng = scirs2_core::random::seeded_rng(99);
let sketch = gaussian_sketch(&a.view(), k, &mut rng).expect("sketch norm test");
let orig_norm_sq: f64 = (0..m).map(|i| a[[i, 0]] * a[[i, 0]]).sum();
let sketch_norm_sq: f64 = (0..k).map(|i| sketch[[i, 0]] * sketch[[i, 0]]).sum();
let ratio = sketch_norm_sq / orig_norm_sq;
assert!(ratio > 0.5 && ratio < 2.5, "ratio={ratio:.3}");
}
// 4 rows is already a power of two, so no padding occurs; k = 2 rows are kept.
#[test]
fn test_srht_dimensions() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];
let mut rng = make_rng();
let sketch = subsampled_rht(&a.view(), 2, &mut rng).expect("srht dims");
assert_eq!(sketch.nrows(), 2);
assert_eq!(sketch.ncols(), 2);
}
// k = 2 sketch rows with s = 1 non-zero per column of S.
#[test]
fn test_sparse_sign_sketch_dimensions() {
let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
let mut rng = make_rng();
let sketch = sparse_sign_sketch(&a.view(), 2, 1, &mut rng).expect("sparse sketch dims");
assert_eq!(sketch.nrows(), 2);
assert_eq!(sketch.ncols(), 3);
}
// Shape-only check: a (2x2) and b (2x2) with k = 10 give a 2x2 result.
// NOTE(review): values are not checked here.
#[test]
fn test_sketch_multiply_dimensions() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b = array![[5.0_f64, 6.0], [7.0, 8.0]];
let mut rng = make_rng();
let ab = sketch_multiply(&a.view(), &b.view(), 10, &mut rng).expect("sketch_multiply");
assert_eq!(ab.nrows(), 2);
assert_eq!(ab.ncols(), 2);
}
// Shape-only check for the A·B variant with k = 10.
#[test]
fn test_sketch_multiply_direct_dimensions() {
let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
let b = array![[1.0_f64, 2.0], [3.0, 4.0]];
let mut rng = make_rng();
let ab = sketch_multiply_direct(&a.view(), &b.view(), 10, &mut rng).expect("direct");
assert_eq!(ab.nrows(), 2);
assert_eq!(ab.ncols(), 2);
}
// Every row of the identity has unit norm, so every score must be 1.
#[test]
fn test_leverage_scores_orthonormal() {
let q = Array2::<f64>::eye(3);
let scores = leverage_scores(&q.view());
for i in 0..3 {
assert!(
(scores[i] - 1.0).abs() < 1e-10,
"score {} = {}",
i,
scores[i]
);
}
}
// Draws 10 samples from 3 weighted indices; only length and range are checked.
#[test]
fn test_sample_by_leverage() {
let scores = Array1::from_vec(vec![0.5_f64, 0.3, 0.2]);
let mut rng = make_rng();
let samples = sample_by_leverage(&scores, 10, &mut rng);
assert_eq!(samples.len(), 10);
for &s in &samples {
assert!(s < 3, "sample {} out of range", s);
}
}
// The unnormalised Walsh–Hadamard matrix satisfies H·H = m·I. Applying the
// transform twice to m = 4 rows should therefore return 4x the input.
#[test]
fn test_hadamard_transform_involution() {
let mut a = Array2::<f64>::zeros((4, 2));
a[[0, 0]] = 1.0;
a[[1, 0]] = 2.0;
a[[2, 0]] = 3.0;
a[[3, 0]] = 4.0;
a[[0, 1]] = -1.0;
a[[3, 1]] = 1.0;
let original = a.clone();
hadamard_transform_rows(&mut a, 4, 2);
hadamard_transform_rows(&mut a, 4, 2);
for i in 0..4 {
for j in 0..2 {
assert!((a[[i, j]] / 4.0 - original[[i, j]]).abs() < 1e-10);
}
}
}
}