use crate::{
ProxmoxAuth, ProxmoxConnection, ProxmoxError, ProxmoxResult, ValidationConfig,
auth::application::service::login_service::LoginService,
};
use governor::{DefaultDirectRateLimiter, Quota};
use reqwest::{Client, StatusCode};
use serde::de::DeserializeOwned;
use std::num::NonZeroU32;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Envelope of a Proxmox JSON API response body: `{ "data": <payload>, ... }`.
///
/// Only `data` is deserialized; other top-level keys are ignored (serde's default,
/// since `deny_unknown_fields` is not set).
// NOTE: `rename_all = "camelCase"` does not change the single field name `data`.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct ProxmoxResponse<T> {
    /// The payload carried under the `data` key.
    data: T,
}
/// Client for the Proxmox JSON API (`<base url>/api2/json/...`).
///
/// Holds a shared connection description, lazily populated authentication
/// (ticket plus optional CSRF token), and an optional rate limiter for API requests.
#[derive(Debug)]
pub struct ApiClient {
    /// reqwest client; certificate checking is configured from the connection in `new`.
    http_client: Client,
    /// Host, credentials and base URL used to build request URLs and to log in.
    connection: Arc<ProxmoxConnection>,
    /// Cached authentication; `None` until the first login or a call to `set_auth`.
    auth: Arc<RwLock<Option<ProxmoxAuth>>>,
    /// Settings; this client reads `ticket_lifetime` and `rate_limit` from it.
    config: Arc<ValidationConfig>,
    /// Limiter awaited before sending API requests, present when `config.rate_limit` is set.
    rate_limiter: Option<Arc<DefaultDirectRateLimiter>>,
}
impl ApiClient {
pub fn new(connection: ProxmoxConnection, config: ValidationConfig) -> ProxmoxResult<Self> {
let http_client = Client::builder()
.danger_accept_invalid_certs(connection.accept_invalid_certs())
.build()
.map_err(|e| ProxmoxError::Connection(e.to_string()))?;
let rate_limiter = config.rate_limit.map(|rl| {
let quota = Quota::per_second(NonZeroU32::new(rl.requests_per_second).unwrap())
.allow_burst(NonZeroU32::new(rl.burst_size).unwrap());
Arc::new(DefaultDirectRateLimiter::direct(quota))
});
Ok(Self {
http_client,
connection: Arc::new(connection),
auth: Arc::new(RwLock::new(None)),
config: Arc::new(config),
rate_limiter,
})
}
pub fn connection(&self) -> &ProxmoxConnection {
&self.connection
}
pub async fn set_auth(&self, auth: ProxmoxAuth) {
let mut lock = self.auth.write().await;
*lock = Some(auth);
}
pub async fn auth(&self) -> Option<ProxmoxAuth> {
self.auth.read().await.clone()
}
pub async fn is_authenticated(&self) -> bool {
let lock = self.auth.read().await;
lock.as_ref()
.map(|a| !a.ticket().is_expired(self.config.ticket_lifetime))
.unwrap_or(false)
}
#[allow(dead_code)] pub async fn get<T>(&self, path: &str) -> ProxmoxResult<T>
where
T: DeserializeOwned,
{
self.execute_request(reqwest::Method::GET, path, None::<&()>)
.await
}
#[allow(dead_code)] pub async fn post<B, T>(&self, path: &str, body: &B) -> ProxmoxResult<T>
where
B: serde::Serialize,
T: DeserializeOwned,
{
self.execute_request(reqwest::Method::POST, path, Some(body))
.await
}
#[allow(dead_code)] pub async fn put<B, T>(&self, path: &str, body: &B) -> ProxmoxResult<T>
where
B: serde::Serialize,
T: DeserializeOwned,
{
self.execute_request(reqwest::Method::PUT, path, Some(body))
.await
}
#[allow(dead_code)] pub async fn delete<T>(&self, path: &str) -> ProxmoxResult<T>
where
T: DeserializeOwned,
{
self.execute_request(reqwest::Method::DELETE, path, None::<&()>)
.await
}
async fn execute_request<B, T>(
&self,
method: reqwest::Method,
path: &str,
body: Option<&B>,
) -> ProxmoxResult<T>
where
B: serde::Serialize,
T: DeserializeOwned,
{
self.ensure_authenticated().await?;
if let Some(limiter) = &self.rate_limiter {
limiter.until_ready().await;
}
let base = self.connection.url().as_str().trim_end_matches('/');
let url = format!("{}/api2/json/{}", base, path.trim_start_matches('/'));
let mut req_builder = self.http_client.request(method.clone(), &url);
{
let auth_guard = self.auth.read().await;
if let Some(auth) = auth_guard.as_ref() {
req_builder = req_builder
.header("Cookie", auth.ticket().as_cookie_header())
.header("CSRFPreventionToken", auth.csrf_token().unwrap().as_str());
}
}
if let Some(body) = body {
req_builder = req_builder.json(body);
}
let response = req_builder
.send()
.await
.map_err(|e| ProxmoxError::Connection(format!("HTTP request failed: {}", e)))?;
if response.status() == StatusCode::UNAUTHORIZED {
self.refresh_auth().await?;
return self.retry_request(method, path, body).await;
}
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "unknown".to_string());
return Err(ProxmoxError::Connection(format!(
"API error ({}): {}",
status, error_text
)));
}
let proxmox_resp = response
.json::<ProxmoxResponse<T>>()
.await
.map_err(|e| ProxmoxError::Connection(format!("Failed to parse response: {}", e)))?;
Ok(proxmox_resp.data)
}
async fn retry_request<B, T>(
&self,
method: reqwest::Method,
path: &str,
body: Option<&B>,
) -> ProxmoxResult<T>
where
B: serde::Serialize,
T: DeserializeOwned,
{
let base = self.connection.url().as_str().trim_end_matches('/');
let url = format!("{}/api2/json/{}", base, path.trim_start_matches('/'));
let mut req_builder = self.http_client.request(method, &url);
{
let auth_guard = self.auth.read().await;
if let Some(auth) = auth_guard.as_ref() {
req_builder = req_builder
.header("Cookie", auth.ticket().as_cookie_header())
.header("CSRFPreventionToken", auth.csrf_token().unwrap().as_str());
}
}
if let Some(body) = body {
req_builder = req_builder.json(body);
}
let response = req_builder.send().await.map_err(|e| {
ProxmoxError::Connection(format!("HTTP request failed on retry: {}", e))
})?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "unknown".to_string());
return Err(ProxmoxError::Connection(format!(
"API error after refresh ({}): {}",
status, error_text
)));
}
let proxmox_resp = response.json::<ProxmoxResponse<T>>().await.map_err(|e| {
ProxmoxError::Connection(format!("Failed to parse response after refresh: {}", e))
})?;
Ok(proxmox_resp.data)
}
async fn ensure_authenticated(&self) -> ProxmoxResult<()> {
let need_refresh = {
let auth_guard = self.auth.read().await;
match auth_guard.as_ref() {
Some(auth) => auth.ticket().is_expired(self.config.ticket_lifetime),
None => true,
}
};
if need_refresh {
self.refresh_auth().await?;
}
Ok(())
}
async fn refresh_auth(&self) -> ProxmoxResult<()> {
let service = LoginService::new();
let auth = service.execute(&self.connection).await?;
let mut lock = self.auth.write().await;
*lock = Some(auth);
Ok(())
}
}
#[cfg(test)]
mod tests {
    //! Integration-style tests for `ApiClient` against a `wiremock` HTTP server.
    use super::*;
    use crate::{
        ProxmoxHost, ProxmoxPassword, ProxmoxPort, ProxmoxRealm, ProxmoxUrl, ProxmoxUsername,
        RateLimitConfig,
        core::domain::value_object::{ProxmoxCSRFToken, ProxmoxTicket},
    };
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    /// Builds a connection aimed at the mock server at `server_url` (e.g. `MockServer::uri()`).
    ///
    /// The host is `server_url` without its `http://` prefix, and the base URL is
    /// `server_url` plus a trailing `/`. `ApiClient` builds request URLs from that base URL.
    // NOTE(review): the two boolean arguments are passed positionally; confirm their
    // meaning against `ProxmoxConnection::new`.
    fn create_test_connection(server_url: &str) -> ProxmoxConnection {
        let host = ProxmoxHost::new_unchecked(server_url.trim_start_matches("http://").to_string());
        let port = ProxmoxPort::new_unchecked(8006);
        let username = ProxmoxUsername::new_unchecked("testuser".to_string());
        let password = ProxmoxPassword::new_unchecked("testpass".to_string());
        let realm = ProxmoxRealm::new_unchecked("pam".to_string());
        let url = ProxmoxUrl::new_unchecked(server_url.to_string() + "/");
        ProxmoxConnection::new(host, port, username, password, realm, false, true, url)
    }

    /// Returns an auth value with a fixed ticket and CSRF token.
    // NOTE(review): tests that install this auth mount no login mock, so they rely on
    // `is_expired` treating this ticket as fresh. Confirm how expiry is computed.
    fn create_test_auth() -> ProxmoxAuth {
        let ticket = ProxmoxTicket::new_unchecked("PVE:testuser@pam:4EEC61E2::sig".to_string());
        let csrf = ProxmoxCSRFToken::new_unchecked("4EEC61E2:token".to_string());
        ProxmoxAuth::new(ticket, Some(csrf))
    }

    /// A GET with cached auth returns the unwrapped `data` field of the response.
    #[tokio::test]
    async fn test_get_success() {
        let mock_server = MockServer::start().await;
        let connection = create_test_connection(&mock_server.uri());
        let config = ValidationConfig::default();
        let client = ApiClient::new(connection, config).unwrap();
        client.set_auth(create_test_auth()).await;
        Mock::given(method("GET"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": "ok"
            })))
            .mount(&mock_server)
            .await;
        let result: String = client.get("test").await.unwrap();
        assert_eq!(result, "ok");
    }

    /// With no cached auth the client logs in first.
    ///
    /// The first GET is answered by the 401 mock, which is mounted first and limited to
    /// one use. That 401 triggers a second login and a retry, which reaches the 200 mock.
    /// The stored auth must then match the login response.
    #[tokio::test]
    async fn test_unauthorized_triggers_refresh() {
        let mock_server = MockServer::start().await;
        let connection = create_test_connection(&mock_server.uri());
        let config = ValidationConfig::default();
        let client = ApiClient::new(connection, config).unwrap();
        Mock::given(method("GET"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(401))
            .up_to_n_times(1)
            .mount(&mock_server)
            .await;
        // Login endpoint used by `LoginService`; answers every login with a fresh ticket.
        Mock::given(method("POST"))
            .and(path("/api2/json/access/ticket"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "ticket": "PVE:testuser@pam:4EEC61E2::new_sig",
                    "CSRFPreventionToken": "4EEC61E2:abc123"
                }
            })))
            .mount(&mock_server)
            .await;
        Mock::given(method("GET"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": "ok"
            })))
            .mount(&mock_server)
            .await;
        let result: String = client.get("test").await.unwrap();
        assert_eq!(result, "ok");
        let auth = client.auth().await.unwrap();
        assert_eq!(auth.ticket().as_str(), "PVE:testuser@pam:4EEC61E2::new_sig");
        assert_eq!(auth.csrf_token().unwrap().as_str(), "4EEC61E2:abc123");
    }

    /// When the login endpoint rejects the credentials, the request fails with
    /// `ProxmoxError::Authentication`. `ApiClient` never builds that variant itself; it
    /// is propagated from `LoginService::execute`.
    #[tokio::test]
    async fn test_refresh_failure_returns_error() {
        let mock_server = MockServer::start().await;
        let connection = create_test_connection(&mock_server.uri());
        let config = ValidationConfig::default();
        let client = ApiClient::new(connection, config).unwrap();
        Mock::given(method("GET"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&mock_server)
            .await;
        Mock::given(method("POST"))
            .and(path("/api2/json/access/ticket"))
            .respond_with(ResponseTemplate::new(401))
            .mount(&mock_server)
            .await;
        let result: ProxmoxResult<String> = client.get("test").await;
        assert!(matches!(result, Err(ProxmoxError::Authentication(_))));
    }

    /// With a quota of 2 req/s and burst 2, the first pair of requests passes
    /// immediately. The second pair must wait for the limiter to replenish, which
    /// should take about 1 s; the assertion allows roughly 100 ms of slack.
    // NOTE(review): wall-clock based, so this may be sensitive on slow CI machines.
    #[tokio::test]
    async fn test_rate_limiting_delays_requests() {
        use std::time::{Duration, Instant};
        let mock_server = MockServer::start().await;
        let connection = create_test_connection(&mock_server.uri());
        let config = ValidationConfig {
            rate_limit: Some(RateLimitConfig {
                requests_per_second: 2,
                burst_size: 2,
            }),
            ..Default::default()
        };
        let client = ApiClient::new(connection, config).unwrap();
        client.set_auth(create_test_auth()).await;
        Mock::given(method("GET"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": "ok"
            })))
            .expect(4)
            .mount(&mock_server)
            .await;
        // First pair: served from the burst allowance.
        let start = Instant::now();
        let req1 = client.get::<String>("test");
        let req2 = client.get::<String>("test");
        let (res1, res2) = tokio::join!(req1, req2);
        res1.unwrap();
        res2.unwrap();
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(500));
        // Second pair: burst exhausted, so both must wait for the limiter.
        let start = Instant::now();
        let req3 = client.get::<String>("test");
        let req4 = client.get::<String>("test");
        let (res3, res4) = tokio::join!(req3, req4);
        res3.unwrap();
        res4.unwrap();
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(900));
    }

    /// With `rate_limit: None`, ten sequential requests finish well under 500 ms.
    #[tokio::test]
    async fn test_rate_limiting_disabled() {
        use tokio::time::{self, Duration};
        let mock_server = MockServer::start().await;
        let connection = create_test_connection(&mock_server.uri());
        let config = ValidationConfig {
            rate_limit: None,
            ..Default::default()
        };
        let client = ApiClient::new(connection, config).unwrap();
        client.set_auth(create_test_auth()).await;
        Mock::given(method("GET"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": "ok"
            })))
            .expect(10)
            .mount(&mock_server)
            .await;
        let start = time::Instant::now();
        for _ in 0..10 {
            client.get::<String>("test").await.unwrap();
        }
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_millis(500));
    }

    /// A POST with a JSON body returns the `data` payload as a `serde_json::Value`.
    // NOTE: the mock matches only method and path, not the request body, so this checks
    // the response handling rather than the body serialization.
    #[tokio::test]
    async fn test_post_with_body() {
        let mock_server = MockServer::start().await;
        let connection = create_test_connection(&mock_server.uri());
        let config = ValidationConfig::default();
        let client = ApiClient::new(connection, config).unwrap();
        client.set_auth(create_test_auth()).await;
        #[derive(serde::Serialize)]
        struct MyBody {
            key: String,
        }
        Mock::given(method("POST"))
            .and(path("/api2/json/test"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "result": "created"
                }
            })))
            .mount(&mock_server)
            .await;
        let body = MyBody {
            key: "value".into(),
        };
        let result: serde_json::Value = client.post("test", &body).await.unwrap();
        assert_eq!(result["result"], "created");
    }
}