ncps/cells/mod.rs
//! # RNN Cell Implementations
//!
//! This module provides the single-timestep RNN cells used by Neural Circuit Policies.
//! Each call to a cell advances the state by one timestep; the higher-level RNN layers
//! in [`crate::rnn`] wrap these cells for sequence processing.
//!
//! ## Cell Types
//!
//! | Cell | Description | Use Case |
//! |------|-------------|----------|
//! | [`CfCCell`] | Closed-form Continuous-time | Fast and efficient; **recommended** default |
//! | [`LTCCell`] | Liquid Time-Constant | Biologically inspired; slower (numerical ODE solver) |
//! | [`WiredCfCCell`] | CfC with multi-layer wiring | Structured, sparsely wired architectures |
//! | [`LSTMCell`] | Standard LSTM | Memory augmentation for mixed-memory models |
//!
//! ## When to Use Cells Directly
//!
//! Most users should use the higher-level [`CfC`](crate::rnn::CfC) or [`LTC`](crate::rnn::LTC)
//! layers, which handle sequence processing automatically. Use cells directly when you need:
//!
//! - Custom sequence processing logic
//! - Integration with other frameworks
//! - Fine-grained control over state management
//!
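//! ## CfC Operating Modes
//!
//! The [`CfCCell`] supports three operating modes via [`CfcMode`]. In the formulas below,
//! `ff1` and `ff2` are learned feedforward heads computed from the input and the previous
//! hidden state, `t` is the elapsed time for the step, and `σ(t)` is a learned sigmoid time
//! gate whose value depends on `t`.
//!
//! ### Default Mode (Recommended)
//! ```text
//! h = tanh(ff1) × (1 - σ(t)) + tanh(ff2) × σ(t)
//! ```
//! Gated interpolation between the two feedforward heads. Usually the best balance of
//! expressiveness and stability.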
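//!
//! The default-mode update can be sketched per neuron in plain Rust, on `f32` slices for a
//! single sample rather than on Burn tensors. This is a minimal illustration, not this
//! crate's implementation: it assumes the heads `ff1`, `ff2` and the gate pre-activation
//! (the value fed to the sigmoid) are already computed, and the helpers `sigmoid` and
//! `cfc_default` are hypothetical.
//!
//! ```rust
//! /// Logistic sigmoid: 1 / (1 + e^(-x)).
//! fn sigmoid(x: f32) -> f32 {
//!     1.0 / (1.0 + (-x).exp())
//! }
//!
//! /// Hypothetical helper: default-mode CfC update for one sample.
//! /// `gate` holds the pre-sigmoid time-gate values, one per neuron.
//! fn cfc_default(ff1: &[f32], ff2: &[f32], gate: &[f32]) -> Vec<f32> {
//!     ff1.iter()
//!         .zip(ff2)
//!         .zip(gate)
//!         .map(|((&f1, &f2), &g)| {
//!             let s = sigmoid(g);
//!             // Convex interpolation between the two tanh-squashed heads.
//!             f1.tanh() * (1.0 - s) + f2.tanh() * s
//!         })
//!         .collect()
//! }
//! ```
//!
//! A gate pre-activation of `0.0` gives `σ = 0.5`, so the result is the mean of the two
//! `tanh` heads; as `σ` approaches `1.0`, the result approaches `tanh(ff2)`.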
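//!
//! ### Pure Mode
//! ```text
//! h = a - a × exp(-t × (|w_τ| + |ff1|)) × ff1
//! ```
//! Closed-form approximation of the liquid time-constant ODE solution, with no gating.
//! Here `a` and `w_τ` are learned parameters and `ff1` is used without `tanh`. Closer to
//! the biological model, but can be less stable on some tasks.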
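//!
//! A per-neuron sketch of the pure-mode formula, on plain `f32` slices for a single sample.
//! `cfc_pure` is a hypothetical helper, not this crate's implementation; `a` and `w_tau`
//! stand in for the learned parameters `a` and `w_τ`, and `ff1` is taken as the raw head
//! output (no `tanh`), as the formula is written.
//!
//! ```rust
//! /// Hypothetical helper: pure-mode CfC update for one sample at elapsed time `t`.
//! fn cfc_pure(a: &[f32], w_tau: &[f32], ff1: &[f32], t: f32) -> Vec<f32> {
//!     a.iter()
//!         .zip(w_tau)
//!         .zip(ff1)
//!         .map(|((&a_i, &w_i), &f)| {
//!             // The rate |w_τ| + |ff1| is non-negative, so for t >= 0 the factor is at most 1.
//!             let decay = (-t * (w_i.abs() + f.abs())).exp();
//!             a_i - a_i * decay * f
//!         })
//!         .collect()
//! }
//! ```
//!
//! At `t = 0` the decay factor is `1` and `h = a - a × ff1`; as `t` grows (with
//! `|w_τ| + |ff1| > 0`), the exponential vanishes and `h` relaxes toward `a`.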
//!
//! ### NoGate Mode
//! ```text
//! h = tanh(ff1) + tanh(ff2) × σ(t)
//! ```
//! Simplified mode that adds the heads instead of interpolating between them:
//! `tanh(ff1)` always contributes in full and the gate scales only `tanh(ff2)`.
//! Useful for tasks where gating adds unnecessary complexity.
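//!
//! A minimal per-neuron sketch of the no-gate update on plain `f32` slices for a single
//! sample. It is illustrative only: `sigmoid` and `cfc_no_gate` are hypothetical helpers,
//! and `gate` is assumed to hold the pre-sigmoid time-gate values.
//!
//! ```rust
//! /// Logistic sigmoid: 1 / (1 + e^(-x)).
//! fn sigmoid(x: f32) -> f32 {
//!     1.0 / (1.0 + (-x).exp())
//! }
//!
//! /// Hypothetical helper: no-gate CfC update for one sample.
//! fn cfc_no_gate(ff1: &[f32], ff2: &[f32], gate: &[f32]) -> Vec<f32> {
//!     ff1.iter()
//!         .zip(ff2)
//!         .zip(gate)
//!         // tanh(ff1) is always added in full; only tanh(ff2) is scaled by the gate.
//!         .map(|((&f1, &f2), &g)| f1.tanh() + f2.tanh() * sigmoid(g))
//!         .collect()
//! }
//! ```
//!
//! Because `tanh(ff1)` is never scaled down, each component lies in `(-2, 2)` rather than
//! the `(-1, 1)` range of the default-mode interpolation.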
//!
//! ## Tensor Shapes
//!
//! All cells expect 2D tensors for single-timestep processing:
//!
//! | Tensor | Shape | Description |
//! |--------|-------|-------------|
//! | `input` | `[batch, input_size]` | Input features |
//! | `hidden_state` | `[batch, hidden_size]` | Previous hidden state |
//! | `output` | `[batch, hidden_size]` | Cell output |
//! | `new_state` | `[batch, hidden_size]` | Updated hidden state |
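//!
//! The input-side shape contract in the table above can be written down as a small check.
//! This is a sketch on plain `[rows, cols]` arrays rather than Burn tensors, and
//! `check_step_shapes` is a hypothetical helper, not part of this crate; it verifies only
//! the `input` and `hidden_state` rows.
//!
//! ```rust
//! /// Hypothetical helper: validate single-step shapes, given as `[batch, features]`.
//! fn check_step_shapes(
//!     input: [usize; 2],
//!     hidden_state: [usize; 2],
//!     input_size: usize,
//!     hidden_size: usize,
//! ) -> Result<(), String> {
//!     let [batch, in_features] = input;
//!     if in_features != input_size {
//!         return Err(format!("input has {in_features} features, expected {input_size}"));
//!     }
//!     // The hidden state must share the input's batch dimension.
//!     if hidden_state != [batch, hidden_size] {
//!         return Err(format!(
//!             "hidden_state is {hidden_state:?}, expected [{batch}, {hidden_size}]"
//!         ));
//!     }
//!     Ok(())
//! }
//! ```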
//!
//! ## Example: Using CfCCell Directly
//!
//! ```ignore
//! use burn::tensor::{backend::Backend, Distribution, Tensor};
//! use ncps::cells::{CfCCell, CfcMode};
//!
//! fn single_step<B: Backend>(device: &B::Device) {
//!     let batch = 8;
//!     let cell = CfCCell::<B>::new(16, 32, device).with_mode(CfcMode::Default);
//!
//!     // Process a single timestep: random input, zero initial hidden state
//!     let input = Tensor::<B, 2>::random([batch, 16], Distribution::Default, device);
//!     let hidden = Tensor::<B, 2>::zeros([batch, 32], device);
//!
//!     // The last argument is the elapsed time for this step
//!     let (output, new_hidden) = cell.forward(input, hidden, 1.0);
//!     // output: [batch, 32]
//!     // new_hidden: [batch, 32]
//! }
//! ```
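//!
//! For custom sequence processing, the core pattern is to call a cell once per timestep and
//! thread the returned state into the next call. The sketch below shows that loop on plain
//! `f32` vectors for a single sample; the `StepCell` trait, `unroll`, and the toy `Leaky`
//! cell are hypothetical stand-ins for the Burn-based cells in this module, and the actual
//! [`crate::rnn`] layers may differ (for example, in how they batch and slice tensors).
//!
//! ```rust
//! /// Hypothetical stand-in for a single-timestep cell: returns `(output, new_state)`.
//! trait StepCell {
//!     fn step(&self, input: &[f32], state: &[f32], elapsed: f32) -> (Vec<f32>, Vec<f32>);
//! }
//!
//! /// Unroll a cell over a sequence, threading the hidden state between steps.
//! /// `elapsed[i]` is the time gap before step `i` (1.0 for regularly sampled data).
//! fn unroll<C: StepCell>(
//!     cell: &C,
//!     inputs: &[Vec<f32>],
//!     elapsed: &[f32],
//!     initial_state: Vec<f32>,
//! ) -> (Vec<Vec<f32>>, Vec<f32>) {
//!     assert_eq!(inputs.len(), elapsed.len(), "one time gap per step");
//!     let mut state = initial_state;
//!     let mut outputs = Vec::with_capacity(inputs.len());
//!     for (x, &dt) in inputs.iter().zip(elapsed) {
//!         let (out, next) = cell.step(x, &state, dt);
//!         outputs.push(out);
//!         state = next;
//!     }
//!     (outputs, state)
//! }
//!
//! /// Toy cell for illustration only: leaky update h' = h + dt * (x - h).
//! struct Leaky;
//!
//! impl StepCell for Leaky {
//!     fn step(&self, input: &[f32], state: &[f32], elapsed: f32) -> (Vec<f32>, Vec<f32>) {
//!         let next: Vec<f32> = input
//!             .iter()
//!             .zip(state)
//!             .map(|(&x, &h)| h + elapsed * (x - h))
//!             .collect();
//!         (next.clone(), next)
//!     }
//! }
//! ```
//!
//! Passing a per-step `elapsed` value, rather than a fixed step, is what lets continuous-time
//! cells consume irregularly sampled sequences.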
//!
//! ## Input/Output Mapping Modes
//!
//! [`LTCCell`] supports different input/output mapping strategies via [`MappingMode`]:
//!
//! - **Affine**: learned weight and bias (`w * x + b`); most expressive
//! - **Linear**: learned weight without bias (`w * x`)
//! - **None**: direct pass-through with no learned parameters (fastest)
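//!
//! The three strategies can be sketched on plain `f32` slices. This assumes a per-feature
//! (element-wise) weight `w` and bias `b`, as in the reference Python `ncps` package; the
//! `Mapping` enum and `map_features` helper are hypothetical mirrors of [`MappingMode`],
//! not its actual implementation.
//!
//! ```rust
//! /// Hypothetical mirror of the three mapping strategies.
//! enum Mapping {
//!     Affine,
//!     Linear,
//!     None,
//! }
//!
//! /// Apply a mapping element-wise: `w` and `b` hold one weight and one bias per feature.
//! fn map_features(mode: Mapping, x: &[f32], w: &[f32], b: &[f32]) -> Vec<f32> {
//!     match mode {
//!         Mapping::Affine => x.iter().zip(w).zip(b).map(|((&xi, &wi), &bi)| xi * wi + bi).collect(),
//!         Mapping::Linear => x.iter().zip(w).map(|(&xi, &wi)| xi * wi).collect(),
//!         // Identity: parameters are ignored.
//!         Mapping::None => x.to_vec(),
//!     }
//! }
//! ```
//!
//! Under this assumption, `Affine` adds `2 × n` parameters for `n` features, `Linear` adds
//! `n`, and `None` adds none.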

pub mod cfc_cell;
pub mod lstm_cell;
pub mod ltc_cell;
pub mod wired_cfc_cell;

pub use cfc_cell::{CfCCell, CfcMode};
pub use lstm_cell::LSTMCell;
pub use ltc_cell::{LTCCell, MappingMode};
pub use wired_cfc_cell::WiredCfCCell;