1use crate::cells::LSTMCell;
7use crate::cells::LTCCell;
8use crate::wirings::Wiring;
9use burn::module::Module;
10use burn::tensor::backend::Backend;
11use burn::tensor::Tensor;
12
/// Liquid Time-Constant (LTC) recurrent network built from a wiring diagram.
///
/// Wraps an `LTCCell` that is unrolled over the time axis by `forward`;
/// optionally pairs it with an `LSTMCell` ("mixed memory", see
/// `forward_mixed`) as an additional gated memory path.
#[derive(Module, Debug)]
pub struct LTC<B: Backend> {
    /// The LTC cell applied at every time step.
    cell: LTCCell<B>,
    /// Auxiliary LSTM cell; `Some` only after mixed memory is enabled.
    // NOTE(review): `#[module(skip)]` presumably excludes these weights from
    // module parameter collection (and thus training) — confirm intended.
    #[module(skip)]
    lstm_cell: Option<LSTMCell<B>>,
    /// Number of input features per time step.
    #[module(skip)]
    input_size: usize,
    /// Number of hidden units (wiring neurons).
    #[module(skip)]
    state_size: usize,
    /// Number of motor (output) units exposed by the wiring.
    #[module(skip)]
    motor_size: usize,
    /// Input layout: `[batch, seq, feature]` when true, `[seq, batch, feature]` otherwise.
    #[module(skip)]
    batch_first: bool,
    /// Emit one output per step (true) or only the final step (false).
    #[module(skip)]
    return_sequences: bool,
    /// Whether `forward_mixed` may be called.
    #[module(skip)]
    mixed_memory: bool,
}
46
47impl<B: Backend> LTC<B> {
48 pub fn new(input_size: usize, wiring: impl Wiring, device: &B::Device) -> Self {
55 let state_size = wiring.units();
56 let motor_size = wiring.output_dim().unwrap_or(state_size);
57
58 let cell = LTCCell::new(&wiring, Some(input_size), device);
59
60 Self {
61 cell,
62 lstm_cell: None,
63 input_size,
64 state_size,
65 motor_size,
66 batch_first: true,
67 return_sequences: true,
68 mixed_memory: false,
69 }
70 }
71
72 pub fn with_batch_first(mut self, batch_first: bool) -> Self {
77 self.batch_first = batch_first;
78 self
79 }
80
81 pub fn with_return_sequences(mut self, return_sequences: bool) -> Self {
86 self.return_sequences = return_sequences;
87 self
88 }
89
90 pub fn with_mixed_memory(mut self, mixed_memory: bool, device: &B::Device) -> Self {
99 self.mixed_memory = mixed_memory;
100 if mixed_memory && self.lstm_cell.is_none() {
101 self.lstm_cell = Some(LSTMCell::new(self.input_size, self.state_size, device));
103 }
104 self
105 }
106
107 pub fn input_size(&self) -> usize {
109 self.input_size
110 }
111
112 pub fn state_size(&self) -> usize {
114 self.state_size
115 }
116
117 pub fn motor_size(&self) -> usize {
119 self.motor_size
120 }
121
122 pub fn forward(
137 &self,
138 input: Tensor<B, 3>,
139 state: Option<Tensor<B, 2>>,
140 timespans: Option<Tensor<B, 2>>,
141 ) -> (Tensor<B, 3>, Tensor<B, 2>) {
142 let device = input.device();
143
144 let (batch_size, seq_len, _) = if self.batch_first {
146 let dims = input.dims();
147 (dims[0], dims[1], dims[2])
148 } else {
149 let dims = input.dims();
150 (dims[1], dims[0], dims[2])
151 };
152
153 let mut current_state =
155 state.unwrap_or_else(|| Tensor::<B, 2>::zeros([batch_size, self.state_size], &device));
156
157 let timespans =
159 timespans.unwrap_or_else(|| Tensor::<B, 2>::ones([batch_size, seq_len], &device));
160
161 let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
163
164 for t in 0..seq_len {
165 let step_input = if self.batch_first {
167 input.clone().narrow(1, t, 1).squeeze(1)
169 } else {
170 input.clone().narrow(0, t, 1).squeeze(0)
172 };
173
174 let step_time = timespans.clone().narrow(1, t, 1).squeeze(1);
176
177 let (output, new_state) = self.cell.forward(step_input, current_state, step_time);
179 current_state = new_state;
180
181 if self.return_sequences {
182 outputs.push(output);
183 } else if t == seq_len - 1 {
184 outputs.push(output);
186 }
187 }
188
189 let output = Tensor::stack(outputs, 1); (output, current_state)
192 }
193
194 pub fn forward_mixed(
201 &self,
202 input: Tensor<B, 3>,
203 state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
204 timespans: Option<Tensor<B, 2>>,
205 ) -> (Tensor<B, 3>, (Tensor<B, 2>, Tensor<B, 2>))
206 where
207 B: Backend,
208 {
209 if !self.mixed_memory {
210 panic!("Mixed memory not enabled. Call with_mixed_memory(true) first.");
211 }
212
213 let device = input.device();
214
215 let (batch_size, seq_len, _) = if self.batch_first {
217 let dims = input.dims();
218 (dims[0], dims[1], dims[2])
219 } else {
220 let dims = input.dims();
221 (dims[1], dims[0], dims[2])
222 };
223
224 let (mut h_state, mut c_state) = state.unwrap_or_else(|| {
226 (
227 Tensor::<B, 2>::zeros([batch_size, self.state_size], &device),
228 Tensor::<B, 2>::zeros([batch_size, self.state_size], &device),
229 )
230 });
231
232 let timespans =
234 timespans.unwrap_or_else(|| Tensor::<B, 2>::ones([batch_size, seq_len], &device));
235
236 let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
238
239 let lstm = self.lstm_cell.as_ref().expect("LSTM cell not initialized");
241
242 for t in 0..seq_len {
243 let step_input = if self.batch_first {
245 input.clone().narrow(1, t, 1).squeeze(1)
246 } else {
247 input.clone().narrow(0, t, 1).squeeze(0)
248 };
249
250 let step_time = timespans.clone().narrow(1, t, 1).squeeze(1);
252
253 let (new_h, new_c) = lstm.forward(step_input.clone(), (h_state, c_state));
255 h_state = new_h.clone();
256 c_state = new_c;
257
258 let (ltc_output, new_ltc_state) =
260 self.cell.forward(step_input, h_state.clone(), step_time);
261 h_state = new_ltc_state;
262
263 if self.return_sequences {
264 outputs.push(ltc_output);
265 } else if t == seq_len - 1 {
266 outputs.push(ltc_output);
267 }
268 }
269
270 let output = Tensor::stack(outputs, 1);
272 (output, (h_state, c_state))
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wirings::{AutoNCP, FullyConnected};
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    fn get_test_device() -> TestDevice {
        Default::default()
    }

    #[test]
    fn test_ltc_rnn_creation() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);

        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        assert_eq!(ltc.input_size(), 20);
        assert_eq!(ltc.state_size(), 50);
    }

    #[test]
    fn test_ltc_rnn_forward_batch_first() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_batch_first(true);

        // [batch=4, seq=10, feature=20]
        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);

        let (output, state) = ltc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_forward_seq_first() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_batch_first(false);

        // [seq=10, batch=4, feature=20]
        let input = Tensor::<TestBackend, 3>::zeros([10, 4, 20], &device);

        // Final state is not asserted here, so bind it as `_state` to avoid
        // an unused-variable warning.
        let (output, _state) = ltc.forward(input, None, None);

        // Output is stacked batch-first even for seq-first input.
        assert_eq!(output.dims(), [4, 10, 50]);
    }

    #[test]
    fn test_ltc_rnn_return_last_only() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_return_sequences(false);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);

        let (output, state) = ltc.forward(input, None, None);

        // Only the final step is kept; the sequence axis collapses to 1.
        assert_eq!(output.dims(), [4, 1, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_with_initial_state() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let initial_state = Tensor::<TestBackend, 2>::ones([4, 50], &device);

        let (output, state) = ltc.forward(input, Some(initial_state), None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_with_timespans() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let timespans = Tensor::<TestBackend, 2>::full([4, 10], 0.5, &device);

        let (output, state) = ltc.forward(input, None, Some(timespans));

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_with_ncp_wiring() {
        let device = get_test_device();
        let wiring = AutoNCP::new(64, 8, 0.5, 22222);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([2, 5, 20], &device);
        let (output, state) = ltc.forward(input, None, None);

        // NCP wiring exposes only the motor units (8) as output.
        assert_eq!(output.dims(), [2, 5, 8]);
        assert_eq!(state.dims(), [2, 64]);
    }

    #[test]
    fn test_ltc_rnn_sequence_processing() {
        let device = get_test_device();
        let wiring = FullyConnected::new(20, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(10, wiring, &device);

        for seq_len in [1, 5, 20] {
            let input = Tensor::<TestBackend, 3>::zeros([2, seq_len, 10], &device);
            let (output, _) = ltc.forward(input, None, None);

            assert_eq!(output.dims(), [2, seq_len, 20]);
        }
    }

    #[test]
    fn test_ltc_rnn_mixed_memory() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc =
            LTC::<TestBackend>::new(20, wiring, &device).with_mixed_memory(true, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);

        let (output, (h_state, c_state)) = ltc.forward_mixed(input, None, None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(h_state.dims(), [4, 50]);
        assert_eq!(c_state.dims(), [4, 50]);
    }
}