hugging_face_client/
client.rs

1//! Async hub client
2
3mod repo;
4
5use std::time::Duration;
6
7use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
8use serde::{Deserialize, Serialize};
9use snafu::ResultExt;
10
11use crate::{
12  api::HuggingFaceRes,
13  errors::{ReqwestClientSnafu, Result},
14};
15
/// Hub base URL used when no endpoint is configured via [`ClientOption::endpoint`].
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
17
/// Options for creating [`Client`]
///
/// Start with [`ClientOption::new`] and refine with the chainable setters
/// [`endpoint`](ClientOption::endpoint), [`proxy`](ClientOption::proxy) and
/// [`timeout`](ClientOption::timeout). Each setter consumes and returns the option.
#[derive(Clone)]
pub struct ClientOption {
  access_token: String,
  api_endpoint: Option<String>,
  http_proxy: Option<String>,
  timeout: Option<Duration>,
}

// Written by hand instead of derived so the access token never shows up in `{:?}` output
// (logs, panic messages, error reports).
impl std::fmt::Debug for ClientOption {
  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("ClientOption")
      .field("access_token", &"<redacted>")
      .field("api_endpoint", &self.api_endpoint)
      .field("http_proxy", &self.http_proxy)
      .field("timeout", &self.timeout)
      .finish()
  }
}

/// Returns `true` when `url` begins with `scheme`, compared ASCII-case-insensitively,
/// because URL schemes are case-insensitive (RFC 3986, section 3.1).
fn starts_with_scheme(url: &str, scheme: &str) -> bool {
  matches!(
    url.as_bytes().get(..scheme.len()),
    Some(head) if head.eq_ignore_ascii_case(scheme.as_bytes())
  )
}

impl ClientOption {
  /// Create [`ClientOption`] instance with access_token
  ///
  /// `access_token`: authenticate client to the Hugging Face Hub and allow client to perform
  /// actions based on token permissions, see [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
  ///
  /// Endpoint, proxy and timeout all start unset.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN");
  /// }
  /// ```
  #[must_use]
  pub fn new(access_token: impl Into<String>) -> Self {
    ClientOption {
      access_token: access_token.into(),
      api_endpoint: None,
      http_proxy: None,
      timeout: None,
    }
  }

  /// Set endpoint, default is `https://huggingface.co`
  ///
  /// Surrounding whitespace and trailing `/` characters are stripped. If the value does not
  /// already start with `http://` or `https://` (checked case-insensitively), `https://` is
  /// prepended.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .endpoint("https://fast-proxy.huggingface.to");
  /// }
  /// ```
  #[must_use]
  pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
    let raw = endpoint.into();
    // Drop trailing slashes so request paths can later be appended with a leading `/`.
    let trimmed = raw.trim().trim_end_matches('/');
    let normalized =
      if starts_with_scheme(trimmed, "http://") || starts_with_scheme(trimmed, "https://") {
        trimmed.to_string()
      } else {
        format!("https://{trimmed}")
      };
    self.api_endpoint = Some(normalized);
    self
  }

  /// Set proxy, default is `None`
  ///
  /// The URL is only parsed when [`Client::new`] is called. An invalid proxy URL makes that
  /// call return an error.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .proxy("socks5://127.0.0.1:9000");
  /// }
  /// ```
  #[must_use]
  pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
    self.http_proxy = Some(proxy.into());
    self
  }

  /// Set the request timeout, default is `None` (no timeout)
  ///
  /// The [`Duration`] is handed to the underlying `reqwest` client builder when
  /// [`Client::new`] is called.
  ///
  /// ```rust
  /// use hugging_face_client::client::ClientOption;
  /// use std::time::Duration;
  ///
  /// fn main() {
  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
  ///     .timeout(Duration::from_secs(5));
  /// }
  /// ```
  #[must_use]
  pub fn timeout(mut self, timeout: Duration) -> Self {
    self.timeout = Some(timeout);
    self
  }
}
103
104/// Async client for Hugging Face Hub
105#[derive(Debug, Clone)]
106pub struct Client {
107  access_token: String,
108  api_endpoint: String,
109  http_client: ReqwestClient,
110}
111
112impl Client {
113  /// Create [`Client`] instance with [`ClientOption`]
114  ///
115  ///
116  /// ```rust
117  /// use hugging_face_client::client::Client;
118  /// use hugging_face_client::client::ClientOption;
119  /// use std::time::Duration;
120  ///
121  /// fn main() {
122  ///   let option = ClientOption::new("HUGGING_FACE_TOKEN")
123  ///     .timeout(Duration::from_secs(5));
124  ///   let client = Client::new(option);
125  /// }
126  /// ```
127  pub fn new(option: ClientOption) -> Result<Self> {
128    let mut http_client = ReqwestClient::builder();
129    if let Some(http_proxy) = option.http_proxy {
130      let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
131      http_client = http_client.proxy(proxy);
132    }
133    if let Some(timeout) = option.timeout {
134      http_client = http_client.timeout(timeout);
135    }
136    let http_client = http_client.build().context(ReqwestClientSnafu)?;
137    let client = Client {
138      access_token: option.access_token,
139      api_endpoint: option
140        .api_endpoint
141        .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
142      http_client,
143    };
144    Ok(client)
145  }
146}
147
148// private method
149impl Client {
150  async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
151    &self,
152    url: &str,
153    query: Option<&T>,
154    need_token: bool,
155  ) -> Result<U> {
156    let mut req = self.http_client.get(url);
157    if need_token {
158      req = req.bearer_auth(&self.access_token);
159    }
160    if let Some(query) = query {
161      req = req.query(query);
162    }
163    let res = req
164      .send()
165      .await
166      .context(ReqwestClientSnafu)?
167      .json::<HuggingFaceRes<U>>()
168      .await
169      .context(ReqwestClientSnafu)?
170      .unwrap_data()?;
171    Ok(res)
172  }
173
174  async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
175    &self,
176    url: &str,
177    method: Method,
178    body: Option<&T>,
179  ) -> Result<U> {
180    let mut req = self
181      .http_client
182      .request(method, url)
183      .bearer_auth(&self.access_token);
184    if let Some(body) = body {
185      req = req.json(body);
186    }
187    let res = req
188      .send()
189      .await
190      .context(ReqwestClientSnafu)?
191      .json::<HuggingFaceRes<U>>()
192      .await
193      .context(ReqwestClientSnafu)?
194      .unwrap_data()?;
195    Ok(res)
196  }
197
198  async fn exec_request_without_response<T: Serialize>(
199    &self,
200    url: &str,
201    method: Method,
202    body: Option<&T>,
203  ) -> Result<()> {
204    let mut req = self
205      .http_client
206      .request(method, url)
207      .bearer_auth(&self.access_token);
208    if let Some(body) = body {
209      req = req.json(body);
210    }
211    let _res = req
212      .send()
213      .await
214      .context(ReqwestClientSnafu)?
215      .error_for_status()
216      .context(ReqwestClientSnafu)?;
217    Ok(())
218  }
219}