use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use reqwest::{Client, Proxy, StatusCode};
use url::Url;
use crate::config::Config;
use crate::dto::{
AccessTokenRequest, AccessTokenResponse, ApiErrorBody, CreateSessionRequest, SessionResponse,
SessionStatusResponse, UserDataResponse,
};
use crate::error::{MyIdError, MyIdResult};
use crate::types::SessionId;
/// Token endpoint (client-credentials), relative to the configured base URL.
const ACCESS_TOKEN_PATH: &str = "api/v1/auth/clients/access-token";
/// Endpoint used to create a new SDK session, relative to the configured base URL.
const CREATE_SESSION_PATH: &str = "api/v2/sdk/sessions";
/// Endpoint that returns user data for a callback `code`, relative to the base URL.
const USER_DATA_PATH: &str = "api/v1/sdk/data";
/// Session endpoint prefix; the session id is appended as an extra path segment.
const SESSION_RECOVERY_PATH: &str = "api/v1/sdk/sessions";

/// Total number of token-endpoint attempts (the first try plus retries).
const AUTH_MAX_ATTEMPTS: u8 = 4;
/// Delay before the second attempt; it doubles for every later attempt.
const AUTH_RETRY_BASE_MS: u64 = 100;
/// Upper bound for any single retry delay.
const AUTH_RETRY_MAX_MS: u64 = 2_000;
/// Async client for the MyID SDK HTTP API.
///
/// Clones share one token cache and one refresh lock (both live behind `Arc`), so a
/// token fetched through any clone is reused by all of them.
#[derive(Clone)]
pub struct MyIdClient {
    /// Settings: base URL, client credentials, timeouts, user agent and optional proxy.
    config: Config,
    /// HTTP client built from `config`.
    http: Client,
    /// Shared token cache; `None` before the first fetch or after invalidation.
    token: Arc<Mutex<Option<TokenState>>>,
    /// Held while refreshing, so only one task authenticates at a time.
    refresh_lock: Arc<Mutex<()>>,
    /// How long before `expires_at` a cached token stops being reused.
    token_refresh_margin: Duration,
}
/// A cached bearer token and the moment it expires.
#[derive(Clone)]
struct TokenState {
    /// Token string returned by the token endpoint.
    access_token: String,
    /// Absolute expiry, computed from the server's `expires_in` when the token was cached.
    expires_at: Instant,
}
impl TokenState {
    /// Reports whether the token still has strictly more than `margin` left before it
    /// expires.
    ///
    /// Returns `false` for an already-expired token and for any `margin` that reaches
    /// or passes `expires_at`.
    fn is_valid(&self, margin: Duration) -> bool {
        // Saturates to zero once `expires_at` is in the past.
        let remaining = self.expires_at.saturating_duration_since(Instant::now());
        remaining > margin
    }
}
impl MyIdClient {
    /// Creates a client from `config`.
    ///
    /// A cached access token stops being reused 60 seconds before it expires.
    ///
    /// # Errors
    ///
    /// Returns an error if the configured proxy URL is rejected or the underlying HTTP
    /// client cannot be built.
    pub fn new(config: Config) -> MyIdResult<Self> {
        let http = Self::build_http_client(&config)?;
        Ok(Self {
            config,
            http,
            token: Arc::new(Mutex::new(None)),
            refresh_lock: Arc::new(Mutex::new(())),
            token_refresh_margin: Duration::from_secs(60),
        })
    }

    /// Creates a new SDK session by POSTing `request` as JSON.
    ///
    /// If the API answers 401, the request is retried once with a freshly obtained token.
    ///
    /// # Errors
    ///
    /// Returns an error if no access token can be obtained, the request cannot be sent,
    /// the API answers with a non-2xx status, or the body cannot be deserialized.
    pub async fn create_session(
        &self,
        request: &CreateSessionRequest,
    ) -> MyIdResult<SessionResponse> {
        let url = self.endpoint(CREATE_SESSION_PATH)?;
        let response = self
            .send_with_401_retry(|token| {
                self.http
                    .post(url.as_str())
                    .bearer_auth(token)
                    .json(request)
                    .send()
            })
            .await?;
        Self::handle_response(response).await
    }

    /// Returns a bearer token, calling the token endpoint only when needed.
    ///
    /// A cached token is reused while more than the refresh margin is left. Otherwise
    /// the refresh lock is taken and the cache is checked again. Callers that race here
    /// therefore share a single authentication round-trip.
    ///
    /// # Errors
    ///
    /// Returns an error if authentication fails (after retrying transient failures) or
    /// the server reports an unusable `expires_in`.
    pub async fn get_token(&self) -> MyIdResult<String> {
        if let Some(token) = self.read_cached_token().await {
            return Ok(token);
        }
        let _refresh_guard = self.refresh_lock.lock().await;
        // Another task may have refreshed the token while we waited for the lock.
        if let Some(token) = self.read_cached_token().await {
            return Ok(token);
        }
        let fresh = self.authenticate().await?;
        self.write_cached_token(fresh).await
    }

    /// Fetches the status of the existing session `session_id`.
    ///
    /// The id is appended as one percent-encoded path segment. Characters such as `/`,
    /// `?` or `#` in it therefore cannot change the request path or add a query.
    ///
    /// # Errors
    ///
    /// The same failure modes as [`MyIdClient::create_session`]. Also returns a
    /// configuration error if the endpoint URL cannot take path segments.
    pub async fn recover_session(
        &self,
        session_id: SessionId,
    ) -> MyIdResult<SessionStatusResponse> {
        let mut url = self.endpoint(SESSION_RECOVERY_PATH)?;
        url.path_segments_mut()
            .map_err(|()| {
                MyIdError::config(format!(
                    "endpoint `{SESSION_RECOVERY_PATH}` cannot take path segments"
                ))
            })?
            .push(&session_id.to_string());
        let response = self
            .send_with_401_retry(|token| self.http.get(url.as_str()).bearer_auth(token).send())
            .await?;
        Self::handle_response(response).await
    }

    /// Exchanges a callback `code` for the user's data.
    ///
    /// `code` is sent URL-encoded in the `code` query parameter.
    ///
    /// # Errors
    ///
    /// Returns a validation error if `code` is empty or whitespace-only. Otherwise, the
    /// same failure modes as [`MyIdClient::create_session`].
    pub async fn handle_callback(&self, code: &str) -> MyIdResult<UserDataResponse> {
        if code.trim().is_empty() {
            // Message (Uzbek): "code must not be empty".
            return Err(MyIdError::validation("code bo'sh bo'lishi mumkin emas"));
        }
        let mut url = self.endpoint(USER_DATA_PATH)?;
        url.query_pairs_mut().append_pair("code", code);
        let response = self
            .send_with_401_retry(|token| self.http.get(url.as_str()).bearer_auth(token).send())
            .await?;
        Self::handle_response(response).await
    }

    /// Requests a new access token with the configured client credentials.
    ///
    /// Makes up to `AUTH_MAX_ATTEMPTS` attempts. Between attempts it sleeps with a
    /// capped exponential backoff, and it only retries while
    /// [`MyIdClient::is_retryable_auth_error`] accepts the failure.
    async fn authenticate(&self) -> MyIdResult<AccessTokenResponse> {
        let url = self.endpoint(ACCESS_TOKEN_PATH)?;
        let body = AccessTokenRequest {
            client_id: self.config.client_id(),
            client_secret: self.config.client_secret(),
        };
        // Always make at least one attempt, even if the constant is ever set to 0.
        let max_attempts = AUTH_MAX_ATTEMPTS.max(1);
        let mut attempt: u8 = 1;
        loop {
            // Transport and API failures go through the same retry decision.
            let outcome: MyIdResult<AccessTokenResponse> =
                match self.http.post(url.as_str()).json(&body).send().await {
                    Ok(resp) => Self::handle_response(resp).await,
                    Err(e) => Err(MyIdError::http(e)),
                };
            match outcome {
                Ok(token) => return Ok(token),
                Err(err) if attempt < max_attempts && Self::is_retryable_auth_error(&err) => {
                    tokio::time::sleep(Self::auth_retry_backoff(attempt)).await;
                    attempt += 1;
                }
                Err(err) => return Err(err),
            }
        }
    }

    /// Sends the request built by `build_request` with a bearer token.
    ///
    /// On a 401 response, the rejected token is dropped from the cache and the request
    /// is rebuilt and sent once more with a fresh token. The first response's body is
    /// discarded unread.
    async fn send_with_401_retry<F, Fut>(&self, build_request: F) -> MyIdResult<reqwest::Response>
    where
        F: Fn(String) -> Fut,
        Fut: std::future::Future<Output = reqwest::Result<reqwest::Response>>,
    {
        let token = self.get_token().await?;
        let response = build_request(token.clone()).await?;
        if response.status() != StatusCode::UNAUTHORIZED {
            return Ok(response);
        }
        self.invalidate_cached_token(&token).await;
        let retry_token = self.get_token().await?;
        Ok(build_request(retry_token).await?)
    }

    /// Returns a copy of the cached token if more than the refresh margin is left.
    async fn read_cached_token(&self) -> Option<String> {
        let guard = self.token.lock().await;
        guard
            .as_ref()
            .filter(|state| state.is_valid(self.token_refresh_margin))
            .map(|state| state.access_token.clone())
    }

    /// Clears the cached token, but only if it is still the `rejected` one.
    ///
    /// Several in-flight requests can get 401 with the same stale token. The first one
    /// refreshes the cache. This comparison stops the others from discarding that fresh
    /// token and triggering needless re-authentication.
    async fn invalidate_cached_token(&self, rejected: &str) {
        let mut guard = self.token.lock().await;
        if guard
            .as_ref()
            .is_some_and(|state| state.access_token == rejected)
        {
            *guard = None;
        }
    }

    /// Validates `token.expires_in`, caches the token with its absolute expiry, and
    /// returns the token string.
    ///
    /// # Errors
    ///
    /// Returns an internal error if `expires_in` is 0, exceeds one year, or would
    /// overflow `Instant`.
    async fn write_cached_token(&self, token: AccessTokenResponse) -> MyIdResult<String> {
        // One year in seconds; larger values are rejected as implausible.
        const MAX_TTL_SECS: u64 = 31_536_000;
        if token.expires_in == 0 {
            return Err(MyIdError::internal("expires_in must be > 0"));
        }
        if token.expires_in > MAX_TTL_SECS {
            return Err(MyIdError::internal(format!(
                "expires_in too large: {}",
                token.expires_in
            )));
        }
        let expires_at = Instant::now()
            .checked_add(Duration::from_secs(token.expires_in))
            .ok_or_else(|| MyIdError::internal("expires_at overflow"))?;
        // NOTE(review): if `expires_in` is not larger than `token_refresh_margin`, this
        // entry never passes `read_cached_token`, so every call re-authenticates.
        // Consider clamping the margin for short-lived tokens.
        let mut guard = self.token.lock().await;
        *guard = Some(TokenState {
            access_token: token.access_token.clone(),
            expires_at,
        });
        Ok(token.access_token)
    }

    /// Joins `path` onto the configured base URL.
    ///
    /// # Errors
    ///
    /// Returns a configuration error if the join fails.
    fn endpoint(&self, path: &str) -> MyIdResult<Url> {
        self.config
            .base_url_parsed()
            .join(path)
            .map_err(|e| MyIdError::config(format!("invalid endpoint `{path}`: {e}")))
    }

    /// Deserializes a 2xx JSON body into `T`. Otherwise converts the response into
    /// `MyIdError::api`.
    ///
    /// For error responses, the message comes from `ApiErrorBody::message()` when the
    /// body parses as an `ApiErrorBody`, and from the raw body text otherwise.
    async fn handle_response<T: serde::de::DeserializeOwned>(
        response: reqwest::Response,
    ) -> MyIdResult<T> {
        if response.status().is_success() {
            return Ok(response.json().await?);
        }
        let status = response.status().as_u16();
        // Best effort: fall back to a placeholder (Uzbek: "could not read response body").
        let body = response
            .text()
            .await
            .unwrap_or_else(|_| "response body o'qib bo'lmadi".to_string());
        let message = serde_json::from_str::<ApiErrorBody>(&body)
            .map(|e| e.message())
            .unwrap_or(body);
        Err(MyIdError::api(status, message))
    }

    /// Builds the reqwest client from `config`.
    ///
    /// Applies the connect timeout, overall timeout and user agent, and routes traffic
    /// through the configured proxy when one is set.
    fn build_http_client(config: &Config) -> MyIdResult<Client> {
        let mut builder = Client::builder()
            .connect_timeout(config.connection_timeout())
            .timeout(config.timeout())
            .user_agent(config.user_agent());
        if let Some(proxy_url) = config.proxy_url() {
            let proxy = Proxy::all(proxy_url)?;
            builder = builder.proxy(proxy);
        }
        Ok(builder.build()?)
    }

    /// Decides whether a token-endpoint failure is worth retrying.
    ///
    /// Retries HTTP 429, any 5xx, and transport errors that reqwest classifies as
    /// timeout, connect or request errors. Everything else fails fast.
    fn is_retryable_auth_error(err: &MyIdError) -> bool {
        match err {
            MyIdError::Api { status, .. } => *status == 429 || (500..=599).contains(status),
            MyIdError::Http(source) => {
                source.is_timeout() || source.is_connect() || source.is_request()
            }
            _ => false,
        }
    }

    /// Backoff before the retry that follows `attempt` (1-based).
    ///
    /// The delay is `AUTH_RETRY_BASE_MS * 2^(attempt - 1)`, saturating, capped at
    /// `AUTH_RETRY_MAX_MS`. An `attempt` of 0 is treated like 1.
    fn auth_retry_backoff(attempt: u8) -> Duration {
        let shift = u32::from(attempt.saturating_sub(1));
        let factor = 1_u64.checked_shl(shift).unwrap_or(u64::MAX);
        let millis = AUTH_RETRY_BASE_MS
            .saturating_mul(factor)
            .min(AUTH_RETRY_MAX_MS);
        Duration::from_millis(millis)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Backoff doubles from the base delay and is capped at the maximum.
    #[test]
    fn auth_retry_backoff_grows_and_caps() {
        let cases: [(u8, u64); 5] = [(1, 100), (2, 200), (3, 400), (4, 800), (10, 2_000)];
        for (attempt, expected_ms) in cases {
            assert_eq!(
                MyIdClient::auth_retry_backoff(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    /// Rate limiting and server-side failures are retried.
    #[test]
    fn retryable_auth_api_statuses() {
        let cases = [
            (429, "rate limit"),
            (500, "internal"),
            (502, "bad gateway"),
            (503, "unavailable"),
        ];
        for (status, message) in cases {
            assert!(MyIdClient::is_retryable_auth_error(&MyIdError::api(status, message)));
        }
    }

    /// Client errors, including authentication failures, fail fast.
    #[test]
    fn non_retryable_auth_api_statuses() {
        let cases = [(400, "bad request"), (401, "unauthorized"), (403, "forbidden")];
        for (status, message) in cases {
            assert!(!MyIdClient::is_retryable_auth_error(&MyIdError::api(status, message)));
        }
    }
}