use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use scirs2_core::random::{ChaCha8Rng, SeedableRng};
use std::fmt::Debug;
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
/// Dense random sketching operator stored as an explicit matrix.
///
/// `GaussianSketch::new` fills `matrix` by calling `gaussian_random_matrix`
/// with `sketch_dim` rows and `input_dim` columns; `GaussianSketch::apply`
/// left-multiplies an input matrix by it.
#[derive(Debug, Clone)]
pub struct GaussianSketch<F> {
/// The sketching matrix; `new` builds it with shape `(sketch_dim, input_dim)`.
pub matrix: Array2<F>,
/// Number of rows in the sketched output.
pub sketch_dim: usize,
/// Number of rows an input passed to `apply` must have.
pub input_dim: usize,
}
impl<F> GaussianSketch<F>
where
    F: Float + NumAssign + FromPrimitive + Debug + Sum + 'static,
{
    /// Builds a sketch with `sketch_dim` rows and `input_dim` columns.
    ///
    /// The matrix entries come from `gaussian_random_matrix`, which is seeded
    /// with `seed` when one is given.
    ///
    /// # Errors
    /// Returns `LinalgError::ValueError` when either dimension is zero
    /// (`sketch_dim` is checked first) and forwards any error raised while
    /// generating the matrix.
    pub fn new(sketch_dim: usize, input_dim: usize, seed: Option<u64>) -> LinalgResult<Self> {
        // The first arm wins when both dimensions are zero.
        let invalid = match (sketch_dim, input_dim) {
            (0, _) => Some("sketch_dim must be positive"),
            (_, 0) => Some("input_dim must be positive"),
            _ => None,
        };
        if let Some(message) = invalid {
            return Err(LinalgError::ValueError(message.to_string()));
        }

        gaussian_random_matrix(sketch_dim, input_dim, seed).map(|matrix| Self {
            matrix,
            sketch_dim,
            input_dim,
        })
    }

    /// Computes `matrix * a`, producing a `(sketch_dim, a.ncols())` result.
    ///
    /// # Errors
    /// Returns `LinalgError::DimensionError` when `a` does not have exactly
    /// `input_dim` rows.
    pub fn apply(&self, a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
        let rows = a.nrows();
        if rows == self.input_dim {
            matmul_2d(&self.matrix.view(), a)
        } else {
            Err(LinalgError::DimensionError(format!(
                "GaussianSketch: expected {} rows, got {}",
                self.input_dim, rows
            )))
        }
    }
}
/// Subsampled randomized Hadamard transform (SRHT) sketch.
///
/// `SRHTTransform::apply` flips the sign of each input row, runs an
/// unnormalized fast Walsh-Hadamard transform down every column, keeps the
/// rows listed in `row_indices`, and scales them by `1 / sqrt(sketch_dim)`.
#[derive(Debug, Clone)]
pub struct SRHTTransform<F> {
/// One `+1` / `-1` entry per input row; `new` creates `input_dim` of them.
pub signs: Array1<F>,
/// Rows kept after the Hadamard step; `new` draws `sketch_dim` distinct
/// indices from `0..input_dim`.
pub row_indices: Vec<usize>,
/// Number of rows in the sketched output.
pub sketch_dim: usize,
/// Number of rows an input must have; `new` requires a power of two.
pub input_dim: usize,
}
impl<F> SRHTTransform<F>
where
    F: Float + NumAssign + FromPrimitive + Debug + Sum + 'static,
{
    /// Draws the random sign vector and the retained row subset.
    ///
    /// When `seed` is `None` the generator is seeded from the current time in
    /// nanoseconds since the Unix epoch, falling back to `12345` if the clock
    /// reads earlier than the epoch.
    ///
    /// # Errors
    /// Returns `LinalgError::ValueError` when, checked in this order,
    /// `sketch_dim` is zero, `input_dim` is zero, `input_dim` is not a power
    /// of two, or `sketch_dim` exceeds `input_dim`.
    pub fn new(sketch_dim: usize, input_dim: usize, seed: Option<u64>) -> LinalgResult<Self> {
        let problem: Option<String> = if sketch_dim == 0 {
            Some("sketch_dim must be positive".to_string())
        } else if input_dim == 0 {
            Some("input_dim must be positive".to_string())
        } else if !input_dim.is_power_of_two() {
            Some(format!(
                "SRHTTransform: input_dim must be a power of 2, got {}",
                input_dim
            ))
        } else if sketch_dim > input_dim {
            Some(format!(
                "sketch_dim ({sketch_dim}) must be ≤ input_dim ({input_dim})"
            ))
        } else {
            None
        };
        if let Some(message) = problem {
            return Err(LinalgError::ValueError(message));
        }

        let seed_value = seed.unwrap_or_else(|| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(12345, |elapsed| elapsed.as_nanos() as u64)
        });
        let mut rng: ChaCha8Rng = ChaCha8Rng::seed_from_u64(seed_value);

        // One draw per input row; the low bit selects the sign.
        let signs = Array1::from_iter((0..input_dim).map(|_| {
            if (rng.next_u32() & 1) == 0 {
                F::one()
            } else {
                -F::one()
            }
        }));

        // The row subset is drawn after all signs, from the same generator.
        let row_indices = sample_without_replacement(input_dim, sketch_dim, &mut rng);

        Ok(Self {
            signs,
            row_indices,
            sketch_dim,
            input_dim,
        })
    }

    /// Applies the transform to `a`, returning a `(sketch_dim, a.ncols())` matrix.
    ///
    /// Each column is sign-flipped row by row, passed through `fwht_inplace`,
    /// then sampled at `row_indices` and multiplied by `1 / sqrt(sketch_dim)`.
    ///
    /// # Errors
    /// Returns `LinalgError::DimensionError` when `a` does not have
    /// `input_dim` rows, and `LinalgError::ComputationError` when the scale
    /// factor cannot be converted to `F`.
    pub fn apply(&self, a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
        let (n, d) = a.dim();
        if n != self.input_dim {
            return Err(LinalgError::DimensionError(format!(
                "SRHTTransform: expected {} rows, got {}",
                self.input_dim, n
            )));
        }

        let scale = match F::from(1.0 / (self.sketch_dim as f64).sqrt()) {
            Some(value) => value,
            None => {
                return Err(LinalgError::ComputationError(
                    "SRHT scale conversion failed".to_string(),
                ))
            }
        };

        let mut out = Array2::<F>::zeros((self.sketch_dim, d));
        for j in 0..d {
            // Sign-flip the column, transform it, then keep the sampled rows.
            let mut column: Vec<F> = (0..n).map(|i| self.signs[i] * a[[i, j]]).collect();
            fwht_inplace(&mut column);
            for (out_row, &src_row) in self.row_indices.iter().enumerate() {
                out[[out_row, j]] = column[src_row] * scale;
            }
        }
        Ok(out)
    }
}
/// Count-sketch operator stored implicitly as a bucket index and a sign per input row.
///
/// `CountSketchMatrix::apply` adds each signed input row into the output row
/// named by its bucket.
#[derive(Debug, Clone)]
pub struct CountSketchMatrix<F> {
/// `hash[i]` is the output row that input row `i` is added into; `new`
/// draws each value in `0..sketch_dim`.
pub hash: Vec<usize>,
/// One `+1` / `-1` entry per input row; `new` creates `input_dim` of them.
pub signs: Array1<F>,
/// Number of rows in the sketched output.
pub sketch_dim: usize,
/// Number of rows an input passed to `apply` must have.
pub input_dim: usize,
}
impl<F> CountSketchMatrix<F>
where
    F: Float + NumAssign + FromPrimitive + Debug + Sum + 'static,
{
    /// Draws a bucket index and a sign for every input row.
    ///
    /// When `seed` is `None` the generator is seeded from the current time in
    /// nanoseconds since the Unix epoch, falling back to `99999` if the clock
    /// reads earlier than the epoch.
    ///
    /// # Errors
    /// Returns `LinalgError::ValueError` when `sketch_dim` or `input_dim` is
    /// zero (`sketch_dim` is checked first).
    pub fn new(sketch_dim: usize, input_dim: usize, seed: Option<u64>) -> LinalgResult<Self> {
        let invalid = match (sketch_dim, input_dim) {
            (0, _) => Some("sketch_dim must be positive"),
            (_, 0) => Some("input_dim must be positive"),
            _ => None,
        };
        if let Some(message) = invalid {
            return Err(LinalgError::ValueError(message.to_string()));
        }

        let seed_value = seed.unwrap_or_else(|| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(99999, |elapsed| elapsed.as_nanos() as u64)
        });
        let mut rng: ChaCha8Rng = ChaCha8Rng::seed_from_u64(seed_value);

        // Per row: first a 64-bit draw for the bucket, then a 32-bit draw for the sign.
        let (hash, sign_values): (Vec<usize>, Vec<F>) = (0..input_dim)
            .map(|_| {
                let bucket = (rng.next_u64() as usize) % sketch_dim;
                let sign = if (rng.next_u32() & 1) == 0 {
                    F::one()
                } else {
                    -F::one()
                };
                (bucket, sign)
            })
            .unzip();

        Ok(Self {
            hash,
            signs: Array1::from_vec(sign_values),
            sketch_dim,
            input_dim,
        })
    }

    /// Sketches `a` into a `(sketch_dim, a.ncols())` matrix.
    ///
    /// Every input row is multiplied by its sign and accumulated into the
    /// output row given by `hash`.
    ///
    /// # Errors
    /// Returns `LinalgError::DimensionError` when `a` does not have
    /// `input_dim` rows.
    pub fn apply(&self, a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
        let (n, d) = a.dim();
        if n != self.input_dim {
            return Err(LinalgError::DimensionError(format!(
                "CountSketchMatrix: expected {} rows, got {}",
                self.input_dim, n
            )));
        }

        let mut out = Array2::<F>::zeros((self.sketch_dim, d));
        for i in 0..n {
            let (bucket, sign) = (self.hash[i], self.signs[i]);
            let source_row = a.row(i);
            for (j, &value) in source_row.iter().enumerate() {
                out[[bucket, j]] += sign * value;
            }
        }
        Ok(out)
    }
}
/// Johnson-Lindenstrauss style random projection from `source_dim` to `target_dim` coordinates.
///
/// The projection matrix is produced by `gaussian_random_matrix`; points are
/// embedded by multiplying with it (`embed_point`, `embed_rows`).
#[derive(Debug, Clone)]
pub struct JLTransform<F> {
/// Projection matrix; both constructors build it with shape `(target_dim, source_dim)`.
pub matrix: Array2<F>,
/// Dimension of embedded points.
pub target_dim: usize,
/// Dimension of input points.
pub source_dim: usize,
/// Distortion parameter supplied by the caller. `new` requires it to lie in
/// `(0, 1)` and derives `target_dim` from it; `with_target_dim` stores it unchecked.
pub epsilon: F,
}
impl<F> JLTransform<F>
where
    F: Float + NumAssign + FromPrimitive + Debug + Sum + 'static,
{
    /// Builds a transform whose target dimension is derived from `n_points` and `epsilon`.
    ///
    /// The target dimension is `ceil(4 ln(n) / (eps^2 / 2 - eps^3 / 3))` with
    /// `n = max(n_points, 2)`, clamped to the range `1..=source_dim`.
    ///
    /// # Errors
    /// Returns `LinalgError::ValueError` when `source_dim` is zero, when
    /// `epsilon` cannot be converted to `f64`, or when it is not strictly
    /// between 0 and 1 (NaN is rejected). Matrix generation errors are forwarded.
    pub fn new(
        source_dim: usize,
        n_points: usize,
        epsilon: F,
        seed: Option<u64>,
    ) -> LinalgResult<Self> {
        if source_dim == 0 {
            return Err(LinalgError::ValueError(
                "source_dim must be positive".to_string(),
            ));
        }

        let eps_f64 = match epsilon.to_f64() {
            Some(value) => value,
            None => {
                return Err(LinalgError::ValueError(
                    "epsilon conversion failed".to_string(),
                ))
            }
        };
        // Written as a positive test so that NaN fails it.
        let in_open_unit_interval = eps_f64 > 0.0 && eps_f64 < 1.0;
        if !in_open_unit_interval {
            return Err(LinalgError::ValueError(format!(
                "epsilon must be in (0,1), got {eps_f64}"
            )));
        }

        let log_points = (n_points.max(2) as f64).ln();
        let denom = eps_f64 * eps_f64 / 2.0 - eps_f64 * eps_f64 * eps_f64 / 3.0;
        let unclamped = (4.0 * log_points / denom).ceil() as usize;
        // source_dim >= 1 here, so the clamp bounds are well ordered.
        let target_dim = unclamped.clamp(1, source_dim);

        gaussian_random_matrix(target_dim, source_dim, seed).map(|matrix| Self {
            matrix,
            target_dim,
            source_dim,
            epsilon,
        })
    }

    /// Builds a transform with an explicitly chosen `target_dim`.
    ///
    /// `epsilon` is stored as given and is not validated here.
    ///
    /// # Errors
    /// Returns `LinalgError::ValueError` when `source_dim` or `target_dim` is
    /// zero; matrix generation errors are forwarded.
    pub fn with_target_dim(
        source_dim: usize,
        target_dim: usize,
        epsilon: F,
        seed: Option<u64>,
    ) -> LinalgResult<Self> {
        let both_positive = source_dim != 0 && target_dim != 0;
        if !both_positive {
            return Err(LinalgError::ValueError(
                "source_dim and target_dim must be positive".to_string(),
            ));
        }

        gaussian_random_matrix(target_dim, source_dim, seed).map(|matrix| Self {
            matrix,
            target_dim,
            source_dim,
            epsilon,
        })
    }

    /// Projects a single vector: returns `matrix * x` with `target_dim` entries.
    ///
    /// # Errors
    /// Returns `LinalgError::DimensionError` when `x.len() != source_dim`.
    pub fn embed_point(&self, x: &Array1<F>) -> LinalgResult<Array1<F>> {
        if x.len() != self.source_dim {
            return Err(LinalgError::DimensionError(format!(
                "JLTransform: expected vector length {}, got {}",
                self.source_dim,
                x.len()
            )));
        }

        let projected = (0..self.target_dim).map(|i| {
            (0..self.source_dim).fold(F::zero(), |acc, j| acc + self.matrix[[i, j]] * x[j])
        });
        Ok(Array1::from_iter(projected))
    }

    /// Projects every row of `x`: returns `x * matrix^T` with shape `(x.nrows(), target_dim)`.
    ///
    /// # Errors
    /// Returns `LinalgError::DimensionError` when `x` does not have
    /// `source_dim` columns.
    pub fn embed_rows(&self, x: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
        let (rows, cols) = x.dim();
        if cols != self.source_dim {
            return Err(LinalgError::DimensionError(format!(
                "JLTransform: expected {} columns, got {}",
                self.source_dim, cols
            )));
        }

        Ok(Array2::from_shape_fn(
            (rows, self.target_dim),
            |(i, k)| {
                (0..self.source_dim)
                    .fold(F::zero(), |acc, j| acc + x[[i, j]] * self.matrix[[k, j]])
            },
        ))
    }
}
/// Multiplies an explicit sketching matrix by `a`, returning `sketch * a`.
///
/// # Errors
/// Returns `LinalgError::DimensionError` when the number of columns of
/// `sketch` differs from the number of rows of `a`.
pub fn apply_sketch<F>(sketch: &ArrayView2<F>, a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + FromPrimitive + Debug + Sum + 'static,
{
    // Only the inner dimensions matter for the compatibility check.
    let n = sketch.ncols();
    let a_rows = a.nrows();
    if n == a_rows {
        matmul_2d(sketch, a)
    } else {
        Err(LinalgError::DimensionError(format!(
            "apply_sketch: sketch has {n} columns but A has {a_rows} rows"
        )))
    }
}
/// Builds a `JLTransform` sized for the rows of `x` and embeds them in one call.
///
/// The rows of `x` are treated as points: the row count is passed as
/// `n_points` and the column count as `source_dim` to `JLTransform::new`.
/// Returns the embedded rows together with the transform that produced them.
///
/// # Errors
/// Returns `LinalgError::ValueError` when `x` has no rows, and forwards any
/// error from constructing or applying the transform.
pub fn jl_embed_points<F>(
    x: &ArrayView2<F>,
    epsilon: F,
    seed: Option<u64>,
) -> LinalgResult<(Array2<F>, JLTransform<F>)>
where
    F: Float + NumAssign + FromPrimitive + Debug + Sum + 'static,
{
    let n_points = x.nrows();
    let source_dim = x.ncols();
    if n_points == 0 {
        return Err(LinalgError::ValueError(
            "jl_embed_points: empty input matrix".to_string(),
        ));
    }

    let transform = JLTransform::new(source_dim, n_points, epsilon, seed)?;
    transform
        .embed_rows(x)
        .map(|embedded| (embedded, transform))
}
/// Draws a `rows x cols` matrix of independent zero-mean Gaussian entries
/// scaled so that it preserves squared norms in expectation.
///
/// For a matrix `S` with i.i.d. `N(0, s^2)` entries, `E[||S x||^2] = rows * s^2 * ||x||^2`.
/// Norm preservation therefore needs `s = 1 / sqrt(rows)`, where `rows` is the
/// sketch / target dimension. Scaling by the number of columns instead shrinks
/// every embedded vector by `sqrt(rows / cols)`.
///
/// When `seed` is `None` the generator is seeded from the current time in
/// nanoseconds since the Unix epoch, falling back to `777` if the clock reads
/// earlier than the epoch.
///
/// # Errors
/// Returns `LinalgError::ComputationError` when the normal distribution cannot
/// be constructed or a sample cannot be converted to `F`.
fn gaussian_random_matrix<F>(rows: usize, cols: usize, seed: Option<u64>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + FromPrimitive + Debug + 'static,
{
    use scirs2_core::random::{Distribution, Normal};

    let mut rng: ChaCha8Rng = match seed {
        Some(s) => ChaCha8Rng::seed_from_u64(s),
        None => {
            let t = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_nanos() as u64)
                .unwrap_or(777);
            ChaCha8Rng::seed_from_u64(t)
        }
    };

    // Variance 1/rows. `max(1)` keeps the deviation finite for an empty matrix.
    let std_dev = 1.0 / (rows.max(1) as f64).sqrt();
    let normal = Normal::new(0.0, std_dev).map_err(|e| {
        LinalgError::ComputationError(format!("Normal distribution creation failed: {e}"))
    })?;

    // Entries are filled in row-major order so a fixed seed gives a fixed matrix.
    let mut mat = Array2::<F>::zeros((rows, cols));
    for i in 0..rows {
        for j in 0..cols {
            let v: f64 = normal.sample(&mut rng);
            mat[[i, j]] = F::from(v).ok_or_else(|| {
                LinalgError::ComputationError(format!("Failed to convert f64 {v} to F"))
            })?;
        }
    }
    Ok(mat)
}
/// Runs an unnormalized fast Walsh-Hadamard transform on `x` in place.
///
/// Butterflies are applied with strides 1, 2, 4, ... while the stride is
/// below `x.len()`. No scaling is applied, so the first output is the sum of
/// the inputs. Slices of length 0 or 1 are left untouched.
///
/// # Panics
/// Indexes out of bounds when `x.len()` is greater than 1 and not a power of two.
fn fwht_inplace<F>(x: &mut [F])
where
    F: Float + NumAssign,
{
    let len = x.len();
    let mut stride = 1usize;
    while stride < len {
        // Each block of `2 * stride` entries pairs its lower half with its upper half.
        for base in (0..len).step_by(2 * stride) {
            for lo in base..base + stride {
                let hi = lo + stride;
                let (a, b) = (x[lo], x[hi]);
                x[lo] = a + b;
                x[hi] = a - b;
            }
        }
        stride *= 2;
    }
}
/// Returns `k` distinct indices from `0..n` using a partial Fisher-Yates shuffle.
///
/// Exactly one `next_u64` draw is consumed per selected index; position `i`
/// is swapped with a position chosen from `i..n`.
///
/// # Panics
/// Panics (remainder by zero) when `k > n`.
fn sample_without_replacement<R>(n: usize, k: usize, rng: &mut R) -> Vec<usize>
where
    R: scirs2_core::random::Rng,
{
    let mut indices: Vec<usize> = Vec::with_capacity(n);
    indices.extend(0..n);

    for slot in 0..k {
        let offset = (rng.next_u64() as usize) % (n - slot);
        indices.swap(slot, slot + offset);
    }

    // The first `k` positions now hold the sample.
    indices.truncate(k);
    indices
}
/// Computes the dense matrix product `a * b` with a plain triple loop.
///
/// Each output entry is accumulated from zero over the shared dimension in
/// increasing index order.
///
/// # Errors
/// Returns `LinalgError::DimensionError` when the column count of `a` differs
/// from the row count of `b`.
fn matmul_2d<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + 'static,
{
    let (m, k) = a.dim();
    let (k2, n) = b.dim();
    if k != k2 {
        return Err(LinalgError::DimensionError(format!(
            "matmul_2d: inner dimensions {k} and {k2} do not match"
        )));
    }

    Ok(Array2::from_shape_fn((m, n), |(i, j)| {
        (0..k).fold(F::zero(), |acc, l| acc + a[[i, l]] * b[[l, j]])
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Euclidean norm of one row of `a`.
    fn row_norm(a: &Array2<f64>, row: usize) -> f64 {
        let mut s = 0.0f64;
        for j in 0..a.ncols() {
            s += a[[row, j]] * a[[row, j]];
        }
        s.sqrt()
    }

    #[test]
    fn test_gaussian_sketch_shape() {
        let a = Array2::<f64>::ones((10, 4));
        let sk = GaussianSketch::new(3, 10, Some(1)).expect("gaussian sketch");
        let sa = sk.apply(&a.view()).expect("apply failed");
        assert_eq!(sa.shape(), &[3, 4]);
    }

    #[test]
    fn test_gaussian_sketch_dimension_mismatch() {
        let a = Array2::<f64>::ones((5, 4));
        let sk = GaussianSketch::new(3, 10, Some(1)).expect("gaussian sketch");
        assert!(sk.apply(&a.view()).is_err());
    }

    #[test]
    fn test_apply_sketch_helper() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let sk = GaussianSketch::new(2, 3, Some(42)).expect("sketch");
        let sa = apply_sketch(&sk.matrix.view(), &a.view()).expect("apply_sketch");
        assert_eq!(sa.shape(), &[2, 2]);
    }

    #[test]
    fn test_srht_shape_power_of_two() {
        let n = 8usize;
        let m = 4usize;
        let a = Array2::<f64>::ones((n, 3));
        let srht = SRHTTransform::new(m, n, Some(7)).expect("srht");
        let sa = srht.apply(&a.view()).expect("srht apply");
        assert_eq!(sa.shape(), &[m, 3]);
    }

    #[test]
    fn test_srht_rejects_non_power_of_two() {
        assert!(SRHTTransform::<f64>::new(3, 7, None).is_err());
    }

    #[test]
    fn test_count_sketch_shape() {
        let a = Array2::<f64>::ones((20, 5));
        let cs = CountSketchMatrix::new(8, 20, Some(3)).expect("count sketch");
        let sa = cs.apply(&a.view()).expect("cs apply");
        assert_eq!(sa.shape(), &[8, 5]);
    }

    /// A count sketch is a linear map: S(A + B) must equal S(A) + S(B).
    #[test]
    fn test_count_sketch_linearity() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let b = array![[0.0_f64, 1.0], [1.0, 0.0], [1.0, -1.0]];
        let cs = CountSketchMatrix::new(2, 3, Some(99)).expect("cs");
        let sa = cs.apply(&a.view()).expect("sa");
        let sb = cs.apply(&b.view()).expect("sb");
        let mut ab = a.clone();
        for i in 0..3 {
            for j in 0..2 {
                ab[[i, j]] += b[[i, j]];
            }
        }
        let s_ab = cs.apply(&ab.view()).expect("s_ab");
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (s_ab[[i, j]] - sa[[i, j]] - sb[[i, j]]).abs() < 1e-12,
                    "linearity violated at [{i},{j}]"
                );
            }
        }
    }

    #[test]
    fn test_jl_embed_points_shape() {
        let x = Array2::<f64>::ones((50, 100));
        let (emb, t) = jl_embed_points(&x.view(), 0.3, Some(11)).expect("jl embed");
        assert_eq!(emb.nrows(), 50);
        assert_eq!(emb.ncols(), t.target_dim);
        assert!(t.target_dim <= 100);
    }

    #[test]
    fn test_jl_preserves_distances_roughly() {
        let x = array![
            [1.0_f64, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
        ];
        let transform =
            JLTransform::with_target_dim(8, 4, 0.3f64, Some(7)).expect("jl transform");
        let emb = transform.embed_rows(&x.view()).expect("embed");
        for i in 0..4 {
            let n = row_norm(&emb, i);
            assert!(n > 0.0, "row {i} has zero norm after JL");
        }
    }

    /// Pins the full unnormalized Walsh-Hadamard transform (natural ordering)
    /// of a fixed length-8 input, not just its first coefficient.
    #[test]
    fn test_fwht_known_values() {
        let mut x = vec![1.0f64, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0];
        fwht_inplace(&mut x);
        // Entry i is the sum over the non-zero inputs j of (-1)^popcount(i & j).
        let expected = [4.0f64, 2.0, 0.0, -2.0, 0.0, 2.0, 0.0, 2.0];
        assert_eq!(x.len(), expected.len());
        for (i, (&got, &want)) in x.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "FWHT x[{i}] = {got}, expected {want}"
            );
        }
    }

    #[test]
    fn test_jl_transform_with_target_dim() {
        let t = JLTransform::<f64>::with_target_dim(10, 3, 0.3, Some(1)).expect("jl");
        assert_eq!(t.target_dim, 3);
        assert_eq!(t.source_dim, 10);
    }
}