use std::time::Duration;

use anyhow::{Context as _, Result};
use chrono::Utc;
use oauth2::{
    AuthUrl, ClientId, RefreshToken, ResourceOwnerPassword, ResourceOwnerUsername, Scope,
    TokenResponse as _, TokenUrl, basic::BasicClient,
};
use opentalk_client_data_persistence::{AccountTokens, DataManager};
use secrecy::{ExposeSecret, SecretString};

use crate::{Authorization, oidc::OidcEndpoints, oidc_authorization::REFRESH_BEFORE_EXPIRY};
/// OIDC authorization based on the resource owner password credentials
/// ("direct access") grant, with account tokens persisted through a
/// [`DataManager`].
///
/// Instances are created with
/// [`OidcDirectAccessGrant::create_with_direct_access_grant`] and hand out
/// access tokens via the [`Authorization`] trait.
#[derive(Debug)]
pub struct OidcDirectAccessGrant {
    /// Backend from which account tokens are loaded and to which refreshed
    /// tokens are stored.
    data_manager: Box<dyn DataManager>,
    /// The OIDC provider's endpoints; the authorization and token endpoint
    /// URLs are used to configure the OAuth2 client.
    oidc_endpoints: OidcEndpoints,
    /// OAuth2 client id sent with every token request.
    oidc_client_id: String,
}
/// Hands out access tokens, refreshing them once they come within
/// [`REFRESH_BEFORE_EXPIRY`] of their expiry.
#[async_trait::async_trait(?Send)]
impl Authorization for OidcDirectAccessGrant {
    async fn get_access_token(&self) -> Result<String> {
        Self::get_token_and_refresh_if_needed(self, REFRESH_BEFORE_EXPIRY).await
    }
}
/// Allows a shared reference to be used wherever an [`Authorization`] is
/// expected by forwarding to the implementation on the owned type.
#[async_trait::async_trait(?Send)]
impl Authorization for &OidcDirectAccessGrant {
    async fn get_access_token(&self) -> Result<String> {
        // `self` is `&&OidcDirectAccessGrant`; strip one reference layer so the
        // call dispatches to the owned impl rather than recursing into this one.
        let grant: &OidcDirectAccessGrant = *self;
        <OidcDirectAccessGrant as Authorization>::get_access_token(grant).await
    }
}
impl OidcDirectAccessGrant {
pub async fn get_token_and_refresh_if_needed(
&self,
refresh_before_expiry: Duration,
) -> Result<String> {
let AccountTokens {
access_token_expiry,
access_token,
..
} = self.data_manager.load_account_tokens()?;
let now = Utc::now();
if now + refresh_before_expiry > access_token_expiry {
Ok(self.refresh_token().await?)
} else {
Ok(access_token)
}
}
pub async fn refresh_token(&self) -> Result<String> {
let AccountTokens { refresh_token, .. } = self.data_manager.load_account_tokens()?;
let client = BasicClient::new(ClientId::new(self.oidc_client_id.clone()))
.set_auth_uri(
AuthUrl::new(self.oidc_endpoints.authorization_endpoint.to_string()).unwrap(),
)
.set_token_uri(TokenUrl::new(self.oidc_endpoints.token_endpoint.to_string()).unwrap());
let builder = reqwest::ClientBuilder::new();
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
let builder = {
builder.redirect(reqwest::redirect::Policy::none())
};
let http_client = builder.build().expect("Client should build");
let http_client = super::ClientWrapper(http_client);
let response = client
.exchange_refresh_token(&RefreshToken::new(refresh_token))
.request_async(&http_client)
.await
.unwrap();
let now = Utc::now();
let account_tokens = AccountTokens {
access_token_expiry: now + response.expires_in().unwrap_or_default(),
access_token: response.access_token().secret().clone(),
refresh_token: response.refresh_token().unwrap().secret().clone(),
};
let _ = self
.data_manager
.store_account_tokens(account_tokens.clone());
Ok(account_tokens.access_token)
}
pub async fn create_with_direct_access_grant(
data_manager: Box<dyn DataManager>,
oidc_endpoints: OidcEndpoints,
oidc_client_id: String,
oidc_user: String,
oidc_password: SecretString,
) -> Result<Self> {
let oidc_client = BasicClient::new(ClientId::new(oidc_client_id.clone()))
.set_auth_uri(AuthUrl::new(oidc_endpoints.authorization_endpoint.to_string()).unwrap())
.set_token_uri(TokenUrl::new(oidc_endpoints.token_endpoint.to_string()).unwrap());
let builder = reqwest::ClientBuilder::new();
#[cfg(not(all(target_family = "wasm", target_os = "unknown")))]
let builder = {
builder.redirect(reqwest::redirect::Policy::none())
};
let http_client = builder.build().expect("Client should build");
let http_client = super::ClientWrapper(http_client);
let token_result = oidc_client
.exchange_password(
&ResourceOwnerUsername::new(oidc_user.clone()),
&ResourceOwnerPassword::new(oidc_password.expose_secret().to_string()),
)
.add_scope(Scope::new("openid".to_string()))
.request_async(&http_client)
.await
.unwrap();
let now = Utc::now();
let account_tokens = AccountTokens {
access_token_expiry: now + token_result.expires_in().unwrap_or_default(),
access_token: token_result.access_token().clone().into_secret(),
refresh_token: token_result
.refresh_token()
.expect("Refresh token should be exist")
.clone()
.into_secret(),
};
data_manager.store_account_tokens(account_tokens.clone())?;
println!("{:?}", token_result);
Ok(Self {
data_manager,
oidc_endpoints,
oidc_client_id,
})
}
}