1use aios_protocol::SteeringMode;
7use serde::{Deserialize, Serialize};
8use std::collections::VecDeque;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11
/// Tunables for a [`MessageQueue`].
#[derive(Debug, Clone)]
pub struct QueueConfig {
    /// Maximum number of pending messages; `enqueue` rejects messages beyond this.
    pub max_queue_depth: usize,
    /// Steer timeout. Within this file it is only used by `health_check`, which
    /// reports the oldest message as stale once it is older than twice this value.
    pub steer_timeout: Duration,
    /// Collect messages queued within this duration of each other are handed out
    /// together by `drain_after_run`.
    pub collect_coalesce_window: Duration,
}
24
25impl Default for QueueConfig {
26 fn default() -> Self {
27 Self {
28 max_queue_depth: 10,
29 steer_timeout: Duration::from_secs(30),
30 collect_coalesce_window: Duration::from_secs(2),
31 }
32 }
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct QueuedMessage {
40 pub id: String,
42 pub mode: SteeringMode,
44 pub content: String,
46 #[serde(skip)]
48 pub queued_at: Option<Instant>,
49}
50
51#[derive(Debug, Clone)]
55pub enum SteeringAction {
56 Continue,
58 InjectMessage(String),
60 CompleteAndSwitch(QueuedMessage),
62 Abort { reason: String },
64}
65
66pub trait PreemptionCheck: Send + Sync {
73 fn check_preemption(&self) -> Result<SteeringAction, QueueError>;
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct QueueStatus {
83 pub depth: usize,
84 pub pending: Vec<QueuedMessage>,
85 pub has_active_run: bool,
86 pub oldest_message_age_ms: Option<u64>,
87}
88
89#[derive(Debug, Clone, thiserror::Error)]
93pub enum QueueError {
94 #[error("queue is full (depth: {depth}, max: {max})")]
95 QueueFull { depth: usize, max: usize },
96 #[error("message not found: {id}")]
97 NotFound { id: String },
98 #[error("internal lock poisoned: {0}")]
99 LockPoisoned(String),
100}
101
102pub struct MessageQueue {
110 inner: Arc<Mutex<QueueInner>>,
111 config: QueueConfig,
112}
113
114struct QueueInner {
115 pending: VecDeque<QueuedMessage>,
116 has_active_run: bool,
117}
118
119impl MessageQueue {
120 pub fn new(config: QueueConfig) -> Self {
122 Self {
123 inner: Arc::new(Mutex::new(QueueInner {
124 pending: VecDeque::new(),
125 has_active_run: false,
126 })),
127 config,
128 }
129 }
130
131 pub fn enqueue(&self, message: QueuedMessage) -> Result<(), QueueError> {
133 let mut inner = self
134 .inner
135 .lock()
136 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
137 if inner.pending.len() >= self.config.max_queue_depth {
138 return Err(QueueError::QueueFull {
139 depth: inner.pending.len(),
140 max: self.config.max_queue_depth,
141 });
142 }
143 let mut msg = message;
144 msg.queued_at = Some(Instant::now());
145 inner.pending.push_back(msg);
146 Ok(())
147 }
148
149 pub fn remove(&self, id: &str) -> Result<QueuedMessage, QueueError> {
151 let mut inner = self
152 .inner
153 .lock()
154 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
155 let pos = inner
156 .pending
157 .iter()
158 .position(|m| m.id == id)
159 .ok_or_else(|| QueueError::NotFound { id: id.to_owned() })?;
160 Ok(inner.pending.remove(pos).expect("position valid"))
161 }
162
163 pub fn status(&self) -> Result<QueueStatus, QueueError> {
165 let inner = self
166 .inner
167 .lock()
168 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
169 let oldest_age = inner
170 .pending
171 .front()
172 .and_then(|m| m.queued_at.map(|t| t.elapsed().as_millis() as u64));
173 Ok(QueueStatus {
174 depth: inner.pending.len(),
175 pending: inner.pending.iter().cloned().collect(),
176 has_active_run: inner.has_active_run,
177 oldest_message_age_ms: oldest_age,
178 })
179 }
180
181 pub fn set_active_run(&self, active: bool) -> Result<(), QueueError> {
183 let mut inner = self
184 .inner
185 .lock()
186 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
187 inner.has_active_run = active;
188 Ok(())
189 }
190
191 pub fn has_active_run(&self) -> Result<bool, QueueError> {
193 let inner = self
194 .inner
195 .lock()
196 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
197 Ok(inner.has_active_run)
198 }
199
200 pub fn check_preemption(&self) -> Result<SteeringAction, QueueError> {
205 let mut inner = self
206 .inner
207 .lock()
208 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
209
210 if let Some(pos) = inner
212 .pending
213 .iter()
214 .position(|m| m.mode == SteeringMode::Interrupt)
215 {
216 let msg = inner.pending.remove(pos).expect("position valid");
217 return Ok(SteeringAction::Abort {
218 reason: format!("interrupted by queue message: {}", msg.id),
219 });
220 }
221
222 if let Some(pos) = inner
224 .pending
225 .iter()
226 .position(|m| m.mode == SteeringMode::Steer)
227 {
228 let msg = inner.pending.remove(pos).expect("position valid");
229 return Ok(SteeringAction::CompleteAndSwitch(msg));
230 }
231
232 Ok(SteeringAction::Continue)
233 }
234
235 pub fn drain_after_run(&self) -> Result<Vec<QueuedMessage>, QueueError> {
240 let mut inner = self
241 .inner
242 .lock()
243 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
244 inner.has_active_run = false;
245
246 if inner.pending.is_empty() {
247 return Ok(Vec::new());
248 }
249
250 let mut followups = Vec::new();
251 let mut collects = Vec::new();
252 let mut remaining = VecDeque::new();
253
254 for msg in inner.pending.drain(..) {
255 match msg.mode {
256 SteeringMode::Followup => followups.push(msg),
257 SteeringMode::Collect => collects.push(msg),
258 SteeringMode::Interrupt | SteeringMode::Steer => collects.push(msg),
261 }
262 }
263
264 let window = self.config.collect_coalesce_window;
266 if collects.len() > 1 {
267 let now = Instant::now();
268 let (within_window, outside): (Vec<_>, Vec<_>) = collects
269 .into_iter()
270 .partition(|m| m.queued_at.is_some_and(|t| now.duration_since(t) <= window));
271 for msg in outside {
272 remaining.push_back(msg);
273 }
274 collects = within_window;
275 }
276
277 inner.pending = remaining;
278
279 let mut result = followups;
280 result.extend(collects);
281 Ok(result)
282 }
283
284 pub fn health_check(&self) -> Result<Vec<String>, QueueError> {
286 let inner = self
287 .inner
288 .lock()
289 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
290 let mut warnings = Vec::new();
291
292 let depth = inner.pending.len();
293 let threshold = self.config.max_queue_depth / 2;
294 if depth > threshold {
295 warnings.push(format!(
296 "queue depth {depth} exceeds warning threshold {threshold}"
297 ));
298 }
299
300 let stale_timeout = self.config.steer_timeout * 2;
301 if let Some(oldest) = inner.pending.front()
302 && let Some(queued_at) = oldest.queued_at
303 && queued_at.elapsed() > stale_timeout
304 {
305 warnings.push(format!(
306 "oldest message {} is stale ({:.1}s old)",
307 oldest.id,
308 queued_at.elapsed().as_secs_f64()
309 ));
310 }
311
312 Ok(warnings)
313 }
314
315 pub fn depth(&self) -> Result<usize, QueueError> {
317 let inner = self
318 .inner
319 .lock()
320 .map_err(|e| QueueError::LockPoisoned(e.to_string()))?;
321 Ok(inner.pending.len())
322 }
323
324 pub fn config(&self) -> &QueueConfig {
326 &self.config
327 }
328}
329
330impl PreemptionCheck for MessageQueue {
331 fn check_preemption(&self) -> Result<SteeringAction, QueueError> {
332 self.check_preemption()
333 }
334}
335
336impl Clone for MessageQueue {
337 fn clone(&self) -> Self {
338 Self {
339 inner: self.inner.clone(),
340 config: self.config.clone(),
341 }
342 }
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an un-timestamped message; `enqueue` fills in `queued_at`.
    fn make_msg(id: &str, mode: SteeringMode, content: &str) -> QueuedMessage {
        QueuedMessage {
            id: String::from(id),
            mode,
            content: String::from(content),
            queued_at: None,
        }
    }

    /// A default-configured queue with a run marked active.
    fn active_queue() -> MessageQueue {
        let queue = MessageQueue::new(QueueConfig::default());
        queue.set_active_run(true).unwrap();
        queue
    }

    /// Enqueues every `(id, mode, content)` triple in order, panicking on failure.
    fn enqueue_all(queue: &MessageQueue, msgs: Vec<(&str, SteeringMode, &str)>) {
        for (id, mode, content) in msgs {
            queue.enqueue(make_msg(id, mode, content)).unwrap();
        }
    }

    /// The ids of `msgs`, in order.
    fn ids(msgs: &[QueuedMessage]) -> Vec<&str> {
        msgs.iter().map(|m| m.id.as_str()).collect()
    }

    #[test]
    fn collect_mode_queued_and_drained() {
        let queue = active_queue();
        enqueue_all(&queue, vec![("q1", SteeringMode::Collect, "do this later")]);

        assert_eq!(queue.depth().unwrap(), 1);
        assert!(queue.has_active_run().unwrap());

        // Collect messages never preempt.
        let action = queue.check_preemption().unwrap();
        assert!(matches!(action, SteeringAction::Continue));

        let drained = queue.drain_after_run().unwrap();
        assert_eq!(ids(&drained), ["q1"]);
        assert!(!queue.has_active_run().unwrap());
    }

    #[test]
    fn steer_mode_preempts_at_tool_boundary() {
        let queue = active_queue();
        enqueue_all(&queue, vec![("q1", SteeringMode::Steer, "do this instead")]);

        let action = queue.check_preemption().unwrap();
        match action {
            SteeringAction::CompleteAndSwitch(msg) => {
                assert_eq!(msg.id, "q1");
                assert_eq!(msg.content, "do this instead");
            }
            other => panic!("expected CompleteAndSwitch, got {other:?}"),
        }

        assert_eq!(queue.depth().unwrap(), 0);
    }

    #[test]
    fn followup_inherits_context_order() {
        let queue = active_queue();
        enqueue_all(
            &queue,
            vec![
                ("c1", SteeringMode::Collect, "fresh run"),
                ("f1", SteeringMode::Followup, "same context"),
            ],
        );

        // Followups jump ahead of collects.
        let drained = queue.drain_after_run().unwrap();
        assert_eq!(ids(&drained), ["f1", "c1"]);
    }

    #[test]
    fn interrupt_aborts_at_tool_boundary() {
        let queue = active_queue();
        enqueue_all(&queue, vec![("i1", SteeringMode::Interrupt, "stop now")]);

        let action = queue.check_preemption().unwrap();
        match action {
            SteeringAction::Abort { reason } => assert!(reason.contains("i1")),
            other => panic!("expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn queue_depth_limit_enforced() {
        let queue = MessageQueue::new(QueueConfig {
            max_queue_depth: 2,
            ..Default::default()
        });
        enqueue_all(
            &queue,
            vec![
                ("q1", SteeringMode::Collect, "1"),
                ("q2", SteeringMode::Collect, "2"),
            ],
        );

        let overflow = queue.enqueue(make_msg("q3", SteeringMode::Collect, "3"));
        assert!(overflow.is_err());
        let err = overflow.unwrap_err();
        assert!(matches!(err, QueueError::QueueFull { depth: 2, max: 2 }));
    }

    #[test]
    fn steer_timeout_falls_back_to_collect() {
        let queue = active_queue();
        enqueue_all(&queue, vec![("s1", SteeringMode::Steer, "should steer")]);

        // Never consumed by check_preemption, so it is drained like a collect.
        let drained = queue.drain_after_run().unwrap();
        assert_eq!(ids(&drained), ["s1"]);
    }

    #[test]
    fn collect_messages_coalesced_within_window() {
        let queue = MessageQueue::new(QueueConfig {
            collect_coalesce_window: Duration::from_secs(10),
            ..Default::default()
        });
        queue.set_active_run(true).unwrap();
        enqueue_all(
            &queue,
            vec![
                ("c1", SteeringMode::Collect, "a"),
                ("c2", SteeringMode::Collect, "b"),
                ("c3", SteeringMode::Collect, "c"),
            ],
        );

        let drained = queue.drain_after_run().unwrap();
        assert_eq!(drained.len(), 3);
    }

    #[test]
    fn drain_order_interrupt_steer_followup_collect() {
        let queue = active_queue();
        enqueue_all(
            &queue,
            vec![
                ("c1", SteeringMode::Collect, "collect"),
                ("f1", SteeringMode::Followup, "followup"),
                ("i1", SteeringMode::Interrupt, "interrupt"),
                ("s1", SteeringMode::Steer, "steer"),
            ],
        );

        let first = queue.check_preemption().unwrap();
        match first {
            SteeringAction::Abort { .. } => {}
            other => panic!("expected Abort from interrupt, got {other:?}"),
        }

        let second = queue.check_preemption().unwrap();
        match second {
            SteeringAction::CompleteAndSwitch(msg) => assert_eq!(msg.id, "s1"),
            other => panic!("expected CompleteAndSwitch from steer, got {other:?}"),
        }

        let drained = queue.drain_after_run().unwrap();
        assert_eq!(ids(&drained), ["f1", "c1"]);
    }

    #[test]
    fn preemption_returns_continue_on_empty_queue() {
        let queue = MessageQueue::new(QueueConfig::default());
        let action = queue.check_preemption().unwrap();
        assert!(matches!(action, SteeringAction::Continue));
    }

    #[test]
    fn remove_specific_message() {
        let queue = MessageQueue::new(QueueConfig::default());
        enqueue_all(
            &queue,
            vec![
                ("q1", SteeringMode::Collect, "a"),
                ("q2", SteeringMode::Collect, "b"),
            ],
        );

        let removed = queue.remove("q1").unwrap();
        assert_eq!(removed.id, "q1");
        assert_eq!(queue.depth().unwrap(), 1);

        assert!(queue.remove("q99").is_err());
    }

    #[test]
    fn status_snapshot() {
        let queue = active_queue();
        enqueue_all(&queue, vec![("q1", SteeringMode::Collect, "test")]);

        let snapshot = queue.status().unwrap();
        assert_eq!(snapshot.depth, 1);
        assert!(snapshot.has_active_run);
        assert_eq!(snapshot.pending.len(), 1);
        assert!(snapshot.oldest_message_age_ms.is_some());
    }

    #[test]
    fn health_check_warns_on_depth() {
        let queue = MessageQueue::new(QueueConfig {
            max_queue_depth: 4,
            ..Default::default()
        });
        (0..3).for_each(|i| {
            queue
                .enqueue(make_msg(&format!("q{i}"), SteeringMode::Collect, "x"))
                .unwrap();
        });

        let warnings = queue.health_check().unwrap();
        let flagged = warnings
            .iter()
            .any(|w| w.contains("exceeds warning threshold"));
        assert!(flagged);
    }

    #[test]
    fn clone_shares_state() {
        let original = MessageQueue::new(QueueConfig::default());
        let handle = original.clone();

        enqueue_all(&original, vec![("q1", SteeringMode::Collect, "shared")]);
        assert_eq!(handle.depth().unwrap(), 1);
    }

    #[test]
    fn multiple_followups_preserved_in_order() {
        let queue = active_queue();
        enqueue_all(
            &queue,
            vec![
                ("f1", SteeringMode::Followup, "first"),
                ("f2", SteeringMode::Followup, "second"),
                ("f3", SteeringMode::Followup, "third"),
            ],
        );

        let drained = queue.drain_after_run().unwrap();
        assert_eq!(ids(&drained), ["f1", "f2", "f3"]);
    }

    #[test]
    fn drain_empty_queue_returns_empty() {
        let queue = active_queue();
        assert!(queue.drain_after_run().unwrap().is_empty());
    }
}