Skip to main content

ncps/cells/
ltc_cell.rs

//! Liquid Time-Constant (LTC) Cell Implementation
//!
//! Reference: Hasani et al., "Liquid time-constant networks", AAAI 2021

use burn::module::{Module, Param};
use burn::tensor::activation;
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution, ElementConversion, Tensor};

use crate::wirings::Wiring;

/// How a cell transforms its inputs (before the ODE) or its outputs (after it).
///
/// Every mode acts per feature; [`MappingMode::Affine`] is the `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MappingMode {
    /// Per-feature scale and shift: `y = w * x + b`.
    #[default]
    Affine,
    /// Per-feature scale only: `y = w * x`.
    Linear,
    /// Identity: values pass through untouched.
    None,
}

23/// Liquid Time-Constant (LTC) Cell
24#[derive(Debug, Module)]
25pub struct LTCCell<B: Backend> {
26    /// Leak conductance (must be positive)
27    pub gleak: Param<Tensor<B, 1>>,
28    /// Leak reversal potential
29    pub vleak: Param<Tensor<B, 1>>,
30    /// Membrane capacitance (must be positive)
31    pub cm: Param<Tensor<B, 1>>,
32    /// Sigmoid center parameter for internal synapses
33    pub sigma: Param<Tensor<B, 2>>,
34    /// Sigmoid steepness parameter for internal synapses
35    pub mu: Param<Tensor<B, 2>>,
36    /// Synaptic weights for internal synapses (must be positive)
37    pub w: Param<Tensor<B, 2>>,
38    /// Reversal potentials for internal synapses (from wiring)
39    pub erev: Param<Tensor<B, 2>>,
40    /// Sigmoid center parameter for sensory synapses
41    pub sensory_sigma: Param<Tensor<B, 2>>,
42    /// Sigmoid steepness parameter for sensory synapses
43    pub sensory_mu: Param<Tensor<B, 2>>,
44    /// Synaptic weights for sensory synapses (must be positive)
45    pub sensory_w: Param<Tensor<B, 2>>,
46    /// Reversal potentials for sensory synapses (from wiring)
47    pub sensory_erev: Param<Tensor<B, 2>>,
48    /// Sparsity mask for internal synapses (non-trainable)
49    pub sparsity_mask: Param<Tensor<B, 2>>,
50    /// Sparsity mask for sensory synapses (non-trainable)
51    pub sensory_sparsity_mask: Param<Tensor<B, 2>>,
52    /// Input weight for mapping
53    pub input_w: Option<Param<Tensor<B, 1>>>,
54    /// Input bias for mapping
55    pub input_b: Option<Param<Tensor<B, 1>>>,
56    /// Output weight for mapping
57    pub output_w: Option<Param<Tensor<B, 1>>>,
58    /// Output bias for mapping
59    pub output_b: Option<Param<Tensor<B, 1>>>,
60    /// Number of ODE solver steps per forward pass
61    #[module(skip)]
62    ode_unfolds: usize,
63    /// Epsilon for numerical stability
64    #[module(skip)]
65    epsilon: f64,
66    /// State size (number of neurons)
67    #[module(skip)]
68    state_size: usize,
69    /// Motor size (output neurons)
70    #[module(skip)]
71    motor_size: usize,
72    /// Sensory size (input neurons)
73    #[module(skip)]
74    sensory_size: usize,
75    /// Input mapping mode (0=None, 1=Linear, 2=Affine)
76    #[module(skip)]
77    input_mapping: u8,
78    /// Output mapping mode (0=None, 1=Linear, 2=Affine)
79    #[module(skip)]
80    output_mapping: u8,
81}
82
83impl<B: Backend> LTCCell<B> {
84    /// Creates a new LTC Cell with the given wiring configuration
85    pub fn new(wiring: &dyn Wiring, sensory_size: Option<usize>, device: &B::Device) -> Self {
86        let state_size = wiring.units();
87        let motor_size = wiring.output_dim().unwrap_or(state_size);
88        let actual_sensory_size = sensory_size.or_else(|| wiring.input_dim()).expect(
89            "LTCCell requires sensory_size or wiring with input_dim. Call wiring.build() first.",
90        );
91
92        // Initialize parameters with specified ranges
93        let gleak = Self::init_param([state_size], 0.001, 1.0, device);
94        let vleak = Self::init_param([state_size], -0.2, 0.2, device);
95        let cm = Self::init_param([state_size], 0.4, 0.6, device);
96
97        // 2D parameters
98        let sigma = Self::init_param([state_size, state_size], 3.0, 8.0, device);
99        let mu = Self::init_param([state_size, state_size], 0.3, 0.8, device);
100        let w = Self::init_param([state_size, state_size], 0.001, 1.0, device);
101
102        // Get erev from wiring adjacency matrix (this encodes excitatory/inhibitory polarity)
103        let erev_matrix = wiring.erev_initializer();
104        let erev = Self::tensor_from_ndarray(&erev_matrix, device);
105
106        // Get sparsity mask from adjacency matrix (absolute values)
107        let sparsity_mask = Self::sparsity_mask_from_ndarray(&erev_matrix, device);
108
109        let sensory_sigma = Self::init_param([actual_sensory_size, state_size], 3.0, 8.0, device);
110        let sensory_mu = Self::init_param([actual_sensory_size, state_size], 0.3, 0.8, device);
111        let sensory_w = Self::init_param([actual_sensory_size, state_size], 0.001, 1.0, device);
112
113        // Get sensory erev and sparsity mask from wiring
114        let (sensory_erev, sensory_sparsity_mask) =
115            if let Some(sensory_matrix) = wiring.sensory_erev_initializer() {
116                (
117                    Self::tensor_from_ndarray(&sensory_matrix, device),
118                    Self::sparsity_mask_from_ndarray(&sensory_matrix, device),
119                )
120            } else {
121                // If no sensory adjacency, create fully connected
122                (
123                    Param::from_tensor(Tensor::ones([actual_sensory_size, state_size], device)),
124                    Param::from_tensor(Tensor::ones([actual_sensory_size, state_size], device)),
125                )
126            };
127
128        Self {
129            gleak,
130            vleak,
131            cm,
132            sigma,
133            mu,
134            w,
135            erev,
136            sensory_sigma,
137            sensory_mu,
138            sensory_w,
139            sensory_erev,
140            sparsity_mask,
141            sensory_sparsity_mask,
142            input_w: None,
143            input_b: None,
144            output_w: None,
145            output_b: None,
146            ode_unfolds: 6,
147            epsilon: 1e-8,
148            state_size,
149            motor_size,
150            sensory_size: actual_sensory_size,
151            input_mapping: 0,  // MappingMode::None
152            output_mapping: 0, // MappingMode::None
153        }
154    }
155
156    /// Convert ndarray to Burn tensor parameter
157    fn tensor_from_ndarray(
158        arr: &ndarray::Array2<i32>,
159        device: &B::Device,
160    ) -> Param<Tensor<B, 2>> {
161        let shape = arr.shape();
162        let data: Vec<f32> = arr.iter().map(|&x| x as f32).collect();
163        let tensor: Tensor<B, 2> =
164            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
165        Param::from_tensor(tensor)
166    }
167
168    /// Create sparsity mask from adjacency matrix (|adjacency|)
169    fn sparsity_mask_from_ndarray(
170        arr: &ndarray::Array2<i32>,
171        device: &B::Device,
172    ) -> Param<Tensor<B, 2>> {
173        let shape = arr.shape();
174        let data: Vec<f32> = arr.iter().map(|&x| x.abs() as f32).collect();
175        let tensor: Tensor<B, 2> =
176            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
177        Param::from_tensor(tensor)
178    }
179
180    fn init_param<const D: usize>(
181        shape: [usize; D],
182        min: f64,
183        max: f64,
184        device: &B::Device,
185    ) -> Param<Tensor<B, D>> {
186        let tensor = Tensor::random(shape, Distribution::Uniform(min, max), device);
187        Param::from_tensor(tensor)
188    }
189
190    pub fn with_ode_unfolds(mut self, unfolds: usize) -> Self {
191        self.ode_unfolds = unfolds;
192        self
193    }
194
195    pub fn with_epsilon(mut self, epsilon: f64) -> Self {
196        self.epsilon = epsilon;
197        self
198    }
199
200    /// Set input mapping mode (affine, linear, or none)
201    pub fn with_input_mapping(mut self, mode: MappingMode, device: &B::Device) -> Self {
202        self.input_mapping = match mode {
203            MappingMode::None => 0,
204            MappingMode::Linear => 1,
205            MappingMode::Affine => 2,
206        };
207        match mode {
208            MappingMode::Affine => {
209                self.input_w =
210                    Some(Param::from_tensor(Tensor::ones([self.sensory_size], device)));
211                self.input_b =
212                    Some(Param::from_tensor(Tensor::zeros([self.sensory_size], device)));
213            }
214            MappingMode::Linear => {
215                self.input_w =
216                    Some(Param::from_tensor(Tensor::ones([self.sensory_size], device)));
217                self.input_b = None;
218            }
219            MappingMode::None => {
220                self.input_w = None;
221                self.input_b = None;
222            }
223        }
224        self
225    }
226
227    /// Set output mapping mode (affine, linear, or none)
228    pub fn with_output_mapping(mut self, mode: MappingMode, device: &B::Device) -> Self {
229        self.output_mapping = match mode {
230            MappingMode::None => 0,
231            MappingMode::Linear => 1,
232            MappingMode::Affine => 2,
233        };
234        match mode {
235            MappingMode::Affine => {
236                self.output_w = Some(Param::from_tensor(Tensor::ones([self.motor_size], device)));
237                self.output_b = Some(Param::from_tensor(Tensor::zeros([self.motor_size], device)));
238            }
239            MappingMode::Linear => {
240                self.output_w = Some(Param::from_tensor(Tensor::ones([self.motor_size], device)));
241                self.output_b = None;
242            }
243            MappingMode::None => {
244                self.output_w = None;
245                self.output_b = None;
246            }
247        }
248        self
249    }
250
251    pub fn state_size(&self) -> usize {
252        self.state_size
253    }
254
255    pub fn motor_size(&self) -> usize {
256        self.motor_size
257    }
258
259    pub fn sensory_size(&self) -> usize {
260        self.sensory_size
261    }
262
263    pub fn synapse_count(&self) -> usize {
264        self.state_size * self.state_size
265    }
266
267    pub fn sensory_synapse_count(&self) -> usize {
268        self.sensory_size * self.state_size
269    }
270
271    /// Apply input mapping
272    fn map_inputs(&self, inputs: Tensor<B, 2>) -> Tensor<B, 2> {
273        let mut result = inputs;
274        if let Some(ref w) = self.input_w {
275            result = result.mul(w.val().unsqueeze::<2>());
276        }
277        if let Some(ref b) = self.input_b {
278            result = result.add(b.val().unsqueeze::<2>());
279        }
280        result
281    }
282
283    /// Apply output mapping
284    fn map_outputs(&self, state: Tensor<B, 2>) -> Tensor<B, 2> {
285        // First slice to motor size
286        let mut output = state.narrow(1, 0, self.motor_size);
287
288        if let Some(ref w) = self.output_w {
289            output = output.mul(w.val().unsqueeze::<2>());
290        }
291        if let Some(ref b) = self.output_b {
292            output = output.add(b.val().unsqueeze::<2>());
293        }
294        output
295    }
296
297    /// Apply weight constraints (clamp positive parameters to be >= 0)
298    pub fn apply_weight_constraints(&mut self) {
299        // In implicit mode (default), constraints are applied via softplus
300        // This method is for explicit mode where we clamp negative values
301        self.w = Param::from_tensor(self.w.val().clamp_min(0.0));
302        self.sensory_w = Param::from_tensor(self.sensory_w.val().clamp_min(0.0));
303        self.cm = Param::from_tensor(self.cm.val().clamp_min(0.0));
304        self.gleak = Param::from_tensor(self.gleak.val().clamp_min(0.0));
305    }
306}
307
308impl<B: Backend> LTCCell<B> {
309    fn softplus_1d(&self, x: Tensor<B, 1>) -> Tensor<B, 1> {
310        x.exp().add_scalar(1.0).log()
311    }
312
313    fn softplus_2d(&self, x: &Tensor<B, 2>) -> Tensor<B, 2> {
314        x.clone().exp().add_scalar(1.0).log()
315    }
316
317    fn _ode_solver(
318        &self,
319        inputs: Tensor<B, 2>,
320        state: Tensor<B, 2>,
321        elapsed_time: Tensor<B, 1>,
322    ) -> Tensor<B, 2> {
323        let [batch, state_size] = state.dims();
324        let sensory_size = self.sensory_size;
325        let mut v_pre = state;
326
327        // Compute cm_t: cm is [state_size], time is [batch]
328        // Formula: cm_t = softplus(cm) / (elapsed_time / ode_unfolds)
329        let cm = self.softplus_1d(self.cm.val()); // [state_size]
330
331        // Expand cm: [state_size] -> unsqueeze to [1, state_size] -> expand to [batch, state_size]
332        let cm_expanded = cm
333            .unsqueeze::<2>() // [1, state_size]
334            .expand([batch, state_size]); // [batch, state_size]
335
336        // Compute dt per unfold: [batch] -> unsqueeze_dim(1) -> [batch, 1] -> expand to [batch, state_size]
337        let dt = elapsed_time.div_scalar(self.ode_unfolds as f64); // [batch]
338        let dt_expanded = dt
339            .unsqueeze_dim::<2>(1) // [batch, 1]
340            .expand([batch, state_size]); // [batch, state_size]
341
342        let cm_t = cm_expanded.div(dt_expanded);
343
344        // Compute sensory activations
345        // sensory_sigmoid: [batch, sensory_size, state_size]
346        let sensory_sigmoid = self.compute_sensory_sigmoid(&inputs);
347
348        // w * sigmoid(inputs): [batch, sensory_size, state_size]
349        let sensory_w_pos = self.softplus_2d(&self.sensory_w.val());
350        let sensory_w_expanded = sensory_w_pos.unsqueeze::<3>();
351        let sensory_w_activation = sensory_w_expanded.mul(sensory_sigmoid);
352
353        // Apply sensory sparsity mask: [sensory_size, state_size] -> [1, sensory_size, state_size]
354        let sensory_mask_expanded = self
355            .sensory_sparsity_mask
356            .val()
357            .reshape([1, sensory_size, state_size]);
358        let sensory_w_activation = sensory_w_activation.mul(sensory_mask_expanded);
359
360        // erev * w_activation
361        let sensory_erev_expanded = self.sensory_erev.val().unsqueeze::<3>();
362        let sensory_rev_activation = sensory_w_activation.clone().mul(sensory_erev_expanded);
363
364        // Sum over sensory dimension
365        let w_numerator_sensory: Tensor<B, 2> = sensory_rev_activation.sum_dim(1).squeeze(1);
366        let w_denominator_sensory: Tensor<B, 2> = sensory_w_activation.sum_dim(1).squeeze(1);
367
368        let w_pos = self.softplus_2d(&self.w.val());
369
370        // Get sparsity mask for internal synapses: [state_size, state_size] -> [1, state_size, state_size]
371        let sparsity_mask_expanded = self
372            .sparsity_mask
373            .val()
374            .reshape([1, state_size, state_size]);
375
376        // ODE iterations
377        for _ in 0..self.ode_unfolds {
378            // Compute internal synapse activations
379            let sigmoid_val = self.compute_sigmoid_2d(&v_pre, &self.mu.val(), &self.sigma.val());
380
381            // w_activation = w_pos * sigmoid_val
382            let w_expanded = w_pos.clone().unsqueeze::<3>();
383            let w_activation = w_expanded.mul(sigmoid_val);
384
385            // Apply sparsity mask to enforce wiring connectivity
386            let w_activation = w_activation.mul(sparsity_mask_expanded.clone());
387
388            // rev_activation = w_activation * erev
389            let erev_expanded = self.erev.val().unsqueeze::<3>();
390            let rev_activation = w_activation.clone().mul(erev_expanded);
391
392            // Sum over source dimension
393            let w_numerator: Tensor<B, 2> = rev_activation
394                .sum_dim(1)
395                .squeeze(1)
396                .add(w_numerator_sensory.clone());
397            let w_denominator: Tensor<B, 2> = w_activation
398                .sum_dim(1)
399                .squeeze(1)
400                .add(w_denominator_sensory.clone());
401
402            // Update voltage
403            let gleak_pos = self
404                .softplus_1d(self.gleak.val())
405                .unsqueeze::<2>()
406                .expand([batch, state_size]);
407            let vleak_expanded = self
408                .vleak
409                .val()
410                .unsqueeze::<2>()
411                .expand([batch, state_size]);
412
413            let numerator = cm_t
414                .clone()
415                .mul(v_pre.clone())
416                .add(gleak_pos.clone().mul(vleak_expanded))
417                .add(w_numerator);
418            let denominator = cm_t
419                .clone()
420                .add(gleak_pos)
421                .add(w_denominator)
422                .add_scalar(self.epsilon);
423
424            v_pre = numerator.div(denominator);
425        }
426
427        v_pre
428    }
429
430    fn compute_sigmoid_2d(
431        &self,
432        v_pre: &Tensor<B, 2>,
433        mu: &Tensor<B, 2>,
434        sigma: &Tensor<B, 2>,
435    ) -> Tensor<B, 3> {
436        let [batch, state_size] = v_pre.dims();
437
438        // v_pre: [batch, state_size] -> [batch, state_size, 1]
439        // mu, sigma: [state_size, state_size]
440        let v_expanded = v_pre.clone().reshape([batch, state_size, 1]);
441        let mu_expanded = mu.clone().reshape([1, state_size, state_size]);
442        let sigma_expanded = sigma.clone().reshape([1, state_size, state_size]);
443
444        let diff = v_expanded.sub(mu_expanded);
445        let scaled = sigma_expanded.mul(diff);
446
447        activation::sigmoid(scaled.reshape([batch * state_size, state_size]))
448            .reshape([batch, state_size, state_size])
449    }
450
451    fn compute_sensory_sigmoid(&self, inputs: &Tensor<B, 2>) -> Tensor<B, 3> {
452        let [batch, sensory_size] = inputs.dims();
453        let state_size = self.state_size;
454
455        // inputs: [batch, sensory_size] -> [batch, sensory_size, 1]
456        let inputs_expanded = inputs.clone().reshape([batch, sensory_size, 1]);
457        let mu_expanded = self.sensory_mu.val().reshape([1, sensory_size, state_size]);
458        let sigma_expanded = self
459            .sensory_sigma
460            .val()
461            .reshape([1, sensory_size, state_size]);
462
463        let diff = inputs_expanded.sub(mu_expanded);
464        let scaled = sigma_expanded.mul(diff);
465
466        activation::sigmoid(scaled.reshape([batch * sensory_size, state_size])).reshape([
467            batch,
468            sensory_size,
469            state_size,
470        ])
471    }
472
473    pub fn forward(
474        &self,
475        inputs: Tensor<B, 2>,
476        states: Tensor<B, 2>,
477        elapsed_time: Tensor<B, 1>,
478    ) -> (Tensor<B, 2>, Tensor<B, 2>) {
479        // Apply input mapping
480        let mapped_inputs = self.map_inputs(inputs);
481
482        // Run ODE solver
483        let new_states = self._ode_solver(mapped_inputs, states, elapsed_time);
484
485        // Apply output mapping
486        let output = self.map_outputs(new_states.clone());
487
488        (output, new_states)
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type Backend = NdArray<f32>;

    const STATE_SIZE: usize = 10;
    const MOTOR_SIZE: usize = 5;
    const SENSORY_SIZE: usize = 8;

    /// Builds the 10-unit / 5-motor / 8-input fully connected cell shared by the tests.
    fn create_test_cell() -> LTCCell<Backend> {
        let device = Default::default();
        let wiring = crate::wirings::FullyConnected::new(STATE_SIZE, Some(MOTOR_SIZE), 1234, true);

        LTCCell::new(&wiring, Some(SENSORY_SIZE), &device)
            .with_ode_unfolds(6)
            .with_epsilon(1e-8)
    }

    #[test]
    fn test_ltc_cell_creation() {
        let device = Default::default();
        let wiring = crate::wirings::FullyConnected::new(STATE_SIZE, Some(MOTOR_SIZE), 1234, true);
        let cell = LTCCell::<Backend>::new(&wiring, Some(SENSORY_SIZE), &device);

        assert_eq!(cell.state_size(), STATE_SIZE);
        assert_eq!(cell.motor_size(), MOTOR_SIZE);
        assert_eq!(cell.sensory_size(), SENSORY_SIZE);
    }

    #[test]
    fn test_ltc_cell_forward() {
        let device = Default::default();
        let cell = create_test_cell();

        let batch_size = 4;
        let inputs = Tensor::<Backend, 2>::zeros([batch_size, SENSORY_SIZE], &device);
        let states = Tensor::<Backend, 2>::zeros([batch_size, STATE_SIZE], &device);
        let elapsed_time = Tensor::<Backend, 1>::ones([batch_size], &device);

        let (output, new_state) = cell.forward(inputs, states, elapsed_time);

        assert_eq!(output.dims(), [batch_size, MOTOR_SIZE]);
        assert_eq!(new_state.dims(), [batch_size, STATE_SIZE]);
    }

    #[test]
    fn test_ltc_cell_forward_with_mappings() {
        let device = Default::default();
        let cell = create_test_cell()
            .with_input_mapping(MappingMode::Affine, &device)
            .with_output_mapping(MappingMode::Linear, &device);

        let batch_size = 3;
        let inputs = Tensor::<Backend, 2>::random(
            [batch_size, SENSORY_SIZE],
            Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let states = Tensor::<Backend, 2>::zeros([batch_size, STATE_SIZE], &device);
        let elapsed_time = Tensor::<Backend, 1>::ones([batch_size], &device);

        let (output, new_state) = cell.forward(inputs, states, elapsed_time);

        assert_eq!(output.dims(), [batch_size, MOTOR_SIZE]);
        assert_eq!(new_state.dims(), [batch_size, STATE_SIZE]);
    }

    #[test]
    fn test_ltc_state_change() {
        let device = Default::default();
        let cell = create_test_cell();

        let inputs = Tensor::<Backend, 2>::random(
            [2, SENSORY_SIZE],
            Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let states = Tensor::<Backend, 2>::zeros([2, STATE_SIZE], &device);
        let elapsed_time = Tensor::<Backend, 1>::full([2], 1.0, &device);

        let (_output, new_state) = cell.forward(inputs, states, elapsed_time);

        // Starting from a zero state, any non-trivial dynamics must move it.
        let state_diff = new_state.abs().mean().into_scalar();
        assert!(state_diff > 0.0, "State should change after forward pass");
    }
}
552}