1pub mod adapter;
4pub mod batch;
5pub mod fallback;
6pub mod files;
7pub mod gateway;
8pub mod messages;
9pub mod network;
10pub mod recovery;
11pub mod resilience;
12pub mod schema;
13mod streaming;
14
15pub use adapter::{
16 AnthropicAdapter, BetaConfig, BetaFeature, CloudProvider, DEFAULT_MODEL,
17 DEFAULT_REASONING_MODEL, DEFAULT_SMALL_MODEL, FRONTIER_MODEL, ModelConfig, ModelType,
18 ProviderAdapter, ProviderConfig,
19};
20pub use batch::{
21 BatchClient, BatchRequest, BatchResult, BatchStatus, CreateBatchRequest, MessageBatch,
22};
23pub use fallback::{FallbackConfig, FallbackTrigger};
24pub use files::{File, FileData, FileDownload, FileListResponse, FilesClient, UploadFileRequest};
25pub use gateway::GatewayConfig;
26pub use messages::{
27 ClearConfig, ClearTrigger, ContextEdit, ContextManagement, CountTokensContextManagement,
28 CountTokensRequest, CountTokensResponse, CreateMessageRequest, EffortLevel, KeepConfig,
29 KeepThinkingConfig, OutputConfig, OutputFormat, ThinkingConfig, ThinkingType, ToolChoice,
30};
31pub use network::{ClientCertConfig, NetworkConfig, ProxyConfig};
32pub use recovery::StreamRecoveryState;
33pub use resilience::{
34 CircuitBreaker, CircuitConfig, CircuitState, ExponentialBackoff, Resilience, ResilienceConfig,
35 RetryConfig,
36};
37pub use schema::{strict_schema, transform_for_strict};
38pub use streaming::{RecoverableStream, StreamItem, StreamParser};
39
40#[cfg(feature = "aws")]
41pub use adapter::BedrockAdapter;
42#[cfg(feature = "azure")]
43pub use adapter::FoundryAdapter;
44#[cfg(feature = "gcp")]
45pub use adapter::VertexAdapter;
46
47use std::sync::Arc;
48use std::time::Duration;
49
50use crate::auth::{Auth, Credential, OAuthConfig};
51use crate::{Error, Result};
52
/// Request timeout (five minutes) used when no explicit timeout is supplied.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);
54
55#[derive(Clone)]
56pub struct Client {
57 adapter: Arc<dyn ProviderAdapter>,
58 http: reqwest::Client,
59 fallback_config: Option<FallbackConfig>,
60 resilience: Option<Arc<Resilience>>,
61}
62
63impl Client {
64 pub fn new(adapter: impl ProviderAdapter + 'static) -> Result<Self> {
65 let timeout = DEFAULT_TIMEOUT;
66 let http = reqwest::Client::builder()
67 .timeout(timeout)
68 .build()
69 .map_err(Error::Network)?;
70
71 Ok(Self {
72 adapter: Arc::new(adapter),
73 http,
74 fallback_config: None,
75 resilience: None,
76 })
77 }
78
79 pub fn with_http(adapter: impl ProviderAdapter + 'static, http: reqwest::Client) -> Self {
80 Self {
81 adapter: Arc::new(adapter),
82 http,
83 fallback_config: None,
84 resilience: None,
85 }
86 }
87
88 pub fn with_fallback(mut self, config: FallbackConfig) -> Self {
89 self.fallback_config = Some(config);
90 self
91 }
92
93 pub fn with_resilience(mut self, config: ResilienceConfig) -> Self {
94 self.resilience = Some(Arc::new(Resilience::new(config)));
95 self
96 }
97
98 pub fn resilience(&self) -> Option<&Arc<Resilience>> {
99 self.resilience.as_ref()
100 }
101
102 pub fn builder() -> ClientBuilder {
103 ClientBuilder::default()
104 }
105
106 pub async fn query(&self, prompt: &str) -> Result<String> {
107 self.query_with_model(prompt, ModelType::Primary).await
108 }
109
110 pub async fn query_with_model(&self, prompt: &str, model_type: ModelType) -> Result<String> {
111 let model = self.adapter.model(model_type).to_string();
112 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
113 .with_max_tokens(self.adapter.config().max_tokens);
114
115 let response = self.adapter.send(&self.http, request).await?;
116 Ok(response.text())
117 }
118
119 pub async fn send(&self, request: CreateMessageRequest) -> Result<crate::types::ApiResponse> {
120 let fallback = match &self.fallback_config {
121 Some(f) => f,
122 None => return self.adapter.send(&self.http, request).await,
123 };
124
125 let mut current_request = request;
126 let mut attempt = 0;
127 let mut using_fallback = false;
128
129 loop {
130 match self.adapter.send(&self.http, current_request.clone()).await {
131 Ok(response) => return Ok(response),
132 Err(e) => {
133 if !fallback.should_fallback(&e) {
134 return Err(e);
135 }
136
137 attempt += 1;
138 if attempt > fallback.max_retries {
139 return Err(e);
140 }
141
142 if !using_fallback {
143 tracing::warn!(
144 error = %e,
145 fallback_model = %fallback.fallback_model,
146 attempt,
147 max_retries = fallback.max_retries,
148 "Primary model failed, falling back"
149 );
150 current_request = current_request.with_model(&fallback.fallback_model);
151 using_fallback = true;
152 } else {
153 tracing::warn!(
154 error = %e,
155 attempt,
156 max_retries = fallback.max_retries,
157 "Fallback model failed, retrying"
158 );
159 }
160 }
161 }
162 }
163 }
164
165 pub async fn send_no_fallback(
166 &self,
167 request: CreateMessageRequest,
168 ) -> Result<crate::types::ApiResponse> {
169 self.adapter.send(&self.http, request).await
170 }
171
172 pub fn fallback_config(&self) -> Option<&FallbackConfig> {
173 self.fallback_config.as_ref()
174 }
175
176 pub async fn stream(
177 &self,
178 prompt: &str,
179 ) -> Result<impl futures::Stream<Item = Result<String>> + Send + 'static + use<>> {
180 let model = self.adapter.model(ModelType::Primary).to_string();
181 let request = CreateMessageRequest::new(&model, vec![crate::types::Message::user(prompt)])
182 .with_max_tokens(self.adapter.config().max_tokens);
183
184 let response = self.adapter.send_stream(&self.http, request).await?;
185 let stream = StreamParser::new(response.bytes_stream());
186
187 Ok(futures::StreamExt::filter_map(stream, |item| async move {
188 match item {
189 Ok(StreamItem::Text(text)) => Some(Ok(text)),
190 Ok(StreamItem::Thinking(text)) => Some(Ok(text)),
191 Ok(StreamItem::Event(_) | StreamItem::Citation(_)) => None,
192 Err(e) => Some(Err(e)),
193 }
194 }))
195 }
196
197 pub async fn stream_request(
198 &self,
199 request: CreateMessageRequest,
200 ) -> Result<impl futures::Stream<Item = Result<StreamItem>> + Send + 'static + use<>> {
201 let response = self.adapter.send_stream(&self.http, request).await?;
202 Ok(StreamParser::new(response.bytes_stream()))
203 }
204
205 pub async fn stream_recoverable(
206 &self,
207 request: CreateMessageRequest,
208 ) -> Result<
209 RecoverableStream<
210 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
211 + Send
212 + 'static
213 + use<>,
214 >,
215 > {
216 let response = self.adapter.send_stream(&self.http, request).await?;
217 Ok(RecoverableStream::new(response.bytes_stream()))
218 }
219
220 pub async fn stream_with_recovery(
221 &self,
222 request: CreateMessageRequest,
223 recovery_state: Option<StreamRecoveryState>,
224 ) -> Result<
225 RecoverableStream<
226 impl futures::Stream<Item = std::result::Result<bytes::Bytes, reqwest::Error>>
227 + Send
228 + 'static
229 + use<>,
230 >,
231 > {
232 let request = match recovery_state {
233 Some(state) if state.is_recoverable() => {
234 let mut req = request;
235 req.messages = state.build_continuation_messages(&req.messages);
236 req
237 }
238 _ => request,
239 };
240 self.stream_recoverable(request).await
241 }
242
243 pub fn batch(&self) -> BatchClient<'_> {
244 BatchClient::new(self)
245 }
246
247 pub fn files(&self) -> FilesClient<'_> {
248 FilesClient::new(self)
249 }
250
251 pub fn adapter(&self) -> &dyn ProviderAdapter {
252 self.adapter.as_ref()
253 }
254
255 pub fn config(&self) -> &ProviderConfig {
256 self.adapter.config()
257 }
258
259 pub(crate) fn http(&self) -> &reqwest::Client {
260 &self.http
261 }
262
263 pub async fn refresh_credentials(&self) -> Result<()> {
264 self.adapter.refresh_credentials().await
265 }
266
267 pub async fn count_tokens(
268 &self,
269 request: messages::CountTokensRequest,
270 ) -> Result<messages::CountTokensResponse> {
271 self.adapter.count_tokens(&self.http, request).await
272 }
273
274 pub async fn count_tokens_for_request(
275 &self,
276 request: &CreateMessageRequest,
277 ) -> Result<messages::CountTokensResponse> {
278 let count_request = messages::CountTokensRequest::from_message_request(request);
279 self.count_tokens(count_request).await
280 }
281}
282
283impl std::fmt::Debug for Client {
284 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
285 f.debug_struct("Client")
286 .field("provider", &self.adapter.name())
287 .finish()
288 }
289}
290
291#[derive(Default)]
292pub struct ClientBuilder {
293 provider: Option<CloudProvider>,
294 credential: Option<Credential>,
295 oauth_config: Option<OAuthConfig>,
296 config: Option<ProviderConfig>,
297 models: Option<ModelConfig>,
298 network: Option<NetworkConfig>,
299 gateway: Option<GatewayConfig>,
300 timeout: Option<Duration>,
301 fallback_config: Option<FallbackConfig>,
302 resilience_config: Option<ResilienceConfig>,
303
304 #[cfg(feature = "aws")]
305 aws_region: Option<String>,
306 #[cfg(feature = "gcp")]
307 gcp_project: Option<String>,
308 #[cfg(feature = "gcp")]
309 gcp_region: Option<String>,
310 #[cfg(feature = "azure")]
311 azure_resource: Option<String>,
312}
313
314impl ClientBuilder {
315 pub async fn auth(mut self, auth: impl Into<Auth>) -> Result<Self> {
319 let auth = auth.into();
320
321 #[allow(unreachable_patterns)]
322 match &auth {
323 #[cfg(feature = "aws")]
324 Auth::Bedrock { region } => {
325 self.provider = Some(CloudProvider::Bedrock);
326 self.aws_region = Some(region.clone());
327 }
328 #[cfg(feature = "gcp")]
329 Auth::Vertex { project, region } => {
330 self.provider = Some(CloudProvider::Vertex);
331 self.gcp_project = Some(project.clone());
332 self.gcp_region = Some(region.clone());
333 }
334 #[cfg(feature = "azure")]
335 Auth::Foundry { resource } => {
336 self.provider = Some(CloudProvider::Foundry);
337 self.azure_resource = Some(resource.clone());
338 }
339 _ => {
340 self.provider = Some(CloudProvider::Anthropic);
341 }
342 }
343
344 let credential = auth.resolve().await?;
345 if !credential.is_default() {
346 self.credential = Some(credential);
347 }
348
349 Ok(self)
350 }
351
352 pub fn anthropic(mut self) -> Self {
353 self.provider = Some(CloudProvider::Anthropic);
354 self
355 }
356
357 #[cfg(feature = "aws")]
358 pub(crate) fn with_aws_region(mut self, region: String) -> Self {
359 self.provider = Some(CloudProvider::Bedrock);
360 self.aws_region = Some(region);
361 self
362 }
363
364 #[cfg(feature = "gcp")]
365 pub(crate) fn with_gcp(mut self, project: String, region: String) -> Self {
366 self.provider = Some(CloudProvider::Vertex);
367 self.gcp_project = Some(project);
368 self.gcp_region = Some(region);
369 self
370 }
371
372 #[cfg(feature = "azure")]
373 pub(crate) fn with_azure_resource(mut self, resource: String) -> Self {
374 self.provider = Some(CloudProvider::Foundry);
375 self.azure_resource = Some(resource);
376 self
377 }
378
379 pub fn oauth_config(mut self, config: OAuthConfig) -> Self {
380 self.oauth_config = Some(config);
381 self
382 }
383
384 pub fn models(mut self, models: ModelConfig) -> Self {
385 self.models = Some(models);
386 self
387 }
388
389 pub fn config(mut self, config: ProviderConfig) -> Self {
390 self.config = Some(config);
391 self
392 }
393
394 pub fn network(mut self, network: NetworkConfig) -> Self {
395 self.network = Some(network);
396 self
397 }
398
399 pub fn gateway(mut self, gateway: GatewayConfig) -> Self {
400 self.gateway = Some(gateway);
401 self
402 }
403
404 pub fn timeout(mut self, timeout: Duration) -> Self {
405 self.timeout = Some(timeout);
406 self
407 }
408
409 pub fn fallback(mut self, config: FallbackConfig) -> Self {
410 self.fallback_config = Some(config);
411 self
412 }
413
414 pub fn fallback_model(mut self, model: impl Into<String>) -> Self {
415 self.fallback_config = Some(FallbackConfig::new(model));
416 self
417 }
418
419 pub fn resilience(mut self, config: ResilienceConfig) -> Self {
420 self.resilience_config = Some(config);
421 self
422 }
423
424 pub fn with_default_resilience(mut self) -> Self {
425 self.resilience_config = Some(ResilienceConfig::default());
426 self
427 }
428
429 pub async fn build(self) -> Result<Client> {
430 let provider = self.provider.unwrap_or_else(CloudProvider::from_env);
431
432 let models = self.models.unwrap_or_else(|| provider.default_models());
433
434 let config = self.config.unwrap_or_else(|| ProviderConfig::new(models));
435
436 let adapter: Box<dyn ProviderAdapter> = match provider {
437 CloudProvider::Anthropic => {
438 let adapter = if let Some(cred) = self.credential {
439 let mut a = AnthropicAdapter::from_credential(config, cred, self.oauth_config);
440 if let Some(ref gw) = self.gateway
441 && let Some(ref url) = gw.base_url
442 {
443 a = a.with_base_url(url);
444 }
445 a
446 } else {
447 let mut a = AnthropicAdapter::new(config);
448 if let Some(ref gw) = self.gateway {
449 if let Some(ref url) = gw.base_url {
450 a = a.with_base_url(url);
451 }
452 if let Some(ref token) = gw.auth_token {
453 a = a.with_api_key(token);
454 }
455 }
456 a
457 };
458 Box::new(adapter)
459 }
460 #[cfg(feature = "aws")]
461 CloudProvider::Bedrock => {
462 let mut adapter = adapter::BedrockAdapter::from_env(config).await?;
463 if let Some(region) = self.aws_region {
464 adapter = adapter.with_region(region);
465 }
466 Box::new(adapter)
467 }
468 #[cfg(feature = "gcp")]
469 CloudProvider::Vertex => {
470 let mut adapter = adapter::VertexAdapter::from_env(config).await?;
471 if let Some(project) = self.gcp_project {
472 adapter = adapter.with_project(project);
473 }
474 if let Some(region) = self.gcp_region {
475 adapter = adapter.with_region(region);
476 }
477 Box::new(adapter)
478 }
479 #[cfg(feature = "azure")]
480 CloudProvider::Foundry => {
481 let mut adapter = adapter::FoundryAdapter::from_env(config).await?;
482 if let Some(resource) = self.azure_resource {
483 adapter = adapter.with_resource(resource);
484 }
485 Box::new(adapter)
486 }
487 };
488
489 let mut http_builder =
490 reqwest::Client::builder().timeout(self.timeout.unwrap_or(DEFAULT_TIMEOUT));
491
492 if let Some(ref network) = self.network {
493 http_builder = network
494 .apply_to_builder(http_builder)
495 .map_err(|e| Error::Config(e.to_string()))?;
496 }
497
498 let http = http_builder.build().map_err(Error::Network)?;
499
500 let resilience = self.resilience_config.map(|c| Arc::new(Resilience::new(c)));
501
502 Ok(Client {
503 adapter: Arc::from(adapter),
504 http,
505 fallback_config: self.fallback_config,
506 resilience,
507 })
508 }
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a fresh builder accepts the Anthropic provider selection.
    #[test]
    fn test_client_builder() {
        let builder = Client::builder();
        let _anthropic = builder.anthropic();
    }

    /// `CloudProvider::from_env` is expected to resolve to Anthropic.
    // NOTE(review): presumably this depends on the process environment of the
    // test run — confirm which variables `from_env` reads.
    #[test]
    fn test_cloud_provider_from_env() {
        assert_eq!(CloudProvider::from_env(), CloudProvider::Anthropic);
    }

    /// Resolving an API-key credential through `auth` succeeds.
    #[tokio::test]
    async fn test_builder_with_auth_credential() {
        let credential = Credential::api_key("test-key");
        let _builder = Client::builder()
            .anthropic()
            .auth(credential)
            .await
            .unwrap();
    }
}