// ncps/rnn/cfc.rs
1//! Closed-form Continuous-time (CfC) RNN Layer
2//!
3//! Full RNN layer that handles sequence processing, batching, and state management
4//! for CfC (Closed-form Continuous-time) cells.
5
6use crate::cells::CfCCell;
7use crate::wirings::Wiring;
8use burn::module::Module;
9use burn::nn::{Linear, LinearConfig};
10use burn::tensor::backend::Backend;
11use burn::tensor::Tensor;
12
/// CfC RNN Layer
///
/// A full RNN layer that processes sequences using CfC cells.
/// Supports batching, state management, and optional output projections.
///
/// Construct with [`Self::new`] for a plain hidden state or
/// [`Self::with_wiring`] for NCP-style wirings, then chain the `with_*`
/// builder methods to adjust layout and output behavior.
///
/// # Type Parameters
/// * `B` - The backend type
#[derive(Module, Debug)]
pub struct CfC<B: Backend> {
    /// The CfC cell for processing individual timesteps
    cell: CfCCell<B>,
    /// Optional projection layer mapping hidden state to motor/output size
    proj: Option<Linear<B>>,
    /// Input size (number of features); configuration, not a learnable parameter
    #[module(skip)]
    input_size: usize,
    /// Hidden state size (wiring `units()` when built via `with_wiring`)
    #[module(skip)]
    hidden_size: usize,
    /// Whether input is batch-first ([batch, seq, features] vs [seq, batch, features])
    #[module(skip)]
    batch_first: bool,
    /// Whether to return the full sequence or just the last timestep
    #[module(skip)]
    return_sequences: bool,
    /// Projection size (set when a projection layer is configured)
    #[module(skip)]
    proj_size: Option<usize>,
    /// Output size (hidden_size, or proj_size when projecting)
    #[module(skip)]
    output_size: usize,
}

46impl<B: Backend> CfC<B> {
47    /// Create a new CfC RNN layer with simple hidden size
48    ///
49    /// # Arguments
50    /// * `input_size` - Number of input features
51    /// * `hidden_size` - Number of hidden units
52    /// * `device` - Device to create the module on
53    pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
54        let cell = CfCCell::new(input_size, hidden_size, device);
55
56        Self {
57            cell,
58            proj: None,
59            input_size,
60            hidden_size,
61            batch_first: true,
62            return_sequences: true,
63            proj_size: None,
64            output_size: hidden_size,
65        }
66    }
67
68    /// Create a new CfC RNN layer with wiring configuration
69    ///
70    /// # Arguments
71    /// * `input_size` - Number of input features
72    /// * `wiring` - Wiring configuration (e.g., AutoNCP)
73    /// * `device` - Device to create the module on
74    pub fn with_wiring(input_size: usize, wiring: impl Wiring, device: &B::Device) -> Self {
75        let state_size = wiring.units();
76        let motor_size = wiring.output_dim().unwrap_or(state_size);
77
78        let cell = CfCCell::new(input_size, state_size, device);
79
80        let output_size = motor_size;
81
82        // Create projection layer if motor_size differs from state_size
83        let proj = if motor_size != state_size {
84            Some(
85                LinearConfig::new(state_size, motor_size)
86                    .with_bias(true)
87                    .init(device),
88            )
89        } else {
90            None
91        };
92
93        Self {
94            cell,
95            proj,
96            input_size,
97            hidden_size: state_size,
98            batch_first: true,
99            return_sequences: true,
100            proj_size: if motor_size != state_size {
101                Some(motor_size)
102            } else {
103                None
104            },
105            output_size,
106        }
107    }
108
109    /// Set whether input is batch-first (default: true)
110    pub fn with_batch_first(mut self, batch_first: bool) -> Self {
111        self.batch_first = batch_first;
112        self
113    }
114
115    /// Set whether to return full sequences (default: true)
116    pub fn with_return_sequences(mut self, return_sequences: bool) -> Self {
117        self.return_sequences = return_sequences;
118        self
119    }
120
121    /// Set projection size for motor outputs and create the projection layer
122    pub fn with_proj_size(mut self, proj_size: usize) -> Self {
123        let device = self.get_device();
124        self.proj = Some(
125            LinearConfig::new(self.hidden_size, proj_size)
126                .with_bias(true)
127                .init(&device),
128        );
129        self.proj_size = Some(proj_size);
130        self.output_size = proj_size;
131        self
132    }
133
    /// Configure backbone - currently a no-op for API compatibility
    ///
    /// All arguments are intentionally ignored; this method exists so that
    /// callers written against backbone-capable CfC implementations compile
    /// unchanged. TODO(review): wire these through to the cell if/when
    /// backbone layers are supported.
    pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
        self
    }

139    /// Helper method to get the device from the cell (defaults to CPU)
140    fn get_device(&self) -> B::Device {
141        B::Device::default()
142    }
143
    /// Get input size
    ///
    /// Number of features expected on the feature axis of `forward` inputs.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Get hidden size
    ///
    /// Width of the recurrent hidden state.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Get output size (considering projection)
    ///
    /// Equals `hidden_size` unless a projection is configured, in which case
    /// it is the projection width.
    pub fn output_size(&self) -> usize {
        self.output_size
    }

    /// Forward pass through the CfC RNN layer
    ///
    /// Iterates over the time dimension, feeding one timestep at a time
    /// through the wrapped CfC cell and threading the hidden state forward.
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape:
    ///   - 3D: [batch, seq, features] if batch_first=true
    ///   - 3D: [seq, batch, features] if batch_first=false
    /// * `state` - Optional initial hidden state tensor of shape
    ///   [batch, hidden_size]; zeros are used when `None`
    /// * `_timespans` - Currently IGNORED: every timestep is advanced with a
    ///   fixed interval of 1.0 because the cell's forward takes a scalar ts.
    ///   TODO(review): honor per-step timespans once the cell API supports
    ///   tensor-valued intervals.
    ///
    /// # Returns
    /// Tuple of (output, final_state) where:
    /// - output: [batch, seq, output_size] when return_sequences=true,
    ///   [batch, 1, output_size] otherwise — always batch-first regardless of
    ///   the input layout
    /// - final_state: [batch, hidden_size]
    pub fn forward(
        &self,
        input: Tensor<B, 3>,
        state: Option<Tensor<B, 2>>,
        _timespans: Option<Tensor<B, 2>>,
    ) -> (Tensor<B, 3>, Tensor<B, 2>) {
        let device = input.device();

        // Resolve (batch, seq) according to the configured memory layout.
        let (batch_size, seq_len, _) = if self.batch_first {
            let dims = input.dims();
            (dims[0], dims[1], dims[2])
        } else {
            let dims = input.dims();
            (dims[1], dims[0], dims[2])
        };

        // Initialize state with zeros if not provided
        let mut current_state =
            state.unwrap_or_else(|| Tensor::<B, 2>::zeros([batch_size, self.hidden_size], &device));

        // Collect outputs for each timestep
        let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Slice out the current timestep along the time axis.
            let step_input = if self.batch_first {
                // input[batch, t, features] -> [batch, features]
                input.clone().narrow(1, t, 1).squeeze(1)
            } else {
                // input[t, batch, features] -> [batch, features]
                input.clone().narrow(0, t, 1).squeeze(0)
            };

            // Forward through CfC cell (ts fixed at 1.0; see `_timespans` note above)
            let (mut output, new_state) = self.cell.forward(step_input, current_state, 1.0);
            current_state = new_state;

            // Apply projection if configured
            if let Some(ref proj) = self.proj {
                output = proj.forward(output);
            }

            if self.return_sequences {
                outputs.push(output);
            } else if t == seq_len - 1 {
                // Only keep last output if not returning sequences
                outputs.push(output);
            }
        }

        // Stack along dim 1 -> [batch, seq, output_size]
        // (seq == 1 when return_sequences is false)
        let output = Tensor::stack(outputs, 1);
        (output, current_state)
    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `FullyConnected` appears unused in these tests — confirm
    // a test was removed, or drop the import.
    use crate::wirings::{AutoNCP, FullyConnected};
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    // CPU ndarray backend keeps the tests hardware-independent.
    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    fn get_test_device() -> TestDevice {
        Default::default()
    }

    // Reported sizes match the constructor arguments.
    #[test]
    fn test_cfc_rnn_creation() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device);

        assert_eq!(cfc.input_size(), 20);
        assert_eq!(cfc.hidden_size(), 50);
        assert_eq!(cfc.output_size(), 50);
    }

    // With an NCP wiring, output size is the motor dimension, not `units`.
    #[test]
    fn test_cfc_rnn_with_wiring() {
        let device = get_test_device();
        let wiring = AutoNCP::new(32, 8, 0.5, 22222);
        let cfc = CfC::<TestBackend>::with_wiring(20, wiring, &device);

        assert_eq!(cfc.output_size(), 8);
    }

    // Default forward: full sequence out, batch-first, state [batch, hidden].
    #[test]
    fn test_cfc_rnn_forward() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_rnn_with_projection() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_proj_size(10);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, _) = cfc.forward(input, None, None);

        // Output should be projected to 10
        assert_eq!(output.dims(), [4, 10, 10]);
        assert_eq!(cfc.output_size(), 10);
    }

    // `with_backbone` is a documented no-op; shapes must be unchanged.
    #[test]
    fn test_cfc_rnn_backbone_config() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_backbone(128, 2, 0.1);

        let input = Tensor::<TestBackend, 3>::zeros([2, 5, 20], &device);
        let (output, _) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [2, 5, 50]);
    }

    #[test]
    fn test_cfc_rnn_return_last_only() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_return_sequences(false);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        // Should return [batch, 1, hidden_size]
        assert_eq!(output.dims(), [4, 1, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    // Sequence-first input is accepted, but output is always batch-first.
    #[test]
    fn test_cfc_rnn_seq_first() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_batch_first(false);

        // [seq, batch, features]
        let input = Tensor::<TestBackend, 3>::zeros([10, 4, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_rnn_with_initial_state() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let initial_state = Tensor::<TestBackend, 2>::ones([4, 50], &device);

        let (output, state) = cfc.forward(input, Some(initial_state), None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }
}