claude_agent/client/
mod.rs

1//! Anthropic API client with multi-provider support.
2
3pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16    AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17    DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18    ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21    BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27    ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28    CountTokensRequest, CountTokensResponse, CreateMessageRequest, DEFAULT_MAX_TOKENS, EffortLevel,
29    KeepConfig, KeepThinkingConfig, MAX_TOKENS_128K, MIN_MAX_TOKENS, MIN_THINKING_BUDGET,
30    OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, TokenValidationError, ToolChoice,
31};
32pub use network::{ClientCertConfig, NetworkConfig, PoolConfig, ProxyConfig};
33pub use recovery::StreamRecoveryState;
34pub use resilience::{
35    CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
36    RetryConfig,
37};
38pub use schema::{strict_schema, transform_for_strict};
39pub use streaming::{RecoverableStream, StreamItem, StreamParser};
40
41#[cfg(feature = "aws")]
42pub use adapter::BedrockAdapter;
43#[cfg(feature = "azure")]
44pub use adapter::FoundryAdapter;
45#[cfg(feature = "gcp")]
46pub use adapter::VertexAdapter;
47
48use std::sync::Arc;
49use std::time::Duration;
50
51use crate::auth::{Auth, Credential, OAuthConfig};
52use crate::{Error, Result};
53
/// Total per-request timeout (five minutes) given to the underlying `reqwest`
/// client when the caller does not choose one.
///
/// Used by [`Client::new`], and by [`ClientBuilder::build`] when
/// [`ClientBuilder::timeout`] was not called.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
55
56#[derive(Clone)]
57pub struct Client {
58    adapter: Arc<dyn ProviderAdapter>,
59    http: reqwest::Client,
60    fallback_config: Option<FallbackConfig>,
61    resilience: Option<Arc<Resilience>>,
62}
63
64impl Client {
65    pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
66        let timeout = DEFAULT_TIMEOUT;
67        let http = reqwest::Client::builder()
68            .timeout(timeout)
69            .build()
70            .map_err(Error::Network)?;
71
72        Ok(Self {
73            adapter: Arc::new(adapter),
74            http,
75            fallback_config: None,
76            resilience: None,
77        })
78    }
79
80    pub fn with_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
81        Self {
82            adapter: Arc::new(adapter),
83            http,
84            fallback_config: None,
85            resilience: None,
86        }
87    }
88
89    pub fn with_fallback(mut self, config: FallbackConfig) -> Self {
90        self.fallback_config = Some(config);
91        self
92    }
93
94    pub fn with_resilience(mut self, config: ResilienceConfig) -> Self {
95        self.resilience = Some(Arc::new(Resilience::new(config)));
96        self
97    }
98
99    pub fn resilience(&self) -> Option<&Arc<Resilience>> {
100        self.resilience.as_ref()
101    }
102
103    pub fn builder() -> ClientBuilder {
104        ClientBuilder::default()
105    }
106
107    pub async fn query(&self, prompt: &str) -> Result<String> {
108        self.query_with_model(prompt, ModelType::Primary).await
109    }
110
111    pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
112        let model = self.adapter.model(model_type).to_string();
113        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
114            .with_max_tokens(self.adapter.config().max_tokens);
115        request.validate()?;
116
117        let response = self.adapter.send(&self.http, request).await?;
118        Ok(response.text())
119    }
120
121    pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
122        request.validate()?;
123
124        let fallback = match &self.fallback_config {
125            Some(f) => f,
126            None => return self.adapter.send(&self.http, request).await,
127        };
128
129        let mut current_request = request;
130        let mut attempt = 0;
131        let mut using_fallback = false;
132
133        loop {
134            match self.adapter.send(&self.http, current_request.clone()).await {
135                Ok(response) => return Ok(response),
136                Err(e) => {
137                    if !fallback.should_fallback(&e) {
138                        return Err(e);
139                    }
140
141                    attempt += 1;
142                    if attempt > fallback.max_retries {
143                        return Err(e);
144                    }
145
146                    if !using_fallback {
147                        tracing::warn!(
148                            error = %e,
149                            fallback_model = %fallback.fallback_model,
150                            attempt,
151                            max_retries = fallback.max_retries,
152                            "Primary model failed, falling back"
153                        );
154                        current_request = current_request.with_model(&fallback.fallback_model);
155                        using_fallback = true;
156                    } else {
157                        tracing::warn!(
158                            error = %e,
159                            attempt,
160                            max_retries = fallback.max_retries,
161                            "Fallback model failed, retrying"
162                        );
163                    }
164                }
165            }
166        }
167    }
168
169    pub async fn send_no_fallback(
170        &self,
171        request: CreateMessageRequest,
172    ) -> Result<crate::types::ApiResponse> {
173        request.validate()?;
174        self.adapter.send(&self.http, request).await
175    }
176
177    pub fn fallback_config(&self) -> Option<&FallbackConfig> {
178        self.fallback_config.as_ref()
179    }
180
181    pub async fn stream(
182        &self,
183        prompt: &str,
184    ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
185        let model = self.adapter.model(ModelType::Primary).to_string();
186        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
187            .with_max_tokens(self.adapter.config().max_tokens);
188        request.validate()?;
189
190        let response = self.adapter.send_stream(&self.http, request).await?;
191        let stream = StreamParser::new(response.bytes_stream());
192
193        Ok(futures::StreamExt::filter_map(stream, |item| async move {
194            match item {
195                Ok(StreamItem::Text(text)) => Some(Ok(text)),
196                Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
197                Ok(
198                    StreamItem::Event(_) | StreamItem::Citation(_) | StreamItem::ToolUseComplete(_),
199                ) => None,
200                Err(e) => Some(Err(e)),
201            }
202        }))
203    }
204
205    pub async fn stream_request(
206        &self,
207        request: CreateMessageRequest,
208    ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
209        request.validate()?;
210        let response = self.adapter.send_stream(&self.http, request).await?;
211        Ok(StreamParser::new(response.bytes_stream()))
212    }
213
214    pub async fn stream_recoverable(
215        &self,
216        request: CreateMessageRequest,
217    ) -> Result<
218        RecoverableStream<
219            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
220            + Send
221            + 'static
222            + use<>,
223        >,
224    > {
225        request.validate()?;
226        let response = self.adapter.send_stream(&self.http, request).await?;
227        Ok(RecoverableStream::new(response.bytes_stream()))
228    }
229
230    pub async fn stream_with_recovery(
231        &self,
232        request: CreateMessageRequest,
233        recovery_state: Option<StreamRecoveryState>,
234    ) -> Result<
235        RecoverableStream<
236            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
237            + Send
238            + 'static
239            + use<>,
240        >,
241    > {
242        let request = match recovery_state {
243            Some(state) if state.is_recoverable() => {
244                let mut req = request;
245                req.messages = state.build_continuation_messages(&req.messages);
246                req
247            }
248            _ => request,
249        };
250        self.stream_recoverable(request).await
251    }
252
253    pub fn batch(&self) -> BatchClient<'_> {
254        BatchClient::new(self)
255    }
256
257    pub fn files(&self) -> FilesClient<'_> {
258        FilesClient::new(self)
259    }
260
261    pub fn adapter(&self) -> &dyn ProviderAdapter {
262        self.adapter.as_ref()
263    }
264
265    pub fn config(&self) -> &ProviderConfig {
266        self.adapter.config()
267    }
268
269    pub(crate) fn http(&self) -> &reqwest::Client {
270        &self.http
271    }
272
273    pub async fn refresh_credentials(&self) -> Result<()> {
274        self.adapter.refresh_credentials().await
275    }
276
277    pub async fn send_with_auth_retry(
278        &self,
279        request: CreateMessageRequest,
280    ) -> Result<crate::types::ApiResponse> {
281        self.adapter.ensure_fresh_credentials().await?;
282
283        match self.send(request.clone()).await {
284            Ok(resp) => Ok(resp),
285            Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
286                tracing::debug!("Received 401, refreshing credentials");
287                self.refresh_credentials().await?;
288                self.send(request).await
289            }
290            Err(e) => Err(e),
291        }
292    }
293
294    pub async fn send_stream_with_auth_retry(
295        &self,
296        request: CreateMessageRequest,
297    ) -> Result<reqwest::Response> {
298        request.validate()?;
299        self.adapter.ensure_fresh_credentials().await?;
300
301        match self.adapter.send_stream(&self.http, request.clone()).await {
302            Ok(resp) => Ok(resp),
303            Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
304                tracing::debug!("Received 401, refreshing credentials");
305                self.refresh_credentials().await?;
306                self.adapter.send_stream(&self.http, request).await
307            }
308            Err(e) => Err(e),
309        }
310    }
311
312    pub async fn count_tokens(
313        &self,
314        request: messages::CountTokensRequest,
315    ) -> Result<messages::CountTokensResponse> {
316        self.adapter.ensure_fresh_credentials().await?;
317
318        match self.adapter.count_tokens(&self.http, request.clone()).await {
319            Ok(resp) => Ok(resp),
320            Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
321                self.refresh_credentials().await?;
322                self.adapter.count_tokens(&self.http, request).await
323            }
324            Err(e) => Err(e),
325        }
326    }
327
328    pub async fn count_tokens_for_request(
329        &self,
330        request: &CreateMessageRequest,
331    ) -> Result<messages::CountTokensResponse> {
332        let count_request = messages::CountTokensRequest::from_message_request(request);
333        self.count_tokens(count_request).await
334    }
335}
336
337impl std::fmt::Debug for Client {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        f.debug_struct("Client")
340            .field("provider", &self.adapter.name())
341            .finish()
342    }
343}
344
345#[derive(Default)]
346pub struct ClientBuilder {
347    provider: Option<CloudProvider>,
348    credential: Option<Credential>,
349    credential_provider: Option<Arc<dyn crate::auth::CredentialProvider>>,
350    oauth_config: Option<OAuthConfig>,
351    config: Option<ProviderConfig>,
352    models: Option<ModelConfig>,
353    network: Option<NetworkConfig>,
354    gateway: Option<GatewayConfig>,
355    timeout: Option<Duration>,
356    fallback_config: Option<FallbackConfig>,
357    resilience_config: Option<ResilienceConfig>,
358
359    #[cfg(feature = "aws")]
360    aws_region: Option<String>,
361    #[cfg(feature = "gcp")]
362    gcp_project: Option<String>,
363    #[cfg(feature = "gcp")]
364    gcp_region: Option<String>,
365    #[cfg(feature = "azure")]
366    azure_resource: Option<String>,
367}
368
369impl ClientBuilder {
370    /// Configure authentication for the client.
371    ///
372    /// Accepts `Auth` enum or any type that converts to it (e.g., API key string).
373    /// For `Auth::ClaudeCli`, the credential provider is preserved for automatic token refresh.
374    pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
375        let auth = auth.into();
376
377        #[allow(unreachable_patterns)]
378        match &auth {
379            #[cfg(feature = "aws")]
380            Auth::Bedrock { region } => {
381                self.provider = Some(CloudProvider::Bedrock);
382                self.aws_region = Some(region.clone());
383            }
384            #[cfg(feature = "gcp")]
385            Auth::Vertex { project, region } => {
386                self.provider = Some(CloudProvider::Vertex);
387                self.gcp_project = Some(project.clone());
388                self.gcp_region = Some(region.clone());
389            }
390            #[cfg(feature = "azure")]
391            Auth::Foundry { resource } => {
392                self.provider = Some(CloudProvider::Foundry);
393                self.azure_resource = Some(resource.clone());
394            }
395            _ => {
396                self.provider = Some(CloudProvider::Anthropic);
397            }
398        }
399
400        let (credential, provider) = auth.resolve_with_provider().await?;
401        if !credential.is_default() {
402            self.credential = Some(credential);
403        }
404        self.credential_provider = provider;
405
406        Ok(self)
407    }
408
409    pub fn anthropic(mut self) -> Self {
410        self.provider = Some(CloudProvider::Anthropic);
411        self
412    }
413
414    #[cfg(feature = "aws")]
415    pub(crate) fn with_aws_region(mut self, region: String) -> Self {
416        self.provider = Some(CloudProvider::Bedrock);
417        self.aws_region = Some(region);
418        self
419    }
420
421    #[cfg(feature = "gcp")]
422    pub(crate) fn with_gcp(mut self, project: String, region: String) -> Self {
423        self.provider = Some(CloudProvider::Vertex);
424        self.gcp_project = Some(project);
425        self.gcp_region = Some(region);
426        self
427    }
428
429    #[cfg(feature = "azure")]
430    pub(crate) fn with_azure_resource(mut self, resource: String) -> Self {
431        self.provider = Some(CloudProvider::Foundry);
432        self.azure_resource = Some(resource);
433        self
434    }
435
436    pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
437        self.oauth_config = Some(config);
438        self
439    }
440
441    pub fn models(mut self, models: ModelConfig) -> Self {
442        self.models = Some(models);
443        self
444    }
445
446    pub fn config(mut self, config: ProviderConfig) -> Self {
447        self.config = Some(config);
448        self
449    }
450
451    pub fn network(mut self, network: NetworkConfig) -> Self {
452        self.network = Some(network);
453        self
454    }
455
456    pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
457        self.gateway = Some(gateway);
458        self
459    }
460
461    pub fn timeout(mut self, timeout: Duration) -> Self {
462        self.timeout = Some(timeout);
463        self
464    }
465
466    pub fn fallback(mut self, config: FallbackConfig) -> Self {
467        self.fallback_config = Some(config);
468        self
469    }
470
471    pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
472        self.fallback_config = Some(FallbackConfig::new(model));
473        self
474    }
475
476    pub fn resilience(mut self, config: ResilienceConfig) -> Self {
477        self.resilience_config = Some(config);
478        self
479    }
480
481    pub fn with_default_resilience(mut self) -> Self {
482        self.resilience_config = Some(ResilienceConfig::default());
483        self
484    }
485
486    pub async fn build(self) -> Result<Client> {
487        let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
488
489        let models = self.models.unwrap_or_else(|| provider.default_models());
490
491        let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
492
493        let adapter: Box<dyn ProviderAdapter> = match provider {
494            CloudProvider::Anthropic => {
495                let adapter = if let Some(cred) = self.credential {
496                    let mut a = if let Some(cred_provider) = self.credential_provider {
497                        AnthropicAdapter::from_credential_provider(
498                            config,
499                            cred,
500                            self.oauth_config,
501                            cred_provider,
502                        )
503                    } else {
504                        AnthropicAdapter::from_credential(config, cred, self.oauth_config)
505                    };
506                    if let Some(ref gw) = self.gateway
507                        && let Some(ref url) = gw.base_url
508                    {
509                        a = a.with_base_url(url);
510                    }
511                    a
512                } else {
513                    let mut a = AnthropicAdapter::new(config);
514                    if let Some(ref gw) = self.gateway {
515                        if let Some(ref url) = gw.base_url {
516                            a = a.with_base_url(url);
517                        }
518                        if let Some(ref token) = gw.auth_token {
519                            a = a.with_api_key(token);
520                        }
521                    }
522                    a
523                };
524                Box::new(adapter)
525            }
526            #[cfg(feature = "aws")]
527            CloudProvider::Bedrock => {
528                let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
529                if let Some(region) = self.aws_region {
530                    adapter = adapter.with_region(region);
531                }
532                Box::new(adapter)
533            }
534            #[cfg(feature = "gcp")]
535            CloudProvider::Vertex => {
536                let mut adapter = adapter::VertexAdapter::from_env(config).await?;
537                if let Some(project) = self.gcp_project {
538                    adapter = adapter.with_project(project);
539                }
540                if let Some(region) = self.gcp_region {
541                    adapter = adapter.with_region(region);
542                }
543                Box::new(adapter)
544            }
545            #[cfg(feature = "azure")]
546            CloudProvider::Foundry => {
547                let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
548                if let Some(resource) = self.azure_resource {
549                    adapter = adapter.with_resource(resource);
550                }
551                Box::new(adapter)
552            }
553        };
554
555        let mut http_builder =
556            reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
557
558        if let Some(ref network) = self.network {
559            http_builder = network
560                .apply_to_builder(http_builder)
561                .map_err(|e| Error::Config(e.to_string()))?;
562        }
563
564        let http = http_builder.build().map_err(Error::Network)?;
565
566        let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
567
568        Ok(Client {
569            adapter: Arc::from(adapter),
570            http,
571            fallback_config: self.fallback_config,
572            resilience,
573        })
574    }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting the Anthropic provider on a fresh builder must not panic.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder();
        let _builder = builder.anthropic();
    }

    /// The provider detected from the environment defaults to Anthropic.
    ///
    /// NOTE(review): presumably depends on the process environment; this will
    /// fail if the variables `CloudProvider::from_env` reads are set during tests.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// `auth` accepts a plain API-key credential without error.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let credential = Credential::api_key("test-key");
        let result = Client::builder().anthropic().auth(credential).await;
        let _builder = result.unwrap();
    }
}