greentic_oauth_core/
state.rs

1use std::time::{Duration, SystemTime, UNIX_EPOCH};
2
3use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use crate::types::ProviderId;
8
/// Lifetime given to state tokens when the caller has no specific TTL in mind: five minutes.
pub const DEFAULT_STATE_TTL: Duration = Duration::from_secs(5 * 60);
11
/// Claims embedded into the secure state token.
///
/// This is the JWT payload written by [`sign_state`] and recovered by [`verify_state`].
/// Field declaration order is the serialized key order (serde derive), so reordering
/// fields changes the exact token bytes.
#[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StateClaims {
    /// Tenant identifier.
    pub tenant: String,
    /// Optional team identifier.
    pub team: Option<String>,
    /// Provider identifier, stored as the string from `ProviderId::as_str`.
    pub provider: String,
    /// Caller-supplied nonce. NOTE(review): presumably compared against a stored value
    /// when the callback arrives — confirm with callers.
    pub nonce: String,
    /// Optional redirect target. Left out of the serialized payload when `None`, and
    /// treated as `None` when the key is missing on decode.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub redirect_to: Option<String>,
    /// Issued-at time in Unix seconds (set from the system clock by [`StateClaims::new`]).
    pub iat: u64,
    /// Expiry time in Unix seconds (`iat` plus the TTL in [`StateClaims::new`]).
    pub exp: u64,
}
25
26impl StateClaims {
27    /// Construct claims with the provided context and TTL.
28    pub fn new(
29        tenant: impl Into<String>,
30        team: Option<String>,
31        provider: &ProviderId,
32        nonce: impl Into<String>,
33        redirect_to: Option<String>,
34        ttl: Duration,
35    ) -> Self {
36        let issued_at = current_epoch();
37        let expires_at = issued_at.saturating_add(ttl.as_secs());
38        Self {
39            tenant: tenant.into(),
40            team,
41            provider: provider.as_str().to_owned(),
42            nonce: nonce.into(),
43            redirect_to,
44            iat: issued_at,
45            exp: expires_at,
46        }
47    }
48}
49
/// Errors returned by [`sign_state`] and [`verify_state`].
#[derive(Debug, Error)]
pub enum StateError {
    /// Wraps whatever error `jsonwebtoken` reported while encoding or decoding the token.
    /// `#[from]` provides `From<jsonwebtoken::errors::Error>`, so `?` converts it directly.
    #[error("state token error: {0}")]
    Jwt(#[from] jsonwebtoken::errors::Error),
}
55
56/// Sign the provided claims using HS256.
57pub fn sign_state(claims: &StateClaims, secret: &[u8]) -> Result<String, StateError> {
58    let header = Header::new(Algorithm::HS256);
59    encode(&header, claims, &EncodingKey::from_secret(secret)).map_err(StateError::from)
60}
61
62/// Verify the state token and return the embedded claims.
63pub fn verify_state(token: &str, secret: &[u8]) -> Result<StateClaims, StateError> {
64    let validation = Validation::new(Algorithm::HS256);
65    decode::<StateClaims>(token, &DecodingKey::from_secret(secret), &validation)
66        .map(|data| data.claims)
67        .map_err(StateError::from)
68}
69
/// Current Unix time in whole seconds.
///
/// If the system clock reports a time before 1970, this returns `0` instead of failing.
fn current_epoch() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Best-effort: a pre-epoch clock is clamped to the epoch rather than panicking.
        Err(_) => 0,
    }
}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProviderId;

    /// Claims for `tenant` with no team/redirect, a fixed nonce, and the default TTL.
    fn claims_for(tenant: &str) -> StateClaims {
        StateClaims::new(tenant, None, &ProviderId::Google, "nonce", None, DEFAULT_STATE_TTL)
    }

    #[test]
    fn sign_and_verify_roundtrip() {
        let claims = StateClaims::new(
            "acme",
            Some("platform".into()),
            &ProviderId::Google,
            "nonce-123",
            Some("https://app.example.com/callback".into()),
            DEFAULT_STATE_TTL,
        );
        let secret = b"super-secret-key";
        let token = sign_state(&claims, secret).expect("sign");
        let decoded = verify_state(&token, secret).expect("verify");
        // Compare the whole struct so `iat` and `exp` are covered as well.
        assert_eq!(claims, decoded);
    }

    #[test]
    fn new_derives_expiry_from_ttl() {
        let claims = claims_for("acme");
        assert_eq!(claims.exp - claims.iat, DEFAULT_STATE_TTL.as_secs());
    }

    #[test]
    fn rejects_tampered_token() {
        let secret = b"secret";
        let mut token = sign_state(&claims_for("acme"), secret).expect("sign");
        token.push('a');
        assert!(verify_state(&token, secret).is_err());
    }

    #[test]
    fn rejects_payload_spliced_from_another_token() {
        let secret = b"secret";
        let genuine = sign_state(&claims_for("acme"), secret).expect("sign");
        let other = sign_state(&claims_for("intruder"), secret).expect("sign");
        let genuine_parts: Vec<&str> = genuine.split('.').collect();
        let other_parts: Vec<&str> = other.split('.').collect();
        assert_eq!(genuine_parts.len(), 3);
        assert_eq!(other_parts.len(), 3);
        // Keep the genuine header and signature but swap in a different, well-formed payload,
        // so only the signature check can catch the forgery.
        let forged = format!("{}.{}.{}", genuine_parts[0], other_parts[1], genuine_parts[2]);
        assert!(verify_state(&forged, secret).is_err());
    }

    #[test]
    fn rejects_token_signed_with_different_secret() {
        let token = sign_state(&claims_for("acme"), b"secret-a").expect("sign");
        assert!(verify_state(&token, b"secret-b").is_err());
    }

    #[test]
    fn rejects_expired_token() {
        let secret = b"secret";
        let mut claims = claims_for("acme");
        let now = claims.iat;
        // Expire an hour ago: far outside any reasonable validation leeway.
        claims.iat = now.saturating_sub(7_200);
        claims.exp = now.saturating_sub(3_600);
        let token = sign_state(&claims, secret).expect("sign");
        match verify_state(&token, secret) {
            Err(StateError::Jwt(err)) => assert!(matches!(
                err.kind(),
                jsonwebtoken::errors::ErrorKind::ExpiredSignature
            )),
            Ok(_) => panic!("expired state token was accepted"),
        }
    }
}