Skip to main content

rig/providers/huggingface/
client.rs

1use crate::client::{
2    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient,
4};
5use crate::http_client;
6#[cfg(feature = "image")]
7use crate::image_generation::ImageGenerationError;
8use crate::transcription::TranscriptionError;
9use std::fmt::Debug;
10use std::fmt::Display;
11
/// Inference backend that a Hugging Face request is routed through.
///
/// The `Display` implementation renders the route string associated with each
/// variant. `Eq` and `Hash` are derived alongside `PartialEq` so values can be
/// used as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face Inference; this is the default variant.
    #[default]
    HFInference,
    /// Together.
    Together,
    /// SambaNova.
    SambaNova,
    /// Fireworks; model names are rewritten by `model_identifier`.
    Fireworks,
    /// Hyperbolic.
    Hyperbolic,
    /// Nebius.
    Nebius,
    /// Novita.
    Novita,
    /// Any other sub-provider, identified by its raw route string.
    Custom(String),
}
24
25impl SubProvider {
26    /// Get the chat completion endpoint for the SubProvider
27    /// Required because Huggingface Inference requires the model
28    /// in the url and in the request body.
29    pub fn completion_endpoint(&self, _model: &str) -> String {
30        "v1/chat/completions".to_string()
31    }
32
33    /// Get the transcription endpoint for the SubProvider
34    /// Required because Huggingface Inference requires the model
35    /// in the url and in the request body.
36    pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
37        match self {
38            SubProvider::HFInference => Ok(format!("/{model}")),
39            _ => Err(TranscriptionError::ProviderError(format!(
40                "transcription endpoint is not supported yet for {self}"
41            ))),
42        }
43    }
44
45    /// Get the image generation endpoint for the SubProvider
46    /// Required because Huggingface Inference requires the model
47    /// in the url and in the request body.
48    #[cfg(feature = "image")]
49    pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
50        match self {
51            SubProvider::HFInference => Ok(format!("/{model}")),
52            _ => Err(ImageGenerationError::ProviderError(format!(
53                "image generation endpoint is not supported yet for {self}"
54            ))),
55        }
56    }
57
58    pub fn model_identifier(&self, model: &str) -> String {
59        match self {
60            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
61            _ => model.to_string(),
62        }
63    }
64}
65
66impl From<&str> for SubProvider {
67    fn from(s: &str) -> Self {
68        SubProvider::Custom(s.to_string())
69    }
70}
71
72impl From<String> for SubProvider {
73    fn from(value: String) -> Self {
74        SubProvider::Custom(value)
75    }
76}
77
78impl Display for SubProvider {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        let route = match self {
81            SubProvider::HFInference => "hf-inference/models".to_string(),
82            SubProvider::Together => "together".to_string(),
83            SubProvider::SambaNova => "sambanova".to_string(),
84            SubProvider::Fireworks => "fireworks-ai".to_string(),
85            SubProvider::Hyperbolic => "hyperbolic".to_string(),
86            SubProvider::Nebius => "nebius".to_string(),
87            SubProvider::Novita => "novita".to_string(),
88            SubProvider::Custom(route) => route.clone(),
89        };
90
91        write!(f, "{route}")
92    }
93}
94
95// ================================================================
96// Main Huggingface Client
97// ================================================================
/// Base URL of the Hugging Face router; exposed as `ProviderBuilder::BASE_URL`
/// by the `HuggingFaceBuilder` implementation in this file.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co";
99
/// Provider extension carried by the Hugging Face `Client` alias.
///
/// Stores the [`SubProvider`] that was selected on the builder.
#[derive(Debug, Default, Clone)]
pub struct HuggingFaceExt {
    /// Sub-provider cloned from the builder in `ProviderBuilder::build`.
    subprovider: SubProvider,
}
104
/// Builder-side state for the Hugging Face `ClientBuilder` alias.
///
/// With the derived `Default`, the sub-provider starts as `SubProvider::default()`.
#[derive(Debug, Default, Clone)]
pub struct HuggingFaceBuilder {
    /// Sub-provider to use; replaced by `ClientBuilder::subprovider`.
    subprovider: SubProvider,
}
109
/// API key type for Hugging Face: an alias for `BearerAuth`.
type HuggingFaceApiKey = BearerAuth;

/// Hugging Face client: the generic `client::Client` specialised with
/// [`HuggingFaceExt`]. The HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<HuggingFaceExt, H>;
/// Builder for [`Client`]: the generic `client::ClientBuilder` specialised with
/// [`HuggingFaceBuilder`] and the Hugging Face API key type.
pub type ClientBuilder<H = reqwest::Client> =
    client::ClientBuilder<HuggingFaceBuilder, HuggingFaceApiKey, H>;
115
/// Registers [`HuggingFaceExt`] as a provider whose builder is [`HuggingFaceBuilder`].
impl Provider for HuggingFaceExt {
    type Builder = HuggingFaceBuilder;

    // NOTE(review): presumably the path requested to verify the API key —
    // confirm against the `Provider` trait definition.
    const VERIFY_PATH: &'static str = "/api/whoami-v2";
}
121
/// Declares which capabilities the Hugging Face provider exposes.
///
/// Completion and transcription are marked `Capable`; embeddings and model
/// listing are marked `Nothing`. Image generation is `Capable` when the `image`
/// feature is enabled, and audio generation is `Nothing` under the `audio` feature.
impl<H> Capabilities<H> for HuggingFaceExt {
    type Completion = Capable<super::completion::CompletionModel<H>>;
    type Embeddings = Nothing;
    type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
    type ModelListing = Nothing;
    #[cfg(feature = "image")]
    type ImageGeneration = Capable<super::image_generation::ImageGenerationModel<H>>;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
133
134impl DebugExt for HuggingFaceExt {
135    fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
136        std::iter::once(("subprovider", (&self.subprovider as &dyn Debug)))
137    }
138}
139
140impl ProviderBuilder for HuggingFaceBuilder {
141    type Extension<H>
142        = HuggingFaceExt
143    where
144        H: http_client::HttpClientExt;
145    type ApiKey = HuggingFaceApiKey;
146
147    const BASE_URL: &'static str = HUGGINGFACE_API_BASE_URL;
148
149    fn build<H>(
150        builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
151    ) -> http_client::Result<Self::Extension<H>>
152    where
153        H: http_client::HttpClientExt,
154    {
155        Ok(HuggingFaceExt {
156            subprovider: builder.ext().subprovider.clone(),
157        })
158    }
159}
160
161impl ProviderClient for Client {
162    type Input = String;
163
164    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
165    /// Panics if the environment variable is not set.
166    fn from_env() -> Self {
167        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
168
169        Self::new(&api_key).unwrap()
170    }
171
172    fn from_val(input: Self::Input) -> Self {
173        Self::new(&input).unwrap()
174    }
175}
176
177impl<H> ClientBuilder<H> {
178    pub fn subprovider(mut self, subprovider: SubProvider) -> Self {
179        *self.ext_mut() = HuggingFaceBuilder { subprovider };
180        self
181    }
182}
183
184impl<H> Client<H> {
185    pub(crate) fn subprovider(&self) -> &SubProvider {
186        &self.ext().subprovider
187    }
188}
#[cfg(test)]
mod tests {
    use crate::providers::huggingface::Client;

    /// Constructing a client directly and through the builder both succeed
    /// with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}