use scirs2_core::ndarray::{Array, Array3, ArrayBase, Data, Dimension, Ix2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Soft orthogonality regularizer for 2-D weight matrices.
///
/// For a weight matrix `W` with `ncols` columns and `n = min(nrows, ncols)` the
/// penalty is
///
/// ```text
/// lambda * || W^T W - E ||_F^2
/// ```
///
/// where `E` is the `ncols x ncols` matrix with ones on its first `n` diagonal
/// entries and zeros elsewhere. For square and tall matrices `E` is the identity,
/// so the penalty pulls the columns of `W` towards an orthonormal set. For wide
/// matrices the trailing `ncols - n` columns are additionally pulled towards zero.
#[derive(Debug, Clone)]
pub struct OrthogonalRegularization<A: Float> {
    /// Strength of the penalty; multiplies the squared Frobenius deviation.
    lambda: A,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> OrthogonalRegularization<A> {
    /// Creates a regularizer with penalty strength `lambda`.
    pub fn new(lambda: A) -> Self {
        Self { lambda }
    }

    /// Returns the deviation `W^T W - E` (an `ncols x ncols` matrix).
    ///
    /// Shared by the penalty and the gradient so both use the same target `E`.
    fn gram_deviation<S: Data<Elem = A>>(weights: &ArrayBase<S, Ix2>) -> Array<A, Ix2> {
        let n = weights.nrows().min(weights.ncols());
        let mut deviation = weights.t().dot(weights);
        // Only the first `n` diagonal entries target 1; for wide inputs the rest target 0.
        for d in deviation.diag_mut().iter_mut().take(n) {
            *d = *d - A::one();
        }
        deviation
    }

    /// Computes `lambda * ||W^T W - E||_F^2` for `weights = W`.
    ///
    /// Returns zero when the columns of `weights` are already orthonormal
    /// (square or tall input).
    pub fn compute_penalty_2d<S: Data<Elem = A>>(&self, weights: &ArrayBase<S, Ix2>) -> A {
        let deviation = Self::gram_deviation(weights);
        let sum_of_squares = deviation.iter().fold(A::zero(), |acc, &d| acc + d * d);
        self.lambda * sum_of_squares
    }

    /// Gradient of [`Self::compute_penalty_2d`] with respect to `weights`.
    ///
    /// Equals `4 * lambda * W (W^T W - E)`. One factor of 2 comes from the square
    /// in the Frobenius norm and the other from `W` appearing twice in `W^T W`.
    /// The result has the same shape as `weights`.
    fn compute_gradient_2d<S: Data<Elem = A>>(&self, weights: &ArrayBase<S, Ix2>) -> Array<A, Ix2> {
        let deviation = Self::gram_deviation(weights);
        let two = A::one() + A::one();
        let scale = two * two * self.lambda;
        weights.dot(&deviation) * scale
    }
}
impl<
        A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dimension + Send + Sync,
    > Regularizer<A, D> for OrthogonalRegularization<A>
{
    /// Adds the orthogonality gradient to `gradients` and returns the penalty.
    ///
    /// Parameters that are not 2-D are left untouched: `gradients` is not modified
    /// and the returned penalty is zero.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `params` is 2-D but `gradients`
    /// cannot be viewed as 2-D, or if the two arrays have different shapes.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        if params.ndim() != 2 {
            return Ok(A::zero());
        }
        let params_2d = params
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        let mut gradients_2d = gradients
            .view_mut()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        // `zip_mut_with` would otherwise broadcast the update (silently wrong) or
        // panic; a shape mismatch is a caller error, so report it.
        if gradients_2d.dim() != params_2d.dim() {
            return Err(OptimError::InvalidConfig(format!(
                "Gradient shape {:?} does not match parameter shape {:?}",
                gradients_2d.dim(),
                params_2d.dim()
            )));
        }
        let gradient_update = self.compute_gradient_2d(&params_2d);
        gradients_2d.zip_mut_with(&gradient_update, |g, &u| *g = *g + u);
        Ok(self.compute_penalty_2d(&params_2d))
    }

    /// Returns the orthogonality penalty for `params`, or zero if it is not 2-D.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if a 2-D view of `params` cannot be taken.
    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        if params.ndim() != 2 {
            return Ok(A::zero());
        }
        let params_2d = params
            .view()
            .into_dimensionality::<Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        Ok(self.compute_penalty_2d(&params_2d))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_orthogonal_creation() {
        let ortho = OrthogonalRegularization::<f64>::new(0.01);
        assert_eq!(ortho.lambda, 0.01);
    }

    #[test]
    fn test_identity_matrix_penalty() {
        let ortho = OrthogonalRegularization::new(0.01);
        let weights = array![[1.0, 0.0], [0.0, 1.0]];
        let penalty = ortho.compute_penalty_2d(&weights);
        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_non_orthogonal_penalty() {
        let ortho = OrthogonalRegularization::new(0.01);
        let weights = array![[1.0, 0.5], [0.5, 1.0]];
        let penalty = ortho.compute_penalty_2d(&weights);
        assert!(penalty > 0.0);
    }

    #[test]
    fn test_scalar_penalty_value() {
        // 1x1 case: lambda * (w^2 - 1)^2 = 0.5 * (4 - 1)^2 = 4.5
        let ortho = OrthogonalRegularization::new(0.5);
        let weights = array![[2.0]];
        let penalty = ortho.compute_penalty_2d(&weights);
        assert_relative_eq!(penalty, 4.5, epsilon = 1e-12);
    }

    #[test]
    fn test_rectangular_matrix() {
        // Wide matrix whose leading columns are orthonormal and trailing column is zero.
        let ortho = OrthogonalRegularization::new(0.01);
        let weights = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];
        let penalty = ortho.compute_penalty_2d(&weights);
        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_tall_orthonormal_matrix() {
        let ortho = OrthogonalRegularization::new(0.01);
        let weights = array![[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]];
        let penalty = ortho.compute_penalty_2d(&weights);
        assert_relative_eq!(penalty, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_gradient_computation() {
        let ortho = OrthogonalRegularization::new(0.1);
        let weights = array![[1.0, 0.5], [0.5, 1.0]];
        let gradient = ortho.compute_gradient_2d(&weights);
        assert!(gradient.abs().sum() > 0.0);
    }

    #[test]
    fn test_gradient_zero_at_identity() {
        let ortho = OrthogonalRegularization::new(0.1);
        let weights = array![[1.0, 0.0], [0.0, 1.0]];
        let gradient = ortho.compute_gradient_2d(&weights);
        assert!(gradient.iter().all(|&g| g == 0.0));
    }

    #[test]
    fn test_regularizer_trait() {
        let ortho = OrthogonalRegularization::new(0.01);
        let params = array![[1.0, 0.5], [0.5, 1.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradient = gradient.clone();
        let penalty = ortho.apply(&params, &mut gradient).expect("unwrap failed");
        assert!(penalty > 0.0);
        assert_ne!(gradient, original_gradient);
        let penalty2 = ortho.penalty(&params).expect("unwrap failed");
        assert_relative_eq!(penalty, penalty2, epsilon = 1e-10);
    }

    #[test]
    fn test_non_2d_array() {
        let ortho = OrthogonalRegularization::new(0.01);
        let params = Array3::<f64>::zeros((2, 2, 2));
        let mut gradient = Array3::<f64>::zeros((2, 2, 2));
        let penalty = ortho.apply(&params, &mut gradient).expect("unwrap failed");
        assert_eq!(penalty, 0.0);
        assert_eq!(gradient, Array3::<f64>::zeros((2, 2, 2)));
    }
}