Skip to main content

oauth2_test_server/crypto.rs

1use base64::{engine::general_purpose, Engine};
2use chrono::{Duration, Utc};
3use jsonwebtoken::jwk::{CommonParameters, Jwk};
4use jsonwebtoken::{encode, Algorithm, DecodingKey, EncodingKey, Header};
5use rsa::pkcs8::{DecodePublicKey, EncodePrivateKey, EncodePublicKey};
6use rsa::traits::PublicKeyParts;
7use rsa::RsaPrivateKey;
8use serde_json::json;
9use sha2::{Digest, Sha256};
10use uuid::Uuid;
11
12use crate::models::{Claims, IdTokenClaims};
13
14/// RSA key pair used for signing and verifying JWT access tokens.
15pub struct Keys {
16    pub encoding: EncodingKey,
17    pub decoding: DecodingKey,
18    pub public_pem: String,
19    /// Key ID embedded in JWT headers and JWKS; unique per server instance.
20    pub kid: String,
21}
22
23impl Keys {
24    /// Generate a fresh 2048-bit RSA key pair for this server instance.
25    pub fn generate() -> Self {
26        let mut rng = rand::thread_rng();
27
28        let private_key =
29            RsaPrivateKey::new(&mut rng, 2048).expect("failed to generate RSA key pair");
30        let public_key = private_key.to_public_key();
31
32        let private_pem = private_key
33            .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
34            .expect("failed to encode private key as PKCS8 PEM")
35            .to_string();
36
37        let public_pem = public_key
38            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
39            .expect("failed to encode public key as PEM")
40            .to_string();
41
42        let encoding =
43            EncodingKey::from_rsa_pem(private_pem.as_bytes()).expect("failed to build EncodingKey");
44        let decoding =
45            DecodingKey::from_rsa_pem(public_pem.as_bytes()).expect("failed to build DecodingKey");
46        let kid = format!("key-{}", Uuid::new_v4());
47
48        Keys {
49            encoding,
50            decoding,
51            public_pem,
52            kid,
53        }
54    }
55}
56
57/// Build the JWKS JSON document (public keys) for a given key set.
58pub fn build_jwks_json(keys: &Keys) -> serde_json::Value {
59    let public_key = rsa::RsaPublicKey::from_public_key_pem(&keys.public_pem)
60        .expect("failed to re-parse stored public key");
61
62    let jwk = Jwk {
63        common: CommonParameters {
64            key_algorithm: Some(jsonwebtoken::jwk::KeyAlgorithm::RS256),
65            key_id: Some(keys.kid.clone()),
66            ..Default::default()
67        },
68        algorithm: jsonwebtoken::jwk::AlgorithmParameters::RSA(
69            jsonwebtoken::jwk::RSAKeyParameters {
70                n: general_purpose::URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be()),
71                e: general_purpose::URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be()),
72                key_type: jsonwebtoken::jwk::RSAKeyType::RSA,
73            },
74        ),
75    };
76
77    json!({ "keys": [jwk] })
78}
79
80/// Sign and return a JWT access token.
81pub fn issue_jwt(
82    issuer: &str,
83    client_id: &str,
84    user_id: &str,
85    requested_scope: &str,
86    expires_in: i64,
87    keys: &Keys,
88) -> Result<String, jsonwebtoken::errors::Error> {
89    let iat = Utc::now().timestamp() as usize;
90    let exp = (Utc::now() + Duration::seconds(expires_in)).timestamp() as usize;
91
92    let scopes: Vec<&str> = requested_scope.split_whitespace().collect();
93
94    let claims = Claims {
95        iss: issuer.to_string(),
96        sub: user_id.to_string(),
97        aud: client_id.to_string(),
98        exp,
99        iat,
100        scope: Some(scopes.join(" ")),
101        auth_time: Some(iat),
102        typ: "Bearer".to_string(),
103        azp: Some(client_id.to_string()),
104        sid: Some(format!("sid-{}", Uuid::new_v4())),
105        jti: Uuid::new_v4().to_string(),
106    };
107
108    let mut header = Header::new(Algorithm::RS256);
109    header.typ = Some("JWT".to_string());
110    header.kid = Some(keys.kid.clone());
111
112    encode(&header, &claims, &keys.encoding)
113}
114
115/// Generate a short authorization code.
116pub fn generate_code() -> String {
117    Uuid::new_v4().to_string()[..20].to_string()
118}
119
120/// Generate an opaque access/refresh token string.
121pub fn generate_token_string() -> String {
122    format!("tok_{}", Uuid::new_v4().to_string().replace("-", ""))
123}
124
125/// Calculate at_hash (Access Token Hash) per OIDC Core Section 3.2.2.9.
126/// Used to validate that an access token was issued alongside an ID token.
127pub fn calculate_at_hash(access_token: &str) -> String {
128    let hash = Sha256::digest(access_token.as_bytes());
129    let half = &hash[..hash.len() / 2];
130    general_purpose::URL_SAFE_NO_PAD.encode(half)
131}
132
133/// Calculate c_hash (Code Hash) per OIDC Core Section 3.2.2.9.
134/// Used to validate that an authorization code was issued alongside an ID token.
135pub fn calculate_c_hash(authorization_code: &str) -> String {
136    let hash = Sha256::digest(authorization_code.as_bytes());
137    let half = &hash[..hash.len() / 2];
138    general_purpose::URL_SAFE_NO_PAD.encode(half)
139}
140
141#[allow(clippy::too_many_arguments)]
142/// Sign and return an ID Token per OpenID Connect Core 1.0.
143pub fn issue_id_token(
144    issuer: &str,
145    client_id: &str,
146    user_id: &str,
147    nonce: Option<&str>,
148    at_hash: Option<&str>,
149    c_hash: Option<&str>,
150    expires_in: i64,
151    user_claims: serde_json::Value,
152    keys: &Keys,
153) -> Result<String, jsonwebtoken::errors::Error> {
154    let iat = Utc::now().timestamp() as usize;
155    let exp = (Utc::now() + Duration::seconds(expires_in)).timestamp() as usize;
156
157    let mut claims = IdTokenClaims::new(issuer, user_id, client_id, exp, iat);
158
159    if let Some(n) = nonce {
160        claims = claims.with_nonce(n);
161    }
162    if let Some(hash) = at_hash {
163        claims = claims.with_at_hash(hash);
164    }
165    if let Some(hash) = c_hash {
166        claims = claims.with_c_hash(hash);
167    }
168    claims = claims.with_azp(client_id);
169    claims = claims.with_user_claims(user_claims);
170
171    let mut header = Header::new(Algorithm::RS256);
172    header.typ = Some("JWT".to_string());
173    header.kid = Some(keys.kid.clone());
174
175    encode(&header, &claims, &keys.encoding)
176}