// hugging_face_client/client.rs

//! Async client for the Hugging Face Hub
2
3mod arxiv;
4mod collection;
5mod organization;
6mod repo;
7mod user;
8
9use std::time::Duration;
10
11use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
12use serde::{Deserialize, Serialize};
13use snafu::ResultExt;
14
15use crate::{
16  api::HuggingFaceRes,
17  errors::{ReqwestClientSnafu, Result},
18};
19
/// Base URL used when no endpoint has been configured on the client options.
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
21
/// Options for creating [`Client`]
///
/// Start with [`ClientOption::new`] and refine the value with the chainable
/// setters [`endpoint`](ClientOption::endpoint), [`proxy`](ClientOption::proxy)
/// and [`timeout`](ClientOption::timeout).
#[derive(Debug, Clone)]
pub struct ClientOption {
  /// Hugging Face access token
  access_token: String,
  /// Normalized base URL of the hub, `None` means "use the default endpoint"
  api_endpoint: Option<String>,
  /// Proxy URL, `None` means "no explicit proxy"
  http_proxy: Option<String>,
  /// Timeout handed to the HTTP client, `None` means "not set"
  timeout: Option<Duration>,
}

impl ClientOption {
  /// Create [`ClientOption`] instance with access_token
  ///
  /// `access_token`: authenticate client to the Hugging Face Hub and allow client to perform
  /// actions based on token permissions, see [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN");
  /// }
  /// ```
  pub fn new(access_token: impl Into<String>) -> Self {
    ClientOption {
      access_token: access_token.into(),
      api_endpoint: None,
      http_proxy: None,
      timeout: None,
    }
  }

  /// Set endpoint, default is `https://huggingface.co`
  ///
  /// The value is normalized before it is stored:
  ///
  /// - surrounding whitespace and trailing `/` characters are removed;
  /// - `https://` is prepended when the value does not already start with
  ///   `http://` or `https://` (the scheme is matched case-insensitively);
  /// - a value that is empty after trimming is ignored and the currently
  ///   configured endpoint is kept.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .endpoint("https://fast-proxy.huggingface.to");
  /// }
  /// ```
  pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
    let endpoint: String = endpoint.into();
    let endpoint = endpoint.trim().trim_end_matches('/');
    if endpoint.is_empty() {
      // A bare `https://` is not a usable base URL, so an empty value must not
      // overwrite whatever is configured right now.
      return self;
    }
    let normalized = if Self::has_http_scheme(endpoint) {
      endpoint.to_string()
    } else {
      format!("https://{}", endpoint)
    };
    self.api_endpoint = Some(normalized);
    self
  }

  /// Set proxy, default is `None`
  ///
  /// The value is stored as given; it is only parsed when the client is built.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .proxy("socks5://127.0.0.1:9000");
  /// }
  /// ```
  pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
    self.http_proxy = Some(proxy.into());
    self
  }

  /// Set timeout as a [`Duration`], default is `None`
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  /// use std::time::Duration;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .timeout(Duration::from_secs(5));
  /// }
  /// ```
  pub fn timeout(mut self, timeout: Duration) -> Self {
    self.timeout = Some(timeout);
    self
  }

  /// Report whether `endpoint` starts with `http://` or `https://`, ignoring ASCII case
  fn has_http_scheme(endpoint: &str) -> bool {
    ["http://", "https://"].iter().any(|scheme| {
      // `get` yields `None` when the input is shorter than the scheme or the
      // cut would split a multi-byte character, so this never panics.
      endpoint
        .get(..scheme.len())
        .map_or(false, |prefix| prefix.eq_ignore_ascii_case(scheme))
    })
  }
}
107
108/// Async client for Hugging Face Hub
109#[derive(Debug, Clone)]
110pub struct Client {
111  access_token: String,
112  api_endpoint: String,
113  http_client: ReqwestClient,
114}
115
116impl Client {
117  /// Create [`Client`] instance with [`ClientOption`]
118  ///
119  ///
120  /// ```rust
121  /// use hugging_face_client::client::Client;
122  /// use hugging_face_client::client::ClientOption;
123  /// use std::time::Duration;
124  ///
125  /// fn main() {
126  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
127  ///     .timeout(Duration::from_secs(5));
128  ///   let client = Client::new(option);
129  /// }
130  /// ```
131  pub fn new(option: ClientOption) -> Result<Self> {
132    let mut http_client = ReqwestClient::builder();
133    if let Some(http_proxy) = option.http_proxy {
134      let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
135      http_client = http_client.proxy(proxy);
136    }
137    if let Some(timeout) = option.timeout {
138      http_client = http_client.timeout(timeout);
139    }
140    let http_client = http_client.build().context(ReqwestClientSnafu)?;
141    let client = Client {
142      access_token: option.access_token,
143      api_endpoint: option
144        .api_endpoint
145        .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
146      http_client,
147    };
148    Ok(client)
149  }
150}
151
152// private method
153impl Client {
154  async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
155    &self,
156    url: &str,
157    query: Option<&T>,
158    need_token: bool,
159  ) -> Result<U> {
160    let mut req = self.http_client.get(url);
161    if need_token {
162      req = req.bearer_auth(&self.access_token);
163    }
164    if let Some(query) = query {
165      req = req.query(query);
166    }
167    let res = req
168      .send()
169      .await
170      .context(ReqwestClientSnafu)?
171      .json::<HuggingFaceRes<U>>()
172      .await
173      .context(ReqwestClientSnafu)?
174      .unwrap_data()?;
175    Ok(res)
176  }
177
178  async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
179    &self,
180    url: &str,
181    method: Method,
182    body: Option<&T>,
183  ) -> Result<U> {
184    let mut req = self
185      .http_client
186      .request(method, url)
187      .bearer_auth(&self.access_token);
188    if let Some(body) = body {
189      req = req.json(body);
190    }
191    let res = req
192      .send()
193      .await
194      .context(ReqwestClientSnafu)?
195      .json::<HuggingFaceRes<U>>()
196      .await
197      .context(ReqwestClientSnafu)?
198      .unwrap_data()?;
199    Ok(res)
200  }
201
202  async fn exec_request_without_response<T: Serialize>(
203    &self,
204    url: &str,
205    method: Method,
206    body: Option<&T>,
207  ) -> Result<()> {
208    let mut req = self
209      .http_client
210      .request(method, url)
211      .bearer_auth(&self.access_token);
212    if let Some(body) = body {
213      req = req.json(body);
214    }
215    let _res = req
216      .send()
217      .await
218      .context(ReqwestClientSnafu)?
219      .error_for_status()
220      .context(ReqwestClientSnafu)?;
221    Ok(())
222  }
223  
224  #[inline]
225  fn empty_req(&self) -> Option<&()> {
226    if true { None } else { Some(&())}
227  }
228}