hugging_face_client/
client.rs1mod repo;
4
5use std::time::Duration;
6
7use reqwest::{Client as ReqwestClient, Method, Proxy as ReqwestProxy};
8use serde::{Deserialize, Serialize};
9use snafu::ResultExt;
10
11use crate::{
12 api::HuggingFaceRes,
13 errors::{ReqwestClientSnafu, Result},
14};
15
/// Base URL of the public Hugging Face Hub, used by `Client::new` when no custom
/// endpoint was configured on the `ClientOption`.
const DEFAULT_API_ENDPOINT: &str = "https://huggingface.co";
17
/// Builder-style configuration consumed by `Client::new`.
///
/// Only the access token is mandatory; endpoint, proxy and timeout are optional and
/// left unset by [`ClientOption::new`].
#[derive(Debug, Clone)]
pub struct ClientOption {
    /// Token sent as a bearer credential on authenticated requests.
    access_token: String,
    /// Normalised API base URL; `None` means "use the default endpoint".
    api_endpoint: Option<String>,
    /// Proxy URL applied to all traffic of the underlying HTTP client.
    http_proxy: Option<String>,
    /// Request timeout handed to the underlying HTTP client.
    timeout: Option<Duration>,
}

impl ClientOption {
    /// Creates options holding `access_token`, with no custom endpoint, proxy or timeout.
    #[must_use]
    pub fn new(access_token: impl Into<String>) -> Self {
        ClientOption {
            access_token: access_token.into(),
            api_endpoint: None,
            http_proxy: None,
            timeout: None,
        }
    }

    /// Sets a custom API endpoint.
    ///
    /// Surrounding whitespace and trailing `/` characters are removed. An explicit
    /// `http://` or `https://` scheme is kept; otherwise `https://` is prepended.
    /// An endpoint with no host part (e.g. `""` or `"https://"`) resets the option
    /// so the default endpoint is used instead of producing an invalid URL.
    #[must_use]
    pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
        let raw = endpoint.into();
        let trimmed = raw.trim();
        // Split off the scheme *before* trimming slashes; trimming first would turn
        // "https://" into "https:" and then into the bogus "https://https:".
        let (scheme, host) = match trimmed.strip_prefix("http://") {
            Some(rest) => ("http://", rest),
            None => (
                "https://",
                trimmed.strip_prefix("https://").unwrap_or(trimmed),
            ),
        };
        let host = host.trim_end_matches('/');
        self.api_endpoint = if host.is_empty() {
            None
        } else {
            Some(format!("{}{}", scheme, host))
        };
        self
    }

    /// Routes all requests through `proxy`.
    ///
    /// The URL is not validated here; it is parsed (and may fail) in `Client::new`.
    #[must_use]
    pub fn proxy(mut self, proxy: impl Into<String>) -> Self {
        self.http_proxy = Some(proxy.into());
        self
    }

    /// Sets the timeout passed to the HTTP client builder in `Client::new`.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}
103
104#[derive(Debug, Clone)]
106pub struct Client {
107 access_token: String,
108 api_endpoint: String,
109 http_client: ReqwestClient,
110}
111
112impl Client {
113 pub fn new(option: ClientOption) -> Result<Self> {
128 let mut http_client = ReqwestClient::builder();
129 if let Some(http_proxy) = option.http_proxy {
130 let proxy = ReqwestProxy::all(&http_proxy).context(ReqwestClientSnafu)?;
131 http_client = http_client.proxy(proxy);
132 }
133 if let Some(timeout) = option.timeout {
134 http_client = http_client.timeout(timeout);
135 }
136 let http_client = http_client.build().context(ReqwestClientSnafu)?;
137 let client = Client {
138 access_token: option.access_token,
139 api_endpoint: option
140 .api_endpoint
141 .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
142 http_client,
143 };
144 Ok(client)
145 }
146}
147
148impl Client {
150 async fn get_request<T: Serialize, U: for<'de> Deserialize<'de>>(
151 &self,
152 url: &str,
153 query: Option<&T>,
154 need_token: bool,
155 ) -> Result<U> {
156 let mut req = self.http_client.get(url);
157 if need_token {
158 req = req.bearer_auth(&self.access_token);
159 }
160 if let Some(query) = query {
161 req = req.query(query);
162 }
163 let res = req
164 .send()
165 .await
166 .context(ReqwestClientSnafu)?
167 .json::<HuggingFaceRes<U>>()
168 .await
169 .context(ReqwestClientSnafu)?
170 .unwrap_data()?;
171 Ok(res)
172 }
173
174 async fn exec_request<T: Serialize, U: for<'de> Deserialize<'de>>(
175 &self,
176 url: &str,
177 method: Method,
178 body: Option<&T>,
179 ) -> Result<U> {
180 let mut req = self
181 .http_client
182 .request(method, url)
183 .bearer_auth(&self.access_token);
184 if let Some(body) = body {
185 req = req.json(body);
186 }
187 let res = req
188 .send()
189 .await
190 .context(ReqwestClientSnafu)?
191 .json::<HuggingFaceRes<U>>()
192 .await
193 .context(ReqwestClientSnafu)?
194 .unwrap_data()?;
195 Ok(res)
196 }
197
198 async fn exec_request_without_response<T: Serialize>(
199 &self,
200 url: &str,
201 method: Method,
202 body: Option<&T>,
203 ) -> Result<()> {
204 let mut req = self
205 .http_client
206 .request(method, url)
207 .bearer_auth(&self.access_token);
208 if let Some(body) = body {
209 req = req.json(body);
210 }
211 let _res = req
212 .send()
213 .await
214 .context(ReqwestClientSnafu)?
215 .error_for_status()
216 .context(ReqwestClientSnafu)?;
217 Ok(())
218 }
219}