use crate::{ClientRegisterRequest, database::Database};
use anyhow::Result;
use reqwest::{Client, Method, RequestBuilder, Response};
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, info, warn};
/// HTTP client wrapper that attaches the stored client ID to requests aimed
/// at the configured server.
///
/// Cloning is cheap in terms of shared state: the database handle and the
/// client ID slot are both behind `Arc`, so clones observe the same client ID.
#[derive(Debug, Clone)]
pub struct AuthenticatedClient {
/// Underlying `reqwest` client used to build and send every request.
client: Client,
/// Shared database handle; read for the initial client ID and written
/// whenever the ID changes.
database: Arc<Database>,
/// URL prefix used to decide whether a request targets our own server.
server_base_url: String,
/// In-memory copy of the current client ID (`None` when no ID is stored).
client_id: Arc<RwLock<Option<String>>>,
}
impl AuthenticatedClient {
/// Creates a client bound to `server_base_url`, seeding the in-memory client
/// ID from whatever `database` currently has stored.
///
/// # Errors
/// Propagates the error returned by `Database::get_client_id`.
pub async fn new(database: Arc<Database>, server_base_url: String) -> Result<Self> {
    let http = Client::new();
    let stored_id = database.get_client_id().await?;
    let authenticated = Self {
        client: http,
        database,
        server_base_url,
        client_id: Arc::new(RwLock::new(stored_id)),
    };
    Ok(authenticated)
}
/// Reports whether `url` points at the configured server.
///
/// A bare prefix test is not sufficient: with a base of
/// `https://api.example.com` it would also accept
/// `https://api.example.com.evil.net/...` (or `:8080` for a base ending in
/// `:80`), and the client ID header would be sent to a foreign host. The
/// base must therefore be followed by the end of the URL or by a
/// path/query/fragment delimiter, unless the base itself already ends in `/`.
///
/// An empty base URL never matches, since it cannot identify a server.
fn is_our_server(&self, url: &str) -> bool {
    let base = self.server_base_url.as_str();
    if base.is_empty() {
        return false;
    }
    match url.strip_prefix(base) {
        None => false,
        Some(rest) => {
            base.ends_with('/') || rest.is_empty() || rest.starts_with(['/', '?', '#'])
        }
    }
}
/// Returns `true` when `url` contains the client registration path.
fn is_register_endpoint(&self, url: &str) -> bool {
    const REGISTER_PATH: &str = "/clients/register";
    url.contains(REGISTER_PATH)
}
/// Returns a copy of the client ID currently held in memory, if any.
async fn get_client_id(&self) -> Option<String> {
    let guard = self.client_id.read().await;
    (*guard).clone()
}
/// Replaces the in-memory client ID and then persists it to the database.
///
/// The in-memory value is written first, so it keeps the new ID even when
/// the database update fails. The write lock is released before the database
/// call is awaited.
///
/// # Errors
/// Propagates the error returned by `Database::update_client_id`.
async fn set_client_id(&self, new_client_id: String) -> Result<()> {
    {
        let mut slot = self.client_id.write().await;
        *slot = Some(new_client_id.clone());
    }
    self.database.update_client_id(&new_client_id).await?;
    Ok(())
}
/// Registers this machine with the server and stores the returned client ID.
///
/// Posts the current OS and architecture to the registration endpoint, reads
/// the `client_id` string from the JSON reply, and saves it through
/// `set_client_id`.
///
/// # Errors
/// Fails when the request cannot be sent, the server answers with a
/// non-success status, the reply is not valid JSON or lacks a string
/// `client_id`, or storing the new ID fails.
async fn auto_register(&self) -> Result<String> {
    info!("Attempting to auto-register client...");
    let request = ClientRegisterRequest {
        os: std::env::consts::OS.to_string(),
        arch: std::env::consts::ARCH.to_string(),
    };
    let register_url = format!(
        "{}{}",
        self.server_base_url,
        crate::constants::api::endpoints::CLIENT_REGISTER
    );
    let response = self
        .client
        .post(&register_url)
        .json(&request)
        .send()
        .await?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: an unreadable error body must not mask the status.
        let text = response.text().await.unwrap_or_default();
        error!("Client registration failed: {} - {}", status, text);
        return Err(anyhow::anyhow!("Registration failed: {status} - {text}"));
    }

    let register_response: serde_json::Value = response.json().await?;
    match register_response.get("client_id").and_then(|v| v.as_str()) {
        Some(client_id) => {
            let client_id = client_id.to_string();
            info!("Auto-registration successful, client ID: {}", client_id);
            self.set_client_id(client_id.clone()).await?;
            Ok(client_id)
        }
        None => Err(anyhow::anyhow!("Invalid registration response format")),
    }
}
/// Adds the `X-Client-ID` header to `request_builder` when `url` targets our
/// server, is not the registration endpoint, and a client ID is available.
/// In every other case the builder is returned untouched.
async fn add_auth_header(
    &self,
    request_builder: RequestBuilder,
    url: &str,
) -> RequestBuilder {
    let wants_client_id = self.is_our_server(url) && !self.is_register_endpoint(url);
    if !wants_client_id {
        return request_builder;
    }
    match self.get_client_id().await {
        Some(id) => request_builder.header("X-Client-ID", id),
        None => request_builder,
    }
}
/// Prepares (but does not send) a request for `method` and `url`, with the
/// client ID header applied where `add_auth_header` deems it appropriate.
async fn execute_request(&self, method: Method, url: &str) -> Result<RequestBuilder> {
    let prepared = self
        .add_auth_header(self.client.request(method, url), url)
        .await;
    Ok(prepared)
}
/// Prepares (but does not send) a request for `method` and `url` whose body
/// is `json` serialized as JSON, with the client ID header applied where
/// `add_auth_header` deems it appropriate.
async fn execute_request_with_json<T: Serialize>(
    &self,
    method: Method,
    url: &str,
    json: &T,
) -> Result<RequestBuilder> {
    let with_body = self.client.request(method, url).json(json);
    let prepared = self.add_auth_header(with_body, url).await;
    Ok(prepared)
}
/// Sends `request_builder` and, on a 401 from our own server, re-registers
/// and replays the same request once with the new client ID.
///
/// The replay is built from a clone of the original request taken before it
/// is sent, so the HTTP method, headers and body are preserved; only the
/// `X-Client-ID` header is replaced with the freshly issued ID. Requests to
/// other hosts or to the registration endpoint are never retried.
///
/// If the request cannot be cloned (for example, a streaming body), the
/// client is still re-registered but the original 401 response is returned,
/// because there is nothing that can safely be replayed.
///
/// # Errors
/// Fails when sending either attempt fails, when re-registration fails, or
/// when the new client ID is not a valid header value.
async fn send_with_retry(
    &self,
    request_builder: RequestBuilder,
    original_url: &str,
) -> Result<Response> {
    let may_retry = self.is_our_server(original_url) && !self.is_register_endpoint(original_url);

    // `send` consumes the builder, so the replay copy must be taken up front.
    let replay = if may_retry {
        request_builder.try_clone()
    } else {
        None
    };

    let response = request_builder.send().await?;
    if !may_retry || response.status() != reqwest::StatusCode::UNAUTHORIZED {
        return Ok(response);
    }

    warn!("API request authentication failed (401), attempting to auto-re-register...");
    let new_client_id = match self.auto_register().await {
        Ok(id) => id,
        Err(e) => {
            error!("Auto re-registration failed: {}", e);
            return Err(anyhow::anyhow!(
                "Authentication failed and unable to re-register: {e}"
            ));
        }
    };

    let replay = match replay {
        Some(builder) => builder,
        None => {
            warn!(
                "Auto re-registration successful, client ID: {}, but the request cannot be replayed; returning the original response",
                new_client_id
            );
            return Ok(response);
        }
    };

    info!(
        "Auto re-registration successful, client ID: {}, retrying request...",
        new_client_id
    );
    let mut retry_request = replay.build()?;
    // `insert` replaces the stale ID carried over from the first attempt;
    // `RequestBuilder::header` would append a second value instead.
    retry_request.headers_mut().insert(
        reqwest::header::HeaderName::from_static("x-client-id"),
        reqwest::header::HeaderValue::from_str(&new_client_id)?,
    );
    let retry_response = self.client.execute(retry_request).await?;
    Ok(retry_response)
}
/// Returns an unsent `GET` request builder for `url`, prepared by
/// `execute_request`.
pub async fn get(&self, url: &str) -> Result<RequestBuilder> {
    let builder = self.execute_request(Method::GET, url).await?;
    Ok(builder)
}
/// Returns an unsent `POST` request builder for `url`, prepared by
/// `execute_request`.
pub async fn post(&self, url: &str) -> Result<RequestBuilder> {
    let builder = self.execute_request(Method::POST, url).await?;
    Ok(builder)
}
/// Returns an unsent `PUT` request builder for `url`, prepared by
/// `execute_request`.
pub async fn put(&self, url: &str) -> Result<RequestBuilder> {
    let builder = self.execute_request(Method::PUT, url).await?;
    Ok(builder)
}
/// Returns an unsent `DELETE` request builder for `url`, prepared by
/// `execute_request`.
pub async fn delete(&self, url: &str) -> Result<RequestBuilder> {
    let builder = self.execute_request(Method::DELETE, url).await?;
    Ok(builder)
}
/// Sends a `POST` to `url` with `json` as the JSON body, dispatching it
/// through `send_with_retry`.
///
/// # Errors
/// Propagates any error from preparing or sending the request.
pub async fn post_json<T: Serialize>(&self, url: &str, json: &T) -> Result<Response> {
    let prepared = self
        .execute_request_with_json(Method::POST, url, json)
        .await?;
    let response = self.send_with_retry(prepared, url).await?;
    Ok(response)
}
/// Sends a `PUT` to `url` with `json` as the JSON body, dispatching it
/// through `send_with_retry`.
///
/// # Errors
/// Propagates any error from preparing or sending the request.
pub async fn put_json<T: Serialize>(&self, url: &str, json: &T) -> Result<Response> {
    let prepared = self
        .execute_request_with_json(Method::PUT, url, json)
        .await?;
    let response = self.send_with_retry(prepared, url).await?;
    Ok(response)
}
/// Sends a previously built request through `send_with_retry`.
///
/// `url` must be the URL the builder was created for; it is what decides
/// whether the request is eligible for the 401 handling.
pub async fn send(&self, request_builder: RequestBuilder, url: &str) -> Result<Response> {
    let response = self.send_with_retry(request_builder, url).await?;
    Ok(response)
}
pub fn inner(&self) -> &Client {
&self.client
}
/// Returns a copy of the client ID currently held in memory, or `None` when
/// no ID is known.
pub async fn current_client_id(&self) -> Option<String> {
    let current = self.get_client_id().await;
    current
}
}