Skip to main content

rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod model_listing;
11pub mod transcription;
12pub mod verify;
13
14use bytes::Bytes;
15pub use completion::CompletionClient;
16pub use embeddings::EmbeddingsClient;
17use http::{HeaderMap, HeaderName, HeaderValue};
18pub use model_listing::{ModelLister, ModelListingClient};
19use serde::{Deserialize, Serialize};
20use std::{fmt::Debug, marker::PhantomData, sync::Arc};
21use thiserror::Error;
22pub use verify::{VerifyClient, VerifyError};
23
24#[cfg(feature = "image")]
25use crate::image_generation::ImageGenerationModel;
26#[cfg(feature = "image")]
27use image_generation::ImageGenerationClient;
28
29#[cfg(feature = "audio")]
30use crate::audio_generation::*;
31#[cfg(feature = "audio")]
32use audio_generation::*;
33
34use crate::{
35    completion::CompletionModel,
36    embeddings::EmbeddingModel,
37    http_client::{
38        self, Builder, HttpClientExt, LazyBody, MultipartForm, Request, Response, make_auth_header,
39    },
40    prelude::TranscriptionClient,
41    transcription::TranscriptionModel,
42    wasm_compat::{WasmCompatSend, WasmCompatSync},
43};
44
45#[derive(Debug, Error)]
46#[non_exhaustive]
47pub enum ClientBuilderError {
48    #[error("reqwest error: {0}")]
49    HttpError(
50        #[from]
51        #[source]
52        reqwest::Error,
53    ),
54    #[error("invalid property: {0}")]
55    InvalidProperty(&'static str),
56}
57
/// Abstracts over the ability to instantiate a client, either from environment variables or
/// from some provider-specific `Self::Input`.
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Create a client from the process's environment.
    ///
    /// # Panics
    ///
    /// Panics if the environment is improperly configured.
    fn from_env() -> Self;

    /// Create a client from an explicitly supplied `Self::Input`.
    fn from_val(input: Self::Input) -> Self;
}
69
70use crate::completion::{GetTokenUsage, Usage};
71
72/// The final streaming response from a dynamic client.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct FinalCompletionResponse {
75    pub usage: Option<Usage>,
76}
77
78impl GetTokenUsage for FinalCompletionResponse {
79    fn token_usage(&self) -> Option<Usage> {
80        self.usage
81    }
82}
83
84/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
85/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
86pub trait ApiKey: Sized {
87    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
88        None
89    }
90}
91
92/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
93pub struct BearerAuth(String);
94
95impl ApiKey for BearerAuth {
96    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
97        Some(make_auth_header(self.0))
98    }
99}
100
101impl<S> From<S> for BearerAuth
102where
103    S: Into<String>,
104{
105    fn from(value: S) -> Self {
106        Self(value.into())
107    }
108}
109
110/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
111/// the lack of a capability or field (an API key, for instance)
112#[derive(Debug, Default, Clone, Copy)]
113pub struct Nothing;
114
115impl ApiKey for Nothing {}
116
117impl TryFrom<String> for Nothing {
118    type Error = &'static str;
119
120    fn try_from(_: String) -> Result<Self, Self::Error> {
121        Err(
122            "Tried to create a Nothing from a string - this should not happen, please file an issue",
123        )
124    }
125}
126
127#[derive(Clone)]
128pub struct Client<Ext = Nothing, H = reqwest::Client> {
129    base_url: Arc<str>,
130    headers: Arc<HeaderMap>,
131    http_client: H,
132    ext: Ext,
133}
134
135pub trait DebugExt: Debug {
136    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
137        std::iter::empty()
138    }
139}
140
141impl<Ext, H> std::fmt::Debug for Client<Ext, H>
142where
143    Ext: DebugExt,
144    H: std::fmt::Debug,
145{
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        let mut d = &mut f.debug_struct("Client");
148
149        d = d
150            .field("base_url", &self.base_url)
151            .field(
152                "headers",
153                &self
154                    .headers
155                    .iter()
156                    .filter_map(|(k, v)| {
157                        if k == http::header::AUTHORIZATION || k.as_str().contains("api-key") {
158                            None
159                        } else {
160                            Some((k, v))
161                        }
162                    })
163                    .collect::<Vec<(&HeaderName, &HeaderValue)>>(),
164            )
165            .field("http_client", &self.http_client);
166
167        self.ext
168            .fields()
169            .fold(d, |d, (name, field)| d.field(name, field))
170            .finish()
171    }
172}
173
/// The kind of transport a request uses.
///
/// The transport is passed to [`Provider::build_uri`], so a provider's URI construction can
/// vary by transport; the default implementation ignores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// A plain request/response HTTP exchange.
    Http,
    /// A streaming response delivered as server-sent events.
    Sse,
    /// A streaming response delivered as newline-delimited JSON.
    NdJson,
}
179
180/// An API provider extension, this abstracts over extensions which may be used in conjunction with
181/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
182/// auth, instantiating models
183pub trait Provider: Sized {
184    /// The builder type that constructs this provider extension.
185    /// This associates extensions with their builders for type inference.
186    type Builder: ProviderBuilder;
187
188    const VERIFY_PATH: &'static str;
189
190    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
191        // Some providers (like Azure) have a blank base URL to allow users to input their own endpoints.
192        let base_url = if base_url.is_empty() {
193            base_url.to_string()
194        } else {
195            base_url.to_string() + "/"
196        };
197
198        base_url.to_string() + path.trim_start_matches('/')
199    }
200
201    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
202        Ok(req)
203    }
204}
205
/// Marks a provider capability as *present*, recording the model type `M` that implements it.
/// Runtime checks are available through the [`Capability`] trait.
pub struct Capable<M>(PhantomData<M>);

/// Type-level flag describing whether a provider supports a capability.
pub trait Capability {
    /// `true` when the capability is supported.
    const CAPABLE: bool;
}

// Every `Capable<M>` reports support, regardless of `M`.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
216
217impl Capability for Nothing {
218    const CAPABLE: bool = false;
219}
220
221/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
222pub trait Capabilities<H = reqwest::Client> {
223    type Completion: Capability;
224    type Embeddings: Capability;
225    type Transcription: Capability;
226    type ModelListing: Capability;
227    #[cfg(feature = "image")]
228    type ImageGeneration: Capability;
229    #[cfg(feature = "audio")]
230    type AudioGeneration: Capability;
231}
232
233/// An API provider extension *builder*, this abstracts over provider-specific builders which are
234/// able to configure and produce a given provider's extension type
235///
236/// See [Provider]
237pub trait ProviderBuilder: Sized + Default + Clone {
238    type Extension<H>: Provider
239    where
240        H: HttpClientExt;
241    type ApiKey: ApiKey;
242
243    const BASE_URL: &'static str;
244
245    /// Build the provider extension from the client builder configuration.
246    fn build<H>(
247        builder: &ClientBuilder<Self, Self::ApiKey, H>,
248    ) -> http_client::Result<Self::Extension<H>>
249    where
250        H: HttpClientExt;
251
252    /// This method can be used to customize the fields of `builder` before it is used to create
253    /// a client. For example, adding default headers
254    fn finish<H>(
255        &self,
256        builder: ClientBuilder<Self, Self::ApiKey, H>,
257    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
258        Ok(builder)
259    }
260}
261
262impl<Ext> Client<Ext, reqwest::Client>
263where
264    Ext: Provider,
265    Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
266{
267    pub fn new(
268        api_key: impl Into<<Ext::Builder as ProviderBuilder>::ApiKey>,
269    ) -> http_client::Result<Self> {
270        Self::builder().api_key(api_key).build()
271    }
272}
273
274impl<Ext, H> Client<Ext, H> {
275    pub fn base_url(&self) -> &str {
276        &self.base_url
277    }
278
279    pub fn headers(&self) -> &HeaderMap {
280        &self.headers
281    }
282
283    pub fn ext(&self) -> &Ext {
284        &self.ext
285    }
286
287    pub fn with_ext<NewExt>(self, new_ext: NewExt) -> Client<NewExt, H> {
288        Client {
289            base_url: self.base_url,
290            headers: self.headers,
291            http_client: self.http_client,
292            ext: new_ext,
293        }
294    }
295}
296
297impl<Ext, H> HttpClientExt for Client<Ext, H>
298where
299    H: HttpClientExt + 'static,
300    Ext: WasmCompatSend + WasmCompatSync + 'static,
301{
302    fn send<T, U>(
303        &self,
304        mut req: Request<T>,
305    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
306    where
307        T: Into<Bytes> + WasmCompatSend,
308        U: From<Bytes>,
309        U: WasmCompatSend + 'static,
310    {
311        req.headers_mut().insert(
312            http::header::CONTENT_TYPE,
313            http::HeaderValue::from_static("application/json"),
314        );
315
316        self.http_client.send(req)
317    }
318
319    fn send_multipart<U>(
320        &self,
321        req: Request<MultipartForm>,
322    ) -> impl Future<Output = http_client::Result<Response<LazyBody<U>>>> + WasmCompatSend + 'static
323    where
324        U: From<Bytes>,
325        U: WasmCompatSend + 'static,
326    {
327        self.http_client.send_multipart(req)
328    }
329
330    fn send_streaming<T>(
331        &self,
332        mut req: Request<T>,
333    ) -> impl Future<Output = http_client::Result<http_client::StreamingResponse>> + WasmCompatSend
334    where
335        T: Into<Bytes>,
336    {
337        req.headers_mut().insert(
338            http::header::CONTENT_TYPE,
339            http::HeaderValue::from_static("application/json"),
340        );
341
342        self.http_client.send_streaming(req)
343    }
344}
345
346impl<Ext> Client<Ext, reqwest::Client>
347where
348    Ext: Provider,
349    Ext::Builder: ProviderBuilder<Extension<reqwest::Client> = Ext> + Default,
350{
351    pub fn builder() -> ClientBuilder<Ext::Builder, NeedsApiKey, reqwest::Client> {
352        ClientBuilder {
353            api_key: NeedsApiKey,
354            headers: Default::default(),
355            base_url: <Ext::Builder as ProviderBuilder>::BASE_URL.into(),
356            http_client: None,
357            ext: Default::default(),
358        }
359    }
360}
361
362impl<Ext, H> Client<Ext, H>
363where
364    Ext: Provider,
365{
366    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
367    where
368        S: AsRef<str>,
369    {
370        let uri = self
371            .ext
372            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
373
374        let mut req = Request::post(uri);
375
376        if let Some(hs) = req.headers_mut() {
377            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
378        }
379
380        self.ext.with_custom(req)
381    }
382
383    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
384    where
385        S: AsRef<str>,
386    {
387        let uri = self
388            .ext
389            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
390
391        let mut req = Request::post(uri);
392
393        if let Some(hs) = req.headers_mut() {
394            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
395        }
396
397        self.ext.with_custom(req)
398    }
399
400    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
401    where
402        S: AsRef<str>,
403    {
404        let uri = self
405            .ext
406            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
407
408        let mut req = Request::get(uri);
409
410        if let Some(hs) = req.headers_mut() {
411            hs.extend(self.headers.iter().map(|(k, v)| (k.clone(), v.clone())));
412        }
413
414        self.ext.with_custom(req)
415    }
416}
417
418impl<Ext, H> VerifyClient for Client<Ext, H>
419where
420    H: HttpClientExt,
421    Ext: DebugExt + Provider + WasmCompatSync,
422{
423    async fn verify(&self) -> Result<(), VerifyError> {
424        use http::StatusCode;
425
426        let req = self
427            .get(Ext::VERIFY_PATH)?
428            .body(http_client::NoBody)
429            .map_err(http_client::Error::from)?;
430
431        let response = self.http_client.send(req).await?;
432
433        match response.status() {
434            StatusCode::OK => Ok(()),
435            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
436                Err(VerifyError::InvalidAuthentication)
437            }
438            StatusCode::INTERNAL_SERVER_ERROR => {
439                let text = http_client::text(response).await?;
440                Err(VerifyError::ProviderError(text))
441            }
442            status if status.as_u16() == 529 => {
443                let text = http_client::text(response).await?;
444                Err(VerifyError::ProviderError(text))
445            }
446            _ => {
447                let status = response.status();
448
449                if status.is_success() {
450                    Ok(())
451                } else {
452                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
453                    Err(VerifyError::HttpError(http_client::Error::Instance(
454                        format!("Failed with '{status}': {text}").into(),
455                    )))
456                }
457            }
458        }
459    }
460}
461
/// Type-state marker for a [`ClientBuilder`] that has not been given an API key yet.
///
/// While this marker occupies the key slot, `build` is unavailable: it requires the
/// provider's own API key type. Calling `ClientBuilder::api_key` replaces the marker with the
/// real key.
#[derive(Clone, Copy, Debug, Default)]
pub struct NeedsApiKey;
464
465// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
466#[derive(Clone)]
467pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
468    base_url: String,
469    api_key: ApiKey,
470    headers: HeaderMap,
471    http_client: Option<H>,
472    ext: Ext,
473}
474
475impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
476where
477    H: Default,
478    ExtBuilder: ProviderBuilder + Default,
479{
480    fn default() -> Self {
481        Self {
482            api_key: NeedsApiKey,
483            headers: Default::default(),
484            base_url: ExtBuilder::BASE_URL.into(),
485            http_client: None,
486            ext: Default::default(),
487        }
488    }
489}
490
491impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
492    /// Set the API key for this client. This *must* be done before the `build` method can be
493    /// called
494    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
495        ClientBuilder {
496            api_key: api_key.into(),
497            base_url: self.base_url,
498            headers: self.headers,
499            http_client: self.http_client,
500            ext: self.ext,
501        }
502    }
503}
504
505impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
506where
507    Ext: Clone,
508{
509    /// Owned map over the ext field
510    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
511    where
512        F: FnOnce(Ext) -> NewExt,
513    {
514        let ClientBuilder {
515            base_url,
516            api_key,
517            headers,
518            http_client,
519            ext,
520        } = self;
521
522        let new_ext = f(ext.clone());
523
524        ClientBuilder {
525            base_url,
526            api_key,
527            headers,
528            http_client,
529            ext: new_ext,
530        }
531    }
532
533    /// Set the base URL for this client
534    pub fn base_url<S>(self, base_url: S) -> Self
535    where
536        S: AsRef<str>,
537    {
538        Self {
539            base_url: base_url.as_ref().to_string(),
540            ..self
541        }
542    }
543
544    /// Set the HTTP backend used in this client
545    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
546        ClientBuilder {
547            http_client: Some(http_client),
548            base_url: self.base_url,
549            api_key: self.api_key,
550            headers: self.headers,
551            ext: self.ext,
552        }
553    }
554
555    /// Set the HTTP headers used in this client
556    pub fn http_headers(self, headers: HeaderMap) -> Self {
557        Self { headers, ..self }
558    }
559
560    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
561        &mut self.headers
562    }
563
564    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
565        &mut self.ext
566    }
567}
568
569impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
570    pub(crate) fn get_api_key(&self) -> &ApiKey {
571        &self.api_key
572    }
573}
574
575impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
576    pub fn ext(&self) -> &Ext {
577        &self.ext
578    }
579}
580
581impl<ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
582where
583    ExtBuilder: ProviderBuilder<ApiKey = Key>,
584    Key: ApiKey,
585    H: Default + HttpClientExt,
586{
587    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Extension<H>, H>> {
588        let ext_builder = self.ext.clone();
589
590        self = ext_builder.finish(self)?;
591        let ext = ExtBuilder::build(&self)?;
592
593        let ClientBuilder {
594            http_client,
595            base_url,
596            mut headers,
597            api_key,
598            ..
599        } = self;
600
601        if let Some((k, v)) = api_key.into_header().transpose()? {
602            headers.insert(k, v);
603        }
604
605        let http_client = http_client.unwrap_or_default();
606
607        Ok(Client {
608            http_client,
609            base_url: Arc::from(base_url.as_str()),
610            headers: Arc::new(headers),
611            ext,
612        })
613    }
614}
615
616impl<M, Ext, H> CompletionClient for Client<Ext, H>
617where
618    Ext: Capabilities<H, Completion = Capable<M>>,
619    M: CompletionModel<Client = Self>,
620{
621    type CompletionModel = M;
622
623    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
624        M::make(self, model)
625    }
626}
627
628impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
629where
630    Ext: Capabilities<H, Embeddings = Capable<M>>,
631    M: EmbeddingModel<Client = Self>,
632{
633    type EmbeddingModel = M;
634
635    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
636        M::make(self, model, None)
637    }
638
639    fn embedding_model_with_ndims(
640        &self,
641        model: impl Into<String>,
642        ndims: usize,
643    ) -> Self::EmbeddingModel {
644        M::make(self, model, Some(ndims))
645    }
646}
647
648impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
649where
650    Ext: Capabilities<H, Transcription = Capable<M>>,
651    M: TranscriptionModel<Client = Self> + WasmCompatSend,
652{
653    type TranscriptionModel = M;
654
655    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
656        M::make(self, model)
657    }
658}
659
/// A client whose extension declares `ImageGeneration = Capable<M>` hands out image
/// generation models of type `M` (only with the `image` feature).
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        <M as ImageGenerationModel>::make(self, model)
    }
}

/// A client whose extension declares `AudioGeneration = Capable<M>` hands out audio
/// generation models of type `M` (only with the `audio` feature).
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        <M as AudioGenerationModel>::make(self, model)
    }
}
685
686impl<M, Ext, H> ModelListingClient for Client<Ext, H>
687where
688    Ext: Capabilities<H, ModelListing = Capable<M>> + Clone,
689    M: ModelLister<H, Client = Self> + Send + Sync + Clone + 'static,
690    H: Send + Sync + Clone,
691{
692    fn list_models(
693        &self,
694    ) -> impl std::future::Future<
695        Output = Result<crate::model::ModelList, crate::model::ModelListingError>,
696    > + WasmCompatSend {
697        let lister = M::new(self.clone());
698        async move { lister.list_all().await }
699    }
700}
701
#[cfg(test)]
mod tests {
    use crate::providers::anthropic;

    /// Type-level test: the `Client::builder()` chain must infer the backing HTTP client type
    /// from `.http_client(...)` without any explicit type annotations.
    #[test]
    fn ensures_client_builder_no_annotation() {
        let backend = reqwest::Client::default();
        let built = anthropic::Client::builder()
            .http_client(backend)
            .api_key("Foo")
            .build();

        let _ = built.unwrap();
    }
}