hugging_face_client/
client.rs1mod arxiv;
4mod collection;
5mod organization;
6mod repo;
7mod user;
8
9use std::time::Duration;
10
11use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
12use serde::{Deserialize, Serialize};
13use snafu::ResultExt;
14
15use crate::{
16 api::HuggingFaceRes,
17 errors::{ReqwestClientSnafu, Result},
18};
19
/// Base URL of the Hugging Face Hub, used when no custom endpoint is configured.
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
21
/// Builder-style configuration for constructing a client.
///
/// Only the access token is required; the endpoint, proxy and timeout are
/// optional and are set through the chained builder methods.
#[derive(Clone)]
pub struct ClientOption {
    /// Access token supplied by the caller.
    access_token: String,
    /// Normalized API base URL; `None` means "use the default endpoint".
    api_endpoint: Option<String>,
    /// Proxy URL, if any.
    http_proxy: Option<String>,
    /// Request timeout, if any.
    timeout: Option<Duration>,
}

// Hand-written instead of derived so the access token never ends up in logs
// or panic messages through `{:?}` formatting.
impl std::fmt::Debug for ClientOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientOption")
            .field("access_token", &"<redacted>")
            .field("api_endpoint", &self.api_endpoint)
            .field("http_proxy", &self.http_proxy)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl ClientOption {
    /// Creates options holding `access_token`, with every optional setting unset.
    pub fn new(access_token: impl Into<String>) -> Self {
        Self {
            access_token: access_token.into(),
            api_endpoint: None,
            http_proxy: None,
            timeout: None,
        }
    }

    /// Sets the API base URL.
    ///
    /// Surrounding whitespace and trailing `/` characters are removed. If the
    /// value does not start with `http://` or `https://` (compared
    /// case-insensitively), `https://` is prepended. A value that is empty
    /// after trimming is ignored and leaves the current setting untouched,
    /// rather than storing the unusable URL `https://`.
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        let raw = endpoint.into();
        let trimmed = raw.trim().trim_end_matches('/');
        if trimmed.is_empty() {
            return self;
        }
        let normalized = if Self::has_http_scheme(trimmed) {
            trimmed.to_string()
        } else {
            format!("https://{}", trimmed)
        };
        self.api_endpoint = Some(normalized);
        self
    }

    /// Sets the proxy URL to use for requests.
    pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
        self.http_proxy = Some(proxy.into());
        self
    }

    /// Sets the request timeout.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Returns `true` when `url` begins with `http://` or `https://`.
    ///
    /// URL schemes are case-insensitive, so the comparison ignores ASCII case.
    /// It works on bytes so a multi-byte character near the start of `url`
    /// cannot cause a slicing panic.
    fn has_http_scheme(url: &str) -> bool {
        ["http://", "https://"].iter().any(|scheme| {
            url.len() >= scheme.len()
                && url.as_bytes()[..scheme.len()].eq_ignore_ascii_case(scheme.as_bytes())
        })
    }
}
107
108#[derive(Debug, Clone)]
110pub struct Client {
111 access_token: String,
112 api_endpoint: String,
113 http_client: ReqwestClient,
114}
115
116impl Client {
117 pub fn new(option: ClientOption) -> Result<Self> {
132 let mut http_client = ReqwestClient::builder();
133 if let Some(http_proxy) = option.http_proxy {
134 let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
135 http_client = http_client.proxy(proxy);
136 }
137 if let Some(timeout) = option.timeout {
138 http_client = http_client.timeout(timeout);
139 }
140 let http_client = http_client.build().context(ReqwestClientSnafu)?;
141 let client = Client {
142 access_token: option.access_token,
143 api_endpoint: option
144 .api_endpoint
145 .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
146 http_client,
147 };
148 Ok(client)
149 }
150}
151
152impl Client {
154 async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
155 &self,
156 url: &str,
157 query: Option<&T>,
158 need_token: bool,
159 ) -> Result<U> {
160 let mut req = self.http_client.get(url);
161 if need_token {
162 req = req.bearer_auth(&self.access_token);
163 }
164 if let Some(query) = query {
165 req = req.query(query);
166 }
167 let res = req
168 .send()
169 .await
170 .context(ReqwestClientSnafu)?
171 .json::<HuggingFaceRes<U>>()
172 .await
173 .context(ReqwestClientSnafu)?
174 .unwrap_data()?;
175 Ok(res)
176 }
177
178 async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
179 &self,
180 url: &str,
181 method: Method,
182 body: Option<&T>,
183 ) -> Result<U> {
184 let mut req = self
185 .http_client
186 .request(method, url)
187 .bearer_auth(&self.access_token);
188 if let Some(body) = body {
189 req = req.json(body);
190 }
191 let res = req
192 .send()
193 .await
194 .context(ReqwestClientSnafu)?
195 .json::<HuggingFaceRes<U>>()
196 .await
197 .context(ReqwestClientSnafu)?
198 .unwrap_data()?;
199 Ok(res)
200 }
201
202 async fn exec_request_without_response<T: Serialize>(
203 &self,
204 url: &str,
205 method: Method,
206 body: Option<&T>,
207 ) -> Result<()> {
208 let mut req = self
209 .http_client
210 .request(method, url)
211 .bearer_auth(&self.access_token);
212 if let Some(body) = body {
213 req = req.json(body);
214 }
215 let _res = req
216 .send()
217 .await
218 .context(ReqwestClientSnafu)?
219 .error_for_status()
220 .context(ReqwestClientSnafu)?;
221 Ok(())
222 }
223
224 #[inline]
225 fn empty_req(&self) -> Option<&()> {
226 if true { None } else { Some(&())}
227 }
228}