use anda_core::{
AgentOutput, BoxError, CacheExpiry, CacheFeatures, CompletionRequest, Json, Resource,
StateFeatures, ToolOutput,
};
use async_trait::async_trait;
use std::{sync::Arc, time::Duration};
use structured_logger::unix_ms;
use crate::context::{AgentCtx, BaseCtx};
/// Lifecycle callbacks for agent and tool executions.
///
/// Every method ships with a pass-through default, so an implementor overrides only the
/// events it cares about. The `*_start` methods can veto by returning `Err`. The `*_end`
/// methods receive the produced output by value and hand back the output to use in its
/// place, either unchanged or rewritten.
#[async_trait]
pub trait Hook: Send + Sync {
    /// Hook point for the start of the agent run named `_agent`.
    ///
    /// Default: accepts unconditionally.
    async fn on_agent_start(&self, _ctx: &AgentCtx, _agent: &str) -> Result<(), BoxError> {
        Ok(())
    }

    /// Hook point for the end of the agent run named `_agent`.
    ///
    /// Default: returns `output` untouched.
    async fn on_agent_end(
        &self,
        _ctx: &AgentCtx,
        _agent: &str,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        Ok(output)
    }

    /// Hook point for the start of the tool call named `_tool`.
    ///
    /// Default: accepts unconditionally.
    async fn on_tool_start(&self, _ctx: &BaseCtx, _tool: &str) -> Result<(), BoxError> {
        Ok(())
    }

    /// Hook point for the end of the tool call named `_tool`. The output carries a JSON
    /// payload.
    ///
    /// Default: returns `output` untouched.
    async fn on_tool_end(
        &self,
        _ctx: &BaseCtx,
        _tool: &str,
        output: ToolOutput<Json>,
    ) -> Result<ToolOutput<Json>, BoxError> {
        Ok(output)
    }
}
/// Typed interception points around a single tool.
///
/// `I` is the argument type handed to the tool, and `O` is the payload type of its
/// [`ToolOutput`]. Each method defaults to pass-through or no-op behaviour.
#[async_trait]
pub trait ToolHook<I, O>: Send + Sync
where
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    /// Runs ahead of the tool call. It may rewrite `args`, or reject them by returning `Err`.
    ///
    /// Default: forwards `args` unchanged.
    async fn before_tool_call(&self, _ctx: &BaseCtx, args: I) -> Result<I, BoxError> {
        Ok(args)
    }

    /// Runs after the tool call. It may replace `output`, or fail with `Err`.
    ///
    /// Default: forwards `output` unchanged.
    async fn after_tool_call(
        &self,
        _ctx: &BaseCtx,
        output: ToolOutput<O>,
    ) -> Result<ToolOutput<O>, BoxError> {
        Ok(output)
    }

    /// Notification that background task `_task_id` was started with `_args`.
    /// It cannot fail. Default: does nothing.
    async fn on_background_start(&self, _ctx: &BaseCtx, _task_id: &str, _args: &I) {}

    /// Notification that background task `_task_id` produced `_output`. It receives
    /// owned values and cannot fail. Default: does nothing.
    async fn on_background_end(&self, _ctx: BaseCtx, _task_id: String, _output: ToolOutput<O>) {}
}
/// Type-erased handle to a shared [`ToolHook`] implementation.
///
/// Cloning only bumps the reference count of the inner `Arc`, so every clone forwards to
/// the same hook instance.
pub struct DynToolHook<I, O> {
    inner: Arc<dyn ToolHook<I, O>>,
}

// `#[derive(Clone)]` would generate `impl<I: Clone, O: Clone> Clone`. That makes the wrapper
// unclonable whenever the tool's argument or output type is not `Clone`, even though
// cloning an `Arc` never needs those bounds. A manual impl keeps `Clone` unconditional.
impl<I, O> Clone for DynToolHook<I, O> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<I, O> DynToolHook<I, O>
where
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    /// Wraps an already shared tool hook.
    pub fn new(inner: Arc<dyn ToolHook<I, O>>) -> Self {
        Self { inner }
    }
}
/// Delegates every callback to the wrapped hook, so a `DynToolHook<I, O>` can stand in
/// anywhere a `ToolHook<I, O>` is expected.
#[async_trait]
impl<I, O> ToolHook<I, O> for DynToolHook<I, O>
where
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    async fn before_tool_call(&self, ctx: &BaseCtx, args: I) -> Result<I, BoxError> {
        let hook = &*self.inner;
        hook.before_tool_call(ctx, args).await
    }

    async fn after_tool_call(
        &self,
        ctx: &BaseCtx,
        output: ToolOutput<O>,
    ) -> Result<ToolOutput<O>, BoxError> {
        let hook = &*self.inner;
        hook.after_tool_call(ctx, output).await
    }

    async fn on_background_start(&self, ctx: &BaseCtx, task_id: &str, args: &I) {
        let hook = &*self.inner;
        hook.on_background_start(ctx, task_id, args).await;
    }

    async fn on_background_end(&self, ctx: BaseCtx, task_id: String, output: ToolOutput<O>) {
        let hook = &*self.inner;
        hook.on_background_end(ctx, task_id, output).await;
    }
}
/// Interception points around a single agent's execution.
///
/// Every method defaults to pass-through or no-op behaviour.
#[async_trait]
pub trait AgentHook: Send + Sync {
    /// Runs before the agent. It may rewrite the `prompt` and its attached `resources`,
    /// or fail with `Err`.
    ///
    /// Default: forwards both values unchanged.
    async fn before_agent_run(
        &self,
        _ctx: &AgentCtx,
        prompt: String,
        resources: Vec<Resource>,
    ) -> Result<(String, Vec<Resource>), BoxError> {
        Ok((prompt, resources))
    }

    /// Runs after the agent. It may replace `output`, or fail with `Err`.
    ///
    /// Default: forwards `output` unchanged.
    async fn after_agent_run(
        &self,
        _ctx: &AgentCtx,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        Ok(output)
    }

    /// Notification that background task `_task_id` was started for `_req`.
    /// It cannot fail. Default: does nothing.
    async fn on_background_start(
        &self,
        _ctx: &AgentCtx,
        _task_id: &str,
        _req: &CompletionRequest,
    ) {
    }

    /// Notification that background task `_task_id` produced `_output`. It receives
    /// owned values and cannot fail. Default: does nothing.
    async fn on_background_end(&self, _ctx: AgentCtx, _task_id: String, _output: AgentOutput) {}
}
/// Type-erased, clonable handle to a shared [`AgentHook`].
///
/// Cloning copies the inner `Arc`, so all clones forward to the same hook instance.
#[derive(Clone)]
pub struct DynAgentHook {
    inner: Arc<dyn AgentHook>,
}

impl DynAgentHook {
    /// Wraps an already shared agent hook.
    pub fn new(inner: Arc<dyn AgentHook>) -> Self {
        DynAgentHook { inner }
    }
}
/// Delegates every callback to the wrapped hook, so a `DynAgentHook` can stand in
/// anywhere an `AgentHook` is expected.
#[async_trait]
impl AgentHook for DynAgentHook {
    async fn before_agent_run(
        &self,
        ctx: &AgentCtx,
        prompt: String,
        resources: Vec<Resource>,
    ) -> Result<(String, Vec<Resource>), BoxError> {
        let hook = &*self.inner;
        hook.before_agent_run(ctx, prompt, resources).await
    }

    async fn after_agent_run(
        &self,
        ctx: &AgentCtx,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let hook = &*self.inner;
        hook.after_agent_run(ctx, output).await
    }

    async fn on_background_start(&self, ctx: &AgentCtx, task_id: &str, req: &CompletionRequest) {
        let hook = &*self.inner;
        hook.on_background_start(ctx, task_id, req).await;
    }

    async fn on_background_end(&self, ctx: AgentCtx, task_id: String, output: AgentOutput) {
        let hook = &*self.inner;
        hook.on_background_end(ctx, task_id, output).await;
    }
}
/// An ordered collection of [`Hook`]s. The collection itself implements [`Hook`] by
/// invoking each member in turn.
pub struct Hooks {
    /// Registered hooks, kept in the order they were added.
    hooks: Vec<Box<dyn Hook>>,
}

impl Default for Hooks {
    /// Same as [`Hooks::new`]: an empty collection.
    fn default() -> Self {
        Hooks::new()
    }
}

impl Hooks {
    /// Creates a collection with no registered hooks.
    pub fn new() -> Self {
        Hooks { hooks: vec![] }
    }

    /// Appends `hook` after every previously added hook.
    pub fn add(&mut self, hook: Box<dyn Hook>) {
        let registry = &mut self.hooks;
        registry.push(hook);
    }
}
/// Fans each lifecycle event out to every registered hook, in insertion order.
///
/// The first hook that returns `Err` short-circuits the chain, and its error is returned;
/// the remaining hooks are not invoked. For the `*_end` events each hook receives the
/// output produced by the previous one, so rewrites compose.
#[async_trait]
impl Hook for Hooks {
    async fn on_agent_start(&self, ctx: &AgentCtx, agent: &str) -> Result<(), BoxError> {
        for registered in self.hooks.iter() {
            registered.on_agent_start(ctx, agent).await?;
        }
        Ok(())
    }

    async fn on_agent_end(
        &self,
        ctx: &AgentCtx,
        agent: &str,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let mut current = output;
        for registered in self.hooks.iter() {
            current = registered.on_agent_end(ctx, agent, current).await?;
        }
        Ok(current)
    }

    async fn on_tool_start(&self, ctx: &BaseCtx, tool: &str) -> Result<(), BoxError> {
        for registered in self.hooks.iter() {
            registered.on_tool_start(ctx, tool).await?;
        }
        Ok(())
    }

    async fn on_tool_end(
        &self,
        ctx: &BaseCtx,
        tool: &str,
        output: ToolOutput<Json>,
    ) -> Result<ToolOutput<Json>, BoxError> {
        let mut current = output;
        for registered in self.hooks.iter() {
            current = registered.on_tool_end(ctx, tool, current).await?;
        }
        Ok(current)
    }
}
/// A [`Hook`] that lets each caller run at most one agent prompt at a time.
///
/// The lock is an entry in the agent context's cache, keyed by the caller. `ttl` is the
/// expiry attached to that entry.
pub struct SingleThreadHook {
    /// Expiry applied to each per-caller cache entry.
    ttl: Duration,
}

impl SingleThreadHook {
    /// Creates a hook whose per-caller cache entries use `ttl` as their expiry.
    pub fn new(ttl: Duration) -> Self {
        SingleThreadHook { ttl }
    }
}
/// Uses the agent context's cache as a per-caller lock. `on_agent_start` claims the entry
/// and fails if it already exists; `on_agent_end` deletes it.
#[async_trait]
impl Hook for SingleThreadHook {
    async fn on_agent_start(&self, ctx: &AgentCtx, _agent: &str) -> Result<(), BoxError> {
        let key = ctx.caller().to_string();
        // The stored value is the acquisition time in unix ms. The TTL presumably lets an
        // entry that is never deleted expire on its own.
        let entry = (unix_ms(), Some(CacheExpiry::TTL(self.ttl)));
        if ctx.cache_set_if_not_exists(key.as_str(), entry).await {
            Ok(())
        } else {
            Err("Only one prompt can run at a time.".into())
        }
    }

    async fn on_agent_end(
        &self,
        ctx: &AgentCtx,
        _agent: &str,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let key = ctx.caller().to_string();
        ctx.cache_delete(key.as_str()).await;
        Ok(output)
    }
}