use crate::context::{RunContext, UsageLimits};
use crate::errors::AgentRunError;
use crate::history::HistoryProcessor;
use crate::instructions::{InstructionFn, SystemPromptFn};
use crate::output::{OutputMode, OutputSchema, OutputValidator};
use crate::run::{AgentRun, AgentRunResult, RunOptions};
use crate::stream::AgentStream;
use serdes_ai_core::messages::UserContent;
use serdes_ai_core::ModelSettings;
use serdes_ai_models::Model;
use serdes_ai_tools::ToolDefinition;
use std::marker::PhantomData;
use std::sync::Arc;
/// Strategy selecting when an agent run ends.
///
/// NOTE(review): the exact semantics of each variant are implemented outside
/// this file; the per-variant notes below are inferred from the names and
/// should be confirmed against the run loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EndStrategy {
/// The default strategy. Presumably ends the run as soon as a final output
/// is available — TODO confirm.
#[default]
Early,
/// Presumably keeps processing outstanding tool calls before ending the
/// run — TODO confirm.
Exhaustive,
}
/// Instrumentation options attached to an agent.
///
/// The derived `Default` yields tracing disabled (`false`) with no log level
/// and no span name (`None`).
#[derive(Debug, Clone, Default)]
pub struct InstrumentationSettings {
/// Whether tracing is enabled.
pub enable_tracing: bool,
/// Optional log level, stored as free-form text.
/// NOTE(review): accepted values are not defined in this file — confirm.
pub log_level: Option<String>,
/// Optional span name override.
pub span_name: Option<String>,
}
/// A configured agent: a model plus prompts, tools, output handling and run
/// settings.
///
/// `Deps` is the dependency type handed to each run (defaults to `()`);
/// `Output` is the type of the final result (defaults to `String`).
/// All fields are crate-private; the panic message of the `Default` impl
/// points callers at `Agent::builder()` / `AgentBuilder` for construction.
pub struct Agent<Deps = (), Output = String> {
/// Shared handle to the model backing this agent.
pub(crate) model: Arc<dyn Model>,
/// Optional agent name; also shown in the `Debug` output.
pub(crate) name: Option<String>,
/// Model settings exposed through `model_settings()`.
pub(crate) model_settings: ModelSettings,
/// Fixed part of the system prompt; may be empty.
pub(crate) static_system_prompt: Arc<str>,
/// Dynamic instruction generators, evaluated after `system_prompt_fns`
/// when the system prompt is built.
pub(crate) instruction_fns: Vec<Box<dyn InstructionFn<Deps>>>,
/// Dynamic system prompt generators, evaluated right after the static
/// prompt when the system prompt is built.
pub(crate) system_prompt_fns: Vec<Box<dyn SystemPromptFn<Deps>>>,
/// Registered tools, searched by definition name in `find_tool`.
pub(crate) tools: Vec<RegisteredTool<Deps>>,
/// Shared list of tool definitions handed out by `tool_definitions()`.
/// NOTE(review): presumably mirrors `tools` and is filled in by the
/// builder — confirm.
pub(crate) cached_tool_defs: Arc<Vec<ToolDefinition>>,
/// Output schema; supplies the output mode and the optional output tool name.
pub(crate) output_schema: Box<dyn OutputSchema<Output>>,
/// Validators for the final output.
/// NOTE(review): not used in this file — presumably applied by the run loop.
pub(crate) output_validators: Vec<Box<dyn OutputValidator<Output, Deps>>>,
/// End strategy for runs; shown in the `Debug` output.
pub(crate) end_strategy: EndStrategy,
/// Retry budget for producing a valid output.
/// NOTE(review): consumed outside this file — confirm exact semantics.
pub(crate) max_output_retries: u32,
/// Agent-level retry budget for tools.
/// NOTE(review): marked `dead_code`, so it is apparently not read anywhere yet.
#[allow(dead_code)]
pub(crate) max_tool_retries: u32,
/// Optional usage limits exposed through `usage_limits()`.
pub(crate) usage_limits: Option<UsageLimits>,
/// Message history processors.
/// NOTE(review): not used in this file — presumably applied by the run loop.
pub(crate) history_processors: Vec<Box<dyn HistoryProcessor<Deps>>>,
/// Optional instrumentation settings.
/// NOTE(review): marked `dead_code`, so it is apparently not read anywhere yet.
#[allow(dead_code)]
pub(crate) instrument: Option<InstrumentationSettings>,
/// Flag exposed through `parallel_tool_calls()`.
pub(crate) parallel_tool_calls: bool,
/// Optional cap exposed through `max_concurrent_tools()`; `None` means no
/// cap is configured here.
pub(crate) max_concurrent_tools: Option<usize>,
/// Zero-sized marker tying the struct to its `Deps` and `Output` parameters.
pub(crate) _phantom: PhantomData<(Deps, Output)>,
}
/// A tool registered on an agent: its definition paired with the executor
/// that runs it.
pub struct RegisteredTool<Deps> {
/// Definition of the tool; its `name` is what `Agent::find_tool` matches on.
pub definition: ToolDefinition,
/// Shared executor invoked to run the tool.
pub executor: Arc<dyn ToolExecutor<Deps>>,
/// Per-tool retry budget.
/// NOTE(review): consumed outside this file — confirm exact semantics.
pub max_retries: u32,
}
/// Manual `Clone` so that `Deps` itself does not need to be `Clone`: the
/// executor is shared through its `Arc`, only the definition is deep-copied.
impl<Deps> Clone for RegisteredTool<Deps> {
    fn clone(&self) -> Self {
        // Destructure so that a newly added field forces this impl to be revisited.
        let Self {
            definition,
            executor,
            max_retries,
        } = self;

        Self {
            definition: definition.clone(),
            executor: Arc::clone(executor),
            max_retries: *max_retries,
        }
    }
}
/// Asynchronous executor for a single tool.
///
/// Implementors must be `Send + Sync`; agents store them behind
/// `Arc<dyn ToolExecutor<Deps>>`.
#[async_trait::async_trait]
pub trait ToolExecutor<Deps>: Send + Sync {
/// Executes the tool.
///
/// * `args` - the call arguments as a JSON value.
/// * `ctx` - the context of the current run, giving access to `Deps`.
///
/// Returns the tool's `ToolReturn` on success or a `ToolError` on failure.
async fn execute(
&self,
args: serde_json::Value,
ctx: &RunContext<Deps>,
) -> Result<serdes_ai_tools::ToolReturn, serdes_ai_tools::ToolError>;
}
impl<Deps, Output> Agent<Deps, Output>
where
Deps: Send + Sync + 'static,
Output: Send + Sync + 'static,
{
pub fn model(&self) -> &dyn Model {
self.model.as_ref()
}
pub fn model_arc(&self) -> Arc<dyn Model> {
Arc::clone(&self.model)
}
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
pub fn model_settings(&self) -> &ModelSettings {
&self.model_settings
}
pub fn tools(&self) -> Vec<&ToolDefinition> {
self.tools.iter().map(|t| &t.definition).collect()
}
pub fn output_mode(&self) -> OutputMode {
self.output_schema.mode()
}
pub fn has_tools(&self) -> bool {
!self.tools.is_empty()
}
pub fn usage_limits(&self) -> Option<&UsageLimits> {
self.usage_limits.as_ref()
}
pub fn parallel_tool_calls(&self) -> bool {
self.parallel_tool_calls
}
pub fn max_concurrent_tools(&self) -> Option<usize> {
self.max_concurrent_tools
}
pub async fn run(
&self,
prompt: impl Into<UserContent>,
deps: Deps,
) -> Result<AgentRunResult<Output>, AgentRunError> {
self.run_with_options(prompt, deps, RunOptions::default())
.await
}
pub async fn run_with_options(
&self,
prompt: impl Into<UserContent>,
deps: Deps,
options: RunOptions,
) -> Result<AgentRunResult<Output>, AgentRunError> {
let run = self.start_run(prompt, deps, options).await?;
run.run_to_completion().await
}
pub fn run_sync(
&self,
prompt: impl Into<UserContent>,
deps: Deps,
) -> Result<AgentRunResult<Output>, AgentRunError> {
tokio::runtime::Handle::current().block_on(self.run(prompt, deps))
}
pub async fn start_run(
&self,
prompt: impl Into<UserContent>,
deps: Deps,
options: RunOptions,
) -> Result<AgentRun<'_, Deps, Output>, AgentRunError> {
AgentRun::new(self, prompt.into(), deps, options).await
}
pub async fn run_stream(
&self,
prompt: impl Into<UserContent>,
deps: Deps,
) -> Result<AgentStream, AgentRunError> {
self.run_stream_with_options(prompt, deps, RunOptions::default())
.await
}
pub async fn run_stream_with_options(
&self,
prompt: impl Into<UserContent>,
deps: Deps,
options: RunOptions,
) -> Result<AgentStream, AgentRunError> {
AgentStream::new(self, prompt.into(), deps, options).await
}
pub(crate) async fn build_system_prompt(&self, ctx: &RunContext<Deps>) -> String {
let has_dynamic = !self.system_prompt_fns.is_empty() || !self.instruction_fns.is_empty();
if !has_dynamic {
return self.static_system_prompt.to_string();
}
let mut parts = Vec::new();
if !self.static_system_prompt.is_empty() {
parts.push(self.static_system_prompt.to_string());
}
for prompt_fn in &self.system_prompt_fns {
if let Some(prompt) = prompt_fn.generate(ctx).await {
if !prompt.is_empty() {
parts.push(prompt);
}
}
}
for instruction_fn in &self.instruction_fns {
if let Some(instruction) = instruction_fn.generate(ctx).await {
if !instruction.is_empty() {
parts.push(instruction);
}
}
}
parts.join("\n\n")
}
pub(crate) fn tool_definitions(&self) -> Arc<Vec<ToolDefinition>> {
Arc::clone(&self.cached_tool_defs)
}
pub(crate) fn find_tool(&self, name: &str) -> Option<&RegisteredTool<Deps>> {
self.tools.iter().find(|t| t.definition.name == name)
}
pub(crate) fn is_output_tool(&self, name: &str) -> bool {
self.output_schema
.tool_name()
.map(|n| n == name)
.unwrap_or(false)
}
#[allow(dead_code)]
pub(crate) fn output_tool_name(&self) -> Option<String> {
self.output_schema.tool_name().map(|s| s.to_string())
}
pub fn static_system_prompt(&self) -> &str {
&self.static_system_prompt
}
}
/// Placeholder `Default` implementation that always panics.
///
/// An agent cannot be built without a model, so there is no meaningful
/// default value.
impl<Deps> Default for Agent<Deps, String>
where
    Deps: Send + Sync + 'static,
{
    /// # Panics
    ///
    /// Always panics; the message directs callers to the builder.
    fn default() -> Self {
        panic!("Agent must be created using Agent::builder() or AgentBuilder")
    }
}
/// Compact `Debug` output: name, model name, number of tools and end
/// strategy. The remaining fields are omitted.
impl<Deps, Output> std::fmt::Debug for Agent<Deps, Output> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the tool count is printed, not the tools themselves.
        let tool_count = self.tools.len();

        let mut out = f.debug_struct("Agent");
        out.field("name", &self.name);
        out.field("model", &self.model.name());
        out.field("tools", &tool_count);
        out.field("end_strategy", &self.end_strategy);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `EndStrategy` defaults to `Early`.
    #[test]
    fn test_end_strategy_default() {
        let strategy = EndStrategy::default();
        assert_eq!(strategy, EndStrategy::Early);
    }

    /// Default instrumentation has tracing off and no log level.
    #[test]
    fn test_instrumentation_settings_default() {
        let InstrumentationSettings {
            enable_tracing,
            log_level,
            ..
        } = InstrumentationSettings::default();

        assert!(!enable_tracing);
        assert!(log_level.is_none());
    }
}