1use crate::address::Pattern;
42use crate::{Error, Result};
43use std::collections::HashMap;
44use std::fmt;
45use std::str::FromStr;
46use std::sync::RwLock;
47use std::time::{Duration, SystemTime, UNIX_EPOCH};
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
51pub enum Action {
52 Read,
54 Write,
56 Admin,
58}
59
60impl Action {
61 pub fn allows(&self, other: Action) -> bool {
63 match self {
64 Action::Admin => true, Action::Write => matches!(other, Action::Write | Action::Read),
66 Action::Read => matches!(other, Action::Read),
67 }
68 }
69}
70
71impl fmt::Display for Action {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Action::Read => write!(f, "read"),
75 Action::Write => write!(f, "write"),
76 Action::Admin => write!(f, "admin"),
77 }
78 }
79}
80
81impl FromStr for Action {
82 type Err = Error;
83
84 fn from_str(s: &str) -> Result<Self> {
85 match s.to_lowercase().as_str() {
86 "read" | "r" => Ok(Action::Read),
87 "write" | "w" => Ok(Action::Write),
88 "admin" | "a" | "*" => Ok(Action::Admin),
89 _ => Err(Error::InvalidPattern(format!("unknown action: {}", s))),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96pub struct Scope {
97 action: Action,
98 pattern: Pattern,
99 raw: String,
100}
101
102impl Scope {
103 pub fn new(action: Action, pattern_str: &str) -> Result<Self> {
105 let pattern = Pattern::compile(pattern_str)?;
106 Ok(Self {
107 action,
108 pattern,
109 raw: format!("{}:{}", action, pattern_str),
110 })
111 }
112
113 pub fn parse(s: &str) -> Result<Self> {
115 let parts: Vec<&str> = s.splitn(2, ':').collect();
116 if parts.len() != 2 {
117 return Err(Error::InvalidPattern(format!(
118 "scope must be in format 'action:pattern', got: {}",
119 s
120 )));
121 }
122
123 let action = Action::from_str(parts[0])?;
124 let pattern = Pattern::compile(parts[1])?;
125
126 Ok(Self {
127 action,
128 pattern,
129 raw: s.to_string(),
130 })
131 }
132
133 pub fn allows(&self, action: Action, address: &str) -> bool {
135 self.action.allows(action) && self.pattern.matches(address)
136 }
137
138 pub fn action(&self) -> Action {
140 self.action
141 }
142
143 pub fn pattern(&self) -> &Pattern {
145 &self.pattern
146 }
147
148 pub fn as_str(&self) -> &str {
150 &self.raw
151 }
152}
153
154impl fmt::Display for Scope {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 write!(f, "{}", self.raw)
157 }
158}
159
160impl FromStr for Scope {
161 type Err = Error;
162
163 fn from_str(s: &str) -> Result<Self> {
164 Scope::parse(s)
165 }
166}
167
168#[derive(Debug, Clone)]
170pub struct TokenInfo {
171 pub token_id: String,
173 pub subject: Option<String>,
175 pub scopes: Vec<Scope>,
177 pub expires_at: Option<SystemTime>,
179 pub metadata: HashMap<String, String>,
181}
182
183impl TokenInfo {
184 pub fn new(token_id: String, scopes: Vec<Scope>) -> Self {
186 Self {
187 token_id,
188 subject: None,
189 scopes,
190 expires_at: None,
191 metadata: HashMap::new(),
192 }
193 }
194
195 pub fn is_expired(&self) -> bool {
200 if let Some(expires_at) = self.expires_at {
201 SystemTime::now() > expires_at
202 } else {
203 false
204 }
205 }
206
207 pub fn has_scope(&self, action: Action, address: &str) -> bool {
209 self.scopes
210 .iter()
211 .any(|scope| scope.allows(action, address))
212 }
213
214 pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
216 self.subject = Some(subject.into());
217 self
218 }
219
220 pub fn with_expires_at(mut self, expires_at: SystemTime) -> Self {
222 self.expires_at = Some(expires_at);
223 self
224 }
225
226 pub fn with_expires_in(mut self, duration: Duration) -> Self {
228 self.expires_at = Some(SystemTime::now() + duration);
229 self
230 }
231
232 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
234 self.metadata.insert(key.into(), value.into());
235 self
236 }
237}
238
239#[derive(Debug)]
241pub enum ValidationResult {
242 Valid(TokenInfo),
244 NotMyToken,
246 Invalid(String),
248 Expired,
250}
251
252pub trait TokenValidator: Send + Sync + std::any::Any {
254 fn validate(&self, token: &str) -> ValidationResult;
256
257 fn name(&self) -> &str;
259
260 fn as_any(&self) -> &dyn std::any::Any;
262}
263
264pub struct CpskValidator {
269 tokens: RwLock<HashMap<String, TokenInfo>>,
270 default_ttl: Option<Duration>,
271}
272
273impl CpskValidator {
274 pub const PREFIX: &'static str = "cpsk_";
276
277 pub fn new() -> Self {
279 Self {
280 tokens: RwLock::new(HashMap::new()),
281 default_ttl: None,
282 }
283 }
284
285 pub fn with_default_ttl(ttl: Duration) -> Self {
289 Self {
290 tokens: RwLock::new(HashMap::new()),
291 default_ttl: Some(ttl),
292 }
293 }
294
295 pub fn register(&self, token: String, mut info: TokenInfo) {
298 if info.expires_at.is_none() {
299 if let Some(ttl) = self.default_ttl {
300 info.expires_at = Some(SystemTime::now() + ttl);
301 }
302 }
303 self.tokens.write().unwrap().insert(token, info);
304 }
305
306 pub fn revoke(&self, token: &str) -> bool {
308 self.tokens.write().unwrap().remove(token).is_some()
309 }
310
311 pub fn exists(&self, token: &str) -> bool {
313 self.tokens.read().unwrap().contains_key(token)
314 }
315
316 pub fn len(&self) -> usize {
318 self.tokens.read().unwrap().len()
319 }
320
321 pub fn is_empty(&self) -> bool {
323 self.tokens.read().unwrap().is_empty()
324 }
325
326 pub fn list_tokens(&self) -> Vec<String> {
328 self.tokens.read().unwrap().keys().cloned().collect()
329 }
330
331 pub fn generate_token() -> String {
333 let uuid = uuid::Uuid::new_v4();
334 format!("{}{}", Self::PREFIX, uuid.as_simple())
335 }
336}
337
338impl Default for CpskValidator {
339 fn default() -> Self {
340 Self::new()
341 }
342}
343
344impl TokenValidator for CpskValidator {
345 fn validate(&self, token: &str) -> ValidationResult {
346 if !token.starts_with(Self::PREFIX) {
348 return ValidationResult::NotMyToken;
349 }
350
351 let tokens = self.tokens.read().unwrap();
353 match tokens.get(token) {
354 Some(info) => {
355 if info.is_expired() {
356 ValidationResult::Expired
357 } else {
358 ValidationResult::Valid(info.clone())
359 }
360 }
361 None => ValidationResult::Invalid("token not found".to_string()),
362 }
363 }
364
365 fn name(&self) -> &str {
366 "CPSK"
367 }
368
369 fn as_any(&self) -> &dyn std::any::Any {
370 self
371 }
372}
373
374pub struct ValidatorChain {
376 validators: Vec<Box<dyn TokenValidator>>,
377}
378
379impl ValidatorChain {
380 pub fn new() -> Self {
382 Self {
383 validators: Vec::new(),
384 }
385 }
386
387 pub fn add<V: TokenValidator + 'static>(&mut self, validator: V) {
389 self.validators.push(Box::new(validator));
390 }
391
392 pub fn with<V: TokenValidator + 'static>(mut self, validator: V) -> Self {
394 self.add(validator);
395 self
396 }
397
398 pub fn validate(&self, token: &str) -> ValidationResult {
400 for validator in &self.validators {
401 match validator.validate(token) {
402 ValidationResult::NotMyToken => continue,
403 result => return result,
404 }
405 }
406 ValidationResult::Invalid("no validator accepted the token".to_string())
407 }
408
409 pub fn len(&self) -> usize {
411 self.validators.len()
412 }
413
414 pub fn is_empty(&self) -> bool {
416 self.validators.is_empty()
417 }
418}
419
420impl TokenValidator for ValidatorChain {
421 fn validate(&self, token: &str) -> ValidationResult {
422 for validator in &self.validators {
423 match validator.validate(token) {
424 ValidationResult::NotMyToken => continue,
425 result => return result,
426 }
427 }
428 ValidationResult::Invalid("no validator accepted the token".to_string())
429 }
430
431 fn name(&self) -> &str {
432 "ValidatorChain"
433 }
434
435 fn as_any(&self) -> &dyn std::any::Any {
436 self
437 }
438}
439
440impl Default for ValidatorChain {
441 fn default() -> Self {
442 Self::new()
443 }
444}
445
446#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
448pub enum SecurityMode {
449 #[default]
451 Open,
452 Authenticated,
454}
455
456impl fmt::Display for SecurityMode {
457 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
458 match self {
459 SecurityMode::Open => write!(f, "open"),
460 SecurityMode::Authenticated => write!(f, "authenticated"),
461 }
462 }
463}
464
465impl FromStr for SecurityMode {
466 type Err = Error;
467
468 fn from_str(s: &str) -> Result<Self> {
469 match s.to_lowercase().as_str() {
470 "open" | "none" | "off" => Ok(SecurityMode::Open),
471 "authenticated" | "auth" | "token" => Ok(SecurityMode::Authenticated),
472 _ => Err(Error::InvalidPattern(format!(
473 "unknown security mode: {}",
474 s
475 ))),
476 }
477 }
478}
479
480pub fn parse_scopes(s: &str) -> Result<Vec<Scope>> {
482 s.split(',').map(|part| Scope::parse(part.trim())).collect()
483}
484
485pub fn parse_duration(s: &str) -> Result<Duration> {
487 let s = s.trim();
488 if s.is_empty() {
489 return Err(Error::InvalidPattern("empty duration".to_string()));
490 }
491
492 let (num_str, unit) = if let Some(n) = s.strip_suffix('d') {
493 (n, "d")
494 } else if let Some(n) = s.strip_suffix('h') {
495 (n, "h")
496 } else if let Some(n) = s.strip_suffix('m') {
497 (n, "m")
498 } else if let Some(n) = s.strip_suffix('s') {
499 (n, "s")
500 } else {
501 (s, "s")
503 };
504
505 let num: u64 = num_str
506 .parse()
507 .map_err(|_| Error::InvalidPattern(format!("invalid duration number: {}", num_str)))?;
508
509 let secs = match unit {
510 "d" => num * 86400,
511 "h" => num * 3600,
512 "m" => num * 60,
513 "s" => num,
514 _ => unreachable!(),
515 };
516
517 Ok(Duration::from_secs(secs))
518}
519
/// Converts `time` to whole seconds since the Unix epoch, truncating any
/// sub-second part. Times before the epoch are clamped to `0`.
pub fn to_unix_timestamp(time: SystemTime) -> u64 {
    match time.duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
526
/// Converts whole seconds since the Unix epoch back into a `SystemTime`.
///
/// # Panics
/// Panics if the resulting instant cannot be represented by `SystemTime`
/// (the `SystemTime + Duration` operator panics on overflow).
pub fn from_unix_timestamp(ts: u64) -> SystemTime {
    let offset = Duration::from_secs(ts);
    UNIX_EPOCH + offset
}
531
#[cfg(test)]
mod tests {
    use super::*;

    // Privilege hierarchy: Admin covers everything, Write covers Read, Read covers only itself.
    #[test]
    fn test_action_allows() {
        assert!(Action::Admin.allows(Action::Read));
        assert!(Action::Admin.allows(Action::Write));
        assert!(Action::Admin.allows(Action::Admin));

        assert!(Action::Write.allows(Action::Read));
        assert!(Action::Write.allows(Action::Write));
        assert!(!Action::Write.allows(Action::Admin));

        assert!(Action::Read.allows(Action::Read));
        assert!(!Action::Read.allows(Action::Write));
        assert!(!Action::Read.allows(Action::Admin));
    }

    // Full names and single-letter shorthands parse; unknown names are rejected.
    #[test]
    fn test_action_from_str() {
        assert_eq!(Action::from_str("read").unwrap(), Action::Read);
        assert_eq!(Action::from_str("write").unwrap(), Action::Write);
        assert_eq!(Action::from_str("admin").unwrap(), Action::Admin);
        assert_eq!(Action::from_str("r").unwrap(), Action::Read);
        assert_eq!(Action::from_str("w").unwrap(), Action::Write);
        assert_eq!(Action::from_str("a").unwrap(), Action::Admin);
        assert!(Action::from_str("invalid").is_err());
    }

    // A scope grants its own action plus every lower one, but only on addresses
    // its pattern matches.
    #[test]
    fn test_scope_parse() {
        let scope = Scope::parse("read:/**").unwrap();
        assert_eq!(scope.action(), Action::Read);
        assert!(scope.allows(Action::Read, "/any/path"));
        assert!(!scope.allows(Action::Write, "/any/path"));

        let scope = Scope::parse("write:/lights/**").unwrap();
        assert!(scope.allows(Action::Write, "/lights/room/1"));
        assert!(scope.allows(Action::Read, "/lights/room/1"));
        assert!(!scope.allows(Action::Write, "/sensors/temp"));
        assert!(!scope.allows(Action::Read, "/sensors/temp"));

        let scope = Scope::parse("admin:/**").unwrap();
        assert!(scope.allows(Action::Admin, "/any/path"));
        assert!(scope.allows(Action::Write, "/any/path"));
        assert!(scope.allows(Action::Read, "/any/path"));
    }

    // Mixed wildcards: the assertions expect `*` to match one path segment and
    // `**` to match the remaining segments.
    #[test]
    fn test_scope_wildcards() {
        let scope = Scope::parse("read:/lumen/scene/*/layer/**").unwrap();
        assert!(scope.allows(Action::Read, "/lumen/scene/0/layer/1/opacity"));
        assert!(scope.allows(Action::Read, "/lumen/scene/main/layer/2"));
        assert!(!scope.allows(Action::Read, "/lumen/scene/0/effect"));
    }

    // A token is authorized if any one of its scopes allows the request.
    #[test]
    fn test_token_info() {
        let scopes = vec![
            Scope::parse("read:/**").unwrap(),
            Scope::parse("write:/lights/**").unwrap(),
        ];
        let info = TokenInfo::new("test_token".to_string(), scopes);

        assert!(info.has_scope(Action::Read, "/any/path"));
        assert!(info.has_scope(Action::Write, "/lights/room"));
        assert!(!info.has_scope(Action::Write, "/sensors/temp"));
        assert!(!info.is_expired());
    }

    // Absolute expiry in the past vs. relative expiry one hour ahead.
    #[test]
    fn test_token_expiry() {
        let scopes = vec![Scope::parse("read:/**").unwrap()];
        let info = TokenInfo::new("test_token".to_string(), scopes)
            .with_expires_at(SystemTime::now() - Duration::from_secs(1));
        assert!(info.is_expired());

        let scopes = vec![Scope::parse("read:/**").unwrap()];
        let info = TokenInfo::new("test_token".to_string(), scopes)
            .with_expires_in(Duration::from_secs(3600));
        assert!(!info.is_expired());
    }

    // Full lifecycle against one validator: generate, register, validate,
    // reject unknown and foreign tokens, then revoke. Statement order matters.
    #[test]
    fn test_cpsk_validator() {
        let validator = CpskValidator::new();

        let token = CpskValidator::generate_token();
        assert!(token.starts_with("cpsk_"));
        // "cpsk_" (5 chars) + 32 hex digits of the simple UUID form.
        assert_eq!(token.len(), 37);
        let scopes = vec![Scope::parse("read:/**").unwrap()];
        let info = TokenInfo::new(token.clone(), scopes);
        validator.register(token.clone(), info);

        match validator.validate(&token) {
            ValidationResult::Valid(info) => {
                assert!(info.has_scope(Action::Read, "/test"));
            }
            _ => panic!("expected valid token"),
        }

        // Right prefix but never registered: claimed and rejected.
        match validator.validate("cpsk_unknown") {
            ValidationResult::Invalid(_) => {}
            _ => panic!("expected invalid token"),
        }

        // Foreign prefix: not claimed, so a chain could try another validator.
        match validator.validate("jwt_token") {
            ValidationResult::NotMyToken => {}
            _ => panic!("expected not my token"),
        }

        assert!(validator.revoke(&token));
        match validator.validate(&token) {
            ValidationResult::Invalid(_) => {}
            _ => panic!("expected invalid after revoke"),
        }
    }

    // The chain forwards claimed results; an unclaimed token becomes Invalid.
    #[test]
    fn test_validator_chain() {
        let mut chain = ValidatorChain::new();

        let cpsk = CpskValidator::new();
        let token = CpskValidator::generate_token();
        let scopes = vec![Scope::parse("admin:/**").unwrap()];
        cpsk.register(token.clone(), TokenInfo::new(token.clone(), scopes));
        chain.add(cpsk);

        match chain.validate(&token) {
            ValidationResult::Valid(_) => {}
            _ => panic!("expected valid token"),
        }

        match chain.validate("unknown_token") {
            ValidationResult::Invalid(_) => {}
            _ => panic!("expected invalid token"),
        }
    }

    // Same as above, but through the `TokenValidator` trait object, so the
    // trait impl (not the inherent method) is exercised.
    #[test]
    fn test_validator_chain_as_trait_object() {
        let mut chain = ValidatorChain::new();

        let cpsk = CpskValidator::new();
        let token = CpskValidator::generate_token();
        let scopes = vec![Scope::parse("admin:/**").unwrap()];
        cpsk.register(token.clone(), TokenInfo::new(token.clone(), scopes));
        chain.add(cpsk);

        let validator: &dyn TokenValidator = &chain;
        assert_eq!(validator.name(), "ValidatorChain");

        match validator.validate(&token) {
            ValidationResult::Valid(info) => {
                assert!(info.has_scope(Action::Admin, "/any/path"));
            }
            _ => panic!("expected valid token through trait object"),
        }

        match validator.validate("unknown_prefix_token") {
            ValidationResult::Invalid(_) => {}
            _ => panic!("expected invalid for unknown token"),
        }
    }

    // Comma-separated list; whitespace around entries is tolerated.
    #[test]
    fn test_parse_scopes() {
        let scopes = parse_scopes("read:/**, write:/lights/**").unwrap();
        assert_eq!(scopes.len(), 2);
        assert!(scopes[0].allows(Action::Read, "/any"));
        assert!(scopes[1].allows(Action::Write, "/lights/1"));
    }

    // Every unit suffix, the bare-seconds form, and two malformed inputs.
    #[test]
    fn test_parse_duration() {
        assert_eq!(
            parse_duration("7d").unwrap(),
            Duration::from_secs(7 * 86400)
        );
        assert_eq!(
            parse_duration("24h").unwrap(),
            Duration::from_secs(24 * 3600)
        );
        assert_eq!(parse_duration("30m").unwrap(), Duration::from_secs(30 * 60));
        assert_eq!(parse_duration("60s").unwrap(), Duration::from_secs(60));
        assert_eq!(parse_duration("120").unwrap(), Duration::from_secs(120));
        assert!(parse_duration("").is_err());
        assert!(parse_duration("abc").is_err());
    }

    #[test]
    fn test_security_mode() {
        assert_eq!(SecurityMode::from_str("open").unwrap(), SecurityMode::Open);
        assert_eq!(
            SecurityMode::from_str("authenticated").unwrap(),
            SecurityMode::Authenticated
        );
        assert_eq!(
            SecurityMode::from_str("auth").unwrap(),
            SecurityMode::Authenticated
        );
    }

    // A token registered without an expiry picks up the validator's default TTL.
    #[test]
    fn test_cpsk_default_ttl() {
        let validator = CpskValidator::with_default_ttl(Duration::from_secs(3600));
        let token = CpskValidator::generate_token();
        let scopes = vec![Scope::parse("read:/**").unwrap()];

        let info = TokenInfo::new(token.clone(), scopes);
        // Precondition: no expiry before registration.
        assert!(info.expires_at.is_none());
        validator.register(token.clone(), info);

        // Reads the private registry directly (visible to this child module);
        // the read guard is held until the end of the test.
        let tokens = validator.tokens.read().unwrap();
        let stored = tokens.get(&token).unwrap();
        assert!(stored.expires_at.is_some());
        assert!(!stored.is_expired());
    }

    // An explicit expiry (2h) must survive registration with a shorter default TTL (1h).
    #[test]
    fn test_cpsk_default_ttl_no_override() {
        let validator = CpskValidator::with_default_ttl(Duration::from_secs(3600));
        let token = CpskValidator::generate_token();
        let scopes = vec![Scope::parse("read:/**").unwrap()];

        let explicit_expiry = SystemTime::now() + Duration::from_secs(7200);
        let info = TokenInfo::new(token.clone(), scopes).with_expires_at(explicit_expiry);
        validator.register(token.clone(), info);

        let tokens = validator.tokens.read().unwrap();
        let stored = tokens.get(&token).unwrap();
        let stored_expiry = stored.expires_at.unwrap();
        // More than one hour remaining proves the 1h default was not applied.
        let diff = stored_expiry
            .duration_since(SystemTime::now())
            .unwrap()
            .as_secs();
        assert!(diff > 3600, "explicit expiry should be preserved");
    }

    // Smoke test: 10k freshly generated tokens contain no duplicates.
    #[test]
    fn test_cpsk_token_uniqueness() {
        use std::collections::HashSet;

        let mut tokens = HashSet::new();
        for _ in 0..10_000 {
            let token = CpskValidator::generate_token();
            assert!(
                tokens.insert(token.clone()),
                "duplicate token generated: {}",
                token
            );
        }
        assert_eq!(tokens.len(), 10_000);
    }

    // Shape check: prefix, total length, and an all-hex body.
    #[test]
    fn test_cpsk_token_format() {
        let token = CpskValidator::generate_token();
        assert!(token.starts_with("cpsk_"));
        assert_eq!(token.len(), 37);
        // Skip the 5-byte "cpsk_" prefix.
        let uuid_part = &token[5..];
        assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
    }
}