openid_client/helpers/public.rs

1use std::time::{SystemTime, UNIX_EPOCH};
2
3use josekit::{jws::JwsHeader, jwt::JwtPayload};
4use jwt_compact::jwk::JsonWebKey;
5use rand::Rng;
6use serde::Deserialize;
7use serde_json::{Map, Value};
8use sha2::{Digest, Sha256};
9
10use crate::types::{DecodedToken, OidcClientError, OidcReturnType};
11
/// Returns the current Unix time as whole seconds since [`UNIX_EPOCH`], read via
/// [`SystemTime::now`].
///
/// # Panics
/// Panics with `"Time went backwards"` if the system clock reports a time earlier than
/// [`UNIX_EPOCH`].
pub fn now() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    elapsed.as_secs()
}
20
21/// Generates a random string using [rand::thread_rng]. You can pass in the bytes to generates
22pub fn generate_random(bytes_to_generate: Option<u32>) -> String {
23    let mut random_bytes = vec![];
24
25    for _ in 0..bytes_to_generate.unwrap_or(32) {
26        random_bytes.push(rand::thread_rng().gen());
27    }
28
29    base64_url::encode(&random_bytes)
30}
31
32/// Generates a random string as the state. Uses [generate_random] under the hood.
33pub fn generate_state(bytes: Option<u32>) -> String {
34    generate_random(bytes)
35}
36
37/// Generates a random string as the nonce. Uses [generate_random] under the hood.
38pub fn generate_nonce(bytes: Option<u32>) -> String {
39    generate_random(bytes)
40}
41
42/// Generates a random string as the code_verifier. Uses [generate_random] under the hood.
43pub fn generate_code_verifier() -> String {
44    generate_random(None)
45}
46
47/// Generates the S256 PKCE code challenge for `verifier`.
48pub fn code_challenge(verifier: &str) -> String {
49    let mut hasher = Sha256::new();
50
51    hasher.update(verifier);
52
53    base64_url::encode(&hasher.finalize().to_vec())
54}
55
56/// Converts plain JSON to a struct/enum that impl's serde's [Deserialize]. Uses [serde_json::from_str] under
57/// the hood
58pub fn convert_json_to<T: for<'a> Deserialize<'a>>(plain: &str) -> Result<T, String> {
59    if let Ok(r) = serde_json::from_str::<T>(plain) {
60        return Ok(r);
61    }
62
63    Err("Parse Error".to_string())
64}
65
66/// Gets S256 thumbprint of a JWK JSON.
67pub fn get_s256_jwk_thumbprint(jwk_str: &str) -> OidcReturnType<String> {
68    let jwk: JsonWebKey<'_> = serde_json::from_str(jwk_str)
69        .map_err(|_| OidcClientError::new_error("Invalid JWK", None))?;
70
71    Ok(base64_url::encode(&jwk.thumbprint::<Sha256>().to_vec()))
72}
73
74/// Decodes a JWT without verification
75pub fn decode_jwt(token: &str) -> OidcReturnType<DecodedToken> {
76    let split_token: Vec<&str> = token.split('.').collect();
77
78    if split_token.len() == 5 {
79        return Err(Box::new(OidcClientError::new_type_error(
80            "encrypted JWTs cannot be decoded",
81            None,
82        )));
83    }
84
85    if split_token.len() != 3 {
86        return Err(Box::new(OidcClientError::new_error(
87            "JWTs must have three components",
88            None,
89        )));
90    }
91
92    let map_err_decode = |_| OidcClientError::new_error("JWT is malformed", None);
93    let map_err_deserialize = |_| OidcClientError::new_error("JWT is malformed", None);
94    let map_err_jose = |_| OidcClientError::new_error("JWT is malformed", None);
95
96    let header_str = base64_url::decode(split_token[0]).map_err(map_err_decode)?;
97    let payload_str = base64_url::decode(split_token[1]).map_err(map_err_decode)?;
98    let signature = split_token[2].to_string();
99
100    let header = serde_json::from_slice::<Map<String, Value>>(&header_str)
101        .map(JwsHeader::from_map)
102        .map_err(map_err_deserialize)?
103        .map_err(map_err_jose)?;
104
105    let payload = serde_json::from_slice::<Map<String, Value>>(&payload_str)
106        .map(JwtPayload::from_map)
107        .map_err(map_err_deserialize)?
108        .map_err(map_err_jose)?;
109
110    Ok(DecodedToken {
111        header,
112        payload,
113        signature,
114    })
115}
116
#[cfg(test)]
mod helper_tests {
    use crate::helpers::{code_challenge, generate_code_verifier};

    /// Known-answer check for the S256 transform.
    #[test]
    fn code_challenge_should_work() {
        let verifier = "xupVEHY65t6sSASJL5eWq8e736TvtlgQUrU2c9hMqaA";

        assert_eq!(
            "TQduXP_9QfLCe9TY10TxZEP4gXy6xBPirtydtDOoQC0",
            code_challenge(verifier)
        );
    }

    /// Verifiers are 32 random bytes base64url-encoded without padding, so every one must be
    /// exactly 43 characters drawn from the URL-safe alphabet.
    #[test]
    fn code_verifier_should_only_create_string_of_length_43() {
        for _ in 0..100 {
            let verifier = generate_code_verifier();

            assert_eq!(43, verifier.len());
            assert!(verifier
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_'));
        }
    }
}
137}