spacetimedb_auth/
identity.rs

1pub use jsonwebtoken::errors::Error as JwtError;
2pub use jsonwebtoken::errors::ErrorKind as JwtErrorKind;
3pub use jsonwebtoken::{DecodingKey, EncodingKey};
4use serde::Deserializer;
5use serde::{Deserialize, Serialize};
6use spacetimedb_lib::Identity;
7use std::time::SystemTime;
8
9// These are the claims that can be attached to a request/connection.
10#[serde_with::serde_as]
11#[derive(Debug, Serialize, Deserialize)]
12pub struct SpacetimeIdentityClaims {
13    #[serde(rename = "hex_identity")]
14    pub identity: Identity,
15    #[serde(rename = "sub")]
16    pub subject: String,
17    #[serde(rename = "iss")]
18    pub issuer: String,
19    #[serde(rename = "aud")]
20    pub audience: Vec<String>,
21
22    /// The unix timestamp the token was issued at
23    #[serde_as(as = "serde_with::TimestampSeconds")]
24    pub iat: SystemTime,
25    #[serde_as(as = "Option<serde_with::TimestampSeconds>")]
26    pub exp: Option<SystemTime>,
27}
28
29fn deserialize_audience<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
30where
31    D: Deserializer<'de>,
32{
33    // By using `untagged`, it will try the different options.
34    #[derive(Deserialize)]
35    #[serde(untagged)]
36    enum Audience {
37        Single(String),
38        Multiple(Vec<String>),
39    }
40
41    // Deserialize into the enum
42    let audience = Audience::deserialize(deserializer)?;
43
44    // Convert the enum into a Vec<String>
45    Ok(match audience {
46        Audience::Single(s) => vec![s],
47        Audience::Multiple(v) => v,
48    })
49}
50
51// IncomingClaims are from the token we receive from the client.
52// The signature should be verified already, but further validation is needed to have a SpacetimeIdentityClaims2.
53#[serde_with::serde_as]
54#[derive(Debug, Serialize, Deserialize)]
55pub struct IncomingClaims {
56    #[serde(rename = "hex_identity")]
57    pub identity: Option<Identity>,
58    #[serde(rename = "sub")]
59    pub subject: String,
60    #[serde(rename = "iss")]
61    pub issuer: String,
62    #[serde(rename = "aud", default, deserialize_with = "deserialize_audience")]
63    pub audience: Vec<String>,
64
65    /// The unix timestamp the token was issued at
66    #[serde_as(as = "serde_with::TimestampSeconds")]
67    pub iat: SystemTime,
68    #[serde_as(as = "Option<serde_with::TimestampSeconds>")]
69    pub exp: Option<SystemTime>,
70}
71
72impl TryInto<SpacetimeIdentityClaims> for IncomingClaims {
73    type Error = anyhow::Error;
74
75    fn try_into(self) -> anyhow::Result<SpacetimeIdentityClaims> {
76        // The issuer and subject must be less than 128 bytes.
77        if self.issuer.len() > 128 {
78            return Err(anyhow::anyhow!("Issuer too long: {:?}", self.issuer));
79        }
80        if self.subject.len() > 128 {
81            return Err(anyhow::anyhow!("Subject too long: {:?}", self.subject));
82        }
83        // The issuer and subject must be non-empty.
84        if self.issuer.is_empty() {
85            return Err(anyhow::anyhow!("Issuer empty"));
86        }
87        if self.subject.is_empty() {
88            return Err(anyhow::anyhow!("Subject empty"));
89        }
90
91        let computed_identity = Identity::from_claims(&self.issuer, &self.subject);
92        // If an identity is provided, it must match the computed identity.
93        if let Some(token_identity) = self.identity {
94            if token_identity != computed_identity {
95                return Err(anyhow::anyhow!(
96                    "Identity mismatch: token identity {:?} does not match computed identity {:?}",
97                    token_identity,
98                    computed_identity,
99                ));
100            }
101        }
102
103        Ok(SpacetimeIdentityClaims {
104            identity: computed_identity,
105            subject: self.subject,
106            issuer: self.issuer,
107            audience: self.audience,
108            iat: self.iat,
109            exp: self.exp,
110        })
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::{Duration, UNIX_EPOCH};

    /// `iat` used by every fixture below, in unix seconds.
    const IAT_SECS: u64 = 1693425600;
    /// `exp` used by the JSON fixtures below, in unix seconds.
    const EXP_SECS: u64 = 1693512000;

    /// Checks the non-`aud` fields shared by every JSON fixture.
    fn assert_common_fields(claims: &IncomingClaims) {
        assert_eq!(claims.subject, "123");
        assert_eq!(claims.issuer, "example.com");
        assert_eq!(claims.iat, UNIX_EPOCH + Duration::from_secs(IAT_SECS));
        assert_eq!(claims.exp, Some(UNIX_EPOCH + Duration::from_secs(EXP_SECS)));
    }

    /// Builds `IncomingClaims` directly (bypassing JSON) with no `hex_identity`.
    fn incoming(issuer: &str, subject: &str) -> IncomingClaims {
        IncomingClaims {
            identity: None,
            subject: subject.to_owned(),
            issuer: issuer.to_owned(),
            audience: vec!["audience1".to_owned()],
            iat: UNIX_EPOCH + Duration::from_secs(IAT_SECS),
            exp: None,
        }
    }

    /// Runs the validating conversion through `try_into()`, which works whether the
    /// conversion is implemented via `TryFrom` or `TryInto`.
    fn validate(claims: IncomingClaims) -> anyhow::Result<SpacetimeIdentityClaims> {
        claims.try_into()
    }

    #[test]
    fn test_deserialize_audience_single_string() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "aud": "audience1",
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });

        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();

        assert_eq!(claims.audience, vec!["audience1"]);
        assert_common_fields(&claims);
    }

    #[test]
    fn test_deserialize_audience_multiple_strings() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "aud": ["audience1", "audience2"],
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });

        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();

        assert_eq!(claims.audience, vec!["audience1", "audience2"]);
        assert_common_fields(&claims);
    }

    #[test]
    fn test_deserialize_audience_missing_field() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });

        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();

        // `default` on the field turns a missing `aud` into an empty vector.
        assert!(claims.audience.is_empty());
        assert_common_fields(&claims);
    }

    #[test]
    fn test_valid_claims_compute_identity() {
        let claims = validate(incoming("example.com", "123")).unwrap();

        assert_eq!(claims.identity, Identity::from_claims("example.com", "123"));
        assert_eq!(claims.issuer, "example.com");
        assert_eq!(claims.subject, "123");
        assert_eq!(claims.audience, vec!["audience1"]);
        assert_eq!(claims.iat, UNIX_EPOCH + Duration::from_secs(IAT_SECS));
        assert_eq!(claims.exp, None);
    }

    #[test]
    fn test_matching_token_identity_is_accepted() {
        let mut claims = incoming("example.com", "123");
        claims.identity = Some(Identity::from_claims("example.com", "123"));

        assert!(validate(claims).is_ok());
    }

    #[test]
    fn test_mismatched_token_identity_is_rejected() {
        let mut claims = incoming("example.com", "123");
        claims.identity = Some(Identity::from_claims("example.com", "456"));

        assert!(validate(claims).is_err());
    }

    #[test]
    fn test_issuer_and_subject_length_limits() {
        // The limit is inclusive: exactly 128 bytes is accepted, 129 is not.
        let at_limit = "a".repeat(128);
        let over_limit = "a".repeat(129);

        assert!(validate(incoming(&at_limit, "123")).is_ok());
        assert!(validate(incoming("example.com", &at_limit)).is_ok());
        assert!(validate(incoming(&over_limit, "123")).is_err());
        assert!(validate(incoming("example.com", &over_limit)).is_err());
    }

    #[test]
    fn test_empty_issuer_or_subject_is_rejected() {
        assert!(validate(incoming("", "123")).is_err());
        assert!(validate(incoming("example.com", "")).is_err());
    }
}