use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Weak};
use std::time::Duration;
use async_trait::async_trait;
use oxi_ai::Message;
use parking_lot::Mutex;
use tokio::sync::oneshot;
use crate::advisor::types::AdvisorNote;
/// Interface through which the runtime drives the advisor model.
///
/// Only the calls made by `AdvisorRuntime` are described here; what an
/// implementation does internally is not visible from this file.
#[async_trait]
pub trait AdvisorAgent: Send + Sync + 'static {
/// Sends one rendered batch of session updates to the advisor.
/// An `Err` carries a textual failure description that the runtime may
/// forward to the host.
async fn prompt(&self, input: String) -> Result<(), String>;
/// Called by the runtime on reset ("advisor reset") and on dispose
/// ("advisor disposed"). Presumably cancels in-flight work — confirm
/// against the implementation.
fn abort(&self, reason: &str);
/// Called by the runtime whenever it resets the advisor context.
fn reset(&self);
/// Called after a failed prompt with the value `message_count` returned
/// just before that prompt was issued.
async fn rollback_to(&self, count: usize);
/// Snapshot taken by the runtime immediately before each prompt.
fn message_count(&self) -> usize;
}
/// Callbacks the runtime makes into its owner.
pub trait AdvisorRuntimeHost: Send + Sync + 'static {
/// Not called by the runtime code visible in this file.
fn snapshot_messages(&self) -> Vec<Message>;
/// Not called by the runtime code visible in this file.
fn enqueue_advice(&self, note: AdvisorNote);
/// Consulted before every prompt. The runtime passes the batch text's
/// length in bytes (`String::len`), despite the parameter name.
/// Returning `true` makes the runtime reset the advisor context and
/// re-render the session from its latest message list. Defaults to `false`.
fn maintain_context(&self, _incoming_tokens: usize) -> bool {
false
}
/// Invoked right before each prompt is sent to the agent. No-op by default.
fn begin_advisor_update(&self) {}
/// Invoked when the runtime drops a batch after three consecutive prompt
/// failures, at most once until a prompt succeeds or the runtime is
/// reset. No-op by default.
fn notify_failure(&self, _error: &str) {}
}
/// One rendered session update waiting to be sent to the advisor.
struct PendingDelta {
// Markdown text produced by `render_delta` (or a re-queued batch after a failed prompt).
text: String,
// Number of backlog units this entry accounts for: 1 for a fresh turn,
// the batch's total when a failed batch is re-queued.
turns: u64,
}
/// Queue state shared between `on_turn_end` and the background drain task.
#[derive(Default)]
struct DrainState {
// Deltas queued but not yet taken by the drain loop.
pending: Vec<PendingDelta>,
// True while a drain task owns the queue; prevents a second drain from starting.
draining: bool,
}
/// A `wait_for_catchup` caller parked until the backlog drops below its threshold.
struct CatchupWaiter {
// The waiter is released once `backlog < threshold`.
threshold: u64,
// Taken (set to `None`) when the waiter is signalled.
tx: Option<oneshot::Sender<()>>,
}
/// Feeds session updates to an advisor agent in the background and tracks
/// how far the advisor lags behind the session.
pub struct AdvisorRuntime {
// Advisor being prompted.
agent: Arc<dyn AdvisorAgent>,
// Owner callbacks (context maintenance, failure notification, ...).
host: Arc<dyn AdvisorRuntimeHost>,
// Pending deltas plus the single-drainer flag.
state: Mutex<DrainState>,
// Bumped by `reset`, `seed_to` and `dispose`; the drain loop compares it
// before and after slow steps to detect that its batch was invalidated.
epoch: AtomicU64,
// Turns queued but not yet settled by the drain loop.
backlog: AtomicU64,
// Number of messages already rendered into deltas (see `render_delta`).
last_count: AtomicU64,
// Most recent message list handed to `on_turn_end`; re-rendered on reprime.
latest: Mutex<Option<Vec<Message>>>,
// Parked `wait_for_catchup` callers.
waiters: Mutex<Vec<CatchupWaiter>>,
// Consecutive prompt failures; at 3 the current batch is dropped.
consecutive_failures: AtomicU32,
// Ensures `host.notify_failure` fires once per failure streak.
failure_notified: AtomicBool,
// Set by `dispose`; makes `on_turn_end` a no-op.
disposed: AtomicBool,
// Pause between a failed prompt and its retry.
retry_delay: Duration,
// Weak back-reference installed via `install_self`, upgraded to spawn the drain task.
self_ref: Mutex<Option<Weak<AdvisorRuntime>>>,
}
impl AdvisorRuntime {
/// Creates an idle runtime around `agent` and `host`.
///
/// `retry_delay` is the pause between a failed prompt and its retry.
/// Every counter starts at zero, every flag at `false`, and no
/// self-reference is installed; call [`AdvisorRuntime::install_self`]
/// afterwards so turns can spawn the background drain.
#[must_use]
pub fn new(
    agent: Arc<dyn AdvisorAgent>,
    host: Arc<dyn AdvisorRuntimeHost>,
    retry_delay: Duration,
) -> Self {
    Self {
        // Caller-supplied collaborators and configuration.
        agent,
        host,
        retry_delay,
        // Lock-protected state: empty queue, no snapshot, no waiters, no self-reference.
        state: Mutex::default(),
        latest: Mutex::default(),
        waiters: Mutex::default(),
        self_ref: Mutex::default(),
        // Counters start at zero, flags at false.
        epoch: AtomicU64::default(),
        backlog: AtomicU64::default(),
        last_count: AtomicU64::default(),
        consecutive_failures: AtomicU32::default(),
        failure_notified: AtomicBool::default(),
        disposed: AtomicBool::default(),
    }
}
/// Stores the weak back-reference that `on_turn_end` upgrades in order to
/// spawn the drain task. Replaces any previously installed reference.
pub fn install_self(&self, weak: Weak<AdvisorRuntime>) {
    let mut slot = self.self_ref.lock();
    *slot = Some(weak);
}
/// Current number of turns queued for the advisor but not yet settled.
#[must_use]
pub fn backlog(&self) -> u64 {
self.backlog.load(Ordering::SeqCst)
}
/// Whether `dispose` has been called on this runtime.
#[must_use]
pub fn is_disposed(&self) -> bool {
self.disposed.load(Ordering::SeqCst)
}
/// Records the end of a session turn and queues the new messages for the advisor.
///
/// `messages` is the full message list of the session so far. The part not
/// yet rendered (per `last_count`) becomes one pending delta worth one
/// backlog unit. If no drain task is running, one is spawned on the current
/// tokio runtime.
///
/// Does nothing once the runtime is disposed. If there is nothing new to
/// render, the list is still remembered as the latest snapshot but nothing
/// is queued.
pub fn on_turn_end(&self, messages: Vec<Message>) {
    if self.disposed.load(Ordering::SeqCst) {
        return;
    }
    // Render and store under the `latest` lock, then move the vector in.
    // This avoids deep-cloning the whole history every turn, and keeps the
    // snapshot and `last_count` in step with the reprime render in `drain`,
    // which renders under the same lock.
    let rendered = {
        let mut latest = self.latest.lock();
        let rendered = self.render_delta(&messages);
        *latest = Some(messages);
        rendered
    };
    let Some(text) = rendered else {
        return;
    };
    let spawn = {
        let mut s = self.state.lock();
        s.pending.push(PendingDelta { text, turns: 1 });
        self.backlog.fetch_add(1, Ordering::SeqCst);
        // Read under the same lock the drain loop uses to clear the flag,
        // so a push is never missed by an exiting drain.
        !s.draining
    };
    self.notify_waiters();
    if !spawn {
        return;
    }
    // Resolve the handle in its own statement so the `self_ref` guard is
    // released before the task is spawned.
    let drain_handle = self.self_ref.lock().as_ref().and_then(Weak::upgrade);
    if let Some(this) = drain_handle {
        tokio::spawn(async move {
            this.drain().await;
        });
    }
}
/// Waits until the backlog is strictly below `threshold`, for at most `max`.
///
/// Returns immediately if the runtime is disposed or the backlog is already
/// below `threshold`. Also returns when the runtime is reset, seeded or
/// disposed, since those wake every waiter. A `threshold` of 0 can never be
/// satisfied by the backlog itself, so such a call ends only on timeout or
/// a wake-all.
pub async fn wait_for_catchup(&self, max: Duration, threshold: u64) {
    let satisfied = || {
        self.disposed.load(Ordering::SeqCst) || self.backlog.load(Ordering::SeqCst) < threshold
    };
    if satisfied() {
        return;
    }
    let (tx, rx) = oneshot::channel();
    {
        let mut waiters = self.waiters.lock();
        // Re-check under the waiters lock: notifications take this lock, so
        // a change after this point is guaranteed to see the new entry.
        if satisfied() {
            return;
        }
        waiters.push(CatchupWaiter {
            threshold,
            tx: Some(tx),
        });
    }
    let timed_out = tokio::time::timeout(max, rx).await.is_err();
    if timed_out {
        // The receiver is gone now, so this waiter's sender (and any other
        // abandoned one) is closed. Remove them, otherwise entries whose
        // threshold is never reached pile up indefinitely.
        self.waiters
            .lock()
            .retain(|w| w.tx.as_ref().is_some_and(|tx| !tx.is_closed()));
    }
}
/// Discards all queued work and resets the advisor.
///
/// The epoch is bumped first so an in-flight drain iteration notices that
/// its batch is stale. Then the queue, the backlog and the failure counters
/// are cleared and the agent is reset, and finally every waiter is woken.
pub fn reset(&self) {
self.epoch.fetch_add(1, Ordering::SeqCst);
self.reset_advisor_context(true);
self.wake_all_waiters();
}
/// Marks the first `count` messages as already seen by the advisor.
///
/// Invalidates any in-flight batch (epoch bump), drops queued deltas,
/// zeroes the backlog and the failure bookkeeping, and wakes all waiters.
/// The agent itself is not touched.
pub fn seed_to(&self, count: u64) {
    self.epoch.fetch_add(1, Ordering::SeqCst);
    self.last_count.store(count, Ordering::SeqCst);
    {
        // The counters are reset while the queue lock is held.
        let mut queue = self.state.lock();
        queue.pending.clear();
        self.backlog.store(0, Ordering::SeqCst);
        self.consecutive_failures.store(0, Ordering::SeqCst);
        self.failure_notified.store(false, Ordering::SeqCst);
    }
    self.wake_all_waiters();
}
/// Shuts the runtime down.
///
/// Sets the disposed flag, invalidates any in-flight batch, empties the
/// queue, clears the draining flag and the backlog, wakes all waiters and
/// finally aborts the agent with the reason "advisor disposed".
pub fn dispose(&self) {
    self.disposed.store(true, Ordering::SeqCst);
    self.epoch.fetch_add(1, Ordering::SeqCst);
    {
        let mut queue = self.state.lock();
        queue.pending.clear();
        queue.draining = false;
        self.backlog.store(0, Ordering::SeqCst);
    }
    self.wake_all_waiters();
    self.agent.abort("advisor disposed");
}
/// Forgets everything sent to the advisor so far and resets the agent.
///
/// Rewinds `last_count` to 0, drops queued deltas and clears the failure
/// bookkeeping. The backlog counter is zeroed only when `clear_backlog` is
/// true. Afterwards the agent is reset and then aborted with the reason
/// "advisor reset".
fn reset_advisor_context(&self, clear_backlog: bool) {
    self.last_count.store(0, Ordering::SeqCst);
    {
        let mut queue = self.state.lock();
        queue.pending.clear();
        if clear_backlog {
            self.backlog.store(0, Ordering::SeqCst);
        }
        self.consecutive_failures.store(0, Ordering::SeqCst);
        self.failure_notified.store(false, Ordering::SeqCst);
    }
    // The queue lock is released before calling into the agent.
    self.agent.reset();
    self.agent.abort("advisor reset");
}
fn render_delta(&self, messages: &[Message]) -> Option<String> {
let last = self.last_count.load(Ordering::SeqCst) as usize;
if messages.len() < last {
self.last_count
.store(messages.len() as u64, Ordering::SeqCst);
return None;
}
let delta = &messages[last..];
self.last_count
.store(messages.len() as u64, Ordering::SeqCst);
if delta.is_empty() {
return None;
}
let mut parts: Vec<String> = Vec::new();
for msg in delta {
if let Some(md) = format_message_md(msg) {
parts.push(md);
}
}
if parts.is_empty() {
return None;
}
Some(format!("### Session update\n\n{}", parts.join("\n\n")))
}
/// Releases every parked `wait_for_catchup` caller, regardless of threshold,
/// and empties the waiter list.
fn wake_all_waiters(&self) {
    let mut registered = self.waiters.lock();
    registered
        .drain(..)
        .filter_map(|waiter| waiter.tx)
        .for_each(|tx| {
            // A send error only means the waiter already gave up.
            let _ = tx.send(());
        });
}
fn notify_waiters(&self) {
let mut waiters = self.waiters.lock();
let backlog = self.backlog.load(Ordering::SeqCst);
for w in waiters.iter_mut() {
if backlog < w.threshold
&& let Some(tx) = w.tx.take()
{
let _ = tx.send(());
}
}
waiters.retain(|w| w.tx.is_some());
}
/// Atomically subtracts `by` from the backlog, saturating at zero.
fn decrement_backlog(&self, by: u64) {
    // The closure always returns `Some`, so the update retries until it
    // succeeds and the returned `Result` carries no information.
    let _ = self
        .backlog
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
            Some(current.saturating_sub(by))
        });
}
/// Background loop that sends queued deltas to the advisor, one batch at a time.
///
/// Only one drain runs at a time (guarded by `DrainState::draining`). Each
/// iteration takes everything queued, asks the host whether the advisor
/// context must be rebuilt, prompts the agent, and then settles the backlog:
///
/// * success: the batch's turns are subtracted from the backlog;
/// * failure: the agent is rolled back and the batch is re-queued after
///   `retry_delay`; on the third consecutive failure the batch is dropped
///   and the host is notified once per failure streak;
/// * epoch change (reset / seed / dispose) during a slow step: the batch is
///   abandoned without touching the backlog, which the reset already zeroed.
///
/// The loop exits when the queue is empty.
async fn drain(self: Arc<Self>) {
    // Claim the single-drainer flag; bail if it is taken or nothing is queued.
    {
        let mut s = self.state.lock();
        if s.draining || s.pending.is_empty() {
            return;
        }
        s.draining = true;
    }
    loop {
        // Take the whole queue and capture the epoch under the same lock, so
        // a reset landing right after the take shows up as an epoch change
        // instead of letting a stale batch through.
        let (taken, epoch_start) = {
            let mut s = self.state.lock();
            if s.pending.is_empty() {
                // `on_turn_end` pushes under this lock, so a push either
                // landed before this check or will see `draining == false`.
                s.draining = false;
                return;
            }
            let taken: Vec<PendingDelta> = s.pending.drain(..).collect();
            (taken, self.epoch.load(Ordering::SeqCst))
        };
        // Joining happens outside the lock to keep the critical section short.
        let turns_covered: u64 = taken.iter().map(|d| d.turns).sum();
        let batch_text = taken
            .into_iter()
            .map(|d| d.text)
            .collect::<Vec<_>>()
            .join("\n\n");

        // Note: the host receives the batch length in bytes.
        let should_reprime = self.host.maintain_context(batch_text.len());
        if self.epoch.load(Ordering::SeqCst) != epoch_start {
            continue;
        }

        let (batch, final_turns) = if should_reprime {
            // The context reset below discards whatever was queued meanwhile,
            // and the full re-render supersedes it. Count those turns before
            // they are dropped so they are settled with this batch; reading
            // the queue after the reset would miss them and leave the
            // backlog inflated.
            let superseded: u64 = {
                let mut s = self.state.lock();
                s.pending.drain(..).map(|d| d.turns).sum()
            };
            self.reset_advisor_context(false);
            let rendered = self
                .latest
                .lock()
                .as_ref()
                .and_then(|m| self.render_delta(m));
            let final_turns = turns_covered.saturating_add(superseded);
            match rendered {
                Some(b) => (b, final_turns),
                None => {
                    // Nothing to send after the reset; settle the turns anyway.
                    self.decrement_backlog(final_turns);
                    self.notify_waiters();
                    continue;
                }
            }
        } else {
            (batch_text, turns_covered)
        };

        if self.disposed.load(Ordering::SeqCst) {
            self.decrement_backlog(final_turns);
            self.notify_waiters();
            continue;
        }

        let message_snapshot = self.agent.message_count();
        self.host.begin_advisor_update();
        let prompt_result = self.agent.prompt(batch.clone()).await;
        if self.epoch.load(Ordering::SeqCst) != epoch_start {
            continue;
        }

        if let Err(err) = prompt_result {
            self.agent.rollback_to(message_snapshot).await;
            // A reset during the rollback makes this batch stale. Re-queueing
            // it would resurrect turns the reset already wiped.
            if self.epoch.load(Ordering::SeqCst) != epoch_start {
                continue;
            }
            let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
            if failures < 3 {
                // Put the batch back at the front so it is retried first.
                {
                    let mut s = self.state.lock();
                    s.pending.insert(
                        0,
                        PendingDelta {
                            text: batch,
                            turns: final_turns,
                        },
                    );
                }
                tokio::time::sleep(self.retry_delay).await;
                continue;
            }
            tracing::warn!(
                failures,
                "advisor failed consecutively; dropping backlog to prevent stall"
            );
            if !self.failure_notified.swap(true, Ordering::SeqCst) {
                self.host.notify_failure(&err);
            }
            self.consecutive_failures.store(0, Ordering::SeqCst);
        } else {
            self.consecutive_failures.store(0, Ordering::SeqCst);
            self.failure_notified.store(false, Ordering::SeqCst);
        }

        // Reached on success and when a batch is dropped after repeated failures.
        self.decrement_backlog(final_turns);
        self.notify_waiters();
    }
}
}
/// Formats one message as a markdown snippet: a bold role tag
/// (`user`, `assistant` or `tool`) on the first line, then the text.
///
/// Returns `None` when the message has no text or only whitespace.
fn format_message_md(msg: &Message) -> Option<String> {
    let text = msg.text_content().unwrap_or_default();
    let has_content = !text.trim().is_empty();
    has_content.then(|| {
        let role = match msg {
            Message::User(_) => "user",
            Message::Assistant(_) => "assistant",
            Message::ToolResult(_) => "tool",
        };
        format!("**[{role}]**\n{text}")
    })
}
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use super::*;
use std::sync::Mutex as StdMutex;
type PromptLog = Arc<StdMutex<Vec<String>>>;
type AdviceLog = Arc<StdMutex<Vec<AdvisorNote>>>;
/// Test double for `AdvisorAgent` that records every prompt it receives.
struct FakeAgent {
// Shared log of prompt inputs, also handed to the test.
prompts: PromptLog,
// Number of upcoming prompts that should fail.
fail_first_n: AtomicU32,
// Simulated message count; grows by 4 per prompt.
messages_len: AtomicU64,
}
impl FakeAgent {
fn new() -> (Arc<Self>, PromptLog) {
let prompts = Arc::new(StdMutex::new(Vec::new()));
let a = Arc::new(Self {
prompts: Arc::clone(&prompts),
fail_first_n: AtomicU32::new(0),
messages_len: AtomicU64::new(0),
});
(a, prompts)
}
}
#[async_trait]
impl AdvisorAgent for FakeAgent {
    /// Logs the input, bumps the simulated message count by 4, and fails
    /// while `fail_first_n` is non-zero (consuming one failure per call).
    async fn prompt(&self, input: String) -> Result<(), String> {
        self.messages_len.fetch_add(4, Ordering::SeqCst);
        self.prompts.lock().unwrap().push(input);
        let remaining = self.fail_first_n.load(Ordering::SeqCst);
        if remaining == 0 {
            return Ok(());
        }
        self.fail_first_n.fetch_sub(1, Ordering::SeqCst);
        Err(String::from("simulated advisor failure"))
    }

    /// Aborting is a no-op for the fake.
    fn abort(&self, _reason: &str) {}

    /// Resets the simulated message count to zero.
    fn reset(&self) {
        self.messages_len.store(0, Ordering::SeqCst);
    }

    /// Sets the simulated message count to `count`.
    async fn rollback_to(&self, count: usize) {
        let restored = count as u64;
        self.messages_len.store(restored, Ordering::SeqCst);
    }

    /// Returns the simulated message count.
    fn message_count(&self) -> usize {
        let current = self.messages_len.load(Ordering::SeqCst);
        current as usize
    }
}
/// Test double for `AdvisorRuntimeHost` that collects enqueued advice.
struct FakeHost {
// Shared log of advice notes, also handed to the test.
advice: AdviceLog,
}
// Uses the trait defaults for everything else, so `maintain_context`
// returns `false` and no reprime is ever requested.
impl AdvisorRuntimeHost for FakeHost {
// Always reports an empty session.
fn snapshot_messages(&self) -> Vec<Message> {
Vec::new()
}
// Records the note so a test can inspect it.
fn enqueue_advice(&self, note: AdvisorNote) {
self.advice.lock().unwrap().push(note);
}
}
/// Assembles a runtime wired to the fakes with a 10 ms retry delay and its
/// self-reference installed. Returns the runtime and both log handles.
fn build() -> (Arc<AdvisorRuntime>, PromptLog, AdviceLog) {
    let (agent, prompts) = FakeAgent::new();
    let advice: AdviceLog = Arc::default();
    let host: Arc<dyn AdvisorRuntimeHost> = Arc::new(FakeHost {
        advice: Arc::clone(&advice),
    });
    let retry_delay = Duration::from_millis(10);
    let runtime = Arc::new(AdvisorRuntime::new(agent, host, retry_delay));
    runtime.install_self(Arc::downgrade(&runtime));
    (runtime, prompts, advice)
}
// Shorthand for building a user message with the given text.
fn user_msg(s: &str) -> Message {
Message::user(s)
}
// A single turn should reach the agent as exactly one prompt, rendered
// with the "### Session update" header.
#[tokio::test]
async fn drain_prompts_advisor_with_delta() {
let (rt, prompts, _advice) = build();
rt.on_turn_end(vec![user_msg("turn 1")]);
// Give the spawned drain task time to run.
tokio::time::sleep(Duration::from_millis(50)).await;
let p = prompts.lock().unwrap();
assert_eq!(p.len(), 1);
assert!(p[0].contains("turn 1"));
assert!(p[0].starts_with("### Session update"));
}
// Resetting right after a turn must leave the backlog at zero.
// Only the backlog is asserted; whether the prompt was sent before the
// reset is timing-dependent, so the prompt count is read and discarded.
#[tokio::test]
async fn reset_aborts_inflight_and_drops_batch() {
let (rt, prompts, _advice) = build();
rt.on_turn_end(vec![user_msg("turn 1")]);
rt.reset(); tokio::time::sleep(Duration::from_millis(50)).await;
assert_eq!(rt.backlog(), 0);
let _ = prompts.lock().unwrap().len();
}
// Fires 20 turn-end notifications from separate tasks and checks that the
// queue and the backlog both end up empty, i.e. no queued delta is left
// behind by a drain task that was exiting while a push arrived.
#[tokio::test]
async fn drain_exit_racing_turn_end_no_lost_wakeup() {
let (rt, _prompts, _advice) = build();
let rt2 = Arc::clone(&rt);
let handles: Vec<_> = (0..20)
.map(move |i| {
let rt3 = Arc::clone(&rt2);
tokio::spawn(async move {
// Each call passes a one-message list.
rt3.on_turn_end(vec![user_msg(&format!("turn {i}"))]);
})
})
.collect();
for h in handles {
h.await.unwrap();
}
// Allow the background drain to finish.
tokio::time::sleep(Duration::from_millis(120)).await;
assert_eq!(rt.backlog(), 0);
let pending = rt.state.lock().pending.len();
assert_eq!(pending, 0);
}
// Checks that `wait_for_catchup` returns and that the backlog drains to zero.
// With a threshold of 0 the `backlog < threshold` condition can never hold
// for a `u64`, so the call below returns through its 50 ms timeout.
#[tokio::test]
async fn wait_for_catchup_resolves_below_threshold() {
let (rt, _prompts, _advice) = build();
rt.on_turn_end(vec![user_msg("turn 1")]);
rt.wait_for_catchup(Duration::from_millis(50), 0).await;
// Poll for up to 200 ms; the timeout result is ignored and the final
// assertion decides the outcome.
let _ = tokio::time::timeout(Duration::from_millis(200), async {
while rt.backlog() > 0 {
tokio::time::sleep(Duration::from_millis(5)).await;
}
})
.await;
assert_eq!(rt.backlog(), 0);
}
// After seeding to 5, a 3-message list is shorter than the seeded count,
// so nothing is rendered and the agent receives no prompt.
#[tokio::test]
async fn seed_to_skips_history() {
let (rt, prompts, _advice) = build();
rt.seed_to(5); rt.on_turn_end(vec![user_msg("a"), user_msg("b"), user_msg("c")]);
tokio::time::sleep(Duration::from_millis(30)).await;
assert!(prompts.lock().unwrap().is_empty());
}
// When the host always requests a reprime, the first prompt should still
// contain the complete session (both turns), re-rendered from the latest
// message list after the advisor context is reset.
#[tokio::test]
async fn reprime_via_maintain_context() {
// Host whose `maintain_context` always asks for a reprime.
struct ReprimeHost {
advice: Arc<StdMutex<Vec<AdvisorNote>>>,
}
impl AdvisorRuntimeHost for ReprimeHost {
fn snapshot_messages(&self) -> Vec<Message> {
Vec::new()
}
fn enqueue_advice(&self, n: AdvisorNote) {
self.advice.lock().unwrap().push(n);
}
fn maintain_context(&self, _t: usize) -> bool {
true
}
}
// Same wiring as `build()`, but with the reprime host.
let (agent, prompts) = FakeAgent::new();
let advice = Arc::new(StdMutex::new(Vec::new()));
let host: Arc<dyn AdvisorRuntimeHost> = Arc::new(ReprimeHost {
advice: Arc::clone(&advice),
});
let rt = Arc::new(AdvisorRuntime::new(agent, host, Duration::from_millis(10)));
rt.install_self(Arc::downgrade(&rt));
rt.on_turn_end(vec![user_msg("turn 1"), user_msg("turn 2")]);
tokio::time::sleep(Duration::from_millis(60)).await;
let p = prompts.lock().unwrap();
assert!(!p.is_empty());
assert!(p[0].contains("turn 1") && p[0].contains("turn 2"));
}
}