//! Activity regularization layers.
//!
//! These layers penalize a layer's *activations* (rather than its weights)
//! with L1 and/or L2 terms. They are identity maps in the forward pass and
//! expose the accumulated penalty via `get_activity_loss`.

use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
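/// Applies L1 and/or L2 regularization penalties to its input activations.
///
/// The layer is an identity map in the forward pass: it returns the input
/// unchanged, caches it, and records the penalty
/// `l1 * sum(|x|) + l2 * sum(x^2)` over all activations. In the backward pass
/// the penalty gradient `l1 * sign(x) + 2 * l2 * x` is added to the incoming
/// gradient. The recorded penalty should be added to the training loss by the
/// caller.
///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` since the import path depends on
/// the consuming crate); `input` is assumed to be any `Array<f64, IxDyn>` of
/// activations:
///
/// ```ignore
/// let reg = ActivityRegularization::<f64>::new(Some(0.01), Some(0.02), Some("act_reg"))?;
/// let output = reg.forward(&input)?; // identity; caches input and penalty
/// let penalty = reg.get_activity_loss()?; // add this to the training loss
/// ```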
#[derive(Debug, Clone)]
pub struct ActivityRegularization<F: Float + Debug + Send + Sync + NumAssign> {
l1_factor: Option<F>,
l2_factor: Option<F>,
name: Option<String>,
input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
activity_loss: Arc<RwLock<F>>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign>
ActivityRegularization<F>
{
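    /// Creates a new activity regularization layer.
    ///
    /// # Arguments
    ///
    /// * `l1_factor` - Optional non-negative coefficient for the L1 penalty
    /// * `l2_factor` - Optional non-negative coefficient for the L2 penalty
    /// * `name` - Optional name for the layer
    ///
    /// # Errors
    ///
    /// Returns an error if both factors are `None` or if either factor is
    /// negative.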
pub fn new(l1_factor: Option<f64>, l2_factor: Option<f64>, name: Option<&str>) -> Result<Self> {
if l1_factor.is_none() && l2_factor.is_none() {
            return Err(NeuralError::InvalidArchitecture(
                "At least one of the L1 and L2 regularization factors must be provided"
                    .to_string(),
            ));
}
if let Some(l1) = l1_factor {
if l1 < 0.0 {
return Err(NeuralError::InvalidArchitecture(
"L1 regularization factor must be non-negative".to_string(),
));
}
}
if let Some(l2) = l2_factor {
if l2 < 0.0 {
return Err(NeuralError::InvalidArchitecture(
"L2 regularization factor must be non-negative".to_string(),
));
}
}
Ok(Self {
l1_factor: l1_factor.map(|x| F::from(x).expect("Failed to convert to float")),
l2_factor: l2_factor.map(|x| F::from(x).expect("Failed to convert to float")),
name: name.map(String::from),
input_cache: Arc::new(RwLock::new(None)),
activity_loss: Arc::new(RwLock::new(F::zero())),
})
}
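    /// Returns the layer's name, if one was set.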
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
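    /// Returns the activity loss recorded by the most recent forward pass,
    /// or zero if `forward` has not been called yet.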
pub fn get_activity_loss(&self) -> Result<F> {
match self.activity_loss.read() {
Ok(loss) => Ok(*loss),
Err(_) => Err(NeuralError::InferenceError(
"Failed to acquire read lock on activity loss".to_string(),
)),
}
}
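    /// Computes the penalty `l1 * sum(|x|) + l2 * sum(x^2)` over all
    /// activations, skipping whichever term's factor is `None`.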
fn calculate_activity_loss(&self, input: &Array<F, IxDyn>) -> F {
let mut total_loss = F::zero();
if let Some(l1_factor) = self.l1_factor {
let l1_loss = input.mapv(|x| x.abs()).sum();
total_loss += l1_factor * l1_loss;
}
if let Some(l2_factor) = self.l2_factor {
let l2_loss = input.mapv(|x| x * x).sum();
total_loss += l2_factor * l2_loss;
}
total_loss
}
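    /// Computes the elementwise penalty gradient `l1 * sign(x) + 2 * l2 * x`,
    /// using the subgradient `0` for the L1 term at `x == 0`.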
fn calculate_activity_gradients(&self, input: &Array<F, IxDyn>) -> Array<F, IxDyn> {
let mut grad = Array::<F, IxDyn>::zeros(input.raw_dim());
if let Some(l1_factor) = self.l1_factor {
let l1_grad = input.mapv(|x| {
if x > F::zero() {
l1_factor
} else if x < F::zero() {
-l1_factor
} else {
F::zero()
}
});
grad = grad + l1_grad;
}
if let Some(l2_factor) = self.l2_factor {
let two = F::from(2.0).expect("Failed to convert constant to float");
let l2_grad = input.mapv(|x| two * l2_factor * x);
grad = grad + l2_grad;
}
grad
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> Layer<F>
for ActivityRegularization<F>
{
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
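    /// Identity forward pass: caches the input and records the activity loss.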
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
if let Ok(mut cache) = self.input_cache.write() {
*cache = Some(input.clone());
} else {
return Err(NeuralError::InferenceError(
"Failed to acquire write lock on input cache".to_string(),
));
}
let loss = self.calculate_activity_loss(input);
if let Ok(mut loss_cache) = self.activity_loss.write() {
*loss_cache = loss;
} else {
return Err(NeuralError::InferenceError(
"Failed to acquire write lock on activity loss cache".to_string(),
));
}
Ok(input.clone())
}
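    /// Adds the activity penalty gradient (computed from the input cached by
    /// the most recent `forward` call) to `grad_output`.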
fn backward(
&self,
_input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
let input_ref = match self.input_cache.read() {
Ok(guard) => guard,
Err(_) => {
return Err(NeuralError::InferenceError(
"Failed to acquire read lock on input cache".to_string(),
));
}
};
        let cached_input = input_ref.as_ref().ok_or_else(|| {
            NeuralError::InferenceError(
                "No cached input for backward pass. Call forward() first.".to_string(),
            )
        })?;
if cached_input.shape() != grad_output.shape() {
return Err(NeuralError::InferenceError(
"Input and gradient output shapes must match".to_string(),
));
}
let activity_grad = self.calculate_activity_gradients(cached_input);
Ok(grad_output + &activity_grad)
}
fn update(&mut self, _learning_rate: F) -> Result<()> {
Ok(())
}
fn layer_type(&self) -> &str {
"ActivityRegularization"
}
fn parameter_count(&self) -> usize {
0
}
fn layer_description(&self) -> String {
let l1_str = match self.l1_factor {
Some(l1) => format!("{l1:?}"),
None => "None".to_string(),
};
let l2_str = match self.l2_factor {
Some(l2) => format!("{l2:?}"),
None => "None".to_string(),
};
format!(
"type:ActivityRegularization, l1:{l1_str}, l2:{l2_str}, name:{}",
self.name.as_ref().map_or("None", |s| s)
)
}
}
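/// Convenience wrapper around [`ActivityRegularization`] that applies only an
/// L1 penalty to its input activations.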
#[derive(Debug, Clone)]
pub struct L1ActivityRegularization<F: Float + Debug + Send + Sync + NumAssign> {
inner: ActivityRegularization<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign>
L1ActivityRegularization<F>
{
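    /// Creates an L1 activity regularization layer with the given
    /// non-negative `factor`.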
pub fn new(factor: f64, name: Option<&str>) -> Result<Self> {
Ok(Self {
inner: ActivityRegularization::new(Some(factor), None, name)?,
})
}
pub fn name(&self) -> Option<&str> {
self.inner.name()
}
pub fn get_activity_loss(&self) -> Result<F> {
self.inner.get_activity_loss()
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> Layer<F>
for L1ActivityRegularization<F>
{
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
self.inner.forward(input)
}
fn backward(
&self,
input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
self.inner.backward(input, grad_output)
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.inner.update(learning_rate)
}
fn layer_type(&self) -> &str {
"L1ActivityRegularization"
}
fn parameter_count(&self) -> usize {
self.inner.parameter_count()
}
fn layer_description(&self) -> String {
self.inner
.layer_description()
.replace("ActivityRegularization", "L1ActivityRegularization")
}
}
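/// Convenience wrapper around [`ActivityRegularization`] that applies only an
/// L2 penalty to its input activations.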
#[derive(Debug, Clone)]
pub struct L2ActivityRegularization<F: Float + Debug + Send + Sync + NumAssign> {
inner: ActivityRegularization<F>,
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign>
L2ActivityRegularization<F>
{
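    /// Creates an L2 activity regularization layer with the given
    /// non-negative `factor`.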
pub fn new(factor: f64, name: Option<&str>) -> Result<Self> {
Ok(Self {
inner: ActivityRegularization::new(None, Some(factor), name)?,
})
}
pub fn name(&self) -> Option<&str> {
self.inner.name()
}
pub fn get_activity_loss(&self) -> Result<F> {
self.inner.get_activity_loss()
}
}
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static + NumAssign> Layer<F>
for L2ActivityRegularization<F>
{
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
self.inner.forward(input)
}
fn backward(
&self,
input: &Array<F, IxDyn>,
grad_output: &Array<F, IxDyn>,
) -> Result<Array<F, IxDyn>> {
self.inner.backward(input, grad_output)
}
fn update(&mut self, learning_rate: F) -> Result<()> {
self.inner.update(learning_rate)
}
fn layer_type(&self) -> &str {
"L2ActivityRegularization"
}
fn parameter_count(&self) -> usize {
self.inner.parameter_count()
}
fn layer_description(&self) -> String {
self.inner
.layer_description()
.replace("ActivityRegularization", "L2ActivityRegularization")
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::{array, Array2};
#[test]
fn test_activity_regularization_creation() {
let l1_reg = ActivityRegularization::<f64>::new(Some(0.01), None, Some("l1"))
.expect("Operation failed");
assert!(l1_reg.l1_factor.is_some());
assert!(l1_reg.l2_factor.is_none());
let l2_reg = ActivityRegularization::<f64>::new(None, Some(0.02), Some("l2"))
.expect("Operation failed");
assert!(l2_reg.l1_factor.is_none());
assert!(l2_reg.l2_factor.is_some());
let both_reg = ActivityRegularization::<f64>::new(Some(0.01), Some(0.02), Some("both"))
.expect("Operation failed");
assert!(both_reg.l1_factor.is_some());
assert!(both_reg.l2_factor.is_some());
assert!(ActivityRegularization::<f64>::new(None, None, Some("none")).is_err());
}
#[test]
fn test_activity_regularization_forward() {
let reg = ActivityRegularization::<f64>::new(Some(0.01), Some(0.02), Some("test"))
.expect("Operation failed");
        let input = Array2::<f64>::from_elem((2, 3), 1.0).into_dyn();
        let output = reg.forward(&input).expect("Operation failed");
        assert_eq!(input.shape(), output.shape());
        for (a, b) in input.iter().zip(output.iter()) {
assert!((a - b).abs() < 1e-10);
}
}
#[test]
fn test_activity_regularization_backward() {
let reg = ActivityRegularization::<f64>::new(Some(0.1), Some(0.1), Some("test"))
.expect("Operation failed");
let input = array![[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]].into_dyn();
let grad_output = array![[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]].into_dyn();
let _output = reg.forward(&input).expect("Operation failed");
let grad_input = reg
.backward(&input, &grad_output)
.expect("Operation failed");
assert_eq!(grad_input.shape(), input.shape());
}
#[test]
fn test_l1_activity_regularization() {
let reg =
L1ActivityRegularization::<f64>::new(0.01, Some("l1_test")).expect("Operation failed");
let input = Array2::<f64>::from_elem((2, 3), 2.0).into_dyn();
let _output = reg.forward(&input).expect("Operation failed");
let loss = reg.get_activity_loss().expect("Operation failed");
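        // Six activations of 2.0: expected loss = 0.01 * (6 * |2.0|) = 0.12.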
        assert!(loss > 0.0);
    }
#[test]
fn test_l2_activity_regularization() {
let reg =
L2ActivityRegularization::<f64>::new(0.01, Some("l2_test")).expect("Operation failed");
let input = Array2::<f64>::from_elem((2, 3), 2.0).into_dyn();
let _output = reg.forward(&input).expect("Operation failed");
let loss = reg.get_activity_loss().expect("Operation failed");
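        // Six activations of 2.0: expected loss = 0.01 * (6 * 2.0^2) = 0.24.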
        assert!(loss > 0.0);
    }
#[test]
fn test_activity_loss_calculation() {
let reg = ActivityRegularization::<f64>::new(Some(0.1), Some(0.1), Some("test"))
.expect("Operation failed");
let input = array![[1.0, -1.0], [2.0, 0.0]].into_dyn();
let _output = reg.forward(&input).expect("Operation failed");
let loss = reg.get_activity_loss().expect("Operation failed");
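        // l1: 0.1 * (1 + 1 + 2 + 0) = 0.4; l2: 0.1 * (1 + 1 + 4 + 0) = 0.6; total = 1.0.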
assert!((loss - 1.0).abs() < 1e-10);
}
}