use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::error::AgentError;
use crate::operations::agent::{Model, PermissionMode};
/// Boxed, pinned, `Send` future returned by [`AgentProvider::invoke`].
///
/// It resolves to `Ok(AgentOutput)` on success or `Err(AgentError)` on
/// failure. The `'a` lifetime bounds the future, so it may borrow from the
/// provider and the config that `invoke` receives.
pub type InvokeFuture<'a> =
Pin<Box<dyn Future<Output = Result<AgentOutput, AgentError>> + Send + 'a>>;
/// Type-state marker: the default value of [`AgentConfig`]'s `Tools`
/// parameter. The schema setters (`output`, `output_schema_raw`) are only
/// defined for configs carrying this marker.
#[derive(Debug, Clone, Copy)]
pub struct NoTools;
/// Type-state marker: the `Tools` state that `AgentConfig::allow_tool`
/// returns.
#[derive(Debug, Clone, Copy)]
pub struct WithTools;
/// Type-state marker: the default value of [`AgentConfig`]'s `Schema`
/// parameter. `allow_tool` is only defined for configs carrying this marker.
#[derive(Debug, Clone, Copy)]
pub struct NoSchema;
/// Type-state marker: the `Schema` state that `AgentConfig::output` and
/// `AgentConfig::output_schema_raw` return.
#[derive(Debug, Clone, Copy)]
pub struct WithSchema;
/// Configuration for one agent invocation; [`AgentProvider::invoke`] takes
/// it by reference.
///
/// `Tools` and `Schema` are compile-time state markers that exist only in
/// `_marker`; they default to [`NoTools`] and [`NoSchema`]. The serde bounds
/// are deliberately empty so the marker types themselves do not have to
/// implement `Serialize` / `Deserialize`.
///
/// NOTE(review): how each field is interpreted is decided by the
/// `AgentProvider` implementation, which is not visible in this file; the
/// per-field notes below only state what this file establishes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "", deserialize = ""))]
#[non_exhaustive]
pub struct AgentConfig<Tools = NoTools, Schema = NoSchema> {
/// Optional system prompt; `None` until `system_prompt(..)` is called.
pub system_prompt: Option<String>,
/// Prompt text supplied to `AgentConfig::new`.
pub prompt: String,
/// Model identifier. Falls back to `default_model()` (i.e.
/// `Model::SONNET`) when the key is absent from serialized input.
#[serde(default = "default_model")]
pub model: String,
/// Tool names accumulated by `allow_tool`; empty when absent from
/// serialized input.
#[serde(default)]
pub allowed_tools: Vec<String>,
/// Tool names set (wholesale) by `disallowed_tools(..)`; empty when absent
/// from serialized input.
#[serde(default)]
pub disallowed_tools: Vec<String>,
/// Optional turn limit. NOTE(review): presumably enforced by the provider
/// — nothing in this file reads it.
pub max_turns: Option<u32>,
/// Optional budget cap; the field name implies US dollars. NOTE(review):
/// enforcement is not visible in this file.
pub max_budget_usd: Option<f64>,
/// Optional working directory, stored as a plain string. NOTE(review): how
/// it is resolved is up to the provider — confirm.
pub working_dir: Option<String>,
/// Optional MCP configuration string. NOTE(review): whether this is a path
/// or inline JSON is not established here (the tests pass `"{}"`).
pub mcp_config: Option<String>,
/// Flag set by `strict_mcp_config(..)`; `false` when absent from
/// serialized input. NOTE(review): exact semantics live in the provider.
#[serde(default)]
pub strict_mcp_config: bool,
/// Flag set by `bare(..)`; `false` when absent from serialized input.
/// NOTE(review): its meaning is not visible in this file — TODO confirm.
#[serde(default)]
pub bare: bool,
/// Permission mode; falls back to `PermissionMode::default()` when absent
/// from serialized input.
#[serde(default)]
pub permission_mode: PermissionMode,
/// JSON Schema for the output, held as a serialized string (written by
/// `output` / `output_schema_raw`). When deserializing, the key
/// `output_schema` is accepted as an alias.
#[serde(alias = "output_schema")]
pub json_schema: Option<String>,
/// Session id stored by `resume(..)`. NOTE(review): presumably identifies
/// an earlier session to continue — confirm against the provider.
pub resume_session_id: Option<String>,
/// Flag set by `verbose(..)`; `false` when absent from serialized input.
#[serde(default)]
pub verbose: bool,
/// Zero-sized carrier for the type-state markers; never serialized or
/// deserialized.
#[serde(skip)]
pub(crate) _marker: PhantomData<(Tools, Schema)>,
}
/// Serde fallback for `AgentConfig::model`: the string form of
/// `Model::SONNET`, used when the `model` key is missing from the input.
fn default_model() -> String {
    let fallback = Model::SONNET;
    fallback.to_string()
}
impl AgentConfig {
pub fn new(prompt: &str) -> Self {
Self {
system_prompt: None,
prompt: prompt.to_string(),
model: Model::SONNET.to_string(),
allowed_tools: Vec::new(),
disallowed_tools: Vec::new(),
max_turns: None,
max_budget_usd: None,
working_dir: None,
mcp_config: None,
strict_mcp_config: false,
bare: false,
permission_mode: PermissionMode::Default,
json_schema: None,
resume_session_id: None,
verbose: false,
_marker: PhantomData,
}
}
}
/// Builder methods that are available in every type state.
///
/// Each setter takes the config by value and hands it back so calls can be
/// chained. They are marked `#[must_use]` because ignoring the return value
/// drops the entire configuration, which is never what the caller intended.
impl<Tools, Schema> AgentConfig<Tools, Schema> {
    /// Sets the system prompt, replacing any previous value.
    #[must_use]
    pub fn system_prompt(mut self, prompt: &str) -> Self {
        self.system_prompt = Some(prompt.to_string());
        self
    }

    /// Replaces the model identifier.
    #[must_use]
    pub fn model(mut self, model: &str) -> Self {
        self.model = model.to_string();
        self
    }

    /// Stores `budget` in `max_budget_usd`.
    #[must_use]
    pub fn max_budget_usd(mut self, budget: f64) -> Self {
        self.max_budget_usd = Some(budget);
        self
    }

    /// Stores `turns` in `max_turns`.
    #[must_use]
    pub fn max_turns(mut self, turns: u32) -> Self {
        self.max_turns = Some(turns);
        self
    }

    /// Stores `dir` in `working_dir`.
    #[must_use]
    pub fn working_dir(mut self, dir: &str) -> Self {
        self.working_dir = Some(dir.to_string());
        self
    }

    /// Replaces the permission mode.
    #[must_use]
    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
        self.permission_mode = mode;
        self
    }

    /// Sets the `verbose` flag.
    #[must_use]
    pub fn verbose(mut self, enabled: bool) -> Self {
        self.verbose = enabled;
        self
    }

    /// Stores `config` in `mcp_config`.
    #[must_use]
    pub fn mcp_config(mut self, config: &str) -> Self {
        self.mcp_config = Some(config.to_string());
        self
    }

    /// Sets the `strict_mcp_config` flag.
    #[must_use]
    pub fn strict_mcp_config(mut self, strict: bool) -> Self {
        self.strict_mcp_config = strict;
        self
    }

    /// Sets the `bare` flag.
    #[must_use]
    pub fn bare(mut self, enabled: bool) -> Self {
        self.bare = enabled;
        self
    }

    /// Replaces the whole `disallowed_tools` list with `tools`.
    ///
    /// Unlike `allow_tool`, which appends, this overwrites: passing an empty
    /// iterator clears the list.
    #[must_use]
    pub fn disallowed_tools<I, S>(mut self, tools: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.disallowed_tools = tools.into_iter().map(Into::into).collect();
        self
    }

    /// Stores `session_id` in `resume_session_id`.
    #[must_use]
    pub fn resume(mut self, session_id: &str) -> Self {
        self.resume_session_id = Some(session_id.to_string());
        self
    }

    /// Rebuilds the config under different type-state markers.
    ///
    /// Every data field is moved across untouched; only the type carried by
    /// the `PhantomData` marker changes.
    fn change_state<T2, S2>(self) -> AgentConfig<T2, S2> {
        AgentConfig {
            system_prompt: self.system_prompt,
            prompt: self.prompt,
            model: self.model,
            allowed_tools: self.allowed_tools,
            disallowed_tools: self.disallowed_tools,
            max_turns: self.max_turns,
            max_budget_usd: self.max_budget_usd,
            working_dir: self.working_dir,
            mcp_config: self.mcp_config,
            strict_mcp_config: self.strict_mcp_config,
            bare: self.bare,
            permission_mode: self.permission_mode,
            json_schema: self.json_schema,
            resume_session_id: self.resume_session_id,
            verbose: self.verbose,
            _marker: PhantomData,
        }
    }
}
/// Tool allow-listing, available only while no output schema has been set
/// (`Schema = NoSchema`).
impl<Tools> AgentConfig<Tools, NoSchema> {
    /// Appends `tool` to `allowed_tools` and returns the config in the
    /// [`WithTools`] state.
    ///
    /// The method is generic over the incoming `Tools` marker, so it can be
    /// chained to allow several tools. Marked `#[must_use]` because the
    /// config is consumed: ignoring the result throws it away.
    #[must_use]
    pub fn allow_tool(mut self, tool: &str) -> AgentConfig<WithTools, NoSchema> {
        self.allowed_tools.push(tool.to_string());
        self.change_state()
    }
}
/// Structured-output setters, available only while the tools marker is
/// `NoTools`.
///
/// Both methods consume the config and return it in the [`WithSchema`]
/// state; they are `#[must_use]` because ignoring the result discards the
/// configuration.
impl<Schema> AgentConfig<NoTools, Schema> {
    /// Generates a JSON Schema for `T` via `schemars`, stores its serialized
    /// form in `json_schema` (replacing any previous schema), and moves the
    /// config into the [`WithSchema`] state.
    ///
    /// # Panics
    ///
    /// Panics if the generated schema cannot be serialized to a JSON string;
    /// the message names `T`.
    #[must_use]
    pub fn output<T: JsonSchema>(mut self) -> AgentConfig<NoTools, WithSchema> {
        let schema = schemars::schema_for!(T);
        let serialized = serde_json::to_string(&schema).unwrap_or_else(|e| {
            panic!(
                "failed to serialize JSON schema for {}: {e}",
                std::any::type_name::<T>()
            )
        });
        self.json_schema = Some(serialized);
        self.change_state()
    }

    /// Stores `schema` verbatim in `json_schema` (replacing any previous
    /// schema) and moves the config into the [`WithSchema`] state.
    ///
    /// The string is not parsed or validated here.
    #[must_use]
    pub fn output_schema_raw(mut self, schema: &str) -> AgentConfig<NoTools, WithSchema> {
        self.json_schema = Some(schema.to_string());
        self.change_state()
    }
}
impl From<AgentConfig<WithTools, NoSchema>> for AgentConfig {
fn from(config: AgentConfig<WithTools, NoSchema>) -> Self {
config.change_state()
}
}
impl From<AgentConfig<NoTools, WithSchema>> for AgentConfig {
fn from(config: AgentConfig<NoTools, WithSchema>) -> Self {
config.change_state()
}
}
/// Result of an agent invocation — the success value of [`InvokeFuture`].
///
/// Only `value` and `duration_ms` are mandatory; `AgentOutput::new` fills the
/// rest with `None` and a zero duration.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct AgentOutput {
/// The agent's output as an arbitrary JSON value.
pub value: Value,
/// Session id, if one was reported. NOTE(review): presumably usable with
/// `AgentConfig::resume` — confirm against the provider.
pub session_id: Option<String>,
/// Cost of the invocation, if known; the field name implies US dollars.
pub cost_usd: Option<f64>,
/// Input token count, if known.
pub input_tokens: Option<u64>,
/// Output token count, if known.
pub output_tokens: Option<u64>,
/// Model identifier, if reported.
pub model: Option<String>,
/// Duration of the invocation; the field name implies milliseconds.
/// NOTE(review): what exactly is timed is decided by the provider.
pub duration_ms: u64,
/// Optional per-message debug trace. NOTE(review): presumably only
/// populated when `AgentConfig::verbose` is set — confirm.
pub debug_messages: Option<Vec<DebugMessage>>,
}
/// One entry of `AgentOutput::debug_messages`.
///
/// Its `Display` impl prints the thinking line (or a redaction notice), the
/// assistant text, and then every tool call and tool result.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct DebugMessage {
/// Assistant text, if any.
pub text: Option<String>,
/// Thinking text, if any. Omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub thinking: Option<String>,
/// Whether thinking was redacted. Omitted from serialized output when
/// `false`, and `false` when absent from the input.
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub thinking_redacted: bool,
/// Tool calls in this message. Note: unlike `tool_results`, this field has
/// no `#[serde(default)]`, so the key must be present when deserializing.
pub tool_calls: Vec<DebugToolCall>,
/// Tool results in this message. Omitted from serialized output when
/// empty, and empty when absent from the input.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tool_results: Vec<DebugToolResult>,
/// Stop reason, if reported. NOTE(review): the set of possible values is
/// not defined in this file.
pub stop_reason: Option<String>,
/// Input token count for this message, if known. Omitted from serialized
/// output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub input_tokens: Option<u64>,
/// Output token count for this message, if known. Omitted from serialized
/// output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub output_tokens: Option<u64>,
}
/// Human-readable rendering of a debug message, one item per line.
///
/// Order: the thinking text (or `[thinking redacted]` when the text is absent
/// but the redaction flag is set), the assistant text, every tool call, then
/// every tool result. Tool calls and results are written with their own
/// `Display` impls and no added separator.
impl fmt::Display for DebugMessage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Actual thinking text wins over the redaction notice.
        match (self.thinking.as_deref(), self.thinking_redacted) {
            (Some(thinking), _) => writeln!(f, "[thinking] {thinking}")?,
            (None, true) => writeln!(f, "[thinking redacted]")?,
            (None, false) => {}
        }

        if let Some(text) = self.text.as_deref() {
            writeln!(f, "[assistant] {text}")?;
        }

        self.tool_calls
            .iter()
            .try_for_each(|tc| write!(f, "{tc}"))?;
        self.tool_results
            .iter()
            .try_for_each(|tr| write!(f, "{tr}"))
    }
}
/// A tool call recorded in a [`DebugMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct DebugToolCall {
/// Identifier of the call, if any. Omitted from serialized output when
/// `None`. NOTE(review): presumably matched by
/// `DebugToolResult::tool_use_id` — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Name of the tool.
pub name: String,
/// Input passed to the tool, as an arbitrary JSON value.
pub input: Value,
}
/// Renders the call as a single newline-terminated line containing the tool
/// name and its input; the `id` is not shown.
impl fmt::Display for DebugToolCall {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tool = &self.name;
        let arguments = &self.input;
        writeln!(f, " [tool_use] {} -> {}", tool, arguments)
    }
}
/// A tool result recorded in a [`DebugMessage`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct DebugToolResult {
/// Identifier of the originating tool call, if any. Omitted from
/// serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_use_id: Option<String>,
/// Result content, as an arbitrary JSON value.
pub content: Value,
/// Whether the result is an error; `false` when absent from the input.
/// The `Display` impl labels the line `tool_error` instead of
/// `tool_result` when this is set.
#[serde(default)]
pub is_error: bool,
}
/// Renders the result as a single newline-terminated line, labelled
/// `tool_error` when `is_error` is set and `tool_result` otherwise; the
/// `tool_use_id` is not shown.
impl fmt::Display for DebugToolResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let kind = match self.is_error {
            true => "tool_error",
            false => "tool_result",
        };
        writeln!(f, " [{kind}] {}", self.content)
    }
}
impl AgentOutput {
pub fn new(value: Value) -> Self {
Self {
value,
session_id: None,
cost_usd: None,
input_tokens: None,
output_tokens: None,
model: None,
duration_ms: 0,
debug_messages: None,
}
}
}
/// A backend that can carry out an agent invocation.
///
/// Implementors must be `Send + Sync`. The method returns a boxed future
/// rather than being an `async fn`.
pub trait AgentProvider: Send + Sync {
/// Starts the invocation described by `config`.
///
/// Takes the base `AgentConfig` (default markers); configs in the
/// `WithTools` or `WithSchema` state convert to it via `From` / `Into`.
/// The returned future may borrow both `self` and `config` for `'a`, and
/// resolves to an `AgentOutput` or an `AgentError`.
fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a>;
}
// Unit tests: constructor defaults, builder and type-state transitions,
// `From` conversions back to the base type, and serde round-trips.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Fixture built with a struct literal (bypassing the builder) in which every
// field is populated except `resume_session_id`; `verbose` stays `false`.
fn full_config() -> AgentConfig {
AgentConfig {
system_prompt: Some("you are helpful".to_string()),
prompt: "do stuff".to_string(),
model: Model::OPUS.to_string(),
allowed_tools: vec!["Read".to_string(), "Write".to_string()],
disallowed_tools: vec!["Bash".to_string()],
max_turns: Some(10),
max_budget_usd: Some(2.5),
working_dir: Some("/tmp".to_string()),
mcp_config: Some("{}".to_string()),
strict_mcp_config: true,
bare: true,
permission_mode: PermissionMode::Auto,
json_schema: Some(r#"{"type":"object"}"#.to_string()),
resume_session_id: None,
verbose: false,
_marker: PhantomData,
}
}
// A fully populated config survives JSON serialization and deserialization
// for the fields checked below.
#[test]
fn agent_config_serialize_deserialize_roundtrip() {
let config = full_config();
let json = serde_json::to_string(&config).unwrap();
let back: AgentConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.system_prompt, Some("you are helpful".to_string()));
assert_eq!(back.prompt, "do stuff");
assert_eq!(back.allowed_tools, vec!["Read", "Write"]);
assert_eq!(back.max_turns, Some(10));
assert_eq!(back.max_budget_usd, Some(2.5));
assert_eq!(back.working_dir, Some("/tmp".to_string()));
assert_eq!(back.mcp_config, Some("{}".to_string()));
assert_eq!(back.json_schema, Some(r#"{"type":"object"}"#.to_string()));
}
// `None` optionals and empty lists round-trip as `None` / empty.
#[test]
fn agent_config_with_all_optional_fields_none() {
let config: AgentConfig = AgentConfig {
system_prompt: None,
prompt: "hello".to_string(),
model: Model::HAIKU.to_string(),
allowed_tools: vec![],
disallowed_tools: vec![],
max_turns: None,
max_budget_usd: None,
working_dir: None,
mcp_config: None,
strict_mcp_config: false,
bare: false,
permission_mode: PermissionMode::Default,
json_schema: None,
resume_session_id: None,
verbose: false,
_marker: PhantomData,
};
let json = serde_json::to_string(&config).unwrap();
let back: AgentConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.system_prompt, None);
assert_eq!(back.prompt, "hello");
assert!(back.allowed_tools.is_empty());
assert_eq!(back.max_turns, None);
assert_eq!(back.max_budget_usd, None);
assert_eq!(back.working_dir, None);
assert_eq!(back.mcp_config, None);
assert_eq!(back.json_schema, None);
}
// `AgentOutput` round-trips through JSON with its metadata intact.
#[test]
fn agent_output_serialize_deserialize_roundtrip() {
let output = AgentOutput {
value: json!({"key": "value"}),
session_id: Some("sess-abc".to_string()),
cost_usd: Some(0.01),
input_tokens: Some(500),
output_tokens: Some(200),
model: Some("claude-sonnet".to_string()),
duration_ms: 3000,
debug_messages: None,
};
let json = serde_json::to_string(&output).unwrap();
let back: AgentOutput = serde_json::from_str(&json).unwrap();
assert_eq!(back.value, json!({"key": "value"}));
assert_eq!(back.session_id, Some("sess-abc".to_string()));
assert_eq!(back.cost_usd, Some(0.01));
assert_eq!(back.input_tokens, Some(500));
assert_eq!(back.output_tokens, Some(200));
assert_eq!(back.model, Some("claude-sonnet".to_string()));
assert_eq!(back.duration_ms, 3000);
}
// `AgentConfig::new` keeps the prompt, picks `Model::SONNET`, and leaves
// everything else unset.
#[test]
fn agent_config_new_has_correct_defaults() {
let config = AgentConfig::new("test prompt");
assert_eq!(config.prompt, "test prompt");
assert_eq!(config.system_prompt, None);
assert_eq!(config.model, Model::SONNET);
assert!(config.allowed_tools.is_empty());
assert_eq!(config.max_turns, None);
assert_eq!(config.max_budget_usd, None);
assert_eq!(config.working_dir, None);
assert_eq!(config.mcp_config, None);
assert!(matches!(config.permission_mode, PermissionMode::Default));
assert_eq!(config.json_schema, None);
assert_eq!(config.resume_session_id, None);
assert!(!config.verbose);
}
// `AgentOutput::new` keeps the value and blanks all metadata.
#[test]
fn agent_output_new_has_correct_defaults() {
let output = AgentOutput::new(json!("test"));
assert_eq!(output.value, json!("test"));
assert_eq!(output.session_id, None);
assert_eq!(output.cost_usd, None);
assert_eq!(output.input_tokens, None);
assert_eq!(output.output_tokens, None);
assert_eq!(output.model, None);
assert_eq!(output.duration_ms, 0);
assert!(output.debug_messages.is_none());
}
// `resume_session_id` set directly on the field survives a round-trip.
#[test]
fn agent_config_resume_session_roundtrip() {
let mut config = AgentConfig::new("test");
config.resume_session_id = Some("sess-xyz".to_string());
let json = serde_json::to_string(&config).unwrap();
let back: AgentConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.resume_session_id, Some("sess-xyz".to_string()));
}
// Smoke test: the derived `Debug` output is non-empty for a blank output.
#[test]
fn agent_output_debug_does_not_panic() {
let output = AgentOutput {
value: json!(null),
session_id: None,
cost_usd: None,
input_tokens: None,
output_tokens: None,
model: None,
duration_ms: 0,
debug_messages: None,
};
let debug_str = format!("{:?}", output);
assert!(!debug_str.is_empty());
}
// `allow_tool` appends and can be chained once in the `WithTools` state.
#[test]
fn allow_tool_transitions_to_with_tools() {
let config = AgentConfig::new("test").allow_tool("Read");
assert_eq!(config.allowed_tools, vec!["Read"]);
let config = config.allow_tool("Write");
assert_eq!(config.allowed_tools, vec!["Read", "Write"]);
}
// `output_schema_raw` stores the schema string verbatim.
#[test]
fn output_schema_raw_transitions_to_with_schema() {
let config = AgentConfig::new("test").output_schema_raw(r#"{"type":"object"}"#);
assert_eq!(config.json_schema.as_deref(), Some(r#"{"type":"object"}"#));
}
// A `WithTools` config converts to the base type and keeps its tools.
#[test]
fn with_tools_converts_to_base_type() {
let typed = AgentConfig::new("test").allow_tool("Read");
let base: AgentConfig = typed.into();
assert_eq!(base.allowed_tools, vec!["Read"]);
}
// A `WithSchema` config converts to the base type and keeps its schema.
#[test]
fn with_schema_converts_to_base_type() {
let typed = AgentConfig::new("test").output_schema_raw(r#"{"type":"object"}"#);
let base: AgentConfig = typed.into();
assert_eq!(base.json_schema.as_deref(), Some(r#"{"type":"object"}"#));
}
// The marker field never appears in the JSON, and a `WithTools` config can be
// read back as the base type.
#[test]
fn serde_roundtrip_ignores_marker() {
let config = AgentConfig::new("test").allow_tool("Read");
let json = serde_json::to_string(&config).unwrap();
assert!(!json.contains("marker"));
let back: AgentConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.allowed_tools, vec!["Read"]);
}
// `bare` is off after `new`.
#[test]
fn bare_defaults_to_false() {
let config = AgentConfig::new("hello");
assert!(!config.bare, "bare must default to false");
}
// The `bare` builder both enables and disables the flag.
#[test]
fn bare_builder_sets_flag() {
let config = AgentConfig::new("hello").bare(true);
assert!(config.bare, "bare(true) must enable the flag");
let config = config.bare(false);
assert!(!config.bare, "bare(false) must disable the flag");
}
// A payload without a `bare` key deserializes with `bare == false`.
#[test]
fn bare_serde_default_when_missing() {
let raw = r#"{"prompt":"hello","model":"sonnet"}"#;
let config: AgentConfig = serde_json::from_str(raw).unwrap();
assert!(
!config.bare,
"bare must default to false when absent from serialized payload"
);
}
// `bare: true` is written to the JSON and read back.
#[test]
fn bare_serde_roundtrip() {
let mut config = AgentConfig::new("hello");
config.bare = true;
let json = serde_json::to_string(&config).unwrap();
assert!(
json.contains("\"bare\":true"),
"serialized form must contain bare:true, got: {json}"
);
let back: AgentConfig = serde_json::from_str(&json).unwrap();
assert!(back.bare, "bare must survive a serde roundtrip");
}
// `disallowed_tools` is empty after `new`.
#[test]
fn disallowed_tools_defaults_to_empty() {
let config = AgentConfig::new("hello");
assert!(
config.disallowed_tools.is_empty(),
"disallowed_tools must default to empty"
);
}
// The `disallowed_tools` builder overwrites the list on every call; an empty
// iterator clears it.
#[test]
fn disallowed_tools_builder_replaces_list() {
let config = AgentConfig::new("hello").disallowed_tools(["Write", "Edit"]);
assert_eq!(config.disallowed_tools, vec!["Write", "Edit"]);
let config = config.disallowed_tools(["Bash"]);
assert_eq!(config.disallowed_tools, vec!["Bash"]);
let config = config.disallowed_tools(std::iter::empty::<String>());
assert!(config.disallowed_tools.is_empty());
}
// `disallowed_tools` may be called before or after `output::<T>()`; the
// result is in the `WithSchema` state either way.
#[test]
fn disallowed_tools_compatible_with_output() {
#[derive(serde::Deserialize, JsonSchema)]
#[allow(dead_code)]
struct Out {
ok: bool,
}
let before: AgentConfig<NoTools, WithSchema> = AgentConfig::new("classify")
.disallowed_tools(["Write", "Edit"])
.output::<Out>();
assert_eq!(before.disallowed_tools, vec!["Write", "Edit"]);
assert!(before.json_schema.is_some());
let after: AgentConfig<NoTools, WithSchema> = AgentConfig::new("classify")
.output::<Out>()
.disallowed_tools(["Write"]);
assert_eq!(after.disallowed_tools, vec!["Write"]);
assert!(after.json_schema.is_some());
}
// A payload without a `disallowed_tools` key deserializes to an empty list.
#[test]
fn disallowed_tools_serde_default_when_missing() {
let raw = r#"{"prompt":"hello","model":"sonnet"}"#;
let config: AgentConfig = serde_json::from_str(raw).unwrap();
assert!(
config.disallowed_tools.is_empty(),
"disallowed_tools must default to empty when absent from serialized payload"
);
}
// The `disallowed_tools` array is written to the JSON in order and read back.
#[test]
fn disallowed_tools_serde_roundtrip() {
let config = AgentConfig::new("hello").disallowed_tools(["Write", "Edit"]);
let json = serde_json::to_string(&config).unwrap();
assert!(
json.contains("\"disallowed_tools\":[\"Write\",\"Edit\"]"),
"serialized form must contain the disallowed_tools array, got: {json}"
);
let back: AgentConfig = serde_json::from_str(&json).unwrap();
assert_eq!(back.disallowed_tools, vec!["Write", "Edit"]);
}
}