use std::{fmt::Debug, future::Future, sync::Arc};
use reqwest::{Client, IntoUrl, Response, StatusCode};
use tokio::sync::{Mutex, RwLock};
use crate::{error::ApiError, models::auth::BodyAdminTokenApiAdminTokenPost};
/// Public handle to the Marzban API client.
///
/// All state lives behind an [`Arc`], so cloning this handle copies only the
/// pointer and every clone shares the same `MarzbanAPIClientRef`.
#[derive(Debug, Clone)]
pub struct MarzbanAPIClient {
/// Shared state (HTTP client, token and saved credentials) common to all clones.
pub(crate) inner: Arc<MarzbanAPIClientRef>,
}
/// Shared state behind a [`MarzbanAPIClient`] handle.
///
/// The mutable pieces are wrapped in async locks so they can be read and
/// updated through a shared reference.
pub(crate) struct MarzbanAPIClientRef {
/// Base URL string supplied to the constructor.
pub(crate) base_url: String,
/// `reqwest` client used to build every outgoing request.
pub(crate) client: Client,
/// Bearer token attached to requests when present; `None` means no token is set.
pub(crate) token: RwLock<Option<String>>,
/// Saved username used to request a new token after a 401 response.
pub(crate) username: RwLock<Option<String>>,
/// Saved password used to request a new token after a 401 response.
pub(crate) password: RwLock<Option<String>>,
/// Mutex locked by `send_with_auth_retry` after a 401 response.
pub(crate) refresh_lock: Mutex<()>,
}
/// Hand-written `Debug` so the token and password never end up in logs.
impl Debug for MarzbanAPIClientRef {
    /// Formats the client state, replacing the secret fields with a fixed mask.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholder printed in place of secret values.
        const REDACTED: &str = "*****";

        let mut out = f.debug_struct("MarzbanAPIClient");
        out.field("base_url", &self.base_url);
        out.field("client", &self.client);
        out.field("token", &REDACTED);
        out.field("username", &self.username);
        out.field("password", &REDACTED);
        out.finish()
    }
}
impl MarzbanAPIClient {
pub fn new(base_url: &str) -> Self {
MarzbanAPIClient {
inner: MarzbanAPIClientRef {
base_url: base_url.to_string(),
client: Client::new(),
token: RwLock::new(None),
username: RwLock::new(None),
password: RwLock::new(None),
refresh_lock: Mutex::new(()),
}
.into(),
}
}
pub fn new_with_token(base_url: &str, token: &str) -> Self {
MarzbanAPIClient {
inner: MarzbanAPIClientRef {
base_url: base_url.to_string(),
client: Client::new(),
token: RwLock::new(Some(token.to_owned())),
username: RwLock::new(None),
password: RwLock::new(None),
refresh_lock: Mutex::new(()),
}
.into(),
}
}
pub fn prepare_request(
&self,
method: reqwest::Method,
url: impl IntoUrl,
) -> reqwest::RequestBuilder {
self.inner.client.request(method, url)
}
pub async fn attach_auth(&self, mut rb: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
if let Some(token) = self.inner.token.read().await.as_ref() {
rb = rb.bearer_auth(token);
}
rb
}
pub async fn send_with_auth_retry<M, Fut>(&self, make: M) -> Result<Response, ApiError>
where
M: Fn() -> Fut + Send + Sync,
Fut: Future<Output = reqwest::RequestBuilder> + Send,
{
let mut stage = 0usize;
loop {
let resp = self.attach_auth(make().await).await.send().await?;
if resp.status() != StatusCode::UNAUTHORIZED {
return Ok(resp);
}
match stage {
0 => {
let _g = self.inner.refresh_lock.lock().await;
stage = 1;
continue;
}
1 => {
self.refresh_token_from_saved_creds().await?;
stage = 2;
continue;
}
_ => {
return Err(ApiError::ApiResponseError("Unauthorized".into()));
}
}
}
}
async fn refresh_token_from_saved_creds(&self) -> Result<(), ApiError> {
let (username, password) = {
let u = self.inner.username.read().await.clone();
let p = self.inner.password.read().await.clone();
(u, p)
};
let (username, password) = match (username, password) {
(Some(u), Some(p)) => (u, p),
_ => {
return Err(ApiError::ApiResponseError(
"401 and no saved credentials to refresh token".into(),
));
}
};
let body = BodyAdminTokenApiAdminTokenPost {
username,
password,
grant_type: None,
scope: "".to_string(),
client_id: None,
client_secret: None,
};
let token = self.admin_token(body).await?;
let mut token_lock = self.inner.token.write().await;
*token_lock = Some(token.access_token);
Ok(())
}
}