use std::f64::consts::PI;
use scirs2_core::random::{seeded_rng, Distribution, Normal, Rng, RngExt, SeedableRng, Uniform};
use crate::error::{Result, TransformError};
fn omp(x: &[f64], dictionary: &[Vec<f64>], sparsity: usize) -> Vec<f64> {
let d = x.len();
let n_atoms = dictionary.len();
let k = sparsity.min(n_atoms);
let mut residual = x.to_vec();
let mut support: Vec<usize> = Vec::with_capacity(k);
for _ in 0..k {
let best_idx = (0..n_atoms)
.filter(|idx| !support.contains(idx))
.max_by(|&a, &b| {
let ca: f64 = residual.iter().zip(dictionary[a].iter()).map(|(r, di)| r * di).sum::<f64>().abs();
let cb: f64 = residual.iter().zip(dictionary[b].iter()).map(|(r, di)| r * di).sum::<f64>().abs();
ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
});
let best_idx = match best_idx {
Some(i) => i,
None => break,
};
support.push(best_idx);
let s = support.len();
let mut dsd = vec![vec![0.0f64; s]; s]; let mut dsx = vec![0.0f64; s];
for (si, &ai) in support.iter().enumerate() {
let da = &dictionary[ai];
for (sj, &aj) in support.iter().enumerate() {
let db = &dictionary[aj];
let dot: f64 = da.iter().zip(db.iter()).map(|(a, b)| a * b).sum();
dsd[si][sj] = dot;
}
dsx[si] = da.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
}
let alpha_s = solve_small_system(&dsd, &dsx);
residual = x.to_vec();
for (si, &ai) in support.iter().enumerate() {
for (fi, r) in residual.iter_mut().enumerate() {
if fi < d && fi < dictionary[ai].len() {
*r -= alpha_s[si] * dictionary[ai][fi];
}
}
}
let res_norm: f64 = residual.iter().map(|r| r * r).sum::<f64>().sqrt();
if res_norm < 1e-10 {
break;
}
}
let mut alpha = vec![0.0f64; n_atoms];
let s = support.len();
if s > 0 {
let mut dsd = vec![vec![0.0f64; s]; s];
let mut dsx = vec![0.0f64; s];
for (si, &ai) in support.iter().enumerate() {
let da = &dictionary[ai];
for (sj, &aj) in support.iter().enumerate() {
let db = &dictionary[aj];
let dot: f64 = da.iter().zip(db.iter()).map(|(a, b)| a * b).sum();
dsd[si][sj] = dot;
}
dsx[si] = da.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
}
let alpha_s = solve_small_system(&dsd, &dsx);
for (si, &ai) in support.iter().enumerate() {
alpha[ai] = alpha_s[si];
}
}
alpha
}
/// Solves the dense linear system `a * x = b` (with `a` an n-by-n row-major
/// matrix and `b` of length n) by Gaussian elimination with partial pivoting.
///
/// Pivot rows are normalized so the diagonal becomes 1 before elimination.
/// A column whose best pivot magnitude is below 1e-12 is skipped, so
/// (near-)singular systems yield a best-effort answer rather than an error.
fn solve_small_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
    let n = b.len();
    if n == 0 {
        return Vec::new();
    }
    // Augmented matrix [a | b]: n rows of n+1 columns.
    let mut aug: Vec<Vec<f64>> = a
        .iter()
        .zip(b.iter())
        .map(|(row, &bi)| {
            let mut r = row.clone();
            r.push(bi);
            r
        })
        .collect();
    // Forward elimination with partial pivoting.
    for col in 0..n {
        let pivot_row = match (col..n).max_by(|&i, &j| {
            aug[i][col]
                .abs()
                .partial_cmp(&aug[j][col].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        }) {
            Some(r) => r,
            None => break, // unreachable: col..n is never empty here
        };
        if aug[pivot_row][col].abs() < 1e-12 {
            continue; // column numerically zero below the diagonal — skip it
        }
        aug.swap(col, pivot_row);
        // Normalize the pivot row so aug[col][col] == 1.
        let pivot = aug[col][col];
        for v in aug[col][col..].iter_mut() {
            *v /= pivot;
        }
        // Eliminate the column from all rows below the pivot.
        let (upper, lower) = aug.split_at_mut(col + 1);
        let pivot_vals = &upper[col][col..];
        for row in lower.iter_mut() {
            let factor = row[col];
            for (cell, &p) in row[col..].iter_mut().zip(pivot_vals.iter()) {
                *cell -= factor * p;
            }
        }
    }
    // Back substitution (relies on unit diagonal from normalization above).
    let mut x = vec![0.0f64; n];
    for i in (0..n).rev() {
        let mut xi = aug[i][n];
        for j in (i + 1)..n {
            xi -= aug[i][j] * x[j];
        }
        x[i] = xi;
    }
    x
}
/// Sparse coding transform: learns an overcomplete dictionary from data
/// (alternating OMP sparse coding with per-atom updates) and encodes
/// samples as sparse coefficient vectors.
#[derive(Debug, Clone)]
pub struct SparseDictTransform {
/// Learned atoms, one `Vec<f64>` (length = feature dimension) per atom;
/// atoms are produced by `normalize_atom`, i.e. unit L2 norm or all-zero.
pub dictionary: Vec<Vec<f64>>,
/// Maximum number of non-zero coefficients per encoded sample (OMP budget).
pub sparsity: usize,
}
impl SparseDictTransform {
    /// Learns a dictionary of `n_atoms` unit-norm atoms from the rows of `x`
    /// by alternating OMP sparse coding with per-atom (K-SVD-style) updates.
    ///
    /// Atoms are initialized from cyclically repeated data rows plus small
    /// Gaussian noise; atoms that no sample uses are re-seeded from a random
    /// data row chosen with the seeded RNG, keeping fits reproducible.
    ///
    /// # Errors
    /// * `InvalidInput` — empty dataset, zero feature dimension, or
    ///   `n_atoms == 0`.
    /// * `ComputationError` — the noise distribution could not be built.
    pub fn fit(
        x: &[Vec<f64>],
        n_atoms: usize,
        sparsity: usize,
        n_iter: usize,
        seed: u64,
    ) -> Result<Self> {
        let n = x.len();
        if n == 0 {
            return Err(TransformError::InvalidInput("Empty dataset".to_string()));
        }
        let d = x[0].len();
        if d == 0 {
            return Err(TransformError::InvalidInput("Feature dim must be > 0".to_string()));
        }
        if n_atoms == 0 {
            return Err(TransformError::InvalidInput("n_atoms must be > 0".to_string()));
        }
        let mut rng = seeded_rng(seed);
        // Initialization: normalized data rows, repeated cyclically when
        // n_atoms > n.
        let mut dictionary: Vec<Vec<f64>> = (0..n_atoms)
            .map(|k| normalize_atom(&x[k % n]))
            .collect();
        // Small Gaussian perturbation so repeated rows become distinct atoms.
        let noise_dist = Normal::new(0.0_f64, 0.01).map_err(|e| {
            TransformError::ComputationError(format!("Normal distribution: {e}"))
        })?;
        for atom in dictionary.iter_mut() {
            for v in atom.iter_mut() {
                *v += noise_dist.sample(&mut rng);
            }
            *atom = normalize_atom(atom);
        }
        for _iter in 0..n_iter {
            // Sparse-coding step: encode every sample against the current
            // dictionary.
            let codes: Vec<Vec<f64>> =
                x.iter().map(|xi| omp(xi, &dictionary, sparsity)).collect();
            // Dictionary-update step, one atom at a time.
            for k in 0..n_atoms {
                // Samples whose code actually uses atom k.
                let users: Vec<usize> = (0..n)
                    .filter(|&i| codes[i][k].abs() > 1e-10)
                    .collect();
                if users.is_empty() {
                    // Dead atom: re-seed it from a random data row.
                    let idx = (rng.next_u64() as usize) % n;
                    dictionary[k] = normalize_atom(&x[idx]);
                    continue;
                }
                // Residuals of the users with every atom's contribution
                // except atom k removed (E_k in K-SVD notation).
                let e_k: Vec<Vec<f64>> = users
                    .iter()
                    .map(|&i| {
                        let mut res = x[i].clone();
                        for (m, atom) in dictionary.iter().enumerate() {
                            if m == k {
                                continue;
                            }
                            let coef = codes[i][m];
                            if coef.abs() < 1e-12 {
                                continue;
                            }
                            for (fi, r) in res.iter_mut().enumerate() {
                                if fi < atom.len() {
                                    *r -= coef * atom[fi];
                                }
                            }
                        }
                        res
                    })
                    .collect();
                let coefs_k: Vec<f64> = users.iter().map(|&i| codes[i][k]).collect();
                let coef_sq: f64 = coefs_k.iter().map(|c| c * c).sum::<f64>();
                if coef_sq < 1e-12 {
                    // Numerically unusable coefficients: re-seed the atom.
                    let idx = (rng.next_u64() as usize) % n;
                    dictionary[k] = normalize_atom(&x[idx]);
                    continue;
                }
                // Least-squares atom update: new_atom = E_k^T c / ||c||^2,
                // then renormalized to unit length.
                let mut new_atom = vec![0.0f64; d];
                for (ui, &coef) in coefs_k.iter().enumerate() {
                    for (fi, &ev) in e_k[ui].iter().enumerate() {
                        if fi < d {
                            new_atom[fi] += coef * ev;
                        }
                    }
                }
                for v in new_atom.iter_mut() {
                    *v /= coef_sq;
                }
                dictionary[k] = normalize_atom(&new_atom);
                // Codes are not re-fit against the updated atom here; they
                // are recomputed at the start of the next outer iteration.
            }
        }
        Ok(SparseDictTransform { dictionary, sparsity })
    }

    /// Encodes each row of `x` as a sparse coefficient vector (length =
    /// number of atoms) via OMP against the learned dictionary.
    ///
    /// # Errors
    /// `NotFitted` if the dictionary is empty; `InvalidInput` if any row's
    /// length differs from the dictionary's feature dimension.
    pub fn transform(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if self.dictionary.is_empty() {
            return Err(TransformError::NotFitted(
                "Dictionary is empty".to_string(),
            ));
        }
        let d = self.dictionary[0].len();
        let mut out = Vec::with_capacity(x.len());
        for (i, row) in x.iter().enumerate() {
            if row.len() != d {
                return Err(TransformError::InvalidInput(format!(
                    "Row {i}: expected {d} features, got {}",
                    row.len()
                )));
            }
            out.push(omp(row, &self.dictionary, self.sparsity));
        }
        Ok(out)
    }

    /// Reconstructs approximate feature vectors from sparse codes:
    /// `x_hat = sum_k code[k] * dictionary[k]`.
    ///
    /// # Errors
    /// `NotFitted` if the dictionary is empty; `InvalidInput` if any code's
    /// length differs from the number of atoms.
    pub fn reconstruct(&self, codes: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        if self.dictionary.is_empty() {
            return Err(TransformError::NotFitted("Empty dictionary".to_string()));
        }
        let d = self.dictionary[0].len();
        let n_atoms = self.dictionary.len();
        let mut out = Vec::with_capacity(codes.len());
        for (i, code) in codes.iter().enumerate() {
            if code.len() != n_atoms {
                return Err(TransformError::InvalidInput(format!(
                    "Code {i}: expected {n_atoms} atoms, got {}",
                    code.len()
                )));
            }
            let mut rec = vec![0.0f64; d];
            for (k, &ck) in code.iter().enumerate() {
                if ck.abs() < 1e-12 {
                    continue; // skip (numerically) zero coefficients
                }
                for (fi, r) in rec.iter_mut().enumerate() {
                    if fi < self.dictionary[k].len() {
                        *r += ck * self.dictionary[k][fi];
                    }
                }
            }
            out.push(rec);
        }
        Ok(out)
    }
}
/// Returns `v` scaled to unit Euclidean (L2) norm. A vector whose norm is
/// below 1e-12 maps to the all-zeros vector of the same length.
fn normalize_atom(v: &[f64]) -> Vec<f64> {
    let norm_sq: f64 = v.iter().map(|x| x * x).sum();
    let norm = norm_sq.sqrt();
    if norm < 1e-12 {
        vec![0.0f64; v.len()]
    } else {
        v.iter().map(|x| x / norm).collect()
    }
}
/// Random Fourier feature map approximating the RBF kernel
/// `k(x, y) = exp(-gamma * ||x - y||^2)` via `phi(x) . phi(y)`,
/// using `cos(omega . x + b)` features with Gaussian `omega`.
#[derive(Debug, Clone)]
pub struct RandomFourierFeatures {
/// Output dimensionality of the feature map.
pub n_components: usize,
/// RBF bandwidth parameter (validated > 0 in `new`).
pub gamma: f64,
/// Projection matrix: `n_components` rows of length `n_features`, sampled
/// i.i.d. from a normal with standard deviation `sqrt(2 * gamma)`.
pub random_weights: Vec<Vec<f64>>,
/// Phase offsets, one per component, sampled uniformly from [0, 2*pi).
pub biases: Vec<f64>,
/// Expected input dimensionality.
pub n_features: usize,
}
impl RandomFourierFeatures {
    /// Builds the random projection (weights and phase offsets) used to
    /// approximate the RBF kernel `exp(-gamma * ||x - y||^2)`.
    ///
    /// # Errors
    /// `InvalidInput` if `n_components` or `n_features` is zero, or if
    /// `gamma <= 0`; `ComputationError` if a sampling distribution cannot
    /// be constructed.
    pub fn new(n_components: usize, gamma: f64, n_features: usize, seed: u64) -> Result<Self> {
        if n_components == 0 {
            return Err(TransformError::InvalidInput(
                "n_components must be > 0".to_string(),
            ));
        }
        if n_features == 0 {
            return Err(TransformError::InvalidInput(
                "n_features must be > 0".to_string(),
            ));
        }
        if gamma <= 0.0 {
            return Err(TransformError::InvalidInput(
                "gamma must be > 0".to_string(),
            ));
        }
        // omega ~ N(0, 2*gamma): std dev is sqrt(2*gamma).
        let omega_std = (2.0 * gamma).sqrt();
        let omega_dist = Normal::new(0.0_f64, omega_std)
            .map_err(|e| TransformError::ComputationError(format!("Normal dist: {e}")))?;
        let bias_dist = Uniform::new(0.0_f64, 2.0 * PI)
            .map_err(|e| TransformError::ComputationError(format!("Uniform dist: {e}")))?;
        let mut rng = seeded_rng(seed);
        // Sampling order matters for reproducibility: all weight rows first
        // (row-major), then all biases.
        let mut random_weights = Vec::with_capacity(n_components);
        for _ in 0..n_components {
            let mut row = Vec::with_capacity(n_features);
            for _ in 0..n_features {
                row.push(omega_dist.sample(&mut rng));
            }
            random_weights.push(row);
        }
        let mut biases = Vec::with_capacity(n_components);
        for _ in 0..n_components {
            biases.push(bias_dist.sample(&mut rng));
        }
        Ok(RandomFourierFeatures {
            n_components,
            gamma,
            random_weights,
            biases,
            n_features,
        })
    }

    /// Maps each input row to `sqrt(2/m) * cos(omega_j . x + b_j)` for
    /// j = 1..m, where m = `n_components`.
    ///
    /// # Errors
    /// `InvalidInput` if any row's length differs from `n_features`.
    pub fn transform(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let scale = (2.0 / self.n_components as f64).sqrt();
        let mut out = Vec::with_capacity(x.len());
        for (i, row) in x.iter().enumerate() {
            if row.len() != self.n_features {
                return Err(TransformError::InvalidInput(format!(
                    "Row {i}: expected {} features, got {}",
                    self.n_features,
                    row.len()
                )));
            }
            let mut features = Vec::with_capacity(self.n_components);
            for (omega, &bias) in self.random_weights.iter().zip(self.biases.iter()) {
                let mut dot = 0.0f64;
                for (o, xi) in omega.iter().zip(row.iter()) {
                    dot += o * xi;
                }
                features.push(scale * (dot + bias).cos());
            }
            out.push(features);
        }
        Ok(out)
    }

    /// Estimates the RBF kernel value `k(x, y)` as the dot product of the
    /// two feature vectors `phi(x) . phi(y)`.
    ///
    /// # Errors
    /// Propagates `transform` errors (feature-length mismatch).
    pub fn estimate_kernel(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let phi_x = self.transform(&[x.to_vec()])?;
        let phi_y = self.transform(&[y.to_vec()])?;
        let mut k = 0.0f64;
        for (a, b) in phi_x[0].iter().zip(phi_y[0].iter()) {
            k += a * b;
        }
        Ok(k)
    }
}
/// Count-sketch-based random feature map for a polynomial kernel
/// (TensorSketch-style): `degree` independent count-sketches combined by
/// circular convolution.
#[derive(Debug, Clone)]
pub struct PolynomialRandomFeatures {
/// Sketch length, i.e. output dimensionality of the feature map.
pub n_components: usize,
/// Polynomial degree (number of count-sketches convolved together).
pub degree: usize,
/// Input scaling factor applied element-wise before sketching.
pub gamma: f64,
/// Polynomial bias term. NOTE(review): stored but never read by
/// `transform`, so it currently has no effect — confirm intended.
pub coef0: f64,
/// Per-degree hash tables: `h_maps[j][i]` is the sketch bucket for
/// input feature `i` in the j-th count-sketch.
h_maps: Vec<Vec<usize>>,
/// Per-degree sign tables: `s_maps[j][i]` is -1 or +1.
s_maps: Vec<Vec<i8>>,
/// Expected input dimensionality.
pub n_features: usize,
}
impl PolynomialRandomFeatures {
/// Builds the random hash (`h_maps`) and sign (`s_maps`) tables for
/// `degree` independent count-sketches of `n_components` buckets each.
///
/// `gamma` and `coef0` are stored verbatim. NOTE(review): `coef0` is never
/// read by `transform`, so the additive kernel term is currently a no-op —
/// confirm whether that is intended.
///
/// # Errors
/// `InvalidInput` if `n_components`, `degree`, or `n_features` is zero;
/// `ComputationError` if a sampling distribution cannot be constructed.
pub fn new(
n_components: usize,
degree: usize,
gamma: f64,
coef0: f64,
n_features: usize,
seed: u64,
) -> Result<Self> {
if n_components == 0 || degree == 0 || n_features == 0 {
return Err(TransformError::InvalidInput(
"n_components, degree, and n_features must all be > 0".to_string(),
));
}
let mut rng = seeded_rng(seed);
// Bucket indices are uniform over [0, n_components).
let h_dist = Uniform::new(0_usize, n_components).map_err(|e| {
TransformError::ComputationError(format!("Uniform h dist: {e}"))
})?;
// Fair coin for the +/-1 signs.
let s_dist = Uniform::new(0_usize, 2).map_err(|e| {
TransformError::ComputationError(format!("Uniform s dist: {e}"))
})?;
let mut h_maps: Vec<Vec<usize>> = Vec::with_capacity(degree);
let mut s_maps: Vec<Vec<i8>> = Vec::with_capacity(degree);
// One independent (h, s) pair per degree; sampling order (h row, then s
// row, per degree) determines reproducibility for a given seed.
for _ in 0..degree {
let h: Vec<usize> = (0..n_features).map(|_| h_dist.sample(&mut rng)).collect();
let s: Vec<i8> = (0..n_features)
.map(|_| if s_dist.sample(&mut rng) == 0 { -1i8 } else { 1i8 })
.collect();
h_maps.push(h);
s_maps.push(s);
}
Ok(PolynomialRandomFeatures {
n_components,
degree,
gamma,
coef0,
h_maps,
s_maps,
n_features,
})
}
/// Count-sketch of `x` using the j-th hash/sign tables:
/// `sketch[h[i]] += s[i] * x[i]` for every input feature `i`.
fn count_sketch(&self, x: &[f64], j: usize) -> Vec<f64> {
let mut sketch = vec![0.0f64; self.n_components];
for (i, &xi) in x.iter().enumerate() {
let bucket = self.h_maps[j][i];
let sign = self.s_maps[j][i] as f64;
sketch[bucket] += sign * xi;
}
sketch
}
/// Circular (cyclic) convolution of two equal-length vectors, computed
/// directly in O(n^2); no FFT is used.
fn circular_convolve(a: &[f64], b: &[f64]) -> Vec<f64> {
let n = a.len();
let mut out = vec![0.0f64; n];
for i in 0..n {
for j in 0..n {
out[(i + j) % n] += a[i] * b[j];
}
}
out
}
/// Maps each input row to its sketch feature vector of length
/// `n_components`: scale the row by `gamma`, count-sketch it once per
/// degree, combine the sketches by circular convolution, then scale the
/// result by `1/sqrt(n_components)`.
///
/// NOTE(review): because both sides of a kernel estimate get the
/// element-wise `gamma` factor and the per-side `1/sqrt(n_components)`
/// scale, the implied kernel looks proportional to
/// `(gamma^2 * <x,y>)^degree / n_components` rather than
/// `(gamma * <x,y> + coef0)^degree` — confirm the intended normalization.
///
/// # Errors
/// `InvalidInput` if any row's length differs from `n_features`.
pub fn transform(&self, x: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
let mut out = Vec::with_capacity(x.len());
for (i, row) in x.iter().enumerate() {
if row.len() != self.n_features {
return Err(TransformError::InvalidInput(format!(
"Row {i}: expected {} features, got {}",
self.n_features,
row.len()
)));
}
let scaled: Vec<f64> = row.iter().map(|&v| self.gamma * v).collect();
let sketches: Vec<Vec<f64>> = (0..self.degree)
.map(|j| self.count_sketch(&scaled, j))
.collect();
// Fold the per-degree sketches together with circular convolution.
let mut feature = sketches[0].clone();
for j in 1..self.degree {
feature = Self::circular_convolve(&feature, &sketches[j]);
}
let scale = 1.0 / (self.n_components as f64).sqrt();
let feature: Vec<f64> = feature.iter().map(|&v| v * scale).collect();
out.push(feature);
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic (n, d) matrix of standard-normal samples.
    fn make_data(n: usize, d: usize, seed: u64) -> Vec<Vec<f64>> {
        let mut rng = seeded_rng(seed);
        let dist = Normal::new(0.0_f64, 1.0).expect("Normal");
        let mut rows = Vec::with_capacity(n);
        for _ in 0..n {
            let row: Vec<f64> = (0..d).map(|_| dist.sample(&mut rng)).collect();
            rows.push(row);
        }
        rows
    }

    #[test]
    fn test_rff_transform_shape() {
        let data = make_data(20, 5, 0);
        let rff = RandomFourierFeatures::new(50, 1.0, 5, 42).expect("RFF::new");
        let features = rff.transform(&data).expect("transform");
        assert_eq!(features.len(), 20);
        assert_eq!(features[0].len(), 50);
    }

    #[test]
    fn test_rff_kernel_approximation() {
        // The RBF kernel of a point with itself is exactly 1.
        let point = vec![1.0, 0.0, 0.0];
        let rff = RandomFourierFeatures::new(5000, 1.0, 3, 0).expect("RFF::new");
        let k = rff.estimate_kernel(&point, &point).expect("kernel");
        assert!((k - 1.0).abs() < 0.05, "RBF(x,x) ≈ 1, got {k:.4}");
    }

    #[test]
    fn test_rff_kernel_decreasing_with_distance() {
        let rff = RandomFourierFeatures::new(2000, 1.0, 3, 1).expect("RFF::new");
        let origin = vec![0.0, 0.0, 0.0];
        let near = vec![0.1, 0.0, 0.0];
        let far = vec![2.0, 0.0, 0.0];
        let k_near = rff.estimate_kernel(&origin, &near).expect("k_near");
        let k_far = rff.estimate_kernel(&origin, &far).expect("k_far");
        assert!(k_near > k_far, "Near kernel {k_near:.4} should exceed far {k_far:.4}");
    }

    #[test]
    fn test_rff_invalid() {
        // Zero components, zero gamma, and negative gamma are all rejected.
        assert!(RandomFourierFeatures::new(0, 1.0, 3, 0).is_err());
        assert!(RandomFourierFeatures::new(10, 0.0, 3, 0).is_err());
        assert!(RandomFourierFeatures::new(10, -1.0, 3, 0).is_err());
    }

    #[test]
    fn test_sparse_dict_basic() {
        let data = make_data(30, 8, 5);
        let sdt = SparseDictTransform::fit(&data, 16, 3, 5, 0).expect("fit");
        assert_eq!(sdt.dictionary.len(), 16);
        assert_eq!(sdt.dictionary[0].len(), 8);
        let codes = sdt.transform(&data).expect("transform");
        assert_eq!(codes.len(), 30);
        assert_eq!(codes[0].len(), 16);
        // Every code must respect the sparsity budget.
        for code in &codes {
            let nnz = code.iter().filter(|&&v| v.abs() > 1e-10).count();
            assert!(nnz <= sdt.sparsity, "NNZ {nnz} > sparsity {}", sdt.sparsity);
        }
    }

    #[test]
    fn test_sparse_dict_reconstruct() {
        let data = make_data(20, 6, 10);
        let sdt = SparseDictTransform::fit(&data, 12, 4, 10, 1).expect("fit");
        let codes = sdt.transform(&data).expect("transform");
        let recon = sdt.reconstruct(&codes).expect("reconstruct");
        assert_eq!(recon.len(), 20);
        assert_eq!(recon[0].len(), 6);
    }

    #[test]
    fn test_poly_rff_shape() {
        let data = make_data(15, 4, 0);
        let prff = PolynomialRandomFeatures::new(32, 3, 1.0, 0.0, 4, 0).expect("new");
        let out = prff.transform(&data).expect("transform");
        assert_eq!(out.len(), 15);
        assert_eq!(out[0].len(), 32);
    }

    #[test]
    fn test_omp_basic() {
        // 2-D signal against a 3-atom overcomplete dictionary: two atoms
        // suffice for exact reconstruction.
        let dictionary = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0 / 2.0_f64.sqrt(), 1.0 / 2.0_f64.sqrt()],
        ];
        let signal = vec![0.5, 0.3];
        let codes = omp(&signal, &dictionary, 2);
        assert_eq!(codes.len(), 3);
        let recon: Vec<f64> = (0..2)
            .map(|fi| {
                codes
                    .iter()
                    .zip(dictionary.iter())
                    .map(|(c, atom)| c * atom[fi])
                    .sum::<f64>()
            })
            .collect();
        let err: f64 = signal
            .iter()
            .zip(recon.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>();
        assert!(err < 1e-8, "OMP reconstruction error {err:.2e}");
    }
}