use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use reqwest::header::HeaderMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
use crate::error::ApiError;
use crate::request::RequestConfig;
use crate::response::ApiResponse;
/// HTTP methods supported by [`ApiClient`].
///
/// `Hash` is derived alongside `Eq` so methods can be used as keys in hash maps/sets
/// (e.g. per-method counters or routing tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    /// `GET`; a request body is not sent for this method.
    Get,
    /// `POST`; an optional request body is serialized as JSON.
    Post,
}
impl HttpMethod {
fn as_reqwest(&self) -> reqwest::Method {
match self {
HttpMethod::Get => reqwest::Method::GET,
HttpMethod::Post => reqwest::Method::POST,
}
}
}
impl std::fmt::Display for HttpMethod {
    /// Writes the upper-case method name (`GET` or `POST`).
    ///
    /// Goes through [`Formatter::pad`](std::fmt::Formatter::pad) rather than `write!`, so
    /// width, fill, alignment and precision flags (e.g. `{:<6}`) are honoured exactly as they
    /// are for `str`. With no flags the output is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            HttpMethod::Get => "GET",
            HttpMethod::Post => "POST",
        };
        f.pad(name)
    }
}
/// Lock-protected holder for per-session request state (XSRF token and cookie).
///
/// The state lives behind an `Arc`, so cloning a `SessionStore` is cheap and every clone
/// observes and mutates the same underlying data.
#[derive(Clone, Debug, Default)]
pub struct SessionStore {
    inner: Arc<RwLock<SessionData>>,
}
/// Mutable session state guarded by `SessionStore`'s lock.
#[derive(Default, Debug)]
struct SessionData {
    /// Token sent back to the server in the `X-XSRF-TOKEN` request header.
    xsrf_token: Option<String>,
    /// Value sent as the `Cookie` request header.
    cookie: Option<String>,
}
impl SessionStore {
    /// Creates an empty session store (no XSRF token, no cookie).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of the current XSRF token, if one has been captured or set.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn xsrf_token(&self) -> Option<String> {
        self.inner.read().unwrap().xsrf_token.clone()
    }

    /// Returns a copy of the current `Cookie` header value, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn cookie(&self) -> Option<String> {
        self.inner.read().unwrap().cookie.clone()
    }

    /// Captures session state from a response's headers.
    ///
    /// * XSRF token: taken from the first of `x-xsrf-token`, `xsrf-token`, `x-csrf-token`
    ///   that is present with a visible-ASCII value.
    /// * Cookies: every `Set-Cookie` header contributes its leading `name=value` pair.
    ///   Attributes such as `Path`, `Domain` or `HttpOnly` are dropped because they must not
    ///   be echoed back in a `Cookie` request header. The pairs are merged into the stored
    ///   cookie by name, so cookies set by earlier responses survive unless overwritten.
    ///   Expiry/deletion attributes (`Max-Age`, `Expires`) are not interpreted.
    ///
    /// The write lock is only taken when there is something to update.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn update_from_response(&self, headers: &HeaderMap) {
        let token = ["x-xsrf-token", "xsrf-token", "x-csrf-token"]
            .iter()
            .find_map(|key| {
                // A present but non-ASCII value falls through to the next candidate header.
                let value = headers.get(*key)?.to_str().ok()?;
                Some((*key, value.to_string()))
            });

        let cookie_pairs: Vec<&str> = headers
            .get_all("set-cookie")
            .iter()
            .filter_map(|value| value.to_str().ok())
            // Only the part before the first `;` is the cookie itself; the rest are attributes.
            .map(|v| v.split_once(';').map_or(v, |(pair, _)| pair).trim())
            // RFC 6265 §5.2: ignore a Set-Cookie without `=` or with an empty name.
            .filter(|pair| {
                matches!(pair.split_once('='), Some((name, _)) if !name.trim().is_empty())
            })
            .collect();

        if token.is_none() && cookie_pairs.is_empty() {
            return;
        }

        let mut data = self.inner.write().unwrap();
        if let Some((key, value)) = token {
            data.xsrf_token = Some(value);
            tracing::debug!("Updated XSRF token from header '{}'", key);
        }
        if !cookie_pairs.is_empty() {
            let merged = Self::merge_cookie_pairs(data.cookie.as_deref(), &cookie_pairs);
            data.cookie = Some(merged);
            tracing::debug!("Updated cookie from response");
        }
    }

    /// Merges `name=value` pairs into an existing `Cookie` header value.
    ///
    /// Entries whose name already exists are replaced in place. New names are appended in
    /// the order given, and entries without `=` or with an empty name are skipped. The result
    /// is formatted as `a=1; b=2`.
    fn merge_cookie_pairs<'a>(existing: Option<&'a str>, updates: &[&'a str]) -> String {
        let mut jar: Vec<(&'a str, &'a str)> = Vec::new();
        let all_pairs = existing
            .unwrap_or("")
            .split(';')
            .chain(updates.iter().copied());
        for pair in all_pairs {
            if let Some((name, value)) = pair.split_once('=') {
                let (name, value) = (name.trim(), value.trim());
                if name.is_empty() {
                    continue;
                }
                match jar.iter_mut().find(|(known, _)| *known == name) {
                    Some(entry) => entry.1 = value,
                    None => jar.push((name, value)),
                }
            }
        }
        jar.iter()
            .map(|(name, value)| format!("{}={}", name, value))
            .collect::<Vec<_>>()
            .join("; ")
    }

    /// Adds the stored XSRF token (as `X-XSRF-TOKEN`) and the stored cookie (as `Cookie`)
    /// to `builder` when they are present.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    fn apply_to_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        let data = self.inner.read().unwrap();
        let mut builder = builder;
        if let Some(ref token) = data.xsrf_token {
            builder = builder.header("X-XSRF-TOKEN", token);
        }
        if let Some(ref cookie) = data.cookie {
            builder = builder.header("Cookie", cookie);
        }
        builder
    }

    /// Overrides the stored XSRF token.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set_xsrf_token(&self, token: impl Into<String>) {
        self.inner.write().unwrap().xsrf_token = Some(token.into());
    }

    /// Replaces the stored `Cookie` header value wholesale (no merging).
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn set_cookie(&self, cookie: impl Into<String>) {
        self.inner.write().unwrap().cookie = Some(cookie.into());
    }

    /// Forgets both the XSRF token and the cookie.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn clear(&self) {
        let mut data = self.inner.write().unwrap();
        data.xsrf_token = None;
        data.cookie = None;
    }
}
/// JSON-over-HTTP client with an optional base URL, default headers, and session tracking.
///
/// Construct one through [`ApiClient::builder`].
#[derive(Debug, Clone)]
pub struct ApiClient {
    /// Underlying `reqwest` client that sends every request.
    http_client: reqwest::Client,
    /// Prefix joined onto relative request URLs, if configured.
    base_url: Option<String>,
    /// Headers added to every request, before any per-request headers.
    default_headers: HashMap<String, String>,
    /// Timeout used when a request's config does not supply one.
    default_timeout: Duration,
    /// Whether `session` headers are attached to requests and refreshed from responses.
    session_enabled: bool,
    /// Session state (XSRF token, cookie); clones of this client share the same store.
    pub session: SessionStore,
}
impl ApiClient {
pub fn builder() -> ApiClientBuilder {
ApiClientBuilder::default()
}
pub async fn make_request<T, B>(
&self,
method: HttpMethod,
url: &str,
body: Option<&B>,
query: Option<&[(&str, &str)]>,
config: Option<RequestConfig>,
) -> Result<ApiResponse<T>, ApiError>
where
T: DeserializeOwned,
B: Serialize + ?Sized,
{
let config = config.unwrap_or_default();
let full_url = self.build_url(url)?;
tracing::info!("{} {}", method, full_url);
let timeout = config.timeout.unwrap_or(self.default_timeout);
let mut builder = self
.http_client
.request(method.as_reqwest(), &full_url)
.timeout(timeout);
for (key, value) in &self.default_headers {
builder = builder.header(key.as_str(), value.as_str());
}
for (key, value) in &config.headers {
builder = builder.header(key.as_str(), value.as_str());
}
if let Some(ref token) = config.bearer_token {
builder = builder.bearer_auth(token);
tracing::debug!("Authorization header set");
}
if self.session_enabled {
builder = self.session.apply_to_request(builder);
}
if let Some(params) = query {
builder = builder.query(params);
tracing::debug!("Query params: {:?}", params);
}
if let Some(body) = body {
if method == HttpMethod::Post {
let json_value = serde_json::to_value(body)?;
tracing::debug!("Request body: {}", json_value);
builder = builder.json(&json_value);
}
}
let response = builder.send().await.map_err(|e| {
if e.is_timeout() {
tracing::error!("Request timed out: {} {}", method, full_url);
ApiError::Timeout
} else {
tracing::error!("Network error: {} {} — {}", method, full_url, e);
ApiError::NetworkError(e)
}
})?;
let status = response.status().as_u16();
let response_headers = response.headers().clone();
if self.session_enabled {
self.session.update_from_response(&response_headers);
}
let headers: HashMap<String, String> = response_headers
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.to_string(), v.to_string()))
})
.collect();
let raw_body = response.text().await.map_err(|e| {
tracing::error!("Failed to read response body: {}", e);
ApiError::NetworkError(e)
})?;
if raw_body.is_empty() {
tracing::warn!("Empty response body from {} {}", method, full_url);
return Err(ApiError::EmptyResponse);
}
tracing::debug!("Response [{}] from {}: {}", status, full_url, &raw_body);
if !(200..300).contains(&(status as usize)) {
tracing::error!("HTTP {} from {} {}", status, method, full_url);
return Err(ApiError::HttpError {
status,
body: raw_body,
});
}
let body: T = serde_json::from_str(&raw_body).map_err(|e| {
tracing::error!(
"Failed to deserialize response from {} {}: {}",
method,
full_url,
e
);
ApiError::SerializationError(e)
})?;
Ok(ApiResponse {
status,
headers,
body,
raw_body,
})
}
pub async fn get<T: DeserializeOwned>(
&self,
url: &str,
query: Option<&[(&str, &str)]>,
config: Option<RequestConfig>,
) -> Result<ApiResponse<T>, ApiError> {
self.make_request::<T, ()>(HttpMethod::Get, url, None, query, config)
.await
}
pub async fn post<T, B>(
&self,
url: &str,
body: Option<&B>,
config: Option<RequestConfig>,
) -> Result<ApiResponse<T>, ApiError>
where
T: DeserializeOwned,
B: Serialize + ?Sized,
{
self.make_request(HttpMethod::Post, url, body, None, config)
.await
}
fn build_url(&self, url: &str) -> Result<String, ApiError> {
let full_url = if url.starts_with("http://") || url.starts_with("https://") {
url.to_string()
} else if let Some(ref base) = self.base_url {
let base = base.trim_end_matches('/');
let path = url.trim_start_matches('/');
format!("{}/{}", base, path)
} else {
url.to_string()
};
reqwest::Url::parse(&full_url)
.map_err(|_| ApiError::InvalidUrl(full_url.clone()))?;
Ok(full_url)
}
}
/// Step-by-step configuration for an [`ApiClient`]; every setting is optional.
#[derive(Default, Debug)]
pub struct ApiClientBuilder {
    /// Base URL that relative request paths are joined onto.
    base_url: Option<String>,
    /// Headers added to every request.
    default_headers: HashMap<String, String>,
    /// Default per-request timeout; `None` means the builder's fallback is used.
    default_timeout: Option<Duration>,
    /// Whether session tracking (XSRF token / cookie reuse) is turned on.
    session_enabled: bool,
}
impl ApiClientBuilder {
    /// Timeout used by [`ApiClientBuilder::build`] when [`ApiClientBuilder::default_timeout`]
    /// was never called.
    const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);

    /// Sets the base URL that relative request paths are joined onto.
    pub fn base_url(self, url: impl Into<String>) -> Self {
        Self {
            base_url: Some(url.into()),
            ..self
        }
    }

    /// Sets the timeout applied when a request does not specify its own.
    pub fn default_timeout(self, timeout: Duration) -> Self {
        Self {
            default_timeout: Some(timeout),
            ..self
        }
    }

    /// Adds a header sent with every request; repeating a key replaces its earlier value.
    pub fn default_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (name, header_value) = (key.into(), value.into());
        self.default_headers.insert(name, header_value);
        self
    }

    /// Turns session tracking (XSRF token / cookie reuse across requests) on or off.
    pub fn session_enabled(self, enabled: bool) -> Self {
        Self {
            session_enabled: enabled,
            ..self
        }
    }

    /// Builds the configured [`ApiClient`] with a fresh, empty session store.
    ///
    /// When sessions are enabled, the underlying `reqwest` client is also created with its
    /// cookie store turned on. Without an explicit default timeout, 60 seconds is used.
    ///
    /// # Errors
    ///
    /// [`ApiError::NetworkError`] if the underlying `reqwest` client cannot be constructed.
    pub fn build(self) -> Result<ApiClient, ApiError> {
        let Self {
            base_url,
            default_headers,
            default_timeout,
            session_enabled,
        } = self;

        let http_client = reqwest::Client::builder()
            .cookie_store(session_enabled)
            .build()
            .map_err(ApiError::NetworkError)?;

        Ok(ApiClient {
            http_client,
            base_url,
            default_headers,
            default_timeout: default_timeout.unwrap_or(Self::FALLBACK_TIMEOUT),
            session_enabled,
            session: SessionStore::new(),
        })
    }
}