use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use reqwest::{Client, Proxy};
use url::Url;
use crate::config::Config;
use crate::dto::{AccessTokenRequest, AccessTokenResponse, CreateSessionRequest, SessionResponse};
use crate::error::{MyIdError, MyIdResult};
// Endpoint paths, written without a leading slash; the ones used in this file
// are resolved against the configured base URL by `MyIdClient::endpoint`.

/// Path POSTed to by `MyIdClient::authenticate` to obtain an access token.
const ACCESS_TOKEN_PATH: &str = "api/v1/auth/clients/access-token";
/// Path POSTed to by `MyIdClient::create_session`.
const CREATE_SESSION_PATH: &str = "api/v2/sdk/sessions";
// Not referenced anywhere in the visible part of this file, hence the lint
// suppression. NOTE(review): presumably the user-data endpoint — confirm.
#[allow(dead_code)]
const USER_DATA_PATH: &str = "api/v1/sdk/data";
// Also unreferenced in the visible code. NOTE(review): presumably used to poll
// a session's status; it is `v1` while `CREATE_SESSION_PATH` is `v2` — confirm
// that the version difference is intentional.
#[allow(dead_code)]
const SESSION_STATUS_PATH: &str = "api/v1/sdk/sessions";
/// HTTP client wrapper that obtains an access token, caches it, and uses it
/// for authenticated requests.
///
/// The token cache lives behind an `Arc`, so values produced by the derived
/// `Clone` share one cache.
#[derive(Clone)]
pub struct MyIdClient {
/// Settings consulted for the base URL, credentials, timeouts, user agent and proxy.
config: Config,
/// `reqwest` client built once in `new` and used for every request.
http: Client,
/// Most recently stored token; `None` until one has been written.
token: Arc<Mutex<Option<TokenState>>>,
/// Remaining lifetime below which a cached token is no longer served
/// (`new` sets this to 60 seconds).
token_refresh_margin: Duration,
}
/// Access token paired with the monotonic instant at which it is treated as expired.
#[derive(Clone)]
struct TokenState {
    /// The token string handed back to callers.
    access_token: String,
    /// Instant at which the token is treated as expired.
    expires_at: Instant,
}

impl TokenState {
    /// Returns `true` while strictly more than `margin` remains before `expires_at`.
    ///
    /// The remaining lifetime is measured with `saturating_duration_since`
    /// instead of computing `Instant::now() + margin`: `Instant + Duration`
    /// panics on overflow, whereas this form yields a zero duration for an
    /// already-expired token and can never panic, whatever the margin.
    fn is_valid(&self, margin: Duration) -> bool {
        let remaining = self.expires_at.saturating_duration_since(Instant::now());
        remaining > margin
    }
}
impl MyIdClient {
/// Builds a client from `config`.
///
/// The HTTP client is constructed up front, the token cache starts empty,
/// and the refresh margin is fixed at 60 seconds.
///
/// # Errors
///
/// Returns whatever error `build_http_client` produces.
pub fn new(config: Config) -> MyIdResult<Self> {
    let http = Self::build_http_client(&config)?;
    let client = Self {
        token_refresh_margin: Duration::from_secs(60),
        token: Arc::new(Mutex::new(None)),
        http,
        config,
    };
    Ok(client)
}
/// POSTs `request` as JSON to the session-creation endpoint, authenticated
/// with a bearer token from `get_token`, and decodes the reply.
///
/// # Errors
///
/// Fails if no token can be obtained, the endpoint URL cannot be built, the
/// request cannot be sent, or `handle_response` rejects the reply.
pub async fn create_session(
    &self,
    request: &CreateSessionRequest,
) -> MyIdResult<SessionResponse> {
    // The token is resolved before the URL, so a token failure is reported first.
    let bearer = self.get_token().await?;
    let target = self.endpoint(CREATE_SESSION_PATH)?;

    let pending = self
        .http
        .post(target.as_str())
        .bearer_auth(bearer)
        .json(request);
    let reply = pending.send().await?;

    Self::handle_response(reply).await
}
/// Returns an access token, preferring the cached one while it is still valid.
///
/// On a cache miss a new token is requested via `authenticate` and stored
/// with `write_cached_token`. The cache lock is not held while
/// authenticating, so callers that miss at the same time each perform their
/// own authentication request.
///
/// # Errors
///
/// Propagates failures from `authenticate` and `write_cached_token`.
pub async fn get_token(&self) -> MyIdResult<String> {
    match self.read_cached_token().await {
        Some(cached) => Ok(cached),
        None => {
            let issued = self.authenticate().await?;
            self.write_cached_token(issued).await
        }
    }
}
/// Requests a new access token by POSTing the configured client id and
/// secret as JSON to the access-token endpoint.
///
/// # Errors
///
/// Fails if the endpoint URL cannot be built, the request cannot be sent,
/// or `handle_response` rejects the reply.
async fn authenticate(&self) -> MyIdResult<AccessTokenResponse> {
    let target = self.endpoint(ACCESS_TOKEN_PATH)?;
    let credentials = AccessTokenRequest {
        client_id: self.config.client_id(),
        client_secret: self.config.client_secret(),
    };

    let reply = self
        .http
        .post(target.as_str())
        .json(&credentials)
        .send()
        .await?;

    Self::handle_response(reply).await
}
/// Returns a copy of the cached token if one is stored and it still has
/// more than the refresh margin left; otherwise `None`.
async fn read_cached_token(&self) -> Option<String> {
    let slot = self.token.lock().await;
    slot.as_ref()
        .filter(|state| state.is_valid(self.token_refresh_margin))
        .map(|state| state.access_token.clone())
}
/// Validates the lifetime reported in `token`, stores the token in the
/// cache, and returns the token string.
///
/// # Errors
///
/// Returns an internal error when `expires_in` is zero, exceeds
/// `MAX_TTL_SECS`, or would overflow the expiry instant. Nothing is cached
/// in those cases.
async fn write_cached_token(&self, token: AccessTokenResponse) -> MyIdResult<String> {
    // Upper bound on an accepted lifetime: 365 days expressed in seconds.
    const MAX_TTL_SECS: u64 = 31_536_000;

    let ttl_secs = match token.expires_in {
        0 => return Err(MyIdError::internal("expires_in must be > 0")),
        secs if secs > MAX_TTL_SECS => {
            return Err(MyIdError::internal(format!(
                "expires_in too large: {}",
                secs
            )));
        }
        secs => secs,
    };

    let deadline = match Instant::now().checked_add(Duration::from_secs(ttl_secs)) {
        Some(instant) => instant,
        None => return Err(MyIdError::internal("expires_at overflow")),
    };

    let state = TokenState {
        access_token: token.access_token.clone(),
        expires_at: deadline,
    };
    // The guard is a temporary, so the lock is released at the end of this statement.
    *self.token.lock().await = Some(state);

    Ok(token.access_token)
}
/// Resolves `path` against the configured base URL.
///
/// NOTE(review): `Url::join` replaces the last path segment of the base
/// unless the base ends with `/`; confirm that `Config` guarantees a
/// trailing slash on the base URL.
///
/// # Errors
///
/// Returns a config error naming `path` when the join fails.
fn endpoint(&self, path: &str) -> MyIdResult<Url> {
    let base = self.config.base_url_parsed();
    match base.join(path) {
        Ok(resolved) => Ok(resolved),
        Err(e) => Err(MyIdError::config(format!(
            "invalid endpoint `{path}`: {e}"
        ))),
    }
}
/// Converts an HTTP response into `T` or an API error.
///
/// A success status has its body decoded as JSON into `T`. Any other status
/// becomes `MyIdError::api` carrying the numeric status code and the body
/// text; if the body cannot be read, a fixed placeholder message is used.
///
/// # Errors
///
/// Returns the API error described above, or the decoding error when a
/// successful response cannot be parsed as `T`.
async fn handle_response<T: serde::de::DeserializeOwned>(
    response: reqwest::Response,
) -> MyIdResult<T> {
    let status = response.status();

    if !status.is_success() {
        let body = match response.text().await {
            Ok(text) => text,
            Err(_) => "response body o'qib bo'lmadi".to_string(),
        };
        return Err(MyIdError::api(status.as_u16(), body));
    }

    let parsed = response.json::<T>().await?;
    Ok(parsed)
}
/// Creates the `reqwest` client from `config`: connect timeout, overall
/// timeout, user agent, and — when a proxy URL is configured — a proxy
/// applied to all traffic.
///
/// # Errors
///
/// Fails when the proxy cannot be created from the configured URL or the
/// client cannot be built.
fn build_http_client(config: &Config) -> MyIdResult<Client> {
    let base = Client::builder()
        .connect_timeout(config.connection_timeout())
        .timeout(config.timeout())
        .user_agent(config.user_agent());

    let builder = match config.proxy_url() {
        Some(proxy_url) => base.proxy(Proxy::all(proxy_url)?),
        None => base,
    };

    let client = builder.build()?;
    Ok(client)
}
}