1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod model_listing;
11pub mod transcription;
12pub mod verify;
13
14use bytes::Bytes;
15pub use completion::CompletionClient;
16pub use embeddings::EmbeddingsClient;
17use http::{HeaderMap, HeaderName, HeaderValue};
18pub use model_listing::{ModelLister, ModelListingClient};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, marker::PhantomData, sync::Arc};
21use thiserror::Error;
22pub use verify::{VerifyClient, VerifyError};
23
24#[cfg(feature = "image")]
25use crate::image_generation::ImageGenerationModel;
26#[cfg(feature = "image")]
27use image_generation::ImageGenerationClient;
28
29#[cfg(feature = "audio")]
30use crate::audio_generation::*;
31#[cfg(feature = "audio")]
32use audio_generation::*;
33
34use crate::{
35 completion::CompletionModel,
36 embeddings::EmbeddingModel,
37 http_client::{
38 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
39 },
40 prelude::TranscriptionClient,
41 transcription::TranscriptionModel,
42 wasm_compat::{WasmCompatSend, WasmCompatSync},
43};
44
/// Errors that can arise while configuring a client.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ClientBuilderError {
    /// An error reported by `reqwest`; convertible via `?` thanks to `#[from]`.
    #[error("reqwest error: {0}")]
    HttpError(
        #[from]
        #[source]
        reqwest::Error,
    ),
    /// A builder property was invalid. The payload is presumably the property
    /// name (it is rendered as `invalid property: <payload>`) — confirm at call sites.
    #[error("invalid property: {0}")]
    InvalidProperty(&'static str),
}
57
/// A provider client that can be constructed directly, without a builder.
///
/// Both constructors return `Self` rather than a `Result`, so implementations
/// cannot report failure through the return type.
pub trait ProviderClient {
    /// Value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Creates a client from the process environment.
    /// NOTE(review): presumably reads provider-specific variables such as the
    /// API key — confirm in each implementor.
    fn from_env() -> Self;

    /// Creates a client from an explicit `Input` value.
    fn from_val(input: Self::Input) -> Self;
}
69
70use crate::completion::{GetTokenUsage, Usage};
71
/// Final completion payload that only carries token usage.
/// NOTE(review): presumably emitted at the end of a streamed completion — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinalCompletionResponse {
    /// Token usage, or `None` if the provider did not report it.
    pub usage: Option<Usage>,
}
77
78impl GetTokenUsage for FinalCompletionResponse {
79 fn token_usage(&self) -> Option<Usage> {
80 self.usage
81 }
82}
83
/// A credential that can optionally be turned into a request header.
///
/// `ClientBuilder::build` inserts the header returned by
/// [`ApiKey::into_header`] into the client's default headers.
pub trait ApiKey: Sized {
    /// Converts the key into a `(name, value)` header pair.
    ///
    /// Returns `None` when the key contributes no header (the default), and
    /// `Some(Err(_))` if the header could not be constructed.
    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
        None
    }
}
91
/// API key that is turned into an auth header by `make_auth_header`.
/// NOTE(review): presumably `Authorization: Bearer <token>` — confirm in `http_client`.
pub struct BearerAuth(String);
94
95impl ApiKey for BearerAuth {
96 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
97 Some(make_auth_header(self.0))
98 }
99}
100
101impl<S> From<S> for BearerAuth
102where
103 S: Into<String>,
104{
105 fn from(value: S) -> Self {
106 Self(value.into())
107 }
108}
109
/// Placeholder meaning "nothing here".
///
/// In this module it serves as an [`ApiKey`] that adds no header, as a
/// [`Capability`] marking a feature as unsupported, and as the default `Ext`
/// parameter of [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct Nothing;

// Relies on the default `into_header`, so no authentication header is added.
impl ApiKey for Nothing {}
116
117impl TryFrom<String> for Nothing {
118 type Error = &'static str;
119
120 fn try_from(_: String) -> Result<Self, Self::Error> {
121 Err(
122 "Tried to create a Nothing from a string - this should not happen, please file an issue",
123 )
124 }
125}
126
/// Generic provider client.
///
/// Holds a base URL, default headers, an HTTP client `H`, and a
/// provider-specific extension `Ext`. `base_url` and `headers` live behind
/// `Arc`, so cloning a `Client` shares them instead of copying them.
#[derive(Clone)]
pub struct Client<Ext = Nothing, H = reqwest::Client> {
    base_url: Arc<str>,
    /// Headers copied onto every request built by `post`/`get`/`post_sse`/`get_sse`.
    /// This includes the API-key header inserted by `ClientBuilder::build`.
    headers: Arc<HeaderMap>,
    http_client: H,
    ext: Ext,
}
134
/// Lets a provider extension contribute extra fields to `Client`'s `Debug` output.
pub trait DebugExt: Debug {
    /// Extra `(name, value)` pairs appended to the `Client` debug struct.
    /// Defaults to no extra fields.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
140
141impl<Ext, H> std::fmt::Debug for Client<Ext, H>
142where
143 Ext: DebugExt,
144 H: std::fmt::Debug,
145{
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 let mut d = &mut f.debug_struct("Client");
148
149 d = d
150 .field("base_url", &self.base_url)
151 .field(
152 "headers",
153 &self
154 .headers
155 .iter()
156 .filter_map(|(k, v)| {
157 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
158 None
159 } else {
160 Some((k, v))
161 }
162 })
163 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
164 )
165 .field("http_client", &self.http_client);
166
167 self.ext
168 .fields()
169 .fold(d, |d, (name, field)| d.field(name, field))
170 .finish()
171 }
172}
173
/// Transport style of a request; passed to [`Provider::build_uri`] so a
/// provider can vary the URI per transport.
pub enum Transport {
    /// Plain request/response HTTP (used by `Client::post` / `Client::get`).
    Http,
    /// Server-sent events (used by `Client::post_sse` / `Client::get_sse`).
    Sse,
    /// Newline-delimited JSON streaming.
    NdJson,
}
179
180pub trait Provider: Sized {
184 type Builder: ProviderBuilder;
187
188 const VERIFY_PATH: &'static str;
189
190 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191 let base_url = if base_url.is_empty() {
193 base_url.to_string()
194 } else {
195 base_url.to_string() + "/"
196 };
197
198 base_url.to_string() + path.trim_start_matches('/')
199 }
200
201 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
202 Ok(req)
203 }
204}
205
/// Marker stating that a capability is supported and served by model type `M`.
pub struct Capable<M>(PhantomData<M>);

/// Type-level flag describing whether a capability is supported.
pub trait Capability {
    /// `true` if the capability is supported.
    const CAPABLE: bool;
}

impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}

// `Nothing` marks a capability as unsupported.
impl Capability for Nothing {
    const CAPABLE: bool = false;
}
220
/// Declares, per capability, whether a provider extension supports it.
///
/// Each associated type is typically [`Capable<M>`] (supported, with model type
/// `M`) or [`Nothing`] (unsupported). The `*Client` trait impls for [`Client`]
/// in this module are only available when the matching type is `Capable<M>`.
pub trait Capabilities<H = reqwest::Client> {
    /// Completion support.
    type Completion: Capability;
    /// Embeddings support.
    type Embeddings: Capability;
    /// Transcription support.
    type Transcription: Capability;
    /// Model-listing support.
    type ModelListing: Capability;
    /// Image-generation support (only with the `image` feature).
    #[cfg(feature = "image")]
    type ImageGeneration: Capability;
    /// Audio-generation support (only with the `audio` feature).
    #[cfg(feature = "audio")]
    type AudioGeneration: Capability;
}
232
/// Provider-specific part of a [`ClientBuilder`] that produces the extension
/// stored in the built [`Client`].
pub trait ProviderBuilder: Sized + Default + Clone {
    /// The extension type stored in the built [`Client`] for HTTP client `H`.
    type Extension<H>: Provider
    where
        H: HttpClientExt;
    /// Credential type that `ClientBuilder::build` requires for this provider.
    type ApiKey: ApiKey;

    /// Base URL that new builders start from.
    const BASE_URL: &'static str;

    /// Builds the extension from the configured builder.
    ///
    /// `ClientBuilder::build` calls this after [`ProviderBuilder::finish`].
    fn build<H>(
        builder: &ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt;

    /// Last chance to adjust the builder before the extension is built.
    ///
    /// The default returns `builder` unchanged.
    fn finish<H>(
        &self,
        builder: ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
        Ok(builder)
    }
}
261
262impl<Ext> Client<Ext, reqwest::Client>
263where
264 Ext: Provider,
265 Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
266{
267 pub fn new(
268 api_key: impl Into<<Ext::Builder as ProviderBuilder>::ApiKey>,
269 ) -> http_client::Result<Self> {
270 Self::builder().api_key(api_key).build()
271 }
272}
273
274impl<Ext, H> Client<Ext, H> {
275 pub fn base_url(&self) -> &str {
276 &self.base_url
277 }
278
279 pub fn headers(&self) -> &HeaderMap {
280 &self.headers
281 }
282
283 pub fn ext(&self) -> &Ext {
284 &self.ext
285 }
286
287 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
288 Client {
289 base_url: self.base_url,
290 headers: self.headers,
291 http_client: self.http_client,
292 ext: new_ext,
293 }
294 }
295}
296
297impl<Ext, H> HttpClientExt for Client<Ext, H>
298where
299 H: HttpClientExt + 'static,
300 Ext: WasmCompatSend + WasmCompatSync + 'static,
301{
302 fn send<T, U>(
303 &self,
304 mut req: Request<T>,
305 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
306 where
307 T: Into<Bytes> + WasmCompatSend,
308 U: From<Bytes>,
309 U: WasmCompatSend + 'static,
310 {
311 req.headers_mut().insert(
312 http::header::CONTENT_TYPE,
313 http::HeaderValue::from_static("application/json"),
314 );
315
316 self.http_client.send(req)
317 }
318
319 fn send_multipart<U>(
320 &self,
321 req: Request<MultipartForm>,
322 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
323 where
324 U: From<Bytes>,
325 U: WasmCompatSend + 'static,
326 {
327 self.http_client.send_multipart(req)
328 }
329
330 fn send_streaming<T>(
331 &self,
332 mut req: Request<T>,
333 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
334 where
335 T: Into<Bytes>,
336 {
337 req.headers_mut().insert(
338 http::header::CONTENT_TYPE,
339 http::HeaderValue::from_static("application/json"),
340 );
341
342 self.http_client.send_streaming(req)
343 }
344}
345
346impl<Ext> Client<Ext, reqwest::Client>
347where
348 Ext: Provider,
349 Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
350{
351 pub fn builder() -> ClientBuilder<Ext::Builder, NeedsApiKey, reqwest::Client> {
352 ClientBuilder {
353 api_key: NeedsApiKey,
354 headers: Default::default(),
355 base_url: <Ext::Builder as ProviderBuilder>::BASE_URL.into(),
356 http_client: None,
357 ext: Default::default(),
358 }
359 }
360}
361
362impl<Ext, H> Client<Ext, H>
363where
364 Ext: Provider,
365{
366 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
367 where
368 S: AsRef<str>,
369 {
370 let uri = self
371 .ext
372 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
373
374 let mut req = Request::post(uri);
375
376 if let Some(hs) = req.headers_mut() {
377 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
378 }
379
380 self.ext.with_custom(req)
381 }
382
383 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
384 where
385 S: AsRef<str>,
386 {
387 let uri = self
388 .ext
389 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
390
391 let mut req = Request::post(uri);
392
393 if let Some(hs) = req.headers_mut() {
394 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
395 }
396
397 self.ext.with_custom(req)
398 }
399
400 pub fn get_sse<S>(&self, path: S) -> http_client::Result<Builder>
401 where
402 S: AsRef<str>,
403 {
404 let uri = self
405 .ext
406 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
407
408 let mut req = Request::get(uri);
409
410 if let Some(hs) = req.headers_mut() {
411 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
412 }
413
414 self.ext.with_custom(req)
415 }
416
417 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
418 where
419 S: AsRef<str>,
420 {
421 let uri = self
422 .ext
423 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
424
425 let mut req = Request::get(uri);
426
427 if let Some(hs) = req.headers_mut() {
428 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
429 }
430
431 self.ext.with_custom(req)
432 }
433}
434
435impl<Ext, H> VerifyClient for Client<Ext, H>
436where
437 H: HttpClientExt,
438 Ext: DebugExt + Provider + WasmCompatSync,
439{
440 async fn verify(&self) -> Result<(), VerifyError> {
441 use http::StatusCode;
442
443 let req = self
444 .get(Ext::VERIFY_PATH)?
445 .body(http_client::NoBody)
446 .map_err(http_client::Error::from)?;
447
448 let response = self.http_client.send(req).await?;
449
450 match response.status() {
451 StatusCode::OK => Ok(()),
452 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
453 Err(VerifyError::InvalidAuthentication)
454 }
455 StatusCode::INTERNAL_SERVER_ERROR => {
456 let text = http_client::text(response).await?;
457 Err(VerifyError::ProviderError(text))
458 }
459 status if status.as_u16() == 529 => {
460 let text = http_client::text(response).await?;
461 Err(VerifyError::ProviderError(text))
462 }
463 _ => {
464 let status = response.status();
465
466 if status.is_success() {
467 Ok(())
468 } else {
469 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
470 Err(VerifyError::HttpError(http_client::Error::Instance(
471 format!("Failed with '{status}': {text}").into(),
472 )))
473 }
474 }
475 }
476 }
477}
478
/// Type-state marker: the [`ClientBuilder`] has not been given an API key yet.
/// `ClientBuilder::api_key` is only available in this state.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
481
/// Builder for [`Client`].
///
/// Type parameters:
/// - `Ext`: the provider's [`ProviderBuilder`].
/// - `ApiKey`: starts as [`NeedsApiKey`] and becomes the key type once
///   `ClientBuilder::api_key` is called.
/// - `H`: the HTTP client type.
#[derive(Clone)]
pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
    base_url: String,
    api_key: ApiKey,
    headers: HeaderMap,
    /// `None` means `H::default()` is used when the client is built.
    http_client: Option<H>,
    ext: Ext,
}
491
492impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
493where
494 H: Default,
495 ExtBuilder: ProviderBuilder + Default,
496{
497 fn default() -> Self {
498 Self {
499 api_key: NeedsApiKey,
500 headers: Default::default(),
501 base_url: ExtBuilder::BASE_URL.into(),
502 http_client: None,
503 ext: Default::default(),
504 }
505 }
506}
507
508impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
509 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
512 ClientBuilder {
513 api_key: api_key.into(),
514 base_url: self.base_url,
515 headers: self.headers,
516 http_client: self.http_client,
517 ext: self.ext,
518 }
519 }
520}
521
522impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
523where
524 Ext: Clone,
525{
526 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
528 where
529 F: FnOnce(Ext) -> NewExt,
530 {
531 let ClientBuilder {
532 base_url,
533 api_key,
534 headers,
535 http_client,
536 ext,
537 } = self;
538
539 let new_ext = f(ext.clone());
540
541 ClientBuilder {
542 base_url,
543 api_key,
544 headers,
545 http_client,
546 ext: new_ext,
547 }
548 }
549
550 pub fn base_url<S>(self, base_url: S) -> Self
552 where
553 S: AsRef<str>,
554 {
555 Self {
556 base_url: base_url.as_ref().to_string(),
557 ..self
558 }
559 }
560
561 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
563 ClientBuilder {
564 http_client: Some(http_client),
565 base_url: self.base_url,
566 api_key: self.api_key,
567 headers: self.headers,
568 ext: self.ext,
569 }
570 }
571
572 pub fn http_headers(self, headers: HeaderMap) -> Self {
574 Self { headers, ..self }
575 }
576
577 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
578 &mut self.headers
579 }
580
581 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
582 &mut self.ext
583 }
584}
585
586impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
587 pub(crate) fn get_api_key(&self) -> &ApiKey {
588 &self.api_key
589 }
590}
591
592impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
593 pub fn ext(&self) -> &Ext {
594 &self.ext
595 }
596}
597
598impl<ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
599where
600 ExtBuilder: ProviderBuilder<ApiKey = Key>,
601 Key: ApiKey,
602 H: Default + HttpClientExt,
603{
604 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Extension<H>, H>> {
605 let ext_builder = self.ext.clone();
606
607 self = ext_builder.finish(self)?;
608 let ext = ExtBuilder::build(&self)?;
609
610 let ClientBuilder {
611 http_client,
612 base_url,
613 mut headers,
614 api_key,
615 ..
616 } = self;
617
618 if let Some((k, v)) = api_key.into_header().transpose()? {
619 headers.insert(k, v);
620 }
621
622 let http_client = http_client.unwrap_or_default();
623
624 Ok(Client {
625 http_client,
626 base_url: Arc::from(base_url.as_str()),
627 headers: Arc::new(headers),
628 ext,
629 })
630 }
631}
632
633impl<M, Ext, H> CompletionClient for Client<Ext, H>
634where
635 Ext: Capabilities<H, Completion = Capable<M>>,
636 M: CompletionModel<Client = Self>,
637{
638 type CompletionModel = M;
639
640 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
641 M::make(self, model)
642 }
643}
644
645impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
646where
647 Ext: Capabilities<H, Embeddings = Capable<M>>,
648 M: EmbeddingModel<Client = Self>,
649{
650 type EmbeddingModel = M;
651
652 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
653 M::make(self, model, None)
654 }
655
656 fn embedding_model_with_ndims(
657 &self,
658 model: impl Into<String>,
659 ndims: usize,
660 ) -> Self::EmbeddingModel {
661 M::make(self, model, Some(ndims))
662 }
663}
664
665impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
666where
667 Ext: Capabilities<H, Transcription = Capable<M>>,
668 M: TranscriptionModel<Client = Self> + WasmCompatSend,
669{
670 type TranscriptionModel = M;
671
672 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
673 M::make(self, model)
674 }
675}
676
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates the provider's image-generation model `M` named `model`, bound to this client.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        <M as ImageGenerationModel>::make(self, model)
    }
}
689
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates the provider's audio-generation model `M` named `model`, bound to this client.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        <M as AudioGenerationModel>::make(self, model)
    }
}
702
703impl<M, Ext, H> ModelListingClient for Client<Ext, H>
704where
705 Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
706 M: ModelLister<H, Client = Self> + Send + Sync + Clone + 'static,
707 H: Send + Sync + Clone,
708{
709 fn list_models(
710 &self,
711 ) -> impl std::future::Future<
712 Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
713 > + WasmCompatSend {
714 let lister = M::new(self.clone());
715 async move { lister.list_all().await }
716 }
717}
718
#[cfg(test)]
mod tests {
    use crate::providers::anthropic;

    /// The builder chain must type-check without turbofish or type
    /// annotations, even when a custom HTTP client is supplied before the key.
    #[test]
    fn ensures_client_builder_no_annotation() {
        let built = anthropic::Client::builder()
            .http_client(reqwest::Client::default())
            .api_key("Foo")
            .build();

        built.unwrap();
    }
}