use crate::auth::AuthManager;
use crate::error::{WebullError, WebullResult};
use crate::models::response::ApiResponse;
use crate::utils::cache::CacheManager;
use crate::utils::rate_limit::RateLimiter;
use reqwest::{Client, Method, RequestBuilder, StatusCode};
use reqwest::header::AUTHORIZATION;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::sync::Arc;
use std::time::Duration;
use url::Url;
/// Shared HTTP plumbing for API endpoint groups: URL construction, bearer
/// authentication, rate limiting, response-envelope decoding and response caching.
pub struct BaseEndpoint {
    /// Root URL that request paths are appended to.
    base_url: String,
    /// HTTP client every request is issued through.
    client: Client,
    /// Supplies the bearer token attached to authenticated requests.
    auth_manager: Arc<AuthManager>,
    /// Waited on with the request path before each request is sent.
    rate_limiter: Arc<RateLimiter>,
    /// Owns the per-method response caches.
    cache_manager: Arc<CacheManager>,
}
impl BaseEndpoint {
    /// How long a successful GET response stays in the response cache.
    const GET_CACHE_TTL: Duration = Duration::from_secs(60);

    /// Creates an endpoint helper that sends requests through `client` to `base_url`.
    ///
    /// A fresh rate limiter and an empty cache manager are created for every call, so
    /// neither is shared with other `BaseEndpoint` instances.
    pub fn new(client: Client, base_url: String, auth_manager: Arc<AuthManager>) -> Self {
        Self {
            client,
            base_url,
            auth_manager,
            // NOTE(review): the meaning of `60` is defined by `RateLimiter::new`
            // (presumably a per-minute request budget) — confirm.
            rate_limiter: Arc::new(RateLimiter::new(60)),
            cache_manager: Arc::new(CacheManager::new()),
        }
    }

    /// Starts a request for `method` against `path`, resolved relative to the base URL.
    ///
    /// The type parameter `T` is not used by this method; it is kept so existing
    /// `request::<T>(..)` call sites keep compiling.
    ///
    /// An invalid combined URL does not panic: `reqwest` records the parse error in the
    /// returned builder, and [`Self::send_request`] reports it as `WebullError::NetworkError`.
    pub fn request<T>(&self, method: Method, path: &str) -> RequestBuilder
    where
        T: DeserializeOwned,
    {
        self.client.request(method, self.build_url(path))
    }

    /// Rate-limits and sends `request`, then decodes the `ApiResponse<T>` envelope and
    /// returns a clone of its data.
    ///
    /// # Errors
    /// - `NetworkError` if the request cannot be built (e.g. invalid URL or body) or sent,
    ///   or if the response body cannot be read.
    /// - `InvalidRequest` if an otherwise valid request cannot be cloned (e.g. a streaming
    ///   body), because a clone is needed to find the path for rate limiting.
    /// - `RateLimitExceeded` on HTTP 429, returned after sleeping for the `Retry-After`
    ///   value in seconds (1 second if the header is missing or not an integer).
    /// - `Unauthorized` on HTTP 401.
    /// - `ApiError` for any other non-2xx status, an envelope that reports failure, or an
    ///   envelope without data.
    /// - `SerializationError` if the body is not a valid `ApiResponse<T>`.
    pub async fn send_request<T>(&self, request: RequestBuilder) -> WebullResult<T>
    where
        T: DeserializeOwned + Clone,
    {
        // A throwaway copy is built only to learn the path the rate limiter waits on;
        // the original builder is the one actually sent.
        let probe = match request.try_clone() {
            Some(probe) => probe,
            None => {
                // `try_clone` also yields `None` when the builder already carries an error
                // (bad URL, failed JSON serialisation). Report that real error instead of
                // a misleading clone failure.
                return Err(match request.build() {
                    Err(e) => WebullError::NetworkError(e),
                    Ok(_) => WebullError::InvalidRequest("Failed to clone request".to_string()),
                });
            }
        };
        let req_url: Url = probe.build().map_err(WebullError::NetworkError)?.url().clone();
        self.rate_limiter.wait(req_url.path()).await;

        let response = request.send().await.map_err(WebullError::NetworkError)?;
        let status = response.status();

        if status == StatusCode::TOO_MANY_REQUESTS {
            // NOTE(review): this waits out `Retry-After` and then still fails, so a caller
            // that retries immediately is already past the back-off window. The header is
            // read as whole seconds and the wait is not capped.
            let retry_after = response
                .headers()
                .get("retry-after")
                .and_then(|h| h.to_str().ok())
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(1);
            tokio::time::sleep(Duration::from_secs(retry_after)).await;
            return Err(WebullError::RateLimitExceeded);
        }
        if status == StatusCode::UNAUTHORIZED {
            return Err(WebullError::Unauthorized);
        }

        let body = response.text().await.map_err(WebullError::NetworkError)?;
        if !status.is_success() {
            return Err(WebullError::ApiError {
                code: status.as_u16().to_string(),
                message: body,
            });
        }

        let api_response: ApiResponse<T> =
            serde_json::from_str(&body).map_err(WebullError::SerializationError)?;
        if !api_response.is_success() {
            return Err(WebullError::ApiError {
                code: api_response.code.unwrap_or_else(|| "unknown".to_string()),
                message: api_response.message.unwrap_or_else(|| "Unknown error".to_string()),
            });
        }
        api_response.get_data().cloned().ok_or_else(|| WebullError::ApiError {
            code: "no_data".to_string(),
            message: "Response did not contain data".to_string(),
        })
    }

    /// Joins `base_url` and `path` with exactly one `/` between them.
    ///
    /// The string is returned unparsed so that an invalid URL becomes a request error
    /// reported by `reqwest`, rather than a panic inside the library.
    fn build_url(&self, path: &str) -> String {
        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        )
    }

    /// Adds an `Authorization: Bearer <token>` header using the auth manager's token.
    ///
    /// # Errors
    /// Propagates any error returned by `AuthManager::get_token`.
    pub async fn authenticate_request(
        &self,
        request: RequestBuilder,
    ) -> WebullResult<RequestBuilder> {
        let token = self.auth_manager.get_token().await?;
        Ok(request.header(AUTHORIZATION, format!("Bearer {}", token.token)))
    }

    /// Performs an authenticated GET on `path`.
    ///
    /// A cached response is returned when present. Otherwise, a successful response is
    /// cached for [`Self::GET_CACHE_TTL`].
    pub async fn get<T>(&self, path: &str) -> WebullResult<T>
    where
        T: DeserializeOwned + Clone + Send + Sync + 'static,
    {
        let cache = self.cache_manager.get_cache::<T>("get");
        if let Some(cached) = cache.get("GET", path, None, None) {
            return Ok(cached);
        }
        let request = self
            .authenticate_request(self.request::<T>(Method::GET, path))
            .await?;
        let response: T = self.send_request(request).await?;
        cache.set("GET", path, None, None, response.clone(), Some(Self::GET_CACHE_TTL));
        Ok(response)
    }

    /// Performs an authenticated POST on `path` with `body` serialised as JSON.
    ///
    /// Responses are never cached. POST is not idempotent in general, so replaying a
    /// cached response for an identical body would silently skip a request the caller
    /// meant to send. Because a POST may change server state, the GET cache for `T` is
    /// cleared after a successful call.
    pub async fn post<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
    where
        T: DeserializeOwned + Clone + Send + Sync + 'static,
        B: Serialize,
    {
        let request = self.request::<T>(Method::POST, path).json(body);
        let request = self.authenticate_request(request).await?;
        let response: T = self.send_request(request).await?;
        self.invalidate_get_cache::<T>();
        Ok(response)
    }

    /// Performs an authenticated PUT on `path` with `body` serialised as JSON.
    ///
    /// After a successful call the GET cache for `T` is cleared.
    pub async fn put<T, B>(&self, path: &str, body: &B) -> WebullResult<T>
    where
        T: DeserializeOwned + Clone + Send + Sync + 'static,
        B: Serialize,
    {
        let request = self.request::<T>(Method::PUT, path).json(body);
        let request = self.authenticate_request(request).await?;
        let response: T = self.send_request(request).await?;
        self.invalidate_get_cache::<T>();
        Ok(response)
    }

    /// Performs an authenticated DELETE on `path`.
    ///
    /// After a successful call the GET cache for `T` is cleared.
    pub async fn delete<T>(&self, path: &str) -> WebullResult<T>
    where
        T: DeserializeOwned + Clone + Send + Sync + 'static,
    {
        let request = self.request::<T>(Method::DELETE, path);
        let request = self.authenticate_request(request).await?;
        let response: T = self.send_request(request).await?;
        self.invalidate_get_cache::<T>();
        Ok(response)
    }

    /// Clears the GET response cache for `T` after a state-changing request.
    ///
    /// NOTE(review): `get_cache::<T>` is generic over the response type, so presumably
    /// only `T`'s GET cache is cleared. Cached GET responses of other types that the
    /// mutation made stale would survive until their TTL expires — confirm.
    fn invalidate_get_cache<T>(&self)
    where
        T: DeserializeOwned + Clone + Send + Sync + 'static,
    {
        self.cache_manager.get_cache::<T>("get").clear();
    }
}