1use crate::address::Pattern;
42use crate::{Error, Result};
43use std::collections::HashMap;
44use std::fmt;
45use std::str::FromStr;
46use std::sync::RwLock;
47use std::time::{Duration, SystemTime, UNIX_EPOCH};
48
/// Permission level that a scope grants or that an operation requires.
///
/// The levels form a hierarchy (see [`Action::allows`]): `Admin` covers every
/// level, `Write` covers `Write` and `Read`, and `Read` covers only `Read`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Action {
    /// Read access.
    Read,
    /// Write access; also satisfies `Read` requests.
    Write,
    /// Administrative access; satisfies every request.
    Admin,
}

impl Action {
    /// Returns `true` when holding `self` is sufficient for a request that
    /// needs `other`.
    pub fn allows(&self, other: Action) -> bool {
        // One arm per granted level: (granted, requested).
        matches!(
            (*self, other),
            (Action::Admin, _)
                | (Action::Write, Action::Write | Action::Read)
                | (Action::Read, Action::Read)
        )
    }
}
70
71impl fmt::Display for Action {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Action::Read => write!(f, "read"),
75 Action::Write => write!(f, "write"),
76 Action::Admin => write!(f, "admin"),
77 }
78 }
79}
80
81impl FromStr for Action {
82 type Err = Error;
83
84 fn from_str(s: &str) -> Result<Self> {
85 match s.to_lowercase().as_str() {
86 "read" | "r" => Ok(Action::Read),
87 "write" | "w" => Ok(Action::Write),
88 "admin" | "a" | "*" => Ok(Action::Admin),
89 _ => Err(Error::InvalidPattern(format!("unknown action: {}", s))),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96pub struct Scope {
97 action: Action,
98 pattern: Pattern,
99 raw: String,
100}
101
102impl Scope {
103 pub fn new(action: Action, pattern_str: &str) -> Result<Self> {
105 let pattern = Pattern::compile(pattern_str)?;
106 Ok(Self {
107 action,
108 pattern,
109 raw: format!("{}:{}", action, pattern_str),
110 })
111 }
112
113 pub fn parse(s: &str) -> Result<Self> {
115 let parts: Vec<&str> = s.splitn(2, ':').collect();
116 if parts.len() != 2 {
117 return Err(Error::InvalidPattern(format!(
118 "scope must be in format 'action:pattern', got: {}",
119 s
120 )));
121 }
122
123 let action = Action::from_str(parts[0])?;
124 let pattern = Pattern::compile(parts[1])?;
125
126 Ok(Self {
127 action,
128 pattern,
129 raw: s.to_string(),
130 })
131 }
132
133 pub fn allows(&self, action: Action, address: &str) -> bool {
135 self.action.allows(action) && self.pattern.matches(address)
136 }
137
138 pub fn action(&self) -> Action {
140 self.action
141 }
142
143 pub fn pattern(&self) -> &Pattern {
145 &self.pattern
146 }
147
148 pub fn as_str(&self) -> &str {
150 &self.raw
151 }
152}
153
154impl fmt::Display for Scope {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 write!(f, "{}", self.raw)
157 }
158}
159
160impl FromStr for Scope {
161 type Err = Error;
162
163 fn from_str(s: &str) -> Result<Self> {
164 Scope::parse(s)
165 }
166}
167
168#[derive(Debug, Clone)]
170pub struct TokenInfo {
171 pub token_id: String,
173 pub subject: Option<String>,
175 pub scopes: Vec<Scope>,
177 pub expires_at: Option<SystemTime>,
179 pub metadata: HashMap<String, String>,
181}
182
183impl TokenInfo {
184 pub fn new(token_id: String, scopes: Vec<Scope>) -> Self {
186 Self {
187 token_id,
188 subject: None,
189 scopes,
190 expires_at: None,
191 metadata: HashMap::new(),
192 }
193 }
194
195 pub fn is_expired(&self) -> bool {
197 if let Some(expires_at) = self.expires_at {
198 SystemTime::now() > expires_at
199 } else {
200 false
201 }
202 }
203
204 pub fn has_scope(&self, action: Action, address: &str) -> bool {
206 self.scopes
207 .iter()
208 .any(|scope| scope.allows(action, address))
209 }
210
211 pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
213 self.subject = Some(subject.into());
214 self
215 }
216
217 pub fn with_expires_at(mut self, expires_at: SystemTime) -> Self {
219 self.expires_at = Some(expires_at);
220 self
221 }
222
223 pub fn with_expires_in(mut self, duration: Duration) -> Self {
225 self.expires_at = Some(SystemTime::now() + duration);
226 self
227 }
228
229 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
231 self.metadata.insert(key.into(), value.into());
232 self
233 }
234}
235
/// Outcome of checking a token with a [`TokenValidator`].
#[derive(Debug)]
pub enum ValidationResult {
    /// The token was accepted; carries the [`TokenInfo`] associated with it.
    Valid(TokenInfo),
    /// The validator does not handle this token. `ValidatorChain::validate`
    /// reacts to this by moving on to the next validator.
    NotMyToken,
    /// The token was rejected; the string describes the reason
    /// (for example `"token not found"`).
    Invalid(String),
    /// The token is known to the validator but its expiry has passed.
    Expired,
}
248
/// Interface for a component that can check one kind of token.
///
/// Implementors must be `Send + Sync` and implement `Any`.
pub trait TokenValidator: Send + Sync + std::any::Any {
    /// Checks `token` and reports the outcome as a [`ValidationResult`].
    fn validate(&self, token: &str) -> ValidationResult;

    /// Returns a name for this validator (the CPSK validator in this file
    /// returns `"CPSK"`).
    fn name(&self) -> &str;

    /// Returns the validator as `&dyn Any`.
    // NOTE(review): presumably so callers can downcast to the concrete
    // validator type — confirm against callers.
    fn as_any(&self) -> &dyn std::any::Any;
}
260
261pub struct CpskValidator {
266 tokens: RwLock<HashMap<String, TokenInfo>>,
267}
268
269impl CpskValidator {
270 pub const PREFIX: &'static str = "cpsk_";
272
273 pub fn new() -> Self {
275 Self {
276 tokens: RwLock::new(HashMap::new()),
277 }
278 }
279
280 pub fn register(&self, token: String, info: TokenInfo) {
282 self.tokens.write().unwrap().insert(token, info);
283 }
284
285 pub fn revoke(&self, token: &str) -> bool {
287 self.tokens.write().unwrap().remove(token).is_some()
288 }
289
290 pub fn exists(&self, token: &str) -> bool {
292 self.tokens.read().unwrap().contains_key(token)
293 }
294
295 pub fn len(&self) -> usize {
297 self.tokens.read().unwrap().len()
298 }
299
300 pub fn is_empty(&self) -> bool {
302 self.tokens.read().unwrap().is_empty()
303 }
304
305 pub fn list_tokens(&self) -> Vec<String> {
307 self.tokens.read().unwrap().keys().cloned().collect()
308 }
309
310 pub fn generate_token() -> String {
312 let uuid = uuid::Uuid::new_v4();
313 format!("{}{}", Self::PREFIX, uuid.as_simple())
314 }
315}
316
317impl Default for CpskValidator {
318 fn default() -> Self {
319 Self::new()
320 }
321}
322
323impl TokenValidator for CpskValidator {
324 fn validate(&self, token: &str) -> ValidationResult {
325 if !token.starts_with(Self::PREFIX) {
327 return ValidationResult::NotMyToken;
328 }
329
330 let tokens = self.tokens.read().unwrap();
332 match tokens.get(token) {
333 Some(info) => {
334 if info.is_expired() {
335 ValidationResult::Expired
336 } else {
337 ValidationResult::Valid(info.clone())
338 }
339 }
340 None => ValidationResult::Invalid("token not found".to_string()),
341 }
342 }
343
344 fn name(&self) -> &str {
345 "CPSK"
346 }
347
348 fn as_any(&self) -> &dyn std::any::Any {
349 self
350 }
351}
352
353pub struct ValidatorChain {
355 validators: Vec<Box<dyn TokenValidator>>,
356}
357
358impl ValidatorChain {
359 pub fn new() -> Self {
361 Self {
362 validators: Vec::new(),
363 }
364 }
365
366 pub fn add<V: TokenValidator + 'static>(&mut self, validator: V) {
368 self.validators.push(Box::new(validator));
369 }
370
371 pub fn with<V: TokenValidator + 'static>(mut self, validator: V) -> Self {
373 self.add(validator);
374 self
375 }
376
377 pub fn validate(&self, token: &str) -> ValidationResult {
379 for validator in &self.validators {
380 match validator.validate(token) {
381 ValidationResult::NotMyToken => continue,
382 result => return result,
383 }
384 }
385 ValidationResult::Invalid("no validator accepted the token".to_string())
386 }
387
388 pub fn len(&self) -> usize {
390 self.validators.len()
391 }
392
393 pub fn is_empty(&self) -> bool {
395 self.validators.is_empty()
396 }
397}
398
399impl Default for ValidatorChain {
400 fn default() -> Self {
401 Self::new()
402 }
403}
404
/// Security mode setting; `Open` is the default.
///
/// The `FromStr` implementation in this file accepts `open`/`none`/`off` for
/// [`SecurityMode::Open`] and `authenticated`/`auth`/`token` for
/// [`SecurityMode::Authenticated`].
// NOTE(review): how each mode is enforced is not visible in this file;
// presumably `Authenticated` requires a valid token — confirm against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SecurityMode {
    /// Default mode; displayed as `open`.
    #[default]
    Open,
    /// Displayed as `authenticated`.
    Authenticated,
}
414
415impl fmt::Display for SecurityMode {
416 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
417 match self {
418 SecurityMode::Open => write!(f, "open"),
419 SecurityMode::Authenticated => write!(f, "authenticated"),
420 }
421 }
422}
423
424impl FromStr for SecurityMode {
425 type Err = Error;
426
427 fn from_str(s: &str) -> Result<Self> {
428 match s.to_lowercase().as_str() {
429 "open" | "none" | "off" => Ok(SecurityMode::Open),
430 "authenticated" | "auth" | "token" => Ok(SecurityMode::Authenticated),
431 _ => Err(Error::InvalidPattern(format!(
432 "unknown security mode: {}",
433 s
434 ))),
435 }
436 }
437}
438
439pub fn parse_scopes(s: &str) -> Result<Vec<Scope>> {
441 s.split(',').map(|part| Scope::parse(part.trim())).collect()
442}
443
444pub fn parse_duration(s: &str) -> Result<Duration> {
446 let s = s.trim();
447 if s.is_empty() {
448 return Err(Error::InvalidPattern("empty duration".to_string()));
449 }
450
451 let (num_str, unit) = if s.ends_with('d') {
452 (&s[..s.len() - 1], "d")
453 } else if s.ends_with('h') {
454 (&s[..s.len() - 1], "h")
455 } else if s.ends_with('m') {
456 (&s[..s.len() - 1], "m")
457 } else if s.ends_with('s') {
458 (&s[..s.len() - 1], "s")
459 } else {
460 (s, "s")
462 };
463
464 let num: u64 = num_str
465 .parse()
466 .map_err(|_| Error::InvalidPattern(format!("invalid duration number: {}", num_str)))?;
467
468 let secs = match unit {
469 "d" => num * 86400,
470 "h" => num * 3600,
471 "m" => num * 60,
472 "s" => num,
473 _ => unreachable!(),
474 };
475
476 Ok(Duration::from_secs(secs))
477}
478
/// Converts a [`SystemTime`] to whole seconds since the Unix epoch.
///
/// Sub-second precision is dropped. Times before the epoch yield `0`.
pub fn to_unix_timestamp(time: SystemTime) -> u64 {
    match time.duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        // `duration_since` fails when `time` is earlier than the epoch.
        Err(_) => 0,
    }
}
485
/// Converts whole seconds since the Unix epoch into a [`SystemTime`].
///
/// # Panics
///
/// The `SystemTime + Duration` addition may panic when the resulting instant
/// cannot be represented on the platform (extremely large `ts`).
pub fn from_unix_timestamp(ts: u64) -> SystemTime {
    let since_epoch = Duration::from_secs(ts);
    UNIX_EPOCH + since_epoch
}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// Every (granted, requested) pair of the action hierarchy.
    #[test]
    fn test_action_allows() {
        let table = [
            (Action::Admin, Action::Read, true),
            (Action::Admin, Action::Write, true),
            (Action::Admin, Action::Admin, true),
            (Action::Write, Action::Read, true),
            (Action::Write, Action::Write, true),
            (Action::Write, Action::Admin, false),
            (Action::Read, Action::Read, true),
            (Action::Read, Action::Write, false),
            (Action::Read, Action::Admin, false),
        ];
        for (granted, requested, expected) in table {
            assert_eq!(granted.allows(requested), expected);
        }
    }

    /// Long and short spellings parse; unknown text is rejected.
    #[test]
    fn test_action_from_str() {
        let accepted = [
            ("read", Action::Read),
            ("write", Action::Write),
            ("admin", Action::Admin),
            ("r", Action::Read),
            ("w", Action::Write),
            ("a", Action::Admin),
        ];
        for (text, expected) in accepted {
            assert_eq!(Action::from_str(text).unwrap(), expected);
        }
        assert!(Action::from_str("invalid").is_err());
    }

    /// Parsed scopes honour both the action level and the address pattern.
    #[test]
    fn test_scope_parse() {
        let read_all = Scope::parse("read:/**").unwrap();
        assert_eq!(read_all.action(), Action::Read);
        assert!(read_all.allows(Action::Read, "/any/path"));
        assert!(!read_all.allows(Action::Write, "/any/path"));

        let lights_writer = Scope::parse("write:/lights/**").unwrap();
        for action in [Action::Write, Action::Read] {
            assert!(lights_writer.allows(action, "/lights/room/1"));
            assert!(!lights_writer.allows(action, "/sensors/temp"));
        }

        let admin_all = Scope::parse("admin:/**").unwrap();
        for action in [Action::Admin, Action::Write, Action::Read] {
            assert!(admin_all.allows(action, "/any/path"));
        }
    }

    /// `*` and `**` wildcards inside a scope pattern.
    #[test]
    fn test_scope_wildcards() {
        let scope = Scope::parse("read:/lumen/scene/*/layer/**").unwrap();
        for covered in [
            "/lumen/scene/0/layer/1/opacity",
            "/lumen/scene/main/layer/2",
        ] {
            assert!(scope.allows(Action::Read, covered));
        }
        assert!(!scope.allows(Action::Read, "/lumen/scene/0/effect"));
    }

    /// A token with several scopes is allowed whatever any one scope allows.
    #[test]
    fn test_token_info() {
        let scopes: Vec<Scope> = ["read:/**", "write:/lights/**"]
            .iter()
            .map(|text| Scope::parse(text).unwrap())
            .collect();
        let info = TokenInfo::new("test_token".to_string(), scopes);

        assert!(info.has_scope(Action::Read, "/any/path"));
        assert!(info.has_scope(Action::Write, "/lights/room"));
        assert!(!info.has_scope(Action::Write, "/sensors/temp"));
        assert!(!info.is_expired());
    }

    /// An expiry in the past counts as expired; one in the future does not.
    #[test]
    fn test_token_expiry() {
        let read_all = || vec![Scope::parse("read:/**").unwrap()];

        let one_second_ago = SystemTime::now() - Duration::from_secs(1);
        let stale =
            TokenInfo::new("test_token".to_string(), read_all()).with_expires_at(one_second_ago);
        assert!(stale.is_expired());

        let fresh = TokenInfo::new("test_token".to_string(), read_all())
            .with_expires_in(Duration::from_secs(3600));
        assert!(!fresh.is_expired());
    }

    /// Register, validate, reject unknown or foreign tokens, then revoke.
    #[test]
    fn test_cpsk_validator() {
        let validator = CpskValidator::new();

        let token = CpskValidator::generate_token();
        assert!(token.starts_with("cpsk_"));
        assert_eq!(token.len(), 37);

        let info = TokenInfo::new(token.clone(), vec![Scope::parse("read:/**").unwrap()]);
        validator.register(token.clone(), info);

        // A registered token validates and carries its scopes.
        if let ValidationResult::Valid(info) = validator.validate(&token) {
            assert!(info.has_scope(Action::Read, "/test"));
        } else {
            panic!("expected valid token");
        }

        // Right prefix, but never registered.
        assert!(
            matches!(
                validator.validate("cpsk_unknown"),
                ValidationResult::Invalid(_)
            ),
            "expected invalid token"
        );

        // Wrong prefix: not this validator's token.
        assert!(
            matches!(
                validator.validate("jwt_token"),
                ValidationResult::NotMyToken
            ),
            "expected not my token"
        );

        // After revocation the token is unknown again.
        assert!(validator.revoke(&token));
        assert!(
            matches!(validator.validate(&token), ValidationResult::Invalid(_)),
            "expected invalid after revoke"
        );
    }

    /// A chain with one CPSK validator accepts its token and rejects others.
    #[test]
    fn test_validator_chain() {
        let cpsk = CpskValidator::new();
        let token = CpskValidator::generate_token();
        let admin_all = vec![Scope::parse("admin:/**").unwrap()];
        cpsk.register(token.clone(), TokenInfo::new(token.clone(), admin_all));

        let mut chain = ValidatorChain::new();
        chain.add(cpsk);

        assert!(
            matches!(chain.validate(&token), ValidationResult::Valid(_)),
            "expected valid token"
        );
        assert!(
            matches!(
                chain.validate("unknown_token"),
                ValidationResult::Invalid(_)
            ),
            "expected invalid token"
        );
    }

    /// Comma-separated scope lists are split and trimmed.
    #[test]
    fn test_parse_scopes() {
        let scopes = parse_scopes("read:/**, write:/lights/**").unwrap();
        assert_eq!(scopes.len(), 2);

        let (first, second) = (&scopes[0], &scopes[1]);
        assert!(first.allows(Action::Read, "/any"));
        assert!(second.allows(Action::Write, "/lights/1"));
    }

    /// Each unit suffix, the bare-number form, and two invalid inputs.
    #[test]
    fn test_parse_duration() {
        let accepted = [
            ("7d", 7 * 86400),
            ("24h", 24 * 3600),
            ("30m", 30 * 60),
            ("60s", 60),
            ("120", 120),
        ];
        for (text, seconds) in accepted {
            assert_eq!(parse_duration(text).unwrap(), Duration::from_secs(seconds));
        }

        for rejected in ["", "abc"] {
            assert!(parse_duration(rejected).is_err());
        }
    }

    /// Full names and the `auth` alias parse to the expected modes.
    #[test]
    fn test_security_mode() {
        let accepted = [
            ("open", SecurityMode::Open),
            ("authenticated", SecurityMode::Authenticated),
            ("auth", SecurityMode::Authenticated),
        ];
        for (text, expected) in accepted {
            assert_eq!(SecurityMode::from_str(text).unwrap(), expected);
        }
    }

    /// 10,000 generated tokens are all distinct.
    #[test]
    fn test_cpsk_token_uniqueness() {
        use std::collections::HashSet;

        let mut seen = HashSet::new();
        for _ in 0..10_000 {
            let token = CpskValidator::generate_token();
            assert!(
                !seen.contains(&token),
                "duplicate token generated: {}",
                token
            );
            seen.insert(token);
        }
        assert_eq!(seen.len(), 10_000);
    }

    /// Tokens are `cpsk_` followed by 32 hexadecimal digits.
    #[test]
    fn test_cpsk_token_format() {
        let token = CpskValidator::generate_token();
        assert!(token.starts_with("cpsk_"));
        assert_eq!(token.len(), 37);

        // Everything after the 5-character prefix is the UUID part.
        let uuid_part = &token["cpsk_".len()..];
        assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
    }
}