use anyhow::anyhow;
use openidconnect::core::CoreUserInfoClaims;
use openidconnect::OAuth2TokenResponse;
use serde::{Deserialize, Serialize};
use tracing::trace;
/// Additional provider metadata carrying the device authorization endpoint.
///
/// Used as the additional-metadata parameter of `DeviceProviderMetadata` so that discovery
/// also deserializes `device_authorization_endpoint`, which `Oidc::authenticate` needs for
/// the device authorization grant.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct DeviceEndpointProviderMetadata {
/// Device authorization endpoint from the discovery document. This is not an `Option`, so
/// deserialization (and therefore discovery with this type) should fail if the provider
/// omits it.
device_authorization_endpoint: openidconnect::DeviceAuthorizationUrl,
}
/// Lets this struct be used as the additional-metadata type of
/// `openidconnect::ProviderMetadata`; no trait methods are implemented here.
impl openidconnect::AdditionalProviderMetadata for DeviceEndpointProviderMetadata {}
/// Provider metadata whose additional-metadata slot is `DeviceEndpointProviderMetadata`.
///
/// Every other type parameter is the corresponding `openidconnect::core` type. The
/// discovered metadata can therefore be handed to `CoreClient::from_provider_metadata` (as
/// `Oidc::authenticate` does) while still exposing the device authorization endpoint.
type DeviceProviderMetadata = openidconnect::ProviderMetadata<
DeviceEndpointProviderMetadata,
openidconnect::core::CoreAuthDisplay,
openidconnect::core::CoreClientAuthMethod,
openidconnect::core::CoreClaimName,
openidconnect::core::CoreClaimType,
openidconnect::core::CoreGrantType,
openidconnect::core::CoreJweContentEncryptionAlgorithm,
openidconnect::core::CoreJweKeyManagementAlgorithm,
openidconnect::core::CoreJwsSigningAlgorithm,
openidconnect::core::CoreJsonWebKeyType,
openidconnect::core::CoreJsonWebKeyUse,
openidconnect::core::CoreJsonWebKey,
openidconnect::core::CoreResponseMode,
openidconnect::core::CoreResponseType,
openidconnect::core::CoreSubjectIdentifierType,
>;
/// An OAuth 2.0 access token issued by the provider.
///
/// `Debug` is implemented by hand so that the secret never ends up in `tracing` output,
/// `{:?}` formatting or panic messages. This matches the redacting `Debug` of the
/// `openidconnect` secret types. Read the inner value only where the raw token is actually
/// required.
#[derive(Clone, Serialize, Deserialize)]
pub struct AccessToken(pub(crate) String);

impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit the token value.
        f.write_str("AccessToken([redacted])")
    }
}
/// An OAuth 2.0 refresh token, exchangeable for a new `Token` via `Oidc::refresh`.
///
/// The inner string is private to this module and the type does not derive `Debug`, so it
/// cannot be `{:?}`-formatted. Note that it does derive `Serialize`, so serializing it does
/// expose the secret.
#[derive(Clone, Serialize, Deserialize)]
pub struct RefreshToken(String);
/// Profile claims about a user, as returned by `Oidc::user_info`.
///
/// Each field is `None` when the provider did not return the corresponding claim. For the
/// localized name claims, it is also `None` when only language-tagged values were returned.
#[derive(Clone, Serialize, Deserialize)]
pub struct UserInfo {
/// The `preferred_username` claim.
pub username: Option<String>,
/// The `given_name` claim (the value without a language tag).
pub given_name: Option<String>,
/// The `family_name` claim (the value without a language tag).
pub family_name: Option<String>,
/// The `email` claim.
pub email: Option<String>,
}
/// Tokens obtained from the provider's token endpoint, plus the bookkeeping needed to
/// estimate how long the access token stays valid.
#[derive(Clone)]
pub struct Token {
    /// Access token returned by the token endpoint.
    pub access_token: AccessToken,
    /// Refresh token, if the provider returned one.
    pub refresh_token: Option<RefreshToken>,
    /// Monotonic timestamp taken when this value was built from a token response.
    created_at: std::time::Instant,
    /// Lifetime reported by the provider in the token response, if any.
    lifespan: Option<std::time::Duration>,
}
impl Token {
    /// Returns how much longer the access token is expected to remain valid.
    ///
    /// The countdown starts when this `Token` was created from the token response.
    /// Returns `None` when the provider reported no lifespan. Returns
    /// `Some(Duration::ZERO)` once the lifespan has fully elapsed.
    pub fn expires_in(&self) -> Option<std::time::Duration> {
        // `Duration - Duration` panics when the result would be negative, which is exactly
        // what happens when an already-expired token is queried. Saturate at zero instead.
        self.lifespan
            .map(|lifespan| lifespan.saturating_sub(self.created_at.elapsed()))
    }
}
/// Handle on an OpenID Connect provider for a single client that uses no client secret.
pub struct Oidc {
    /// Issuer URL used for provider discovery at the start of each operation.
    issuer_url: openidconnect::IssuerUrl,
    /// Client identifier presented to the provider.
    client_id: openidconnect::ClientId,
}
/// Information the user needs in order to approve a device authorization request.
///
/// Built by `Oidc::authenticate` from the provider's device authorization response and
/// passed to the caller's `AuthHandler`.
#[derive(Clone, Serialize)]
pub struct VerificationDetails {
/// The `user_code` from the device authorization response (the code the user enters,
/// per RFC 8628).
pub user_code: String,
/// The `verification_uri` from the device authorization response.
pub verification_uri: String,
/// The `verification_uri_complete` from the response, if the provider supplied one.
pub verification_uri_complete: Option<String>,
}
/// Callback invoked at most once by `Oidc::authenticate`, with the verification details,
/// before the token endpoint is polled. Returning an error aborts authentication.
type AuthHandler = dyn FnOnce(&VerificationDetails) -> anyhow::Result<()> + Send;
impl Oidc {
    /// Creates a handle for the OpenID Connect provider at `issuer_url`, acting as the
    /// client `client_id` (no client secret is ever sent).
    ///
    /// Nothing is fetched here; provider discovery happens inside each operation.
    ///
    /// # Errors
    ///
    /// Returns an error if `issuer_url` is not a valid URL.
    pub fn new(client_id: &str, issuer_url: &str) -> anyhow::Result<Self> {
        Ok(Self {
            client_id: openidconnect::ClientId::new(client_id.into()),
            issuer_url: openidconnect::IssuerUrl::new(issuer_url.into())?,
        })
    }

    /// Discovers the provider metadata and builds a client without a client secret.
    ///
    /// Discovery runs on every call; the metadata is not cached.
    async fn client(&self) -> anyhow::Result<openidconnect::core::CoreClient> {
        let provider_metadata = openidconnect::core::CoreProviderMetadata::discover_async(
            self.issuer_url.clone(),
            http_client,
        )
        .await?;
        trace!(?provider_metadata, "provider metadata");
        Ok(openidconnect::core::CoreClient::from_provider_metadata(
            provider_metadata,
            self.client_id.clone(),
            None,
        ))
    }

    /// Requests the profile claims of the user that `access_token` was issued to.
    ///
    /// Localized claims (`given_name`, `family_name`) are taken in their form without a
    /// language tag.
    ///
    /// # Errors
    ///
    /// Fails if discovery fails, the provider exposes no userinfo endpoint, or the
    /// userinfo request fails.
    pub async fn user_info(&self, access_token: &AccessToken) -> anyhow::Result<UserInfo> {
        let client = self.client().await?;
        let userinfo: CoreUserInfoClaims = client
            .user_info(openidconnect::AccessToken::new(access_token.0.clone()), None)
            .map_err(|err| anyhow!("No user info endpoint: {:?}", err))?
            .request_async(http_client)
            .await
            .map_err(|err| anyhow!("Failed requesting user info: {:?}", err))?;
        trace!(?userinfo, "user info");
        // `get(None)` selects the value of a localized claim that carries no language tag.
        macro_rules! localized_claim {
            ($input:expr) => {
                $input
                    .and_then(|claim| claim.get(None))
                    .map(|value| value.to_string())
            };
        }
        Ok(UserInfo {
            given_name: localized_claim!(userinfo.given_name()),
            family_name: localized_claim!(userinfo.family_name()),
            username: userinfo
                .preferred_username()
                .map(|username| username.to_string()),
            email: userinfo.email().map(|email| email.to_string()),
        })
    }

    /// Signs the user in with the OAuth 2.0 device authorization grant, using PKCE.
    ///
    /// The device authorization endpoint is read from the provider's discovery document.
    /// `auth_handler` is called once with the details the user needs. Afterwards the token
    /// endpoint is polled, sleeping with `tokio::time::sleep` between attempts, until the
    /// grant completes or fails. No extra timeout is passed to the poller.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - discovery fails, including when the discovery document lacks a device
    ///   authorization endpoint;
    /// - the device authorization request or the token polling fails;
    /// - `auth_handler` returns an error.
    pub async fn authenticate(&self, auth_handler: Box<AuthHandler>) -> anyhow::Result<Token> {
        let provider_metadata =
            DeviceProviderMetadata::discover_async(self.issuer_url.clone(), http_client).await?;
        trace!(?provider_metadata, "provider metadata");
        let device_authorization_uri = provider_metadata
            .additional_metadata()
            .device_authorization_endpoint
            .clone();
        let client = openidconnect::core::CoreClient::from_provider_metadata(
            provider_metadata,
            self.client_id.clone(),
            None,
        )
        .set_device_authorization_uri(device_authorization_uri)
        // Send the client id in the request body rather than via HTTP Basic authentication.
        .set_auth_type(openidconnect::AuthType::RequestBody);
        // The PKCE challenge goes with the device code request, and the matching verifier
        // goes with the token request. Both are attached as raw extra parameters.
        let (pkce_challenge, pkce_verifier) = openidconnect::PkceCodeChallenge::new_random_sha256();
        let details: openidconnect::core::CoreDeviceAuthorizationResponse = client
            .exchange_device_code()?
            .add_extra_param("code_challenge", pkce_challenge.as_str())
            .add_extra_param("code_challenge_method", pkce_challenge.method().as_str())
            .request_async(http_client)
            .await?;
        auth_handler(&VerificationDetails {
            user_code: details.user_code().secret().to_string(),
            verification_uri: details.verification_uri().to_string(),
            verification_uri_complete: details
                .verification_uri_complete()
                .map(|uri| uri.secret().to_string()),
        })?;
        let token = client
            .exchange_device_access_token(&details)
            .add_extra_param("code_verifier", pkce_verifier.secret())
            .request_async(http_client, tokio::time::sleep, None)
            .await?;
        Ok(Self::token_from_response(&token))
    }

    /// Exchanges `refresh_token` for a fresh `Token`.
    ///
    /// If the provider does not return a new refresh token, the returned `Token` has
    /// `refresh_token: None`. Callers that want to keep refreshing must then retain the
    /// token they passed in.
    ///
    /// # Errors
    ///
    /// Fails if discovery or the token request fails.
    pub async fn refresh(&self, refresh_token: &RefreshToken) -> anyhow::Result<Token> {
        let client = self.client().await?;
        let token = client
            .exchange_refresh_token(&openidconnect::RefreshToken::new(refresh_token.0.clone()))
            .request_async(http_client)
            .await?;
        Ok(Self::token_from_response(&token))
    }

    /// Converts a token endpoint response into a `Token`.
    ///
    /// The current instant is recorded as the creation time, so `Token::expires_in` counts
    /// down from the moment the response was processed.
    fn token_from_response(response: &openidconnect::core::CoreTokenResponse) -> Token {
        Token {
            access_token: AccessToken(response.access_token().secret().to_string()),
            refresh_token: response
                .refresh_token()
                .map(|t| RefreshToken(t.secret().to_string())),
            lifespan: response.expires_in(),
            created_at: std::time::Instant::now(),
        }
    }
}
/// Bridges `openidconnect`'s HTTP abstraction onto `reqwest`.
///
/// Each call does the following:
/// - builds a new `reqwest::Client` that uses rustls and never follows redirects;
/// - copies the method, URL, body and headers of `request` onto it;
/// - converts the response status, headers and body back into an
///   `openidconnect::HttpResponse`.
///
/// # Errors
///
/// Propagates the `reqwest::Error` from building the client, building or sending the
/// request, or reading the response body.
///
/// # Panics
///
/// Panics (via `expect`) if the method, the status code, or a header name/value fails to
/// convert between the two crates' HTTP types.
async fn http_client(
    request: openidconnect::HttpRequest,
) -> Result<openidconnect::HttpResponse, reqwest::Error> {
    let reqwest_client = reqwest::Client::builder()
        .use_rustls_tls()
        // Redirects are not followed; a 3xx response is handed back to the caller unchanged.
        .redirect(reqwest::redirect::Policy::none())
        .build()?;

    let method =
        reqwest::Method::from_bytes(request.method.as_str().as_bytes()).expect("valid method");
    let builder = reqwest_client
        .request(method, request.url.as_str())
        .body(request.body);
    let builder = request
        .headers
        .iter()
        .fold(builder, |builder, (name, value)| {
            builder.header(name.as_str(), value.as_bytes())
        });
    let response = reqwest_client.execute(builder.build()?).await?;

    let status_code = openidconnect::http::StatusCode::from_u16(response.status().as_u16())
        .expect("valid status code");
    let headers = response
        .headers()
        .iter()
        .map(|(name, value)| {
            let converted_name = name.as_str().as_bytes().try_into().expect("valid header name");
            let converted_value = value.as_bytes().try_into().expect("valid header value");
            (converted_name, converted_value)
        })
        .collect();
    let body = response.bytes().await?.to_vec();

    Ok(openidconnect::HttpResponse {
        status_code,
        headers,
        body,
    })
}