use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::AgentError;
use crate::operations::agent::{Model, PermissionMode};
/// Boxed, pinned future returned by [`AgentProvider::invoke`].
///
/// The future is `Send`, so it may be moved to another thread while it runs.
/// It is bounded by `'a`, which lets an implementation keep borrowing the
/// provider and the config for as long as the future is alive. It resolves to
/// an [`AgentOutput`] on success or an [`AgentError`] on failure.
pub type InvokeFuture<'a> =
Pin<Box<dyn Future<Output = Result<AgentOutput, AgentError>> + Send + 'a>>;
/// Typestate marker: no tools have been allow-listed through the typed builder.
#[derive(Debug, Clone, Copy)]
pub struct NoTools;
/// Typestate marker: at least one tool was added via `allow_tool`.
#[derive(Debug, Clone, Copy)]
pub struct WithTools;
/// Typestate marker: no output schema has been set through the typed builder.
#[derive(Debug, Clone, Copy)]
pub struct NoSchema;
/// Typestate marker: an output schema was set via `output` or `output_schema_raw`.
#[derive(Debug, Clone, Copy)]
pub struct WithSchema;
/// Configuration for a single agent invocation.
///
/// `Tools` and `Schema` are compile-time typestate markers
/// ([`NoTools`]/[`WithTools`], [`NoSchema`]/[`WithSchema`]):
/// - `allow_tool` is only callable while `Schema = NoSchema`.
/// - `output`/`output_schema_raw` are only callable while `Tools = NoTools`.
///
/// So the typed builder path does not let both be set together. Converting a
/// typed config back into the default `AgentConfig` (via the `From` impls)
/// erases the markers, after which either family of methods is available again.
///
/// The marker field is skipped by serde and the serde bounds are emptied, so
/// every typestate serializes to the same JSON shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "", deserialize = ""))]
#[non_exhaustive]
pub struct AgentConfig<Tools = NoTools, Schema = NoSchema> {
/// Optional system prompt.
pub system_prompt: Option<String>,
/// Prompt text; the only field without a fallback when deserializing.
pub prompt: String,
/// Model identifier; falls back to `default_model()` (`Model::SONNET`) when
/// the key is missing from serialized input.
#[serde(default = "default_model")]
pub model: String,
/// Tool names added via `allow_tool`; empty when missing from serialized input.
#[serde(default)]
pub allowed_tools: Vec<String>,
/// Optional turn limit; how it is enforced is up to the provider (not shown here).
pub max_turns: Option<u32>,
/// Optional budget cap, presumably in US dollars per the field name — enforcement
/// is up to the provider.
pub max_budget_usd: Option<f64>,
/// Optional working directory, stored as a plain string.
pub working_dir: Option<String>,
/// Optional MCP configuration string. NOTE(review): whether this is inline JSON
/// or a path is not decided here — confirm against the provider.
pub mcp_config: Option<String>,
/// Permission mode; falls back to `PermissionMode::default()` when missing.
#[serde(default)]
pub permission_mode: PermissionMode,
/// Serialized JSON schema for structured output. Deserialization also accepts
/// the key `output_schema` for this field.
#[serde(alias = "output_schema")]
pub json_schema: Option<String>,
/// Session id to resume, set by `resume`.
pub resume_session_id: Option<String>,
/// Verbose flag; `false` when missing from serialized input.
#[serde(default)]
pub verbose: bool,
/// Zero-sized typestate marker; never serialized.
#[serde(skip)]
pub(crate) _marker: PhantomData<(Tools, Schema)>,
}
/// Serde fallback used for `AgentConfig::model` when the key is absent from
/// the deserialized input. Returns the owned `Model::SONNET` identifier.
fn default_model() -> String {
    let fallback = Model::SONNET;
    fallback.to_string()
}
impl AgentConfig {
pub fn new(prompt: &str) -> Self {
Self {
system_prompt: None,
prompt: prompt.to_string(),
model: Model::SONNET.to_string(),
allowed_tools: Vec::new(),
max_turns: None,
max_budget_usd: None,
working_dir: None,
mcp_config: None,
permission_mode: PermissionMode::Default,
json_schema: None,
resume_session_id: None,
verbose: false,
_marker: PhantomData,
}
}
}
impl<Tools, Schema> AgentConfig<Tools, Schema> {
    /// Sets the system prompt, replacing any previous value.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn system_prompt(mut self, prompt: &str) -> Self {
        self.system_prompt = Some(prompt.to_string());
        self
    }

    /// Sets the model identifier (stored verbatim, not validated here).
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self
    }

    /// Sets the budget cap stored in `max_budget_usd`; the value is not validated here.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn max_budget_usd(mut self, budget: f64) -> Self {
        self.max_budget_usd = Some(budget);
        self
    }

    /// Sets the turn limit stored in `max_turns`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn max_turns(mut self, turns: u32) -> Self {
        self.max_turns = Some(turns);
        self
    }

    /// Sets the working directory (stored as a plain string).
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn working_dir(mut self, dir: &str) -> Self {
        self.working_dir = Some(dir.to_string());
        self
    }

    /// Sets the permission mode.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
        self.permission_mode = mode;
        self
    }

    /// Enables or disables the `verbose` flag.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn verbose(mut self, enabled: bool) -> Self {
        self.verbose = enabled;
        self
    }

    /// Sets the MCP configuration string (stored verbatim).
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn mcp_config(mut self, config: &str) -> Self {
        self.mcp_config = Some(config.to_string());
        self
    }

    /// Records a session id to resume, stored in `resume_session_id`.
    #[must_use = "builder methods consume `self` and return the updated config"]
    pub fn resume(mut self, session_id: &str) -> Self {
        self.resume_session_id = Some(session_id.to_string());
        self
    }

    /// Moves every field into a config with different typestate markers.
    ///
    /// Field values are carried over unchanged; only the zero-sized marker type
    /// changes. Kept private so state transitions happen only through the typed
    /// builder methods and the `From` conversions.
    fn change_state<T2, S2>(self) -> AgentConfig<T2, S2> {
        AgentConfig {
            system_prompt: self.system_prompt,
            prompt: self.prompt,
            model: self.model,
            allowed_tools: self.allowed_tools,
            max_turns: self.max_turns,
            max_budget_usd: self.max_budget_usd,
            working_dir: self.working_dir,
            mcp_config: self.mcp_config,
            permission_mode: self.permission_mode,
            json_schema: self.json_schema,
            resume_session_id: self.resume_session_id,
            verbose: self.verbose,
            _marker: PhantomData,
        }
    }
}
impl<Tools> AgentConfig<Tools, NoSchema> {
    /// Appends `tool` to the allow-list and moves the config into the
    /// `WithTools` state.
    ///
    /// Only available while no output schema has been set (`Schema = NoSchema`).
    /// Names are appended as-is; duplicates are not removed.
    #[must_use = "allow_tool consumes `self` and returns the updated config"]
    pub fn allow_tool(mut self, tool: &str) -> AgentConfig<WithTools, NoSchema> {
        self.allowed_tools.push(tool.to_string());
        self.change_state()
    }
}
impl<Schema> AgentConfig<NoTools, Schema> {
    /// Derives a JSON schema from `T` and stores it (serialized) as the output
    /// schema, moving the config into the `WithSchema` state.
    ///
    /// Only available while no tools have been allow-listed (`Tools = NoTools`).
    /// Calling it again replaces the previously stored schema.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be serialized to a JSON string.
    /// The panic message names the offending type.
    #[must_use = "output consumes `self` and returns the updated config"]
    pub fn output<T: JsonSchema>(mut self) -> AgentConfig<NoTools, WithSchema> {
        let schema = schemars::schema_for!(T);
        let serialized = serde_json::to_string(&schema).unwrap_or_else(|e| {
            panic!(
                "failed to serialize JSON schema for {}: {e}",
                std::any::type_name::<T>()
            )
        });
        self.json_schema = Some(serialized);
        self.change_state()
    }

    /// Stores `schema` verbatim as the output schema and moves the config into
    /// the `WithSchema` state.
    ///
    /// The string is not parsed or validated here; malformed JSON is passed
    /// through unchanged.
    #[must_use = "output_schema_raw consumes `self` and returns the updated config"]
    pub fn output_schema_raw(mut self, schema: &str) -> AgentConfig<NoTools, WithSchema> {
        self.json_schema = Some(schema.to_string());
        self.change_state()
    }
}
impl From<AgentConfig<WithTools, NoSchema>> for AgentConfig {
fn from(config: AgentConfig<WithTools, NoSchema>) -> Self {
config.change_state()
}
}
impl From<AgentConfig<NoTools, WithSchema>> for AgentConfig {
fn from(config: AgentConfig<NoTools, WithSchema>) -> Self {
config.change_state()
}
}
/// Result of an agent invocation together with optional run metadata.
///
/// `AgentOutput::new` produces an instance with only `value` populated.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AgentOutput {
/// Result payload as an arbitrary JSON value.
pub value: Value,
/// Session identifier, if the provider reported one.
pub session_id: Option<String>,
/// Reported cost, presumably in US dollars per the field name.
pub cost_usd: Option<f64>,
/// Input token count, if reported.
pub input_tokens: Option<u64>,
/// Output token count, if reported.
pub output_tokens: Option<u64>,
/// Model identifier reported for the run, if any.
pub model: Option<String>,
/// Run duration, in milliseconds per the field name; `0` when unset.
pub duration_ms: u64,
/// Captured assistant messages for debugging, if collected.
pub debug_messages: Option<Vec<DebugMessage>>,
}
/// One captured assistant message, used for debugging output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct DebugMessage {
/// Assistant text, if the message contained any.
pub text: Option<String>,
/// Tool calls made in this message, in order.
pub tool_calls: Vec<DebugToolCall>,
/// Stop reason string, if present; not included in the `Display` output.
pub stop_reason: Option<String>,
}
impl fmt::Display for DebugMessage {
    /// Writes an `[assistant] <text>` line when text is present, then each tool
    /// call via its own `Display` impl. `stop_reason` is not rendered, and a
    /// message with neither text nor tool calls renders as an empty string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(text) = self.text.as_deref() {
            writeln!(f, "[assistant] {text}")?;
        }
        self.tool_calls
            .iter()
            .try_for_each(|tc| write!(f, "{tc}"))
    }
}
/// A single tool invocation captured for debugging.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct DebugToolCall {
/// Name of the tool that was called.
pub name: String,
/// Arguments passed to the tool, as a JSON value.
pub input: Value,
}
impl fmt::Display for DebugToolCall {
    /// Writes the call as one newline-terminated `[tool_use] <name> -> <input>`
    /// line (prefixed by the format string's leading space), with `input`
    /// rendered through its `Display` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, " [tool_use] {} -> {}", self.name, self.input)?;
        f.write_str("\n")
    }
}
impl AgentOutput {
    /// Wraps `value` in an output with no metadata: every optional field is
    /// `None` and `duration_ms` is `0`.
    pub fn new(value: Value) -> Self {
        let duration_ms = 0;
        Self {
            value,
            duration_ms,
            debug_messages: None,
            model: None,
            session_id: None,
            cost_usd: None,
            input_tokens: None,
            output_tokens: None,
        }
    }
}
/// Backend that can run an agent for a given [`AgentConfig`].
///
/// Implementors must be `Send + Sync`, so a provider can be shared between
/// threads. `invoke` takes the default, marker-erased `AgentConfig`; typed
/// builder states convert into it via `From`.
pub trait AgentProvider: Send + Sync {
/// Returns a future that resolves to the agent's [`AgentOutput`] or an
/// [`AgentError`] for the invocation described by `config`.
///
/// The returned [`InvokeFuture`] is `Send` and may borrow both `self` and
/// `config` for `'a`.
fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a>;
}
#[cfg(test)]
mod tests {
    //! Unit tests for the agent config/output types. They cover:
    //! - serde round-trips;
    //! - serde defaults and aliases;
    //! - typestate builder transitions;
    //! - `Display` rendering of debug messages.
    use super::*;
    use serde_json::json;

    /// A config with every field populated, built directly (bypassing the builder).
    fn full_config() -> AgentConfig {
        AgentConfig {
            system_prompt: Some("you are helpful".to_string()),
            prompt: "do stuff".to_string(),
            model: Model::OPUS.to_string(),
            allowed_tools: vec!["Read".to_string(), "Write".to_string()],
            max_turns: Some(10),
            max_budget_usd: Some(2.5),
            working_dir: Some("/tmp".to_string()),
            mcp_config: Some("{}".to_string()),
            permission_mode: PermissionMode::Auto,
            json_schema: Some(r#"{"type":"object"}"#.to_string()),
            resume_session_id: None,
            verbose: false,
            _marker: PhantomData,
        }
    }

    #[test]
    fn agent_config_serialize_deserialize_roundtrip() {
        let config = full_config();
        let json = serde_json::to_string(&config).unwrap();
        let back: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.system_prompt, Some("you are helpful".to_string()));
        assert_eq!(back.prompt, "do stuff");
        assert_eq!(back.model, Model::OPUS);
        assert_eq!(back.allowed_tools, vec!["Read", "Write"]);
        assert_eq!(back.max_turns, Some(10));
        assert_eq!(back.max_budget_usd, Some(2.5));
        assert_eq!(back.working_dir, Some("/tmp".to_string()));
        assert_eq!(back.mcp_config, Some("{}".to_string()));
        assert!(matches!(back.permission_mode, PermissionMode::Auto));
        assert_eq!(back.json_schema, Some(r#"{"type":"object"}"#.to_string()));
        assert_eq!(back.resume_session_id, None);
        assert!(!back.verbose);
    }

    #[test]
    fn agent_config_with_all_optional_fields_none() {
        let config: AgentConfig = AgentConfig {
            system_prompt: None,
            prompt: "hello".to_string(),
            model: Model::HAIKU.to_string(),
            allowed_tools: vec![],
            max_turns: None,
            max_budget_usd: None,
            working_dir: None,
            mcp_config: None,
            permission_mode: PermissionMode::Default,
            json_schema: None,
            resume_session_id: None,
            verbose: false,
            _marker: PhantomData,
        };
        let json = serde_json::to_string(&config).unwrap();
        let back: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.system_prompt, None);
        assert_eq!(back.prompt, "hello");
        assert_eq!(back.model, Model::HAIKU);
        assert!(back.allowed_tools.is_empty());
        assert_eq!(back.max_turns, None);
        assert_eq!(back.max_budget_usd, None);
        assert_eq!(back.working_dir, None);
        assert_eq!(back.mcp_config, None);
        assert_eq!(back.json_schema, None);
    }

    #[test]
    fn agent_config_deserialize_fills_serde_defaults() {
        // Only `prompt` is required; every other field is an Option or has a serde default.
        let back: AgentConfig = serde_json::from_str(r#"{"prompt":"hi"}"#).unwrap();
        assert_eq!(back.prompt, "hi");
        assert_eq!(back.model, Model::SONNET);
        assert!(back.allowed_tools.is_empty());
        assert!(!back.verbose);
        assert_eq!(back.system_prompt, None);
        assert_eq!(back.json_schema, None);
        assert_eq!(back.resume_session_id, None);
    }

    #[test]
    fn agent_config_accepts_output_schema_alias() {
        let back: AgentConfig =
            serde_json::from_str(r#"{"prompt":"hi","output_schema":"{}"}"#).unwrap();
        assert_eq!(back.json_schema.as_deref(), Some("{}"));
    }

    #[test]
    fn agent_output_serialize_deserialize_roundtrip() {
        let output = AgentOutput {
            value: json!({"key": "value"}),
            session_id: Some("sess-abc".to_string()),
            cost_usd: Some(0.01),
            input_tokens: Some(500),
            output_tokens: Some(200),
            model: Some("claude-sonnet".to_string()),
            duration_ms: 3000,
            debug_messages: None,
        };
        let json = serde_json::to_string(&output).unwrap();
        let back: AgentOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(back.value, json!({"key": "value"}));
        assert_eq!(back.session_id, Some("sess-abc".to_string()));
        assert_eq!(back.cost_usd, Some(0.01));
        assert_eq!(back.input_tokens, Some(500));
        assert_eq!(back.output_tokens, Some(200));
        assert_eq!(back.model, Some("claude-sonnet".to_string()));
        assert_eq!(back.duration_ms, 3000);
        assert!(back.debug_messages.is_none());
    }

    #[test]
    fn agent_output_roundtrip_preserves_debug_messages() {
        let mut output = AgentOutput::new(json!(1));
        output.debug_messages = Some(vec![DebugMessage {
            text: Some("hi".to_string()),
            tool_calls: vec![DebugToolCall {
                name: "Read".to_string(),
                input: json!({"path": "/a"}),
            }],
            stop_reason: Some("end_turn".to_string()),
        }]);
        let json = serde_json::to_string(&output).unwrap();
        let back: AgentOutput = serde_json::from_str(&json).unwrap();
        let messages = back.debug_messages.expect("debug_messages must survive roundtrip");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].text.as_deref(), Some("hi"));
        assert_eq!(messages[0].stop_reason.as_deref(), Some("end_turn"));
        assert_eq!(messages[0].tool_calls.len(), 1);
        assert_eq!(messages[0].tool_calls[0].name, "Read");
        assert_eq!(messages[0].tool_calls[0].input, json!({"path": "/a"}));
    }

    #[test]
    fn agent_config_new_has_correct_defaults() {
        let config = AgentConfig::new("test prompt");
        assert_eq!(config.prompt, "test prompt");
        assert_eq!(config.system_prompt, None);
        assert_eq!(config.model, Model::SONNET);
        assert!(config.allowed_tools.is_empty());
        assert_eq!(config.max_turns, None);
        assert_eq!(config.max_budget_usd, None);
        assert_eq!(config.working_dir, None);
        assert_eq!(config.mcp_config, None);
        assert!(matches!(config.permission_mode, PermissionMode::Default));
        assert_eq!(config.json_schema, None);
        assert_eq!(config.resume_session_id, None);
        assert!(!config.verbose);
    }

    #[test]
    fn builder_methods_set_fields() {
        let config = AgentConfig::new("test")
            .system_prompt("sys")
            .model("custom-model")
            .max_budget_usd(1.5)
            .max_turns(3)
            .working_dir("/work")
            .permission_mode(PermissionMode::Auto)
            .verbose(true)
            .mcp_config("{}")
            .resume("sess-1");
        assert_eq!(config.system_prompt.as_deref(), Some("sys"));
        assert_eq!(config.model, "custom-model");
        assert_eq!(config.max_budget_usd, Some(1.5));
        assert_eq!(config.max_turns, Some(3));
        assert_eq!(config.working_dir.as_deref(), Some("/work"));
        assert!(matches!(config.permission_mode, PermissionMode::Auto));
        assert!(config.verbose);
        assert_eq!(config.mcp_config.as_deref(), Some("{}"));
        assert_eq!(config.resume_session_id.as_deref(), Some("sess-1"));
    }

    #[test]
    fn agent_output_new_has_correct_defaults() {
        let output = AgentOutput::new(json!("test"));
        assert_eq!(output.value, json!("test"));
        assert_eq!(output.session_id, None);
        assert_eq!(output.cost_usd, None);
        assert_eq!(output.input_tokens, None);
        assert_eq!(output.output_tokens, None);
        assert_eq!(output.model, None);
        assert_eq!(output.duration_ms, 0);
        assert!(output.debug_messages.is_none());
    }

    #[test]
    fn agent_config_resume_session_roundtrip() {
        let mut config = AgentConfig::new("test");
        config.resume_session_id = Some("sess-xyz".to_string());
        let json = serde_json::to_string(&config).unwrap();
        let back: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.resume_session_id, Some("sess-xyz".to_string()));
    }

    #[test]
    fn agent_output_debug_does_not_panic() {
        let output = AgentOutput {
            value: json!(null),
            session_id: None,
            cost_usd: None,
            input_tokens: None,
            output_tokens: None,
            model: None,
            duration_ms: 0,
            debug_messages: None,
        };
        let debug_str = format!("{:?}", output);
        assert!(!debug_str.is_empty());
    }

    #[test]
    fn debug_message_display_renders_text_and_tool_calls() {
        let msg = DebugMessage {
            text: Some("thinking".to_string()),
            tool_calls: vec![DebugToolCall {
                name: "Read".to_string(),
                input: json!({"path": "/a"}),
            }],
            stop_reason: Some("tool_use".to_string()),
        };
        let rendered = msg.to_string();
        assert!(rendered.starts_with("[assistant] thinking\n"), "got: {rendered:?}");
        assert!(
            rendered.contains(r#"[tool_use] Read -> {"path":"/a"}"#),
            "got: {rendered:?}"
        );
        assert!(rendered.ends_with('\n'));
    }

    #[test]
    fn debug_message_display_is_empty_without_content() {
        let msg = DebugMessage {
            text: None,
            tool_calls: vec![],
            stop_reason: Some("end_turn".to_string()),
        };
        assert_eq!(msg.to_string(), "");
    }

    #[test]
    fn allow_tool_transitions_to_with_tools() {
        let config = AgentConfig::new("test").allow_tool("Read");
        assert_eq!(config.allowed_tools, vec!["Read"]);
        let config = config.allow_tool("Write");
        assert_eq!(config.allowed_tools, vec!["Read", "Write"]);
    }

    #[test]
    fn output_schema_raw_transitions_to_with_schema() {
        let config = AgentConfig::new("test").output_schema_raw(r#"{"type":"object"}"#);
        assert_eq!(config.json_schema.as_deref(), Some(r#"{"type":"object"}"#));
    }

    #[test]
    fn output_generates_schema_from_type() {
        let config = AgentConfig::new("test").output::<String>();
        let schema = config.json_schema.expect("output() must set json_schema");
        assert!(schema.contains(r#""type":"string""#), "unexpected schema: {schema}");
    }

    #[test]
    fn with_tools_converts_to_base_type() {
        let typed = AgentConfig::new("test").allow_tool("Read");
        let base: AgentConfig = typed.into();
        assert_eq!(base.allowed_tools, vec!["Read"]);
    }

    #[test]
    fn with_schema_converts_to_base_type() {
        let typed = AgentConfig::new("test").output_schema_raw(r#"{"type":"object"}"#);
        let base: AgentConfig = typed.into();
        assert_eq!(base.json_schema.as_deref(), Some(r#"{"type":"object"}"#));
    }

    #[test]
    fn serde_roundtrip_ignores_marker() {
        let config = AgentConfig::new("test").allow_tool("Read");
        let json = serde_json::to_string(&config).unwrap();
        assert!(!json.contains("marker"));
        let back: AgentConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(back.allowed_tools, vec!["Read"]);
    }
}