1use crate::wirings::Wiring;
6use burn::module::{Module, Param};
7use burn::tensor::activation;
8use burn::tensor::backend::Backend;
9use burn::tensor::{Distribution, Tensor};
10
/// How cell inputs are transformed before, and motor outputs after, the ODE solve.
///
/// Selected via [`LTCCell::with_input_mapping`] / [`LTCCell::with_output_mapping`],
/// which allocate (or drop) the corresponding per-element weight/bias parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MappingMode {
    /// Learnable element-wise scale and bias: `w * x + b`. The default mode.
    #[default]
    Affine,
    /// Learnable element-wise scale only: `w * x`.
    Linear,
    /// Pass values through unchanged (no mapping parameters allocated).
    None,
}
22
/// Liquid Time-Constant (LTC) recurrent cell.
///
/// Holds per-neuron membrane parameters and per-synapse gating parameters for
/// both recurrent (`state -> state`) and sensory (`input -> state`) synapses.
/// Sparsity masks derived from the wiring's reversal-potential matrix zero out
/// synapses that do not exist in the wiring.
#[derive(Debug, Module)]
pub struct LTCCell<B: Backend> {
    /// Leak conductance per neuron, shape `[state_size]`.
    pub gleak: Param<Tensor<B, 1>>,
    /// Leak (resting) potential per neuron, shape `[state_size]`.
    pub vleak: Param<Tensor<B, 1>>,
    /// Membrane capacitance per neuron, shape `[state_size]`.
    pub cm: Param<Tensor<B, 1>>,
    /// Recurrent synapse sigmoid steepness, shape `[state_size, state_size]`.
    pub sigma: Param<Tensor<B, 2>>,
    /// Recurrent synapse sigmoid midpoint, shape `[state_size, state_size]`.
    pub mu: Param<Tensor<B, 2>>,
    /// Recurrent synapse weights, shape `[state_size, state_size]`.
    pub w: Param<Tensor<B, 2>>,
    /// Recurrent reversal potentials (sign encodes excitatory/inhibitory),
    /// shape `[state_size, state_size]`.
    pub erev: Param<Tensor<B, 2>>,
    /// Sensory synapse sigmoid steepness, shape `[sensory_size, state_size]`.
    pub sensory_sigma: Param<Tensor<B, 2>>,
    /// Sensory synapse sigmoid midpoint, shape `[sensory_size, state_size]`.
    pub sensory_mu: Param<Tensor<B, 2>>,
    /// Sensory synapse weights, shape `[sensory_size, state_size]`.
    pub sensory_w: Param<Tensor<B, 2>>,
    /// Sensory reversal potentials, shape `[sensory_size, state_size]`.
    pub sensory_erev: Param<Tensor<B, 2>>,
    /// 0/1 mask of existing recurrent synapses (|erev|), same shape as `erev`.
    pub sparsity_mask: Param<Tensor<B, 2>>,
    /// 0/1 mask of existing sensory synapses, same shape as `sensory_erev`.
    pub sensory_sparsity_mask: Param<Tensor<B, 2>>,
    /// Input mapping scale, `Some` for `Affine`/`Linear` modes.
    pub input_w: Option<Param<Tensor<B, 1>>>,
    /// Input mapping bias, `Some` only for `Affine` mode.
    pub input_b: Option<Param<Tensor<B, 1>>>,
    /// Output mapping scale, `Some` for `Affine`/`Linear` modes.
    pub output_w: Option<Param<Tensor<B, 1>>>,
    /// Output mapping bias, `Some` only for `Affine` mode.
    pub output_b: Option<Param<Tensor<B, 1>>>,
    // Number of fused ODE steps per forward call (see `_ode_solver`).
    #[module(skip)]
    ode_unfolds: usize,
    // Added to the update denominator to avoid division by zero.
    #[module(skip)]
    epsilon: f64,
    // Number of neurons (hidden state width).
    #[module(skip)]
    state_size: usize,
    // Number of motor (output) neurons; a prefix of the state.
    #[module(skip)]
    motor_size: usize,
    // Input feature width.
    #[module(skip)]
    sensory_size: usize,
    // Encoded MappingMode for inputs: 0 = None, 1 = Linear, 2 = Affine.
    #[module(skip)]
    input_mapping: u8,
    // Encoded MappingMode for outputs: 0 = None, 1 = Linear, 2 = Affine.
    #[module(skip)]
    output_mapping: u8,
}
82
83impl<B: Backend> LTCCell<B> {
84 pub fn new(wiring: &dyn Wiring, sensory_size: Option<usize>, device: &B::Device) -> Self {
86 let state_size = wiring.units();
87 let motor_size = wiring.output_dim().unwrap_or(state_size);
88 let actual_sensory_size = sensory_size.or_else(|| wiring.input_dim()).expect(
89 "LTCCell requires sensory_size or wiring with input_dim. Call wiring.build() first.",
90 );
91
92 let gleak = Self::init_param([state_size], 0.001, 1.0, device);
94 let vleak = Self::init_param([state_size], -0.2, 0.2, device);
95 let cm = Self::init_param([state_size], 0.4, 0.6, device);
96
97 let sigma = Self::init_param([state_size, state_size], 3.0, 8.0, device);
99 let mu = Self::init_param([state_size, state_size], 0.3, 0.8, device);
100 let w = Self::init_param([state_size, state_size], 0.001, 1.0, device);
101
102 let erev_matrix = wiring.erev_initializer();
104 let erev = Self::tensor_from_ndarray(&erev_matrix, device);
105
106 let sparsity_mask = Self::sparsity_mask_from_ndarray(&erev_matrix, device);
108
109 let sensory_sigma = Self::init_param([actual_sensory_size, state_size], 3.0, 8.0, device);
110 let sensory_mu = Self::init_param([actual_sensory_size, state_size], 0.3, 0.8, device);
111 let sensory_w = Self::init_param([actual_sensory_size, state_size], 0.001, 1.0, device);
112
113 let (sensory_erev, sensory_sparsity_mask) =
115 if let Some(sensory_matrix) = wiring.sensory_erev_initializer() {
116 (
117 Self::tensor_from_ndarray(&sensory_matrix, device),
118 Self::sparsity_mask_from_ndarray(&sensory_matrix, device),
119 )
120 } else {
121 (
123 Param::from_tensor(Tensor::ones([actual_sensory_size, state_size], device)),
124 Param::from_tensor(Tensor::ones([actual_sensory_size, state_size], device)),
125 )
126 };
127
128 Self {
129 gleak,
130 vleak,
131 cm,
132 sigma,
133 mu,
134 w,
135 erev,
136 sensory_sigma,
137 sensory_mu,
138 sensory_w,
139 sensory_erev,
140 sparsity_mask,
141 sensory_sparsity_mask,
142 input_w: None,
143 input_b: None,
144 output_w: None,
145 output_b: None,
146 ode_unfolds: 6,
147 epsilon: 1e-8,
148 state_size,
149 motor_size,
150 sensory_size: actual_sensory_size,
151 input_mapping: 0, output_mapping: 0, }
154 }
155
156 fn tensor_from_ndarray(
158 arr: &ndarray::Array2<i32>,
159 device: &B::Device,
160 ) -> Param<Tensor<B, 2>> {
161 let shape = arr.shape();
162 let data: Vec<f32> = arr.iter().map(|&x| x as f32).collect();
163 let tensor: Tensor<B, 2> =
164 Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
165 Param::from_tensor(tensor)
166 }
167
168 fn sparsity_mask_from_ndarray(
170 arr: &ndarray::Array2<i32>,
171 device: &B::Device,
172 ) -> Param<Tensor<B, 2>> {
173 let shape = arr.shape();
174 let data: Vec<f32> = arr.iter().map(|&x| x.abs() as f32).collect();
175 let tensor: Tensor<B, 2> =
176 Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
177 Param::from_tensor(tensor)
178 }
179
180 fn init_param<const D: usize>(
181 shape: [usize; D],
182 min: f64,
183 max: f64,
184 device: &B::Device,
185 ) -> Param<Tensor<B, D>> {
186 let tensor = Tensor::random(shape, Distribution::Uniform(min, max), device);
187 Param::from_tensor(tensor)
188 }
189
190 pub fn with_ode_unfolds(mut self, unfolds: usize) -> Self {
191 self.ode_unfolds = unfolds;
192 self
193 }
194
195 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
196 self.epsilon = epsilon;
197 self
198 }
199
200 pub fn with_input_mapping(mut self, mode: MappingMode, device: &B::Device) -> Self {
202 self.input_mapping = match mode {
203 MappingMode::None => 0,
204 MappingMode::Linear => 1,
205 MappingMode::Affine => 2,
206 };
207 match mode {
208 MappingMode::Affine => {
209 self.input_w =
210 Some(Param::from_tensor(Tensor::ones([self.sensory_size], device)));
211 self.input_b =
212 Some(Param::from_tensor(Tensor::zeros([self.sensory_size], device)));
213 }
214 MappingMode::Linear => {
215 self.input_w =
216 Some(Param::from_tensor(Tensor::ones([self.sensory_size], device)));
217 self.input_b = None;
218 }
219 MappingMode::None => {
220 self.input_w = None;
221 self.input_b = None;
222 }
223 }
224 self
225 }
226
227 pub fn with_output_mapping(mut self, mode: MappingMode, device: &B::Device) -> Self {
229 self.output_mapping = match mode {
230 MappingMode::None => 0,
231 MappingMode::Linear => 1,
232 MappingMode::Affine => 2,
233 };
234 match mode {
235 MappingMode::Affine => {
236 self.output_w = Some(Param::from_tensor(Tensor::ones([self.motor_size], device)));
237 self.output_b = Some(Param::from_tensor(Tensor::zeros([self.motor_size], device)));
238 }
239 MappingMode::Linear => {
240 self.output_w = Some(Param::from_tensor(Tensor::ones([self.motor_size], device)));
241 self.output_b = None;
242 }
243 MappingMode::None => {
244 self.output_w = None;
245 self.output_b = None;
246 }
247 }
248 self
249 }
250
251 pub fn state_size(&self) -> usize {
252 self.state_size
253 }
254
255 pub fn motor_size(&self) -> usize {
256 self.motor_size
257 }
258
259 pub fn sensory_size(&self) -> usize {
260 self.sensory_size
261 }
262
263 pub fn synapse_count(&self) -> usize {
264 self.state_size * self.state_size
265 }
266
267 pub fn sensory_synapse_count(&self) -> usize {
268 self.sensory_size * self.state_size
269 }
270
271 fn map_inputs(&self, inputs: Tensor<B, 2>) -> Tensor<B, 2> {
273 let mut result = inputs;
274 if let Some(ref w) = self.input_w {
275 result = result.mul(w.val().unsqueeze::<2>());
276 }
277 if let Some(ref b) = self.input_b {
278 result = result.add(b.val().unsqueeze::<2>());
279 }
280 result
281 }
282
283 fn map_outputs(&self, state: Tensor<B, 2>) -> Tensor<B, 2> {
285 let mut output = state.narrow(1, 0, self.motor_size);
287
288 if let Some(ref w) = self.output_w {
289 output = output.mul(w.val().unsqueeze::<2>());
290 }
291 if let Some(ref b) = self.output_b {
292 output = output.add(b.val().unsqueeze::<2>());
293 }
294 output
295 }
296
297 pub fn apply_weight_constraints(&mut self) {
299 self.w = Param::from_tensor(self.w.val().clamp_min(0.0));
302 self.sensory_w = Param::from_tensor(self.sensory_w.val().clamp_min(0.0));
303 self.cm = Param::from_tensor(self.cm.val().clamp_min(0.0));
304 self.gleak = Param::from_tensor(self.gleak.val().clamp_min(0.0));
305 }
306}
307
impl<B: Backend> LTCCell<B> {
    /// Softplus `ln(1 + e^x)` for rank-1 tensors; keeps parameters positive.
    ///
    /// NOTE(review): the naive `exp` can overflow for large `x` — confirm the
    /// expected parameter range or switch to a numerically stable softplus.
    fn softplus_1d(&self, x: Tensor<B, 1>) -> Tensor<B, 1> {
        x.exp().add_scalar(1.0).log()
    }

    /// Softplus `ln(1 + e^x)` for rank-2 tensors (same overflow caveat as above).
    fn softplus_2d(&self, x: &Tensor<B, 2>) -> Tensor<B, 2> {
        x.clone().exp().add_scalar(1.0).log()
    }

    /// Integrates the LTC membrane ODE over `elapsed_time` using
    /// `ode_unfolds` fused fixed-point steps.
    ///
    /// Each step rewrites the update as `v = numerator / denominator`, which
    /// matches a semi-implicit (backward-Euler-style) discretization of the
    /// conductance-based dynamics — presumably mirroring the LTC reference
    /// implementation; verify against it if behavior must match exactly.
    ///
    /// Shapes: `inputs` is `[batch, sensory_size]`, `state` is
    /// `[batch, state_size]`, `elapsed_time` is `[batch]`.
    fn _ode_solver(
        &self,
        inputs: Tensor<B, 2>,
        state: Tensor<B, 2>,
        elapsed_time: Tensor<B, 1>,
    ) -> Tensor<B, 2> {
        let [batch, state_size] = state.dims();
        let sensory_size = self.sensory_size;
        let mut v_pre = state;

        // cm / dt, broadcast to [batch, state_size]; dt is the per-sample
        // elapsed time split evenly across the unfold steps.
        let cm = self.softplus_1d(self.cm.val());
        let cm_expanded = cm.unsqueeze::<2>().expand([batch, state_size]);
        let dt = elapsed_time.div_scalar(self.ode_unfolds as f64);
        let dt_expanded = dt.unsqueeze_dim::<2>(1).expand([batch, state_size]);
        let cm_t = cm_expanded.div(dt_expanded);

        // Sensory gating depends only on the inputs, so it is computed once
        // outside the unfold loop. Shape: [batch, sensory_size, state_size].
        let sensory_sigmoid = self.compute_sensory_sigmoid(&inputs);

        // Per-synapse sensory conductances, masked by the wiring sparsity.
        let sensory_w_pos = self.softplus_2d(&self.sensory_w.val());
        let sensory_w_expanded = sensory_w_pos.unsqueeze::<3>();
        let sensory_w_activation = sensory_w_expanded.mul(sensory_sigmoid);

        let sensory_mask_expanded = self
            .sensory_sparsity_mask
            .val()
            .reshape([1, sensory_size, state_size]);
        let sensory_w_activation = sensory_w_activation.mul(sensory_mask_expanded);

        let sensory_erev_expanded = self.sensory_erev.val().unsqueeze::<3>();
        let sensory_rev_activation = sensory_w_activation.clone().mul(sensory_erev_expanded);

        // Reduce over the sensory (source) axis -> [batch, state_size].
        let w_numerator_sensory: Tensor<B, 2> = sensory_rev_activation.sum_dim(1).squeeze(1);
        let w_denominator_sensory: Tensor<B, 2> = sensory_w_activation.sum_dim(1).squeeze(1);

        // Recurrent weights and mask are loop-invariant; hoisted.
        let w_pos = self.softplus_2d(&self.w.val());

        let sparsity_mask_expanded = self
            .sparsity_mask
            .val()
            .reshape([1, state_size, state_size]);

        for _ in 0..self.ode_unfolds {
            // Recurrent gating depends on the evolving state, so it is
            // recomputed every step. Shape: [batch, state_size, state_size].
            let sigmoid_val = self.compute_sigmoid_2d(&v_pre, &self.mu.val(), &self.sigma.val());

            let w_expanded = w_pos.clone().unsqueeze::<3>();
            let w_activation = w_expanded.mul(sigmoid_val);

            let w_activation = w_activation.mul(sparsity_mask_expanded.clone());

            let erev_expanded = self.erev.val().unsqueeze::<3>();
            let rev_activation = w_activation.clone().mul(erev_expanded);

            // Reduce over the presynaptic axis and fold in the (constant)
            // sensory contributions.
            let w_numerator: Tensor<B, 2> = rev_activation
                .sum_dim(1)
                .squeeze(1)
                .add(w_numerator_sensory.clone());
            let w_denominator: Tensor<B, 2> = w_activation
                .sum_dim(1)
                .squeeze(1)
                .add(w_denominator_sensory.clone());

            let gleak_pos = self
                .softplus_1d(self.gleak.val())
                .unsqueeze::<2>()
                .expand([batch, state_size]);
            let vleak_expanded = self
                .vleak
                .val()
                .unsqueeze::<2>()
                .expand([batch, state_size]);

            // Fused update: v <- (cm/dt * v + gleak*vleak + Σ w*erev*σ)
            //                    / (cm/dt + gleak + Σ w*σ + ε)
            let numerator = cm_t
                .clone()
                .mul(v_pre.clone())
                .add(gleak_pos.clone().mul(vleak_expanded))
                .add(w_numerator);
            let denominator = cm_t
                .clone()
                .add(gleak_pos)
                .add(w_denominator)
                .add_scalar(self.epsilon);

            v_pre = numerator.div(denominator);
        }

        v_pre
    }

    /// Per-synapse recurrent gate: `σ(sigma * (v_pre - mu))` broadcast to
    /// `[batch, state_size, state_size]` (presynaptic axis = dim 1).
    fn compute_sigmoid_2d(
        &self,
        v_pre: &Tensor<B, 2>,
        mu: &Tensor<B, 2>,
        sigma: &Tensor<B, 2>,
    ) -> Tensor<B, 3> {
        let [batch, state_size] = v_pre.dims();

        let v_expanded = v_pre.clone().reshape([batch, state_size, 1]);
        let mu_expanded = mu.clone().reshape([1, state_size, state_size]);
        let sigma_expanded = sigma.clone().reshape([1, state_size, state_size]);

        let diff = v_expanded.sub(mu_expanded);
        let scaled = sigma_expanded.mul(diff);

        // Flatten to rank 2 for the activation, then restore rank 3.
        activation::sigmoid(scaled.reshape([batch * state_size, state_size]))
            .reshape([batch, state_size, state_size])
    }

    /// Per-synapse sensory gate: `σ(sensory_sigma * (inputs - sensory_mu))`
    /// broadcast to `[batch, sensory_size, state_size]`.
    fn compute_sensory_sigmoid(&self, inputs: &Tensor<B, 2>) -> Tensor<B, 3> {
        let [batch, sensory_size] = inputs.dims();
        let state_size = self.state_size;

        let inputs_expanded = inputs.clone().reshape([batch, sensory_size, 1]);
        let mu_expanded = self.sensory_mu.val().reshape([1, sensory_size, state_size]);
        let sigma_expanded = self
            .sensory_sigma
            .val()
            .reshape([1, sensory_size, state_size]);

        let diff = inputs_expanded.sub(mu_expanded);
        let scaled = sigma_expanded.mul(diff);

        // Flatten to rank 2 for the activation, then restore rank 3.
        activation::sigmoid(scaled.reshape([batch * sensory_size, state_size])).reshape([
            batch,
            sensory_size,
            state_size,
        ])
    }

    /// Runs one cell step: map inputs, integrate the ODE, map outputs.
    ///
    /// Returns `(output, new_state)` with shapes `[batch, motor_size]` and
    /// `[batch, state_size]` respectively.
    pub fn forward(
        &self,
        inputs: Tensor<B, 2>,
        states: Tensor<B, 2>,
        elapsed_time: Tensor<B, 1>,
    ) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let mapped_inputs = self.map_inputs(inputs);

        let new_states = self._ode_solver(mapped_inputs, states, elapsed_time);

        let output = self.map_outputs(new_states.clone());

        (output, new_states)
    }
}
491
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type Backend = NdArray<f32>;

    /// Small fixture cell: 10 neurons, 5 motor neurons, 8 sensory inputs.
    fn create_test_cell() -> LTCCell<Backend> {
        let device = Default::default();
        let wiring = crate::wirings::FullyConnected::new(10, Some(5), 1234, true);

        LTCCell::new(&wiring, Some(8), &device)
            .with_ode_unfolds(6)
            .with_epsilon(1e-8)
    }

    #[test]
    fn test_ltc_cell_creation() {
        let device = Default::default();
        let wiring = crate::wirings::FullyConnected::new(10, Some(5), 1234, true);
        let cell = LTCCell::<Backend>::new(&wiring, Some(8), &device);

        assert_eq!(cell.state_size(), 10);
        assert_eq!(cell.motor_size(), 5);
        assert_eq!(cell.sensory_size(), 8);
    }

    #[test]
    fn test_ltc_cell_forward() {
        let device = Default::default();
        let cell = create_test_cell();

        let batch_size = 4;
        let inputs = Tensor::<Backend, 2>::zeros([batch_size, 8], &device);
        let states = Tensor::<Backend, 2>::zeros([batch_size, 10], &device);
        let elapsed_time = Tensor::<Backend, 1>::ones([batch_size], &device);

        let (output, new_state) = cell.forward(inputs, states, elapsed_time);

        assert_eq!(output.dims(), [batch_size, 5]);
        assert_eq!(new_state.dims(), [batch_size, 10]);
    }

    #[test]
    fn test_ltc_state_change() {
        let device = Default::default();
        let cell = create_test_cell();

        let inputs =
            Tensor::<Backend, 2>::random([2, 8], Distribution::Uniform(-1.0, 1.0), &device);
        let states = Tensor::<Backend, 2>::zeros([2, 10], &device);
        let elapsed_time = Tensor::<Backend, 1>::full([2], 1.0, &device);

        // The tensors are not used afterwards, so they are moved in directly
        // (previously redundant clones); the unused output binding is ignored.
        let (_output, new_state) = cell.forward(inputs, states, elapsed_time);

        // Starting from a zero state, a forward pass with random inputs
        // should produce a non-zero state (mean |state| > 0).
        let mean_abs_state = new_state.abs().mean().into_scalar();
        assert!(mean_abs_state > 0.0, "State should change after forward pass");
    }
}
552}