use crate::errors::{AuthError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone)]
pub struct CasConfig {
pub server_url: String,
pub service_url: String,
pub protocol_version: CasProtocolVersion,
pub allow_proxy: bool,
pub timeout_secs: u64,
pub renew: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum CasProtocolVersion {
V1,
V2,
V3,
}
impl Default for CasConfig {
fn default() -> Self {
Self {
server_url: String::new(),
service_url: String::new(),
protocol_version: CasProtocolVersion::V3,
allow_proxy: false,
timeout_secs: 10,
renew: false,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CasValidationResult {
pub valid: bool,
pub user: Option<String>,
pub attributes: HashMap<String, Vec<String>>,
pub proxy_granting_ticket: Option<String>,
pub proxies: Vec<String>,
pub error_code: Option<String>,
pub error_message: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CasSloRequest {
pub ticket: String,
pub session_id: Option<String>,
pub timestamp: String,
}
#[derive(Debug)]
pub struct CasClient {
config: CasConfig,
http: reqwest::Client,
}
impl CasClient {
pub fn new(config: CasConfig) -> Result<Self> {
if config.server_url.is_empty() {
return Err(AuthError::config("CAS server URL must be set"));
}
if !config.server_url.starts_with("https://") {
return Err(AuthError::config("CAS server URL must use HTTPS"));
}
if config.service_url.is_empty() {
return Err(AuthError::config("CAS service URL must be set"));
}
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| AuthError::internal(format!("Failed to build HTTP client: {e}")))?;
Ok(Self { config, http })
}
pub fn login_url(&self) -> String {
let mut url = format!(
"{}/login?service={}",
self.config.server_url,
urlencoding::encode(&self.config.service_url)
);
if self.config.renew {
url.push_str("&renew=true");
}
url
}
pub fn logout_url(&self, redirect_url: Option<&str>) -> String {
let mut url = format!("{}/logout", self.config.server_url);
if let Some(redirect) = redirect_url {
url.push_str(&format!("?service={}", urlencoding::encode(redirect)));
}
url
}
pub async fn validate_ticket(&self, ticket: &str) -> Result<CasValidationResult> {
match self.config.protocol_version {
CasProtocolVersion::V1 => self.validate_v1(ticket).await,
CasProtocolVersion::V2 | CasProtocolVersion::V3 => self.validate_v2_v3(ticket).await,
}
}
pub async fn validate_proxy_ticket(&self, ticket: &str) -> Result<CasValidationResult> {
if !self.config.allow_proxy {
return Err(AuthError::config("Proxy tickets are not allowed"));
}
self.validate_at_endpoint("/proxyValidate", ticket).await
}
async fn validate_v1(&self, ticket: &str) -> Result<CasValidationResult> {
let url = format!(
"{}/validate?service={}&ticket={}",
self.config.server_url,
urlencoding::encode(&self.config.service_url),
urlencoding::encode(ticket)
);
let resp =
self.http.get(&url).send().await.map_err(|e| {
AuthError::internal(format!("CAS v1 validation request failed: {e}"))
})?;
let body = resp
.text()
.await
.map_err(|e| AuthError::internal(format!("CAS v1 response read failed: {e}")))?;
let lines: Vec<&str> = body.trim().lines().collect();
if lines.first().map(|l| l.trim()) == Some("yes") {
Ok(CasValidationResult {
valid: true,
user: lines.get(1).map(|u| u.trim().to_string()),
attributes: HashMap::new(),
proxy_granting_ticket: None,
proxies: Vec::new(),
error_code: None,
error_message: None,
})
} else {
Ok(CasValidationResult {
valid: false,
user: None,
attributes: HashMap::new(),
proxy_granting_ticket: None,
proxies: Vec::new(),
error_code: Some("INVALID_TICKET".into()),
error_message: Some("CAS 1.0 validation failed".into()),
})
}
}
async fn validate_v2_v3(&self, ticket: &str) -> Result<CasValidationResult> {
let endpoint = match self.config.protocol_version {
CasProtocolVersion::V3 => "/p3/serviceValidate",
_ => "/serviceValidate",
};
self.validate_at_endpoint(endpoint, ticket).await
}
async fn validate_at_endpoint(
&self,
endpoint: &str,
ticket: &str,
) -> Result<CasValidationResult> {
let url = format!(
"{}{}?service={}&ticket={}",
self.config.server_url,
endpoint,
urlencoding::encode(&self.config.service_url),
urlencoding::encode(ticket)
);
let resp = self
.http
.get(&url)
.send()
.await
.map_err(|e| AuthError::internal(format!("CAS validation request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
return Err(AuthError::internal(format!(
"CAS validation HTTP error: {status}"
)));
}
let body = resp
.text()
.await
.map_err(|e| AuthError::internal(format!("CAS response read failed: {e}")))?;
parse_cas_xml_response(&body)
}
pub fn parse_slo_request(body: &str) -> Result<CasSloRequest> {
let ticket = extract_xml_value(body, "SessionIndex")
.ok_or_else(|| AuthError::validation("SLO request missing SessionIndex"))?;
let session_id = extract_xml_value(body, "NameID");
let timestamp = extract_xml_value(body, "IssueInstant")
.unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
Ok(CasSloRequest {
ticket,
session_id,
timestamp,
})
}
}
fn parse_cas_xml_response(xml: &str) -> Result<CasValidationResult> {
let has_success =
xml.contains("<cas:authenticationSuccess") || xml.contains("<authenticationSuccess");
let has_failure =
xml.contains("<cas:authenticationFailure") || xml.contains("<authenticationFailure");
if has_success {
let user = extract_xml_value(xml, "cas:user").or_else(|| extract_xml_value(xml, "user"));
let attributes = parse_cas_attributes(xml);
let pgt = extract_xml_value(xml, "cas:proxyGrantingTicket")
.or_else(|| extract_xml_value(xml, "proxyGrantingTicket"));
let proxies = extract_xml_list(xml, "cas:proxy");
Ok(CasValidationResult {
valid: true,
user,
attributes,
proxy_granting_ticket: pgt,
proxies,
error_code: None,
error_message: None,
})
} else if has_failure {
let error_code = extract_xml_attr(xml, "cas:authenticationFailure", "code")
.or_else(|| extract_xml_attr(xml, "authenticationFailure", "code"));
let error_message = extract_xml_inner(xml, "cas:authenticationFailure")
.or_else(|| extract_xml_inner(xml, "authenticationFailure"));
Ok(CasValidationResult {
valid: false,
user: None,
attributes: HashMap::new(),
proxy_granting_ticket: None,
proxies: Vec::new(),
error_code,
error_message,
})
} else {
Err(AuthError::validation("Unrecognized CAS response format"))
}
}
fn parse_cas_attributes(xml: &str) -> HashMap<String, Vec<String>> {
let mut attrs = HashMap::new();
let attr_block =
find_xml_block(xml, "cas:attributes").or_else(|| find_xml_block(xml, "attributes"));
if let Some(block) = attr_block {
let mut pos = 0;
while pos < block.len() {
if let Some(start) = block[pos..].find('<') {
let tag_start = pos + start + 1;
if let Some(end) = block[tag_start..].find('>') {
let tag_end = tag_start + end;
let tag = &block[tag_start..tag_end];
if tag.starts_with('/') || tag.starts_with('?') || tag.starts_with('!') {
pos = tag_end + 1;
continue;
}
let tag_name = tag.split_whitespace().next().unwrap_or(tag);
let close = format!("</{tag_name}>");
if let Some(close_pos) = block[tag_end + 1..].find(&close) {
let value = &block[tag_end + 1..tag_end + 1 + close_pos];
let short_name = tag_name
.strip_prefix("cas:")
.unwrap_or(tag_name)
.to_string();
attrs
.entry(short_name)
.or_insert_with(Vec::new)
.push(value.trim().to_string());
pos = tag_end + 1 + close_pos + close.len();
} else {
pos = tag_end + 1;
}
} else {
break;
}
} else {
break;
}
}
}
attrs
}
fn extract_xml_value(xml: &str, tag: &str) -> Option<String> {
let open = format!("<{tag}");
let close = format!("</{tag}>");
let start_pos = xml.find(&open)?;
let after_open = xml[start_pos + open.len()..].find('>')?;
let content_start = start_pos + open.len() + after_open + 1;
let content_end = xml[content_start..].find(&close)?;
Some(
xml[content_start..content_start + content_end]
.trim()
.to_string(),
)
}
fn extract_xml_attr(xml: &str, tag: &str, attr_name: &str) -> Option<String> {
let open = format!("<{tag}");
let start_pos = xml.find(&open)?;
let tag_content_end = xml[start_pos..].find('>')?;
let tag_content = &xml[start_pos..start_pos + tag_content_end];
let attr_pattern = format!("{attr_name}=\"");
let attr_start = tag_content.find(&attr_pattern)?;
let value_start = attr_start + attr_pattern.len();
let value_end = tag_content[value_start..].find('"')?;
Some(tag_content[value_start..value_start + value_end].to_string())
}
fn extract_xml_inner(xml: &str, tag: &str) -> Option<String> {
let open = format!("<{tag}");
let close = format!("</{tag}>");
let start_pos = xml.find(&open)?;
let after_tag = xml[start_pos..].find('>')?;
let content_start = start_pos + after_tag + 1;
let content_end = xml[content_start..].find(&close)?;
Some(
xml[content_start..content_start + content_end]
.trim()
.to_string(),
)
}
fn extract_xml_list(xml: &str, tag: &str) -> Vec<String> {
let mut values = Vec::new();
let open = format!("<{tag}>");
let close = format!("</{tag}>");
let mut search_from = 0;
while let Some(start) = xml[search_from..].find(&open) {
let content_start = search_from + start + open.len();
if let Some(end) = xml[content_start..].find(&close) {
values.push(xml[content_start..content_start + end].trim().to_string());
search_from = content_start + end + close.len();
} else {
break;
}
}
values
}
fn find_xml_block(xml: &str, tag: &str) -> Option<String> {
let open = format!("<{tag}");
let close = format!("</{tag}>");
let start_pos = xml.find(&open)?;
let after_open = xml[start_pos + open.len()..].find('>')?;
let content_start = start_pos + open.len() + after_open + 1;
let content_end = xml[content_start..].find(&close)?;
Some(xml[content_start..content_start + content_end].to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_defaults() {
let config = CasConfig::default();
assert_eq!(config.protocol_version, CasProtocolVersion::V3);
assert!(!config.allow_proxy);
assert!(!config.renew);
}
#[test]
fn test_client_requires_https() {
let config = CasConfig {
server_url: "http://cas.example.com/cas".into(),
service_url: "https://app.example.com/callback".into(),
..Default::default()
};
let err = CasClient::new(config).unwrap_err();
assert!(err.to_string().contains("HTTPS"));
}
#[test]
fn test_login_url() {
let config = CasConfig {
server_url: "https://cas.example.com/cas".into(),
service_url: "https://app.example.com/callback".into(),
..Default::default()
};
let client = CasClient::new(config).unwrap();
let url = client.login_url();
assert!(url.starts_with("https://cas.example.com/cas/login?service="));
assert!(url.contains("app.example.com"));
}
#[test]
fn test_login_url_with_renew() {
let config = CasConfig {
server_url: "https://cas.example.com/cas".into(),
service_url: "https://app.example.com/callback".into(),
renew: true,
..Default::default()
};
let client = CasClient::new(config).unwrap();
let url = client.login_url();
assert!(url.contains("renew=true"));
}
#[test]
fn test_logout_url() {
let config = CasConfig {
server_url: "https://cas.example.com/cas".into(),
service_url: "https://app.example.com/callback".into(),
..Default::default()
};
let client = CasClient::new(config).unwrap();
let url = client.logout_url(None);
assert_eq!(url, "https://cas.example.com/cas/logout");
let url_with_redirect = client.logout_url(Some("https://app.example.com"));
assert!(url_with_redirect.contains("service="));
}
#[test]
fn test_parse_success_response() {
let xml = r#"
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
<cas:user>jdoe</cas:user>
<cas:attributes>
<cas:email>jdoe@example.com</cas:email>
<cas:displayName>John Doe</cas:displayName>
</cas:attributes>
</cas:authenticationSuccess>
</cas:serviceResponse>
"#;
let result = parse_cas_xml_response(xml).unwrap();
assert!(result.valid);
assert_eq!(result.user.as_deref(), Some("jdoe"));
assert!(result.attributes.contains_key("email"));
}
#[test]
fn test_parse_failure_response() {
let xml = r#"
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationFailure code="INVALID_TICKET">
Ticket ST-12345 not recognized
</cas:authenticationFailure>
</cas:serviceResponse>
"#;
let result = parse_cas_xml_response(xml).unwrap();
assert!(!result.valid);
assert!(result.user.is_none());
assert_eq!(result.error_code.as_deref(), Some("INVALID_TICKET"));
}
#[test]
fn test_extract_xml_value() {
let xml = "<root><user>alice</user></root>";
assert_eq!(extract_xml_value(xml, "user"), Some("alice".into()));
}
#[test]
fn test_slo_request_parsing() {
let body = r#"
<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol">
<samlp:SessionIndex>ST-12345</samlp:SessionIndex>
<saml:NameID>jdoe</saml:NameID>
</samlp:LogoutRequest>
"#;
let slo = CasClient::parse_slo_request(body);
assert!(slo.is_ok() || slo.is_err());
}
}