use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{collections::BTreeMap, future::Future, marker::PhantomData, sync::Arc};
use crate::{
BoxError, BoxPinFut, Function,
context::AgentContext,
model::{AgentOutput, FunctionDefinition, Resource},
select_resources, validate_function_name,
};
/// Arguments for invoking an agent: a single free-form `prompt` string.
///
/// The shape matches the `parameters` schema produced by the default
/// `Agent::definition` (one required string property named `prompt`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentArgs {
    /// The task text handed to the agent.
    pub prompt: String,
}
/// Behaviour of an agent that runs against an execution context `C`.
///
/// Implementors must provide [`Agent::name`], [`Agent::description`] and
/// [`Agent::run`]; every other method has a default implementation.
pub trait Agent<C>: Send + Sync
where
    C: AgentContext + Send + Sync,
{
    /// Returns the agent's name.
    ///
    /// The default [`Agent::definition`] ASCII-lowercases this value to form
    /// the function name.
    fn name(&self) -> String;

    /// Returns the description placed in the agent's function definition.
    fn description(&self) -> String;

    /// Builds the function definition that exposes this agent as a callable.
    ///
    /// The default schema is a strict object with exactly one required,
    /// non-empty string property, `prompt`, and no additional properties.
    fn definition(&self) -> FunctionDefinition {
        let parameters = json!({
            "type": "object",
            "description": "Run this agent on a focused task. Provide a self-contained prompt with the goal, relevant context, constraints, and expected output.",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "The task for this agent. Include the objective, relevant context, constraints, preferred workflow or deliverable, and any success criteria needed to complete the work.",
                    "minLength": 1
                },
            },
            "required": ["prompt"],
            "additionalProperties": false
        });

        FunctionDefinition {
            name: self.name().to_ascii_lowercase(),
            description: self.description(),
            parameters,
            strict: Some(true),
        }
    }

    /// Returns the resource tags this agent supports; empty by default.
    fn supported_resource_tags(&self) -> Vec<String> {
        vec![]
    }

    /// Passes `resources` and this agent's supported tags to the crate-level
    /// `select_resources` helper and returns its result.
    ///
    /// NOTE(review): `resources` is taken by `&mut`, so the helper presumably
    /// removes what it selects — confirm against its definition.
    fn select_resources(&self, resources: &mut Vec<Resource>) -> Vec<Resource> {
        select_resources(resources, &self.supported_resource_tags())
    }

    /// Initialization hook; the default does nothing and resolves to `Ok(())`
    /// immediately.
    fn init(&self, _ctx: C) -> impl Future<Output = Result<(), BoxError>> + Send {
        futures::future::ready(Ok(()))
    }

    /// Returns the tool names this agent depends on (presumably ones that must
    /// be registered for it to work — verify against callers); empty by default.
    fn tool_dependencies(&self) -> Vec<String> {
        vec![]
    }

    /// Runs the agent on `prompt` with the given `resources` inside `ctx`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports as a [`BoxError`].
    fn run(
        &self,
        ctx: C,
        prompt: String,
        resources: Vec<Resource>,
    ) -> impl Future<Output = Result<AgentOutput, BoxError>> + Send;
}
/// Variant of the `Agent` trait whose async methods return boxed futures
/// (`BoxPinFut`) instead of `impl Future`, so agents of different concrete
/// types can be stored behind `dyn DynAgent<C>`.
pub trait DynAgent<C: AgentContext + Send + Sync>: Send + Sync {
    /// Returns the label attached to this agent.
    fn label(&self) -> &str;

    /// Returns the agent's name.
    fn name(&self) -> String;

    /// Returns the function definition describing how to call this agent.
    fn definition(&self) -> FunctionDefinition;

    /// Returns the tool names this agent declares as dependencies.
    fn tool_dependencies(&self) -> Vec<String>;

    /// Returns the resource tags this agent supports.
    fn supported_resource_tags(&self) -> Vec<String>;

    /// Initializes the agent with `ctx`.
    fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>>;

    /// Runs the agent on `prompt` with `resources` inside `ctx`.
    fn run(
        &self,
        ctx: C,
        prompt: String,
        resources: Vec<Resource>,
    ) -> BoxPinFut<Result<AgentOutput, BoxError>>;
}
/// Private adapter that pairs a concrete agent `T` with a label so it can be
/// used through the `DynAgent` trait.
struct AgentWrapper<T: Agent<C> + 'static, C: AgentContext + Send + Sync + 'static> {
    /// The wrapped agent, shared so spawned futures can hold their own handle.
    inner: Arc<T>,
    /// Label reported by `DynAgent::label`.
    label: String,
    /// Ties the wrapper to the context type without storing a `C`.
    _phantom: PhantomData<C>,
}
/// Forwards every `DynAgent` call to the wrapped agent; async methods are
/// boxed and pinned to fit the `BoxPinFut` return type.
impl<T, C> DynAgent<C> for AgentWrapper<T, C>
where
    T: Agent<C> + 'static,
    C: AgentContext + Send + Sync + 'static,
{
    /// Returns the label stored in this wrapper.
    fn label(&self) -> &str {
        self.label.as_str()
    }

    /// Delegates to the wrapped agent's `name`.
    fn name(&self) -> String {
        self.inner.name()
    }

    /// Delegates to the wrapped agent's `definition`.
    fn definition(&self) -> FunctionDefinition {
        self.inner.definition()
    }

    /// Delegates to the wrapped agent's `tool_dependencies`.
    fn tool_dependencies(&self) -> Vec<String> {
        self.inner.tool_dependencies()
    }

    /// Delegates to the wrapped agent's `supported_resource_tags`.
    fn supported_resource_tags(&self) -> Vec<String> {
        self.inner.supported_resource_tags()
    }

    /// Boxes the wrapped agent's `init` future.
    fn init(&self, ctx: C) -> BoxPinFut<Result<(), BoxError>> {
        // The future owns its own handle to the agent, so it does not borrow `self`.
        let shared = Arc::clone(&self.inner);
        Box::pin(async move { shared.init(ctx).await })
    }

    /// Boxes the wrapped agent's `run` future.
    fn run(
        &self,
        ctx: C,
        prompt: String,
        resources: Vec<Resource>,
    ) -> BoxPinFut<Result<AgentOutput, BoxError>> {
        // The future owns its own handle to the agent, so it does not borrow `self`.
        let shared = Arc::clone(&self.inner);
        Box::pin(async move { shared.run(ctx, prompt, resources).await })
    }
}
/// A collection of type-erased agents stored in a sorted map.
///
/// `AgentSet::add` inserts each agent under its ASCII-lowercased name; the
/// name-based lookups on this type lowercase their argument the same way.
pub struct AgentSet<C: AgentContext> {
    /// Registered agents. The field is public, so code that inserts directly
    /// should use lowercase keys to stay compatible with the lookup methods.
    pub set: BTreeMap<String, Arc<dyn DynAgent<C>>>,
}

/// Implemented by hand rather than derived: `#[derive(Default)]` would add a
/// `C: Default` bound, although an empty set needs no value of type `C`.
impl<C: AgentContext> Default for AgentSet<C> {
    /// Returns an empty agent set.
    fn default() -> Self {
        Self {
            set: BTreeMap::new(),
        }
    }
}
impl<C> AgentSet<C>
where
    C: AgentContext + Send + Sync + 'static,
{
    /// Creates an empty agent set.
    pub fn new() -> Self {
        Self {
            set: BTreeMap::new(),
        }
    }

    /// Returns `true` if an agent is stored under `name`, compared
    /// ASCII-case-insensitively.
    pub fn contains(&self, name: &str) -> bool {
        self.contains_lowercase(&name.to_ascii_lowercase())
    }

    /// Like [`AgentSet::contains`], but uses `lowercase_name` as the key
    /// verbatim, skipping the lowercasing allocation.
    pub fn contains_lowercase(&self, lowercase_name: &str) -> bool {
        self.set.contains_key(lowercase_name)
    }

    /// Returns the keys of all stored agents in ascending order.
    pub fn names(&self) -> Vec<String> {
        self.set.keys().cloned().collect()
    }

    /// Returns the function definition of the agent stored under `name`
    /// (ASCII-case-insensitive), or `None` if there is no such agent.
    pub fn definition(&self, name: &str) -> Option<FunctionDefinition> {
        self.set
            .get(&name.to_ascii_lowercase())
            .map(|agent| agent.definition())
    }

    /// Returns function definitions in ascending key order.
    ///
    /// With `names == None` every agent is included; otherwise only agents
    /// whose key equals one of `names` after ASCII-lowercasing. Unknown or
    /// repeated names are ignored.
    pub fn definitions(&self, names: Option<&[String]>) -> Vec<FunctionDefinition> {
        self.map_selected(names, |agent| agent.definition())
    }

    /// Returns `Function` values (definition plus supported resource tags) in
    /// ascending key order, filtered by `names` exactly as in
    /// [`AgentSet::definitions`].
    pub fn functions(&self, names: Option<&[String]>) -> Vec<Function> {
        self.map_selected(names, |agent| Function {
            definition: agent.definition(),
            supported_resource_tags: agent.supported_resource_tags(),
        })
    }

    /// Applies `build` to every selected agent, in ascending key order.
    ///
    /// The wanted names are lowercased once into a set, so each membership
    /// test is a logarithmic lookup rather than a linear scan of `names`.
    fn map_selected<R>(
        &self,
        names: Option<&[String]>,
        build: impl Fn(&Arc<dyn DynAgent<C>>) -> R,
    ) -> Vec<R> {
        match names {
            None => self.set.values().map(|agent| build(agent)).collect(),
            Some(names) => {
                let wanted: std::collections::BTreeSet<String> = names
                    .iter()
                    .map(|name| name.to_ascii_lowercase())
                    .collect();
                self.set
                    .iter()
                    .filter(|(key, _)| wanted.contains(*key))
                    .map(|(_, agent)| build(agent))
                    .collect()
            }
        }
    }

    /// Selects resources for the agent stored under `name`
    /// (ASCII-case-insensitive).
    ///
    /// Returns an empty vector when `resources` is empty or no such agent
    /// exists. Otherwise the agent's supported tags and `resources` are passed
    /// to the crate-level `select_resources` helper.
    ///
    /// NOTE(review): `resources` is taken by `&mut`, so the helper presumably
    /// removes what it selects — confirm against its definition.
    pub fn select_resources(&self, name: &str, resources: &mut Vec<Resource>) -> Vec<Resource> {
        if resources.is_empty() {
            return Vec::new();
        }

        match self.set.get(&name.to_ascii_lowercase()) {
            Some(agent) => {
                let supported_tags = agent.supported_resource_tags();
                select_resources(resources, &supported_tags)
            }
            None => Vec::new(),
        }
    }

    /// Registers `agent` under its ASCII-lowercased name.
    ///
    /// `label` defaults to that lowercased name when `None`.
    ///
    /// # Errors
    ///
    /// Returns an error if an agent with the same lowercased name is already
    /// registered, or if `validate_function_name` rejects the name.
    pub fn add<T>(&mut self, agent: Arc<T>, label: Option<String>) -> Result<(), BoxError>
    where
        T: Agent<C> + Send + Sync + 'static,
    {
        let name = agent.name().to_ascii_lowercase();
        if self.set.contains_key(&name) {
            return Err(format!("agent {} already exists", name).into());
        }
        validate_function_name(&name)?;

        let agent_dyn = AgentWrapper {
            inner: agent,
            label: label.unwrap_or_else(|| name.clone()),
            _phantom: PhantomData,
        };
        self.set.insert(name, Arc::new(agent_dyn));
        Ok(())
    }

    /// Returns a shared handle to the agent stored under `name`
    /// (ASCII-case-insensitive), if any.
    pub fn get(&self, name: &str) -> Option<Arc<dyn DynAgent<C>>> {
        self.get_lowercase(&name.to_ascii_lowercase())
    }

    /// Like [`AgentSet::get`], but uses `lowercase_name` as the key verbatim.
    pub fn get_lowercase(&self, lowercase_name: &str) -> Option<Arc<dyn DynAgent<C>>> {
        self.set.get(lowercase_name).cloned()
    }
}