1pub mod audio_generation;
5pub mod completion;
6pub mod embeddings;
7pub mod image_generation;
8pub mod model_listing;
9pub mod transcription;
10pub mod verify;
11
12use bytes::Bytes;
13pub use completion::CompletionClient;
14pub use embeddings::EmbeddingsClient;
15use http::{HeaderMap, HeaderName, HeaderValue};
16pub use model_listing::{ModelLister, ModelListingClient};
17use std::{env::VarError, fmt::Debug, marker::PhantomData, sync::Arc};
18use thiserror::Error;
19pub use verify::{VerifyClient, VerifyError};
20
21#[cfg(feature = "image")]
22use crate::image_generation::ImageGenerationModel;
23#[cfg(feature = "image")]
24use image_generation::ImageGenerationClient;
25
26#[cfg(feature = "audio")]
27use crate::audio_generation::*;
28#[cfg(feature = "audio")]
29use audio_generation::*;
30
31use crate::{
32 completion::CompletionModel,
33 embeddings::EmbeddingModel,
34 http_client::{
35 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
36 },
37 markers::Missing,
38 prelude::TranscriptionClient,
39 transcription::TranscriptionModel,
40 wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46 #[error("reqwest error: {0}")]
48 HttpError(
49 #[from]
50 #[source]
51 reqwest::Error,
52 ),
53 #[error("invalid property: {0}")]
55 InvalidProperty(&'static str),
56}
57
58#[derive(Debug, Error)]
64#[non_exhaustive]
65pub enum ProviderClientError {
66 #[error("environment variable `{name}` is not set or is invalid")]
70 EnvironmentVariable {
71 name: &'static str,
73 #[source]
75 source: VarError,
76 },
77 #[error(transparent)]
79 Http(#[from] http_client::Error),
80 #[error("{0}")]
82 InvalidConfiguration(&'static str),
83}
84
85pub type ProviderClientResult<T> = std::result::Result<T, ProviderClientError>;
87
88pub fn required_env_var(name: &'static str) -> ProviderClientResult<String> {
93 std::env::var(name).map_err(|source| ProviderClientError::EnvironmentVariable { name, source })
94}
95
96pub fn optional_env_var(name: &'static str) -> ProviderClientResult<Option<String>> {
101 match std::env::var(name) {
102 Ok(value) => Ok(Some(value)),
103 Err(VarError::NotPresent) => Ok(None),
104 Err(source) => Err(ProviderClientError::EnvironmentVariable { name, source }),
105 }
106}
107
/// Fallible constructors shared by provider clients.
///
/// Both constructors are restricted to `Self: Sized`.
pub trait ProviderClient {
    /// Value accepted by [`ProviderClient::from_val`].
    type Input;
    /// Error returned when construction fails.
    type Error;

    /// Builds the client without any explicit input.
    ///
    /// NOTE(review): going by the name, implementations are expected to read
    /// their configuration from environment variables; the trait itself does
    /// not enforce this.
    fn from_env() -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Builds the client from an explicit [`ProviderClient::Input`] value.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
126
127pub trait ApiKey: Sized {
132 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
135 None
136 }
137}
138
139pub struct BearerAuth(String);
141
142impl ApiKey for BearerAuth {
143 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
144 Some(make_auth_header(self.0))
145 }
146}
147
148impl<S> From<S> for BearerAuth
149where
150 S: Into<String>,
151{
152 fn from(value: S) -> Self {
153 Self(value.into())
154 }
155}
156
157#[derive(Debug, Default, Clone, Copy)]
160pub struct Nothing;
161
162impl ApiKey for Nothing {}
163
164impl TryFrom<String> for Nothing {
165 type Error = &'static str;
166
167 fn try_from(_: String) -> Result<Self, Self::Error> {
168 Err(
169 "Tried to create a Nothing from a string - this should not happen, please file an issue",
170 )
171 }
172}
173
174#[derive(Clone)]
175pub struct Client<Ext = Nothing, H = reqwest::Client> {
181 base_url: Arc<str>,
182 headers: Arc<HeaderMap>,
183 http_client: H,
184 ext: Ext,
185}
186
/// Lets a provider extension contribute extra fields to the `Debug` output of
/// the `Client` that owns it.
pub trait DebugExt: Debug {
    /// Extra `(name, value)` pairs appended after the client's own fields.
    ///
    /// The default contributes nothing.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
194
195impl<Ext, H> std::fmt::Debug for Client<Ext, H>
196where
197 Ext: DebugExt,
198 H: std::fmt::Debug,
199{
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 let mut d = &mut f.debug_struct("Client");
202
203 d = d
204 .field("base_url", &self.base_url)
205 .field(
206 "headers",
207 &self
208 .headers
209 .iter()
210 .filter_map(|(k, v)| {
211 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
212 None
213 } else {
214 Some((k, v))
215 }
216 })
217 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
218 )
219 .field("http_client", &self.http_client);
220
221 self.ext
222 .fields()
223 .fold(d, |d, (name, field)| d.field(name, field))
224 .finish()
225 }
226}
227
/// The kind of exchange a request URI is being built for.
///
/// It is passed to `Provider::build_uri` so a provider can vary the URI per
/// transport; the default `build_uri` ignores it.
pub enum Transport {
    /// Passed by `Client::post` and `Client::get`.
    Http,
    /// Passed by `Client::post_sse` and `Client::get_sse`.
    Sse,
    /// NOTE(review): presumably newline-delimited JSON streaming; nothing in
    /// this file constructs this variant.
    NdJson,
}
236
237pub trait Provider: Sized {
241 type Builder: ProviderBuilder;
244
245 const VERIFY_PATH: &'static str;
247
248 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
250 let base_url = if base_url.is_empty() {
252 base_url.to_string()
253 } else {
254 base_url.to_string() + "/"
255 };
256
257 base_url.to_string() + path.trim_start_matches('/')
258 }
259
260 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
262 Ok(req)
263 }
264}
265
/// Type-level marker stating that a capability is available and is served by
/// the model type `M`. No value of `M` is stored.
pub struct Capable<M>(PhantomData<M>);

/// Type-level flag telling whether a capability slot is filled.
pub trait Capability {
    /// `true` when the capability is supported.
    const CAPABLE: bool;
}

/// Every `Capable<M>` reports the capability as supported.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
278
279impl Capability for Nothing {
280 const CAPABLE: bool = false;
281}
282
283pub trait Capabilities<H = reqwest::Client> {
285 type Completion: Capability;
287 type Embeddings: Capability;
289 type Transcription: Capability;
291 type ModelListing: Capability;
293 #[cfg(feature = "image")]
294 type ImageGeneration: Capability;
296 #[cfg(feature = "audio")]
297 type AudioGeneration: Capability;
299}
300
301pub trait ProviderBuilder: Sized + Default + Clone {
306 type Extension<H>: Provider
308 where
309 H: HttpClientExt;
310 type ApiKey: ApiKey;
312
313 const BASE_URL: &'static str;
315
316 fn build<H>(
318 builder: &ClientBuilder<Self, Self::ApiKey, H>,
319 ) -> http_client::Result<Self::Extension<H>>
320 where
321 H: HttpClientExt;
322
323 fn finish<H>(
326 &self,
327 builder: ClientBuilder<Self, Self::ApiKey, H>,
328 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
329 Ok(builder)
330 }
331}
332
333impl<Ext> Client<Ext, reqwest::Client>
337where
338 Ext: Provider,
339 Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
340{
341 pub fn new(
343 api_key: impl Into<<Ext::Builder as ProviderBuilder>::ApiKey>,
344 ) -> http_client::Result<Self> {
345 Self::builder().api_key(api_key).build()
346 }
347}
348
349impl<Ext, H> Client<Ext, H> {
350 pub fn base_url(&self) -> &str {
352 &self.base_url
353 }
354
355 pub fn headers(&self) -> &HeaderMap {
357 &self.headers
358 }
359
360 pub fn ext(&self) -> &Ext {
362 &self.ext
363 }
364
365 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
367 Client {
368 base_url: self.base_url,
369 headers: self.headers,
370 http_client: self.http_client,
371 ext: new_ext,
372 }
373 }
374}
375
376impl<Ext, H> HttpClientExt for Client<Ext, H>
377where
378 H: HttpClientExt + 'static,
379 Ext: WasmCompatSend + WasmCompatSync + 'static,
380{
381 fn send<T, U>(
382 &self,
383 mut req: Request<T>,
384 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
385 where
386 T: Into<Bytes> + WasmCompatSend,
387 U: From<Bytes>,
388 U: WasmCompatSend + 'static,
389 {
390 req.headers_mut().insert(
391 http::header::CONTENT_TYPE,
392 http::HeaderValue::from_static("application/json"),
393 );
394
395 self.http_client.send(req)
396 }
397
398 fn send_multipart<U>(
399 &self,
400 req: Request<MultipartForm>,
401 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
402 where
403 U: From<Bytes>,
404 U: WasmCompatSend + 'static,
405 {
406 self.http_client.send_multipart(req)
407 }
408
409 fn send_streaming<T>(
410 &self,
411 mut req: Request<T>,
412 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
413 where
414 T: Into<Bytes> + WasmCompatSend,
415 {
416 req.headers_mut().insert(
417 http::header::CONTENT_TYPE,
418 http::HeaderValue::from_static("application/json"),
419 );
420
421 self.http_client.send_streaming(req)
422 }
423}
424
425impl<Ext> Client<Ext, reqwest::Client>
431where
432 Ext: Provider,
433 Ext::Builder: ProviderBuilder + Default,
434{
435 pub fn builder() -> ClientBuilder<Ext::Builder, Missing, Missing> {
437 ClientBuilder::default()
438 }
439}
440
441impl<Ext, H> Client<Ext, H>
442where
443 Ext: Provider,
444{
445 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
447 where
448 S: AsRef<str>,
449 {
450 let uri = self
451 .ext
452 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
453
454 let mut req = Request::post(uri);
455
456 if let Some(hs) = req.headers_mut() {
457 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
458 }
459
460 self.ext.with_custom(req)
461 }
462
463 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
465 where
466 S: AsRef<str>,
467 {
468 let uri = self
469 .ext
470 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
471
472 let mut req = Request::post(uri);
473
474 if let Some(hs) = req.headers_mut() {
475 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
476 }
477
478 self.ext.with_custom(req)
479 }
480
481 pub fn get_sse<S>(&self, path: S) -> http_client::Result<Builder>
483 where
484 S: AsRef<str>,
485 {
486 let uri = self
487 .ext
488 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
489
490 let mut req = Request::get(uri);
491
492 if let Some(hs) = req.headers_mut() {
493 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
494 }
495
496 self.ext.with_custom(req)
497 }
498
499 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
501 where
502 S: AsRef<str>,
503 {
504 let uri = self
505 .ext
506 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
507
508 let mut req = Request::get(uri);
509
510 if let Some(hs) = req.headers_mut() {
511 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
512 }
513
514 self.ext.with_custom(req)
515 }
516}
517
518impl<Ext, H> VerifyClient for Client<Ext, H>
519where
520 H: HttpClientExt,
521 Ext: DebugExt + Provider + WasmCompatSync,
522{
523 async fn verify(&self) -> Result<(), VerifyError> {
524 use http::StatusCode;
525
526 let req = self
527 .get(Ext::VERIFY_PATH)?
528 .body(http_client::NoBody)
529 .map_err(http_client::Error::from)?;
530
531 let response = self.http_client.send(req).await?;
532
533 match response.status() {
534 StatusCode::OK => Ok(()),
535 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
536 Err(VerifyError::InvalidAuthentication)
537 }
538 StatusCode::INTERNAL_SERVER_ERROR => {
539 let text = http_client::text(response).await?;
540 Err(VerifyError::ProviderError(text))
541 }
542 status if status.as_u16() == 529 => {
543 let text = http_client::text(response).await?;
544 Err(VerifyError::ProviderError(text))
545 }
546 _ => {
547 let status = response.status();
548
549 if status.is_success() {
550 Ok(())
551 } else {
552 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
553 Err(VerifyError::HttpError(http_client::Error::Instance(
554 format!("Failed with '{status}': {text}").into(),
555 )))
556 }
557 }
558 }
559 }
560}
561
562#[derive(Clone)]
578pub struct ClientBuilder<Ext, ApiKey = Missing, H = Missing> {
579 base_url: String,
580 api_key: ApiKey,
581 headers: HeaderMap,
582 http_client: H,
583 ext: Ext,
584}
585
586impl<ExtBuilder> Default for ClientBuilder<ExtBuilder, Missing, Missing>
587where
588 ExtBuilder: ProviderBuilder + Default,
589{
590 fn default() -> Self {
591 Self {
592 api_key: Missing,
593 headers: Default::default(),
594 base_url: ExtBuilder::BASE_URL.into(),
595 http_client: Missing,
596 ext: Default::default(),
597 }
598 }
599}
600
601impl<Ext, H> ClientBuilder<Ext, Missing, H> {
602 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
605 ClientBuilder {
606 api_key: api_key.into(),
607 base_url: self.base_url,
608 headers: self.headers,
609 http_client: self.http_client,
610 ext: self.ext,
611 }
612 }
613}
614
615impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
616where
617 Ext: Clone,
618{
619 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
621 where
622 F: FnOnce(Ext) -> NewExt,
623 {
624 let ClientBuilder {
625 base_url,
626 api_key,
627 headers,
628 http_client,
629 ext,
630 } = self;
631
632 let new_ext = f(ext.clone());
633
634 ClientBuilder {
635 base_url,
636 api_key,
637 headers,
638 http_client,
639 ext: new_ext,
640 }
641 }
642
643 pub fn base_url<S>(self, base_url: S) -> Self
645 where
646 S: AsRef<str>,
647 {
648 Self {
649 base_url: base_url.as_ref().to_string(),
650 ..self
651 }
652 }
653
654 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
659 ClientBuilder {
660 http_client,
661 base_url: self.base_url,
662 api_key: self.api_key,
663 headers: self.headers,
664 ext: self.ext,
665 }
666 }
667
668 pub fn http_headers(self, headers: HeaderMap) -> Self {
670 Self { headers, ..self }
671 }
672
673 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
674 &mut self.headers
675 }
676
677 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
678 &mut self.ext
679 }
680}
681
682impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
683 pub(crate) fn get_api_key(&self) -> &ApiKey {
684 &self.api_key
685 }
686}
687
688impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
689 pub fn ext(&self) -> &Ext {
691 &self.ext
692 }
693
694 pub fn get_base_url(&self) -> &str {
696 &self.base_url
697 }
698}
699
700impl<ExtBuilder, Key> ClientBuilder<ExtBuilder, Key, Missing>
706where
707 ExtBuilder: ProviderBuilder<ApiKey = Key>,
708 Key: ApiKey,
709{
710 pub fn build(
712 self,
713 ) -> http_client::Result<Client<ExtBuilder::Extension<reqwest::Client>, reqwest::Client>> {
714 self.http_client(reqwest::Client::default()).build()
715 }
716}
717
718impl<ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
721where
722 ExtBuilder: ProviderBuilder<ApiKey = Key>,
723 Key: ApiKey,
724 H: HttpClientExt,
725{
726 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Extension<H>, H>> {
728 let ext_builder = self.ext.clone();
729
730 self = ext_builder.finish(self)?;
731 let ext = ExtBuilder::build(&self)?;
732
733 let ClientBuilder {
734 http_client,
735 base_url,
736 mut headers,
737 api_key,
738 ..
739 } = self;
740
741 if let Some((k, v)) = api_key.into_header().transpose()?
742 && !headers.contains_key(&k)
743 {
744 headers.insert(k, v);
745 }
746
747 Ok(Client {
748 http_client,
749 base_url: Arc::from(base_url.as_str()),
750 headers: Arc::new(headers),
751 ext,
752 })
753 }
754}
755
756impl<M, Ext, H> CompletionClient for Client<Ext, H>
757where
758 Ext: Capabilities<H, Completion = Capable<M>>,
759 M: CompletionModel<Client = Self>,
760{
761 type CompletionModel = M;
762
763 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
764 M::make(self, model)
765 }
766}
767
768impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
769where
770 Ext: Capabilities<H, Embeddings = Capable<M>>,
771 M: EmbeddingModel<Client = Self>,
772{
773 type EmbeddingModel = M;
774
775 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
776 M::make(self, model, None)
777 }
778
779 fn embedding_model_with_ndims(
780 &self,
781 model: impl Into<String>,
782 ndims: usize,
783 ) -> Self::EmbeddingModel {
784 M::make(self, model, Some(ndims))
785 }
786}
787
788impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
789where
790 Ext: Capabilities<H, Transcription = Capable<M>>,
791 M: TranscriptionModel<Client = Self> + WasmCompatSend,
792{
793 type TranscriptionModel = M;
794
795 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
796 M::make(self, model)
797 }
798}
799
/// Available (with the `image` feature) when the provider declares
/// `ImageGeneration = Capable<M>`; `M` is the image model type handed out.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates an image generation model handle for the model named `model`.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        let model_name: String = model.into();
        M::make(self, model_name)
    }
}
812
/// Available (with the `audio` feature) when the provider declares
/// `AudioGeneration = Capable<M>`; `M` is the audio model type handed out.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates an audio generation model handle for the model named `model`.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        let model_name: String = model.into();
        M::make(self, model_name)
    }
}
825
826impl<M, Ext, H> ModelListingClient for Client<Ext, H>
827where
828 Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
829 M: ModelLister<H, Client = Self> + WasmCompatSend + WasmCompatSync + Clone + 'static,
830 H: WasmCompatSend + WasmCompatSync + Clone,
831{
832 fn list_models(
833 &self,
834 ) -> impl std::future::Future<
835 Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
836 > + WasmCompatSend {
837 let lister = M::new(self.clone());
838 async move { lister.list_all().await }
839 }
840}
841
/// Compile-time checks (wasm32 with the `wasm` feature only) that the listed
/// providers' clients implement `ModelListingClient` even when the HTTP
/// backend is neither `Send` nor `Sync`. Nothing here is meant to be run.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
mod wasm_model_listing_compile_checks {
    use super::{ModelListingClient, Nothing};
    use crate::{
        http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response},
        providers::{anthropic, deepseek, mistral, ollama, openai, openrouter},
        wasm_compat::WasmCompatSend,
    };
    use bytes::Bytes;
    use std::{
        future::{self, Future},
        marker::PhantomData,
        rc::Rc,
    };

    /// Stub backend; the `Rc` marker makes it `!Send` and `!Sync`.
    #[derive(Clone, Default)]
    struct WasmOnlyHttpClient {
        _not_send_sync: PhantomData<Rc<()>>,
    }

    /// Every method resolves immediately to `Error::StreamEnded`; only the
    /// signatures matter for these checks.
    impl HttpClientExt for WasmOnlyHttpClient {
        fn send<T, U>(
            &self,
            _request: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes> + WasmCompatSend,
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::StreamEnded))
        }

        fn send_multipart<U>(
            &self,
            _request: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::StreamEnded))
        }

        fn send_streaming<T>(
            &self,
            _request: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes> + WasmCompatSend,
        {
            future::ready(Err(http_client::Error::StreamEnded))
        }
    }

    /// Compiles only if `C` implements `ModelListingClient`.
    fn assert_model_listing_client<C>(client: C)
    where
        C: ModelListingClient,
    {
        let _ = client.list_models();
    }

    /// Builds `$provider`'s client on top of the stub backend and feeds it to
    /// `assert_model_listing_client`.
    macro_rules! check_provider {
        ($provider:ident, $key:expr) => {
            let _ = $provider::Client::builder()
                .api_key($key)
                .http_client(WasmOnlyHttpClient::default())
                .build()
                .map(assert_model_listing_client);
        };
    }

    fn assert_simple_model_listers_accept_wasm_only_http_clients() {
        check_provider!(openrouter, "dummy-key");
        check_provider!(openai, "dummy-key");
        check_provider!(mistral, "dummy-key");
        check_provider!(anthropic, "dummy-key");
        // Ollama is built with the `Nothing` key instead of a string key.
        check_provider!(ollama, Nothing);
        check_provider!(deepseek, "dummy-key");
    }

    // Never called: this function exists only so the checks above are
    // referenced and type-checked without dead-code warnings.
    #[allow(dead_code)]
    fn compile_assertions() {
        assert_simple_model_listers_accept_wasm_only_http_clients();
    }
}
945
#[cfg(test)]
mod tests {
    use crate::providers::anthropic;

    /// The builder chain must type-check without explicit type annotations
    /// when the HTTP backend is supplied before the API key.
    #[test]
    fn ensures_client_builder_no_annotation() {
        let _ = anthropic::Client::builder()
            .http_client(reqwest::Client::default())
            .api_key("Foo")
            .build()
            .unwrap();
    }
}
961}