// allowthem_core/oauth_google.rs

use std::fmt;
use std::time::Duration;

use serde::Deserialize;
use url::Url;

use crate::auth_client::AuthFuture;
use crate::error::AuthError;
use crate::oauth::{OAuthProvider, OAuthUserInfo};

// ---------------------------------------------------------------------------
// GoogleProvider
// ---------------------------------------------------------------------------

12/// OAuth2 authorization-code + PKCE provider for Google.
13pub struct GoogleProvider {
14    client_id: String,
15    client_secret: String,
16    http: reqwest::Client,
17}
18
19impl GoogleProvider {
20    pub fn new(client_id: impl Into<String>, client_secret: impl Into<String>) -> Self {
21        Self {
22            client_id: client_id.into(),
23            client_secret: client_secret.into(),
24            http: reqwest::Client::new(),
25        }
26    }
27}

// ---------------------------------------------------------------------------
// Response shapes
// ---------------------------------------------------------------------------

33#[derive(Deserialize)]
34struct TokenResponse {
35    access_token: String,
36}
37
38#[derive(Deserialize)]
39struct UserInfoResponse {
40    sub: String,
41    email: String,
42    email_verified: bool,
43    name: Option<String>,
44}

// ---------------------------------------------------------------------------
// OAuthProvider impl
// ---------------------------------------------------------------------------

50impl OAuthProvider for GoogleProvider {
51    fn name(&self) -> &str {
52        "google"
53    }
54
55    fn authorize_url(&self, redirect_uri: &str, state: &str, pkce_challenge: &str) -> String {
56        let mut url =
57            Url::parse("https://accounts.google.com/o/oauth2/v2/auth").expect("static URL");
58        url.query_pairs_mut()
59            .append_pair("client_id", &self.client_id)
60            .append_pair("redirect_uri", redirect_uri)
61            .append_pair("response_type", "code")
62            .append_pair("scope", "openid email profile")
63            .append_pair("state", state)
64            .append_pair("code_challenge", pkce_challenge)
65            .append_pair("code_challenge_method", "S256");
66        url.into()
67    }
68
69    fn exchange_code<'a>(
70        &'a self,
71        code: &'a str,
72        redirect_uri: &'a str,
73        pkce_verifier: &'a str,
74    ) -> AuthFuture<'a, String> {
75        Box::pin(async move {
76            let resp = self
77                .http
78                .post("https://oauth2.googleapis.com/token")
79                .form(&[
80                    ("code", code),
81                    ("client_id", &self.client_id),
82                    ("client_secret", &self.client_secret),
83                    ("redirect_uri", redirect_uri),
84                    ("grant_type", "authorization_code"),
85                    ("code_verifier", pkce_verifier),
86                ])
87                .send()
88                .await
89                .map_err(|e| AuthError::OAuthHttp(e.to_string()))?;
90
91            if !resp.status().is_success() {
92                let status = resp.status();
93                let body = resp.text().await.unwrap_or_default();
94                return Err(AuthError::OAuthTokenExchange(format!("{status}: {body}")));
95            }
96
97            let token: TokenResponse = resp
98                .json()
99                .await
100                .map_err(|e| AuthError::OAuthTokenExchange(e.to_string()))?;
101
102            Ok(token.access_token)
103        })
104    }
105
106    fn user_info<'a>(&'a self, access_token: &'a str) -> AuthFuture<'a, OAuthUserInfo> {
107        Box::pin(async move {
108            let resp = self
109                .http
110                .get("https://www.googleapis.com/oauth2/v3/userinfo")
111                .bearer_auth(access_token)
112                .send()
113                .await
114                .map_err(|e| AuthError::OAuthHttp(e.to_string()))?;
115
116            if !resp.status().is_success() {
117                let status = resp.status();
118                let body = resp.text().await.unwrap_or_default();
119                return Err(AuthError::OAuthUserInfoFetch(format!("{status}: {body}")));
120            }
121
122            let info: UserInfoResponse = resp
123                .json()
124                .await
125                .map_err(|e| AuthError::OAuthUserInfoFetch(e.to_string()))?;
126
127            Ok(OAuthUserInfo {
128                provider_user_id: info.sub,
129                email: info.email,
130                email_verified: info.email_verified,
131                name: info.name,
132            })
133        })
134    }
135}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with fixed, easily recognisable dummy credentials.
    fn make_provider() -> GoogleProvider {
        GoogleProvider::new("test-client-id", "test-client-secret")
    }

    /// Parses `url` and returns its query string as decoded, owned `(key, value)` pairs.
    fn query_pairs(url: &str) -> Vec<(String, String)> {
        url::Url::parse(url)
            .expect("authorize_url must return an absolute URL")
            .query_pairs()
            .into_owned()
            .collect()
    }

    /// Returns the decoded value of `key`; fails the test if the key is absent or repeated.
    fn single_param<'p>(pairs: &'p [(String, String)], key: &str) -> &'p str {
        let mut hits = pairs.iter().filter(|(k, _)| k == key);
        let (_, value) = hits
            .next()
            .unwrap_or_else(|| panic!("missing `{key}` query parameter"));
        assert!(
            hits.next().is_none(),
            "`{key}` query parameter appears more than once"
        );
        value
    }

    #[test]
    fn name_returns_google() {
        assert_eq!(make_provider().name(), "google");
    }

    #[test]
    fn authorize_url_contains_required_params() {
        let url = make_provider().authorize_url(
            "https://example.com/callback",
            "state-abc",
            "challenge-xyz",
        );

        // The raw URL must carry the redirect URI percent-encoded, not verbatim.
        assert!(
            url.contains("redirect_uri=https%3A%2F%2Fexample.com%2Fcallback"),
            "redirect_uri encoded: {url}"
        );

        // Compare decoded values exactly (this is independent of `+` vs `%20`
        // spacing) and require each parameter to appear exactly once.
        let pairs = query_pairs(&url);
        let expected = [
            ("client_id", "test-client-id"),
            ("redirect_uri", "https://example.com/callback"),
            ("response_type", "code"),
            ("scope", "openid email profile"),
            ("state", "state-abc"),
            ("code_challenge", "challenge-xyz"),
            ("code_challenge_method", "S256"),
        ];
        for &(key, value) in &expected {
            assert_eq!(single_param(&pairs, key), value, "{key}");
        }
    }

    #[test]
    fn authorize_url_starts_with_google_endpoint() {
        let p = make_provider();
        let url = p.authorize_url("https://example.com/cb", "s", "c");
        assert!(
            url.starts_with("https://accounts.google.com/o/oauth2/v2/auth"),
            "unexpected base: {url}"
        );
    }

    #[test]
    fn authorize_url_never_contains_client_secret() {
        // The authorize URL goes to the user's browser; the secret must never be in it.
        let url = make_provider().authorize_url("https://example.com/cb", "s", "c");
        assert!(!url.contains("test-client-secret"), "client secret leaked: {url}");
        assert!(!url.contains("client_secret"), "client_secret param present: {url}");
    }

    #[test]
    fn authorize_url_escapes_caller_supplied_values() {
        // A hostile `state` must not be able to inject or override parameters.
        let url = make_provider().authorize_url("https://example.com/cb", "x&client_id=evil", "c");
        let pairs = query_pairs(&url);
        assert_eq!(single_param(&pairs, "state"), "x&client_id=evil");
        assert_eq!(single_param(&pairs, "client_id"), "test-client-id");
    }

    #[test]
    fn new_accepts_string_and_str() {
        let _p1 = GoogleProvider::new("id".to_string(), "secret".to_string());
        let _p2 = GoogleProvider::new("id", "secret");
    }
}