Skip to main content

modo/auth/oauth/
state.rs

1use axum::extract::{FromRef, FromRequestParts};
2use axum::response::{IntoResponse, Redirect, Response};
3use axum_extra::extract::cookie::Key;
4use cookie::{Cookie, CookieJar, SameSite};
5use http::header::{COOKIE, SET_COOKIE};
6use http::request::Parts;
7
8use crate::cookie::CookieConfig;
9use crate::service::AppState;
10
/// Name of the cookie that carries the signed OAuth state payload.
const OAUTH_COOKIE_NAME: &str = "_oauth_state";
/// `Max-Age` applied to the state cookie, in seconds (300 s = 5 minutes).
const OAUTH_COOKIE_MAX_AGE_SECS: i64 = 300;
13
/// OAuth 2.0 state extracted from the signed `_oauth_state` cookie.
///
/// This is an axum extractor. Add it as a handler parameter on your callback route to
/// automatically parse and verify the state cookie that was set by [`AuthorizationRequest`].
///
/// Requires a [`Key`](axum_extra::extract::cookie::Key) to be registered in the
/// [`Registry`](crate::service::Registry) so that the cookie signature can be verified.
///
/// # Errors
///
/// Extraction rejects with [`crate::Error::bad_request`] if the cookie is missing, tampered
/// with, or structurally invalid (not JSON, or lacking one of the expected string fields),
/// and with [`crate::Error::internal`] if no `Key` is registered.
pub struct OAuthState {
    /// Value of the `state` field in the cookie payload.
    state_nonce: String,
    /// Value of the `pkce_verifier` field in the cookie payload.
    pkce_verifier: String,
    /// Value of the `provider` field in the cookie payload.
    provider: String,
}
29
30impl OAuthState {
31    pub(crate) fn provider(&self) -> &str {
32        &self.provider
33    }
34
35    pub(crate) fn pkce_verifier(&self) -> &str {
36        &self.pkce_verifier
37    }
38
39    pub(crate) fn state_nonce(&self) -> &str {
40        &self.state_nonce
41    }
42
43    pub(crate) fn from_signed_cookie(cookie_header: &str, key: &Key) -> crate::Result<Self> {
44        let mut jar = CookieJar::new();
45
46        for part in cookie_header.split(';') {
47            let trimmed = part.trim();
48            if let Ok(cookie) = Cookie::parse(trimmed) {
49                jar.add_original(cookie.into_owned());
50            }
51        }
52
53        let verified = jar
54            .signed(key)
55            .get(OAUTH_COOKIE_NAME)
56            .ok_or_else(|| crate::Error::bad_request("invalid or missing OAuth state cookie"))?;
57
58        let payload: serde_json::Value = serde_json::from_str(verified.value())
59            .map_err(|e| crate::Error::bad_request(format!("invalid OAuth state: {e}")))?;
60
61        Ok(Self {
62            state_nonce: payload["state"]
63                .as_str()
64                .ok_or_else(|| crate::Error::bad_request("missing state nonce"))?
65                .to_string(),
66            pkce_verifier: payload["pkce_verifier"]
67                .as_str()
68                .ok_or_else(|| crate::Error::bad_request("missing PKCE verifier"))?
69                .to_string(),
70            provider: payload["provider"]
71                .as_str()
72                .ok_or_else(|| crate::Error::bad_request("missing provider"))?
73                .to_string(),
74        })
75    }
76}
77
78impl<S> FromRequestParts<S> for OAuthState
79where
80    S: Send + Sync,
81    AppState: axum::extract::FromRef<S>,
82{
83    type Rejection = crate::Error;
84
85    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
86        let app_state = AppState::from_ref(state);
87        let key: std::sync::Arc<Key> = app_state
88            .get::<Key>()
89            .ok_or_else(|| crate::Error::internal("Key not registered in service registry"))?;
90
91        let cookie_header = parts
92            .headers
93            .get(COOKIE)
94            .and_then(|v| v.to_str().ok())
95            .ok_or_else(|| crate::Error::bad_request("missing OAuth state cookie"))?;
96
97        Self::from_signed_cookie(cookie_header, &key)
98    }
99}
100
/// An authorization redirect that also sets the `_oauth_state` cookie.
///
/// Returned by [`OAuthProvider::authorize_url`](super::OAuthProvider::authorize_url).
/// Implements [`axum::response::IntoResponse`]: returning it from a handler issues an HTTP
/// `303 See Other` redirect to the provider's authorization endpoint and attaches the signed
/// state cookie.
pub struct AuthorizationRequest {
    /// Target URL of the redirect response.
    pub(crate) redirect_url: String,
    /// Raw `Set-Cookie` header value attached to the redirect response.
    pub(crate) set_cookie_header: String,
}
111
112impl IntoResponse for AuthorizationRequest {
113    fn into_response(self) -> Response {
114        let mut response = Redirect::to(&self.redirect_url).into_response();
115        if let Ok(value) = self.set_cookie_header.parse() {
116            response.headers_mut().insert(SET_COOKIE, value);
117        }
118        response
119    }
120}
121
122/// Build a signed OAuth state cookie. Returns (set_cookie_header, state_nonce, pkce_verifier).
123pub(crate) fn build_oauth_cookie(
124    provider: &str,
125    key: &Key,
126    cookie_config: &CookieConfig,
127) -> (String, String, String) {
128    let state_nonce = generate_random_string(32);
129    let pkce_verifier = generate_random_string(64);
130
131    let payload = serde_json::json!({
132        "state": state_nonce,
133        "pkce_verifier": pkce_verifier,
134        "provider": provider,
135    });
136
137    let mut jar = CookieJar::new();
138    let mut cookie = Cookie::new(OAUTH_COOKIE_NAME, payload.to_string());
139    cookie.set_path("/");
140    cookie.set_http_only(cookie_config.http_only);
141    cookie.set_secure(cookie_config.secure);
142    cookie.set_max_age(cookie::time::Duration::seconds(OAUTH_COOKIE_MAX_AGE_SECS));
143    cookie.set_same_site(match cookie_config.same_site.as_str() {
144        "strict" => SameSite::Strict,
145        "none" => SameSite::None,
146        _ => SameSite::Lax,
147    });
148
149    jar.signed_mut(key).add(cookie);
150
151    let set_cookie_header = jar
152        .get(OAUTH_COOKIE_NAME)
153        .map(|c| c.to_string())
154        .unwrap_or_default();
155
156    (set_cookie_header, state_nonce, pkce_verifier)
157}
158
159/// Generate a PKCE code challenge (S256) from the verifier.
160pub(crate) fn pkce_challenge(verifier: &str) -> String {
161    use sha2::{Digest, Sha256};
162    let hash = Sha256::digest(verifier.as_bytes());
163    base64url_encode(&hash)
164}
165
166fn generate_random_string(len: usize) -> String {
167    let mut bytes = vec![0u8; len];
168    rand::fill(&mut bytes[..]);
169    base64url_encode(&bytes)
170}
171
172fn base64url_encode(bytes: &[u8]) -> String {
173    crate::encoding::base64url::encode(bytes)
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    /// Cookie settings shared by all tests: non-secure, HTTP-only, `lax` same-site.
    fn test_cookie_config() -> CookieConfig {
        CookieConfig {
            secret: "a".repeat(64),
            secure: false,
            http_only: true,
            same_site: "lax".to_string(),
        }
    }

    /// Signing key derived from [`test_cookie_config`].
    fn test_key() -> Key {
        let config = test_cookie_config();
        crate::cookie::key_from_config(&config).unwrap()
    }

    #[test]
    fn authorization_request_into_response_redirects() {
        let redirect_url = "https://accounts.google.com/o/oauth2/v2/auth?foo=bar".to_string();
        let set_cookie_header =
            "_oauth_state=signed_value; Path=/; HttpOnly; SameSite=Lax".to_string();

        let response = AuthorizationRequest {
            redirect_url,
            set_cookie_header,
        }
        .into_response();

        // The response must be a 303 redirect that carries the state cookie.
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let header_value = response.headers().get("set-cookie").unwrap();
        let cookie_text = header_value.to_str().unwrap();
        assert!(cookie_text.contains("_oauth_state="));
    }

    #[test]
    fn build_and_parse_oauth_cookie_roundtrip() {
        let signing_key = test_key();
        let config = test_cookie_config();

        let (header, nonce, verifier) = build_oauth_cookie("google", &signing_key, &config);

        assert!(header.contains("_oauth_state="));
        assert!(header.contains("HttpOnly"));
        assert!(!nonce.is_empty());
        assert!(!verifier.is_empty());

        // Parsing the freshly built cookie must give back exactly what was generated.
        let restored = OAuthState::from_signed_cookie(&header, &signing_key).unwrap();
        assert_eq!(restored.provider(), "google");
        assert_eq!(restored.state_nonce(), nonce.as_str());
        assert_eq!(restored.pkce_verifier(), verifier.as_str());
    }

    #[test]
    fn parse_tampered_cookie_fails() {
        let signing_key = test_key();
        let config = test_cookie_config();

        let (header, _, _) = build_oauth_cookie("google", &signing_key, &config);

        // Prefixing the cookie value with extra text must break signature verification.
        let forged = header.replace("_oauth_state=", "_oauth_state=tampered");
        let outcome = OAuthState::from_signed_cookie(&forged, &signing_key);
        assert!(outcome.is_err());
    }

    #[test]
    fn cross_provider_state_detected() {
        let signing_key = test_key();
        let config = test_cookie_config();

        let (header, _, _) = build_oauth_cookie("google", &signing_key, &config);
        let restored = OAuthState::from_signed_cookie(&header, &signing_key).unwrap();

        // The provider recorded in the cookie is the one it was built for, and no other.
        assert_eq!(restored.provider(), "google");
        assert_ne!(restored.provider(), "github");
    }
}