use crate::errors::{AuthError, Result};
use html_escape;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Mode used to deliver an authorization response (stored as
/// `JarmResponseMode::delivery_mode`).
///
/// NOTE(review): the `Jwt*` variants presumably correspond to the JARM
/// response modes (`query.jwt`, `fragment.jwt`, `form_post.jwt`) and the
/// others to the plain OAuth `query` / `fragment` / `form_post` modes —
/// confirm against the code that consumes this enum.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ResponseMode {
Query,
Fragment,
FormPost,
JwtQuery,
JwtFragment,
JwtFormPost,
}
/// Parses, validates and answers OAuth `response_type` requests according to
/// a `MultipleResponseTypesConfig`.
#[derive(Debug, Clone)]
pub struct MultipleResponseTypesManager {
// Settings consulted when deciding which response types are supported.
config: MultipleResponseTypesConfig,
}
/// Settings describing which OAuth `response_type` values are accepted.
#[derive(Debug, Clone)]
pub struct MultipleResponseTypesConfig {
    /// Accepted `response_type` values; multi-valued entries are written as
    /// space-separated lists (e.g. `"code id_token"`).
    pub supported_response_types: Vec<String>,
    /// NOTE(review): presumably toggles whether a request may carry more than
    /// one response type at once — confirm against the manager's validation.
    pub enable_multiple_types: bool,
}

impl Default for MultipleResponseTypesConfig {
    /// Every single and combined form of `code`, `token` and `id_token`,
    /// with `enable_multiple_types` switched on.
    fn default() -> Self {
        const DEFAULT_RESPONSE_TYPES: [&str; 7] = [
            "code",
            "token",
            "id_token",
            "code token",
            "code id_token",
            "token id_token",
            "code token id_token",
        ];
        let supported_response_types = DEFAULT_RESPONSE_TYPES
            .iter()
            .map(|&entry| String::from(entry))
            .collect();
        Self {
            supported_response_types,
            enable_multiple_types: true,
        }
    }
}
/// Data for rendering an auto-submitting HTML form that POSTs `parameters`
/// to `redirect_uri` (see `generate_html_form`).
#[derive(Debug, Clone)]
pub struct FormPostResponseMode {
/// Target URL written into the form's `action` attribute.
pub redirect_uri: String,
/// Name/value pairs rendered as hidden `<input>` fields.
pub parameters: HashMap<String, String>,
}
/// A response token paired with the mode used to deliver it.
#[derive(Debug, Clone)]
pub struct JarmResponseMode {
/// Value emitted as the `response` parameter by `generate_response_parameters`.
pub response_token: String,
/// Delivery mode; not consulted by `generate_response_parameters`.
pub delivery_mode: ResponseMode,
}
impl MultipleResponseTypesManager {
    /// Creates a manager that validates requests against `config`.
    pub fn new(config: MultipleResponseTypesConfig) -> Self {
        Self { config }
    }

    /// Splits a whitespace-delimited `response_type` parameter into its
    /// individual values and validates them, returning the values in request
    /// order.
    ///
    /// # Errors
    ///
    /// Returns a validation error when any value is unsupported (see
    /// `is_supported_response_type`), or when the list is empty, has more than
    /// three entries, repeats a value, or has several entries while
    /// `enable_multiple_types` is `false`.
    pub fn parse_response_type(&self, response_type: &str) -> Result<Vec<String>> {
        let types: Vec<String> = response_type
            .split_whitespace()
            .map(String::from)
            .collect();
        for value in &types {
            if !self.is_supported_response_type(value) {
                return Err(AuthError::validation(format!(
                    "Unsupported response_type: {}",
                    value
                )));
            }
        }
        self.validate_response_type_combination(&types)?;
        Ok(types)
    }

    /// Reports whether `response_type` is one of the atomic values `code`,
    /// `token` or `id_token` AND appears as a whole component of at least one
    /// entry in `supported_response_types`.
    pub fn is_supported_response_type(&self, response_type: &str) -> bool {
        if !matches!(response_type, "code" | "token" | "id_token") {
            return false;
        }
        // Compare whole space-separated components, not substrings:
        // `"id_token"` contains the substring `"token"`, which must not make
        // `token` count as supported.
        self.config
            .supported_response_types
            .iter()
            .any(|entry| entry.split_whitespace().any(|part| part == response_type))
    }

    /// Checks that an already-parsed list of response types forms an
    /// acceptable request.
    ///
    /// NOTE(review): the combination as a whole is not matched against
    /// `supported_response_types`; only the individual values are checked by
    /// the caller.
    fn validate_response_type_combination(&self, types: &[String]) -> Result<()> {
        if types.is_empty() {
            return Err(AuthError::validation("Empty response_type"));
        }
        if types.len() > 3 {
            return Err(AuthError::validation("Too many response types"));
        }
        // The order of response_type values does not matter (RFC 6749
        // §3.1.1), so a repeated value carries no meaning and is rejected.
        // The list is at most three long here, so the quadratic scan is trivial.
        let has_duplicate = types
            .iter()
            .enumerate()
            .any(|(i, value)| types[..i].contains(value));
        if has_duplicate {
            return Err(AuthError::validation("Duplicate response type"));
        }
        if types.len() > 1 && !self.config.enable_multiple_types {
            return Err(AuthError::validation(
                "Multiple response types are not enabled",
            ));
        }
        Ok(())
    }

    /// Builds the response parameters for each requested response type.
    ///
    /// `code` adds `code`; `token` adds `access_token`, `token_type`
    /// (`Bearer`) and `expires_in`; `id_token` adds `id_token`. A type whose
    /// value was passed as `None` contributes nothing.
    ///
    /// # Errors
    ///
    /// Returns a validation error for any type other than `code`, `token` or
    /// `id_token`.
    pub async fn generate_response(
        &self,
        response_types: &[String],
        authorization_code: Option<String>,
        access_token: Option<String>,
        id_token: Option<String>,
    ) -> Result<HashMap<String, String>> {
        let mut response = HashMap::new();
        for response_type in response_types {
            match response_type.as_str() {
                "code" => {
                    if let Some(code) = &authorization_code {
                        response.insert("code".to_string(), code.clone());
                    }
                }
                "token" => {
                    if let Some(token) = &access_token {
                        response.insert("access_token".to_string(), token.clone());
                        response.insert("token_type".to_string(), "Bearer".to_string());
                        // NOTE(review): fixed at 3600 regardless of the
                        // token's real lifetime.
                        response.insert("expires_in".to_string(), "3600".to_string());
                    }
                }
                "id_token" => {
                    if let Some(token) = &id_token {
                        response.insert("id_token".to_string(), token.clone());
                    }
                }
                _ => {
                    return Err(AuthError::validation(format!(
                        "Unsupported response type: {}",
                        response_type
                    )));
                }
            }
        }
        Ok(response)
    }
}
impl FormPostResponseMode {
    /// Creates a form-post response that will submit `parameters` to `redirect_uri`.
    pub fn new(redirect_uri: String, parameters: HashMap<String, String>) -> Self {
        Self {
            redirect_uri,
            parameters,
        }
    }

    /// Renders an HTML page with a POST form targeting `redirect_uri`, one
    /// hidden input per parameter (sorted by name), and a script that submits
    /// the form on load.
    ///
    /// Every interpolated value is escaped for a double-quoted attribute
    /// context. Parameters such as `state` can be client-supplied, and
    /// escaping only `&`, `<` and `>` would leave `"` intact, letting a value
    /// break out of its attribute and inject markup or script.
    pub fn generate_html_form(&self) -> String {
        let mut form = format!(
            r#"<!DOCTYPE html>
<html>
<head>
<title>Authorization Response</title>
</head>
<body>
<form method="post" action="{}" id="response_form">
"#,
            html_escape::encode_double_quoted_attribute(&self.redirect_uri)
        );
        // Sort by name so the rendered page is deterministic; HashMap
        // iteration order is not.
        let mut params: Vec<(&String, &String)> = self.parameters.iter().collect();
        params.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (name, value) in params {
            form.push_str(&format!(
                r#" <input type="hidden" name="{}" value="{}" />
"#,
                html_escape::encode_double_quoted_attribute(name),
                html_escape::encode_double_quoted_attribute(value)
            ));
        }
        form.push_str(
            r#" </form>
<script>
window.onload = function() {
document.getElementById('response_form').submit();
};
</script>
</body>
</html>"#,
        );
        form
    }
}
impl JarmResponseMode {
    /// Pairs `response_token` with the mode it is meant to be delivered in.
    pub fn new(response_token: String, delivery_mode: ResponseMode) -> Self {
        Self {
            delivery_mode,
            response_token,
        }
    }

    /// Returns a map with a single `response` entry holding a copy of the
    /// token; `delivery_mode` does not influence the result.
    pub fn generate_response_parameters(&self) -> HashMap<String, String> {
        let entry = ("response".to_string(), self.response_token.clone());
        std::iter::once(entry).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager built from `MultipleResponseTypesConfig::default()`.
    fn default_manager() -> MultipleResponseTypesManager {
        MultipleResponseTypesManager::new(MultipleResponseTypesConfig::default())
    }

    #[test]
    fn test_multiple_response_types_parsing() {
        let mgr = default_manager();
        let single = mgr.parse_response_type("code").unwrap();
        assert_eq!(single, vec!["code"]);
        let hybrid = mgr.parse_response_type("code token").unwrap();
        assert_eq!(hybrid, vec!["code", "token"]);
        assert!(mgr.parse_response_type("invalid").is_err());
    }

    #[test]
    fn test_form_post_html_generation() {
        let callback = "https://client.example.com/callback";
        let params: HashMap<String, String> =
            [("code", "auth_code_123"), ("state", "client_state")]
                .iter()
                .map(|&(k, v)| (k.to_string(), v.to_string()))
                .collect();
        let html = FormPostResponseMode::new(callback.to_string(), params).generate_html_form();
        for needle in ["auth_code_123", "client_state", callback].iter() {
            assert!(html.contains(*needle));
        }
    }

    #[test]
    fn test_jarm_response_generation() {
        let token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9...";
        let params = JarmResponseMode::new(token.to_string(), ResponseMode::JwtQuery)
            .generate_response_parameters();
        let value = params
            .get("response")
            .expect("`response` parameter must be present");
        assert!(value.starts_with("eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9"));
    }
}