1use async_trait::async_trait;
2use authly_core::{AuthError, Identity, OAuthProvider, OAuthToken};
3use serde::Deserialize;
4use std::collections::HashMap;
5
6pub struct GoogleProvider {
7 client_id: String,
8 client_secret: String,
9 redirect_uri: String,
10 http_client: reqwest::Client,
11 auth_url: String,
12 token_url: String,
13 userinfo_url: String,
14 revoke_url: String,
15}
16
17impl GoogleProvider {
18 pub fn new(client_id: String, client_secret: String, redirect_uri: String) -> Self {
19 Self {
20 client_id,
21 client_secret,
22 redirect_uri,
23 http_client: reqwest::Client::new(),
24 auth_url: "https://accounts.google.com/o/oauth2/v2/auth".to_string(),
25 token_url: "https://oauth2.googleapis.com/token".to_string(),
26 userinfo_url: "https://www.googleapis.com/oauth2/v3/userinfo".to_string(),
27 revoke_url: "https://oauth2.googleapis.com/revoke".to_string(),
28 }
29 }
30
31 pub fn with_test_urls(
32 mut self,
33 auth_url: String,
34 token_url: String,
35 userinfo_url: String,
36 revoke_url: String,
37 ) -> Self {
38 self.auth_url = auth_url;
39 self.token_url = token_url;
40 self.userinfo_url = userinfo_url;
41 self.revoke_url = revoke_url;
42 self
43 }
44}
45
/// JSON body returned by the token endpoint, for both the authorization-code exchange and
/// the refresh-token grant. Each field is copied into the provider-agnostic `OAuthToken`.
///
/// Field order is kept as-is: it also defines the order used by the derived
/// `Deserialize` impl when decoding from sequence-shaped input.
#[derive(Deserialize)]
struct GoogleTokenResponse {
    // Sent as the bearer credential when fetching the user profile.
    access_token: String,
    token_type: String,
    // Token lifetime; per RFC 6749 this is in seconds — not checked here.
    expires_in: Option<u64>,
    // May be absent; `refresh_token` falls back to the token the caller supplied.
    refresh_token: Option<String>,
    scope: Option<String>,
    // Passed through to the caller unvalidated.
    id_token: Option<String>,
}
55
/// Profile claims returned by the userinfo endpoint, fetched with the access token as a
/// bearer credential.
///
/// Field order is kept as-is: it also defines the order used by the derived
/// `Deserialize` impl when decoding from sequence-shaped input.
#[derive(Deserialize)]
struct GoogleUserResponse {
    // Becomes `Identity::external_id`; the only required claim.
    sub: String,
    email: Option<String>,
    // Becomes `Identity::username`.
    name: Option<String>,
    // The remaining claims are stored as strings in `Identity::attributes`.
    picture: Option<String>,
    email_verified: Option<bool>,
    locale: Option<String>,
}
65
66#[async_trait]
67impl OAuthProvider for GoogleProvider {
68 fn get_authorization_url(
69 &self,
70 state: &str,
71 scopes: &[&str],
72 code_challenge: Option<&str>,
73 ) -> String {
74 let scope_param = if scopes.is_empty() {
75 "openid email profile".to_string()
76 } else {
77 scopes.join(" ")
78 };
79
80 let mut url = format!(
81 "{}?client_id={}&redirect_uri={}&state={}&scope={}&response_type=code&access_type=offline&prompt=consent",
82 self.auth_url, self.client_id, self.redirect_uri, state, scope_param
83 );
84
85 if let Some(challenge) = code_challenge {
86 url.push_str(&format!(
87 "&code_challenge={}&code_challenge_method=S256",
88 challenge
89 ));
90 }
91
92 url
93 }
94
95 async fn exchange_code_for_identity(
96 &self,
97 code: &str,
98 code_verifier: Option<&str>,
99 ) -> Result<(Identity, OAuthToken), AuthError> {
100 let mut params = vec![
102 ("code", code.to_string()),
103 ("client_id", self.client_id.clone()),
104 ("client_secret", self.client_secret.clone()),
105 ("redirect_uri", self.redirect_uri.clone()),
106 ("grant_type", "authorization_code".to_string()),
107 ];
108
109 if let Some(verifier) = code_verifier {
110 params.push(("code_verifier", verifier.to_string()));
111 }
112
113 let token_response = self
114 .http_client
115 .post(&self.token_url)
116 .form(¶ms)
117 .send()
118 .await
119 .map_err(|_| AuthError::Network)?
120 .json::<GoogleTokenResponse>()
121 .await
122 .map_err(|e| AuthError::Provider(format!("Failed to parse token response: {}", e)))?;
123
124 let user_response = self
126 .http_client
127 .get(&self.userinfo_url)
128 .header(
129 "Authorization",
130 format!("Bearer {}", token_response.access_token),
131 )
132 .send()
133 .await
134 .map_err(|_| AuthError::Network)?
135 .json::<GoogleUserResponse>()
136 .await
137 .map_err(|e| AuthError::Provider(format!("Failed to parse user response: {}", e)))?;
138
139 let mut attributes = HashMap::new();
141 if let Some(picture) = user_response.picture {
142 attributes.insert("picture".to_string(), picture);
143 }
144 if let Some(verified) = user_response.email_verified {
145 attributes.insert("email_verified".to_string(), verified.to_string());
146 }
147 if let Some(locale) = user_response.locale {
148 attributes.insert("locale".to_string(), locale);
149 }
150
151 let identity = Identity {
152 provider_id: "google".to_string(),
153 external_id: user_response.sub,
154 email: user_response.email,
155 username: user_response.name,
156 attributes,
157 };
158
159 let token = OAuthToken {
160 access_token: token_response.access_token,
161 token_type: token_response.token_type,
162 expires_in: token_response.expires_in,
163 refresh_token: token_response.refresh_token,
164 scope: token_response.scope,
165 id_token: token_response.id_token,
166 };
167
168 Ok((identity, token))
169 }
170
171 async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken, AuthError> {
172 let token_response = self
173 .http_client
174 .post(&self.token_url)
175 .form(&[
176 ("refresh_token", refresh_token),
177 ("client_id", &self.client_id),
178 ("client_secret", &self.client_secret),
179 ("grant_type", "refresh_token"),
180 ])
181 .send()
182 .await
183 .map_err(|_| AuthError::Network)?
184 .json::<GoogleTokenResponse>()
185 .await
186 .map_err(|e| {
187 AuthError::Provider(format!("Failed to parse refresh token response: {}", e))
188 })?;
189
190 Ok(OAuthToken {
191 access_token: token_response.access_token,
192 token_type: token_response.token_type,
193 expires_in: token_response.expires_in,
194 refresh_token: token_response
195 .refresh_token
196 .or_else(|| Some(refresh_token.to_string())),
197 scope: token_response.scope,
198 id_token: token_response.id_token,
199 })
200 }
201
202 async fn revoke_token(&self, token: &str) -> Result<(), AuthError> {
203 let response = self
204 .http_client
205 .post(&self.revoke_url)
206 .form(&[("token", token)])
207 .send()
208 .await
209 .map_err(|_| AuthError::Network)?;
210
211 if response.status().is_success() {
212 Ok(())
213 } else {
214 let error_text = response
215 .text()
216 .await
217 .unwrap_or_else(|_| "Unknown error".to_string());
218 Err(AuthError::Provider(format!(
219 "Failed to revoke token: {}",
220 error_text
221 )))
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_string_contains, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a provider whose four endpoints all point at `server`.
    fn provider_for(server: &MockServer) -> GoogleProvider {
        GoogleProvider::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost/callback".to_string(),
        )
        .with_test_urls(
            format!("{}/auth", server.uri()),
            format!("{}/token", server.uri()),
            format!("{}/userinfo", server.uri()),
            format!("{}/revoke", server.uri()),
        )
    }

    #[test]
    fn test_get_authorization_url() {
        let provider = GoogleProvider::new(
            "client_id".to_string(),
            "client_secret".to_string(),
            "http://localhost/callback".to_string(),
        );

        let url = provider.get_authorization_url("xyz", &[], Some("challenge123"));

        assert!(url.starts_with("https://accounts.google.com/o/oauth2/v2/auth?"));
        assert!(url.contains("client_id=client_id"));
        assert!(url.contains("state=xyz"));
        assert!(url.contains("response_type=code"));
        assert!(url.contains("code_challenge=challenge123&code_challenge_method=S256"));
    }

    #[tokio::test]
    async fn test_exchange_code_for_identity() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/token"))
            .and(body_string_contains("code=test_code"))
            .and(body_string_contains("grant_type=authorization_code"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "test_token",
                "token_type": "Bearer",
                "expires_in": 3600,
                "refresh_token": "test_refresh_token"
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .and(header("Authorization", "Bearer test_token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sub": "google-123",
                "email": "test@google.com",
                "name": "Google User",
                "picture": "http://picture",
                "email_verified": true,
                "locale": "en"
            })))
            .mount(&server)
            .await;

        let (identity, token) = provider_for(&server)
            .exchange_code_for_identity("test_code", None)
            .await
            .unwrap();

        assert_eq!(identity.provider_id, "google");
        assert_eq!(identity.external_id, "google-123");
        assert_eq!(identity.username, Some("Google User".to_string()));
        assert_eq!(identity.email, Some("test@google.com".to_string()));
        assert_eq!(
            identity.attributes.get("picture").unwrap(),
            "http://picture"
        );
        assert_eq!(identity.attributes.get("email_verified").unwrap(), "true");
        assert_eq!(identity.attributes.get("locale").unwrap(), "en");
        assert_eq!(token.access_token, "test_token");
        assert_eq!(token.refresh_token, Some("test_refresh_token".to_string()));
    }

    #[tokio::test]
    async fn test_exchange_code_reports_token_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/token"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "Bad Request"
            })))
            .mount(&server)
            .await;

        let result = provider_for(&server)
            .exchange_code_for_identity("bad_code", None)
            .await;

        assert!(matches!(result, Err(AuthError::Provider(_))));
    }

    #[tokio::test]
    async fn test_refresh_token_keeps_existing_refresh_token() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/token"))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains("refresh_token=old_refresh"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new_token",
                "token_type": "Bearer",
                "expires_in": 3600
            })))
            .mount(&server)
            .await;

        let token = provider_for(&server)
            .refresh_token("old_refresh")
            .await
            .unwrap();

        assert_eq!(token.access_token, "new_token");
        assert_eq!(token.expires_in, Some(3600));
        assert_eq!(token.refresh_token, Some("old_refresh".to_string()));
    }

    #[tokio::test]
    async fn test_revoke_token_success() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/revoke"))
            .and(body_string_contains("token=abc"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        assert!(provider_for(&server).revoke_token("abc").await.is_ok());
    }

    #[tokio::test]
    async fn test_revoke_token_reports_provider_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/revoke"))
            .respond_with(ResponseTemplate::new(400).set_body_string("invalid_token"))
            .mount(&server)
            .await;

        let result = provider_for(&server).revoke_token("abc").await;

        assert!(matches!(
            result,
            Err(AuthError::Provider(ref msg)) if msg.contains("invalid_token")
        ));
    }
}