1use std::{fmt, str::FromStr};
2
3use base64::Engine;
4use rand::Rng;
5
/// Fixed-size, 16-byte identifier.
///
/// The bytes are handed out as text by `as_str` without re-validation, so the
/// inner array is kept private: only code in this module can decide which bytes
/// end up inside a `Sid`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Sid([u8; 16]);
9
10impl Sid {
11 pub const ZERO: Self = Self([0u8; 16]);
13 pub fn new() -> Self {
15 Self::default()
16 }
17
18 pub const fn as_str(&self) -> &str {
20 unsafe { std::str::from_utf8_unchecked(&self.0) }
22 }
23}
24
/// Reasons why a string could not be parsed into a `Sid`.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can compare
/// error values directly instead of comparing their display strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SidDecodeError {
    /// The input has the right length but contains a byte that is not an ASCII
    /// letter, an ASCII digit, `-` or `_`.
    InvalidBase64String,
    /// The input is not exactly 16 bytes long.
    InvalidLength,
}
33impl fmt::Display for SidDecodeError {
34 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35 match self {
36 SidDecodeError::InvalidBase64String => write!(f, "Invalid url base64 string"),
37 SidDecodeError::InvalidLength => write!(f, "Invalid sid length"),
38 }
39 }
40}
// Marker impl: no trait method is overridden, so the defaults apply and the
// error reports no underlying `source()`.
impl std::error::Error for SidDecodeError {}
42
43impl FromStr for Sid {
44 type Err = SidDecodeError;
45
46 fn from_str(s: &str) -> Result<Self, Self::Err> {
47 use SidDecodeError::*;
48
49 let mut id = [0u8; 16];
50
51 if s.len() != 16 {
53 return Err(InvalidLength);
54 }
55
56 for (idx, byte) in s.as_bytes()[0..16].iter().enumerate() {
58 if byte.is_ascii_alphanumeric() || byte == &b'_' || byte == &b'-' {
59 id[idx] = *byte;
60 } else {
61 return Err(InvalidBase64String);
62 }
63 }
64 Ok(Sid(id))
65 }
66}
67
68impl Default for Sid {
69 fn default() -> Self {
70 let mut random = [0u8; 12]; let mut id = [0u8; 16];
72
73 rand::rng().fill(&mut random);
74
75 base64::prelude::BASE64_URL_SAFE_NO_PAD
76 .encode_slice(random, &mut id)
77 .unwrap();
78
79 Sid(id)
80 }
81}
82
83impl fmt::Display for Sid {
84 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85 write!(f, "{}", self.as_str())
86 }
87}
88impl serde::Serialize for Sid {
89 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
90 serializer.serialize_str(self.as_str())
91 }
92}
93
94struct SidVisitor;
95impl serde::de::Visitor<'_> for SidVisitor {
96 type Value = Sid;
97
98 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
99 formatter.write_str("a valid sid")
100 }
101
102 fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
103 Sid::from_str(v).map_err(serde::de::Error::custom)
104 }
105}
106impl<'de> serde::Deserialize<'de> for Sid {
107 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
108 deserializer.deserialize_str(SidVisitor)
109 }
110}
111
112impl fmt::Debug for Sid {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 write!(f, "{}", self.as_str())
115 }
116}
117
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::sid::Sid;

    /// A freshly generated sid and a hand-written one both survive a round trip
    /// through their text form.
    #[test]
    fn test_sid_from_str() {
        let id = Sid::new();
        let id2 = Sid::from_str(&id.to_string()).unwrap();
        assert_eq!(id, id2);
        let id = Sid::from_str("AA9AAA0AAzAAAAHs").unwrap();
        assert_eq!(id.to_string(), "AA9AAA0AAzAAAAHs");
    }

    /// Inputs of the wrong length or with bytes outside the accepted alphabet
    /// are rejected with the matching error message.
    #[test]
    fn test_sid_from_str_invalid() {
        // `\u{f9}` is `ù` (2 bytes in UTF-8), so this input is 8 bytes long.
        // It is written as an escape so a re-encoded source file cannot change
        // the byte length; the previous mis-encoded literal was exactly
        // 16 bytes and therefore never reached the length error.
        let err = Sid::from_str("*$^\u{f9}\u{f9}!").unwrap_err();
        assert_eq!(err.to_string(), "Invalid sid length");
        let err = Sid::from_str("aoassaAZDoin#zd{").unwrap_err();
        assert_eq!(err.to_string(), "Invalid url base64 string");
        let err = Sid::from_str("aoassaAZDoinazd<").unwrap_err();
        assert_eq!(err.to_string(), "Invalid url base64 string");
    }
}