greentic_oauth_core/
state.rs1use std::time::{Duration, SystemTime, UNIX_EPOCH};
2
3use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use crate::types::ProviderId;
8
/// Default lifetime of a signed state token: five minutes.
pub const DEFAULT_STATE_TTL: Duration = Duration::from_secs(5 * 60);
11
12#[cfg_attr(feature = "schemas", derive(schemars::JsonSchema))]
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
15pub struct StateClaims {
16 pub tenant: String,
17 pub team: Option<String>,
18 pub provider: String,
19 pub nonce: String,
20 #[serde(default, skip_serializing_if = "Option::is_none")]
21 pub redirect_to: Option<String>,
22 pub iat: u64,
23 pub exp: u64,
24}
25
26impl StateClaims {
27 pub fn new(
29 tenant: impl Into<String>,
30 team: Option<String>,
31 provider: &ProviderId,
32 nonce: impl Into<String>,
33 redirect_to: Option<String>,
34 ttl: Duration,
35 ) -> Self {
36 let issued_at = current_epoch();
37 let expires_at = issued_at.saturating_add(ttl.as_secs());
38 Self {
39 tenant: tenant.into(),
40 team,
41 provider: provider.as_str().to_owned(),
42 nonce: nonce.into(),
43 redirect_to,
44 iat: issued_at,
45 exp: expires_at,
46 }
47 }
48}
49
50#[derive(Debug, Error)]
51pub enum StateError {
52 #[error("state token error: {0}")]
53 Jwt(#[from] jsonwebtoken::errors::Error),
54}
55
56pub fn sign_state(claims: &StateClaims, secret: &[u8]) -> Result<String, StateError> {
58 let header = Header::new(Algorithm::HS256);
59 encode(&header, claims, &EncodingKey::from_secret(secret)).map_err(StateError::from)
60}
61
62pub fn verify_state(token: &str, secret: &[u8]) -> Result<StateClaims, StateError> {
64 let validation = Validation::new(Algorithm::HS256);
65 decode::<StateClaims>(token, &DecodingKey::from_secret(secret), &validation)
66 .map(|data| data.claims)
67 .map_err(StateError::from)
68}
69
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`
/// instead of failing.
fn current_epoch() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
76
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProviderId;
    use jsonwebtoken::errors::ErrorKind;

    /// Minimal claims used by the negative-path tests.
    fn sample_claims() -> StateClaims {
        StateClaims::new(
            "acme",
            None,
            &ProviderId::Google,
            "nonce",
            None,
            DEFAULT_STATE_TTL,
        )
    }

    /// Extracts the `jsonwebtoken` error kind wrapped by a `StateError`.
    fn jwt_kind(err: &StateError) -> &ErrorKind {
        match err {
            StateError::Jwt(inner) => inner.kind(),
        }
    }

    #[test]
    fn sign_and_verify_roundtrip() {
        let claims = StateClaims::new(
            "acme",
            Some("platform".into()),
            &ProviderId::Google,
            "nonce-123",
            Some("https://app.example.com/callback".into()),
            DEFAULT_STATE_TTL,
        );
        let secret = b"super-secret-key";
        let token = sign_state(&claims, secret).expect("sign");
        let decoded = verify_state(&token, secret).expect("verify");
        // Every field, including the timestamps, must survive the roundtrip.
        assert_eq!(claims, decoded);
        assert_eq!(decoded.exp - decoded.iat, DEFAULT_STATE_TTL.as_secs());
    }

    #[test]
    fn rejects_tampered_token() {
        let claims = sample_claims();
        let secret = b"secret";
        let mut token = sign_state(&claims, secret).expect("sign");
        token.push('a');
        assert!(verify_state(&token, secret).is_err());
    }

    #[test]
    fn rejects_wrong_secret() {
        let claims = sample_claims();
        let token = sign_state(&claims, b"secret").expect("sign");
        let err = verify_state(&token, b"other-secret").expect_err("wrong key must be rejected");
        assert!(matches!(jwt_kind(&err), ErrorKind::InvalidSignature));
    }

    #[test]
    fn rejects_expired_token() {
        // Timestamps far in the past: expired regardless of any leeway setting.
        let mut claims = sample_claims();
        claims.iat = 1_000;
        claims.exp = 1_300;
        let secret = b"secret";
        let token = sign_state(&claims, secret).expect("sign");
        let err = verify_state(&token, secret).expect_err("expired token must be rejected");
        assert!(matches!(jwt_kind(&err), ErrorKind::ExpiredSignature));
    }

    #[test]
    fn rejects_other_algorithm() {
        let claims = sample_claims();
        let secret = b"secret";
        let token = encode(
            &Header::new(Algorithm::HS512),
            &claims,
            &EncodingKey::from_secret(secret),
        )
        .expect("sign with HS512");
        assert!(verify_state(&token, secret).is_err());
    }
}