// ncps/cells/cfc_cell.rs
//! Closed-form Continuous-time (CfC) Cell Implementation
//!
//! The CfC cell is a fast approximation of the LTC (Liquid Time-Constant) cell.
//! It provides closed-form solutions to continuous-time neural dynamics without
//! requiring iterative ODE solvers.
//!
//! Three modes are supported:
//! - **Default**: Gated interpolation between two feedforward paths
//! - **Pure**: Direct ODE solution without gating
//! - **NoGate**: Simplified gating with addition instead of interpolation

use burn::module::{Module, Param};
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use ndarray::Array2;

/// Operating modes for CfC cells.
///
/// Each mode implements a different approach to computing the hidden state update.
/// The choice affects both computational cost and model behavior.
///
/// # Mode Comparison
///
/// | Mode | Formula | Best For |
/// |------|---------|----------|
/// | [`Default`](CfcMode::Default) | `h = tanh(ff1) × (1-σ) + tanh(ff2) × σ` | Most tasks (recommended) |
/// | [`Pure`](CfcMode::Pure) | `h = a - a × exp(-t × (\|w_τ\| + \|ff1\|)) × ff1` | Biological accuracy |
/// | [`NoGate`](CfcMode::NoGate) | `h = tanh(ff1) + tanh(ff2) × σ` | Simple dynamics |
///
/// # Choosing a Mode
///
/// ## Default (Recommended)
///
/// The **gated interpolation** mode provides the best balance of:
/// - Expressive power (two separate paths)
/// - Training stability (sigmoid gating)
/// - Flexibility (learns when to use each path)
///
/// Use this unless you have a specific reason not to.
///
/// ## Pure
///
/// The **ODE-based** mode is closer to the original LTC formulation:
/// - More biologically plausible
/// - Direct continuous-time dynamics
/// - Can be less stable during training
/// - Fewer parameters (no ff2, time_a, time_b)
///
/// Use when biological accuracy matters or for comparison studies.
///
/// ## NoGate
///
/// The **additive** mode simplifies the gating mechanism:
/// - Adds instead of interpolates
/// - Can be easier to train on simple tasks
/// - May underfit on complex tasks
///
/// Use for simple sequential patterns or ablation studies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CfcMode {
    /// **Recommended.** Gated interpolation between two feedforward paths.
    ///
    /// Formula: `h = tanh(ff1) × (1 - σ(t)) + tanh(ff2) × σ(t)`
    ///
    /// The time-dependent sigmoid `σ(t) = sigmoid(time_a × t + time_b)` controls
    /// how much each path contributes to the output.
    Default = 0,

    /// Pure ODE solution without gating. More biologically plausible.
    ///
    /// Formula: `h = a - a × exp(-t × (|w_τ| + |ff1|)) × ff1`
    ///
    /// Uses learned time constants `w_τ` and amplitude `a`.
    /// Has fewer parameters than Default mode.
    Pure = 1,

    /// Simplified additive mode (no interpolation).
    ///
    /// Formula: `h = tanh(ff1) + tanh(ff2) × σ(t)`
    ///
    /// Adds the gated component instead of interpolating.
    /// Can work well for simpler tasks.
    NoGate = 2,
}
87
/// A Closed-form Continuous-time cell
///
/// This is an RNNCell that processes single time-steps. To get a full RNN
/// that can process sequences, see the full RNN layer implementation.
///
/// # Type Parameters
/// * `B` - The backend type
#[derive(Module, Debug)]
pub struct CfCCell<B: Backend> {
    /// Number of input features per time step.
    #[module(skip)]
    input_size: usize,
    /// Number of hidden units (state dimension).
    #[module(skip)]
    hidden_size: usize,
    /// Mode: 0=Default, 1=Pure, 2=NoGate (see [`CfcMode`]).
    #[module(skip)]
    mode: u8,
    /// Whether sparsity mask is enabled
    #[module(skip)]
    has_sparsity_mask: bool,
    /// Primary feedforward path over the concatenated `[input, hidden]` vector.
    ff1: Linear<B>,
    /// Second feedforward path; present in Default/NoGate modes, `None` in Pure mode.
    ff2: Option<Linear<B>>,
    /// Time-gate scale; present in Default/NoGate modes only.
    time_a: Option<Linear<B>>,
    /// Time-gate offset; present in Default/NoGate modes only.
    time_b: Option<Linear<B>>,
    /// Learned time constants; present in Pure mode only.
    w_tau: Option<Linear<B>>,
    /// Learned output amplitude; present in Pure mode only.
    a: Option<Linear<B>>,
    /// Sparsity mask for output (transposed from input mask)
    /// NOTE(review): stored as a `Param`, which exposes it as a trainable
    /// parameter to optimizers — confirm whether it should be a constant.
    sparsity_mask: Option<Param<Tensor<B, 2>>>,
}
116
117impl<B: Backend> CfCCell<B> {
118    /// Create a new CfC cell
119    pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
120        let ff1 = LinearConfig::new(input_size + hidden_size, hidden_size)
121            .with_bias(true)
122            .init(device);
123
124        let ff2 = LinearConfig::new(input_size + hidden_size, hidden_size)
125            .with_bias(true)
126            .init(device);
127
128        let time_a = LinearConfig::new(input_size + hidden_size, hidden_size)
129            .with_bias(true)
130            .init(device);
131
132        let time_b = LinearConfig::new(input_size + hidden_size, hidden_size)
133            .with_bias(true)
134            .init(device);
135
136        Self {
137            input_size,
138            hidden_size,
139            mode: 0, // Default
140            has_sparsity_mask: false,
141            ff1,
142            ff2: Some(ff2),
143            time_a: Some(time_a),
144            time_b: Some(time_b),
145            w_tau: None,
146            a: None,
147            sparsity_mask: None,
148        }
149    }
150
151    /// Set the CfC mode (Default, Pure, or NoGate)
152    pub fn with_mode(mut self, mode: CfcMode) -> Self {
153        self.mode = match mode {
154            CfcMode::Default => 0,
155            CfcMode::Pure => 1,
156            CfcMode::NoGate => 2,
157        };
158        self.reconfigure_for_mode();
159        self
160    }
161
162    /// Configure backbone (currently a no-op, kept for API compatibility)
163    pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
164        // Backbone support would require dynamic layer sizing
165        // For now, we keep the simple version working
166        self
167    }
168
169    /// Set the backbone activation (currently a no-op, kept for API compatibility)
170    pub fn with_activation(self, activation: &str) -> Self {
171        let valid_activations = ["relu", "tanh", "gelu", "silu", "lecun_tanh"];
172        if !valid_activations.contains(&activation) {
173            panic!(
174                "Unknown activation: {}. Valid options are {:?}",
175                activation, valid_activations
176            );
177        }
178        self
179    }
180
181    /// Set a sparsity mask to enforce wiring connectivity
182    ///
183    /// The mask should have shape [hidden_size, hidden_size] and contain
184    /// 0s for blocked connections and 1s for allowed connections.
185    /// Note: The mask is transposed internally to match PyTorch convention.
186    pub fn with_sparsity_mask(mut self, mask: Array2<f32>, device: &B::Device) -> Self {
187        let shape = mask.shape();
188        // Transpose the mask to match PyTorch's convention (sparsity_mask.T)
189        let transposed = mask.t();
190        let data: Vec<f32> = transposed.iter().map(|&x| x.abs()).collect();
191        let tensor: Tensor<B, 2> =
192            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[1], shape[0]]);
193        self.sparsity_mask = Some(Param::from_tensor(tensor));
194        self.has_sparsity_mask = true;
195        self
196    }
197
198    /// Create a CfC cell from a wiring configuration
199    pub fn from_wiring(
200        input_size: usize,
201        wiring: &dyn crate::wirings::Wiring,
202        device: &B::Device,
203    ) -> Self {
204        let hidden_size = wiring.units();
205        let mut cell = Self::new(input_size, hidden_size, device);
206
207        // Apply sparsity mask from adjacency matrix
208        let adj_matrix = wiring.adjacency_matrix();
209        let shape = adj_matrix.shape();
210        let data: Vec<f32> = adj_matrix.iter().map(|&x| x.abs() as f32).collect();
211        let mask_tensor: Tensor<B, 2> =
212            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
213        cell.sparsity_mask = Some(Param::from_tensor(mask_tensor));
214        cell.has_sparsity_mask = true;
215
216        cell
217    }
218
219    /// Get input size
220    pub fn input_size(&self) -> usize {
221        self.input_size
222    }
223
224    /// Get hidden size
225    pub fn hidden_size(&self) -> usize {
226        self.hidden_size
227    }
228
229    /// Get current mode
230    pub fn mode(&self) -> CfcMode {
231        match self.mode {
232            0 => CfcMode::Default,
233            1 => CfcMode::Pure,
234            2 => CfcMode::NoGate,
235            _ => CfcMode::Default,
236        }
237    }
238
239    fn reconfigure_for_mode(&mut self) {
240        let device = self.ff1.weight.device();
241
242        match self.mode {
243            1 => {
244                // Pure mode: use w_tau and a, remove ff2/time parameters
245                self.ff2 = None;
246                self.time_a = None;
247                self.time_b = None;
248
249                self.w_tau = Some(
250                    LinearConfig::new(1, self.hidden_size)
251                        .with_bias(false)
252                        .init(&device),
253                );
254                self.a = Some(
255                    LinearConfig::new(1, self.hidden_size)
256                        .with_bias(false)
257                        .init(&device),
258                );
259            }
260            _ => {
261                // Default/NoGate mode: ensure ff2, time_a, time_b exist
262                if self.ff2.is_none() {
263                    self.ff2 = Some(
264                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
265                            .with_bias(true)
266                            .init(&device),
267                    );
268                }
269                if self.time_a.is_none() {
270                    self.time_a = Some(
271                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
272                            .with_bias(true)
273                            .init(&device),
274                    );
275                }
276                if self.time_b.is_none() {
277                    self.time_b = Some(
278                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
279                            .with_bias(true)
280                            .init(&device),
281                    );
282                }
283                self.w_tau = None;
284                self.a = None;
285            }
286        }
287        // Note: sparsity_mask is preserved across mode changes
288    }
289
290    /// Check if this cell has a sparsity mask
291    pub fn has_sparsity_mask(&self) -> bool {
292        self.has_sparsity_mask
293    }
294
295    /// Apply sparsity mask to a tensor if mask exists
296    fn apply_sparsity_mask(&self, tensor: Tensor<B, 2>) -> Tensor<B, 2> {
297        if let Some(ref mask) = self.sparsity_mask {
298            // Mask shape is [hidden_size, hidden_size], we need to broadcast
299            // For output masking, we just multiply element-wise with the diagonal
300            // or apply the full mask if needed
301            let mask_val = mask.val();
302            let [batch_size, hidden_size] = tensor.dims();
303
304            // For simple sparsity, we take the row sums as a per-neuron mask
305            // This approximates the effect of masked weights
306            let row_mask: Tensor<B, 1> = mask_val.clone().sum_dim(1).squeeze(1);
307            let row_mask_normalized = row_mask.div_scalar(hidden_size as f32);
308            let mask_expanded = row_mask_normalized.unsqueeze::<2>().expand([batch_size, hidden_size]);
309
310            tensor.mul(mask_expanded)
311        } else {
312            tensor
313        }
314    }
315
316    /// Perform a forward pass through the CfC cell
317    pub fn forward(
318        &self,
319        input: Tensor<B, 2>,
320        hx: Tensor<B, 2>,
321        ts: f32,
322    ) -> (Tensor<B, 2>, Tensor<B, 2>) {
323        let batch_size = input.dims()[0];
324        let device = input.device();
325
326        // Concatenate input and hidden state
327        let x = Tensor::cat(vec![input, hx], 1);
328
329        // Compute ff1 and apply sparsity mask
330        let ff1_out = self.ff1.forward(x.clone());
331        let ff1_out = self.apply_sparsity_mask(ff1_out);
332
333        match self.mode {
334            1 => {
335                // Pure mode
336                let w_tau_layer = self.w_tau.as_ref().unwrap();
337                let a_layer = self.a.as_ref().unwrap();
338
339                let ones_input = Tensor::<B, 2>::ones([batch_size, 1], &device);
340                let w_tau_out = w_tau_layer.forward(ones_input.clone());
341                let a_out = a_layer.forward(ones_input);
342
343                let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
344                let abs_w_tau = w_tau_out.abs();
345                let abs_ff1 = ff1_out.clone().abs();
346
347                let exp_term = (ts_tensor * (abs_w_tau + abs_ff1)).neg().exp();
348                let new_hidden = a_out.clone() - a_out * exp_term * ff1_out;
349
350                (new_hidden.clone(), new_hidden)
351            }
352            _ => {
353                // Default or NoGate mode
354                let ff2_out = self.ff2.as_ref().unwrap().forward(x.clone());
355                let ff2_out = self.apply_sparsity_mask(ff2_out);
356
357                let ff1_tanh = ff1_out.tanh();
358                let ff2_tanh = ff2_out.tanh();
359
360                let time_a = self.time_a.as_ref().unwrap().forward(x.clone());
361                let time_b = self.time_b.as_ref().unwrap().forward(x);
362
363                // Compute time interpolation
364                let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
365                let t_interp = activation::sigmoid(time_a * ts_tensor + time_b);
366
367                let new_hidden = if self.mode == 2 {
368                    // NoGate: h = ff1 + t_interp * ff2
369                    ff1_tanh + t_interp * ff2_tanh
370                } else {
371                    // Default: h = ff1 * (1 - t_interp) + t_interp * ff2
372                    ff1_tanh
373                        * (Tensor::<B, 2>::ones([batch_size, self.hidden_size], &device)
374                            - t_interp.clone())
375                        + t_interp * ff2_tanh
376                };
377
378                (new_hidden.clone(), new_hidden)
379            }
380        }
381    }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    fn get_test_device() -> TestDevice {
        Default::default()
    }

    #[test]
    fn test_cfc_cell_creation() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);
        assert_eq!(cell.mode(), CfcMode::Default);
    }

    #[test]
    fn test_cfc_forward_default() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        let batch_size = 4;
        let input = Tensor::<TestBackend, 2>::zeros([batch_size, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([batch_size, 50], &device);

        let (output, new_hidden) = cell.forward(input, hx, 1.0);

        assert_eq!(output.dims(), [batch_size, 50]);
        assert_eq!(new_hidden.dims(), [batch_size, 50]);
    }

    #[test]
    fn test_cfc_forward_pure() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device).with_mode(CfcMode::Pure);

        assert_eq!(cell.mode(), CfcMode::Pure);

        let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, _) = cell.forward(input, hx, 1.0);

        assert_eq!(output.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_forward_no_gate() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device).with_mode(CfcMode::NoGate);

        assert_eq!(cell.mode(), CfcMode::NoGate);

        let input = Tensor::<TestBackend, 2>::ones([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, new_hidden) = cell.forward(input, hx, 1.0);

        assert_eq!(output.dims(), [2, 50]);
        assert_eq!(new_hidden.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_state_change() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        let input = Tensor::<TestBackend, 2>::ones([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, new_hidden) = cell.forward(input, hx.clone(), 1.0);

        // State should have changed
        let diff = (new_hidden.clone() - hx).abs().mean().into_scalar();
        assert!(diff > 0.0, "State should change after forward pass");

        // Output should equal new_hidden for CfC
        let output_diff = (output - new_hidden).abs().mean().into_scalar();
        assert!(output_diff < 1e-6, "Output should equal new_hidden");
    }

    #[test]
    fn test_cfc_different_modes_produce_different_results() {
        let device = get_test_device();

        // Clone the cell BEFORE switching modes so both copies share identical
        // weights. Any output difference is then attributable to the mode
        // alone, rather than to two independent random initializations
        // (which would differ regardless of mode and make the test vacuous).
        let cell_default = CfCCell::<TestBackend>::new(20, 50, &device);
        let cell_no_gate = cell_default.clone().with_mode(CfcMode::NoGate);

        let input = Tensor::<TestBackend, 2>::random(
            [2, 20],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (out1, _) = cell_default.forward(input.clone(), hx.clone(), 1.0);
        let (out2, _) = cell_no_gate.forward(input, hx, 1.0);

        let diff = (out1 - out2).abs().mean().into_scalar();
        assert!(
            diff > 0.01,
            "Different modes should produce different outputs"
        );
    }

    #[test]
    fn test_cfc_backbone_configurations() {
        let device = get_test_device();

        // These should not panic (backbone is currently no-op)
        let _cell_no_backbone =
            CfCCell::<TestBackend>::new(20, 50, &device).with_backbone(0, 0, 0.0);

        let _cell_deep_backbone =
            CfCCell::<TestBackend>::new(20, 50, &device).with_backbone(64, 3, 0.2);
    }

    #[test]
    fn test_cfc_activations() {
        let device = get_test_device();

        for activation in ["relu", "tanh", "gelu", "silu", "lecun_tanh"] {
            let cell = CfCCell::<TestBackend>::new(20, 50, &device)
                .with_backbone(64, 1, 0.0)
                .with_activation(activation);

            let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
            let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

            let (output, _) = cell.forward(input, hx, 1.0);
            assert_eq!(output.dims()[0], 2);
        }
    }

    #[test]
    #[should_panic]
    fn test_cfc_invalid_activation() {
        let device = get_test_device();
        let _cell =
            CfCCell::<TestBackend>::new(20, 50, &device).with_activation("invalid_activation");
    }

    #[test]
    fn test_cfc_batch_processing() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        // Test with batch sizes 1, 8, 32
        for batch in [1, 8, 32] {
            let input = Tensor::<TestBackend, 2>::zeros([batch, 20], &device);
            let hx = Tensor::<TestBackend, 2>::zeros([batch, 50], &device);

            let (output, _) = cell.forward(input, hx, 1.0);
            assert_eq!(output.dims(), [batch, 50]);
        }
    }

    #[test]
    fn test_cfc_sparsity_mask() {
        let device = get_test_device();
        let mask = Array2::from_shape_vec((50, 50), vec![1.0f32; 2500]).unwrap();

        let cell = CfCCell::<TestBackend>::new(20, 50, &device).with_sparsity_mask(mask, &device);

        assert!(cell.has_sparsity_mask());

        let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, _) = cell.forward(input, hx, 1.0);
        assert_eq!(output.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_from_wiring() {
        let device = get_test_device();
        let wiring = crate::wirings::FullyConnected::new(50, None, 1234, true);

        let cell = CfCCell::<TestBackend>::from_wiring(20, &wiring, &device);

        assert!(cell.has_sparsity_mask());
        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);

        let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, _) = cell.forward(input, hx, 1.0);
        assert_eq!(output.dims(), [2, 50]);
    }
}