use crate::agents::{
error::AgentError,
runner::{AgentRunOutcome, AgentRunner},
session::SessionState,
};
use async_trait::async_trait;
use potato_type::{tools::AsyncTool, TypeError};
use serde_json::Value;
use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
/// Restrictions attached to an [`AgentTool`].
///
/// The `Default` value is fully permissive: the flag is `false` and the
/// deny-list is empty.
#[derive(Clone, Debug, Default)]
pub struct AgentToolPolicy {
    /// Flag intended to forbid sub-agent calls.
    ///
    /// NOTE(review): no code visible alongside this definition reads the
    /// flag; presumably it is enforced by whoever inspects
    /// `AgentTool::policy` — confirm.
    pub disallow_sub_agent_calls: bool,
    /// Agent ids that must not be invoked. `AgentTool::dispatch` rejects the
    /// call when the wrapped runner's id is a member of this set.
    pub disallowed_agent_ids: HashSet<String>,
}
/// Adapter that exposes an [`AgentRunner`] as a callable tool.
///
/// Cloning is cheap with respect to the runner: the clone shares the same
/// `Arc<dyn AgentRunner>`.
#[derive(Clone, Debug)]
pub struct AgentTool {
    /// Name reported through the tool traits.
    name: String,
    /// Human-readable description reported through the tool traits.
    description: String,
    /// The agent that is run when the tool is invoked.
    runner: Arc<dyn AgentRunner>,
    /// Call restrictions consulted before the runner is invoked.
    pub policy: AgentToolPolicy,
}
impl AgentTool {
    /// Builds a tool that delegates to `runner`, using the default
    /// [`AgentToolPolicy`].
    pub fn new(
        name: impl Into<String>,
        description: impl Into<String>,
        runner: Arc<dyn AgentRunner>,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            runner,
            policy: AgentToolPolicy::default(),
        }
    }

    /// Replaces the tool's policy and returns the tool, builder style.
    pub fn with_policy(mut self, policy: AgentToolPolicy) -> Self {
        self.policy = policy;
        self
    }

    /// Returns the shared handle to the wrapped runner.
    pub fn runner(&self) -> &Arc<dyn AgentRunner> {
        &self.runner
    }

    /// JSON schema shared by every agent tool: an object with one required
    /// string property, `input`.
    fn fixed_schema() -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "input": { "type": "string" }
            },
            "required": ["input"]
        })
    }

    /// Runs the wrapped agent inside `session` and returns its final response
    /// text as a JSON string.
    ///
    /// The runner's id is pushed onto the session's ancestor stack for the
    /// duration of the run and popped afterwards. The pop is performed by a
    /// drop guard, so it also happens when this future is dropped at the
    /// `.await` or when the run unwinds; otherwise the id would stay on the
    /// stack and later calls to the same agent in this session would be
    /// rejected by the ancestor check below.
    ///
    /// If `args` has no string-valued `"input"` entry, the agent is run with
    /// an empty input string.
    ///
    /// # Errors
    ///
    /// * [`AgentError::CircularAgentCall`] when the runner's id is already an
    ///   ancestor in `session`.
    /// * [`AgentError::DisallowedAgentCall`] when the runner's id is listed in
    ///   `policy.disallowed_agent_ids`.
    /// * [`AgentError::SubAgentNeedsInput`] when the run ends in
    ///   `AgentRunOutcome::NeedsInput`.
    /// * Any error returned by the runner itself.
    pub(crate) async fn dispatch(
        &self,
        args: Value,
        session: &mut SessionState,
    ) -> Result<Value, AgentError> {
        /// Pops the most recently pushed ancestor when dropped.
        struct AncestorGuard<'a>(&'a mut SessionState);

        impl Drop for AncestorGuard<'_> {
            fn drop(&mut self) {
                self.0.pop_ancestor();
            }
        }

        let agent_id = self.runner.id();

        if session.is_ancestor(agent_id) {
            return Err(AgentError::CircularAgentCall(agent_id.to_string()));
        }
        if self.policy.disallowed_agent_ids.contains(agent_id) {
            return Err(AgentError::DisallowedAgentCall(agent_id.to_string()));
        }

        let input = args
            .get("input")
            .and_then(|v| v.as_str())
            .unwrap_or("")
            .to_string();

        session.push_ancestor(agent_id);
        let guard = AncestorGuard(session);
        let result = self.runner.run(&input, &mut *guard.0).await;
        // Normal path: pop before inspecting the outcome, as before.
        drop(guard);

        match result? {
            AgentRunOutcome::Complete(r) => Ok(Value::String(r.final_response.response_text())),
            AgentRunOutcome::NeedsInput { question, .. } => {
                Err(AgentError::SubAgentNeedsInput(question))
            }
        }
    }
}
#[async_trait]
impl AsyncTool for AgentTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameter_schema(&self) -> Value {
Self::fixed_schema()
}
async fn execute(&self, args: Value) -> Result<Value, TypeError> {
let mut session = SessionState::new();
self.dispatch(args, &mut session)
.await
.map_err(|e| TypeError::Error(e.to_string()))
}
fn as_any(&self) -> Option<&dyn std::any::Any> {
Some(self)
}
}
impl potato_type::tools::Tool for AgentTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn parameter_schema(&self) -> Value {
Self::fixed_schema()
}
fn execute(&self, args: Value) -> Result<Value, TypeError> {
potato_state::block_on(async {
let mut session = SessionState::new();
self.dispatch(args, &mut session)
.await
.map_err(|e| TypeError::Error(e.to_string()))
})
}
}