use std::collections::HashSet;
use serde::{Deserialize, Serialize};
/// Server-side OAuth settings for this protected resource.
///
/// `Default` yields empty strings, an empty scope list and
/// `require_auth == false`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
/// URI identifying this resource; copied into
/// `ProtectedResourceMetadata::resource` by `from_config`.
pub resource_uri: String,
/// Authorization server URL; becomes the single entry of
/// `ProtectedResourceMetadata::authorization_servers`.
pub authorization_server: String,
/// Scope values advertised in the protected-resource metadata.
pub scopes_supported: Vec<String>,
/// Whether requests must carry a token.
/// NOTE(review): not read anywhere in this file — enforcement presumably
/// lives in the caller; confirm.
pub require_auth: bool,
}
/// PKCE `code_challenge_method` values supported by this module.
///
/// Only `S256` is represented. It corresponds to the transform in
/// `compute_code_challenge` (SHA-256 digest, unpadded base64url). With the
/// default serde representation the variant serializes as the string `"S256"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum CodeChallengeMethod {
/// SHA-256 based challenge.
S256,
}
/// Generates a fresh random PKCE code verifier.
///
/// 32 random bytes from `rand_bytes` are rendered as 64 lowercase hex
/// characters. Hex digits are a subset of the RFC 7636 "unreserved" set, and
/// 64 lies inside the required 43..=128 length range.
///
/// # Panics
///
/// Panics if `rand_bytes` cannot obtain randomness from `getrandom`.
#[must_use]
pub fn generate_code_verifier() -> String {
    use std::fmt::Write;

    let bytes: [u8; 32] = rand_bytes();
    // Each byte becomes exactly two hex digits, so this is the final length
    // and the string never has to grow.
    let mut verifier = String::with_capacity(bytes.len() * 2);
    for b in &bytes {
        // Writing into a `String` cannot fail; the `fmt::Result` is always Ok.
        let _ = write!(verifier, "{b:02x}");
    }
    verifier
}
/// Derives the PKCE `S256` code challenge for `verifier`.
///
/// The challenge is the SHA-256 digest of the verifier's UTF-8 bytes,
/// encoded as unpadded base64url.
#[must_use]
pub fn compute_code_challenge(verifier: &str) -> String {
    use sha2::{Digest, Sha256};

    let digest = Sha256::digest(verifier.as_bytes());
    let digest_bytes: &[u8] = &digest[..];
    base64_url_encode(digest_bytes)
}
/// Checks whether `verifier` hashes to `challenge` under the `S256` method.
///
/// The recomputed challenge is compared with `constant_time_eq`. For
/// equal-length inputs, that comparison does not stop at the first
/// mismatching byte.
#[must_use]
pub fn verify_pkce(verifier: &str, challenge: &str) -> bool {
    let expected = compute_code_challenge(verifier);
    let (lhs, rhs) = (expected.as_bytes(), challenge.as_bytes());
    constant_time_eq(lhs, rhs)
}
/// Claims carried by an access token, as consumed by the checks in
/// `impl TokenClaims`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TokenClaims {
/// Subject the token was issued to.
pub sub: String,
/// Granted scopes; `has_scope` does an exact membership test.
pub scopes: HashSet<String>,
/// Expiry in whole seconds since the Unix epoch; `is_expired` compares it
/// against the current system time.
pub exp: u64,
/// Resource the token is bound to. `None` is accepted for any resource
/// by `valid_for_resource`.
pub resource: Option<String>,
}
impl TokenClaims {
    /// Returns `true` if `scope` is among the granted scopes
    /// (exact, case-sensitive match).
    #[must_use]
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scopes.contains(scope)
    }

    /// Returns `true` once the current system time (whole seconds since the
    /// Unix epoch) has reached `exp`.
    ///
    /// If the system clock reports a time before the Unix epoch, "now" cannot
    /// be trusted. In that case the token is treated as expired (fail closed)
    /// rather than compared against a fabricated time of `0`, which would
    /// accept every token with a non-zero `exp`.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        use std::time::{SystemTime, UNIX_EPOCH};

        match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(now) => now.as_secs() >= self.exp,
            Err(_) => true,
        }
    }

    /// Returns `true` if the token may be used against `resource_uri`.
    ///
    /// A token bound to a resource is accepted only for that exact URI
    /// string. A token without a `resource` binding is accepted for any
    /// resource.
    #[must_use]
    pub fn valid_for_resource(&self, resource_uri: &str) -> bool {
        match self.resource.as_deref() {
            Some(bound) => bound == resource_uri,
            None => true,
        }
    }
}
/// Outcome of validating a bearer token, as returned by
/// `TokenValidator::validate_token`.
///
/// NOTE(review): none of these variants are constructed in this file. The
/// per-variant meanings below are inferred from their names and should be
/// confirmed against the `TokenValidator` implementations.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum TokenValidation {
/// Token accepted; carries the claims of the accepted token.
Valid(TokenClaims),
/// Presumably: the token's expiry has passed (cf. `TokenClaims::is_expired`).
Expired,
/// Presumably: the token lacks the scope named in `required`.
InsufficientScope { required: String },
/// Presumably: the token is bound to resource `got`, but `expected` was required.
WrongResource { expected: String, got: String },
/// Presumably: no token was supplied.
Missing,
/// Token rejected for another reason; the string describes why.
Invalid(String),
}
/// Protected-resource metadata document advertised to clients
/// (presumably the RFC 9728 `oauth-protected-resource` document — confirm).
///
/// Build it from server settings with `ProtectedResourceMetadata::from_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ProtectedResourceMetadata {
/// Identifier of the protected resource.
pub resource: String,
/// Authorization servers listed for this resource.
pub authorization_servers: Vec<String>,
/// Scopes the resource advertises; an absent field deserializes as empty.
#[serde(default)]
pub scopes_supported: Vec<String>,
}
impl ProtectedResourceMetadata {
#[must_use]
pub fn from_config(config: &OAuthConfig) -> Self {
Self {
resource: config.resource_uri.clone(),
authorization_servers: vec![config.authorization_server.clone()],
scopes_supported: config.scopes_supported.clone(),
}
}
}
/// Builds the `WWW-Authenticate` challenge that points clients at the
/// protected-resource metadata document.
///
/// `metadata_url` is emitted as an HTTP quoted-string:
/// - `"` and `\` are backslash-escaped;
/// - control characters other than tab are dropped.
///
/// A malformed value therefore cannot end the parameter early or inject a
/// new header line. Well-formed URLs contain none of these characters and
/// pass through unchanged.
#[must_use]
pub fn www_authenticate_header(metadata_url: &str) -> String {
    format!("Bearer resource_metadata={}", quoted_string(metadata_url))
}

/// Builds the `WWW-Authenticate` challenge for a request whose token lacks
/// `required_scope` (an `insufficient_scope` error).
///
/// `required_scope` is quoted with the same escaping as
/// `www_authenticate_header`. Valid scope tokens never need escaping, so
/// normal output is unchanged.
#[must_use]
pub fn insufficient_scope_header(required_scope: &str) -> String {
    format!(
        r#"Bearer error="insufficient_scope", scope={}"#,
        quoted_string(required_scope)
    )
}

/// Renders `value` as an HTTP `quoted-string`, including the surrounding
/// double quotes.
///
/// `"` and `\` become quoted-pairs. Other control characters except HTAB
/// cannot appear in a quoted-string, so they are removed.
fn quoted_string(value: &str) -> String {
    let mut out = String::with_capacity(value.len() + 2);
    out.push('"');
    for ch in value.chars() {
        match ch {
            '"' | '\\' => {
                out.push('\\');
                out.push(ch);
            }
            '\t' => out.push(ch),
            c if c.is_control() => {}
            c => out.push(c),
        }
    }
    out.push('"');
    out
}
/// Strategy for turning a raw token string into a `TokenValidation` verdict.
///
/// The `Send + Sync` supertraits allow one validator to be shared across
/// threads.
/// NOTE(review): whether `token` includes the `Bearer ` prefix is not
/// established here; confirm with callers.
pub trait TokenValidator: Send + Sync {
/// Validates `token` and reports the outcome.
fn validate_token(&self, token: &str) -> TokenValidation;
}
/// OAuth client metadata. Field names match RFC 7591 client metadata;
/// presumably that is the intended schema — confirm.
///
/// Every field except `client_id` is `#[serde(default)]`. Such fields may be
/// absent when deserializing and then become an empty `Vec` / `None`.
///
/// NOTE(review): `None` fields serialize as JSON `null` rather than being
/// omitted. Consider `skip_serializing_if` if a peer rejects nulls.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ClientMetadata {
/// Client identifier (the tests use a URL to a hosted metadata document).
pub client_id: String,
/// Redirect URIs registered for the client.
#[serde(default)]
pub redirect_uris: Vec<String>,
/// Human-readable client name.
#[serde(default)]
pub client_name: Option<String>,
/// Grant types, e.g. `authorization_code` (as in the tests).
#[serde(default)]
pub grant_types: Vec<String>,
/// Response types, e.g. `code` (as in the tests).
#[serde(default)]
pub response_types: Vec<String>,
/// Token endpoint authentication method, e.g. `none` (as in the tests).
#[serde(default)]
pub token_endpoint_auth_method: Option<String>,
}
/// Returns an array of `N` bytes filled by `getrandom`.
///
/// # Panics
///
/// Panics with "getrandom failed" if `getrandom` reports an error.
fn rand_bytes<const N: usize>() -> [u8; N] {
    let mut out = [0_u8; N];
    getrandom::getrandom(out.as_mut_slice()).expect("getrandom failed");
    out
}
fn base64_url_encode(data: &[u8]) -> String {
use base64::Engine;
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
}
/// Compares two byte slices without stopping at the first mismatch.
///
/// Slices of different lengths are rejected immediately, so the length is
/// not treated as secret. For equal lengths, every byte pair is visited and
/// the XOR differences are OR-ed together. The result is `true` only when
/// every difference is zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        diff == 0
    } else {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Characters allowed in a PKCE code verifier (RFC 7636 "unreserved").
    fn is_unreserved(c: char) -> bool {
        c.is_ascii_alphanumeric() || matches!(c, '-' | '.' | '_' | '~')
    }

    #[test]
    fn pkce_roundtrip() {
        let verifier = generate_code_verifier();
        // RFC 7636 §4.1: 43..=128 characters from the unreserved set.
        assert!((43..=128).contains(&verifier.len()));
        assert!(verifier.chars().all(is_unreserved));
        let challenge = compute_code_challenge(&verifier);
        assert!(verify_pkce(&verifier, &challenge));
    }

    #[test]
    fn pkce_verifiers_are_unique() {
        assert_ne!(generate_code_verifier(), generate_code_verifier());
    }

    #[test]
    fn pkce_challenge_is_unpadded_base64url_sha256() {
        let challenge = compute_code_challenge("any-verifier");
        // A 32-byte SHA-256 digest is 43 base64url characters without padding.
        assert_eq!(challenge.len(), 43);
        assert!(challenge
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'));
    }

    #[test]
    fn pkce_wrong_verifier_fails() {
        let verifier = generate_code_verifier();
        let challenge = compute_code_challenge(&verifier);
        assert!(!verify_pkce("wrong-verifier", &challenge));
    }

    #[test]
    fn constant_time_eq_behaviour() {
        assert!(constant_time_eq(b"abc", b"abc"));
        assert!(!constant_time_eq(b"abc", b"abd"));
        assert!(!constant_time_eq(b"abc", b"abcd"));
        assert!(constant_time_eq(b"", b""));
    }

    #[test]
    fn token_claims_scope_check() {
        let claims = TokenClaims {
            sub: "client-1".into(),
            scopes: ["read".into(), "write".into()].into_iter().collect(),
            exp: u64::MAX,
            resource: None,
        };
        assert!(claims.has_scope("read"));
        assert!(claims.has_scope("write"));
        assert!(!claims.has_scope("admin"));
    }

    #[test]
    fn token_claims_expired() {
        let claims = TokenClaims {
            sub: "client-1".into(),
            scopes: HashSet::new(),
            exp: 0,
            resource: None,
        };
        assert!(claims.is_expired());
    }

    #[test]
    fn token_claims_not_expired() {
        let claims = TokenClaims {
            sub: "client-1".into(),
            scopes: HashSet::new(),
            exp: u64::MAX,
            resource: None,
        };
        assert!(!claims.is_expired());
    }

    #[test]
    fn token_resource_validation() {
        let claims = TokenClaims {
            sub: "client-1".into(),
            scopes: HashSet::new(),
            exp: u64::MAX,
            resource: Some("https://mcp.example.com/mcp".into()),
        };
        assert!(claims.valid_for_resource("https://mcp.example.com/mcp"));
        assert!(!claims.valid_for_resource("https://other.example.com/mcp"));
    }

    #[test]
    fn token_no_resource_valid_for_any() {
        let claims = TokenClaims {
            sub: "client-1".into(),
            scopes: HashSet::new(),
            exp: u64::MAX,
            resource: None,
        };
        assert!(claims.valid_for_resource("https://anything.com"));
    }

    #[test]
    fn protected_resource_metadata_from_config() {
        let config = OAuthConfig {
            resource_uri: "https://mcp.example.com/mcp".into(),
            authorization_server: "https://auth.example.com".into(),
            scopes_supported: vec!["read".into(), "write".into()],
            require_auth: true,
        };
        let meta = ProtectedResourceMetadata::from_config(&config);
        assert_eq!(meta.resource, "https://mcp.example.com/mcp");
        assert_eq!(meta.authorization_servers, vec!["https://auth.example.com".to_string()]);
        assert_eq!(meta.scopes_supported, vec!["read".to_string(), "write".to_string()]);
    }

    #[test]
    fn protected_resource_metadata_scopes_default_empty() {
        let meta: ProtectedResourceMetadata = serde_json::from_str(
            r#"{"resource":"https://mcp.example.com/mcp","authorization_servers":["https://auth.example.com"]}"#,
        )
        .unwrap();
        assert!(meta.scopes_supported.is_empty());
    }

    #[test]
    fn www_authenticate_header_format() {
        let header =
            www_authenticate_header("https://mcp.example.com/.well-known/oauth-protected-resource");
        assert_eq!(
            header,
            r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#
        );
    }

    #[test]
    fn insufficient_scope_header_format() {
        let header = insufficient_scope_header("admin");
        assert_eq!(header, r#"Bearer error="insufficient_scope", scope="admin""#);
    }

    #[test]
    fn oauth_config_default() {
        let config = OAuthConfig::default();
        assert!(!config.require_auth);
        assert!(config.resource_uri.is_empty());
        assert!(config.authorization_server.is_empty());
        assert!(config.scopes_supported.is_empty());
    }

    #[test]
    fn client_metadata_serde() {
        let meta = ClientMetadata {
            client_id: "https://app.example.com/oauth/client-metadata.json".into(),
            redirect_uris: vec!["https://app.example.com/callback".into()],
            client_name: Some("Test App".into()),
            grant_types: vec!["authorization_code".into()],
            response_types: vec!["code".into()],
            token_endpoint_auth_method: Some("none".into()),
        };
        let json = serde_json::to_string(&meta).unwrap();
        let back: ClientMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(back.client_id, meta.client_id);
        assert_eq!(back.redirect_uris, meta.redirect_uris);
        assert_eq!(back.client_name, meta.client_name);
        assert_eq!(back.token_endpoint_auth_method, meta.token_endpoint_auth_method);
    }

    #[test]
    fn client_metadata_optional_fields_default() {
        let meta: ClientMetadata = serde_json::from_str(r#"{"client_id":"abc"}"#).unwrap();
        assert_eq!(meta.client_id, "abc");
        assert!(meta.redirect_uris.is_empty());
        assert!(meta.client_name.is_none());
        assert!(meta.grant_types.is_empty());
        assert!(meta.response_types.is_empty());
        assert!(meta.token_endpoint_auth_method.is_none());
    }
}