use std::sync::Arc;
use async_trait::async_trait;
use entelix_core::context::ExecutionContext;
use entelix_core::error::{Error, Result};
/// Observer of an agent's lifecycle, generic over the agent state type `S`.
///
/// Every hook ships with a default that does nothing and returns `Ok(())`,
/// so an implementor overrides only the events it is interested in.
/// The `Send + Sync` supertraits let an observer be stored behind an `Arc`
/// (see `DynObserver`) and shared between threads.
///
/// NOTE(review): the code that drives these hooks lives outside this file;
/// the per-hook timing described below is inferred from the hook names and
/// should be confirmed against the agent runtime.
#[async_trait]
pub trait AgentObserver<S: Clone + Send + Sync + 'static>: Send + Sync {
    /// Identifier for this observer.
    ///
    /// The default implementation returns the empty string `""`.
    fn name(&self) -> &'static str {
        ""
    }

    /// Hook that receives the current state and execution context,
    /// presumably before an agent turn runs.
    ///
    /// # Errors
    ///
    /// The default implementation never fails; overrides may return an error.
    async fn pre_turn(&self, _state: &S, _ctx: &ExecutionContext) -> Result<()> {
        Ok(())
    }

    /// Hook that receives the state and execution context, presumably once a
    /// run has completed successfully.
    ///
    /// # Errors
    ///
    /// The default implementation never fails; overrides may return an error.
    async fn on_complete(&self, _state: &S, _ctx: &ExecutionContext) -> Result<()> {
        Ok(())
    }

    /// Hook that receives an error together with the execution context.
    /// Unlike the other hooks, it is not given the agent state.
    ///
    /// # Errors
    ///
    /// The default implementation never fails; overrides may return an error.
    async fn on_error(&self, _error: &Error, _ctx: &ExecutionContext) -> Result<()> {
        Ok(())
    }
}
/// Shared, type-erased handle to an [`AgentObserver`] over state `S`.
///
/// The `'static` bound is what `Arc<dyn _>` already defaults to in a type
/// alias; it is written out here only to make the requirement explicit.
pub type DynObserver<S> = Arc<dyn AgentObserver<S> + 'static>;
#[cfg(test)]
// `unwrap` is deliberate in tests: a hook returning `Err` should fail the test loudly.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Test observer that counts how many times each hook has been invoked.
    struct CountingObserver {
        name: &'static str,
        pre_turn: AtomicUsize,
        on_complete: AtomicUsize,
        on_error: AtomicUsize,
    }

    impl CountingObserver {
        fn new(name: &'static str) -> Self {
            Self {
                name,
                pre_turn: AtomicUsize::new(0),
                on_complete: AtomicUsize::new(0),
                on_error: AtomicUsize::new(0),
            }
        }

        /// Snapshot of the `(pre_turn, on_complete, on_error)` counters.
        fn counts(&self) -> (usize, usize, usize) {
            (
                self.pre_turn.load(Ordering::SeqCst),
                self.on_complete.load(Ordering::SeqCst),
                self.on_error.load(Ordering::SeqCst),
            )
        }
    }

    #[async_trait]
    impl AgentObserver<i32> for CountingObserver {
        fn name(&self) -> &'static str {
            self.name
        }

        async fn pre_turn(&self, _state: &i32, _ctx: &ExecutionContext) -> Result<()> {
            self.pre_turn.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_complete(&self, _state: &i32, _ctx: &ExecutionContext) -> Result<()> {
            self.on_complete.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn on_error(&self, _error: &Error, _ctx: &ExecutionContext) -> Result<()> {
            self.on_error.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    #[tokio::test]
    async fn default_methods_are_no_ops() {
        // An observer that overrides nothing must accept every hook without error.
        struct Bare;

        #[async_trait]
        impl AgentObserver<i32> for Bare {}

        let observer = Bare;
        let ctx = ExecutionContext::new();
        observer.pre_turn(&0, &ctx).await.unwrap();
        observer.on_complete(&0, &ctx).await.unwrap();
        observer
            .on_error(&Error::config("nope"), &ctx)
            .await
            .unwrap();
        assert_eq!(observer.name(), "");
    }

    #[tokio::test]
    async fn observer_records_each_lifecycle_event() {
        let obs = CountingObserver::new("test");
        let ctx = ExecutionContext::new();
        assert_eq!(obs.name(), "test");
        assert_eq!(obs.counts(), (0, 0, 0));

        // Check after every call so a hook that bumps the wrong counter is caught.
        obs.pre_turn(&0, &ctx).await.unwrap();
        assert_eq!(obs.counts(), (1, 0, 0));

        obs.on_complete(&100, &ctx).await.unwrap();
        assert_eq!(obs.counts(), (1, 1, 0));

        obs.on_error(&Error::config("nope"), &ctx).await.unwrap();
        assert_eq!(obs.counts(), (1, 1, 1));
    }

    #[tokio::test]
    async fn dyn_observer_handle_works_for_storage() {
        let raw = Arc::new(CountingObserver::new("dyn-test"));
        // `raw.clone()` yields `Arc<CountingObserver>`, which unsizes into the alias.
        let dyn_obs: DynObserver<i32> = raw.clone();
        let ctx = ExecutionContext::new();

        // Every trait method must dispatch through the trait object to the concrete impl.
        assert_eq!(dyn_obs.name(), "dyn-test");
        dyn_obs.pre_turn(&0, &ctx).await.unwrap();
        dyn_obs.on_complete(&1, &ctx).await.unwrap();
        dyn_obs
            .on_error(&Error::config("nope"), &ctx)
            .await
            .unwrap();
        assert_eq!(raw.counts(), (1, 1, 1));
    }
}