hugging_face_client/
client.rs

1//! Async hub client
2
3mod arxiv;
4mod repo;
5
6use std::time::Duration;
7
8use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
9use serde::{Deserialize, Serialize};
10use snafu::ResultExt;
11
12use crate::{
13  api::HuggingFaceRes,
14  errors::{ReqwestClientSnafu, Result},
15};
16
/// Hub endpoint used when [`ClientOption`] does not set one.
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
18
/// Options for creating [`Client`]
///
/// Start with [`ClientOption::new`] and refine it with the chained setters
/// [`endpoint`](ClientOption::endpoint), [`proxy`](ClientOption::proxy) and
/// [`timeout`](ClientOption::timeout).
#[derive(Clone)]
pub struct ClientOption {
  access_token: String,
  api_endpoint: Option<String>,
  http_proxy: Option<String>,
  timeout: Option<Duration>,
}

// `Debug` is written by hand so the access token is never printed (e.g. in
// logs or `{:?}` output); all other fields are shown unchanged.
impl std::fmt::Debug for ClientOption {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("ClientOption")
      .field("access_token", &"<redacted>")
      .field("api_endpoint", &self.api_endpoint)
      .field("http_proxy", &self.http_proxy)
      .field("timeout", &self.timeout)
      .finish()
  }
}

impl ClientOption {
  /// Create [`ClientOption`] instance with access_token
  ///
  /// `access_token`: authenticate client to the Hugging Face Hub and allow client to perform
  /// actions based on token permissions, see [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN");
  /// }
  /// ```
  pub fn new(access_token: impl Into<String>) -> Self {
    ClientOption {
      access_token: access_token.into(),
      api_endpoint: None,
      http_proxy: None,
      timeout: None,
    }
  }

  /// Set endpoint, default is `https://huggingface.co`
  ///
  /// The value is normalized before it is stored:
  /// - surrounding whitespace and all trailing `/` characters are removed;
  /// - `https://` is prepended unless the value already starts with `http://`
  ///   or `https://` (compared case-insensitively, as URL schemes are);
  /// - a value that is empty after trimming resets the endpoint to the default.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .endpoint("https://fast-proxy.huggingface.to");
  /// }
  /// ```
  pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
    /// Whether `s` begins with an `http://` or `https://` scheme, ignoring ASCII case.
    fn has_http_scheme(s: &str) -> bool {
      let bytes = s.as_bytes();
      ["http://", "https://"].iter().any(|scheme| {
        let scheme = scheme.as_bytes();
        // Compare bytes so a multi-byte char near the start cannot cause a slicing panic.
        bytes.len() >= scheme.len() && bytes[..scheme.len()].eq_ignore_ascii_case(scheme)
      })
    }

    let endpoint = endpoint.into();
    let endpoint = endpoint.trim().trim_end_matches('/');
    self.api_endpoint = if endpoint.is_empty() {
      // "https://" alone is never usable; fall back to the default endpoint.
      None
    } else if has_http_scheme(endpoint) {
      Some(endpoint.to_string())
    } else {
      Some(format!("https://{}", endpoint))
    };
    self
  }

  /// Set proxy, default is `None`
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .proxy("socks5://127.0.0.1:9000");
  /// }
  /// ```
  pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
    self.http_proxy = Some(proxy.into());
    self
  }

  /// Set the request timeout, default is `None`
  ///
  /// When unset, no timeout is configured on the HTTP client builder.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  /// use std::time::Duration;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .timeout(Duration::from_secs(5));
  /// }
  /// ```
  pub fn timeout(mut self, timeout: Duration) -> Self {
    self.timeout = Some(timeout);
    self
  }
}
104
105/// Async client for Hugging Face Hub
106#[derive(Debug, Clone)]
107pub struct Client {
108  access_token: String,
109  api_endpoint: String,
110  http_client: ReqwestClient,
111}
112
113impl Client {
114  /// Create [`Client`] instance with [`ClientOption`]
115  ///
116  ///
117  /// ```rust
118  /// use hugging_face_client::client::Client;
119  /// use hugging_face_client::client::ClientOption;
120  /// use std::time::Duration;
121  ///
122  /// fn main() {
123  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
124  ///     .timeout(Duration::from_secs(5));
125  ///   let client = Client::new(option);
126  /// }
127  /// ```
128  pub fn new(option: ClientOption) -> Result<Self> {
129    let mut http_client = ReqwestClient::builder();
130    if let Some(http_proxy) = option.http_proxy {
131      let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
132      http_client = http_client.proxy(proxy);
133    }
134    if let Some(timeout) = option.timeout {
135      http_client = http_client.timeout(timeout);
136    }
137    let http_client = http_client.build().context(ReqwestClientSnafu)?;
138    let client = Client {
139      access_token: option.access_token,
140      api_endpoint: option
141        .api_endpoint
142        .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
143      http_client,
144    };
145    Ok(client)
146  }
147}
148
149// private method
150impl Client {
151  async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
152    &self,
153    url: &str,
154    query: Option<&T>,
155    need_token: bool,
156  ) -> Result<U> {
157    let mut req = self.http_client.get(url);
158    if need_token {
159      req = req.bearer_auth(&self.access_token);
160    }
161    if let Some(query) = query {
162      req = req.query(query);
163    }
164    let res = req
165      .send()
166      .await
167      .context(ReqwestClientSnafu)?
168      .json::<HuggingFaceRes<U>>()
169      .await
170      .context(ReqwestClientSnafu)?
171      .unwrap_data()?;
172    Ok(res)
173  }
174
175  async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
176    &self,
177    url: &str,
178    method: Method,
179    body: Option<&T>,
180  ) -> Result<U> {
181    let mut req = self
182      .http_client
183      .request(method, url)
184      .bearer_auth(&self.access_token);
185    if let Some(body) = body {
186      req = req.json(body);
187    }
188    let res = req
189      .send()
190      .await
191      .context(ReqwestClientSnafu)?
192      .json::<HuggingFaceRes<U>>()
193      .await
194      .context(ReqwestClientSnafu)?
195      .unwrap_data()?;
196    Ok(res)
197  }
198
199  async fn exec_request_without_response<T: Serialize>(
200    &self,
201    url: &str,
202    method: Method,
203    body: Option<&T>,
204  ) -> Result<()> {
205    let mut req = self
206      .http_client
207      .request(method, url)
208      .bearer_auth(&self.access_token);
209    if let Some(body) = body {
210      req = req.json(body);
211    }
212    let _res = req
213      .send()
214      .await
215      .context(ReqwestClientSnafu)?
216      .error_for_status()
217      .context(ReqwestClientSnafu)?;
218    Ok(())
219  }
220}