1mod builder;
57mod config;
58mod rate_limit;
59
60pub use builder::ClientBuilder;
61pub use config::{AuthConfig, Config};
62use rate_limit::RateLimiter;
63
64use crate::error::{Error, Result};
65use reqwest::{Client as ReqwestClient, RequestBuilder, StatusCode};
66use serde::{de::DeserializeOwned, Serialize};
67use std::sync::Arc;
68use tokio::sync::RwLock;
69
70#[derive(Debug, Clone)]
72pub struct Client {
73 http_client: ReqwestClient,
74 base_url: String,
75 auth: Arc<RwLock<Option<AuthConfig>>>,
76 rate_limiter: Arc<RateLimiter>,
77}
78
79impl Client {
80 pub fn new() -> Result<Self> {
82 Self::builder().build()
83 }
84
85 pub fn builder() -> ClientBuilder {
87 ClientBuilder::new()
88 }
89
90 pub(crate) fn with_config(config: Config) -> Result<Self> {
92 let http_client = ReqwestClient::builder().timeout(config.timeout).build()?;
93
94 let base_url = config.base_url();
95
96 Ok(Self {
97 http_client,
98 base_url,
99 auth: Arc::new(RwLock::new(config.auth)),
100 rate_limiter: Arc::new(RateLimiter::new(config.rate_limit)),
101 })
102 }
103
104 pub fn base_url(&self) -> &str {
106 &self.base_url
107 }
108
109 pub async fn get_auth_token(&self) -> Option<String> {
111 let auth = self.auth.read().await;
112 auth.as_ref().and_then(|auth| {
113 if auth.is_valid() {
114 Some(auth.token.clone())
115 } else {
116 None
117 }
118 })
119 }
120
121 pub fn auth_token(&self) -> Option<String> {
123 futures::executor::block_on(self.get_auth_token())
124 }
125
126 pub async fn set_auth_token(&self, token: String) {
128 let mut auth = self.auth.write().await;
129 *auth = Some(AuthConfig::new(token));
130 }
131
132 pub async fn set_auth_token_with_expiry(
134 &self,
135 token: String,
136 expiry: chrono::DateTime<chrono::Utc>,
137 ) {
138 let mut auth = self.auth.write().await;
139 *auth = Some(AuthConfig::with_expiry(token, expiry));
140 }
141
142 pub async fn clear_auth_token(&self) {
144 let mut auth = self.auth.write().await;
145 *auth = None;
146 }
147
148 pub async fn has_valid_auth(&self) -> bool {
150 let auth = self.auth.read().await;
151 auth.as_ref().map_or(false, |auth| auth.is_valid())
152 }
153
154 async fn build_request(&self, request: RequestBuilder) -> RequestBuilder {
156 let auth = self.auth.read().await;
157 if let Some(auth) = auth.as_ref() {
158 if auth.is_valid() {
159 return request.header("Authorization", format!("Bearer {}", auth.token));
160 }
161 }
162 request
163 }
164
165 pub(crate) async fn get<T>(&self, endpoint: &str) -> Result<T>
167 where
168 T: DeserializeOwned,
169 {
170 self.rate_limiter
172 .check()
173 .await
174 .map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
175
176 let url = format!("{}{}", self.base_url, endpoint);
177 let request = self.http_client.get(&url);
178 let request = self.build_request(request).await;
179
180 let response = request.send().await?;
181
182 match response.status() {
183 StatusCode::OK => Ok(response.json().await?),
184 status => {
185 let message = response
186 .text()
187 .await
188 .unwrap_or_else(|_| "Unknown error".to_string());
189 Err(Error::Api {
190 status: status.as_u16(),
191 message,
192 })
193 }
194 }
195 }
196
197 pub(crate) async fn post<T, B>(&self, endpoint: &str, body: &B) -> Result<T>
199 where
200 T: DeserializeOwned,
201 B: Serialize,
202 {
203 self.rate_limiter
205 .check()
206 .await
207 .map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
208
209 let url = format!("{}{}", self.base_url, endpoint);
210 let request = self.http_client.post(&url).json(body);
211 let request = self.build_request(request).await;
212
213 let response = request.send().await?;
214 let status = response.status();
215 let text = response.text().await?;
216
217 if !status.is_success() {
219 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
220 if let Some(error) = json.get("error") {
221 if let Some(message) = error.get("message").and_then(|m| m.as_str()) {
222 return Err(Error::Api {
223 status: status.as_u16(),
224 message: message.to_string(),
225 });
226 }
227 }
228 }
229 return Err(Error::Api {
230 status: status.as_u16(),
231 message: text,
232 });
233 }
234
235 match serde_json::from_str(&text) {
237 Ok(value) => Ok(value),
238 Err(e) => Err(Error::Json(e)),
239 }
240 }
241
242 pub(crate) async fn post_cbor<T>(&self, endpoint: &str, data: &[u8]) -> Result<T>
244 where
245 T: DeserializeOwned,
246 {
247 self.rate_limiter
249 .check()
250 .await
251 .map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
252
253 let url = format!("{}{}", self.base_url, endpoint);
254 let request = self
255 .http_client
256 .post(&url)
257 .header("Content-Type", "application/cbor")
258 .body(data.to_vec());
259
260 let request = self.build_request(request).await;
261
262 let response = request.send().await?;
263
264 match response.status() {
265 StatusCode::OK | StatusCode::ACCEPTED => Ok(response.json().await?),
266 status => {
267 let message = response
268 .text()
269 .await
270 .unwrap_or_else(|_| "Unknown error".to_string());
271 Err(Error::Api {
272 status: status.as_u16(),
273 message,
274 })
275 }
276 }
277 }
278
279 fn is_rate_limit_error(status: StatusCode, _text: &str) -> Option<u64> {
281 if status == StatusCode::TOO_MANY_REQUESTS {
282 return Some(60);
284 }
285 None
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a default client whose base URL points at `server`.
    fn client_for(server: &MockServer) -> Client {
        Client::builder().base_url(server.uri()).build().unwrap()
    }

    /// A successful GET returns the mocked JSON body unchanged.
    #[tokio::test]
    async fn test_get_request() {
        let server = MockServer::start().await;
        let client = client_for(&server);
        let expected = json!({ "data": "test" });

        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected))
            .mount(&server)
            .await;

        let actual: serde_json::Value = client.get("/test").await.unwrap();
        assert_eq!(actual, expected);
    }

    /// A successful POST returns the mocked JSON body unchanged.
    #[tokio::test]
    async fn test_post_request() {
        let server = MockServer::start().await;
        let client = client_for(&server);
        let payload = json!({ "test": "data" });
        let expected = json!({ "result": "success" });

        Mock::given(method("POST"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected))
            .mount(&server)
            .await;

        let actual: serde_json::Value = client.post("/test", &payload).await.unwrap();
        assert_eq!(actual, expected);
    }

    /// A stored token is reported as valid and sent as a bearer header.
    #[tokio::test]
    async fn test_auth_token() {
        let server = MockServer::start().await;
        let client = client_for(&server);

        client.set_auth_token("test-token".to_string()).await;
        assert!(client.has_valid_auth().await);
        assert_eq!(
            client.get_auth_token().await,
            Some("test-token".to_string())
        );

        let expected = json!({ "data": "test" });

        Mock::given(method("GET"))
            .and(path("/test"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&expected))
            .mount(&server)
            .await;

        let actual: serde_json::Value = client.get("/test").await.unwrap();
        assert_eq!(actual, expected);
    }

    /// A 404 with a plain-text body surfaces as `Error::Api` carrying that text.
    #[tokio::test]
    async fn test_error_handling() {
        let server = MockServer::start().await;
        let client = client_for(&server);

        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(404).set_body_string("Not Found"))
            .mount(&server)
            .await;

        match client.get::<serde_json::Value>("/test").await.unwrap_err() {
            Error::Api { status, message } => {
                assert_eq!(status, 404);
                assert_eq!(message, "Not Found");
            }
            _ => panic!("Expected API error"),
        }
    }

    /// A server-side 429 surfaces as `Error::Api`, not the local rate-limit error.
    #[tokio::test]
    async fn test_rate_limit() {
        let server = MockServer::start().await;
        let client = client_for(&server);

        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(429).set_body_string("Too Many Requests"))
            .mount(&server)
            .await;

        match client.get::<serde_json::Value>("/test").await.unwrap_err() {
            Error::Api { status, message } => {
                assert_eq!(status, 429);
                assert_eq!(message, "Too Many Requests");
            }
            _ => panic!("Expected rate limit error"),
        }
    }
}