1pub mod audio_generation;
5pub mod completion;
6pub mod embeddings;
7pub mod image_generation;
8pub mod model_listing;
9pub mod transcription;
10pub mod verify;
11
12use bytes::Bytes;
13pub use completion::CompletionClient;
14pub use embeddings::EmbeddingsClient;
15use http::{HeaderMap, HeaderName, HeaderValue};
16pub use model_listing::{ModelLister, ModelListingClient};
17use std::{env::VarError, fmt::Debug, marker::PhantomData, sync::Arc};
18use thiserror::Error;
19pub use verify::{VerifyClient, VerifyError};
20
21#[cfg(feature = "image")]
22use crate::image_generation::ImageGenerationModel;
23#[cfg(feature = "image")]
24use image_generation::ImageGenerationClient;
25
26#[cfg(feature = "audio")]
27use crate::audio_generation::*;
28#[cfg(feature = "audio")]
29use audio_generation::*;
30
31use crate::{
32 completion::CompletionModel,
33 embeddings::EmbeddingModel,
34 http_client::{
35 self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
36 },
37 markers::Missing,
38 prelude::TranscriptionClient,
39 transcription::TranscriptionModel,
40 wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46 #[error("reqwest error: {0}")]
47 HttpError(
48 #[from]
49 #[source]
50 reqwest::Error,
51 ),
52 #[error("invalid property: {0}")]
53 InvalidProperty(&'static str),
54}
55
56#[derive(Debug, Error)]
62#[non_exhaustive]
63pub enum ProviderClientError {
64 #[error("environment variable `{name}` is not set or is invalid")]
68 EnvironmentVariable {
69 name: &'static str,
71 #[source]
73 source: VarError,
74 },
75 #[error(transparent)]
77 Http(#[from] http_client::Error),
78 #[error("{0}")]
80 InvalidConfiguration(&'static str),
81}
82
83pub type ProviderClientResult<T> = std::result::Result<T, ProviderClientError>;
85
86pub fn required_env_var(name: &'static str) -> ProviderClientResult<String> {
91 std::env::var(name).map_err(|source| ProviderClientError::EnvironmentVariable { name, source })
92}
93
94pub fn optional_env_var(name: &'static str) -> ProviderClientResult<Option<String>> {
99 match std::env::var(name) {
100 Ok(value) => Ok(Some(value)),
101 Err(VarError::NotPresent) => Ok(None),
102 Err(source) => Err(ProviderClientError::EnvironmentVariable { name, source }),
103 }
104}
105
/// A client that can be constructed either from the process environment or from an
/// explicitly supplied value.
pub trait ProviderClient {
    /// Error returned by both constructors.
    type Error;

    /// Value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Constructs the client from the process environment.
    fn from_env() -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Constructs the client from the explicit `input`.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
124
125pub trait ApiKey: Sized {
128 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
129 None
130 }
131}
132
133pub struct BearerAuth(String);
135
136impl ApiKey for BearerAuth {
137 fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
138 Some(make_auth_header(self.0))
139 }
140}
141
142impl<S> From<S> for BearerAuth
143where
144 S: Into<String>,
145{
146 fn from(value: S) -> Self {
147 Self(value.into())
148 }
149}
150
151#[derive(Debug, Default, Clone, Copy)]
154pub struct Nothing;
155
156impl ApiKey for Nothing {}
157
158impl TryFrom<String> for Nothing {
159 type Error = &'static str;
160
161 fn try_from(_: String) -> Result<Self, Self::Error> {
162 Err(
163 "Tried to create a Nothing from a string - this should not happen, please file an issue",
164 )
165 }
166}
167
168#[derive(Clone)]
169pub struct Client<Ext = Nothing, H = reqwest::Client> {
170 base_url: Arc<str>,
171 headers: Arc<HeaderMap>,
172 http_client: H,
173 ext: Ext,
174}
175
176pub trait DebugExt: Debug {
177 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
178 std::iter::empty()
179 }
180}
181
182impl<Ext, H> std::fmt::Debug for Client<Ext, H>
183where
184 Ext: DebugExt,
185 H: std::fmt::Debug,
186{
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 let mut d = &mut f.debug_struct("Client");
189
190 d = d
191 .field("base_url", &self.base_url)
192 .field(
193 "headers",
194 &self
195 .headers
196 .iter()
197 .filter_map(|(k, v)| {
198 if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
199 None
200 } else {
201 Some((k, v))
202 }
203 })
204 .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
205 )
206 .field("http_client", &self.http_client);
207
208 self.ext
209 .fields()
210 .fold(d, |d, (name, field)| d.field(name, field))
211 .finish()
212 }
213}
214
215pub enum Transport {
216 Http,
217 Sse,
218 NdJson,
219}
220
221pub trait Provider: Sized {
225 type Builder: ProviderBuilder;
228
229 const VERIFY_PATH: &'static str;
230
231 fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
232 let base_url = if base_url.is_empty() {
234 base_url.to_string()
235 } else {
236 base_url.to_string() + "/"
237 };
238
239 base_url.to_string() + path.trim_start_matches('/')
240 }
241
242 fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
243 Ok(req)
244 }
245}
246
247pub struct Capable<M>(PhantomData<M>);
249
250pub trait Capability {
251 const CAPABLE: bool;
252}
253
254impl<M> Capability for Capable<M> {
255 const CAPABLE: bool = true;
256}
257
258impl Capability for Nothing {
259 const CAPABLE: bool = false;
260}
261
262pub trait Capabilities<H = reqwest::Client> {
264 type Completion: Capability;
265 type Embeddings: Capability;
266 type Transcription: Capability;
267 type ModelListing: Capability;
268 #[cfg(feature = "image")]
269 type ImageGeneration: Capability;
270 #[cfg(feature = "audio")]
271 type AudioGeneration: Capability;
272}
273
274pub trait ProviderBuilder: Sized + Default + Clone {
279 type Extension<H>: Provider
280 where
281 H: HttpClientExt;
282 type ApiKey: ApiKey;
283
284 const BASE_URL: &'static str;
285
286 fn build<H>(
288 builder: &ClientBuilder<Self, Self::ApiKey, H>,
289 ) -> http_client::Result<Self::Extension<H>>
290 where
291 H: HttpClientExt;
292
293 fn finish<H>(
296 &self,
297 builder: ClientBuilder<Self, Self::ApiKey, H>,
298 ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
299 Ok(builder)
300 }
301}
302
303impl<Ext> Client<Ext, reqwest::Client>
304where
305 Ext: Provider,
306 Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
307{
308 pub fn new(
309 api_key: impl Into<<Ext::Builder as ProviderBuilder>::ApiKey>,
310 ) -> http_client::Result<Self> {
311 Self::builder().api_key(api_key).build()
312 }
313}
314
315impl<Ext, H> Client<Ext, H> {
316 pub fn base_url(&self) -> &str {
317 &self.base_url
318 }
319
320 pub fn headers(&self) -> &HeaderMap {
321 &self.headers
322 }
323
324 pub fn ext(&self) -> &Ext {
325 &self.ext
326 }
327
328 pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
329 Client {
330 base_url: self.base_url,
331 headers: self.headers,
332 http_client: self.http_client,
333 ext: new_ext,
334 }
335 }
336}
337
338impl<Ext, H> HttpClientExt for Client<Ext, H>
339where
340 H: HttpClientExt + 'static,
341 Ext: WasmCompatSend + WasmCompatSync + 'static,
342{
343 fn send<T, U>(
344 &self,
345 mut req: Request<T>,
346 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
347 where
348 T: Into<Bytes> + WasmCompatSend,
349 U: From<Bytes>,
350 U: WasmCompatSend + 'static,
351 {
352 req.headers_mut().insert(
353 http::header::CONTENT_TYPE,
354 http::HeaderValue::from_static("application/json"),
355 );
356
357 self.http_client.send(req)
358 }
359
360 fn send_multipart<U>(
361 &self,
362 req: Request<MultipartForm>,
363 ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
364 where
365 U: From<Bytes>,
366 U: WasmCompatSend + 'static,
367 {
368 self.http_client.send_multipart(req)
369 }
370
371 fn send_streaming<T>(
372 &self,
373 mut req: Request<T>,
374 ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
375 where
376 T: Into<Bytes> + WasmCompatSend,
377 {
378 req.headers_mut().insert(
379 http::header::CONTENT_TYPE,
380 http::HeaderValue::from_static("application/json"),
381 );
382
383 self.http_client.send_streaming(req)
384 }
385}
386
387impl<Ext> Client<Ext, reqwest::Client>
388where
389 Ext: Provider,
390 Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
391{
392 pub fn builder() -> ClientBuilder<Ext::Builder, Missing, reqwest::Client> {
393 ClientBuilder {
394 api_key: Missing,
395 headers: Default::default(),
396 base_url: <Ext::Builder as ProviderBuilder>::BASE_URL.into(),
397 http_client: None,
398 ext: Default::default(),
399 }
400 }
401}
402
403impl<Ext, H> Client<Ext, H>
404where
405 Ext: Provider,
406{
407 pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
408 where
409 S: AsRef<str>,
410 {
411 let uri = self
412 .ext
413 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
414
415 let mut req = Request::post(uri);
416
417 if let Some(hs) = req.headers_mut() {
418 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
419 }
420
421 self.ext.with_custom(req)
422 }
423
424 pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
425 where
426 S: AsRef<str>,
427 {
428 let uri = self
429 .ext
430 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
431
432 let mut req = Request::post(uri);
433
434 if let Some(hs) = req.headers_mut() {
435 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
436 }
437
438 self.ext.with_custom(req)
439 }
440
441 pub fn get_sse<S>(&self, path: S) -> http_client::Result<Builder>
442 where
443 S: AsRef<str>,
444 {
445 let uri = self
446 .ext
447 .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
448
449 let mut req = Request::get(uri);
450
451 if let Some(hs) = req.headers_mut() {
452 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
453 }
454
455 self.ext.with_custom(req)
456 }
457
458 pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
459 where
460 S: AsRef<str>,
461 {
462 let uri = self
463 .ext
464 .build_uri(&self.base_url, path.as_ref(), Transport::Http);
465
466 let mut req = Request::get(uri);
467
468 if let Some(hs) = req.headers_mut() {
469 hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
470 }
471
472 self.ext.with_custom(req)
473 }
474}
475
476impl<Ext, H> VerifyClient for Client<Ext, H>
477where
478 H: HttpClientExt,
479 Ext: DebugExt + Provider + WasmCompatSync,
480{
481 async fn verify(&self) -> Result<(), VerifyError> {
482 use http::StatusCode;
483
484 let req = self
485 .get(Ext::VERIFY_PATH)?
486 .body(http_client::NoBody)
487 .map_err(http_client::Error::from)?;
488
489 let response = self.http_client.send(req).await?;
490
491 match response.status() {
492 StatusCode::OK => Ok(()),
493 StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
494 Err(VerifyError::InvalidAuthentication)
495 }
496 StatusCode::INTERNAL_SERVER_ERROR => {
497 let text = http_client::text(response).await?;
498 Err(VerifyError::ProviderError(text))
499 }
500 status if status.as_u16() == 529 => {
501 let text = http_client::text(response).await?;
502 Err(VerifyError::ProviderError(text))
503 }
504 _ => {
505 let status = response.status();
506
507 if status.is_success() {
508 Ok(())
509 } else {
510 let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
511 Err(VerifyError::HttpError(http_client::Error::Instance(
512 format!("Failed with '{status}': {text}").into(),
513 )))
514 }
515 }
516 }
517 }
518}
519
520#[derive(Clone)]
522pub struct ClientBuilder<Ext, ApiKey = Missing, H = reqwest::Client> {
523 base_url: String,
524 api_key: ApiKey,
525 headers: HeaderMap,
526 http_client: Option<H>,
527 ext: Ext,
528}
529
530impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, Missing, H>
531where
532 H: Default,
533 ExtBuilder: ProviderBuilder + Default,
534{
535 fn default() -> Self {
536 Self {
537 api_key: Missing,
538 headers: Default::default(),
539 base_url: ExtBuilder::BASE_URL.into(),
540 http_client: None,
541 ext: Default::default(),
542 }
543 }
544}
545
546impl<Ext, H> ClientBuilder<Ext, Missing, H> {
547 pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
550 ClientBuilder {
551 api_key: api_key.into(),
552 base_url: self.base_url,
553 headers: self.headers,
554 http_client: self.http_client,
555 ext: self.ext,
556 }
557 }
558}
559
560impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
561where
562 Ext: Clone,
563{
564 pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
566 where
567 F: FnOnce(Ext) -> NewExt,
568 {
569 let ClientBuilder {
570 base_url,
571 api_key,
572 headers,
573 http_client,
574 ext,
575 } = self;
576
577 let new_ext = f(ext.clone());
578
579 ClientBuilder {
580 base_url,
581 api_key,
582 headers,
583 http_client,
584 ext: new_ext,
585 }
586 }
587
588 pub fn base_url<S>(self, base_url: S) -> Self
590 where
591 S: AsRef<str>,
592 {
593 Self {
594 base_url: base_url.as_ref().to_string(),
595 ..self
596 }
597 }
598
599 pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
601 ClientBuilder {
602 http_client: Some(http_client),
603 base_url: self.base_url,
604 api_key: self.api_key,
605 headers: self.headers,
606 ext: self.ext,
607 }
608 }
609
610 pub fn http_headers(self, headers: HeaderMap) -> Self {
612 Self { headers, ..self }
613 }
614
615 pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
616 &mut self.headers
617 }
618
619 pub(crate) fn ext_mut(&mut self) -> &mut Ext {
620 &mut self.ext
621 }
622}
623
624impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
625 pub(crate) fn get_api_key(&self) -> &ApiKey {
626 &self.api_key
627 }
628}
629
630impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
631 pub fn ext(&self) -> &Ext {
632 &self.ext
633 }
634
635 pub fn get_base_url(&self) -> &str {
636 &self.base_url
637 }
638}
639
640impl<ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
641where
642 ExtBuilder: ProviderBuilder<ApiKey = Key>,
643 Key: ApiKey,
644 H: Default + HttpClientExt,
645{
646 pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Extension<H>, H>> {
647 let ext_builder = self.ext.clone();
648
649 self = ext_builder.finish(self)?;
650 let ext = ExtBuilder::build(&self)?;
651
652 let ClientBuilder {
653 http_client,
654 base_url,
655 mut headers,
656 api_key,
657 ..
658 } = self;
659
660 if let Some((k, v)) = api_key.into_header().transpose()?
661 && !headers.contains_key(&k)
662 {
663 headers.insert(k, v);
664 }
665
666 let http_client = http_client.unwrap_or_default();
667
668 Ok(Client {
669 http_client,
670 base_url: Arc::from(base_url.as_str()),
671 headers: Arc::new(headers),
672 ext,
673 })
674 }
675}
676
677impl<M, Ext, H> CompletionClient for Client<Ext, H>
678where
679 Ext: Capabilities<H, Completion = Capable<M>>,
680 M: CompletionModel<Client = Self>,
681{
682 type CompletionModel = M;
683
684 fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
685 M::make(self, model)
686 }
687}
688
689impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
690where
691 Ext: Capabilities<H, Embeddings = Capable<M>>,
692 M: EmbeddingModel<Client = Self>,
693{
694 type EmbeddingModel = M;
695
696 fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
697 M::make(self, model, None)
698 }
699
700 fn embedding_model_with_ndims(
701 &self,
702 model: impl Into<String>,
703 ndims: usize,
704 ) -> Self::EmbeddingModel {
705 M::make(self, model, Some(ndims))
706 }
707}
708
709impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
710where
711 Ext: Capabilities<H, Transcription = Capable<M>>,
712 M: TranscriptionModel<Client = Self> + WasmCompatSend,
713{
714 type TranscriptionModel = M;
715
716 fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
717 M::make(self, model)
718 }
719}
720
/// Any client whose extension declares `ImageGeneration = Capable<M>` hands out image models `M`.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    M: ImageGenerationModel<Client = Self>,
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
{
    type ImageGenerationModel = M;

    fn image_generation_model(&self, model: impl Into<String>) -> M {
        M::make(self, model)
    }
}

/// Any client whose extension declares `AudioGeneration = Capable<M>` hands out audio models `M`.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    M: AudioGenerationModel<Client = Self>,
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
{
    type AudioGenerationModel = M;

    fn audio_generation_model(&self, model: impl Into<String>) -> M {
        M::make(self, model)
    }
}
746
747impl<M, Ext, H> ModelListingClient for Client<Ext, H>
748where
749 Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
750 M: ModelLister<H, Client = Self> + WasmCompatSend + WasmCompatSync + Clone + 'static,
751 H: WasmCompatSend + WasmCompatSync + Clone,
752{
753 fn list_models(
754 &self,
755 ) -> impl std::future::Future<
756 Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
757 > + WasmCompatSend {
758 let lister = M::new(self.clone());
759 async move { lister.list_all().await }
760 }
761}
762
/// Compile-only checks (wasm32 with the `wasm` feature) that model listing works on top of
/// an HTTP client that is neither `Send` nor `Sync`.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
mod wasm_model_listing_compile_checks {
    use super::{ModelListingClient, Nothing};
    use crate::{
        http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response},
        providers::{anthropic, deepseek, mistral, ollama, openai, openrouter},
        wasm_compat::WasmCompatSend,
    };
    use bytes::Bytes;
    use std::{
        future::{self, Future},
        marker::PhantomData,
        rc::Rc,
    };

    /// Transport stub: the `Rc` marker makes it `!Send + !Sync`, and every call fails at once.
    #[derive(Clone, Default)]
    struct WasmOnlyHttpClient {
        _not_send_sync: PhantomData<Rc<()>>,
    }

    /// The ready-made failure returned by every `WasmOnlyHttpClient` method.
    fn failing_call<T>() -> future::Ready<http_client::Result<T>> {
        future::ready(Err(http_client::Error::StreamEnded))
    }

    impl HttpClientExt for WasmOnlyHttpClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes> + WasmCompatSend,
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            failing_call()
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            failing_call()
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes> + WasmCompatSend,
        {
            failing_call()
        }
    }

    /// Only type-checks when `C` implements `ModelListingClient`.
    fn assert_model_listing_client<C>(client: C)
    where
        C: ModelListingClient,
    {
        let _ = client.list_models();
    }

    /// Builds `$provider::Client` with key `$key` on `WasmOnlyHttpClient` and checks that the
    /// result is a `ModelListingClient`.
    macro_rules! check_provider {
        ($provider:ident, $key:expr) => {
            let _ = $provider::Client::builder()
                .api_key($key)
                .http_client(WasmOnlyHttpClient::default())
                .build()
                .map(assert_model_listing_client);
        };
    }

    fn assert_simple_model_listers_accept_wasm_only_http_clients() {
        check_provider!(openrouter, "dummy-key");
        check_provider!(openai, "dummy-key");
        check_provider!(mistral, "dummy-key");
        check_provider!(anthropic, "dummy-key");
        check_provider!(ollama, Nothing);
        check_provider!(deepseek, "dummy-key");
    }

    // Never called: this function exists only to force the checks above to be type-checked.
    #[allow(dead_code)]
    fn compile_assertions() {
        assert_simple_model_listers_accept_wasm_only_http_clients();
    }
}
866
#[cfg(test)]
mod tests {
    use crate::providers::anthropic;

    /// The builder must infer all type parameters without annotations, even when the HTTP
    /// client is supplied before the API key.
    #[test]
    fn ensures_client_builder_no_annotation() {
        let builder = anthropic::Client::builder().http_client(reqwest::Client::default());
        let built = builder.api_key("Foo").build();
        let _ = built.unwrap();
    }
}
882}