1use crate::address::Pattern;
42use crate::{Error, Result};
43use std::collections::HashMap;
44use std::fmt;
45use std::str::FromStr;
46use std::sync::RwLock;
47use std::time::{Duration, SystemTime, UNIX_EPOCH};
48
/// Permission level granted by a scope.
///
/// Levels are cumulative: `Admin` allows every action, `Write` also allows
/// `Read`, and `Read` allows only `Read` (see [`Action::allows`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Action {
    /// Read access.
    Read,
    /// Write access; also allows `Read`.
    Write,
    /// Administrative access; allows every action.
    Admin,
}

impl Action {
    /// Returns `true` if holding `self` permits performing `other`.
    pub fn allows(&self, other: Action) -> bool {
        // Each level is ranked; a holder may perform anything ranked at or below it.
        let rank = |level: Action| -> u8 {
            match level {
                Action::Read => 0,
                Action::Write => 1,
                Action::Admin => 2,
            }
        };
        rank(*self) >= rank(other)
    }
}

impl fmt::Display for Action {
    /// Writes the lowercase action name: `read`, `write` or `admin`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Action::Read => "read",
            Action::Write => "write",
            Action::Admin => "admin",
        };
        f.write_str(name)
    }
}
80
81impl FromStr for Action {
82 type Err = Error;
83
84 fn from_str(s: &str) -> Result<Self> {
85 match s.to_lowercase().as_str() {
86 "read" | "r" => Ok(Action::Read),
87 "write" | "w" => Ok(Action::Write),
88 "admin" | "a" | "*" => Ok(Action::Admin),
89 _ => Err(Error::InvalidPattern(format!("unknown action: {}", s))),
90 }
91 }
92}
93
94#[derive(Debug, Clone)]
96pub struct Scope {
97 action: Action,
98 pattern: Pattern,
99 raw: String,
100}
101
102impl Scope {
103 pub fn new(action: Action, pattern_str: &str) -> Result<Self> {
105 let pattern = Pattern::compile(pattern_str)?;
106 Ok(Self {
107 action,
108 pattern,
109 raw: format!("{}:{}", action, pattern_str),
110 })
111 }
112
113 pub fn parse(s: &str) -> Result<Self> {
115 let parts: Vec<&str> = s.splitn(2, ':').collect();
116 if parts.len() != 2 {
117 return Err(Error::InvalidPattern(format!(
118 "scope must be in format 'action:pattern', got: {}",
119 s
120 )));
121 }
122
123 let action = Action::from_str(parts[0])?;
124 let pattern = Pattern::compile(parts[1])?;
125
126 Ok(Self {
127 action,
128 pattern,
129 raw: s.to_string(),
130 })
131 }
132
133 pub fn allows(&self, action: Action, address: &str) -> bool {
135 self.action.allows(action) && self.pattern.matches(address)
136 }
137
138 pub fn action(&self) -> Action {
140 self.action
141 }
142
143 pub fn pattern(&self) -> &Pattern {
145 &self.pattern
146 }
147
148 pub fn as_str(&self) -> &str {
150 &self.raw
151 }
152}
153
154impl fmt::Display for Scope {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 write!(f, "{}", self.raw)
157 }
158}
159
160impl FromStr for Scope {
161 type Err = Error;
162
163 fn from_str(s: &str) -> Result<Self> {
164 Scope::parse(s)
165 }
166}
167
168#[derive(Debug, Clone)]
170pub struct TokenInfo {
171 pub token_id: String,
173 pub subject: Option<String>,
175 pub scopes: Vec<Scope>,
177 pub expires_at: Option<SystemTime>,
179 pub metadata: HashMap<String, String>,
181}
182
183impl TokenInfo {
184 pub fn new(token_id: String, scopes: Vec<Scope>) -> Self {
186 Self {
187 token_id,
188 subject: None,
189 scopes,
190 expires_at: None,
191 metadata: HashMap::new(),
192 }
193 }
194
195 pub fn is_expired(&self) -> bool {
197 if let Some(expires_at) = self.expires_at {
198 SystemTime::now() > expires_at
199 } else {
200 false
201 }
202 }
203
204 pub fn has_scope(&self, action: Action, address: &str) -> bool {
206 self.scopes
207 .iter()
208 .any(|scope| scope.allows(action, address))
209 }
210
211 pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
213 self.subject = Some(subject.into());
214 self
215 }
216
217 pub fn with_expires_at(mut self, expires_at: SystemTime) -> Self {
219 self.expires_at = Some(expires_at);
220 self
221 }
222
223 pub fn with_expires_in(mut self, duration: Duration) -> Self {
225 self.expires_at = Some(SystemTime::now() + duration);
226 self
227 }
228
229 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
231 self.metadata.insert(key.into(), value.into());
232 self
233 }
234}
235
236#[derive(Debug)]
238pub enum ValidationResult {
239 Valid(TokenInfo),
241 NotMyToken,
243 Invalid(String),
245 Expired,
247}
248
249pub trait TokenValidator: Send + Sync + std::any::Any {
251 fn validate(&self, token: &str) -> ValidationResult;
253
254 fn name(&self) -> &str;
256
257 fn as_any(&self) -> &dyn std::any::Any;
259}
260
261pub struct CpskValidator {
266 tokens: RwLock<HashMap<String, TokenInfo>>,
267}
268
269impl CpskValidator {
270 pub const PREFIX: &'static str = "cpsk_";
272
273 pub fn new() -> Self {
275 Self {
276 tokens: RwLock::new(HashMap::new()),
277 }
278 }
279
280 pub fn register(&self, token: String, info: TokenInfo) {
282 self.tokens.write().unwrap().insert(token, info);
283 }
284
285 pub fn revoke(&self, token: &str) -> bool {
287 self.tokens.write().unwrap().remove(token).is_some()
288 }
289
290 pub fn exists(&self, token: &str) -> bool {
292 self.tokens.read().unwrap().contains_key(token)
293 }
294
295 pub fn len(&self) -> usize {
297 self.tokens.read().unwrap().len()
298 }
299
300 pub fn is_empty(&self) -> bool {
302 self.tokens.read().unwrap().is_empty()
303 }
304
305 pub fn list_tokens(&self) -> Vec<String> {
307 self.tokens.read().unwrap().keys().cloned().collect()
308 }
309
310 pub fn generate_token() -> String {
312 use std::time::{SystemTime, UNIX_EPOCH};
313
314 let seed = SystemTime::now()
316 .duration_since(UNIX_EPOCH)
317 .map(|d| d.as_nanos())
318 .unwrap_or(0);
319
320 let mut state = seed as u64;
322 let mut chars = String::with_capacity(32);
323 const ALPHABET: &[u8] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
324
325 for _ in 0..32 {
326 state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
327 let idx = ((state >> 33) as usize) % ALPHABET.len();
328 chars.push(ALPHABET[idx] as char);
329 }
330
331 format!("{}{}", Self::PREFIX, chars)
332 }
333}
334
335impl Default for CpskValidator {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341impl TokenValidator for CpskValidator {
342 fn validate(&self, token: &str) -> ValidationResult {
343 if !token.starts_with(Self::PREFIX) {
345 return ValidationResult::NotMyToken;
346 }
347
348 let tokens = self.tokens.read().unwrap();
350 match tokens.get(token) {
351 Some(info) => {
352 if info.is_expired() {
353 ValidationResult::Expired
354 } else {
355 ValidationResult::Valid(info.clone())
356 }
357 }
358 None => ValidationResult::Invalid("token not found".to_string()),
359 }
360 }
361
362 fn name(&self) -> &str {
363 "CPSK"
364 }
365
366 fn as_any(&self) -> &dyn std::any::Any {
367 self
368 }
369}
370
371pub struct ValidatorChain {
373 validators: Vec<Box<dyn TokenValidator>>,
374}
375
376impl ValidatorChain {
377 pub fn new() -> Self {
379 Self {
380 validators: Vec::new(),
381 }
382 }
383
384 pub fn add<V: TokenValidator + 'static>(&mut self, validator: V) {
386 self.validators.push(Box::new(validator));
387 }
388
389 pub fn with<V: TokenValidator + 'static>(mut self, validator: V) -> Self {
391 self.add(validator);
392 self
393 }
394
395 pub fn validate(&self, token: &str) -> ValidationResult {
397 for validator in &self.validators {
398 match validator.validate(token) {
399 ValidationResult::NotMyToken => continue,
400 result => return result,
401 }
402 }
403 ValidationResult::Invalid("no validator accepted the token".to_string())
404 }
405
406 pub fn len(&self) -> usize {
408 self.validators.len()
409 }
410
411 pub fn is_empty(&self) -> bool {
413 self.validators.is_empty()
414 }
415}
416
417impl Default for ValidatorChain {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
/// Security mode setting; `Open` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SecurityMode {
    /// Open mode (the default).
    #[default]
    Open,
    /// Authenticated mode. NOTE(review): presumably requires a token accepted by
    /// the validators in this module — confirm at the enforcement site.
    Authenticated,
}

impl fmt::Display for SecurityMode {
    /// Writes the lowercase mode name: `open` or `authenticated`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            SecurityMode::Open => "open",
            SecurityMode::Authenticated => "authenticated",
        })
    }
}
441
442impl FromStr for SecurityMode {
443 type Err = Error;
444
445 fn from_str(s: &str) -> Result<Self> {
446 match s.to_lowercase().as_str() {
447 "open" | "none" | "off" => Ok(SecurityMode::Open),
448 "authenticated" | "auth" | "token" => Ok(SecurityMode::Authenticated),
449 _ => Err(Error::InvalidPattern(format!(
450 "unknown security mode: {}",
451 s
452 ))),
453 }
454 }
455}
456
457pub fn parse_scopes(s: &str) -> Result<Vec<Scope>> {
459 s.split(',').map(|part| Scope::parse(part.trim())).collect()
460}
461
462pub fn parse_duration(s: &str) -> Result<Duration> {
464 let s = s.trim();
465 if s.is_empty() {
466 return Err(Error::InvalidPattern("empty duration".to_string()));
467 }
468
469 let (num_str, unit) = if s.ends_with('d') {
470 (&s[..s.len() - 1], "d")
471 } else if s.ends_with('h') {
472 (&s[..s.len() - 1], "h")
473 } else if s.ends_with('m') {
474 (&s[..s.len() - 1], "m")
475 } else if s.ends_with('s') {
476 (&s[..s.len() - 1], "s")
477 } else {
478 (s, "s")
480 };
481
482 let num: u64 = num_str
483 .parse()
484 .map_err(|_| Error::InvalidPattern(format!("invalid duration number: {}", num_str)))?;
485
486 let secs = match unit {
487 "d" => num * 86400,
488 "h" => num * 3600,
489 "m" => num * 60,
490 "s" => num,
491 _ => unreachable!(),
492 };
493
494 Ok(Duration::from_secs(secs))
495}
496
/// Converts `time` to whole seconds since the Unix epoch.
///
/// Sub-second precision is truncated, and times before the epoch map to `0`.
pub fn to_unix_timestamp(time: SystemTime) -> u64 {
    match time.duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
503
/// Converts whole seconds since the Unix epoch into a `SystemTime`.
///
/// `UNIX_EPOCH + Duration` panics when the result is not representable on the
/// platform. Since `ts` may come from external data, out-of-range timestamps
/// are instead clamped to the latest whole second the platform can represent.
pub fn from_unix_timestamp(ts: u64) -> SystemTime {
    let at = |secs: u64| UNIX_EPOCH.checked_add(Duration::from_secs(secs));
    if let Some(time) = at(ts) {
        return time;
    }

    // Binary-search the largest representable whole second below `ts`.
    // Invariant: `lo` is representable (0 is the epoch itself), `hi` is not,
    // and `latest` is the time at `lo`.
    let (mut lo, mut hi) = (0u64, ts);
    let mut latest = UNIX_EPOCH;
    while hi - lo > 1 {
        let mid = lo + (hi - lo) / 2;
        match at(mid) {
            Some(time) => {
                lo = mid;
                latest = time;
            }
            None => hi = mid,
        }
    }
    latest
}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a scope the test knows to be well-formed.
    fn scope(text: &str) -> Scope {
        Scope::parse(text).unwrap()
    }

    #[test]
    fn test_action_allows() {
        // (granted, requested, expected)
        let matrix = [
            (Action::Admin, Action::Read, true),
            (Action::Admin, Action::Write, true),
            (Action::Admin, Action::Admin, true),
            (Action::Write, Action::Read, true),
            (Action::Write, Action::Write, true),
            (Action::Write, Action::Admin, false),
            (Action::Read, Action::Read, true),
            (Action::Read, Action::Write, false),
            (Action::Read, Action::Admin, false),
        ];
        for (granted, requested, expected) in matrix {
            assert_eq!(
                granted.allows(requested),
                expected,
                "{}.allows({})",
                granted,
                requested
            );
        }
    }

    #[test]
    fn test_action_from_str() {
        let accepted = [
            ("read", Action::Read),
            ("write", Action::Write),
            ("admin", Action::Admin),
            ("r", Action::Read),
            ("w", Action::Write),
            ("a", Action::Admin),
        ];
        for (text, expected) in accepted {
            assert_eq!(Action::from_str(text).unwrap(), expected);
        }
        assert!(Action::from_str("invalid").is_err());
    }

    #[test]
    fn test_scope_parse() {
        let read_all = scope("read:/**");
        assert_eq!(read_all.action(), Action::Read);
        assert!(read_all.allows(Action::Read, "/any/path"));
        assert!(!read_all.allows(Action::Write, "/any/path"));

        let lights = scope("write:/lights/**");
        for action in [Action::Write, Action::Read] {
            assert!(lights.allows(action, "/lights/room/1"));
            assert!(!lights.allows(action, "/sensors/temp"));
        }

        let admin = scope("admin:/**");
        for action in [Action::Admin, Action::Write, Action::Read] {
            assert!(admin.allows(action, "/any/path"));
        }
    }

    #[test]
    fn test_scope_wildcards() {
        let layers = scope("read:/lumen/scene/*/layer/**");
        assert!(layers.allows(Action::Read, "/lumen/scene/0/layer/1/opacity"));
        assert!(layers.allows(Action::Read, "/lumen/scene/main/layer/2"));
        assert!(!layers.allows(Action::Read, "/lumen/scene/0/effect"));
    }

    #[test]
    fn test_token_info() {
        let info = TokenInfo::new(
            "test_token".to_string(),
            vec![scope("read:/**"), scope("write:/lights/**")],
        );

        assert!(info.has_scope(Action::Read, "/any/path"));
        assert!(info.has_scope(Action::Write, "/lights/room"));
        assert!(!info.has_scope(Action::Write, "/sensors/temp"));
        assert!(!info.is_expired());
    }

    #[test]
    fn test_token_expiry() {
        let past = TokenInfo::new("test_token".to_string(), vec![scope("read:/**")])
            .with_expires_at(SystemTime::now() - Duration::from_secs(1));
        assert!(past.is_expired());

        let future = TokenInfo::new("test_token".to_string(), vec![scope("read:/**")])
            .with_expires_in(Duration::from_secs(3600));
        assert!(!future.is_expired());
    }

    #[test]
    fn test_cpsk_validator() {
        let validator = CpskValidator::new();

        let token = CpskValidator::generate_token();
        assert!(token.starts_with("cpsk_"));
        assert_eq!(token.len(), 37);
        validator.register(
            token.clone(),
            TokenInfo::new(token.clone(), vec![scope("read:/**")]),
        );

        let fresh = validator.validate(&token);
        assert!(
            matches!(&fresh, ValidationResult::Valid(info) if info.has_scope(Action::Read, "/test")),
            "expected valid token, got {:?}",
            fresh
        );

        let unknown = validator.validate("cpsk_unknown");
        assert!(
            matches!(&unknown, ValidationResult::Invalid(_)),
            "expected invalid token, got {:?}",
            unknown
        );

        let foreign = validator.validate("jwt_token");
        assert!(
            matches!(&foreign, ValidationResult::NotMyToken),
            "expected not my token, got {:?}",
            foreign
        );

        assert!(validator.revoke(&token));
        let revoked = validator.validate(&token);
        assert!(
            matches!(&revoked, ValidationResult::Invalid(_)),
            "expected invalid after revoke, got {:?}",
            revoked
        );
    }

    #[test]
    fn test_validator_chain() {
        let cpsk = CpskValidator::new();
        let token = CpskValidator::generate_token();
        cpsk.register(
            token.clone(),
            TokenInfo::new(token.clone(), vec![scope("admin:/**")]),
        );
        let chain = ValidatorChain::new().with(cpsk);

        assert!(
            matches!(chain.validate(&token), ValidationResult::Valid(_)),
            "expected valid token"
        );
        assert!(
            matches!(chain.validate("unknown_token"), ValidationResult::Invalid(_)),
            "expected invalid token"
        );
    }

    #[test]
    fn test_parse_scopes() {
        let scopes = parse_scopes("read:/**, write:/lights/**").unwrap();
        match scopes.as_slice() {
            [everything, lights] => {
                assert!(everything.allows(Action::Read, "/any"));
                assert!(lights.allows(Action::Write, "/lights/1"));
            }
            other => panic!("expected 2 scopes, got {}", other.len()),
        }
    }

    #[test]
    fn test_parse_duration() {
        let cases = [
            ("7d", 7 * 86400),
            ("24h", 24 * 3600),
            ("30m", 30 * 60),
            ("60s", 60),
            ("120", 120),
        ];
        for (text, secs) in cases {
            assert_eq!(parse_duration(text).unwrap(), Duration::from_secs(secs));
        }
        for bad in ["", "abc"] {
            assert!(parse_duration(bad).is_err());
        }
    }

    #[test]
    fn test_security_mode() {
        let cases = [
            ("open", SecurityMode::Open),
            ("authenticated", SecurityMode::Authenticated),
            ("auth", SecurityMode::Authenticated),
        ];
        for (text, expected) in cases {
            assert_eq!(SecurityMode::from_str(text).unwrap(), expected);
        }
    }
}