pub use jsonwebtoken::errors::Error as JwtError;
pub use jsonwebtoken::errors::ErrorKind as JwtErrorKind;
pub use jsonwebtoken::{DecodingKey, EncodingKey};
use serde::Deserializer;
use serde::{Deserialize, Serialize};
use spacetimedb_lib::Identity;
use std::time::SystemTime;
/// Validated claims of a SpacetimeDB identity token.
///
/// Built by converting an [`IncomingClaims`] (`claims.try_into()`). That conversion
/// rejects empty or over-long `iss`/`sub` values and sets `identity` to the value
/// computed by `Identity::from_claims` from the issuer and subject.
#[serde_with::serde_as]
#[derive(Debug, Serialize, Deserialize)]
pub struct SpacetimeIdentityClaims {
    /// Identity, (de)serialized under the JSON key `hex_identity`.
    #[serde(rename = "hex_identity")]
    pub identity: Identity,
    /// The `sub` (subject) claim.
    #[serde(rename = "sub")]
    pub subject: String,
    /// The `iss` (issuer) claim.
    #[serde(rename = "iss")]
    pub issuer: String,
    /// The `aud` (audience) claim, serialized as a JSON array.
    ///
    /// Unlike [`IncomingClaims`], no custom `deserialize_with` or `default` is applied
    /// here, so deserializing this struct expects `aud` to be an array of strings.
    #[serde(rename = "aud")]
    pub audience: Vec<String>,
    /// The `iat` (issued-at) claim, encoded as seconds since the Unix epoch.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    pub iat: SystemTime,
    /// The optional `exp` (expiry) claim, encoded the same way as `iat`.
    #[serde_as(as = "Option<serde_with::TimestampSeconds>")]
    pub exp: Option<SystemTime>,
}
/// Deserializes a JWT `aud` claim into a list of audiences.
///
/// RFC 7519 §4.1.3 lets `aud` be either a single string or an array of strings.
/// Both shapes are accepted and normalized to a `Vec<String>`; a lone string becomes
/// a one-element vector.
///
/// # Errors
///
/// Propagates the deserializer's error when the value is neither a string nor an
/// array of strings.
fn deserialize_audience<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    // The enum name is kept as `Audience` on purpose: serde's untagged-enum error
    // message includes it, so renaming it would change the error text.
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum Audience {
        One(String),
        Many(Vec<String>),
    }

    Audience::deserialize(deserializer).map(|parsed| match parsed {
        Audience::One(only) => vec![only],
        Audience::Many(list) => list,
    })
}
/// Claims as decoded from an incoming token, before validation.
///
/// Convert into [`SpacetimeIdentityClaims`] with `try_into()` to validate them.
#[serde_with::serde_as]
#[derive(Debug, Serialize, Deserialize)]
pub struct IncomingClaims {
    /// Optional identity carried in the token under the JSON key `hex_identity`.
    ///
    /// If present, conversion fails unless it equals the identity computed from
    /// `iss` and `sub`.
    #[serde(rename = "hex_identity")]
    pub identity: Option<Identity>,
    /// The `sub` (subject) claim.
    #[serde(rename = "sub")]
    pub subject: String,
    /// The `iss` (issuer) claim.
    #[serde(rename = "iss")]
    pub issuer: String,
    /// The `aud` (audience) claim.
    ///
    /// Accepts either a single string or an array of strings (see
    /// `deserialize_audience`). A missing `aud` yields an empty vector.
    #[serde(rename = "aud", default, deserialize_with = "deserialize_audience")]
    pub audience: Vec<String>,
    /// The `iat` (issued-at) claim, encoded as seconds since the Unix epoch.
    #[serde_as(as = "serde_with::TimestampSeconds")]
    pub iat: SystemTime,
    /// The optional `exp` (expiry) claim, encoded the same way as `iat`.
    #[serde_as(as = "Option<serde_with::TimestampSeconds>")]
    pub exp: Option<SystemTime>,
}
/// Validation applied when turning untrusted token claims into
/// [`SpacetimeIdentityClaims`].
///
/// This is implemented as `TryFrom` rather than `TryInto`, as the standard library
/// recommends. The blanket impl still provides
/// `IncomingClaims: TryInto<SpacetimeIdentityClaims>` with the same error type, so
/// existing `claims.try_into()` callers are unaffected. Callers additionally gain
/// `SpacetimeIdentityClaims::try_from(claims)`.
impl TryFrom<IncomingClaims> for SpacetimeIdentityClaims {
    type Error = anyhow::Error;

    /// Validates `claims` and builds the corresponding [`SpacetimeIdentityClaims`].
    ///
    /// The resulting `identity` is always the one computed by `Identity::from_claims`
    /// from the issuer and subject. A `hex_identity` supplied in the token serves only
    /// as a cross-check.
    ///
    /// # Errors
    ///
    /// Fails if the issuer or subject is longer than 128 bytes or is empty, or if the
    /// token's `hex_identity` differs from the computed identity.
    fn try_from(claims: IncomingClaims) -> anyhow::Result<Self> {
        // The limit is in bytes: `str::len` counts UTF-8 bytes, not characters.
        const MAX_CLAIM_LEN: usize = 128;

        anyhow::ensure!(
            claims.issuer.len() <= MAX_CLAIM_LEN,
            "Issuer too long: {:?}",
            claims.issuer
        );
        anyhow::ensure!(
            claims.subject.len() <= MAX_CLAIM_LEN,
            "Subject too long: {:?}",
            claims.subject
        );
        anyhow::ensure!(!claims.issuer.is_empty(), "Issuer empty");
        anyhow::ensure!(!claims.subject.is_empty(), "Subject empty");

        let computed_identity = Identity::from_claims(&claims.issuer, &claims.subject);
        // A token may optionally carry its identity. If it does, it must agree with
        // the one derived from (issuer, subject); the derived one always wins.
        if let Some(token_identity) = claims.identity {
            anyhow::ensure!(
                token_identity == computed_identity,
                "Identity mismatch: token identity {:?} does not match computed identity {:?}",
                token_identity,
                computed_identity
            );
        }

        Ok(Self {
            identity: computed_identity,
            subject: claims.subject,
            issuer: claims.issuer,
            audience: claims.audience,
            iat: claims.iat,
            exp: claims.exp,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::time::{Duration, UNIX_EPOCH};

    /// `iat` used by every fixture, in seconds since the Unix epoch.
    const IAT_SECS: u64 = 1693425600;
    /// `exp` used by every JSON fixture, in seconds since the Unix epoch.
    const EXP_SECS: u64 = 1693512000;

    /// Asserts the `sub`/`iss`/`iat`/`exp` values shared by the JSON fixtures.
    fn assert_common_fields(claims: &IncomingClaims) {
        assert_eq!(claims.subject, "123");
        assert_eq!(claims.issuer, "example.com");
        assert_eq!(claims.iat, UNIX_EPOCH + Duration::from_secs(IAT_SECS));
        assert_eq!(claims.exp, Some(UNIX_EPOCH + Duration::from_secs(EXP_SECS)));
    }

    /// Builds `IncomingClaims` directly (bypassing JSON) for conversion tests.
    fn incoming(issuer: &str, subject: &str, identity: Option<Identity>) -> IncomingClaims {
        IncomingClaims {
            identity,
            subject: subject.to_owned(),
            issuer: issuer.to_owned(),
            audience: Vec::new(),
            iat: UNIX_EPOCH + Duration::from_secs(IAT_SECS),
            exp: None,
        }
    }

    #[test]
    fn test_deserialize_audience_single_string() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "aud": "audience1",
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });
        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();
        assert_eq!(claims.audience, vec!["audience1"]);
        assert_common_fields(&claims);
    }

    #[test]
    fn test_deserialize_audience_multiple_strings() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "aud": ["audience1", "audience2"],
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });
        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();
        assert_eq!(claims.audience, vec!["audience1", "audience2"]);
        assert_common_fields(&claims);
    }

    #[test]
    fn test_deserialize_audience_missing_field() {
        let json_data = json!({
            "sub": "123",
            "iss": "example.com",
            "iat": IAT_SECS,
            "exp": EXP_SECS
        });
        let claims: IncomingClaims = serde_json::from_value(json_data).unwrap();
        assert!(claims.audience.is_empty());
        assert_common_fields(&claims);
    }

    #[test]
    fn test_try_into_computes_identity_from_issuer_and_subject() {
        let converted: SpacetimeIdentityClaims =
            incoming("example.com", "123", None).try_into().unwrap();
        assert_eq!(converted.identity, Identity::from_claims("example.com", "123"));
        assert_eq!(converted.issuer, "example.com");
        assert_eq!(converted.subject, "123");
    }

    #[test]
    fn test_try_into_accepts_matching_token_identity() {
        let expected = Identity::from_claims("example.com", "123");
        let converted: SpacetimeIdentityClaims =
            incoming("example.com", "123", Some(expected)).try_into().unwrap();
        assert_eq!(converted.identity, Identity::from_claims("example.com", "123"));
    }

    #[test]
    fn test_try_into_rejects_mismatched_token_identity() {
        let other = Identity::from_claims("other.example.com", "123");
        let result: anyhow::Result<SpacetimeIdentityClaims> =
            incoming("example.com", "123", Some(other)).try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_try_into_enforces_length_bounds() {
        let at_limit = "x".repeat(128);
        let ok: anyhow::Result<SpacetimeIdentityClaims> =
            incoming(&at_limit, &at_limit, None).try_into();
        assert!(ok.is_ok(), "128-byte issuer/subject must be accepted");

        let too_long = "x".repeat(129);
        let rejected = [
            ("", "123"),
            ("example.com", ""),
            (too_long.as_str(), "123"),
            ("example.com", too_long.as_str()),
        ];
        for (issuer, subject) in rejected {
            let result: anyhow::Result<SpacetimeIdentityClaims> =
                incoming(issuer, subject, None).try_into();
            assert!(result.is_err(), "expected rejection for iss={issuer:?}, sub={subject:?}");
        }
    }
}