use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// OAuth 2.0 grant types a client can be registered for.
///
/// Each variant (de)serializes as its RFC 6749 `grant_type` parameter value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub enum GrantType {
    /// Wire name `authorization_code`.
    #[serde(rename = "authorization_code")]
    AuthorizationCode,
    /// Wire name `refresh_token`.
    #[serde(rename = "refresh_token")]
    RefreshToken,
    /// Wire name `client_credentials`.
    #[serde(rename = "client_credentials")]
    ClientCredentials,
}
impl GrantType {
pub fn as_str(&self) -> &'static str {
match self {
Self::AuthorizationCode => "authorization_code",
Self::RefreshToken => "refresh_token",
Self::ClientCredentials => "client_credentials",
}
}
}
/// Formats a grant type as its wire name (same text as [`GrantType::as_str`]).
impl std::fmt::Display for GrantType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let wire_name = self.as_str();
        f.write_str(wire_name)
    }
}
/// OAuth 2.0 `response_type` values modelled here; only the authorization-code flow exists.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub enum ResponseType {
    /// Wire name `code`.
    #[serde(rename = "code")]
    Code,
}
impl ResponseType {
pub fn as_str(&self) -> &'static str {
match self {
Self::Code => "code",
}
}
}
/// PKCE `code_challenge_method` (RFC 7636).
///
/// The wire names differ in case (`plain` vs `S256`), so each variant is renamed individually.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub enum CodeChallengeMethod {
    /// The challenge is the verifier itself; wire name `plain`.
    #[serde(rename = "plain")]
    Plain,
    /// SHA-256 based challenge; wire name `S256`.
    #[serde(rename = "S256")]
    S256,
}
impl CodeChallengeMethod {
pub fn as_str(&self) -> &'static str {
match self {
Self::Plain => "plain",
Self::S256 => "S256",
}
}
}
impl Default for CodeChallengeMethod {
fn default() -> Self {
Self::S256
}
}
/// How a client authenticates itself.
///
/// Wire names match the RFC 7591 `token_endpoint_auth_method` values.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub enum ClientAuthMethod {
    /// No client authentication (public clients); wire name `none`.
    #[serde(rename = "none")]
    None,
    /// Secret sent in the request body; wire name `client_secret_post`.
    #[serde(rename = "client_secret_post")]
    ClientSecretPost,
    /// Secret sent via HTTP Basic auth; wire name `client_secret_basic`.
    #[serde(rename = "client_secret_basic")]
    ClientSecretBasic,
}
impl Default for ClientAuthMethod {
fn default() -> Self {
Self::None
}
}
/// Registration record for a single OAuth client.
///
/// `Debug` is implemented by hand so that `client_secret` is never printed: a derived impl
/// would leak the secret into any `{:?}` output, including that of configs which contain
/// this type.
#[derive(Clone, Serialize, Deserialize)]
pub struct ClientConfig {
    /// Unique client identifier.
    pub client_id: String,
    /// Shared secret for confidential clients; `None` for public clients.
    /// Omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Registered redirect URIs; requests are checked against these by exact string match.
    pub redirect_uris: Vec<String>,
    /// Grant types this client may use. Empty (no grants) when missing from deserialized input.
    #[serde(default)]
    pub grant_types: Vec<GrantType>,
    /// Client authentication method; `ClientAuthMethod::None` when missing from input.
    #[serde(default)]
    pub auth_method: ClientAuthMethod,
    /// Whether PKCE is mandatory; `true` when missing from deserialized input.
    #[serde(default = "default_pkce_required")]
    pub pkce_required: bool,
    /// Optional human-readable client name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Optional homepage URL of the client application.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub homepage_url: Option<String>,
    /// Scopes the client may be granted; an empty list means "no restriction".
    #[serde(default)]
    pub allowed_scopes: Vec<String>,
    /// Scopes used when a request names none.
    #[serde(default)]
    pub default_scopes: Vec<String>,
}

impl std::fmt::Debug for ClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClientConfig")
            .field("client_id", &self.client_id)
            // Show only whether a secret is configured, never its value.
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("redirect_uris", &self.redirect_uris)
            .field("grant_types", &self.grant_types)
            .field("auth_method", &self.auth_method)
            .field("pkce_required", &self.pkce_required)
            .field("name", &self.name)
            .field("homepage_url", &self.homepage_url)
            .field("allowed_scopes", &self.allowed_scopes)
            .field("default_scopes", &self.default_scopes)
            .finish()
    }
}
/// Serde default for `ClientConfig::pkce_required`: PKCE is enforced whenever the field is
/// absent from the deserialized configuration.
fn default_pkce_required() -> bool {
    const PKCE_REQUIRED_BY_DEFAULT: bool = true;
    PKCE_REQUIRED_BY_DEFAULT
}
impl ClientConfig {
pub fn public(client_id: impl Into<String>, redirect_uris: Vec<String>) -> Self {
Self {
client_id: client_id.into(),
client_secret: None,
redirect_uris,
grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
auth_method: ClientAuthMethod::None,
pkce_required: true, name: None,
homepage_url: None,
allowed_scopes: Vec::new(),
default_scopes: Vec::new(),
}
}
pub fn confidential(
client_id: impl Into<String>,
client_secret: impl Into<String>,
redirect_uris: Vec<String>,
) -> Self {
Self {
client_id: client_id.into(),
client_secret: Some(client_secret.into()),
redirect_uris,
grant_types: vec![GrantType::AuthorizationCode, GrantType::RefreshToken],
auth_method: ClientAuthMethod::ClientSecretBasic,
pkce_required: false, name: None,
homepage_url: None,
allowed_scopes: Vec::new(),
default_scopes: Vec::new(),
}
}
pub fn with_grant_type(mut self, grant_type: GrantType) -> Self {
if !self.grant_types.contains(&grant_type) {
self.grant_types.push(grant_type);
}
self
}
pub fn with_allowed_scopes(mut self, scopes: Vec<String>) -> Self {
self.allowed_scopes = scopes;
self
}
pub fn with_default_scopes(mut self, scopes: Vec<String>) -> Self {
self.default_scopes = scopes;
self
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn is_redirect_uri_allowed(&self, uri: &str) -> bool {
self.redirect_uris.iter().any(|allowed| allowed == uri)
}
pub fn is_grant_type_allowed(&self, grant_type: GrantType) -> bool {
self.grant_types.contains(&grant_type)
}
pub fn is_scope_allowed(&self, scope: &str) -> bool {
self.allowed_scopes.is_empty() || self.allowed_scopes.iter().any(|s| s == scope)
}
pub fn filter_scopes(&self, requested: &[String]) -> Vec<String> {
if self.allowed_scopes.is_empty() {
requested.to_vec()
} else {
requested
.iter()
.filter(|s| self.allowed_scopes.contains(s))
.cloned()
.collect()
}
}
pub fn effective_scopes(&self, requested: &[String]) -> Vec<String> {
if requested.is_empty() {
self.default_scopes.clone()
} else {
self.filter_scopes(requested)
}
}
}
/// Configuration of the OAuth 2.0 authorization server itself.
///
/// Endpoint fields hold paths. The `*_endpoint_url` methods join them onto `issuer` to form
/// absolute URLs. Access/refresh token lifetimes are in seconds (see the `with_*_lifetime`
/// setters).
#[derive(Debug, Clone)]
pub struct OAuthProviderConfig {
    /// Base URL of this server (e.g. `https://auth.example.com`); empty by default.
    pub issuer: String,
    /// Path of the authorization endpoint (default `/oauth/authorize`).
    pub authorization_endpoint: String,
    /// Path of the token endpoint (default `/oauth/token`).
    pub token_endpoint: String,
    /// Path of the token revocation endpoint (default `/oauth/revoke`).
    pub revocation_endpoint: String,
    /// Path of the token introspection endpoint (default `/oauth/introspect`).
    pub introspection_endpoint: String,
    /// Path of the JWKS document (default `/.well-known/jwks.json`).
    pub jwks_endpoint: String,
    /// Authorization code lifetime (default 600). Presumably seconds like the token
    /// lifetimes; confirm with the code that consumes it.
    pub authorization_code_lifetime: u64,
    /// Access token lifetime in seconds (default one hour).
    pub access_token_lifetime: u64,
    /// Refresh token lifetime in seconds (default 30 days).
    pub refresh_token_lifetime: u64,
    /// Whether refresh tokens are issued at all (default `true`).
    pub issue_refresh_tokens: bool,
    /// Name of the token signing algorithm (default `RS256`).
    pub signing_algorithm: String,
    /// Registered clients keyed by `client_id`.
    pub clients: HashMap<String, ClientConfig>,
    /// Scopes this server supports.
    pub supported_scopes: Vec<String>,
}

impl Default for OAuthProviderConfig {
    fn default() -> Self {
        Self {
            issuer: String::new(),
            authorization_endpoint: "/oauth/authorize".to_string(),
            token_endpoint: "/oauth/token".to_string(),
            revocation_endpoint: "/oauth/revoke".to_string(),
            introspection_endpoint: "/oauth/introspect".to_string(),
            jwks_endpoint: "/.well-known/jwks.json".to_string(),
            authorization_code_lifetime: 10 * 60,
            access_token_lifetime: 60 * 60,
            refresh_token_lifetime: 30 * 24 * 60 * 60,
            issue_refresh_tokens: true,
            signing_algorithm: "RS256".to_string(),
            clients: HashMap::new(),
            supported_scopes: Vec::new(),
        }
    }
}

impl OAuthProviderConfig {
    /// Creates a configuration for `issuer` with every other setting at its default.
    pub fn new(issuer: impl Into<String>) -> Self {
        Self {
            issuer: issuer.into(),
            ..Default::default()
        }
    }

    /// Registers `config`, replacing any existing client with the same `client_id`.
    pub fn with_client(mut self, config: ClientConfig) -> Self {
        self.clients.insert(config.client_id.clone(), config);
        self
    }

    /// Replaces the list of supported scopes.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.supported_scopes = scopes;
        self
    }

    /// Sets the access token lifetime in seconds.
    pub fn with_access_token_lifetime(mut self, seconds: u64) -> Self {
        self.access_token_lifetime = seconds;
        self
    }

    /// Sets the refresh token lifetime in seconds.
    pub fn with_refresh_token_lifetime(mut self, seconds: u64) -> Self {
        self.refresh_token_lifetime = seconds;
        self
    }

    /// Looks up a registered client by id.
    pub fn get_client(&self, client_id: &str) -> Option<&ClientConfig> {
        self.clients.get(client_id)
    }

    /// Absolute URL of the authorization endpoint.
    pub fn authorization_endpoint_url(&self) -> String {
        self.endpoint_url(&self.authorization_endpoint)
    }

    /// Absolute URL of the token endpoint.
    pub fn token_endpoint_url(&self) -> String {
        self.endpoint_url(&self.token_endpoint)
    }

    /// Absolute URL of the token revocation endpoint.
    pub fn revocation_endpoint_url(&self) -> String {
        self.endpoint_url(&self.revocation_endpoint)
    }

    /// Absolute URL of the token introspection endpoint.
    pub fn introspection_endpoint_url(&self) -> String {
        self.endpoint_url(&self.introspection_endpoint)
    }

    /// Absolute URL of the JWKS document.
    pub fn jwks_endpoint_url(&self) -> String {
        self.endpoint_url(&self.jwks_endpoint)
    }

    /// Joins `path` onto `issuer` with exactly one `/` between them.
    ///
    /// Plain concatenation produced `https://host//oauth/token` for an issuer with a
    /// trailing slash, and `https://hostoauth/token` for a path without a leading one.
    /// With an empty issuer the result is the root-relative path, as before.
    fn endpoint_url(&self, path: &str) -> String {
        let base = self.issuer.trim_end_matches('/');
        let path = path.trim_start_matches('/');
        format!("{}/{}", base, path)
    }
}
/// Error body returned to OAuth clients (RFC 6749 §4.1.2.1 and §5.2).
///
/// Optional members are left out of the serialized form when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OAuthError {
    /// Machine-readable error code such as `invalid_request`.
    pub error: String,
    /// Human-readable explanation of the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_description: Option<String>,
    /// URI of a page with more information about the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_uri: Option<String>,
    /// `state` value from the client's authorization request, returned with the error.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub state: Option<String>,
}
impl OAuthError {
    /// Builds an error with the given `code` and description, and no `error_uri` or `state`.
    ///
    /// Every public constructor below delegates here, so they differ only in the code string.
    fn from_code(code: &str, description: impl Into<String>) -> Self {
        Self {
            error: code.to_string(),
            error_description: Some(description.into()),
            error_uri: None,
            state: None,
        }
    }

    /// `invalid_request`: a parameter is missing or invalid, or the request is malformed.
    pub fn invalid_request(description: impl Into<String>) -> Self {
        Self::from_code("invalid_request", description)
    }

    /// `unauthorized_client`: the client may not use this flow or grant type.
    pub fn unauthorized_client(description: impl Into<String>) -> Self {
        Self::from_code("unauthorized_client", description)
    }

    /// `access_denied`: the resource owner or server denied the request.
    pub fn access_denied(description: impl Into<String>) -> Self {
        Self::from_code("access_denied", description)
    }

    /// `unsupported_response_type`: the requested `response_type` is not supported.
    pub fn unsupported_response_type(description: impl Into<String>) -> Self {
        Self::from_code("unsupported_response_type", description)
    }

    /// `invalid_scope`: the requested scope is invalid, unknown, or malformed.
    pub fn invalid_scope(description: impl Into<String>) -> Self {
        Self::from_code("invalid_scope", description)
    }

    /// `server_error`: the server hit an unexpected condition.
    pub fn server_error(description: impl Into<String>) -> Self {
        Self::from_code("server_error", description)
    }

    /// `invalid_grant`: the code or refresh token is invalid, expired, or revoked.
    pub fn invalid_grant(description: impl Into<String>) -> Self {
        Self::from_code("invalid_grant", description)
    }

    /// `invalid_client`: client authentication failed.
    pub fn invalid_client(description: impl Into<String>) -> Self {
        Self::from_code("invalid_client", description)
    }

    /// `unsupported_grant_type`: the requested `grant_type` is not supported.
    pub fn unsupported_grant_type(description: impl Into<String>) -> Self {
        Self::from_code("unsupported_grant_type", description)
    }

    /// Attaches the client's `state` value so it is included in the serialized error.
    pub fn with_state(mut self, state: impl Into<String>) -> Self {
        self.state = Some(state.into());
        self
    }
}
/// Successful token endpoint response body (RFC 6749 §5.1).
///
/// Optional members are left out of the serialized form when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Token type; `TokenResponse::new` sets it to `Bearer`.
    pub token_type: String,
    /// Access token lifetime in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_in: Option<u64>,
    /// Refresh token, when one is issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Granted scopes as a single space-separated string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
impl TokenResponse {
pub fn new(access_token: impl Into<String>) -> Self {
Self {
access_token: access_token.into(),
token_type: "Bearer".to_string(),
expires_in: None,
refresh_token: None,
scope: None,
}
}
pub fn with_expires_in(mut self, seconds: u64) -> Self {
self.expires_in = Some(seconds);
self
}
pub fn with_refresh_token(mut self, token: impl Into<String>) -> Self {
self.refresh_token = Some(token.into());
self
}
pub fn with_scope(mut self, scopes: &[String]) -> Self {
if !scopes.is_empty() {
self.scope = Some(scopes.join(" "));
}
self
}
}
/// Token introspection response body (RFC 7662 §2.2).
///
/// Only `active` is always present; every other member is omitted when `None`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct IntrospectionResponse {
    /// Whether the token is currently active.
    pub active: bool,
    /// Space-separated scopes of the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Client the token was issued to.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,
    /// Subject of the token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sub: Option<String>,
    /// Expiry time. Presumably seconds since the Unix epoch (as RFC 7662 specifies);
    /// confirm with callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exp: Option<u64>,
    /// Issue time, in the same unit as `exp`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub iat: Option<u64>,
    /// Token type, e.g. `Bearer`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_type: Option<String>,
}
impl IntrospectionResponse {
pub fn inactive() -> Self {
Self {
active: false,
scope: None,
client_id: None,
sub: None,
exp: None,
iat: None,
token_type: None,
}
}
pub fn active(
subject: impl Into<String>,
client_id: impl Into<String>,
scopes: &[String],
expires_at: u64,
issued_at: u64,
) -> Self {
Self {
active: true,
scope: if scopes.is_empty() {
None
} else {
Some(scopes.join(" "))
},
client_id: Some(client_id.into()),
sub: Some(subject.into()),
exp: Some(expires_at),
iat: Some(issued_at),
token_type: Some("Bearer".to_string()),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const CALLBACK: &str = "https://app.example.com/callback";

    /// Turns string literals into the owned strings the API takes.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_public_client_requires_pkce() {
        let client = ClientConfig::public("my-client", strings(&[CALLBACK]));
        assert!(client.pkce_required);
        assert!(client.client_secret.is_none());
    }

    #[test]
    fn test_confidential_client_has_secret() {
        let client = ClientConfig::confidential("my-client", "secret123", strings(&[CALLBACK]));
        assert!(!client.pkce_required);
        assert!(client.client_secret.is_some());
    }

    #[test]
    fn test_redirect_uri_validation() {
        let registered = strings(&[CALLBACK, "https://app.example.com/callback2"]);
        let client = ClientConfig::public("my-client", registered.clone());
        for uri in &registered {
            assert!(client.is_redirect_uri_allowed(uri));
        }
        assert!(!client.is_redirect_uri_allowed("https://evil.com/callback"));
    }

    #[test]
    fn test_scope_filtering() {
        let client = ClientConfig::public("my-client", Vec::new())
            .with_allowed_scopes(strings(&["read", "write"]));
        let filtered = client.filter_scopes(&strings(&["read", "admin", "write"]));
        assert_eq!(filtered, vec!["read", "write"]);
    }

    #[test]
    fn test_oauth_error_types() {
        assert_eq!(
            OAuthError::invalid_request("Missing client_id").error,
            "invalid_request"
        );
        let err = OAuthError::invalid_grant("Code expired").with_state("state123");
        assert_eq!(err.error, "invalid_grant");
        assert_eq!(err.state.as_deref(), Some("state123"));
    }

    #[test]
    fn test_token_response() {
        let response = TokenResponse::new("access_token_123")
            .with_expires_in(3600)
            .with_refresh_token("refresh_token_456")
            .with_scope(&strings(&["read", "write"]));
        assert_eq!(response.access_token, "access_token_123");
        assert_eq!(response.token_type, "Bearer");
        assert_eq!(response.expires_in, Some(3600));
        assert_eq!(response.refresh_token.as_deref(), Some("refresh_token_456"));
        assert_eq!(response.scope.as_deref(), Some("read write"));
    }
}