rig_core/providers/huggingface/
client.rs1use crate::client::{
2 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient,
4};
5use crate::http_client;
6#[cfg(feature = "image")]
7use crate::image_generation::ImageGenerationError;
8use crate::transcription::TranscriptionError;
9use std::fmt::Debug;
10use std::fmt::Display;
11
/// Backend that Hugging Face router requests are addressed to.
///
/// The route segment for each variant is produced by this type's `Display`
/// implementation. `Custom` carries an arbitrary route string for backends
/// without a dedicated variant.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` so a sub-provider can be
/// used as a key in hashed collections (e.g. per-backend configuration maps).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference route (`hf-inference/models`); the default.
    #[default]
    HFInference,
    /// Routed as `together`.
    Together,
    /// Routed as `sambanova`.
    SambaNova,
    /// Routed as `fireworks-ai`; model names get a Fireworks account prefix.
    Fireworks,
    /// Routed as `hyperbolic`.
    Hyperbolic,
    /// Routed as `nebius`.
    Nebius,
    /// Routed as `novita`.
    Novita,
    /// Any other backend, identified by its raw route string.
    Custom(String),
}
24
25impl SubProvider {
26 pub fn completion_endpoint(&self, _model: &str) -> String {
30 "v1/chat/completions".to_string()
31 }
32
33 pub fn transcription_endpoint(&self, model: &str) -> Result<String, TranscriptionError> {
37 match self {
38 SubProvider::HFInference => Ok(format!("/{model}")),
39 _ => Err(TranscriptionError::ProviderError(format!(
40 "transcription endpoint is not supported yet for {self}"
41 ))),
42 }
43 }
44
45 #[cfg(feature = "image")]
49 pub fn image_generation_endpoint(&self, model: &str) -> Result<String, ImageGenerationError> {
50 match self {
51 SubProvider::HFInference => Ok(format!("/{model}")),
52 _ => Err(ImageGenerationError::ProviderError(format!(
53 "image generation endpoint is not supported yet for {self}"
54 ))),
55 }
56 }
57
58 pub fn model_identifier(&self, model: &str) -> String {
59 match self {
60 SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
61 _ => model.to_string(),
62 }
63 }
64}
65
66impl From<&str> for SubProvider {
67 fn from(s: &str) -> Self {
68 SubProvider::Custom(s.to_string())
69 }
70}
71
72impl From<String> for SubProvider {
73 fn from(value: String) -> Self {
74 SubProvider::Custom(value)
75 }
76}
77
78impl Display for SubProvider {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 let route = match self {
81 SubProvider::HFInference => "hf-inference/models".to_string(),
82 SubProvider::Together => "together".to_string(),
83 SubProvider::SambaNova => "sambanova".to_string(),
84 SubProvider::Fireworks => "fireworks-ai".to_string(),
85 SubProvider::Hyperbolic => "hyperbolic".to_string(),
86 SubProvider::Nebius => "nebius".to_string(),
87 SubProvider::Novita => "novita".to_string(),
88 SubProvider::Custom(route) => route.clone(),
89 };
90
91 write!(f, "{route}")
92 }
93}
94
/// Base URL of the Hugging Face router; `HuggingFaceBuilder` uses it as its
/// `BASE_URL`.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co";
99
100#[derive(Debug, Default, Clone)]
101pub struct HuggingFaceExt {
102 subprovider: SubProvider,
103}
104
105#[derive(Debug, Default, Clone)]
106pub struct HuggingFaceBuilder {
107 subprovider: SubProvider,
108}
109
110type HuggingFaceApiKey = BearerAuth;
111
112pub type Client<H = reqwest::Client> = client::Client<HuggingFaceExt, H>;
113pub type ClientBuilder<H = crate::markers::Missing> =
114 client::ClientBuilder<HuggingFaceBuilder, HuggingFaceApiKey, H>;
115
116impl Provider for HuggingFaceExt {
117 type Builder = HuggingFaceBuilder;
118
119 const VERIFY_PATH: &'static str = "/api/whoami-v2";
120}
121
122impl<H> Capabilities<H> for HuggingFaceExt {
123 type Completion = Capable<super::completion::CompletionModel<H>>;
124 type Embeddings = Nothing;
125 type Transcription = Capable<super::transcription::TranscriptionModel<H>>;
126 type ModelListing = Nothing;
127 #[cfg(feature = "image")]
128 type ImageGeneration = Capable<super::image_generation::ImageGenerationModel<H>>;
129
130 #[cfg(feature = "audio")]
131 type AudioGeneration = Nothing;
132 type Rerank = Nothing;
133}
134
135impl DebugExt for HuggingFaceExt {
136 fn fields(&self) -> impl Iterator<Item = (&'static str, &dyn Debug)> {
137 std::iter::once(("subprovider", (&self.subprovider as &dyn Debug)))
138 }
139}
140
141impl ProviderBuilder for HuggingFaceBuilder {
142 type Extension<H>
143 = HuggingFaceExt
144 where
145 H: http_client::HttpClientExt;
146 type ApiKey = HuggingFaceApiKey;
147
148 const BASE_URL: &'static str = HUGGINGFACE_API_BASE_URL;
149
150 fn build<H>(
151 builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
152 ) -> http_client::Result<Self::Extension<H>>
153 where
154 H: http_client::HttpClientExt,
155 {
156 Ok(HuggingFaceExt {
157 subprovider: builder.ext().subprovider.clone(),
158 })
159 }
160}
161
162impl ProviderClient for Client {
163 type Input = String;
164 type Error = crate::client::ProviderClientError;
165
166 fn from_env() -> Result<Self, Self::Error> {
168 let api_key = crate::client::required_env_var("HUGGINGFACE_API_KEY")?;
169
170 Self::new(&api_key).map_err(Into::into)
171 }
172
173 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
174 Self::new(&input).map_err(Into::into)
175 }
176}
177
178impl<H> ClientBuilder<H> {
179 pub fn subprovider(mut self, subprovider: SubProvider) -> Self {
180 *self.ext_mut() = HuggingFaceBuilder { subprovider };
181 self
182 }
183}
184
185impl<H> Client<H> {
186 pub(crate) fn subprovider(&self) -> &SubProvider {
187 &self.ext().subprovider
188 }
189}
#[cfg(test)]
mod tests {
    use crate::providers::huggingface::Client;

    // Both construction paths (direct and via the builder) must succeed with a
    // placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}