1use std::collections::VecDeque;
11
12use async_trait::async_trait;
13
14use super::case::EvaluationCase;
15use super::trial::TrialResult;
16
/// Simulated loop-detection scenario.
///
/// The case replays a synthetic sequence of tool names through a sliding
/// window and reports whether the window ever becomes completely filled with
/// a single repeated name. Nothing outside this struct is consulted: the
/// sequence is generated from the fields below.
#[derive(Debug, Clone)]
pub struct LoopDetectionSimCase {
    /// Case name returned by the `EvaluationCase::name` implementation.
    name: String,
    /// Total number of simulated steps; steps are numbered starting at 1.
    pub n_steps: usize,
    /// Tool name emitted on every step from `loop_starts_at` onwards.
    pub looping_tool: String,
    /// First step (1-based) at which `looping_tool` is emitted.
    /// `should_not_detect` sets this to `usize::MAX` so the loop never starts
    /// in any practical run.
    pub loop_starts_at: usize,
    /// Number of most recent tool names kept in the sliding window.
    pub window_size: usize,
    /// Whether the simulation is expected to report a loop.
    pub expect_detection: bool,
}

impl LoopDetectionSimCase {
    /// Builds a case in which `looping_tool` repeats from step `loop_starts_at`
    /// onwards and a loop is expected to be reported.
    ///
    /// The case name is derived from `window_size` and `loop_starts_at`.
    pub fn should_detect(
        n_steps: usize,
        looping_tool: impl Into<String>,
        loop_starts_at: usize,
        window_size: usize,
    ) -> Self {
        Self {
            name: format!("loop_detection_window{window_size}_step{loop_starts_at}"),
            n_steps,
            looping_tool: looping_tool.into(),
            loop_starts_at,
            window_size,
            expect_detection: true,
        }
    }

    /// Builds a case whose tool sequence only cycles through the diverse tool
    /// list, so no loop is expected to be reported.
    ///
    /// The case name is derived from `window_size` and `n_steps`.
    pub fn should_not_detect(n_steps: usize, window_size: usize) -> Self {
        Self {
            name: format!("loop_no_detection_window{window_size}_{n_steps}steps"),
            n_steps,
            looping_tool: "read_file".into(),
            // Sentinel: no reachable step number is >= usize::MAX in practice.
            loop_starts_at: usize::MAX,
            window_size,
            expect_detection: false,
        }
    }

    /// Replays the synthetic tool sequence and returns `true` as soon as the
    /// sliding window is full and every entry in it is the same tool name.
    ///
    /// Returns `false` when the run ends without that happening, and also for
    /// windows that can never fill: a zero-sized window or one larger than
    /// `n_steps`.
    fn simulate(&self) -> bool {
        // Tools cycled through, in order, before the loop starts.
        const DIVERSE_TOOLS: [&str; 5] =
            ["read_file", "write_file", "search_code", "list_dir", "bash"];

        // A window of size 0 never matches `len == window_size` after the first
        // push, and a window wider than the run never fills. Bail out before
        // allocating so huge sizes cannot trigger a capacity-overflow panic.
        if self.window_size == 0 || self.window_size > self.n_steps {
            return false;
        }

        // Borrowed names are enough here; no per-step String allocation needed.
        let mut window: VecDeque<&str> = VecDeque::with_capacity(self.window_size);

        for step in 1..=self.n_steps {
            let tool: &str = if step >= self.loop_starts_at {
                self.looping_tool.as_str()
            } else {
                DIVERSE_TOOLS[(step - 1) % DIVERSE_TOOLS.len()]
            };

            // Evict the oldest entry so the window never exceeds `window_size`.
            if window.len() == self.window_size {
                window.pop_front();
            }
            window.push_back(tool);

            // All entries equal each other exactly when they all equal the
            // entry that was just pushed.
            if window.len() == self.window_size && window.iter().all(|&seen| seen == tool) {
                return true;
            }
        }
        false
    }
}
100
101#[async_trait]
102impl EvaluationCase for LoopDetectionSimCase {
103 fn name(&self) -> &str {
104 &self.name
105 }
106
107 fn category(&self) -> &str {
108 "stability/loop_detection"
109 }
110
111 async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
112 let start = std::time::Instant::now();
113 let detected = self.simulate();
114 let ms = start.elapsed().as_millis() as u64;
115
116 if detected == self.expect_detection {
117 Ok(TrialResult::success(trial_id, ms)
118 .with_meta("loop_detected", serde_json::json!(detected))
119 .with_meta("n_steps", serde_json::json!(self.n_steps))
120 .with_meta("window_size", serde_json::json!(self.window_size)))
121 } else {
122 let msg = if self.expect_detection {
123 format!(
124 "Expected loop detection after {} steps (window={}) but none fired",
125 self.n_steps, self.window_size
126 )
127 } else {
128 format!(
129 "Expected no loop detection but one fired at window={}",
130 self.window_size
131 )
132 };
133 Ok(TrialResult::failure(trial_id, ms, msg))
134 }
135 }
136}
137
/// Simulated goal-revalidation scenario.
///
/// Models a run of `n_iterations` iterations (numbered from 1) in which the
/// goal is re-injected on a fixed cadence, and exposes both the expected and
/// the simulated injection points so they can be compared.
#[derive(Debug, Clone)]
pub struct GoalPreservationCase {
    /// Case name returned by the `EvaluationCase::name` implementation.
    name: String,
    /// Number of simulated iterations.
    pub n_iterations: usize,
    /// Distance, in iterations, between goal injections; `0` yields none.
    pub revalidation_interval: usize,
    /// Goal text stored with the case; `new` fills in a fixed default.
    pub goal_text: String,
}

impl GoalPreservationCase {
    /// Creates a case with the default goal text and a name derived from
    /// `n_iterations` and `revalidation_interval`.
    pub fn new(n_iterations: usize, revalidation_interval: usize) -> Self {
        let name = format!("goal_preservation_{n_iterations}iter_every{revalidation_interval}");
        Self {
            name,
            n_iterations,
            revalidation_interval,
            goal_text: String::from("Complete the long-horizon task reliably"),
        }
    }

    /// Iterations at which a goal injection is expected, in ascending order.
    ///
    /// These are the iterations `i >= 2` with `(i - 1)` divisible by the
    /// interval, i.e. `interval + 1`, `2 * interval + 1`, ... up to
    /// `n_iterations`. An interval of `0` produces no injection points.
    fn expected_injection_points(&self) -> Vec<usize> {
        let interval = self.revalidation_interval;
        if interval == 0 {
            // `step_by(0)` would panic; no injections are scheduled anyway.
            return Vec::new();
        }
        // The first injection lands one interval after iteration 1. If that
        // overflows, it lies beyond any representable iteration.
        let Some(first) = interval.checked_add(1) else {
            return Vec::new();
        };
        (first..=self.n_iterations).step_by(interval).collect()
    }

    /// Walks every iteration and records the ones where an injection happens:
    /// never on iteration 1, and afterwards whenever `(iteration - 1)` is a
    /// multiple of a non-zero interval.
    fn simulate_injections(&self) -> Vec<usize> {
        let interval = self.revalidation_interval;
        (1..=self.n_iterations)
            .filter(|&iteration| {
                // `interval > 0` is checked first so the modulo cannot divide by zero.
                interval > 0 && iteration > 1 && (iteration - 1) % interval == 0
            })
            .collect()
    }
}
192
193#[async_trait]
194impl EvaluationCase for GoalPreservationCase {
195 fn name(&self) -> &str {
196 &self.name
197 }
198
199 fn category(&self) -> &str {
200 "stability/goal_preservation"
201 }
202
203 async fn run(&self, trial_id: usize) -> anyhow::Result<TrialResult> {
204 let start = std::time::Instant::now();
205 let injected = self.simulate_injections();
206 let expected = self.expected_injection_points();
207 let ms = start.elapsed().as_millis() as u64;
208
209 if self.n_iterations >= 15 && self.revalidation_interval > 0 {
212 let expected_min = 1usize;
213 if injected.len() < expected_min {
214 return Ok(TrialResult::failure(
215 trial_id,
216 ms,
217 format!(
218 "Expected at least {} goal injection(s) across {} iterations \
219 (interval={}), got 0",
220 expected_min, self.n_iterations, self.revalidation_interval
221 ),
222 ));
223 }
224 }
225
226 if injected != expected {
228 return Ok(TrialResult::failure(
229 trial_id,
230 ms,
231 format!(
232 "Goal injection mismatch: expected at iterations {:?}, got {:?}",
233 expected, injected
234 ),
235 ));
236 }
237
238 Ok(TrialResult::success(trial_id, ms)
239 .with_meta("n_iterations", serde_json::json!(self.n_iterations))
240 .with_meta("injections", serde_json::json!(injected.len()))
241 .with_meta("interval", serde_json::json!(self.revalidation_interval)))
242 }
243}
244
245pub fn long_horizon_stability_suite() -> Vec<std::sync::Arc<dyn EvaluationCase>> {
254 vec![
255 std::sync::Arc::new(LoopDetectionSimCase::should_detect(20, "read_file", 3, 5)),
258 std::sync::Arc::new(LoopDetectionSimCase::should_detect(15, "write_file", 1, 5)),
260 std::sync::Arc::new(LoopDetectionSimCase::should_detect(25, "bash", 10, 7)),
262 std::sync::Arc::new(LoopDetectionSimCase::should_detect(
264 30,
265 "search_code",
266 5,
267 10,
268 )),
269 std::sync::Arc::new(LoopDetectionSimCase::should_not_detect(20, 5)),
272 std::sync::Arc::new(LoopDetectionSimCase::should_not_detect(30, 7)),
273 std::sync::Arc::new(GoalPreservationCase::new(15, 10)),
276 std::sync::Arc::new(GoalPreservationCase::new(20, 5)),
278 std::sync::Arc::new(GoalPreservationCase::new(30, 10)),
280 std::sync::Arc::new(GoalPreservationCase::new(50, 15)),
282 ]
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use crate::suite::EvaluationSuite;

    /// Loop from step 3 with a 5-wide window is reported within 20 steps.
    #[test]
    fn test_loop_sim_fires_at_correct_step() {
        let fired = LoopDetectionSimCase::should_detect(20, "read_file", 3, 5).simulate();
        assert!(fired, "expected loop detection to fire");
    }

    /// A purely cycling tool sequence never fills the window with one tool.
    #[test]
    fn test_loop_sim_does_not_fire_diverse() {
        let fired = LoopDetectionSimCase::should_not_detect(20, 5).simulate();
        assert!(!fired, "expected no loop detection on diverse sequence");
    }

    /// Looping from the very first step is reported.
    #[test]
    fn test_loop_sim_fires_immediately() {
        let sim = LoopDetectionSimCase::should_detect(10, "write_file", 1, 3);
        assert!(sim.simulate());
    }

    /// Two steps cannot fill a 5-wide window, so nothing is reported.
    #[test]
    fn test_loop_sim_short_run_no_loop() {
        let sim = LoopDetectionSimCase::should_detect(2, "read_file", 1, 5);
        assert!(!sim.simulate());
    }

    /// 15 iterations at interval 10 give a single injection at iteration 11.
    #[test]
    fn test_goal_injection_points_15iter_interval10() {
        let points = GoalPreservationCase::new(15, 10).expected_injection_points();
        assert_eq!(points, vec![11]);
    }

    /// 20 iterations at interval 5 give injections at 6, 11 and 16.
    #[test]
    fn test_goal_injection_points_20iter_interval5() {
        let points = GoalPreservationCase::new(20, 5).expected_injection_points();
        assert_eq!(points, vec![6, 11, 16]);
    }

    /// Simulated and expected injection points agree.
    #[test]
    fn test_goal_injection_simulation_matches_expected() {
        let goal_case = GoalPreservationCase::new(30, 10);
        let simulated = goal_case.simulate_injections();
        let expected = goal_case.expected_injection_points();
        assert_eq!(simulated, expected);
    }

    /// `run` reports success when an expected loop is detected.
    #[tokio::test]
    async fn test_loop_detection_case_succeeds_when_loop_fires() {
        let sim = LoopDetectionSimCase::should_detect(20, "read_file", 3, 5);
        let outcome = sim.run(0).await.unwrap();
        assert!(
            outcome.success,
            "case should succeed when detection fires as expected: {:?}",
            outcome.error
        );
    }

    /// `run` reports failure when an expected loop is never detected.
    #[tokio::test]
    async fn test_loop_detection_case_fails_when_no_loop_fires() {
        let sim = LoopDetectionSimCase::should_detect(2, "read_file", 1, 5);
        let outcome = sim.run(0).await.unwrap();
        assert!(
            !outcome.success,
            "case should fail when expected detection didn't fire"
        );
    }

    /// `run` reports success for a well-formed goal-preservation case.
    #[tokio::test]
    async fn test_goal_preservation_case_succeeds() {
        let goal_case = GoalPreservationCase::new(20, 5);
        let outcome = goal_case.run(0).await.unwrap();
        assert!(
            outcome.success,
            "goal preservation case should pass: {:?}",
            outcome.error
        );
    }

    /// The whole suite runs and produces at least one case result.
    #[tokio::test]
    async fn test_full_stability_suite_runs() {
        let runner = EvaluationSuite::new(1);
        let suite_cases = long_horizon_stability_suite();
        let report = runner.run_suite(&suite_cases).await;
        assert!(!report.case_results.is_empty());
    }
}