#[cfg(test)]
mod tests {
use burn::backend::NdArray;
use burn::tensor::Tensor;
use ncps::cells::cfc_cell::{CfCCell, CfcMode};
type Backend = NdArray<f32>;
fn create_test_cell(mode: CfcMode) -> CfCCell<Backend> {
let device = Default::default();
CfCCell::new(20, 50, &device)
.with_mode(mode)
.with_backbone(128, 1, 0.0)
.with_activation("lecun_tanh")
}
#[test]
fn test_cfc_cell_creation() {
let device = Default::default();
let cell = CfCCell::<Backend>::new(20, 50, &device);
assert_eq!(cell.input_size(), 20);
assert_eq!(cell.hidden_size(), 50);
assert_eq!(cell.mode(), CfcMode::Default);
}
#[test]
fn test_cfc_forward_default() {
let device = Default::default();
let cell = create_test_cell(CfcMode::Default);
let batch_size = 4;
let input = Tensor::<Backend, 2>::zeros([batch_size, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([batch_size, 50], &device);
let (output, new_hidden) = cell.forward(input, hx, 1.0);
assert_eq!(output.dims(), [batch_size, 50]);
assert_eq!(new_hidden.dims(), [batch_size, 50]);
}
#[test]
fn test_cfc_forward_pure() {
let device = Default::default();
let cell = create_test_cell(CfcMode::Pure);
let input = Tensor::<Backend, 2>::random(
[2, 20],
burn::tensor::Distribution::Uniform(-0.5, 0.5),
&device,
);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let (output, _) = cell.forward(input, hx, 1.0);
assert_eq!(output.dims(), [2, 50]);
assert_eq!(cell.mode(), CfcMode::Pure);
}
#[test]
fn test_cfc_forward_no_gate() {
let device = Default::default();
let cell = create_test_cell(CfcMode::NoGate);
let input = Tensor::<Backend, 2>::ones([2, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let (output, new_hidden) = cell.forward(input, hx, 1.0);
assert_eq!(output.dims(), [2, 50]);
assert_eq!(new_hidden.dims(), [2, 50]);
assert_eq!(cell.mode(), CfcMode::NoGate);
}
#[test]
fn test_cfc_state_change() {
let device = Default::default();
let cell = create_test_cell(CfcMode::Default);
let input = Tensor::<Backend, 2>::ones([2, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let (output, new_hidden) = cell.forward(input, hx.clone(), 1.0);
let diff = (new_hidden.clone() - hx).abs().mean().into_scalar();
assert!(diff > 0.0);
let output_diff = (output - new_hidden).abs().mean().into_scalar();
assert!(output_diff < 1e-6, "Output should equal new_hidden");
}
#[test]
fn test_cfc_different_modes_produce_different_results() {
let device = Default::default();
let cell_default = create_test_cell(CfcMode::Default);
let cell_no_gate = create_test_cell(CfcMode::NoGate);
let input = Tensor::<Backend, 2>::random(
[2, 20],
burn::tensor::Distribution::Uniform(-1.0, 1.0),
&device,
);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let (out1, _) = cell_default.forward(input.clone(), hx.clone(), 1.0);
let (out2, _) = cell_no_gate.forward(input, hx, 1.0);
let diff = (out1 - out2).abs().mean().into_scalar();
assert!(
diff > 0.01,
"Different modes should produce different outputs"
);
}
#[test]
fn test_cfc_backbone_configurations() {
let device = Default::default();
let _cell_no_backbone = CfCCell::<Backend>::new(20, 50, &device).with_backbone(0, 0, 0.0);
let _cell_deep_backbone =
CfCCell::<Backend>::new(20, 50, &device).with_backbone(64, 3, 0.2);
let input = Tensor::<Backend, 2>::zeros([2, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let _cell = CfCCell::<Backend>::new(20, 50, &device).with_backbone(0, 0, 0.0);
let (out1, _) = _cell.forward(input.clone(), hx.clone(), 1.0);
assert_eq!(out1.dims(), [2, 50]);
let _cell2 = CfCCell::<Backend>::new(20, 50, &device).with_backbone(64, 3, 0.2);
let (out2, _) = _cell2.forward(input, hx, 1.0);
assert_eq!(out2.dims(), [2, 50]);
}
#[test]
fn test_cfc_activations() {
let device = Default::default();
for activation in ["relu", "tanh", "gelu", "silu", "lecun_tanh"] {
let cell = CfCCell::<Backend>::new(20, 50, &device)
.with_backbone(64, 1, 0.0)
.with_activation(activation);
let input = Tensor::<Backend, 2>::zeros([2, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let (output, _) = cell.forward(input, hx, 1.0);
assert_eq!(output.dims()[0], 2);
}
}
#[test]
#[should_panic]
fn test_cfc_invalid_activation() {
let device = Default::default();
CfCCell::<Backend>::new(20, 50, &device).with_activation("invalid_activation");
}
#[test]
fn test_cfc_batch_processing() {
let device = Default::default();
let cell = create_test_cell(CfcMode::Default);
for batch in [1, 8, 32] {
let input = Tensor::<Backend, 2>::zeros([batch, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([batch, 50], &device);
let (output, _) = cell.forward(input, hx, 1.0);
assert_eq!(output.dims(), [batch, 50]);
}
}
#[test]
fn test_cfc_sparsity_mask() {
let device = Default::default();
use ndarray::Array2;
let mask = Array2::from_shape_vec((50, 50), vec![1.0f32; 2500]).unwrap();
let cell = CfCCell::<Backend>::new(20, 50, &device).with_sparsity_mask(mask, &device);
let input = Tensor::<Backend, 2>::zeros([2, 20], &device);
let hx = Tensor::<Backend, 2>::zeros([2, 50], &device);
let (output, _) = cell.forward(input, hx, 1.0);
assert_eq!(output.dims(), [2, 50]);
}
}