claude_agent/client/
mod.rs

1//! Anthropic API client with multi-provider support.
2
3pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16    AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17    DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18    ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21    BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27    ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28    CountTokensRequest, CountTokensResponse, CreateMessageRequest, EffortLevel, KeepConfig,
29    KeepThinkingConfig, OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, ToolChoice,
30};
31pub use network::{ClientCertConfig, NetworkConfig, ProxyConfig};
32pub use recovery::StreamRecoveryState;
33pub use resilience::{
34    CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
35    RetryConfig,
36};
37pub use schema::{strict_schema, transform_for_strict};
38pub use streaming::{RecoverableStream, StreamItem, StreamParser};
39
40#[cfg(feature = "aws")]
41pub use adapter::BedrockAdapter;
42#[cfg(feature = "azure")]
43pub use adapter::FoundryAdapter;
44#[cfg(feature = "gcp")]
45pub use adapter::VertexAdapter;
46
47use std::sync::Arc;
48use std::time::Duration;
49
50use crate::auth::{Auth, Credential, OAuthConfig};
51use crate::{Error, Result};
52
/// Per-request HTTP timeout used when the caller does not configure one (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
54
55#[derive(Clone)]
56pub struct Client {
57    adapter: Arc<dyn ProviderAdapter>,
58    http: reqwest::Client,
59    fallback_config: Option<FallbackConfig>,
60    resilience: Option<Arc<Resilience>>,
61}
62
63impl Client {
64    pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
65        let timeout = DEFAULT_TIMEOUT;
66        let http = reqwest::Client::builder()
67            .timeout(timeout)
68            .build()
69            .map_err(Error::Network)?;
70
71        Ok(Self {
72            adapter: Arc::new(adapter),
73            http,
74            fallback_config: None,
75            resilience: None,
76        })
77    }
78
79    pub fn with_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
80        Self {
81            adapter: Arc::new(adapter),
82            http,
83            fallback_config: None,
84            resilience: None,
85        }
86    }
87
88    pub fn with_fallback(mut self, config: FallbackConfig) -> Self {
89        self.fallback_config = Some(config);
90        self
91    }
92
93    pub fn with_resilience(mut self, config: ResilienceConfig) -> Self {
94        self.resilience = Some(Arc::new(Resilience::new(config)));
95        self
96    }
97
98    pub fn resilience(&self) -> Option<&Arc<Resilience>> {
99        self.resilience.as_ref()
100    }
101
102    pub fn builder() -> ClientBuilder {
103        ClientBuilder::default()
104    }
105
106    pub async fn query(&self, prompt: &str) -> Result<String> {
107        self.query_with_model(prompt, ModelType::Primary).await
108    }
109
110    pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
111        let model = self.adapter.model(model_type).to_string();
112        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
113            .with_max_tokens(self.adapter.config().max_tokens);
114
115        let response = self.adapter.send(&self.http, request).await?;
116        Ok(response.text())
117    }
118
119    pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
120        let fallback = match &self.fallback_config {
121            Some(f) => f,
122            None => return self.adapter.send(&self.http, request).await,
123        };
124
125        let mut current_request = request;
126        let mut attempt = 0;
127        let mut using_fallback = false;
128
129        loop {
130            match self.adapter.send(&self.http, current_request.clone()).await {
131                Ok(response) => return Ok(response),
132                Err(e) => {
133                    if !fallback.should_fallback(&e) {
134                        return Err(e);
135                    }
136
137                    attempt += 1;
138                    if attempt > fallback.max_retries {
139                        return Err(e);
140                    }
141
142                    if !using_fallback {
143                        tracing::warn!(
144                            error = %e,
145                            fallback_model = %fallback.fallback_model,
146                            attempt,
147                            max_retries = fallback.max_retries,
148                            "Primary model failed, falling back"
149                        );
150                        current_request = current_request.with_model(&fallback.fallback_model);
151                        using_fallback = true;
152                    } else {
153                        tracing::warn!(
154                            error = %e,
155                            attempt,
156                            max_retries = fallback.max_retries,
157                            "Fallback model failed, retrying"
158                        );
159                    }
160                }
161            }
162        }
163    }
164
165    pub async fn send_no_fallback(
166        &self,
167        request: CreateMessageRequest,
168    ) -> Result<crate::types::ApiResponse> {
169        self.adapter.send(&self.http, request).await
170    }
171
172    pub fn fallback_config(&self) -> Option<&FallbackConfig> {
173        self.fallback_config.as_ref()
174    }
175
176    pub async fn stream(
177        &self,
178        prompt: &str,
179    ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
180        let model = self.adapter.model(ModelType::Primary).to_string();
181        let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
182            .with_max_tokens(self.adapter.config().max_tokens);
183
184        let response = self.adapter.send_stream(&self.http, request).await?;
185        let stream = StreamParser::new(response.bytes_stream());
186
187        Ok(futures::StreamExt::filter_map(stream, |item| async move {
188            match item {
189                Ok(StreamItem::Text(text)) => Some(Ok(text)),
190                Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
191                Ok(StreamItem::Event(_) | StreamItem::Citation(_)) => None,
192                Err(e) => Some(Err(e)),
193            }
194        }))
195    }
196
197    pub async fn stream_request(
198        &self,
199        request: CreateMessageRequest,
200    ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
201        let response = self.adapter.send_stream(&self.http, request).await?;
202        Ok(StreamParser::new(response.bytes_stream()))
203    }
204
205    pub async fn stream_recoverable(
206        &self,
207        request: CreateMessageRequest,
208    ) -> Result<
209        RecoverableStream<
210            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
211            + Send
212            + 'static
213            + use<>,
214        >,
215    > {
216        let response = self.adapter.send_stream(&self.http, request).await?;
217        Ok(RecoverableStream::new(response.bytes_stream()))
218    }
219
220    pub async fn stream_with_recovery(
221        &self,
222        request: CreateMessageRequest,
223        recovery_state: Option<StreamRecoveryState>,
224    ) -> Result<
225        RecoverableStream<
226            impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
227            + Send
228            + 'static
229            + use<>,
230        >,
231    > {
232        let request = match recovery_state {
233            Some(state) if state.is_recoverable() => {
234                let mut req = request;
235                req.messages = state.build_continuation_messages(&req.messages);
236                req
237            }
238            _ => request,
239        };
240        self.stream_recoverable(request).await
241    }
242
243    pub fn batch(&self) -> BatchClient<'_> {
244        BatchClient::new(self)
245    }
246
247    pub fn files(&self) -> FilesClient<'_> {
248        FilesClient::new(self)
249    }
250
251    pub fn adapter(&self) -> &dyn ProviderAdapter {
252        self.adapter.as_ref()
253    }
254
255    pub fn config(&self) -> &ProviderConfig {
256        self.adapter.config()
257    }
258
259    pub(crate) fn http(&self) -> &reqwest::Client {
260        &self.http
261    }
262
263    pub async fn refresh_credentials(&self) -> Result<()> {
264        self.adapter.refresh_credentials().await
265    }
266
267    pub async fn count_tokens(
268        &self,
269        request: messages::CountTokensRequest,
270    ) -> Result<messages::CountTokensResponse> {
271        self.adapter.count_tokens(&self.http, request).await
272    }
273
274    pub async fn count_tokens_for_request(
275        &self,
276        request: &CreateMessageRequest,
277    ) -> Result<messages::CountTokensResponse> {
278        let count_request = messages::CountTokensRequest::from_message_request(request);
279        self.count_tokens(count_request).await
280    }
281}
282
283impl std::fmt::Debug for Client {
284    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285        f.debug_struct("Client")
286            .field("provider", &self.adapter.name())
287            .finish()
288    }
289}
290
291#[derive(Default)]
292pub struct ClientBuilder {
293    provider: Option<CloudProvider>,
294    credential: Option<Credential>,
295    oauth_config: Option<OAuthConfig>,
296    config: Option<ProviderConfig>,
297    models: Option<ModelConfig>,
298    network: Option<NetworkConfig>,
299    gateway: Option<GatewayConfig>,
300    timeout: Option<Duration>,
301    fallback_config: Option<FallbackConfig>,
302    resilience_config: Option<ResilienceConfig>,
303
304    #[cfg(feature = "aws")]
305    aws_region: Option<String>,
306    #[cfg(feature = "gcp")]
307    gcp_project: Option<String>,
308    #[cfg(feature = "gcp")]
309    gcp_region: Option<String>,
310    #[cfg(feature = "azure")]
311    azure_resource: Option<String>,
312}
313
314impl ClientBuilder {
315    /// Configure authentication for the client.
316    ///
317    /// Accepts `Auth` enum or any type that converts to it (e.g., API key string).
318    pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
319        let auth = auth.into();
320
321        #[allow(unreachable_patterns)]
322        match &auth {
323            #[cfg(feature = "aws")]
324            Auth::Bedrock { region } => {
325                self.provider = Some(CloudProvider::Bedrock);
326                self.aws_region = Some(region.clone());
327            }
328            #[cfg(feature = "gcp")]
329            Auth::Vertex { project, region } => {
330                self.provider = Some(CloudProvider::Vertex);
331                self.gcp_project = Some(project.clone());
332                self.gcp_region = Some(region.clone());
333            }
334            #[cfg(feature = "azure")]
335            Auth::Foundry { resource } => {
336                self.provider = Some(CloudProvider::Foundry);
337                self.azure_resource = Some(resource.clone());
338            }
339            _ => {
340                self.provider = Some(CloudProvider::Anthropic);
341            }
342        }
343
344        let credential = auth.resolve().await?;
345        if !credential.is_default() {
346            self.credential = Some(credential);
347        }
348
349        Ok(self)
350    }
351
352    pub fn anthropic(mut self) -> Self {
353        self.provider = Some(CloudProvider::Anthropic);
354        self
355    }
356
357    #[cfg(feature = "aws")]
358    pub(crate) fn with_aws_region(mut self, region: String) -> Self {
359        self.provider = Some(CloudProvider::Bedrock);
360        self.aws_region = Some(region);
361        self
362    }
363
364    #[cfg(feature = "gcp")]
365    pub(crate) fn with_gcp(mut self, project: String, region: String) -> Self {
366        self.provider = Some(CloudProvider::Vertex);
367        self.gcp_project = Some(project);
368        self.gcp_region = Some(region);
369        self
370    }
371
372    #[cfg(feature = "azure")]
373    pub(crate) fn with_azure_resource(mut self, resource: String) -> Self {
374        self.provider = Some(CloudProvider::Foundry);
375        self.azure_resource = Some(resource);
376        self
377    }
378
379    pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
380        self.oauth_config = Some(config);
381        self
382    }
383
384    pub fn models(mut self, models: ModelConfig) -> Self {
385        self.models = Some(models);
386        self
387    }
388
389    pub fn config(mut self, config: ProviderConfig) -> Self {
390        self.config = Some(config);
391        self
392    }
393
394    pub fn network(mut self, network: NetworkConfig) -> Self {
395        self.network = Some(network);
396        self
397    }
398
399    pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
400        self.gateway = Some(gateway);
401        self
402    }
403
404    pub fn timeout(mut self, timeout: Duration) -> Self {
405        self.timeout = Some(timeout);
406        self
407    }
408
409    pub fn fallback(mut self, config: FallbackConfig) -> Self {
410        self.fallback_config = Some(config);
411        self
412    }
413
414    pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
415        self.fallback_config = Some(FallbackConfig::new(model));
416        self
417    }
418
419    pub fn resilience(mut self, config: ResilienceConfig) -> Self {
420        self.resilience_config = Some(config);
421        self
422    }
423
424    pub fn with_default_resilience(mut self) -> Self {
425        self.resilience_config = Some(ResilienceConfig::default());
426        self
427    }
428
429    pub async fn build(self) -> Result<Client> {
430        let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
431
432        let models = self.models.unwrap_or_else(|| provider.default_models());
433
434        let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
435
436        let adapter: Box<dyn ProviderAdapter> = match provider {
437            CloudProvider::Anthropic => {
438                let adapter = if let Some(cred) = self.credential {
439                    let mut a = AnthropicAdapter::from_credential(config, cred, self.oauth_config);
440                    if let Some(ref gw) = self.gateway
441                        && let Some(ref url) = gw.base_url
442                    {
443                        a = a.with_base_url(url);
444                    }
445                    a
446                } else {
447                    let mut a = AnthropicAdapter::new(config);
448                    if let Some(ref gw) = self.gateway {
449                        if let Some(ref url) = gw.base_url {
450                            a = a.with_base_url(url);
451                        }
452                        if let Some(ref token) = gw.auth_token {
453                            a = a.with_api_key(token);
454                        }
455                    }
456                    a
457                };
458                Box::new(adapter)
459            }
460            #[cfg(feature = "aws")]
461            CloudProvider::Bedrock => {
462                let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
463                if let Some(region) = self.aws_region {
464                    adapter = adapter.with_region(region);
465                }
466                Box::new(adapter)
467            }
468            #[cfg(feature = "gcp")]
469            CloudProvider::Vertex => {
470                let mut adapter = adapter::VertexAdapter::from_env(config).await?;
471                if let Some(project) = self.gcp_project {
472                    adapter = adapter.with_project(project);
473                }
474                if let Some(region) = self.gcp_region {
475                    adapter = adapter.with_region(region);
476                }
477                Box::new(adapter)
478            }
479            #[cfg(feature = "azure")]
480            CloudProvider::Foundry => {
481                let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
482                if let Some(resource) = self.azure_resource {
483                    adapter = adapter.with_resource(resource);
484                }
485                Box::new(adapter)
486            }
487        };
488
489        let mut http_builder =
490            reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
491
492        if let Some(ref network) = self.network {
493            http_builder = network
494                .apply_to_builder(http_builder)
495                .map_err(|e| Error::Config(e.to_string()))?;
496        }
497
498        let http = http_builder.build().map_err(Error::Network)?;
499
500        let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
501
502        Ok(Client {
503            adapter: Arc::from(adapter),
504            http,
505            fallback_config: self.fallback_config,
506            resilience,
507        })
508    }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting the first-party provider is a plain builder call.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder();
        let _anthropic = builder.anthropic();
    }

    /// NOTE(review): depends on the test environment not selecting a cloud provider —
    /// confirm which variables `CloudProvider::from_env` reads.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// An explicit API-key credential must be accepted by `ClientBuilder::auth`.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let outcome = Client::builder()
            .anthropic()
            .auth(Credential::api_key("test-key"))
            .await;
        let _builder = outcome.unwrap();
    }
}