use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use url::Url;
use super::port::{IdAssertion, ScopePiiReader};
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Config {
pub client_id: String,
pub redirect_uri: Url,
pub issuer: Url,
pub state_ttl: Duration,
}
impl Config {
pub fn new(client_id: impl Into<String>, redirect_uri: Url, issuer: Url) -> Self {
Self {
client_id: client_id.into(),
redirect_uri,
issuer,
state_ttl: Duration::from_secs(600),
}
}
#[must_use]
pub fn with_state_ttl(mut self, ttl: Duration) -> Self {
self.state_ttl = ttl;
self
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct State(String);
impl State {
pub fn from_string(s: String) -> Self {
Self(s)
}
pub fn as_str(&self) -> &str {
&self.0
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct RelativePath(String);
impl RelativePath {
pub fn as_str(&self) -> &str {
&self.0
}
}
impl Default for RelativePath {
fn default() -> Self {
Self("/".to_owned())
}
}
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub enum RelativePathError {
#[error("relative path must not be protocol-relative (e.g., '//host/path')")]
ProtocolRelative,
#[error("relative path must start with '/'")]
NotRooted,
#[error("relative path must not contain a scheme (e.g., 'https://...', 'javascript:')")]
SchemePresent,
#[error("relative path must not contain control characters")]
ControlCharacters,
}
impl<'a> TryFrom<&'a str> for RelativePath {
type Error = RelativePathError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
if value.starts_with("//") {
return Err(RelativePathError::ProtocolRelative);
}
if !value.starts_with('/') {
return Err(RelativePathError::NotRooted);
}
let path_only = value.split(['?', '#']).next().unwrap_or(value);
if path_only.contains(':') {
return Err(RelativePathError::SchemePresent);
}
if value.chars().any(char::is_control) {
return Err(RelativePathError::ControlCharacters);
}
Ok(Self(value.to_owned()))
}
}
impl TryFrom<String> for RelativePath {
type Error = RelativePathError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_str())
}
}
impl<'de> Deserialize<'de> for RelativePath {
fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
RelativePath::try_from(s).map_err(serde::de::Error::custom)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PendingAuthRequest {
pub code_verifier: String,
pub nonce: String,
pub after_login: RelativePath,
#[serde(with = "time::serde::rfc3339")]
pub created_at: OffsetDateTime,
}
#[derive(Debug, Clone)]
pub struct AuthorizationRedirect {
pub url: Url,
pub state: State,
}
#[derive(Debug, Clone)]
pub struct CallbackParams {
pub code: String,
pub state: State,
}
#[derive(Debug)]
pub struct Completion<S: ScopePiiReader> {
pub id_assertion: IdAssertion<S>,
pub tokens: crate::oauth::TokenResponse,
pub redirect_to: RelativePath,
}
#[async_trait]
pub trait StateStore: Send + Sync {
async fn put(
&self,
state: &State,
pending: PendingAuthRequest,
ttl: Duration,
) -> Result<(), StateStoreError>;
async fn take(
&self,
state: &State,
) -> Result<Option<PendingAuthRequest>, StateStoreError>;
}
#[derive(Debug, thiserror::Error)]
pub enum StateStoreError {
#[error("state-store substrate failure: {0}")]
Substrate(String),
#[error("state-store serialization failure: {0}")]
Serialization(String),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn relative_path_accepts_root() {
let p = RelativePath::try_from("/").expect("rooted path accepted");
assert_eq!(p.as_str(), "/");
}
#[test]
fn relative_path_accepts_nested() {
let p = RelativePath::try_from("/dashboard/settings").expect("nested accepted");
assert_eq!(p.as_str(), "/dashboard/settings");
}
#[test]
fn relative_path_accepts_query_and_fragment() {
let p = RelativePath::try_from("/x?y=1#z").expect("query+fragment accepted");
assert_eq!(p.as_str(), "/x?y=1#z");
}
#[test]
fn relative_path_rejects_https_scheme() {
assert_eq!(
RelativePath::try_from("https://evil.com"),
Err(RelativePathError::NotRooted),
);
}
#[test]
fn relative_path_rejects_javascript_scheme() {
assert_eq!(
RelativePath::try_from("javascript:alert(1)"),
Err(RelativePathError::NotRooted),
);
}
#[test]
fn relative_path_rejects_protocol_relative() {
assert_eq!(
RelativePath::try_from("//evil.com/path"),
Err(RelativePathError::ProtocolRelative),
);
}
#[test]
fn relative_path_rejects_colon_smuggled_after_root() {
assert_eq!(
RelativePath::try_from("/https://x"),
Err(RelativePathError::SchemePresent),
);
}
#[test]
fn relative_path_rejects_control_characters() {
assert_eq!(
RelativePath::try_from("/path\rwith\nnewline"),
Err(RelativePathError::ControlCharacters),
);
}
#[test]
fn relative_path_serde_roundtrip_validates() {
let p = RelativePath::try_from("/ok").unwrap();
let json = serde_json::to_string(&p).unwrap();
let back: RelativePath = serde_json::from_str(&json).unwrap();
assert_eq!(back.as_str(), "/ok");
}
#[test]
fn relative_path_deserialize_rejects_smuggled_scheme() {
let result: Result<RelativePath, _> = serde_json::from_str(r#""https://evil""#);
assert!(result.is_err(), "smuggled absolute URL must reject on deserialize");
}
}