mod builder;
mod config;
mod rate_limit;
pub use builder::ClientBuilder;
pub use config::{AuthConfig, Config};
use rate_limit::RateLimiter;
use crate::error::{Error, Result};
use reqwest::{Client as ReqwestClient, RequestBuilder, StatusCode};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Asynchronous HTTP client for the remote API.
///
/// Clones share the same credentials and rate limiter (both are held behind `Arc`), so a
/// token set through one clone is visible to every other clone.
///
/// `Debug` is implemented by hand so that formatting a `Client` never prints the stored
/// bearer token.
#[derive(Clone)]
pub struct Client {
    /// Underlying `reqwest` client used to send every request.
    http_client: ReqwestClient,
    /// Prefix that endpoint paths are appended to.
    base_url: String,
    /// Current credentials; shared between clones and replaceable at runtime.
    auth: Arc<RwLock<Option<AuthConfig>>>,
    /// Client-side limiter consulted before each request.
    rate_limiter: Arc<RateLimiter>,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether credentials are present. A derived Debug would print the
        // lock's contents (and therefore, potentially, the token itself).
        let auth = match self.auth.try_read() {
            Ok(guard) if guard.is_some() => "Some(<redacted>)",
            Ok(_) => "None",
            Err(_) => "<locked>",
        };
        f.debug_struct("Client")
            .field("http_client", &self.http_client)
            .field("base_url", &self.base_url)
            .field("auth", &format_args!("{}", auth))
            .field("rate_limiter", &self.rate_limiter)
            .finish()
    }
}
impl Client {
pub fn new() -> Result<Self> {
Self::builder().build()
}
pub fn builder() -> ClientBuilder {
ClientBuilder::new()
}
pub(crate) fn with_config(config: Config) -> Result<Self> {
let http_client = ReqwestClient::builder().timeout(config.timeout).build()?;
let base_url = config.base_url();
Ok(Self {
http_client,
base_url,
auth: Arc::new(RwLock::new(config.auth)),
rate_limiter: Arc::new(RateLimiter::new(config.rate_limit)),
})
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub async fn get_auth_token(&self) -> Option<String> {
let auth = self.auth.read().await;
auth.as_ref().and_then(|auth| {
if auth.is_valid() {
Some(auth.token.clone())
} else {
None
}
})
}
pub fn auth_token(&self) -> Option<String> {
futures::executor::block_on(self.get_auth_token())
}
pub async fn set_auth_token(&self, token: String) {
let mut auth = self.auth.write().await;
*auth = Some(AuthConfig::new(token));
}
pub async fn set_auth_token_with_expiry(
&self,
token: String,
expiry: chrono::DateTime<chrono::Utc>,
) {
let mut auth = self.auth.write().await;
*auth = Some(AuthConfig::with_expiry(token, expiry));
}
pub async fn clear_auth_token(&self) {
let mut auth = self.auth.write().await;
*auth = None;
}
pub async fn has_valid_auth(&self) -> bool {
let auth = self.auth.read().await;
auth.as_ref().map_or(false, |auth| auth.is_valid())
}
async fn build_request(&self, request: RequestBuilder) -> RequestBuilder {
let auth = self.auth.read().await;
if let Some(auth) = auth.as_ref() {
if auth.is_valid() {
return request.header("Authorization", format!("Bearer {}", auth.token));
}
}
request
}
pub(crate) async fn get<T>(&self, endpoint: &str) -> Result<T>
where
T: DeserializeOwned,
{
self.rate_limiter
.check()
.await
.map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
let url = format!("{}{}", self.base_url, endpoint);
let request = self.http_client.get(&url);
let request = self.build_request(request).await;
let response = request.send().await?;
match response.status() {
StatusCode::OK => Ok(response.json().await?),
status => {
let message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(Error::Api {
status: status.as_u16(),
message,
})
}
}
}
pub(crate) async fn post<T, B>(&self, endpoint: &str, body: &B) -> Result<T>
where
T: DeserializeOwned,
B: Serialize,
{
self.rate_limiter
.check()
.await
.map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
let url = format!("{}{}", self.base_url, endpoint);
let request = self.http_client.post(&url).json(body);
let request = self.build_request(request).await;
let response = request.send().await?;
let status = response.status();
let text = response.text().await?;
if !status.is_success() {
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
if let Some(error) = json.get("error") {
if let Some(message) = error.get("message").and_then(|m| m.as_str()) {
return Err(Error::Api {
status: status.as_u16(),
message: message.to_string(),
});
}
}
}
return Err(Error::Api {
status: status.as_u16(),
message: text,
});
}
match serde_json::from_str(&text) {
Ok(value) => Ok(value),
Err(e) => Err(Error::Json(e)),
}
}
pub(crate) async fn post_cbor<T>(&self, endpoint: &str, data: &[u8]) -> Result<T>
where
T: DeserializeOwned,
{
self.rate_limiter
.check()
.await
.map_err(|e| Error::RateLimit(e.wait_time().as_secs()))?;
let url = format!("{}{}", self.base_url, endpoint);
let request = self
.http_client
.post(&url)
.header("Content-Type", "application/cbor")
.body(data.to_vec());
let request = self.build_request(request).await;
let response = request.send().await?;
match response.status() {
StatusCode::OK | StatusCode::ACCEPTED => Ok(response.json().await?),
status => {
let message = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
Err(Error::Api {
status: status.as_u16(),
message,
})
}
}
}
fn is_rate_limit_error(status: StatusCode, _text: &str) -> Option<u64> {
if status == StatusCode::TOO_MANY_REQUESTS {
return Some(60);
}
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{body_bytes, body_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client that targets `server`, with default settings otherwise.
    fn client_for(server: &MockServer) -> Client {
        Client::builder().base_url(server.uri()).build().unwrap()
    }

    #[tokio::test]
    async fn test_get_request() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        let mock_response = json!({
            "data": "test"
        });
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&mock_response))
            .mount(&mock_server)
            .await;
        let response: serde_json::Value = client.get("/test").await.unwrap();
        assert_eq!(response, mock_response);
    }

    #[tokio::test]
    async fn test_post_request() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        let request_body = json!({
            "test": "data"
        });
        let mock_response = json!({
            "result": "success"
        });
        // Matching on the body also verifies the request payload is serialised as sent.
        Mock::given(method("POST"))
            .and(path("/test"))
            .and(body_json(&request_body))
            .respond_with(ResponseTemplate::new(200).set_body_json(&mock_response))
            .mount(&mock_server)
            .await;
        let response: serde_json::Value = client.post("/test", &request_body).await.unwrap();
        assert_eq!(response, mock_response);
    }

    #[tokio::test]
    async fn test_post_error_uses_json_message() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        Mock::given(method("POST"))
            .and(path("/test"))
            .respond_with(
                ResponseTemplate::new(400)
                    .set_body_json(json!({ "error": { "message": "invalid input" } })),
            )
            .mount(&mock_server)
            .await;
        let error = client
            .post::<serde_json::Value, _>("/test", &json!({}))
            .await
            .unwrap_err();
        match error {
            Error::Api { status, message } => {
                assert_eq!(status, 400);
                assert_eq!(message, "invalid input");
            }
            _ => panic!("Expected API error"),
        }
    }

    #[tokio::test]
    async fn test_post_cbor_request() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        let payload = vec![0xa1_u8, 0x61, 0x61, 0x01];
        let mock_response = json!({ "accepted": true });
        Mock::given(method("POST"))
            .and(path("/cbor"))
            .and(header("Content-Type", "application/cbor"))
            .and(body_bytes(payload.clone()))
            .respond_with(ResponseTemplate::new(202).set_body_json(&mock_response))
            .mount(&mock_server)
            .await;
        let response: serde_json::Value = client.post_cbor("/cbor", &payload).await.unwrap();
        assert_eq!(response, mock_response);
    }

    #[tokio::test]
    async fn test_auth_token() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        client.set_auth_token("test-token".to_string()).await;
        assert!(client.has_valid_auth().await);
        assert_eq!(
            client.get_auth_token().await,
            Some("test-token".to_string())
        );
        let mock_response = json!({
            "data": "test"
        });
        Mock::given(method("GET"))
            .and(path("/test"))
            .and(header("Authorization", "Bearer test-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&mock_response))
            .mount(&mock_server)
            .await;
        let response: serde_json::Value = client.get("/test").await.unwrap();
        assert_eq!(response, mock_response);
    }

    #[tokio::test]
    async fn test_clear_auth_token() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        client.set_auth_token("test-token".to_string()).await;
        client.clear_auth_token().await;
        assert!(!client.has_valid_auth().await);
        assert_eq!(client.get_auth_token().await, None);
    }

    #[tokio::test]
    async fn test_error_handling() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(404).set_body_string("Not Found"))
            .mount(&mock_server)
            .await;
        let error = client.get::<serde_json::Value>("/test").await.unwrap_err();
        match error {
            Error::Api { status, message } => {
                assert_eq!(status, 404);
                assert_eq!(message, "Not Found");
            }
            _ => panic!("Expected API error"),
        }
    }

    /// A server-side 429 currently surfaces as `Error::Api`, not `Error::RateLimit`.
    /// `Error::RateLimit` is produced only by the client-side limiter.
    #[tokio::test]
    async fn test_rate_limit() {
        let mock_server = MockServer::start().await;
        let client = client_for(&mock_server);
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(ResponseTemplate::new(429).set_body_string("Too Many Requests"))
            .mount(&mock_server)
            .await;
        let error = client.get::<serde_json::Value>("/test").await.unwrap_err();
        match error {
            Error::Api { status, message } => {
                assert_eq!(status, 429);
                assert_eq!(message, "Too Many Requests");
            }
            _ => panic!("Expected API error for a 429 response"),
        }
    }
}