Skip to main content

rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod model_listing;
11pub mod transcription;
12pub mod verify;
13
14use bytes::Bytes;
15pub use completion::CompletionClient;
16pub use embeddings::EmbeddingsClient;
17use http::{HeaderMap, HeaderName, HeaderValue};
18pub use model_listing::{ModelLister, ModelListingClient};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, marker::PhantomData, sync::Arc};
21use thiserror::Error;
22pub use verify::{VerifyClient, VerifyError};
23
24#[cfg(feature = "image")]
25use crate::image_generation::ImageGenerationModel;
26#[cfg(feature = "image")]
27use image_generation::ImageGenerationClient;
28
29#[cfg(feature = "audio")]
30use crate::audio_generation::*;
31#[cfg(feature = "audio")]
32use audio_generation::*;
33
34use crate::{
35    completion::CompletionModel,
36    embeddings::EmbeddingModel,
37    http_client::{
38        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
39    },
40    prelude::TranscriptionClient,
41    transcription::TranscriptionModel,
42    wasm_compat::{WasmCompatSend, WasmCompatSync},
43};
44
45#[derive(Debug, Error)]
46#[non_exhaustive]
47pub enum ClientBuilderError {
48    #[error("reqwest error: {0}")]
49    HttpError(
50        #[from]
51        #[source]
52        reqwest::Error,
53    ),
54    #[error("invalid property: {0}")]
55    InvalidProperty(&'static str),
56}
57
/// A client that can be instantiated either from environment variables or from an explicit
/// `Self::Input` value.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Creates a client from the process's environment.
    ///
    /// # Panics
    ///
    /// Panics if the environment is improperly configured.
    fn from_env() -> Self;

    /// Creates a client from an explicit input value.
    fn from_val(input: Self::Input) -> Self;
}
69
70use crate::completion::{GetTokenUsage, Usage};
71
72/// The final streaming response from a dynamic client.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FinalCompletionResponse {
75    pub usage: Option<Usage>,
76}
77
78impl GetTokenUsage for FinalCompletionResponse {
79    fn token_usage(&self) -> Option<Usage> {
80        self.usage
81    }
82}
83
84/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
85/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
86pub trait ApiKey: Sized {
87    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
88        None
89    }
90}
91
92/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
93pub struct BearerAuth(String);
94
95impl ApiKey for BearerAuth {
96    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
97        Some(make_auth_header(self.0))
98    }
99}
100
101impl<S> From<S> for BearerAuth
102where
103    S: Into<String>,
104{
105    fn from(value: S) -> Self {
106        Self(value.into())
107    }
108}
109
110/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
111/// the lack of a capability or field (an API key, for instance)
112#[derive(Debug, Default, Clone, Copy)]
113pub struct Nothing;
114
115impl ApiKey for Nothing {}
116
117impl TryFrom<String> for Nothing {
118    type Error = &'static str;
119
120    fn try_from(_: String) -> Result<Self, Self::Error> {
121        Err(
122            "Tried to create a Nothing from a string - this should not happen, please file an issue",
123        )
124    }
125}
126
127#[derive(Clone)]
128pub struct Client<Ext = Nothing, H = reqwest::Client> {
129    base_url: Arc<str>,
130    headers: Arc<HeaderMap>,
131    http_client: H,
132    ext: Ext,
133}
134
/// Lets a provider extension contribute extra named fields to the `Debug` output of a
/// [Client].
pub trait DebugExt: Debug {
    /// Yields `(name, value)` pairs to append to the client's debug output.
    ///
    /// The default implementation yields nothing.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
140
141impl<Ext, H> std::fmt::Debug for Client<Ext, H>
142where
143    Ext: DebugExt,
144    H: std::fmt::Debug,
145{
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        let mut d = &mut f.debug_struct("Client");
148
149        d = d
150            .field("base_url", &self.base_url)
151            .field(
152                "headers",
153                &self
154                    .headers
155                    .iter()
156                    .filter_map(|(k, v)| {
157                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
158                            None
159                        } else {
160                            Some((k, v))
161                        }
162                    })
163                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
164            )
165            .field("http_client", &self.http_client);
166
167        self.ext
168            .fields()
169            .fold(d, |d, (name, field)| d.field(name, field))
170            .finish()
171    }
172}
173
/// The kind of transport a request URI is being built for.
///
/// A value is passed to `Provider::build_uri` so that an implementation can choose a
/// different URI per transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// A plain HTTP request.
    Http,
    /// A server-sent events request.
    Sse,
    /// A newline-delimited JSON request.
    NdJson,
}
179
180/// An API provider extension, this abstracts over extensions which may be use in conjunction with
181/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
182/// auth, instantiating models
183pub trait Provider: Sized {
184    const VERIFY_PATH: &'static str;
185
186    type Builder: ProviderBuilder;
187
188    fn build<H>(
189        builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
190    ) -> http_client::Result<Self>;
191
192    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
193        // Some providers (like Azure) have a blank base URL to allow users to input their own endpoints.
194        let base_url = if base_url.is_empty() {
195            base_url.to_string()
196        } else {
197            base_url.to_string() + "/"
198        };
199
200        base_url.to_string() + path.trim_start_matches('/')
201    }
202
203    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
204        Ok(req)
205    }
206}
207
208/// A wrapper type providing runtime checks on a provider's capabilities via the [Capability] trait
209pub struct Capable<M>(PhantomData<M>);
210
211pub trait Capability {
212    const CAPABLE: bool;
213}
214
215impl<M> Capability for Capable<M> {
216    const CAPABLE: bool = true;
217}
218
219impl Capability for Nothing {
220    const CAPABLE: bool = false;
221}
222
223/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
224pub trait Capabilities<H = reqwest::Client> {
225    type Completion: Capability;
226    type Embeddings: Capability;
227    type Transcription: Capability;
228    type ModelListing: Capability;
229    #[cfg(feature = "image")]
230    type ImageGeneration: Capability;
231    #[cfg(feature = "audio")]
232    type AudioGeneration: Capability;
233}
234
235/// An API provider extension *builder*, this abstracts over provider-specific builders which are
236/// able to configure and produce a given provider's extension type
237///
238/// See [Provider]
239pub trait ProviderBuilder: Sized {
240    type Output: Provider;
241    type ApiKey;
242
243    const BASE_URL: &'static str;
244
245    /// This method can be used to customize the fields of `builder` before it is used to create
246    /// a client. For example, adding default headers
247    fn finish<H>(
248        &self,
249        builder: ClientBuilder<Self, Self::ApiKey, H>,
250    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
251        Ok(builder)
252    }
253}
254
255impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
256where
257    ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
258    Ext: Provider<Builder = ExtBuilder>,
259    H: Default + HttpClientExt,
260    Key: ApiKey,
261{
262    pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
263        Self::builder().api_key(api_key).build()
264    }
265}
266
267impl<Ext, H> Client<Ext, H> {
268    pub fn base_url(&self) -> &str {
269        &self.base_url
270    }
271
272    pub fn headers(&self) -> &HeaderMap {
273        &self.headers
274    }
275
276    pub fn ext(&self) -> &Ext {
277        &self.ext
278    }
279
280    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
281        Client {
282            base_url: self.base_url,
283            headers: self.headers,
284            http_client: self.http_client,
285            ext: new_ext,
286        }
287    }
288}
289
290impl<Ext, H> HttpClientExt for Client<Ext, H>
291where
292    H: HttpClientExt + 'static,
293    Ext: WasmCompatSend + WasmCompatSync + 'static,
294{
295    fn send<T, U>(
296        &self,
297        mut req: Request<T>,
298    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
299    where
300        T: Into<Bytes> + WasmCompatSend,
301        U: From<Bytes>,
302        U: WasmCompatSend + 'static,
303    {
304        req.headers_mut().insert(
305            http::header::CONTENT_TYPE,
306            http::HeaderValue::from_static("application/json"),
307        );
308
309        self.http_client.send(req)
310    }
311
312    fn send_multipart<U>(
313        &self,
314        req: Request<MultipartForm>,
315    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
316    where
317        U: From<Bytes>,
318        U: WasmCompatSend + 'static,
319    {
320        self.http_client.send_multipart(req)
321    }
322
323    fn send_streaming<T>(
324        &self,
325        mut req: Request<T>,
326    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
327    where
328        T: Into<Bytes>,
329    {
330        req.headers_mut().insert(
331            http::header::CONTENT_TYPE,
332            http::HeaderValue::from_static("application/json"),
333        );
334
335        self.http_client.send_streaming(req)
336    }
337}
338
339impl<Ext, Builder, H> Client<Ext, H>
340where
341    H: Default + HttpClientExt,
342    Ext: Provider<Builder = Builder>,
343    Builder: Default + ProviderBuilder,
344{
345    pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
346        ClientBuilder::default()
347    }
348}
349
350impl<Ext, H> Client<Ext, H>
351where
352    Ext: Provider,
353{
354    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
355    where
356        S: AsRef<str>,
357    {
358        let uri = self
359            .ext
360            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
361
362        let mut req = Request::post(uri);
363
364        if let Some(hs) = req.headers_mut() {
365            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
366        }
367
368        self.ext.with_custom(req)
369    }
370
371    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
372    where
373        S: AsRef<str>,
374    {
375        let uri = self
376            .ext
377            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
378
379        let mut req = Request::post(uri);
380
381        if let Some(hs) = req.headers_mut() {
382            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
383        }
384
385        self.ext.with_custom(req)
386    }
387
388    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
389    where
390        S: AsRef<str>,
391    {
392        let uri = self
393            .ext
394            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
395
396        let mut req = Request::get(uri);
397
398        if let Some(hs) = req.headers_mut() {
399            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
400        }
401
402        self.ext.with_custom(req)
403    }
404}
405
406impl<Ext, H> VerifyClient for Client<Ext, H>
407where
408    H: HttpClientExt,
409    Ext: DebugExt + Provider + WasmCompatSync,
410{
411    async fn verify(&self) -> Result<(), VerifyError> {
412        use http::StatusCode;
413
414        let req = self
415            .get(Ext::VERIFY_PATH)?
416            .body(http_client::NoBody)
417            .map_err(http_client::Error::from)?;
418
419        let response = self.http_client.send(req).await?;
420
421        match response.status() {
422            StatusCode::OK => Ok(()),
423            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
424                Err(VerifyError::InvalidAuthentication)
425            }
426            StatusCode::INTERNAL_SERVER_ERROR => {
427                let text = http_client::text(response).await?;
428                Err(VerifyError::ProviderError(text))
429            }
430            status if status.as_u16() == 529 => {
431                let text = http_client::text(response).await?;
432                Err(VerifyError::ProviderError(text))
433            }
434            _ => {
435                let status = response.status();
436
437                if status.is_success() {
438                    Ok(())
439                } else {
440                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
441                    Err(VerifyError::HttpError(http_client::Error::Instance(
442                        format!("Failed with '{status}': {text}").into(),
443                    )))
444                }
445            }
446        }
447    }
448}
449
450#[derive(Debug, Default, Clone, Copy)]
451pub struct NeedsApiKey;
452
453// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
454#[derive(Clone)]
455pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
456    base_url: String,
457    api_key: ApiKey,
458    headers: HeaderMap,
459    http_client: Option<H>,
460    ext: Ext,
461}
462
463impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
464where
465    H: Default,
466    ExtBuilder: ProviderBuilder + Default,
467{
468    fn default() -> Self {
469        Self {
470            api_key: NeedsApiKey,
471            headers: Default::default(),
472            base_url: ExtBuilder::BASE_URL.into(),
473            http_client: None,
474            ext: Default::default(),
475        }
476    }
477}
478
479impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
480    /// Set the API key for this client. This *must* be done before the `build` method can be
481    /// called
482    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
483        ClientBuilder {
484            api_key: api_key.into(),
485            base_url: self.base_url,
486            headers: self.headers,
487            http_client: self.http_client,
488            ext: self.ext,
489        }
490    }
491}
492
493impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
494where
495    Ext: Clone,
496{
497    /// Owned map over the ext field
498    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
499    where
500        F: FnOnce(Ext) -> NewExt,
501    {
502        let ClientBuilder {
503            base_url,
504            api_key,
505            headers,
506            http_client,
507            ext,
508        } = self;
509
510        let new_ext = f(ext.clone());
511
512        ClientBuilder {
513            base_url,
514            api_key,
515            headers,
516            http_client,
517            ext: new_ext,
518        }
519    }
520
521    /// Set the base URL for this client
522    pub fn base_url<S>(self, base_url: S) -> Self
523    where
524        S: AsRef<str>,
525    {
526        Self {
527            base_url: base_url.as_ref().to_string(),
528            ..self
529        }
530    }
531
532    /// Set the HTTP backend used in this client
533    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
534        ClientBuilder {
535            http_client: Some(http_client),
536            base_url: self.base_url,
537            api_key: self.api_key,
538            headers: self.headers,
539            ext: self.ext,
540        }
541    }
542
543    /// Set the HTTP headers used in this client
544    pub fn http_headers(self, headers: HeaderMap) -> Self {
545        Self { headers, ..self }
546    }
547
548    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
549        &mut self.headers
550    }
551
552    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
553        &mut self.ext
554    }
555}
556
557impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
558    pub(crate) fn get_api_key(&self) -> &ApiKey {
559        &self.api_key
560    }
561}
562
563impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
564    pub fn ext(&self) -> &Ext {
565        &self.ext
566    }
567}
568
569impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
570where
571    ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
572    Ext: Provider<Builder = ExtBuilder>,
573    Key: ApiKey,
574    H: Default,
575{
576    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
577        let ext = self.ext.clone();
578
579        self = ext.finish(self)?;
580        let ext = Ext::build(&self)?;
581
582        let ClientBuilder {
583            http_client,
584            base_url,
585            mut headers,
586            api_key,
587            ..
588        } = self;
589
590        if let Some((k, v)) = api_key.into_header().transpose()? {
591            headers.insert(k, v);
592        }
593
594        let http_client = http_client.unwrap_or_default();
595
596        Ok(Client {
597            http_client,
598            base_url: Arc::from(base_url.as_str()),
599            headers: Arc::new(headers),
600            ext,
601        })
602    }
603}
604
605impl<M, Ext, H> CompletionClient for Client<Ext, H>
606where
607    Ext: Capabilities<H, Completion = Capable<M>>,
608    M: CompletionModel<Client = Self>,
609{
610    type CompletionModel = M;
611
612    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
613        M::make(self, model)
614    }
615}
616
617impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
618where
619    Ext: Capabilities<H, Embeddings = Capable<M>>,
620    M: EmbeddingModel<Client = Self>,
621{
622    type EmbeddingModel = M;
623
624    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
625        M::make(self, model, None)
626    }
627
628    fn embedding_model_with_ndims(
629        &self,
630        model: impl Into<String>,
631        ndims: usize,
632    ) -> Self::EmbeddingModel {
633        M::make(self, model, Some(ndims))
634    }
635}
636
637impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
638where
639    Ext: Capabilities<H, Transcription = Capable<M>>,
640    M: TranscriptionModel<Client = Self> + WasmCompatSend,
641{
642    type TranscriptionModel = M;
643
644    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
645        M::make(self, model)
646    }
647}
648
// Available (with the `image` feature) whenever the extension declares an image generation
// capability `Capable<M>`.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates the provider's image generation model for `model` via `M::make`.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        let client = self;
        M::make(client, model)
    }
}
661
// Available (with the `audio` feature) whenever the extension declares an audio generation
// capability `Capable<M>`.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates the provider's audio generation model for `model` via `M::make`.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        let client = self;
        M::make(client, model)
    }
}
674
675impl<M, Ext, H> ModelListingClient for Client<Ext, H>
676where
677    Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
678    M: ModelLister<H, Client = Self> + Send + Sync + Clone + 'static,
679    H: Send + Sync + Clone,
680{
681    fn list_models(
682        &self,
683    ) -> impl std::future::Future<
684        Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
685    > + WasmCompatSend {
686        let lister = M::new(self.clone());
687        async move { lister.list_all().await }
688    }
689}