1use crate::cells::CfCCell;
7use crate::wirings::Wiring;
8use burn::module::Module;
9use burn::nn::{Linear, LinearConfig};
10use burn::tensor::backend::Backend;
11use burn::tensor::Tensor;
12
/// Sequence-level Closed-form Continuous-time (CfC) recurrent layer.
///
/// Unrolls a [`CfCCell`] over the time axis of a 3-D input tensor and optionally
/// applies a linear projection to every per-step output.
#[derive(Module, Debug)]
pub struct CfC<B: Backend> {
    /// Recurrent cell applied once per time step.
    cell: CfCCell<B>,
    /// Optional linear map applied to each cell output (`hidden_size -> output_size`).
    proj: Option<Linear<B>>,
    // The fields below are marked `#[module(skip)]`; they hold configuration only and,
    // presumably, are not part of the saved parameter record (confirm with burn docs).
    #[module(skip)]
    /// Number of features expected in the last dimension of the input.
    input_size: usize,
    #[module(skip)]
    /// Size of the recurrent state carried between time steps.
    hidden_size: usize,
    #[module(skip)]
    /// `true`: input is `[batch, seq, features]`; `false`: input is `[seq, batch, features]`.
    batch_first: bool,
    #[module(skip)]
    /// `true`: return the output of every step; `false`: return only the last step.
    return_sequences: bool,
    #[module(skip)]
    /// Output size of `proj` when a projection is configured, otherwise `None`.
    proj_size: Option<usize>,
    #[module(skip)]
    /// Feature size of the tensor returned by `forward` (projection size or hidden size).
    output_size: usize,
}
45
46impl<B: Backend> CfC<B> {
47 pub fn new(input_size: usize, hidden_size: usize, device: &B::Device) -> Self {
54 let cell = CfCCell::new(input_size, hidden_size, device);
55
56 Self {
57 cell,
58 proj: None,
59 input_size,
60 hidden_size,
61 batch_first: true,
62 return_sequences: true,
63 proj_size: None,
64 output_size: hidden_size,
65 }
66 }
67
68 pub fn with_wiring(input_size: usize, wiring: impl Wiring, device: &B::Device) -> Self {
75 let state_size = wiring.units();
76 let motor_size = wiring.output_dim().unwrap_or(state_size);
77
78 let cell = CfCCell::new(input_size, state_size, device);
79
80 let output_size = motor_size;
81
82 let proj = if motor_size != state_size {
84 Some(
85 LinearConfig::new(state_size, motor_size)
86 .with_bias(true)
87 .init(device),
88 )
89 } else {
90 None
91 };
92
93 Self {
94 cell,
95 proj,
96 input_size,
97 hidden_size: state_size,
98 batch_first: true,
99 return_sequences: true,
100 proj_size: if motor_size != state_size {
101 Some(motor_size)
102 } else {
103 None
104 },
105 output_size,
106 }
107 }
108
109 pub fn with_batch_first(mut self, batch_first: bool) -> Self {
111 self.batch_first = batch_first;
112 self
113 }
114
115 pub fn with_return_sequences(mut self, return_sequences: bool) -> Self {
117 self.return_sequences = return_sequences;
118 self
119 }
120
121 pub fn with_proj_size(mut self, proj_size: usize) -> Self {
123 let device = self.get_device();
124 self.proj = Some(
125 LinearConfig::new(self.hidden_size, proj_size)
126 .with_bias(true)
127 .init(&device),
128 );
129 self.proj_size = Some(proj_size);
130 self.output_size = proj_size;
131 self
132 }
133
134 pub fn with_backbone(self, _units: usize, _layers: usize, _dropout: f64) -> Self {
136 self
137 }
138
139 fn get_device(&self) -> B::Device {
141 B::Device::default()
142 }
143
144 pub fn input_size(&self) -> usize {
146 self.input_size
147 }
148
149 pub fn hidden_size(&self) -> usize {
151 self.hidden_size
152 }
153
154 pub fn output_size(&self) -> usize {
156 self.output_size
157 }
158
159 pub fn forward(
173 &self,
174 input: Tensor<B, 3>,
175 state: Option<Tensor<B, 2>>,
176 _timespans: Option<Tensor<B, 2>>,
177 ) -> (Tensor<B, 3>, Tensor<B, 2>) {
178 let device = input.device();
179
180 let (batch_size, seq_len, _) = if self.batch_first {
182 let dims = input.dims();
183 (dims[0], dims[1], dims[2])
184 } else {
185 let dims = input.dims();
186 (dims[1], dims[0], dims[2])
187 };
188
189 let mut current_state =
191 state.unwrap_or_else(|| Tensor::<B, 2>::zeros([batch_size, self.hidden_size], &device));
192
193 let mut outputs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
195
196 for t in 0..seq_len {
197 let step_input = if self.batch_first {
199 input.clone().narrow(1, t, 1).squeeze(1)
201 } else {
202 input.clone().narrow(0, t, 1).squeeze(0)
204 };
205
206 let (mut output, new_state) = self.cell.forward(step_input, current_state, 1.0);
208 current_state = new_state;
209
210 if let Some(ref proj) = self.proj {
212 output = proj.forward(output);
213 }
214
215 if self.return_sequences {
216 outputs.push(output);
217 } else if t == seq_len - 1 {
218 outputs.push(output);
220 }
221 }
222
223 let output = Tensor::stack(outputs, 1); (output, current_state)
226 }
227}
228
#[cfg(test)]
mod tests {
    //! Shape-level tests for the sequence `CfC` layer on the NdArray backend.

    use super::*;
    use crate::wirings::AutoNCP;
    use burn::backend::NdArray;
    use burn::tensor::backend::Backend as BurnBackend;

    type TestBackend = NdArray<f32>;
    type TestDevice = <TestBackend as BurnBackend>::Device;

    fn get_test_device() -> TestDevice {
        Default::default()
    }

    #[test]
    fn test_cfc_rnn_creation() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device);

        assert_eq!(cfc.input_size(), 20);
        assert_eq!(cfc.hidden_size(), 50);
        assert_eq!(cfc.output_size(), 50);
    }

    #[test]
    fn test_cfc_rnn_with_wiring() {
        let device = get_test_device();
        let wiring = AutoNCP::new(32, 8, 0.5, 22222);
        let cfc = CfC::<TestBackend>::with_wiring(20, wiring, &device);

        assert_eq!(cfc.output_size(), 8);
    }

    #[test]
    fn test_cfc_rnn_with_wiring_forward() {
        let device = get_test_device();
        let wiring = AutoNCP::new(32, 8, 0.5, 22222);
        let cfc = CfC::<TestBackend>::with_wiring(20, wiring, &device);

        let input = Tensor::<TestBackend, 3>::zeros([2, 5, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        // Outputs are projected to the motor size; the state keeps the full width.
        assert_eq!(output.dims(), [2, 5, 8]);
        assert_eq!(state.dims(), [2, cfc.hidden_size()]);
    }

    #[test]
    fn test_cfc_rnn_forward() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_rnn_with_projection() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_proj_size(10);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, _) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 10, 10]);
        assert_eq!(cfc.output_size(), 10);
    }

    #[test]
    fn test_cfc_rnn_projection_last_only() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device)
            .with_proj_size(10)
            .with_return_sequences(false);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        // Only the final step is returned, and it is projected; the state is not.
        assert_eq!(output.dims(), [4, 1, 10]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_rnn_backbone_config() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_backbone(128, 2, 0.1);

        let input = Tensor::<TestBackend, 3>::zeros([2, 5, 20], &device);
        let (output, _) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [2, 5, 50]);
    }

    #[test]
    fn test_cfc_rnn_return_last_only() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_return_sequences(false);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 1, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_rnn_seq_first() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device).with_batch_first(false);

        // Seq-first input: [seq, batch, features].
        let input = Tensor::<TestBackend, 3>::zeros([10, 4, 20], &device);
        let (output, state) = cfc.forward(input, None, None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }

    #[test]
    fn test_cfc_rnn_with_initial_state() {
        let device = get_test_device();
        let cfc = CfC::<TestBackend>::new(20, 50, &device);

        let input = Tensor::<TestBackend, 3>::zeros([4, 10, 20], &device);
        let initial_state = Tensor::<TestBackend, 2>::ones([4, 50], &device);

        let (output, state) = cfc.forward(input, Some(initial_state), None);

        assert_eq!(output.dims(), [4, 10, 50]);
        assert_eq!(state.dims(), [4, 50]);
    }
}