1pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16 AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17 DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18 ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21 BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27 ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28 CountTokensRequest, CountTokensResponse, CreateMessageRequest, DEFAULT_MAX_TOKENS, EffortLevel,
29 KeepConfig, KeepThinkingConfig, MAX_TOKENS_128K, MIN_MAX_TOKENS, MIN_THINKING_BUDGET,
30 OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, TokenValidationError, ToolChoice,
31};
32pub use network::{ClientCertConfig, HttpNetworkConfig, PoolConfig, ProxyConfig};
33pub use recovery::StreamRecoveryState;
34pub use resilience::{
35 CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
36 RetryConfig,
37};
38pub use schema::{strict_schema, transform_for_strict};
39pub use streaming::{RecoverableStream, StreamItem, StreamParser};
40
41#[cfg(feature = "aws")]
42pub use adapter::BedrockAdapter;
43#[cfg(feature = "azure")]
44pub use adapter::FoundryAdapter;
45#[cfg(feature = "gcp")]
46pub use adapter::VertexAdapter;
47
48use std::sync::Arc;
49use std::time::Duration;
50
51use crate::auth::{Auth, Credential, OAuthConfig};
52use crate::{Error, Result};
53
/// HTTP request timeout (300 seconds) used by [`Client::new`] and by
/// [`ClientBuilder::build`] when no explicit timeout was configured.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
55
/// Handle for talking to a model provider through a [`ProviderAdapter`].
///
/// The adapter and the resilience state are held behind `Arc`s, so cloning the
/// client shares them with the clone.
#[derive(Clone)]
pub struct Client {
    /// Provider adapter every request is dispatched through.
    adapter: Arc<dyn ProviderAdapter>,
    /// HTTP client handed to the adapter on every call.
    http: reqwest::Client,
    /// When set, `send` retries failed requests on the configured fallback model.
    fallback_config: Option<FallbackConfig>,
    /// When set, supplies the circuit breaker consulted by `send` and `stream_request`.
    resilience: Option<Arc<Resilience>>,
}
63
64impl Client {
65 pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
66 let timeout = DEFAULT_TIMEOUT;
67 let http = reqwest::Client::builder()
68 .timeout(timeout)
69 .build()
70 .map_err(Error::Network)?;
71
72 Ok(Self {
73 adapter: Arc::new(adapter),
74 http,
75 fallback_config: None,
76 resilience: None,
77 })
78 }
79
80 pub fn from_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
81 Self {
82 adapter: Arc::new(adapter),
83 http,
84 fallback_config: None,
85 resilience: None,
86 }
87 }
88
89 pub fn fallback(mut self, config: FallbackConfig) -> Self {
90 self.fallback_config = Some(config);
91 self
92 }
93
94 pub fn resilience(mut self, config: ResilienceConfig) -> Self {
95 self.resilience = Some(Arc::new(Resilience::new(config)));
96 self
97 }
98
99 pub fn resilience_ref(&self) -> Option<&Arc<Resilience>> {
100 self.resilience.as_ref()
101 }
102
103 pub fn builder() -> ClientBuilder {
104 ClientBuilder::default()
105 }
106
107 pub async fn query(&self, prompt: &str) -> Result<String> {
108 self.query_with_model(prompt, ModelType::Primary).await
109 }
110
111 pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
112 let model = self.adapter.model(model_type).to_string();
113 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
114 .max_tokens(self.adapter.config().max_tokens);
115 request.validate()?;
116
117 let response = self.adapter.send(&self.http, request).await?;
118 Ok(response.text())
119 }
120
121 fn check_circuit_breaker(&self) -> Result<Option<Arc<CircuitBreaker>>> {
122 let cb = self.resilience.as_ref().and_then(|r| r.circuit().cloned());
123 if let Some(ref cb) = cb
124 && !cb.allow_request()
125 {
126 return Err(Error::CircuitOpen);
127 }
128 Ok(cb)
129 }
130
131 fn record_circuit_result<T>(cb: &Option<Arc<CircuitBreaker>>, result: &Result<T>) {
132 if let Some(cb) = cb {
133 match result {
134 Ok(_) => cb.record_success(),
135 Err(_) => cb.record_failure(),
136 }
137 }
138 }
139
140 pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
141 let cb = self.check_circuit_breaker()?;
142 let result = self.send_inner(request).await;
143 Self::record_circuit_result(&cb, &result);
144 result
145 }
146
147 async fn send_inner(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
148 request.validate()?;
149
150 let fallback = match &self.fallback_config {
151 Some(f) => f,
152 None => return self.adapter.send(&self.http, request).await,
153 };
154
155 let mut current_request = request;
156 let mut attempt = 0;
157 let mut using_fallback = false;
158
159 loop {
160 match self.adapter.send(&self.http, current_request.clone()).await {
161 Ok(response) => return Ok(response),
162 Err(e) => {
163 if !fallback.should_fallback(&e) {
164 return Err(e);
165 }
166
167 attempt += 1;
168 if attempt > fallback.max_retries {
169 return Err(e);
170 }
171
172 if !using_fallback {
173 tracing::warn!(
174 error = %e,
175 fallback_model = %fallback.fallback_model,
176 attempt,
177 max_retries = fallback.max_retries,
178 "Primary model failed, falling back"
179 );
180 current_request = current_request.model(&fallback.fallback_model);
181 using_fallback = true;
182 } else {
183 tracing::warn!(
184 error = %e,
185 attempt,
186 max_retries = fallback.max_retries,
187 "Fallback model failed, retrying"
188 );
189 }
190 }
191 }
192 }
193 }
194
195 pub async fn send_no_fallback(
196 &self,
197 request: CreateMessageRequest,
198 ) -> Result<crate::types::ApiResponse> {
199 request.validate()?;
200 self.adapter.send(&self.http, request).await
201 }
202
203 pub fn fallback_config(&self) -> Option<&FallbackConfig> {
204 self.fallback_config.as_ref()
205 }
206
207 pub async fn stream(
208 &self,
209 prompt: &str,
210 ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
211 let model = self.adapter.model(ModelType::Primary).to_string();
212 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
213 .max_tokens(self.adapter.config().max_tokens);
214 request.validate()?;
215
216 let response = self.adapter.send_stream(&self.http, request).await?;
217 let stream = StreamParser::new(response.bytes_stream());
218
219 Ok(futures::StreamExt::filter_map(stream, |item| async move {
220 match item {
221 Ok(StreamItem::Text(text)) => Some(Ok(text)),
222 Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
223 Ok(
224 StreamItem::Event(_) | StreamItem::Citation(_) | StreamItem::ToolUseComplete(_),
225 ) => None,
226 Err(e) => Some(Err(e)),
227 }
228 }))
229 }
230
231 pub async fn stream_request(
232 &self,
233 request: CreateMessageRequest,
234 ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
235 let cb = self.check_circuit_breaker()?;
236 let result = self.stream_request_inner(request).await;
237 Self::record_circuit_result(&cb, &result);
238 result
239 }
240
241 async fn stream_request_inner(
242 &self,
243 request: CreateMessageRequest,
244 ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
245 request.validate()?;
246 let response = self.adapter.send_stream(&self.http, request).await?;
247 Ok(StreamParser::new(response.bytes_stream()))
248 }
249
250 pub async fn stream_recoverable(
251 &self,
252 request: CreateMessageRequest,
253 ) -> Result<
254 RecoverableStream<
255 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
256 + Send
257 + 'static
258 + use<>,
259 >,
260 > {
261 request.validate()?;
262 let response = self.adapter.send_stream(&self.http, request).await?;
263 Ok(RecoverableStream::new(response.bytes_stream()))
264 }
265
266 pub async fn stream_with_recovery(
267 &self,
268 request: CreateMessageRequest,
269 recovery_state: Option<StreamRecoveryState>,
270 ) -> Result<
271 RecoverableStream<
272 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
273 + Send
274 + 'static
275 + use<>,
276 >,
277 > {
278 let request = match recovery_state {
279 Some(state) if state.is_recoverable() => {
280 let mut req = request;
281 req.messages = state.build_continuation_messages(&req.messages);
282 req
283 }
284 _ => request,
285 };
286 self.stream_recoverable(request).await
287 }
288
289 pub fn batch(&self) -> BatchClient<'_> {
290 BatchClient::new(self)
291 }
292
293 pub fn files(&self) -> FilesClient<'_> {
294 FilesClient::new(self)
295 }
296
297 pub fn adapter(&self) -> &dyn ProviderAdapter {
298 self.adapter.as_ref()
299 }
300
301 pub fn config(&self) -> &ProviderConfig {
302 self.adapter.config()
303 }
304
305 pub(crate) fn http(&self) -> &reqwest::Client {
306 &self.http
307 }
308
309 pub async fn refresh_credentials(&self) -> Result<()> {
310 self.adapter.refresh_credentials().await
311 }
312
313 async fn with_auth_retry<T, F, Fut>(&self, op: F) -> Result<T>
314 where
315 F: Fn() -> Fut,
316 Fut: std::future::Future<Output = Result<T>>,
317 {
318 self.adapter.ensure_fresh_credentials().await?;
319
320 match op().await {
321 Ok(resp) => Ok(resp),
322 Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
323 tracing::debug!("Received 401, refreshing credentials");
324 self.refresh_credentials().await?;
325 op().await
326 }
327 Err(e) => Err(e),
328 }
329 }
330
331 pub async fn send_with_auth_retry(
332 &self,
333 request: CreateMessageRequest,
334 ) -> Result<crate::types::ApiResponse> {
335 self.with_auth_retry(|| self.send(request.clone())).await
336 }
337
338 pub async fn send_stream_with_auth_retry(
339 &self,
340 request: CreateMessageRequest,
341 ) -> Result<reqwest::Response> {
342 request.validate()?;
343 self.with_auth_retry(|| self.adapter.send_stream(&self.http, request.clone()))
344 .await
345 }
346
347 pub async fn count_tokens(
348 &self,
349 request: messages::CountTokensRequest,
350 ) -> Result<messages::CountTokensResponse> {
351 self.with_auth_retry(|| self.adapter.count_tokens(&self.http, request.clone()))
352 .await
353 }
354
355 pub async fn count_tokens_for_request(
356 &self,
357 request: &CreateMessageRequest,
358 ) -> Result<messages::CountTokensResponse> {
359 let count_request = messages::CountTokensRequest::from_message_request(request);
360 self.count_tokens(count_request).await
361 }
362}
363
364impl std::fmt::Debug for Client {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 f.debug_struct("Client")
367 .field("provider", &self.adapter.name())
368 .finish()
369 }
370}
371
/// Step-by-step configuration for a [`Client`]; finish with `build`.
///
/// Every option starts unset (`Default`).
#[derive(Default)]
pub struct ClientBuilder {
    /// Target provider; `build` falls back to `CloudProvider::from_env()` when unset.
    provider: Option<CloudProvider>,
    /// Credential resolved by `auth`; stored only when it is not a placeholder.
    credential: Option<Credential>,
    /// Credential provider returned alongside the credential by `auth`.
    credential_provider: Option<Arc<dyn crate::auth::CredentialProvider>>,
    /// OAuth settings handed to the Anthropic adapter when a credential is present.
    oauth_config: Option<OAuthConfig>,
    /// Full provider configuration; when set, `models` is not consulted by `build`.
    config: Option<ProviderConfig>,
    /// Model selection used to create the provider configuration when `config` is unset.
    models: Option<ModelConfig>,
    /// Network settings applied to the HTTP client builder.
    network: Option<HttpNetworkConfig>,
    /// Gateway base URL / auth token applied to the Anthropic adapter.
    gateway: Option<GatewayConfig>,
    /// HTTP timeout; `build` uses `DEFAULT_TIMEOUT` when unset.
    timeout: Option<Duration>,
    /// Fallback policy copied into the built client.
    fallback_config: Option<FallbackConfig>,
    /// Configuration the built client's `Resilience` is created from.
    resilience_config: Option<ResilienceConfig>,

    /// Region override applied to the Bedrock adapter.
    #[cfg(feature = "aws")]
    aws_region: Option<String>,
    /// Project override applied to the Vertex adapter.
    #[cfg(feature = "gcp")]
    gcp_project: Option<String>,
    /// Region override applied to the Vertex adapter.
    #[cfg(feature = "gcp")]
    gcp_region: Option<String>,
    /// Resource override applied to the Foundry adapter.
    #[cfg(feature = "azure")]
    azure_resource: Option<String>,
}
395
396impl ClientBuilder {
397 pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
402 let auth = auth.into();
403
404 #[allow(unreachable_patterns)]
405 match &auth {
406 #[cfg(feature = "aws")]
407 Auth::Bedrock { region } => {
408 self.provider = Some(CloudProvider::Bedrock);
409 self.aws_region = Some(region.clone());
410 }
411 #[cfg(feature = "gcp")]
412 Auth::Vertex { project, region } => {
413 self.provider = Some(CloudProvider::Vertex);
414 self.gcp_project = Some(project.clone());
415 self.gcp_region = Some(region.clone());
416 }
417 #[cfg(feature = "azure")]
418 Auth::Foundry { resource } => {
419 self.provider = Some(CloudProvider::Foundry);
420 self.azure_resource = Some(resource.clone());
421 }
422 _ => {
423 self.provider = Some(CloudProvider::Anthropic);
424 }
425 }
426
427 let (credential, provider) = auth.resolve_with_provider().await?;
428 if !credential.is_placeholder() {
429 self.credential = Some(credential);
430 }
431 self.credential_provider = provider;
432
433 Ok(self)
434 }
435
436 pub fn anthropic(mut self) -> Self {
437 self.provider = Some(CloudProvider::Anthropic);
438 self
439 }
440
/// Selects the Bedrock provider and stores the AWS region to use.
#[cfg(feature = "aws")]
pub(crate) fn aws_region(self, region: String) -> Self {
    Self {
        provider: Some(CloudProvider::Bedrock),
        aws_region: Some(region),
        ..self
    }
}
447
/// Selects the Vertex provider and stores the GCP project and region to use.
#[cfg(feature = "gcp")]
pub(crate) fn gcp(self, project: String, region: String) -> Self {
    Self {
        provider: Some(CloudProvider::Vertex),
        gcp_project: Some(project),
        gcp_region: Some(region),
        ..self
    }
}
455
/// Selects the Foundry provider and stores the Azure resource to use.
#[cfg(feature = "azure")]
pub(crate) fn azure_resource(self, resource: String) -> Self {
    Self {
        provider: Some(CloudProvider::Foundry),
        azure_resource: Some(resource),
        ..self
    }
}
462
463 pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
464 self.oauth_config = Some(config);
465 self
466 }
467
468 pub fn models(mut self, models: ModelConfig) -> Self {
469 self.models = Some(models);
470 self
471 }
472
473 pub fn config(mut self, config: ProviderConfig) -> Self {
474 self.config = Some(config);
475 self
476 }
477
478 pub fn network(mut self, network: HttpNetworkConfig) -> Self {
479 self.network = Some(network);
480 self
481 }
482
483 pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
484 self.gateway = Some(gateway);
485 self
486 }
487
488 pub fn timeout(mut self, timeout: Duration) -> Self {
489 self.timeout = Some(timeout);
490 self
491 }
492
493 pub fn fallback(mut self, config: FallbackConfig) -> Self {
494 self.fallback_config = Some(config);
495 self
496 }
497
498 pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
499 self.fallback_config = Some(FallbackConfig::new(model));
500 self
501 }
502
503 pub fn resilience(mut self, config: ResilienceConfig) -> Self {
504 self.resilience_config = Some(config);
505 self
506 }
507
508 pub fn default_resilience(mut self) -> Self {
509 self.resilience_config = Some(ResilienceConfig::default());
510 self
511 }
512
513 pub async fn build(self) -> Result<Client> {
514 let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
515
516 let models = self.models.unwrap_or_else(|| provider.default_models());
517
518 let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
519
520 let adapter: Box<dyn ProviderAdapter> = match provider {
521 CloudProvider::Anthropic => {
522 let adapter = if let Some(ref cred) = self.credential {
523 let mut a = if let Some(cred_provider) = self.credential_provider {
524 AnthropicAdapter::from_credential_provider(
525 config,
526 cred,
527 self.oauth_config,
528 cred_provider,
529 )
530 } else {
531 AnthropicAdapter::from_credential(config, cred, self.oauth_config)
532 };
533 if let Some(ref gw) = self.gateway
534 && let Some(ref url) = gw.base_url
535 {
536 a = a.base_url(url);
537 }
538 a
539 } else {
540 let mut a = AnthropicAdapter::new(config);
541 if let Some(ref gw) = self.gateway {
542 if let Some(ref url) = gw.base_url {
543 a = a.base_url(url);
544 }
545 if let Some(ref token) = gw.auth_token {
546 a = a.api_key(token);
547 }
548 }
549 a
550 };
551 Box::new(adapter)
552 }
553 #[cfg(feature = "aws")]
554 CloudProvider::Bedrock => {
555 let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
556 if let Some(region) = self.aws_region {
557 adapter = adapter.region(region);
558 }
559 Box::new(adapter)
560 }
561 #[cfg(feature = "gcp")]
562 CloudProvider::Vertex => {
563 let mut adapter = adapter::VertexAdapter::from_env(config).await?;
564 if let Some(project) = self.gcp_project {
565 adapter = adapter.project(project);
566 }
567 if let Some(region) = self.gcp_region {
568 adapter = adapter.region(region);
569 }
570 Box::new(adapter)
571 }
572 #[cfg(feature = "azure")]
573 CloudProvider::Foundry => {
574 let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
575 if let Some(resource) = self.azure_resource {
576 adapter = adapter.resource(resource);
577 }
578 Box::new(adapter)
579 }
580 };
581
582 let mut http_builder =
583 reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
584
585 if let Some(ref network) = self.network {
586 http_builder = network
587 .apply_to_builder(http_builder)
588 .await
589 .map_err(|e| Error::Config(e.to_string()))?;
590 }
591
592 let http = http_builder.build().map_err(Error::Network)?;
593
594 let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
595
596 Ok(Client {
597 adapter: Arc::from(adapter),
598 http,
599 fallback_config: self.fallback_config,
600 resilience,
601 })
602 }
603}
604
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a builder can be created and pointed at Anthropic.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder();
        let _builder = builder.anthropic();
    }

    /// Expects `CloudProvider::from_env()` to yield the Anthropic provider.
    // NOTE(review): presumably relies on no provider-selecting environment
    // variables being set in the test environment — confirm.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// An API-key credential can be resolved through `ClientBuilder::auth`.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let credential = Credential::api_key("test-key");
        let pending = Client::builder().anthropic().auth(credential);
        let _builder = pending.await.unwrap();
    }
}