engineioxide_core/
sid.rs

1use std::{fmt, str::FromStr};
2
3use base64::Engine;
4use rand::Rng;
5
/// Session id stored inline as 16 ASCII bytes (128 bits of storage).
///
/// Ids created by [`Sid::new`] are the url-safe, unpadded base64 encoding of 12 random
/// bytes. Ids parsed through `FromStr` are restricted to that same `[A-Za-z0-9_-]` alphabet.
#[derive(Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Sid([u8; 16]);
9
10impl Sid {
11    /// A zeroed session id
12    pub const ZERO: Self = Self([0u8; 16]);
13    /// Generate a new random session id (base64 16 chars)
14    pub fn new() -> Self {
15        Self::default()
16    }
17
18    /// Get the session id as a base64 16 chars string
19    pub const fn as_str(&self) -> &str {
20        // SAFETY: SID is always a base64 chars string
21        unsafe { std::str::from_utf8_unchecked(&self.0) }
22    }
23}
24
/// Error type for [`Sid::from_str`]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SidDecodeError {
    /// The input has the right length but contains a byte outside the url-safe
    /// base64 alphabet `[A-Za-z0-9_-]`.
    InvalidBase64String,
    /// The input is not exactly 16 bytes long.
    ///
    /// The length is measured in bytes, not characters.
    InvalidLength,
}

impl fmt::Display for SidDecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg = match self {
            SidDecodeError::InvalidBase64String => "Invalid url base64 string",
            SidDecodeError::InvalidLength => "Invalid sid length",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for SidDecodeError {}
42
43impl FromStr for Sid {
44    type Err = SidDecodeError;
45
46    fn from_str(s: &str) -> Result<Self, Self::Err> {
47        use SidDecodeError::*;
48
49        let mut id = [0u8; 16];
50
51        // Verify the length of the string
52        if s.len() != 16 {
53            return Err(InvalidLength);
54        }
55
56        // Verify that the string is a valid base64 url safe string without padding
57        for (idx, byte) in s.as_bytes()[0..16].iter().enumerate() {
58            if byte.is_ascii_alphanumeric() || byte == &b'_' || byte == &b'-' {
59                id[idx] = *byte;
60            } else {
61                return Err(InvalidBase64String);
62            }
63        }
64        Ok(Sid(id))
65    }
66}
67
68impl Default for Sid {
69    fn default() -> Self {
70        let mut random = [0u8; 12]; // 12 bytes = 16 chars base64
71        let mut id = [0u8; 16];
72
73        rand::rng().fill(&mut random);
74
75        base64::prelude::BASE64_URL_SAFE_NO_PAD
76            .encode_slice(random, &mut id)
77            .unwrap();
78
79        Sid(id)
80    }
81}
82
83impl fmt::Display for Sid {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        write!(f, "{}", self.as_str())
86    }
87}
88impl serde::Serialize for Sid {
89    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
90        serializer.serialize_str(self.as_str())
91    }
92}
93
94struct SidVisitor;
95impl serde::de::Visitor<'_> for SidVisitor {
96    type Value = Sid;
97
98    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
99        formatter.write_str("a valid sid")
100    }
101
102    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
103        Sid::from_str(v).map_err(serde::de::Error::custom)
104    }
105}
106impl<'de> serde::Deserialize<'de> for Sid {
107    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
108        deserializer.deserialize_str(SidVisitor)
109    }
110}
111
112impl fmt::Debug for Sid {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        write!(f, "{}", self.as_str())
115    }
116}
117
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use crate::sid::Sid;

    #[test]
    fn test_sid_from_str() {
        let id = Sid::new();
        let id2 = Sid::from_str(&id.to_string()).unwrap();
        assert_eq!(id, id2);
        let id = Sid::from_str("AA9AAA0AAzAAAAHs").unwrap();
        assert_eq!(id.to_string(), "AA9AAA0AAzAAAAHs");
    }

    #[test]
    fn test_sid_from_str_invalid() {
        // "*$^ùù!" is written with escapes so that re-encoding this file cannot silently
        // change its byte length. It is 8 bytes long, so it must fail the length check.
        let id = Sid::from_str("*$^\u{f9}\u{f9}!").unwrap_err();
        assert_eq!(id.to_string(), "Invalid sid length");
        // Exactly 16 bytes, built from 2-byte characters. It passes the length check
        // and must then be rejected by the alphabet check.
        let id = Sid::from_str(&"\u{e9}".repeat(8)).unwrap_err();
        assert_eq!(id.to_string(), "Invalid url base64 string");
        let id = Sid::from_str("aoassaAZDoin#zd{").unwrap_err();
        assert_eq!(id.to_string(), "Invalid url base64 string");
        let id = Sid::from_str("aoassaAZDoinazd<").unwrap_err();
        assert_eq!(id.to_string(), "Invalid url base64 string");
    }

    #[test]
    fn test_sid_zero() {
        assert_eq!(Sid::ZERO.as_str(), "\0".repeat(16));
        assert_ne!(Sid::new(), Sid::ZERO);
    }
}