// ncps_rust/cells/cfc_cell.rs
//! Closed-form Continuous-time (CfC) Cell Implementation
//!
//! The CfC cell is a fast approximation of the LTC (Liquid Time-Constant) cell.
//! It provides closed-form solutions to continuous-time neural dynamics without
//! requiring iterative ODE solvers.
//!
//! Three modes are supported:
//! - **Default**: Gated interpolation between two feedforward paths
//! - **Pure**: Direct ODE solution without gating
//! - **NoGate**: Simplified gating with addition instead of interpolation

use burn::module::{Module, Param};
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
use ndarray::Array2;

/// CfC cell operating modes.
///
/// The explicit discriminants (0/1/2) match the `u8` mode field stored on
/// `CfCCell` (see its `mode: u8` field and `with_mode`/`mode()` accessors).
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CfcMode {
    /// Default gated mode: h = tanh(ff1) * (1 - σ) + tanh(ff2) * σ
    Default = 0,
    /// Pure ODE solution without gating
    Pure = 1,
    /// No-gate mode: h = tanh(ff1) + tanh(ff2) * σ
    /// (the forward pass applies tanh to ff1 as well, unlike the gate-free
    /// name might suggest)
    NoGate = 2,
}
29
/// A Closed-form Continuous-time cell
///
/// This is an RNNCell that processes single time-steps. To get a full RNN
/// that can process sequences, see the full RNN layer implementation.
///
/// # Type Parameters
/// * `B` - The backend type
#[derive(Module, Debug)]
pub struct CfCCell<B: Backend> {
    /// Number of input features per time step.
    #[module(skip)]
    input_size: usize,
    /// Number of hidden units (state dimension).
    #[module(skip)]
    hidden_size: usize,
    /// Mode: 0=Default, 1=Pure, 2=NoGate (mirrors [`CfcMode`] discriminants).
    #[module(skip)]
    mode: u8,
    /// Whether sparsity mask is enabled
    #[module(skip)]
    has_sparsity_mask: bool,
    /// First feedforward projection over `[input, hidden]`; present in all modes.
    ff1: Linear<B>,
    /// Second feedforward projection; `Some` in Default/NoGate modes, `None` in Pure mode.
    ff2: Option<Linear<B>>,
    /// Time-gate slope projection (gate is σ(time_a·ts + time_b)); Default/NoGate only.
    time_a: Option<Linear<B>>,
    /// Time-gate offset projection; Default/NoGate only.
    time_b: Option<Linear<B>>,
    /// Per-neuron decay parameters for the Pure-mode closed-form solution; Pure only.
    w_tau: Option<Linear<B>>,
    /// Per-neuron amplitude parameters for the Pure-mode solution; Pure only.
    a: Option<Linear<B>>,
    /// Sparsity mask for output (transposed from input mask)
    sparsity_mask: Option<Param<Tensor<B, 2>>>,
}
58
59impl<B: Backend> CfCCell<B> {
60    /// Create a new CfC cell
61    pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
62        let ff1 = LinearConfig::new(input_size + hidden_size, hidden_size)
63            .with_bias(true)
64            .init(device);
65
66        let ff2 = LinearConfig::new(input_size + hidden_size, hidden_size)
67            .with_bias(true)
68            .init(device);
69
70        let time_a = LinearConfig::new(input_size + hidden_size, hidden_size)
71            .with_bias(true)
72            .init(device);
73
74        let time_b = LinearConfig::new(input_size + hidden_size, hidden_size)
75            .with_bias(true)
76            .init(device);
77
78        Self {
79            input_size,
80            hidden_size,
81            mode: 0, // Default
82            has_sparsity_mask: false,
83            ff1,
84            ff2: Some(ff2),
85            time_a: Some(time_a),
86            time_b: Some(time_b),
87            w_tau: None,
88            a: None,
89            sparsity_mask: None,
90        }
91    }
92
93    /// Set the CfC mode (Default, Pure, or NoGate)
94    pub fn with_mode(mut self, mode: CfcMode) -> Self {
95        self.mode = match mode {
96            CfcMode::Default => 0,
97            CfcMode::Pure => 1,
98            CfcMode::NoGate => 2,
99        };
100        self.reconfigure_for_mode();
101        self
102    }
103
104    /// Configure backbone (currently a no-op, kept for API compatibility)
105    pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
106        // Backbone support would require dynamic layer sizing
107        // For now, we keep the simple version working
108        self
109    }
110
111    /// Set the backbone activation (currently a no-op, kept for API compatibility)
112    pub fn with_activation(self, activation: &str) -> Self {
113        let valid_activations = ["relu", "tanh", "gelu", "silu", "lecun_tanh"];
114        if !valid_activations.contains(&activation) {
115            panic!(
116                "Unknown activation: {}. Valid options are {:?}",
117                activation, valid_activations
118            );
119        }
120        self
121    }
122
123    /// Set a sparsity mask to enforce wiring connectivity
124    ///
125    /// The mask should have shape [hidden_size, hidden_size] and contain
126    /// 0s for blocked connections and 1s for allowed connections.
127    /// Note: The mask is transposed internally to match PyTorch convention.
128    pub fn with_sparsity_mask(mut self, mask: Array2<f32>, device: &B::Device) -> Self {
129        let shape = mask.shape();
130        // Transpose the mask to match PyTorch's convention (sparsity_mask.T)
131        let transposed = mask.t();
132        let data: Vec<f32> = transposed.iter().map(|&x| x.abs()).collect();
133        let tensor: Tensor<B, 2> =
134            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[1], shape[0]]);
135        self.sparsity_mask = Some(Param::from_tensor(tensor));
136        self.has_sparsity_mask = true;
137        self
138    }
139
140    /// Create a CfC cell from a wiring configuration
141    pub fn from_wiring(
142        input_size: usize,
143        wiring: &dyn crate::wirings::Wiring,
144        device: &B::Device,
145    ) -> Self {
146        let hidden_size = wiring.units();
147        let mut cell = Self::new(input_size, hidden_size, device);
148
149        // Apply sparsity mask from adjacency matrix
150        let adj_matrix = wiring.adjacency_matrix();
151        let shape = adj_matrix.shape();
152        let data: Vec<f32> = adj_matrix.iter().map(|&x| x.abs() as f32).collect();
153        let mask_tensor: Tensor<B, 2> =
154            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
155        cell.sparsity_mask = Some(Param::from_tensor(mask_tensor));
156        cell.has_sparsity_mask = true;
157
158        cell
159    }
160
161    /// Get input size
162    pub fn input_size(&self) -> usize {
163        self.input_size
164    }
165
166    /// Get hidden size
167    pub fn hidden_size(&self) -> usize {
168        self.hidden_size
169    }
170
171    /// Get current mode
172    pub fn mode(&self) -> CfcMode {
173        match self.mode {
174            0 => CfcMode::Default,
175            1 => CfcMode::Pure,
176            2 => CfcMode::NoGate,
177            _ => CfcMode::Default,
178        }
179    }
180
181    fn reconfigure_for_mode(&mut self) {
182        let device = self.ff1.weight.device();
183
184        match self.mode {
185            1 => {
186                // Pure mode: use w_tau and a, remove ff2/time parameters
187                self.ff2 = None;
188                self.time_a = None;
189                self.time_b = None;
190
191                self.w_tau = Some(
192                    LinearConfig::new(1, self.hidden_size)
193                        .with_bias(false)
194                        .init(&device),
195                );
196                self.a = Some(
197                    LinearConfig::new(1, self.hidden_size)
198                        .with_bias(false)
199                        .init(&device),
200                );
201            }
202            _ => {
203                // Default/NoGate mode: ensure ff2, time_a, time_b exist
204                if self.ff2.is_none() {
205                    self.ff2 = Some(
206                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
207                            .with_bias(true)
208                            .init(&device),
209                    );
210                }
211                if self.time_a.is_none() {
212                    self.time_a = Some(
213                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
214                            .with_bias(true)
215                            .init(&device),
216                    );
217                }
218                if self.time_b.is_none() {
219                    self.time_b = Some(
220                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
221                            .with_bias(true)
222                            .init(&device),
223                    );
224                }
225                self.w_tau = None;
226                self.a = None;
227            }
228        }
229        // Note: sparsity_mask is preserved across mode changes
230    }
231
232    /// Check if this cell has a sparsity mask
233    pub fn has_sparsity_mask(&self) -> bool {
234        self.has_sparsity_mask
235    }
236
237    /// Apply sparsity mask to a tensor if mask exists
238    fn apply_sparsity_mask(&self, tensor: Tensor<B, 2>) -> Tensor<B, 2> {
239        if let Some(ref mask) = self.sparsity_mask {
240            // Mask shape is [hidden_size, hidden_size], we need to broadcast
241            // For output masking, we just multiply element-wise with the diagonal
242            // or apply the full mask if needed
243            let mask_val = mask.val();
244            let [batch_size, hidden_size] = tensor.dims();
245
246            // For simple sparsity, we take the row sums as a per-neuron mask
247            // This approximates the effect of masked weights
248            let row_mask: Tensor<B, 1> = mask_val.clone().sum_dim(1).squeeze(1);
249            let row_mask_normalized = row_mask.div_scalar(hidden_size as f32);
250            let mask_expanded = row_mask_normalized.unsqueeze::<2>().expand([batch_size, hidden_size]);
251
252            tensor.mul(mask_expanded)
253        } else {
254            tensor
255        }
256    }
257
258    /// Perform a forward pass through the CfC cell
259    pub fn forward(
260        &self,
261        input: Tensor<B, 2>,
262        hx: Tensor<B, 2>,
263        ts: f32,
264    ) -> (Tensor<B, 2>, Tensor<B, 2>) {
265        let batch_size = input.dims()[0];
266        let device = input.device();
267
268        // Concatenate input and hidden state
269        let x = Tensor::cat(vec![input, hx], 1);
270
271        // Compute ff1 and apply sparsity mask
272        let ff1_out = self.ff1.forward(x.clone());
273        let ff1_out = self.apply_sparsity_mask(ff1_out);
274
275        match self.mode {
276            1 => {
277                // Pure mode
278                let w_tau_layer = self.w_tau.as_ref().unwrap();
279                let a_layer = self.a.as_ref().unwrap();
280
281                let ones_input = Tensor::<B, 2>::ones([batch_size, 1], &device);
282                let w_tau_out = w_tau_layer.forward(ones_input.clone());
283                let a_out = a_layer.forward(ones_input);
284
285                let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
286                let abs_w_tau = w_tau_out.abs();
287                let abs_ff1 = ff1_out.clone().abs();
288
289                let exp_term = (ts_tensor * (abs_w_tau + abs_ff1)).neg().exp();
290                let new_hidden = a_out.clone() - a_out * exp_term * ff1_out;
291
292                (new_hidden.clone(), new_hidden)
293            }
294            _ => {
295                // Default or NoGate mode
296                let ff2_out = self.ff2.as_ref().unwrap().forward(x.clone());
297                let ff2_out = self.apply_sparsity_mask(ff2_out);
298
299                let ff1_tanh = ff1_out.tanh();
300                let ff2_tanh = ff2_out.tanh();
301
302                let time_a = self.time_a.as_ref().unwrap().forward(x.clone());
303                let time_b = self.time_b.as_ref().unwrap().forward(x);
304
305                // Compute time interpolation
306                let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
307                let t_interp = activation::sigmoid(time_a * ts_tensor + time_b);
308
309                let new_hidden = if self.mode == 2 {
310                    // NoGate: h = ff1 + t_interp * ff2
311                    ff1_tanh + t_interp * ff2_tanh
312                } else {
313                    // Default: h = ff1 * (1 - t_interp) + t_interp * ff2
314                    ff1_tanh
315                        * (Tensor::<B, 2>::ones([batch_size, self.hidden_size], &device)
316                            - t_interp.clone())
317                        + t_interp * ff2_tanh
318                };
319
320                (new_hidden.clone(), new_hidden)
321            }
322        }
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    fn get_test_device() -> TestDevice {
        Default::default()
    }

    /// Shorthand for an all-zeros `[rows, cols]` test tensor.
    fn zeros2(rows: usize, cols: usize, dev: &TestDevice) -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::zeros([rows, cols], dev)
    }

    #[test]
    fn test_cfc_cell_creation() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);
        assert_eq!(cell.mode(), CfcMode::Default);
    }

    #[test]
    fn test_cfc_forward_default() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        let (out, state) = cell.forward(zeros2(4, 20, &dev), zeros2(4, 50, &dev), 1.0);

        assert_eq!(out.dims(), [4, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_forward_pure() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_mode(CfcMode::Pure);
        assert_eq!(cell.mode(), CfcMode::Pure);

        let (out, _) = cell.forward(zeros2(2, 20, &dev), zeros2(2, 50, &dev), 1.0);
        assert_eq!(out.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_forward_no_gate() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_mode(CfcMode::NoGate);
        assert_eq!(cell.mode(), CfcMode::NoGate);

        let input = Tensor::<TestBackend, 2>::ones([2, 20], &dev);
        let (out, state) = cell.forward(input, zeros2(2, 50, &dev), 1.0);

        assert_eq!(out.dims(), [2, 50]);
        assert_eq!(state.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_state_change() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        let input = Tensor::<TestBackend, 2>::ones([2, 20], &dev);
        let hx = zeros2(2, 50, &dev);
        let (out, state) = cell.forward(input, hx.clone(), 1.0);

        // A non-zero input through freshly initialized weights must move the state.
        let state_delta = (state.clone() - hx).abs().mean().into_scalar();
        assert!(state_delta > 0.0, "State should change after forward pass");

        // CfC returns its new hidden state as the output.
        let out_delta = (out - state).abs().mean().into_scalar();
        assert!(out_delta < 1e-6, "Output should equal new_hidden");
    }

    #[test]
    fn test_cfc_different_modes_produce_different_results() {
        let dev = get_test_device();

        let gated = CfCCell::<TestBackend>::new(20, 50, &dev);
        let ungated = CfCCell::<TestBackend>::new(20, 50, &dev).with_mode(CfcMode::NoGate);

        let input = Tensor::<TestBackend, 2>::random(
            [2, 20],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &dev,
        );
        let hx = zeros2(2, 50, &dev);

        let (out_gated, _) = gated.forward(input.clone(), hx.clone(), 1.0);
        let (out_ungated, _) = ungated.forward(input, hx, 1.0);

        let diff = (out_gated - out_ungated).abs().mean().into_scalar();
        assert!(
            diff > 0.01,
            "Different modes should produce different outputs"
        );
    }

    #[test]
    fn test_cfc_backbone_configurations() {
        let dev = get_test_device();

        // Backbone configuration is currently a no-op; neither call may panic.
        let _shallow = CfCCell::<TestBackend>::new(20, 50, &dev).with_backbone(0, 0, 0.0);
        let _deep = CfCCell::<TestBackend>::new(20, 50, &dev).with_backbone(64, 3, 0.2);
    }

    #[test]
    fn test_cfc_activations() {
        let dev = get_test_device();

        for name in ["relu", "tanh", "gelu", "silu", "lecun_tanh"] {
            let cell = CfCCell::<TestBackend>::new(20, 50, &dev)
                .with_backbone(64, 1, 0.0)
                .with_activation(name);

            let (out, _) = cell.forward(zeros2(2, 20, &dev), zeros2(2, 50, &dev), 1.0);
            assert_eq!(out.dims()[0], 2);
        }
    }

    #[test]
    #[should_panic]
    fn test_cfc_invalid_activation() {
        let dev = get_test_device();
        let _cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_activation("invalid_activation");
    }

    #[test]
    fn test_cfc_batch_processing() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        // Exercise a few representative batch sizes.
        for batch in [1, 8, 32] {
            let (out, _) = cell.forward(zeros2(batch, 20, &dev), zeros2(batch, 50, &dev), 1.0);
            assert_eq!(out.dims(), [batch, 50]);
        }
    }

    #[test]
    fn test_cfc_sparsity_mask() {
        let dev = get_test_device();
        let mask = Array2::from_shape_vec((50, 50), vec![1.0f32; 2500]).unwrap();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_sparsity_mask(mask, &dev);

        assert!(cell.has_sparsity_mask());

        let (out, _) = cell.forward(zeros2(2, 20, &dev), zeros2(2, 50, &dev), 1.0);
        assert_eq!(out.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_from_wiring() {
        let dev = get_test_device();
        let wiring = crate::wirings::FullyConnected::new(50, None, 1234, true);
        let cell = CfCCell::<TestBackend>::from_wiring(20, &wiring, &dev);

        assert!(cell.has_sparsity_mask());
        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);

        let (out, _) = cell.forward(zeros2(2, 20, &dev), zeros2(2, 50, &dev), 1.0);
        assert_eq!(out.dims(), [2, 50]);
    }
}