auth_framework/protocols/
cas.rs1use crate::errors::{AuthError, Result};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25
26#[derive(Debug, Clone)]
30pub struct CasConfig {
31 pub server_url: String,
33
34 pub service_url: String,
36
37 pub protocol_version: CasProtocolVersion,
39
40 pub allow_proxy: bool,
42
43 pub timeout_secs: u64,
45
46 pub renew: bool,
48}
49
50#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
52pub enum CasProtocolVersion {
53 V1,
55 V2,
57 V3,
59}
60
61impl Default for CasConfig {
62 fn default() -> Self {
63 Self {
64 server_url: String::new(),
65 service_url: String::new(),
66 protocol_version: CasProtocolVersion::V3,
67 allow_proxy: false,
68 timeout_secs: 10,
69 renew: false,
70 }
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct CasValidationResult {
79 pub valid: bool,
81
82 pub user: Option<String>,
84
85 pub attributes: HashMap<String, Vec<String>>,
87
88 pub proxy_granting_ticket: Option<String>,
90
91 pub proxies: Vec<String>,
93
94 pub error_code: Option<String>,
96
97 pub error_message: Option<String>,
99}
100
101#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct CasSloRequest {
104 pub ticket: String,
106
107 pub session_id: Option<String>,
109
110 pub timestamp: String,
112}
113
114#[derive(Debug)]
118pub struct CasClient {
119 config: CasConfig,
120 http: reqwest::Client,
121}
122
123impl CasClient {
124 pub fn new(config: CasConfig) -> Result<Self> {
126 if config.server_url.is_empty() {
127 return Err(AuthError::config("CAS server URL must be set"));
128 }
129 if !config.server_url.starts_with("https://") {
130 return Err(AuthError::config("CAS server URL must use HTTPS"));
131 }
132 if config.service_url.is_empty() {
133 return Err(AuthError::config("CAS service URL must be set"));
134 }
135
136 let http = reqwest::Client::builder()
137 .timeout(std::time::Duration::from_secs(config.timeout_secs))
138 .build()
139 .map_err(|e| AuthError::internal(format!("Failed to build HTTP client: {e}")))?;
140
141 Ok(Self { config, http })
142 }
143
144 pub fn login_url(&self) -> String {
146 let mut url = format!(
147 "{}/login?service={}",
148 self.config.server_url,
149 urlencoding::encode(&self.config.service_url)
150 );
151 if self.config.renew {
152 url.push_str("&renew=true");
153 }
154 url
155 }
156
157 pub fn logout_url(&self, redirect_url: Option<&str>) -> String {
159 let mut url = format!("{}/logout", self.config.server_url);
160 if let Some(redirect) = redirect_url {
161 url.push_str(&format!("?service={}", urlencoding::encode(redirect)));
162 }
163 url
164 }
165
166 pub async fn validate_ticket(&self, ticket: &str) -> Result<CasValidationResult> {
168 match self.config.protocol_version {
169 CasProtocolVersion::V1 => self.validate_v1(ticket).await,
170 CasProtocolVersion::V2 | CasProtocolVersion::V3 => self.validate_v2_v3(ticket).await,
171 }
172 }
173
174 pub async fn validate_proxy_ticket(&self, ticket: &str) -> Result<CasValidationResult> {
176 if !self.config.allow_proxy {
177 return Err(AuthError::config("Proxy tickets are not allowed"));
178 }
179 self.validate_at_endpoint("/proxyValidate", ticket).await
180 }
181
182 async fn validate_v1(&self, ticket: &str) -> Result<CasValidationResult> {
184 let url = format!(
185 "{}/validate?service={}&ticket={}",
186 self.config.server_url,
187 urlencoding::encode(&self.config.service_url),
188 urlencoding::encode(ticket)
189 );
190
191 let resp =
192 self.http.get(&url).send().await.map_err(|e| {
193 AuthError::internal(format!("CAS v1 validation request failed: {e}"))
194 })?;
195
196 let body = resp
197 .text()
198 .await
199 .map_err(|e| AuthError::internal(format!("CAS v1 response read failed: {e}")))?;
200
201 let lines: Vec<&str> = body.trim().lines().collect();
203 if lines.first().map(|l| l.trim()) == Some("yes") {
204 Ok(CasValidationResult {
205 valid: true,
206 user: lines.get(1).map(|u| u.trim().to_string()),
207 attributes: HashMap::new(),
208 proxy_granting_ticket: None,
209 proxies: Vec::new(),
210 error_code: None,
211 error_message: None,
212 })
213 } else {
214 Ok(CasValidationResult {
215 valid: false,
216 user: None,
217 attributes: HashMap::new(),
218 proxy_granting_ticket: None,
219 proxies: Vec::new(),
220 error_code: Some("INVALID_TICKET".into()),
221 error_message: Some("CAS 1.0 validation failed".into()),
222 })
223 }
224 }
225
226 async fn validate_v2_v3(&self, ticket: &str) -> Result<CasValidationResult> {
228 let endpoint = match self.config.protocol_version {
229 CasProtocolVersion::V3 => "/p3/serviceValidate",
230 _ => "/serviceValidate",
231 };
232 self.validate_at_endpoint(endpoint, ticket).await
233 }
234
235 async fn validate_at_endpoint(
237 &self,
238 endpoint: &str,
239 ticket: &str,
240 ) -> Result<CasValidationResult> {
241 let url = format!(
242 "{}{}?service={}&ticket={}",
243 self.config.server_url,
244 endpoint,
245 urlencoding::encode(&self.config.service_url),
246 urlencoding::encode(ticket)
247 );
248
249 let resp = self
250 .http
251 .get(&url)
252 .send()
253 .await
254 .map_err(|e| AuthError::internal(format!("CAS validation request failed: {e}")))?;
255
256 if !resp.status().is_success() {
257 let status = resp.status();
258 return Err(AuthError::internal(format!(
259 "CAS validation HTTP error: {status}"
260 )));
261 }
262
263 let body = resp
264 .text()
265 .await
266 .map_err(|e| AuthError::internal(format!("CAS response read failed: {e}")))?;
267
268 parse_cas_xml_response(&body)
269 }
270
271 pub fn parse_slo_request(body: &str) -> Result<CasSloRequest> {
275 let ticket = extract_xml_value(body, "SessionIndex")
277 .ok_or_else(|| AuthError::validation("SLO request missing SessionIndex"))?;
278
279 let session_id = extract_xml_value(body, "NameID");
280 let timestamp = extract_xml_value(body, "IssueInstant")
281 .unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
282
283 Ok(CasSloRequest {
284 ticket,
285 session_id,
286 timestamp,
287 })
288 }
289}
290
291fn parse_cas_xml_response(xml: &str) -> Result<CasValidationResult> {
295 let has_success =
298 xml.contains("<cas:authenticationSuccess") || xml.contains("<authenticationSuccess");
299 let has_failure =
300 xml.contains("<cas:authenticationFailure") || xml.contains("<authenticationFailure");
301
302 if has_success {
303 let user = extract_xml_value(xml, "cas:user").or_else(|| extract_xml_value(xml, "user"));
304
305 let attributes = parse_cas_attributes(xml);
306
307 let pgt = extract_xml_value(xml, "cas:proxyGrantingTicket")
308 .or_else(|| extract_xml_value(xml, "proxyGrantingTicket"));
309
310 let proxies = extract_xml_list(xml, "cas:proxy");
311
312 Ok(CasValidationResult {
313 valid: true,
314 user,
315 attributes,
316 proxy_granting_ticket: pgt,
317 proxies,
318 error_code: None,
319 error_message: None,
320 })
321 } else if has_failure {
322 let error_code = extract_xml_attr(xml, "cas:authenticationFailure", "code")
323 .or_else(|| extract_xml_attr(xml, "authenticationFailure", "code"));
324 let error_message = extract_xml_inner(xml, "cas:authenticationFailure")
325 .or_else(|| extract_xml_inner(xml, "authenticationFailure"));
326
327 Ok(CasValidationResult {
328 valid: false,
329 user: None,
330 attributes: HashMap::new(),
331 proxy_granting_ticket: None,
332 proxies: Vec::new(),
333 error_code,
334 error_message,
335 })
336 } else {
337 Err(AuthError::validation("Unrecognized CAS response format"))
338 }
339}
340
341fn parse_cas_attributes(xml: &str) -> HashMap<String, Vec<String>> {
343 let mut attrs = HashMap::new();
344
345 let attr_block =
347 find_xml_block(xml, "cas:attributes").or_else(|| find_xml_block(xml, "attributes"));
348
349 if let Some(block) = attr_block {
350 let mut pos = 0;
352 while pos < block.len() {
353 if let Some(start) = block[pos..].find('<') {
354 let tag_start = pos + start + 1;
355 if let Some(end) = block[tag_start..].find('>') {
356 let tag_end = tag_start + end;
357 let tag = &block[tag_start..tag_end];
358
359 if tag.starts_with('/') || tag.starts_with('?') || tag.starts_with('!') {
361 pos = tag_end + 1;
362 continue;
363 }
364
365 let tag_name = tag.split_whitespace().next().unwrap_or(tag);
366 let close = format!("</{tag_name}>");
367 if let Some(close_pos) = block[tag_end + 1..].find(&close) {
368 let value = &block[tag_end + 1..tag_end + 1 + close_pos];
369 let short_name = tag_name
370 .strip_prefix("cas:")
371 .unwrap_or(tag_name)
372 .to_string();
373 attrs
374 .entry(short_name)
375 .or_insert_with(Vec::new)
376 .push(value.trim().to_string());
377 pos = tag_end + 1 + close_pos + close.len();
378 } else {
379 pos = tag_end + 1;
380 }
381 } else {
382 break;
383 }
384 } else {
385 break;
386 }
387 }
388 }
389
390 attrs
391}
392
393fn extract_xml_value(xml: &str, tag: &str) -> Option<String> {
395 let open = format!("<{tag}");
396 let close = format!("</{tag}>");
397
398 let start_pos = xml.find(&open)?;
399 let after_open = xml[start_pos + open.len()..].find('>')?;
400 let content_start = start_pos + open.len() + after_open + 1;
401 let content_end = xml[content_start..].find(&close)?;
402
403 Some(
404 xml[content_start..content_start + content_end]
405 .trim()
406 .to_string(),
407 )
408}
409
/// Returns the value of attribute `attr_name` on the first `<tag …>` element.
///
/// Matching rules:
///
/// - The tag name must be followed by whitespace, `/` or `>`, so `f` does not
///   match `<fx>`.
/// - The attribute name must be preceded by whitespace, so `code` does not
///   match `errorcode`.
/// - Both `"` and `'` quoting are accepted, with optional whitespace around
///   `=`.
///
/// Entities in the value are not decoded.
///
/// Returns `None` if the element or attribute is absent, the opening tag or
/// the quoted value is unterminated, or `attr_name` is empty.
fn extract_xml_attr(xml: &str, tag: &str, attr_name: &str) -> Option<String> {
    // An empty name would match at every position and never advance the scan.
    if attr_name.is_empty() {
        return None;
    }

    // Isolate the attribute section of the first genuine `<tag …>` opening tag.
    let open = format!("<{tag}");
    let mut search_from = 0;
    let attrs = loop {
        let start = search_from + xml[search_from..].find(&open)?;
        let name_end = start + open.len();
        let rest = &xml[name_end..];
        if matches!(rest.chars().next(), Some(c) if c == '>' || c == '/' || c.is_whitespace()) {
            break &rest[..rest.find('>')?];
        }
        search_from = name_end;
    };

    // Find `attr_name` at an attribute boundary, followed by `=` and a quoted value.
    let mut from = 0;
    while let Some(offset) = attrs[from..].find(attr_name) {
        let name_start = from + offset;
        let name_end = name_start + attr_name.len();
        from = name_end;

        if !attrs[..name_start].ends_with(char::is_whitespace) {
            continue;
        }
        let after_eq = match attrs[name_end..].trim_start().strip_prefix('=') {
            Some(rest) => rest.trim_start(),
            None => continue,
        };
        let quote = match after_eq.chars().next() {
            Some(q @ ('"' | '\'')) => q,
            _ => continue,
        };
        let value = &after_eq[1..];
        let end = value.find(quote)?;
        return Some(value[..end].to_string());
    }
    None
}
424
425fn extract_xml_inner(xml: &str, tag: &str) -> Option<String> {
427 let open = format!("<{tag}");
428 let close = format!("</{tag}>");
429
430 let start_pos = xml.find(&open)?;
431 let after_tag = xml[start_pos..].find('>')?;
432 let content_start = start_pos + after_tag + 1;
433 let content_end = xml[content_start..].find(&close)?;
434
435 Some(
436 xml[content_start..content_start + content_end]
437 .trim()
438 .to_string(),
439 )
440}
441
/// Returns the trimmed text of every attribute-less `<tag>…</tag>` element in
/// document order.
///
/// Only the exact opening form `<tag>` is recognised. Scanning stops at the
/// first opening tag without a matching `</tag>`.
fn extract_xml_list(xml: &str, tag: &str) -> Vec<String> {
    let opening = format!("<{tag}>");
    let closing = format!("</{tag}>");
    let mut found = Vec::new();
    let mut rest = xml;

    while let Some(idx) = rest.find(&opening) {
        let after_open = &rest[idx + opening.len()..];
        match after_open.find(&closing) {
            Some(len) => {
                found.push(after_open[..len].trim().to_string());
                rest = &after_open[len + closing.len()..];
            }
            None => break,
        }
    }

    found
}
461
/// Returns the raw (untrimmed) content of the first `<tag>` element in `xml`.
///
/// Matching rules:
///
/// - An occurrence of `<tag` counts only when it is followed by `>`, `/` or
///   whitespace, so looking up `cas:user` does not match `<cas:userName>`.
/// - A self-closing `<tag/>` yields an empty string.
/// - Otherwise the content runs up to the first following `</tag>`, so nested
///   elements with the same name are not supported.
///
/// Returns `None` if no opening tag matches, its `>` is missing, or no `</tag>`
/// follows it.
///
/// This is a lightweight scanner rather than an XML parser: entities are not
/// decoded, and `>` inside attribute values is not handled.
fn find_xml_block(xml: &str, tag: &str) -> Option<String> {
    let open = format!("<{tag}");
    let close = format!("</{tag}>");

    // Byte index of the `>` that ends the first genuine opening tag.
    let mut search_from = 0;
    let gt = loop {
        let start = search_from + xml[search_from..].find(&open)?;
        let name_end = start + open.len();
        let rest = &xml[name_end..];
        if matches!(rest.chars().next(), Some(c) if c == '>' || c == '/' || c.is_whitespace()) {
            break name_end + rest.find('>')?;
        }
        // A longer tag name sharing the prefix; keep looking after it.
        search_from = name_end;
    };

    if xml[..gt].ends_with('/') {
        // Self-closing element: present, but with no content.
        return Some(String::new());
    }

    let content_start = gt + 1;
    let content_len = xml[content_start..].find(&close)?;
    Some(xml[content_start..content_start + content_len].to_string())
}
474
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let config = CasConfig::default();
        assert_eq!(config.protocol_version, CasProtocolVersion::V3);
        assert!(!config.allow_proxy);
        assert!(!config.renew);
    }

    #[test]
    fn test_client_requires_https() {
        let config = CasConfig {
            server_url: "http://cas.example.com/cas".into(),
            service_url: "https://app.example.com/callback".into(),
            ..Default::default()
        };
        let err = CasClient::new(config).unwrap_err();
        assert!(err.to_string().contains("HTTPS"));
    }

    #[test]
    fn test_login_url() {
        let config = CasConfig {
            server_url: "https://cas.example.com/cas".into(),
            service_url: "https://app.example.com/callback".into(),
            ..Default::default()
        };
        let client = CasClient::new(config).unwrap();
        let url = client.login_url();
        assert!(url.starts_with("https://cas.example.com/cas/login?service="));
        assert!(url.contains("app.example.com"));
    }

    #[test]
    fn test_login_url_with_renew() {
        let config = CasConfig {
            server_url: "https://cas.example.com/cas".into(),
            service_url: "https://app.example.com/callback".into(),
            renew: true,
            ..Default::default()
        };
        let client = CasClient::new(config).unwrap();
        let url = client.login_url();
        assert!(url.contains("renew=true"));
    }

    #[test]
    fn test_logout_url() {
        let config = CasConfig {
            server_url: "https://cas.example.com/cas".into(),
            service_url: "https://app.example.com/callback".into(),
            ..Default::default()
        };
        let client = CasClient::new(config).unwrap();
        let url = client.logout_url(None);
        assert_eq!(url, "https://cas.example.com/cas/logout");

        let url_with_redirect = client.logout_url(Some("https://app.example.com"));
        assert!(url_with_redirect.contains("service="));
    }

    #[test]
    fn test_parse_success_response() {
        let xml = r#"
            <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
                <cas:authenticationSuccess>
                    <cas:user>jdoe</cas:user>
                    <cas:attributes>
                        <cas:email>jdoe@example.com</cas:email>
                        <cas:displayName>John Doe</cas:displayName>
                    </cas:attributes>
                </cas:authenticationSuccess>
            </cas:serviceResponse>
        "#;

        let result = parse_cas_xml_response(xml).unwrap();
        assert!(result.valid);
        assert_eq!(result.user.as_deref(), Some("jdoe"));
        assert!(result.attributes.contains_key("email"));
    }

    #[test]
    fn test_parse_failure_response() {
        let xml = r#"
            <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
                <cas:authenticationFailure code="INVALID_TICKET">
                    Ticket ST-12345 not recognized
                </cas:authenticationFailure>
            </cas:serviceResponse>
        "#;

        let result = parse_cas_xml_response(xml).unwrap();
        assert!(!result.valid);
        assert!(result.user.is_none());
        assert_eq!(result.error_code.as_deref(), Some("INVALID_TICKET"));
    }

    #[test]
    fn test_extract_xml_value() {
        let xml = "<root><user>alice</user></root>";
        assert_eq!(extract_xml_value(xml, "user"), Some("alice".into()));
    }

    #[test]
    fn test_slo_request_parsing() {
        // Replaces an `is_ok() || is_err()` assertion, which could never fail.
        let body = r#"
            <LogoutRequest>
                <NameID>jdoe</NameID>
                <SessionIndex>ST-12345</SessionIndex>
            </LogoutRequest>
        "#;

        let slo = CasClient::parse_slo_request(body).unwrap();
        assert_eq!(slo.ticket, "ST-12345");
        assert_eq!(slo.session_id.as_deref(), Some("jdoe"));
        assert!(!slo.timestamp.is_empty());
    }

    #[test]
    fn test_slo_request_requires_session_index() {
        let body = "<LogoutRequest><NameID>jdoe</NameID></LogoutRequest>";
        assert!(CasClient::parse_slo_request(body).is_err());
    }
}