rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData, sync::Arc};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33    completion::CompletionModel,
34    embeddings::EmbeddingModel,
35    http_client::{self, Builder, HttpClientExt, LazyBody, Request, Response, make_auth_header},
36    prelude::TranscriptionClient,
37    transcription::TranscriptionModel,
38    wasm_compat::{WasmCompatSend, WasmCompatSync},
39};
40
41#[derive(Debug, Error)]
42#[non_exhaustive]
43pub enum ClientBuilderError {
44    #[error("reqwest error: {0}")]
45    HttpError(
46        #[from]
47        #[source]
48        reqwest::Error,
49    ),
50    #[error("invalid property: {0}")]
51    InvalidProperty(&'static str),
52}
53
/// Abstracts over the ability to instantiate a client, either from environment variables or
/// from some provider-specific `Self::Input` value.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Create a client from the process's environment.
    ///
    /// # Panics
    /// Panics if the environment is improperly configured.
    fn from_env() -> Self;

    /// Create a client from an explicitly supplied input value.
    fn from_val(input: Self::Input) -> Self;
}
65
66use crate::completion::{GetTokenUsage, Usage};
67
68/// The final streaming response from a dynamic client.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct FinalCompletionResponse {
71    pub usage: Option<Usage>,
72}
73
74impl GetTokenUsage for FinalCompletionResponse {
75    fn token_usage(&self) -> Option<Usage> {
76        self.usage
77    }
78}
79
80/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
81/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
82pub trait ApiKey: Sized {
83    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
84        None
85    }
86}
87
88/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
89pub struct BearerAuth(String);
90
91impl ApiKey for BearerAuth {
92    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
93        Some(make_auth_header(self.0))
94    }
95}
96
97impl<S> From<S> for BearerAuth
98where
99    S: Into<String>,
100{
101    fn from(value: S) -> Self {
102        Self(value.into())
103    }
104}
105
106/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
107/// the lack of a capability or field (an API key, for instance)
108#[derive(Debug, Default, Clone, Copy)]
109pub struct Nothing;
110
111impl ApiKey for Nothing {}
112
113impl TryFrom<String> for Nothing {
114    type Error = &'static str;
115
116    fn try_from(_: String) -> Result<Self, Self::Error> {
117        Err(
118            "Tried to create a Nothing from a string - this should not happen, please file an issue",
119        )
120    }
121}
122
123#[derive(Clone)]
124pub struct Client<Ext = Nothing, H = reqwest::Client> {
125    base_url: Arc<str>,
126    headers: Arc<HeaderMap>,
127    http_client: H,
128    ext: Ext,
129}
130
/// `Debug` support for provider extensions.
///
/// An extension can override [`DebugExt::fields`] to contribute extra `(name, value)` pairs
/// to a [`Client`]'s `Debug` output.
pub trait DebugExt: Debug {
    /// Extra fields to show when debugging a client. The default contributes nothing.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
136
137impl<Ext, H> std::fmt::Debug for Client<Ext, H>
138where
139    Ext: DebugExt,
140    H: std::fmt::Debug,
141{
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        let mut d = &mut f.debug_struct("Client");
144
145        d = d
146            .field("base_url", &self.base_url)
147            .field(
148                "headers",
149                &self.headers.iter().filter_map(|(k, v)| {
150                    (!(k == http::header::AUTHORIZATION || k.as_str().contains("api-key")))
151                        .then_some((k, v))
152                }),
153            )
154            .field("http_client", &self.http_client);
155
156        self.ext
157            .fields()
158            .fold(d, |d, (name, field)| d.field(name, field))
159            .finish()
160    }
161}
162
/// The kind of exchange a request is built for. It is passed to `Provider::build_uri` so a
/// provider can route different kinds of requests to different endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// A plain request/response exchange (used by the `post` and `get` helpers).
    Http,
    /// A server-sent-events streaming request (used by `post_sse`).
    Sse,
    /// Presumably newline-delimited JSON streaming; not used by the helpers in this module.
    NdJson,
}
168
169/// An API provider extension, this abstracts over extensions which may be use in conjunction with
170/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
171/// auth, instantiating models
172pub trait Provider: Sized {
173    const VERIFY_PATH: &'static str;
174
175    type Builder: ProviderBuilder;
176
177    fn build<H>(
178        builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
179    ) -> http_client::Result<Self>;
180
181    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
182        base_url.to_string() + "/" + path.trim_start_matches('/')
183    }
184
185    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
186        Ok(req)
187    }
188}
189
/// Marker for a capability a provider *does* have. `M` is the model type implementing it;
/// the flag is exposed through the [`Capability`] trait.
pub struct Capable<M>(PhantomData<M>);

/// Type-level flag recording whether a provider supports a given capability.
pub trait Capability {
    /// `true` when the capability is available.
    const CAPABLE: bool;
}

/// Every `Capable<M>` reports the capability as available.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
200
201impl Capability for Nothing {
202    const CAPABLE: bool = false;
203}
204
205/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
206pub trait Capabilities<H = reqwest::Client> {
207    type Completion: Capability;
208    type Embeddings: Capability;
209    type Transcription: Capability;
210    #[cfg(feature = "image")]
211    type ImageGeneration: Capability;
212    #[cfg(feature = "audio")]
213    type AudioGeneration: Capability;
214}
215
216/// An API provider extension *builder*, this abstracts over provider-specific builders which are
217/// able to configure and produce a given provider's extension type
218///
219/// See [Provider]
220pub trait ProviderBuilder: Sized {
221    type Output: Provider;
222    type ApiKey;
223
224    const BASE_URL: &'static str;
225
226    /// This method can be used to customize the fields of `builder` before it is used to create
227    /// a client. For example, adding default headers
228    fn finish<H>(
229        &self,
230        builder: ClientBuilder<Self, Self::ApiKey, H>,
231    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
232        Ok(builder)
233    }
234}
235
236impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
237where
238    ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
239    Ext: Provider<Builder = ExtBuilder>,
240    H: Default + HttpClientExt,
241    Key: ApiKey,
242{
243    pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
244        Self::builder().api_key(api_key).build()
245    }
246}
247
248impl<Ext, H> Client<Ext, H> {
249    pub fn base_url(&self) -> &str {
250        &self.base_url
251    }
252
253    pub fn headers(&self) -> &HeaderMap {
254        &self.headers
255    }
256
257    pub fn ext(&self) -> &Ext {
258        &self.ext
259    }
260
261    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
262        Client {
263            base_url: self.base_url,
264            headers: self.headers,
265            http_client: self.http_client,
266            ext: new_ext,
267        }
268    }
269}
270
271impl<Ext, H> HttpClientExt for Client<Ext, H>
272where
273    H: HttpClientExt + 'static,
274    Ext: WasmCompatSend + WasmCompatSync + 'static,
275{
276    fn send<T, U>(
277        &self,
278        mut req: Request<T>,
279    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
280    where
281        T: Into<Bytes> + WasmCompatSend,
282        U: From<Bytes>,
283        U: WasmCompatSend + 'static,
284    {
285        req.headers_mut().insert(
286            http::header::CONTENT_TYPE,
287            http::HeaderValue::from_static("application/json"),
288        );
289
290        self.http_client.send(req)
291    }
292
293    fn send_multipart<U>(
294        &self,
295        req: Request<reqwest::multipart::Form>,
296    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
297    where
298        U: From<Bytes>,
299        U: WasmCompatSend + 'static,
300    {
301        self.http_client.send_multipart(req)
302    }
303
304    fn send_streaming<T>(
305        &self,
306        req: Request<T>,
307    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
308    where
309        T: Into<Bytes>,
310    {
311        self.http_client.send_streaming(req)
312    }
313}
314
315impl<Ext, Builder, H> Client<Ext, H>
316where
317    H: Default + HttpClientExt,
318    Ext: Provider<Builder = Builder>,
319    Builder: Default + ProviderBuilder,
320{
321    pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
322        ClientBuilder::default()
323    }
324}
325
326impl<Ext, H> Client<Ext, H>
327where
328    Ext: Provider,
329{
330    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
331    where
332        S: AsRef<str>,
333    {
334        let uri = self
335            .ext
336            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
337
338        let mut req = Request::post(uri);
339
340        if let Some(hs) = req.headers_mut() {
341            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
342        }
343
344        self.ext.with_custom(req)
345    }
346
347    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
348    where
349        S: AsRef<str>,
350    {
351        let uri = self
352            .ext
353            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
354
355        let mut req = Request::post(uri);
356
357        if let Some(hs) = req.headers_mut() {
358            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
359        }
360
361        self.ext.with_custom(req)
362    }
363
364    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
365    where
366        S: AsRef<str>,
367    {
368        let uri = self
369            .ext
370            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
371
372        self.ext.with_custom(Request::get(uri))
373    }
374}
375
376impl<Ext, H> VerifyClient for Client<Ext, H>
377where
378    H: HttpClientExt,
379    Ext: DebugExt + Provider + WasmCompatSync,
380{
381    async fn verify(&self) -> Result<(), VerifyError> {
382        use http::StatusCode;
383
384        let req = self
385            .get(Ext::VERIFY_PATH)?
386            .body(http_client::NoBody)
387            .map_err(http_client::Error::from)?;
388
389        let response = self.http_client.send(req).await?;
390
391        match response.status() {
392            StatusCode::OK => Ok(()),
393            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
394                Err(VerifyError::InvalidAuthentication)
395            }
396            StatusCode::INTERNAL_SERVER_ERROR => {
397                let text = http_client::text(response).await?;
398                Err(VerifyError::ProviderError(text))
399            }
400            status if status.as_u16() == 529 => {
401                let text = http_client::text(response).await?;
402                Err(VerifyError::ProviderError(text))
403            }
404            _ => {
405                let status = response.status();
406
407                if status.is_success() {
408                    Ok(())
409                } else {
410                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
411                    Err(VerifyError::HttpError(http_client::Error::Instance(
412                        format!("Failed with '{status}': {text}").into(),
413                    )))
414                }
415            }
416        }
417    }
418}
419
420#[derive(Debug, Default, Clone, Copy)]
421pub struct NeedsApiKey;
422
423// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
424#[derive(Clone)]
425pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
426    base_url: String,
427    api_key: ApiKey,
428    headers: HeaderMap,
429    http_client: Option<H>,
430    ext: Ext,
431}
432
433impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
434where
435    H: Default,
436    ExtBuilder: ProviderBuilder + Default,
437{
438    fn default() -> Self {
439        Self {
440            api_key: NeedsApiKey,
441            headers: Default::default(),
442            base_url: ExtBuilder::BASE_URL.into(),
443            http_client: None,
444            ext: Default::default(),
445        }
446    }
447}
448
449impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
450    /// Set the API key for this client. This *must* be done before the `build` method can be
451    /// called
452    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
453        ClientBuilder {
454            api_key: api_key.into(),
455            base_url: self.base_url,
456            headers: self.headers,
457            http_client: self.http_client,
458            ext: self.ext,
459        }
460    }
461}
462
463impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
464where
465    Ext: Clone,
466{
467    /// Owned map over the ext field
468    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
469    where
470        F: FnOnce(Ext) -> NewExt,
471    {
472        let ClientBuilder {
473            base_url,
474            api_key,
475            headers,
476            http_client,
477            ext,
478        } = self;
479
480        let new_ext = f(ext.clone());
481
482        ClientBuilder {
483            base_url,
484            api_key,
485            headers,
486            http_client,
487            ext: new_ext,
488        }
489    }
490
491    /// Set the base URL for this client
492    pub fn base_url<S>(self, base_url: S) -> Self
493    where
494        S: AsRef<str>,
495    {
496        Self {
497            base_url: base_url.as_ref().to_string(),
498            ..self
499        }
500    }
501
502    /// Set the HTTP backend used in this client
503    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
504        ClientBuilder {
505            http_client: Some(http_client),
506            base_url: self.base_url,
507            api_key: self.api_key,
508            headers: self.headers,
509            ext: self.ext,
510        }
511    }
512
513    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
514        &mut self.headers
515    }
516
517    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
518        &mut self.ext
519    }
520}
521
522impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
523    pub(crate) fn get_api_key(&self) -> &ApiKey {
524        &self.api_key
525    }
526}
527
528impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
529    pub fn ext(&self) -> &Ext {
530        &self.ext
531    }
532}
533
534impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
535where
536    ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
537    Ext: Provider<Builder = ExtBuilder>,
538    Key: ApiKey,
539    H: Default,
540{
541    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
542        let ext = self.ext.clone();
543
544        self = ext.finish(self)?;
545        let ext = Ext::build(&self)?;
546
547        let ClientBuilder {
548            http_client,
549            base_url,
550            mut headers,
551            api_key,
552            ..
553        } = self;
554
555        if let Some((k, v)) = api_key.into_header().transpose()? {
556            headers.insert(k, v);
557        }
558
559        let http_client = http_client.unwrap_or_default();
560
561        Ok(Client {
562            http_client,
563            base_url: Arc::from(base_url.as_str()),
564            headers: Arc::new(headers),
565            ext,
566        })
567    }
568}
569
570impl<M, Ext, H> CompletionClient for Client<Ext, H>
571where
572    Ext: Capabilities<H, Completion = Capable<M>>,
573    M: CompletionModel<Client = Self>,
574{
575    type CompletionModel = M;
576
577    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
578        M::make(self, model)
579    }
580}
581
582impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
583where
584    Ext: Capabilities<H, Embeddings = Capable<M>>,
585    M: EmbeddingModel<Client = Self>,
586{
587    type EmbeddingModel = M;
588
589    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
590        M::make(self, model, None)
591    }
592
593    fn embedding_model_with_ndims(
594        &self,
595        model: impl Into<String>,
596        ndims: usize,
597    ) -> Self::EmbeddingModel {
598        M::make(self, model, Some(ndims))
599    }
600}
601
602impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
603where
604    Ext: Capabilities<H, Transcription = Capable<M>>,
605    M: TranscriptionModel<Client = Self> + WasmCompatSend,
606{
607    type TranscriptionModel = M;
608
609    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
610        M::make(self, model)
611    }
612}
613
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates the provider's image generation model for `model`.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        <M as ImageGenerationModel>::make(self, model)
    }
}
626
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates the provider's audio generation model for `model`.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        <M as AudioGenerationModel>::make(self, model)
    }
}