rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13use bytes::Bytes;
14pub use completion::CompletionClient;
15pub use embeddings::EmbeddingsClient;
16use http::{HeaderMap, HeaderName, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, marker::PhantomData};
19use thiserror::Error;
20pub use verify::{VerifyClient, VerifyError};
21
22#[cfg(feature = "image")]
23use crate::image_generation::ImageGenerationModel;
24#[cfg(feature = "image")]
25use image_generation::ImageGenerationClient;
26
27#[cfg(feature = "audio")]
28use crate::audio_generation::*;
29#[cfg(feature = "audio")]
30use audio_generation::*;
31
32use crate::{
33    completion::CompletionModel,
34    embeddings::EmbeddingModel,
35    http_client::{self, Builder, HttpClientExt, LazyBody, Request, Response, make_auth_header},
36    prelude::TranscriptionClient,
37    transcription::TranscriptionModel,
38    wasm_compat::{WasmCompatSend, WasmCompatSync},
39};
40
/// Error type for client-builder failures.
///
/// The enum is `#[non_exhaustive]`, so code matching on it from outside this crate needs a
/// wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ClientBuilderError {
    /// Wraps a [`reqwest::Error`]; `#[from]` lets `?` convert one into this variant.
    #[error("reqwest error: {0}")]
    HttpError(
        #[from]
        #[source]
        reqwest::Error,
    ),
    /// A property was invalid; the payload is a static string describing it.
    #[error("invalid property: {0}")]
    InvalidProperty(&'static str),
}
53
/// Abstracts over the ability to instantiate a client, either via environment variables or some
/// `Self::Input`
pub trait ProviderClient {
    /// The value accepted by [`ProviderClient::from_val`].
    type Input;

    /// Create a client from the process's environment.
    /// Panics if an environment is improperly configured.
    fn from_env() -> Self;

    /// Create a client from an explicit `Self::Input` value instead of the environment.
    fn from_val(input: Self::Input) -> Self;
}
65
66use crate::completion::{GetTokenUsage, Usage};
67
/// The final streaming response from a dynamic client.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FinalCompletionResponse {
    /// Token usage carried by the response, if any; this is what
    /// [`GetTokenUsage::token_usage`] returns for this type.
    pub usage: Option<Usage>,
}
73
impl GetTokenUsage for FinalCompletionResponse {
    /// Returns the stored `usage` field unchanged (`None` when no usage was recorded).
    fn token_usage(&self) -> Option<Usage> {
        self.usage
    }
}
79
/// A trait for API keys. This determines whether the key is inserted into a [Client]'s default
/// headers (in the `Some` case) or handled by a given provider extension (in the `None` case)
pub trait ApiKey: Sized {
    /// Converts the key into a header name/value pair.
    ///
    /// The default implementation returns `None`, meaning no header is produced. An
    /// implementation returning `Some(Err(_))` signals that building the header failed;
    /// `ClientBuilder::build` propagates that error.
    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
        None
    }
}
87
/// An API key which will be inserted into a `Client`'s default headers as a bearer auth token
///
/// The wrapped `String` is private; values are created through the `From<S: Into<String>>`
/// conversion. The type deliberately derives nothing here, so the key cannot be printed via
/// `Debug`.
pub struct BearerAuth(String);
90
impl ApiKey for BearerAuth {
    /// Always returns `Some`: the wrapped string is handed to `make_auth_header`, whose result
    /// (header pair or error) is passed through unchanged.
    fn into_header(self) -> Option<http_client::Result<(HeaderName, HeaderValue)>> {
        Some(make_auth_header(self.0))
    }
}
96
97impl<S> From<S> for BearerAuth
98where
99    S: Into<String>,
100{
101    fn from(value: S) -> Self {
102        Self(value.into())
103    }
104}
105
/// A type containing nothing at all. For `Option`-like behavior on the type level, i.e. to describe
/// the lack of an API key in a `ClientBuilder`
///
/// It is a zero-sized unit struct and is also the default `Ext` parameter of [`Client`].
#[derive(Debug, Default, Clone, Copy)]
pub struct Nothing;
110
// Relies on the default `ApiKey::into_header`, which returns `None`: no header is produced.
impl ApiKey for Nothing {}
112
113impl TryFrom<String> for Nothing {
114    type Error = &'static str;
115
116    fn try_from(_: String) -> Result<Self, Self::Error> {
117        Err(
118            "Tried to create a Nothing from a string - this should not happen, please file an issue",
119        )
120    }
121}
122
// So that Ollama can ignore auth
// NOTE(review): the comment above presumably refers to the `Nothing` API key / default
// `Ext = Nothing` parameter — confirm and move it next to the item it describes.
/// A provider client: a base URL, default headers, an HTTP backend `H` and a provider
/// extension `Ext`.
#[derive(Clone)]
pub struct Client<Ext = Nothing, H = reqwest::Client> {
    /// Base URL handed to `Provider::build_uri` when request URIs are built.
    base_url: String,
    /// Default headers; `ClientBuilder::build` inserts the API-key header here when the key
    /// produces one.
    headers: HeaderMap,
    /// The HTTP backend used to send requests.
    http_client: H,
    /// The provider extension, consulted for URI building and request customization.
    ext: Ext,
}
131
/// Lets a provider extension contribute extra fields to the `Debug` output of a
/// `Client<Ext, H>`.
pub trait DebugExt: Debug {
    /// Returns `(field name, value)` pairs that `Client`'s `Debug` impl appends after its own
    /// fields. The default implementation yields no fields.
    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
        std::iter::empty()
    }
}
137
138impl<Ext, H> std::fmt::Debug for Client<Ext, H>
139where
140    Ext: DebugExt,
141    H: std::fmt::Debug,
142{
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        let mut d = &mut f.debug_struct("Client");
145
146        d = d
147            .field("base_url", &self.base_url)
148            .field(
149                "headers",
150                &self.headers.iter().filter_map(|(k, v)| {
151                    (!(k == http::header::AUTHORIZATION || k.as_str().contains("api-key")))
152                        .then_some((k, v))
153                }),
154            )
155            .field("http_client", &self.http_client);
156
157        self.ext
158            .fields()
159            .fold(d, |d, (name, field)| d.field(name, field))
160            .finish()
161    }
162}
163
/// The kind of transport a request URI is being built for; passed to `Provider::build_uri` so a
/// provider can vary the URI per transport.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// Plain HTTP request/response; used by `Client::post` and `Client::get`.
    Http,
    /// Used by `Client::post_sse` (presumably server-sent events streaming).
    Sse,
    /// Not constructed in this module; presumably newline-delimited JSON streaming.
    NdJson,
}
169
170/// An API provider extension, this abstracts over extensions which may be use in conjunction with
171/// the `Client<Ext, H>` struct to define the behavior of a provider with respect to networking,
172/// auth, instantiating models
173pub trait Provider: Sized {
174    const VERIFY_PATH: &'static str;
175
176    type Builder: ProviderBuilder;
177
178    fn build<H>(
179        builder: &ClientBuilder<Self::Builder, <Self::Builder as ProviderBuilder>::ApiKey, H>,
180    ) -> http_client::Result<Self>;
181
182    fn build_uri(&self, base_url: &str, path: &str, _transport: Transport) -> String {
183        base_url.to_string() + "/" + path.trim_start_matches('/')
184    }
185
186    fn with_custom(&self, req: http_client::Builder) -> http_client::Result<http_client::Builder> {
187        Ok(req)
188    }
189}
190
/// A wrapper type providing runtime checks on a provider's capabilities via the [Capability] trait
///
/// `M` is only carried in `PhantomData`; the capability-client impls on `Client` bind it to the
/// model type they produce (e.g. `Completion = Capable<M>`).
pub struct Capable<M>(PhantomData<M>);
193
/// Marks whether a capability slot is populated.
pub trait Capability {
    /// `true` for `Capable<M>`, `false` for `Nothing`.
    const CAPABLE: bool;
}
197
// Any `Capable<M>` reports the capability as present, regardless of `M`.
impl<M> Capability for Capable<M> {
    const CAPABLE: bool = true;
}
201
// `Nothing` in a capability slot reports the capability as absent.
impl Capability for Nothing {
    const CAPABLE: bool = false;
}
205
/// The capabilities of a given provider, i.e. embeddings, audio transcriptions, text completion
///
/// Each associated type is a [`Capability`]. The capability-client impls on `Client<Ext, H>`
/// require the matching slot to be `Capable<M>` and use `M` as the model type.
pub trait Capabilities<H = reqwest::Client> {
    /// Text completion capability.
    type Completion: Capability;
    /// Embeddings capability.
    type Embeddings: Capability;
    /// Transcription capability.
    type Transcription: Capability;
    /// Image generation capability (only with the `image` feature).
    #[cfg(feature = "image")]
    type ImageGeneration: Capability;
    /// Audio generation capability (only with the `audio` feature).
    #[cfg(feature = "audio")]
    type AudioGeneration: Capability;
}
216
/// An API provider extension *builder*, this abstracts over provider-specific builders which are
/// able to configure and produce a given provider's extension type
///
/// See [Provider]
pub trait ProviderBuilder: Sized {
    /// The provider extension this builder produces.
    type Output: Provider;
    /// The API key type the resulting client is built with.
    type ApiKey;

    /// The base URL a `ClientBuilder` starts with (see its `Default` impl).
    const BASE_URL: &'static str;

    /// This method can be used to customize the fields of `builder` before it is used to create
    /// a client. For example, adding default headers
    ///
    /// `ClientBuilder::build` calls this first, on a clone of the builder's extension builder.
    /// The default implementation returns `builder` unchanged.
    fn finish<H>(
        &self,
        builder: ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<ClientBuilder<Self, Self::ApiKey, H>> {
        Ok(builder)
    }
}
236
impl<Ext, ExtBuilder, Key, H> Client<Ext, H>
where
    ExtBuilder: Clone + Default + ProviderBuilder<Output = Ext, ApiKey = Key>,
    Ext: Provider<Builder = ExtBuilder>,
    H: Default + HttpClientExt,
    Key: ApiKey,
{
    /// Creates a client with default settings and the given API key.
    ///
    /// Shorthand for `Self::builder().api_key(api_key).build()`.
    ///
    /// # Errors
    /// Returns whatever error `ClientBuilder::build` reports.
    pub fn new(api_key: impl Into<Key>) -> http_client::Result<Self> {
        Self::builder().api_key(api_key).build()
    }
}
248
impl<Ext, H> Client<Ext, H> {
    /// The base URL request URIs are built from.
    pub fn base_url(&self) -> &str {
        &self.base_url
    }

    /// The client's default headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// The underlying HTTP backend.
    pub fn http_client(&self) -> &H {
        &self.http_client
    }

    /// The provider extension.
    pub fn ext(&self) -> &Ext {
        &self.ext
    }

    /// Create a new client from individual parts.
    /// This is useful for converting between different provider extensions.
    ///
    /// The parts are stored as given; no API-key header is added and nothing is validated.
    pub fn from_parts(base_url: String, headers: HeaderMap, http_client: H, ext: Ext) -> Self {
        Self {
            base_url,
            headers,
            http_client,
            ext,
        }
    }
}
277
// Note: within this impl `Builder` is a type parameter (the provider's builder type), not the
// imported `http_client::Builder`.
impl<Ext, Builder, H> Client<Ext, H>
where
    H: Default + HttpClientExt,
    Ext: Provider<Builder = Builder>,
    Builder: Default + ProviderBuilder,
{
    /// Returns a default `ClientBuilder` for this provider, with no API key set yet
    /// (`NeedsApiKey`).
    pub fn builder() -> ClientBuilder<Builder, NeedsApiKey, H> {
        ClientBuilder::default()
    }
}
288
289impl<Ext, H> Client<Ext, H>
290where
291    Ext: Provider,
292{
293    pub fn post<S>(&self, path: S) -> http_client::Result<Builder>
294    where
295        S: AsRef<str>,
296    {
297        let uri = self
298            .ext
299            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
300
301        let mut req = Request::post(uri);
302
303        if let Some(hs) = req.headers_mut() {
304            hs.extend(self.headers.clone());
305        }
306
307        self.ext.with_custom(req)
308    }
309
310    pub fn post_sse<S>(&self, path: S) -> http_client::Result<Builder>
311    where
312        S: AsRef<str>,
313    {
314        let uri = self
315            .ext
316            .build_uri(&self.base_url, path.as_ref(), Transport::Sse);
317
318        let mut req = Request::post(uri);
319
320        if let Some(hs) = req.headers_mut() {
321            hs.extend(self.headers.clone());
322        }
323
324        self.ext.with_custom(req)
325    }
326
327    pub fn get<S>(&self, path: S) -> http_client::Result<Builder>
328    where
329        S: AsRef<str>,
330    {
331        let uri = self
332            .ext
333            .build_uri(&self.base_url, path.as_ref(), Transport::Http);
334
335        self.ext.with_custom(Request::get(uri))
336    }
337}
338
339impl<Ext, H> Client<Ext, H>
340where
341    H: HttpClientExt,
342{
343    pub async fn send<T, U>(
344        &self,
345        mut req: Request<T>,
346    ) -> http_client::Result<Response<LazyBody<U>>>
347    where
348        T: std::fmt::Debug + Into<Bytes> + WasmCompatSend,
349        U: std::fmt::Debug + From<Bytes> + WasmCompatSend + 'static,
350    {
351        req.headers_mut().insert(
352            http::header::CONTENT_TYPE,
353            http::HeaderValue::from_static("application/json"),
354        );
355
356        dbg!(&req);
357
358        self.http_client.send(req).await
359    }
360
361    pub async fn send_streaming<U, R>(
362        &self,
363        req: Request<U>,
364    ) -> Result<http_client::StreamingResponse, http_client::Error>
365    where
366        U: std::fmt::Debug + Into<Bytes>,
367    {
368        self.http_client.send_streaming(req).await
369    }
370}
371
372impl<Ext, H> VerifyClient for Client<Ext, H>
373where
374    H: HttpClientExt,
375    Ext: DebugExt + Provider + WasmCompatSync,
376{
377    async fn verify(&self) -> Result<(), VerifyError> {
378        use http::StatusCode;
379
380        let req = self
381            .get(Ext::VERIFY_PATH)?
382            .body(http_client::NoBody)
383            .map_err(http_client::Error::from)?;
384
385        let response = self.http_client.send(req).await?;
386
387        match response.status() {
388            StatusCode::OK => Ok(()),
389            StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
390                Err(VerifyError::InvalidAuthentication)
391            }
392            StatusCode::INTERNAL_SERVER_ERROR => {
393                let text = http_client::text(response).await?;
394                Err(VerifyError::ProviderError(text))
395            }
396            status if status.as_u16() == 529 => {
397                let text = http_client::text(response).await?;
398                Err(VerifyError::ProviderError(text))
399            }
400            _ => {
401                let status = response.status();
402
403                if status.is_success() {
404                    Ok(())
405                } else {
406                    let text: String = String::from_utf8_lossy(&response.into_body().await?).into();
407                    Err(VerifyError::HttpError(http_client::Error::Instance(
408                        format!("Failed with '{status}': {text}").into(),
409                    )))
410                }
411            }
412        }
413    }
414}
415
/// Placeholder for the API-key slot of a `ClientBuilder` before a key has been supplied. It is
/// the builder's default `ApiKey` parameter and is replaced by `ClientBuilder::api_key`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NeedsApiKey;
418
// ApiKey is generic because Anthropic uses custom auth header, local models like Ollama use none
/// Builder for a `Client`: collects the base URL, API key, default headers, optional HTTP
/// backend and the provider's extension builder.
#[derive(Clone)]
pub struct ClientBuilder<Ext, ApiKey = NeedsApiKey, H = reqwest::Client> {
    /// Base URL for the client; starts as the provider builder's `BASE_URL`.
    base_url: String,
    /// The API key, or `NeedsApiKey` while none has been set.
    api_key: ApiKey,
    /// Default headers for the client.
    headers: HeaderMap,
    /// Explicit HTTP backend; when `None`, `build` falls back to `H::default()`.
    http_client: Option<H>,
    /// The provider's extension builder.
    ext: Ext,
}
428
429impl<ExtBuilder, H> Default for ClientBuilder<ExtBuilder, NeedsApiKey, H>
430where
431    H: Default,
432    ExtBuilder: ProviderBuilder + Default,
433{
434    fn default() -> Self {
435        Self {
436            api_key: NeedsApiKey,
437            headers: Default::default(),
438            base_url: ExtBuilder::BASE_URL.into(),
439            http_client: None,
440            ext: Default::default(),
441        }
442    }
443}
444
445impl<Ext, H> ClientBuilder<Ext, NeedsApiKey, H> {
446    /// Set the API key for this client. This *must* be done before the `build` method can be
447    /// called
448    pub fn api_key<ApiKey>(self, api_key: impl Into<ApiKey>) -> ClientBuilder<Ext, ApiKey, H> {
449        ClientBuilder {
450            api_key: api_key.into(),
451            base_url: self.base_url,
452            headers: self.headers,
453            http_client: self.http_client,
454            ext: self.ext,
455        }
456    }
457}
458
459impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H>
460where
461    Ext: Clone,
462{
463    /// Owned map over the ext field
464    pub(crate) fn over_ext<F, NewExt>(self, f: F) -> ClientBuilder<NewExt, ApiKey, H>
465    where
466        F: FnOnce(Ext) -> NewExt,
467    {
468        let ClientBuilder {
469            base_url,
470            api_key,
471            headers,
472            http_client,
473            ext,
474        } = self;
475
476        let new_ext = f(ext.clone());
477
478        ClientBuilder {
479            base_url,
480            api_key,
481            headers,
482            http_client,
483            ext: new_ext,
484        }
485    }
486
487    /// Set the base URL for this client
488    pub fn base_url<S>(self, base_url: S) -> Self
489    where
490        S: AsRef<str>,
491    {
492        Self {
493            base_url: base_url.as_ref().to_string(),
494            ..self
495        }
496    }
497
498    /// Set the HTTP backend used in this client
499    pub fn http_client<U>(self, http_client: U) -> ClientBuilder<Ext, ApiKey, U> {
500        ClientBuilder {
501            http_client: Some(http_client),
502            base_url: self.base_url,
503            api_key: self.api_key,
504            headers: self.headers,
505            ext: self.ext,
506        }
507    }
508
509    pub(crate) fn headers_mut(&mut self) -> &mut HeaderMap {
510        &mut self.headers
511    }
512
513    pub(crate) fn ext_mut(&mut self) -> &mut Ext {
514        &mut self.ext
515    }
516}
517
impl<Ext, ApiKey, H> ClientBuilder<Ext, ApiKey, H> {
    /// Borrows the API key currently stored in the builder.
    pub(crate) fn get_api_key(&self) -> &ApiKey {
        &self.api_key
    }
}
523
impl<Ext, Key, H> ClientBuilder<Ext, Key, H> {
    /// Borrows the provider's extension builder.
    pub fn ext(&self) -> &Ext {
        &self.ext
    }
}
529
530impl<Ext, ExtBuilder, Key, H> ClientBuilder<ExtBuilder, Key, H>
531where
532    ExtBuilder: Clone + ProviderBuilder<Output = Ext, ApiKey = Key> + Default,
533    Ext: Provider<Builder = ExtBuilder>,
534    Key: ApiKey,
535    H: Default,
536{
537    pub fn build(mut self) -> http_client::Result<Client<ExtBuilder::Output, H>> {
538        let ext = self.ext.clone();
539
540        self = ext.finish(self)?;
541        let ext = Ext::build(&self)?;
542
543        let ClientBuilder {
544            http_client,
545            base_url,
546            mut headers,
547            api_key,
548            ..
549        } = self;
550
551        if let Some((k, v)) = api_key.into_header().transpose()? {
552            headers.insert(k, v);
553        }
554
555        let http_client = http_client.unwrap_or_default();
556
557        Ok(Client {
558            http_client,
559            base_url,
560            headers,
561            ext,
562        })
563    }
564}
565
/// `Client` is a [`CompletionClient`] whenever its extension declares
/// `Completion = Capable<M>` for a completion model `M` built from this client type.
impl<M, Ext, H> CompletionClient for Client<Ext, H>
where
    Ext: Capabilities<H, Completion = Capable<M>>,
    M: CompletionModel<Client = Self>,
{
    type CompletionModel = M;

    /// Creates the completion model by delegating to `M::make(self, model)`.
    fn completion_model(&self, model: impl Into<String>) -> Self::CompletionModel {
        M::make(self, model)
    }
}
577
/// `Client` is an [`EmbeddingsClient`] whenever its extension declares
/// `Embeddings = Capable<M>` for an embedding model `M` built from this client type.
impl<M, Ext, H> EmbeddingsClient for Client<Ext, H>
where
    Ext: Capabilities<H, Embeddings = Capable<M>>,
    M: EmbeddingModel<Client = Self>,
{
    type EmbeddingModel = M;

    /// Creates the embedding model via `M::make`, passing `None` for the dimension count.
    fn embedding_model(&self, model: impl Into<String>) -> Self::EmbeddingModel {
        M::make(self, model, None)
    }

    /// Creates the embedding model via `M::make`, passing `Some(ndims)` as the dimension count.
    fn embedding_model_with_ndims(
        &self,
        model: impl Into<String>,
        ndims: usize,
    ) -> Self::EmbeddingModel {
        M::make(self, model, Some(ndims))
    }
}
597
/// `Client` is a [`TranscriptionClient`] whenever its extension declares
/// `Transcription = Capable<M>` for a transcription model `M` built from this client type.
impl<M, Ext, H> TranscriptionClient for Client<Ext, H>
where
    Ext: Capabilities<H, Transcription = Capable<M>>,
    M: TranscriptionModel<Client = Self> + WasmCompatSend,
{
    type TranscriptionModel = M;

    /// Creates the transcription model by delegating to `M::make(self, model)`.
    fn transcription_model(&self, model: impl Into<String>) -> Self::TranscriptionModel {
        M::make(self, model)
    }
}
609
/// With the `image` feature: `Client` is an [`ImageGenerationClient`] whenever its extension
/// declares `ImageGeneration = Capable<M>` for a model `M` built from this client type.
#[cfg(feature = "image")]
impl<M, Ext, H> ImageGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, ImageGeneration = Capable<M>>,
    M: ImageGenerationModel<Client = Self>,
{
    type ImageGenerationModel = M;

    /// Creates the image generation model by delegating to `M::make(self, model)`.
    fn image_generation_model(&self, model: impl Into<String>) -> Self::ImageGenerationModel {
        M::make(self, model)
    }
}
622
/// With the `audio` feature: `Client` is an [`AudioGenerationClient`] whenever its extension
/// declares `AudioGeneration = Capable<M>` for a model `M` built from this client type.
#[cfg(feature = "audio")]
impl<M, Ext, H> AudioGenerationClient for Client<Ext, H>
where
    Ext: Capabilities<H, AudioGeneration = Capable<M>>,
    M: AudioGenerationModel<Client = Self>,
{
    type AudioGenerationModel = M;

    /// Creates the audio generation model by delegating to `M::make(self, model)`.
    fn audio_generation_model(&self, model: impl Into<String>) -> Self::AudioGenerationModel {
        M::make(self, model)
    }
}