Skip to main content

candle_mi/transformer/
recurrent.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Recurrent feedback types for anacrousis experiments.
4//!
5//! Anacrousis (ἀνάκρουσις — "the upbeat before the first full bar") re-runs
6//! a subset of transformer layers ("commitment layers") with optional feedback
7//! injection, giving the model extra depth to sustain planning signals through
8//! generation.
9//!
10//! Two modes are supported:
11//!
12//! - **Prefill-only** (`sustained: false`): the recurrent multi-pass applies
13//!   at every step (since candle-mi recomputes from scratch without KV cache),
14//!   but feedback is injected only at the original prompt positions.
15//!
16//! - **Sustained** (`sustained: true`): in addition to the original feedback
//!   positions, feedback is also injected at the current last token at each
//!   autoregressive step — the transformer analog of the per-tick recurrence
//!   in the Deep Repeating ConvLSTM (DRC) studied by Taufeeque et al. (2024).
20//!
21//! The `depth` parameter (default 2) controls how many times the recurrent
22//! layer block is executed. Higher depths give the model more iterations to
23//! propagate planning signals, at the cost of proportionally more compute.
24
25use candle_core::Tensor;
26
27use crate::error::{MIError, Result};
28
29// ---------------------------------------------------------------------------
30// RecurrentFeedbackEntry
31// ---------------------------------------------------------------------------
32
33/// A single feedback injection between recurrent passes.
34#[derive(Debug, Clone)]
35pub struct RecurrentFeedbackEntry {
36    /// Token position in the sequence to inject feedback at.
37    pub position: usize,
38    /// Feedback direction vector.
39    ///
40    /// # Shapes
41    /// - `[d_model]`
42    pub vector: Tensor,
43    /// Amplification strength.
44    pub strength: f32,
45}
46
47// ---------------------------------------------------------------------------
48// RecurrentPassSpec
49// ---------------------------------------------------------------------------
50
51/// Specification for a recurrent multi-pass forward through a layer block.
52///
53/// The recurrence re-runs layers `loop_start..=loop_end` a total of `depth`
54/// times, with optional feedback injected into the hidden state between
55/// passes.
56///
57/// # Without feedback
58///
59/// Each subsequent pass receives the previous pass's output (true
60/// recurrence — extra depth).
61///
62/// # With feedback
63///
64/// Each subsequent pass resets to the saved pre-loop hidden state plus
65/// feedback vectors: `hidden[position] += strength * vector`.
66/// This means every recurrent pass sees the same clean input with the
67/// nudge applied — the layers process `H₀ + nudge` rather than
68/// iterating on their own output, which would cause degeneration.
69#[derive(Debug, Clone)]
70pub struct RecurrentPassSpec {
71    /// First layer of the recurrent block (inclusive).
72    pub loop_start: usize,
73    /// Last layer of the recurrent block (inclusive).
74    pub loop_end: usize,
75    /// Feedback vectors to inject between passes.
76    ///
77    /// When present, each recurrent pass resets to the saved pre-loop state
78    /// and injects these vectors before re-running the loop layers.
79    /// If empty, each pass receives the previous pass's output unmodified
80    /// (pure depth increase).
81    pub feedback: Vec<RecurrentFeedbackEntry>,
82    /// If true, also inject feedback at the current last token position
83    /// during each autoregressive generation step (sustained recurrence).
84    ///
85    /// If false, feedback is only injected at the original prompt positions
86    /// (prefill-only recurrence).
87    pub sustained: bool,
88    /// Number of times to run the recurrent layer block.
89    ///
90    /// Must be at least 1 (a single pass, no recurrence). The default is 2
91    /// (one initial pass plus one recurrent pass with feedback injection).
92    pub depth: usize,
93}
94
95impl RecurrentPassSpec {
96    /// Create a spec with no feedback (pure double-pass, depth 2).
97    #[must_use]
98    pub const fn no_feedback(loop_start: usize, loop_end: usize) -> Self {
99        Self {
100            loop_start,
101            loop_end,
102            feedback: Vec::new(),
103            sustained: false,
104            depth: 2,
105        }
106    }
107
108    /// Set the sustained flag (builder pattern).
109    #[must_use]
110    pub const fn with_sustained(mut self, sustained: bool) -> Self {
111        self.sustained = sustained;
112        self
113    }
114
115    /// Set the recurrence depth (builder pattern).
116    ///
117    /// `depth` is the total number of times the recurrent layer block is
118    /// executed. The default is 2 (one initial pass plus one recurrent
119    /// pass). A depth of 1 means no recurrence (single pass).
120    #[must_use]
121    pub const fn with_depth(mut self, depth: usize) -> Self {
122        self.depth = depth;
123        self
124    }
125
126    /// Add a feedback entry.
127    pub fn add_feedback(&mut self, position: usize, vector: Tensor, strength: f32) {
128        self.feedback.push(RecurrentFeedbackEntry {
129            position,
130            vector,
131            strength,
132        });
133    }
134
135    /// Validate the spec against model dimensions.
136    ///
137    /// # Errors
138    ///
139    /// Returns [`MIError::Intervention`] if the layer range is invalid,
140    /// feedback positions exceed sequence length, or feedback vectors
141    /// have the wrong dimension.
142    pub fn validate(&self, n_layers: usize, seq_len: usize, d_model: usize) -> Result<()> {
143        if self.depth == 0 {
144            return Err(MIError::Intervention("depth must be >= 1 (got 0)".into()));
145        }
146        if self.loop_start > self.loop_end {
147            return Err(MIError::Intervention(format!(
148                "loop_start ({}) > loop_end ({})",
149                self.loop_start, self.loop_end
150            )));
151        }
152        if self.loop_end >= n_layers {
153            return Err(MIError::Intervention(format!(
154                "loop_end ({}) >= n_layers ({})",
155                self.loop_end, n_layers
156            )));
157        }
158        for entry in &self.feedback {
159            if entry.position >= seq_len {
160                return Err(MIError::Intervention(format!(
161                    "feedback position {} >= seq_len {}",
162                    entry.position, seq_len
163                )));
164            }
165            let vec_dim = entry.vector.dim(0).map_err(|e| {
166                MIError::Intervention(format!("feedback vector dimension error: {e}"))
167            })?;
168            if vec_dim != d_model {
169                return Err(MIError::Intervention(format!(
170                    "feedback vector dim {vec_dim} != d_model {d_model}"
171                )));
172            }
173        }
174        Ok(())
175    }
176}
177
178// ---------------------------------------------------------------------------
179// Tests
180// ---------------------------------------------------------------------------
181
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Zero-filled 1-D `f32` tensor of length `len`, allocated on the CPU.
    fn zero_vec(len: usize) -> Tensor {
        Tensor::zeros(len, DType::F32, &Device::Cpu).unwrap()
    }

    #[test]
    fn no_feedback_builder() {
        let spec = RecurrentPassSpec::no_feedback(14, 15);
        assert_eq!(spec.loop_start, 14);
        assert_eq!(spec.loop_end, 15);
        assert!(spec.feedback.is_empty());
        assert!(!spec.sustained);
    }

    #[test]
    fn with_sustained_builder() {
        let spec = RecurrentPassSpec::no_feedback(14, 15).with_sustained(true);
        assert!(spec.sustained);
    }

    #[test]
    fn add_feedback_entry() {
        let mut spec = RecurrentPassSpec::no_feedback(14, 15);
        spec.add_feedback(5, zero_vec(2048), 2.0);
        assert_eq!(spec.feedback.len(), 1);
        assert_eq!(spec.feedback[0].position, 5);
        assert!((spec.feedback[0].strength - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn validate_good_spec() {
        let spec = RecurrentPassSpec::no_feedback(14, 15);
        assert!(spec.validate(16, 10, 2048).is_ok());
    }

    #[test]
    fn validate_start_gt_end() {
        let spec = RecurrentPassSpec::no_feedback(15, 14);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("loop_start"));
    }

    #[test]
    fn validate_end_out_of_range() {
        let spec = RecurrentPassSpec::no_feedback(14, 16);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("loop_end"));
    }

    #[test]
    fn validate_zero_layers() {
        // With no layers at all, even a single-layer loop at index 0 is invalid.
        let spec = RecurrentPassSpec::no_feedback(0, 0);
        let err = spec.validate(0, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("n_layers"));
    }

    #[test]
    fn validate_single_layer_loop() {
        let spec = RecurrentPassSpec::no_feedback(0, 0);
        assert!(spec.validate(1, 1, 2048).is_ok());
    }

    #[test]
    fn validate_feedback_position_out_of_range() {
        let mut spec = RecurrentPassSpec::no_feedback(14, 15);
        spec.add_feedback(20, zero_vec(2048), 1.0);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("position"));
    }

    #[test]
    fn validate_feedback_at_last_position() {
        let mut spec = RecurrentPassSpec::no_feedback(14, 15);
        spec.add_feedback(9, zero_vec(2048), 1.0);
        assert!(spec.validate(16, 10, 2048).is_ok());
    }

    #[test]
    fn validate_feedback_position_equal_seq_len() {
        let mut spec = RecurrentPassSpec::no_feedback(14, 15);
        spec.add_feedback(10, zero_vec(2048), 1.0);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("position"));
    }

    #[test]
    fn default_depth_is_two() {
        let spec = RecurrentPassSpec::no_feedback(14, 15);
        assert_eq!(spec.depth, 2);
    }

    #[test]
    fn with_depth_builder() {
        let spec = RecurrentPassSpec::no_feedback(14, 15).with_depth(4);
        assert_eq!(spec.depth, 4);
    }

    #[test]
    fn validate_depth_zero() {
        let spec = RecurrentPassSpec::no_feedback(14, 15).with_depth(0);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("depth"));
    }

    #[test]
    fn validate_depth_one() {
        let spec = RecurrentPassSpec::no_feedback(14, 15).with_depth(1);
        assert!(spec.validate(16, 10, 2048).is_ok());
    }

    #[test]
    fn validate_feedback_wrong_dim() {
        let mut spec = RecurrentPassSpec::no_feedback(14, 15);
        spec.add_feedback(5, zero_vec(1024), 1.0);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("d_model"));
    }

    #[test]
    fn validate_scalar_feedback_vector() {
        // A rank-0 tensor has no first dimension and must be rejected, not panic.
        let mut spec = RecurrentPassSpec::no_feedback(14, 15);
        let scalar = Tensor::new(1.0f32, &Device::Cpu).unwrap();
        spec.add_feedback(5, scalar, 1.0);
        let err = spec.validate(16, 10, 2048);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("feedback vector"));
    }
}