Skip to main content

authly_providers_google/
lib.rs

1use async_trait::async_trait;
2use authly_core::{AuthError, Identity, OAuthProvider, OAuthToken};
3use serde::Deserialize;
4use std::collections::HashMap;
5
6pub struct GoogleProvider {
7    client_id: String,
8    client_secret: String,
9    redirect_uri: String,
10    http_client: reqwest::Client,
11    auth_url: String,
12    token_url: String,
13    userinfo_url: String,
14    revoke_url: String,
15}
16
17impl GoogleProvider {
18    pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
19        Self {
20            client_id,
21            client_secret,
22            redirect_uri,
23            http_client: reqwest::Client::new(),
24            auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
25            token_url: "https://oauth2.googleapis.com/token".to_string(),
26            userinfo_url: "https://www.googleapis.com/oauth2/v3/userinfo".to_string(),
27            revoke_url: "https://oauth2.googleapis.com/revoke".to_string(),
28        }
29    }
30
31    pub fn with_test_urls(
32        mut self,
33        auth_url: String,
34        token_url: String,
35        userinfo_url: String,
36        revoke_url: String,
37    ) -> Self {
38        self.auth_url = auth_url;
39        self.token_url = token_url;
40        self.userinfo_url = userinfo_url;
41        self.revoke_url = revoke_url;
42        self
43    }
44}
45
46#[derive(Deserialize)]
47struct GoogleTokenResponse {
48    access_token: String,
49    token_type: String,
50    expires_in: Option<u64>,
51    refresh_token: Option<String>,
52    scope: Option<String>,
53    id_token: Option<String>,
54}
55
56#[derive(Deserialize)]
57struct GoogleUserResponse {
58    sub: String,
59    email: Option<String>,
60    name: Option<String>,
61    picture: Option<String>,
62    email_verified: Option<bool>,
63    locale: Option<String>,
64}
65
66#[async_trait]
67impl OAuthProvider for GoogleProvider {
68    fn get_authorization_url(
69        &self,
70        state: &str,
71        scopes: &[&str],
72        code_challenge: Option<&str>,
73    ) -> String {
74        let scope_param = if scopes.is_empty() {
75            "openid email profile".to_string()
76        } else {
77            scopes.join(" ")
78        };
79
80        let mut url = format!(
81            "{}?client_id={}&redirect_uri={}&state={}&scope={}&response_type=code&access_type=offline&prompt=consent",
82            self.auth_url, self.client_id, self.redirect_uri, state, scope_param
83        );
84
85        if let Some(challenge) = code_challenge {
86            url.push_str(&format!(
87                "&code_challenge={}&code_challenge_method=S256",
88                challenge
89            ));
90        }
91
92        url
93    }
94
95    async fn exchange_code_for_identity(
96        &self,
97        code: &str,
98        code_verifier: Option<&str>,
99    ) -> Result<(Identity, OAuthToken), AuthError> {
100        // 1. Exchange code for access token
101        let mut params = vec![
102            ("code", code.to_string()),
103            ("client_id", self.client_id.clone()),
104            ("client_secret", self.client_secret.clone()),
105            ("redirect_uri", self.redirect_uri.clone()),
106            ("grant_type", "authorization_code".to_string()),
107        ];
108
109        if let Some(verifier) = code_verifier {
110            params.push(("code_verifier", verifier.to_string()));
111        }
112
113        let token_response = self
114            .http_client
115            .post(&self.token_url)
116            .form(&params)
117            .send()
118            .await
119            .map_err(|_| AuthError::Network)?
120            .json::<GoogleTokenResponse>()
121            .await
122            .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
123
124        // 2. Get user information
125        let user_response = self
126            .http_client
127            .get(&self.userinfo_url)
128            .header(
129                "Authorization",
130                format!("Bearer {}", token_response.access_token),
131            )
132            .send()
133            .await
134            .map_err(|_| AuthError::Network)?
135            .json::<GoogleUserResponse>()
136            .await
137            .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
138
139        // 3. Map to Identity
140        let mut attributes = HashMap::new();
141        if let Some(picture) = user_response.picture {
142            attributes.insert("picture".to_string(), picture);
143        }
144        if let Some(verified) = user_response.email_verified {
145            attributes.insert("email_verified".to_string(), verified.to_string());
146        }
147        if let Some(locale) = user_response.locale {
148            attributes.insert("locale".to_string(), locale);
149        }
150
151        let identity = Identity {
152            provider_id: "google".to_string(),
153            external_id: user_response.sub,
154            email: user_response.email,
155            username: user_response.name,
156            attributes,
157        };
158
159        let token = OAuthToken {
160            access_token: token_response.access_token,
161            token_type: token_response.token_type,
162            expires_in: token_response.expires_in,
163            refresh_token: token_response.refresh_token,
164            scope: token_response.scope,
165            id_token: token_response.id_token,
166        };
167
168        Ok((identity, token))
169    }
170
171    async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
172        let token_response = self
173            .http_client
174            .post(&self.token_url)
175            .form(&[
176                ("refresh_token", refresh_token),
177                ("client_id", &self.client_id),
178                ("client_secret", &self.client_secret),
179                ("grant_type", "refresh_token"),
180            ])
181            .send()
182            .await
183            .map_err(|_| AuthError::Network)?
184            .json::<GoogleTokenResponse>()
185            .await
186            .map_err(|e| {
187                AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
188            })?;
189
190        Ok(OAuthToken {
191            access_token: token_response.access_token,
192            token_type: token_response.token_type,
193            expires_in: token_response.expires_in,
194            refresh_token: token_response
195                .refresh_token
196                .or_else(|| Some(refresh_token.to_string())),
197            scope: token_response.scope,
198            id_token: token_response.id_token,
199        })
200    }
201
202    async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
203        let response = self
204            .http_client
205            .post(&self.revoke_url)
206            .form(&[("token", token)])
207            .send()
208            .await
209            .map_err(|_| AuthError::Network)?;
210
211        if response.status().is_success() {
212            Ok(())
213        } else {
214            let error_text = response
215                .text()
216                .await
217                .unwrap_or_else(|_| "Unknown error".to_string());
218            Err(AuthError::Provider(format!(
219                "Failed to revoke token: {}",
220                error_text
221            )))
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a provider with fixed test credentials whose four endpoints all
    /// point at `server`.
    fn provider_for(server: &MockServer) -> GoogleProvider {
        let base = server.uri();
        GoogleProvider::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost/callback".to_string(),
        )
        .with_test_urls(
            format!("{}/auth", base),
            format!("{}/token", base),
            format!("{}/userinfo", base),
            format!("{}/revoke", base),
        )
    }

    #[tokio::test]
    async fn test_exchange_code_for_identity() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "test_token",
                "token_type": "Bearer",
                "expires_in": 3600,
                "refresh_token": "test_refresh_token"
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sub": "google-123",
                "email": "test@google.com",
                "name": "Google User",
                "picture": "http://picture",
                "email_verified": true,
                "locale": "en"
            })))
            .mount(&server)
            .await;

        let (identity, token) = provider_for(&server)
            .exchange_code_for_identity("test_code", None)
            .await
            .unwrap();

        assert_eq!(identity.provider_id, "google");
        assert_eq!(identity.external_id, "google-123");
        assert_eq!(identity.username, Some("Google User".to_string()));
        assert_eq!(identity.email, Some("test@google.com".to_string()));
        assert_eq!(
            identity.attributes.get("picture").unwrap(),
            "http://picture"
        );
        assert_eq!(identity.attributes.get("email_verified").unwrap(), "true");
        assert_eq!(identity.attributes.get("locale").unwrap(), "en");
        assert_eq!(token.access_token, "test_token");
        assert_eq!(token.refresh_token, Some("test_refresh_token".to_string()));
    }

    /// A rejected code exchange must surface as a provider error.
    #[tokio::test]
    async fn test_exchange_code_token_endpoint_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(
                ResponseTemplate::new(400)
                    .set_body_json(serde_json::json!({"error": "invalid_grant"})),
            )
            .mount(&server)
            .await;

        let result = provider_for(&server)
            .exchange_code_for_identity("bad_code", None)
            .await;

        assert!(matches!(result, Err(AuthError::Provider(_))));
    }

    /// When the refresh response carries no refresh token, the one the caller
    /// presented is kept.
    #[tokio::test]
    async fn test_refresh_token_keeps_existing_refresh_token() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/token"))
            .and(body_string_contains("grant_type=refresh_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new_access_token",
                "token_type": "Bearer",
                "expires_in": 3600
            })))
            .mount(&server)
            .await;

        let token = provider_for(&server)
            .refresh_token("old_refresh_token")
            .await
            .unwrap();

        assert_eq!(token.access_token, "new_access_token");
        assert_eq!(token.expires_in, Some(3600));
        assert_eq!(token.refresh_token, Some("old_refresh_token".to_string()));
    }

    #[tokio::test]
    async fn test_revoke_token_success() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/revoke"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        assert!(provider_for(&server).revoke_token("some_token").await.is_ok());
    }

    /// A failed revocation reports the endpoint's response body.
    #[tokio::test]
    async fn test_revoke_token_failure_includes_body() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/revoke"))
            .respond_with(ResponseTemplate::new(400).set_body_string("invalid_token"))
            .mount(&server)
            .await;

        match provider_for(&server).revoke_token("some_token").await {
            Err(AuthError::Provider(message)) => assert!(
                message.contains("invalid_token"),
                "unexpected message: {}",
                message
            ),
            other => panic!("expected a provider error, got {:?}", other),
        }
    }

    #[test]
    fn test_authorization_url_includes_core_and_pkce_params() {
        let provider = GoogleProvider::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost/callback".to_string(),
        );

        let url = provider.get_authorization_url("state123", &["openid"], Some("challenge123"));

        assert!(url.starts_with("https://accounts.google.com/o/oauth2/v2/auth?client_id=client_id&"));
        assert!(url.contains("&state=state123&"));
        assert!(url.contains("&scope=openid&"));
        assert!(url.contains("&response_type=code&"));
        assert!(url.ends_with("&code_challenge=challenge123&code_challenge_method=S256"));
    }
}