use anyhow::{Result, Context};
use base64::Engine;
use chrono::{DateTime, Utc};
use oauth2::{
AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl, TokenResponse as OAuthTokenResponse,
};
use reqwest::Client as HttpClient;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use crate::config::AuthConfig;
/// Base URL for Microsoft Graph v1.0 REST calls.
/// NOTE(review): not referenced in this part of the file — confirm it is used elsewhere.
const GRAPH_API_BASE_URL: &str = "https://graph.microsoft.com/v1.0";
/// Microsoft identity platform v2.0 token endpoint. The literal `{tenant}` placeholder
/// is replaced with the configured tenant ID before each request.
const TOKEN_URL: &str = "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token";
/// Microsoft identity platform v2.0 authorization endpoint. The literal `{tenant}`
/// placeholder is replaced with the configured tenant ID when building the login URL.
const AUTH_URL: &str = "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize";
/// An OAuth 2.0 access token together with its optional refresh token.
///
/// `Debug` is implemented by hand so that the access and refresh tokens are never
/// written to logs via `{:?}`. `Serialize` still emits them in full, so treat any
/// serialized form of this struct as a secret.
#[derive(Clone, Serialize, Deserialize)]
pub struct OAuthToken {
    /// Access token returned by the token endpoint.
    pub access_token: String,
    /// Refresh token, if the token endpoint issued one.
    pub refresh_token: Option<String>,
    /// Absolute UTC instant at which `access_token` expires.
    pub expires_at: DateTime<Utc>,
    /// Token type reported by the token endpoint (e.g. `"Bearer"`).
    pub token_type: String,
}

impl std::fmt::Debug for OAuthToken {
    /// Formats the token with both secrets redacted. Only the presence of a refresh
    /// token is shown, never its value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuthToken")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_at", &self.expires_at)
            .field("token_type", &self.token_type)
            .finish()
    }
}
/// OAuth 2.0 client for the Microsoft identity platform.
///
/// Supports the authorization-code and refresh-token grants and caches the most
/// recently obtained token in memory.
///
/// NOTE(review): this type holds `client_secret`; if `Debug` is ever added, redact it.
pub struct OAuth2Client {
    /// Tenant identifier substituted into the `{tenant}` placeholder of the endpoint URLs.
    tenant_id: String,
    /// Application (client) ID sent to the authorize and token endpoints.
    client_id: String,
    /// Client secret sent to the token endpoint; must never be logged.
    client_secret: String,
    /// Redirect URI sent both in the authorization URL and in the code exchange.
    redirect_uri: String,
    /// HTTP client used for token-endpoint requests.
    http_client: HttpClient,
    /// Most recently obtained token behind an async `RwLock`; `None` until a token is
    /// exchanged, refreshed, or set explicitly.
    current_token: Arc<RwLock<Option<OAuthToken>>>,
}
impl OAuth2Client {
pub fn new(config: &AuthConfig) -> Result<Self> {
Ok(Self {
tenant_id: config.tenant_id.clone(),
client_id: config.client_id.clone(),
client_secret: config.client_secret.clone(),
redirect_uri: config.redirect_uri.clone(),
http_client: HttpClient::new(),
current_token: Arc::new(RwLock::new(None)),
})
}
pub fn get_authorization_url(&self, state: &str) -> Result<String> {
let auth_url = AUTH_URL.replace("{tenant}", &self.tenant_id);
let scope = "https://graph.microsoft.com/.default offline_access";
let url = format!(
"{}?client_id={}&response_type=code&redirect_uri={}&scope={}&state={}",
auth_url,
self.client_id,
urlencoding::encode(&self.redirect_uri),
urlencoding::encode(scope),
state
);
Ok(url)
}
pub async fn exchange_code_for_token(&self, code: &str) -> Result<OAuthToken> {
let token_url = TOKEN_URL.replace("{tenant}", &self.tenant_id);
let params = [
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("code", code),
("redirect_uri", self.redirect_uri.as_str()),
("grant_type", "authorization_code"),
];
let response = self
.http_client
.post(&token_url)
.form(¶ms)
.send()
.await
.context("failed to exchange authorization code")?;
let token_response: TokenResponse = response
.json()
.await
.context("failed to parse token response")?;
let token = OAuthToken {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_at: Utc::now() + chrono::Duration::seconds(token_response.expires_in),
token_type: token_response.token_type,
};
let mut current_token = self.current_token.write().await;
*current_token = Some(token.clone());
Ok(token)
}
pub async fn refresh_access_token(&self) -> Result<OAuthToken> {
let current_token = self.current_token.read().await;
let refresh_token = current_token
.as_ref()
.and_then(|t| t.refresh_token.clone())
.ok_or_else(|| anyhow::anyhow!("no refresh token available"))?;
let token_url = TOKEN_URL.replace("{tenant}", &self.tenant_id);
let params = [
("client_id", self.client_id.as_str()),
("client_secret", self.client_secret.as_str()),
("refresh_token", refresh_token.as_str()),
("grant_type", "refresh_token"),
("scope", "https://graph.microsoft.com/.default"),
];
let response = self
.http_client
.post(&token_url)
.form(¶ms)
.send()
.await
.context("failed to refresh access token")?;
let token_response: TokenResponse = response
.json()
.await
.context("failed to parse token response")?;
let token = OAuthToken {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token.or(Some(refresh_token)),
expires_at: Utc::now() + chrono::Duration::seconds(token_response.expires_in),
token_type: token_response.token_type,
};
let mut current_token = self.current_token.write().await;
*current_token = Some(token.clone());
Ok(token)
}
pub async fn get_access_token(&self) -> Result<String> {
let current_token = self.current_token.read().await;
if let Some(token) = current_token.as_ref() {
if token.expires_at > Utc::now() + chrono::Duration::seconds(60) {
return Ok(token.access_token.clone());
}
}
drop(current_token);
let new_token = self.refresh_access_token().await?;
Ok(new_token.access_token)
}
pub async fn set_token(&self, token: OAuthToken) {
let mut current_token = self.current_token.write().await;
*current_token = Some(token);
}
}
/// Raw JSON reply from the token endpoint, before conversion into `OAuthToken`.
///
/// `Debug` is implemented by hand so that the access and refresh tokens are never
/// written to logs via `{:?}`.
#[derive(Deserialize)]
struct TokenResponse {
    /// Access token issued by the endpoint.
    access_token: String,
    /// Refresh token, if the endpoint issued one in this reply.
    refresh_token: Option<String>,
    /// Token lifetime in seconds, relative to when the reply was received.
    expires_in: i64,
    /// Token type reported by the endpoint (e.g. `"Bearer"`).
    token_type: String,
}

impl std::fmt::Debug for TokenResponse {
    /// Formats the reply with both secrets redacted. Only the presence of a refresh
    /// token is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenResponse")
            .field("access_token", &"<redacted>")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .field("expires_in", &self.expires_in)
            .field("token_type", &self.token_type)
            .finish()
    }
}