use std::any;
use schemars::{JsonSchema, schema_for};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::{Value, from_value, to_string};
use tokio::time;
use tracing::{info, warn};
use crate::error::OperationError;
#[cfg(feature = "prometheus")]
use crate::metric_names;
use crate::provider::{AgentConfig, AgentOutput, AgentProvider, DebugMessage};
use crate::retry::RetryPolicy;
/// Namespace for well-known model identifier strings.
///
/// The constants are plain `&str` values handed to [`Agent::model`], which
/// also accepts any other model name.
pub struct Model;

impl Model {
    // Unversioned family names.

    /// Sonnet family name.
    pub const SONNET: &str = "sonnet";
    /// Opus family name.
    pub const OPUS: &str = "opus";
    /// Haiku family name.
    pub const HAIKU: &str = "haiku";

    // Pinned model versions.

    /// Claude Haiku 4.5 (dated snapshot).
    pub const HAIKU_45: &str = "claude-haiku-4-5-20251001";
    /// Claude Sonnet 4.6.
    pub const SONNET_46: &str = "claude-sonnet-4-6";
    /// Claude Opus 4.6.
    pub const OPUS_46: &str = "claude-opus-4-6";
    /// Claude Opus 4.7.
    pub const OPUS_47: &str = "claude-opus-4-7";

    // `[1m]`-suffixed variants; presumably select the 1M-token context
    // window — TODO confirm against the provider's documentation.

    /// Claude Sonnet 4.6, `[1m]` variant.
    pub const SONNET_46_1M: &str = "claude-sonnet-4-6[1m]";
    /// Claude Opus 4.6, `[1m]` variant.
    pub const OPUS_46_1M: &str = "claude-opus-4-6[1m]";
    /// Claude Opus 4.7, `[1m]` variant.
    pub const OPUS_47_1M: &str = "claude-opus-4-7[1m]";
}
#[derive(Debug, Default, Clone, Copy, Serialize)]
pub enum PermissionMode {
#[default]
Default,
Auto,
DontAsk,
BypassPermissions,
}
impl<'de> Deserialize<'de> for PermissionMode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
Ok(match s.to_lowercase().replace('_', "").as_str() {
"auto" => Self::Auto,
"dontask" => Self::DontAsk,
"bypass" | "bypasspermissions" => Self::BypassPermissions,
_ => Self::Default,
})
}
}
/// Builder for a single agent invocation.
///
/// Configure it with the chained builder methods, then execute it with
/// [`Agent::run`].
#[must_use = "an Agent does nothing until .run() is awaited"]
pub struct Agent {
    /// Settings handed to the provider on every invocation.
    config: AgentConfig,
    /// Per-agent dry-run override; `None` lets
    /// `crate::dry_run::effective_dry_run` decide.
    dry_run: Option<bool>,
    /// Retry behaviour for failed invocations; `None` means a single attempt.
    retry_policy: Option<RetryPolicy>,
}
impl Agent {
    /// Creates an agent with an empty prompt and the defaults produced by
    /// `AgentConfig::new`.
    ///
    /// A prompt must be set with [`Agent::prompt`] before calling
    /// [`Agent::run`], which panics on an empty prompt.
    pub fn new() -> Self {
        Self::from_config(AgentConfig::new(""))
    }

    /// Creates an agent from a pre-built configuration.
    ///
    /// Dry-run and retry settings start unset; configure them with
    /// [`Agent::dry_run`], [`Agent::retry`] or [`Agent::retry_policy`].
    pub fn from_config(config: impl Into<AgentConfig>) -> Self {
        Self {
            config: config.into(),
            dry_run: None,
            retry_policy: None,
        }
    }

    /// Sets the system prompt.
    pub fn system_prompt(mut self, prompt: &str) -> Self {
        self.config.system_prompt = Some(prompt.to_string());
        self
    }

    /// Sets the user prompt. It must not be blank when [`Agent::run`] is called.
    pub fn prompt(mut self, prompt: &str) -> Self {
        self.config.prompt = prompt.to_string();
        self
    }

    /// Sets the model identifier.
    ///
    /// Any string is accepted; [`Model`] provides constants for common names.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.config.model = model.into();
        self
    }

    /// Replaces the list of tools the agent is allowed to use.
    pub fn allowed_tools(mut self, tools: &[&str]) -> Self {
        self.config.allowed_tools = tools.iter().map(|s| s.to_string()).collect();
        self
    }

    /// Limits the number of agent turns.
    ///
    /// # Panics
    ///
    /// Panics if `turns` is `0`.
    pub fn max_turns(mut self, turns: u32) -> Self {
        assert!(turns > 0, "max_turns must be greater than 0");
        self.config.max_turns = Some(turns);
        self
    }

    /// Caps the spend for this invocation, in US dollars.
    ///
    /// # Panics
    ///
    /// Panics unless `budget` is finite and strictly positive.
    pub fn max_budget_usd(mut self, budget: f64) -> Self {
        assert!(
            budget.is_finite() && budget > 0.0,
            "budget must be a positive finite number, got {budget}"
        );
        self.config.max_budget_usd = Some(budget);
        self
    }

    /// Sets the working directory stored in the configuration.
    pub fn working_dir(mut self, dir: &str) -> Self {
        self.config.working_dir = Some(dir.to_string());
        self
    }

    /// Sets the MCP configuration string stored in the configuration.
    pub fn mcp_config(mut self, config: &str) -> Self {
        self.config.mcp_config = Some(config.to_string());
        self
    }

    /// Sets the permission mode.
    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
        self.config.permission_mode = mode;
        self
    }

    /// Requests structured output matching the JSON schema generated for `T`.
    ///
    /// If the schema cannot be serialized, a warning is logged and any
    /// previously configured schema is cleared, disabling structured output.
    pub fn output<T: JsonSchema>(mut self) -> Self {
        let schema = schema_for!(T);
        self.config.json_schema = match to_string(&schema) {
            Ok(s) => Some(s),
            Err(e) => {
                warn!(error = %e, type_name = any::type_name::<T>(), "failed to serialize JSON schema, structured output disabled");
                None
            }
        };
        self
    }

    /// Sets a JSON schema string verbatim. The string is not validated here.
    pub fn output_schema_raw(mut self, schema: &str) -> Self {
        self.config.json_schema = Some(schema.to_string());
        self
    }

    /// Enables retries using `RetryPolicy::new(max_retries)`.
    pub fn retry(mut self, max_retries: u32) -> Self {
        self.retry_policy = Some(RetryPolicy::new(max_retries));
        self
    }

    /// Enables retries using a custom policy.
    pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
        self.retry_policy = Some(policy);
        self
    }

    /// Overrides the dry-run setting for this agent.
    ///
    /// When never called, `crate::dry_run::effective_dry_run` decides.
    pub fn dry_run(mut self, enabled: bool) -> Self {
        self.dry_run = Some(enabled);
        self
    }

    /// Sets the configuration's `verbose` flag.
    pub fn verbose(mut self) -> Self {
        self.config.verbose = true;
        self
    }

    /// Resumes an existing session instead of starting a new one.
    ///
    /// # Panics
    ///
    /// Panics if `session_id` is empty or contains anything other than ASCII
    /// alphanumerics, `-` or `_`.
    pub fn resume(mut self, session_id: &str) -> Self {
        assert!(!session_id.is_empty(), "session_id must not be empty");
        assert!(
            session_id
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
            "session_id must only contain alphanumeric characters, hyphens, or underscores, got: {session_id}"
        );
        self.config.resume_session_id = Some(session_id.to_string());
        self
    }

    /// Runs the agent against `provider`.
    ///
    /// In dry-run mode the provider is not called. The result instead holds the
    /// placeholder text `"[dry-run] agent call skipped"`, with zero cost and
    /// zero token counts.
    ///
    /// When a retry policy is configured, failures that
    /// `crate::retry::is_retryable` accepts are retried. The policy allows up
    /// to `max_retries` additional attempts (at most `max_retries + 1` calls in
    /// total), sleeping `delay_for_attempt(attempt)` before each retry.
    ///
    /// # Errors
    ///
    /// Returns [`OperationError::Agent`] wrapping the provider error from the
    /// final attempt.
    ///
    /// # Panics
    ///
    /// Panics if the prompt is empty or whitespace-only.
    #[tracing::instrument(name = "agent", skip_all, fields(model = %self.config.model, prompt_len = self.config.prompt.len()))]
    pub async fn run(self, provider: &dyn AgentProvider) -> Result<AgentResult, OperationError> {
        assert!(
            !self.config.prompt.trim().is_empty(),
            "prompt must not be empty - call .prompt(\"...\") before .run()"
        );
        if crate::dry_run::effective_dry_run(self.dry_run) {
            info!(
                prompt_len = self.config.prompt.len(),
                "[dry-run] agent call skipped"
            );
            let mut output =
                AgentOutput::new(Value::String("[dry-run] agent call skipped".to_string()));
            output.cost_usd = Some(0.0);
            output.input_tokens = Some(0);
            output.output_tokens = Some(0);
            return Ok(AgentResult { output });
        }

        let mut last_result = self.invoke_once(provider).await;
        let policy = match &self.retry_policy {
            Some(p) => p,
            None => return last_result,
        };
        for attempt in 0..policy.max_retries {
            // Only a retryable error earns another attempt; success and
            // permanent failures are returned as-is.
            match &last_result {
                Err(err) if crate::retry::is_retryable(err) => {}
                _ => return last_result,
            }
            let delay = policy.delay_for_attempt(attempt);
            // Saturate rather than truncate the u128 millisecond count.
            let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
            warn!(
                attempt = attempt + 1,
                max_retries = policy.max_retries,
                delay_ms = delay_ms,
                "retrying agent invocation"
            );
            time::sleep(delay).await;
            last_result = self.invoke_once(provider).await;
        }
        last_result
    }

    /// Performs a single provider call and logs usage for it.
    ///
    /// With the `prometheus` feature enabled, it also records invocation
    /// counts, duration, cost and token metrics, labelled by the configured
    /// model.
    async fn invoke_once(
        &self,
        provider: &dyn AgentProvider,
    ) -> Result<AgentResult, OperationError> {
        #[cfg(feature = "prometheus")]
        let model_label = self.config.model.clone();
        let output = match provider.invoke(&self.config).await {
            Ok(output) => output,
            Err(e) => {
                #[cfg(feature = "prometheus")]
                {
                    // Last use on this path, so the label can be moved.
                    metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label, "status" => metric_names::STATUS_ERROR).increment(1);
                }
                return Err(OperationError::Agent(e));
            }
        };
        info!(
            duration_ms = output.duration_ms,
            cost_usd = output.cost_usd,
            input_tokens = output.input_tokens,
            output_tokens = output.output_tokens,
            model = output.model,
            "agent completed"
        );
        #[cfg(feature = "prometheus")]
        {
            metrics::counter!(metric_names::AGENT_TOTAL, "model" => model_label.clone(), "status" => metric_names::STATUS_SUCCESS).increment(1);
            metrics::histogram!(metric_names::AGENT_DURATION_SECONDS, "model" => model_label.clone())
                .record(output.duration_ms as f64 / 1000.0);
            if let Some(cost) = output.cost_usd {
                // `metrics` counters only accept integer increments, so the
                // fractional running cost total is accumulated in a gauge.
                metrics::gauge!(metric_names::AGENT_COST_USD_TOTAL, "model" => model_label.clone())
                    .increment(cost);
            }
            if let Some(tokens) = output.input_tokens {
                metrics::counter!(metric_names::AGENT_TOKENS_INPUT_TOTAL, "model" => model_label.clone()).increment(tokens);
            }
            if let Some(tokens) = output.output_tokens {
                metrics::counter!(metric_names::AGENT_TOKENS_OUTPUT_TOTAL, "model" => model_label)
                    .increment(tokens);
            }
        }
        Ok(AgentResult { output })
    }
}
impl Default for Agent {
fn default() -> Self {
Self::new()
}
}
/// Outcome of a successful [`Agent::run`].
///
/// Wraps the provider's [`AgentOutput`] and exposes read-only accessors for
/// the output value and its usage metadata.
#[derive(Debug)]
pub struct AgentResult {
    /// Raw provider output backing every accessor.
    output: AgentOutput,
}
impl AgentResult {
pub fn text(&self) -> &str {
match self.output.value.as_str() {
Some(s) => s,
None => {
warn!(
value_type = self.output.value.to_string(),
"agent output is not a string, returning empty"
);
""
}
}
}
pub fn value(&self) -> &Value {
&self.output.value
}
pub fn json<T: DeserializeOwned>(&self) -> Result<T, OperationError> {
from_value(self.output.value.clone()).map_err(OperationError::deserialize::<T>)
}
pub fn into_json<T: DeserializeOwned>(self) -> Result<T, OperationError> {
from_value(self.output.value).map_err(OperationError::deserialize::<T>)
}
#[cfg(test)]
pub(crate) fn from_output(output: AgentOutput) -> Self {
Self { output }
}
pub fn session_id(&self) -> Option<&str> {
self.output.session_id.as_deref()
}
pub fn cost_usd(&self) -> Option<f64> {
self.output.cost_usd
}
pub fn input_tokens(&self) -> Option<u64> {
self.output.input_tokens
}
pub fn output_tokens(&self) -> Option<u64> {
self.output.output_tokens
}
pub fn duration_ms(&self) -> u64 {
self.output.duration_ms
}
pub fn model(&self) -> Option<&str> {
self.output.model.as_deref()
}
pub fn debug_messages(&self) -> Option<&[DebugMessage]> {
self.output.debug_messages.as_deref()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AgentError;
    use crate::provider::InvokeFuture;
    use serde_json::json;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    /// Field-by-field copy of `output` (which is not `Clone`), with
    /// `debug_messages` always `None` as every test provider reports.
    fn copy_output(output: &AgentOutput) -> AgentOutput {
        AgentOutput {
            value: output.value.clone(),
            session_id: output.session_id.clone(),
            cost_usd: output.cost_usd,
            input_tokens: output.input_tokens,
            output_tokens: output.output_tokens,
            model: output.model.clone(),
            duration_ms: output.duration_ms,
            debug_messages: None,
        }
    }

    /// Provider that always succeeds with a copy of `output`.
    struct TestProvider {
        output: AgentOutput,
    }

    impl AgentProvider for TestProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move { Ok(copy_output(&self.output)) })
        }
    }

    /// Provider that echoes the received config back as the output value, so
    /// tests can assert on what the builder produced.
    struct ConfigCapture {
        output: AgentOutput,
    }

    impl AgentProvider for ConfigCapture {
        fn invoke<'a>(&'a self, config: &'a AgentConfig) -> InvokeFuture<'a> {
            let config_json = serde_json::to_value(config).unwrap();
            Box::pin(async move {
                Ok(AgentOutput {
                    value: config_json,
                    ..copy_output(&self.output)
                })
            })
        }
    }

    /// Provider that fails with a retryable error for the first
    /// `failures_before_success` calls and succeeds afterwards.
    struct FailNTimesProvider {
        fail_count: AtomicU32,
        failures_before_success: u32,
        output: AgentOutput,
    }

    impl AgentProvider for FailNTimesProvider {
        fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
            Box::pin(async move {
                let current = self.fail_count.fetch_add(1, Ordering::SeqCst);
                if current < self.failures_before_success {
                    Err(AgentError::ProcessFailed {
                        exit_code: 1,
                        stderr: format!("transient failure #{}", current + 1),
                    })
                } else {
                    Ok(copy_output(&self.output))
                }
            })
        }
    }

    fn default_output() -> AgentOutput {
        AgentOutput {
            value: json!("test output"),
            session_id: Some("sess-123".to_string()),
            cost_usd: Some(0.05),
            input_tokens: Some(100),
            output_tokens: Some(50),
            model: Some("sonnet".to_string()),
            duration_ms: 1500,
            debug_messages: None,
        }
    }

    #[test]
    fn model_constants_have_expected_values() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_eq!(Model::OPUS, "opus");
        assert_eq!(Model::HAIKU, "haiku");
        assert_eq!(Model::HAIKU_45, "claude-haiku-4-5-20251001");
        assert_eq!(Model::SONNET_46, "claude-sonnet-4-6");
        assert_eq!(Model::OPUS_46, "claude-opus-4-6");
        assert_eq!(Model::SONNET_46_1M, "claude-sonnet-4-6[1m]");
        assert_eq!(Model::OPUS_46_1M, "claude-opus-4-6[1m]");
        assert_eq!(Model::OPUS_47, "claude-opus-4-7");
        assert_eq!(Model::OPUS_47_1M, "claude-opus-4-7[1m]");
    }

    #[tokio::test]
    async fn agent_new_default_values() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        let config = result.value();
        assert_eq!(config["system_prompt"], json!(null));
        assert_eq!(config["prompt"], json!("hi"));
        assert_eq!(config["model"], json!("sonnet"));
        assert_eq!(config["allowed_tools"], json!([]));
        assert_eq!(config["max_turns"], json!(null));
        assert_eq!(config["max_budget_usd"], json!(null));
        assert_eq!(config["working_dir"], json!(null));
        assert_eq!(config["mcp_config"], json!(null));
        assert_eq!(config["permission_mode"], json!("Default"));
        assert_eq!(config["json_schema"], json!(null));
    }

    #[tokio::test]
    async fn agent_default_matches_new() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result_new = Agent::new().prompt("x").run(&provider).await.unwrap();
        let result_default = Agent::default().prompt("x").run(&provider).await.unwrap();
        assert_eq!(result_new.value(), result_default.value());
    }

    #[tokio::test]
    async fn builder_methods_store_values_correctly() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .system_prompt("you are a bot")
            .prompt("do something")
            .model(Model::OPUS)
            .allowed_tools(&["Read", "Write"])
            .max_turns(5)
            .max_budget_usd(1.5)
            .working_dir("/tmp")
            .mcp_config("{}")
            .permission_mode(PermissionMode::Auto)
            .run(&provider)
            .await
            .unwrap();
        let config = result.value();
        assert_eq!(config["system_prompt"], json!("you are a bot"));
        assert_eq!(config["prompt"], json!("do something"));
        assert_eq!(config["model"], json!("opus"));
        assert_eq!(config["allowed_tools"], json!(["Read", "Write"]));
        assert_eq!(config["max_turns"], json!(5));
        assert_eq!(config["max_budget_usd"], json!(1.5));
        assert_eq!(config["working_dir"], json!("/tmp"));
        assert_eq!(config["mcp_config"], json!("{}"));
        assert_eq!(config["permission_mode"], json!("Auto"));
    }

    #[test]
    #[should_panic(expected = "max_turns must be greater than 0")]
    fn max_turns_zero_panics() {
        let _ = Agent::new().max_turns(0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_negative_panics() {
        let _ = Agent::new().max_budget_usd(-1.0);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_nan_panics() {
        let _ = Agent::new().max_budget_usd(f64::NAN);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_infinity_panics() {
        let _ = Agent::new().max_budget_usd(f64::INFINITY);
    }

    #[test]
    #[should_panic(expected = "budget must be a positive finite number")]
    fn max_budget_zero_panics() {
        let _ = Agent::new().max_budget_usd(0.0);
    }

    #[tokio::test]
    async fn agent_result_text_with_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("hello world"),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "hello world");
    }

    #[tokio::test]
    async fn agent_result_text_with_non_string_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_text_with_null_value() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(null),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.text(), "");
    }

    #[tokio::test]
    async fn agent_result_json_successful_deserialize() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct MyOutput {
            name: String,
            count: u32,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test", "count": 7}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: MyOutput = result.json().unwrap();
        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.count, 7);
    }

    #[tokio::test]
    async fn agent_result_json_failed_deserialize() {
        #[derive(Debug, Deserialize)]
        // The field is never read: this test only checks that deserialization fails.
        #[allow(dead_code)]
        struct MyOutput {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.json::<MyOutput>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[tokio::test]
    async fn agent_result_accessors() {
        let provider = TestProvider {
            output: AgentOutput {
                value: json!("v"),
                session_id: Some("s-1".to_string()),
                cost_usd: Some(0.123),
                input_tokens: Some(999),
                output_tokens: Some(456),
                model: Some("opus".to_string()),
                duration_ms: 2000,
                debug_messages: None,
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert_eq!(result.session_id(), Some("s-1"));
        assert_eq!(result.cost_usd(), Some(0.123));
        assert_eq!(result.input_tokens(), Some(999));
        assert_eq!(result.output_tokens(), Some(456));
        assert_eq!(result.duration_ms(), 2000);
        assert_eq!(result.model(), Some("opus"));
    }

    #[tokio::test]
    async fn resume_passes_session_id_in_config() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("followup")
            .resume("sess-abc")
            .run(&provider)
            .await
            .unwrap();
        let config = result.value();
        assert_eq!(config["resume_session_id"], json!("sess-abc"));
    }

    #[tokio::test]
    async fn no_resume_has_null_session_id() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("first call")
            .run(&provider)
            .await
            .unwrap();
        let config = result.value();
        assert_eq!(config["resume_session_id"], json!(null));
    }

    #[test]
    #[should_panic(expected = "session_id must not be empty")]
    fn resume_empty_session_id_panics() {
        let _ = Agent::new().resume("");
    }

    #[test]
    #[should_panic(expected = "session_id must only contain")]
    fn resume_invalid_chars_panics() {
        let _ = Agent::new().resume("sess;rm -rf /");
    }

    #[test]
    fn resume_valid_formats_accepted() {
        let _ = Agent::new().resume("sess-abc123");
        let _ = Agent::new().resume("a1b2c3d4_session");
        let _ = Agent::new().resume("abc-DEF-123_456");
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_without_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().run(&provider).await;
    }

    #[tokio::test]
    #[should_panic(expected = "prompt must not be empty")]
    async fn run_with_whitespace_only_prompt_panics() {
        let provider = TestProvider {
            output: default_output(),
        };
        let _ = Agent::new().prompt(" ").run(&provider).await;
    }

    #[tokio::test]
    async fn model_accepts_custom_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .model("mistral-large-latest")
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("mistral-large-latest"));
    }

    #[tokio::test]
    async fn model_accepts_owned_string() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let model_name = String::from("gpt-4o");
        let result = Agent::new()
            .prompt("hi")
            .model(model_name)
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["model"], json!("gpt-4o"));
    }

    #[tokio::test]
    async fn verbose_sets_config_flag() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("hi")
            .verbose()
            .run(&provider)
            .await
            .unwrap();
        assert_eq!(result.value()["verbose"], json!(true));
    }

    #[tokio::test]
    async fn verbose_not_set_by_default() {
        let provider = ConfigCapture {
            output: default_output(),
        };
        let result = Agent::new().prompt("hi").run(&provider).await.unwrap();
        assert_eq!(result.value()["verbose"], json!(false));
    }

    #[tokio::test]
    async fn debug_messages_none_without_verbose() {
        let provider = TestProvider {
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        assert!(result.debug_messages().is_none());
    }

    #[tokio::test]
    async fn into_json_success() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!({"name": "test"}),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let parsed: Out = result.into_json().unwrap();
        assert_eq!(parsed.name, "test");
    }

    #[tokio::test]
    async fn into_json_failure() {
        #[derive(Debug, Deserialize)]
        // The field is never read: this test only checks that deserialization fails.
        #[allow(dead_code)]
        struct Out {
            name: String,
        }
        let provider = TestProvider {
            output: AgentOutput {
                value: json!(42),
                ..default_output()
            },
        };
        let result = Agent::new().prompt("test").run(&provider).await.unwrap();
        let err = result.into_json::<Out>().unwrap_err();
        assert!(matches!(err, OperationError::Deserialize { .. }));
    }

    #[test]
    fn from_output_creates_result() {
        let output = AgentOutput {
            value: json!("hello"),
            ..default_output()
        };
        let result = AgentResult::from_output(output);
        assert_eq!(result.text(), "hello");
        assert_eq!(result.cost_usd(), Some(0.05));
    }

    #[test]
    fn model_constant_equality() {
        assert_eq!(Model::SONNET, "sonnet");
        assert_ne!(Model::SONNET, Model::OPUS);
    }

    #[test]
    fn permission_mode_serialize_deserialize_roundtrip() {
        for mode in [
            PermissionMode::Default,
            PermissionMode::Auto,
            PermissionMode::DontAsk,
            PermissionMode::BypassPermissions,
        ] {
            let json = to_string(&mode).unwrap();
            let back: PermissionMode = serde_json::from_str(&json).unwrap();
            assert_eq!(format!("{mode:?}"), format!("{back:?}"));
        }
    }

    #[test]
    fn permission_mode_deserialize_is_lenient() {
        let parse = |s: &str| -> PermissionMode { serde_json::from_value(json!(s)).unwrap() };
        assert!(matches!(parse("AUTO"), PermissionMode::Auto));
        assert!(matches!(parse("dont_ask"), PermissionMode::DontAsk));
        assert!(matches!(parse("bypass"), PermissionMode::BypassPermissions));
        assert!(matches!(
            parse("bypass_permissions"),
            PermissionMode::BypassPermissions
        ));
        assert!(matches!(parse("not-a-mode"), PermissionMode::Default));
    }

    #[test]
    fn retry_builder_stores_policy() {
        let agent = Agent::new().retry(3);
        assert!(agent.retry_policy.is_some());
        assert_eq!(agent.retry_policy.unwrap().max_retries(), 3);
    }

    #[test]
    fn retry_policy_builder_stores_custom_policy() {
        let policy = RetryPolicy::new(5).backoff(Duration::from_secs(1));
        let agent = Agent::new().retry_policy(policy);
        let p = agent.retry_policy.unwrap();
        assert_eq!(p.max_retries(), 5);
    }

    #[test]
    fn no_retry_by_default() {
        let agent = Agent::new();
        assert!(agent.retry_policy.is_none());
    }

    #[tokio::test]
    async fn retry_succeeds_after_transient_failures() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 2,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;
        assert!(result.is_ok(), "expected success, got {result:?}");
        // Two failures followed by one success.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_exhausted_returns_last_error() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            // More failures than the policy allows retries.
            failures_before_success: 10,
            output: default_output(),
        };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(2).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;
        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_does_not_retry_non_retryable_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = call_count.clone();
        struct CountingNonRetryable {
            count: Arc<AtomicU32>,
        }
        impl AgentProvider for CountingNonRetryable {
            fn invoke<'a>(&'a self, _config: &'a AgentConfig) -> InvokeFuture<'a> {
                self.count.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move {
                    Err(AgentError::SchemaValidation {
                        expected: "object".to_string(),
                        got: "string".to_string(),
                        debug_messages: Vec::new(),
                        partial_usage: Box::default(),
                    })
                })
            }
        }
        let provider = CountingNonRetryable { count };
        let result = Agent::new()
            .prompt("test")
            .retry_policy(RetryPolicy::new(3).backoff(Duration::from_millis(1)))
            .run(&provider)
            .await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn no_retry_without_policy() {
        let provider = FailNTimesProvider {
            fail_count: AtomicU32::new(0),
            failures_before_success: 1,
            output: default_output(),
        };
        let result = Agent::new().prompt("test").run(&provider).await;
        assert!(result.is_err());
        assert_eq!(provider.fail_count.load(Ordering::SeqCst), 1);
    }
}