use std::collections::VecDeque;
use async_trait::async_trait;
use super::case::EvaluationCase;
use super::trial::TrialResult;
/// Simulated long-horizon run used to exercise repeated-tool-call (loop) detection.
///
/// The simulation walks steps `1..=n_steps`. Before step `loop_starts_at` the agent
/// cycles through a fixed rotation of five tool names; from that step on it calls
/// `looping_tool` every time. Detection fires as soon as the most recent
/// `window_size` tool calls are all identical.
#[derive(Debug, Clone)]
pub struct LoopDetectionSimCase {
    /// Case name returned by the `EvaluationCase::name` implementation.
    name: String,
    /// Total number of simulated steps.
    pub n_steps: usize,
    /// Tool name called on every step once the loop has started.
    pub looping_tool: String,
    /// 1-based step at which the loop starts (`should_not_detect` uses `usize::MAX`,
    /// so the loop never starts in practice).
    pub loop_starts_at: usize,
    /// Number of most recent calls that must be identical for detection to fire.
    pub window_size: usize,
    /// Whether the case passes when detection fires (`true`) or when it stays silent (`false`).
    pub expect_detection: bool,
}

impl LoopDetectionSimCase {
    /// Builds a case in which `looping_tool` is repeated from step `loop_starts_at`
    /// onward and the detector is expected to fire within `n_steps` steps.
    pub fn should_detect(
        n_steps: usize,
        looping_tool: impl Into<String>,
        loop_starts_at: usize,
        window_size: usize,
    ) -> Self {
        Self {
            name: format!("loop_detection_window{window_size}_step{loop_starts_at}"),
            n_steps,
            looping_tool: looping_tool.into(),
            loop_starts_at,
            window_size,
            expect_detection: true,
        }
    }

    /// Builds a case whose tool sequence never loops, so the detector is expected to stay silent.
    ///
    /// Every step uses the fixed rotation. Consecutive rotation entries are distinct, so no
    /// window of two or more calls is uniform. Note that with `window_size == 1` a single
    /// call is trivially uniform and detection *does* fire.
    pub fn should_not_detect(n_steps: usize, window_size: usize) -> Self {
        Self {
            name: format!("loop_no_detection_window{window_size}_{n_steps}steps"),
            n_steps,
            looping_tool: "read_file".into(),
            loop_starts_at: usize::MAX,
            window_size,
            expect_detection: false,
        }
    }

    /// Runs the simulation and reports whether loop detection fired.
    ///
    /// Returns `false` when `window_size == 0` (a zero-width window is never "full") or
    /// when fewer than `window_size` steps are simulated (the window never fills).
    fn simulate(&self) -> bool {
        const ROTATION: [&str; 5] = ["read_file", "write_file", "search_code", "list_dir", "bash"];

        // Without this guard a zero-width window never pops (its length is only 0 before
        // the first push), so it would grow to `n_steps` entries and still never fire.
        if self.window_size == 0 {
            return false;
        }

        // Borrow tool names instead of cloning a `String` per step; only equality matters.
        let mut window: VecDeque<&str> = VecDeque::with_capacity(self.window_size);
        for step in 1..=self.n_steps {
            let tool = if step >= self.loop_starts_at {
                self.looping_tool.as_str()
            } else {
                ROTATION[(step - 1) % ROTATION.len()]
            };
            // Slide the window: drop the oldest call once it is at capacity.
            if window.len() == self.window_size {
                window.pop_front();
            }
            window.push_back(tool);
            if window.len() == self.window_size && window.iter().all(|&t| t == window[0]) {
                return true;
            }
        }
        false
    }
}
#[async_trait]
impl EvaluationCase for LoopDetectionSimCase {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn category(&self) -> &str {
        "stability/loop_detection"
    }

    /// Runs the loop-detection simulation once and grades it against `expect_detection`.
    ///
    /// Always returns `Ok`. A mismatch between the simulated outcome and the expectation
    /// is reported as a failed `TrialResult`; a match is a success carrying the
    /// `loop_detected`, `n_steps` and `window_size` metadata.
    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
        let timer = std::time::Instant::now();
        let detected = self.simulate();
        let elapsed_ms = timer.elapsed().as_millis() as u64;

        if detected != self.expect_detection {
            // The wording depends on which side of the expectation was violated.
            let reason = if !self.expect_detection {
                format!(
                    "Expected no loop detection but one fired at window={}",
                    self.window_size
                )
            } else {
                format!(
                    "Expected loop detection after {} steps (window={}) but none fired",
                    self.n_steps, self.window_size
                )
            };
            return Ok(TrialResult::failure(trial_id, elapsed_ms, reason));
        }

        Ok(TrialResult::success(trial_id, elapsed_ms)
            .with_meta("loop_detected", serde_json::json!(detected))
            .with_meta("n_steps", serde_json::json!(self.n_steps))
            .with_meta("window_size", serde_json::json!(self.window_size)))
    }
}
/// Simulated long-horizon run used to check periodic goal re-injection.
///
/// Over iterations `1..=n_iterations`, the goal is re-injected after every
/// `revalidation_interval` completed iterations, i.e. at iterations
/// `1 + k * revalidation_interval` for `k >= 1`. An interval of `0` disables re-injection.
#[derive(Debug, Clone)]
pub struct GoalPreservationCase {
    /// Case name returned by the `EvaluationCase::name` implementation.
    name: String,
    /// Number of simulated iterations.
    pub n_iterations: usize,
    /// Completed iterations between goal re-injections; `0` disables re-injection.
    pub revalidation_interval: usize,
    /// Goal statement carried with the case; the schedule helpers below do not read it.
    pub goal_text: String,
}

impl GoalPreservationCase {
    /// Creates a case with a fixed goal text and a name derived from both parameters.
    pub fn new(n_iterations: usize, revalidation_interval: usize) -> Self {
        let name = format!("goal_preservation_{n_iterations}iter_every{revalidation_interval}");
        let goal_text = String::from("Complete the long-horizon task reliably");
        Self {
            name,
            n_iterations,
            revalidation_interval,
            goal_text,
        }
    }

    /// Closed-form schedule: iterations `interval + 1`, `2 * interval + 1`, ... that do not
    /// exceed `n_iterations`. Empty when the interval is `0`.
    fn expected_injection_points(&self) -> Vec<usize> {
        let interval = self.revalidation_interval;
        if interval == 0 {
            return Vec::new();
        }
        // The first injection follows `interval` completed iterations; if that index is
        // not representable, no iteration can ever reach it.
        match interval.checked_add(1) {
            Some(first) => (first..=self.n_iterations).step_by(interval).collect(),
            None => Vec::new(),
        }
    }

    /// Step-by-step simulation: walks every iteration after the first and records those
    /// at which a whole number of intervals has been completed.
    fn simulate_injections(&self) -> Vec<usize> {
        let interval = self.revalidation_interval;
        let mut injections = Vec::new();
        if interval == 0 {
            return injections;
        }
        for iteration in 2..=self.n_iterations {
            let completed = iteration - 1;
            if completed % interval == 0 {
                injections.push(iteration);
            }
        }
        injections
    }
}
#[async_trait]
impl EvaluationCase for GoalPreservationCase {
    fn name(&self) -> &str {
        self.name.as_str()
    }

    fn category(&self) -> &str {
        "stability/goal_preservation"
    }

    /// Compares the simulated goal re-injection schedule with the expected one.
    ///
    /// Always returns `Ok`. The trial fails when a run of at least 15 iterations with
    /// revalidation enabled produces no injections at all, or when the simulated schedule
    /// differs from the expected schedule. Otherwise it succeeds with the `n_iterations`,
    /// `injections` (count) and `interval` metadata.
    async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
        let timer = std::time::Instant::now();
        let injected = self.simulate_injections();
        let expected = self.expected_injection_points();
        let elapsed_ms = timer.elapsed().as_millis() as u64;

        // A long horizon with revalidation enabled must re-inject the goal at least once,
        // even when the expected schedule itself is empty.
        let expected_min = 1usize;
        let long_horizon = self.n_iterations >= 15 && self.revalidation_interval > 0;
        if long_horizon && injected.len() < expected_min {
            let reason = format!(
                "Expected at least {} goal injection(s) across {} iterations \
                 (interval={}), got 0",
                expected_min, self.n_iterations, self.revalidation_interval
            );
            return Ok(TrialResult::failure(trial_id, elapsed_ms, reason));
        }

        if injected == expected {
            Ok(TrialResult::success(trial_id, elapsed_ms)
                .with_meta("n_iterations", serde_json::json!(self.n_iterations))
                .with_meta("injections", serde_json::json!(injected.len()))
                .with_meta("interval", serde_json::json!(self.revalidation_interval)))
        } else {
            let reason = format!(
                "Goal injection mismatch: expected at iterations {:?}, got {:?}",
                expected, injected
            );
            Ok(TrialResult::failure(trial_id, elapsed_ms, reason))
        }
    }
}
/// Builds the long-horizon stability suite.
///
/// In order, it contains four loop-detection cases that must fire, two that must stay
/// silent, and four goal-preservation schedules.
pub fn long_horizon_stability_suite() -> Vec<std::sync::Arc<dyn EvaluationCase>> {
    use std::sync::Arc;

    // (n_steps, looping_tool, loop_starts_at, window_size)
    const DETECT: [(usize, &str, usize, usize); 4] = [
        (20, "read_file", 3, 5),
        (15, "write_file", 1, 5),
        (25, "bash", 10, 7),
        (30, "search_code", 5, 10),
    ];
    // (n_steps, window_size)
    const NO_DETECT: [(usize, usize); 2] = [(20, 5), (30, 7)];
    // (n_iterations, revalidation_interval)
    const GOALS: [(usize, usize); 4] = [(15, 10), (20, 5), (30, 10), (50, 15)];

    let mut suite: Vec<Arc<dyn EvaluationCase>> =
        Vec::with_capacity(DETECT.len() + NO_DETECT.len() + GOALS.len());
    for &(steps, tool, start, window) in DETECT.iter() {
        suite.push(Arc::new(LoopDetectionSimCase::should_detect(
            steps, tool, start, window,
        )));
    }
    for &(steps, window) in NO_DETECT.iter() {
        suite.push(Arc::new(LoopDetectionSimCase::should_not_detect(steps, window)));
    }
    for &(iterations, interval) in GOALS.iter() {
        suite.push(Arc::new(GoalPreservationCase::new(iterations, interval)));
    }
    suite
}
#[cfg(test)]
mod tests {
    //! Unit tests for the simulated stability cases and the suite builder.
    use super::*;
    use crate::eval::suite::EvaluationSuite;

    #[test]
    fn test_loop_sim_fires_at_correct_step() {
        let case = LoopDetectionSimCase::should_detect(20, "read_file", 3, 5);
        assert!(case.simulate(), "expected loop detection to fire");
    }

    #[test]
    fn test_loop_sim_does_not_fire_diverse() {
        let case = LoopDetectionSimCase::should_not_detect(20, 5);
        assert!(
            !case.simulate(),
            "expected no loop detection on diverse sequence"
        );
    }

    #[test]
    fn test_loop_sim_fires_immediately() {
        let case = LoopDetectionSimCase::should_detect(10, "write_file", 1, 3);
        assert!(case.simulate());
    }

    #[test]
    fn test_loop_sim_short_run_no_loop() {
        let case = LoopDetectionSimCase::should_detect(2, "read_file", 1, 5);
        assert!(!case.simulate());
    }

    // Edge case: a zero-width window can never be full, so detection must not fire.
    #[test]
    fn test_loop_sim_zero_window_never_fires() {
        let case = LoopDetectionSimCase::should_detect(50, "bash", 1, 0);
        assert!(!case.simulate(), "a zero-width window must never fire");
    }

    #[test]
    fn test_goal_injection_points_15iter_interval10() {
        let case = GoalPreservationCase::new(15, 10);
        let pts = case.expected_injection_points();
        assert_eq!(pts, vec![11]);
    }

    #[test]
    fn test_goal_injection_points_20iter_interval5() {
        let case = GoalPreservationCase::new(20, 5);
        let pts = case.expected_injection_points();
        assert_eq!(pts, vec![6, 11, 16]);
    }

    #[test]
    fn test_goal_injection_simulation_matches_expected() {
        let case = GoalPreservationCase::new(30, 10);
        assert_eq!(case.simulate_injections(), case.expected_injection_points());
    }

    // An interval of 0 disables re-injection in both the oracle and the simulation.
    #[test]
    fn test_goal_injection_disabled_interval() {
        let case = GoalPreservationCase::new(30, 0);
        assert!(case.expected_injection_points().is_empty());
        assert!(case.simulate_injections().is_empty());
    }

    #[test]
    fn test_stability_suite_case_count() {
        assert_eq!(long_horizon_stability_suite().len(), 10);
    }

    #[tokio::test]
    async fn test_loop_detection_case_succeeds_when_loop_fires() {
        let case = LoopDetectionSimCase::should_detect(20, "read_file", 3, 5);
        let result = case.run(0).await.unwrap();
        assert!(
            result.success,
            "case should succeed when detection fires as expected: {:?}",
            result.error
        );
    }

    #[tokio::test]
    async fn test_loop_detection_case_fails_when_no_loop_fires() {
        let case = LoopDetectionSimCase::should_detect(2, "read_file", 1, 5);
        let result = case.run(0).await.unwrap();
        assert!(
            !result.success,
            "case should fail when expected detection didn't fire"
        );
    }

    // The negative case passes when detection stays silent.
    #[tokio::test]
    async fn test_loop_detection_negative_case_succeeds() {
        let case = LoopDetectionSimCase::should_not_detect(20, 5);
        let result = case.run(0).await.unwrap();
        assert!(
            result.success,
            "negative case should succeed when no detection fires: {:?}",
            result.error
        );
    }

    #[tokio::test]
    async fn test_goal_preservation_case_succeeds() {
        let case = GoalPreservationCase::new(20, 5);
        let result = case.run(0).await.unwrap();
        assert!(
            result.success,
            "goal preservation case should pass: {:?}",
            result.error
        );
    }

    // Interval longer than the run: the schedule is empty, which the long-horizon
    // (>= 15 iterations) guard rejects even though expected and simulated agree.
    #[tokio::test]
    async fn test_goal_preservation_fails_without_injections_on_long_horizon() {
        let case = GoalPreservationCase::new(15, 20);
        let result = case.run(0).await.unwrap();
        assert!(
            !result.success,
            "long horizon without any goal injection should fail"
        );
    }

    // Below 15 iterations the guard does not apply, so an empty schedule passes.
    #[tokio::test]
    async fn test_goal_preservation_short_horizon_without_injections_succeeds() {
        let case = GoalPreservationCase::new(10, 20);
        let result = case.run(0).await.unwrap();
        assert!(
            result.success,
            "short horizon with an empty schedule should pass: {:?}",
            result.error
        );
    }

    #[tokio::test]
    async fn test_full_stability_suite_runs() {
        let suite = EvaluationSuite::new(1);
        let cases = long_horizon_stability_suite();
        let results = suite.run_suite(&cases).await;
        assert!(!results.case_results.is_empty());
    }
}