use std::fmt;
/// Error codes that may appear in the `error` attribute of a Bearer
/// `WWW-Authenticate` challenge (the codes listed in RFC 6750, section 3.1).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OAuth2Error {
    /// Wire code `invalid_request` — per RFC 6750, the request itself is malformed.
    InvalidRequest,
    /// Wire code `invalid_token` — per RFC 6750, the access token is expired,
    /// revoked, malformed, or otherwise invalid.
    InvalidToken,
    /// Wire code `insufficient_scope` — per RFC 6750, the token lacks the
    /// privileges the request needs.
    InsufficientScope,
}

impl OAuth2Error {
    /// Returns the exact lowercase code written into the `error="..."` attribute.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::InvalidRequest => "invalid_request",
            Self::InvalidToken => "invalid_token",
            Self::InsufficientScope => "insufficient_scope",
        }
    }
}

impl fmt::Display for OAuth2Error {
    /// Displays the error as its wire code (identical to [`OAuth2Error::as_str`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
/// A Bearer `WWW-Authenticate` challenge (RFC 6750, section 3) that can tell a
/// client which additional scopes it must obtain (incremental consent).
///
/// Create one with [`IncrementalConsentChallenge::builder`] and render it with
/// [`IncrementalConsentChallenge::to_header_value`].
#[derive(Debug, Clone)]
pub struct IncrementalConsentChallenge {
    /// Value of the `realm` attribute, if any.
    realm: Option<String>,
    /// Scope tokens, rendered space-separated in the `scope` attribute.
    scopes: Vec<String>,
    /// Error code rendered in the `error` attribute.
    error: Option<OAuth2Error>,
    /// Human-readable text rendered in the `error_description` attribute.
    error_description: Option<String>,
    /// URI rendered in the `error_uri` attribute.
    error_uri: Option<String>,
    /// Extra `(name, value)` auth-params, rendered after the standard ones in
    /// insertion order.
    additional_params: Vec<(String, String)>,
}

impl IncrementalConsentChallenge {
    /// Returns a fresh, empty builder.
    pub fn builder() -> IncrementalConsentChallengeBuilder {
        IncrementalConsentChallengeBuilder::default()
    }

    /// Renders the challenge as a `WWW-Authenticate` header value.
    ///
    /// The output is `Bearer`, optionally followed by a space and a
    /// comma-separated list of `name="value"` auth-params in this order:
    /// `realm`, `scope`, `error`, `error_description`, `error_uri`, then the
    /// additional params. Quoted values go through `escape_param_value`; the
    /// `error` code is a fixed ASCII string and is written as-is.
    ///
    /// To keep the header well-formed:
    /// - empty scope strings are skipped (they would create empty scope
    ///   tokens), and `scope` is omitted entirely if none remain;
    /// - an additional param is omitted if its name is not a valid RFC 7230
    ///   `token` (it would break the header grammar or allow injection), or
    ///   if its name — compared ASCII case-insensitively — was already
    ///   emitted, since each auth-param may appear only once per challenge.
    ///   The first occurrence wins.
    ///
    /// When no params remain, the result is just `"Bearer"`.
    pub fn to_header_value(&self) -> String {
        // (name, already-escaped value) in emission order.
        let mut params: Vec<(&str, String)> = Vec::new();

        if let Some(realm) = &self.realm {
            params.push(("realm", escape_param_value(realm)));
        }

        let scope_value = self
            .scopes
            .iter()
            .map(String::as_str)
            .filter(|scope| !scope.is_empty())
            .collect::<Vec<&str>>()
            .join(" ");
        if !scope_value.is_empty() {
            params.push(("scope", escape_param_value(&scope_value)));
        }

        if let Some(error) = self.error {
            params.push(("error", error.as_str().to_string()));
        }
        if let Some(desc) = &self.error_description {
            params.push(("error_description", escape_param_value(desc)));
        }
        if let Some(uri) = &self.error_uri {
            params.push(("error_uri", escape_param_value(uri)));
        }

        for (key, value) in &self.additional_params {
            let already_emitted = params
                .iter()
                .any(|(name, _)| name.eq_ignore_ascii_case(key));
            if Self::is_token(key) && !already_emitted {
                params.push((key.as_str(), escape_param_value(value)));
            }
        }

        if params.is_empty() {
            return "Bearer".to_string();
        }
        let rendered: Vec<String> = params
            .iter()
            .map(|(name, value)| format!("{}=\"{}\"", name, value))
            .collect();
        format!("Bearer {}", rendered.join(", "))
    }

    /// Returns the `realm` attribute, if set.
    pub fn realm(&self) -> Option<&str> {
        self.realm.as_deref()
    }

    /// Returns the scopes exactly as they were supplied to the builder.
    pub fn scopes(&self) -> &[String] {
        &self.scopes
    }

    /// Returns the error code, if set.
    pub fn error(&self) -> Option<OAuth2Error> {
        self.error
    }

    /// Returns the unescaped error description, if set.
    pub fn error_description(&self) -> Option<&str> {
        self.error_description.as_deref()
    }

    /// Returns the unescaped error URI, if set.
    pub fn error_uri(&self) -> Option<&str> {
        self.error_uri.as_deref()
    }

    /// Returns the extra `(name, value)` params exactly as they were supplied
    /// to the builder, including any that `to_header_value` would omit.
    pub fn additional_params(&self) -> &[(String, String)] {
        &self.additional_params
    }

    /// Returns true if `name` is a non-empty RFC 7230 `token`
    /// (`1*tchar`: ALPHA / DIGIT / one of ``!#$%&'*+-.^_`|~``).
    fn is_token(name: &str) -> bool {
        !name.is_empty()
            && name
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b"!#$%&'*+-.^_`|~".contains(&b))
    }
}
/// Fluent builder for [`IncrementalConsentChallenge`].
///
/// Every setter takes the builder by value and returns the updated builder,
/// so calls are meant to be chained (or the result reassigned). The type is
/// `#[must_use]`, so discarding a returned builder produces a compiler warning
/// instead of silently losing the setting.
#[derive(Debug, Default)]
#[must_use = "builder methods consume the builder and return the updated one"]
pub struct IncrementalConsentChallengeBuilder {
    realm: Option<String>,
    scopes: Vec<String>,
    error: Option<OAuth2Error>,
    error_description: Option<String>,
    error_uri: Option<String>,
    additional_params: Vec<(String, String)>,
}

impl IncrementalConsentChallengeBuilder {
    /// Sets the `realm` attribute, replacing any previous value.
    pub fn realm(mut self, realm: impl Into<String>) -> Self {
        self.realm = Some(realm.into());
        self
    }

    /// Appends one scope token after any scopes added so far.
    pub fn scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Replaces every previously added scope with `scopes`.
    pub fn scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }

    /// Appends one scope token; identical to [`Self::scope`].
    pub fn add_scope(self, scope: impl Into<String>) -> Self {
        self.scope(scope)
    }

    /// Sets the `error` code, replacing any previous value.
    pub fn error(mut self, error: OAuth2Error) -> Self {
        self.error = Some(error);
        self
    }

    /// Sets the `error_description` text, replacing any previous value.
    pub fn error_description(mut self, description: impl Into<String>) -> Self {
        self.error_description = Some(description.into());
        self
    }

    /// Sets the `error_uri`, replacing any previous value.
    pub fn error_uri(mut self, uri: impl Into<String>) -> Self {
        self.error_uri = Some(uri.into());
        self
    }

    /// Appends an extra `key`/`value` auth-param; params keep insertion order.
    pub fn additional_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.additional_params.push((key.into(), value.into()));
        self
    }

    /// Consumes the builder and returns the finished challenge.
    #[must_use]
    pub fn build(self) -> IncrementalConsentChallenge {
        IncrementalConsentChallenge {
            realm: self.realm,
            scopes: self.scopes,
            error: self.error,
            error_description: self.error_description,
            error_uri: self.error_uri,
            additional_params: self.additional_params,
        }
    }
}
/// Escapes `value` for use inside a double-quoted auth-param value
/// (the caller supplies the surrounding quotes).
///
/// - `\` and `"` are backslash-escaped (RFC 7230 `quoted-pair`).
/// - HTAB is kept as-is (allowed in `qdtext`).
/// - Every other ASCII control character (including CR and LF) is replaced
///   with a single space. RFC 7230 forbids these inside a `quoted-string`,
///   even escaped, and passing CR/LF through would allow header injection.
/// - All other characters, including non-ASCII, are copied unchanged.
fn escape_param_value(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' | '"' => {
                escaped.push('\\');
                escaped.push(ch);
            }
            '\t' => escaped.push('\t'),
            c if c.is_ascii_control() => escaped.push(' '),
            c => escaped.push(c),
        }
    }
    escaped
}
/// Convenience constructor for an `insufficient_scope` challenge.
///
/// Sets `realm`, replaces the scope list with `required_scopes`, and sets the
/// error to [`OAuth2Error::InsufficientScope`]. `error_description` is set
/// only when `description` is `Some`.
pub fn insufficient_scope_challenge(
    realm: impl Into<String>,
    required_scopes: Vec<String>,
    description: Option<&str>,
) -> IncrementalConsentChallenge {
    let base = IncrementalConsentChallenge::builder()
        .realm(realm)
        .scopes(required_scopes)
        .error(OAuth2Error::InsufficientScope);
    let configured = match description {
        Some(text) => base.error_description(text),
        None => base,
    };
    configured.build()
}
/// Convenience constructor for an `invalid_token` challenge.
///
/// Sets `realm` and the error [`OAuth2Error::InvalidToken`]. `error_description`
/// is set only when `description` is `Some`; no scopes are added.
pub fn invalid_token_challenge(
    realm: impl Into<String>,
    description: Option<&str>,
) -> IncrementalConsentChallenge {
    let base = IncrementalConsentChallenge::builder()
        .realm(realm)
        .error(OAuth2Error::InvalidToken);
    let finished = if let Some(text) = description {
        base.error_description(text)
    } else {
        base
    };
    finished.build()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Finishes `builder` and renders the resulting challenge header.
    fn render(builder: IncrementalConsentChallengeBuilder) -> String {
        builder.build().to_header_value()
    }

    /// Asserts that every fragment in `needles` occurs somewhere in `header`.
    fn assert_contains_all(header: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                header.contains(*needle),
                "expected `{}` in `{}`",
                needle,
                header
            );
        }
    }

    #[test]
    fn test_oauth2_error_codes() {
        let table = [
            (OAuth2Error::InvalidRequest, "invalid_request"),
            (OAuth2Error::InvalidToken, "invalid_token"),
            (OAuth2Error::InsufficientScope, "insufficient_scope"),
        ];
        for (code, wire) in table.iter() {
            assert_eq!(code.as_str(), *wire);
        }
    }

    #[test]
    fn test_basic_challenge() {
        let header = render(IncrementalConsentChallenge::builder().realm("Example Realm"));
        assert_eq!(header, "Bearer realm=\"Example Realm\"");
    }

    #[test]
    fn test_challenge_with_scopes() {
        let header = render(
            IncrementalConsentChallenge::builder()
                .realm("API")
                .scopes(vec!["read".to_string(), "write".to_string()]),
        );
        assert_contains_all(&header, &["realm=\"API\"", "scope=\"read write\""]);
    }

    #[test]
    fn test_insufficient_scope_challenge() {
        let header = render(
            IncrementalConsentChallenge::builder()
                .realm("MCP Server")
                .scopes(vec!["write:tools".to_string()])
                .error(OAuth2Error::InsufficientScope)
                .error_description("Additional permissions required"),
        );
        assert_contains_all(
            &header,
            &[
                "realm=\"MCP Server\"",
                "scope=\"write:tools\"",
                "error=\"insufficient_scope\"",
                "error_description=\"Additional permissions required\"",
            ],
        );
    }

    #[test]
    fn test_challenge_with_error_uri() {
        let header = render(
            IncrementalConsentChallenge::builder()
                .realm("API")
                .error(OAuth2Error::InvalidToken)
                .error_uri("https://docs.example.com/errors/invalid_token"),
        );
        assert_contains_all(
            &header,
            &[
                "error=\"invalid_token\"",
                "error_uri=\"https://docs.example.com/errors/invalid_token\"",
            ],
        );
    }

    #[test]
    fn test_escape_param_value() {
        let raw = "Hello \"World\" with \\backslash";
        let expected = "Hello \\\"World\\\" with \\\\backslash";
        assert_eq!(escape_param_value(raw), expected);
    }

    #[test]
    fn test_challenge_with_special_characters() {
        let header = render(
            IncrementalConsentChallenge::builder()
                .realm("Realm with \"quotes\"")
                .error_description("Error with \\backslash"),
        );
        assert_contains_all(
            &header,
            &[
                "realm=\"Realm with \\\"quotes\\\"\"",
                "error_description=\"Error with \\\\backslash\"",
            ],
        );
    }

    #[test]
    fn test_insufficient_scope_helper() {
        let challenge = insufficient_scope_challenge(
            "MCP Tools",
            vec!["write:tools".to_string(), "read:resources".to_string()],
            Some("Tool execution requires additional scopes"),
        );
        let expected_scopes = ["write:tools", "read:resources"];
        assert_eq!(challenge.realm(), Some("MCP Tools"));
        assert_eq!(challenge.scopes(), &expected_scopes);
        assert_eq!(challenge.error(), Some(OAuth2Error::InsufficientScope));
        assert_eq!(
            challenge.error_description(),
            Some("Tool execution requires additional scopes")
        );
    }

    #[test]
    fn test_invalid_token_helper() {
        let challenge = invalid_token_challenge("MCP API", Some("Token has expired"));
        assert_eq!(challenge.realm(), Some("MCP API"));
        assert_eq!(challenge.error(), Some(OAuth2Error::InvalidToken));
        assert_eq!(challenge.error_description(), Some("Token has expired"));
    }

    #[test]
    fn test_additional_params() {
        let header = render(
            IncrementalConsentChallenge::builder()
                .realm("API")
                .additional_param("custom_param", "custom_value"),
        );
        assert_contains_all(&header, &["custom_param=\"custom_value\""]);
    }

    #[test]
    fn test_bearer_only() {
        assert_eq!(render(IncrementalConsentChallenge::builder()), "Bearer");
    }

    #[test]
    fn test_multiple_scopes_ordering() {
        let builder = ["read", "write", "delete"]
            .iter()
            .fold(IncrementalConsentChallenge::builder(), |acc, s| acc.scope(*s));
        assert_contains_all(&render(builder), &["scope=\"read write delete\""]);
    }
}