1use axum::extract::{FromRef, FromRequestParts};
2use axum::response::{IntoResponse, Redirect, Response};
3use axum_extra::extract::cookie::Key;
4use cookie::{Cookie, CookieJar, SameSite};
5use http::header::{COOKIE, SET_COOKIE};
6use http::request::Parts;
7
8use crate::cookie::CookieConfig;
9use crate::service::AppState;
10
/// Name of the signed cookie that holds the OAuth state payload.
const OAUTH_COOKIE_NAME: &str = "_oauth_state";
/// `Max-Age` of the OAuth state cookie, in seconds (five minutes).
const OAUTH_COOKIE_MAX_AGE_SECS: i64 = 5 * 60;
13
/// OAuth flow state recovered from the signed `_oauth_state` cookie.
pub struct OAuthState {
    /// Provider name the state was minted for (e.g. `"google"`).
    provider: String,
    /// Random nonce stored under the `state` key of the cookie payload.
    state_nonce: String,
    /// PKCE code verifier stored under the `pkce_verifier` key.
    pkce_verifier: String,
}
29
30impl OAuthState {
31 pub(crate) fn provider(&self) -> &str {
32 &self.provider
33 }
34
35 pub(crate) fn pkce_verifier(&self) -> &str {
36 &self.pkce_verifier
37 }
38
39 pub(crate) fn state_nonce(&self) -> &str {
40 &self.state_nonce
41 }
42
43 pub(crate) fn from_signed_cookie(cookie_header: &str, key: &Key) -> crate::Result<Self> {
44 let mut jar = CookieJar::new();
45
46 for part in cookie_header.split(';') {
47 let trimmed = part.trim();
48 if let Ok(cookie) = Cookie::parse(trimmed) {
49 jar.add_original(cookie.into_owned());
50 }
51 }
52
53 let verified = jar
54 .signed(key)
55 .get(OAUTH_COOKIE_NAME)
56 .ok_or_else(|| crate::Error::bad_request("invalid or missing OAuth state cookie"))?;
57
58 let payload: serde_json::Value = serde_json::from_str(verified.value())
59 .map_err(|e| crate::Error::bad_request(format!("invalid OAuth state: {e}")))?;
60
61 Ok(Self {
62 state_nonce: payload["state"]
63 .as_str()
64 .ok_or_else(|| crate::Error::bad_request("missing state nonce"))?
65 .to_string(),
66 pkce_verifier: payload["pkce_verifier"]
67 .as_str()
68 .ok_or_else(|| crate::Error::bad_request("missing PKCE verifier"))?
69 .to_string(),
70 provider: payload["provider"]
71 .as_str()
72 .ok_or_else(|| crate::Error::bad_request("missing provider"))?
73 .to_string(),
74 })
75 }
76}
77
78impl<S> FromRequestParts<S> for OAuthState
79where
80 S: Send + Sync,
81 AppState: axum::extract::FromRef<S>,
82{
83 type Rejection = crate::Error;
84
85 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
86 let app_state = AppState::from_ref(state);
87 let key: std::sync::Arc<Key> = app_state
88 .get::<Key>()
89 .ok_or_else(|| crate::Error::internal("Key not registered in service registry"))?;
90
91 let cookie_header = parts
92 .headers
93 .get(COOKIE)
94 .and_then(|v| v.to_str().ok())
95 .ok_or_else(|| crate::Error::bad_request("missing OAuth state cookie"))?;
96
97 Self::from_signed_cookie(cookie_header, &key)
98 }
99}
100
/// Response that sends the browser to a provider's authorization URL while
/// setting the signed OAuth state cookie.
pub struct AuthorizationRequest {
    /// Absolute URL the client is redirected to.
    pub(crate) redirect_url: String,
    /// Complete `Set-Cookie` header value carrying the signed state cookie.
    pub(crate) set_cookie_header: String,
}
111
112impl IntoResponse for AuthorizationRequest {
113 fn into_response(self) -> Response {
114 let mut response = Redirect::to(&self.redirect_url).into_response();
115 if let Ok(value) = self.set_cookie_header.parse() {
116 response.headers_mut().insert(SET_COOKIE, value);
117 }
118 response
119 }
120}
121
122pub(crate) fn build_oauth_cookie(
124 provider: &str,
125 key: &Key,
126 cookie_config: &CookieConfig,
127) -> (String, String, String) {
128 let state_nonce = generate_random_string(32);
129 let pkce_verifier = generate_random_string(64);
130
131 let payload = serde_json::json!({
132 "state": state_nonce,
133 "pkce_verifier": pkce_verifier,
134 "provider": provider,
135 });
136
137 let mut jar = CookieJar::new();
138 let mut cookie = Cookie::new(OAUTH_COOKIE_NAME, payload.to_string());
139 cookie.set_path("/");
140 cookie.set_http_only(cookie_config.http_only);
141 cookie.set_secure(cookie_config.secure);
142 cookie.set_max_age(cookie::time::Duration::seconds(OAUTH_COOKIE_MAX_AGE_SECS));
143 cookie.set_same_site(match cookie_config.same_site.as_str() {
144 "strict" => SameSite::Strict,
145 "none" => SameSite::None,
146 _ => SameSite::Lax,
147 });
148
149 jar.signed_mut(key).add(cookie);
150
151 let set_cookie_header = jar
152 .get(OAUTH_COOKIE_NAME)
153 .map(|c| c.to_string())
154 .unwrap_or_default();
155
156 (set_cookie_header, state_nonce, pkce_verifier)
157}
158
159pub(crate) fn pkce_challenge(verifier: &str) -> String {
161 use sha2::{Digest, Sha256};
162 let hash = Sha256::digest(verifier.as_bytes());
163 base64url_encode(&hash)
164}
165
166fn generate_random_string(len: usize) -> String {
167 let mut bytes = vec![0u8; len];
168 rand::fill(&mut bytes[..]);
169 base64url_encode(&bytes)
170}
171
172fn base64url_encode(bytes: &[u8]) -> String {
173 crate::encoding::base64url::encode(bytes)
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;

    fn test_cookie_config() -> CookieConfig {
        CookieConfig {
            secret: "a".repeat(64),
            secure: false,
            http_only: true,
            same_site: "lax".to_string(),
        }
    }

    fn test_key() -> Key {
        crate::cookie::key_from_config(&test_cookie_config()).unwrap()
    }

    /// Returns the `name=value` pair of a `Set-Cookie` header, which is what
    /// a browser echoes back in its `Cookie` request header.
    fn request_cookie_pair(set_cookie_header: &str) -> &str {
        set_cookie_header
            .split(';')
            .next()
            .expect("split always yields at least one item")
    }

    #[test]
    fn authorization_request_into_response_redirects() {
        let req = AuthorizationRequest {
            redirect_url: "https://accounts.google.com/o/oauth2/v2/auth?foo=bar".to_string(),
            set_cookie_header: "_oauth_state=signed_value; Path=/; HttpOnly; SameSite=Lax"
                .to_string(),
        };
        let response = req.into_response();
        assert_eq!(response.status(), StatusCode::SEE_OTHER);
        let cookie = response
            .headers()
            .get("set-cookie")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(cookie.contains("_oauth_state="));
    }

    #[test]
    fn build_and_parse_oauth_cookie_roundtrip() {
        let key = test_key();
        let cookie_config = test_cookie_config();

        let (set_cookie_header, state_nonce, pkce_verifier) =
            build_oauth_cookie("google", &key, &cookie_config);

        assert!(set_cookie_header.contains("_oauth_state="));
        assert!(set_cookie_header.contains("HttpOnly"));
        assert!(set_cookie_header.contains("Path=/"));
        assert!(set_cookie_header.contains("Max-Age=300"));
        assert!(set_cookie_header.contains("SameSite=Lax"));
        assert!(!state_nonce.is_empty());
        assert!(!pkce_verifier.is_empty());

        let parsed = OAuthState::from_signed_cookie(&set_cookie_header, &key).unwrap();
        assert_eq!(parsed.provider(), "google");
        assert_eq!(parsed.state_nonce(), &state_nonce);
        assert_eq!(parsed.pkce_verifier(), &pkce_verifier);
    }

    #[test]
    fn parse_from_request_header_with_other_cookies() {
        let key = test_key();
        let (set_cookie_header, state_nonce, pkce_verifier) =
            build_oauth_cookie("google", &key, &test_cookie_config());

        let header = format!(
            "theme=dark; {}; session=abc",
            request_cookie_pair(&set_cookie_header)
        );
        let parsed = OAuthState::from_signed_cookie(&header, &key).unwrap();
        assert_eq!(parsed.provider(), "google");
        assert_eq!(parsed.state_nonce(), &state_nonce);
        assert_eq!(parsed.pkce_verifier(), &pkce_verifier);
    }

    #[test]
    fn parse_tampered_cookie_fails() {
        let key = test_key();
        let cookie_config = test_cookie_config();

        let (set_cookie_header, _, _) = build_oauth_cookie("google", &key, &cookie_config);

        let tampered = set_cookie_header.replace("_oauth_state=", "_oauth_state=tampered");
        assert!(OAuthState::from_signed_cookie(&tampered, &key).is_err());
    }

    #[test]
    fn parse_with_wrong_key_fails() {
        let (set_cookie_header, _, _) =
            build_oauth_cookie("google", &test_key(), &test_cookie_config());

        let other_key = Key::from(&[7u8; 64][..]);
        assert!(OAuthState::from_signed_cookie(&set_cookie_header, &other_key).is_err());
    }

    #[test]
    fn parse_unsigned_cookie_fails() {
        let unsigned = r#"_oauth_state={"state":"s","pkce_verifier":"v","provider":"google"}"#;
        assert!(OAuthState::from_signed_cookie(unsigned, &test_key()).is_err());
    }

    #[test]
    fn parse_missing_cookie_fails() {
        let key = test_key();
        assert!(OAuthState::from_signed_cookie("", &key).is_err());
        assert!(OAuthState::from_signed_cookie("theme=dark; session=abc", &key).is_err());
    }

    #[test]
    fn cross_provider_state_detected() {
        let key = test_key();
        let cookie_config = test_cookie_config();

        let (set_cookie_header, _, _) = build_oauth_cookie("github", &key, &cookie_config);
        let parsed = OAuthState::from_signed_cookie(&set_cookie_header, &key).unwrap();
        assert_eq!(parsed.provider(), "github");

        // Relabelling a GitHub state as Google must break the signature.
        let relabelled =
            set_cookie_header.replace(r#""provider":"github""#, r#""provider":"google""#);
        assert_ne!(relabelled, set_cookie_header, "payload layout changed; fix the test");
        assert!(OAuthState::from_signed_cookie(&relabelled, &key).is_err());
    }
}
252}