rig/client/mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33    completion::CompletionModel,
34    embeddings::EmbeddingModel,
35    http_client::{
36        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
37    },
38    prelude::TranscriptionClient,
39    transcription::TranscriptionModel,
40    wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46    #[error("reqwest error: {0}")]
47    HttpError(
48        #[from]
49        #[source]
50        reqwest::Error,
51    ),
52    #[error("invalid property: {0}")]
53    InvalidProperty(&'static str),
54}
55
/// Abstracts over the ability to instantiate a client, either from environment variables or from
/// an explicit `Self::Input` value.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Create a client from the process's environment.
    ///
    /// # Panics
    ///
    /// Panics if the environment is improperly configured.
    fn from_env() -> Self;

    /// Create a client from an explicit input value.
    fn from_val(input: Self::Input) -> Self;
}
67
68use crate::completion::{GetTokenUsage, Usage};
69
70/// The final streaming response from a dynamic client.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct FinalCompletionResponse {
73    pub usage: Option<Usage>,
74}
75
76impl GetTokenUsage for FinalCompletionResponse {
77    fn token_usage(&self) -> Option<Usage> {
78        self.usage
79    }
80}
81
82/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
83/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
84pub trait ApiKey: Sized {
85    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
86        None
87    }
88}
89
90/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
91pub struct BearerAuth(String);
92
93impl ApiKey for BearerAuth {
94    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
95        Some(make_auth_header(self.0))
96    }
97}
98
99impl<S> From<S> for BearerAuth
100where
101    S: Into<String>,
102{
103    fn from(value: S) -> Self {
104        Self(value.into())
105    }
106}
107
108/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
109/// the lack of a capability or field (an API key, for instance)
110#[derive(Debug, Default, Clone, Copy)]
111pub struct Nothing;
112
113impl ApiKey for Nothing {}
114
115impl TryFrom<String> for Nothing {
116    type Error = &'static str;
117
118    fn try_from(_: String) -> Result<Self, Self::Error> {
119        Err(
120            "Tried to create a Nothing from a string - this should not happen, please file an issue",
121        )
122    }
123}
124
125#[derive(Clone)]
126pub struct Client<Ext = Nothing, H = reqwest::Client> {
127    base_url: Arc<str>,
128    headers: Arc<HeaderMap>,
129    http_client: H,
130    ext: Ext,
131}
132
/// `Debug` support for provider extensions.
///
/// The `(name, value)` pairs yielded by [`DebugExt::fields`] are appended to a [Client]'s `Debug`
/// output. The default implementation yields nothing.
pub trait DebugExt: Debug {
    /// Extra fields to show in the owning client's `Debug` output.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
138
139impl<Ext, H> std::fmt::Debug for Client<Ext, H>
140where
141    Ext: DebugExt,
142    H: std::fmt::Debug,
143{
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        let mut d = &mut f.debug_struct("Client");
146
147        d = d
148            .field("base_url", &self.base_url)
149            .field(
150                "headers",
151                &self
152                    .headers
153                    .iter()
154                    .filter_map(|(k, v)| {
155                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
156                            None
157                        } else {
158                            Some((k, v))
159                        }
160                    })
161                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
162            )
163            .field("http_client", &self.http_client);
164
165        self.ext
166            .fields()
167            .fold(d, |d, (name, field)| d.field(name, field))
168            .finish()
169    }
170}
171
/// The transport a request will use.
///
/// It is passed to `Provider::build_uri` so that providers can vary the request URI by transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// A plain request/response exchange (used by `Client::post` and `Client::get`).
    Http,
    /// Server-sent events (used by `Client::post_sse`).
    Sse,
    /// Newline-delimited JSON streaming; not used within this module.
    NdJson,
}
177
178/// An API provider extension, this abstracts over extensions which may be use in conjunction with
179/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
180/// auth, instantiating models
181pub trait Provider: Sized {
182    const VERIFY_PATH: &'static str;
183
184    type Builder: ProviderBuilder;
185
186    fn build<H>(
187        builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
188    ) -> http_client::Result<Self>;
189
190    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191        // Some providers (like Azure) have a blank base URL to allow users to input their own endpoints.
192        let base_url = if base_url.is_empty() {
193            base_url.to_string()
194        } else {
195            base_url.to_string() + "/"
196        };
197
198        base_url.to_string() + path.trim_start_matches('/')
199    }
200
201    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
202        Ok(req)
203    }
204}
205
/// A wrapper type providing runtime checks on a provider's capabilities via the [Capability]
/// trait.
///
/// `M` is the model type the capability provides; no value of it is stored.
pub struct Capable<M>(PhantomData<M>);

/// Type-level flag stating whether a capability is present.
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}

/// Every `Capable<M>` marks its capability as present, regardless of `M`.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
216
217impl Capability for Nothing {
218    const CAPABLE: bool = false;
219}
220
221/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
222pub trait Capabilities<H = reqwest::Client> {
223    type Completion: Capability;
224    type Embeddings: Capability;
225    type Transcription: Capability;
226    #[cfg(feature = "image")]
227    type ImageGeneration: Capability;
228    #[cfg(feature = "audio")]
229    type AudioGeneration: Capability;
230}
231
232/// An API provider extension *builder*, this abstracts over provider-specific builders which are
233/// able to configure and produce a given provider's extension type
234///
235/// See [Provider]
236pub trait ProviderBuilder: Sized {
237    type Output: Provider;
238    type ApiKey;
239
240    const BASE_URL: &'static str;
241
242    /// This method can be used to customize the fields of `builder` before it is used to create
243    /// a client. For example, adding default headers
244    fn finish<H>(
245        &self,
246        builder: ClientBuilder<Self, Self::ApiKey, H>,
247    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
248        Ok(builder)
249    }
250}
251
252impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
253where
254    ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
255    Ext: Provider<Builder = ExtBuilder>,
256    H: Default + HttpClientExt,
257    Key: ApiKey,
258{
259    pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
260        Self::builder().api_key(api_key).build()
261    }
262}
263
264impl<Ext, H> Client<Ext, H> {
265    pub fn base_url(&self) -> &str {
266        &self.base_url
267    }
268
269    pub fn headers(&self) -> &HeaderMap {
270        &self.headers
271    }
272
273    pub fn ext(&self) -> &Ext {
274        &self.ext
275    }
276
277    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
278        Client {
279            base_url: self.base_url,
280            headers: self.headers,
281            http_client: self.http_client,
282            ext: new_ext,
283        }
284    }
285}
286
287impl<Ext, H> HttpClientExt for Client<Ext, H>
288where
289    H: HttpClientExt + 'static,
290    Ext: WasmCompatSend + WasmCompatSync + 'static,
291{
292    fn send<T, U>(
293        &self,
294        mut req: Request<T>,
295    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
296    where
297        T: Into<Bytes> + WasmCompatSend,
298        U: From<Bytes>,
299        U: WasmCompatSend + 'static,
300    {
301        req.headers_mut().insert(
302            http::header::CONTENT_TYPE,
303            http::HeaderValue::from_static("application/json"),
304        );
305
306        self.http_client.send(req)
307    }
308
309    fn send_multipart<U>(
310        &self,
311        req: Request<MultipartForm>,
312    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
313    where
314        U: From<Bytes>,
315        U: WasmCompatSend + 'static,
316    {
317        self.http_client.send_multipart(req)
318    }
319
320    fn send_streaming<T>(
321        &self,
322        mut req: Request<T>,
323    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
324    where
325        T: Into<Bytes>,
326    {
327        req.headers_mut().insert(
328            http::header::CONTENT_TYPE,
329            http::HeaderValue::from_static("application/json"),
330        );
331
332        self.http_client.send_streaming(req)
333    }
334}
335
336impl<Ext, Builder, H> Client<Ext, H>
337where
338    H: Default + HttpClientExt,
339    Ext: Provider<Builder = Builder>,
340    Builder: Default + ProviderBuilder,
341{
342    pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
343        ClientBuilder::default()
344    }
345}
346
347impl<Ext, H> Client<Ext, H>
348where
349    Ext: Provider,
350{
351    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
352    where
353        S: AsRef<str>,
354    {
355        let uri = self
356            .ext
357            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
358
359        let mut req = Request::post(uri);
360
361        if let Some(hs) = req.headers_mut() {
362            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
363        }
364
365        self.ext.with_custom(req)
366    }
367
368    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
369    where
370        S: AsRef<str>,
371    {
372        let uri = self
373            .ext
374            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
375
376        let mut req = Request::post(uri);
377
378        if let Some(hs) = req.headers_mut() {
379            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
380        }
381
382        self.ext.with_custom(req)
383    }
384
385    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
386    where
387        S: AsRef<str>,
388    {
389        let uri = self
390            .ext
391            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
392
393        let mut req = Request::get(uri);
394
395        if let Some(hs) = req.headers_mut() {
396            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
397        }
398
399        self.ext.with_custom(req)
400    }
401}
402
403impl<Ext, H> VerifyClient for Client<Ext, H>
404where
405    H: HttpClientExt,
406    Ext: DebugExt + Provider + WasmCompatSync,
407{
408    async fn verify(&self) -> Result<(), VerifyError> {
409        use http::StatusCode;
410
411        let req = self
412            .get(Ext::VERIFY_PATH)?
413            .body(http_client::NoBody)
414            .map_err(http_client::Error::from)?;
415
416        let response = self.http_client.send(req).await?;
417
418        match response.status() {
419            StatusCode::OK => Ok(()),
420            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
421                Err(VerifyError::InvalidAuthentication)
422            }
423            StatusCode::INTERNAL_SERVER_ERROR => {
424                let text = http_client::text(response).await?;
425                Err(VerifyError::ProviderError(text))
426            }
427            status if status.as_u16() == 529 => {
428                let text = http_client::text(response).await?;
429                Err(VerifyError::ProviderError(text))
430            }
431            _ => {
432                let status = response.status();
433
434                if status.is_success() {
435                    Ok(())
436                } else {
437                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
438                    Err(VerifyError::HttpError(http_client::Error::Instance(
439                        format!("Failed with '{status}': {text}").into(),
440                    )))
441                }
442            }
443        }
444    }
445}
446
/// Type-level marker meaning "no API key has been supplied yet".
///
/// A [`ClientBuilder`] in this state must be given a key via [`ClientBuilder::api_key`] before it
/// can be built.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
449
450// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
451#[derive(Clone)]
452pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
453    base_url: String,
454    api_key: ApiKey,
455    headers: HeaderMap,
456    http_client: Option<H>,
457    ext: Ext,
458}
459
460impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
461where
462    H: Default,
463    ExtBuilder: ProviderBuilder + Default,
464{
465    fn default() -> Self {
466        Self {
467            api_key: NeedsApiKey,
468            headers: Default::default(),
469            base_url: ExtBuilder::BASE_URL.into(),
470            http_client: None,
471            ext: Default::default(),
472        }
473    }
474}
475
476impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
477    /// Set the API key for this client. This *must* be done before the `build` method can be
478    /// called
479    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
480        ClientBuilder {
481            api_key: api_key.into(),
482            base_url: self.base_url,
483            headers: self.headers,
484            http_client: self.http_client,
485            ext: self.ext,
486        }
487    }
488}
489
490impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
491where
492    Ext: Clone,
493{
494    /// Owned map over the ext field
495    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
496    where
497        F: FnOnce(Ext) -> NewExt,
498    {
499        let ClientBuilder {
500            base_url,
501            api_key,
502            headers,
503            http_client,
504            ext,
505        } = self;
506
507        let new_ext = f(ext.clone());
508
509        ClientBuilder {
510            base_url,
511            api_key,
512            headers,
513            http_client,
514            ext: new_ext,
515        }
516    }
517
518    /// Set the base URL for this client
519    pub fn base_url<S>(self, base_url: S) -> Self
520    where
521        S: AsRef<str>,
522    {
523        Self {
524            base_url: base_url.as_ref().to_string(),
525            ..self
526        }
527    }
528
529    /// Set the HTTP backend used in this client
530    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
531        ClientBuilder {
532            http_client: Some(http_client),
533            base_url: self.base_url,
534            api_key: self.api_key,
535            headers: self.headers,
536            ext: self.ext,
537        }
538    }
539
540    /// Set the HTTP headers used in this client
541    pub fn http_headers(self, headers: HeaderMap) -> Self {
542        Self { headers, ..self }
543    }
544
545    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
546        &mut self.headers
547    }
548
549    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
550        &mut self.ext
551    }
552}
553
554impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
555    pub(crate) fn get_api_key(&self) -> &ApiKey {
556        &self.api_key
557    }
558}
559
560impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
561    pub fn ext(&self) -> &Ext {
562        &self.ext
563    }
564}
565
566impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
567where
568    ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
569    Ext: Provider<Builder = ExtBuilder>,
570    Key: ApiKey,
571    H: Default,
572{
573    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
574        let ext = self.ext.clone();
575
576        self = ext.finish(self)?;
577        let ext = Ext::build(&self)?;
578
579        let ClientBuilder {
580            http_client,
581            base_url,
582            mut headers,
583            api_key,
584            ..
585        } = self;
586
587        if let Some((k, v)) = api_key.into_header().transpose()? {
588            headers.insert(k, v);
589        }
590
591        let http_client = http_client.unwrap_or_default();
592
593        Ok(Client {
594            http_client,
595            base_url: Arc::from(base_url.as_str()),
596            headers: Arc::new(headers),
597            ext,
598        })
599    }
600}
601
602impl<M, Ext, H> CompletionClient for Client<Ext, H>
603where
604    Ext: Capabilities<H, Completion = Capable<M>>,
605    M: CompletionModel<Client = Self>,
606{
607    type CompletionModel = M;
608
609    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
610        M::make(self, model)
611    }
612}
613
614impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
615where
616    Ext: Capabilities<H, Embeddings = Capable<M>>,
617    M: EmbeddingModel<Client = Self>,
618{
619    type EmbeddingModel = M;
620
621    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
622        M::make(self, model, None)
623    }
624
625    fn embedding_model_with_ndims(
626        &self,
627        model: impl Into<String>,
628        ndims: usize,
629    ) -> Self::EmbeddingModel {
630        M::make(self, model, Some(ndims))
631    }
632}
633
634impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
635where
636    Ext: Capabilities<H, Transcription = Capable<M>>,
637    M: TranscriptionModel<Client = Self> + WasmCompatSend,
638{
639    type TranscriptionModel = M;
640
641    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
642        M::make(self, model)
643    }
644}
645
/// A client whose extension advertises `Capable<Model>` for image generation can create image
/// generation models of type `Model`.
#[cfg(feature = "image")]
impl<Model, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<Model>>,
    Model: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = Model;

    /// Create the image generation model named `model` bound to this client.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        Model::make(self, model)
    }
}
658
/// A client whose extension advertises `Capable<Model>` for audio generation can create audio
/// generation models of type `Model`.
#[cfg(feature = "audio")]
impl<Model, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<Model>>,
    Model: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = Model;

    /// Create the audio generation model named `model` bound to this client.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        Model::make(self, model)
    }
}