rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33    completion::CompletionModel,
34    embeddings::EmbeddingModel,
35    http_client::{
36        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
37    },
38    prelude::TranscriptionClient,
39    transcription::TranscriptionModel,
40    wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46    #[error("reqwest error: {0}")]
47    HttpError(
48        #[from]
49        #[source]
50        reqwest::Error,
51    ),
52    #[error("invalid property: {0}")]
53    InvalidProperty(&'static str),
54}
55
/// Abstracts over the ability to instantiate a client, either via environment variables or some
/// `Self::Input`
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`] to construct the client explicitly.
    type Input;

    /// Create a client from the process's environment.
    /// Panics if an environment is improperly configured.
    fn from_env() -> Self;

    /// Create a client from an explicit `Self::Input` value instead of the environment.
    fn from_val(input: Self::Input) -> Self;
}
67
68use crate::completion::{GetTokenUsage, Usage};
69
/// The final streaming response from a dynamic client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinalCompletionResponse {
    /// Token usage for the response, or `None` when no usage was recorded.
    pub usage: Option<Usage>,
}

impl GetTokenUsage for FinalCompletionResponse {
    /// Returns the stored `usage` field as-is.
    fn token_usage(&self) -> Option<Usage> {
        self.usage
    }
}
81
/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
pub trait ApiKey: Sized {
    /// Converts the key into a header name/value pair.
    ///
    /// The default implementation returns `None`, meaning no header is produced from the key.
    /// `Some(Err(_))` reports that building the header failed; `ClientBuilder::build`
    /// propagates that error.
    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
        None
    }
}
89
/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
pub struct BearerAuth(String);

impl ApiKey for BearerAuth {
    /// Always yields `Some`: the header is whatever `make_auth_header` builds from the
    /// wrapped string, including any error it reports.
    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
        Some(make_auth_header(self.0))
    }
}
98
99impl<S> From<S> for BearerAuth
100where
101    S: Into<String>,
102{
103    fn from(value: S) -> Self {
104        Self(value.into())
105    }
106}
107
/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
/// the lack of a capability or field (an API key, for instance)
#[derive(Debug, Default, Clone, Copy)]
pub struct Nothing;

// Uses the trait's default `into_header`, which returns `None`: no header is produced.
impl ApiKey for Nothing {}
114
115impl TryFrom<String> for Nothing {
116    type Error = &'static str;
117
118    fn try_from(_: String) -> Result<Self, Self::Error> {
119        Err(
120            "Tried to create a Nothing from a string - this should not happen, please file an issue",
121        )
122    }
123}
124
/// A provider client: an HTTP backend `H` combined with a provider extension `Ext`.
///
/// Instances are produced by `ClientBuilder::build`.
#[derive(Clone)]
pub struct Client<Ext = Nothing, H = reqwest::Client> {
    /// Base URL handed to `Provider::build_uri` when request URIs are constructed.
    base_url: Arc<str>,
    /// Default headers. `post` and `post_sse` copy these onto each request they build; the
    /// API key header is stored here when the key type produces one.
    headers: Arc<HeaderMap>,
    /// The HTTP backend requests are sent through.
    http_client: H,
    /// The provider extension.
    ext: Ext,
}
132
/// Lets a provider extension add its own named fields to the `Debug` output of the client
/// that holds it.
pub trait DebugExt: Debug {
    /// Extra `(name, value)` pairs to append to the client's debug representation.
    ///
    /// The default contributes no fields.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty::<(&'static str, &dyn Debug)>()
    }
}
138
139impl<Ext, H> std::fmt::Debug for Client<Ext, H>
140where
141    Ext: DebugExt,
142    H: std::fmt::Debug,
143{
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        let mut d = &mut f.debug_struct("Client");
146
147        d = d
148            .field("base_url", &self.base_url)
149            .field(
150                "headers",
151                &self
152                    .headers
153                    .iter()
154                    .filter_map(|(k, v)| {
155                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
156                            None
157                        } else {
158                            Some((k, v))
159                        }
160                    })
161                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
162            )
163            .field("http_client", &self.http_client);
164
165        self.ext
166            .fields()
167            .fold(d, |d, (name, field)| d.field(name, field))
168            .finish()
169    }
170}
171
/// The transport a request is being built for. It is passed to `Provider::build_uri` so a
/// provider can take it into account when constructing the URI.
pub enum Transport {
    /// Used by `Client::post` and `Client::get`.
    Http,
    /// Used by `Client::post_sse`.
    Sse,
    /// Not used in this file; presumably newline-delimited JSON streaming (TODO confirm).
    NdJson,
}
177
178/// An API provider extension, this abstracts over extensions which may be use in conjunction with
179/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
180/// auth, instantiating models
181pub trait Provider: Sized {
182    const VERIFY_PATH: &'static str;
183
184    type Builder: ProviderBuilder;
185
186    fn build<H>(
187        builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
188    ) -> http_client::Result<Self>;
189
190    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191        base_url.to_string() + "/" + path.trim_start_matches('/')
192    }
193
194    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
195        Ok(req)
196    }
197}
198
/// A wrapper type providing runtime checks on a provider's capabilities via the [Capability] trait
///
/// `M` is only carried at the type level (`PhantomData`); no value of `M` is stored.
pub struct Capable<M>(PhantomData<M>);

/// Reports, as an associated constant, whether a capability is available.
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}

// Any `Capable<M>` marks the capability as present.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}

// `Nothing` marks the capability as absent.
impl Capability for Nothing {
    const CAPABLE: bool = false;
}
213
/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
///
/// Each associated type must implement [Capability]; in this file that is `Capable<M>`
/// (capability present, with `M` as the model type) or `Nothing` (capability absent).
pub trait Capabilities<H = reqwest::Client> {
    /// Text completion capability.
    type Completion: Capability;
    /// Embeddings capability.
    type Embeddings: Capability;
    /// Audio transcription capability.
    type Transcription: Capability;
    /// Image generation capability (only with the `image` feature).
    #[cfg(feature = "image")]
    type ImageGeneration: Capability;
    /// Audio generation capability (only with the `audio` feature).
    #[cfg(feature = "audio")]
    type AudioGeneration: Capability;
}
224
/// An API provider extension *builder*, this abstracts over provider-specific builders which are
/// able to configure and produce a given provider's extension type
///
/// See [Provider]
pub trait ProviderBuilder: Sized {
    /// The provider extension this builder produces.
    type Output: Provider;
    /// The API key type stored in the client builder for this provider.
    type ApiKey;

    /// The base URL a default-constructed `ClientBuilder` starts with.
    const BASE_URL: &'static str;

    /// This method can be used to customize the fields of `builder` before it is used to create
    /// a client. For example, adding default headers
    ///
    /// The default implementation returns `builder` unchanged.
    fn finish<H>(
        &self,
        builder: ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
        Ok(builder)
    }
}
244
245impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
246where
247    ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
248    Ext: Provider<Builder = ExtBuilder>,
249    H: Default + HttpClientExt,
250    Key: ApiKey,
251{
252    pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
253        Self::builder().api_key(api_key).build()
254    }
255}
256
257impl<Ext, H> Client<Ext, H> {
258    pub fn base_url(&self) -> &str {
259        &self.base_url
260    }
261
262    pub fn headers(&self) -> &HeaderMap {
263        &self.headers
264    }
265
266    pub fn ext(&self) -> &Ext {
267        &self.ext
268    }
269
270    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
271        Client {
272            base_url: self.base_url,
273            headers: self.headers,
274            http_client: self.http_client,
275            ext: new_ext,
276        }
277    }
278}
279
280impl<Ext, H> HttpClientExt for Client<Ext, H>
281where
282    H: HttpClientExt + 'static,
283    Ext: WasmCompatSend + WasmCompatSync + 'static,
284{
285    fn send<T, U>(
286        &self,
287        mut req: Request<T>,
288    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
289    where
290        T: Into<Bytes> + WasmCompatSend,
291        U: From<Bytes>,
292        U: WasmCompatSend + 'static,
293    {
294        req.headers_mut().insert(
295            http::header::CONTENT_TYPE,
296            http::HeaderValue::from_static("application/json"),
297        );
298
299        self.http_client.send(req)
300    }
301
302    fn send_multipart<U>(
303        &self,
304        req: Request<MultipartForm>,
305    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
306    where
307        U: From<Bytes>,
308        U: WasmCompatSend + 'static,
309    {
310        self.http_client.send_multipart(req)
311    }
312
313    fn send_streaming<T>(
314        &self,
315        req: Request<T>,
316    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
317    where
318        T: Into<Bytes>,
319    {
320        self.http_client.send_streaming(req)
321    }
322}
323
324impl<Ext, Builder, H> Client<Ext, H>
325where
326    H: Default + HttpClientExt,
327    Ext: Provider<Builder = Builder>,
328    Builder: Default + ProviderBuilder,
329{
330    pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
331        ClientBuilder::default()
332    }
333}
334
335impl<Ext, H> Client<Ext, H>
336where
337    Ext: Provider,
338{
339    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
340    where
341        S: AsRef<str>,
342    {
343        let uri = self
344            .ext
345            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
346
347        let mut req = Request::post(uri);
348
349        if let Some(hs) = req.headers_mut() {
350            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
351        }
352
353        self.ext.with_custom(req)
354    }
355
356    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
357    where
358        S: AsRef<str>,
359    {
360        let uri = self
361            .ext
362            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
363
364        let mut req = Request::post(uri);
365
366        if let Some(hs) = req.headers_mut() {
367            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
368        }
369
370        self.ext.with_custom(req)
371    }
372
373    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
374    where
375        S: AsRef<str>,
376    {
377        let uri = self
378            .ext
379            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
380
381        self.ext.with_custom(Request::get(uri))
382    }
383}
384
385impl<Ext, H> VerifyClient for Client<Ext, H>
386where
387    H: HttpClientExt,
388    Ext: DebugExt + Provider + WasmCompatSync,
389{
390    async fn verify(&self) -> Result<(), VerifyError> {
391        use http::StatusCode;
392
393        let req = self
394            .get(Ext::VERIFY_PATH)?
395            .body(http_client::NoBody)
396            .map_err(http_client::Error::from)?;
397
398        let response = self.http_client.send(req).await?;
399
400        match response.status() {
401            StatusCode::OK => Ok(()),
402            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
403                Err(VerifyError::InvalidAuthentication)
404            }
405            StatusCode::INTERNAL_SERVER_ERROR => {
406                let text = http_client::text(response).await?;
407                Err(VerifyError::ProviderError(text))
408            }
409            status if status.as_u16() == 529 => {
410                let text = http_client::text(response).await?;
411                Err(VerifyError::ProviderError(text))
412            }
413            _ => {
414                let status = response.status();
415
416                if status.is_success() {
417                    Ok(())
418                } else {
419                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
420                    Err(VerifyError::HttpError(http_client::Error::Instance(
421                        format!("Failed with '{status}': {text}").into(),
422                    )))
423                }
424            }
425        }
426    }
427}
428
/// Placeholder key type of a `ClientBuilder` whose API key has not been set yet; it is what
/// `ClientBuilder::default()` stores and what `ClientBuilder::api_key` replaces.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
431
/// Builder for a [Client]: collects the base URL, API key, default headers, HTTP backend and
/// provider extension builder.
// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
#[derive(Clone)]
pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
    /// Base URL for the client; starts as the provider builder's `BASE_URL`.
    base_url: String,
    /// The API key, or `NeedsApiKey` while none has been supplied.
    api_key: ApiKey,
    /// Default headers for the client being built.
    headers: HeaderMap,
    /// Explicit HTTP backend; when `None`, `build` falls back to `H::default()`.
    http_client: Option<H>,
    /// The provider extension builder.
    ext: Ext,
}
441
442impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
443where
444    H: Default,
445    ExtBuilder: ProviderBuilder + Default,
446{
447    fn default() -> Self {
448        Self {
449            api_key: NeedsApiKey,
450            headers: Default::default(),
451            base_url: ExtBuilder::BASE_URL.into(),
452            http_client: None,
453            ext: Default::default(),
454        }
455    }
456}
457
458impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
459    /// Set the API key for this client. This *must* be done before the `build` method can be
460    /// called
461    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
462        ClientBuilder {
463            api_key: api_key.into(),
464            base_url: self.base_url,
465            headers: self.headers,
466            http_client: self.http_client,
467            ext: self.ext,
468        }
469    }
470}
471
472impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
473where
474    Ext: Clone,
475{
476    /// Owned map over the ext field
477    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
478    where
479        F: FnOnce(Ext) -> NewExt,
480    {
481        let ClientBuilder {
482            base_url,
483            api_key,
484            headers,
485            http_client,
486            ext,
487        } = self;
488
489        let new_ext = f(ext.clone());
490
491        ClientBuilder {
492            base_url,
493            api_key,
494            headers,
495            http_client,
496            ext: new_ext,
497        }
498    }
499
500    /// Set the base URL for this client
501    pub fn base_url<S>(self, base_url: S) -> Self
502    where
503        S: AsRef<str>,
504    {
505        Self {
506            base_url: base_url.as_ref().to_string(),
507            ..self
508        }
509    }
510
511    /// Set the HTTP backend used in this client
512    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
513        ClientBuilder {
514            http_client: Some(http_client),
515            base_url: self.base_url,
516            api_key: self.api_key,
517            headers: self.headers,
518            ext: self.ext,
519        }
520    }
521
522    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
523        &mut self.headers
524    }
525
526    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
527        &mut self.ext
528    }
529}
530
impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
    /// Borrows the API key currently stored in the builder.
    pub(crate) fn get_api_key(&self) -> &ApiKey {
        &self.api_key
    }
}

impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
    /// Borrows the provider extension builder.
    pub fn ext(&self) -> &Ext {
        &self.ext
    }
}
542
543impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
544where
545    ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
546    Ext: Provider<Builder = ExtBuilder>,
547    Key: ApiKey,
548    H: Default,
549{
550    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
551        let ext = self.ext.clone();
552
553        self = ext.finish(self)?;
554        let ext = Ext::build(&self)?;
555
556        let ClientBuilder {
557            http_client,
558            base_url,
559            mut headers,
560            api_key,
561            ..
562        } = self;
563
564        if let Some((k, v)) = api_key.into_header().transpose()? {
565            headers.insert(k, v);
566        }
567
568        let http_client = http_client.unwrap_or_default();
569
570        Ok(Client {
571            http_client,
572            base_url: Arc::from(base_url.as_str()),
573            headers: Arc::new(headers),
574            ext,
575        })
576    }
577}
578
/// Implements [`CompletionClient`] for every client whose provider extension declares
/// `Completion = Capable<M>`; `M` becomes the completion model type.
impl<M, Ext, H> CompletionClient for Client<Ext, H>
where
    Ext: Capabilities<H, Completion = Capable<M>>,
    M: CompletionModel<Client = Self>,
{
    type CompletionModel = M;

    /// Creates the completion model by delegating to `M::make` with this client and `model`.
    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
        M::make(self, model)
    }
}
590
/// Implements [`EmbeddingsClient`] for every client whose provider extension declares
/// `Embeddings = Capable<M>`; `M` becomes the embedding model type.
impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
where
    Ext: Capabilities<H, Embeddings = Capable<M>>,
    M: EmbeddingModel<Client = Self>,
{
    type EmbeddingModel = M;

    /// Creates the embedding model via `M::make`, passing `None` for the dimension count.
    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
        M::make(self, model, None)
    }

    /// Creates the embedding model via `M::make`, passing `Some(ndims)` as the dimension count.
    fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        ndims: usize,
    ) -> Self::EmbeddingModel {
        M::make(self, model, Some(ndims))
    }
}
610
/// Implements [`TranscriptionClient`] for every client whose provider extension declares
/// `Transcription = Capable<M>`; `M` becomes the transcription model type.
impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
where
    Ext: Capabilities<H, Transcription = Capable<M>>,
    M: TranscriptionModel<Client = Self> + WasmCompatSend,
{
    type TranscriptionModel = M;

    /// Creates the transcription model by delegating to `M::make` with this client and `model`.
    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
        M::make(self, model)
    }
}
622
/// Implements [`ImageGenerationClient`] (with the `image` feature) for every client whose
/// provider extension declares `ImageGeneration = Capable<M>`.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates the image generation model by delegating to `M::make` with this client and
    /// `model`.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        M::make(self, model)
    }
}
635
/// Implements [`AudioGenerationClient`] (with the `audio` feature) for every client whose
/// provider extension declares `AudioGeneration = Capable<M>`.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates the audio generation model by delegating to `M::make` with this client and
    /// `model`.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        M::make(self, model)
    }
}