anda_engine 0.11.12

Agents engine for Anda -- an AI agent framework built with Rust, powered by ICP and TEEs.
//! Hook system for customizing engine behavior.

use anda_core::{
    AgentOutput, BoxError, CacheExpiry, CacheFeatures, CompletionRequest, Json, Resource,
    StateFeatures, ToolOutput,
};
use async_trait::async_trait;
use std::{sync::Arc, time::Duration};
use structured_logger::unix_ms;

use crate::context::{AgentCtx, BaseCtx};

/// Coarse-grained lifecycle hook for the engine.
///
/// Implementors are notified around every agent run and every tool call. The agent or tool is
/// identified only by a `&str`.
///
/// - The `*_start` callbacks may return `Err` to report a failure to the caller.
/// - The `*_end` callbacks receive the produced output and return the (possibly modified) output
///   to use instead.
///
/// Every method has a pass-through default, so an implementation overrides only the callbacks
/// it cares about.
#[async_trait]
pub trait Hook: Send + Sync {
    /// Invoked before an agent runs.
    ///
    /// The default implementation accepts unconditionally.
    async fn on_agent_start(&self, _ctx: &AgentCtx, _agent: &str) -> Result<(), BoxError> {
        Ok(())
    }

    /// Invoked after an agent has run, with the output it produced.
    ///
    /// The returned value replaces `output`. The default hands it back untouched.
    async fn on_agent_end(
        &self,
        _ctx: &AgentCtx,
        _agent: &str,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let unchanged = output;
        Ok(unchanged)
    }

    /// Invoked before a tool is called.
    ///
    /// The default implementation accepts unconditionally.
    async fn on_tool_start(&self, _ctx: &BaseCtx, _tool: &str) -> Result<(), BoxError> {
        Ok(())
    }

    /// Invoked after a tool has been called, with its JSON output.
    ///
    /// The returned value replaces `output`. The default hands it back untouched.
    async fn on_tool_end(
        &self,
        _ctx: &BaseCtx,
        _tool: &str,
        output: ToolOutput<Json>,
    ) -> Result<ToolOutput<Json>, BoxError> {
        let unchanged = output;
        Ok(unchanged)
    }
}

/// ToolHook trait for customizing tool call behavior.
///
/// It provides more fine-grained control over tool calls than [`Hook`]. It lets you:
///
/// - intercept and modify a tool's typed input arguments before the tool runs;
/// - intercept and modify its typed output after the tool returns;
/// - observe tool executions that run in the background.
///
/// `I` is the tool's argument type and `O` is the type carried by its [`ToolOutput`]. Every
/// method has a pass-through (or no-op) default, so implementors override only what they need.
#[async_trait]
pub trait ToolHook<I, O>: Send + Sync
where
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    /// Called before a tool is called, allowing you to modify the input arguments.
    ///
    /// Returning `Err` reports a failure to the caller. The default returns `args` unchanged.
    async fn before_tool_call(&self, _ctx: &BaseCtx, args: I) -> Result<I, BoxError> {
        Ok(args)
    }

    /// Called after a tool is called, allowing you to modify the output.
    ///
    /// Returning `Err` reports a failure to the caller. The default returns `output` unchanged.
    async fn after_tool_call(
        &self,
        _ctx: &BaseCtx,
        output: ToolOutput<O>,
    ) -> Result<ToolOutput<O>, BoxError> {
        Ok(output)
    }

    /// Called when a tool execution is started in the background.
    ///
    /// Receives the background task id and a reference to the call's arguments. The default
    /// does nothing.
    async fn on_background_start(&self, _ctx: &BaseCtx, _task_id: &str, _args: &I) {}

    /// Called with the final output of a tool execution that ran in the background.
    ///
    /// Takes the context, task id and output by value. The default does nothing.
    async fn on_background_end(&self, _ctx: BaseCtx, _task_id: String, _output: ToolOutput<O>) {}
}

/// Type-erased, cheaply cloneable wrapper around a shared [`ToolHook`] implementation.
///
/// Cloning only clones the inner `Arc`; the wrapped hook itself is shared, not copied.
pub struct DynToolHook<I, O> {
    inner: Arc<dyn ToolHook<I, O>>,
}

// `Clone` is implemented by hand instead of derived. `#[derive(Clone)]` would add `I: Clone`
// and `O: Clone` bounds, making the wrapper non-cloneable for non-`Clone` argument or output
// types, even though cloning never touches an `I` or `O`.
impl<I, O> Clone for DynToolHook<I, O> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<I, O> DynToolHook<I, O>
where
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    /// Wraps an already shared tool hook.
    pub fn new(inner: Arc<dyn ToolHook<I, O>>) -> Self {
        Self { inner }
    }
}

/// Forwards every [`ToolHook`] callback, arguments and result unchanged, to the wrapped hook.
#[async_trait]
impl<I, O> ToolHook<I, O> for DynToolHook<I, O>
where
    I: Send + Sync + 'static,
    O: Send + Sync + 'static,
{
    async fn before_tool_call(&self, ctx: &BaseCtx, args: I) -> Result<I, BoxError> {
        let hook = &*self.inner;
        hook.before_tool_call(ctx, args).await
    }

    async fn after_tool_call(
        &self,
        ctx: &BaseCtx,
        output: ToolOutput<O>,
    ) -> Result<ToolOutput<O>, BoxError> {
        let hook = &*self.inner;
        hook.after_tool_call(ctx, output).await
    }

    async fn on_background_start(&self, ctx: &BaseCtx, task_id: &str, args: &I) {
        let hook = &*self.inner;
        hook.on_background_start(ctx, task_id, args).await
    }

    async fn on_background_end(&self, ctx: BaseCtx, task_id: String, output: ToolOutput<O>) {
        let hook = &*self.inner;
        hook.on_background_end(ctx, task_id, output).await
    }
}

/// Fine-grained hook around the execution of an agent.
///
/// Unlike [`Hook`], which only sees the agent as a `&str`, this hook can:
///
/// - rewrite the prompt and resources before a run;
/// - rewrite the output afterwards;
/// - observe runs that execute in the background.
///
/// Every method has a pass-through (or no-op) default.
#[async_trait]
pub trait AgentHook: Send + Sync {
    /// Runs before the agent executes and may replace its prompt and resources.
    ///
    /// The default returns both inputs unchanged.
    async fn before_agent_run(
        &self,
        _ctx: &AgentCtx,
        prompt: String,
        resources: Vec<Resource>,
    ) -> Result<(String, Vec<Resource>), BoxError> {
        let unchanged = (prompt, resources);
        Ok(unchanged)
    }

    /// Runs after the agent executes and may replace its output.
    ///
    /// When the agent executes asynchronously in the background, this is called immediately
    /// before that background run is triggered. The final output is then delivered to
    /// [`AgentHook::on_background_end`] instead.
    async fn after_agent_run(
        &self,
        _ctx: &AgentCtx,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let unchanged = output;
        Ok(unchanged)
    }

    /// Notified when a background execution of the agent starts, with its task id and request.
    ///
    /// The default does nothing.
    async fn on_background_start(
        &self,
        _ctx: &AgentCtx,
        _task_id: &str,
        _req: &CompletionRequest,
    ) {
    }

    /// Receives the final output of a background execution of the agent.
    ///
    /// The default does nothing.
    async fn on_background_end(&self, _ctx: AgentCtx, _task_id: String, _output: AgentOutput) {}
}

/// Type-erased, cheaply cloneable handle to a shared [`AgentHook`] implementation.
///
/// The derived `Clone` clones only the inner `Arc`.
#[derive(Clone)]
pub struct DynAgentHook {
    inner: Arc<dyn AgentHook>,
}

impl DynAgentHook {
    /// Wraps an already shared agent hook.
    pub fn new(inner: Arc<dyn AgentHook>) -> Self {
        DynAgentHook { inner }
    }
}

/// Forwards every [`AgentHook`] callback, arguments and result unchanged, to the wrapped hook.
#[async_trait]
impl AgentHook for DynAgentHook {
    async fn before_agent_run(
        &self,
        ctx: &AgentCtx,
        prompt: String,
        resources: Vec<Resource>,
    ) -> Result<(String, Vec<Resource>), BoxError> {
        let delegate = &*self.inner;
        delegate.before_agent_run(ctx, prompt, resources).await
    }

    async fn after_agent_run(
        &self,
        ctx: &AgentCtx,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let delegate = &*self.inner;
        delegate.after_agent_run(ctx, output).await
    }

    async fn on_background_start(&self, ctx: &AgentCtx, task_id: &str, req: &CompletionRequest) {
        let delegate = &*self.inner;
        delegate.on_background_start(ctx, task_id, req).await
    }

    async fn on_background_end(&self, ctx: AgentCtx, task_id: String, output: AgentOutput) {
        let delegate = &*self.inner;
        delegate.on_background_end(ctx, task_id, output).await
    }
}

/// Ordered collection of [`Hook`]s that itself implements [`Hook`].
///
/// The registered hooks are invoked in the order they were added.
#[derive(Default)]
pub struct Hooks {
    hooks: Vec<Box<dyn Hook>>,
}

impl Hooks {
    /// Creates an empty collection; equivalent to `Hooks::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `hook`, so it runs after every hook added before it.
    pub fn add(&mut self, hook: Box<dyn Hook>) {
        self.hooks.push(hook);
    }
}

/// Fans each callback out to every registered hook, in insertion order.
///
/// - `on_*_start` stops at the first hook that returns `Err` and propagates that error.
/// - `on_*_end` threads the output through the hooks: each hook receives the value returned by
///   the previous one. The first `Err` short-circuits the chain.
#[async_trait]
impl Hook for Hooks {
    async fn on_agent_start(&self, ctx: &AgentCtx, agent: &str) -> Result<(), BoxError> {
        for registered in self.hooks.iter() {
            registered.on_agent_start(ctx, agent).await?;
        }
        Ok(())
    }

    async fn on_agent_end(
        &self,
        ctx: &AgentCtx,
        agent: &str,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let mut current = output;
        for registered in self.hooks.iter() {
            current = registered.on_agent_end(ctx, agent, current).await?;
        }
        Ok(current)
    }

    async fn on_tool_start(&self, ctx: &BaseCtx, tool: &str) -> Result<(), BoxError> {
        for registered in self.hooks.iter() {
            registered.on_tool_start(ctx, tool).await?;
        }
        Ok(())
    }

    async fn on_tool_end(
        &self,
        ctx: &BaseCtx,
        tool: &str,
        output: ToolOutput<Json>,
    ) -> Result<ToolOutput<Json>, BoxError> {
        let mut current = output;
        for registered in self.hooks.iter() {
            current = registered.on_tool_end(ctx, tool, current).await?;
        }
        Ok(current)
    }
}

/// [`Hook`] that rejects a new agent run while the same caller already has one in flight.
///
/// The in-flight marker is a cache entry keyed by the caller and created with a `ttl` expiry.
/// NOTE(review): this presumably bounds how long a run that never reaches `on_agent_end` can
/// block its caller. It also presumably applies within whatever scope the context's cache
/// covers. Confirm both against `CacheFeatures`.
pub struct SingleThreadHook {
    ttl: Duration,
}

impl SingleThreadHook {
    /// Creates the hook; `ttl` is the expiry attached to each caller's in-flight marker.
    pub fn new(ttl: Duration) -> Self {
        SingleThreadHook { ttl }
    }
}

#[async_trait]
impl Hook for SingleThreadHook {
    /// Claims the caller's in-flight marker.
    ///
    /// The marker's value is the current unix time in ms, and it expires after `self.ttl`. Fails
    /// if a marker for this caller already exists.
    async fn on_agent_start(&self, ctx: &AgentCtx, _agent: &str) -> Result<(), BoxError> {
        let key = ctx.caller().to_string();
        let started_at = unix_ms();
        let expiry = Some(CacheExpiry::TTL(self.ttl));
        let claimed = ctx
            .cache_set_if_not_exists(key.as_str(), (started_at, expiry))
            .await;
        if claimed {
            Ok(())
        } else {
            Err("Only one prompt can run at a time.".into())
        }
    }

    /// Releases the caller's in-flight marker and passes `output` through unchanged.
    async fn on_agent_end(
        &self,
        ctx: &AgentCtx,
        _agent: &str,
        output: AgentOutput,
    ) -> Result<AgentOutput, BoxError> {
        let key = ctx.caller().to_string();
        // Best-effort release: the result of the delete is intentionally ignored.
        ctx.cache_delete(key.as_str()).await;
        Ok(output)
    }
}