claude_agent/client/mod.rs

1//! Anthropic API client with multi-provider support.
2
3pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16    AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17    DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18    ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21    BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27    ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28    CountTokensRequest, CountTokensResponse, CreateMessageRequest, EffortLevel, KeepConfig,
29    KeepThinkingConfig, OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, ToolChoice,
30};
31pub use network::{ClientCertConfig, NetworkConfig, PoolConfig, ProxyConfig};
32pub use recovery::StreamRecoveryState;
33pub use resilience::{
34    CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
35    RetryConfig,
36};
37pub use schema::{strict_schema, transform_for_strict};
38pub use streaming::{RecoverableStream, StreamItem, StreamParser};
39
40#[cfg(feature = "aws")]
41pub use adapter::BedrockAdapter;
42#[cfg(feature = "azure")]
43pub use adapter::FoundryAdapter;
44#[cfg(feature = "gcp")]
45pub use adapter::VertexAdapter;
46
47use std::sync::Arc;
48use std::time::Duration;
49
50use crate::auth::{Auth, Credential, OAuthConfig};
51use crate::{Error, Result};
52
/// Per-request HTTP timeout of five minutes.
///
/// Used by `Client::new`, and by `ClientBuilder::build` when no explicit timeout was set.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
54
55#[derive(Clone)]
56pub struct Client {
57    adapter: Arc<dyn ProviderAdapter>,
58    http: reqwest::Client,
59    fallback_config: Option<FallbackConfig>,
60    resilience: Option<Arc<Resilience>>,
61}
62
63impl Client {
64    pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
65        let timeout = DEFAULT_TIMEOUT;
66        let http = reqwest::Client::builder()
67            .timeout(timeout)
68            .build()
69            .map_err(Error::Network)?;
70
71        Ok(Self {
72            adapter: Arc::new(adapter),
73            http,
74            fallback_config: None,
75            resilience: None,
76        })
77    }
78
79    pub fn with_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
80        Self {
81            adapter: Arc::new(adapter),
82            http,
83            fallback_config: None,
84            resilience: None,
85        }
86    }
87
88    pub fn with_fallback(mut self, config: FallbackConfig) -> Self {
89        self.fallback_config = Some(config);
90        self
91    }
92
93    pub fn with_resilience(mut self, config: ResilienceConfig) -> Self {
94        self.resilience = Some(Arc::new(Resilience::new(config)));
95        self
96    }
97
98    pub fn resilience(&self) -> Option<&Arc<Resilience>> {
99        self.resilience.as_ref()
100    }
101
102    pub fn builder() -> ClientBuilder {
103        ClientBuilder::default()
104    }
105
106    pub async fn query(&self, prompt: &str) -> Result<String> {
107        self.query_with_model(prompt, ModelType::Primary).await
108    }
109
110    pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
111        let model = self.adapter.model(model_type).to_string();
112        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
113            .with_max_tokens(self.adapter.config().max_tokens);
114
115        let response = self.adapter.send(&self.http, request).await?;
116        Ok(response.text())
117    }
118
119    pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
120        let fallback = match &self.fallback_config {
121            Some(f) => f,
122            None => return self.adapter.send(&self.http, request).await,
123        };
124
125        let mut current_request = request;
126        let mut attempt = 0;
127        let mut using_fallback = false;
128
129        loop {
130            match self.adapter.send(&self.http, current_request.clone()).await {
131                Ok(response) => return Ok(response),
132                Err(e) => {
133                    if !fallback.should_fallback(&e) {
134                        return Err(e);
135                    }
136
137                    attempt += 1;
138                    if attempt > fallback.max_retries {
139                        return Err(e);
140                    }
141
142                    if !using_fallback {
143                        tracing::warn!(
144                            error = %e,
145                            fallback_model = %fallback.fallback_model,
146                            attempt,
147                            max_retries = fallback.max_retries,
148                            "Primary model failed, falling back"
149                        );
150                        current_request = current_request.with_model(&fallback.fallback_model);
151                        using_fallback = true;
152                    } else {
153                        tracing::warn!(
154                            error = %e,
155                            attempt,
156                            max_retries = fallback.max_retries,
157                            "Fallback model failed, retrying"
158                        );
159                    }
160                }
161            }
162        }
163    }
164
165    pub async fn send_no_fallback(
166        &self,
167        request: CreateMessageRequest,
168    ) -> Result<crate::types::ApiResponse> {
169        self.adapter.send(&self.http, request).await
170    }
171
172    pub fn fallback_config(&self) -> Option<&FallbackConfig> {
173        self.fallback_config.as_ref()
174    }
175
176    pub async fn stream(
177        &self,
178        prompt: &str,
179    ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
180        let model = self.adapter.model(ModelType::Primary).to_string();
181        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
182            .with_max_tokens(self.adapter.config().max_tokens);
183
184        let response = self.adapter.send_stream(&self.http, request).await?;
185        let stream = StreamParser::new(response.bytes_stream());
186
187        Ok(futures::StreamExt::filter_map(stream, |item| async move {
188            match item {
189                Ok(StreamItem::Text(text)) => Some(Ok(text)),
190                Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
191                Ok(
192                    StreamItem::Event(_) | StreamItem::Citation(_) | StreamItem::ToolUseComplete(_),
193                ) => None,
194                Err(e) => Some(Err(e)),
195            }
196        }))
197    }
198
199    pub async fn stream_request(
200        &self,
201        request: CreateMessageRequest,
202    ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
203        let response = self.adapter.send_stream(&self.http, request).await?;
204        Ok(StreamParser::new(response.bytes_stream()))
205    }
206
207    pub async fn stream_recoverable(
208        &self,
209        request: CreateMessageRequest,
210    ) -> Result<
211        RecoverableStream<
212            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
213            + Send
214            + 'static
215            + use<>,
216        >,
217    > {
218        let response = self.adapter.send_stream(&self.http, request).await?;
219        Ok(RecoverableStream::new(response.bytes_stream()))
220    }
221
222    pub async fn stream_with_recovery(
223        &self,
224        request: CreateMessageRequest,
225        recovery_state: Option<StreamRecoveryState>,
226    ) -> Result<
227        RecoverableStream<
228            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
229            + Send
230            + 'static
231            + use<>,
232        >,
233    > {
234        let request = match recovery_state {
235            Some(state) if state.is_recoverable() => {
236                let mut req = request;
237                req.messages = state.build_continuation_messages(&req.messages);
238                req
239            }
240            _ => request,
241        };
242        self.stream_recoverable(request).await
243    }
244
245    pub fn batch(&self) -> BatchClient<'_> {
246        BatchClient::new(self)
247    }
248
249    pub fn files(&self) -> FilesClient<'_> {
250        FilesClient::new(self)
251    }
252
253    pub fn adapter(&self) -> &dyn ProviderAdapter {
254        self.adapter.as_ref()
255    }
256
257    pub fn config(&self) -> &ProviderConfig {
258        self.adapter.config()
259    }
260
261    pub(crate) fn http(&self) -> &reqwest::Client {
262        &self.http
263    }
264
265    pub async fn refresh_credentials(&self) -> Result<()> {
266        self.adapter.refresh_credentials().await
267    }
268
269    /// Send a request with automatic auth retry on 401 errors.
270    ///
271    /// Attempts to refresh credentials and retry once if authentication fails.
272    pub async fn send_with_auth_retry(
273        &self,
274        request: CreateMessageRequest,
275    ) -> Result<crate::types::ApiResponse> {
276        match self.send(request.clone()).await {
277            Ok(resp) => Ok(resp),
278            Err(e) if e.is_unauthorized() => {
279                tracing::debug!("Received 401, attempting credential refresh");
280                self.refresh_credentials().await?;
281                self.send(request).await
282            }
283            Err(e) => Err(e),
284        }
285    }
286
287    /// Send a streaming request with automatic auth retry on 401 errors.
288    ///
289    /// Attempts to refresh credentials and retry once if authentication fails.
290    pub async fn send_stream_with_auth_retry(
291        &self,
292        request: CreateMessageRequest,
293    ) -> Result<reqwest::Response> {
294        match self.adapter.send_stream(&self.http, request.clone()).await {
295            Ok(resp) => Ok(resp),
296            Err(e) if e.is_unauthorized() => {
297                tracing::debug!("Received 401, attempting credential refresh for stream");
298                self.refresh_credentials().await?;
299                self.adapter.send_stream(&self.http, request).await
300            }
301            Err(e) => Err(e),
302        }
303    }
304
305    pub async fn count_tokens(
306        &self,
307        request: messages::CountTokensRequest,
308    ) -> Result<messages::CountTokensResponse> {
309        self.adapter.count_tokens(&self.http, request).await
310    }
311
312    pub async fn count_tokens_for_request(
313        &self,
314        request: &CreateMessageRequest,
315    ) -> Result<messages::CountTokensResponse> {
316        let count_request = messages::CountTokensRequest::from_message_request(request);
317        self.count_tokens(count_request).await
318    }
319}
320
321impl std::fmt::Debug for Client {
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        f.debug_struct("Client")
324            .field("provider", &self.adapter.name())
325            .finish()
326    }
327}
328
329#[derive(Default)]
330pub struct ClientBuilder {
331    provider: Option<CloudProvider>,
332    credential: Option<Credential>,
333    oauth_config: Option<OAuthConfig>,
334    config: Option<ProviderConfig>,
335    models: Option<ModelConfig>,
336    network: Option<NetworkConfig>,
337    gateway: Option<GatewayConfig>,
338    timeout: Option<Duration>,
339    fallback_config: Option<FallbackConfig>,
340    resilience_config: Option<ResilienceConfig>,
341
342    #[cfg(feature = "aws")]
343    aws_region: Option<String>,
344    #[cfg(feature = "gcp")]
345    gcp_project: Option<String>,
346    #[cfg(feature = "gcp")]
347    gcp_region: Option<String>,
348    #[cfg(feature = "azure")]
349    azure_resource: Option<String>,
350}
351
352impl ClientBuilder {
353    /// Configure authentication for the client.
354    ///
355    /// Accepts `Auth` enum or any type that converts to it (e.g., API key string).
356    pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
357        let auth = auth.into();
358
359        #[allow(unreachable_patterns)]
360        match &auth {
361            #[cfg(feature = "aws")]
362            Auth::Bedrock { region } => {
363                self.provider = Some(CloudProvider::Bedrock);
364                self.aws_region = Some(region.clone());
365            }
366            #[cfg(feature = "gcp")]
367            Auth::Vertex { project, region } => {
368                self.provider = Some(CloudProvider::Vertex);
369                self.gcp_project = Some(project.clone());
370                self.gcp_region = Some(region.clone());
371            }
372            #[cfg(feature = "azure")]
373            Auth::Foundry { resource } => {
374                self.provider = Some(CloudProvider::Foundry);
375                self.azure_resource = Some(resource.clone());
376            }
377            _ => {
378                self.provider = Some(CloudProvider::Anthropic);
379            }
380        }
381
382        let credential = auth.resolve().await?;
383        if !credential.is_default() {
384            self.credential = Some(credential);
385        }
386
387        Ok(self)
388    }
389
390    pub fn anthropic(mut self) -> Self {
391        self.provider = Some(CloudProvider::Anthropic);
392        self
393    }
394
395    #[cfg(feature = "aws")]
396    pub(crate) fn with_aws_region(mut self, region: String) -> Self {
397        self.provider = Some(CloudProvider::Bedrock);
398        self.aws_region = Some(region);
399        self
400    }
401
402    #[cfg(feature = "gcp")]
403    pub(crate) fn with_gcp(mut self, project: String, region: String) -> Self {
404        self.provider = Some(CloudProvider::Vertex);
405        self.gcp_project = Some(project);
406        self.gcp_region = Some(region);
407        self
408    }
409
410    #[cfg(feature = "azure")]
411    pub(crate) fn with_azure_resource(mut self, resource: String) -> Self {
412        self.provider = Some(CloudProvider::Foundry);
413        self.azure_resource = Some(resource);
414        self
415    }
416
417    pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
418        self.oauth_config = Some(config);
419        self
420    }
421
422    pub fn models(mut self, models: ModelConfig) -> Self {
423        self.models = Some(models);
424        self
425    }
426
427    pub fn config(mut self, config: ProviderConfig) -> Self {
428        self.config = Some(config);
429        self
430    }
431
432    pub fn network(mut self, network: NetworkConfig) -> Self {
433        self.network = Some(network);
434        self
435    }
436
437    pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
438        self.gateway = Some(gateway);
439        self
440    }
441
442    pub fn timeout(mut self, timeout: Duration) -> Self {
443        self.timeout = Some(timeout);
444        self
445    }
446
447    pub fn fallback(mut self, config: FallbackConfig) -> Self {
448        self.fallback_config = Some(config);
449        self
450    }
451
452    pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
453        self.fallback_config = Some(FallbackConfig::new(model));
454        self
455    }
456
457    pub fn resilience(mut self, config: ResilienceConfig) -> Self {
458        self.resilience_config = Some(config);
459        self
460    }
461
462    pub fn with_default_resilience(mut self) -> Self {
463        self.resilience_config = Some(ResilienceConfig::default());
464        self
465    }
466
467    pub async fn build(self) -> Result<Client> {
468        let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
469
470        let models = self.models.unwrap_or_else(|| provider.default_models());
471
472        let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
473
474        let adapter: Box<dyn ProviderAdapter> = match provider {
475            CloudProvider::Anthropic => {
476                let adapter = if let Some(cred) = self.credential {
477                    let mut a = AnthropicAdapter::from_credential(config, cred, self.oauth_config);
478                    if let Some(ref gw) = self.gateway
479                        && let Some(ref url) = gw.base_url
480                    {
481                        a = a.with_base_url(url);
482                    }
483                    a
484                } else {
485                    let mut a = AnthropicAdapter::new(config);
486                    if let Some(ref gw) = self.gateway {
487                        if let Some(ref url) = gw.base_url {
488                            a = a.with_base_url(url);
489                        }
490                        if let Some(ref token) = gw.auth_token {
491                            a = a.with_api_key(token);
492                        }
493                    }
494                    a
495                };
496                Box::new(adapter)
497            }
498            #[cfg(feature = "aws")]
499            CloudProvider::Bedrock => {
500                let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
501                if let Some(region) = self.aws_region {
502                    adapter = adapter.with_region(region);
503                }
504                Box::new(adapter)
505            }
506            #[cfg(feature = "gcp")]
507            CloudProvider::Vertex => {
508                let mut adapter = adapter::VertexAdapter::from_env(config).await?;
509                if let Some(project) = self.gcp_project {
510                    adapter = adapter.with_project(project);
511                }
512                if let Some(region) = self.gcp_region {
513                    adapter = adapter.with_region(region);
514                }
515                Box::new(adapter)
516            }
517            #[cfg(feature = "azure")]
518            CloudProvider::Foundry => {
519                let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
520                if let Some(resource) = self.azure_resource {
521                    adapter = adapter.with_resource(resource);
522                }
523                Box::new(adapter)
524            }
525        };
526
527        let mut http_builder =
528            reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
529
530        if let Some(ref network) = self.network {
531            http_builder = network
532                .apply_to_builder(http_builder)
533                .map_err(|e| Error::Config(e.to_string()))?;
534        }
535
536        let http = http_builder.build().map_err(Error::Network)?;
537
538        let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
539
540        Ok(Client {
541            adapter: Arc::from(adapter),
542            http,
543            fallback_config: self.fallback_config,
544            resilience,
545        })
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting the Anthropic provider on a fresh builder succeeds without panicking.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder().anthropic();
        drop(builder);
    }

    /// `CloudProvider::from_env` is expected to yield Anthropic in the test environment.
    ///
    /// NOTE(review): this reads the process environment. It presumably fails on machines
    /// configured for a cloud provider; consider isolating the environment.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// Authenticating an Anthropic builder with an API-key credential resolves successfully.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let outcome = Client::builder()
            .anthropic()
            .auth(Credential::api_key("test-key"))
            .await;
        let _builder = outcome.unwrap();
    }
}