tiny_oidc_rp/
client.rs

1// SPDX-License-Identifier: MIT
2use crate::error::AuthenticationFailedError;
3use crate::{Error, IdToken, Provider};
4
/// OpenID Connect `response_mode` authorization request parameter.
///
/// Selects how the authorization response (the authorization `code` and
/// `state`) is delivered back to the redirect URI.
///
/// See: <https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html>
#[derive(Clone, Debug)]
pub enum OidcResponseMode {
    /// `response_mode=query`, the default for the authorization code flow.
    ///
    /// The authorization code arrives as query parameters of an HTTP GET
    /// request to the redirect URI.
    Query,

    /// `response_mode=form_post`, an alternate mode.
    ///
    /// The authorization code arrives in the form body of an HTTP POST request.
    /// This lowers the risk of the code leaking through the `Referer` header or
    /// HTTP server logs. Note, however, that a `SameSite` session cookie is not
    /// sent along with this cross-site POST.
    FormPost,

    /// `response_mode=fragment`, intended for single-page Web apps.
    ///
    /// The authorization code arrives in the URL fragment of an HTTP GET
    /// request, so it is never sent to the server directly.
    Fragment,
}
25
26// response_mode as &str
27impl std::ops::Deref for OidcResponseMode {
28    type Target = str;
29    fn deref(&self) -> &str {
30        match self {
31            Self::Query => "query",
32            Self::FormPost => "form_post",
33            Self::Fragment => "fragment",
34        }
35    }
36}
37
/// OpenID Connect `prompt` authorization request parameter.
///
/// See [OpenID Connect Core 1.0, section 3.1.2.1](https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest).
#[derive(Clone, Debug)]
pub enum OidcPrompt {
    /// `prompt=none`.
    ///
    /// Named `NoPrompt` rather than `None` to avoid confusion with `Option::None`.
    NoPrompt,
    /// `prompt=login`; use this to force the user to log in again.
    Login,
    /// `prompt=consent`.
    Consent,
    /// `prompt=select_account`.
    SelectAccount,
}
46
47// prompt as &str
48impl std::ops::Deref for OidcPrompt {
49    type Target = str;
50    fn deref(&self) -> &str {
51        match self {
52            Self::NoPrompt => "none",
53            Self::Login => "login",
54            Self::Consent => "consent",
55            Self::SelectAccount => "select_account",
56        }
57    }
58}
59
60/// OpenID Connect relying party client
61#[derive(Clone, Debug)]
62pub struct Client<P: Provider> {
63    client_id: String,
64    client_secret: String,
65    redirect_uri: String,
66    response_mode: OidcResponseMode,
67    provider: P,
68}
69
70impl<P: Provider> Client<P> {
71    /// Create authn URL with query parameter
72    ///
73    /// If you request the user to force re-login, set prompt=Some(Login)
74    pub fn auth_url(&self, session: &Session, prompt: Option<OidcPrompt>) -> url::Url {
75        // append queries to authorize endpoint
76        let mut authurl = self.provider.authorization_endpoint();
77        authurl
78            .query_pairs_mut()
79            .append_pair("scope", "openid profile email")
80            .append_pair("response_type", "code")
81            .append_pair("client_id", &self.client_id)
82            .append_pair("nonce", &session.nonce())
83            .append_pair("state", &session.state())
84            .append_pair("response_mode", &self.response_mode)
85            .append_pair("redirect_uri", &self.redirect_uri)
86            .append_pair("code_challenge_method", "S256")
87            .append_pair("code_challenge", &session.pkce_challenge());
88
89        if let Some(prompt) = prompt {
90            authurl.query_pairs_mut().append_pair("prompt", &prompt);
91        }
92
93        authurl
94    }
95
96    /// Authenticate user with `state`, `code`
97    ///
98    /// `state`, `code` are retrived from HTTP query parameters or form body.
99    /// `session` is retrived from HTTP cookie.
100    ///
101    /// If you need decoding extra claims in ID token,
102    /// specify your own Deserialized type as T.
103    /// Otherwise, set T as ()
104    ///
105    /// ```ignore
106    /// let session_key = cookie_jar.get("__Host-oidc-session")?.value();
107    /// let session_value = some_database.load(session_key)?;
108    /// let session = tiny_oidc_rp::Session::load_session(session_key, session_value)?;
109    /// let id_token = oidc_client.authenticate<()>(state, code, &session)?;
110    /// ```
111    pub async fn authenticate<T>(
112        &self,
113        state: &str,
114        code: &str,
115        session: &Session,
116    ) -> Result<IdToken<T>, Error>
117    where
118        T: serde::de::DeserializeOwned,
119    {
120        // Check state mismatch (possible CSRF)
121        if state != session.state() {
122            log::warn!("state mismatch");
123            return Err(Error::BadRequest);
124        }
125
126        // Prepare token endpoint request
127        let code_verifier = session.pkce_verifier();
128        let params = vec![
129            ("grant_type", "authorization_code"),
130            ("code", code),
131            ("client_id", &self.client_id),
132            ("client_secret", &self.client_secret),
133            ("redirect_uri", &self.redirect_uri),
134            ("code_verifier", &code_verifier),
135        ];
136
137        // Send POST request to token endpoint
138        let response = reqwest::Client::new()
139            .post(self.provider.token_endpoint().clone())
140            .form(&params)
141            .send()
142            .await?;
143
144        if let Err(err) = response.error_for_status_ref() {
145            // Error, log body
146            let err_body = response.text().await?;
147            log::warn!("Token endpoint returns error {}", err_body);
148
149            Err(err.into())
150        } else {
151            // Ok, decode body as JSON
152            let token_response = response.json::<OidcTokenEndpointResponse>().await?;
153            log::debug!("Token endpoint returns {:?}", token_response);
154
155            // Decode ID Token string.
156            //   Skip JWS signature validation here,
157            //   because code flow can trust issuer by TLS server certificate validation
158            let id_token = IdToken::<T>::decode_without_jws_validation(&token_response.id_token)?;
159
160            self.validate_claims(&id_token, session)?;
161            Ok(id_token)
162        }
163    }
164
165    /// Validate ID token claims
166    /// See also [OpenID connect spec 3.1.3.7. ID Token Validation](https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation)
167    fn validate_claims<T>(
168        &self,
169        id_token: &IdToken<T>,
170        session: &Session,
171    ) -> Result<(), AuthenticationFailedError> {
172        use std::time::SystemTime;
173
174        if !self.provider.validate_iss(&id_token.iss) {
175            log::info!("Invalid iss {}", id_token.iss);
176            return Err(AuthenticationFailedError::ClaimValidationError);
177        }
178
179        if id_token.aud != self.client_id {
180            log::info!("Invalid aud {}", id_token.aud);
181            return Err(AuthenticationFailedError::ClaimValidationError);
182        }
183
184        if &id_token.nonce != &session.nonce() {
185            log::info!("Invalid nonce {}", id_token.nonce);
186            return Err(AuthenticationFailedError::ClaimValidationError);
187        }
188
189        let now = SystemTime::now()
190            .duration_since(SystemTime::UNIX_EPOCH)
191            .map_or(0, |t| t.as_secs());
192        if id_token.iat > now + 60 || now > id_token.exp {
193            // token expired
194            log::info!(
195                "Invalid iat {} or exp {} : now = {}",
196                id_token.iat,
197                id_token.exp,
198                now
199            );
200            return Err(AuthenticationFailedError::ClaimValidationError);
201        }
202
203        Ok(())
204    }
205}
206
207/// Setup Client
208pub struct ClientBuilder<P: Provider> {
209    client_id: Option<String>,
210    client_secret: Option<String>,
211    redirect_uri: Option<String>,
212    response_mode: OidcResponseMode,
213    provider: P,
214}
215
216impl<P: Provider> ClientBuilder<P> {
217    /// Client builder from OpenID connect Provider
218    pub(crate) fn from_provider(provider: P) -> Self {
219        Self {
220            client_id: None,
221            client_secret: None,
222            redirect_uri: None,
223            response_mode: OidcResponseMode::Query,
224            provider,
225        }
226    }
227
228    /// Build OpenID connect Client
229    pub fn build(self) -> Option<Client<P>> {
230        match self {
231            Self {
232                client_id: Some(client_id),
233                client_secret: Some(client_secret),
234                redirect_uri: Some(redirect_uri),
235                response_mode,
236                provider,
237            } => Some(Client {
238                client_id,
239                client_secret,
240                redirect_uri,
241                response_mode,
242                provider,
243            }),
244            _ => {
245                // Some elements are not initialized.
246                None
247            }
248        }
249    }
250
251    /// Client ID
252    pub fn client_id(self, client_id: &str) -> Self {
253        let mut builder = self;
254        builder.client_id = Some(client_id.to_string());
255        builder
256    }
257
258    /// Client secret
259    pub fn client_secret(self, client_secret: &str) -> Self {
260        let mut builder = self;
261        builder.client_secret = Some(client_secret.to_string());
262        builder
263    }
264
265    /// Redirect URI
266    pub fn redirect_uri(self, redirect_uri: &str) -> Self {
267        let mut builder = self;
268        builder.redirect_uri = Some(redirect_uri.to_string());
269        builder
270    }
271
272    /// Response mode
273    pub fn response_mode(self, response_mode: OidcResponseMode) -> Self {
274        let mut builder = self;
275        builder.response_mode = response_mode;
276        builder
277    }
278}
279
/// OpenID Connect login session.
///
/// All per-login secrets live in one random buffer of four 36-byte parts:
///
/// | bytes       | meaning            |
/// |-------------|--------------------|
/// | `0..36`     | session key        |
/// | `36..72`    | `state`            |
/// | `72..108`   | `nonce`            |
/// | `108..144`  | PKCE code verifier |
pub struct Session {
    rand_bytes: [u8; 144],
}
285
286impl Session {
287    /// Start new OpenID connect session
288    pub fn new_session() -> Result<Session, crate::Error> {
289        // Make random bytes
290        let mut rand_bytes = [0u8; 144];
291        getrandom::fill(&mut rand_bytes).map_err(|e| {
292            log::error!("getrandom() failed with {:?}", e);
293            crate::Error::InternalError
294        })?;
295        Ok(Session { rand_bytes })
296    }
297
298    /// Serialize session and returns (key, value) pair.
299    /// Implementer should store `key` in browser session cookie or local storage,
300    /// and store `(key,value)` pair in server side database.
301    /// Both `key` and `value` is URL safe string
302    pub fn save_session(&self) -> (String, String) {
303        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
304        return (self.key(), URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..]));
305    }
306
307    /// Deserialize session saved by `save_session()`
308    /// Implementer should get session key from cookie,
309    /// and load session_value from server side database.
310    pub fn load_session(
311        session_key: &str,
312        session_value: &str,
313    ) -> Result<Self, base64::DecodeSliceError> {
314        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
315        let mut rand_bytes = [0u8; 144];
316
317        // Decode key & value
318        URL_SAFE_NO_PAD.decode_slice(session_key, &mut rand_bytes[..36])?;
319        URL_SAFE_NO_PAD.decode_slice(session_value, &mut rand_bytes[36..])?;
320
321        Ok(Self { rand_bytes })
322    }
323
324    /// Base64Url(key) -> 48 chars
325    pub fn key(&self) -> String {
326        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
327        URL_SAFE_NO_PAD.encode(&self.rand_bytes[..36])
328    }
329
330    /// Base64Url(state) -> 48 chars
331    fn state(&self) -> String {
332        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
333        URL_SAFE_NO_PAD.encode(&self.rand_bytes[36..72])
334    }
335
336    /// Base64Url(nonce) -> 48 chars
337    fn nonce(&self) -> String {
338        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
339        URL_SAFE_NO_PAD.encode(&self.rand_bytes[72..108])
340    }
341
342    /// PKCE code_challenge in Base64 string
343    fn pkce_challenge(&self) -> String {
344        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
345        use sha2::{Digest, Sha256};
346
347        // PKCE code_challenge=Base64Url(SHA256(pkce_verifier))
348        let challenge_byte = Sha256::digest(&self.pkce_verifier().as_bytes());
349
350        URL_SAFE_NO_PAD.encode(&challenge_byte)
351    }
352
353    /// PKCE code_verifier in Base64 string
354    fn pkce_verifier(&self) -> String {
355        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
356        // code_verifier = Base64Url(pkce_verifier)
357        URL_SAFE_NO_PAD.encode(&self.rand_bytes[108..144])
358    }
359}
360
361/// Response body JSON from token endpoint
362#[derive(Debug, serde::Deserialize)]
363struct OidcTokenEndpointResponse {
364    // access_token: Option<String>,
365    id_token: String,
366}