rig/providers/huggingface/
client.rs1use std::fmt::Display;
2
3use super::completion::CompletionModel;
4use crate::agent::AgentBuilder;
5use crate::providers::huggingface::transcription::TranscriptionModel;
6
/// Default root URL of the Hugging Face inference router.
///
/// `ClientBuilder::build` appends the selected sub-provider's route segment to this URL.
const HUGGINGFACE_API_BASE_URL: &str = "https://router.huggingface.co/";
11
/// Inference provider that the Hugging Face router forwards requests to.
///
/// The `Display` implementation yields the provider's URL route segment.
/// `ClientBuilder::build` appends that segment to the router base URL.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum SubProvider {
    /// Hugging Face's own inference service (route `hf-inference/models`).
    #[default]
    HFInference,
    /// Route `together`.
    Together,
    /// Route `sambanova`.
    SambaNova,
    /// Route `fireworks-ai`.
    Fireworks,
    /// Route `hyperbolic`.
    Hyperbolic,
    /// Route `nebius`.
    Nebius,
    /// Route `novita`.
    Novita,
    /// Any other route; the string is used verbatim as the path segment.
    Custom(String),
}

impl SubProvider {
    /// Returns the chat-completion path for `model`, relative to the client base URL.
    ///
    /// For `HFInference` the model is embedded in the path
    /// (`/{model}/v1/chat/completions`). Every other sub-provider shares the
    /// fixed path `/v1/chat/completions`.
    pub fn completion_endpoint(&self, model: &str) -> String {
        match self {
            SubProvider::HFInference => format!("/{model}/v1/chat/completions"),
            // Listed explicitly so adding a variant forces a decision here.
            SubProvider::Together
            | SubProvider::SambaNova
            | SubProvider::Fireworks
            | SubProvider::Hyperbolic
            | SubProvider::Nebius
            | SubProvider::Novita
            | SubProvider::Custom(_) => "/v1/chat/completions".to_string(),
        }
    }

    /// Returns the transcription path for `model`, relative to the client base URL.
    ///
    /// Clients built by `ClientBuilder` already carry the `hf-inference/models`
    /// route in their base URL (see the `Display` impl), exactly as
    /// `completion_endpoint` assumes. So only `/{model}` is returned here;
    /// repeating the route would yield
    /// `.../hf-inference/models/hf-inference/models/{model}`.
    ///
    /// # Panics
    ///
    /// Panics for every sub-provider other than `HFInference`, because
    /// transcription is not supported for them yet.
    pub fn transcription_endpoint(&self, model: &str) -> String {
        match self {
            SubProvider::HFInference => format!("/{model}"),
            _ => panic!("transcription endpoint is not supported yet for {}", self),
        }
    }

    /// Maps `model` to the identifier format this sub-provider expects.
    ///
    /// Fireworks model names are namespaced under `accounts/fireworks/models/`.
    /// All other sub-providers take the name unchanged.
    pub fn model_identifier(&self, model: &str) -> String {
        match self {
            SubProvider::Fireworks => format!("accounts/fireworks/models/{model}"),
            SubProvider::HFInference
            | SubProvider::Together
            | SubProvider::SambaNova
            | SubProvider::Hyperbolic
            | SubProvider::Nebius
            | SubProvider::Novita
            | SubProvider::Custom(_) => model.to_string(),
        }
    }
}

/// Treats any string as a `Custom` route.
///
/// This applies even to names of the built-in providers: `"together"` becomes
/// `Custom("together")`, not `Together`.
impl From<&str> for SubProvider {
    fn from(s: &str) -> Self {
        SubProvider::Custom(s.to_owned())
    }
}

/// Treats any string as a `Custom` route (see the `From<&str>` impl).
impl From<String> for SubProvider {
    fn from(value: String) -> Self {
        SubProvider::Custom(value)
    }
}

/// Formats the sub-provider as its router path segment.
impl Display for SubProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow the route instead of allocating a String per call.
        let route: &str = match self {
            SubProvider::HFInference => "hf-inference/models",
            SubProvider::Together => "together",
            SubProvider::SambaNova => "sambanova",
            SubProvider::Fireworks => "fireworks-ai",
            SubProvider::Hyperbolic => "hyperbolic",
            SubProvider::Nebius => "nebius",
            SubProvider::Novita => "novita",
            SubProvider::Custom(route) => route,
        };
        f.write_str(route)
    }
}
82
83pub struct ClientBuilder {
84 api_key: String,
85 base_url: String,
86 sub_provider: SubProvider,
87}
88
89impl ClientBuilder {
90 pub fn new(api_key: &str) -> Self {
91 Self {
92 api_key: api_key.to_string(),
93 base_url: HUGGINGFACE_API_BASE_URL.to_string(),
94 sub_provider: SubProvider::default(),
95 }
96 }
97
98 pub fn base_url(mut self, base_url: &str) -> Self {
99 self.base_url = base_url.to_string();
100 self
101 }
102
103 pub fn sub_provider(mut self, provider: impl Into<SubProvider>) -> Self {
104 self.sub_provider = provider.into();
105 self
106 }
107
108 pub fn build(self) -> Client {
109 let route = self.sub_provider.to_string();
110
111 let base_url = format!("{}/{}", self.base_url, route).replace("//", "/");
112
113 Client::from_url(self.api_key.as_str(), base_url.as_str(), self.sub_provider)
114 }
115}
116
117#[derive(Clone)]
118pub struct Client {
119 base_url: String,
120 http_client: reqwest::Client,
121 pub(crate) sub_provider: SubProvider,
122}
123
124impl Client {
125 pub fn new(api_key: &str) -> Self {
127 Self::from_url(api_key, HUGGINGFACE_API_BASE_URL, SubProvider::HFInference)
128 }
129
130 pub fn from_url(api_key: &str, base_url: &str, sub_provider: SubProvider) -> Self {
132 let http_client = reqwest::Client::builder()
133 .default_headers({
134 let mut headers = reqwest::header::HeaderMap::new();
135 headers.insert(
136 "Authorization",
137 format!("Bearer {api_key}")
138 .parse()
139 .expect("Failed to parse API key"),
140 );
141 headers.insert(
142 "Content-Type",
143 "application/json"
144 .parse()
145 .expect("Failed to parse Content-Type"),
146 );
147 headers
148 })
149 .build()
150 .expect("Failed to build HTTP client");
151
152 Self {
153 base_url: base_url.to_owned(),
154 http_client,
155 sub_provider,
156 }
157 }
158 pub fn from_env() -> Self {
161 let api_key = std::env::var("HUGGINGFACE_API_KEY").expect("HUGGINGFACE_API_KEY is not set");
162 Self::new(&api_key)
163 }
164
165 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
166 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
167 self.http_client.post(url)
168 }
169
170 pub fn completion_model(&self, model: &str) -> CompletionModel {
182 CompletionModel::new(self.clone(), model)
183 }
184
185 pub fn transcription_model(&self, model: &str) -> TranscriptionModel {
197 TranscriptionModel::new(self.clone(), model)
198 }
199
200 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
215 AgentBuilder::new(self.completion_model(model))
216 }
217}