use scirs2_core::ndarray::{Array, Array2, ArrayBase, Data, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Manifold (graph-Laplacian) regularizer.
///
/// Given a similarity matrix `W` (set via `set_similarity_matrix`), it builds the diagonal
/// degree matrix `D` (row sums of `W`) and the Laplacian `L = D - W`, and penalizes a 2-D
/// parameter matrix `P` by `lambda * sum(P ∘ (L · P))`, i.e. `lambda * trace(Pᵀ L P)`.
#[derive(Debug, Clone)]
pub struct ManifoldRegularization<A: Float> {
    /// Regularization strength; scales both the penalty and its gradient.
    lambda: A,
    /// Similarity matrix `W`; `None` until `set_similarity_matrix` succeeds.
    similarity_matrix: Option<Array2<A>>,
    /// Diagonal degree matrix `D` whose diagonal holds the row sums of `W`.
    degree_matrix: Option<Array2<A>>,
    /// Graph Laplacian `L = D - W`; penalty and gradient computations fail while this is `None`.
    laplacian: Option<Array2<A>>,
}
impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> ManifoldRegularization<A> {
    /// Creates a regularizer with strength `lambda` and no similarity matrix.
    ///
    /// [`set_similarity_matrix`](Self::set_similarity_matrix) must be called before the
    /// penalty or gradient can be computed.
    pub fn new(lambda: A) -> Self {
        Self {
            lambda,
            similarity_matrix: None,
            degree_matrix: None,
            laplacian: None,
        }
    }

    /// Stores the similarity matrix `W` and derives the degree matrix `D` (diagonal of row
    /// sums of `W`) and the graph Laplacian `L = D - W`.
    ///
    /// `W` is not required to be symmetric; the gradient accounts for that.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `similarity` is not square. In that case the
    /// previously stored matrices are left unchanged.
    pub fn set_similarity_matrix(&mut self, similarity: Array2<A>) -> Result<()> {
        let (rows, cols) = similarity.dim();
        if rows != cols {
            return Err(OptimError::InvalidConfig(
                "Similarity matrix must be square".to_string(),
            ));
        }
        let mut degree: Array2<A> = Array2::zeros((rows, rows));
        for (i, row) in similarity.outer_iter().enumerate() {
            degree[[i, i]] = row.sum();
        }
        let laplacian = &degree - &similarity;
        self.similarity_matrix = Some(similarity);
        self.degree_matrix = Some(degree);
        self.laplacian = Some(laplacian);
        Ok(())
    }

    /// Returns the stored Laplacian after checking that it can left-multiply a parameter
    /// matrix with `param_rows` rows.
    fn laplacian_for(&self, param_rows: usize) -> Result<&Array2<A>> {
        let laplacian = self
            .laplacian
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("Similarity matrix not set".to_string()))?;
        // `dot` panics on incompatible shapes; report a recoverable error instead.
        if laplacian.ncols() != param_rows {
            return Err(OptimError::InvalidConfig(format!(
                "Parameter matrix has {} rows but the similarity matrix is {}x{}",
                param_rows,
                laplacian.nrows(),
                laplacian.ncols()
            )));
        }
        Ok(laplacian)
    }

    /// Computes the penalty `lambda * sum(P ∘ (L · P))` (= `lambda * trace(Pᵀ L P)`).
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if no similarity matrix has been set, or if the
    /// number of rows of `params` differs from the size of the similarity matrix.
    pub fn compute_penalty<S>(&self, params: &ArrayBase<S, scirs2_core::ndarray::Ix2>) -> Result<A>
    where
        S: Data<Elem = A>,
    {
        let laplacian = self.laplacian_for(params.nrows())?;
        let lf = laplacian.dot(params);
        // Both iterators walk in logical row-major order, so elements pair up correctly even
        // when `params` is a non-contiguous view.
        let penalty = params
            .iter()
            .zip(lf.iter())
            .fold(A::zero(), |acc, (&p, &l)| acc + p * l);
        Ok(self.lambda * penalty)
    }

    /// Computes the gradient of [`compute_penalty`](Self::compute_penalty) with respect to
    /// `params`: `lambda * (L + Lᵀ) · P`.
    ///
    /// For a symmetric similarity matrix this equals `2 * lambda * L · P`. Using the
    /// symmetric part keeps it the exact gradient when `W` is not symmetric, which
    /// `set_similarity_matrix` does not forbid.
    ///
    /// # Errors
    ///
    /// Same conditions as [`compute_penalty`](Self::compute_penalty).
    fn compute_gradient<S>(
        &self,
        params: &ArrayBase<S, scirs2_core::ndarray::Ix2>,
    ) -> Result<Array2<A>>
    where
        S: Data<Elem = A>,
    {
        let laplacian = self.laplacian_for(params.nrows())?;
        let symmetric: Array2<A> = laplacian + &laplacian.t();
        Ok(symmetric.dot(params) * self.lambda)
    }
}
impl<
        A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dimension + Send + Sync,
    > Regularizer<A, D> for ManifoldRegularization<A>
{
    /// Adds the manifold-regularization gradient to `gradients` in place and returns the
    /// penalty for `params`.
    ///
    /// Parameters that are not 2-D are not regularized: `gradients` is left untouched and
    /// zero is returned.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if `gradients` does not have the same shape as
    /// `params`. Errors from the penalty/gradient computation (for example, no similarity
    /// matrix set) are propagated. `gradients` is only modified after every fallible step
    /// has succeeded.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        if params.ndim() != 2 {
            return Ok(A::zero());
        }
        if gradients.shape() != params.shape() {
            return Err(OptimError::InvalidConfig(format!(
                "Gradient shape {:?} does not match parameter shape {:?}",
                gradients.shape(),
                params.shape()
            )));
        }
        let params_2d = params
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        // Finish all fallible work before touching `gradients` so an error leaves them intact.
        let gradient_update = self.compute_gradient(&params_2d)?;
        let penalty = self.compute_penalty(&params_2d)?;
        let mut gradients_2d = gradients
            .view_mut()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        gradients_2d.zip_mut_with(&gradient_update, |g, &u| *g = *g + u);
        Ok(penalty)
    }

    /// Returns the manifold penalty for `params`, or zero when `params` is not 2-D.
    ///
    /// # Errors
    ///
    /// Propagates errors from the penalty computation, e.g. when no similarity matrix is set.
    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        if params.ndim() != 2 {
            return Ok(A::zero());
        }
        let params_2d = params
            .view()
            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
            .map_err(|_| OptimError::InvalidConfig("Expected 2D array".to_string()))?;
        self.compute_penalty(&params_2d)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_manifold_creation() {
        let manifold = ManifoldRegularization::<f64>::new(0.01);
        assert_relative_eq!(manifold.lambda, 0.01);
        assert!(manifold.similarity_matrix.is_none());
        assert!(manifold.degree_matrix.is_none());
        assert!(manifold.laplacian.is_none());
    }

    #[test]
    fn test_set_similarity_matrix() {
        let mut manifold = ManifoldRegularization::new(0.01);
        let similarity = array![[1.0, 0.5], [0.5, 1.0]];
        assert!(manifold.set_similarity_matrix(similarity).is_ok());

        // D = diag(1.5, 1.5), L = D - W = [[0.5, -0.5], [-0.5, 0.5]].
        let degree = manifold.degree_matrix.as_ref().expect("degree matrix should be set");
        assert_relative_eq!(degree[[0, 0]], 1.5, epsilon = 1e-10);
        assert_relative_eq!(degree[[1, 1]], 1.5, epsilon = 1e-10);
        assert_relative_eq!(degree[[0, 1]], 0.0, epsilon = 1e-10);

        let laplacian = manifold.laplacian.as_ref().expect("laplacian should be set");
        assert_relative_eq!(laplacian[[0, 0]], 0.5, epsilon = 1e-10);
        assert_relative_eq!(laplacian[[0, 1]], -0.5, epsilon = 1e-10);
        assert_relative_eq!(laplacian[[1, 0]], -0.5, epsilon = 1e-10);
        assert_relative_eq!(laplacian[[1, 1]], 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_invalid_similarity_matrix() {
        let mut manifold = ManifoldRegularization::<f64>::new(0.01);
        let similarity = array![[1.0, 0.5, 0.3], [0.5, 1.0, 0.4]];
        assert!(manifold.set_similarity_matrix(similarity).is_err());
        assert!(manifold.laplacian.is_none());
    }

    #[test]
    fn test_penalty_without_similarity() {
        let manifold = ManifoldRegularization::<f64>::new(0.01);
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        assert!(manifold.compute_penalty(&params).is_err());
    }

    #[test]
    fn test_penalty_computation() {
        let mut manifold = ManifoldRegularization::new(0.1);
        let similarity = array![[1.0, 0.8], [0.8, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");
        // L = [[0.8, -0.8], [-0.8, 0.8]]; with P = I the penalty is 0.1 * trace(L) = 0.16.
        let params = array![[1.0, 0.0], [0.0, 1.0]];
        let penalty = manifold
            .compute_penalty(&params)
            .expect("penalty should be computable");
        assert_relative_eq!(penalty, 0.16, epsilon = 1e-10);
    }

    #[test]
    fn test_gradient_computation() {
        let mut manifold = ManifoldRegularization::new(0.1);
        let similarity = array![[1.0, 0.8], [0.8, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");
        // L · P = [[-1.6, -1.6], [1.6, 1.6]]; gradient = 2 * 0.1 * L · P.
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let gradient = manifold
            .compute_gradient(&params)
            .expect("gradient should be computable");
        let expected = array![[-0.32, -0.32], [0.32, 0.32]];
        for (g, e) in gradient.iter().zip(expected.iter()) {
            assert_relative_eq!(*g, *e, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_regularizer_trait() {
        let mut manifold = ManifoldRegularization::new(0.01);
        let similarity = array![[1.0, 0.6], [0.6, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradient = array![[0.1, 0.2], [0.3, 0.4]];
        let penalty = manifold
            .apply(&params, &mut gradient)
            .expect("apply should succeed");
        // L · P = [[-1.2, -1.2], [1.2, 1.2]]; penalty = 0.01 * 4.8; update = 0.02 * L · P.
        assert_relative_eq!(penalty, 0.048, epsilon = 1e-10);
        let expected = array![[0.076, 0.176], [0.324, 0.424]];
        for (g, e) in gradient.iter().zip(expected.iter()) {
            assert_relative_eq!(*g, *e, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_non_2d_params_are_ignored() {
        let mut manifold = ManifoldRegularization::new(0.1);
        manifold
            .set_similarity_matrix(array![[1.0, 0.5], [0.5, 1.0]])
            .expect("similarity matrix should be accepted");
        let params = array![1.0, 2.0];
        let mut gradient = array![0.5, 0.5];
        let penalty = manifold
            .apply(&params, &mut gradient)
            .expect("apply should succeed");
        assert_relative_eq!(penalty, 0.0);
        assert_eq!(gradient, array![0.5, 0.5]);
    }

    #[test]
    fn test_identity_similarity() {
        let mut manifold = ManifoldRegularization::new(0.1);
        let similarity = array![[1.0, 0.0], [0.0, 1.0]];
        manifold
            .set_similarity_matrix(similarity)
            .expect("similarity matrix should be accepted");
        // D = W = I, so L = 0 and every parameter matrix has zero penalty.
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let penalty = manifold
            .compute_penalty(&params)
            .expect("penalty should be computable");
        assert_relative_eq!(penalty, 0.0, epsilon = 1e-12);
    }
}