// ncps/rnn/mod.rs
1//! # RNN Layers for Sequence Processing
2//!
3//! This module provides complete RNN layers that handle sequence processing,
4//! batching, and hidden state management. **These are the primary APIs most users should use.**
5//!
6//! ## Available Layers
7//!
8//! | Layer | Description | Speed | Biological Accuracy |
9//! |-------|-------------|-------|---------------------|
10//! | [`CfC`] | Closed-form Continuous-time RNN | ⚡ Fast | Medium |
11//! | [`LTC`] | Liquid Time-Constant RNN | 🐢 Slower | High |
12//!
13//! ## Quick Start
14//!
15//! ```ignore
16//! use ncps::prelude::*;
17//! use burn::tensor::Tensor;
18//!
19//! // Create CfC layer with wiring
20//! let mut wiring = AutoNCP::new(32, 8, 0.5, 42);
21//! wiring.build(16);
22//!
23//! let cfc = CfC::<Backend>::with_wiring(16, wiring, &device);
24//!
25//! // Process sequence: [batch=4, seq_len=10, features=16]
26//! let input: Tensor<Backend, 3> = Tensor::zeros([4, 10, 16], &device);
27//! let (output, final_state) = cfc.forward(input, None, None);
28//!
29//! // output: [4, 10, 8] - sequence of outputs
30//! // final_state: [4, 32] - final hidden state
31//! ```
32//!
33//! ## Tensor Shapes
34//!
35//! ### Input Tensor (3D)
36//!
37//! | Format | Shape | Default |
38//! |--------|-------|---------|
39//! | Batch-first | `[batch, seq_len, features]` | ✓ Yes |
40//! | Sequence-first | `[seq_len, batch, features]` | No |
41//!
42//! Use `.with_batch_first(false)` to switch to sequence-first format.
43//!
44//! ### Output Tensor
45//!
46//! | Setting | Shape | Description |
47//! |---------|-------|-------------|
48//! | `return_sequences=true` (default) | `[batch, seq_len, output_size]` | All timesteps |
49//! | `return_sequences=false` | `[batch, 1, output_size]` | Last timestep only |
50//!
51//! ### Hidden State Tensor (2D)
52//!
53//! Shape: `[batch, hidden_size]`
54//!
55//! - `hidden_size` = `wiring.units()` (total neurons)
56//! - Can be passed to preserve state across batches
57//!
58//! ## Common Patterns
59//!
60//! ### Sequence Classification (return last output only)
61//!
62//! ```ignore
63//! let cfc = CfC::<Backend>::new(input_size, hidden_size, &device)
64//!     .with_return_sequences(false);
65//!
66//! let (output, _) = cfc.forward(input, None, None);
67//! // output: [batch, 1, hidden_size] - just the final output
68//! ```
69//!
70//! ### Sequence-to-Sequence (return all outputs)
71//!
72//! ```ignore
73//! let cfc = CfC::<Backend>::new(input_size, hidden_size, &device)
74//!     .with_return_sequences(true);  // default
75//!
76//! let (output, _) = cfc.forward(input, None, None);
77//! // output: [batch, seq_len, hidden_size] - output at every timestep
78//! ```
79//!
80//! ### Stateful Processing (preserve hidden state)
81//!
82//! ```ignore
83//! let cfc = CfC::<Backend>::new(input_size, hidden_size, &device);
84//!
85//! let (output1, state) = cfc.forward(batch1, None, None);
86//! let (output2, state) = cfc.forward(batch2, Some(state), None);
87//! let (output3, state) = cfc.forward(batch3, Some(state), None);
88//! // State persists across batches
89//! ```
90//!
91//! ### With NCP Wiring (sparse, interpretable)
92//!
93//! ```ignore
94//! let mut wiring = AutoNCP::new(64, 10, 0.5, 42);
95//! wiring.build(input_size);
96//!
97//! let cfc = CfC::<Backend>::with_wiring(input_size, wiring, &device);
98//!
99//! let (output, _) = cfc.forward(input, None, None);
100//! // output: [batch, seq_len, 10] - projected to motor neurons
101//! ```
102//!
103//! ## CfC vs LTC: When to Use Each
104//!
105//! ### Use CfC (Recommended) When:
106//! - Speed is important
107//! - Training large models
108//! - Production deployment
109//! - You don't need exact ODE solutions
110//!
111//! ### Use LTC When:
112//! - Biological accuracy matters
113//! - Research applications
114//! - Comparing with neuroscience models
115//! - You need variable time constants
116//!
117//! ## Mixed Memory (LSTM Augmentation)
118//!
119//! [`LTC`] supports "mixed memory" which augments the LTC cell with an LSTM
120//! for improved long-term dependency handling:
121//!
122//! ```ignore
123//! let ltc = LTC::<Backend>::new(input_size, wiring, &device)
124//!     .with_mixed_memory(true, &device);
125//!
126//! // Use forward_mixed() instead of forward()
127//! let (output, ltc_state, lstm_state) = ltc.forward_mixed(input, None, None, None);
128//! ```
129
130pub mod cfc;
131pub mod ltc;
132
133pub use cfc::CfC;
134pub use ltc::LTC;