Skip to main content

claude_agent/client/
mod.rs

1//! Anthropic API client with multi-provider support.
2
3pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16    AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17    DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18    ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21    BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27    ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28    CountTokensRequest, CountTokensResponse, CreateMessageRequest, DEFAULT_MAX_TOKENS, EffortLevel,
29    KeepConfig, KeepThinkingConfig, MAX_TOKENS_128K, MIN_MAX_TOKENS, MIN_THINKING_BUDGET,
30    OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, TokenValidationError, ToolChoice,
31};
32pub use network::{ClientCertConfig, HttpNetworkConfig, PoolConfig, ProxyConfig};
33pub use recovery::StreamRecoveryState;
34pub use resilience::{
35    CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
36    RetryConfig,
37};
38pub use schema::{strict_schema, transform_for_strict};
39pub use streaming::{RecoverableStream, StreamItem, StreamParser};
40
41#[cfg(feature = "aws")]
42pub use adapter::BedrockAdapter;
43#[cfg(feature = "azure")]
44pub use adapter::FoundryAdapter;
45#[cfg(feature = "gcp")]
46pub use adapter::VertexAdapter;
47
48use std::sync::Arc;
49use std::time::Duration;
50
51use crate::auth::{Auth, Credential, OAuthConfig};
52use crate::{Error, Result};
53
/// Timeout given to the HTTP client when the caller does not choose one (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
55
56#[derive(Clone)]
57pub struct Client {
58    adapter: Arc<dyn ProviderAdapter>,
59    http: reqwest::Client,
60    fallback_config: Option<FallbackConfig>,
61    resilience: Option<Arc<Resilience>>,
62}
63
64impl Client {
65    pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
66        let timeout = DEFAULT_TIMEOUT;
67        let http = reqwest::Client::builder()
68            .timeout(timeout)
69            .build()
70            .map_err(Error::Network)?;
71
72        Ok(Self {
73            adapter: Arc::new(adapter),
74            http,
75            fallback_config: None,
76            resilience: None,
77        })
78    }
79
80    pub fn from_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
81        Self {
82            adapter: Arc::new(adapter),
83            http,
84            fallback_config: None,
85            resilience: None,
86        }
87    }
88
89    pub fn fallback(mut self, config: FallbackConfig) -> Self {
90        self.fallback_config = Some(config);
91        self
92    }
93
94    pub fn resilience(mut self, config: ResilienceConfig) -> Self {
95        self.resilience = Some(Arc::new(Resilience::new(config)));
96        self
97    }
98
99    pub fn resilience_ref(&self) -> Option<&Arc<Resilience>> {
100        self.resilience.as_ref()
101    }
102
103    pub fn builder() -> ClientBuilder {
104        ClientBuilder::default()
105    }
106
107    pub async fn query(&self, prompt: &str) -> Result<String> {
108        self.query_with_model(prompt, ModelType::Primary).await
109    }
110
111    pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
112        let model = self.adapter.model(model_type).to_string();
113        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
114            .max_tokens(self.adapter.config().max_tokens);
115        request.validate()?;
116
117        let response = self.adapter.send(&self.http, request).await?;
118        Ok(response.text())
119    }
120
121    fn check_circuit_breaker(&self) -> Result<Option<Arc<CircuitBreaker>>> {
122        let cb = self.resilience.as_ref().and_then(|r| r.circuit().cloned());
123        if let Some(ref cb) = cb
124            && !cb.allow_request()
125        {
126            return Err(Error::CircuitOpen);
127        }
128        Ok(cb)
129    }
130
131    fn record_circuit_result<T>(cb: &Option<Arc<CircuitBreaker>>, result: &Result<T>) {
132        if let Some(cb) = cb {
133            match result {
134                Ok(_) => cb.record_success(),
135                Err(_) => cb.record_failure(),
136            }
137        }
138    }
139
140    pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
141        let cb = self.check_circuit_breaker()?;
142        let result = self.send_inner(request).await;
143        Self::record_circuit_result(&cb, &result);
144        result
145    }
146
147    async fn send_inner(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
148        request.validate()?;
149
150        let fallback = match &self.fallback_config {
151            Some(f) => f,
152            None => return self.adapter.send(&self.http, request).await,
153        };
154
155        let mut current_request = request;
156        let mut attempt = 0;
157        let mut using_fallback = false;
158
159        loop {
160            match self.adapter.send(&self.http, current_request.clone()).await {
161                Ok(response) => return Ok(response),
162                Err(e) => {
163                    if !fallback.should_fallback(&e) {
164                        return Err(e);
165                    }
166
167                    attempt += 1;
168                    if attempt > fallback.max_retries {
169                        return Err(e);
170                    }
171
172                    if !using_fallback {
173                        tracing::warn!(
174                            error = %e,
175                            fallback_model = %fallback.fallback_model,
176                            attempt,
177                            max_retries = fallback.max_retries,
178                            "Primary model failed, falling back"
179                        );
180                        current_request = current_request.model(&fallback.fallback_model);
181                        using_fallback = true;
182                    } else {
183                        tracing::warn!(
184                            error = %e,
185                            attempt,
186                            max_retries = fallback.max_retries,
187                            "Fallback model failed, retrying"
188                        );
189                    }
190                }
191            }
192        }
193    }
194
195    pub async fn send_no_fallback(
196        &self,
197        request: CreateMessageRequest,
198    ) -> Result<crate::types::ApiResponse> {
199        request.validate()?;
200        self.adapter.send(&self.http, request).await
201    }
202
203    pub fn fallback_config(&self) -> Option<&FallbackConfig> {
204        self.fallback_config.as_ref()
205    }
206
207    pub async fn stream(
208        &self,
209        prompt: &str,
210    ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
211        let model = self.adapter.model(ModelType::Primary).to_string();
212        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
213            .max_tokens(self.adapter.config().max_tokens);
214        request.validate()?;
215
216        let response = self.adapter.send_stream(&self.http, request).await?;
217        let stream = StreamParser::new(response.bytes_stream());
218
219        Ok(futures::StreamExt::filter_map(stream, |item| async move {
220            match item {
221                Ok(StreamItem::Text(text)) => Some(Ok(text)),
222                Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
223                Ok(
224                    StreamItem::Event(_) | StreamItem::Citation(_) | StreamItem::ToolUseComplete(_),
225                ) => None,
226                Err(e) => Some(Err(e)),
227            }
228        }))
229    }
230
231    pub async fn stream_request(
232        &self,
233        request: CreateMessageRequest,
234    ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
235        let cb = self.check_circuit_breaker()?;
236        let result = self.stream_request_inner(request).await;
237        Self::record_circuit_result(&cb, &result);
238        result
239    }
240
241    async fn stream_request_inner(
242        &self,
243        request: CreateMessageRequest,
244    ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
245        request.validate()?;
246        let response = self.adapter.send_stream(&self.http, request).await?;
247        Ok(StreamParser::new(response.bytes_stream()))
248    }
249
250    pub async fn stream_recoverable(
251        &self,
252        request: CreateMessageRequest,
253    ) -> Result<
254        RecoverableStream<
255            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
256            + Send
257            + 'static
258            + use<>,
259        >,
260    > {
261        request.validate()?;
262        let response = self.adapter.send_stream(&self.http, request).await?;
263        Ok(RecoverableStream::new(response.bytes_stream()))
264    }
265
266    pub async fn stream_with_recovery(
267        &self,
268        request: CreateMessageRequest,
269        recovery_state: Option<StreamRecoveryState>,
270    ) -> Result<
271        RecoverableStream<
272            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
273            + Send
274            + 'static
275            + use<>,
276        >,
277    > {
278        let request = match recovery_state {
279            Some(state) if state.is_recoverable() => {
280                let mut req = request;
281                req.messages = state.build_continuation_messages(&req.messages);
282                req
283            }
284            _ => request,
285        };
286        self.stream_recoverable(request).await
287    }
288
289    pub fn batch(&self) -> BatchClient<'_> {
290        BatchClient::new(self)
291    }
292
293    pub fn files(&self) -> FilesClient<'_> {
294        FilesClient::new(self)
295    }
296
297    pub fn adapter(&self) -> &dyn ProviderAdapter {
298        self.adapter.as_ref()
299    }
300
301    pub fn config(&self) -> &ProviderConfig {
302        self.adapter.config()
303    }
304
305    pub(crate) fn http(&self) -> &reqwest::Client {
306        &self.http
307    }
308
309    pub async fn refresh_credentials(&self) -> Result<()> {
310        self.adapter.refresh_credentials().await
311    }
312
313    async fn with_auth_retry<T, F, Fut>(&self, op: F) -> Result<T>
314    where
315        F: Fn() -> Fut,
316        Fut: std::future::Future<Output = Result<T>>,
317    {
318        self.adapter.ensure_fresh_credentials().await?;
319
320        match op().await {
321            Ok(resp) => Ok(resp),
322            Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
323                tracing::debug!("Received 401, refreshing credentials");
324                self.refresh_credentials().await?;
325                op().await
326            }
327            Err(e) => Err(e),
328        }
329    }
330
331    pub async fn send_with_auth_retry(
332        &self,
333        request: CreateMessageRequest,
334    ) -> Result<crate::types::ApiResponse> {
335        self.with_auth_retry(|| self.send(request.clone())).await
336    }
337
338    pub async fn send_stream_with_auth_retry(
339        &self,
340        request: CreateMessageRequest,
341    ) -> Result<reqwest::Response> {
342        request.validate()?;
343        self.with_auth_retry(|| self.adapter.send_stream(&self.http, request.clone()))
344            .await
345    }
346
347    pub async fn count_tokens(
348        &self,
349        request: messages::CountTokensRequest,
350    ) -> Result<messages::CountTokensResponse> {
351        self.with_auth_retry(|| self.adapter.count_tokens(&self.http, request.clone()))
352            .await
353    }
354
355    pub async fn count_tokens_for_request(
356        &self,
357        request: &CreateMessageRequest,
358    ) -> Result<messages::CountTokensResponse> {
359        let count_request = messages::CountTokensRequest::from_message_request(request);
360        self.count_tokens(count_request).await
361    }
362}
363
364impl std::fmt::Debug for Client {
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        f.debug_struct("Client")
367            .field("provider", &self.adapter.name())
368            .finish()
369    }
370}
371
372#[derive(Default)]
373pub struct ClientBuilder {
374    provider: Option<CloudProvider>,
375    credential: Option<Credential>,
376    credential_provider: Option<Arc<dyn crate::auth::CredentialProvider>>,
377    oauth_config: Option<OAuthConfig>,
378    config: Option<ProviderConfig>,
379    models: Option<ModelConfig>,
380    network: Option<HttpNetworkConfig>,
381    gateway: Option<GatewayConfig>,
382    timeout: Option<Duration>,
383    fallback_config: Option<FallbackConfig>,
384    resilience_config: Option<ResilienceConfig>,
385
386    #[cfg(feature = "aws")]
387    aws_region: Option<String>,
388    #[cfg(feature = "gcp")]
389    gcp_project: Option<String>,
390    #[cfg(feature = "gcp")]
391    gcp_region: Option<String>,
392    #[cfg(feature = "azure")]
393    azure_resource: Option<String>,
394}
395
396impl ClientBuilder {
397    /// Configure authentication for the client.
398    ///
399    /// Accepts `Auth` enum or any type that converts to it (e.g., API key string).
400    /// For `Auth::ClaudeCli`, the credential provider is preserved for automatic token refresh.
401    pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
402        let auth = auth.into();
403
404        #[allow(unreachable_patterns)]
405        match &auth {
406            #[cfg(feature = "aws")]
407            Auth::Bedrock { region } => {
408                self.provider = Some(CloudProvider::Bedrock);
409                self.aws_region = Some(region.clone());
410            }
411            #[cfg(feature = "gcp")]
412            Auth::Vertex { project, region } => {
413                self.provider = Some(CloudProvider::Vertex);
414                self.gcp_project = Some(project.clone());
415                self.gcp_region = Some(region.clone());
416            }
417            #[cfg(feature = "azure")]
418            Auth::Foundry { resource } => {
419                self.provider = Some(CloudProvider::Foundry);
420                self.azure_resource = Some(resource.clone());
421            }
422            _ => {
423                self.provider = Some(CloudProvider::Anthropic);
424            }
425        }
426
427        let (credential, provider) = auth.resolve_with_provider().await?;
428        if !credential.is_placeholder() {
429            self.credential = Some(credential);
430        }
431        self.credential_provider = provider;
432
433        Ok(self)
434    }
435
436    pub fn anthropic(mut self) -> Self {
437        self.provider = Some(CloudProvider::Anthropic);
438        self
439    }
440
441    #[cfg(feature = "aws")]
442    pub(crate) fn aws_region(mut self, region: String) -> Self {
443        self.provider = Some(CloudProvider::Bedrock);
444        self.aws_region = Some(region);
445        self
446    }
447
448    #[cfg(feature = "gcp")]
449    pub(crate) fn gcp(mut self, project: String, region: String) -> Self {
450        self.provider = Some(CloudProvider::Vertex);
451        self.gcp_project = Some(project);
452        self.gcp_region = Some(region);
453        self
454    }
455
456    #[cfg(feature = "azure")]
457    pub(crate) fn azure_resource(mut self, resource: String) -> Self {
458        self.provider = Some(CloudProvider::Foundry);
459        self.azure_resource = Some(resource);
460        self
461    }
462
463    pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
464        self.oauth_config = Some(config);
465        self
466    }
467
468    pub fn models(mut self, models: ModelConfig) -> Self {
469        self.models = Some(models);
470        self
471    }
472
473    pub fn config(mut self, config: ProviderConfig) -> Self {
474        self.config = Some(config);
475        self
476    }
477
478    pub fn network(mut self, network: HttpNetworkConfig) -> Self {
479        self.network = Some(network);
480        self
481    }
482
483    pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
484        self.gateway = Some(gateway);
485        self
486    }
487
488    pub fn timeout(mut self, timeout: Duration) -> Self {
489        self.timeout = Some(timeout);
490        self
491    }
492
493    pub fn fallback(mut self, config: FallbackConfig) -> Self {
494        self.fallback_config = Some(config);
495        self
496    }
497
498    pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
499        self.fallback_config = Some(FallbackConfig::new(model));
500        self
501    }
502
503    pub fn resilience(mut self, config: ResilienceConfig) -> Self {
504        self.resilience_config = Some(config);
505        self
506    }
507
508    pub fn default_resilience(mut self) -> Self {
509        self.resilience_config = Some(ResilienceConfig::default());
510        self
511    }
512
513    pub async fn build(self) -> Result<Client> {
514        let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
515
516        let models = self.models.unwrap_or_else(|| provider.default_models());
517
518        let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
519
520        let adapter: Box<dyn ProviderAdapter> = match provider {
521            CloudProvider::Anthropic => {
522                let adapter = if let Some(ref cred) = self.credential {
523                    let mut a = if let Some(cred_provider) = self.credential_provider {
524                        AnthropicAdapter::from_credential_provider(
525                            config,
526                            cred,
527                            self.oauth_config,
528                            cred_provider,
529                        )
530                    } else {
531                        AnthropicAdapter::from_credential(config, cred, self.oauth_config)
532                    };
533                    if let Some(ref gw) = self.gateway
534                        && let Some(ref url) = gw.base_url
535                    {
536                        a = a.base_url(url);
537                    }
538                    a
539                } else {
540                    let mut a = AnthropicAdapter::new(config);
541                    if let Some(ref gw) = self.gateway {
542                        if let Some(ref url) = gw.base_url {
543                            a = a.base_url(url);
544                        }
545                        if let Some(ref token) = gw.auth_token {
546                            a = a.api_key(token);
547                        }
548                    }
549                    a
550                };
551                Box::new(adapter)
552            }
553            #[cfg(feature = "aws")]
554            CloudProvider::Bedrock => {
555                let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
556                if let Some(region) = self.aws_region {
557                    adapter = adapter.region(region);
558                }
559                Box::new(adapter)
560            }
561            #[cfg(feature = "gcp")]
562            CloudProvider::Vertex => {
563                let mut adapter = adapter::VertexAdapter::from_env(config).await?;
564                if let Some(project) = self.gcp_project {
565                    adapter = adapter.project(project);
566                }
567                if let Some(region) = self.gcp_region {
568                    adapter = adapter.region(region);
569                }
570                Box::new(adapter)
571            }
572            #[cfg(feature = "azure")]
573            CloudProvider::Foundry => {
574                let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
575                if let Some(resource) = self.azure_resource {
576                    adapter = adapter.resource(resource);
577                }
578                Box::new(adapter)
579            }
580        };
581
582        let mut http_builder =
583            reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
584
585        if let Some(ref network) = self.network {
586            http_builder = network
587                .apply_to_builder(http_builder)
588                .await
589                .map_err(|e| Error::Config(e.to_string()))?;
590        }
591
592        let http = http_builder.build().map_err(Error::Network)?;
593
594        let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
595
596        Ok(Client {
597            adapter: Arc::from(adapter),
598            http,
599            fallback_config: self.fallback_config,
600            resilience,
601        })
602    }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting the first-party provider on a fresh builder must not panic.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder();
        let _builder = builder.anthropic();
    }

    /// With no provider-selecting environment, the Anthropic API is chosen.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// A plain API-key credential resolves without error.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let credential = Credential::api_key("test-key");
        let resolved = Client::builder().anthropic().auth(credential).await;
        let _builder = resolved.unwrap();
    }
}