rig/providers/huggingface/
client.rs

1use std::fmt::Display;
2
3use super::completion::CompletionModel;
4use crate::agent::AgentBuilder;
5use crate::providers::huggingface::transcription::TranscriptionModel;
6
7// ================================================================
8// Main Huggingface Client
9// ================================================================
/// Default base URL for Huggingface requests.
///
/// Used by `Client::new` and as the initial `base_url` of `ClientBuilder::new`.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
11
/// Inference backend a Huggingface request is routed to.
///
/// The [`Display`] implementation yields the provider's URL route segment
/// (for example `together` or `fireworks-ai`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum SubProvider {
    /// Huggingface's own inference service; this is the default.
    #[default]
    HFInference,
    Together,
    SambaNova,
    Fireworks,
    Hyperbolic,
    Nebius,
    Novita,
    /// Any other provider; the wrapped string is used verbatim as the route.
    Custom(String),
}

impl SubProvider {
    /// Returns the chat completion path for this sub-provider.
    ///
    /// Huggingface Inference requires the model in the URL as well as in the
    /// request body, so its path embeds `model`. Every other sub-provider
    /// shares the fixed `/v1/chat/completions` path.
    pub fn completion_endpoint(&self, model: &str) -> String {
        match self {
            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
            _ => String::from("/v1/chat/completions"),
        }
    }

    /// Returns the transcription path for `model`.
    ///
    /// Only [`SubProvider::HFInference`] is supported so far.
    ///
    /// # Panics
    ///
    /// Panics when called on any sub-provider other than
    /// [`SubProvider::HFInference`].
    pub fn transcription_endpoint(&self, model: &str) -> String {
        match self {
            SubProvider::HFInference => format!("hf-inference/models/{model}"),
            other => panic!("transcription endpoint is not supported yet for {other}"),
        }
    }

    /// Returns the model identifier for this sub-provider.
    ///
    /// Fireworks identifiers are prefixed with `accounts/fireworks/models/`;
    /// all other sub-providers return `model` unchanged.
    pub fn model_identifier(&self, model: &str) -> String {
        match self {
            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
            _ => model.to_owned(),
        }
    }
}

impl From<&str> for SubProvider {
    /// Wraps the string in [`SubProvider::Custom`].
    ///
    /// The value is never matched against the built-in provider names.
    fn from(s: &str) -> Self {
        SubProvider::Custom(s.to_owned())
    }
}

impl From<String> for SubProvider {
    /// Wraps the string in [`SubProvider::Custom`] without copying it.
    fn from(value: String) -> Self {
        SubProvider::Custom(value)
    }
}

impl Display for SubProvider {
    /// Writes the URL route segment of the sub-provider.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow the route instead of allocating a `String` per variant.
        let route = match self {
            SubProvider::HFInference => "hf-inference/models",
            SubProvider::Together => "together",
            SubProvider::SambaNova => "sambanova",
            SubProvider::Fireworks => "fireworks-ai",
            SubProvider::Hyperbolic => "hyperbolic",
            SubProvider::Nebius => "nebius",
            SubProvider::Novita => "novita",
            SubProvider::Custom(route) => route.as_str(),
        };

        f.write_str(route)
    }
}
82
83pub struct ClientBuilder {
84    api_key: String,
85    base_url: String,
86    sub_provider: SubProvider,
87}
88
89impl ClientBuilder {
90    pub fn new(api_key: &str) -> Self {
91        Self {
92            api_key: api_key.to_string(),
93            base_url: HUGGINGFACE_API_BASE_URL.to_string(),
94            sub_provider: SubProvider::default(),
95        }
96    }
97
98    pub fn base_url(mut self, base_url: &str) -> Self {
99        self.base_url = base_url.to_string();
100        self
101    }
102
103    pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
104        self.sub_provider = provider.into();
105        self
106    }
107
108    pub fn build(self) -> Client {
109        let route = self.sub_provider.to_string();
110
111        let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
112
113        Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
114    }
115}
116
117#[derive(Clone)]
118pub struct Client {
119    base_url: String,
120    http_client: reqwest::Client,
121    pub(crate) sub_provider: SubProvider,
122}
123
124impl Client {
125    /// Create a new Huggingface client with the given API key.
126    pub fn new(api_key: &str) -> Self {
127        Self::from_url(api_key, HUGGINGFACE_API_BASE_URL, SubProvider::HFInference)
128    }
129
130    /// Create a new Client with the given API key and base API URL.
131    pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
132        let http_client = reqwest::Client::builder()
133            .default_headers({
134                let mut headers = reqwest::header::HeaderMap::new();
135                headers.insert(
136                    "Authorization",
137                    format!("Bearer {api_key}")
138                        .parse()
139                        .expect("Failed to parse API key"),
140                );
141                headers.insert(
142                    "Content-Type",
143                    "application/json"
144                        .parse()
145                        .expect("Failed to parse Content-Type"),
146                );
147                headers
148            })
149            .build()
150            .expect("Failed to build HTTP client");
151
152        Self {
153            base_url: base_url.to_owned(),
154            http_client,
155            sub_provider,
156        }
157    }
158    /// Create a new Huggingface client from the `HUGGINGFACE_API_KEY` environment variable.
159    /// Panics if the environment variable is not set.
160    pub fn from_env() -> Self {
161        let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
162        Self::new(&api_key)
163    }
164
165    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
166        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
167        self.http_client.post(url)
168    }
169
170    /// Create a new completion model with the given name
171    ///
172    /// # Example
173    /// ```
174    /// use rig::providers::huggingface::{Client, self}
175    ///
176    /// // Initialize the Huggingface client
177    /// let client = Client::new("your-huggingface-api-key");
178    ///
179    /// let completion_model = client.completion_model(huggingface::GEMMA_2);
180    /// ```
181    pub fn completion_model(&self, model: &str) -> CompletionModel {
182        CompletionModel::new(self.clone(), model)
183    }
184
185    /// Create a new transcription model with the given name
186    ///
187    /// # Example
188    /// ```
189    /// use rig::providers::huggingface::{Client, self}
190    ///
191    /// // Initialize the Huggingface client
192    /// let client = Client::new("your-huggingface-api-key");
193    ///
194    /// let completion_model = client.transcription_model(huggingface::WHISPER_LARGE_V3);
195    /// ```
196    pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
197        TranscriptionModel::new(self.clone(), model)
198    }
199
200    /// Create an agent builder with the given completion model.
201    ///
202    /// # Example
203    /// ```
204    /// use rig::providers::huggingface::{Client, self};
205    ///
206    /// // Initialize the Anthropic client
207    /// let client = Client::new("your-huggingface-api-key");
208    ///
209    /// let agent = client.agent(huggingface::GEMMA_2)
210    ///    .preamble("You are comedian AI with a mission to make people laugh.")
211    ///    .temperature(0.0)
212    ///    .build();
213    /// ```
214    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
215        AgentBuilder::new(self.completion_model(model))
216    }
217}