Skip to main content

rig_core/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3
4pub mod audio_generation;
5pub mod completion;
6pub mod embeddings;
7pub mod image_generation;
8pub mod model_listing;
9pub mod rerank;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17pub use model_listing::{ModelLister, ModelListingClient};
18pub use rerank::RerankingClient;
19use std::{env::VarError, fmt::Debug, marker::PhantomData, sync::Arc};
20use thiserror::Error;
21pub use verify::{VerifyClient, VerifyError};
22
23#[cfg(feature = "image")]
24use crate::image_generation::ImageGenerationModel;
25#[cfg(feature = "image")]
26use image_generation::ImageGenerationClient;
27
28#[cfg(feature = "audio")]
29use crate::audio_generation::*;
30#[cfg(feature = "audio")]
31use audio_generation::*;
32
33use crate::{
34    completion::CompletionModel,
35    embeddings::EmbeddingModel,
36    http_client::{
37        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
38    },
39    markers::Missing,
40    prelude::TranscriptionClient,
41    rerank::RerankModel,
42    transcription::TranscriptionModel,
43    wasm_compat::{WasmCompatSend, WasmCompatSync},
44};
45
46#[derive(Debug, Error)]
47#[non_exhaustive]
48pub enum ClientBuilderError {
49    /// The underlying HTTP backend failed during builder construction.
50    #[error("reqwest error: {0}")]
51    HttpError(
52        #[from]
53        #[source]
54        reqwest::Error,
55    ),
56    /// A provider-specific builder property was invalid.
57    #[error("invalid property: {0}")]
58    InvalidProperty(&'static str),
59}
60
61/// Errors returned while constructing provider clients from environment variables or explicit input.
62///
63/// Provider-specific client constructors use this error for configuration problems that can be
64/// detected before any model request is sent, such as missing API keys, invalid environment
65/// values, or invalid builder configuration.
66#[derive(Debug, Error)]
67#[non_exhaustive]
68pub enum ProviderClientError {
69    /// A required or optional environment variable could not be read as valid Unicode.
70    ///
71    /// For required variables, this variant is also returned when the variable is not present.
72    #[error("environment variable `{name}` is not set or is invalid")]
73    EnvironmentVariable {
74        /// The environment variable name.
75        name: &'static str,
76        /// The underlying environment lookup error.
77        #[source]
78        source: VarError,
79    },
80    /// The underlying provider client builder failed while constructing HTTP configuration.
81    #[error(transparent)]
82    Http(#[from] http_client::Error),
83    /// The provider received an unsupported or incomplete configuration.
84    #[error("{0}")]
85    InvalidConfiguration(&'static str),
86}
87
88/// Result type returned by provider client construction helpers.
89pub type ProviderClientResult<T> = std::result::Result<T, ProviderClientError>;
90
91/// Read a required environment variable for provider client construction.
92///
93/// Returns [`ProviderClientError::EnvironmentVariable`] when the variable is missing or contains
94/// invalid Unicode.
95pub fn required_env_var(name: &'static str) -> ProviderClientResult<String> {
96    std::env::var(name).map_err(|source| ProviderClientError::EnvironmentVariable { name, source })
97}
98
99/// Read an optional environment variable for provider client construction.
100///
101/// Missing variables return `Ok(None)`. Variables containing invalid Unicode return
102/// [`ProviderClientError::EnvironmentVariable`].
103pub fn optional_env_var(name: &'static str) -> ProviderClientResult<Option<String>> {
104    match std::env::var(name) {
105        Ok(value) => Ok(Some(value)),
106        Err(VarError::NotPresent) => Ok(None),
107        Err(source) => Err(ProviderClientError::EnvironmentVariable { name, source }),
108    }
109}
110
/// Common constructor interface for provider clients.
///
/// A client can be built either from the process environment or from an explicit,
/// provider-defined [`ProviderClient::Input`] value.
pub trait ProviderClient {
    /// Value consumed by [`ProviderClient::from_val`].
    type Input;
    /// Failure type shared by both constructors.
    type Error;

    /// Construct a client from configuration found in the process environment.
    fn from_env() -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Construct a client from the explicit `input` value.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
129
130/// A trait for API key inputs accepted by [`ClientBuilder::api_key`].
131///
132/// Returning `Some` inserts a header into the generic [`Client`]. Returning `None`
133/// lets the provider extension handle credentials itself.
134pub trait ApiKey: Sized {
135    /// Convert this key into a default request header, if the generic client
136    /// should own that authentication header.
137    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
138        None
139    }
140}
141
142/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
143pub struct BearerAuth(String);
144
145impl ApiKey for BearerAuth {
146    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
147        Some(make_auth_header(self.0))
148    }
149}
150
151impl<S> From<S> for BearerAuth
152where
153    S: Into<String>,
154{
155    fn from(value: S) -> Self {
156        Self(value.into())
157    }
158}
159
160/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
161/// the lack of a capability or field (an API key, for instance)
162#[derive(Debug, Default, Clone, Copy)]
163pub struct Nothing;
164
165impl ApiKey for Nothing {}
166
167impl TryFrom<String> for Nothing {
168    type Error = &'static str;
169
170    fn try_from(_: String) -> Result<Self, Self::Error> {
171        Err(
172            "Tried to create a Nothing from a string - this should not happen, please file an issue",
173        )
174    }
175}
176
177#[derive(Clone)]
178/// Generic provider client shared by Rig provider integrations.
179///
180/// `Ext` stores provider-specific behavior such as URL construction, request
181/// customization, and capabilities. `H` is the HTTP backend and defaults to
182/// `reqwest::Client`.
183pub struct Client<Ext = Nothing, H = reqwest::Client> {
184    base_url: Arc<str>,
185    headers: Arc<HeaderMap>,
186    http_client: H,
187    ext: Ext,
188}
189
/// Hook that lets a provider extension contribute extra fields to [`Client`]'s `Debug` output.
///
/// Implementors should only expose values that are safe to print. The default adds nothing.
pub trait DebugExt: Debug {
    /// Extra `(name, value)` pairs appended to the `Client` debug struct.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
197
198impl<Ext, H> std::fmt::Debug for Client<Ext, H>
199where
200    Ext: DebugExt,
201    H: std::fmt::Debug,
202{
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        let mut d = &mut f.debug_struct("Client");
205
206        d = d
207            .field("base_url", &self.base_url)
208            .field(
209                "headers",
210                &self
211                    .headers
212                    .iter()
213                    .filter_map(|(k, v)| {
214                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
215                            None
216                        } else {
217                            Some((k, v))
218                        }
219                    })
220                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
221            )
222            .field("http_client", &self.http_client);
223
224        self.ext
225            .fields()
226            .fold(d, |d, (name, field)| d.field(name, field))
227            .finish()
228    }
229}
230
/// Wire format a request is sent over.
///
/// It is passed to [`Provider::build_uri`] so a provider may choose a different endpoint per
/// transport; the default implementation ignores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// Regular request/response HTTP transport.
    Http,
    /// Server-sent events streaming transport.
    Sse,
    /// Newline-delimited JSON streaming transport.
    NdJson,
}
239
240/// An API provider extension, this abstracts over extensions which may be used in conjunction with
241/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
242/// auth, instantiating models
243pub trait Provider: Sized {
244    /// The builder type that constructs this provider extension.
245    /// This associates extensions with their builders for type inference.
246    type Builder: ProviderBuilder;
247
248    /// Provider endpoint used by [`VerifyClient`] to validate credentials.
249    const VERIFY_PATH: &'static str;
250
251    /// Build a complete request URI for the given base URL, provider path, and transport.
252    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
253        // Some providers (like Azure) have a blank base URL to allow users to input their own endpoints.
254        let base_url = if base_url.is_empty() || base_url.ends_with('/') {
255            base_url.to_string()
256        } else {
257            // Only add a slash to the base_url when it doesn't already end with a slash
258            base_url.to_string() + "/"
259        };
260
261        base_url + path.trim_start_matches('/')
262    }
263
264    /// Apply provider-specific request customization before sending.
265    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
266        Ok(req)
267    }
268}
269
270/// A wrapper type providing runtime checks on a provider's capabilities via the [Capability] trait
271pub struct Capable<M>(PhantomData<M>);
272
273/// Type-level marker for whether a provider supports a capability.
274pub trait Capability {
275    /// Whether this marker represents a supported capability.
276    const CAPABLE: bool;
277}
278
279impl<M> Capability for Capable<M> {
280    const CAPABLE: bool = true;
281}
282
283impl Capability for Nothing {
284    const CAPABLE: bool = false;
285}
286
287/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
288pub trait Capabilities<H = reqwest::Client> {
289    /// Completion model capability marker.
290    type Completion: Capability;
291    /// Embedding model capability marker.
292    type Embeddings: Capability;
293    /// Rerank model capability marker.
294    type Rerank: Capability;
295    /// Audio transcription model capability marker.
296    type Transcription: Capability;
297    /// Model listing capability marker.
298    type ModelListing: Capability;
299    #[cfg(feature = "image")]
300    /// Image generation model capability marker.
301    type ImageGeneration: Capability;
302    #[cfg(feature = "audio")]
303    /// Audio generation model capability marker.
304    type AudioGeneration: Capability;
305}
306
307/// An API provider extension *builder*, this abstracts over provider-specific builders which are
308/// able to configure and produce a given provider's extension type
309///
310/// See [Provider]
311pub trait ProviderBuilder: Sized + Default + Clone {
312    /// Provider extension type built for a concrete HTTP backend.
313    type Extension<H>: Provider
314    where
315        H: HttpClientExt;
316    /// API key input type accepted by the provider's client builder.
317    type ApiKey: ApiKey;
318
319    /// Default base URL for the provider.
320    const BASE_URL: &'static str;
321
322    /// Build the provider extension from the client builder configuration.
323    fn build<H>(
324        builder: &ClientBuilder<Self, Self::ApiKey, H>,
325    ) -> http_client::Result<Self::Extension<H>>
326    where
327        H: HttpClientExt;
328
329    /// This method can be used to customize the fields of `builder` before it is used to create
330    /// a client. For example, adding default headers
331    fn finish<H>(
332        &self,
333        builder: ClientBuilder<Self, Self::ApiKey, H>,
334    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
335        Ok(builder)
336    }
337}
338
339/// `new` is pinned to `H = reqwest::Client` so the call site infers without an explicit `H`
340/// annotation. Callers who want a different backend should go through [`Client::builder`] and
341/// chain [`ClientBuilder::http_client`] before [`ClientBuilder::build`].
342impl<Ext> Client<Ext, reqwest::Client>
343where
344    Ext: Provider,
345    Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
346{
347    /// Construct a provider client using the default `reqwest::Client` backend.
348    pub fn new(
349        api_key: impl Into<<Ext::Builder as ProviderBuilder>::ApiKey>,
350    ) -> http_client::Result<Self> {
351        Self::builder().api_key(api_key).build()
352    }
353}
354
355impl<Ext, H> Client<Ext, H> {
356    /// Returns the configured provider base URL.
357    pub fn base_url(&self) -> &str {
358        &self.base_url
359    }
360
361    /// Returns default headers applied to outgoing provider requests.
362    pub fn headers(&self) -> &HeaderMap {
363        &self.headers
364    }
365
366    /// Returns the provider extension.
367    pub fn ext(&self) -> &Ext {
368        &self.ext
369    }
370
371    /// Reuse this client's base URL, headers, and HTTP backend with a different extension.
372    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
373        Client {
374            base_url: self.base_url,
375            headers: self.headers,
376            http_client: self.http_client,
377            ext: new_ext,
378        }
379    }
380}
381
382impl<Ext, H> HttpClientExt for Client<Ext, H>
383where
384    H: HttpClientExt + 'static,
385    Ext: WasmCompatSend + WasmCompatSync + 'static,
386{
387    fn send<T, U>(
388        &self,
389        mut req: Request<T>,
390    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
391    where
392        T: Into<Bytes> + WasmCompatSend,
393        U: From<Bytes>,
394        U: WasmCompatSend + 'static,
395    {
396        req.headers_mut().insert(
397            http::header::CONTENT_TYPE,
398            http::HeaderValue::from_static("application/json"),
399        );
400
401        self.http_client.send(req)
402    }
403
404    fn send_multipart<U>(
405        &self,
406        req: Request<MultipartForm>,
407    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
408    where
409        U: From<Bytes>,
410        U: WasmCompatSend + 'static,
411    {
412        self.http_client.send_multipart(req)
413    }
414
415    fn send_streaming<T>(
416        &self,
417        mut req: Request<T>,
418    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
419    where
420        T: Into<Bytes> + WasmCompatSend,
421    {
422        req.headers_mut().insert(
423            http::header::CONTENT_TYPE,
424            http::HeaderValue::from_static("application/json"),
425        );
426
427        self.http_client.send_streaming(req)
428    }
429}
430
431/// `builder()` is anchored on `Client<Ext, reqwest::Client>` purely as an inference hook so that
432/// `provider::Client::builder()` resolves without a `H` annotation. The returned builder itself
433/// has `H = Missing`, accurately reflecting that no backend has been chosen yet; the eventual
434/// `Client` produced by `build()` may end up with any HTTP backend depending on whether
435/// [`ClientBuilder::http_client`] was called.
436impl<Ext> Client<Ext, reqwest::Client>
437where
438    Ext: Provider,
439    Ext::Builder: ProviderBuilder + Default,
440{
441    /// Start constructing a provider client.
442    pub fn builder() -> ClientBuilder<Ext::Builder, Missing, Missing> {
443        ClientBuilder::default()
444    }
445}
446
447impl<Ext, H> Client<Ext, H>
448where
449    Ext: Provider,
450{
451    /// Build a provider-customized POST request for a regular HTTP endpoint.
452    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
453    where
454        S: AsRef<str>,
455    {
456        let uri = self
457            .ext
458            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
459
460        let mut req = Request::post(uri);
461
462        if let Some(hs) = req.headers_mut() {
463            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
464        }
465
466        self.ext.with_custom(req)
467    }
468
469    /// Build a provider-customized POST request for an SSE endpoint.
470    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
471    where
472        S: AsRef<str>,
473    {
474        let uri = self
475            .ext
476            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
477
478        let mut req = Request::post(uri);
479
480        if let Some(hs) = req.headers_mut() {
481            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
482        }
483
484        self.ext.with_custom(req)
485    }
486
487    /// Build a provider-customized GET request for an SSE endpoint.
488    pub fn get_sse<S>(&self, path: S) -> http_client::Result<Builder>
489    where
490        S: AsRef<str>,
491    {
492        let uri = self
493            .ext
494            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
495
496        let mut req = Request::get(uri);
497
498        if let Some(hs) = req.headers_mut() {
499            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
500        }
501
502        self.ext.with_custom(req)
503    }
504
505    /// Build a provider-customized GET request for a regular HTTP endpoint.
506    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
507    where
508        S: AsRef<str>,
509    {
510        let uri = self
511            .ext
512            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
513
514        let mut req = Request::get(uri);
515
516        if let Some(hs) = req.headers_mut() {
517            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
518        }
519
520        self.ext.with_custom(req)
521    }
522}
523
524impl<Ext, H> VerifyClient for Client<Ext, H>
525where
526    H: HttpClientExt,
527    Ext: DebugExt + Provider + WasmCompatSync,
528{
529    async fn verify(&self) -> Result<(), VerifyError> {
530        use http::StatusCode;
531
532        let req = self
533            .get(Ext::VERIFY_PATH)?
534            .body(http_client::NoBody)
535            .map_err(http_client::Error::from)?;
536
537        let response = self.http_client.send(req).await?;
538
539        match response.status() {
540            StatusCode::OK => Ok(()),
541            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
542                Err(VerifyError::InvalidAuthentication)
543            }
544            StatusCode::INTERNAL_SERVER_ERROR => {
545                let text = http_client::text(response).await?;
546                Err(VerifyError::ProviderError(text))
547            }
548            status if status.as_u16() == 529 => {
549                let text = http_client::text(response).await?;
550                Err(VerifyError::ProviderError(text))
551            }
552            _ => {
553                let status = response.status();
554
555                if status.is_success() {
556                    Ok(())
557                } else {
558                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
559                    Err(VerifyError::HttpError(http_client::Error::Instance(
560                        format!("Failed with '{status}': {text}").into(),
561                    )))
562                }
563            }
564        }
565    }
566}
567
568/// Type-state builder for [`Client`].
569///
570/// Each generic slot encodes a separate "has the user supplied this yet?" question:
571///
572/// - `ApiKey = Missing` means the caller has not yet called [`Self::api_key`]; transitioning to a
573///   concrete `ApiKey` type is required before [`Self::build`] is reachable.
574/// - `H = Missing` means the caller has not yet called [`Self::http_client`]; in that state
575///   `build()` substitutes the canonical `reqwest::Client` backend at construction time. Once a
576///   backend has been supplied, `H` is the concrete HTTP client type and `build()` uses it
577///   directly.
578///
579/// Keeping `Missing` as the *type-level* placeholder (rather than reusing `reqwest::Client`)
580/// means the builder's generics describe what the caller has actually provided, instead of
581/// pretending a default value is already present. It also avoids carrying an `Option<H>` whose
582/// `None` branch existed only to model the same "user hasn't picked a backend" state.
583#[derive(Clone)]
584pub struct ClientBuilder<Ext, ApiKey = Missing, H = Missing> {
585    base_url: String,
586    api_key: ApiKey,
587    headers: HeaderMap,
588    http_client: H,
589    ext: Ext,
590}
591
592impl<ExtBuilder> Default for ClientBuilder<ExtBuilder, Missing, Missing>
593where
594    ExtBuilder: ProviderBuilder + Default,
595{
596    fn default() -> Self {
597        Self {
598            api_key: Missing,
599            headers: Default::default(),
600            base_url: ExtBuilder::BASE_URL.into(),
601            http_client: Missing,
602            ext: Default::default(),
603        }
604    }
605}
606
607impl<Ext, H> ClientBuilder<Ext, Missing, H> {
608    /// Set the API key for this client. This *must* be done before the `build` method can be
609    /// called
610    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
611        ClientBuilder {
612            api_key: api_key.into(),
613            base_url: self.base_url,
614            headers: self.headers,
615            http_client: self.http_client,
616            ext: self.ext,
617        }
618    }
619}
620
621impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
622where
623    Ext: Clone,
624{
625    /// Owned map over the ext field
626    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
627    where
628        F: FnOnce(Ext) -> NewExt,
629    {
630        let ClientBuilder {
631            base_url,
632            api_key,
633            headers,
634            http_client,
635            ext,
636        } = self;
637
638        let new_ext = f(ext.clone());
639
640        ClientBuilder {
641            base_url,
642            api_key,
643            headers,
644            http_client,
645            ext: new_ext,
646        }
647    }
648
649    /// Set the base URL for this client
650    pub fn base_url<S>(self, base_url: S) -> Self
651    where
652        S: AsRef<str>,
653    {
654        Self {
655            base_url: base_url.as_ref().to_string(),
656            ..self
657        }
658    }
659
660    /// Set the HTTP backend used in this client.
661    ///
662    /// Calling this advances the builder's `H` slot from whatever it was (typically `Missing`)
663    /// to the supplied client's type, which selects the H-generic [`Self::build`] impl below.
664    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
665        ClientBuilder {
666            http_client,
667            base_url: self.base_url,
668            api_key: self.api_key,
669            headers: self.headers,
670            ext: self.ext,
671        }
672    }
673
674    /// Set the HTTP headers used in this client
675    pub fn http_headers(self, headers: HeaderMap) -> Self {
676        Self { headers, ..self }
677    }
678
679    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
680        &mut self.headers
681    }
682
683    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
684        &mut self.ext
685    }
686}
687
688impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
689    pub(crate) fn get_api_key(&self) -> &ApiKey {
690        &self.api_key
691    }
692}
693
694impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
695    /// Returns the provider extension builder state.
696    pub fn ext(&self) -> &Ext {
697        &self.ext
698    }
699
700    /// Returns the configured base URL.
701    pub fn get_base_url(&self) -> &str {
702        &self.base_url
703    }
704}
705
706/// Default-backend `build`: when the caller never called [`ClientBuilder::http_client`], the
707/// builder's `H` slot is still `Missing`, and we substitute the canonical `reqwest::Client` at
708/// build time. This is the only place in the crate that knows about that default, and it is
709/// disjoint by trait bound from the H-generic `build` below (`Missing` does not implement
710/// [`HttpClientExt`]).
711impl<ExtBuilder, Key> ClientBuilder<ExtBuilder, Key, Missing>
712where
713    ExtBuilder: ProviderBuilder<ApiKey = Key>,
714    Key: ApiKey,
715{
716    /// Build a client using the default `reqwest::Client` backend.
717    pub fn build(
718        self,
719    ) -> http_client::Result<Client<ExtBuilder::Extension<reqwest::Client>, reqwest::Client>> {
720        self.http_client(reqwest::Client::default()).build()
721    }
722}
723
724/// Concrete-backend `build`: the caller supplied an HTTP client via
725/// [`ClientBuilder::http_client`], so `H` is a real `HttpClientExt` type and we use it directly.
726impl<ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
727where
728    ExtBuilder: ProviderBuilder<ApiKey = Key>,
729    Key: ApiKey,
730    H: HttpClientExt,
731{
732    /// Build a client using the HTTP backend supplied with [`ClientBuilder::http_client`].
733    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Extension<H>, H>> {
734        let ext_builder = self.ext.clone();
735
736        self = ext_builder.finish(self)?;
737        let ext = ExtBuilder::build(&self)?;
738
739        let ClientBuilder {
740            http_client,
741            base_url,
742            mut headers,
743            api_key,
744            ..
745        } = self;
746
747        if let Some((k, v)) = api_key.into_header().transpose()?
748            && !headers.contains_key(&k)
749        {
750            headers.insert(k, v);
751        }
752
753        Ok(Client {
754            http_client,
755            base_url: Arc::from(base_url.as_str()),
756            headers: Arc::new(headers),
757            ext,
758        })
759    }
760}
761
762impl<M, Ext, H> CompletionClient for Client<Ext, H>
763where
764    Ext: Capabilities<H, Completion = Capable<M>>,
765    M: CompletionModel<Client = Self>,
766{
767    type CompletionModel = M;
768
769    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
770        M::make(self, model)
771    }
772}
773
774impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
775where
776    Ext: Capabilities<H, Embeddings = Capable<M>>,
777    M: EmbeddingModel<Client = Self>,
778{
779    type EmbeddingModel = M;
780
781    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
782        M::make(self, model, None)
783    }
784
785    fn embedding_model_with_ndims(
786        &self,
787        model: impl Into<String>,
788        ndims: usize,
789    ) -> Self::EmbeddingModel {
790        M::make(self, model, Some(ndims))
791    }
792}
793
794impl<M, Ext, H> RerankingClient for Client<Ext, H>
795where
796    Ext: Capabilities<H, Rerank = Capable<M>>,
797    M: RerankModel<Client = Self>,
798{
799    type RerankModel = M;
800
801    fn rerank_model(&self, model: impl Into<String>) -> Self::RerankModel {
802        M::make(self, model)
803    }
804}
805
806impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
807where
808    Ext: Capabilities<H, Transcription = Capable<M>>,
809    M: TranscriptionModel<Client = Self> + WasmCompatSend,
810{
811    type TranscriptionModel = M;
812
813    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
814        M::make(self, model)
815    }
816}
817
/// Any client whose extension declares `ImageGeneration = Capable<Model>` hands out
/// image-generation models of type `Model` (requires the `image` feature).
#[cfg(feature = "image")]
impl<Model, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<Model>>,
    Model: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = Model;

    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        // Construction is delegated entirely to the model type.
        Model::make(self, model)
    }
}
830
/// Any client whose extension declares `AudioGeneration = Capable<Model>` hands out
/// audio-generation models of type `Model` (requires the `audio` feature).
#[cfg(feature = "audio")]
impl<Model, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<Model>>,
    Model: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = Model;

    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        // Construction is delegated entirely to the model type.
        Model::make(self, model)
    }
}
843
844impl<M, Ext, H> ModelListingClient for Client<Ext, H>
845where
846    Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
847    M: ModelLister<H, Client = Self> + WasmCompatSend + WasmCompatSync + Clone + 'static,
848    H: WasmCompatSend + WasmCompatSync + Clone,
849{
850    fn list_models(
851        &self,
852    ) -> impl std::future::Future<
853        Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
854    > + WasmCompatSend {
855        let lister = M::new(self.clone());
856        async move { lister.list_all().await }
857    }
858}
859
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
mod wasm_model_listing_compile_checks {
    //! Compile-only assertions that the simple model listers accept an HTTP backend which is
    //! neither `Send` nor `Sync`. None of this code is meant to run.

    use super::{ModelListingClient, Nothing};
    use crate::{
        http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response},
        providers::{anthropic, deepseek, mistral, ollama, openai, openrouter},
        wasm_compat::WasmCompatSend,
    };
    use bytes::Bytes;
    use std::{
        future::{self, Future},
        marker::PhantomData,
        rc::Rc,
    };

    /// Stand-in backend that is `!Send + !Sync` because it (phantom-)owns an `Rc`.
    #[derive(Clone, Default)]
    struct WasmOnlyHttpClient {
        _not_send_sync: PhantomData<Rc<()>>,
    }

    /// Immediately ready failure returned by every stub request method below.
    fn stub_failure<T>() -> future::Ready<http_client::Result<T>> {
        future::ready(Err(http_client::Error::StreamEnded))
    }

    impl HttpClientExt for WasmOnlyHttpClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes> + WasmCompatSend,
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            stub_failure()
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            stub_failure()
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes> + WasmCompatSend,
        {
            stub_failure()
        }
    }

    /// Compiles only if `C` implements [`ModelListingClient`].
    fn assert_model_listing_client<C>(client: C)
    where
        C: ModelListingClient,
    {
        let _ = client.list_models();
    }

    fn assert_simple_model_listers_accept_wasm_only_http_clients() {
        // Builds `$provider::Client` on the `!Send` backend and checks it lists models.
        macro_rules! check_lister {
            ($provider:ident, $key:expr) => {
                let _ = $provider::Client::builder()
                    .api_key($key)
                    .http_client(WasmOnlyHttpClient::default())
                    .build()
                    .map(assert_model_listing_client);
            };
        }

        check_lister!(openrouter, "dummy-key");
        check_lister!(openai, "dummy-key");
        check_lister!(mistral, "dummy-key");
        check_lister!(anthropic, "dummy-key");
        check_lister!(ollama, Nothing);
        check_lister!(deepseek, "dummy-key");
    }

    #[allow(dead_code)] // Never called: exists only so the assertions above are type-checked.
    fn compile_assertions() {
        assert_simple_model_listers_accept_wasm_only_http_clients();
    }
}
963
#[cfg(test)]
mod tests {
    use crate::providers::anthropic;

    /// Type-level check that `Client::builder()` infers the backing HTTP client from the value
    /// passed to `http_client`, so callers never need to annotate the backend type.
    #[test]
    fn ensures_client_builder_no_annotation() {
        let backend = reqwest::Client::default();
        let built = anthropic::Client::builder()
            .http_client(backend)
            .api_key("Foo")
            .build();

        let _ = built.unwrap();
    }
}