use super::{StdbOidcAuthOptions, common};
use crate::{
error::StdbAuthError,
session::{StdbAuthSessionParts, StdbAuthSessionSource},
token::StdbTokenResponse,
};
use std::{
io::{Read, Write},
net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream},
thread,
time::{Duration, Instant},
};
use url::Url;
const CALLBACK_TIMEOUT: Duration = Duration::from_secs(120);
const CALLBACK_POLL_INTERVAL: Duration = Duration::from_millis(10);
const CALLBACK_READ_TIMEOUT: Duration = Duration::from_secs(5);
const CALLBACK_REQUEST_BUFFER_SIZE: usize = 8192;
pub(crate) async fn acquire_session(
options: StdbOidcAuthOptions,
) -> Result<StdbAuthSessionParts, StdbAuthError> {
let redirect_uri = validate_native_redirect_uri(&options.redirect_uri)?;
#[cfg(feature = "persistence")]
if let Some(parts) = try_refresh_stored_session(&options).await {
return Ok(parts);
}
let listener = bind_loopback_listener(&redirect_uri)?;
let authorization_request = common::build_authorization_request(&options)?;
webbrowser::open(authorization_request.authorization_url.as_str()).map_err(|error| {
StdbAuthError::Internal(format!("failed to open system browser: {error}"))
})?;
let authorization_code = receive_authorization_code(
&listener,
&redirect_uri,
&authorization_request.state,
CALLBACK_TIMEOUT,
)?;
let token_form = common::authorization_code_token_form(
&options,
&authorization_code.code,
&authorization_request.pkce_verifier,
)?;
let token = exchange_authorization_code(token_form)?;
token.into_session_parts(
Some(options.client_id),
StdbAuthSessionSource::Oidc,
options.post_logout_redirect_uri,
)
}
#[cfg(feature = "persistence")]
/// Tries to obtain a session silently from a refresh token persisted for
/// `options.client_id`.
///
/// Returns `None` when no refresh token is stored or when the refresh call
/// fails. Either way the caller falls back to the interactive browser flow, so
/// refresh errors are intentionally discarded.
async fn try_refresh_stored_session(options: &StdbOidcAuthOptions) -> Option<StdbAuthSessionParts> {
    let stored_token =
        super::persistence::stored_refresh_token_best_effort(&options.client_id)?;

    // The access token is left empty: this session only carries metadata
    // (client id, source, logout redirect) into the refresh call.
    // NOTE(review): assumes `refresh_session` does not rely on `access_token` — confirm.
    let template = crate::session::StdbAuthSession {
        access_token: String::new(),
        token_type: String::from("Bearer"),
        expires_at: None,
        can_refresh: true,
        scope: None,
        client_id: Some(options.client_id.clone()),
        source: StdbAuthSessionSource::Oidc,
        post_logout_redirect_uri: options.post_logout_redirect_uri.clone(),
    };

    let refreshed = crate::refresh::refresh_session(template, stored_token).await;
    refreshed.ok()
}
fn bind_loopback_listener(redirect_uri: &Url) -> Result<TcpListener, StdbAuthError> {
let bind_addr = loopback_bind_addr(redirect_uri)?;
let listener = TcpListener::bind(bind_addr).map_err(|error| {
StdbAuthError::Internal(format!("failed to bind OIDC callback listener: {error}"))
})?;
listener.set_nonblocking(true).map_err(|error| {
StdbAuthError::Internal(format!(
"failed to configure OIDC callback listener: {error}"
))
})?;
Ok(listener)
}
/// Waits for the browser to hit the loopback redirect URI and extracts the
/// authorization code.
///
/// `listener` must be non-blocking. While no connection is pending, `accept` is
/// polled every `CALLBACK_POLL_INTERVAL` until `timeout` has elapsed.
///
/// The first accepted connection decides the outcome. Its request is read,
/// checked against `redirect_uri`, and parsed with `expected_state`. A 200 or
/// 400 page is written back on a best-effort basis, and the parse result is
/// returned.
///
/// # Errors
/// - `StdbAuthError::Timeout` if no connection arrives within `timeout`.
/// - `StdbAuthError::Internal` if `accept` fails with anything but `WouldBlock`.
/// - Any error from reading or parsing the callback request.
fn receive_authorization_code(
    listener: &TcpListener,
    redirect_uri: &Url,
    expected_state: &str,
    timeout: Duration,
) -> Result<common::StdbOidcAuthorizationCode, StdbAuthError> {
    let waiting_since = Instant::now();
    loop {
        let mut stream = match listener.accept() {
            Ok((stream, _peer)) => stream,
            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {
                if waiting_since.elapsed() >= timeout {
                    return Err(StdbAuthError::Timeout);
                }
                thread::sleep(CALLBACK_POLL_INTERVAL);
                continue;
            }
            Err(err) => {
                return Err(StdbAuthError::Internal(format!(
                    "failed to accept OIDC callback: {err}"
                )));
            }
        };

        let outcome = read_callback_url(&mut stream, redirect_uri).and_then(|callback_url| {
            common::parse_callback_url(callback_url.as_str(), expected_state)
        });
        // Showing the result page is best-effort; the outcome is what matters.
        let _ = write_callback_response(&mut stream, outcome.is_ok());
        break outcome;
    }
}
/// Posts the authorization-code token request to the token endpoint and decodes
/// the JSON token response.
///
/// # Errors
/// Fails with the `StdbAuthError` produced by the transport layer if the client
/// cannot be built, the request cannot be sent, the endpoint answers with a
/// non-success HTTP status, or the body is not a valid `StdbTokenResponse`.
fn exchange_authorization_code(
    token_form: common::StdbOidcTokenRequestForm,
) -> Result<StdbTokenResponse, StdbAuthError> {
    let client = crate::transport::token_client()?;
    let request = crate::transport::token_endpoint_request(&client).form(&token_form.params);
    let response = request.send()?.error_for_status()?;
    response.json::<StdbTokenResponse>().map_err(StdbAuthError::from)
}
fn read_callback_url(stream: &mut TcpStream, redirect_uri: &Url) -> Result<Url, StdbAuthError> {
stream
.set_read_timeout(Some(CALLBACK_READ_TIMEOUT))
.map_err(|error| {
StdbAuthError::Internal(format!("failed to configure OIDC callback stream: {error}"))
})?;
let mut buffer = [0; CALLBACK_REQUEST_BUFFER_SIZE];
let bytes_read = stream.read(&mut buffer).map_err(|error| {
StdbAuthError::InvalidOidcCallback(format!("failed to read callback request: {error}"))
})?;
let request = String::from_utf8_lossy(&buffer[..bytes_read]);
let request_line = request.lines().next().ok_or_else(|| {
StdbAuthError::InvalidOidcCallback("callback request is empty".to_string())
})?;
let mut parts = request_line.split_whitespace();
let method = parts.next().ok_or_else(|| {
StdbAuthError::InvalidOidcCallback("callback request is missing method".to_string())
})?;
let request_target = parts.next().ok_or_else(|| {
StdbAuthError::InvalidOidcCallback("callback request is missing target".to_string())
})?;
if method != "GET" {
return Err(StdbAuthError::InvalidOidcCallback(
"callback request method must be GET".to_string(),
));
}
callback_url_from_request_target(redirect_uri, request_target)
}
/// Turns an HTTP request target into an absolute callback URL and checks that it
/// points at the configured redirect URI.
///
/// Absolute-form targets (`http://…`, `https://…`) are parsed as-is. Any other
/// target (origin-form such as `/callback?code=…`) is resolved against
/// `redirect_uri`.
///
/// # Errors
/// `StdbAuthError::InvalidOidcCallback` if the target cannot be parsed or
/// resolved, or if the resulting URL's scheme, host, port, or path differs from
/// `redirect_uri`.
fn callback_url_from_request_target(
    redirect_uri: &Url,
    request_target: &str,
) -> Result<Url, StdbAuthError> {
    let is_absolute_form = ["http://", "https://"]
        .iter()
        .any(|&scheme| request_target.starts_with(scheme));
    let parsed = if is_absolute_form {
        Url::parse(request_target)
    } else {
        redirect_uri.join(request_target)
    };
    let callback_url = parsed.map_err(|error| {
        StdbAuthError::InvalidOidcCallback(format!("callback request target is invalid: {error}"))
    })?;
    validate_callback_url_matches_redirect_uri(&callback_url, redirect_uri)
        .map(|()| callback_url)
}
/// Ensures `callback_url` targets the same endpoint as `redirect_uri`.
///
/// Scheme, host, effective port (explicit or the scheme default), and path must
/// all be equal. The query string is deliberately not compared, because that is
/// where the callback parameters arrive.
///
/// # Errors
/// `StdbAuthError::InvalidOidcCallback` when any compared component differs.
fn validate_callback_url_matches_redirect_uri(
    callback_url: &Url,
    redirect_uri: &Url,
) -> Result<(), StdbAuthError> {
    let endpoint = |url: &Url| {
        (
            url.scheme().to_owned(),
            url.host_str().map(str::to_owned),
            url.port_or_known_default(),
            url.path().to_owned(),
        )
    };
    if endpoint(callback_url) == endpoint(redirect_uri) {
        Ok(())
    } else {
        Err(StdbAuthError::InvalidOidcCallback(
            "callback URL does not match the configured redirect URI".to_string(),
        ))
    }
}
/// Sends a minimal HTML page back to the browser after the redirect.
///
/// On success this is `200 OK` with a "completed" page; otherwise it is
/// `400 Bad Request` with a "failed" page. The response declares an exact
/// `Content-Length` and `Connection: close`.
///
/// # Errors
/// Returns any I/O error from writing to `stream`.
fn write_callback_response(stream: &mut TcpStream, succeeded: bool) -> std::io::Result<()> {
    const SUCCESS_PAGE: &str = "<!doctype html><title>Authenticated</title><p>Authentication completed. You can close this window.</p>";
    const FAILURE_PAGE: &str = "<!doctype html><title>Authentication failed</title><p>Authentication failed. You can close this window.</p>";

    let status = if succeeded { "200 OK" } else { "400 Bad Request" };
    let body = if succeeded { SUCCESS_PAGE } else { FAILURE_PAGE };
    let response = format!(
        "HTTP/1.1 {status}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(response.as_bytes())
}
fn validate_native_redirect_uri(redirect_uri: &str) -> Result<Url, StdbAuthError> {
let redirect_uri = Url::parse(redirect_uri).map_err(|error| {
StdbAuthError::InvalidConfig(format!("`redirect_uri` is invalid: {error}"))
})?;
if redirect_uri.scheme() != "http" {
return Err(StdbAuthError::InvalidConfig(
"native OIDC `redirect_uri` must use the `http` scheme".to_string(),
));
}
if redirect_uri.query().is_some() {
return Err(StdbAuthError::InvalidConfig(
"native OIDC `redirect_uri` must not include a query string".to_string(),
));
}
if redirect_uri.fragment().is_some() {
return Err(StdbAuthError::InvalidConfig(
"native OIDC `redirect_uri` must not include a fragment".to_string(),
));
}
let host = redirect_uri.host_str().ok_or_else(|| {
StdbAuthError::InvalidConfig("native OIDC `redirect_uri` must include a host".to_string())
})?;
if !is_loopback_host(host) {
return Err(StdbAuthError::InvalidConfig(
"native OIDC `redirect_uri` must use a loopback host".to_string(),
));
}
if redirect_uri.port().is_none_or(|port| port == 0) {
return Err(StdbAuthError::InvalidConfig(
"native OIDC `redirect_uri` must include a non-zero explicit port".to_string(),
));
}
Ok(redirect_uri)
}
/// Derives the socket address the callback listener binds to from `redirect_uri`.
///
/// `localhost` (case-insensitive) maps to `127.0.0.1`. Otherwise the host must
/// be an IPv4 or IPv6 literal; IPv6 literals are accepted in the bracketed form
/// `Url::host_str` produces (e.g. `[::1]`). The port must be explicit in the URI.
///
/// # Errors
/// `StdbAuthError::InvalidConfig` if the URI has no host or no explicit port,
/// the host is not an IP literal (or `localhost`), or the address is not
/// loopback.
fn loopback_bind_addr(redirect_uri: &Url) -> Result<SocketAddr, StdbAuthError> {
    let host = redirect_uri.host_str().ok_or_else(|| {
        StdbAuthError::InvalidConfig("native OIDC `redirect_uri` must include a host".to_string())
    })?;
    let port = redirect_uri.port().ok_or_else(|| {
        StdbAuthError::InvalidConfig(
            "native OIDC `redirect_uri` must include an explicit port".to_string(),
        )
    })?;
    // `Url::host_str` renders IPv6 hosts with brackets (`[::1]`), which
    // `IpAddr::from_str` rejects; strip them so IPv6 loopback URIs can be bound.
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let ip = if host.eq_ignore_ascii_case("localhost") {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        host.parse::<IpAddr>().map_err(|error| {
            StdbAuthError::InvalidConfig(format!("native OIDC loopback host is invalid: {error}"))
        })?
    };
    if !ip.is_loopback() {
        return Err(StdbAuthError::InvalidConfig(
            "native OIDC `redirect_uri` must use a loopback host".to_string(),
        ));
    }
    Ok(SocketAddr::new(ip, port))
}
/// Returns `true` if `host` names the loopback interface.
///
/// Accepts `localhost` (ASCII case-insensitive) or any loopback IP literal
/// (`127.0.0.0/8`, `::1`). IPv6 literals may be bracketed as `Url::host_str`
/// renders them (e.g. `[::1]`); without stripping the brackets they would fail
/// to parse and be wrongly treated as non-loopback.
fn is_loopback_host(host: &str) -> bool {
    let unbracketed = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    host.eq_ignore_ascii_case("localhost")
        || unbracketed
            .parse::<IpAddr>()
            .is_ok_and(|address| address.is_loopback())
}
#[cfg(test)]
mod tests {
    //! Unit tests for redirect-URI validation, loopback address derivation and
    //! callback request-target handling.
    use super::*;

    #[test]
    fn native_redirect_uri_accepts_loopback_http_with_port() {
        let redirect_uri = validate_native_redirect_uri("http://127.0.0.1:3000/callback")
            .expect("loopback redirect URI should be valid");
        assert_eq!(redirect_uri.as_str(), "http://127.0.0.1:3000/callback");
    }

    #[test]
    fn native_redirect_uri_accepts_localhost() {
        let redirect_uri = validate_native_redirect_uri("http://localhost:8080/callback")
            .expect("localhost redirect URI should be valid");
        assert_eq!(redirect_uri.port(), Some(8080));
    }

    #[test]
    fn native_redirect_uri_rejects_non_loopback_hosts() {
        let result = validate_native_redirect_uri("http://example.com:3000/callback");
        assert!(matches!(result, Err(StdbAuthError::InvalidConfig(_))));
    }

    #[test]
    fn native_redirect_uri_rejects_non_http_scheme() {
        let result = validate_native_redirect_uri("https://127.0.0.1:3000/callback");
        assert!(matches!(result, Err(StdbAuthError::InvalidConfig(_))));
    }

    #[test]
    fn native_redirect_uri_rejects_query_string() {
        let result = validate_native_redirect_uri("http://127.0.0.1:3000/callback?route=auth");
        assert!(matches!(result, Err(StdbAuthError::InvalidConfig(_))));
    }

    #[test]
    fn native_redirect_uri_rejects_fragment() {
        let result = validate_native_redirect_uri("http://127.0.0.1:3000/callback#section");
        assert!(matches!(result, Err(StdbAuthError::InvalidConfig(_))));
    }

    #[test]
    fn native_redirect_uri_rejects_missing_port() {
        let result = validate_native_redirect_uri("http://127.0.0.1/callback");
        assert!(matches!(result, Err(StdbAuthError::InvalidConfig(_))));
    }

    #[test]
    fn native_redirect_uri_rejects_zero_port() {
        let result = validate_native_redirect_uri("http://127.0.0.1:0/callback");
        assert!(matches!(result, Err(StdbAuthError::InvalidConfig(_))));
    }

    #[test]
    fn loopback_bind_addr_maps_localhost_to_ipv4_loopback() {
        let redirect_uri = Url::parse("http://localhost:3000/callback").unwrap();
        let addr = loopback_bind_addr(&redirect_uri).expect("localhost should be bindable");
        assert_eq!(addr, SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 3000));
    }

    #[test]
    fn loopback_host_check_handles_names_and_literals() {
        assert!(is_loopback_host("LOCALHOST"));
        assert!(is_loopback_host("127.0.0.2"));
        assert!(!is_loopback_host("10.0.0.1"));
        assert!(!is_loopback_host("example.com"));
    }

    #[test]
    fn callback_target_uses_redirect_origin() {
        let redirect_uri = Url::parse("http://127.0.0.1:3000/callback").unwrap();
        let callback_url =
            callback_url_from_request_target(&redirect_uri, "/callback?code=abc&state=state")
                .expect("callback target should be valid");
        assert_eq!(
            callback_url.as_str(),
            "http://127.0.0.1:3000/callback?code=abc&state=state"
        );
    }

    #[test]
    fn callback_target_accepts_absolute_form() {
        let redirect_uri = Url::parse("http://127.0.0.1:3000/callback").unwrap();
        let callback_url = callback_url_from_request_target(
            &redirect_uri,
            "http://127.0.0.1:3000/callback?code=abc&state=state",
        )
        .expect("absolute-form callback target should be valid");
        assert_eq!(callback_url.query(), Some("code=abc&state=state"));
    }

    #[test]
    fn callback_target_rejects_wrong_path() {
        let redirect_uri = Url::parse("http://127.0.0.1:3000/callback").unwrap();
        let result = callback_url_from_request_target(&redirect_uri, "/other?code=abc&state=state");
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }

    #[test]
    fn callback_target_rejects_wrong_port() {
        let redirect_uri = Url::parse("http://127.0.0.1:3000/callback").unwrap();
        let result = callback_url_from_request_target(
            &redirect_uri,
            "http://127.0.0.1:4000/callback?code=abc&state=state",
        );
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }

    #[test]
    fn callback_target_rejects_scheme_relative_host_override() {
        // A scheme-relative target would otherwise replace the redirect host.
        let redirect_uri = Url::parse("http://127.0.0.1:3000/callback").unwrap();
        let result =
            callback_url_from_request_target(&redirect_uri, "//example.com/callback?code=abc");
        assert!(matches!(result, Err(StdbAuthError::InvalidOidcCallback(_))));
    }
}