use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use futures::stream::Stream;
use serde_json::Value;
use super::result::{RunConfig, RunEvent, RunResult, UserInput};
use crate::chat::{ResponseFormat, SharedChatProvider};
use crate::error::Result;
use crate::guardrail::{InputGuardrail, OutputGuardrail};
use crate::tool::{BoxedTool, ToolDefinition, ToolExecutionPolicy};
/// A named JSON Schema describing the structured output an agent should produce.
///
/// The `strict` flag decides how [`OutputSchema::to_response_format`] builds the
/// request: strict schemas go through `ResponseFormat::json_schema`, non-strict
/// ones are assembled by hand with `strict: Some(false)`.
#[derive(Clone, Debug)]
pub struct OutputSchema {
    /// Name attached to the schema in the response-format request.
    name: String,
    /// The JSON Schema document.
    schema: Value,
    /// Whether strict adherence to the schema is requested.
    strict: bool,
}

impl OutputSchema {
    /// Creates a strict schema. Equivalent to `OutputSchema::with_strict(name, schema, true)`.
    #[must_use]
    pub fn new(name: impl Into<String>, schema: Value) -> Self {
        Self::with_strict(name, schema, true)
    }

    /// Creates a schema with an explicitly chosen strictness flag.
    #[must_use]
    pub fn with_strict(name: impl Into<String>, schema: Value, strict: bool) -> Self {
        let name: String = name.into();
        Self {
            name,
            schema,
            strict,
        }
    }

    /// Returns the schema's name.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the raw JSON Schema document.
    #[must_use]
    pub const fn schema(&self) -> &Value {
        &self.schema
    }

    /// Reports whether strict adherence is requested.
    #[must_use]
    pub const fn is_strict(&self) -> bool {
        self.strict
    }

    /// Builds the [`ResponseFormat`] carrying this schema.
    ///
    /// Strict schemas are delegated to `ResponseFormat::json_schema`. Non-strict
    /// schemas are built directly as `ResponseFormat::JsonSchema` with
    /// `strict: Some(false)` set explicitly.
    #[must_use]
    pub fn to_response_format(&self) -> ResponseFormat {
        if !self.strict {
            let spec = crate::chat::JsonSchemaSpec {
                name: self.name.clone(),
                schema: self.schema.clone(),
                strict: Some(false),
            };
            return ResponseFormat::JsonSchema { json_schema: spec };
        }
        ResponseFormat::json_schema(&self.name, self.schema.clone())
    }

    /// Derives a strict schema (name and document) from a `schemars::JsonSchema` type.
    #[cfg(feature = "schema")]
    #[must_use]
    pub fn from_type<T: schemars::JsonSchema>() -> Self {
        let (name, schema) = crate::chat::generate_json_schema::<T>();
        Self::with_strict(name, schema, true)
    }
}
/// Instruction text for an agent.
///
/// It is either fixed text, or a shared closure that produces the text from the
/// agent's name each time it is resolved.
#[derive(Clone)]
pub enum Instructions {
    /// Fixed instruction text, returned as-is.
    Static(String),
    /// Instruction text computed on demand; the closure receives the agent name.
    Dynamic(Arc<dyn Fn(&str) -> String + Send + Sync>),
}

impl Instructions {
    /// Produces the instruction text for the agent called `agent_name`.
    ///
    /// `Dynamic` calls its closure with `agent_name`. `Static` returns an owned
    /// copy of its text and ignores `agent_name`.
    #[must_use]
    pub fn resolve(&self, agent_name: &str) -> String {
        match self {
            Self::Dynamic(build) => build(agent_name),
            Self::Static(text) => text.to_owned(),
        }
    }
}

impl fmt::Debug for Instructions {
    /// Prints static text verbatim. A closure cannot be printed, so it is shown
    /// as the placeholder string `"<closure>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (variant, payload): (&str, &dyn fmt::Debug) = match self {
            Self::Static(text) => ("Static", text),
            Self::Dynamic(_) => ("Dynamic", &"<closure>"),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}

impl<S: Into<String>> From<S> for Instructions {
    /// Wraps any string-like value as [`Instructions::Static`].
    fn from(text: S) -> Self {
        Self::Static(text.into())
    }
}
/// An agent definition assembled with the builder methods on `Agent`.
///
/// Every field starts at the value assigned by `Agent::new` and is changed via
/// the corresponding builder method.
pub struct Agent {
    /// Agent name; also used as the tool name in `tool_definition` and passed to
    /// dynamic instructions by `resolve_instructions`.
    pub(crate) name: String,
    /// Instruction text or generator; resolved via `resolve_instructions`.
    pub(crate) instructions: Instructions,
    /// Model identifier; empty until set with `model`.
    pub(crate) model: String,
    /// Chat provider; `None` until set with `provider`.
    pub(crate) provider: Option<SharedChatProvider>,
    /// Tools directly available to this agent.
    pub(crate) tools: Vec<BoxedTool>,
    /// Sub-agents; `total_tool_count` counts each one as a tool.
    pub(crate) managed_agents: Vec<Self>,
    /// Step limit; defaults to `Agent::DEFAULT_MAX_STEPS`.
    pub(crate) max_steps: usize,
    /// Description; defaults to `"Agent: <name>"` and is used as the tool
    /// description in `tool_definition`.
    pub(crate) description: String,
    /// Execution policies registered via `tool_policy`, keyed by the name given
    /// there (presumably a tool name; confirm against the runner).
    pub(crate) tool_policies: HashMap<String, ToolExecutionPolicy>,
    /// Optional schema for structured output.
    pub(crate) output_schema: Option<OutputSchema>,
    /// Guardrails registered via `input_guardrail`, in insertion order.
    pub(crate) input_guardrails: Vec<InputGuardrail>,
    /// Guardrails registered via `output_guardrail`, in insertion order.
    pub(crate) output_guardrails: Vec<OutputGuardrail>,
}
impl fmt::Debug for Agent {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Agent")
.field("name", &self.name)
.field("instructions", &self.instructions)
.field("model", &self.model)
.field("provider", &self.provider.is_some())
.field(
"tools",
&self.tools.iter().map(|t| t.name()).collect::<Vec<_>>(),
)
.field(
"managed_agents",
&self
.managed_agents
.iter()
.map(|a| &a.name)
.collect::<Vec<_>>(),
)
.field("max_steps", &self.max_steps)
.field("description", &self.description)
.field("tool_policies", &self.tool_policies)
.field(
"output_schema",
&self.output_schema.as_ref().map(OutputSchema::name),
)
.field("input_guardrails", &self.input_guardrails)
.field("output_guardrails", &self.output_guardrails)
.finish()
}
}
impl Agent {
    /// Step limit assigned by [`Agent::new`] until overridden with [`Agent::max_steps`].
    pub const DEFAULT_MAX_STEPS: usize = 10;

    /// Creates an agent called `name` with default settings.
    ///
    /// The defaults are:
    /// - empty static instructions and an empty model;
    /// - no provider, tools, managed agents, tool policies, output schema or guardrails;
    /// - [`Self::DEFAULT_MAX_STEPS`] steps;
    /// - the description `"Agent: <name>"`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let agent_name: String = name.into();
        let description = format!("Agent: {agent_name}");
        Self {
            name: agent_name,
            description,
            instructions: Instructions::Static(String::new()),
            model: String::new(),
            provider: None,
            tools: Vec::new(),
            managed_agents: Vec::new(),
            max_steps: Self::DEFAULT_MAX_STEPS,
            tool_policies: HashMap::new(),
            output_schema: None,
            input_guardrails: Vec::new(),
            output_guardrails: Vec::new(),
        }
    }

    /// Sets fixed instruction text, replacing any previous static or dynamic instructions.
    #[must_use]
    pub fn instructions(self, instructions: impl Into<String>) -> Self {
        Self {
            instructions: Instructions::Static(instructions.into()),
            ..self
        }
    }

    /// Sets instructions produced by `f`, replacing any previous instructions.
    ///
    /// `f` receives the agent's name whenever the instructions are resolved.
    #[must_use]
    pub fn dynamic_instructions<F>(self, f: F) -> Self
    where
        F: Fn(&str) -> String + Send + Sync + 'static,
    {
        Self {
            instructions: Instructions::Dynamic(Arc::new(f)),
            ..self
        }
    }

    /// Sets the model identifier.
    #[must_use]
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..self
        }
    }

    /// Sets the chat provider.
    #[must_use]
    pub fn provider(self, provider: SharedChatProvider) -> Self {
        Self {
            provider: Some(provider),
            ..self
        }
    }

    /// Appends a single tool.
    #[must_use]
    pub fn tool(mut self, tool: BoxedTool) -> Self {
        self.tools.push(tool);
        self
    }

    /// Replaces the whole tool list.
    ///
    /// Tools added earlier via `tool` or `wallet` are discarded.
    #[must_use]
    pub fn tools(self, tools: Vec<BoxedTool>) -> Self {
        Self { tools, ..self }
    }

    /// Appends a single managed sub-agent.
    #[must_use]
    pub fn managed_agent(mut self, agent: Self) -> Self {
        self.managed_agents.push(agent);
        self
    }

    /// Replaces the whole list of managed sub-agents.
    #[must_use]
    pub fn managed_agents(self, agents: Vec<Self>) -> Self {
        Self {
            managed_agents: agents,
            ..self
        }
    }

    /// Sets the step limit (default [`Self::DEFAULT_MAX_STEPS`]).
    ///
    /// How steps are counted is up to the runner.
    #[must_use]
    pub const fn max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }

    /// Sets the description; it is also used as the tool description in `tool_definition`.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: description.into(),
            ..self
        }
    }

    /// Appends every tool returned by `wallet.into_tools()`.
    #[cfg(feature = "wallet")]
    #[must_use]
    pub fn wallet(mut self, wallet: crate::wallet::EvmWallet) -> Self {
        let wallet_tools = wallet.into_tools();
        self.tools.extend(wallet_tools);
        self
    }

    /// Registers `policy` under `name`, overwriting any policy already stored for that name.
    #[must_use]
    pub fn tool_policy(mut self, name: impl Into<String>, policy: ToolExecutionPolicy) -> Self {
        let key: String = name.into();
        self.tool_policies.insert(key, policy);
        self
    }

    /// Sets the structured-output schema, replacing any previous one.
    #[must_use]
    pub fn output_schema(self, schema: OutputSchema) -> Self {
        Self {
            output_schema: Some(schema),
            ..self
        }
    }

    /// Appends an input guardrail.
    #[must_use]
    pub fn input_guardrail(mut self, guardrail: InputGuardrail) -> Self {
        self.input_guardrails.push(guardrail);
        self
    }

    /// Appends an output guardrail.
    #[must_use]
    pub fn output_guardrail(mut self, guardrail: OutputGuardrail) -> Self {
        self.output_guardrails.push(guardrail);
        self
    }

    /// Sets the output schema derived from `T` via [`OutputSchema::from_type`].
    #[cfg(feature = "schema")]
    #[must_use]
    pub fn output_type<T: schemars::JsonSchema>(self) -> Self {
        self.output_schema(OutputSchema::from_type::<T>())
    }

    /// Returns the agent's name.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Returns the model identifier (empty if never set).
    #[must_use]
    pub fn get_model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the description.
    #[must_use]
    pub fn get_description(&self) -> &str {
        self.description.as_str()
    }

    /// Returns the step limit.
    #[must_use]
    pub const fn get_max_steps(&self) -> usize {
        self.max_steps
    }

    /// Reports whether a provider has been set.
    #[must_use]
    pub fn has_provider(&self) -> bool {
        self.provider.is_some()
    }

    /// Returns the number of directly attached tools (managed agents excluded).
    #[must_use]
    pub fn tool_count(&self) -> usize {
        self.tools.len()
    }

    /// Resolves the instructions for this agent, passing its own name to dynamic instructions.
    #[must_use]
    pub fn resolve_instructions(&self) -> String {
        self.instructions.resolve(self.name.as_str())
    }

    /// Reports whether at least one managed sub-agent is attached.
    #[must_use]
    pub const fn has_managed_agents(&self) -> bool {
        !self.managed_agents.is_empty()
    }

    /// Returns the direct tool count plus one per managed sub-agent.
    #[must_use]
    pub fn total_tool_count(&self) -> usize {
        let delegated = self.managed_agents.len();
        self.tool_count() + delegated
    }

    /// Runs the agent on `input` by delegating to `super::Runner::run`.
    ///
    /// The returned future borrows `self` for `'a`.
    pub fn run<'a>(
        &'a self,
        input: impl Into<UserInput>,
        config: RunConfig,
    ) -> Pin<Box<dyn Future<Output = Result<RunResult>> + Send + 'a>> {
        super::Runner::run(self, input, config)
    }

    /// Runs the agent as a stream of [`RunEvent`]s by delegating to
    /// `super::Runner::run_streamed`.
    ///
    /// The stream borrows `self` for `'a`.
    pub fn run_streamed<'a>(
        &'a self,
        input: impl Into<UserInput>,
        config: RunConfig,
    ) -> Pin<Box<dyn Stream<Item = Result<RunEvent>> + Send + 'a>> {
        super::Runner::run_streamed(self, input, config)
    }

    /// Describes this agent as a callable tool.
    ///
    /// The tool uses the agent's name and description, and accepts one required
    /// string argument `task` with no additional properties.
    #[must_use]
    pub fn tool_definition(&self) -> ToolDefinition {
        let parameters = serde_json::json!({
            "type": "object",
            "properties": {
                "task": {
                    "type": "string",
                    "description": "The task to delegate to this agent."
                }
            },
            "required": ["task"],
            "additionalProperties": false
        });
        ToolDefinition::new(&self.name, &self.description, parameters)
    }
}