//! # RNN Layers for Sequence Processing
//!
//! This module provides complete RNN layers that handle sequence processing,
//! batching, and hidden state management. **These are the primary APIs most users should use.**
//!
//! ## Available Layers
//!
//! | Layer | Description | Speed | Biological Accuracy |
//! |-------|-------------|-------|---------------------|
//! | [`CfC`] | Closed-form Continuous-time RNN | ⚡ Fast | Medium |
//! | [`LTC`] | Liquid Time-Constant RNN | 🐢 Slower | High |
//!
//! ## Quick Start
//!
//! ```ignore
//! use ncps::prelude::*;
//! use burn::tensor::Tensor;
//!
//! // Create CfC layer with wiring
//! let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
//! wiring.build(16);
//!
//! let cfc = CfC::<Backend>::with_wiring(16, wiring, &device);
//!
//! // Process sequence: [batch=4, seq_len=10, features=16]
//! let input: Tensor<Backend, 3> = Tensor::zeros([4, 10, 16], &device);
//! let (output, final_state) = cfc.forward(input, None, None);
//!
//! // output: [4, 10, 8] - sequence of outputs
//! // final_state: [4, 32] - final hidden state
//! ```
//!
//! ## Tensor Shapes
//!
//! ### Input Tensor (3D)
//!
//! | Format | Shape | Default |
//! |--------|-------|---------|
//! | Batch-first | `[batch, seq_len, features]` | ✓ Yes |
//! | Sequence-first | `[seq_len, batch, features]` | No |
//!
//! Use `.with_batch_first(false)` to switch to sequence-first format.
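//!
//! For example, a sketch of sequence-first processing (mirroring the Quick Start
//! API above; shapes in the comments follow the table):
//!
//! ```ignore
//! let cfc = CfC::<Backend>::new(16, 32, &device)
//!     .with_batch_first(false);
//!
//! // Input is now [seq_len=10, batch=4, features=16]
//! let input: Tensor<Backend, 3> = Tensor::zeros([10, 4, 16], &device);
//! let (output, state) = cfc.forward(input, None, None);
//! ```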
//!
//! ### Output Tensor
//!
//! | Setting | Shape | Description |
//! |---------|-------|-------------|
//! | `return_sequences=true` (default) | `[batch, seq_len, output_size]` | All timesteps |
//! | `return_sequences=false` | `[batch, 1, output_size]` | Last timestep only |
//!
//! ### Hidden State Tensor (2D)
//!
//! Shape: `[batch, hidden_size]`
//!
//! - `hidden_size` = `wiring.units()` (total neurons)
//! - Can be passed to preserve state across batches
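//!
//! For example, a sketch of seeding an explicit initial state (assuming, as is
//! typical for RNN layers, that passing `None` is equivalent to a zero state):
//!
//! ```ignore
//! // Explicit initial hidden state: [batch=4, hidden_size=32]
//! let h0: Tensor<Backend, 2> = Tensor::zeros([4, 32], &device);
//! let (output, h_final) = cfc.forward(input, Some(h0), None);
//! ```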
//!
//! ## Common Patterns
//!
//! ### Sequence Classification (return last output only)
//!
//! ```ignore
//! let cfc = CfC::<Backend>::new(input_size, hidden_size, &device)
//!     .with_return_sequences(false);
//!
//! let (output, _) = cfc.forward(input, None, None);
//! // output: [batch, 1, hidden_size] - just the final output
//! ```
//!
//! ### Sequence-to-Sequence (return all outputs)
//!
//! ```ignore
//! let cfc = CfC::<Backend>::new(input_size, hidden_size, &device)
//!     .with_return_sequences(true); // default
//!
//! let (output, _) = cfc.forward(input, None, None);
//! // output: [batch, seq_len, hidden_size] - output at every timestep
//! ```
//!
//! ### Stateful Processing (preserve hidden state)
//!
//! ```ignore
//! let cfc = CfC::<Backend>::new(input_size, hidden_size, &device);
//!
//! let (output1, state) = cfc.forward(batch1, None, None);
//! let (output2, state) = cfc.forward(batch2, Some(state), None);
//! let (output3, state) = cfc.forward(batch3, Some(state), None);
//! // State persists across batches
//! ```
//!
//! ### With NCP Wiring (sparse, interpretable)
//!
//! ```ignore
//! let mut wiring = AutoNCP::new(64, 10, 0.5, 42);
//! wiring.build(input_size);
//!
//! let cfc = CfC::<Backend>::with_wiring(input_size, wiring, &device);
//!
//! let (output, _) = cfc.forward(input, None, None);
//! // output: [batch, seq_len, 10] - projected to motor neurons
//! ```
//!
//! ## CfC vs LTC: When to Use Each
//!
//! ### Use CfC (Recommended) When:
//! - Speed is important
//! - Training large models
//! - Production deployment
//! - You don't need exact ODE solutions
//!
//! ### Use LTC When:
//! - Biological accuracy matters
//! - Research applications
//! - Comparing with neuroscience models
//! - You need variable time constants
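//!
//! The two layers accept the same wiring, so swapping one for the other is a
//! small change. A sketch (assuming `LTC::forward` mirrors `CfC::forward`, and
//! using the `LTC::new` signature shown in the Mixed Memory example below):
//!
//! ```ignore
//! let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
//! wiring.build(16);
//!
//! // Same wiring, slower but more biologically faithful dynamics
//! let ltc = LTC::<Backend>::new(16, wiring, &device);
//! let (output, state) = ltc.forward(input, None, None);
//! ```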
//!
//! ## Mixed Memory (LSTM Augmentation)
//!
//! [`LTC`] supports "mixed memory" which augments the LTC cell with an LSTM
//! for improved long-term dependency handling:
//!
//! ```ignore
//! let ltc = LTC::<Backend>::new(input_size, wiring, &device)
//!     .with_mixed_memory(true, &device);
//!
//! // Use forward_mixed() instead of forward()
//! let (output, ltc_state, lstm_state) = ltc.forward_mixed(input, None, None, None);
//! ```
// Re-export the layer types. The submodule names `cfc` and `ltc` are assumed;
// adjust the paths to wherever `CfC` and `LTC` are defined in this crate.
mod cfc;
mod ltc;

pub use cfc::CfC;
pub use ltc::LTC;