use crate::cells::LSTMCell;
use crate::cells::LTCCell;
use crate::wirings::Wiring;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
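/// A Liquid Time-Constant (LTC) recurrent layer.
///
/// Unrolls an [`LTCCell`] over an input sequence, optionally running an LSTM cell
/// before the LTC cell at every step ("mixed memory").
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this file
/// (imports and device setup abbreviated; assumes the `ndarray` backend feature
/// is enabled):
///
/// ```ignore
/// let device = Default::default();
/// let wiring = FullyConnected::new(50, None, 1234, true);
/// let ltc = LTC::<NdArray<f32>>::new(20, wiring, &device).with_batch_first(true);
/// // input: [batch = 4, seq_len = 10, features = 20]
/// let input = Tensor::<NdArray<f32>, 3>::zeros([4, 10, 20], &device);
/// let (output, final_state) = ltc.forward(input, None, None);
/// assert_eq!(output.dims(), [4, 10, 50]);
/// assert_eq!(final_state.dims(), [4, 50]);
/// ```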
#[derive(Module, Debug)]
pub struct LTC<B: Backend> {
cell: LTCCell<B>,
#[module(skip)]
lstm_cell: Option<LSTMCell<B>>,
#[module(skip)]
input_size: usize,
#[module(skip)]
state_size: usize,
#[module(skip)]
motor_size: usize,
#[module(skip)]
batch_first: bool,
#[module(skip)]
return_sequences: bool,
#[module(skip)]
mixed_memory: bool,
}
impl<B: Backend> LTC<B> {
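/// Creates an LTC layer for inputs with `input_size` features, wired according to
/// `wiring`.
///
/// The state size is taken from `wiring.units()` and the motor (output) size from
/// `wiring.output_dim()`, falling back to the state size when the wiring defines no
/// output dimension. Defaults: `batch_first = true`, `return_sequences = true`,
/// `mixed_memory = false`.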
pub fn new(input_size: usize, wiring: impl Wiring, device: &B::Device) -> Self {
let state_size = wiring.units();
let motor_size = wiring.output_dim().unwrap_or(state_size);
let cell = LTCCell::new(&wiring, Some(input_size), device);
Self {
cell,
lstm_cell: None,
input_size,
state_size,
motor_size,
batch_first: true,
return_sequences: true,
mixed_memory: false,
}
}
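/// Sets the expected input layout: `[batch, seq_len, features]` when `true`
/// (the default), `[seq_len, batch, features]` when `false`.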
pub fn with_batch_first(mut self, batch_first: bool) -> Self {
self.batch_first = batch_first;
self
}
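/// Controls whether `forward` returns the output for every time step (the default)
/// or only for the last step (a sequence of length 1).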
pub fn with_return_sequences(mut self, return_sequences: bool) -> Self {
self.return_sequences = return_sequences;
self
}
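/// Enables or disables mixed memory. When enabled, an LSTM cell is created lazily
/// (if not already present) and run before the LTC cell at every step of
/// `forward_mixed`.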
pub fn with_mixed_memory(mut self, mixed_memory: bool, device: &B::Device) -> Self {
self.mixed_memory = mixed_memory;
if mixed_memory && self.lstm_cell.is_none() {
self.lstm_cell = Some(LSTMCell::new(self.input_size, self.state_size, device));
}
self
}
pub fn input_size(&self) -> usize {
self.input_size
}
pub fn state_size(&self) -> usize {
self.state_size
}
pub fn motor_size(&self) -> usize {
self.motor_size
}
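/// Runs the LTC cell over the input sequence.
///
/// `state` defaults to zeros of shape `[batch, state_size]` and `timespans` to ones
/// of shape `[batch, seq_len]` (one time unit per step). Returns the stacked outputs
/// of shape `[batch, seq_len, motor_size]` (or `[batch, 1, motor_size]` when
/// `return_sequences` is disabled) together with the final state. Note that the
/// output sequence is stacked batch-first regardless of `batch_first`.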
pub fn forward(
&self,
input: Tensor<B, 3>,
state: Option<Tensor<B, 2>>,
timespans: Option<Tensor<B, 2>>,
) -> (Tensor<B, 3>, Tensor<B, 2>) {
let device = input.device();
let dims = input.dims();
let (batch_size, seq_len) = if self.batch_first {
(dims[0], dims[1])
} else {
(dims[1], dims[0])
};
let mut current_state =
state.unwrap_or_else(|| Tensor::<B, 2>::zeros([batch_size, self.state_size], &device));
let timespans =
timespans.unwrap_or_else(|| Tensor::<B, 2>::ones([batch_size, seq_len], &device));
let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
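// Unroll the cell over the sequence, one time step at a time.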
for t in 0..seq_len {
let step_input = if self.batch_first {
input.clone().narrow(1, t, 1).squeeze(1)
} else {
input.clone().narrow(0, t, 1).squeeze(0)
};
let step_time = timespans.clone().narrow(1, t, 1).squeeze(1);
let (output, new_state) = self.cell.forward(step_input, current_state, step_time);
current_state = new_state;
if self.return_sequences || t == seq_len - 1 {
outputs.push(output);
}
}
let output = Tensor::stack(outputs, 1);
(output, current_state)
}
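/// Runs the sequence with mixed memory: at each step the LSTM cell processes the
/// input first, its hidden state seeds the LTC cell, and the LTC cell's new state is
/// carried forward as the hidden state for the next step.
///
/// `state` is the `(hidden, cell)` pair, both defaulting to zeros of shape
/// `[batch, state_size]`. Returns the stacked outputs and the final `(hidden, cell)`
/// pair. Panics if mixed memory has not been enabled via `with_mixed_memory`.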
pub fn forward_mixed(
&self,
input: Tensor<B, 3>,
state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
timespans: Option<Tensor<B, 2>>,
) -> (Tensor<B, 3>, (Tensor<B, 2>, Tensor<B, 2>)) {
if !self.mixed_memory {
panic!("Mixed memory is not enabled. Call with_mixed_memory(true, &device) first.");
}
let device = input.device();
let dims = input.dims();
let (batch_size, seq_len) = if self.batch_first {
(dims[0], dims[1])
} else {
(dims[1], dims[0])
};
let (mut h_state, mut c_state) = state.unwrap_or_else(|| {
(
Tensor::<B, 2>::zeros([batch_size, self.state_size], &device),
Tensor::<B, 2>::zeros([batch_size, self.state_size], &device),
)
});
let timespans =
timespans.unwrap_or_else(|| Tensor::<B, 2>::ones([batch_size, seq_len], &device));
let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
let lstm = self.lstm_cell.as_ref().expect("LSTM cell not initialized");
for t in 0..seq_len {
let step_input = if self.batch_first {
input.clone().narrow(1, t, 1).squeeze(1)
} else {
input.clone().narrow(0, t, 1).squeeze(0)
};
let step_time = timespans.clone().narrow(1, t, 1).squeeze(1);
// The LSTM runs first; its hidden state seeds the LTC cell, whose new state
// is carried forward as the hidden state for the next step.
let (new_h, new_c) = lstm.forward(step_input.clone(), (h_state, c_state));
c_state = new_c;
let (ltc_output, new_ltc_state) = self.cell.forward(step_input, new_h, step_time);
h_state = new_ltc_state;
if self.return_sequences || t == seq_len - 1 {
outputs.push(ltc_output);
}
}
let output = Tensor::stack(outputs, 1);
(output, (h_state, c_state))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::wirings::{AutoNCP, FullyConnected};
use burn::backend::NdArray;
use burn::tensor::backend::Backend as BurnBackend;
type TestBackend = NdArray<f32>;
type TestDevice = <TestBackend as BurnBackend>::Device;
fn get_test_device() -> TestDevice {
Default::default()
}
#[test]
fn test_ltc_rnn_creation() {
let device = get_test_device();
let wiring = FullyConnected::new(50, None, 1234, true);
let ltc = LTC::<TestBackend>::new(20, wiring, &device);
assert_eq!(ltc.input_size(), 20);
assert_eq!(ltc.state_size(), 50);
}
#[test]
fn test_ltc_rnn_forward_batch_first() {
let device = get_test_device();
let wiring = FullyConnected::new(50, None, 1234, true);
let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_batch_first(true);
let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
let (output, state) = ltc.forward(input, None, None);
assert_eq!(output.dims(), [4, 10, 50]);
assert_eq!(state.dims(), [4, 50]);
}
#[test]
fn test_ltc_rnn_forward_seq_first() {
let device = get_test_device();
let wiring = FullyConnected::new(50, None, 1234, true);
let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_batch_first(false);
let input = Tensor::<TestBackend, 3>::zeros([10, 4, 20], &device);
let (output, state) = ltc.forward(input, None, None);
assert_eq!(output.dims(), [4, 10, 50]);
assert_eq!(state.dims(), [4, 50]);
}
#[test]
fn test_ltc_rnn_return_last_only() {
let device = get_test_device();
let wiring = FullyConnected::new(50, None, 1234, true);
let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_return_sequences(false);
let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
let (output, state) = ltc.forward(input, None, None);
assert_eq!(output.dims(), [4, 1, 50]);
assert_eq!(state.dims(), [4, 50]);
}
#[test]
fn test_ltc_rnn_with_initial_state() {
let device = get_test_device();
let wiring = FullyConnected::new(50, None, 1234, true);
let ltc = LTC::<TestBackend>::new(20, wiring, &device);
let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
let initial_state = Tensor::<TestBackend, 2>::ones([4, 50], &device);
let (output, state) = ltc.forward(input, Some(initial_state), None);
assert_eq!(output.dims(), [4, 10, 50]);
assert_eq!(state.dims(), [4, 50]);
}
#[test]
fn test_ltc_rnn_with_timespans() {
let device = get_test_device();
let wiring = FullyConnected::new(50, None, 1234, true);
let ltc = LTC::<TestBackend>::new(20, wiring, &device);
let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
let timespans = Tensor::<TestBackend, 2>::full([4, 10], 0.5, &device);
let (output, state) = ltc.forward(input, None, Some(timespans));
assert_eq!(output.dims(), [4, 10, 50]);
assert_eq!(state.dims(), [4, 50]);
}
#[test]
fn test_ltc_rnn_with_ncp_wiring() {
let device = get_test_device();
let wiring = AutoNCP::new(64, 8, 0.5, 22222);
let ltc = LTC::<TestBackend>::new(20, wiring, &device);
let input = Tensor::<TestBackend, 3>::zeros([2, 5, 20], &device);
let (output, state) = ltc.forward(input, None, None);
assert_eq!(output.dims(), [2, 5, 8]);
assert_eq!(state.dims(), [2, 64]);
}
#[test]
fn test_ltc_rnn_sequence_processing() {
let device = get_test_device();
let wiring = FullyConnected::new(20, None, 1234, true);
let ltc = LTC::<TestBackend>::new(10, wiring, &device);
for seq_len in [1, 5, 20] {
let input = Tensor::<TestBackend, 3>::zeros([2, seq_len, 10], &device);
let (output, _) = ltc.forward(input, None, None);
assert_eq!(output.dims(), [2, seq_len, 20]);
}
}
}