1pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16 AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17 DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18 ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21 BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27 ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28 CountTokensRequest, CountTokensResponse, CreateMessageRequest, DEFAULT_MAX_TOKENS, EffortLevel,
29 KeepConfig, KeepThinkingConfig, MAX_TOKENS_128K, MIN_MAX_TOKENS, MIN_THINKING_BUDGET,
30 OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, TokenValidationError, ToolChoice,
31};
32pub use network::{ClientCertConfig, NetworkConfig, PoolConfig, ProxyConfig};
33pub use recovery::StreamRecoveryState;
34pub use resilience::{
35 CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
36 RetryConfig,
37};
38pub use schema::{strict_schema, transform_for_strict};
39pub use streaming::{RecoverableStream, StreamItem, StreamParser};
40
41#[cfg(feature = "aws")]
42pub use adapter::BedrockAdapter;
43#[cfg(feature = "azure")]
44pub use adapter::FoundryAdapter;
45#[cfg(feature = "gcp")]
46pub use adapter::VertexAdapter;
47
48use std::sync::Arc;
49use std::time::Duration;
50
51use crate::auth::{Auth, Credential, OAuthConfig};
52use crate::{Error, Result};
53
/// Request timeout (five minutes) used by `Client::new` and by `ClientBuilder::build`
/// when no explicit timeout was configured.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
55
/// Entry point for sending requests through a provider adapter.
///
/// The adapter and the optional resilience state are held behind `Arc`, so
/// cloning a `Client` shares them with the clone.
#[derive(Clone)]
pub struct Client {
    /// Provider adapter every request is dispatched through.
    adapter: Arc<dyn ProviderAdapter>,
    /// HTTP client handed to the adapter on each call.
    http: reqwest::Client,
    /// Fallback policy consulted by `send`; `None` disables fallback.
    fallback_config: Option<FallbackConfig>,
    /// Resilience state exposed through `resilience()`; `None` when not configured.
    resilience: Option<Arc<Resilience>>,
}
63
64impl Client {
65 pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
66 let timeout = DEFAULT_TIMEOUT;
67 let http = reqwest::Client::builder()
68 .timeout(timeout)
69 .build()
70 .map_err(Error::Network)?;
71
72 Ok(Self {
73 adapter: Arc::new(adapter),
74 http,
75 fallback_config: None,
76 resilience: None,
77 })
78 }
79
80 pub fn with_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
81 Self {
82 adapter: Arc::new(adapter),
83 http,
84 fallback_config: None,
85 resilience: None,
86 }
87 }
88
89 pub fn with_fallback(mut self, config: FallbackConfig) -> Self {
90 self.fallback_config = Some(config);
91 self
92 }
93
94 pub fn with_resilience(mut self, config: ResilienceConfig) -> Self {
95 self.resilience = Some(Arc::new(Resilience::new(config)));
96 self
97 }
98
99 pub fn resilience(&self) -> Option<&Arc<Resilience>> {
100 self.resilience.as_ref()
101 }
102
103 pub fn builder() -> ClientBuilder {
104 ClientBuilder::default()
105 }
106
107 pub async fn query(&self, prompt: &str) -> Result<String> {
108 self.query_with_model(prompt, ModelType::Primary).await
109 }
110
111 pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
112 let model = self.adapter.model(model_type).to_string();
113 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
114 .with_max_tokens(self.adapter.config().max_tokens);
115 request.validate()?;
116
117 let response = self.adapter.send(&self.http, request).await?;
118 Ok(response.text())
119 }
120
121 pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
122 request.validate()?;
123
124 let fallback = match &self.fallback_config {
125 Some(f) => f,
126 None => return self.adapter.send(&self.http, request).await,
127 };
128
129 let mut current_request = request;
130 let mut attempt = 0;
131 let mut using_fallback = false;
132
133 loop {
134 match self.adapter.send(&self.http, current_request.clone()).await {
135 Ok(response) => return Ok(response),
136 Err(e) => {
137 if !fallback.should_fallback(&e) {
138 return Err(e);
139 }
140
141 attempt += 1;
142 if attempt > fallback.max_retries {
143 return Err(e);
144 }
145
146 if !using_fallback {
147 tracing::warn!(
148 error = %e,
149 fallback_model = %fallback.fallback_model,
150 attempt,
151 max_retries = fallback.max_retries,
152 "Primary model failed, falling back"
153 );
154 current_request = current_request.with_model(&fallback.fallback_model);
155 using_fallback = true;
156 } else {
157 tracing::warn!(
158 error = %e,
159 attempt,
160 max_retries = fallback.max_retries,
161 "Fallback model failed, retrying"
162 );
163 }
164 }
165 }
166 }
167 }
168
169 pub async fn send_no_fallback(
170 &self,
171 request: CreateMessageRequest,
172 ) -> Result<crate::types::ApiResponse> {
173 request.validate()?;
174 self.adapter.send(&self.http, request).await
175 }
176
177 pub fn fallback_config(&self) -> Option<&FallbackConfig> {
178 self.fallback_config.as_ref()
179 }
180
181 pub async fn stream(
182 &self,
183 prompt: &str,
184 ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
185 let model = self.adapter.model(ModelType::Primary).to_string();
186 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
187 .with_max_tokens(self.adapter.config().max_tokens);
188 request.validate()?;
189
190 let response = self.adapter.send_stream(&self.http, request).await?;
191 let stream = StreamParser::new(response.bytes_stream());
192
193 Ok(futures::StreamExt::filter_map(stream, |item| async move {
194 match item {
195 Ok(StreamItem::Text(text)) => Some(Ok(text)),
196 Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
197 Ok(
198 StreamItem::Event(_) | StreamItem::Citation(_) | StreamItem::ToolUseComplete(_),
199 ) => None,
200 Err(e) => Some(Err(e)),
201 }
202 }))
203 }
204
205 pub async fn stream_request(
206 &self,
207 request: CreateMessageRequest,
208 ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
209 request.validate()?;
210 let response = self.adapter.send_stream(&self.http, request).await?;
211 Ok(StreamParser::new(response.bytes_stream()))
212 }
213
214 pub async fn stream_recoverable(
215 &self,
216 request: CreateMessageRequest,
217 ) -> Result<
218 RecoverableStream<
219 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
220 + Send
221 + 'static
222 + use<>,
223 >,
224 > {
225 request.validate()?;
226 let response = self.adapter.send_stream(&self.http, request).await?;
227 Ok(RecoverableStream::new(response.bytes_stream()))
228 }
229
230 pub async fn stream_with_recovery(
231 &self,
232 request: CreateMessageRequest,
233 recovery_state: Option<StreamRecoveryState>,
234 ) -> Result<
235 RecoverableStream<
236 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
237 + Send
238 + 'static
239 + use<>,
240 >,
241 > {
242 let request = match recovery_state {
243 Some(state) if state.is_recoverable() => {
244 let mut req = request;
245 req.messages = state.build_continuation_messages(&req.messages);
246 req
247 }
248 _ => request,
249 };
250 self.stream_recoverable(request).await
251 }
252
253 pub fn batch(&self) -> BatchClient<'_> {
254 BatchClient::new(self)
255 }
256
257 pub fn files(&self) -> FilesClient<'_> {
258 FilesClient::new(self)
259 }
260
261 pub fn adapter(&self) -> &dyn ProviderAdapter {
262 self.adapter.as_ref()
263 }
264
265 pub fn config(&self) -> &ProviderConfig {
266 self.adapter.config()
267 }
268
269 pub(crate) fn http(&self) -> &reqwest::Client {
270 &self.http
271 }
272
273 pub async fn refresh_credentials(&self) -> Result<()> {
274 self.adapter.refresh_credentials().await
275 }
276
277 pub async fn send_with_auth_retry(
278 &self,
279 request: CreateMessageRequest,
280 ) -> Result<crate::types::ApiResponse> {
281 self.adapter.ensure_fresh_credentials().await?;
282
283 match self.send(request.clone()).await {
284 Ok(resp) => Ok(resp),
285 Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
286 tracing::debug!("Received 401, refreshing credentials");
287 self.refresh_credentials().await?;
288 self.send(request).await
289 }
290 Err(e) => Err(e),
291 }
292 }
293
294 pub async fn send_stream_with_auth_retry(
295 &self,
296 request: CreateMessageRequest,
297 ) -> Result<reqwest::Response> {
298 request.validate()?;
299 self.adapter.ensure_fresh_credentials().await?;
300
301 match self.adapter.send_stream(&self.http, request.clone()).await {
302 Ok(resp) => Ok(resp),
303 Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
304 tracing::debug!("Received 401, refreshing credentials");
305 self.refresh_credentials().await?;
306 self.adapter.send_stream(&self.http, request).await
307 }
308 Err(e) => Err(e),
309 }
310 }
311
312 pub async fn count_tokens(
313 &self,
314 request: messages::CountTokensRequest,
315 ) -> Result<messages::CountTokensResponse> {
316 self.adapter.ensure_fresh_credentials().await?;
317
318 match self.adapter.count_tokens(&self.http, request.clone()).await {
319 Ok(resp) => Ok(resp),
320 Err(e) if e.is_unauthorized() && self.adapter.supports_credential_refresh() => {
321 self.refresh_credentials().await?;
322 self.adapter.count_tokens(&self.http, request).await
323 }
324 Err(e) => Err(e),
325 }
326 }
327
328 pub async fn count_tokens_for_request(
329 &self,
330 request: &CreateMessageRequest,
331 ) -> Result<messages::CountTokensResponse> {
332 let count_request = messages::CountTokensRequest::from_message_request(request);
333 self.count_tokens(count_request).await
334 }
335}
336
337impl std::fmt::Debug for Client {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 f.debug_struct("Client")
340 .field("provider", &self.adapter.name())
341 .finish()
342 }
343}
344
/// Step-by-step configuration for a [`Client`]; every field starts unset.
#[derive(Default)]
pub struct ClientBuilder {
    /// Selected provider; `build` falls back to `CloudProvider::from_env` when unset.
    provider: Option<CloudProvider>,
    /// Credential resolved by `auth`; stored only when it is not the default credential.
    credential: Option<Credential>,
    /// Credential provider returned alongside the credential by `auth`.
    credential_provider: Option<Arc<dyn crate::auth::CredentialProvider>>,
    /// OAuth settings passed to the Anthropic adapter when a credential is present.
    oauth_config: Option<OAuthConfig>,
    /// Full provider configuration; when set, `models` is not consulted by `build`.
    config: Option<ProviderConfig>,
    /// Model selection used to create a `ProviderConfig` when `config` is unset.
    models: Option<ModelConfig>,
    /// Network settings applied to the HTTP client builder.
    network: Option<NetworkConfig>,
    /// Gateway settings (base URL, auth token) applied to the Anthropic adapter.
    gateway: Option<GatewayConfig>,
    /// HTTP timeout; `build` uses `DEFAULT_TIMEOUT` when unset.
    timeout: Option<Duration>,
    /// Fallback policy copied into the built client.
    fallback_config: Option<FallbackConfig>,
    /// Configuration used to create the client's resilience state.
    resilience_config: Option<ResilienceConfig>,

    /// Region applied to the Bedrock adapter.
    #[cfg(feature = "aws")]
    aws_region: Option<String>,
    /// Project applied to the Vertex adapter.
    #[cfg(feature = "gcp")]
    gcp_project: Option<String>,
    /// Region applied to the Vertex adapter.
    #[cfg(feature = "gcp")]
    gcp_region: Option<String>,
    /// Resource applied to the Foundry adapter.
    #[cfg(feature = "azure")]
    azure_resource: Option<String>,
}
368
369impl ClientBuilder {
370 pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
375 let auth = auth.into();
376
377 #[allow(unreachable_patterns)]
378 match &auth {
379 #[cfg(feature = "aws")]
380 Auth::Bedrock { region } => {
381 self.provider = Some(CloudProvider::Bedrock);
382 self.aws_region = Some(region.clone());
383 }
384 #[cfg(feature = "gcp")]
385 Auth::Vertex { project, region } => {
386 self.provider = Some(CloudProvider::Vertex);
387 self.gcp_project = Some(project.clone());
388 self.gcp_region = Some(region.clone());
389 }
390 #[cfg(feature = "azure")]
391 Auth::Foundry { resource } => {
392 self.provider = Some(CloudProvider::Foundry);
393 self.azure_resource = Some(resource.clone());
394 }
395 _ => {
396 self.provider = Some(CloudProvider::Anthropic);
397 }
398 }
399
400 let (credential, provider) = auth.resolve_with_provider().await?;
401 if !credential.is_default() {
402 self.credential = Some(credential);
403 }
404 self.credential_provider = provider;
405
406 Ok(self)
407 }
408
409 pub fn anthropic(mut self) -> Self {
410 self.provider = Some(CloudProvider::Anthropic);
411 self
412 }
413
414 #[cfg(feature = "aws")]
415 pub(crate) fn with_aws_region(mut self, region: String) -> Self {
416 self.provider = Some(CloudProvider::Bedrock);
417 self.aws_region = Some(region);
418 self
419 }
420
421 #[cfg(feature = "gcp")]
422 pub(crate) fn with_gcp(mut self, project: String, region: String) -> Self {
423 self.provider = Some(CloudProvider::Vertex);
424 self.gcp_project = Some(project);
425 self.gcp_region = Some(region);
426 self
427 }
428
429 #[cfg(feature = "azure")]
430 pub(crate) fn with_azure_resource(mut self, resource: String) -> Self {
431 self.provider = Some(CloudProvider::Foundry);
432 self.azure_resource = Some(resource);
433 self
434 }
435
436 pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
437 self.oauth_config = Some(config);
438 self
439 }
440
441 pub fn models(mut self, models: ModelConfig) -> Self {
442 self.models = Some(models);
443 self
444 }
445
446 pub fn config(mut self, config: ProviderConfig) -> Self {
447 self.config = Some(config);
448 self
449 }
450
451 pub fn network(mut self, network: NetworkConfig) -> Self {
452 self.network = Some(network);
453 self
454 }
455
456 pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
457 self.gateway = Some(gateway);
458 self
459 }
460
461 pub fn timeout(mut self, timeout: Duration) -> Self {
462 self.timeout = Some(timeout);
463 self
464 }
465
466 pub fn fallback(mut self, config: FallbackConfig) -> Self {
467 self.fallback_config = Some(config);
468 self
469 }
470
471 pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
472 self.fallback_config = Some(FallbackConfig::new(model));
473 self
474 }
475
476 pub fn resilience(mut self, config: ResilienceConfig) -> Self {
477 self.resilience_config = Some(config);
478 self
479 }
480
481 pub fn with_default_resilience(mut self) -> Self {
482 self.resilience_config = Some(ResilienceConfig::default());
483 self
484 }
485
486 pub async fn build(self) -> Result<Client> {
487 let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
488
489 let models = self.models.unwrap_or_else(|| provider.default_models());
490
491 let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
492
493 let adapter: Box<dyn ProviderAdapter> = match provider {
494 CloudProvider::Anthropic => {
495 let adapter = if let Some(cred) = self.credential {
496 let mut a = if let Some(cred_provider) = self.credential_provider {
497 AnthropicAdapter::from_credential_provider(
498 config,
499 cred,
500 self.oauth_config,
501 cred_provider,
502 )
503 } else {
504 AnthropicAdapter::from_credential(config, cred, self.oauth_config)
505 };
506 if let Some(ref gw) = self.gateway
507 && let Some(ref url) = gw.base_url
508 {
509 a = a.with_base_url(url);
510 }
511 a
512 } else {
513 let mut a = AnthropicAdapter::new(config);
514 if let Some(ref gw) = self.gateway {
515 if let Some(ref url) = gw.base_url {
516 a = a.with_base_url(url);
517 }
518 if let Some(ref token) = gw.auth_token {
519 a = a.with_api_key(token);
520 }
521 }
522 a
523 };
524 Box::new(adapter)
525 }
526 #[cfg(feature = "aws")]
527 CloudProvider::Bedrock => {
528 let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
529 if let Some(region) = self.aws_region {
530 adapter = adapter.with_region(region);
531 }
532 Box::new(adapter)
533 }
534 #[cfg(feature = "gcp")]
535 CloudProvider::Vertex => {
536 let mut adapter = adapter::VertexAdapter::from_env(config).await?;
537 if let Some(project) = self.gcp_project {
538 adapter = adapter.with_project(project);
539 }
540 if let Some(region) = self.gcp_region {
541 adapter = adapter.with_region(region);
542 }
543 Box::new(adapter)
544 }
545 #[cfg(feature = "azure")]
546 CloudProvider::Foundry => {
547 let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
548 if let Some(resource) = self.azure_resource {
549 adapter = adapter.with_resource(resource);
550 }
551 Box::new(adapter)
552 }
553 };
554
555 let mut http_builder =
556 reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
557
558 if let Some(ref network) = self.network {
559 http_builder = network
560 .apply_to_builder(http_builder)
561 .map_err(|e| Error::Config(e.to_string()))?;
562 }
563
564 let http = http_builder.build().map_err(Error::Network)?;
565
566 let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
567
568 Ok(Client {
569 adapter: Arc::from(adapter),
570 http,
571 fallback_config: self.fallback_config,
572 resilience,
573 })
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// `anthropic()` must record the Anthropic provider on the builder.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder().anthropic();
        assert_eq!(builder.provider, Some(CloudProvider::Anthropic));
    }

    /// NOTE(review): this reads the process environment, so it presumably fails
    /// when the environment selects another provider — confirm how CI sets it up.
    #[test]
    fn test_cloud_provider_from_env() {
        let provider = CloudProvider::from_env();
        assert_eq!(provider, CloudProvider::Anthropic);
    }

    /// Authenticating with a plain API key must succeed and keep the Anthropic provider.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let builder = Client::builder()
            .anthropic()
            .auth(Credential::api_key("test-key"))
            .await
            .unwrap();
        assert_eq!(builder.provider, Some(CloudProvider::Anthropic));
    }
}