1use burn::module::{Module, Param};
13use burn::nn::{Linear, LinearConfig};
14use burn::tensor::activation;
15use burn::tensor::backend::Backend;
16use burn::tensor::Tensor;
17use ndarray::Array2;
18
/// Selects which update rule [`CfCCell::forward`] uses to compute the new hidden state.
///
/// The explicit discriminants match the `u8` value stored inside the cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum CfcMode {
    /// Gated interpolation: `tanh(ff1) * (1 - t) + t * tanh(ff2)`,
    /// where `t = sigmoid(time_a * ts + time_b)`.
    Default = 0,
    /// Closed-form rule without the gating layers:
    /// `a - a * exp(-ts * (|w_tau| + |ff1|)) * ff1`.
    Pure = 1,
    /// Like `Default` but without the `(1 - t)` factor: `tanh(ff1) + t * tanh(ff2)`.
    NoGate = 2,
}
29
/// A single CfC recurrent cell.
///
/// Which optional layers are populated depends on the active mode: `ff2`, `time_a` and
/// `time_b` are present in the default and no-gate modes, while `w_tau` and `a` are
/// present only in the pure mode.
#[derive(Module, Debug)]
pub struct CfCCell<B: Backend> {
    /// Number of input features per step.
    #[module(skip)]
    input_size: usize,
    /// Number of hidden units; also the width of the cell output.
    #[module(skip)]
    hidden_size: usize,
    /// Active mode stored as its `CfcMode` discriminant (0 = Default, 1 = Pure, 2 = NoGate).
    #[module(skip)]
    mode: u8,
    /// Set to `true` once a sparsity mask has been installed.
    #[module(skip)]
    has_sparsity_mask: bool,
    /// Linear layer from `input_size + hidden_size` to `hidden_size`; used in every mode.
    ff1: Linear<B>,
    /// Second feed-forward layer (same shape as `ff1`); `None` in pure mode.
    ff2: Option<Linear<B>>,
    /// Produces the factor multiplied by the timespan inside the gate; `None` in pure mode.
    time_a: Option<Linear<B>>,
    /// Produces the offset added inside the gate; `None` in pure mode.
    time_b: Option<Linear<B>>,
    /// Bias-free `1 -> hidden_size` layer used only in pure mode.
    w_tau: Option<Linear<B>>,
    /// Bias-free `1 -> hidden_size` layer used only in pure mode.
    a: Option<Linear<B>>,
    /// Optional sparsity mask (absolute values of the supplied matrix).
    sparsity_mask: Option<Param<Tensor<B, 2>>>,
}
58
59impl<B: Backend> CfCCell<B> {
60 pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
62 let ff1 = LinearConfig::new(input_size + hidden_size, hidden_size)
63 .with_bias(true)
64 .init(device);
65
66 let ff2 = LinearConfig::new(input_size + hidden_size, hidden_size)
67 .with_bias(true)
68 .init(device);
69
70 let time_a = LinearConfig::new(input_size + hidden_size, hidden_size)
71 .with_bias(true)
72 .init(device);
73
74 let time_b = LinearConfig::new(input_size + hidden_size, hidden_size)
75 .with_bias(true)
76 .init(device);
77
78 Self {
79 input_size,
80 hidden_size,
81 mode: 0, has_sparsity_mask: false,
83 ff1,
84 ff2: Some(ff2),
85 time_a: Some(time_a),
86 time_b: Some(time_b),
87 w_tau: None,
88 a: None,
89 sparsity_mask: None,
90 }
91 }
92
93 pub fn with_mode(mut self, mode: CfcMode) -> Self {
95 self.mode = match mode {
96 CfcMode::Default => 0,
97 CfcMode::Pure => 1,
98 CfcMode::NoGate => 2,
99 };
100 self.reconfigure_for_mode();
101 self
102 }
103
104 pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
106 self
109 }
110
111 pub fn with_activation(self, activation: &str) -> Self {
113 let valid_activations = ["relu", "tanh", "gelu", "silu", "lecun_tanh"];
114 if !valid_activations.contains(&activation) {
115 panic!(
116 "Unknown activation: {}. Valid options are {:?}",
117 activation, valid_activations
118 );
119 }
120 self
121 }
122
123 pub fn with_sparsity_mask(mut self, mask: Array2<f32>, device: &B::Device) -> Self {
129 let shape = mask.shape();
130 let transposed = mask.t();
132 let data: Vec<f32> = transposed.iter().map(|&x| x.abs()).collect();
133 let tensor: Tensor<B, 2> =
134 Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[1], shape[0]]);
135 self.sparsity_mask = Some(Param::from_tensor(tensor));
136 self.has_sparsity_mask = true;
137 self
138 }
139
140 pub fn from_wiring(
142 input_size: usize,
143 wiring: &dyn crate::wirings::Wiring,
144 device: &B::Device,
145 ) -> Self {
146 let hidden_size = wiring.units();
147 let mut cell = Self::new(input_size, hidden_size, device);
148
149 let adj_matrix = wiring.adjacency_matrix();
151 let shape = adj_matrix.shape();
152 let data: Vec<f32> = adj_matrix.iter().map(|&x| x.abs() as f32).collect();
153 let mask_tensor: Tensor<B, 2> =
154 Tensor::<B, 1>::from_floats(data.as_slice(), device).reshape([shape[0], shape[1]]);
155 cell.sparsity_mask = Some(Param::from_tensor(mask_tensor));
156 cell.has_sparsity_mask = true;
157
158 cell
159 }
160
161 pub fn input_size(&self) -> usize {
163 self.input_size
164 }
165
166 pub fn hidden_size(&self) -> usize {
168 self.hidden_size
169 }
170
171 pub fn mode(&self) -> CfcMode {
173 match self.mode {
174 0 => CfcMode::Default,
175 1 => CfcMode::Pure,
176 2 => CfcMode::NoGate,
177 _ => CfcMode::Default,
178 }
179 }
180
181 fn reconfigure_for_mode(&mut self) {
182 let device = self.ff1.weight.device();
183
184 match self.mode {
185 1 => {
186 self.ff2 = None;
188 self.time_a = None;
189 self.time_b = None;
190
191 self.w_tau = Some(
192 LinearConfig::new(1, self.hidden_size)
193 .with_bias(false)
194 .init(&device),
195 );
196 self.a = Some(
197 LinearConfig::new(1, self.hidden_size)
198 .with_bias(false)
199 .init(&device),
200 );
201 }
202 _ => {
203 if self.ff2.is_none() {
205 self.ff2 = Some(
206 LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
207 .with_bias(true)
208 .init(&device),
209 );
210 }
211 if self.time_a.is_none() {
212 self.time_a = Some(
213 LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
214 .with_bias(true)
215 .init(&device),
216 );
217 }
218 if self.time_b.is_none() {
219 self.time_b = Some(
220 LinearConfig::new(self.input_size + self.hidden_size, self.hidden_size)
221 .with_bias(true)
222 .init(&device),
223 );
224 }
225 self.w_tau = None;
226 self.a = None;
227 }
228 }
229 }
231
232 pub fn has_sparsity_mask(&self) -> bool {
234 self.has_sparsity_mask
235 }
236
237 fn apply_sparsity_mask(&self, tensor: Tensor<B, 2>) -> Tensor<B, 2> {
239 if let Some(ref mask) = self.sparsity_mask {
240 let mask_val = mask.val();
244 let [batch_size, hidden_size] = tensor.dims();
245
246 let row_mask: Tensor<B, 1> = mask_val.clone().sum_dim(1).squeeze(1);
249 let row_mask_normalized = row_mask.div_scalar(hidden_size as f32);
250 let mask_expanded = row_mask_normalized.unsqueeze::<2>().expand([batch_size, hidden_size]);
251
252 tensor.mul(mask_expanded)
253 } else {
254 tensor
255 }
256 }
257
258 pub fn forward(
260 &self,
261 input: Tensor<B, 2>,
262 hx: Tensor<B, 2>,
263 ts: f32,
264 ) -> (Tensor<B, 2>, Tensor<B, 2>) {
265 let batch_size = input.dims()[0];
266 let device = input.device();
267
268 let x = Tensor::cat(vec![input, hx], 1);
270
271 let ff1_out = self.ff1.forward(x.clone());
273 let ff1_out = self.apply_sparsity_mask(ff1_out);
274
275 match self.mode {
276 1 => {
277 let w_tau_layer = self.w_tau.as_ref().unwrap();
279 let a_layer = self.a.as_ref().unwrap();
280
281 let ones_input = Tensor::<B, 2>::ones([batch_size, 1], &device);
282 let w_tau_out = w_tau_layer.forward(ones_input.clone());
283 let a_out = a_layer.forward(ones_input);
284
285 let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
286 let abs_w_tau = w_tau_out.abs();
287 let abs_ff1 = ff1_out.clone().abs();
288
289 let exp_term = (ts_tensor * (abs_w_tau + abs_ff1)).neg().exp();
290 let new_hidden = a_out.clone() - a_out * exp_term * ff1_out;
291
292 (new_hidden.clone(), new_hidden)
293 }
294 _ => {
295 let ff2_out = self.ff2.as_ref().unwrap().forward(x.clone());
297 let ff2_out = self.apply_sparsity_mask(ff2_out);
298
299 let ff1_tanh = ff1_out.tanh();
300 let ff2_tanh = ff2_out.tanh();
301
302 let time_a = self.time_a.as_ref().unwrap().forward(x.clone());
303 let time_b = self.time_b.as_ref().unwrap().forward(x);
304
305 let ts_tensor = Tensor::<B, 2>::full([batch_size, self.hidden_size], ts, &device);
307 let t_interp = activation::sigmoid(time_a * ts_tensor + time_b);
308
309 let new_hidden = if self.mode == 2 {
310 ff1_tanh + t_interp * ff2_tanh
312 } else {
313 ff1_tanh
315 * (Tensor::<B, 2>::ones([batch_size, self.hidden_size], &device)
316 - t_interp.clone())
317 + t_interp * ff2_tanh
318 };
319
320 (new_hidden.clone(), new_hidden)
321 }
322 }
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    /// Default device of the test backend.
    fn get_test_device() -> TestDevice {
        Default::default()
    }

    /// A freshly built cell reports the sizes it was given and starts in default mode.
    #[test]
    fn test_cfc_cell_creation() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);
        assert_eq!(cell.mode(), CfcMode::Default);
    }

    /// Default mode yields `[batch, hidden]` tensors for both return values.
    #[test]
    fn test_cfc_forward_default() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        let batch_size = 4;
        let input = Tensor::<TestBackend, 2>::zeros([batch_size, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([batch_size, 50], &device);

        let (output, new_hidden) = cell.forward(input, hx, 1.0);

        assert_eq!(output.dims(), [batch_size, 50]);
        assert_eq!(new_hidden.dims(), [batch_size, 50]);
    }

    /// Pure mode can be selected and produces output of the expected shape.
    #[test]
    fn test_cfc_forward_pure() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device).with_mode(CfcMode::Pure);

        assert_eq!(cell.mode(), CfcMode::Pure);

        let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, _) = cell.forward(input, hx, 1.0);

        assert_eq!(output.dims(), [2, 50]);
    }

    /// No-gate mode can be selected and produces output of the expected shape.
    #[test]
    fn test_cfc_forward_no_gate() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device).with_mode(CfcMode::NoGate);

        assert_eq!(cell.mode(), CfcMode::NoGate);

        let input = Tensor::<TestBackend, 2>::ones([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, new_hidden) = cell.forward(input, hx, 1.0);

        assert_eq!(output.dims(), [2, 50]);
        assert_eq!(new_hidden.dims(), [2, 50]);
    }

    /// A non-zero input moves the hidden state, and output equals the new hidden state.
    #[test]
    fn test_cfc_state_change() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        let input = Tensor::<TestBackend, 2>::ones([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, new_hidden) = cell.forward(input, hx.clone(), 1.0);

        let diff = (new_hidden.clone() - hx).abs().mean().into_scalar();
        assert!(diff > 0.0, "State should change after forward pass");

        let output_diff = (output - new_hidden).abs().mean().into_scalar();
        assert!(output_diff < 1e-6, "Output should equal new_hidden");
    }

    /// With identical weights, default and no-gate modes must still disagree.
    #[test]
    fn test_cfc_different_modes_produce_different_results() {
        let device = get_test_device();

        let cell_default = CfCCell::<TestBackend>::new(20, 50, &device);
        // Clone before switching mode so both cells share the same weights; two
        // independently initialised cells would differ regardless of the mode logic.
        let cell_no_gate = cell_default.clone().with_mode(CfcMode::NoGate);

        let input = Tensor::<TestBackend, 2>::random(
            [2, 20],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (out1, _) = cell_default.forward(input.clone(), hx.clone(), 1.0);
        let (out2, _) = cell_no_gate.forward(input, hx, 1.0);

        let diff = (out1 - out2).abs().mean().into_scalar();
        assert!(
            diff > 0.01,
            "Different modes should produce different outputs"
        );
    }

    /// `with_backbone` accepts both an empty and a deep configuration.
    #[test]
    fn test_cfc_backbone_configurations() {
        let device = get_test_device();

        let _cell_no_backbone =
            CfCCell::<TestBackend>::new(20, 50, &device).with_backbone(0, 0, 0.0);

        let _cell_deep_backbone =
            CfCCell::<TestBackend>::new(20, 50, &device).with_backbone(64, 3, 0.2);
    }

    /// Every supported activation name is accepted and the cell still runs.
    #[test]
    fn test_cfc_activations() {
        let device = get_test_device();

        for activation in ["relu", "tanh", "gelu", "silu", "lecun_tanh"] {
            let cell = CfCCell::<TestBackend>::new(20, 50, &device)
                .with_backbone(64, 1, 0.0)
                .with_activation(activation);

            let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
            let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

            let (output, _) = cell.forward(input, hx, 1.0);
            assert_eq!(output.dims()[0], 2);
        }
    }

    /// An unknown activation name panics.
    #[test]
    #[should_panic]
    fn test_cfc_invalid_activation() {
        let device = get_test_device();
        let _cell =
            CfCCell::<TestBackend>::new(20, 50, &device).with_activation("invalid_activation");
    }

    /// The same cell handles several batch sizes.
    #[test]
    fn test_cfc_batch_processing() {
        let device = get_test_device();
        let cell = CfCCell::<TestBackend>::new(20, 50, &device);

        for batch in [1, 8, 32] {
            let input = Tensor::<TestBackend, 2>::zeros([batch, 20], &device);
            let hx = Tensor::<TestBackend, 2>::zeros([batch, 50], &device);

            let (output, _) = cell.forward(input, hx, 1.0);
            assert_eq!(output.dims(), [batch, 50]);
        }
    }

    /// A cell with an explicit all-ones mask reports the mask and still runs.
    #[test]
    fn test_cfc_sparsity_mask() {
        let device = get_test_device();
        let mask = Array2::from_shape_vec((50, 50), vec![1.0f32; 2500]).unwrap();

        let cell = CfCCell::<TestBackend>::new(20, 50, &device).with_sparsity_mask(mask, &device);

        assert!(cell.has_sparsity_mask());

        let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, _) = cell.forward(input, hx, 1.0);
        assert_eq!(output.dims(), [2, 50]);
    }

    /// A cell built from a wiring takes its hidden size and mask from that wiring.
    #[test]
    fn test_cfc_from_wiring() {
        let device = get_test_device();
        let wiring = crate::wirings::FullyConnected::new(50, None, 1234, true);

        let cell = CfCCell::<TestBackend>::from_wiring(20, &wiring, &device);

        assert!(cell.has_sparsity_mask());
        assert_eq!(cell.input_size(), 20);
        assert_eq!(cell.hidden_size(), 50);

        let input = Tensor::<TestBackend, 2>::zeros([2, 20], &device);
        let hx = Tensor::<TestBackend, 2>::zeros([2, 50], &device);

        let (output, _) = cell.forward(input, hx, 1.0);
        assert_eq!(output.dims(), [2, 50]);
    }
}