use crate::error::InterpolateError;
use crate::random_features::feature_map::{FourierFeatureMap, Lcg64, RffKernel};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
/// Orthonormalises the columns of `a` with modified Gram-Schmidt.
///
/// Returns an `m x n` matrix `q` (same shape as `a`). Column `j` of `q` is column `j` of `a`
/// with its components along the already-processed columns `q[:, 0..j]` removed, then
/// scaled to unit length.
///
/// If a column's residual norm is at most `1e-12` times that column's original norm, the
/// column is treated as numerically dependent on the earlier ones and left as all zeros in
/// `q`. Exactly-zero input columns are treated the same way. The tolerance is relative, so
/// inputs are treated the same regardless of their overall scale.
fn gram_schmidt(a: &Array2<f64>) -> Array2<f64> {
    let (m, n) = a.dim();
    let mut q = Array2::<f64>::zeros((m, n));
    for j in 0..n {
        let mut v: Vec<f64> = a.column(j).to_vec();
        let input_norm: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
        for k in 0..j {
            let qk = q.column(k);
            // Projecting the running residual `v` (not the original column) is what makes
            // this the numerically stabler "modified" Gram-Schmidt variant.
            let dot: f64 = v.iter().zip(qk.iter()).map(|(vi, qi)| vi * qi).sum();
            for (vi, qi) in v.iter_mut().zip(qk.iter()) {
                *vi -= dot * qi;
            }
        }
        let norm: f64 = v.iter().map(|vi| vi * vi).sum::<f64>().sqrt();
        // Relative rank test: an absolute cut-off would zero out every column of a
        // well-conditioned matrix whose entries happen to be tiny.
        if norm > 1e-12 * input_norm {
            for (i, vi) in v.iter().enumerate() {
                q[(i, j)] = vi / norm;
            }
        }
    }
    q
}
/// Random Fourier feature map whose frequency directions come from orthogonal blocks
/// (orthogonal random features).
///
/// Frequencies are generated in square `d_in x d_in` blocks. The directions within a block
/// are mutually orthonormal, and each direction is rescaled by a per-row radius chosen for
/// the kernel. The resulting frequencies, phases and output scale are stored in a plain
/// [`FourierFeatureMap`], to which all evaluation is delegated.
#[derive(Debug, Clone)]
pub struct OrthogonalFourierFeatureMap {
    /// Underlying feature map holding the sampled frequencies, phases and output scale.
    base: FourierFeatureMap,
}

impl OrthogonalFourierFeatureMap {
    /// Builds an orthogonal random Fourier feature map.
    ///
    /// * `kernel` – kernel to approximate; its length scale divides every frequency row.
    /// * `d_in` – input dimensionality.
    /// * `d_out` – number of output features.
    /// * `seed` – seed for the internal [`Lcg64`] generator; equal seeds give equal maps.
    ///
    /// Construction proceeds as follows:
    ///
    /// 1. Draw `ceil(d_out / d_in)` blocks. Each block is orthonormalised (by Gram-Schmidt)
    ///    from a `d_in x d_in` matrix of standard-normal draws.
    /// 2. Scale each basis vector by a radius drawn by `sample_row_scale`.
    /// 3. Discard rows past `d_out`.
    /// 4. Draw phase offsets as `next_f64() * 2π`, and set the output scale to
    ///    `sqrt(2 / d_out)`.
    ///
    /// A length scale below `1e-300` (including NaN) is clamped to `1e-300` rather than
    /// rejected.
    ///
    /// # Panics
    ///
    /// Panics if `d_in == 0` or `d_out == 0`.
    pub fn new(kernel: RffKernel, d_in: usize, d_out: usize, seed: u64) -> Self {
        assert!(d_in > 0, "d_in must be > 0");
        assert!(d_out > 0, "d_out must be > 0");
        let mut rng = Lcg64::new(seed);
        // Clamp so the per-row radius division can never divide by zero.
        let ls = kernel.length_scale().max(1e-300);
        // Ceiling division: enough square blocks to cover `d_out` rows.
        let n_blocks = (d_out + d_in - 1) / d_in;
        let d_out_padded = n_blocks * d_in;

        // Row-major `d_out_padded x d_in` frequency matrix, filled one row at a time
        // directly into a single buffer.
        let mut omega_flat: Vec<f64> = Vec::with_capacity(d_out_padded * d_in);
        for _ in 0..n_blocks {
            let g = Array2::from_shape_fn((d_in, d_in), |_| rng.next_normal());
            let q = gram_schmidt(&g);
            for col_idx in 0..d_in {
                let scale = sample_row_scale(&kernel, &mut rng, d_in, ls);
                omega_flat.extend(q.column(col_idx).iter().map(|v| v * scale));
            }
        }
        // Drop the padding rows of the last block. Their random draws were still consumed
        // above, so the phase offsets drawn next are the same as for an unpadded layout
        // with the same seed.
        omega_flat.truncate(d_out * d_in);

        let bias_data: Vec<f64> = (0..d_out)
            .map(|_| rng.next_f64() * 2.0 * std::f64::consts::PI)
            .collect();
        let omega = Array2::from_shape_vec((d_out, d_in), omega_flat).expect("shape consistent");
        let bias = Array1::from_vec(bias_data);
        let scale = (2.0 / d_out as f64).sqrt();
        let base = FourierFeatureMap::from_parts(kernel, d_in, d_out, omega, bias, scale);
        Self { base }
    }

    /// Maps every row of `x` to its `d_out`-dimensional feature vector.
    ///
    /// Delegates to [`FourierFeatureMap::transform`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying map. Presumably this covers an input
    /// whose column count differs from `d_in`; confirm in `FourierFeatureMap`.
    pub fn transform(&self, x: &ArrayView2<f64>) -> Result<Array2<f64>, InterpolateError> {
        self.base.transform(x)
    }

    /// Returns the feature-space approximation of the kernel value between `x1` and `x2`.
    ///
    /// Delegates to [`FourierFeatureMap::kernel_approx`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying map.
    pub fn kernel_approx(&self, x1: &[f64], x2: &[f64]) -> Result<f64, InterpolateError> {
        self.base.kernel_approx(x1, x2)
    }

    /// Input dimensionality this map was built for.
    pub fn d_in(&self) -> usize {
        self.base.d_in
    }

    /// Number of output features produced per input row.
    pub fn d_out(&self) -> usize {
        self.base.d_out
    }
}
/// Draws the radius (Euclidean length) of one orthogonal frequency row for `kernel`.
///
/// The row's direction is already fixed by an orthonormal basis, so only a length is
/// produced here. The length is the Euclidean norm of `d_in` draws from the kernel's
/// per-coordinate sampler, divided by the length scale `ls`. The sampler is called exactly
/// `d_in` times.
///
/// For the Gaussian kernel this is the norm of a standard-normal vector, i.e. a chi radius
/// with `d_in` degrees of freedom.
///
/// NOTE(review): for an isotropic Matérn kernel the spectral density is a multivariate
/// Student-t. Its radius is `chi_d * sqrt(nu / chi2_nu)`, with ONE shared chi-square draw.
/// The norm of `d_in` independent t draws used below has a different distribution when
/// `d_in > 1`. Confirm against how `FourierFeatureMap` defines and samples these kernels.
///
/// NOTE(review): i.i.d. Cauchy coordinates correspond to a product (L1) Laplacian spectrum.
/// That spectrum is not rotation-invariant, so pairing it with orthogonal directions is an
/// approximation. Confirm this is intended.
fn sample_row_scale(kernel: &RffKernel, rng: &mut Lcg64, d_in: usize, ls: f64) -> f64 {
    // Pick the per-coordinate sampler once; every kernel shares the same radius formula.
    let draw: fn(&mut Lcg64) -> f64 = match kernel {
        RffKernel::Gaussian { .. } => |r: &mut Lcg64| r.next_normal(),
        RffKernel::Laplacian { .. } => |r: &mut Lcg64| r.next_cauchy(),
        RffKernel::Matern32 { .. } => |r: &mut Lcg64| r.next_student_t(3),
        RffKernel::Matern52 { .. } => |r: &mut Lcg64| r.next_student_t(5),
    };
    let sum_of_squares: f64 = (0..d_in).map(|_| draw(rng).powi(2)).sum();
    sum_of_squares.sqrt() / ls
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_orf_output_shape() {
        let map =
            OrthogonalFourierFeatureMap::new(RffKernel::Gaussian { length_scale: 1.0 }, 3, 64, 99);
        let x = Array2::<f64>::zeros((5, 3));
        let z = map.transform(&x.view()).expect("transform");
        assert_eq!(z.shape(), &[5, 64]);
    }

    #[test]
    fn test_orf_kernel_approx_reasonable() {
        let map =
            OrthogonalFourierFeatureMap::new(RffKernel::Gaussian { length_scale: 1.0 }, 2, 512, 7);
        let x1 = [1.0f64, 0.0];
        let x2 = [0.0f64, 1.0];
        let true_k = (-1.0f64).exp();
        let approx_k = map.kernel_approx(&x1, &x2).expect("approx");
        let err = (approx_k - true_k).abs();
        assert!(err < 0.15, "ORF error={err:.4}, expected < 0.15 for D=512");
    }

    #[test]
    fn test_orf_all_features_finite() {
        for kernel in [
            RffKernel::Gaussian { length_scale: 1.0 },
            RffKernel::Laplacian { length_scale: 1.0 },
            RffKernel::Matern32 { length_scale: 1.0 },
            RffKernel::Matern52 { length_scale: 1.0 },
        ] {
            let x = Array2::from_shape_fn((4, 2), |(i, j)| (i + j) as f64 * 0.2);
            let map = OrthogonalFourierFeatureMap::new(kernel, 2, 32, 0);
            let z = map.transform(&x.view()).expect("transform");
            assert!(
                z.iter().all(|v| v.is_finite()),
                "ORF: all features must be finite"
            );
        }
    }

    #[test]
    fn test_orf_d_out_not_multiple_of_d_in() {
        let map =
            OrthogonalFourierFeatureMap::new(RffKernel::Gaussian { length_scale: 1.0 }, 3, 7, 5);
        let x = Array2::<f64>::zeros((2, 3));
        let z = map.transform(&x.view()).expect("transform");
        assert_eq!(z.shape(), &[2, 7]);
    }

    // Q^T Q must be the identity for a (generically full-rank) Gaussian input.
    #[test]
    fn test_gram_schmidt_orthonormal_columns() {
        let mut rng = Lcg64::new(42);
        let a = Array2::from_shape_fn((4, 4), |_| rng.next_normal());
        let q = gram_schmidt(&a);
        for i in 0..4 {
            for j in 0..4 {
                let dot: f64 = q
                    .column(i)
                    .iter()
                    .zip(q.column(j).iter())
                    .map(|(x, y)| x * y)
                    .sum();
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((dot - expected).abs() < 1e-10, "Q^T Q[{i},{j}] = {dot}");
            }
        }
    }

    // The second column is exactly twice the first, so it must come back as zeros.
    #[test]
    fn test_gram_schmidt_dependent_column_is_zeroed() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 0.0, 0.0]).expect("shape");
        let q = gram_schmidt(&a);
        assert_eq!(q.column(0).to_vec(), vec![1.0, 0.0]);
        assert_eq!(q.column(1).to_vec(), vec![0.0, 0.0]);
    }

    #[test]
    fn test_orf_same_seed_is_deterministic() {
        let make = || {
            OrthogonalFourierFeatureMap::new(RffKernel::Gaussian { length_scale: 0.7 }, 3, 10, 123)
        };
        let x = Array2::from_shape_fn((3, 3), |(i, j)| i as f64 - 0.5 * j as f64);
        let z1 = make().transform(&x.view()).expect("transform");
        let z2 = make().transform(&x.view()).expect("transform");
        assert_eq!(z1, z2);
    }

    #[test]
    fn test_orf_dimension_accessors() {
        let map =
            OrthogonalFourierFeatureMap::new(RffKernel::Matern52 { length_scale: 2.0 }, 4, 9, 1);
        assert_eq!(map.d_in(), 4);
        assert_eq!(map.d_out(), 9);
    }

    #[test]
    #[should_panic(expected = "d_in must be > 0")]
    fn test_orf_zero_d_in_panics() {
        let _ =
            OrthogonalFourierFeatureMap::new(RffKernel::Gaussian { length_scale: 1.0 }, 0, 4, 0);
    }
}