candle_mi/transformer/recurrent.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Recurrent feedback types for anacrousis experiments.
4//!
5//! Anacrousis (ἀνάκρουσις — "the upbeat before the first full bar") re-runs
6//! a subset of transformer layers ("commitment layers") with optional feedback
7//! injection, giving the model extra depth to sustain planning signals through
8//! generation.
9//!
10//! Two modes are supported:
11//!
12//! - **Prefill-only** (`sustained: false`): the recurrent multi-pass applies
13//! at every step (since candle-mi recomputes from scratch without KV cache),
14//! but feedback is injected only at the original prompt positions.
15//!
16//! - **Sustained** (`sustained: true`): in addition to the original feedback
17//! positions, feedback is also injected at the current last token at each
18//! autoregressive step — the transformer analog of the DRC's per-tick
19//! recurrence (Taufeeque et al., 2024).
20//!
21//! The `depth` parameter (default 2) controls how many times the recurrent
22//! layer block is executed. Higher depths give the model more iterations to
23//! propagate planning signals, at the cost of proportionally more compute.
24
25use candle_core::Tensor;
26
27use crate::error::{MIError, Result};
28
29// ---------------------------------------------------------------------------
30// RecurrentFeedbackEntry
31// ---------------------------------------------------------------------------
32
33/// A single feedback injection between recurrent passes.
34#[derive(Debug, Clone)]
35pub struct RecurrentFeedbackEntry {
36 /// Token position in the sequence to inject feedback at.
37 pub position: usize,
38 /// Feedback direction vector.
39 ///
40 /// # Shapes
41 /// - `[d_model]`
42 pub vector: Tensor,
43 /// Amplification strength.
44 pub strength: f32,
45}
46
47// ---------------------------------------------------------------------------
48// RecurrentPassSpec
49// ---------------------------------------------------------------------------
50
51/// Specification for a recurrent multi-pass forward through a layer block.
52///
53/// The recurrence re-runs layers `loop_start..=loop_end` a total of `depth`
54/// times, with optional feedback injected into the hidden state between
55/// passes.
56///
57/// # Without feedback
58///
59/// Each subsequent pass receives the previous pass's output (true
60/// recurrence — extra depth).
61///
62/// # With feedback
63///
64/// Each subsequent pass resets to the saved pre-loop hidden state plus
65/// feedback vectors: `hidden[position] += strength * vector`.
66/// This means every recurrent pass sees the same clean input with the
67/// nudge applied — the layers process `H₀ + nudge` rather than
68/// iterating on their own output, which would cause degeneration.
69#[derive(Debug, Clone)]
70pub struct RecurrentPassSpec {
71 /// First layer of the recurrent block (inclusive).
72 pub loop_start: usize,
73 /// Last layer of the recurrent block (inclusive).
74 pub loop_end: usize,
75 /// Feedback vectors to inject between passes.
76 ///
77 /// When present, each recurrent pass resets to the saved pre-loop state
78 /// and injects these vectors before re-running the loop layers.
79 /// If empty, each pass receives the previous pass's output unmodified
80 /// (pure depth increase).
81 pub feedback: Vec<RecurrentFeedbackEntry>,
82 /// If true, also inject feedback at the current last token position
83 /// during each autoregressive generation step (sustained recurrence).
84 ///
85 /// If false, feedback is only injected at the original prompt positions
86 /// (prefill-only recurrence).
87 pub sustained: bool,
88 /// Number of times to run the recurrent layer block.
89 ///
90 /// Must be at least 1 (a single pass, no recurrence). The default is 2
91 /// (one initial pass plus one recurrent pass with feedback injection).
92 pub depth: usize,
93}
94
95impl RecurrentPassSpec {
96 /// Create a spec with no feedback (pure double-pass, depth 2).
97 #[must_use]
98 pub const fn no_feedback(loop_start: usize, loop_end: usize) -> Self {
99 Self {
100 loop_start,
101 loop_end,
102 feedback: Vec::new(),
103 sustained: false,
104 depth: 2,
105 }
106 }
107
108 /// Set the sustained flag (builder pattern).
109 #[must_use]
110 pub const fn with_sustained(mut self, sustained: bool) -> Self {
111 self.sustained = sustained;
112 self
113 }
114
115 /// Set the recurrence depth (builder pattern).
116 ///
117 /// `depth` is the total number of times the recurrent layer block is
118 /// executed. The default is 2 (one initial pass plus one recurrent
119 /// pass). A depth of 1 means no recurrence (single pass).
120 #[must_use]
121 pub const fn with_depth(mut self, depth: usize) -> Self {
122 self.depth = depth;
123 self
124 }
125
126 /// Add a feedback entry.
127 pub fn add_feedback(&mut self, position: usize, vector: Tensor, strength: f32) {
128 self.feedback.push(RecurrentFeedbackEntry {
129 position,
130 vector,
131 strength,
132 });
133 }
134
135 /// Validate the spec against model dimensions.
136 ///
137 /// # Errors
138 ///
139 /// Returns [`MIError::Intervention`] if the layer range is invalid,
140 /// feedback positions exceed sequence length, or feedback vectors
141 /// have the wrong dimension.
142 pub fn validate(&self, n_layers: usize, seq_len: usize, d_model: usize) -> Result<()> {
143 if self.depth == 0 {
144 return Err(MIError::Intervention("depth must be >= 1 (got 0)".into()));
145 }
146 if self.loop_start > self.loop_end {
147 return Err(MIError::Intervention(format!(
148 "loop_start ({}) > loop_end ({})",
149 self.loop_start, self.loop_end
150 )));
151 }
152 if self.loop_end >= n_layers {
153 return Err(MIError::Intervention(format!(
154 "loop_end ({}) >= n_layers ({})",
155 self.loop_end, n_layers
156 )));
157 }
158 for entry in &self.feedback {
159 if entry.position >= seq_len {
160 return Err(MIError::Intervention(format!(
161 "feedback position {} >= seq_len {}",
162 entry.position, seq_len
163 )));
164 }
165 let vec_dim = entry.vector.dim(0).map_err(|e| {
166 MIError::Intervention(format!("feedback vector dimension error: {e}"))
167 })?;
168 if vec_dim != d_model {
169 return Err(MIError::Intervention(format!(
170 "feedback vector dim {vec_dim} != d_model {d_model}"
171 )));
172 }
173 }
174 Ok(())
175 }
176}
177
178// ---------------------------------------------------------------------------
179// Tests
180// ---------------------------------------------------------------------------
181
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Model dimensions shared by every `validate` call in this module.
    const N_LAYERS: usize = 16;
    const SEQ_LEN: usize = 10;
    const D_MODEL: usize = 2048;

    /// Baseline spec over layers 14..=15 with no feedback.
    fn base_spec() -> RecurrentPassSpec {
        RecurrentPassSpec::no_feedback(14, 15)
    }

    /// All-zero `f32` CPU tensor of shape `[len]`.
    fn zero_vector(len: usize) -> Tensor {
        Tensor::zeros(len, DType::F32, &Device::Cpu).unwrap()
    }

    /// Validate `spec` against the shared dimensions, require a failure,
    /// and return the rendered error message.
    fn rejection_message(spec: &RecurrentPassSpec) -> String {
        let outcome = spec.validate(N_LAYERS, SEQ_LEN, D_MODEL);
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    #[test]
    fn no_feedback_builder() {
        let built = base_spec();
        assert_eq!(built.loop_start, 14);
        assert_eq!(built.loop_end, 15);
        assert!(built.feedback.is_empty());
        assert!(!built.sustained);
    }

    #[test]
    fn with_sustained_builder() {
        let built = base_spec().with_sustained(true);
        assert!(built.sustained);
    }

    #[test]
    fn add_feedback_entry() {
        let mut built = base_spec();
        built.add_feedback(5, zero_vector(D_MODEL), 2.0);
        assert_eq!(built.feedback.len(), 1);
        let first = &built.feedback[0];
        assert_eq!(first.position, 5);
        assert!((first.strength - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn validate_good_spec() {
        let outcome = base_spec().validate(N_LAYERS, SEQ_LEN, D_MODEL);
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_start_gt_end() {
        let reversed = RecurrentPassSpec::no_feedback(15, 14);
        assert!(rejection_message(&reversed).contains("loop_start"));
    }

    #[test]
    fn validate_end_out_of_range() {
        let too_far = RecurrentPassSpec::no_feedback(14, 16);
        assert!(rejection_message(&too_far).contains("loop_end"));
    }

    #[test]
    fn validate_feedback_position_out_of_range() {
        let mut built = base_spec();
        built.add_feedback(20, zero_vector(D_MODEL), 1.0);
        assert!(rejection_message(&built).contains("position"));
    }

    #[test]
    fn default_depth_is_two() {
        assert_eq!(base_spec().depth, 2);
    }

    #[test]
    fn with_depth_builder() {
        let built = base_spec().with_depth(4);
        assert_eq!(built.depth, 4);
    }

    #[test]
    fn validate_depth_zero() {
        let built = base_spec().with_depth(0);
        assert!(rejection_message(&built).contains("depth"));
    }

    #[test]
    fn validate_depth_one() {
        let built = base_spec().with_depth(1);
        assert!(built.validate(N_LAYERS, SEQ_LEN, D_MODEL).is_ok());
    }

    #[test]
    fn validate_feedback_wrong_dim() {
        let mut built = base_spec();
        built.add_feedback(5, zero_vector(1024), 1.0);
        assert!(rejection_message(&built).contains("d_model"));
    }
}