use std::fmt;
/// OAuth bearer-token failures that can be rendered as an HTTP status code
/// (`status_code`) and as a `Bearer` `WWW-Authenticate` challenge
/// (`www_authenticate`).
///
/// Only `InsufficientScope` maps to status `403`; every other variant maps
/// to `401`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum OAuthError {
/// Displayed as "missing bearer token". The challenge carries no `error`
/// parameter for this variant.
MissingToken,
/// Reported in the challenge as `error="invalid_token"`.
InvalidToken {
/// Free-form detail; emitted as the challenge's `error_description`
/// and appended to the `Display` message.
description: String,
},
/// Reported in the challenge as `error="insufficient_scope"`.
InsufficientScope {
/// Scopes emitted (space-separated) as the challenge's `scope`
/// parameter; the parameter is omitted when this list is empty.
required: Vec<String>,
/// Scopes shown only in the `Display` message; not part of the
/// challenge header.
provided: Vec<String>,
},
/// Reported in the challenge as `error="invalid_token"` with a fixed
/// description stating that the token audience does not match.
InvalidAudience,
/// Reported in the challenge as `error="invalid_token"` with a fixed
/// description stating that the access token has expired.
ExpiredToken,
}
impl OAuthError {
    /// Returns the HTTP status code to send for this error.
    ///
    /// [`OAuthError::InsufficientScope`] maps to `403`; every other variant
    /// maps to `401`.
    pub fn status_code(&self) -> u16 {
        match self {
            OAuthError::InsufficientScope { .. } => 403,
            // Listed explicitly (no `_` arm) so that adding a variant forces a
            // deliberate choice of status code here.
            OAuthError::MissingToken
            | OAuthError::InvalidToken { .. }
            | OAuthError::InvalidAudience
            | OAuthError::ExpiredToken => 401,
        }
    }

    /// Builds the value of the `WWW-Authenticate` response header.
    ///
    /// The result always starts with the `Bearer` scheme. When
    /// `resource_metadata_url` is `Some`, a `resource_metadata` parameter is
    /// emitted first, followed by the variant-specific `error`,
    /// `error_description` and `scope` parameters. Parameters are separated
    /// by `", "`. A [`OAuthError::MissingToken`] without a metadata URL yields
    /// the bare string `"Bearer"`.
    ///
    /// Every parameter value is written as a quoted-string: `"` and `\` are
    /// backslash-escaped and control characters (including CR/LF) are replaced
    /// by a space, so caller-supplied text cannot terminate the quoted value
    /// or inject additional header lines.
    pub fn www_authenticate(&self, resource_metadata_url: Option<&str>) -> String {
        let mut params: Vec<String> = Vec::new();
        if let Some(url) = resource_metadata_url {
            params.push(Self::quoted_param("resource_metadata", url));
        }

        match self {
            // A missing token is reported without any error parameters.
            OAuthError::MissingToken => {}
            OAuthError::InvalidToken { description } => {
                params.push(Self::quoted_param("error", "invalid_token"));
                params.push(Self::quoted_param("error_description", description));
            }
            OAuthError::InsufficientScope { required, .. } => {
                params.push(Self::quoted_param("error", "insufficient_scope"));
                if !required.is_empty() {
                    params.push(Self::quoted_param("scope", &required.join(" ")));
                }
            }
            OAuthError::InvalidAudience => {
                params.push(Self::quoted_param("error", "invalid_token"));
                params.push(Self::quoted_param(
                    "error_description",
                    "The token audience does not match this resource",
                ));
            }
            OAuthError::ExpiredToken => {
                params.push(Self::quoted_param("error", "invalid_token"));
                params.push(Self::quoted_param(
                    "error_description",
                    "The access token has expired",
                ));
            }
        }

        if params.is_empty() {
            "Bearer".to_string()
        } else {
            format!("Bearer {}", params.join(", "))
        }
    }

    /// Formats `name="value"` with `value` made safe for a quoted-string.
    fn quoted_param(name: &str, value: &str) -> String {
        let mut out = String::with_capacity(name.len() + value.len() + 3);
        out.push_str(name);
        out.push_str("=\"");
        for ch in value.chars() {
            match ch {
                // Quoted-pair escaping keeps the value inside its quotes.
                '"' | '\\' => {
                    out.push('\\');
                    out.push(ch);
                }
                // Control characters (notably CR/LF) must never reach a header.
                c if c.is_control() => out.push(' '),
                c => out.push(c),
            }
        }
        out.push('"');
        out
    }
}
/// Short, single-line, human-readable description of each error.
impl fmt::Display for OAuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Variants that carry data are formatted and returned immediately;
        // the rest resolve to a fixed message that is written at the end.
        let message = match self {
            OAuthError::InvalidToken { description } => {
                return write!(f, "invalid token: {}", description);
            }
            OAuthError::InsufficientScope { required, provided } => {
                let needed = required.join(", ");
                let granted = provided.join(", ");
                return write!(
                    f,
                    "insufficient scope: required [{}], provided [{}]",
                    needed, granted
                );
            }
            OAuthError::MissingToken => "missing bearer token",
            OAuthError::InvalidAudience => "token audience does not match",
            OAuthError::ExpiredToken => "token has expired",
        };
        f.write_str(message)
    }
}
// Marks `OAuthError` as a standard error type. The body is empty, so every
// `std::error::Error` method keeps its default implementation.
impl std::error::Error for OAuthError {}
/// Unit tests for the status code, `WWW-Authenticate` header and `Display`
/// output of [`OAuthError`].
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_missing_token_no_metadata() {
        let err = OAuthError::MissingToken;
        assert_eq!(err.status_code(), 401);
        assert_eq!(err.www_authenticate(None), "Bearer");
    }

    #[test]
    fn test_missing_token_with_metadata() {
        let err = OAuthError::MissingToken;
        assert_eq!(err.status_code(), 401);
        let header = err.www_authenticate(Some(
            "https://example.com/.well-known/oauth-protected-resource",
        ));
        assert!(header.starts_with("Bearer "));
        assert!(header.contains("resource_metadata="));
    }

    #[test]
    fn test_invalid_token() {
        let err = OAuthError::InvalidToken {
            description: "signature mismatch".to_string(),
        };
        assert_eq!(err.status_code(), 401);
        let header = err.www_authenticate(None);
        assert!(header.contains("error=\"invalid_token\""));
        assert!(header.contains("error_description=\"signature mismatch\""));
    }

    #[test]
    fn test_insufficient_scope() {
        let err = OAuthError::InsufficientScope {
            required: vec!["mcp:admin".to_string()],
            provided: vec!["mcp:read".to_string()],
        };
        assert_eq!(err.status_code(), 403);
        let header = err.www_authenticate(None);
        assert!(header.contains("error=\"insufficient_scope\""));
        assert!(header.contains("scope=\"mcp:admin\""));
    }

    /// With no required scopes the `scope` parameter must be left out.
    #[test]
    fn test_insufficient_scope_without_required_scopes() {
        let err = OAuthError::InsufficientScope {
            required: Vec::new(),
            provided: vec!["mcp:read".to_string()],
        };
        assert_eq!(err.status_code(), 403);
        assert_eq!(
            err.www_authenticate(None),
            "Bearer error=\"insufficient_scope\""
        );
    }

    #[test]
    fn test_invalid_audience() {
        let err = OAuthError::InvalidAudience;
        assert_eq!(err.status_code(), 401);
        let header = err.www_authenticate(None);
        assert!(header.contains("error=\"invalid_token\""));
        assert!(header.contains("audience"));
    }

    #[test]
    fn test_expired_token() {
        let err = OAuthError::ExpiredToken;
        assert_eq!(err.status_code(), 401);
        let header = err.www_authenticate(None);
        assert!(header.contains("error=\"invalid_token\""));
        assert!(header.contains("expired"));
    }

    /// The metadata parameter comes first and is followed by the error
    /// parameters, all separated by `", "`.
    #[test]
    fn test_error_with_metadata() {
        let header = OAuthError::ExpiredToken.www_authenticate(Some("https://example.com/meta"));
        assert_eq!(
            header,
            "Bearer resource_metadata=\"https://example.com/meta\", \
             error=\"invalid_token\", \
             error_description=\"The access token has expired\""
        );
    }

    #[test]
    fn test_display() {
        assert_eq!(OAuthError::MissingToken.to_string(), "missing bearer token");
        assert_eq!(OAuthError::ExpiredToken.to_string(), "token has expired");
        assert_eq!(
            OAuthError::InvalidAudience.to_string(),
            "token audience does not match"
        );
        assert_eq!(
            OAuthError::InvalidToken {
                description: "signature mismatch".to_string(),
            }
            .to_string(),
            "invalid token: signature mismatch"
        );
        assert_eq!(
            OAuthError::InsufficientScope {
                required: vec!["mcp:admin".to_string()],
                provided: vec!["mcp:read".to_string()],
            }
            .to_string(),
            "insufficient scope: required [mcp:admin], provided [mcp:read]"
        );
    }
}