use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use parking_lot::Mutex;
#[cfg(not(target_arch = "wasm32"))]
use tokio::task::JoinHandle;
use tracing::warn;
use crate::connections::Connection;
use crate::error::{Error, Result};
use crate::runtime::MaybeSendSync;
use crate::types::TriggerDelivery;
/// Handle passed to [`Trigger::run`] for pushing trigger content into the conversation.
///
/// Cloning is cheap: the only state is a shared `Arc` to the underlying [`Connection`].
#[derive(Clone)]
pub struct TriggerContext {
    connection: Arc<dyn Connection>,
}

impl TriggerContext {
    /// Wraps `connection` so triggers can send through it.
    pub fn new(connection: Arc<dyn Connection>) -> Self {
        TriggerContext { connection }
    }

    /// Sends `content` as a trigger message without waiting for the connection to go idle.
    ///
    /// # Errors
    /// Returns whatever error `Connection::send_trigger` reports.
    pub async fn send(&self, content: impl Into<String>) -> Result<()> {
        let message: String = content.into();
        self.connection.send_trigger(message).await
    }

    /// Waits for `Connection::wait_for_idle` to resolve, then sends `content`.
    ///
    /// NOTE(review): nothing here prevents the connection from becoming busy again between
    /// the idle wait and the send.
    ///
    /// # Errors
    /// Returns the first error from waiting or from sending; nothing is sent if the wait fails.
    pub async fn send_when_idle(&self, content: impl Into<String>) -> Result<()> {
        self.connection.wait_for_idle().await?;
        let message: String = content.into();
        self.connection.send_trigger(message).await
    }

    /// Reports the connection's current idle state without waiting.
    pub fn is_idle(&self) -> bool {
        let conn = &self.connection;
        conn.is_idle()
    }
}
/// A long-running task that injects content into a conversation through a [`TriggerContext`].
///
/// The runner spawns each trigger on its own task. An error returned from [`Trigger::run`]
/// is logged by the runner and does not affect the other triggers.
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
pub trait Trigger: MaybeSendSync {
    /// Drives the trigger until it finishes, fails, or is cancelled by the runner.
    async fn run(&self, ctx: TriggerContext) -> Result<()>;

    /// Human-readable identifier, used when the runner logs a failure.
    fn name(&self) -> &str;

    /// Preferred delivery mode for this trigger; defaults to `TriggerDelivery::WaitIdle`.
    ///
    /// NOTE(review): `TriggerRunner` in this module never reads this value, so an
    /// implementation presumably has to honour it itself (e.g. by calling
    /// [`TriggerContext::send_when_idle`]) — confirm against other callers.
    fn delivery(&self) -> TriggerDelivery {
        TriggerDelivery::WaitIdle
    }
}
pub struct TriggerRunner {
triggers: Vec<Arc<dyn Trigger>>,
connection: Arc<dyn Connection>,
#[cfg(not(target_arch = "wasm32"))]
tasks: Mutex<Option<Vec<JoinHandle<()>>>>,
#[cfg(target_arch = "wasm32")]
started: Mutex<bool>,
}
impl TriggerRunner {
pub fn new(triggers: Vec<Arc<dyn Trigger>>, connection: Arc<dyn Connection>) -> Self {
Self {
triggers,
connection,
#[cfg(not(target_arch = "wasm32"))]
tasks: Mutex::new(None),
#[cfg(target_arch = "wasm32")]
started: Mutex::new(false),
}
}
#[cfg(not(target_arch = "wasm32"))]
pub fn start(&self) -> Result<()> {
let mut guard = self.tasks.lock();
if guard.is_some() {
return Err(Error::AlreadyStarted);
}
let mut handles = Vec::with_capacity(self.triggers.len());
for trig in &self.triggers {
let ctx = TriggerContext::new(self.connection.clone());
let trig = trig.clone();
handles.push(tokio::spawn(async move {
let name = trig.name().to_string();
if let Err(e) = trig.run(ctx).await {
warn!(%name, error = %e, "trigger exited with error");
}
}));
}
*guard = Some(handles);
Ok(())
}
#[cfg(target_arch = "wasm32")]
pub fn start(&self) -> Result<()> {
let mut guard = self.started.lock();
if *guard {
return Err(Error::AlreadyStarted);
}
for trig in &self.triggers {
let ctx = TriggerContext::new(self.connection.clone());
let trig = trig.clone();
crate::runtime::spawn(async move {
let name = trig.name().to_string();
if let Err(e) = trig.run(ctx).await {
warn!(%name, error = %e, "trigger exited with error");
}
});
}
*guard = true;
Ok(())
}
#[cfg(not(target_arch = "wasm32"))]
pub async fn stop(&self) {
let handles = self.tasks.lock().take();
if let Some(handles) = handles {
for h in &handles {
h.abort();
}
for h in handles {
let _ = h.await;
}
}
}
#[cfg(target_arch = "wasm32")]
pub async fn stop(&self) {
*self.started.lock() = false;
}
}
#[cfg(not(target_arch = "wasm32"))]
impl Drop for TriggerRunner {
fn drop(&mut self) {
if let Some(handles) = self.tasks.lock().take() {
for h in handles {
h.abort();
}
}
}
}
/// Builds a [`Trigger`] named `name` that calls `handler` once per `period`.
///
/// The first call happens one full period after the trigger starts running. Errors
/// returned by `handler` are logged and do not stop the schedule.
#[cfg(not(target_arch = "wasm32"))]
pub fn every<F, Fut>(period: Duration, name: impl Into<String>, handler: F) -> Arc<dyn Trigger>
where
    F: Fn(TriggerContext) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<()>> + Send + 'static,
{
    let name: String = name.into();
    // The closure sits directly in the field initialiser so its signature is inferred from
    // the boxed-handler field type (which erases `Fut` into a boxed future).
    let periodic = PeriodicTrigger {
        name,
        period,
        handler: Arc::new(move |ctx| Box::pin(handler(ctx))),
    };
    Arc::new(periodic)
}
/// Builds a [`Trigger`] named `name` that calls `handler` once per `period` (wasm32 variant,
/// which does not require the handler or its future to be `Send`).
///
/// The first call happens one full period after the trigger starts running. Errors
/// returned by `handler` are logged and do not stop the schedule.
#[cfg(target_arch = "wasm32")]
pub fn every<F, Fut>(period: Duration, name: impl Into<String>, handler: F) -> Arc<dyn Trigger>
where
    F: Fn(TriggerContext) -> Fut + 'static,
    Fut: std::future::Future<Output = Result<()>> + 'static,
{
    let name: String = name.into();
    // The closure sits directly in the field initialiser so its signature is inferred from
    // the boxed-handler field type (which erases `Fut` into a local boxed future).
    let periodic = PeriodicTrigger {
        name,
        period,
        handler: Arc::new(move |ctx| Box::pin(handler(ctx))),
    };
    Arc::new(periodic)
}
/// Type-erased callback invoked on every tick of a native [`PeriodicTrigger`].
#[cfg(not(target_arch = "wasm32"))]
type PeriodicHandler =
    dyn Fn(TriggerContext) -> futures_util::future::BoxFuture<'static, Result<()>> + Send + Sync;

/// Trigger built by `every`: invokes `handler` once per `period`.
#[cfg(not(target_arch = "wasm32"))]
struct PeriodicTrigger {
    /// Name reported through `Trigger::name`.
    name: String,
    /// Interval between handler invocations.
    period: Duration,
    /// Callback run on each tick; its errors are logged rather than ending the trigger.
    handler: Arc<PeriodicHandler>,
}
/// Type-erased callback invoked on every tick of a wasm32 [`PeriodicTrigger`]
/// (no `Send`/`Sync` requirement on this target).
#[cfg(target_arch = "wasm32")]
type PeriodicHandler =
    dyn Fn(TriggerContext) -> futures_util::future::LocalBoxFuture<'static, Result<()>>;

/// Trigger built by `every`: invokes `handler` once per `period`.
#[cfg(target_arch = "wasm32")]
struct PeriodicTrigger {
    /// Name reported through `Trigger::name`.
    name: String,
    /// Interval between handler invocations.
    period: Duration,
    /// Callback run on each tick; its errors are logged rather than ending the trigger.
    handler: Arc<PeriodicHandler>,
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Trigger for PeriodicTrigger {
fn name(&self) -> &str {
&self.name
}
#[cfg(not(target_arch = "wasm32"))]
async fn run(&self, ctx: TriggerContext) -> Result<()> {
let mut ticker = tokio::time::interval(self.period);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
ticker.tick().await;
loop {
ticker.tick().await;
if let Err(e) = (self.handler)(ctx.clone()).await {
warn!(name = %self.name, error = %e, "periodic trigger handler errored");
}
}
}
#[cfg(target_arch = "wasm32")]
async fn run(&self, ctx: TriggerContext) -> Result<()> {
let ms = u32::try_from(self.period.as_millis()).unwrap_or(u32::MAX);
loop {
crate::runtime::sleep_ms(ms).await;
if let Err(e) = (self.handler)(ctx.clone()).await {
warn!(name = %self.name, error = %e, "periodic trigger handler errored");
}
}
}
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    use crate::content::Content;
    use crate::types::ToolResult;
    use tokio::sync::Notify;

    /// In-memory `Connection` that counts trigger sends and lets a test flip it to idle.
    struct MockConn {
        trigger_count: AtomicU32,
        idle: AtomicBool,
        idle_notify: Notify,
    }

    impl MockConn {
        /// Creates a connection that starts busy (not idle) with zero triggers sent.
        fn new() -> Arc<Self> {
            Arc::new(Self {
                trigger_count: AtomicU32::new(0),
                idle: AtomicBool::new(false),
                idle_notify: Notify::new(),
            })
        }

        /// Marks the connection idle and wakes every task blocked in `wait_for_idle`.
        fn go_idle(&self) {
            self.idle.store(true, Ordering::Release);
            self.idle_notify.notify_waiters();
        }

        /// Number of `send_trigger` calls observed so far.
        fn triggers_sent(&self) -> u32 {
            self.trigger_count.load(Ordering::Acquire)
        }
    }

    #[async_trait]
    impl Connection for MockConn {
        fn is_idle(&self) -> bool {
            self.idle.load(Ordering::Acquire)
        }

        fn conversation_id(&self) -> &str {
            "mock"
        }

        async fn send(&self, _content: Content) -> Result<()> {
            Ok(())
        }

        async fn send_trigger(&self, _content: String) -> Result<()> {
            self.trigger_count.fetch_add(1, Ordering::AcqRel);
            Ok(())
        }

        async fn send_tool_results(&self, _results: Vec<ToolResult>) -> Result<()> {
            Ok(())
        }

        fn subscribe_steps(&self) -> crate::connections::StepStream {
            Box::pin(futures_util::stream::empty())
        }

        async fn wait_for_idle(&self) -> Result<()> {
            loop {
                // Create the `Notified` future *before* checking the flag: `notify_waiters`
                // stores no permit, but it does wake `Notified` futures that already exist,
                // so this order cannot miss a `go_idle` landing between check and wait.
                let notified = self.idle_notify.notified();
                if self.is_idle() {
                    return Ok(());
                }
                notified.await;
            }
        }

        async fn shutdown(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Trigger that sends a single "ping" using its configured delivery mode, then records
    /// that its body finished.
    struct OneShot {
        name: String,
        ran: Arc<AtomicBool>,
        delivery: TriggerDelivery,
    }

    #[async_trait]
    impl Trigger for OneShot {
        fn name(&self) -> &str {
            &self.name
        }

        fn delivery(&self) -> TriggerDelivery {
            self.delivery
        }

        async fn run(&self, ctx: TriggerContext) -> Result<()> {
            match self.delivery {
                TriggerDelivery::WaitIdle => ctx.send_when_idle("ping").await?,
                TriggerDelivery::SendImmediately => ctx.send("ping").await?,
            }
            self.ran.store(true, Ordering::Release);
            Ok(())
        }
    }

    /// Trigger whose `run` fails immediately.
    struct AlwaysErr {
        name: String,
    }

    #[async_trait]
    impl Trigger for AlwaysErr {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _ctx: TriggerContext) -> Result<()> {
            Err(Error::other("trigger boom"))
        }
    }

    /// Trigger whose `run` panics.
    struct Panicker {
        name: String,
    }

    #[async_trait]
    impl Trigger for Panicker {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, _ctx: TriggerContext) -> Result<()> {
            panic!("trigger panic");
        }
    }

    #[tokio::test]
    async fn double_start_is_rejected() {
        let conn = MockConn::new();
        let runner = TriggerRunner::new(vec![], conn.clone());
        runner.start().expect("first start ok");
        let err = runner.start().expect_err("second start must fail");
        assert!(matches!(err, Error::AlreadyStarted));
        runner.stop().await;
    }

    #[tokio::test]
    async fn start_after_stop_succeeds() {
        let conn = MockConn::new();
        let runner = TriggerRunner::new(vec![], conn.clone());
        runner.start().expect("start");
        runner.stop().await;
        runner.start().expect("restart after stop");
        runner.stop().await;
    }

    #[tokio::test]
    async fn stop_without_start_is_noop() {
        let conn = MockConn::new();
        let runner = TriggerRunner::new(vec![], conn.clone());
        runner.stop().await;
        runner.start().expect("start after a no-op stop");
        runner.stop().await;
    }

    #[tokio::test]
    async fn immediate_trigger_delivers_and_stop_joins() {
        let conn = MockConn::new();
        let ran = Arc::new(AtomicBool::new(false));
        let trig = Arc::new(OneShot {
            name: "imm".into(),
            ran: ran.clone(),
            delivery: TriggerDelivery::SendImmediately,
        });
        let runner = TriggerRunner::new(vec![trig], conn.clone());
        runner.start().expect("start");
        tokio::time::timeout(Duration::from_secs(5), async {
            while conn.triggers_sent() == 0 {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("immediate trigger must deliver");
        runner.stop().await;
        assert_eq!(conn.triggers_sent(), 1, "exactly one trigger send");
        assert!(ran.load(Ordering::Acquire), "trigger body ran to completion");
    }

    #[tokio::test]
    async fn send_when_idle_waits_for_idle() {
        let conn = MockConn::new();
        let ran = Arc::new(AtomicBool::new(false));
        let trig = Arc::new(OneShot {
            name: "idle".into(),
            ran: ran.clone(),
            delivery: TriggerDelivery::WaitIdle,
        });
        let runner = TriggerRunner::new(vec![trig], conn.clone());
        runner.start().expect("start");
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        assert_eq!(conn.triggers_sent(), 0, "must not fire while turn is active");
        assert!(!ran.load(Ordering::Acquire));
        conn.go_idle();
        tokio::time::timeout(Duration::from_secs(5), async {
            while conn.triggers_sent() == 0 {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("idle trigger must fire once idle");
        runner.stop().await;
        assert_eq!(conn.triggers_sent(), 1);
        assert!(ran.load(Ordering::Acquire), "trigger body ran after going idle");
    }

    #[tokio::test]
    async fn erroring_trigger_does_not_kill_siblings() {
        let conn = MockConn::new();
        let ran = Arc::new(AtomicBool::new(false));
        let good = Arc::new(OneShot {
            name: "good".into(),
            ran: ran.clone(),
            delivery: TriggerDelivery::SendImmediately,
        });
        let bad = Arc::new(AlwaysErr { name: "bad".into() });
        let runner = TriggerRunner::new(vec![bad, good], conn.clone());
        runner.start().expect("start");
        tokio::time::timeout(Duration::from_secs(5), async {
            while !ran.load(Ordering::Acquire) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("the good trigger must still run despite the erroring sibling");
        runner.stop().await;
        assert_eq!(conn.triggers_sent(), 1, "only the good trigger sent");
    }

    #[tokio::test]
    async fn panicking_trigger_is_isolated_and_stop_survives() {
        let conn = MockConn::new();
        let ran = Arc::new(AtomicBool::new(false));
        let good = Arc::new(OneShot {
            name: "good".into(),
            ran: ran.clone(),
            delivery: TriggerDelivery::SendImmediately,
        });
        let boom = Arc::new(Panicker { name: "boom".into() });
        let runner = TriggerRunner::new(vec![boom, good], conn.clone());
        runner.start().expect("start");
        tokio::time::timeout(Duration::from_secs(5), async {
            while !ran.load(Ordering::Acquire) {
                tokio::task::yield_now().await;
            }
        })
        .await
        .expect("sibling runs despite a panicking trigger");
        runner.stop().await;
        assert_eq!(conn.triggers_sent(), 1);
    }

    #[tokio::test(start_paused = true)]
    async fn drop_aborts_running_triggers() {
        let conn = MockConn::new();
        conn.go_idle();
        let counter = Arc::new(AtomicU32::new(0));
        let c2 = counter.clone();
        let trig = every(Duration::from_millis(10), "tick", move |_ctx| {
            let c = c2.clone();
            async move {
                c.fetch_add(1, Ordering::AcqRel);
                Ok(())
            }
        });
        let runner = TriggerRunner::new(vec![trig], conn.clone());
        runner.start().expect("start");
        for _ in 0..8 {
            tokio::task::yield_now().await;
        }
        for _ in 0..4 {
            tokio::time::advance(Duration::from_millis(10)).await;
            for _ in 0..4 {
                tokio::task::yield_now().await;
            }
        }
        let before = counter.load(Ordering::Acquire);
        assert!(before >= 1, "periodic trigger fired at least once, got {before}");
        drop(runner);
        for _ in 0..4 {
            tokio::task::yield_now().await;
        }
        for _ in 0..10 {
            tokio::time::advance(Duration::from_millis(10)).await;
            tokio::task::yield_now().await;
        }
        let after = counter.load(Ordering::Acquire);
        assert_eq!(
            before, after,
            "dropped runner must not keep firing ({before} -> {after})"
        );
    }
}