1use async_trait::async_trait;
2use chrono::{Duration, Utc};
3use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
4use reqwest::header::HeaderMap;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::credentials::ServiceAccountCredentials;
10use crate::error::Result;
11use crate::token::{Token, TokenResponse, TokenStore};
12
/// Google OAuth 2.0 token endpoint.
///
/// Every token request in this module is POSTed here (JWT-bearer, authorization-code and
/// refresh-token grants), and it is also the `aud` claim of the service-account assertion.
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
14
15#[async_trait]
17pub trait AuthStrategy: Send + Sync {
18 async fn apply(&self, headers: &mut HeaderMap) -> Result<()>;
20}
21
/// Authenticates requests with a Google API key, sent as the `x-goog-api-key` header.
///
/// `Debug` is implemented by hand so the key never leaks into logs or panic messages.
#[derive(Clone)]
pub struct ApiKeyAuth {
    /// The raw API key. This is a secret, so never print it.
    api_key: String,
}

// A derived `Debug` would print the secret verbatim; render a placeholder instead.
impl std::fmt::Debug for ApiKeyAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyAuth")
            .field("api_key", &"<redacted>")
            .finish()
    }
}
27
28impl ApiKeyAuth {
29 pub fn new(api_key: impl Into<String>) -> Self {
31 Self {
32 api_key: api_key.into(),
33 }
34 }
35}
36
37#[async_trait]
38impl AuthStrategy for ApiKeyAuth {
39 async fn apply(&self, headers: &mut HeaderMap) -> Result<()> {
40 headers.insert("x-goog-api-key", self.api_key.parse().unwrap());
41 Ok(())
42 }
43}
44
45#[derive(Debug, Serialize, Deserialize)]
47struct JwtClaims {
48 iss: String,
49 sub: Option<String>,
50 aud: String,
51 iat: i64,
52 exp: i64,
53 scope: String,
54}
55
56pub struct ServiceAccountAuth {
58 credentials: ServiceAccountCredentials,
59 scopes: Vec<String>,
60 subject: Option<String>,
61 token: Arc<RwLock<Option<Token>>>,
62 http: reqwest::Client,
63}
64
65impl ServiceAccountAuth {
66 pub fn new(credentials: ServiceAccountCredentials, scopes: Vec<String>) -> Self {
68 Self {
69 credentials,
70 scopes,
71 subject: None,
72 token: Arc::new(RwLock::new(None)),
73 http: reqwest::Client::new(),
74 }
75 }
76
77 pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
79 self.subject = Some(subject.into());
80 self
81 }
82
83 fn create_jwt(&self) -> Result<String> {
85 let now = Utc::now();
86 let claims = JwtClaims {
87 iss: self.credentials.client_email.clone(),
88 sub: self.subject.clone(),
89 aud: GOOGLE_TOKEN_URL.to_string(),
90 iat: now.timestamp(),
91 exp: (now + Duration::hours(1)).timestamp(),
92 scope: self.scopes.join(" "),
93 };
94
95 let header = Header::new(Algorithm::RS256);
96 let key = EncodingKey::from_rsa_pem(self.credentials.private_key.as_bytes())?;
97 Ok(encode(&header, &claims, &key)?)
98 }
99
100 async fn fetch_token(&self) -> Result<Token> {
102 let jwt = self.create_jwt()?;
103
104 let params = [
105 ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
106 ("assertion", &jwt),
107 ];
108
109 let response = self
110 .http
111 .post(GOOGLE_TOKEN_URL)
112 .form(¶ms)
113 .send()
114 .await?;
115
116 if !response.status().is_success() {
117 let status = response.status();
118 let text = response.text().await.unwrap_or_default();
119 return Err(crate::Error::TokenRefresh {
120 message: format!("Status {}: {}", status, text),
121 });
122 }
123
124 let token_resp: TokenResponse = response.json().await?;
125 Ok(token_resp.into())
126 }
127
128 async fn get_token(&self) -> Result<String> {
130 {
131 let token = self.token.read().await;
132 if let Some(t) = token.as_ref() {
133 if !t.is_expired() {
134 return Ok(t.access_token.clone());
135 }
136 }
137 }
138
139 let new_token = self.fetch_token().await?;
140 let access_token = new_token.access_token.clone();
141
142 let mut token = self.token.write().await;
143 *token = Some(new_token);
144
145 Ok(access_token)
146 }
147}
148
149#[async_trait]
150impl AuthStrategy for ServiceAccountAuth {
151 async fn apply(&self, headers: &mut HeaderMap) -> Result<()> {
152 let token = self.get_token().await?;
153 headers.insert(
154 "Authorization",
155 format!("Bearer {}", token).parse().unwrap(),
156 );
157 Ok(())
158 }
159}
160
161pub struct OAuth2Auth {
163 client_id: String,
164 client_secret: String,
165 scopes: Vec<String>,
166 token_store: Option<Arc<dyn TokenStore>>,
167 token: Arc<RwLock<Option<Token>>>,
168 http: reqwest::Client,
169}
170
171impl OAuth2Auth {
172 pub fn new(
174 client_id: impl Into<String>,
175 client_secret: impl Into<String>,
176 scopes: Vec<String>,
177 ) -> Self {
178 Self {
179 client_id: client_id.into(),
180 client_secret: client_secret.into(),
181 scopes,
182 token_store: None,
183 token: Arc::new(RwLock::new(None)),
184 http: reqwest::Client::new(),
185 }
186 }
187
188 pub fn with_token_store(mut self, store: Arc<dyn TokenStore>) -> Self {
190 self.token_store = Some(store);
191 self
192 }
193
194 pub fn with_token(mut self, token: Token) -> Self {
196 self.token = Arc::new(RwLock::new(Some(token)));
197 self
198 }
199
200 pub fn authorization_url(&self, redirect_uri: &str, state: &str) -> String {
202 let scope = self.scopes.join(" ");
203 format!(
204 "https://accounts.google.com/o/oauth2/v2/auth?\
205 client_id={}&\
206 redirect_uri={}&\
207 response_type=code&\
208 scope={}&\
209 state={}&\
210 access_type=offline&\
211 prompt=consent",
212 urlencoding::encode(&self.client_id),
213 urlencoding::encode(redirect_uri),
214 urlencoding::encode(&scope),
215 urlencoding::encode(state)
216 )
217 }
218
219 pub async fn exchange_code(&self, code: &str, redirect_uri: &str) -> Result<Token> {
221 let params = [
222 ("code", code),
223 ("client_id", &self.client_id),
224 ("client_secret", &self.client_secret),
225 ("redirect_uri", redirect_uri),
226 ("grant_type", "authorization_code"),
227 ];
228
229 let response = self
230 .http
231 .post(GOOGLE_TOKEN_URL)
232 .form(¶ms)
233 .send()
234 .await?;
235
236 if !response.status().is_success() {
237 let status = response.status();
238 let text = response.text().await.unwrap_or_default();
239 return Err(crate::Error::AuthorizationFailed(format!(
240 "Status {}: {}",
241 status, text
242 )));
243 }
244
245 let token_resp: TokenResponse = response.json().await?;
246 let token: Token = token_resp.into();
247
248 let mut stored = self.token.write().await;
249 *stored = Some(token.clone());
250
251 if let Some(store) = &self.token_store {
252 store.store("google_oauth", &token).await?;
253 }
254
255 Ok(token)
256 }
257
258 async fn refresh_token(&self, refresh_token: &str) -> Result<Token> {
260 let params = [
261 ("refresh_token", refresh_token),
262 ("client_id", &self.client_id),
263 ("client_secret", &self.client_secret),
264 ("grant_type", "refresh_token"),
265 ];
266
267 let response = self
268 .http
269 .post(GOOGLE_TOKEN_URL)
270 .form(¶ms)
271 .send()
272 .await?;
273
274 if !response.status().is_success() {
275 let status = response.status();
276 let text = response.text().await.unwrap_or_default();
277 return Err(crate::Error::TokenRefresh {
278 message: format!("Status {}: {}", status, text),
279 });
280 }
281
282 let token_resp: TokenResponse = response.json().await?;
283 let mut token: Token = token_resp.into();
284
285 if token.refresh_token.is_none() || token.refresh_token.as_ref().map(|s| s.is_empty()).unwrap_or(true) {
287 token.refresh_token = Some(refresh_token.to_string());
288 }
289
290 Ok(token)
291 }
292
293 async fn get_token(&self) -> Result<String> {
295 {
296 let token = self.token.read().await;
297 if let Some(t) = token.as_ref() {
298 if !t.is_expired() {
299 return Ok(t.access_token.clone());
300 }
301 }
302 }
303
304 let refresh_token = {
305 let token = self.token.read().await;
306 token
307 .as_ref()
308 .and_then(|t| t.refresh_token.clone())
309 .ok_or(crate::Error::TokenExpired)?
310 };
311
312 let new_token = self.refresh_token(&refresh_token).await?;
313 let access_token = new_token.access_token.clone();
314
315 let mut token = self.token.write().await;
316 *token = Some(new_token.clone());
317
318 if let Some(store) = &self.token_store {
319 store.store("google_oauth", &new_token).await?;
320 }
321
322 Ok(access_token)
323 }
324}
325
326#[async_trait]
327impl AuthStrategy for OAuth2Auth {
328 async fn apply(&self, headers: &mut HeaderMap) -> Result<()> {
329 let token = self.get_token().await?;
330 headers.insert(
331 "Authorization",
332 format!("Bearer {}", token).parse().unwrap(),
333 );
334 Ok(())
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_key_auth() {
        let auth = ApiKeyAuth::new("test-key");
        assert_eq!(auth.api_key, "test-key");
    }

    #[test]
    fn test_authorization_url() {
        let auth = OAuth2Auth::new(
            "client-id",
            "client-secret",
            vec!["https://www.googleapis.com/auth/drive".to_string()],
        );
        let url = auth.authorization_url("http://localhost:8080/callback", "state123");
        assert!(url.contains("client_id=client-id"));
        assert!(url.contains("state=state123"));
    }

    /// Reserved characters in every caller-supplied value must be percent-encoded, and
    /// multiple scopes must be space-joined before encoding.
    #[test]
    fn test_authorization_url_percent_encodes_parameters() {
        let auth = OAuth2Auth::new(
            "client-id",
            "client-secret",
            vec![
                "https://www.googleapis.com/auth/drive".to_string(),
                "openid".to_string(),
            ],
        );
        let url = auth.authorization_url("http://localhost:8080/callback", "a b&c");

        assert!(url.starts_with(
            "https://accounts.google.com/o/oauth2/v2/auth?client_id=client-id&"
        ));
        assert!(url.contains("redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&"));
        assert!(url.contains("scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20openid&"));
        assert!(url.contains("state=a%20b%26c&"));
        assert!(url.contains("response_type=code&"));
        assert!(url.contains("access_type=offline&"));
        assert!(url.ends_with("prompt=consent"));
    }
}