hugging_face_client/
client.rs

1//! Async hub client
2
3use std::time::Duration;
4
5use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
6use serde::{Deserialize, Serialize};
7use snafu::ResultExt;
8
9use crate::{
10  api::{
11    CreateRepoReq, CreateRepoRes, DeleteRepoReq, GetModelReq, GetModelRes, GetModelsReq,
12    GetModelsRes, HuggingFaceRes,
13  },
14  errors::{ReqwestClientSnafu, Result},
15};
16
/// Hub endpoint used when [`ClientOption::endpoint`] is never called.
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
18
/// Options for creating [`Client`]
///
/// Build one with [`ClientOption::new`] and the chained setters.
#[derive(Clone)]
pub struct ClientOption {
  /// Bearer token sent with every request.
  access_token: String,
  /// Normalized endpoint override; `None` means the default Hub endpoint.
  api_endpoint: Option<String>,
  /// Proxy URL handed to reqwest when the client is built.
  http_proxy: Option<String>,
  /// Timeout handed to reqwest when the client is built.
  timeout: Option<Duration>,
}

// Hand-written so that `{:?}` (e.g. in logs or panic messages) never prints the access token.
impl std::fmt::Debug for ClientOption {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("ClientOption")
      .field("access_token", &"<redacted>")
      .field("api_endpoint", &self.api_endpoint)
      .field("http_proxy", &self.http_proxy)
      .field("timeout", &self.timeout)
      .finish()
  }
}
27
28impl ClientOption {
29  /// Create [`ClientOption`] instance with access_token
30  ///
31  /// `access_token`: authenticate client to the Hugging Face Hub and allow client to perform
32  /// actions based on token permissions, see [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
33  ///
34  /// ```rust
35  /// use hugging_face_client::client::ClientOption;
36  ///
37  /// fn main() {
38  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN");
39  /// }
40  /// ```
41  pub fn new(access_token: impl Into<String>) -> Self {
42    ClientOption {
43      access_token: access_token.into(),
44      api_endpoint: None,
45      http_proxy: None,
46      timeout: None,
47    }
48  }
49
50  /// Set endpoint, default is `https://huggingface.co`
51  ///
52  /// ```rust
53  /// use hugging_face_client::client::ClientOption;
54  ///
55  /// fn main() {
56  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
57  ///     .endpoint("https://fast-proxy.huggingface.to");
58  /// }
59  /// ```
60  pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
61    let endpoint = endpoint.into().trim().trim_end_matches('/').to_string();
62    if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
63      self.api_endpoint = Some(endpoint.into());
64      return self;
65    }
66    let mut result = String::with_capacity(endpoint.len() + 8);
67    result.push_str("https://");
68    result.push_str(endpoint.as_str());
69    self.api_endpoint = Some(result);
70    self
71  }
72
73  /// Set proxy, default is `None`
74  ///
75  /// ```rust
76  /// use hugging_face_client::client::ClientOption;
77  ///
78  /// fn main() {
79  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
80  ///     .proxy("socks5://127.0.0.1:9000");
81  /// }
82  /// ```
83  pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
84    self.http_proxy = Some(proxy.into());
85    self
86  }
87
88  /// Set timeout in second, default is `None`
89  ///
90  /// ```rust
91  /// use hugging_face_client::client::ClientOption;
92  /// use std::time::Duration;
93  ///
94  /// fn main() {
95  /// let option = ClientOption::new("HUGGING_FACE_TOKEN")
96  ///     .timeout(Duration::from_secs(5));
97  /// }
98  /// ```
99  pub fn timeout(mut self, timeout: Duration) -> Self {
100    self.timeout = Some(timeout);
101    self
102  }
103}
104
105/// Async client for Hugging Face Hub
106#[derive(Debug, Clone)]
107pub struct Client {
108  access_token: String,
109  api_endpoint: String,
110  http_client: ReqwestClient,
111}
112
113impl Client {
114  /// Create [`Client`] instance with [`ClientOption`]
115  ///
116  ///
117  /// ```rust
118  /// use hugging_face_client::client::Client;
119  /// use hugging_face_client::client::ClientOption;
120  /// use std::time::Duration;
121  ///
122  /// fn main() {
123  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
124  ///     .timeout(Duration::from_secs(5));
125  ///   let client = Client::new(option);
126  /// }
127  /// ```
128  pub fn new(option: ClientOption) -> Result<Self> {
129    let mut http_client = ReqwestClient::builder();
130    if let Some(http_proxy) = option.http_proxy {
131      let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
132      http_client = http_client.proxy(proxy);
133    }
134    if let Some(timeout) = option.timeout {
135      http_client = http_client.timeout(timeout);
136    }
137    let http_client = http_client.build().context(ReqwestClientSnafu)?;
138    let client = Client {
139      access_token: option.access_token,
140      api_endpoint: option
141        .api_endpoint
142        .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
143      http_client,
144    };
145    Ok(client)
146  }
147
148  /// Endpoint: GET /api/models
149  ///
150  /// Get information from all models in the Hub
151  pub async fn get_models(&self, req: GetModelsReq<'_>) -> Result<GetModelsRes> {
152    let url = format!("{}/api/models", &self.api_endpoint);
153    self.get_request(&url, Some(&req)).await
154  }
155
156  /// Endpoint: GET /api/models/{repo_id}
157  ///
158  /// Endpoint: GET /api/models/{repo_id}/revision/{revision}
159  ///
160  /// Get all information for a specific model
161  pub async fn get_model(&self, req: GetModelReq<'_>) -> Result<GetModelRes> {
162    let url = if let Some(revision) = req.revision {
163      format!(
164        "{}/api/models/{}/revision/{}",
165        &self.api_endpoint, req.name, revision
166      )
167    } else {
168      format!("{}/api/models/{}", &self.api_endpoint, req.name)
169    };
170    let req = if true { None } else { Some(&req) };
171    self.get_request(&url, req).await
172  }
173
174  /// Endpoint:  POST /api/repos/create
175  ///
176  /// Create a repository, model repo by default.
177  pub async fn create_repo(&self, req: CreateRepoReq<'_>) -> Result<CreateRepoRes> {
178    let url = format!("{}/api/repos/create", &self.api_endpoint);
179    self.exec_request(&url, Method::POST, Some(&req)).await
180  }
181
182  /// Endpoint: DELETE /api/repos/delete
183  ///
184  pub async fn delete_repo(&self, req: DeleteRepoReq<'_>) -> Result<()> {
185    let url = format!("{}/api/repos/delete", &self.api_endpoint);
186    self
187      .exec_request_without_response(&url, Method::DELETE, Some(&req))
188      .await
189  }
190}
191
192// private method
193impl Client {
194  async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
195    &self,
196    url: &str,
197    query: Option<&T>,
198  ) -> Result<U> {
199    let mut req = self.http_client.get(url).bearer_auth(&self.access_token);
200    if let Some(query) = query {
201      req = req.query(query);
202    }
203    let res = req
204      .send()
205      .await
206      .context(ReqwestClientSnafu)?
207      .json::<HuggingFaceRes<U>>()
208      .await
209      .context(ReqwestClientSnafu)?
210      .unwrap_data()?;
211    Ok(res)
212  }
213
214  async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
215    &self,
216    url: &str,
217    method: Method,
218    body: Option<&T>,
219  ) -> Result<U> {
220    let mut req = self
221      .http_client
222      .request(method, url)
223      .bearer_auth(&self.access_token);
224    if let Some(body) = body {
225      req = req.json(body);
226    }
227    let res = req
228      .send()
229      .await
230      .context(ReqwestClientSnafu)?
231      .json::<HuggingFaceRes<U>>()
232      .await
233      .context(ReqwestClientSnafu)?
234      .unwrap_data()?;
235    Ok(res)
236  }
237
238  async fn exec_request_without_response<T: Serialize>(
239    &self,
240    url: &str,
241    method: Method,
242    body: Option<&T>,
243  ) -> Result<()> {
244    let mut req = self
245      .http_client
246      .request(method, url)
247      .bearer_auth(&self.access_token);
248    if let Some(body) = body {
249      req = req.json(body);
250    }
251    let _res = req
252      .send()
253      .await
254      .context(ReqwestClientSnafu)?
255      .error_for_status()
256      .context(ReqwestClientSnafu)?;
257    Ok(())
258  }
259}