use crate::error::{AuthError, CallbackError, RefreshError};
use crate::jwks::{JwksValidator, JwksValidatorStorage, RemoteJwksValidator};
use crate::oidc::OpenIdConfiguration;
/// How ID-token signatures are checked once OIDC is enabled on a client.
///
/// With `Enabled`, the stored validator's `validate` is called on the raw
/// JWT. With `Disabled`, the signature check is skipped. In both cases the
/// standard-claim checks performed by this module still run.
#[non_exhaustive]
pub enum OidcJwksConfig {
    /// Verify ID-token signatures with the stored JWKS validator.
    Enabled(JwksValidatorStorage),
    /// Skip ID-token signature verification.
    Disabled,
}
use crate::pages::{
ErrorPageRenderer, ErrorRendererStorage, SuccessPageRenderer, SuccessRendererStorage,
};
use crate::scope::{OAuth2Scope, RequestScope};
use crate::server::{
CallbackResult, HttpTransport, PortConfig, RenderedHtml, ServerState, Transport, bind_listener,
};
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc, oneshot};
/// Borrowed inputs for the standard query parameters of an authorization request.
///
/// Serialized onto the authorization endpoint URL by `AuthUrlParams::append_to`.
struct AuthUrlParams<'a> {
    /// OAuth2 client identifier, sent as `client_id`.
    client_id: &'a str,
    /// URL of the local callback server, sent as `redirect_uri`.
    redirect_uri: &'a url::Url,
    /// Anti-CSRF value, sent as `state` and compared against the callback.
    state_token: &'a str,
    /// PKCE challenge; only its challenge and method are sent, never the verifier.
    pkce: &'a crate::pkce::PkceChallenge,
    /// OIDC nonce, sent as `nonce` only when present.
    nonce: Option<&'a str>,
    /// Requested scopes, space-joined into `scope` when non-empty.
    scopes: &'a [OAuth2Scope],
}
impl AuthUrlParams<'_> {
    /// Every query key that `append_to` may emit.
    ///
    /// `ExtraAuthParams::apply_to` refuses to set any of these, so this list
    /// must stay in sync with `append_to`.
    const KEYS: &'static [&'static str] = &[
        "response_type",
        "client_id",
        "redirect_uri",
        "state",
        "code_challenge",
        "code_challenge_method",
        "nonce",
        "scope",
    ];

    /// Appends the authorization-request query pairs to `url`.
    ///
    /// These pairs are always written: `response_type=code`, client ID,
    /// redirect URI, state, and the PKCE challenge and method. `nonce` is
    /// written only when set. `scope` is written only when at least one scope
    /// is requested, joined with single spaces. All pairs are appended after
    /// any query already on `url`.
    fn append_to(&self, url: &mut url::Url) {
        let mut query = url.query_pairs_mut();
        query
            .append_pair("response_type", "code")
            .append_pair("client_id", self.client_id)
            .append_pair("redirect_uri", self.redirect_uri.as_str())
            .append_pair("state", self.state_token)
            .append_pair("code_challenge", &self.pkce.code_challenge)
            .append_pair("code_challenge_method", self.pkce.code_challenge_method);
        if let Some(nonce_value) = self.nonce {
            query.append_pair("nonce", nonce_value);
        }
        if !self.scopes.is_empty() {
            let rendered: Vec<String> = self.scopes.iter().map(ToString::to_string).collect();
            query.append_pair("scope", &rendered.join(" "));
        }
    }
}
/// Extra, caller-supplied query parameters for the authorization URL.
///
/// The `on_auth_url` hook fills this in. Pairs are written in insertion
/// order; any pair whose key collides with a standard authorization parameter
/// is dropped.
pub struct ExtraAuthParams {
    pairs: Vec<(String, String)>,
}
impl ExtraAuthParams {
    /// Creates an empty parameter set.
    const fn new() -> Self {
        Self { pairs: Vec::new() }
    }

    /// Queues `key=value` for the authorization URL and returns `self` for chaining.
    ///
    /// When the URL is built, a key reserved for the standard parameters is
    /// ignored and a warning is logged.
    pub fn append(&mut self, key: impl Into<String>, value: impl Into<String>) -> &mut Self {
        let entry = (key.into(), value.into());
        self.pairs.push(entry);
        self
    }

    /// Writes every non-reserved pair onto `url`, logging and skipping reserved ones.
    fn apply_to(self, url: &mut url::Url) {
        for (name, val) in self.pairs {
            let is_reserved = AuthUrlParams::KEYS.iter().any(|k| *k == name.as_str());
            if is_reserved {
                tracing::warn!(
                    key = name.as_str(),
                    "on_auth_url hook attempted to set a reserved parameter; ignoring"
                );
                continue;
            }
            url.query_pairs_mut().append_pair(&name, &val);
        }
    }
}
/// Hook that may add extra query parameters to the authorization URL.
type OnAuthUrlCallback = Box<dyn Fn(&mut ExtraAuthParams) + Send + Sync + 'static>;
/// Hook called with the final authorization URL, before the browser is opened or the URL is logged.
type OnUrlCallback = Box<dyn Fn(&url::Url) + Send + Sync + 'static>;
/// Hook called with the port the callback listener is bound to.
type OnServerReadyCallback = Box<dyn Fn(u16) + Send + Sync + 'static>;
/// OAuth2 client identifier issued by the authorization server.
#[derive(Debug, Clone)]
pub struct ClientId(String);
impl ClientId {
    /// Returns the identifier as a borrowed string slice.
    pub(crate) fn as_str(&self) -> &str {
        let Self(inner) = self;
        inner
    }
}
/// Default time (seconds) to wait for the browser callback; overridable via the builder's `timeout`.
const TIMEOUT_DURATION_IN_SECONDS: u64 = 300;
/// Connect timeout (seconds) for the client's HTTP requests, i.e. the token-endpoint calls in this file.
const HTTP_CONNECT_TIMEOUT_SECONDS: u64 = 10;
/// Overall per-request timeout (seconds) for the client's HTTP requests.
const HTTP_REQUEST_TIMEOUT_SECONDS: u64 = 30;
/// OAuth2 authorization-code + PKCE client for command-line tools.
///
/// Construct it with `CliTokenClient::builder` or
/// `CliTokenClientBuilder::from_open_id_configuration`. Run the interactive
/// flow with `run_authorization_flow` and renew tokens with `refresh` or
/// `refresh_if_expiring`.
pub struct CliTokenClient {
    /// OAuth2 client identifier.
    client_id: ClientId,
    /// Optional client secret, sent as `client_secret` to the token endpoint.
    client_secret: Option<String>,
    /// Authorization endpoint the browser is sent to.
    auth_url: url::Url,
    /// Token endpoint used for code exchange and refresh.
    token_url: url::Url,
    /// Expected ID-token issuer; `None` skips the issuer check.
    issuer: Option<url::Url>,
    /// Scopes requested on authorization and refresh.
    scopes: Vec<OAuth2Scope>,
    /// How the local callback listener picks its port.
    port_config: PortConfig,
    /// Static success page, used when no `success_renderer` is set.
    success_html: Option<String>,
    /// Static error page, used when no `error_renderer` is set.
    error_html: Option<String>,
    /// Custom success-page renderer; takes precedence over `success_html`.
    success_renderer: Option<SuccessRendererStorage>,
    /// Custom error-page renderer; takes precedence over `error_html`.
    error_renderer: Option<ErrorRendererStorage>,
    /// Whether to open the system browser; otherwise the URL is only logged.
    open_browser: bool,
    /// Maximum time to wait for the browser callback.
    timeout: std::time::Duration,
    /// Hook that adds extra authorization-URL parameters.
    on_auth_url: Option<OnAuthUrlCallback>,
    /// Hook called with the final authorization URL.
    on_url: Option<OnUrlCallback>,
    /// Hook called with the bound callback port.
    on_server_ready: Option<OnServerReadyCallback>,
    /// `Some` when OIDC is enabled. A nonce is then sent, and an ID token is
    /// required after the code exchange.
    oidc_jwks: Option<OidcJwksConfig>,
    /// Shared HTTP client for token-endpoint requests.
    http_client: reqwest::Client,
    /// Callback-server transport (HTTP or HTTPS).
    transport: Arc<dyn Transport>,
}
impl CliTokenClient {
    /// Starts a builder with no client ID or endpoints set and OIDC disabled.
    #[must_use]
    pub fn builder() -> CliTokenClientBuilder {
        CliTokenClientBuilder::default()
    }

    /// Runs the interactive authorization-code + PKCE flow end to end.
    ///
    /// Steps, in order:
    /// 1. Bind the local callback server.
    /// 2. Build the authorization URL with a fresh PKCE challenge and `state`,
    ///    plus a `nonce` when OIDC is configured.
    /// 3. Run the user hooks.
    /// 4. Open the browser, or log the URL when browser opening is off.
    /// 5. Wait for the callback and exchange the code for tokens.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] in any of these cases:
    /// * the listener cannot be bound, or the redirect URI is invalid;
    /// * the browser cannot be opened (the callback server is shut down first);
    /// * the callback times out or is cancelled, or fails the state/provider checks;
    /// * the token exchange or ID-token validation fails.
    pub async fn run_authorization_flow(&self) -> Result<crate::token::TokenSet, AuthError> {
        let listener = bind_listener(self.port_config)
            .await
            .map_err(AuthError::ServerBind)?;
        let redirect_uri_url = self
            .transport
            .redirect_uri(&listener)
            .map_err(AuthError::ServerBind)
            .and_then(|redirect_uri| {
                url::Url::parse(&redirect_uri).map_err(AuthError::InvalidUrl)
            })?;
        let pkce = crate::pkce::PkceChallenge::generate();
        let state_token = uuid::Uuid::new_v4().to_string();
        // A nonce is only meaningful when an ID token will be validated.
        let nonce = self
            .oidc_jwks
            .is_some()
            .then(|| uuid::Uuid::new_v4().to_string());
        let mut auth_url = self.auth_url.clone();
        AuthUrlParams {
            client_id: self.client_id.as_str(),
            redirect_uri: &redirect_uri_url,
            state_token: &state_token,
            pkce: &pkce,
            nonce: nonce.as_deref(),
            scopes: &self.scopes,
        }
        .append_to(&mut auth_url);
        if let Some(ref hook) = self.on_auth_url {
            let mut extras = ExtraAuthParams::new();
            hook(&mut extras);
            extras.apply_to(&mut auth_url);
        }
        // outer: server -> flow (callback result); inner: flow -> server (page to render).
        let (outer_tx, outer_rx) = mpsc::channel::<CallbackResult>(1);
        let (inner_tx, inner_rx) = mpsc::channel::<RenderedHtml>(1);
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>();
        let server_state = ServerState {
            outer_tx,
            inner_rx: Arc::new(Mutex::new(Some(inner_rx))),
            shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
        };
        let port = listener.local_addr().map_err(AuthError::ServerBind)?.port();
        let shutdown_arc = Arc::clone(&server_state.shutdown_tx);
        let transport = Arc::clone(&self.transport);
        tokio::spawn(async move {
            transport
                .run_server(listener, server_state, shutdown_rx)
                .await
        });
        if let Some(ref hook) = self.on_server_ready {
            hook(port);
        }
        if let Some(ref hook) = self.on_url {
            hook(&auth_url);
        }
        if self.open_browser {
            if let Err(e) = webbrowser::open(auth_url.as_str()) {
                // The spawned server owns the listener. The timeout and Ctrl-C
                // paths both signal it to shut down; do the same here so the
                // port is not left bound after we bail out.
                trigger_shutdown(&shutdown_arc).await;
                return Err(AuthError::Browser(e.to_string()));
            }
        } else {
            tracing::info!(url = auth_url.as_str(), "authorization URL");
        }
        handle_callback(
            self,
            &redirect_uri_url,
            &state_token,
            &pkce.code_verifier,
            nonce.as_deref(),
            inner_tx,
            outer_rx,
            shutdown_arc,
        )
        .await
    }

    /// Exchanges `refresh_token` for a new token set (RFC 6749 §6).
    ///
    /// If OIDC is configured and the response carries an ID token, that token
    /// is validated. The signature is checked according to the JWKS setting;
    /// the standard claims are always checked. No nonce is checked on
    /// refresh. If the server omits a new refresh token, the one passed in is
    /// kept.
    ///
    /// # Errors
    ///
    /// * `RefreshError::NoRefreshToken` if `refresh_token` is empty.
    /// * Otherwise, transport, token-endpoint, or ID-token validation failures.
    pub async fn refresh(
        &self,
        refresh_token: &str,
    ) -> Result<crate::token::TokenSet, RefreshError> {
        if refresh_token.is_empty() {
            return Err(RefreshError::NoRefreshToken);
        }
        let unvalidated = exchange_refresh_token(
            &self.http_client,
            &self.token_url,
            self.client_id.as_str(),
            self.client_secret.as_deref(),
            refresh_token,
            &self.scopes,
        )
        .await?;
        if let Some(oidc_jwks) = &self.oidc_jwks {
            validate_id_token_if_present(
                oidc_jwks,
                unvalidated,
                self.client_id.as_str(),
                self.issuer.as_ref().map_or(
                    crate::oidc::IssuerValidation::Skip,
                    crate::oidc::IssuerValidation::MustMatch,
                ),
            )
            .await
            .map_err(RefreshError::IdToken)
        } else {
            Ok(unvalidated.into_validated())
        }
    }

    /// Refreshes `tokens` only if they expire within `threshold`.
    ///
    /// # Errors
    ///
    /// * `RefreshError::NoRefreshToken` when a refresh is needed but `tokens`
    ///   has no refresh token.
    /// * Otherwise, any error from [`CliTokenClient::refresh`].
    pub async fn refresh_if_expiring(
        &self,
        tokens: &crate::token::TokenSet,
        threshold: std::time::Duration,
    ) -> Result<crate::token::RefreshOutcome, RefreshError> {
        if !tokens.expires_within(threshold) {
            return Ok(crate::token::RefreshOutcome::NotNeeded);
        }
        let refresh_token = tokens.refresh_token().ok_or(RefreshError::NoRefreshToken)?;
        let new_tokens = self.refresh(refresh_token.as_str()).await?;
        Ok(crate::token::RefreshOutcome::Refreshed(Box::new(
            new_tokens,
        )))
    }
}
/// JSON body of a successful token-endpoint response (RFC 6749 §5.1), plus the OIDC `id_token`.
///
/// NOTE(review): holds raw credentials; avoid adding `Debug` or logging it.
#[derive(serde::Deserialize)]
struct TokenResponse {
    /// Issued access token.
    access_token: String,
    /// Optional refresh token.
    refresh_token: Option<String>,
    /// Lifetime of the access token in seconds, if the server reports it.
    expires_in: Option<u64>,
    /// Token type; callers in this file default it to `"Bearer"` when absent.
    token_type: Option<String>,
    /// Raw OIDC ID token (a JWT), if issued.
    id_token: Option<String>,
    /// Space-delimited granted scopes, if the server reports them.
    scope: Option<String>,
}
/// Decodes the ID token, but only when the `openid` scope was requested.
///
/// Returns `Ok(None)` if `openid` is not in `scopes` (any `id_token` is then
/// ignored) or if no `id_token` was returned. Otherwise returns the result of
/// `Token::from_raw_jwt`. This decodes the token but does not validate it.
fn parse_oidc_if_requested(
    id_token: Option<&str>,
    scopes: &[crate::scope::OAuth2Scope],
) -> Result<Option<crate::oidc::Token>, crate::error::IdTokenError> {
    let openid_requested = scopes.contains(&crate::scope::OAuth2Scope::OpenId);
    match id_token {
        Some(raw_jwt) if openid_requested => crate::oidc::Token::from_raw_jwt(raw_jwt).map(Some),
        _ => Ok(None),
    }
}
/// Splits a space-delimited `scope` string (RFC 6749 §3.3) into scopes.
///
/// Any run of whitespace separates tokens, and empty input yields an empty list.
fn parse_scopes(scope_str: &str) -> Vec<OAuth2Scope> {
    let mut parsed = Vec::new();
    for token in scope_str.split_whitespace() {
        parsed.push(OAuth2Scope::from(token));
    }
    parsed
}
/// Fires the callback server's one-shot shutdown signal, at most once.
///
/// The sender is taken out of the shared slot, so later calls do nothing. A
/// failed send (the receiver is already gone) is ignored.
async fn trigger_shutdown(shutdown_arc: &Arc<Mutex<Option<oneshot::Sender<()>>>>) {
    let sender = shutdown_arc.lock().await.take();
    if let Some(tx) = sender {
        let _ = tx.send(());
    }
}
/// Validates the callback and, on failure, sends the rendered error page to the browser.
///
/// Returns the authorization code on success, or the `CallbackError` after
/// handing the error page to the server through `inner_tx`. A failure to
/// deliver that page is ignored.
async fn resolve_callback_code(
    callback_result: CallbackResult,
    state_token: &str,
    auth: &CliTokenClient,
    redirect_uri_url: &url::Url,
    inner_tx: &mpsc::Sender<RenderedHtml>,
) -> Result<String, CallbackError> {
    let outcome = validate_callback_code(callback_result, state_token);
    if let Err(failure) = &outcome {
        let page = render_error_html(&failure.clone().into(), auth, redirect_uri_url).await;
        let _ = inner_tx.send(RenderedHtml(page)).await;
    }
    outcome
}
/// Extracts the authorization code from a callback and enforces the `state` check.
///
/// The returned `state` is compared with `state_token` using `subtle`'s
/// constant-time equality rather than `==`.
///
/// # Errors
///
/// * `CallbackError::StateMismatch` when the states differ.
/// * `CallbackError::ProviderError` when the provider reported an error. The
///   description is empty if the provider sent none.
fn validate_callback_code(
    callback_result: CallbackResult,
    state_token: &str,
) -> Result<String, CallbackError> {
    use subtle::ConstantTimeEq as _;
    match callback_result {
        CallbackResult::ProviderError { error, description } => Err(CallbackError::ProviderError {
            error,
            description: description.unwrap_or_default(),
        }),
        CallbackResult::Success { code, state } => {
            let state_matches: bool = state.as_bytes().ct_eq(state_token.as_bytes()).into();
            if state_matches {
                Ok(code)
            } else {
                Err(CallbackError::StateMismatch)
            }
        }
    }
}
/// Validates a token set that must contain an ID token (used after the code exchange).
///
/// If JWKS validation is enabled, the raw JWT's signature is checked first.
/// The standard claims are then checked against `client_id`, `issuer` and
/// `expected_nonce`.
///
/// # Errors
///
/// * `IdTokenError::NoIdToken` if the token set has no ID token.
/// * `IdTokenError::JwksValidationFailed` if the signature check fails.
/// * Any error from the standard-claim validation.
async fn validate_id_token_required(
    oidc_jwks: &OidcJwksConfig,
    token_set: crate::token::TokenSet<crate::token::Unvalidated>,
    client_id: &str,
    issuer: crate::oidc::IssuerValidation<'_>,
    expected_nonce: Option<&str>,
) -> Result<crate::token::TokenSet<crate::token::Validated>, crate::error::IdTokenError> {
    use crate::error::IdTokenError;
    let Some(id_token) = token_set.oidc_token() else {
        return Err(IdTokenError::NoIdToken);
    };
    match oidc_jwks {
        OidcJwksConfig::Enabled(validator) => validator
            .validate(id_token.raw())
            .await
            .map_err(IdTokenError::JwksValidationFailed)?,
        OidcJwksConfig::Disabled => {}
    }
    id_token.validate_standard_claims(client_id, issuer, expected_nonce)?;
    Ok(token_set.into_validated())
}
async fn validate_id_token_if_present(
oidc_jwks: &OidcJwksConfig,
token_set: crate::token::TokenSet<crate::token::Unvalidated>,
client_id: &str,
issuer: crate::oidc::IssuerValidation<'_>,
) -> Result<crate::token::TokenSet<crate::token::Validated>, crate::error::IdTokenError> {
use crate::error::IdTokenError;
let Some(oidc) = token_set.oidc_token() else {
return Ok(token_set.into_validated());
};
if let OidcJwksConfig::Enabled(validator) = oidc_jwks {
validator
.validate(oidc.raw())
.await
.map_err(IdTokenError::JwksValidationFailed)?;
}
oidc.validate_standard_claims(client_id, issuer, None)?;
Ok(token_set.into_validated())
}
/// Waits for the browser callback, exchanges the code, validates it, and answers the browser.
///
/// Steps:
/// 1. Wait for a `CallbackResult` from the callback server. The wait is
///    bounded by `auth.timeout` and can be cancelled with Ctrl-C; both of
///    those paths signal the server to shut down before returning.
/// 2. Check `state` against `state_token` and surface provider errors.
/// 3. Exchange the code at `auth.token_url` using the PKCE `code_verifier`.
/// 4. When OIDC is configured, require an ID token and validate it,
///    including the `nonce`.
/// 5. Send the rendered success or error page to the server over `inner_tx`,
///    so the browser gets a response.
///
/// # Errors
///
/// * `AuthError::Timeout` or `AuthError::Cancelled`.
/// * `AuthError::Server("channel closed")` if the server drops its sender.
/// * Any state/provider, token-exchange, or ID-token error. For these, an
///   error page is sent to the browser first.
#[expect(
    clippy::too_many_arguments,
    reason = "private orchestrator function; all args are distinct concerns that cannot be bundled without noise"
)]
async fn handle_callback(
    auth: &CliTokenClient,
    redirect_uri_url: &url::Url,
    state_token: &str,
    code_verifier: &str,
    nonce: Option<&str>,
    inner_tx: mpsc::Sender<RenderedHtml>,
    mut outer_rx: mpsc::Receiver<CallbackResult>,
    shutdown_arc: Arc<Mutex<Option<oneshot::Sender<()>>>>,
) -> Result<crate::token::TokenSet<crate::token::Validated>, AuthError> {
    // Race the (time-bounded) callback against Ctrl-C.
    let callback_result = tokio::select! {
        result = tokio::time::timeout(auth.timeout, outer_rx.recv()) => {
            match result {
                Err(_) => {
                    // Timed out: stop the callback server before giving up.
                    trigger_shutdown(&shutdown_arc).await;
                    return Err(AuthError::Timeout);
                }
                // The server dropped its sender without delivering a callback.
                Ok(None) => return Err(AuthError::Server("channel closed".to_string())),
                Ok(Some(r)) => r,
            }
        }
        _ = tokio::signal::ctrl_c() => {
            trigger_shutdown(&shutdown_arc).await;
            return Err(AuthError::Cancelled);
        }
    };
    // On failure this has already sent an error page to the browser.
    let code = resolve_callback_code(
        callback_result,
        state_token,
        auth,
        redirect_uri_url,
        &inner_tx,
    )
    .await?;
    let token_set = match exchange_code(
        &auth.http_client,
        &auth.token_url,
        auth.client_id.as_str(),
        auth.client_secret.as_deref(),
        &code,
        redirect_uri_url.as_str(),
        code_verifier,
        &auth.scopes,
    )
    .await
    {
        Ok(ts) => ts,
        Err(e) => {
            let html = render_error_html(&e, auth, redirect_uri_url).await;
            // Best effort: a failed send just means the browser gets no page.
            let _ = inner_tx.send(RenderedHtml(html)).await;
            return Err(e);
        }
    };
    // With OIDC enabled, an ID token is mandatory and must match the nonce we sent.
    let token_set = if let Some(oidc_jwks) = &auth.oidc_jwks {
        match validate_id_token_required(
            oidc_jwks,
            token_set,
            auth.client_id.as_str(),
            auth.issuer.as_ref().map_or(
                crate::oidc::IssuerValidation::Skip,
                crate::oidc::IssuerValidation::MustMatch,
            ),
            nonce,
        )
        .await
        .map_err(AuthError::IdToken)
        {
            Ok(ts) => ts,
            Err(e) => {
                let html = render_error_html(&e, auth, redirect_uri_url).await;
                let _ = inner_tx.send(RenderedHtml(html)).await;
                return Err(e);
            }
        }
    } else {
        token_set.into_validated()
    };
    let html = render_success_html(
        &token_set,
        &auth.scopes,
        redirect_uri_url,
        auth.client_id.as_str(),
        auth.success_renderer.as_deref(),
        auth.success_html.as_deref(),
    )
    .await;
    let _ = inner_tx.send(RenderedHtml(html)).await;
    Ok(token_set)
}
/// Renders the browser error page for `err`.
///
/// Precedence:
/// 1. the custom `error_renderer`;
/// 2. the static `error_html`, returned verbatim;
/// 3. the crate's default error page.
async fn render_error_html(
    err: &AuthError,
    auth: &CliTokenClient,
    redirect_uri_url: &url::Url,
) -> String {
    let page_ctx = crate::pages::ErrorPageContext::new(
        err,
        &auth.scopes,
        redirect_uri_url,
        auth.client_id.as_str(),
    );
    match (auth.error_renderer.as_deref(), auth.error_html.as_deref()) {
        (Some(custom), _) => custom.render_error(&page_ctx).await,
        (None, Some(static_page)) => static_page.to_owned(),
        (None, None) => {
            crate::pages::DefaultErrorPageRenderer
                .render_error(&page_ctx)
                .await
        }
    }
}
/// Renders the browser success page for a validated token set.
///
/// The page context carries:
/// * the ID-token claims (if any) and the scopes;
/// * the redirect URI and the client ID;
/// * the expiry, and whether a refresh token was issued.
///
/// Precedence:
/// 1. `success_renderer`;
/// 2. the static `success_html`, returned verbatim;
/// 3. the crate's default success page.
async fn render_success_html(
    token_set: &crate::token::TokenSet,
    scopes: &[OAuth2Scope],
    redirect_uri_url: &url::Url,
    client_id: &str,
    success_renderer: Option<&(dyn crate::pages::SuccessPageRenderer + Send + Sync)>,
    success_html: Option<&str>,
) -> String {
    let claims = token_set.oidc().map(crate::oidc::Token::claims);
    let expires_at = token_set.expires_at();
    let has_refresh_token = token_set.refresh_token().is_some();
    let page_ctx = crate::pages::PageContext::new(
        claims,
        scopes,
        redirect_uri_url,
        client_id,
        expires_at,
        has_refresh_token,
    );
    match (success_renderer, success_html) {
        (Some(custom), _) => custom.render_success(&page_ctx).await,
        (None, Some(static_page)) => static_page.to_owned(),
        (None, None) => {
            crate::pages::DefaultSuccessPageRenderer
                .render_success(&page_ctx)
                .await
        }
    }
}
/// Exchanges an authorization `code` for tokens at `token_url` (RFC 6749 §4.1.3).
///
/// The request sends the PKCE `code_verifier`, the same `redirect_uri` used in
/// the authorization request, and `client_id`, plus `client_secret` when set.
///
/// Missing response fields fall back as follows:
/// * `scope` → the requested `scopes`;
/// * `token_type` → `"Bearer"`.
///
/// An `expires_in` too large to add to the current time yields
/// `expires_at = None` instead of panicking. The ID token is only decoded
/// when `openid` is among `scopes`; it is not validated here.
///
/// # Errors
///
/// * `AuthError::TokenExchange` with the status and body on a non-2xx response.
/// * `AuthError::Server` if a success body is not a valid token response.
/// * `AuthError::IdToken` if a requested ID token cannot be decoded.
/// * Transport errors from `reqwest`, via `?`.
#[expect(
    clippy::too_many_arguments,
    reason = "all arguments are distinct OAuth2 code exchange parameters; grouping them would obscure their individual meanings"
)]
async fn exchange_code(
    http_client: &reqwest::Client,
    token_url: &url::Url,
    client_id: &str,
    client_secret: Option<&str>,
    code: &str,
    redirect_uri: &str,
    code_verifier: &str,
    scopes: &[crate::scope::OAuth2Scope],
) -> Result<crate::token::TokenSet<crate::token::Unvalidated>, AuthError> {
    let mut params = vec![
        ("grant_type", "authorization_code"),
        ("code", code),
        ("redirect_uri", redirect_uri),
        ("client_id", client_id),
        ("code_verifier", code_verifier),
    ];
    if let Some(secret) = client_secret {
        params.push(("client_secret", secret));
    }
    // Take the timestamp before sending so the computed expiry errs on the early side.
    let t0 = std::time::SystemTime::now();
    let response = http_client
        .post(token_url.as_str())
        .header(reqwest::header::ACCEPT, "application/json")
        .form(&params)
        .send()
        .await?;
    if !response.status().is_success() {
        let status = response.status().as_u16();
        let body_bytes = response.bytes().await.unwrap_or_default();
        let body = String::from_utf8_lossy(&body_bytes).into_owned();
        return Err(AuthError::TokenExchange { status, body });
    }
    let body = response.text().await?;
    // NOTE(review): the raw body is echoed into the error. On a 2xx response it
    // may contain credentials — consider redacting before surfacing or logging.
    let token_response: TokenResponse =
        serde_json::from_str(&body).map_err(|e| AuthError::Server(format!("{e}: {body}")))?;
    // `expires_in` is server-controlled, and `SystemTime + Duration` panics on
    // overflow, so use checked arithmetic and treat an unrepresentable expiry as unknown.
    let expires_at = token_response
        .expires_in
        .and_then(|secs| t0.checked_add(std::time::Duration::from_secs(secs)));
    let oidc = parse_oidc_if_requested(token_response.id_token.as_deref(), scopes)
        .map_err(AuthError::IdToken)?;
    // RFC 6749 §5.1: an omitted `scope` means the requested scopes were granted.
    let resolved_scopes = token_response
        .scope
        .as_deref()
        .map_or_else(|| scopes.to_vec(), parse_scopes);
    Ok(crate::token::TokenSet::new(
        token_response.access_token,
        token_response.refresh_token,
        expires_at,
        token_response
            .token_type
            .unwrap_or_else(|| "Bearer".to_string()),
        oidc,
        resolved_scopes,
    ))
}
/// Performs the `refresh_token` grant against `token_url` (RFC 6749 §6).
///
/// The request sends the configured scopes (space-joined) when there are any,
/// and the client secret when set.
///
/// Missing response fields fall back as follows:
/// * `scope` → the requested `scopes`;
/// * `token_type` → `"Bearer"`;
/// * `refresh_token` → the token that was just used.
///
/// An `expires_in` too large to add to the current time yields
/// `expires_at = None` instead of panicking.
///
/// # Errors
///
/// * `RefreshError::TokenExchange` with the status and body on a non-2xx response.
/// * Transport or JSON-decoding errors, via `?`.
/// * `RefreshError::IdToken` if a requested ID token cannot be decoded.
async fn exchange_refresh_token(
    http_client: &reqwest::Client,
    token_url: &url::Url,
    client_id: &str,
    client_secret: Option<&str>,
    refresh_token: &str,
    scopes: &[crate::scope::OAuth2Scope],
) -> Result<crate::token::TokenSet<crate::token::Unvalidated>, RefreshError> {
    let scope_str = (!scopes.is_empty()).then(|| {
        scopes
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(" ")
    });
    let mut params = vec![
        ("grant_type", "refresh_token"),
        ("refresh_token", refresh_token),
        ("client_id", client_id),
    ];
    if let Some(secret) = client_secret {
        params.push(("client_secret", secret));
    }
    if let Some(ref s) = scope_str {
        params.push(("scope", s.as_str()));
    }
    // Take the timestamp before sending so the computed expiry errs on the early side.
    let t0 = std::time::SystemTime::now();
    let response = http_client
        .post(token_url.as_str())
        .header(reqwest::header::ACCEPT, "application/json")
        .form(&params)
        .send()
        .await?;
    if !response.status().is_success() {
        let status = response.status().as_u16();
        let body_bytes = response.bytes().await.unwrap_or_default();
        let body = String::from_utf8_lossy(&body_bytes).into_owned();
        return Err(RefreshError::TokenExchange { status, body });
    }
    let token_response: TokenResponse = response.json().await?;
    // `expires_in` is server-controlled, and `SystemTime + Duration` panics on overflow.
    let expires_at = token_response
        .expires_in
        .and_then(|secs| t0.checked_add(std::time::Duration::from_secs(secs)));
    let oidc = parse_oidc_if_requested(token_response.id_token.as_deref(), scopes)
        .map_err(RefreshError::IdToken)?;
    let resolved_scopes = token_response
        .scope
        .as_deref()
        .map_or_else(|| scopes.to_vec(), parse_scopes);
    // RFC 6749 §6: the server MAY omit a new refresh token; keep using the current one.
    let resolved_refresh_token = Some(
        token_response
            .refresh_token
            .unwrap_or_else(|| refresh_token.to_owned()),
    );
    Ok(crate::token::TokenSet::new(
        token_response.access_token,
        resolved_refresh_token,
        expires_at,
        token_response
            .token_type
            .unwrap_or_else(|| "Bearer".to_string()),
        oidc,
        resolved_scopes,
    ))
}
/// Transport marker: serve the local callback over plain HTTP (the builder default).
#[non_exhaustive]
pub struct Http;
/// Transport marker: serve the local callback over HTTPS.
///
/// With a certificate, the custom-certificate transport is used. With `None`,
/// `HttpsSelfSignedTransport` is used.
pub struct Https(Option<crate::tls::TlsCertificate>);
/// Converts a transport marker (`Http` / `Https`) into the callback-server transport.
///
/// Sealed: only the markers defined in this module can implement it.
pub trait IntoTransport: sealed::Sealed {
    /// Consumes the marker and returns the transport used to serve the callback.
    fn into_transport(self) -> Arc<dyn Transport>;
}
impl sealed::Sealed for Http {}
impl IntoTransport for Http {
    /// Plain-HTTP callback server.
    fn into_transport(self) -> Arc<dyn Transport> {
        Arc::new(HttpTransport)
    }
}
impl sealed::Sealed for Https {}
impl IntoTransport for Https {
    /// HTTPS callback server. Uses the supplied certificate's acceptor if one
    /// was given, otherwise the self-signed transport.
    fn into_transport(self) -> Arc<dyn Transport> {
        let Self(certificate) = self;
        if let Some(cert) = certificate {
            Arc::new(crate::server::HttpsCustomTransport {
                acceptor: cert.acceptor,
            })
        } else {
            Arc::new(crate::server::HttpsSelfSignedTransport)
        }
    }
}
/// Private supertrait module that keeps `IntoTransport` implementable only inside this crate.
mod sealed {
    /// Marker supertrait; implemented only for `Http` and `Https`.
    pub trait Sealed {}
}
/// Builder type state: no client ID set yet.
#[non_exhaustive]
pub struct NoClientId;
/// Builder type state: client ID set.
#[non_exhaustive]
pub struct HasClientId(ClientId);
/// Builder type state: no authorization endpoint set yet.
#[non_exhaustive]
pub struct NoAuthUrl;
/// Builder type state: authorization endpoint set.
#[non_exhaustive]
pub struct HasAuthUrl(url::Url);
/// Builder type state: no token endpoint set yet.
#[non_exhaustive]
pub struct NoTokenUrl;
/// Builder type state: token endpoint set.
#[non_exhaustive]
pub struct HasTokenUrl(url::Url);
/// Builder type state: OIDC not enabled (no ID-token validation, no nonce).
#[non_exhaustive]
pub struct NoOidc;
/// Builder type state: the `openid` scope has been requested.
///
/// A JWKS decision (a validator, or an explicit opt-out) is still required
/// before `build` is available.
#[non_exhaustive]
pub struct OidcPending;
/// Builder type state: ID-token signatures are verified with the stored validator.
pub struct JwksEnabled(JwksValidatorStorage);
/// Builder type state: ID-token signature verification explicitly disabled.
#[non_exhaustive]
pub struct JwksDisabled;
/// Builder settings that do not affect the type state.
///
/// Carried unchanged through every builder transition, then moved into
/// `CliTokenClient` by `build_client`.
struct BuilderConfig {
    /// Optional confidential-client secret.
    client_secret: Option<String>,
    /// Expected ID-token issuer; `None` skips the issuer check.
    issuer: Option<url::Url>,
    /// Requested scopes. A set, so duplicates collapse and iteration follows `Ord` order.
    scopes: std::collections::BTreeSet<OAuth2Scope>,
    /// Callback listener port policy.
    port_config: PortConfig,
    /// Static success page, used when no success renderer is set.
    success_html: Option<String>,
    /// Static error page, used when no error renderer is set.
    error_html: Option<String>,
    /// Custom success-page renderer; takes precedence over `success_html`.
    success_renderer: Option<SuccessRendererStorage>,
    /// Custom error-page renderer; takes precedence over `error_html`.
    error_renderer: Option<ErrorRendererStorage>,
    /// Whether to open the system browser automatically.
    open_browser: bool,
    /// Maximum time to wait for the browser callback.
    timeout: std::time::Duration,
    /// Hook that adds extra authorization-URL parameters.
    on_auth_url: Option<OnAuthUrlCallback>,
    /// Hook called with the final authorization URL.
    on_url: Option<OnUrlCallback>,
    /// Hook called with the bound callback port.
    on_server_ready: Option<OnServerReadyCallback>,
}
impl Default for BuilderConfig {
    /// Returns the default builder settings.
    ///
    /// No secret, issuer, scopes, custom pages or hooks are set. The port is
    /// random, the browser is opened automatically, and the callback timeout
    /// is `TIMEOUT_DURATION_IN_SECONDS`.
    fn default() -> Self {
        let callback_timeout = std::time::Duration::from_secs(TIMEOUT_DURATION_IN_SECONDS);
        Self {
            client_secret: None,
            issuer: None,
            scopes: std::collections::BTreeSet::default(),
            port_config: PortConfig::Random,
            success_html: None,
            error_html: None,
            success_renderer: None,
            error_renderer: None,
            open_browser: true,
            timeout: callback_timeout,
            on_auth_url: None,
            on_url: None,
            on_server_ready: None,
        }
    }
}
/// Type-state builder for [`CliTokenClient`].
///
/// The type parameters track which required pieces have been supplied:
/// * `C`: client ID;
/// * `A`: authorization URL;
/// * `T`: token URL;
/// * `O`: OIDC/JWKS state;
/// * `S`: callback transport.
///
/// `build` is only available once `C`, `A` and `T` are set and `O` is
/// `NoOidc`, `JwksEnabled` or `JwksDisabled`. It is not available while `O`
/// is `OidcPending`.
pub struct CliTokenClientBuilder<
    C = NoClientId,
    A = NoAuthUrl,
    T = NoTokenUrl,
    O = NoOidc,
    S = Http,
> {
    client_id: C,
    auth_url: A,
    token_url: T,
    oidc: O,
    scheme: S,
    config: BuilderConfig,
}
impl Default for CliTokenClientBuilder {
    /// An empty builder: no required values set, OIDC off, plain-HTTP callback.
    fn default() -> Self {
        let config = BuilderConfig::default();
        Self {
            client_id: NoClientId,
            auth_url: NoAuthUrl,
            token_url: NoTokenUrl,
            oidc: NoOidc,
            scheme: Http,
            config,
        }
    }
}
impl CliTokenClientBuilder {
    /// Starts a builder from an OIDC discovery document.
    ///
    /// Sets the authorization and token endpoints, sets the expected issuer,
    /// and requests the `openid` scope. A client ID and a JWKS decision are
    /// still required before `build`.
    #[must_use]
    pub fn from_open_id_configuration(
        open_id_configuration: &OpenIdConfiguration,
    ) -> CliTokenClientBuilder<NoClientId, HasAuthUrl, HasTokenUrl, OidcPending, Http> {
        let authorize_endpoint = open_id_configuration.authorization_endpoint().clone();
        let token_endpoint = open_id_configuration.token_endpoint().clone();
        let config = BuilderConfig {
            scopes: std::collections::BTreeSet::from([OAuth2Scope::OpenId]),
            issuer: Some(open_id_configuration.issuer().clone()),
            ..BuilderConfig::default()
        };
        CliTokenClientBuilder {
            client_id: NoClientId,
            auth_url: HasAuthUrl(authorize_endpoint),
            token_url: HasTokenUrl(token_endpoint),
            oidc: OidcPending,
            scheme: Http,
            config,
        }
    }
}
impl<C, A, T, O, S> CliTokenClientBuilder<C, A, T, O, S> {
#[must_use]
pub fn client_id(self, v: impl Into<String>) -> CliTokenClientBuilder<HasClientId, A, T, O, S> {
CliTokenClientBuilder {
client_id: HasClientId(ClientId(v.into())),
auth_url: self.auth_url,
token_url: self.token_url,
oidc: self.oidc,
scheme: self.scheme,
config: self.config,
}
}
#[must_use]
pub fn auth_url(self, v: url::Url) -> CliTokenClientBuilder<C, HasAuthUrl, T, O, S> {
CliTokenClientBuilder {
client_id: self.client_id,
auth_url: HasAuthUrl(v),
token_url: self.token_url,
oidc: self.oidc,
scheme: self.scheme,
config: self.config,
}
}
#[must_use]
pub fn token_url(self, v: url::Url) -> CliTokenClientBuilder<C, A, HasTokenUrl, O, S> {
CliTokenClientBuilder {
client_id: self.client_id,
auth_url: self.auth_url,
token_url: HasTokenUrl(v),
oidc: self.oidc,
scheme: self.scheme,
config: self.config,
}
}
#[must_use]
pub fn client_secret(mut self, v: impl Into<String>) -> Self {
self.config.client_secret = Some(v.into());
self
}
#[must_use]
pub fn add_scopes(mut self, v: impl IntoIterator<Item = RequestScope>) -> Self {
self.config
.scopes
.extend(v.into_iter().map(OAuth2Scope::from));
self
}
#[must_use]
pub const fn port_hint(mut self, v: u16) -> Self {
self.config.port_config = PortConfig::Hint(v);
self
}
#[must_use]
pub const fn require_port(mut self, v: u16) -> Self {
self.config.port_config = PortConfig::Required(v);
self
}
#[must_use]
pub fn success_html(mut self, v: impl Into<String>) -> Self {
self.config.success_html = Some(v.into());
self
}
#[must_use]
pub fn error_html(mut self, v: impl Into<String>) -> Self {
self.config.error_html = Some(v.into());
self
}
#[must_use]
pub fn success_renderer(mut self, r: impl SuccessPageRenderer + 'static) -> Self {
self.config.success_renderer = Some(Box::new(r));
self
}
#[must_use]
pub fn error_renderer(mut self, r: impl ErrorPageRenderer + 'static) -> Self {
self.config.error_renderer = Some(Box::new(r));
self
}
#[must_use]
pub const fn open_browser(mut self, v: bool) -> Self {
self.config.open_browser = v;
self
}
#[must_use]
pub const fn timeout(mut self, v: std::time::Duration) -> Self {
self.config.timeout = v;
self
}
#[must_use]
pub fn on_auth_url(mut self, f: impl Fn(&mut ExtraAuthParams) + Send + Sync + 'static) -> Self {
self.config.on_auth_url = Some(Box::new(f));
self
}
#[must_use]
pub fn on_url(mut self, f: impl Fn(&url::Url) + Send + Sync + 'static) -> Self {
self.config.on_url = Some(Box::new(f));
self
}
#[must_use]
pub fn on_server_ready(mut self, f: impl Fn(u16) + Send + Sync + 'static) -> Self {
self.config.on_server_ready = Some(Box::new(f));
self
}
#[must_use]
pub fn use_https(self) -> CliTokenClientBuilder<C, A, T, O, Https> {
CliTokenClientBuilder {
client_id: self.client_id,
auth_url: self.auth_url,
token_url: self.token_url,
oidc: self.oidc,
scheme: Https(None),
config: self.config,
}
}
#[must_use]
pub fn use_https_with(
self,
certificate: crate::tls::TlsCertificate,
) -> CliTokenClientBuilder<C, A, T, O, Https> {
CliTokenClientBuilder {
client_id: self.client_id,
auth_url: self.auth_url,
token_url: self.token_url,
oidc: self.oidc,
scheme: Https(Some(certificate)),
config: self.config,
}
}
}
impl<C, A, T, S> CliTokenClientBuilder<C, A, T, NoOidc, S> {
    /// Enables OIDC by requesting the `openid` scope.
    ///
    /// A JWKS decision (a validator, or `without_jwks_validation`) is then
    /// required before `build`.
    #[must_use]
    pub fn with_openid_scope(self) -> CliTokenClientBuilder<C, A, T, OidcPending, S> {
        let Self {
            client_id,
            auth_url,
            token_url,
            scheme,
            mut config,
            ..
        } = self;
        config.scopes.insert(OAuth2Scope::OpenId);
        CliTokenClientBuilder {
            client_id,
            auth_url,
            token_url,
            oidc: OidcPending,
            scheme,
            config,
        }
    }
}
impl<C, A, T, S> CliTokenClientBuilder<C, A, T, OidcPending, S> {
#[must_use]
pub fn issuer(mut self, v: url::Url) -> Self {
self.config.issuer = Some(v);
self
}
#[must_use]
pub fn jwks_validator(
self,
v: Box<dyn JwksValidator>,
) -> CliTokenClientBuilder<C, A, T, JwksEnabled, S> {
CliTokenClientBuilder {
client_id: self.client_id,
auth_url: self.auth_url,
token_url: self.token_url,
oidc: JwksEnabled(v),
scheme: self.scheme,
config: self.config,
}
}
#[must_use]
pub fn without_jwks_validation(self) -> CliTokenClientBuilder<C, A, T, JwksDisabled, S> {
CliTokenClientBuilder {
client_id: self.client_id,
auth_url: self.auth_url,
token_url: self.token_url,
oidc: JwksDisabled,
scheme: self.scheme,
config: self.config,
}
}
}
impl<A, T, S> CliTokenClientBuilder<HasClientId, A, T, OidcPending, S> {
    /// Verifies ID-token signatures against the provider's discovery document.
    ///
    /// Builds a `RemoteJwksValidator` from `open_id_configuration` and the
    /// client ID that has already been set.
    #[must_use]
    pub fn with_open_id_configuration_jwks_validator(
        self,
        open_id_configuration: &OpenIdConfiguration,
    ) -> CliTokenClientBuilder<HasClientId, A, T, JwksEnabled, S> {
        let Self {
            client_id,
            auth_url,
            token_url,
            scheme,
            config,
            ..
        } = self;
        let remote = RemoteJwksValidator::from_open_id_configuration(
            open_id_configuration,
            String::from(client_id.0.as_str()),
        );
        CliTokenClientBuilder {
            client_id,
            auth_url,
            token_url,
            oidc: JwksEnabled(Box::new(remote)),
            scheme,
            config,
        }
    }
}
impl<C, A, T, S> CliTokenClientBuilder<C, A, T, JwksEnabled, S> {
#[must_use]
pub fn issuer(mut self, v: url::Url) -> Self {
self.config.issuer = Some(v);
self
}
}
impl<C, A, T, S> CliTokenClientBuilder<C, A, T, JwksDisabled, S> {
#[must_use]
pub fn issuer(mut self, v: url::Url) -> Self {
self.config.issuer = Some(v);
self
}
}
impl<S: IntoTransport> CliTokenClientBuilder<HasClientId, HasAuthUrl, HasTokenUrl, JwksEnabled, S> {
#[must_use]
pub fn build(mut self) -> CliTokenClient {
self.config.scopes.insert(OAuth2Scope::OpenId);
build_client(
self.client_id.0,
self.auth_url.0,
self.token_url.0,
self.config,
Some(OidcJwksConfig::Enabled(self.oidc.0)),
self.scheme.into_transport(),
)
}
}
impl<S: IntoTransport>
CliTokenClientBuilder<HasClientId, HasAuthUrl, HasTokenUrl, JwksDisabled, S>
{
#[must_use]
pub fn build(mut self) -> CliTokenClient {
self.config.scopes.insert(OAuth2Scope::OpenId);
build_client(
self.client_id.0,
self.auth_url.0,
self.token_url.0,
self.config,
Some(OidcJwksConfig::Disabled),
self.scheme.into_transport(),
)
}
}
impl<S: IntoTransport> CliTokenClientBuilder<HasClientId, HasAuthUrl, HasTokenUrl, NoOidc, S> {
#[must_use]
pub fn build(self) -> CliTokenClient {
build_client(
self.client_id.0,
self.auth_url.0,
self.token_url.0,
self.config,
None,
self.scheme.into_transport(),
)
}
}
fn build_client(
client_id: ClientId,
auth_url: url::Url,
token_url: url::Url,
config: BuilderConfig,
oidc_jwks: Option<OidcJwksConfig>,
transport: Arc<dyn Transport>,
) -> CliTokenClient {
CliTokenClient {
client_id,
client_secret: config.client_secret,
auth_url,
token_url,
issuer: config.issuer,
scopes: config.scopes.into_iter().collect(),
port_config: config.port_config,
success_html: config.success_html,
error_html: config.error_html,
success_renderer: config.success_renderer,
error_renderer: config.error_renderer,
open_browser: config.open_browser,
timeout: config.timeout,
on_auth_url: config.on_auth_url,
on_url: config.on_url,
on_server_ready: config.on_server_ready,
oidc_jwks,
http_client: reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_secs(HTTP_CONNECT_TIMEOUT_SECONDS))
.timeout(std::time::Duration::from_secs(HTTP_REQUEST_TIMEOUT_SECONDS))
.build()
.unwrap_or_default(),
transport,
}
}
#[cfg(test)]
mod tests {
    #![expect(
        clippy::indexing_slicing,
        clippy::expect_used,
        clippy::unwrap_used,
        reason = "tests do not need to meet production lint standards"
    )]
    // Unit tests for token/scope parsing, builder type states, and authorization-URL assembly.
    use super::{
        AuthUrlParams, CliTokenClient, CliTokenClientBuilder, ExtraAuthParams, HasAuthUrl,
        HasClientId, HasTokenUrl, NoOidc, parse_scopes,
    };
    use crate::jwks::{JwksValidationError, JwksValidator};
    use crate::oidc::Token;
    use crate::scope::OAuth2Scope;
    use async_trait::async_trait;

    /// Builds an unsigned JWT with `sub`/`email` and fixed `iss`/`iat`/`exp`.
    fn fake_jwt(sub: &str, email: &str) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let claims = URL_SAFE_NO_PAD.encode(format!(
            r#"{{"sub":"{sub}","email":"{email}","iss":"https://accounts.example.com","iat":1000000000,"exp":9999999999}}"#
        ));
        format!("{header}.{claims}.fakesig")
    }

    /// Builds an unsigned Google-style JWT with profile claims and an `aud`.
    fn fake_jwt_google_style(
        sub: &str,
        email: &str,
        name: &str,
        picture: &str,
        aud: &str,
    ) -> String {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
        let header = URL_SAFE_NO_PAD.encode(r#"{"alg":"RS256","typ":"JWT"}"#);
        let claims = URL_SAFE_NO_PAD.encode(format!(
            r#"{{"iss":"https://accounts.google.com","aud":"{aud}","sub":"{sub}","email":"{email}","email_verified":true,"name":"{name}","picture":"{picture}","iat":1000000000,"exp":9999999999}}"#
        ));
        format!("{header}.{claims}.fakesig")
    }

    #[test]
    fn oidc_token_from_raw_jwt_returns_ok_for_valid_fake_jwt() {
        let jwt = fake_jwt("user_42", "user@example.com");
        let oidc = Token::from_raw_jwt(&jwt).expect("expected Ok for valid fake JWT");
        assert_eq!(oidc.claims().sub().as_str(), "user_42");
        assert_eq!(
            oidc.claims().email().map(crate::oidc::Email::as_str),
            Some("user@example.com")
        );
    }

    #[test]
    fn oidc_token_from_raw_jwt_returns_err_for_invalid_input() {
        let result = Token::from_raw_jwt("not.a.jwt");
        assert!(result.is_err(), "expected Err for invalid JWT");
    }

    #[test]
    fn oidc_token_from_raw_jwt_with_aud_claim_returns_ok() {
        let jwt = fake_jwt_google_style(
            "1234567890",
            "user@gmail.com",
            "Test User",
            "https://example.com/photo.jpg",
            "my-client-id.apps.googleusercontent.com",
        );
        let oidc = Token::from_raw_jwt(&jwt).expect("expected Ok for JWT with aud claim");
        assert_eq!(oidc.claims().sub().as_str(), "1234567890");
        assert_eq!(
            oidc.claims().email().map(crate::oidc::Email::as_str),
            Some("user@gmail.com")
        );
        assert_eq!(oidc.claims().name(), Some("Test User"));
        assert_eq!(
            oidc.claims().picture().map(|p| p.as_url().as_str()),
            Some("https://example.com/photo.jpg")
        );
        assert!(oidc.claims().email().unwrap().is_verified());
    }

    /// A builder with all required non-OIDC pieces set.
    fn valid_builder() -> CliTokenClientBuilder<HasClientId, HasAuthUrl, HasTokenUrl, NoOidc> {
        CliTokenClient::builder()
            .client_id("test-client")
            .auth_url(url::Url::parse("https://example.com/auth").unwrap())
            .token_url(url::Url::parse("https://example.com/token").unwrap())
    }

    #[test]
    fn builder_returns_cli_token_client_builder() {
        let _builder: CliTokenClientBuilder = CliTokenClient::builder();
    }

    #[test]
    fn build_with_valid_inputs_returns_client() {
        let _client = valid_builder().build();
    }

    #[test]
    fn rfc_6749_s5_1_scope_fallback_uses_requested_scopes_when_response_omits_scope() {
        let requested = vec![OAuth2Scope::OpenId, OAuth2Scope::Email];
        let resolved: Vec<OAuth2Scope> = None::<String>
            .as_deref()
            .map_or_else(|| requested.clone(), parse_scopes);
        assert_eq!(resolved, requested);
        let resolved_from_response: Vec<OAuth2Scope> = Some("openid profile".to_string())
            .as_deref()
            .map_or_else(|| requested.clone(), parse_scopes);
        assert_eq!(
            resolved_from_response,
            vec![OAuth2Scope::OpenId, OAuth2Scope::Profile]
        );
    }

    #[test]
    fn oidc_token_from_raw_jwt_populates_iss_aud_iat_exp() {
        let jwt = fake_jwt_google_style(
            "sub-iss-test",
            "user@example.com",
            "Test User",
            "https://example.com/photo.jpg",
            "my-client-id",
        );
        let oidc = Token::from_raw_jwt(&jwt).expect("should decode");
        let claims = oidc.claims();
        assert_eq!(
            claims.iss().as_url(),
            &url::Url::parse("https://accounts.google.com").unwrap()
        );
        assert_eq!(claims.aud().len(), 1);
        assert_eq!(claims.aud()[0].as_str(), "my-client-id");
        assert!(
            claims.iat() > std::time::UNIX_EPOCH,
            "iat should be after epoch"
        );
        assert!(
            claims.exp() > std::time::UNIX_EPOCH,
            "exp should be after epoch"
        );
    }

    /// JWKS validator stub that accepts every token.
    struct AcceptAll;

    #[async_trait]
    impl JwksValidator for AcceptAll {
        async fn validate(&self, _raw_token: &str) -> Result<(), JwksValidationError> {
            Ok(())
        }
    }

    #[test]
    fn build_with_jwks_validator_and_openid_scope_succeeds() {
        let _client = valid_builder()
            .with_openid_scope()
            .jwks_validator(Box::new(AcceptAll))
            .build();
    }

    fn make_open_id_configuration() -> crate::oidc::OpenIdConfiguration {
        use url::Url;
        crate::oidc::OpenIdConfiguration::new_for_test(
            Url::parse("https://accounts.example.com").unwrap(),
            Url::parse("https://accounts.example.com/authorize").unwrap(),
            Url::parse("https://accounts.example.com/token").unwrap(),
            Url::parse("https://accounts.example.com/.well-known/jwks.json").unwrap(),
        )
    }

    #[test]
    fn from_open_id_configuration_always_includes_openid_scope() {
        let config = make_open_id_configuration();
        let _client = CliTokenClientBuilder::from_open_id_configuration(&config)
            .client_id("test-client")
            .without_jwks_validation()
            .build();
    }

    #[test]
    fn extra_auth_params_append_accumulates_pairs() {
        let mut params = ExtraAuthParams::new();
        params.append("access_type", "offline");
        params.append("prompt", "consent");
        assert_eq!(params.pairs.len(), 2);
        assert_eq!(
            params.pairs[0],
            ("access_type".to_string(), "offline".to_string())
        );
        assert_eq!(
            params.pairs[1],
            ("prompt".to_string(), "consent".to_string())
        );
    }

    #[test]
    fn extra_auth_params_apply_to_adds_non_reserved_keys() {
        let mut params = ExtraAuthParams::new();
        params.append("access_type", "offline");
        let mut url = url::Url::parse("https://example.com/auth").unwrap();
        params.apply_to(&mut url);
        let pairs: Vec<(_, _)> = url.query_pairs().collect();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].0, "access_type");
        assert_eq!(pairs[0].1, "offline");
    }

    #[test]
    fn extra_auth_params_apply_to_drops_reserved_keys() {
        for reserved in AuthUrlParams::KEYS {
            let mut params = ExtraAuthParams::new();
            params.append(*reserved, "injected");
            let mut url = url::Url::parse("https://example.com/auth").unwrap();
            params.apply_to(&mut url);
            assert!(
                url.query_pairs().next().is_none(),
                "reserved key '{reserved}' should have been dropped"
            );
        }
    }

    #[test]
    fn extra_auth_params_apply_to_passes_non_reserved_and_drops_reserved() {
        let mut params = ExtraAuthParams::new();
        params.append("state", "injected");
        params.append("access_type", "offline");
        let mut url = url::Url::parse("https://example.com/auth").unwrap();
        params.apply_to(&mut url);
        let pairs: Vec<(_, _)> = url.query_pairs().collect();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0].0, "access_type");
    }

    /// `KEYS` must list exactly what `append_to` emits, or `ExtraAuthParams`
    /// could override a standard parameter.
    #[test]
    fn auth_url_params_append_to_emits_exactly_the_reserved_keys() {
        let pkce = crate::pkce::PkceChallenge::generate();
        let redirect = url::Url::parse("http://127.0.0.1:8080/callback").unwrap();
        let scopes = [OAuth2Scope::OpenId, OAuth2Scope::Email];
        let mut url = url::Url::parse("https://example.com/auth").unwrap();
        AuthUrlParams {
            client_id: "test-client",
            redirect_uri: &redirect,
            state_token: "state-123",
            pkce: &pkce,
            nonce: Some("nonce-456"),
            scopes: &scopes,
        }
        .append_to(&mut url);
        let pairs: std::collections::HashMap<String, String> =
            url.query_pairs().into_owned().collect();
        for key in AuthUrlParams::KEYS {
            assert!(
                pairs.contains_key(*key),
                "append_to should emit reserved key '{key}'"
            );
        }
        assert_eq!(pairs.len(), AuthUrlParams::KEYS.len());
        assert_eq!(pairs["response_type"], "code");
        assert_eq!(pairs["client_id"], "test-client");
        assert_eq!(pairs["redirect_uri"], redirect.as_str());
        assert_eq!(pairs["state"], "state-123");
        assert_eq!(pairs["nonce"], "nonce-456");
        assert_eq!(pairs["code_challenge_method"], pkce.code_challenge_method);
        assert_eq!(
            pairs["scope"],
            format!("{} {}", OAuth2Scope::OpenId, OAuth2Scope::Email)
        );
    }

    #[test]
    fn auth_url_params_append_to_omits_nonce_and_empty_scope() {
        let pkce = crate::pkce::PkceChallenge::generate();
        let redirect = url::Url::parse("http://127.0.0.1:8080/callback").unwrap();
        let mut url = url::Url::parse("https://example.com/auth").unwrap();
        AuthUrlParams {
            client_id: "test-client",
            redirect_uri: &redirect,
            state_token: "state-123",
            pkce: &pkce,
            nonce: None,
            scopes: &[],
        }
        .append_to(&mut url);
        let keys: Vec<String> = url.query_pairs().map(|(k, _)| k.into_owned()).collect();
        assert!(!keys.iter().any(|k| k == "nonce"));
        assert!(!keys.iter().any(|k| k == "scope"));
        assert_eq!(keys.len(), AuthUrlParams::KEYS.len() - 2);
    }
}