// ncps/rnn/ltc.rs

//! Liquid Time-Constant (LTC) RNN Layer
//!
//! Full RNN layer that handles sequence processing, batching, and state management
//! for LTC (Liquid Time-Constant) cells.

use crate::cells::LSTMCell;
use crate::cells::LTCCell;
use crate::wirings::Wiring;
use burn::module::Module;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
/// LTC RNN Layer
///
/// A full RNN layer that processes sequences using LTC cells.
/// Supports batching, state management, mixed memory (LSTM), and variable timespans.
///
/// # Type Parameters
/// * `B` - The backend type
#[derive(Module, Debug)]
pub struct LTC<B: Backend> {
    /// The LTC cell for processing individual timesteps
    cell: LTCCell<B>,
    /// Optional LSTM cell for mixed memory mode.
    ///
    /// NOTE(review): `#[module(skip)]` excludes this field from the module tree,
    /// which would keep the LSTM's parameters out of parameter visits (and thus
    /// optimization) — confirm this is intentional for mixed-memory training.
    #[module(skip)]
    lstm_cell: Option<LSTMCell<B>>,
    /// Input size (number of features per timestep)
    #[module(skip)]
    input_size: usize,
    /// State size (number of neurons, from `wiring.units()`)
    #[module(skip)]
    state_size: usize,
    /// Motor/output size (from the wiring's output dimension)
    #[module(skip)]
    motor_size: usize,
    /// Whether input is batch-first (batch, seq, features) vs (seq, batch, features)
    #[module(skip)]
    batch_first: bool,
    /// Whether to return full sequence or just last timestep
    #[module(skip)]
    return_sequences: bool,
    /// Whether to use mixed memory (LSTM augmentation)
    #[module(skip)]
    mixed_memory: bool,
}

impl<B: Backend> LTC<B> {
    /// Build a new LTC RNN layer from a wiring specification.
    ///
    /// Defaults: batch-first input, full-sequence output, no mixed memory.
    ///
    /// # Arguments
    /// * `input_size` - Number of input features
    /// * `wiring` - Wiring configuration defining the network structure
    /// * `device` - Device to create the module on
    pub fn new(input_size: usize, wiring: impl Wiring, device: &B::Device) -> Self {
        // The wiring determines both the neuron count and the output width;
        // wirings without an explicit output dimension expose all neurons.
        let units = wiring.units();
        let outputs = wiring.output_dim().unwrap_or(units);

        Self {
            cell: LTCCell::new(&wiring, Some(input_size), device),
            lstm_cell: None,
            input_size,
            state_size: units,
            motor_size: outputs,
            batch_first: true,
            return_sequences: true,
            mixed_memory: false,
        }
    }

72    /// Set whether input is batch-first (default: true)
73    ///
74    /// When true: input shape is [batch, seq, features]
75    /// When false: input shape is [seq, batch, features]
76    pub fn with_batch_first(mut self, batch_first: bool) -> Self {
77        self.batch_first = batch_first;
78        self
79    }
80
81    /// Set whether to return full sequences (default: true)
82    ///
83    /// When true: returns all timesteps [batch, seq, state_size]
84    /// When false: returns only last timestep [batch, state_size]
85    pub fn with_return_sequences(mut self, return_sequences: bool) -> Self {
86        self.return_sequences = return_sequences;
87        self
88    }
89
90    /// Enable or disable mixed memory mode (LSTM augmentation)
91    ///
92    /// When enabled, an LSTM cell processes the LTC output for better long-term memory.
93    /// The LSTM cell is initialized when this is called with `true`.
94    ///
95    /// # Arguments
96    /// * `mixed_memory` - Whether to enable mixed memory mode
97    /// * `device` - Device to create the LSTM cell on (required when enabling)
98    pub fn with_mixed_memory(mut self, mixed_memory: bool, device: &B::Device) -> Self {
99        self.mixed_memory = mixed_memory;
100        if mixed_memory && self.lstm_cell.is_none() {
101            // Create LSTM cell: input_size -> state_size
102            self.lstm_cell = Some(LSTMCell::new(self.input_size, self.state_size, device));
103        }
104        self
105    }
106
    /// Get input size (number of input features expected per timestep)
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Get state size (number of neurons in the wiring)
    pub fn state_size(&self) -> usize {
        self.state_size
    }

    /// Get motor/output size (the wiring's output dimension)
    pub fn motor_size(&self) -> usize {
        self.motor_size
    }

122    /// Forward pass through the LTC RNN layer
123    ///
124    /// # Arguments
125    /// * `input` - Input tensor of shape:
126    ///   - 3D batched: [batch, seq, features] if batch_first=true
127    ///   - 3D batched: [seq, batch, features] if batch_first=false
128    ///   - 2D unbatched: [seq, features]
129    /// * `state` - Optional initial state tensor of shape [batch, state_size]
130    /// * `timespans` - Optional time intervals tensor of shape [batch, seq] or scalar
131    ///
132    /// # Returns
133    /// Tuple of (output, final_state) where:
134    /// - output: [batch, seq, motor_size] or [batch, motor_size] depending on return_sequences
135    /// - final_state: [batch, state_size] or ([batch, state_size], [batch, state_size]) for mixed_memory
136    pub fn forward(
137        &self,
138        input: Tensor<B, 3>,
139        state: Option<Tensor<B, 2>>,
140        timespans: Option<Tensor<B, 2>>,
141    ) -> (Tensor<B, 3>, Tensor<B, 2>) {
142        let device = input.device();
143
144        // Get dimensions
145        let (batch_size, seq_len, _) = if self.batch_first {
146            let dims = input.dims();
147            (dims[0], dims[1], dims[2])
148        } else {
149            let dims = input.dims();
150            (dims[1], dims[0], dims[2])
151        };
152
153        // Initialize state if not provided
154        let mut current_state =
155            state.unwrap_or_else(|| Tensor::<B, 2>::zeros([batch_size, self.state_size], &device));
156
157        // Default timespans (all ones)
158        let timespans =
159            timespans.unwrap_or_else(|| Tensor::<B, 2>::ones([batch_size, seq_len], &device));
160
161        // Collect outputs for each timestep
162        let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
163
164        for t in 0..seq_len {
165            // Extract input for this timestep
166            let step_input = if self.batch_first {
167                // input[batch, t, features] -> [batch, features]
168                input.clone().narrow(1, t, 1).squeeze(1)
169            } else {
170                // input[t, batch, features] -> [batch, features]
171                input.clone().narrow(0, t, 1).squeeze(0)
172            };
173
174            // Extract timespan for this timestep
175            let step_time = timespans.clone().narrow(1, t, 1).squeeze(1);
176
177            // Forward through LTC cell
178            let (output, new_state) = self.cell.forward(step_input, current_state, step_time);
179            current_state = new_state;
180
181            if self.return_sequences {
182                outputs.push(output);
183            } else if t == seq_len - 1 {
184                // Only keep last output if not returning sequences
185                outputs.push(output);
186            }
187        }
188
189        // Stack outputs into final tensor
190        let output = Tensor::stack(outputs, 1); // [batch, seq, motor_size]
191        (output, current_state)
192    }
193
194    /// Forward pass with mixed memory (LSTM augmentation)
195    ///
196    /// This follows the Python implementation order: LSTM first (for memory),
197    /// then LTC (for continuous-time dynamics).
198    ///
199    /// This is only available when mixed_memory is enabled
200    pub fn forward_mixed(
201        &self,
202        input: Tensor<B, 3>,
203        state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
204        timespans: Option<Tensor<B, 2>>,
205    ) -> (Tensor<B, 3>, (Tensor<B, 2>, Tensor<B, 2>))
206    where
207        B: Backend,
208    {
209        if !self.mixed_memory {
210            panic!("Mixed memory not enabled. Call with_mixed_memory(true) first.");
211        }
212
213        let device = input.device();
214
215        // Get dimensions
216        let (batch_size, seq_len, _) = if self.batch_first {
217            let dims = input.dims();
218            (dims[0], dims[1], dims[2])
219        } else {
220            let dims = input.dims();
221            (dims[1], dims[0], dims[2])
222        };
223
224        // Initialize states if not provided
225        let (mut h_state, mut c_state) = state.unwrap_or_else(|| {
226            (
227                Tensor::<B, 2>::zeros([batch_size, self.state_size], &device),
228                Tensor::<B, 2>::zeros([batch_size, self.state_size], &device),
229            )
230        });
231
232        // Default timespans
233        let timespans =
234            timespans.unwrap_or_else(|| Tensor::<B, 2>::ones([batch_size, seq_len], &device));
235
236        // Collect outputs
237        let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
238
239        // Get LSTM cell reference (it should exist if mixed_memory is true)
240        let lstm = self.lstm_cell.as_ref().expect("LSTM cell not initialized");
241
242        for t in 0..seq_len {
243            // Extract input for this timestep
244            let step_input = if self.batch_first {
245                input.clone().narrow(1, t, 1).squeeze(1)
246            } else {
247                input.clone().narrow(0, t, 1).squeeze(0)
248            };
249
250            // Extract timespan
251            let step_time = timespans.clone().narrow(1, t, 1).squeeze(1);
252
253            // FIRST: Forward through LSTM for memory (matches Python implementation)
254            let (new_h, new_c) = lstm.forward(step_input.clone(), (h_state, c_state));
255            h_state = new_h.clone();
256            c_state = new_c;
257
258            // SECOND: Forward through LTC cell with LSTM hidden state
259            let (ltc_output, new_ltc_state) =
260                self.cell.forward(step_input, h_state.clone(), step_time);
261            h_state = new_ltc_state;
262
263            if self.return_sequences {
264                outputs.push(ltc_output);
265            } else if t == seq_len - 1 {
266                outputs.push(ltc_output);
267            }
268        }
269
270        // Stack outputs
271        let output = Tensor::stack(outputs, 1);
272        (output, (h_state, c_state))
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wirings::{AutoNCP, FullyConnected};
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    fn get_test_device() -> TestDevice {
        Default::default()
    }

    #[test]
    fn test_ltc_rnn_creation() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);

        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        assert_eq!(ltc.input_size(), 20);
        assert_eq!(ltc.state_size(), 50);
    }

    #[test]
    fn test_ltc_rnn_forward_batch_first() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_batch_first(true);

        // [batch, seq, features]
        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);

        let (output, state) = ltc.forward(input, None, None);

        // [batch, seq, state_size]
        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_forward_seq_first() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_batch_first(false);

        // [seq, batch, features]
        let input = Tensor::<TestBackend, 3>::zeros([10, 4, 20], &device);

        // State is unused here; underscore-prefix to avoid the compiler warning.
        let (output, _state) = ltc.forward(input, None, None);

        // Output is always [batch, seq, state_size] for consistency
        assert_eq!(output.dims(), [4, 10, 50]);
    }

    #[test]
    fn test_ltc_rnn_return_last_only() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device).with_return_sequences(false);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);

        let (output, state) = ltc.forward(input, None, None);

        // When return_sequences=false, we still return 3D with seq=1
        assert_eq!(output.dims(), [4, 1, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_with_initial_state() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let initial_state = Tensor::<TestBackend, 2>::ones([4, 50], &device);

        let (output, state) = ltc.forward(input, Some(initial_state), None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_with_timespans() {
        let device = get_test_device();
        let wiring = FullyConnected::new(50, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        // Variable time intervals
        let timespans = Tensor::<TestBackend, 2>::full([4, 10], 0.5, &device);

        let (output, state) = ltc.forward(input, None, Some(timespans));

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_ltc_rnn_with_ncp_wiring() {
        let device = get_test_device();
        let wiring = AutoNCP::new(64, 8, 0.5, 22222);
        let ltc = LTC::<TestBackend>::new(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([2, 5, 20], &device);
        let (output, state) = ltc.forward(input, None, None);

        // Output should be motor_size (8)
        assert_eq!(output.dims(), [2, 5, 8]);
        assert_eq!(state.dims(), [2, 64]);
    }

    #[test]
    fn test_ltc_rnn_sequence_processing() {
        let device = get_test_device();
        let wiring = FullyConnected::new(20, None, 1234, true);
        let ltc = LTC::<TestBackend>::new(10, wiring, &device);

        // Test different sequence lengths
        for seq_len in [1, 5, 20] {
            let input = Tensor::<TestBackend, 3>::zeros([2, seq_len, 10], &device);
            let (output, _) = ltc.forward(input, None, None);

            assert_eq!(output.dims(), [2, seq_len, 20]);
        }
    }
406}