Skip to main content

allowthem_core/
social_google.rs

1//! Google `SocialProvider` implementation.
2//!
3//! ## Token URL injection for tests
4//!
5//! The Google token endpoint is hardcoded to `https://oauth2.googleapis.com/token`
6//! in the public `new()` constructor. Tests that need to point at a wiremock server
7//! use the crate-private `new_with_token_url()` constructor instead. This avoids
8//! the `OnceLock`-override pattern which is parallel-test-hostile.
9
10use base64ct::{Base64UrlUnpadded, Encoding};
11use serde::Deserialize;
12use url::Url;
13
14use crate::auth_client::AuthFuture;
15use crate::error::AuthError;
16use crate::social_providers::{ProviderType, SocialProvider, SocialProviderConfig, SocialUserInfo};
17
18// ── Struct ────────────────────────────────────────────────────────────────────
19
20/// Google OAuth 2.0 + OIDC social provider.
21///
22/// Constructed from a [`SocialProviderConfig`] via [`Self::new`]. The
23/// `exchange_code` method returns the `id_token` JWT from the token
24/// response; `fetch_user_info` decodes it locally — no second HTTP call.
25#[derive(Debug)]
26pub struct GoogleSocialProvider {
27    client_id: String,
28    client_secret: String,
29    scopes: Vec<String>,
30    http: reqwest::Client,
31    /// Token endpoint URL. Always `https://oauth2.googleapis.com/token` in
32    /// production; overridden in tests via `new_with_token_url`.
33    token_url: String,
34}
35
36// ── Private claims struct ─────────────────────────────────────────────────────
37
38#[derive(Deserialize)]
39struct GoogleIdTokenClaims {
40    sub: String,
41    email: String,
42    email_verified: bool,
43    name: Option<String>,
44    picture: Option<String>,
45}
46
47// ── Constructors ──────────────────────────────────────────────────────────────
48
49impl GoogleSocialProvider {
50    /// Build a `GoogleSocialProvider` from a decrypted config.
51    ///
52    /// Returns `AuthError::Validation` if `provider_type` is not `Google`
53    /// or `scopes` is empty.
54    pub fn new(config: SocialProviderConfig) -> Result<Self, AuthError> {
55        Self::new_with_token_url(config, "https://oauth2.googleapis.com/token".into())
56    }
57
58    /// Like [`Self::new`] but with an overrideable token endpoint URL.
59    ///
60    /// Used by tests to point at a wiremock server.
61    pub(crate) fn new_with_token_url(
62        config: SocialProviderConfig,
63        token_url: String,
64    ) -> Result<Self, AuthError> {
65        if config.provider_type != ProviderType::Google {
66            return Err(AuthError::Validation(
67                "provider_type mismatch: expected Google".into(),
68            ));
69        }
70        if config.scopes.is_empty() {
71            return Err(AuthError::Validation("scopes must not be empty".into()));
72        }
73        let http = reqwest::Client::builder()
74            .user_agent("allowthem-oauth")
75            .build()
76            .map_err(|e| AuthError::Validation(format!("reqwest client build failed: {e}")))?;
77        Ok(Self {
78            client_id: config.client_id,
79            client_secret: config.client_secret,
80            scopes: config.scopes,
81            http,
82            token_url,
83        })
84    }
85}
86
87// ── SocialProvider impl ───────────────────────────────────────────────────────
88
89impl SocialProvider for GoogleSocialProvider {
90    fn provider_type(&self) -> ProviderType {
91        ProviderType::Google
92    }
93
94    fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String {
95        let mut url =
96            Url::parse("https://accounts.google.com/o/oauth2/v2/auth").expect("static URL");
97        url.query_pairs_mut()
98            .append_pair("client_id", &self.client_id)
99            .append_pair("redirect_uri", redirect_uri)
100            .append_pair("response_type", "code")
101            .append_pair("scope", &self.scopes.join(" "))
102            .append_pair("state", state)
103            .append_pair("code_challenge", pkce_challenge)
104            .append_pair("code_challenge_method", "S256");
105        url.into()
106    }
107
108    fn exchange_code<'a>(
109        &'a self,
110        code: &'a str,
111        redirect_uri: &'a str,
112        pkce_verifier: &'a str,
113    ) -> AuthFuture<'a, String> {
114        Box::pin(async move {
115            let resp = self
116                .http
117                .post(&self.token_url)
118                .form(&[
119                    ("code", code),
120                    ("client_id", self.client_id.as_str()),
121                    ("client_secret", self.client_secret.as_str()),
122                    ("redirect_uri", redirect_uri),
123                    ("grant_type", "authorization_code"),
124                    ("code_verifier", pkce_verifier),
125                ])
126                .send()
127                .await
128                .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
129
130            let status = resp.status();
131            if !status.is_success() {
132                let body = resp.text().await.unwrap_or_default();
133                return Err(AuthError::OAuthTokenExchange(format!("{status}: {body}")));
134            }
135
136            let json: serde_json::Value = resp
137                .json()
138                .await
139                .map_err(|e| AuthError::OAuthHttp(format!("{e}")))?;
140
141            json.get("id_token")
142                .and_then(|v| v.as_str())
143                .map(|s| s.to_owned())
144                .ok_or_else(|| {
145                    AuthError::OAuthTokenExchange(
146                        "missing id_token in Google token response".into(),
147                    )
148                })
149        })
150    }
151
152    fn fetch_user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, SocialUserInfo> {
153        Box::pin(async move {
154            let claims = decode_id_token(access_token)?;
155            Ok(SocialUserInfo {
156                provider_user_id: claims.sub,
157                email: claims.email,
158                email_verified: claims.email_verified,
159                name: claims.name,
160                avatar_url: claims.picture,
161            })
162        })
163    }
164}
165
166// ── id_token decode helper ────────────────────────────────────────────────────
167
168fn decode_id_token(token: &str) -> Result<GoogleIdTokenClaims, AuthError> {
169    // A JWT is exactly three base64url segments separated by '.'.
170    let parts: Vec<&str> = token.split('.').collect();
171    if parts.len() != 3 {
172        return Err(AuthError::OAuthUserInfoFetch("malformed id_token".into()));
173    }
174    let raw = Base64UrlUnpadded::decode_vec(parts[1]).map_err(|_| {
175        AuthError::OAuthUserInfoFetch("id_token payload is not valid base64url".into())
176    })?;
177    serde_json::from_slice::<GoogleIdTokenClaims>(&raw).map_err(|e| {
178        AuthError::OAuthUserInfoFetch(format!("id_token payload JSON parse error: {e}"))
179    })
180}
181
182// ── Tests ─────────────────────────────────────────────────────────────────────
183
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SocialProviderId;

    // ── Shared fixtures ───────────────────────────────────────────────────────

    /// A valid Google provider config with the `openid` and `email` scopes.
    fn google_config() -> SocialProviderConfig {
        SocialProviderConfig {
            id: SocialProviderId::new(),
            provider_type: ProviderType::Google,
            display_name: "Google".into(),
            client_id: "test-client-id".into(),
            client_secret: "test-client-secret".into(),
            scopes: vec!["openid".into(), "email".into()],
            enabled: true,
            priority: 0,
            config: None,
        }
    }

    /// Provider built from [`google_config`] with the production token URL.
    fn google_provider() -> GoogleSocialProvider {
        GoogleSocialProvider::new(google_config()).unwrap()
    }

    /// Wraps `payload` in a JWT-shaped string with a fake signature segment.
    fn make_id_token(payload: &serde_json::Value) -> String {
        let header = Base64UrlUnpadded::encode_string(b"{\"alg\":\"RS256\"}");
        let body = Base64UrlUnpadded::encode_string(payload.to_string().as_bytes());
        format!("{header}.{body}.fakesig")
    }

    /// Runs `fetch_user_info` on a token carrying `payload`; panics on error.
    async fn user_info_for(payload: serde_json::Value) -> SocialUserInfo {
        let provider = google_provider();
        let token = make_id_token(&payload);
        let info = provider.fetch_user_info(&token).await;
        info.unwrap()
    }

    /// Starts a wiremock server answering `POST /token` with `status` and
    /// `body`, plus a provider pointed at it. The server is returned so the
    /// caller keeps it alive for the duration of the test.
    async fn mock_token_endpoint(
        status: u16,
        body: serde_json::Value,
    ) -> (wiremock::MockServer, GoogleSocialProvider) {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(status).set_body_json(body))
            .mount(&server)
            .await;

        let token_url = format!("{}/token", server.uri());
        let provider =
            GoogleSocialProvider::new_with_token_url(google_config(), token_url).unwrap();
        (server, provider)
    }

    // ── Constructor validation ────────────────────────────────────────────────

    #[test]
    fn new_rejects_provider_type_mismatch() {
        let cfg = SocialProviderConfig {
            provider_type: ProviderType::Github,
            ..google_config()
        };
        let err = GoogleSocialProvider::new(cfg).unwrap_err();
        assert!(matches!(err, AuthError::Validation(_)));
    }

    #[test]
    fn new_rejects_empty_scopes() {
        let cfg = SocialProviderConfig {
            scopes: Vec::new(),
            ..google_config()
        };
        let err = GoogleSocialProvider::new(cfg).unwrap_err();
        assert!(matches!(err, AuthError::Validation(_)));
    }

    // ── authorize_url ─────────────────────────────────────────────────────────

    #[test]
    fn authorize_url_contains_required_params() {
        let url = google_provider().authorize_url(
            "https://example.com/callback",
            "mystate",
            "mychallenge",
        );
        let required = [
            "client_id=test-client-id",
            "redirect_uri=",
            "response_type=code",
            "state=mystate",
            "code_challenge=mychallenge",
            "code_challenge_method=S256",
        ];
        for fragment in required {
            assert!(url.contains(fragment), "url: {url}");
        }
    }

    #[test]
    fn authorize_url_uses_config_scopes_joined_by_space() {
        let url = google_provider().authorize_url("https://example.com/callback", "s", "c");
        // The url crate's form-urlencoded query serializer writes a space as
        // `+`; `%20` is accepted too so the test does not pin the encoding.
        let plus_encoded = url.contains("scope=openid+email");
        let percent_encoded = url.contains("scope=openid%20email");
        assert!(plus_encoded || percent_encoded, "url: {url}");
    }

    #[test]
    fn authorize_url_does_not_leak_client_secret() {
        let url = google_provider().authorize_url("https://example.com/callback", "s", "c");
        assert!(!url.contains("test-client-secret"), "url: {url}");
    }

    // ── fetch_user_info (id_token decoding) ──────────────────────────────────

    #[tokio::test]
    async fn decode_id_token_extracts_claims() {
        let info = user_info_for(serde_json::json!({
            "sub": "google-user-123",
            "email": "user@example.com",
            "email_verified": true,
            "name": "Test User",
            "picture": "https://example.com/photo.jpg"
        }))
        .await;
        assert_eq!(info.provider_user_id, "google-user-123");
        assert_eq!(info.email, "user@example.com");
        assert!(info.email_verified);
        assert_eq!(info.name.as_deref(), Some("Test User"));
        assert_eq!(
            info.avatar_url.as_deref(),
            Some("https://example.com/photo.jpg")
        );
    }

    #[tokio::test]
    async fn decode_id_token_rejects_malformed_token() {
        let provider = google_provider();
        let outcome = provider.fetch_user_info("only.two").await;
        assert!(matches!(
            outcome.unwrap_err(),
            AuthError::OAuthUserInfoFetch(_)
        ));
    }

    #[tokio::test]
    async fn decode_id_token_rejects_invalid_base64() {
        let provider = google_provider();
        let outcome = provider.fetch_user_info("header.!!!invalid!!!.sig").await;
        assert!(matches!(
            outcome.unwrap_err(),
            AuthError::OAuthUserInfoFetch(_)
        ));
    }

    #[tokio::test]
    async fn decode_id_token_rejects_non_json_payload() {
        let payload_b64 = Base64UrlUnpadded::encode_string(b"not json at all");
        let token = format!("header.{payload_b64}.sig");
        let provider = google_provider();
        let outcome = provider.fetch_user_info(&token).await;
        assert!(matches!(
            outcome.unwrap_err(),
            AuthError::OAuthUserInfoFetch(_)
        ));
    }

    #[tokio::test]
    async fn decode_id_token_email_unverified_propagates() {
        let info = user_info_for(serde_json::json!({
            "sub": "u1",
            "email": "u@example.com",
            "email_verified": false,
        }))
        .await;
        assert!(!info.email_verified);
    }

    #[tokio::test]
    async fn decode_id_token_picture_maps_to_avatar_url() {
        let info = user_info_for(serde_json::json!({
            "sub": "u1",
            "email": "u@example.com",
            "email_verified": true,
            "picture": "https://cdn.example.com/avatar.png"
        }))
        .await;
        assert_eq!(
            info.avatar_url.as_deref(),
            Some("https://cdn.example.com/avatar.png")
        );
    }

    // ── HTTP tests (wiremock) ─────────────────────────────────────────────────

    #[tokio::test]
    async fn exchange_code_extracts_id_token_on_success() {
        let (_server, provider) = mock_token_endpoint(
            200,
            serde_json::json!({
                "access_token": "unused-access",
                "id_token": "header.payload.sig",
                "token_type": "Bearer"
            }),
        )
        .await;

        let id_token = provider
            .exchange_code("mycode", "https://example.com/cb", "pkce_v")
            .await
            .unwrap();
        assert_eq!(id_token, "header.payload.sig");
    }

    #[tokio::test]
    async fn exchange_code_returns_token_exchange_error_on_4xx() {
        let (_server, provider) = mock_token_endpoint(
            400,
            serde_json::json!({
                "error": "invalid_grant"
            }),
        )
        .await;

        let err = provider
            .exchange_code("badcode", "https://example.com/cb", "v")
            .await
            .unwrap_err();
        assert!(matches!(err, AuthError::OAuthTokenExchange(_)));
    }

    #[tokio::test]
    async fn exchange_code_returns_token_exchange_error_on_missing_id_token() {
        let (_server, provider) = mock_token_endpoint(
            200,
            serde_json::json!({
                "access_token": "some-access-token",
                "token_type": "Bearer"
            }),
        )
        .await;

        let err = provider
            .exchange_code("code", "https://example.com/cb", "v")
            .await
            .unwrap_err();
        match err {
            AuthError::OAuthTokenExchange(msg) => {
                assert!(msg.contains("missing id_token"), "got: {msg}");
            }
            other => panic!("expected OAuthTokenExchange, got {other:?}"),
        }
    }
}