Skip to main content

claude_agent/
lib.rs

1//! # claude-agent
2//!
3//! Rust SDK for building AI agents with Anthropic's Claude.
4//!
5//! This crate provides a production-ready, memory-efficient way to build AI agents
6//! using the Anthropic Messages API directly, without CLI subprocess dependencies.
7//!
8//! ## Quick Start
9//!
10//! ```rust,no_run
11//! use claude_agent::query;
12//!
13//! #[tokio::main]
14//! async fn main() -> Result<(), claude_agent::Error> {
15//!     let response = query("What is 2 + 2?").await?;
16//!     println!("{}", response);
17//!     Ok(())
18//! }
19//! ```
20//!
21//! ## Full Agent Example
22//!
23//! ```rust,no_run
24//! use claude_agent::{Agent, AgentEvent, ToolAccess};
25//! use futures::StreamExt;
26//! use std::pin::pin;
27//!
28//! #[tokio::main]
29//! async fn main() -> Result<(), claude_agent::Error> {
30//!     let agent = Agent::builder()
31//!         .model("claude-sonnet-4-5")
32//!         .tools(ToolAccess::all())
33//!         .working_dir("./project")
34//!         .build()
35//!         .await?;
36//!
37//!     let stream = agent.execute_stream("Fix the bug").await?;
38//!     let mut stream = pin!(stream);
39//!     while let Some(event) = stream.next().await {
40//!         match event? {
41//!             AgentEvent::Text(text) => print!("{}", text),
42//!             AgentEvent::Complete(result) => {
43//!                 println!("Done: {} tokens", result.total_tokens());
44//!             }
45//!             _ => {}
46//!         }
47//!     }
48//!     Ok(())
49//! }
50//! ```
51
52#![cfg_attr(docsrs, feature(doc_cfg))]
53#![allow(missing_docs)]
54#![deny(rustdoc::broken_intra_doc_links)]
55
56pub mod agent;
57pub mod auth;
58pub mod budget;
59pub mod client;
60pub mod common;
61pub mod config;
62pub mod context;
63pub mod hooks;
64pub mod mcp;
65pub mod models;
66pub mod observability;
67pub mod output_style;
68pub mod permissions;
69#[cfg(feature = "plugins")]
70pub mod plugins;
71pub mod prelude;
72pub mod prompts;
73pub mod security;
74pub mod session;
75pub mod skills;
76pub mod subagents;
77pub mod tokens;
78pub mod tools;
79pub mod types;
80
81// =========================================================================
82// Core API re-exports (user-facing types)
83// =========================================================================
84
85pub use agent::{Agent, AgentBuilder, AgentConfig, AgentEvent, AgentResult};
86pub use auth::{Auth, Credential};
87pub use client::{Client, ClientBuilder};
88pub use permissions::{PermissionMode, PermissionPolicy};
89pub use tools::{ExecutionContext, SchemaTool, Tool, ToolAccess, ToolRegistry};
90pub use types::{ContentBlock, Message, Role, ToolDefinition, ToolError, ToolOutput, ToolResult};
91
92// =========================================================================
93// Commonly used configuration re-exports
94// =========================================================================
95
96pub use agent::{
97    AgentMetrics, AgentModelConfig, AgentState, BudgetConfig, CacheConfig, CacheStrategy,
98    ExecutionConfig, PromptConfig, SecurityConfig, SystemPromptMode, ToolStats,
99};
100pub use auth::{CredentialProvider, OAuthConfig};
101pub use client::{
102    BetaConfig, BetaFeature, CloudProvider, EffortLevel, FallbackConfig, ModelConfig, ModelType,
103    OutputConfig, ProviderConfig,
104};
105pub use common::{ContentSource, Index, IndexRegistry, Named, SourceType, ToolRestricted};
106pub use context::{
107    ContextBuilder, MemoryLoader, MemoryProvider, PromptOrchestrator, RuleIndex, StaticContext,
108};
109pub use hooks::{CommandHook, Hook, HookContext, HookEvent, HookManager, HookOutput};
110pub use output_style::OutputStyle;
111pub use session::{
112    Session, SessionConfig, SessionId, SessionManager, SessionMessage, SessionState, ToolState,
113};
114pub use skills::{SkillExecutor, SkillIndex, SkillResult};
115pub use subagents::{SubagentIndex, builtin_subagents};
116
117#[cfg(feature = "cli-integration")]
118pub use auth::ClaudeCliProvider;
119#[cfg(feature = "aws")]
120pub use client::BedrockAdapter;
121#[cfg(feature = "azure")]
122pub use client::FoundryAdapter;
123#[cfg(feature = "gcp")]
124pub use client::VertexAdapter;
125#[cfg(feature = "cli-integration")]
126pub use output_style::{OutputStyleLoader, SystemPromptGenerator};
127#[cfg(feature = "plugins")]
128pub use plugins::{PluginDescriptor, PluginDiscovery, PluginError, PluginManager, PluginManifest};
129#[cfg(feature = "cli-integration")]
130pub use subagents::{SubagentFrontmatter, SubagentIndexLoader};
131
/// Unified error type returned by fallible claude-agent operations.
///
/// The `Display` output of every variant embeds the details carried by that
/// variant (status codes, limits, names, wrapped error text).
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum Error {
    /// The API answered with an error response.
    #[error(
        "API error (HTTP {status}): {message}",
        status = status.map_or_else(|| "unknown".to_string(), |code| code.to_string())
    )]
    Api {
        /// Error text, rendered verbatim in the message.
        message: String,
        /// HTTP status code; rendered as `unknown` when absent.
        status: Option<u16>,
        /// Optional error-type string (not part of the rendered message).
        error_type: Option<String>,
    },

    /// Authentication did not succeed.
    #[error("Authentication failed: {message}")]
    Auth {
        /// Explanation rendered in the message.
        message: String,
    },

    /// A `reqwest` request failed.
    #[error("Network request failed: {0}")]
    Network(#[from] reqwest::Error),

    /// `serde_json` could not serialize or deserialize a value.
    #[error("JSON parsing failed: {0}")]
    Json(#[from] serde_json::Error),

    /// A response or configuration value could not be parsed.
    #[error("Parse error: {0}")]
    Parse(String),

    /// A tool reported a failure.
    #[error("Tool execution failed: {0}")]
    Tool(#[from] types::ToolError),

    /// Configuration is invalid or missing.
    #[error("Configuration error: {0}")]
    Config(String),

    /// A `std::io` operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The API rate limit was hit.
    ///
    /// When `retry_after` is present the message ends with
    /// `, retry in <seconds>s` (rounded to whole seconds).
    #[error(
        "Rate limit exceeded{}",
        retry_after
            .map(|delay| format!(", retry in {:.0}s", delay.as_secs_f64()))
            .unwrap_or_default()
    )]
    RateLimit {
        /// Suggested wait before retrying, if known.
        retry_after: Option<std::time::Duration>,
    },

    /// The context holds more tokens than its limit allows.
    #[error("Context limit exceeded: {current}/{max} tokens ({:.0}% used)", (*current as f64 / *max as f64) * 100.0)]
    ContextOverflow {
        /// Tokens currently in use.
        current: usize,
        /// Token limit.
        max: usize,
    },

    /// A request would not fit in the context window.
    #[error("Context window exceeded: {estimated} tokens > {limit} limit (overage: {overage})")]
    ContextWindowExceeded {
        /// Estimated token count of the request.
        estimated: u64,
        /// Context window limit.
        limit: u64,
        /// Overage value reported alongside the estimate and limit.
        overage: u64,
    },

    /// An operation ran longer than its timeout; the payload is rendered in
    /// seconds with one decimal place.
    #[error("Operation timed out after {:.1}s", .0.as_secs_f64())]
    Timeout(std::time::Duration),

    /// Token configuration failed validation.
    #[error("Token validation failed: {0}")]
    TokenValidation(#[from] client::messages::TokenValidationError),

    /// The request parameters are not valid.
    #[error("Invalid request: {0}")]
    InvalidRequest(String),

    /// A streaming response failed.
    #[error("Stream error: {0}")]
    Stream(String),

    /// An environment variable lookup failed.
    #[error("Environment variable error: {0}")]
    Env(#[from] std::env::VarError),

    /// The current provider does not offer the requested operation.
    #[error("{operation} is not supported by {provider}")]
    NotSupported {
        /// Provider name.
        provider: &'static str,
        /// Operation name.
        operation: &'static str,
    },

    /// A permission policy blocked the operation.
    #[error("Permission denied: {0}")]
    Permission(String),

    /// Spending went past the configured budget.
    #[error("Budget exceeded: ${used} used (limit: ${limit})")]
    BudgetExceeded {
        /// Amount used so far.
        used: rust_decimal::Decimal,
        /// Budget limit.
        limit: rust_decimal::Decimal,
    },

    /// The model is temporarily overloaded.
    #[error("Model {model} is overloaded, try again later")]
    ModelOverloaded {
        /// Name of the overloaded model.
        model: String,
    },

    /// A session operation failed.
    #[error("Session error: {0}")]
    Session(String),

    /// Communication with an MCP server failed.
    #[error("MCP error: {0}")]
    Mcp(mcp::McpError),

    /// A system resource limit was reached (memory, processes, etc.).
    #[error("Resource exhausted: {0}")]
    ResourceExhausted(String),

    /// A hook failed (blockable hooks only).
    #[error("Hook '{hook}' failed: {reason}")]
    HookFailed {
        /// Hook name.
        hook: String,
        /// Failure reason.
        reason: String,
    },

    /// A hook timed out (blockable hooks only).
    #[error("Hook '{hook}' timed out after {duration_secs}s")]
    HookTimeout {
        /// Hook name.
        hook: String,
        /// Timeout in seconds, as rendered in the message.
        duration_secs: u64,
    },

    /// The circuit breaker is open and requests are being rejected.
    #[error("Circuit breaker is open")]
    CircuitOpen,

    /// The plugin system reported an error.
    #[cfg(feature = "plugins")]
    #[error("Plugin error: {0}")]
    Plugin(#[from] plugins::PluginError),
}
266
/// Coarse classification of an error, for unified error handling.
///
/// The type is a plain `Copy` enum that is comparable and hashable, so it can
/// be used directly as a key in `HashMap`/`HashSet` (for example to count
/// failures per category).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorCategory {
    /// Authentication or authorization failures (401, 403)
    Authorization,
    /// Configuration, parsing, or setup errors
    Configuration,
    /// Network, rate limit, or transient errors that may succeed on retry
    Transient,
    /// Session, MCP, or other stateful operation errors
    Stateful,
    /// Internal errors (IO, JSON, unexpected states)
    Internal,
    /// Resource limits (budget, context, timeout)
    ResourceLimit,
}
283
284impl Error {
285    pub fn auth(message: impl Into<String>) -> Self {
286        Error::Auth {
287            message: message.into(),
288        }
289    }
290
291    pub fn category(&self) -> ErrorCategory {
292        match self {
293            Error::Auth { .. } => ErrorCategory::Authorization,
294            Error::Api {
295                status: Some(401 | 403),
296                ..
297            } => ErrorCategory::Authorization,
298            Error::Permission(_) | Error::HookFailed { .. } | Error::HookTimeout { .. } => {
299                ErrorCategory::Authorization
300            }
301
302            Error::Config(_)
303            | Error::Parse(_)
304            | Error::Env(_)
305            | Error::InvalidRequest(_)
306            | Error::TokenValidation(_) => ErrorCategory::Configuration,
307
308            Error::Network(_)
309            | Error::RateLimit { .. }
310            | Error::ModelOverloaded { .. }
311            | Error::CircuitOpen => ErrorCategory::Transient,
312            Error::Api {
313                status: Some(500..=599),
314                ..
315            } => ErrorCategory::Transient,
316
317            Error::Session(_) | Error::Mcp(_) | Error::Stream(_) => ErrorCategory::Stateful,
318
319            Error::BudgetExceeded { .. }
320            | Error::ContextOverflow { .. }
321            | Error::ContextWindowExceeded { .. }
322            | Error::Timeout(_)
323            | Error::ResourceExhausted(_) => ErrorCategory::ResourceLimit,
324
325            Error::Io(_)
326            | Error::Json(_)
327            | Error::Tool(_)
328            | Error::Api { .. }
329            | Error::NotSupported { .. } => ErrorCategory::Internal,
330
331            #[cfg(feature = "plugins")]
332            Error::Plugin(_) => ErrorCategory::Configuration,
333        }
334    }
335
336    pub fn is_unauthorized(&self) -> bool {
337        matches!(
338            self,
339            Error::Api {
340                status: Some(401),
341                ..
342            } | Error::Auth { .. }
343        )
344    }
345
346    pub fn is_overloaded(&self) -> bool {
347        match self {
348            Error::Api {
349                status: Some(529 | 503),
350                ..
351            } => true,
352            Error::Api {
353                error_type: Some(t),
354                ..
355            } if t.contains("overloaded") => true,
356            Error::Api { message, .. } if message.to_lowercase().contains("overloaded") => true,
357            Error::ModelOverloaded { .. } => true,
358            _ => false,
359        }
360    }
361
362    pub fn status_code(&self) -> Option<u16> {
363        match self {
364            Error::Api { status, .. } => *status,
365            _ => None,
366        }
367    }
368
369    pub fn retry_after(&self) -> Option<std::time::Duration> {
370        match self {
371            Error::RateLimit { retry_after } => *retry_after,
372            _ => None,
373        }
374    }
375}
376
377impl From<config::ConfigError> for Error {
378    fn from(err: config::ConfigError) -> Self {
379        match err {
380            config::ConfigError::NotFound { key } => {
381                Error::Config(format!("Key not found: {}", key))
382            }
383            config::ConfigError::InvalidValue { key, message } => {
384                Error::Config(format!("Invalid value for {}: {}", key, message))
385            }
386            config::ConfigError::Serialization(e) => Error::Json(e),
387            config::ConfigError::Io(e) => Error::Io(e),
388            config::ConfigError::Env(e) => Error::Env(e),
389            config::ConfigError::Provider { message } => Error::Config(message),
390            config::ConfigError::ValidationErrors(errors) => Error::Config(errors.to_string()),
391        }
392    }
393}
394
395impl From<context::ContextError> for Error {
396    fn from(err: context::ContextError) -> Self {
397        match err {
398            context::ContextError::Source { message } => Error::Config(message),
399            context::ContextError::TokenBudgetExceeded { current, limit } => {
400                Error::ContextOverflow {
401                    current: current as usize,
402                    max: limit as usize,
403                }
404            }
405            context::ContextError::SkillNotFound { name } => {
406                Error::Config(format!("Skill not found: {}", name))
407            }
408            context::ContextError::RuleNotFound { name } => {
409                Error::Config(format!("Rule not found: {}", name))
410            }
411            context::ContextError::Parse { message } => Error::Parse(message),
412            context::ContextError::Io(e) => Error::Io(e),
413        }
414    }
415}
416
417impl From<session::SessionError> for Error {
418    fn from(err: session::SessionError) -> Self {
419        match err {
420            session::SessionError::NotFound { id } => {
421                Error::Config(format!("Session not found: {}", id))
422            }
423            session::SessionError::Expired { id } => {
424                Error::Config(format!("Session expired: {}", id))
425            }
426            session::SessionError::Storage { message } => Error::Config(message),
427            session::SessionError::Serialization(e) => Error::Json(e),
428            session::SessionError::Compact { message } => Error::Config(message),
429            session::SessionError::Context(e) => e.into(),
430        }
431    }
432}
433
434impl From<security::SecurityError> for Error {
435    fn from(err: security::SecurityError) -> Self {
436        match err {
437            security::SecurityError::Io(e) => Error::Io(e),
438            security::SecurityError::ResourceLimit(msg) => Error::ResourceExhausted(msg),
439            security::SecurityError::BashBlocked(msg) => Error::Permission(msg),
440            security::SecurityError::DeniedPath(path) => {
441                Error::Permission(format!("Denied path: {}", path.display()))
442            }
443            security::SecurityError::PathEscape(path) => {
444                Error::Permission(format!("Path escapes sandbox: {}", path.display()))
445            }
446            security::SecurityError::NotWithinSandbox(path) => {
447                Error::Permission(format!("Path not within sandbox: {}", path.display()))
448            }
449            security::SecurityError::InvalidPath(msg) => Error::Config(msg),
450            security::SecurityError::AbsoluteSymlink(path) => Error::Permission(format!(
451                "Absolute symlink outside sandbox: {}",
452                path.display()
453            )),
454            security::SecurityError::SymlinkDepthExceeded { path, max } => Error::Permission(
455                format!("Symlink depth exceeded (max {}): {}", max, path.display()),
456            ),
457        }
458    }
459}
460
461impl From<security::sandbox::SandboxError> for Error {
462    fn from(err: security::sandbox::SandboxError) -> Self {
463        match err {
464            security::sandbox::SandboxError::Io(e) => Error::Io(e),
465            security::sandbox::SandboxError::NotSupported => {
466                Error::Config("Sandbox not supported on this platform".into())
467            }
468            security::sandbox::SandboxError::NotAvailable(msg) => {
469                Error::Config(format!("Sandbox not available: {}", msg))
470            }
471            security::sandbox::SandboxError::Creation(msg) => {
472                Error::Config(format!("Sandbox creation failed: {}", msg))
473            }
474            security::sandbox::SandboxError::RuleApplication(msg) => {
475                Error::Config(format!("Sandbox rule application failed: {}", msg))
476            }
477            security::sandbox::SandboxError::PathNotAccessible(path) => {
478                Error::Permission(format!("Sandbox path not accessible: {}", path.display()))
479            }
480            security::sandbox::SandboxError::InvalidConfig(msg) => {
481                Error::Config(format!("Invalid sandbox config: {}", msg))
482            }
483        }
484    }
485}
486
487impl From<mcp::McpError> for Error {
488    fn from(err: mcp::McpError) -> Self {
489        match err {
490            mcp::McpError::Io(e) => Error::Io(e),
491            mcp::McpError::Json(e) => Error::Json(e),
492            other => Error::Mcp(other),
493        }
494    }
495}
496
497pub type Result<T> = std::result::Result<T, Error>;
498
499/// Simple query function for one-shot requests
500pub async fn query(prompt: &str) -> Result<String> {
501    let client = Client::builder().auth(Auth::FromEnv).await?.build().await?;
502    client.query(prompt).await
503}
504
505/// Query with a specific model
506pub async fn query_with_model(model: &str, prompt: &str) -> Result<String> {
507    use client::CreateMessageRequest;
508    let client = Client::builder().auth(Auth::FromEnv).await?.build().await?;
509    let request =
510        CreateMessageRequest::new(model, vec![types::Message::user(prompt)]).max_tokens(8192);
511    let response = client.send(request).await?;
512    Ok(response.text())
513}
514
515/// Stream a response for one-shot requests
516pub async fn stream(
517    prompt: &str,
518) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
519    let client = Client::builder().auth(Auth::FromEnv).await?.build().await?;
520    client.stream(prompt).await
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `Error::Api` without an `error_type` for the tests below.
    fn api_error(message: &str, status: u16) -> Error {
        Error::Api {
            message: message.to_string(),
            status: Some(status),
            error_type: None,
        }
    }

    // The API message must appear in the rendered error text.
    #[test]
    fn test_error_display() {
        let rendered = api_error("Invalid API key", 401).to_string();
        assert!(rendered.contains("Invalid API key"));
    }

    // Rate limits and 5xx responses are transient; auth errors are not.
    #[test]
    fn test_error_category() {
        let cases = [
            (
                Error::RateLimit { retry_after: None },
                ErrorCategory::Transient,
            ),
            (api_error("Internal error", 500), ErrorCategory::Transient),
            (Error::auth("Invalid token"), ErrorCategory::Authorization),
        ];

        for (err, expected) in cases {
            assert_eq!(err.category(), expected);
        }
    }

    // A missing configuration key converts into `Error::Config`.
    #[test]
    fn test_config_error_conversion() {
        let converted = Error::from(config::ConfigError::NotFound {
            key: "api_key".to_string(),
        });
        assert!(matches!(converted, Error::Config(_)));
    }
}
557        };
558        let err: Error = config_err.into();
559        assert!(matches!(err, Error::Config(_)));
560    }
561}