pub mod error;
use openidconnect::OAuth2TokenResponse;
use openidconnect::{AsyncHttpClient, core::CoreUserInfoClaims};
use serde::{Deserialize, Serialize};
use tracing::trace;
/// Extra discovery-document field needed for the device authorization flow.
///
/// The field is not optional, so deserializing provider metadata that lacks
/// `device_authorization_endpoint` fails.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DeviceEndpointProviderMetadata {
    /// URL of the provider's device authorization endpoint.
    device_authorization_endpoint: openidconnect::DeviceAuthorizationUrl,
}

// Marker impl so this type can be plugged into `openidconnect::ProviderMetadata`.
impl openidconnect::AdditionalProviderMetadata for DeviceEndpointProviderMetadata {}
/// Provider metadata whose additional fields are `DeviceEndpointProviderMetadata`
/// (i.e. it also carries `device_authorization_endpoint`).
///
/// The remaining type parameters are all `openidconnect::core` types —
/// presumably mirroring `openidconnect::core::CoreProviderMetadata`; verify if
/// the crate version changes.
type DeviceProviderMetadata = openidconnect::ProviderMetadata<
DeviceEndpointProviderMetadata,
openidconnect::core::CoreAuthDisplay,
openidconnect::core::CoreClientAuthMethod,
openidconnect::core::CoreClaimName,
openidconnect::core::CoreClaimType,
openidconnect::core::CoreGrantType,
openidconnect::core::CoreJweContentEncryptionAlgorithm,
openidconnect::core::CoreJweKeyManagementAlgorithm,
openidconnect::core::CoreJsonWebKey,
openidconnect::core::CoreResponseMode,
openidconnect::core::CoreResponseType,
openidconnect::core::CoreSubjectIdentifierType,
>;
/// `CoreClient` with the endpoint type-states produced by
/// `CoreClient::from_provider_metadata`, as stored in `Oidc`.
///
/// NOTE(review): per openidconnect's `Client` docs the parameters are, in order:
/// auth URL (set), device-authorization URL (not set; `Oidc::authenticate` sets
/// it on a clone), introspection URL (not set), revocation URL (not set), token
/// URL (maybe set) and userinfo URL (maybe set), the last two depending on what
/// discovery returned. Confirm against the crate version in use.
type OidcClient = openidconnect::core::CoreClient<
openidconnect::EndpointSet,
openidconnect::EndpointNotSet,
openidconnect::EndpointNotSet,
openidconnect::EndpointNotSet,
openidconnect::EndpointMaybeSet,
openidconnect::EndpointMaybeSet,
>;
/// Opaque access token string.
///
/// `Debug` output is redacted so the secret never ends up in logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct AccessToken(pub(crate) String);

impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "********";
        let mut tuple = f.debug_tuple("AccessToken");
        tuple.field(&REDACTED);
        tuple.finish()
    }
}

impl AsRef<str> for AccessToken {
    /// Borrows the raw token string.
    fn as_ref(&self) -> &str {
        &self.0
    }
}
/// Opaque refresh token string, accepted by `Oidc::refresh`.
///
/// `Debug` output is redacted so the secret never ends up in logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct RefreshToken(String);

impl std::fmt::Debug for RefreshToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "********";
        let mut tuple = f.debug_tuple("RefreshToken");
        tuple.field(&REDACTED);
        tuple.finish()
    }
}
/// Subset of the provider's UserInfo claims, as returned by `Oidc::user_info`.
///
/// Every field is `None` when the provider omitted the corresponding claim.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct UserInfo {
    /// The `preferred_username` claim.
    pub username: Option<String>,
    /// The untagged `given_name` claim.
    pub given_name: Option<String>,
    /// The untagged `family_name` claim.
    pub family_name: Option<String>,
    /// The `email` claim.
    pub email: Option<String>,
}
/// Tokens obtained from the provider, plus enough bookkeeping to tell how long
/// the access token remains valid.
#[derive(Debug, Clone)]
pub struct Token {
    /// Access token returned by the token endpoint.
    pub access_token: AccessToken,
    /// Refresh token for `Oidc::refresh`, if one is available.
    pub refresh_token: Option<RefreshToken>,
    /// Lifetime reported by the token response (`expires_in`), if any.
    lifespan: Option<std::time::Duration>,
    /// Local instant from which `lifespan` is counted.
    created_at: std::time::Instant,
}

impl Token {
    /// Returns how much longer the access token is valid, or `None` if the
    /// provider did not report a lifetime.
    ///
    /// Once the lifetime has elapsed this returns `Some(Duration::ZERO)`.
    /// Plain `Duration` subtraction would panic on underflow here.
    pub fn expires_in(&self) -> Option<std::time::Duration> {
        self.lifespan
            .map(|lifespan| lifespan.saturating_sub(self.created_at.elapsed()))
    }
}
/// OpenID Connect client bound to one provider.
///
/// Supports the device authorization flow, token refresh and UserInfo lookups.
#[derive(Clone, Debug)]
pub struct Oidc {
    /// Transport used for every request to the provider.
    http_client: HttpClient,
    /// Client built from the discovered provider metadata.
    oidc_client: OidcClient,
    /// Discovered metadata; also supplies the device authorization endpoint.
    provider_metadata: DeviceProviderMetadata,
}
#[derive(Debug, Clone)]
pub struct IssuerUrl {
inner: openidconnect::IssuerUrl,
}
impl std::str::FromStr for IssuerUrl {
type Err = error::UrlParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self {
inner: openidconnect::IssuerUrl::from_url(
s.parse()
.map_err(|e| error::UrlParseError(format!("{e}")))?,
),
})
}
}
/// Information the user needs to approve a device login.
///
/// Passed to the handler given to `Oidc::authenticate`.
#[derive(Clone, Debug, Serialize)]
pub struct VerificationDetails {
    /// Code the user enters at the verification page.
    pub user_code: String,
    /// Page where the user enters `user_code`.
    pub verification_uri: String,
    /// Verification page with the user code pre-filled, if the provider offers one.
    pub verification_uri_complete: Option<String>,
}
impl Oidc {
    /// Discovers the OpenID provider at `issuer_url` and builds a client for
    /// `client_id`.
    ///
    /// Discovery fails if the provider metadata lacks
    /// `device_authorization_endpoint`, because that field is required by
    /// `DeviceEndpointProviderMetadata`.
    ///
    /// # Errors
    ///
    /// Returns [`error::ConfigError`] if the HTTP client cannot be built or
    /// discovery fails.
    pub async fn new(client_id: &str, issuer_url: IssuerUrl) -> Result<Self, error::ConfigError> {
        let http_client = HttpClient::new()?;
        let provider_metadata =
            DeviceProviderMetadata::discover_async(issuer_url.inner, &http_client).await?;
        trace!(?provider_metadata, "provider metadata");
        let oidc_client = openidconnect::core::CoreClient::from_provider_metadata(
            provider_metadata.clone(),
            openidconnect::ClientId::new(client_id.into()),
            // No client secret is configured.
            None,
        );
        Ok(Self {
            http_client,
            oidc_client,
            provider_metadata,
        })
    }

    /// Fetches the UserInfo claims for `access_token` and maps them onto
    /// [`UserInfo`].
    ///
    /// Localized claims (`given_name`, `family_name`) use the value stored
    /// without a language tag.
    ///
    /// # Errors
    ///
    /// Returns [`error::UserInfoError`] if the user-info request cannot be
    /// built (e.g. no userinfo endpoint was discovered) or the request fails.
    pub async fn user_info(
        &self,
        access_token: &AccessToken,
    ) -> Result<UserInfo, error::UserInfoError> {
        /// Returns the untagged value of a localized claim, if present.
        fn localized_claim<T: core::ops::Deref<Target = String>>(
            name: Option<&openidconnect::LocalizedClaim<T>>,
        ) -> Option<String> {
            name.and_then(|claim| claim.get(None))
                .map(|value| value.deref().clone())
        }

        let userinfo: CoreUserInfoClaims = self
            .oidc_client
            .user_info(openidconnect::AccessToken::new(access_token.0.clone()), None)?
            .request_async(&self.http_client)
            .await?;
        trace!(?userinfo, "user info");
        Ok(UserInfo {
            given_name: localized_claim(userinfo.given_name()),
            family_name: localized_claim(userinfo.family_name()),
            username: userinfo
                .preferred_username()
                .map(|username| username.to_string()),
            email: userinfo.email().map(|email| email.to_string()),
        })
    }

    /// Runs the OAuth 2.0 device authorization flow.
    ///
    /// The flow has three steps:
    /// 1. Request a device code, sending a PKCE (SHA-256) challenge.
    /// 2. Hand the user code and verification URI(s) to `auth_handler` so they
    ///    can be shown to the user.
    /// 3. Poll the token endpoint until the user completes the flow.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::Handler` if `auth_handler` fails. Other
    /// [`error::AuthError`] variants cover request and configuration failures.
    pub async fn authenticate<E>(
        &self,
        auth_handler: impl FnOnce(VerificationDetails) -> Result<(), E>,
    ) -> Result<Token, error::AuthError<E>> {
        // The base client has no device endpoint; set it on a clone so `self`
        // keeps its original type-state.
        let client = self
            .oidc_client
            .clone()
            .set_device_authorization_url(
                self.provider_metadata
                    .additional_metadata()
                    .device_authorization_endpoint
                    .clone(),
            )
            // Send `client_id` in the form body rather than as HTTP Basic auth.
            .set_auth_type(openidconnect::AuthType::RequestBody);

        // NOTE(review): PKCE is not part of the device flow spec (RFC 8628); the
        // challenge/verifier are sent as extra parameters, presumably for
        // providers that demand PKCE on every grant. Confirm against the provider.
        let (pkce_challenge, pkce_verifier) = openidconnect::PkceCodeChallenge::new_random_sha256();
        let details: openidconnect::core::CoreDeviceAuthorizationResponse = client
            .exchange_device_code()
            .add_extra_param("code_challenge", pkce_challenge.as_str())
            .add_extra_param("code_challenge_method", pkce_challenge.method().as_str())
            .request_async(&self.http_client)
            .await?;

        auth_handler(VerificationDetails {
            user_code: details.user_code().secret().to_string(),
            verification_uri: details.verification_uri().to_string(),
            verification_uri_complete: details
                .verification_uri_complete()
                .map(|uri| uri.secret().to_string()),
        })
        .map_err(error::AuthError::Handler)?;

        // Poll with `tokio::time::sleep` between attempts. `None` passes no
        // explicit timeout; per the oauth2 crate docs, polling is then bounded
        // by the device code's `expires_in` (confirm for the crate version).
        let response = client
            .exchange_device_access_token(&details)?
            .add_extra_param("code_verifier", pkce_verifier.secret())
            .request_async(&self.http_client, tokio::time::sleep, None)
            .await?;
        Ok(Self::token_from_response(&response))
    }

    /// Exchanges `refresh_token` for a fresh [`Token`].
    ///
    /// If the response carries no new refresh token (the provider did not
    /// rotate it), the returned token keeps `refresh_token`. RFC 6749 §6 only
    /// requires discarding the old refresh token when a new one is issued, so
    /// dropping it would needlessly end the session.
    ///
    /// # Errors
    ///
    /// Returns [`error::RefreshTokenError`] if the request cannot be built
    /// (e.g. no token endpoint) or fails.
    pub async fn refresh(
        &self,
        refresh_token: &RefreshToken,
    ) -> Result<Token, error::RefreshTokenError> {
        let response = self
            .oidc_client
            .exchange_refresh_token(&openidconnect::RefreshToken::new(refresh_token.0.clone()))?
            .request_async(&self.http_client)
            .await?;
        let mut token = Self::token_from_response(&response);
        if token.refresh_token.is_none() {
            token.refresh_token = Some(refresh_token.clone());
        }
        Ok(token)
    }

    /// Converts a token-endpoint response into a [`Token`] whose lifetime is
    /// counted from the moment this is called.
    fn token_from_response(response: &openidconnect::core::CoreTokenResponse) -> Token {
        Token {
            access_token: AccessToken(response.access_token().secret().to_string()),
            refresh_token: response
                .refresh_token()
                .map(|t| RefreshToken(t.secret().to_string())),
            lifespan: response.expires_in(),
            created_at: std::time::Instant::now(),
        }
    }
}
/// `reqwest`-backed transport used for every request to the provider.
#[derive(Clone, Debug)]
struct HttpClient {
    client: reqwest::Client,
}

impl HttpClient {
    /// Builds a rustls-backed `reqwest` client that never follows redirects.
    ///
    /// Redirects are disabled (`Policy::none()`); the openidconnect docs
    /// recommend this to avoid SSRF.
    fn new() -> Result<Self, reqwest::Error> {
        reqwest::Client::builder()
            .use_rustls_tls()
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .map(|client| Self { client })
    }
}
impl<'c> AsyncHttpClient<'c> for HttpClient {
    type Error = reqwest::Error;
    type Future = std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<openidconnect::HttpResponse, Self::Error>>
                + 'c
                + Send,
        >,
    >;

    /// Sends `request` through `reqwest` and converts the reply into an
    /// [`openidconnect::HttpResponse`].
    ///
    /// Status, HTTP version and headers are copied over unchanged.
    fn call(&'c self, request: openidconnect::HttpRequest) -> Self::Future {
        Box::pin(async move {
            let outgoing: reqwest::Request = request.try_into()?;
            let reply = self.client.execute(outgoing).await?;

            // Copy the metadata first: `bytes()` consumes `reply`.
            let status = reply.status();
            let version = reply.version();
            let headers = reply.headers().clone();
            let body = reply.bytes().await?.to_vec();

            let mut converted = openidconnect::HttpResponse::new(body);
            *converted.status_mut() = status;
            *converted.version_mut() = version;
            *converted.headers_mut() = headers;
            Ok(converted)
        })
    }
}