1use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
29use std::sync::{Arc, Weak};
30use std::time::Duration;
31
32use async_trait::async_trait;
33use oxi_ai::Message;
34use parking_lot::Mutex;
35use tokio::sync::oneshot;
36
37use crate::advisor::types::AdvisorNote;
38
39#[async_trait]
42pub trait AdvisorAgent: Send + Sync + 'static {
43 async fn prompt(&self, input: String) -> Result<(), String>;
46 fn abort(&self, reason: &str);
48 fn reset(&self);
50 async fn rollback_to(&self, count: usize);
53 fn message_count(&self) -> usize;
55}
56
57pub trait AdvisorRuntimeHost: Send + Sync + 'static {
59 fn snapshot_messages(&self) -> Vec<Message>;
62 fn enqueue_advice(&self, note: AdvisorNote);
65 fn maintain_context(&self, _incoming_tokens: usize) -> bool {
69 false
70 }
71 fn begin_advisor_update(&self) {}
75 fn notify_failure(&self, _error: &str) {}
78}
79
/// One queued chunk of rendered session text waiting to be sent to the advisor.
struct PendingDelta {
    /// Rendered markdown for the chunk.
    text: String,
    /// Number of turns this chunk accounts for in the backlog counter
    /// (1 for a fresh turn; larger for a re-queued, previously merged batch).
    turns: u64,
}
86
87#[derive(Default)]
91struct DrainState {
92 pending: Vec<PendingDelta>,
93 draining: bool,
94}
95
96struct CatchupWaiter {
98 threshold: u64,
99 tx: Option<oneshot::Sender<()>>,
100}
101
102pub struct AdvisorRuntime {
106 agent: Arc<dyn AdvisorAgent>,
107 host: Arc<dyn AdvisorRuntimeHost>,
108
109 state: Mutex<DrainState>,
110 epoch: AtomicU64,
114 backlog: AtomicU64,
116 last_count: AtomicU64,
118 latest: Mutex<Option<Vec<Message>>>,
120
121 waiters: Mutex<Vec<CatchupWaiter>>,
122
123 consecutive_failures: AtomicU32,
124 failure_notified: AtomicBool,
125 disposed: AtomicBool,
126 retry_delay: Duration,
127
128 self_ref: Mutex<Option<Weak<AdvisorRuntime>>>,
130}
131
132impl AdvisorRuntime {
133 #[must_use]
136 pub fn new(
137 agent: Arc<dyn AdvisorAgent>,
138 host: Arc<dyn AdvisorRuntimeHost>,
139 retry_delay: Duration,
140 ) -> Self {
141 Self {
142 agent,
143 host,
144 state: Mutex::new(DrainState::default()),
145 epoch: AtomicU64::new(0),
146 backlog: AtomicU64::new(0),
147 last_count: AtomicU64::new(0),
148 latest: Mutex::new(None),
149 waiters: Mutex::new(Vec::new()),
150 consecutive_failures: AtomicU32::new(0),
151 failure_notified: AtomicBool::new(false),
152 disposed: AtomicBool::new(false),
153 retry_delay,
154 self_ref: Mutex::new(None),
155 }
156 }
157
158 pub fn install_self(&self, weak: Weak<AdvisorRuntime>) {
161 *self.self_ref.lock() = Some(weak);
162 }
163
164 #[must_use]
166 pub fn backlog(&self) -> u64 {
167 self.backlog.load(Ordering::SeqCst)
168 }
169
170 #[must_use]
172 pub fn is_disposed(&self) -> bool {
173 self.disposed.load(Ordering::SeqCst)
174 }
175
176 pub fn on_turn_end(&self, messages: Vec<Message>) {
180 if self.disposed.load(Ordering::SeqCst) {
181 return;
182 }
183 *self.latest.lock() = Some(messages.clone());
184 let Some(render) = self.render_delta(&messages) else {
185 return;
186 };
187 let spawn = {
188 let mut s = self.state.lock();
189 s.pending.push(PendingDelta {
190 text: render,
191 turns: 1,
192 });
193 self.backlog.fetch_add(1, Ordering::SeqCst);
194 !s.draining
195 };
196 self.notify_waiters();
197 let drain_handle = self.self_ref.lock().as_ref().and_then(Weak::upgrade);
198 if spawn && let Some(this) = drain_handle {
199 tokio::spawn(async move {
200 this.drain().await;
201 });
202 }
203 }
204
205 pub async fn wait_for_catchup(&self, max: Duration, threshold: u64) {
210 if self.disposed.load(Ordering::SeqCst) || self.backlog.load(Ordering::SeqCst) < threshold {
211 return;
212 }
213 let (tx, rx) = oneshot::channel();
214 {
215 let mut waiters = self.waiters.lock();
216 if self.backlog.load(Ordering::SeqCst) < threshold {
219 return;
220 }
221 waiters.push(CatchupWaiter {
222 threshold,
223 tx: Some(tx),
224 });
225 }
226 let _ = tokio::time::timeout(max, rx).await;
227 }
228
229 pub fn reset(&self) {
233 self.epoch.fetch_add(1, Ordering::SeqCst);
234 self.reset_advisor_context(true);
235 self.wake_all_waiters();
236 }
237
238 pub fn seed_to(&self, count: u64) {
242 self.epoch.fetch_add(1, Ordering::SeqCst);
243 self.last_count.store(count, Ordering::SeqCst);
244 let mut s = self.state.lock();
245 s.pending.clear();
246 self.backlog.store(0, Ordering::SeqCst);
253 self.consecutive_failures.store(0, Ordering::SeqCst);
254 self.failure_notified.store(false, Ordering::SeqCst);
255 drop(s);
256 self.wake_all_waiters();
257 }
258
259 pub fn dispose(&self) {
262 self.disposed.store(true, Ordering::SeqCst);
263 self.epoch.fetch_add(1, Ordering::SeqCst);
264 let mut s = self.state.lock();
265 s.pending.clear();
266 s.draining = false;
267 self.backlog.store(0, Ordering::SeqCst);
268 drop(s);
269 self.wake_all_waiters();
270 self.agent.abort("advisor disposed");
271 }
272
273 fn reset_advisor_context(&self, clear_backlog: bool) {
274 self.last_count.store(0, Ordering::SeqCst);
275 let mut s = self.state.lock();
276 s.pending.clear();
277 if clear_backlog {
278 self.backlog.store(0, Ordering::SeqCst);
279 }
280 self.consecutive_failures.store(0, Ordering::SeqCst);
281 self.failure_notified.store(false, Ordering::SeqCst);
282 drop(s);
283 self.agent.reset();
284 self.agent.abort("advisor reset");
285 }
286
287 fn render_delta(&self, messages: &[Message]) -> Option<String> {
290 let last = self.last_count.load(Ordering::SeqCst) as usize;
291 if messages.len() < last {
292 self.last_count
293 .store(messages.len() as u64, Ordering::SeqCst);
294 return None;
295 }
296 let delta = &messages[last..];
297 self.last_count
298 .store(messages.len() as u64, Ordering::SeqCst);
299 if delta.is_empty() {
300 return None;
301 }
302 let mut parts: Vec<String> = Vec::new();
303 for msg in delta {
304 if let Some(md) = format_message_md(msg) {
305 parts.push(md);
306 }
307 }
308 if parts.is_empty() {
309 return None;
310 }
311 Some(format!("### Session update\n\n{}", parts.join("\n\n")))
312 }
313
314 fn wake_all_waiters(&self) {
315 let mut waiters = self.waiters.lock();
316 for w in waiters.drain(..) {
317 if let Some(tx) = w.tx {
318 let _ = tx.send(());
319 }
320 }
321 }
322
323 fn notify_waiters(&self) {
324 let mut waiters = self.waiters.lock();
325 let backlog = self.backlog.load(Ordering::SeqCst);
326 for w in waiters.iter_mut() {
327 if backlog < w.threshold
328 && let Some(tx) = w.tx.take()
329 {
330 let _ = tx.send(());
331 }
332 }
333 waiters.retain(|w| w.tx.is_some());
334 }
335
336 fn decrement_backlog(&self, by: u64) {
337 let mut prev = self.backlog.load(Ordering::SeqCst);
338 loop {
339 let next = prev.saturating_sub(by);
340 match self
341 .backlog
342 .compare_exchange(prev, next, Ordering::SeqCst, Ordering::SeqCst)
343 {
344 Ok(_) => break,
345 Err(actual) => prev = actual,
346 }
347 }
348 }
349
350 async fn drain(self: Arc<Self>) {
354 {
355 let mut s = self.state.lock();
356 if s.draining || s.pending.is_empty() {
357 return;
358 }
359 s.draining = true;
360 }
361 loop {
362 let (batch_text, turns_covered) = {
364 let mut s = self.state.lock();
365 if s.pending.is_empty() {
366 s.draining = false;
370 return;
371 }
372 let taken: Vec<PendingDelta> = s.pending.drain(..).collect();
373 let turns: u64 = taken.iter().map(|d| d.turns).sum();
374 let joined = taken
375 .into_iter()
376 .map(|d| d.text)
377 .collect::<Vec<_>>()
378 .join("\n\n");
379 (joined, turns)
380 };
381
382 let epoch_start = self.epoch.load(Ordering::SeqCst);
383
384 let should_reprime = self.host.maintain_context(batch_text.len());
387 if self.epoch.load(Ordering::SeqCst) != epoch_start {
388 continue;
389 }
390
391 let (batch, final_turns) = if should_reprime {
392 self.reset_advisor_context(false);
395 let new_turns = self.state.lock().pending.len() as u64;
396 let rendered = self
397 .latest
398 .lock()
399 .as_ref()
400 .and_then(|m| self.render_delta(m));
401 let final_turns = turns_covered.saturating_add(new_turns);
402 match rendered {
403 Some(b) => (b, final_turns),
404 None => {
405 self.decrement_backlog(final_turns);
406 self.notify_waiters();
407 continue;
408 }
409 }
410 } else {
411 (batch_text, turns_covered)
412 };
413
414 if self.disposed.load(Ordering::SeqCst) {
415 self.decrement_backlog(final_turns);
416 self.notify_waiters();
417 continue;
418 }
419
420 let message_snapshot = self.agent.message_count();
421 self.host.begin_advisor_update();
422 let prompt_result = self.agent.prompt(batch.clone()).await;
423
424 if self.epoch.load(Ordering::SeqCst) != epoch_start {
427 continue;
428 }
429
430 let success;
431 match prompt_result {
432 Ok(()) => {
433 self.consecutive_failures.store(0, Ordering::SeqCst);
434 self.failure_notified.store(false, Ordering::SeqCst);
435 success = true;
436 }
437 Err(err) => {
438 self.agent.rollback_to(message_snapshot).await;
439 let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
440 if failures >= 3 {
441 tracing::warn!(
442 failures,
443 "advisor failed consecutively; dropping backlog to prevent stall"
444 );
445 if !self.failure_notified.swap(true, Ordering::SeqCst) {
446 self.host.notify_failure(&err);
447 }
448 self.consecutive_failures.store(0, Ordering::SeqCst);
449 success = true;
450 } else {
451 {
453 let mut s = self.state.lock();
454 s.pending.insert(
455 0,
456 PendingDelta {
457 text: batch,
458 turns: final_turns,
459 },
460 );
461 }
462 tokio::time::sleep(self.retry_delay).await;
463 continue;
464 }
465 }
466 }
467
468 if success {
469 self.decrement_backlog(final_turns);
470 self.notify_waiters();
471 }
472 }
473 }
474}
475
476fn format_message_md(msg: &Message) -> Option<String> {
480 let role = match msg {
481 Message::User(_) => "user",
482 Message::Assistant(_) => "assistant",
483 Message::ToolResult(_) => "tool",
484 };
485 let text = msg.text_content().unwrap_or_default();
486 if text.trim().is_empty() {
487 return None;
488 }
489 Some(format!("**[{role}]**\n{text}"))
490}
491
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::sync::Mutex as StdMutex;

    type PromptLog = Arc<StdMutex<Vec<String>>>;
    type AdviceLog = Arc<StdMutex<Vec<AdvisorNote>>>;

    /// Agent double that records every prompt it receives.
    struct FakeAgent {
        prompts: PromptLog,
        /// Number of upcoming prompts that should fail.
        fail_first_n: AtomicU32,
        /// Simulated length of the agent's message history.
        messages_len: AtomicU64,
    }

    impl FakeAgent {
        /// Returns the agent together with a handle to its prompt log.
        fn new() -> (Arc<Self>, PromptLog) {
            let prompts: PromptLog = Arc::new(StdMutex::new(Vec::new()));
            let agent = Self {
                prompts: Arc::clone(&prompts),
                fail_first_n: AtomicU32::new(0),
                messages_len: AtomicU64::new(0),
            };
            (Arc::new(agent), prompts)
        }
    }

    #[async_trait]
    impl AdvisorAgent for FakeAgent {
        async fn prompt(&self, input: String) -> Result<(), String> {
            // Each prompt grows the simulated history by four messages.
            self.messages_len.fetch_add(4, Ordering::SeqCst);
            self.prompts.lock().unwrap().push(input);
            if self.fail_first_n.load(Ordering::SeqCst) == 0 {
                return Ok(());
            }
            self.fail_first_n.fetch_sub(1, Ordering::SeqCst);
            Err("simulated advisor failure".into())
        }

        fn abort(&self, _reason: &str) {}

        fn reset(&self) {
            self.messages_len.store(0, Ordering::SeqCst);
        }

        async fn rollback_to(&self, count: usize) {
            self.messages_len.store(count as u64, Ordering::SeqCst);
        }

        fn message_count(&self) -> usize {
            self.messages_len.load(Ordering::SeqCst) as usize
        }
    }

    /// Host double that only records advice and never asks for a re-prime.
    struct FakeHost {
        advice: AdviceLog,
    }

    impl AdvisorRuntimeHost for FakeHost {
        fn snapshot_messages(&self) -> Vec<Message> {
            Vec::new()
        }

        fn enqueue_advice(&self, note: AdvisorNote) {
            self.advice.lock().unwrap().push(note);
        }
    }

    /// Wires a runtime (10 ms retry delay) to a fresh `FakeAgent` and the given host.
    fn build_with_host(host: Arc<dyn AdvisorRuntimeHost>) -> (Arc<AdvisorRuntime>, PromptLog) {
        let (agent, prompts) = FakeAgent::new();
        let rt = Arc::new(AdvisorRuntime::new(agent, host, Duration::from_millis(10)));
        rt.install_self(Arc::downgrade(&rt));
        (rt, prompts)
    }

    /// Runtime backed by `FakeAgent` and `FakeHost`.
    fn build() -> (Arc<AdvisorRuntime>, PromptLog, AdviceLog) {
        let advice: AdviceLog = Arc::new(StdMutex::new(Vec::new()));
        let host: Arc<dyn AdvisorRuntimeHost> = Arc::new(FakeHost {
            advice: Arc::clone(&advice),
        });
        let (rt, prompts) = build_with_host(host);
        (rt, prompts, advice)
    }

    fn user_msg(s: &str) -> Message {
        Message::user(s)
    }

    /// Gives the background drain task time to run.
    async fn pause(millis: u64) {
        tokio::time::sleep(Duration::from_millis(millis)).await;
    }

    #[tokio::test]
    async fn drain_prompts_advisor_with_delta() {
        let (rt, prompts, _advice) = build();
        rt.on_turn_end(vec![user_msg("turn 1")]);
        pause(50).await;

        let seen = prompts.lock().unwrap();
        assert_eq!(seen.len(), 1);
        assert!(seen[0].contains("turn 1"));
        assert!(seen[0].starts_with("### Session update"));
    }

    #[tokio::test]
    async fn reset_aborts_inflight_and_drops_batch() {
        let (rt, prompts, _advice) = build();
        rt.on_turn_end(vec![user_msg("turn 1")]);
        rt.reset();
        pause(50).await;

        assert_eq!(rt.backlog(), 0);
        // The number of prompts is deliberately not asserted here.
        let _ = prompts.lock().unwrap().len();
    }

    #[tokio::test]
    async fn drain_exit_racing_turn_end_no_lost_wakeup() {
        let (rt, _prompts, _advice) = build();

        let mut tasks = Vec::with_capacity(20);
        for i in 0..20 {
            let rt = Arc::clone(&rt);
            tasks.push(tokio::spawn(async move {
                rt.on_turn_end(vec![user_msg(&format!("turn {i}"))]);
            }));
        }
        for task in tasks {
            task.await.unwrap();
        }
        pause(120).await;

        assert_eq!(rt.backlog(), 0);
        let pending = rt.state.lock().pending.len();
        assert_eq!(pending, 0);
    }

    #[tokio::test]
    async fn wait_for_catchup_resolves_below_threshold() {
        let (rt, _prompts, _advice) = build();
        rt.on_turn_end(vec![user_msg("turn 1")]);
        rt.wait_for_catchup(Duration::from_millis(50), 0).await;

        // Poll until the backlog is empty, giving up after 200 ms.
        let wait_until_empty = async {
            while rt.backlog() > 0 {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        };
        let _ = tokio::time::timeout(Duration::from_millis(200), wait_until_empty).await;
        assert_eq!(rt.backlog(), 0);
    }

    #[tokio::test]
    async fn seed_to_skips_history() {
        let (rt, prompts, _advice) = build();
        rt.seed_to(5);
        rt.on_turn_end(vec![user_msg("a"), user_msg("b"), user_msg("c")]);
        pause(30).await;

        assert!(prompts.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn reprime_via_maintain_context() {
        /// Host double that requests a re-prime before every prompt.
        struct ReprimeHost {
            advice: Arc<StdMutex<Vec<AdvisorNote>>>,
        }

        impl AdvisorRuntimeHost for ReprimeHost {
            fn snapshot_messages(&self) -> Vec<Message> {
                Vec::new()
            }

            fn enqueue_advice(&self, n: AdvisorNote) {
                self.advice.lock().unwrap().push(n);
            }

            fn maintain_context(&self, _t: usize) -> bool {
                true
            }
        }

        let advice = Arc::new(StdMutex::new(Vec::new()));
        let host: Arc<dyn AdvisorRuntimeHost> = Arc::new(ReprimeHost {
            advice: Arc::clone(&advice),
        });
        let (rt, prompts) = build_with_host(host);

        rt.on_turn_end(vec![user_msg("turn 1"), user_msg("turn 2")]);
        pause(60).await;

        let seen = prompts.lock().unwrap();
        assert!(!seen.is_empty());
        assert!(seen[0].contains("turn 1") && seen[0].contains("turn 2"));
    }
}