1use burn::module::{Module, Param};
13use burn::nn::{Linear, LinearConfig};
14use burn::tensor::activation;
15use burn::tensor::backend::Backend;
16use burn::tensor::Tensor;
17use ndarray::Array2;
18
/// Operating mode of a CfC cell, selecting which closed-form
/// continuous-time update equation `CfCCell::forward` evaluates.
///
/// The explicit discriminants (0/1/2) mirror the raw `u8` stored in
/// `CfCCell::mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CfcMode {
    /// Gated interpolation between two feed-forward branches:
    /// `h' = tanh(ff1) * (1 - t) + t * tanh(ff2)`, where the gate
    /// `t = sigmoid(time_a * ts + time_b)` depends on the time step.
    Default = 0,

    /// Closed-form "pure" solution: `h' = a - a * exp(-ts * (|w_tau| + |ff1|)) * ff1`,
    /// using per-unit `w_tau` and `a` parameter vectors instead of the
    /// gating networks.
    Pure = 1,

    /// Additive variant without the interpolation gate:
    /// `h' = tanh(ff1) + t * tanh(ff2)`.
    NoGate = 2,
}
87
/// A single Closed-form Continuous-time (CfC) recurrent cell.
///
/// Which optional sub-layers are populated depends on the mode:
/// `ff2`/`time_a`/`time_b` are used by `Default` and `NoGate`, while
/// `w_tau`/`a` exist only in `Pure` mode (see `reconfigure_for_mode`).
#[derive(Module, Debug)]
pub struct CfCCell<B: Backend> {
    // Number of input features per step; excluded from the learnable state.
    #[module(skip)]
    input_size: usize,
    // Number of hidden units (also the output width of every sub-layer).
    #[module(skip)]
    hidden_size: usize,
    // Mode discriminant: 0 = Default, 1 = Pure, 2 = NoGate. Stored as a raw
    // u8 — presumably so the derived `Module` impl does not need `CfcMode`
    // to implement the module traits; TODO confirm.
    #[module(skip)]
    mode: u8,
    // Cached flag mirroring `sparsity_mask.is_some()` so callers can query
    // without touching the tensor.
    #[module(skip)]
    has_sparsity_mask: bool,
    // Main feed-forward branch over the concatenated [input, hidden] vector.
    ff1: Linear<B>,
    // Second branch; present in Default/NoGate modes, dropped in Pure mode.
    ff2: Option<Linear<B>>,
    // Time-gate affine layers (Default/NoGate modes only).
    time_a: Option<Linear<B>>,
    time_b: Option<Linear<B>>,
    // Pure-mode parameter vectors, realized as bias-free Linear(1, hidden)
    // layers driven by a ones tensor.
    w_tau: Option<Linear<B>>,
    a: Option<Linear<B>>,
    // Optional connectivity mask; NOTE(review): stored as a `Param`, so it
    // is part of the module's trainable parameters — presumably intended to
    // stay constant; confirm.
    sparsity_mask: Option<Param<Tensor<B, 2>>>,
}
116
impl<B: Backend> CfCCell<B> {
    /// Creates a cell in `Default` mode with freshly initialized
    /// `ff1`/`ff2`/`time_a`/`time_b` layers, each mapping the concatenated
    /// `[input, hidden]` vector (width `input_size + hidden_size`) to
    /// `hidden_size` units with a bias term.
    pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
        let ff1 = LinearConfig::new(input_size + hidden_size, hidden_size)
            .with_bias(true)
            .init(device);

        let ff2 = LinearConfig::new(input_size + hidden_size, hidden_size)
            .with_bias(true)
            .init(device);

        let time_a = LinearConfig::new(input_size + hidden_size, hidden_size)
            .with_bias(true)
            .init(device);

        let time_b = LinearConfig::new(input_size + hidden_size, hidden_size)
            .with_bias(true)
            .init(device);

        Self {
            input_size,
            hidden_size,
            // Default mode, no sparsity mask until a builder method sets one.
            mode: 0,
            has_sparsity_mask: false,
            ff1,
            ff2: Some(ff2),
            time_a: Some(time_a),
            time_b: Some(time_b),
            w_tau: None,
            a: None,
            sparsity_mask: None,
        }
    }

    /// Builder: switches the cell to `mode`, creating/dropping the
    /// mode-specific sub-layers via `reconfigure_for_mode`. Layers created
    /// here are freshly initialized (previous weights are discarded).
    pub fn with_mode(mut self, mode: CfcMode) -> Self {
        self.mode = match mode {
            CfcMode::Default => 0,
            CfcMode::Pure => 1,
            CfcMode::NoGate => 2,
        };
        self.reconfigure_for_mode();
        self
    }

    /// Builder: accepts backbone hyper-parameters for API compatibility.
    /// NOTE(review): currently a no-op — no backbone network is built and
    /// all three arguments are ignored.
    pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
        self
    }

    /// Builder: validates the backbone activation name.
    ///
    /// # Panics
    /// Panics if `activation` is not one of the listed names.
    ///
    /// NOTE(review): the choice is validated but not stored — the forward
    /// pass always uses `tanh`/`sigmoid` regardless; confirm this is the
    /// intended stub behavior.
    pub fn with_activation(self, activation: &str) -> Self {
        let valid_activations = ["relu", "tanh", "gelu", "silu", "lecun_tanh"];
        if !valid_activations.contains(&activation) {
            panic!(
                "Unknown activation: {}. Valid options are {:?}",
                activation, valid_activations
            );
        }
        self
    }

    /// Builder: installs a sparsity mask from an `ndarray` matrix.
    ///
    /// The mask is transposed, made non-negative with `abs`, and stored as a
    /// `[cols, rows]` tensor (relative to the original `mask` shape).
    pub fn with_sparsity_mask(mut self, mask: Array2<f32>, device: &B::Device) -> Self {
        let shape = mask.shape();
        let transposed = mask.t();
        // Iterating the transposed view yields the data column-major with
        // respect to the original matrix, matching the [cols, rows] reshape.
        let data: Vec<f32> = transposed.iter().map(|&x| x.abs()).collect();
        let tensor: Tensor<B, 2> =
            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[1], shape[0]]);
        self.sparsity_mask = Some(Param::from_tensor(tensor));
        self.has_sparsity_mask = true;
        self
    }

    /// Constructs a cell whose hidden size and sparsity mask come from a
    /// wiring diagram: `hidden_size = wiring.units()` and the mask is the
    /// element-wise absolute value of the wiring's adjacency matrix.
    ///
    /// NOTE(review): unlike `with_sparsity_mask`, the adjacency matrix is
    /// stored without transposition — confirm the two orientations agree.
    pub fn from_wiring(
        input_size: usize,
        wiring: &dyn crate::wirings::Wiring,
        device: &B::Device,
    ) -> Self {
        let hidden_size = wiring.units();
        let mut cell = Self::new(input_size, hidden_size, device);

        let adj_matrix = wiring.adjacency_matrix();
        let shape = adj_matrix.shape();
        // Polarity (+1/-1) is discarded; only connectivity magnitude is kept.
        let data: Vec<f32> = adj_matrix.iter().map(|&x| x.abs() as f32).collect();
        let mask_tensor: Tensor<B, 2> =
            Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
        cell.sparsity_mask = Some(Param::from_tensor(mask_tensor));
        cell.has_sparsity_mask = true;

        cell
    }

    /// Number of input features per time step.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Number of hidden units.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Current operating mode, decoded from the internal `u8` discriminant.
    /// Any out-of-range value (unreachable via the public API) maps to
    /// `Default`.
    pub fn mode(&self) -> CfcMode {
        match self.mode {
            0 => CfcMode::Default,
            1 => CfcMode::Pure,
            2 => CfcMode::NoGate,
            _ => CfcMode::Default,
        }
    }

    /// Re-creates or drops the mode-specific sub-layers after a mode change:
    /// Pure mode replaces `ff2`/`time_a`/`time_b` with the bias-free
    /// `w_tau`/`a` parameter layers; the other modes do the reverse,
    /// re-initializing any gating layer that is missing.
    fn reconfigure_for_mode(&mut self) {
        // New layers are allocated on the same device as the existing ff1
        // weights so the whole cell stays on one device.
        let device = self.ff1.weight.device();

        match self.mode {
            1 => {
                self.ff2 = None;
                self.time_a = None;
                self.time_b = None;

                // Linear(1, hidden) without bias acts as a learnable
                // per-unit parameter vector when fed a ones tensor.
                self.w_tau = Some(
                    LinearConfig::new(1, self.hidden_size)
                        .with_bias(false)
                        .init(&device),
                );
                self.a = Some(
                    LinearConfig::new(1, self.hidden_size)
                        .with_bias(false)
                        .init(&device),
                );
            }
            _ => {
                if self.ff2.is_none() {
                    self.ff2 = Some(
                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
                            .with_bias(true)
                            .init(&device),
                    );
                }
                if self.time_a.is_none() {
                    self.time_a = Some(
                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
                            .with_bias(true)
                            .init(&device),
                    );
                }
                if self.time_b.is_none() {
                    self.time_b = Some(
                        LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
                            .with_bias(true)
                            .init(&device),
                    );
                }
                self.w_tau = None;
                self.a = None;
            }
        }
    }

    /// Whether a sparsity mask has been installed.
    pub fn has_sparsity_mask(&self) -> bool {
        self.has_sparsity_mask
    }

    /// Scales a `[batch, hidden]` activation tensor by the stored sparsity
    /// mask, or returns it unchanged when no mask is set.
    ///
    /// Each hidden unit `j` is multiplied by `sum(mask_row_j) / hidden_size`,
    /// i.e. by the normalized connection density of row `j` of the stored
    /// mask. NOTE(review): this scales whole units rather than zeroing
    /// individual connections, and it requires the stored mask to have
    /// exactly `hidden_size` rows for the shapes to line up — confirm both
    /// are intended.
    fn apply_sparsity_mask(&self, tensor: Tensor<B, 2>) -> Tensor<B, 2> {
        if let Some(ref mask) = self.sparsity_mask {
            let mask_val = mask.val();
            let [batch_size, hidden_size] = tensor.dims();

            // Row sums: [rows, cols] -> [rows, 1] -> 1-D [rows].
            let row_mask: Tensor<B, 1> = mask_val.clone().sum_dim(1).squeeze(1);
            let row_mask_normalized = row_mask.div_scalar(hidden_size as f32);
            // Broadcast the per-unit scale across the batch dimension.
            let mask_expanded = row_mask_normalized.unsqueeze::<2>().expand([batch_size, hidden_size]);

            tensor.mul(mask_expanded)
        } else {
            tensor
        }
    }

    /// Runs one CfC step.
    ///
    /// * `input` — `[batch, input_size]` features for this time step.
    /// * `hx` — `[batch, hidden_size]` previous hidden state.
    /// * `ts` — elapsed time for this step (same scalar for the whole batch).
    ///
    /// Returns `(output, new_hidden)`; the two tensors are identical — the
    /// cell's output *is* its new hidden state.
    ///
    /// # Panics
    /// Unwraps the mode-specific sub-layers; panics only if the struct
    /// invariant maintained by `reconfigure_for_mode` is broken.
    pub fn forward(
        &self,
        input: Tensor<B, 2>,
        hx: Tensor<B, 2>,
        ts: f32,
    ) -> (Tensor<B, 2>, Tensor<B, 2>) {
        let batch_size = input.dims()[0];
        let device = input.device();

        // All branches operate on the concatenated [input, hidden] vector.
        let x = Tensor::cat(vec![input, hx], 1);

        let ff1_out = self.ff1.forward(x.clone());
        let ff1_out = self.apply_sparsity_mask(ff1_out);

        match self.mode {
            // Pure mode: h' = a - a * exp(-ts * (|w_tau| + |ff1|)) * ff1
            1 => {
                let w_tau_layer = self.w_tau.as_ref().unwrap();
                let a_layer = self.a.as_ref().unwrap();

                // Feeding ones through the bias-free Linear(1, hidden)
                // layers materializes the w_tau / a parameter vectors,
                // broadcast to [batch, hidden].
                let ones_input = Tensor::<B, 2>::ones([batch_size, 1], &device);
                let w_tau_out = w_tau_layer.forward(ones_input.clone());
                let a_out = a_layer.forward(ones_input);

                let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
                let abs_w_tau = w_tau_out.abs();
                let abs_ff1 = ff1_out.clone().abs();

                let exp_term = (ts_tensor * (abs_w_tau + abs_ff1)).neg().exp();
                let new_hidden = a_out.clone() - a_out * exp_term * ff1_out;

                (new_hidden.clone(), new_hidden)
            }
            // Default / NoGate modes share the two-branch structure and the
            // time-dependent gate t = sigmoid(time_a * ts + time_b).
            _ => {
                let ff2_out = self.ff2.as_ref().unwrap().forward(x.clone());
                let ff2_out = self.apply_sparsity_mask(ff2_out);

                let ff1_tanh = ff1_out.tanh();
                let ff2_tanh = ff2_out.tanh();

                let time_a = self.time_a.as_ref().unwrap().forward(x.clone());
                let time_b = self.time_b.as_ref().unwrap().forward(x);

                let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
                let t_interp = activation::sigmoid(time_a * ts_tensor + time_b);

                let new_hidden = if self.mode == 2 {
                    // NoGate: additive combination, no interpolation.
                    ff1_tanh + t_interp * ff2_tanh
                } else {
                    // Default: convex-style interpolation between branches.
                    ff1_tanh
                        * (Tensor::<B, 2>::ones([batch_size, self.hidden_size], &device)
                            - t_interp.clone())
                        + t_interp * ff2_tanh
                };

                (new_hidden.clone(), new_hidden)
            }
        }
    }
}
383
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    /// Default NdArray device shared by every test.
    fn get_test_device() -> TestDevice {
        TestDevice::default()
    }

    /// Zeroed `(input, hidden)` pair for a 20-in / 50-hidden cell.
    fn zero_state(
        batch: usize,
        dev: &TestDevice,
    ) -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2>) {
        (
            Tensor::<TestBackend, 2>::zeros([batch, 20], dev),
            Tensor::<TestBackend, 2>::zeros([batch, 50], dev),
        )
    }

    #[test]
    fn test_cfc_cell_creation() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        // A fresh cell reports its sizes and starts in Default mode.
        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);
        assert_eq!(cell.mode(), CfcMode::Default);
    }

    #[test]
    fn test_cfc_forward_default() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        let n = 4;
        let (x, h) = zero_state(n, &dev);
        let (out, next_h) = cell.forward(x, h, 1.0);

        // Both returned tensors carry the hidden-state shape.
        assert_eq!(out.dims(), [n, 50]);
        assert_eq!(next_h.dims(), [n, 50]);
    }

    #[test]
    fn test_cfc_forward_pure() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_mode(CfcMode::Pure);
        assert_eq!(cell.mode(), CfcMode::Pure);

        let (x, h) = zero_state(2, &dev);
        let (out, _) = cell.forward(x, h, 1.0);
        assert_eq!(out.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_forward_no_gate() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_mode(CfcMode::NoGate);
        assert_eq!(cell.mode(), CfcMode::NoGate);

        // Non-zero input so the forward pass exercises real values.
        let x = Tensor::<TestBackend, 2>::ones([2, 20], &dev);
        let h = Tensor::<TestBackend, 2>::zeros([2, 50], &dev);
        let (out, next_h) = cell.forward(x, h, 1.0);

        assert_eq!(out.dims(), [2, 50]);
        assert_eq!(next_h.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_state_change() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        let x = Tensor::<TestBackend, 2>::ones([2, 20], &dev);
        let h = Tensor::<TestBackend, 2>::zeros([2, 50], &dev);
        let (out, next_h) = cell.forward(x, h.clone(), 1.0);

        // Feeding ones must move the state away from all-zeros.
        let diff = (next_h.clone() - h).abs().mean().into_scalar();
        assert!(diff > 0.0, "State should change after forward pass");

        // The output tensor is the same value as the new hidden state.
        let output_diff = (out - next_h).abs().mean().into_scalar();
        assert!(output_diff < 1e-6, "Output should equal new_hidden");
    }

    #[test]
    fn test_cfc_different_modes_produce_different_results() {
        let dev = get_test_device();

        let cell_default = CfCCell::<TestBackend>::new(20, 50, &dev);
        let cell_no_gate = CfCCell::<TestBackend>::new(20, 50, &dev).with_mode(CfcMode::NoGate);

        let x = Tensor::<TestBackend, 2>::random(
            [2, 20],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &dev,
        );
        let h = Tensor::<TestBackend, 2>::zeros([2, 50], &dev);

        let (out_default, _) = cell_default.forward(x.clone(), h.clone(), 1.0);
        let (out_no_gate, _) = cell_no_gate.forward(x, h, 1.0);

        let diff = (out_default - out_no_gate).abs().mean().into_scalar();
        assert!(
            diff > 0.01,
            "Different modes should produce different outputs"
        );
    }

    #[test]
    fn test_cfc_backbone_configurations() {
        let dev = get_test_device();

        // Both extremes must construct without panicking (backbone is a stub).
        let _cell_no_backbone =
            CfCCell::<TestBackend>::new(20, 50, &dev).with_backbone(0, 0, 0.0);
        let _cell_deep_backbone =
            CfCCell::<TestBackend>::new(20, 50, &dev).with_backbone(64, 3, 0.2);
    }

    #[test]
    fn test_cfc_activations() {
        let dev = get_test_device();

        // Every accepted activation name builds and runs a forward pass.
        for name in ["relu", "tanh", "gelu", "silu", "lecun_tanh"] {
            let cell = CfCCell::<TestBackend>::new(20, 50, &dev)
                .with_backbone(64, 1, 0.0)
                .with_activation(name);

            let (x, h) = zero_state(2, &dev);
            let (out, _) = cell.forward(x, h, 1.0);
            assert_eq!(out.dims()[0], 2);
        }
    }

    #[test]
    #[should_panic]
    fn test_cfc_invalid_activation() {
        let dev = get_test_device();
        let _cell =
            CfCCell::<TestBackend>::new(20, 50, &dev).with_activation("invalid_activation");
    }

    #[test]
    fn test_cfc_batch_processing() {
        let dev = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev);

        for &batch in &[1, 8, 32] {
            let (x, h) = zero_state(batch, &dev);
            let (out, _) = cell.forward(x, h, 1.0);
            assert_eq!(out.dims(), [batch, 50]);
        }
    }

    #[test]
    fn test_cfc_sparsity_mask() {
        let dev = get_test_device();
        // Dense all-ones mask: forward must still produce the right shape.
        let mask = Array2::from_shape_vec((50, 50), vec![1.0f32; 2500]).unwrap();
        let cell = CfCCell::<TestBackend>::new(20, 50, &dev).with_sparsity_mask(mask, &dev);

        assert!(cell.has_sparsity_mask());

        let (x, h) = zero_state(2, &dev);
        let (out, _) = cell.forward(x, h, 1.0);
        assert_eq!(out.dims(), [2, 50]);
    }

    #[test]
    fn test_cfc_from_wiring() {
        let dev = get_test_device();
        let wiring = crate::wirings::FullyConnected::new(50, None, 1234, true);
        let cell = CfCCell::<TestBackend>::from_wiring(20, &wiring, &dev);

        // Wiring-derived cells carry a mask and inherit the wiring's size.
        assert!(cell.has_sparsity_mask());
        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);

        let (x, h) = zero_state(2, &dev);
        let (out, _) = cell.forward(x, h, 1.0);
        assert_eq!(out.dims(), [2, 50]);
    }
}