use reqwest_retry::{default_on_request_failure, Retryable, RetryableStrategy};
use secrecy::SecretString;
use std::fmt::{Display, Formatter};
use url::ParseError;
use chrono::Utc;
use reqwest::{Response, StatusCode};
use serde::{Deserialize, Serialize};
use crate::{
constants::{GRANT_TYPE_AUTH_CODE, GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN},
utils::parse_body,
};
use super::{errors::DracoonClientError, Connection, DracoonClient};
/// Implemented by types that can hand out a shared reference to the [`DracoonClient`]
/// they operate on.
///
/// `S` is the client's type parameter, forwarded unchanged
/// (NOTE(review): presumably a connection-state marker — confirm at the `DracoonClient` definition).
pub(crate) trait GetClient<S> {
/// Returns the underlying client.
fn get_client(&self) -> &DracoonClient<S>;
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OAuth2PasswordFlow {
pub username: String,
pub password: String,
pub grant_type: String,
}
impl OAuth2PasswordFlow {
pub fn new(username: &str, password: &str) -> Self {
Self {
username: username.to_string(),
password: password.to_string(),
grant_type: GRANT_TYPE_PASSWORD.to_string(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OAuth2AuthCodeFlow {
pub client_id: String,
pub client_secret: String,
pub grant_type: String,
pub code: String,
pub redirect_uri: String,
}
impl OAuth2AuthCodeFlow {
pub fn new(client_id: &str, client_secret: &str, code: &str, redirect_uri: &str) -> Self {
Self {
client_id: client_id.to_string(),
client_secret: client_secret.to_string(),
grant_type: GRANT_TYPE_AUTH_CODE.to_string(),
code: code.to_string(),
redirect_uri: redirect_uri.to_string(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OAuth2RefreshTokenFlow {
client_id: String,
client_secret: String,
grant_type: String,
refresh_token: String,
}
impl OAuth2RefreshTokenFlow {
pub fn new(client_id: &str, client_secret: &str, refresh_token: &str) -> Self {
Self {
client_id: client_id.to_string(),
client_secret: client_secret.to_string(),
grant_type: GRANT_TYPE_REFRESH_TOKEN.to_string(),
refresh_token: refresh_token.to_string(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OAuth2TokenRevoke {
client_id: String,
client_secret: String,
token_type_hint: String,
token: String,
}
impl OAuth2TokenRevoke {
pub fn new(client_id: &str, client_secret: &str, token_type_hint: &str, token: &str) -> Self {
Self {
client_id: client_id.to_string(),
client_secret: client_secret.to_string(),
token_type_hint: token_type_hint.to_string(),
token: token.to_string(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OAuth2TokenResponse {
access_token: String,
refresh_token: String,
token_type: Option<String>,
expires_in: u64,
expires_in_inactive: Option<u64>,
scope: Option<String>,
}
/// Error body returned by DRACOON API endpoints.
///
/// Deserialized from camelCase JSON keys (`code`, `message`, `debugInfo`, `errorCode`).
#[cfg_attr(feature = "mcp", derive(schemars::JsonSchema))]
#[derive(Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct DracoonErrorResponse {
// Status code; the `is_*` helpers compare it against HTTP status values (403, 404, ...).
code: i32,
// Human-readable error message.
message: String,
// Optional extra detail from the server.
debug_info: Option<String>,
// Optional application-specific error code (NOTE(review): presumably DRACOON's own error numbering — confirm).
error_code: Option<i32>,
}
impl Display for DracoonErrorResponse {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let dbg_info = self.debug_info.as_deref().unwrap_or("No details");
let error_code = self.error_code.unwrap_or(0);
write!(
f,
"{} {} - {dbg_info} ({})",
self.code, self.message, error_code
)
}
}
impl DracoonErrorResponse {
    /// Builds an error response from a status `code` and a `message`, leaving the
    /// optional debug info and error code empty.
    pub fn new(code: i32, message: &str) -> Self {
        Self {
            message: String::from(message),
            code,
            error_code: None,
            debug_info: None,
        }
    }

    // --- accessors ---------------------------------------------------------

    /// Status code of the response.
    pub fn code(&self) -> i32 {
        self.code
    }

    /// Optional application-specific error code.
    pub fn error_code(&self) -> Option<i32> {
        self.error_code
    }

    /// Owned copy of the error message.
    pub fn error_message(&self) -> String {
        self.message.to_owned()
    }

    /// Owned copy of the optional debug info.
    pub fn debug_info(&self) -> Option<String> {
        self.debug_info.as_ref().cloned()
    }

    // --- status classification ---------------------------------------------

    /// `true` for any code in `400..500`.
    pub fn is_client_error(&self) -> bool {
        (400..500).contains(&self.code)
    }

    /// `true` for any code of 500 or above.
    pub fn is_server_error(&self) -> bool {
        self.code >= 500
    }

    /// `true` for 400.
    pub fn is_bad_request(&self) -> bool {
        self.code == 400
    }

    /// `true` for 401.
    pub fn is_unauthorized(&self) -> bool {
        self.code == 401
    }

    /// `true` for 402.
    pub fn is_payment_required(&self) -> bool {
        self.code == 402
    }

    /// `true` for 403.
    pub fn is_forbidden(&self) -> bool {
        self.code == 403
    }

    /// `true` for 404.
    pub fn is_not_found(&self) -> bool {
        self.code == 404
    }

    /// `true` for 409.
    pub fn is_conflict(&self) -> bool {
        self.code == 409
    }

    /// `true` for 412.
    pub fn is_precondition_failed(&self) -> bool {
        self.code == 412
    }

    /// `true` for 429.
    pub fn is_too_many_requests(&self) -> bool {
        self.code == 429
    }
}
/// Error body returned by the OAuth2 endpoints.
#[cfg_attr(feature = "mcp", derive(schemars::JsonSchema))]
#[derive(Deserialize, Debug, PartialEq, Clone)]
#[serde(rename_all = "camelCase")]
pub struct DracoonAuthErrorResponse {
    // OAuth2 error identifier.
    error: String,
    // Optional human-readable description. `rename_all = "camelCase"` makes the primary key
    // `errorDescription`, but OAuth2 error responses (RFC 6749 §5.2) use `error_description`.
    // Without the alias the field would be silently dropped and always deserialize to `None`.
    #[serde(alias = "error_description")]
    error_description: Option<String>,
}
impl DracoonAuthErrorResponse {
    /// Builds a generic `Unauthorized` auth error that carries no description.
    pub fn new_unauthorized() -> Self {
        let error = String::from("Unauthorized");
        Self {
            error,
            error_description: None,
        }
    }
}
impl Display for DracoonAuthErrorResponse {
    /// Renders `"Error: <description> (<error>)"`, using `Unknown` when the server sent no
    /// description.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Borrow the description: formatting never needs an owned copy, so the former
        // clone + `to_string` fallback only cost allocations.
        let description = self.error_description.as_deref().unwrap_or("Unknown");
        write!(f, "Error: {description} ({})", self.error)
    }
}
impl OAuth2TokenResponse {
    /// Parses a token endpoint response with `parse_body`, which decodes either the token
    /// payload or a [`DracoonAuthErrorResponse`].
    ///
    /// # Errors
    /// Propagates any [`DracoonClientError`] returned by `parse_body`.
    pub async fn from_response(res: Response) -> Result<Self, DracoonClientError> {
        let token = parse_body::<Self, DracoonAuthErrorResponse>(res).await?;
        Ok(token)
    }
}
/// Classification of an HTTP status as success or failure, carrying the original code.
///
/// The `From<StatusCode>` impl decides which codes count as `Ok`.
pub enum StatusCodeState {
/// The status is one of the codes treated as success.
Ok(StatusCode),
/// Any other status.
Error(StatusCode),
}
impl From<StatusCode> for StatusCodeState {
    /// Only 200 OK, 201 Created, 202 Accepted and 204 No Content are mapped to `Ok`.
    /// Every other status, including other 2xx codes such as 206, becomes `Error`.
    fn from(value: StatusCode) -> Self {
        let is_expected_success = matches!(
            value,
            StatusCode::OK | StatusCode::CREATED | StatusCode::ACCEPTED | StatusCode::NO_CONTENT
        );
        if is_expected_success {
            Self::Ok(value)
        } else {
            Self::Error(value)
        }
    }
}
impl From<OAuth2TokenResponse> for Connection {
fn from(value: OAuth2TokenResponse) -> Self {
Self {
connected_at: Utc::now(),
access_token: SecretString::from(value.access_token),
refresh_token: SecretString::from(value.refresh_token),
expires_in: value.expires_in,
}
}
}
impl From<DracoonAuthErrorResponse> for DracoonClientError {
fn from(value: DracoonAuthErrorResponse) -> Self {
Self::Auth(value)
}
}
impl From<DracoonErrorResponse> for DracoonClientError {
fn from(value: DracoonErrorResponse) -> Self {
Self::Http(value)
}
}
impl From<ParseError> for DracoonClientError {
fn from(_v: ParseError) -> Self {
Self::InvalidUrl("parsing url failed (invalid)".to_string())
}
}
pub(crate) struct DracoonCustomRetryStrategy;
impl RetryableStrategy for DracoonCustomRetryStrategy {
fn handle(
&self,
res: &Result<reqwest::Response, reqwest_middleware::Error>,
) -> Option<Retryable> {
match res {
Ok(success) => default_on_request_success(success),
Err(error) => default_on_request_failure(error),
}
}
}
/// Decides whether a request that produced an HTTP response should be retried.
///
/// Returns `None` for 2xx (nothing to retry). Otherwise returns `Some(Retryable::Transient)`
/// or `Some(Retryable::Fatal)`:
///
/// * 5xx → transient
/// * 401 on a URL path starting with `/oauth` → fatal
/// * 408, 429 and any other 401 → transient
/// * every other status (remaining 4xx, 1xx, 3xx) → fatal
fn default_on_request_success(success: &reqwest::Response) -> Option<Retryable> {
    match success.status() {
        code if code.is_success() => None,
        code if code.is_server_error() => Some(Retryable::Transient),
        // A 401 from the OAuth endpoints means the credentials themselves were rejected,
        // so resending the identical request cannot succeed.
        StatusCode::UNAUTHORIZED if success.url().path().starts_with("/oauth") => {
            Some(Retryable::Fatal)
        }
        // NOTE(review): a 401 elsewhere is retried, presumably so a token refresh can take
        // effect between attempts — confirm against the middleware setup.
        StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS | StatusCode::UNAUTHORIZED => {
            Some(Retryable::Transient)
        }
        _ => Some(Retryable::Fatal),
    }
}