use crate::cells::CfCCell;
use crate::wirings::Wiring;
use burn::module::Module;
use burn::nn::{Linear, LinearConfig};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
/// Sequence-level recurrent module that unrolls a [`CfCCell`] over the time
/// axis of a rank-3 input and optionally projects each step's output through
/// a linear layer.
///
/// Instances are created with `CfC::new` or `CfC::with_wiring` and adjusted
/// with the consuming `with_*` builder methods.
#[derive(Module, Debug)]
pub struct CfC<B: Backend> {
/// Recurrent cell invoked once per time step by `forward`.
cell: CfCCell<B>,
/// Optional linear layer applied to a step's cell output before it is
/// collected; `None` means the cell output is returned as-is.
proj: Option<Linear<B>>,
/// Number of input features, as passed to the constructor.
#[module(skip)]
input_size: usize,
/// Width of the recurrent state; also used to size the zero initial state
/// when `forward` is called without one.
#[module(skip)]
hidden_size: usize,
/// `true`: `forward` reads the input as `[batch, seq, features]`;
/// `false`: as `[seq, batch, features]`.
#[module(skip)]
batch_first: bool,
/// `true`: `forward` returns every step's output; `false`: only the last.
#[module(skip)]
return_sequences: bool,
/// Width of the projection layer when one is configured.
#[module(skip)]
proj_size: Option<usize>,
/// Width of the last axis of the tensor returned by `forward`
/// (`hidden_size`, or the projection width when a projection is set).
#[module(skip)]
output_size: usize,
}
impl<B: Backend> CfC<B> {
    /// Creates a layer whose cell maps `input_size` features onto a state of
    /// `hidden_size` units, with parameters initialised on `device`.
    ///
    /// Defaults: batch-first input, every time step returned, no projection
    /// (so the output width equals `hidden_size`).
    pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
        Self {
            cell: CfCCell::new(input_size, hidden_size, device),
            proj: None,
            input_size,
            hidden_size,
            batch_first: true,
            return_sequences: true,
            proj_size: None,
            output_size: hidden_size,
        }
    }

    /// Creates a layer sized from `wiring`.
    ///
    /// The state width is `wiring.units()` and the output width is
    /// `wiring.output_dim()`, falling back to the state width when the wiring
    /// reports none. A bias-enabled linear projection from state width to
    /// output width is added only when the two differ.
    ///
    /// Only these two sizes are read from the wiring here; the cell itself is
    /// built from `input_size` and the state width alone.
    pub fn with_wiring(input_size: usize, wiring: impl Wiring, device: &B::Device) -> Self {
        let state_size = wiring.units();
        let motor_size = wiring.output_dim().unwrap_or(state_size);
        // The cell is built before the projection, matching the original
        // parameter-initialisation order.
        let cell = CfCCell::new(input_size, state_size, device);
        // Single source of truth for "is a projection needed?".
        let proj_size = (motor_size != state_size).then_some(motor_size);
        let proj = proj_size.map(|size| {
            LinearConfig::new(state_size, size)
                .with_bias(true)
                .init(device)
        });
        Self {
            cell,
            proj,
            input_size,
            hidden_size: state_size,
            batch_first: true,
            return_sequences: true,
            proj_size,
            output_size: motor_size,
        }
    }

    /// Selects the input layout read by [`CfC::forward`]:
    /// `[batch, seq, features]` when `true`, `[seq, batch, features]` when `false`.
    pub fn with_batch_first(mut self, batch_first: bool) -> Self {
        self.batch_first = batch_first;
        self
    }

    /// Selects whether [`CfC::forward`] returns every step's output (`true`)
    /// or only the final step's output (`false`).
    pub fn with_return_sequences(mut self, return_sequences: bool) -> Self {
        self.return_sequences = return_sequences;
        self
    }

    /// Installs a bias-enabled linear projection from `hidden_size` to
    /// `proj_size`, replacing any projection already present, and updates the
    /// reported output width accordingly.
    ///
    /// The new layer is created on the device that already holds this
    /// module's parameters (see `get_device`).
    pub fn with_proj_size(mut self, proj_size: usize) -> Self {
        let device = self.get_device();
        self.proj = Some(
            LinearConfig::new(self.hidden_size, proj_size)
                .with_bias(true)
                .init(&device),
        );
        self.proj_size = Some(proj_size);
        self.output_size = proj_size;
        self
    }

    /// Currently a no-op: all three arguments are ignored and `self` is
    /// returned unchanged.
    pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
        self
    }

    /// Returns the device on which new sub-layers should be created.
    ///
    /// Uses the first device reported by `Module::devices` for this module so
    /// that layers added after construction live alongside the existing
    /// parameters; falls back to the backend's default device if the module
    /// reports none.
    fn get_device(&self) -> B::Device {
        self.devices().into_iter().next().unwrap_or_default()
    }

    /// Number of input features given at construction.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Width of the recurrent state.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Width of the last axis of the tensor returned by [`CfC::forward`].
    pub fn output_size(&self) -> usize {
        self.output_size
    }

    /// Runs the cell over every time step of `input`.
    ///
    /// # Arguments
    /// * `input` - `[batch, seq, features]` when `batch_first` is set,
    ///   otherwise `[seq, batch, features]`.
    /// * `state` - initial state; a `[batch, hidden_size]` zero tensor on the
    ///   input's device is used when `None`.
    /// * `_timespans` - currently ignored; every step passes the constant
    ///   `1.0` as the third argument of the cell's `forward`.
    ///
    /// # Returns
    /// `(output, final_state)`. Step outputs are always stacked along axis 1,
    /// regardless of `batch_first`, giving `[batch, seq, output_size]`, or
    /// `[batch, 1, output_size]` when only the last step is returned.
    ///
    /// # Panics
    /// Panics if the sequence axis of `input` has length zero.
    pub fn forward(
        &self,
        input: Tensor<B, 3>,
        state: Option<Tensor<B, 2>>,
        _timespans: Option<Tensor<B, 2>>,
    ) -> (Tensor<B, 3>, Tensor<B, 2>) {
        let device = input.device();
        let [dim0, dim1, _] = input.dims();
        let (batch_size, seq_len, time_axis) = if self.batch_first {
            (dim0, dim1, 1)
        } else {
            (dim1, dim0, 0)
        };
        // Stacking an empty list of step outputs cannot succeed; fail early
        // with a message that names the actual problem.
        assert!(
            seq_len > 0,
            "CfC::forward requires at least one time step, got an empty sequence"
        );

        let mut current_state = state
            .unwrap_or_else(|| Tensor::<B, 2>::zeros([batch_size, self.hidden_size], &device));
        let kept_steps = if self.return_sequences { seq_len } else { 1 };
        let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(kept_steps);

        for t in 0..seq_len {
            let step_input = input.clone().narrow(time_axis, t, 1).squeeze(time_axis);
            let (output, new_state) = self.cell.forward(step_input, current_state, 1.0);
            current_state = new_state;

            // The projection never feeds back into the state, so it is only
            // evaluated for steps whose output is actually returned.
            if self.return_sequences || t + 1 == seq_len {
                let output = match &self.proj {
                    Some(proj) => proj.forward(output),
                    None => output,
                };
                outputs.push(output);
            }
        }

        (Tensor::stack(outputs, 1), current_state)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wirings::{AutoNCP, FullyConnected};
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    /// Default device of the test backend.
    fn get_test_device() -> TestDevice {
        Default::default()
    }

    /// All-zero rank-3 input of the given shape on the test device.
    fn zero_input(shape: [usize; 3]) -> Tensor<TestBackend, 3> {
        Tensor::<TestBackend, 3>::zeros(shape, &get_test_device())
    }

    /// Plain constructor reports the sizes it was given.
    #[test]
    fn test_cfc_rnn_creation() {
        let layer = CfC::<TestBackend>::new(20, 50, &get_test_device());

        assert_eq!(layer.input_size(), 20);
        assert_eq!(layer.hidden_size(), 50);
        assert_eq!(layer.output_size(), 50);
    }

    /// Wiring-based constructor takes its output width from the wiring.
    #[test]
    fn test_cfc_rnn_with_wiring() {
        let ncp = AutoNCP::new(32, 8, 0.5, 22222);
        let layer = CfC::<TestBackend>::with_wiring(20, ncp, &get_test_device());

        assert_eq!(layer.output_size(), 8);
    }

    /// Default forward pass returns one output per step plus the final state.
    #[test]
    fn test_cfc_rnn_forward() {
        let layer = CfC::<TestBackend>::new(20, 50, &get_test_device());

        let (seq_out, final_state) = layer.forward(zero_input([4, 10, 20]), None, None);

        assert_eq!(seq_out.dims(), [4, 10, 50]);
        assert_eq!(final_state.dims(), [4, 50]);
    }

    /// A projection changes the width of the returned sequence.
    #[test]
    fn test_cfc_rnn_with_projection() {
        let layer = CfC::<TestBackend>::new(20, 50, &get_test_device()).with_proj_size(10);

        let (seq_out, _) = layer.forward(zero_input([4, 10, 20]), None, None);

        assert_eq!(seq_out.dims(), [4, 10, 10]);
        assert_eq!(layer.output_size(), 10);
    }

    /// Configuring a backbone leaves the output shape untouched.
    #[test]
    fn test_cfc_rnn_backbone_config() {
        let layer = CfC::<TestBackend>::new(20, 50, &get_test_device()).with_backbone(128, 2, 0.1);

        let (seq_out, _) = layer.forward(zero_input([2, 5, 20]), None, None);

        assert_eq!(seq_out.dims(), [2, 5, 50]);
    }

    /// With `return_sequences` off only the last step is returned.
    #[test]
    fn test_cfc_rnn_return_last_only() {
        let layer =
            CfC::<TestBackend>::new(20, 50, &get_test_device()).with_return_sequences(false);

        let (last_out, final_state) = layer.forward(zero_input([4, 10, 20]), None, None);

        assert_eq!(last_out.dims(), [4, 1, 50]);
        assert_eq!(final_state.dims(), [4, 50]);
    }

    /// Sequence-first input still yields a batch-first output.
    #[test]
    fn test_cfc_rnn_seq_first() {
        let layer = CfC::<TestBackend>::new(20, 50, &get_test_device()).with_batch_first(false);

        let (seq_out, final_state) = layer.forward(zero_input([10, 4, 20]), None, None);

        assert_eq!(seq_out.dims(), [4, 10, 50]);
        assert_eq!(final_state.dims(), [4, 50]);
    }

    /// A caller-supplied initial state is accepted.
    #[test]
    fn test_cfc_rnn_with_initial_state() {
        let layer = CfC::<TestBackend>::new(20, 50, &get_test_device());
        let start_state = Tensor::<TestBackend, 2>::ones([4, 50], &get_test_device());

        let (seq_out, final_state) =
            layer.forward(zero_input([4, 10, 20]), Some(start_state), None);

        assert_eq!(seq_out.dims(), [4, 10, 50]);
        assert_eq!(final_state.dims(), [4, 50]);
    }
}