use candle_core::Tensor;
use crate::error::{MIError, Result};
/// A single feedback entry: pairs a sequence position with a feedback tensor
/// and a strength value.
///
/// NOTE(review): how `vector` and `strength` are combined and applied is not
/// visible in this file — confirm against the code that consumes the spec.
#[derive(Debug, Clone)]
pub struct RecurrentFeedbackEntry {
/// Sequence position this entry refers to; `RecurrentPassSpec::validate`
/// rejects the spec unless it is `< seq_len`.
pub position: usize,
/// Feedback tensor; `RecurrentPassSpec::validate` requires its first
/// dimension to equal `d_model` (presumably a rank-1 vector — TODO confirm).
pub vector: Tensor,
/// Strength attached to `vector` (presumably a scaling factor — TODO confirm).
pub strength: f32,
}
/// Describes a recurrent pass: a loop range, a list of feedback entries, and
/// two tuning knobs (`sustained`, `depth`).
///
/// Use `validate` to check a spec against concrete model dimensions before
/// relying on it.
#[derive(Debug, Clone)]
pub struct RecurrentPassSpec {
/// Start of the loop range; `validate` requires `loop_start <= loop_end`.
pub loop_start: usize,
/// End of the loop range; `validate` requires `loop_end < n_layers`, so it
/// is a valid layer index (presumably inclusive — TODO confirm).
pub loop_end: usize,
/// Feedback entries; each one is checked by `validate` for position and
/// first-dimension size.
pub feedback: Vec<RecurrentFeedbackEntry>,
/// Flag set to `false` by `no_feedback`.
/// NOTE(review): its effect is defined by the consumer of this spec —
/// presumably whether feedback persists beyond the first use; confirm.
pub sustained: bool,
/// Depth setting; `validate` requires it to be at least 1 and
/// `no_feedback` sets it to 2 (presumably the number of passes — TODO confirm).
pub depth: usize,
}
impl RecurrentPassSpec {
    /// Creates a spec with the given loop bounds, an empty feedback list,
    /// `sustained` disabled and `depth` set to 2.
    #[must_use]
    pub const fn no_feedback(loop_start: usize, loop_end: usize) -> Self {
        Self {
            depth: 2,
            sustained: false,
            feedback: Vec::new(),
            loop_start,
            loop_end,
        }
    }

    /// Returns the spec with its `sustained` flag replaced by `sustained`.
    #[must_use]
    pub const fn with_sustained(mut self, sustained: bool) -> Self {
        self.sustained = sustained;
        self
    }

    /// Returns the spec with its `depth` replaced by `depth`.
    ///
    /// No check is performed here; a depth of 0 is only rejected later by
    /// [`Self::validate`].
    #[must_use]
    pub const fn with_depth(mut self, depth: usize) -> Self {
        self.depth = depth;
        self
    }

    /// Appends one feedback entry built from `position`, `vector` and
    /// `strength` to the end of the feedback list.
    pub fn add_feedback(&mut self, position: usize, vector: Tensor, strength: f32) {
        let entry = RecurrentFeedbackEntry {
            position,
            vector,
            strength,
        };
        self.feedback.push(entry);
    }

    /// Checks this spec against the given model dimensions.
    ///
    /// Checks run in a fixed order and the first failure is returned:
    /// depth, loop ordering, loop upper bound, then each feedback entry in
    /// insertion order (position first, vector size second).
    ///
    /// # Errors
    ///
    /// Returns [`MIError::Intervention`] when:
    /// - `depth` is 0,
    /// - `loop_start > loop_end`,
    /// - `loop_end >= n_layers`,
    /// - any feedback entry has `position >= seq_len`,
    /// - the first dimension of any feedback vector cannot be queried, or
    /// - the first dimension of any feedback vector differs from `d_model`.
    pub fn validate(&self, n_layers: usize, seq_len: usize, d_model: usize) -> Result<()> {
        // Every failure in this method is reported through the same variant.
        let fail = |message: String| -> Result<()> { Err(MIError::Intervention(message)) };

        if self.depth == 0 {
            return fail("depth must be >= 1 (got 0)".into());
        }
        if self.loop_end < self.loop_start {
            return fail(format!(
                "loop_start ({}) > loop_end ({})",
                self.loop_start, self.loop_end
            ));
        }
        if self.loop_end >= n_layers {
            return fail(format!(
                "loop_end ({}) >= n_layers ({})",
                self.loop_end, n_layers
            ));
        }

        // Stops at the first offending entry, just like an early `return`.
        self.feedback.iter().try_for_each(|entry| {
            if entry.position >= seq_len {
                return fail(format!(
                    "feedback position {} >= seq_len {}",
                    entry.position, seq_len
                ));
            }
            match entry.vector.dim(0) {
                Err(e) => fail(format!("feedback vector dimension error: {e}")),
                Ok(vec_dim) if vec_dim != d_model => fail(format!(
                    "feedback vector dim {vec_dim} != d_model {d_model}"
                )),
                Ok(_) => Ok(()),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Model dimensions passed to every `validate` call in this module.
    const N_LAYERS: usize = 16;
    const SEQ_LEN: usize = 10;
    const D_MODEL: usize = 2048;

    /// Spec used as the starting point of most tests: loop range 14 to 15,
    /// no feedback.
    fn base_spec() -> RecurrentPassSpec {
        RecurrentPassSpec::no_feedback(14, 15)
    }

    /// Builds a one-dimensional F32 zero tensor of length `len` on the CPU.
    fn zero_vector(len: usize) -> Tensor {
        Tensor::zeros(len, candle_core::DType::F32, &candle_core::Device::Cpu).unwrap()
    }

    /// Asserts that validating `spec` fails with a message mentioning `needle`.
    fn assert_rejected(spec: &RecurrentPassSpec, needle: &str) {
        let outcome = spec.validate(N_LAYERS, SEQ_LEN, D_MODEL);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains(needle));
    }

    #[test]
    fn no_feedback_builder() {
        let spec = base_spec();
        assert_eq!(spec.loop_start, 14);
        assert_eq!(spec.loop_end, 15);
        assert!(spec.feedback.is_empty());
        assert!(!spec.sustained);
    }

    #[test]
    fn with_sustained_builder() {
        let spec = base_spec().with_sustained(true);
        assert!(spec.sustained);
    }

    #[test]
    fn add_feedback_entry() {
        let mut spec = base_spec();
        spec.add_feedback(5, zero_vector(2048), 2.0);
        assert_eq!(spec.feedback.len(), 1);
        let entry = &spec.feedback[0];
        assert_eq!(entry.position, 5);
        assert!((entry.strength - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn validate_good_spec() {
        assert!(base_spec().validate(N_LAYERS, SEQ_LEN, D_MODEL).is_ok());
    }

    #[test]
    fn validate_start_gt_end() {
        let spec = RecurrentPassSpec::no_feedback(15, 14);
        assert_rejected(&spec, "loop_start");
    }

    #[test]
    fn validate_end_out_of_range() {
        let spec = RecurrentPassSpec::no_feedback(14, 16);
        assert_rejected(&spec, "loop_end");
    }

    #[test]
    fn validate_feedback_position_out_of_range() {
        let mut spec = base_spec();
        // Position 20 lies beyond the 10-token sequence.
        spec.add_feedback(20, zero_vector(2048), 1.0);
        assert_rejected(&spec, "position");
    }

    #[test]
    fn default_depth_is_two() {
        assert_eq!(base_spec().depth, 2);
    }

    #[test]
    fn with_depth_builder() {
        let spec = base_spec().with_depth(4);
        assert_eq!(spec.depth, 4);
    }

    #[test]
    fn validate_depth_zero() {
        let spec = base_spec().with_depth(0);
        assert_rejected(&spec, "depth");
    }

    #[test]
    fn validate_depth_one() {
        let spec = base_spec().with_depth(1);
        assert!(spec.validate(N_LAYERS, SEQ_LEN, D_MODEL).is_ok());
    }

    #[test]
    fn validate_feedback_wrong_dim() {
        let mut spec = base_spec();
        // 1024 elements where the model dimension is 2048.
        spec.add_feedback(5, zero_vector(1024), 1.0);
        assert_rejected(&spec, "d_model");
    }
}