Skip to main content

rig_core/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3
4pub mod audio_generation;
5pub mod completion;
6pub mod embeddings;
7pub mod image_generation;
8pub mod model_listing;
9pub mod transcription;
10pub mod verify;
11
12use bytes::Bytes;
13pub use completion::CompletionClient;
14pub use embeddings::EmbeddingsClient;
15use http::{HeaderMap, HeaderName, HeaderValue};
16pub use model_listing::{ModelLister, ModelListingClient};
17use std::{env::VarError, fmt::Debug, marker::PhantomData, sync::Arc};
18use thiserror::Error;
19pub use verify::{VerifyClient, VerifyError};
20
21#[cfg(feature = "image")]
22use crate::image_generation::ImageGenerationModel;
23#[cfg(feature = "image")]
24use image_generation::ImageGenerationClient;
25
26#[cfg(feature = "audio")]
27use crate::audio_generation::*;
28#[cfg(feature = "audio")]
29use audio_generation::*;
30
31use crate::{
32    completion::CompletionModel,
33    embeddings::EmbeddingModel,
34    http_client::{
35        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
36    },
37    markers::Missing,
38    prelude::TranscriptionClient,
39    transcription::TranscriptionModel,
40    wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46    /// The underlying HTTP backend failed during builder construction.
47    #[error("reqwest error: {0}")]
48    HttpError(
49        #[from]
50        #[source]
51        reqwest::Error,
52    ),
53    /// A provider-specific builder property was invalid.
54    #[error("invalid property: {0}")]
55    InvalidProperty(&'static str),
56}
57
58/// Errors returned while constructing provider clients from environment variables or explicit input.
59///
60/// Provider-specific client constructors use this error for configuration problems that can be
61/// detected before any model request is sent, such as missing API keys, invalid environment
62/// values, or invalid builder configuration.
63#[derive(Debug, Error)]
64#[non_exhaustive]
65pub enum ProviderClientError {
66    /// A required or optional environment variable could not be read as valid Unicode.
67    ///
68    /// For required variables, this variant is also returned when the variable is not present.
69    #[error("environment variable `{name}` is not set or is invalid")]
70    EnvironmentVariable {
71        /// The environment variable name.
72        name: &'static str,
73        /// The underlying environment lookup error.
74        #[source]
75        source: VarError,
76    },
77    /// The underlying provider client builder failed while constructing HTTP configuration.
78    #[error(transparent)]
79    Http(#[from] http_client::Error),
80    /// The provider received an unsupported or incomplete configuration.
81    #[error("{0}")]
82    InvalidConfiguration(&'static str),
83}
84
85/// Result type returned by provider client construction helpers.
86pub type ProviderClientResult<T> = std::result::Result<T, ProviderClientError>;
87
88/// Read a required environment variable for provider client construction.
89///
90/// Returns [`ProviderClientError::EnvironmentVariable`] when the variable is missing or contains
91/// invalid Unicode.
92pub fn required_env_var(name: &'static str) -> ProviderClientResult<String> {
93    std::env::var(name).map_err(|source| ProviderClientError::EnvironmentVariable { name, source })
94}
95
96/// Read an optional environment variable for provider client construction.
97///
98/// Missing variables return `Ok(None)`. Variables containing invalid Unicode return
99/// [`ProviderClientError::EnvironmentVariable`].
100pub fn optional_env_var(name: &'static str) -> ProviderClientResult<Option<String>> {
101    match std::env::var(name) {
102        Ok(value) => Ok(Some(value)),
103        Err(VarError::NotPresent) => Ok(None),
104        Err(source) => Err(ProviderClientError::EnvironmentVariable { name, source }),
105    }
106}
107
/// Abstracts over the two ways a provider client can be created: from the process
/// environment, or from an explicit provider-specific [`ProviderClient::Input`] value.
pub trait ProviderClient {
    /// Error produced when construction fails.
    type Error;
    /// Value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Create a client from an explicit provider-specific input value.
    fn from_val(input: Self::Input) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Create a client by reading its configuration from the process environment.
    fn from_env() -> Result<Self, Self::Error>
    where
        Self: Sized;
}
126
127/// A trait for API key inputs accepted by [`ClientBuilder::api_key`].
128///
129/// Returning `Some` inserts a header into the generic [`Client`]. Returning `None`
130/// lets the provider extension handle credentials itself.
131pub trait ApiKey: Sized {
132    /// Convert this key into a default request header, if the generic client
133    /// should own that authentication header.
134    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
135        None
136    }
137}
138
139/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
140pub struct BearerAuth(String);
141
142impl ApiKey for BearerAuth {
143    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
144        Some(make_auth_header(self.0))
145    }
146}
147
148impl<S> From<S> for BearerAuth
149where
150    S: Into<String>,
151{
152    fn from(value: S) -> Self {
153        Self(value.into())
154    }
155}
156
157/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
158/// the lack of a capability or field (an API key, for instance)
159#[derive(Debug, Default, Clone, Copy)]
160pub struct Nothing;
161
162impl ApiKey for Nothing {}
163
164impl TryFrom<String> for Nothing {
165    type Error = &'static str;
166
167    fn try_from(_: String) -> Result<Self, Self::Error> {
168        Err(
169            "Tried to create a Nothing from a string - this should not happen, please file an issue",
170        )
171    }
172}
173
174#[derive(Clone)]
175/// Generic provider client shared by Rig provider integrations.
176///
177/// `Ext` stores provider-specific behavior such as URL construction, request
178/// customization, and capabilities. `H` is the HTTP backend and defaults to
179/// `reqwest::Client`.
180pub struct Client<Ext = Nothing, H = reqwest::Client> {
181    base_url: Arc<str>,
182    headers: Arc<HeaderMap>,
183    http_client: H,
184    ext: Ext,
185}
186
/// Hook used by the redacting [`Debug`] implementation of [`Client`] to let a provider
/// extension contribute its own fields.
pub trait DebugExt: Debug {
    /// Extra `(name, value)` pairs appended after the client's own fields. None by default.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        None::<(&'static str, &dyn Debug)>.into_iter()
    }
}
194
195impl<Ext, H> std::fmt::Debug for Client<Ext, H>
196where
197    Ext: DebugExt,
198    H: std::fmt::Debug,
199{
200    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201        let mut d = &mut f.debug_struct("Client");
202
203        d = d
204            .field("base_url", &self.base_url)
205            .field(
206                "headers",
207                &self
208                    .headers
209                    .iter()
210                    .filter_map(|(k, v)| {
211                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
212                            None
213                        } else {
214                            Some((k, v))
215                        }
216                    })
217                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
218            )
219            .field("http_client", &self.http_client);
220
221        self.ext
222            .fields()
223            .fold(d, |d, (name, field)| d.field(name, field))
224            .finish()
225    }
226}
227
/// Wire format of a provider request, passed to [`Provider::build_uri`].
///
/// A provider may use it to pick a different URL per transport. The default `build_uri`
/// ignores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// Regular request/response HTTP transport.
    Http,
    /// Server-sent events streaming transport.
    Sse,
    /// Newline-delimited JSON streaming transport.
    NdJson,
}
236
237/// An API provider extension, this abstracts over extensions which may be used in conjunction with
238/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
239/// auth, instantiating models
240pub trait Provider: Sized {
241    /// The builder type that constructs this provider extension.
242    /// This associates extensions with their builders for type inference.
243    type Builder: ProviderBuilder;
244
245    /// Provider endpoint used by [`VerifyClient`] to validate credentials.
246    const VERIFY_PATH: &'static str;
247
248    /// Build a complete request URI for the given base URL, provider path, and transport.
249    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
250        // Some providers (like Azure) have a blank base URL to allow users to input their own endpoints.
251        let base_url = if base_url.is_empty() {
252            base_url.to_string()
253        } else {
254            base_url.to_string() + "/"
255        };
256
257        base_url.to_string() + path.trim_start_matches('/')
258    }
259
260    /// Apply provider-specific request customization before sending.
261    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
262        Ok(req)
263    }
264}
265
266/// A wrapper type providing runtime checks on a provider's capabilities via the [Capability] trait
267pub struct Capable<M>(PhantomData<M>);
268
269/// Type-level marker for whether a provider supports a capability.
270pub trait Capability {
271    /// Whether this marker represents a supported capability.
272    const CAPABLE: bool;
273}
274
275impl<M> Capability for Capable<M> {
276    const CAPABLE: bool = true;
277}
278
279impl Capability for Nothing {
280    const CAPABLE: bool = false;
281}
282
283/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
284pub trait Capabilities<H = reqwest::Client> {
285    /// Completion model capability marker.
286    type Completion: Capability;
287    /// Embedding model capability marker.
288    type Embeddings: Capability;
289    /// Audio transcription model capability marker.
290    type Transcription: Capability;
291    /// Model listing capability marker.
292    type ModelListing: Capability;
293    #[cfg(feature = "image")]
294    /// Image generation model capability marker.
295    type ImageGeneration: Capability;
296    #[cfg(feature = "audio")]
297    /// Audio generation model capability marker.
298    type AudioGeneration: Capability;
299}
300
301/// An API provider extension *builder*, this abstracts over provider-specific builders which are
302/// able to configure and produce a given provider's extension type
303///
304/// See [Provider]
305pub trait ProviderBuilder: Sized + Default + Clone {
306    /// Provider extension type built for a concrete HTTP backend.
307    type Extension<H>: Provider
308    where
309        H: HttpClientExt;
310    /// API key input type accepted by the provider's client builder.
311    type ApiKey: ApiKey;
312
313    /// Default base URL for the provider.
314    const BASE_URL: &'static str;
315
316    /// Build the provider extension from the client builder configuration.
317    fn build<H>(
318        builder: &ClientBuilder<Self, Self::ApiKey, H>,
319    ) -> http_client::Result<Self::Extension<H>>
320    where
321        H: HttpClientExt;
322
323    /// This method can be used to customize the fields of `builder` before it is used to create
324    /// a client. For example, adding default headers
325    fn finish<H>(
326        &self,
327        builder: ClientBuilder<Self, Self::ApiKey, H>,
328    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
329        Ok(builder)
330    }
331}
332
333/// `new` is pinned to `H = reqwest::Client` so the call site infers without an explicit `H`
334/// annotation. Callers who want a different backend should go through [`Client::builder`] and
335/// chain [`ClientBuilder::http_client`] before [`ClientBuilder::build`].
336impl<Ext> Client<Ext, reqwest::Client>
337where
338    Ext: Provider,
339    Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
340{
341    /// Construct a provider client using the default `reqwest::Client` backend.
342    pub fn new(
343        api_key: impl Into<<Ext::Builder as ProviderBuilder>::ApiKey>,
344    ) -> http_client::Result<Self> {
345        Self::builder().api_key(api_key).build()
346    }
347}
348
349impl<Ext, H> Client<Ext, H> {
350    /// Returns the configured provider base URL.
351    pub fn base_url(&self) -> &str {
352        &self.base_url
353    }
354
355    /// Returns default headers applied to outgoing provider requests.
356    pub fn headers(&self) -> &HeaderMap {
357        &self.headers
358    }
359
360    /// Returns the provider extension.
361    pub fn ext(&self) -> &Ext {
362        &self.ext
363    }
364
365    /// Reuse this client's base URL, headers, and HTTP backend with a different extension.
366    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
367        Client {
368            base_url: self.base_url,
369            headers: self.headers,
370            http_client: self.http_client,
371            ext: new_ext,
372        }
373    }
374}
375
376impl<Ext, H> HttpClientExt for Client<Ext, H>
377where
378    H: HttpClientExt + 'static,
379    Ext: WasmCompatSend + WasmCompatSync + 'static,
380{
381    fn send<T, U>(
382        &self,
383        mut req: Request<T>,
384    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
385    where
386        T: Into<Bytes> + WasmCompatSend,
387        U: From<Bytes>,
388        U: WasmCompatSend + 'static,
389    {
390        req.headers_mut().insert(
391            http::header::CONTENT_TYPE,
392            http::HeaderValue::from_static("application/json"),
393        );
394
395        self.http_client.send(req)
396    }
397
398    fn send_multipart<U>(
399        &self,
400        req: Request<MultipartForm>,
401    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
402    where
403        U: From<Bytes>,
404        U: WasmCompatSend + 'static,
405    {
406        self.http_client.send_multipart(req)
407    }
408
409    fn send_streaming<T>(
410        &self,
411        mut req: Request<T>,
412    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
413    where
414        T: Into<Bytes> + WasmCompatSend,
415    {
416        req.headers_mut().insert(
417            http::header::CONTENT_TYPE,
418            http::HeaderValue::from_static("application/json"),
419        );
420
421        self.http_client.send_streaming(req)
422    }
423}
424
425/// `builder()` is anchored on `Client<Ext, reqwest::Client>` purely as an inference hook so that
426/// `provider::Client::builder()` resolves without a `H` annotation. The returned builder itself
427/// has `H = Missing`, accurately reflecting that no backend has been chosen yet; the eventual
428/// `Client` produced by `build()` may end up with any HTTP backend depending on whether
429/// [`ClientBuilder::http_client`] was called.
430impl<Ext> Client<Ext, reqwest::Client>
431where
432    Ext: Provider,
433    Ext::Builder: ProviderBuilder + Default,
434{
435    /// Start constructing a provider client.
436    pub fn builder() -> ClientBuilder<Ext::Builder, Missing, Missing> {
437        ClientBuilder::default()
438    }
439}
440
441impl<Ext, H> Client<Ext, H>
442where
443    Ext: Provider,
444{
445    /// Build a provider-customized POST request for a regular HTTP endpoint.
446    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
447    where
448        S: AsRef<str>,
449    {
450        let uri = self
451            .ext
452            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
453
454        let mut req = Request::post(uri);
455
456        if let Some(hs) = req.headers_mut() {
457            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
458        }
459
460        self.ext.with_custom(req)
461    }
462
463    /// Build a provider-customized POST request for an SSE endpoint.
464    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
465    where
466        S: AsRef<str>,
467    {
468        let uri = self
469            .ext
470            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
471
472        let mut req = Request::post(uri);
473
474        if let Some(hs) = req.headers_mut() {
475            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
476        }
477
478        self.ext.with_custom(req)
479    }
480
481    /// Build a provider-customized GET request for an SSE endpoint.
482    pub fn get_sse<S>(&self, path: S) -> http_client::Result<Builder>
483    where
484        S: AsRef<str>,
485    {
486        let uri = self
487            .ext
488            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
489
490        let mut req = Request::get(uri);
491
492        if let Some(hs) = req.headers_mut() {
493            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
494        }
495
496        self.ext.with_custom(req)
497    }
498
499    /// Build a provider-customized GET request for a regular HTTP endpoint.
500    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
501    where
502        S: AsRef<str>,
503    {
504        let uri = self
505            .ext
506            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
507
508        let mut req = Request::get(uri);
509
510        if let Some(hs) = req.headers_mut() {
511            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
512        }
513
514        self.ext.with_custom(req)
515    }
516}
517
518impl<Ext, H> VerifyClient for Client<Ext, H>
519where
520    H: HttpClientExt,
521    Ext: DebugExt + Provider + WasmCompatSync,
522{
523    async fn verify(&self) -> Result<(), VerifyError> {
524        use http::StatusCode;
525
526        let req = self
527            .get(Ext::VERIFY_PATH)?
528            .body(http_client::NoBody)
529            .map_err(http_client::Error::from)?;
530
531        let response = self.http_client.send(req).await?;
532
533        match response.status() {
534            StatusCode::OK => Ok(()),
535            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
536                Err(VerifyError::InvalidAuthentication)
537            }
538            StatusCode::INTERNAL_SERVER_ERROR => {
539                let text = http_client::text(response).await?;
540                Err(VerifyError::ProviderError(text))
541            }
542            status if status.as_u16() == 529 => {
543                let text = http_client::text(response).await?;
544                Err(VerifyError::ProviderError(text))
545            }
546            _ => {
547                let status = response.status();
548
549                if status.is_success() {
550                    Ok(())
551                } else {
552                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
553                    Err(VerifyError::HttpError(http_client::Error::Instance(
554                        format!("Failed with '{status}': {text}").into(),
555                    )))
556                }
557            }
558        }
559    }
560}
561
562/// Type-state builder for [`Client`].
563///
564/// Each generic slot encodes a separate "has the user supplied this yet?" question:
565///
566/// - `ApiKey = Missing` means the caller has not yet called [`Self::api_key`]; transitioning to a
567///   concrete `ApiKey` type is required before [`Self::build`] is reachable.
568/// - `H = Missing` means the caller has not yet called [`Self::http_client`]; in that state
569///   `build()` substitutes the canonical `reqwest::Client` backend at construction time. Once a
570///   backend has been supplied, `H` is the concrete HTTP client type and `build()` uses it
571///   directly.
572///
573/// Keeping `Missing` as the *type-level* placeholder (rather than reusing `reqwest::Client`)
574/// means the builder's generics describe what the caller has actually provided, instead of
575/// pretending a default value is already present. It also avoids carrying an `Option<H>` whose
576/// `None` branch existed only to model the same "user hasn't picked a backend" state.
577#[derive(Clone)]
578pub struct ClientBuilder<Ext, ApiKey = Missing, H = Missing> {
579    base_url: String,
580    api_key: ApiKey,
581    headers: HeaderMap,
582    http_client: H,
583    ext: Ext,
584}
585
586impl<ExtBuilder> Default for ClientBuilder<ExtBuilder, Missing, Missing>
587where
588    ExtBuilder: ProviderBuilder + Default,
589{
590    fn default() -> Self {
591        Self {
592            api_key: Missing,
593            headers: Default::default(),
594            base_url: ExtBuilder::BASE_URL.into(),
595            http_client: Missing,
596            ext: Default::default(),
597        }
598    }
599}
600
601impl<Ext, H> ClientBuilder<Ext, Missing, H> {
602    /// Set the API key for this client. This *must* be done before the `build` method can be
603    /// called
604    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
605        ClientBuilder {
606            api_key: api_key.into(),
607            base_url: self.base_url,
608            headers: self.headers,
609            http_client: self.http_client,
610            ext: self.ext,
611        }
612    }
613}
614
615impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
616where
617    Ext: Clone,
618{
619    /// Owned map over the ext field
620    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
621    where
622        F: FnOnce(Ext) -> NewExt,
623    {
624        let ClientBuilder {
625            base_url,
626            api_key,
627            headers,
628            http_client,
629            ext,
630        } = self;
631
632        let new_ext = f(ext.clone());
633
634        ClientBuilder {
635            base_url,
636            api_key,
637            headers,
638            http_client,
639            ext: new_ext,
640        }
641    }
642
643    /// Set the base URL for this client
644    pub fn base_url<S>(self, base_url: S) -> Self
645    where
646        S: AsRef<str>,
647    {
648        Self {
649            base_url: base_url.as_ref().to_string(),
650            ..self
651        }
652    }
653
654    /// Set the HTTP backend used in this client.
655    ///
656    /// Calling this advances the builder's `H` slot from whatever it was (typically `Missing`)
657    /// to the supplied client's type, which selects the H-generic [`Self::build`] impl below.
658    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
659        ClientBuilder {
660            http_client,
661            base_url: self.base_url,
662            api_key: self.api_key,
663            headers: self.headers,
664            ext: self.ext,
665        }
666    }
667
668    /// Set the HTTP headers used in this client
669    pub fn http_headers(self, headers: HeaderMap) -> Self {
670        Self { headers, ..self }
671    }
672
673    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
674        &mut self.headers
675    }
676
677    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
678        &mut self.ext
679    }
680}
681
682impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
683    pub(crate) fn get_api_key(&self) -> &ApiKey {
684        &self.api_key
685    }
686}
687
688impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
689    /// Returns the provider extension builder state.
690    pub fn ext(&self) -> &Ext {
691        &self.ext
692    }
693
694    /// Returns the configured base URL.
695    pub fn get_base_url(&self) -> &str {
696        &self.base_url
697    }
698}
699
700/// Default-backend `build`: when the caller never called [`ClientBuilder::http_client`], the
701/// builder's `H` slot is still `Missing`, and we substitute the canonical `reqwest::Client` at
702/// build time. This is the only place in the crate that knows about that default, and it is
703/// disjoint by trait bound from the H-generic `build` below (`Missing` does not implement
704/// [`HttpClientExt`]).
705impl<ExtBuilder, Key> ClientBuilder<ExtBuilder, Key, Missing>
706where
707    ExtBuilder: ProviderBuilder<ApiKey = Key>,
708    Key: ApiKey,
709{
710    /// Build a client using the default `reqwest::Client` backend.
711    pub fn build(
712        self,
713    ) -> http_client::Result<Client<ExtBuilder::Extension<reqwest::Client>, reqwest::Client>> {
714        self.http_client(reqwest::Client::default()).build()
715    }
716}
717
718/// Concrete-backend `build`: the caller supplied an HTTP client via
719/// [`ClientBuilder::http_client`], so `H` is a real `HttpClientExt` type and we use it directly.
720impl<ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
721where
722    ExtBuilder: ProviderBuilder<ApiKey = Key>,
723    Key: ApiKey,
724    H: HttpClientExt,
725{
726    /// Build a client using the HTTP backend supplied with [`ClientBuilder::http_client`].
727    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Extension<H>, H>> {
728        let ext_builder = self.ext.clone();
729
730        self = ext_builder.finish(self)?;
731        let ext = ExtBuilder::build(&self)?;
732
733        let ClientBuilder {
734            http_client,
735            base_url,
736            mut headers,
737            api_key,
738            ..
739        } = self;
740
741        if let Some((k, v)) = api_key.into_header().transpose()?
742            && !headers.contains_key(&k)
743        {
744            headers.insert(k, v);
745        }
746
747        Ok(Client {
748            http_client,
749            base_url: Arc::from(base_url.as_str()),
750            headers: Arc::new(headers),
751            ext,
752        })
753    }
754}
755
756impl<M, Ext, H> CompletionClient for Client<Ext, H>
757where
758    Ext: Capabilities<H, Completion = Capable<M>>,
759    M: CompletionModel<Client = Self>,
760{
761    type CompletionModel = M;
762
763    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
764        M::make(self, model)
765    }
766}
767
768impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
769where
770    Ext: Capabilities<H, Embeddings = Capable<M>>,
771    M: EmbeddingModel<Client = Self>,
772{
773    type EmbeddingModel = M;
774
775    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
776        M::make(self, model, None)
777    }
778
779    fn embedding_model_with_ndims(
780        &self,
781        model: impl Into<String>,
782        ndims: usize,
783    ) -> Self::EmbeddingModel {
784        M::make(self, model, Some(ndims))
785    }
786}
787
788impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
789where
790    Ext: Capabilities<H, Transcription = Capable<M>>,
791    M: TranscriptionModel<Client = Self> + WasmCompatSend,
792{
793    type TranscriptionModel = M;
794
795    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
796        M::make(self, model)
797    }
798}
799
// Available (with the `image` feature) whenever `ImageGeneration` is `Capable<Model>`.
#[cfg(feature = "image")]
impl<Model, E, Http> ImageGenerationClient for Client<E, Http>
where
    E: Capabilities<Http, ImageGeneration = Capable<Model>>,
    Model: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = Model;

    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        Model::make(self, model)
    }
}
812
// Available (with the `audio` feature) whenever `AudioGeneration` is `Capable<Model>`.
#[cfg(feature = "audio")]
impl<Model, E, Http> AudioGenerationClient for Client<E, Http>
where
    E: Capabilities<Http, AudioGeneration = Capable<Model>>,
    Model: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = Model;

    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        Model::make(self, model)
    }
}
825
826impl<M, Ext, H> ModelListingClient for Client<Ext, H>
827where
828    Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
829    M: ModelLister<H, Client = Self> + WasmCompatSend + WasmCompatSync + Clone + 'static,
830    H: WasmCompatSend + WasmCompatSync + Clone,
831{
832    fn list_models(
833        &self,
834    ) -> impl std::future::Future<
835        Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
836    > + WasmCompatSend {
837        let lister = M::new(self.clone());
838        async move { lister.list_all().await }
839    }
840}
841
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
mod wasm_model_listing_compile_checks {
    //! Compile-time checks that model listing accepts an HTTP backend that is neither `Send`
    //! nor `Sync`. None of these functions are expected to run.

    use super::{ModelListingClient, Nothing};
    use crate::{
        http_client::{self, HttpClientExt, LazyBody, MultipartForm, Request, Response},
        providers::{anthropic, deepseek, mistral, ollama, openai, openrouter},
        wasm_compat::WasmCompatSend,
    };
    use bytes::Bytes;
    use std::{
        future::{self, Future},
        marker::PhantomData,
        rc::Rc,
    };

    /// HTTP backend made `!Send + !Sync` via `Rc`. Every request fails immediately.
    #[derive(Clone, Default)]
    struct WasmOnlyHttpClient {
        _not_send_sync: PhantomData<Rc<()>>,
    }

    impl HttpClientExt for WasmOnlyHttpClient {
        fn send<T, U>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            T: Into<Bytes> + WasmCompatSend,
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::StreamEnded))
        }

        fn send_multipart<U>(
            &self,
            _req: Request<MultipartForm>,
        ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
        where
            U: From<Bytes> + WasmCompatSend + 'static,
        {
            future::ready(Err(http_client::Error::StreamEnded))
        }

        fn send_streaming<T>(
            &self,
            _req: Request<T>,
        ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
        where
            T: Into<Bytes> + WasmCompatSend,
        {
            future::ready(Err(http_client::Error::StreamEnded))
        }
    }

    /// Only compiles when `C` implements [`ModelListingClient`].
    fn assert_model_listing_client<C: ModelListingClient>(client: C) {
        let _ = client.list_models();
    }

    /// Build `$provider::Client` on top of [`WasmOnlyHttpClient`] with API key `$key`, and
    /// require the resulting client to implement [`ModelListingClient`].
    macro_rules! check_lister {
        ($provider:ident, $key:expr) => {
            let _ = $provider::Client::builder()
                .api_key($key)
                .http_client(WasmOnlyHttpClient::default())
                .build()
                .map(assert_model_listing_client);
        };
    }

    fn assert_simple_model_listers_accept_wasm_only_http_clients() {
        check_lister!(openrouter, "dummy-key");
        check_lister!(openai, "dummy-key");
        check_lister!(mistral, "dummy-key");
        check_lister!(anthropic, "dummy-key");
        check_lister!(ollama, Nothing);
        check_lister!(deepseek, "dummy-key");
    }

    // Root of the compile-only call graph; never invoked at runtime.
    #[allow(dead_code)]
    fn compile_assertions() {
        assert_simple_model_listers_accept_wasm_only_http_clients();
    }
}
945
#[cfg(test)]
mod tests {
    use crate::providers::anthropic;

    /// Type-level test: a `Client::builder()` chain must infer the backing HTTP client without
    /// any annotation, even when `http_client` is called before `api_key`.
    #[test]
    fn ensures_client_builder_no_annotation() {
        let backend = reqwest::Client::default();
        let _ = anthropic::Client::builder()
            .http_client(backend)
            .api_key("Foo")
            .build()
            .unwrap();
    }
}