hugging_face_client/
client.rs1mod arxiv;
4mod repo;
5
6use std::time::Duration;
7
8use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
9use serde::{Deserialize, Serialize};
10use snafu::ResultExt;
11
12use crate::{
13 api::HuggingFaceRes,
14 errors::{ReqwestClientSnafu, Result},
15};
16
/// Base URL used by [`Client::new`] when no endpoint was configured on the [`ClientOption`].
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
18
/// Configuration used to build a [`Client`].
///
/// Create it with [`ClientOption::new`] and refine it with the builder-style
/// methods [`endpoint`](ClientOption::endpoint), [`proxy`](ClientOption::proxy)
/// and [`timeout`](ClientOption::timeout).
#[derive(Clone)]
pub struct ClientOption {
    /// Token the client sends as a bearer credential on authenticated requests.
    access_token: String,
    /// API base URL; `None` makes the client fall back to its default endpoint.
    api_endpoint: Option<String>,
    /// Proxy URL applied to all outgoing traffic, if any.
    http_proxy: Option<String>,
    /// Request timeout handed to the underlying HTTP client builder, if any.
    timeout: Option<Duration>,
}

// Hand-written instead of derived so that `{:?}` (e.g. in logs) never prints
// the access token.
impl std::fmt::Debug for ClientOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientOption")
            .field("access_token", &"<redacted>")
            .field("api_endpoint", &self.api_endpoint)
            .field("http_proxy", &self.http_proxy)
            .field("timeout", &self.timeout)
            .finish()
    }
}

impl ClientOption {
    /// Creates options holding `access_token`, with no endpoint, proxy or timeout set.
    pub fn new(access_token: impl Into<String>) -> Self {
        ClientOption {
            access_token: access_token.into(),
            api_endpoint: None,
            http_proxy: None,
            timeout: None,
        }
    }

    /// Sets the API endpoint.
    ///
    /// Surrounding whitespace and trailing `/` characters are removed.
    /// - An endpoint that already starts with `http://` or `https://` (matched
    ///   case-insensitively) is stored as is.
    /// - Anything else is prefixed with `https://`.
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        // True when `s` begins with an HTTP(S) scheme, ignoring ASCII case
        // (URL schemes are case-insensitive).
        fn has_http_scheme(s: &str) -> bool {
            ["http://", "https://"].iter().any(|&scheme| {
                matches!(s.get(..scheme.len()), Some(prefix) if prefix.eq_ignore_ascii_case(scheme))
            })
        }

        let raw = endpoint.into();
        let trimmed = raw.trim().trim_end_matches('/');
        let normalized = if has_http_scheme(trimmed) {
            trimmed.to_string()
        } else {
            ["https://", trimmed].concat()
        };
        self.api_endpoint = Some(normalized);
        self
    }

    /// Routes all client traffic through the given proxy URL.
    ///
    /// The URL is not validated here. `Client::new` parses it when building
    /// the HTTP client and fails if it is invalid.
    pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
        self.http_proxy = Some(proxy.into());
        self
    }

    /// Sets the request timeout applied by the underlying HTTP client.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}
104
105#[derive(Debug, Clone)]
107pub struct Client {
108 access_token: String,
109 api_endpoint: String,
110 http_client: ReqwestClient,
111}
112
113impl Client {
114 pub fn new(option: ClientOption) -> Result<Self> {
129 let mut http_client = ReqwestClient::builder();
130 if let Some(http_proxy) = option.http_proxy {
131 let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
132 http_client = http_client.proxy(proxy);
133 }
134 if let Some(timeout) = option.timeout {
135 http_client = http_client.timeout(timeout);
136 }
137 let http_client = http_client.build().context(ReqwestClientSnafu)?;
138 let client = Client {
139 access_token: option.access_token,
140 api_endpoint: option
141 .api_endpoint
142 .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
143 http_client,
144 };
145 Ok(client)
146 }
147}
148
149impl Client {
151 async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
152 &self,
153 url: &str,
154 query: Option<&T>,
155 need_token: bool,
156 ) -> Result<U> {
157 let mut req = self.http_client.get(url);
158 if need_token {
159 req = req.bearer_auth(&self.access_token);
160 }
161 if let Some(query) = query {
162 req = req.query(query);
163 }
164 let res = req
165 .send()
166 .await
167 .context(ReqwestClientSnafu)?
168 .json::<HuggingFaceRes<U>>()
169 .await
170 .context(ReqwestClientSnafu)?
171 .unwrap_data()?;
172 Ok(res)
173 }
174
175 async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
176 &self,
177 url: &str,
178 method: Method,
179 body: Option<&T>,
180 ) -> Result<U> {
181 let mut req = self
182 .http_client
183 .request(method, url)
184 .bearer_auth(&self.access_token);
185 if let Some(body) = body {
186 req = req.json(body);
187 }
188 let res = req
189 .send()
190 .await
191 .context(ReqwestClientSnafu)?
192 .json::<HuggingFaceRes<U>>()
193 .await
194 .context(ReqwestClientSnafu)?
195 .unwrap_data()?;
196 Ok(res)
197 }
198
199 async fn exec_request_without_response<T: Serialize>(
200 &self,
201 url: &str,
202 method: Method,
203 body: Option<&T>,
204 ) -> Result<()> {
205 let mut req = self
206 .http_client
207 .request(method, url)
208 .bearer_auth(&self.access_token);
209 if let Some(body) = body {
210 req = req.json(body);
211 }
212 let _res = req
213 .send()
214 .await
215 .context(ReqwestClientSnafu)?
216 .error_for_status()
217 .context(ReqwestClientSnafu)?;
218 Ok(())
219 }
220}