1use base64ct::{Base64UrlUnpadded, Encoding};
11use serde::Deserialize;
12use url::Url;
13
14use crate::auth_client::AuthFuture;
15use crate::error::AuthError;
16use crate::social_providers::{ProviderType, SocialProvider, SocialProviderConfig, SocialUserInfo};
17
18#[derive(Debug)]
26pub struct GoogleSocialProvider {
27 client_id: String,
28 client_secret: String,
29 scopes: Vec<String>,
30 http: reqwest::Client,
31 token_url: String,
34}
35
36#[derive(Deserialize)]
39struct GoogleIdTokenClaims {
40 sub: String,
41 email: String,
42 email_verified: bool,
43 name: Option<String>,
44 picture: Option<String>,
45}
46
47impl GoogleSocialProvider {
50 pub fn new(config: SocialProviderConfig) -> Result<Self, AuthError> {
55 Self::new_with_token_url(config, "https://oauth2.googleapis.com/token".into())
56 }
57
58 pub(crate) fn new_with_token_url(
62 config: SocialProviderConfig,
63 token_url: String,
64 ) -> Result<Self, AuthError> {
65 if config.provider_type != ProviderType::Google {
66 return Err(AuthError::Validation(
67 "provider_type mismatch: expected Google".into(),
68 ));
69 }
70 if config.scopes.is_empty() {
71 return Err(AuthError::Validation("scopes must not be empty".into()));
72 }
73 let http = reqwest::Client::builder()
74 .user_agent("allowthem-oauth")
75 .build()
76 .map_err(|e| AuthError::Validation(format!("reqwest client build failed: {e}")))?;
77 Ok(Self {
78 client_id: config.client_id,
79 client_secret: config.client_secret,
80 scopes: config.scopes,
81 http,
82 token_url,
83 })
84 }
85}
86
87impl SocialProvider for GoogleSocialProvider {
90 fn provider_type(&self) -> ProviderType {
91 ProviderType::Google
92 }
93
94 fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String {
95 let mut url =
96 Url::parse("https://accounts.google.com/o/oauth2/v2/auth").expect("static URL");
97 url.query_pairs_mut()
98 .append_pair("client_id", &self.client_id)
99 .append_pair("redirect_uri", redirect_uri)
100 .append_pair("response_type", "code")
101 .append_pair("scope", &self.scopes.join(" "))
102 .append_pair("state", state)
103 .append_pair("code_challenge", pkce_challenge)
104 .append_pair("code_challenge_method", "S256");
105 url.into()
106 }
107
108 fn exchange_code<'a>(
109 &'a self,
110 code: &'a str,
111 redirect_uri: &'a str,
112 pkce_verifier: &'a str,
113 ) -> AuthFuture<'a, String> {
114 Box::pin(async move {
115 let resp = self
116 .http
117 .post(&self.token_url)
118 .form(&[
119 ("code", code),
120 ("client_id", self.client_id.as_str()),
121 ("client_secret", self.client_secret.as_str()),
122 ("redirect_uri", redirect_uri),
123 ("grant_type", "authorization_code"),
124 ("code_verifier", pkce_verifier),
125 ])
126 .send()
127 .await
128 .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
129
130 let status = resp.status();
131 if !status.is_success() {
132 let body = resp.text().await.unwrap_or_default();
133 return Err(AuthError::OAuthTokenExchange(format!("{status}: {body}")));
134 }
135
136 let json: serde_json::Value = resp
137 .json()
138 .await
139 .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
140
141 json.get("id_token")
142 .and_then(|v| v.as_str())
143 .map(|s| s.to_owned())
144 .ok_or_else(|| {
145 AuthError::OAuthTokenExchange(
146 "missing id_token in Google token response".into(),
147 )
148 })
149 })
150 }
151
152 fn fetch_user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, SocialUserInfo> {
153 Box::pin(async move {
154 let claims = decode_id_token(access_token)?;
155 Ok(SocialUserInfo {
156 provider_user_id: claims.sub,
157 email: claims.email,
158 email_verified: claims.email_verified,
159 name: claims.name,
160 avatar_url: claims.picture,
161 })
162 })
163 }
164}
165
166fn decode_id_token(token: &str) -> Result<GoogleIdTokenClaims, AuthError> {
169 let parts: Vec<&str> = token.split('.').collect();
171 if parts.len() != 3 {
172 return Err(AuthError::OAuthUserInfoFetch("malformed id_token".into()));
173 }
174 let raw = Base64UrlUnpadded::decode_vec(parts[1]).map_err(|_| {
175 AuthError::OAuthUserInfoFetch("id_token payload is not valid base64url".into())
176 })?;
177 serde_json::from_slice::<GoogleIdTokenClaims>(&raw).map_err(|e| {
178 AuthError::OAuthUserInfoFetch(format!("id_token payload JSON parse error: {e}"))
179 })
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SocialProviderId;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A valid Google config; individual tests tweak single fields.
    fn google_config() -> SocialProviderConfig {
        SocialProviderConfig {
            id: SocialProviderId::new(),
            provider_type: ProviderType::Google,
            display_name: "Google".into(),
            client_id: "test-client-id".into(),
            client_secret: "test-client-secret".into(),
            scopes: vec!["openid".into(), "email".into()],
            enabled: true,
            priority: 0,
            config: None,
        }
    }

    /// A provider built from [`google_config`] with the default token URL.
    fn provider() -> GoogleSocialProvider {
        GoogleSocialProvider::new(google_config()).unwrap()
    }

    /// Assembles a three-segment token whose payload is `payload`; the header
    /// is a minimal JSON object and the signature is a placeholder.
    fn make_id_token(payload: &serde_json::Value) -> String {
        let header = Base64UrlUnpadded::encode_string(b"{\"alg\":\"RS256\"}");
        let body = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes());
        format!("{header}.{body}.fakesig")
    }

    /// Starts a mock server whose `POST /token` answers with `status` and the
    /// JSON `body`, and returns it alongside a provider aimed at it. The
    /// server shuts down when dropped, so callers must keep it bound.
    async fn provider_with_mock_token_endpoint(
        status: u16,
        body: serde_json::Value,
    ) -> (MockServer, GoogleSocialProvider) {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(status).set_body_json(body))
            .mount(&server)
            .await;
        let token_url = format!("{}/token", server.uri());
        let google =
            GoogleSocialProvider::new_with_token_url(google_config(), token_url).unwrap();
        (server, google)
    }

    #[test]
    fn new_rejects_provider_type_mismatch() {
        let mut cfg = google_config();
        cfg.provider_type = ProviderType::Github;
        let err = GoogleSocialProvider::new(cfg).unwrap_err();
        assert!(matches!(err, AuthError::Validation(_)));
    }

    #[test]
    fn new_rejects_empty_scopes() {
        let mut cfg = google_config();
        cfg.scopes.clear();
        let err = GoogleSocialProvider::new(cfg).unwrap_err();
        assert!(matches!(err, AuthError::Validation(_)));
    }

    #[test]
    fn authorize_url_contains_required_params() {
        let url =
            provider().authorize_url("https://example.com/callback", "mystate", "mychallenge");
        for expected in [
            "client_id=test-client-id",
            "redirect_uri=",
            "response_type=code",
            "state=mystate",
            "code_challenge=mychallenge",
            "code_challenge_method=S256",
        ] {
            assert!(url.contains(expected), "expected `{expected}` in url: {url}");
        }
    }

    #[test]
    fn authorize_url_uses_config_scopes_joined_by_space() {
        let url = provider().authorize_url("https://example.com/callback", "s", "c");
        // A space may be encoded either as `+` or as `%20`.
        let has_joined_scope = ["scope=openid+email", "scope=openid%20email"]
            .iter()
            .any(|encoded| url.contains(*encoded));
        assert!(has_joined_scope, "url: {url}");
    }

    #[test]
    fn authorize_url_does_not_leak_client_secret() {
        let url = provider().authorize_url("https://example.com/callback", "s", "c");
        assert!(!url.contains("test-client-secret"), "url: {url}");
    }

    #[tokio::test]
    async fn decode_id_token_extracts_claims() {
        let token = make_id_token(&serde_json::json!({
            "sub": "google-user-123",
            "email": "user@example.com",
            "email_verified": true,
            "name": "Test User",
            "picture": "https://example.com/photo.jpg"
        }));
        let info = provider().fetch_user_info(&token).await.unwrap();
        assert_eq!(info.provider_user_id, "google-user-123");
        assert_eq!(info.email, "user@example.com");
        assert!(info.email_verified);
        assert_eq!(info.name.as_deref(), Some("Test User"));
        assert_eq!(
            info.avatar_url.as_deref(),
            Some("https://example.com/photo.jpg")
        );
    }

    #[tokio::test]
    async fn decode_id_token_rejects_malformed_token() {
        let err = provider().fetch_user_info("only.two").await.unwrap_err();
        assert!(matches!(err, AuthError::OAuthUserInfoFetch(_)));
    }

    #[tokio::test]
    async fn decode_id_token_rejects_invalid_base64() {
        let err = provider()
            .fetch_user_info("header.!!!invalid!!!.sig")
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::OAuthUserInfoFetch(_)));
    }

    #[tokio::test]
    async fn decode_id_token_rejects_non_json_payload() {
        let token = format!(
            "header.{}.sig",
            Base64UrlUnpadded::encode_string(b"not json at all")
        );
        let err = provider().fetch_user_info(&token).await.unwrap_err();
        assert!(matches!(err, AuthError::OAuthUserInfoFetch(_)));
    }

    #[tokio::test]
    async fn decode_id_token_email_unverified_propagates() {
        let token = make_id_token(&serde_json::json!({
            "sub": "u1",
            "email": "u@example.com",
            "email_verified": false,
        }));
        let info = provider().fetch_user_info(&token).await.unwrap();
        assert!(!info.email_verified);
    }

    #[tokio::test]
    async fn decode_id_token_picture_maps_to_avatar_url() {
        let token = make_id_token(&serde_json::json!({
            "sub": "u1",
            "email": "u@example.com",
            "email_verified": true,
            "picture": "https://cdn.example.com/avatar.png"
        }));
        let info = provider().fetch_user_info(&token).await.unwrap();
        assert_eq!(
            info.avatar_url.as_deref(),
            Some("https://cdn.example.com/avatar.png")
        );
    }

    #[tokio::test]
    async fn exchange_code_extracts_id_token_on_success() {
        let (_server, google) = provider_with_mock_token_endpoint(
            200,
            serde_json::json!({
                "access_token": "unused-access",
                "id_token": "header.payload.sig",
                "token_type": "Bearer"
            }),
        )
        .await;
        let id_token = google
            .exchange_code("mycode", "https://example.com/cb", "pkce_v")
            .await
            .unwrap();
        assert_eq!(id_token, "header.payload.sig");
    }

    #[tokio::test]
    async fn exchange_code_returns_token_exchange_error_on_4xx() {
        let (_server, google) = provider_with_mock_token_endpoint(
            400,
            serde_json::json!({ "error": "invalid_grant" }),
        )
        .await;
        let err = google
            .exchange_code("badcode", "https://example.com/cb", "v")
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::OAuthTokenExchange(_)));
    }

    #[tokio::test]
    async fn exchange_code_returns_token_exchange_error_on_missing_id_token() {
        let (_server, google) = provider_with_mock_token_endpoint(
            200,
            serde_json::json!({
                "access_token": "some-access-token",
                "token_type": "Bearer"
            }),
        )
        .await;
        let err = google
            .exchange_code("code", "https://example.com/cb", "v")
            .await
            .unwrap_err();
        match err {
            AuthError::OAuthTokenExchange(msg) => {
                assert!(msg.contains("missing id_token"), "got: {msg}");
            }
            other => panic!("expected OAuthTokenExchange, got {other:?}"),
        }
    }
}
426}