rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33    completion::CompletionModel,
34    embeddings::EmbeddingModel,
35    http_client::{
36        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
37    },
38    prelude::TranscriptionClient,
39    transcription::TranscriptionModel,
40    wasm_compat::{WasmCompatSend, WasmCompatSync},
41};
42
43#[derive(Debug, Error)]
44#[non_exhaustive]
45pub enum ClientBuilderError {
46    #[error("reqwest error: {0}")]
47    HttpError(
48        #[from]
49        #[source]
50        reqwest::Error,
51    ),
52    #[error("invalid property: {0}")]
53    InvalidProperty(&'static str),
54}
55
/// Abstracts over the ability to instantiate a client, either from the process environment or
/// from an explicitly supplied [`ProviderClient::Input`] value.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Create a client from the process's environment.
    ///
    /// # Panics
    ///
    /// Panics if the environment is improperly configured.
    fn from_env() -> Self;

    /// Create a client from an explicitly supplied input value.
    fn from_val(input: Self::Input) -> Self;
}
67
68use crate::completion::{GetTokenUsage, Usage};
69
70/// The final streaming response from a dynamic client.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct FinalCompletionResponse {
73    pub usage: Option<Usage>,
74}
75
76impl GetTokenUsage for FinalCompletionResponse {
77    fn token_usage(&self) -> Option<Usage> {
78        self.usage
79    }
80}
81
82/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
83/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
84pub trait ApiKey: Sized {
85    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
86        None
87    }
88}
89
90/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
91pub struct BearerAuth(String);
92
93impl ApiKey for BearerAuth {
94    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
95        Some(make_auth_header(self.0))
96    }
97}
98
99impl<S> From<S> for BearerAuth
100where
101    S: Into<String>,
102{
103    fn from(value: S) -> Self {
104        Self(value.into())
105    }
106}
107
/// A type that carries no data at all.
///
/// It provides `Option`-like behaviour at the type level. It is used to describe the absence of
/// a capability or a field (an API key, for instance).
#[derive(Debug, Default, Clone, Copy)]
pub struct Nothing;
112
113impl ApiKey for Nothing {}
114
115impl TryFrom<String> for Nothing {
116    type Error = &'static str;
117
118    fn try_from(_: String) -> Result<Self, Self::Error> {
119        Err(
120            "Tried to create a Nothing from a string - this should not happen, please file an issue",
121        )
122    }
123}
124
125#[derive(Clone)]
126pub struct Client<Ext = Nothing, H = reqwest::Client> {
127    base_url: Arc<str>,
128    headers: Arc<HeaderMap>,
129    http_client: H,
130    ext: Ext,
131}
132
/// Debug support for provider extensions.
///
/// [`DebugExt::fields`] yields extra `(name, value)` pairs. `Client`'s `Debug` implementation
/// appends these after its own fields. The default implementation yields nothing.
pub trait DebugExt: Debug {
    /// Extra fields to show in a `Client`'s debug output.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        None::<(&'static str, &dyn Debug)>.into_iter()
    }
}
138
139impl<Ext, H> std::fmt::Debug for Client<Ext, H>
140where
141    Ext: DebugExt,
142    H: std::fmt::Debug,
143{
144    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        let mut d = &mut f.debug_struct("Client");
146
147        d = d
148            .field("base_url", &self.base_url)
149            .field(
150                "headers",
151                &self
152                    .headers
153                    .iter()
154                    .filter_map(|(k, v)| {
155                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
156                            None
157                        } else {
158                            Some((k, v))
159                        }
160                    })
161                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
162            )
163            .field("http_client", &self.http_client);
164
165        self.ext
166            .fields()
167            .fold(d, |d, (name, field)| d.field(name, field))
168            .finish()
169    }
170}
171
/// The kind of connection a request is made over.
///
/// It is passed to [`Provider::build_uri`] so that providers can route streaming requests to
/// different endpoints. `Client::post` and `Client::get` use [`Transport::Http`]; `Client::post_sse`
/// uses [`Transport::Sse`].
///
/// This is a small field-less enum, so it is `Copy` and comparable. That lets providers match on
/// it or store it without ceremony.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// Plain request/response HTTP.
    Http,
    /// Server-sent events stream.
    Sse,
    /// Newline-delimited JSON stream (presumably; no caller in this module uses it).
    NdJson,
}
177
178/// An API provider extension, this abstracts over extensions which may be use in conjunction with
179/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
180/// auth, instantiating models
181pub trait Provider: Sized {
182    const VERIFY_PATH: &'static str;
183
184    type Builder: ProviderBuilder;
185
186    fn build<H>(
187        builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
188    ) -> http_client::Result<Self>;
189
190    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191        base_url.to_string() + "/" + path.trim_start_matches('/')
192    }
193
194    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
195        Ok(req)
196    }
197}
198
/// Type-level marker stating that a capability is supported.
///
/// `M` names the model type that provides the capability. Query support through the
/// [Capability] trait.
pub struct Capable<M>(PhantomData<M>);

/// Compile-time flag telling whether a capability marker denotes support.
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}

/// Every `Capable<M>` marker denotes a supported capability.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
209
210impl Capability for Nothing {
211    const CAPABLE: bool = false;
212}
213
214/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
215pub trait Capabilities<H = reqwest::Client> {
216    type Completion: Capability;
217    type Embeddings: Capability;
218    type Transcription: Capability;
219    #[cfg(feature = "image")]
220    type ImageGeneration: Capability;
221    #[cfg(feature = "audio")]
222    type AudioGeneration: Capability;
223}
224
225/// An API provider extension *builder*, this abstracts over provider-specific builders which are
226/// able to configure and produce a given provider's extension type
227///
228/// See [Provider]
229pub trait ProviderBuilder: Sized {
230    type Output: Provider;
231    type ApiKey;
232
233    const BASE_URL: &'static str;
234
235    /// This method can be used to customize the fields of `builder` before it is used to create
236    /// a client. For example, adding default headers
237    fn finish<H>(
238        &self,
239        builder: ClientBuilder<Self, Self::ApiKey, H>,
240    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
241        Ok(builder)
242    }
243}
244
245impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
246where
247    ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
248    Ext: Provider<Builder = ExtBuilder>,
249    H: Default + HttpClientExt,
250    Key: ApiKey,
251{
252    pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
253        Self::builder().api_key(api_key).build()
254    }
255}
256
257impl<Ext, H> Client<Ext, H> {
258    pub fn base_url(&self) -> &str {
259        &self.base_url
260    }
261
262    pub fn headers(&self) -> &HeaderMap {
263        &self.headers
264    }
265
266    pub fn ext(&self) -> &Ext {
267        &self.ext
268    }
269
270    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
271        Client {
272            base_url: self.base_url,
273            headers: self.headers,
274            http_client: self.http_client,
275            ext: new_ext,
276        }
277    }
278}
279
280impl<Ext, H> HttpClientExt for Client<Ext, H>
281where
282    H: HttpClientExt + 'static,
283    Ext: WasmCompatSend + WasmCompatSync + 'static,
284{
285    fn send<T, U>(
286        &self,
287        mut req: Request<T>,
288    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
289    where
290        T: Into<Bytes> + WasmCompatSend,
291        U: From<Bytes>,
292        U: WasmCompatSend + 'static,
293    {
294        req.headers_mut().insert(
295            http::header::CONTENT_TYPE,
296            http::HeaderValue::from_static("application/json"),
297        );
298
299        self.http_client.send(req)
300    }
301
302    fn send_multipart<U>(
303        &self,
304        req: Request<MultipartForm>,
305    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
306    where
307        U: From<Bytes>,
308        U: WasmCompatSend + 'static,
309    {
310        self.http_client.send_multipart(req)
311    }
312
313    fn send_streaming<T>(
314        &self,
315        mut req: Request<T>,
316    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
317    where
318        T: Into<Bytes>,
319    {
320        req.headers_mut().insert(
321            http::header::CONTENT_TYPE,
322            http::HeaderValue::from_static("application/json"),
323        );
324
325        self.http_client.send_streaming(req)
326    }
327}
328
329impl<Ext, Builder, H> Client<Ext, H>
330where
331    H: Default + HttpClientExt,
332    Ext: Provider<Builder = Builder>,
333    Builder: Default + ProviderBuilder,
334{
335    pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
336        ClientBuilder::default()
337    }
338}
339
340impl<Ext, H> Client<Ext, H>
341where
342    Ext: Provider,
343{
344    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
345    where
346        S: AsRef<str>,
347    {
348        let uri = self
349            .ext
350            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
351
352        let mut req = Request::post(uri);
353
354        if let Some(hs) = req.headers_mut() {
355            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
356        }
357
358        self.ext.with_custom(req)
359    }
360
361    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
362    where
363        S: AsRef<str>,
364    {
365        let uri = self
366            .ext
367            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
368
369        let mut req = Request::post(uri);
370
371        if let Some(hs) = req.headers_mut() {
372            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
373        }
374
375        self.ext.with_custom(req)
376    }
377
378    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
379    where
380        S: AsRef<str>,
381    {
382        let uri = self
383            .ext
384            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
385
386        let mut req = Request::get(uri);
387
388        if let Some(hs) = req.headers_mut() {
389            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
390        }
391
392        self.ext.with_custom(req)
393    }
394}
395
396impl<Ext, H> VerifyClient for Client<Ext, H>
397where
398    H: HttpClientExt,
399    Ext: DebugExt + Provider + WasmCompatSync,
400{
401    async fn verify(&self) -> Result<(), VerifyError> {
402        use http::StatusCode;
403
404        let req = self
405            .get(Ext::VERIFY_PATH)?
406            .body(http_client::NoBody)
407            .map_err(http_client::Error::from)?;
408
409        let response = self.http_client.send(req).await?;
410
411        match response.status() {
412            StatusCode::OK => Ok(()),
413            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
414                Err(VerifyError::InvalidAuthentication)
415            }
416            StatusCode::INTERNAL_SERVER_ERROR => {
417                let text = http_client::text(response).await?;
418                Err(VerifyError::ProviderError(text))
419            }
420            status if status.as_u16() == 529 => {
421                let text = http_client::text(response).await?;
422                Err(VerifyError::ProviderError(text))
423            }
424            _ => {
425                let status = response.status();
426
427                if status.is_success() {
428                    Ok(())
429                } else {
430                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
431                    Err(VerifyError::HttpError(http_client::Error::Instance(
432                        format!("Failed with '{status}': {text}").into(),
433                    )))
434                }
435            }
436        }
437    }
438}
439
440#[derive(Debug, Default, Clone, Copy)]
441pub struct NeedsApiKey;
442
443// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
444#[derive(Clone)]
445pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
446    base_url: String,
447    api_key: ApiKey,
448    headers: HeaderMap,
449    http_client: Option<H>,
450    ext: Ext,
451}
452
453impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
454where
455    H: Default,
456    ExtBuilder: ProviderBuilder + Default,
457{
458    fn default() -> Self {
459        Self {
460            api_key: NeedsApiKey,
461            headers: Default::default(),
462            base_url: ExtBuilder::BASE_URL.into(),
463            http_client: None,
464            ext: Default::default(),
465        }
466    }
467}
468
469impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
470    /// Set the API key for this client. This *must* be done before the `build` method can be
471    /// called
472    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
473        ClientBuilder {
474            api_key: api_key.into(),
475            base_url: self.base_url,
476            headers: self.headers,
477            http_client: self.http_client,
478            ext: self.ext,
479        }
480    }
481}
482
483impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
484where
485    Ext: Clone,
486{
487    /// Owned map over the ext field
488    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
489    where
490        F: FnOnce(Ext) -> NewExt,
491    {
492        let ClientBuilder {
493            base_url,
494            api_key,
495            headers,
496            http_client,
497            ext,
498        } = self;
499
500        let new_ext = f(ext.clone());
501
502        ClientBuilder {
503            base_url,
504            api_key,
505            headers,
506            http_client,
507            ext: new_ext,
508        }
509    }
510
511    /// Set the base URL for this client
512    pub fn base_url<S>(self, base_url: S) -> Self
513    where
514        S: AsRef<str>,
515    {
516        Self {
517            base_url: base_url.as_ref().to_string(),
518            ..self
519        }
520    }
521
522    /// Set the HTTP backend used in this client
523    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
524        ClientBuilder {
525            http_client: Some(http_client),
526            base_url: self.base_url,
527            api_key: self.api_key,
528            headers: self.headers,
529            ext: self.ext,
530        }
531    }
532
533    /// Set the HTTP headers used in this client
534    pub fn http_headers(self, headers: HeaderMap) -> Self {
535        Self { headers, ..self }
536    }
537
538    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
539        &mut self.headers
540    }
541
542    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
543        &mut self.ext
544    }
545}
546
547impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
548    pub(crate) fn get_api_key(&self) -> &ApiKey {
549        &self.api_key
550    }
551}
552
553impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
554    pub fn ext(&self) -> &Ext {
555        &self.ext
556    }
557}
558
559impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
560where
561    ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
562    Ext: Provider<Builder = ExtBuilder>,
563    Key: ApiKey,
564    H: Default,
565{
566    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
567        let ext = self.ext.clone();
568
569        self = ext.finish(self)?;
570        let ext = Ext::build(&self)?;
571
572        let ClientBuilder {
573            http_client,
574            base_url,
575            mut headers,
576            api_key,
577            ..
578        } = self;
579
580        if let Some((k, v)) = api_key.into_header().transpose()? {
581            headers.insert(k, v);
582        }
583
584        let http_client = http_client.unwrap_or_default();
585
586        Ok(Client {
587            http_client,
588            base_url: Arc::from(base_url.as_str()),
589            headers: Arc::new(headers),
590            ext,
591        })
592    }
593}
594
595impl<M, Ext, H> CompletionClient for Client<Ext, H>
596where
597    Ext: Capabilities<H, Completion = Capable<M>>,
598    M: CompletionModel<Client = Self>,
599{
600    type CompletionModel = M;
601
602    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
603        M::make(self, model)
604    }
605}
606
607impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
608where
609    Ext: Capabilities<H, Embeddings = Capable<M>>,
610    M: EmbeddingModel<Client = Self>,
611{
612    type EmbeddingModel = M;
613
614    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
615        M::make(self, model, None)
616    }
617
618    fn embedding_model_with_ndims(
619        &self,
620        model: impl Into<String>,
621        ndims: usize,
622    ) -> Self::EmbeddingModel {
623        M::make(self, model, Some(ndims))
624    }
625}
626
627impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
628where
629    Ext: Capabilities<H, Transcription = Capable<M>>,
630    M: TranscriptionModel<Client = Self> + WasmCompatSend,
631{
632    type TranscriptionModel = M;
633
634    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
635        M::make(self, model)
636    }
637}
638
/// Any client whose extension declares `ImageGeneration = Capable<M>` can create
/// image-generation models of type `M`. Only available with the `image` feature.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Create the image-generation model named `model`, bound to this client.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        M::make(self, model)
    }
}
651
/// Any client whose extension declares `AudioGeneration = Capable<M>` can create
/// audio-generation models of type `M`. Only available with the `audio` feature.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Create the audio-generation model named `model`, bound to this client.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        M::make(self, model)
    }
}