hugging_face_client/
client.rs1use std::time::Duration;
4
5use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
6use serde::{Deserialize, Serialize};
7use snafu::ResultExt;
8
9use crate::{
10 api::{
11 CreateRepoReq, CreateRepoRes, DeleteRepoReq, GetModelReq, GetModelRes, GetModelsReq,
12 GetModelsRes, HuggingFaceRes,
13 },
14 errors::{ReqwestClientSnafu, Result},
15};
16
/// Base URL used when no custom endpoint was configured via `ClientOption::endpoint`.
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
18
/// Configuration consumed by `Client::new`.
///
/// Only the access token is required; every other setting starts unset and is
/// configured through the consuming builder methods on this type.
#[derive(Debug, Clone)]
pub struct ClientOption {
    /// Token sent as a bearer token with every request.
    access_token: String,
    /// Custom API base URL; `None` means the client's built-in default is used.
    api_endpoint: Option<String>,
    /// Proxy URL to route all requests through, if any.
    http_proxy: Option<String>,
    /// Request timeout handed to the HTTP client builder, if any.
    timeout: Option<Duration>,
}

impl ClientOption {
    /// Creates options that authenticate with `access_token` and leave every
    /// other setting unset.
    #[must_use]
    pub fn new(access_token: impl Into<String>) -> Self {
        ClientOption {
            access_token: access_token.into(),
            api_endpoint: None,
            http_proxy: None,
            timeout: None,
        }
    }

    /// Sets the API base URL.
    ///
    /// Surrounding whitespace and any trailing `/` characters are removed. If
    /// the value does not already start with `http://` or `https://` (compared
    /// ASCII case-insensitively), `https://` is prepended, so `"hf-mirror.com/"`
    /// becomes `"https://hf-mirror.com"`.
    #[must_use]
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        let endpoint = endpoint.into();
        let endpoint = endpoint.trim().trim_end_matches('/');
        let endpoint = if Self::has_http_scheme(endpoint) {
            endpoint.to_string()
        } else {
            format!("https://{}", endpoint)
        };
        self.api_endpoint = Some(endpoint);
        self
    }

    /// Sets the proxy URL that all requests are routed through.
    #[must_use]
    pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
        self.http_proxy = Some(proxy.into());
        self
    }

    /// Sets the request timeout.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Returns `true` when `endpoint` already starts with `http://` or
    /// `https://`, ignoring ASCII case (URI schemes are case-insensitive).
    fn has_http_scheme(endpoint: &str) -> bool {
        ["http://", "https://"].iter().any(|scheme| {
            // `get` returns `None` instead of panicking when the cut would land
            // inside a multi-byte character or past the end of the string.
            endpoint
                .get(..scheme.len())
                .map_or(false, |prefix| prefix.eq_ignore_ascii_case(scheme))
        })
    }
}
104
105#[derive(Debug, Clone)]
107pub struct Client {
108 access_token: String,
109 api_endpoint: String,
110 http_client: ReqwestClient,
111}
112
113impl Client {
114 pub fn new(option: ClientOption) -> Result<Self> {
129 let mut http_client = ReqwestClient::builder();
130 if let Some(http_proxy) = option.http_proxy {
131 let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
132 http_client = http_client.proxy(proxy);
133 }
134 if let Some(timeout) = option.timeout {
135 http_client = http_client.timeout(timeout);
136 }
137 let http_client = http_client.build().context(ReqwestClientSnafu)?;
138 let client = Client {
139 access_token: option.access_token,
140 api_endpoint: option
141 .api_endpoint
142 .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
143 http_client,
144 };
145 Ok(client)
146 }
147
148 pub async fn get_models(&self, req: GetModelsReq<'_>) -> Result<GetModelsRes> {
152 let url = format!("{}/api/models", &self.api_endpoint);
153 self.get_request(&url, Some(&req)).await
154 }
155
156 pub async fn get_model(&self, req: GetModelReq<'_>) -> Result<GetModelRes> {
162 let url = if let Some(revision) = req.revision {
163 format!(
164 "{}/api/models/{}/revision/{}",
165 &self.api_endpoint, req.name, revision
166 )
167 } else {
168 format!("{}/api/models/{}", &self.api_endpoint, req.name)
169 };
170 let req = if true { None } else { Some(&req) };
171 self.get_request(&url, req).await
172 }
173
174 pub async fn create_repo(&self, req: CreateRepoReq<'_>) -> Result<CreateRepoRes> {
178 let url = format!("{}/api/repos/create", &self.api_endpoint);
179 self.exec_request(&url, Method::POST, Some(&req)).await
180 }
181
182 pub async fn delete_repo(&self, req: DeleteRepoReq<'_>) -> Result<()> {
185 let url = format!("{}/api/repos/delete", &self.api_endpoint);
186 self
187 .exec_request_without_response(&url, Method::DELETE, Some(&req))
188 .await
189 }
190}
191
192impl Client {
194 async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
195 &self,
196 url: &str,
197 query: Option<&T>,
198 ) -> Result<U> {
199 let mut req = self.http_client.get(url).bearer_auth(&self.access_token);
200 if let Some(query) = query {
201 req = req.query(query);
202 }
203 let res = req
204 .send()
205 .await
206 .context(ReqwestClientSnafu)?
207 .json::<HuggingFaceRes<U>>()
208 .await
209 .context(ReqwestClientSnafu)?
210 .unwrap_data()?;
211 Ok(res)
212 }
213
214 async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
215 &self,
216 url: &str,
217 method: Method,
218 body: Option<&T>,
219 ) -> Result<U> {
220 let mut req = self
221 .http_client
222 .request(method, url)
223 .bearer_auth(&self.access_token);
224 if let Some(body) = body {
225 req = req.json(body);
226 }
227 let res = req
228 .send()
229 .await
230 .context(ReqwestClientSnafu)?
231 .json::<HuggingFaceRes<U>>()
232 .await
233 .context(ReqwestClientSnafu)?
234 .unwrap_data()?;
235 Ok(res)
236 }
237
238 async fn exec_request_without_response<T: Serialize>(
239 &self,
240 url: &str,
241 method: Method,
242 body: Option<&T>,
243 ) -> Result<()> {
244 let mut req = self
245 .http_client
246 .request(method, url)
247 .bearer_auth(&self.access_token);
248 if let Some(body) = body {
249 req = req.json(body);
250 }
251 let _res = req
252 .send()
253 .await
254 .context(ReqwestClientSnafu)?
255 .error_for_status()
256 .context(ReqwestClientSnafu)?;
257 Ok(())
258 }
259}