1pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16 AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17 DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18 ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21 BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27 ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28 CountTokensRequest, CountTokensResponse, CreateMessageRequest, EffortLevel, KeepConfig,
29 KeepThinkingConfig, OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, ToolChoice,
30};
31pub use network::{ClientCertConfig, NetworkConfig, PoolConfig, ProxyConfig};
32pub use recovery::StreamRecoveryState;
33pub use resilience::{
34 CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
35 RetryConfig,
36};
37pub use schema::{strict_schema, transform_for_strict};
38pub use streaming::{RecoverableStream, StreamItem, StreamParser};
39
40#[cfg(feature = "aws")]
41pub use adapter::BedrockAdapter;
42#[cfg(feature = "azure")]
43pub use adapter::FoundryAdapter;
44#[cfg(feature = "gcp")]
45pub use adapter::VertexAdapter;
46
47use std::sync::Arc;
48use std::time::Duration;
49
50use crate::auth::{Auth, Credential, OAuthConfig};
51use crate::{Error, Result};
52
/// Timeout handed to `reqwest::ClientBuilder::timeout` when none is configured
/// explicitly: five minutes.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
54
55#[derive(Clone)]
56pub struct Client {
57 adapter: Arc<dyn ProviderAdapter>,
58 http: reqwest::Client,
59 fallback_config: Option<FallbackConfig>,
60 resilience: Option<Arc<Resilience>>,
61}
62
63impl Client {
64 pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
65 let timeout = DEFAULT_TIMEOUT;
66 let http = reqwest::Client::builder()
67 .timeout(timeout)
68 .build()
69 .map_err(Error::Network)?;
70
71 Ok(Self {
72 adapter: Arc::new(adapter),
73 http,
74 fallback_config: None,
75 resilience: None,
76 })
77 }
78
79 pub fn with_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
80 Self {
81 adapter: Arc::new(adapter),
82 http,
83 fallback_config: None,
84 resilience: None,
85 }
86 }
87
88 pub fn with_fallback(mut self, config: FallbackConfig) -> Self {
89 self.fallback_config = Some(config);
90 self
91 }
92
93 pub fn with_resilience(mut self, config: ResilienceConfig) -> Self {
94 self.resilience = Some(Arc::new(Resilience::new(config)));
95 self
96 }
97
98 pub fn resilience(&self) -> Option<&Arc<Resilience>> {
99 self.resilience.as_ref()
100 }
101
102 pub fn builder() -> ClientBuilder {
103 ClientBuilder::default()
104 }
105
106 pub async fn query(&self, prompt: &str) -> Result<String> {
107 self.query_with_model(prompt, ModelType::Primary).await
108 }
109
110 pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
111 let model = self.adapter.model(model_type).to_string();
112 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
113 .with_max_tokens(self.adapter.config().max_tokens);
114
115 let response = self.adapter.send(&self.http, request).await?;
116 Ok(response.text())
117 }
118
119 pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
120 let fallback = match &self.fallback_config {
121 Some(f) => f,
122 None => return self.adapter.send(&self.http, request).await,
123 };
124
125 let mut current_request = request;
126 let mut attempt = 0;
127 let mut using_fallback = false;
128
129 loop {
130 match self.adapter.send(&self.http, current_request.clone()).await {
131 Ok(response) => return Ok(response),
132 Err(e) => {
133 if !fallback.should_fallback(&e) {
134 return Err(e);
135 }
136
137 attempt += 1;
138 if attempt > fallback.max_retries {
139 return Err(e);
140 }
141
142 if !using_fallback {
143 tracing::warn!(
144 error = %e,
145 fallback_model = %fallback.fallback_model,
146 attempt,
147 max_retries = fallback.max_retries,
148 "Primary model failed, falling back"
149 );
150 current_request = current_request.with_model(&fallback.fallback_model);
151 using_fallback = true;
152 } else {
153 tracing::warn!(
154 error = %e,
155 attempt,
156 max_retries = fallback.max_retries,
157 "Fallback model failed, retrying"
158 );
159 }
160 }
161 }
162 }
163 }
164
165 pub async fn send_no_fallback(
166 &self,
167 request: CreateMessageRequest,
168 ) -> Result<crate::types::ApiResponse> {
169 self.adapter.send(&self.http, request).await
170 }
171
172 pub fn fallback_config(&self) -> Option<&FallbackConfig> {
173 self.fallback_config.as_ref()
174 }
175
176 pub async fn stream(
177 &self,
178 prompt: &str,
179 ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
180 let model = self.adapter.model(ModelType::Primary).to_string();
181 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
182 .with_max_tokens(self.adapter.config().max_tokens);
183
184 let response = self.adapter.send_stream(&self.http, request).await?;
185 let stream = StreamParser::new(response.bytes_stream());
186
187 Ok(futures::StreamExt::filter_map(stream, |item| async move {
188 match item {
189 Ok(StreamItem::Text(text)) => Some(Ok(text)),
190 Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
191 Ok(
192 StreamItem::Event(_) | StreamItem::Citation(_) | StreamItem::ToolUseComplete(_),
193 ) => None,
194 Err(e) => Some(Err(e)),
195 }
196 }))
197 }
198
199 pub async fn stream_request(
200 &self,
201 request: CreateMessageRequest,
202 ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
203 let response = self.adapter.send_stream(&self.http, request).await?;
204 Ok(StreamParser::new(response.bytes_stream()))
205 }
206
207 pub async fn stream_recoverable(
208 &self,
209 request: CreateMessageRequest,
210 ) -> Result<
211 RecoverableStream<
212 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
213 + Send
214 + 'static
215 + use<>,
216 >,
217 > {
218 let response = self.adapter.send_stream(&self.http, request).await?;
219 Ok(RecoverableStream::new(response.bytes_stream()))
220 }
221
222 pub async fn stream_with_recovery(
223 &self,
224 request: CreateMessageRequest,
225 recovery_state: Option<StreamRecoveryState>,
226 ) -> Result<
227 RecoverableStream<
228 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
229 + Send
230 + 'static
231 + use<>,
232 >,
233 > {
234 let request = match recovery_state {
235 Some(state) if state.is_recoverable() => {
236 let mut req = request;
237 req.messages = state.build_continuation_messages(&req.messages);
238 req
239 }
240 _ => request,
241 };
242 self.stream_recoverable(request).await
243 }
244
245 pub fn batch(&self) -> BatchClient<'_> {
246 BatchClient::new(self)
247 }
248
249 pub fn files(&self) -> FilesClient<'_> {
250 FilesClient::new(self)
251 }
252
253 pub fn adapter(&self) -> &dyn ProviderAdapter {
254 self.adapter.as_ref()
255 }
256
257 pub fn config(&self) -> &ProviderConfig {
258 self.adapter.config()
259 }
260
261 pub(crate) fn http(&self) -> &reqwest::Client {
262 &self.http
263 }
264
265 pub async fn refresh_credentials(&self) -> Result<()> {
266 self.adapter.refresh_credentials().await
267 }
268
269 pub async fn send_with_auth_retry(
273 &self,
274 request: CreateMessageRequest,
275 ) -> Result<crate::types::ApiResponse> {
276 match self.send(request.clone()).await {
277 Ok(resp) => Ok(resp),
278 Err(e) if e.is_unauthorized() => {
279 tracing::debug!("Received 401, attempting credential refresh");
280 self.refresh_credentials().await?;
281 self.send(request).await
282 }
283 Err(e) => Err(e),
284 }
285 }
286
287 pub async fn send_stream_with_auth_retry(
291 &self,
292 request: CreateMessageRequest,
293 ) -> Result<reqwest::Response> {
294 match self.adapter.send_stream(&self.http, request.clone()).await {
295 Ok(resp) => Ok(resp),
296 Err(e) if e.is_unauthorized() => {
297 tracing::debug!("Received 401, attempting credential refresh for stream");
298 self.refresh_credentials().await?;
299 self.adapter.send_stream(&self.http, request).await
300 }
301 Err(e) => Err(e),
302 }
303 }
304
305 pub async fn count_tokens(
306 &self,
307 request: messages::CountTokensRequest,
308 ) -> Result<messages::CountTokensResponse> {
309 self.adapter.count_tokens(&self.http, request).await
310 }
311
312 pub async fn count_tokens_for_request(
313 &self,
314 request: &CreateMessageRequest,
315 ) -> Result<messages::CountTokensResponse> {
316 let count_request = messages::CountTokensRequest::from_message_request(request);
317 self.count_tokens(count_request).await
318 }
319}
320
321impl std::fmt::Debug for Client {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 f.debug_struct("Client")
324 .field("provider", &self.adapter.name())
325 .finish()
326 }
327}
328
329#[derive(Default)]
330pub struct ClientBuilder {
331 provider: Option<CloudProvider>,
332 credential: Option<Credential>,
333 oauth_config: Option<OAuthConfig>,
334 config: Option<ProviderConfig>,
335 models: Option<ModelConfig>,
336 network: Option<NetworkConfig>,
337 gateway: Option<GatewayConfig>,
338 timeout: Option<Duration>,
339 fallback_config: Option<FallbackConfig>,
340 resilience_config: Option<ResilienceConfig>,
341
342 #[cfg(feature = "aws")]
343 aws_region: Option<String>,
344 #[cfg(feature = "gcp")]
345 gcp_project: Option<String>,
346 #[cfg(feature = "gcp")]
347 gcp_region: Option<String>,
348 #[cfg(feature = "azure")]
349 azure_resource: Option<String>,
350}
351
352impl ClientBuilder {
353 pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
357 let auth = auth.into();
358
359 #[allow(unreachable_patterns)]
360 match &auth {
361 #[cfg(feature = "aws")]
362 Auth::Bedrock { region } => {
363 self.provider = Some(CloudProvider::Bedrock);
364 self.aws_region = Some(region.clone());
365 }
366 #[cfg(feature = "gcp")]
367 Auth::Vertex { project, region } => {
368 self.provider = Some(CloudProvider::Vertex);
369 self.gcp_project = Some(project.clone());
370 self.gcp_region = Some(region.clone());
371 }
372 #[cfg(feature = "azure")]
373 Auth::Foundry { resource } => {
374 self.provider = Some(CloudProvider::Foundry);
375 self.azure_resource = Some(resource.clone());
376 }
377 _ => {
378 self.provider = Some(CloudProvider::Anthropic);
379 }
380 }
381
382 let credential = auth.resolve().await?;
383 if !credential.is_default() {
384 self.credential = Some(credential);
385 }
386
387 Ok(self)
388 }
389
390 pub fn anthropic(mut self) -> Self {
391 self.provider = Some(CloudProvider::Anthropic);
392 self
393 }
394
395 #[cfg(feature = "aws")]
396 pub(crate) fn with_aws_region(mut self, region: String) -> Self {
397 self.provider = Some(CloudProvider::Bedrock);
398 self.aws_region = Some(region);
399 self
400 }
401
402 #[cfg(feature = "gcp")]
403 pub(crate) fn with_gcp(mut self, project: String, region: String) -> Self {
404 self.provider = Some(CloudProvider::Vertex);
405 self.gcp_project = Some(project);
406 self.gcp_region = Some(region);
407 self
408 }
409
410 #[cfg(feature = "azure")]
411 pub(crate) fn with_azure_resource(mut self, resource: String) -> Self {
412 self.provider = Some(CloudProvider::Foundry);
413 self.azure_resource = Some(resource);
414 self
415 }
416
417 pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
418 self.oauth_config = Some(config);
419 self
420 }
421
422 pub fn models(mut self, models: ModelConfig) -> Self {
423 self.models = Some(models);
424 self
425 }
426
427 pub fn config(mut self, config: ProviderConfig) -> Self {
428 self.config = Some(config);
429 self
430 }
431
432 pub fn network(mut self, network: NetworkConfig) -> Self {
433 self.network = Some(network);
434 self
435 }
436
437 pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
438 self.gateway = Some(gateway);
439 self
440 }
441
442 pub fn timeout(mut self, timeout: Duration) -> Self {
443 self.timeout = Some(timeout);
444 self
445 }
446
447 pub fn fallback(mut self, config: FallbackConfig) -> Self {
448 self.fallback_config = Some(config);
449 self
450 }
451
452 pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
453 self.fallback_config = Some(FallbackConfig::new(model));
454 self
455 }
456
457 pub fn resilience(mut self, config: ResilienceConfig) -> Self {
458 self.resilience_config = Some(config);
459 self
460 }
461
462 pub fn with_default_resilience(mut self) -> Self {
463 self.resilience_config = Some(ResilienceConfig::default());
464 self
465 }
466
467 pub async fn build(self) -> Result<Client> {
468 let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
469
470 let models = self.models.unwrap_or_else(|| provider.default_models());
471
472 let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
473
474 let adapter: Box<dyn ProviderAdapter> = match provider {
475 CloudProvider::Anthropic => {
476 let adapter = if let Some(cred) = self.credential {
477 let mut a = AnthropicAdapter::from_credential(config, cred, self.oauth_config);
478 if let Some(ref gw) = self.gateway
479 && let Some(ref url) = gw.base_url
480 {
481 a = a.with_base_url(url);
482 }
483 a
484 } else {
485 let mut a = AnthropicAdapter::new(config);
486 if let Some(ref gw) = self.gateway {
487 if let Some(ref url) = gw.base_url {
488 a = a.with_base_url(url);
489 }
490 if let Some(ref token) = gw.auth_token {
491 a = a.with_api_key(token);
492 }
493 }
494 a
495 };
496 Box::new(adapter)
497 }
498 #[cfg(feature = "aws")]
499 CloudProvider::Bedrock => {
500 let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
501 if let Some(region) = self.aws_region {
502 adapter = adapter.with_region(region);
503 }
504 Box::new(adapter)
505 }
506 #[cfg(feature = "gcp")]
507 CloudProvider::Vertex => {
508 let mut adapter = adapter::VertexAdapter::from_env(config).await?;
509 if let Some(project) = self.gcp_project {
510 adapter = adapter.with_project(project);
511 }
512 if let Some(region) = self.gcp_region {
513 adapter = adapter.with_region(region);
514 }
515 Box::new(adapter)
516 }
517 #[cfg(feature = "azure")]
518 CloudProvider::Foundry => {
519 let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
520 if let Some(resource) = self.azure_resource {
521 adapter = adapter.with_resource(resource);
522 }
523 Box::new(adapter)
524 }
525 };
526
527 let mut http_builder =
528 reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
529
530 if let Some(ref network) = self.network {
531 http_builder = network
532 .apply_to_builder(http_builder)
533 .map_err(|e| Error::Config(e.to_string()))?;
534 }
535
536 let http = http_builder.build().map_err(Error::Network)?;
537
538 let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
539
540 Ok(Client {
541 adapter: Arc::from(adapter),
542 http,
543 fallback_config: self.fallback_config,
544 resilience,
545 })
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Selecting the Anthropic provider on a fresh builder must not panic.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder().anthropic();
        drop(builder);
    }

    /// With no provider-selecting environment present, the default is Anthropic.
    /// NOTE(review): this depends on the process environment of the test runner.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// An explicit API-key credential resolves without error.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let builder = Client::builder().anthropic();
        let _configured = builder
            .auth(Credential::api_key("test-key"))
            .await
            .unwrap();
    }
}
572}