lib_client_google_auth/
auth.rs

1use async_trait::async_trait;
2use chrono::{Duration, Utc};
3use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
4use reqwest::header::HeaderMap;
5use serde::{Deserialize, Serialize};
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::credentials::ServiceAccountCredentials;
10use crate::error::Result;
11use crate::token::{Token, TokenResponse, TokenStore};
12
/// Google OAuth 2.0 token endpoint.
///
/// Used as the JWT `aud` claim and as the POST target for the JWT-bearer,
/// authorization-code, and refresh-token grants in this module.
const GOOGLE_TOKEN_URL: &str = "https://oauth2.googleapis.com/token";
14
15/// Authentication strategy for Google APIs.
16#[async_trait]
17pub trait AuthStrategy: Send + Sync {
18    /// Apply authentication to request headers.
19    async fn apply(&self, headers: &mut HeaderMap) -> Result<()>;
20}
21
/// API key authentication.
///
/// The key is sent in the `x-goog-api-key` request header.
#[derive(Clone)]
pub struct ApiKeyAuth {
    api_key: String,
}

impl ApiKeyAuth {
    /// Create new API key auth from any string-like key.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
        }
    }
}

// Written by hand instead of derived: the derived impl would print the secret key
// verbatim wherever this value is logged with `{:?}` or appears in a panic message.
impl std::fmt::Debug for ApiKeyAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ApiKeyAuth")
            .field("api_key", &"<redacted>")
            .finish()
    }
}
36
37#[async_trait]
38impl AuthStrategy for ApiKeyAuth {
39    async fn apply(&self, headers: &mut HeaderMap) -> Result<()> {
40        headers.insert("x-goog-api-key", self.api_key.parse().unwrap());
41        Ok(())
42    }
43}
44
45/// JWT claims for service account authentication.
46#[derive(Debug, Serialize, Deserialize)]
47struct JwtClaims {
48    iss: String,
49    sub: Option<String>,
50    aud: String,
51    iat: i64,
52    exp: i64,
53    scope: String,
54}
55
56/// Service account authentication using JWT.
57pub struct ServiceAccountAuth {
58    credentials: ServiceAccountCredentials,
59    scopes: Vec<String>,
60    subject: Option<String>,
61    token: Arc<RwLock<Option<Token>>>,
62    http: reqwest::Client,
63}
64
65impl ServiceAccountAuth {
66    /// Create new service account auth.
67    pub fn new(credentials: ServiceAccountCredentials, scopes: Vec<String>) -> Self {
68        Self {
69            credentials,
70            scopes,
71            subject: None,
72            token: Arc::new(RwLock::new(None)),
73            http: reqwest::Client::new(),
74        }
75    }
76
77    /// Set subject (for domain-wide delegation).
78    pub fn with_subject(mut self, subject: impl Into<String>) -> Self {
79        self.subject = Some(subject.into());
80        self
81    }
82
83    /// Create signed JWT assertion.
84    fn create_jwt(&self) -> Result<String> {
85        let now = Utc::now();
86        let claims = JwtClaims {
87            iss: self.credentials.client_email.clone(),
88            sub: self.subject.clone(),
89            aud: GOOGLE_TOKEN_URL.to_string(),
90            iat: now.timestamp(),
91            exp: (now + Duration::hours(1)).timestamp(),
92            scope: self.scopes.join(" "),
93        };
94
95        let header = Header::new(Algorithm::RS256);
96        let key = EncodingKey::from_rsa_pem(self.credentials.private_key.as_bytes())?;
97        Ok(encode(&header, &claims, &key)?)
98    }
99
100    /// Fetch access token using JWT assertion.
101    async fn fetch_token(&self) -> Result<Token> {
102        let jwt = self.create_jwt()?;
103
104        let params = [
105            ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
106            ("assertion", &jwt),
107        ];
108
109        let response = self
110            .http
111            .post(GOOGLE_TOKEN_URL)
112            .form(&params)
113            .send()
114            .await?;
115
116        if !response.status().is_success() {
117            let status = response.status();
118            let text = response.text().await.unwrap_or_default();
119            return Err(crate::Error::TokenRefresh {
120                message: format!("Status {}: {}", status, text),
121            });
122        }
123
124        let token_resp: TokenResponse = response.json().await?;
125        Ok(token_resp.into())
126    }
127
128    /// Get valid access token, refreshing if needed.
129    async fn get_token(&self) -> Result<String> {
130        {
131            let token = self.token.read().await;
132            if let Some(t) = token.as_ref() {
133                if !t.is_expired() {
134                    return Ok(t.access_token.clone());
135                }
136            }
137        }
138
139        let new_token = self.fetch_token().await?;
140        let access_token = new_token.access_token.clone();
141
142        let mut token = self.token.write().await;
143        *token = Some(new_token);
144
145        Ok(access_token)
146    }
147}
148
149#[async_trait]
150impl AuthStrategy for ServiceAccountAuth {
151    async fn apply(&self, headers: &mut HeaderMap) -> Result<()> {
152        let token = self.get_token().await?;
153        headers.insert(
154            "Authorization",
155            format!("Bearer {}", token).parse().unwrap(),
156        );
157        Ok(())
158    }
159}
160
161/// OAuth2 authentication with refresh token support.
162pub struct OAuth2Auth {
163    client_id: String,
164    client_secret: String,
165    scopes: Vec<String>,
166    token_store: Option<Arc<dyn TokenStore>>,
167    token: Arc<RwLock<Option<Token>>>,
168    http: reqwest::Client,
169}
170
171impl OAuth2Auth {
172    /// Create new OAuth2 auth.
173    pub fn new(
174        client_id: impl Into<String>,
175        client_secret: impl Into<String>,
176        scopes: Vec<String>,
177    ) -> Self {
178        Self {
179            client_id: client_id.into(),
180            client_secret: client_secret.into(),
181            scopes,
182            token_store: None,
183            token: Arc::new(RwLock::new(None)),
184            http: reqwest::Client::new(),
185        }
186    }
187
188    /// Set token store for persistence.
189    pub fn with_token_store(mut self, store: Arc<dyn TokenStore>) -> Self {
190        self.token_store = Some(store);
191        self
192    }
193
194    /// Set initial token (e.g., from stored refresh token).
195    pub fn with_token(mut self, token: Token) -> Self {
196        self.token = Arc::new(RwLock::new(Some(token)));
197        self
198    }
199
200    /// Generate authorization URL for user consent.
201    pub fn authorization_url(&self, redirect_uri: &str, state: &str) -> String {
202        let scope = self.scopes.join(" ");
203        format!(
204            "https://accounts.google.com/o/oauth2/v2/auth?\
205            client_id={}&\
206            redirect_uri={}&\
207            response_type=code&\
208            scope={}&\
209            state={}&\
210            access_type=offline&\
211            prompt=consent",
212            urlencoding::encode(&self.client_id),
213            urlencoding::encode(redirect_uri),
214            urlencoding::encode(&scope),
215            urlencoding::encode(state)
216        )
217    }
218
219    /// Exchange authorization code for tokens.
220    pub async fn exchange_code(&self, code: &str, redirect_uri: &str) -> Result<Token> {
221        let params = [
222            ("code", code),
223            ("client_id", &self.client_id),
224            ("client_secret", &self.client_secret),
225            ("redirect_uri", redirect_uri),
226            ("grant_type", "authorization_code"),
227        ];
228
229        let response = self
230            .http
231            .post(GOOGLE_TOKEN_URL)
232            .form(&params)
233            .send()
234            .await?;
235
236        if !response.status().is_success() {
237            let status = response.status();
238            let text = response.text().await.unwrap_or_default();
239            return Err(crate::Error::AuthorizationFailed(format!(
240                "Status {}: {}",
241                status, text
242            )));
243        }
244
245        let token_resp: TokenResponse = response.json().await?;
246        let token: Token = token_resp.into();
247
248        let mut stored = self.token.write().await;
249        *stored = Some(token.clone());
250
251        if let Some(store) = &self.token_store {
252            store.store("google_oauth", &token).await?;
253        }
254
255        Ok(token)
256    }
257
258    /// Refresh access token using refresh token.
259    async fn refresh_token(&self, refresh_token: &str) -> Result<Token> {
260        let params = [
261            ("refresh_token", refresh_token),
262            ("client_id", &self.client_id),
263            ("client_secret", &self.client_secret),
264            ("grant_type", "refresh_token"),
265        ];
266
267        let response = self
268            .http
269            .post(GOOGLE_TOKEN_URL)
270            .form(&params)
271            .send()
272            .await?;
273
274        if !response.status().is_success() {
275            let status = response.status();
276            let text = response.text().await.unwrap_or_default();
277            return Err(crate::Error::TokenRefresh {
278                message: format!("Status {}: {}", status, text),
279            });
280        }
281
282        let token_resp: TokenResponse = response.json().await?;
283        let mut token: Token = token_resp.into();
284
285        // Preserve refresh token if not returned
286        if token.refresh_token.is_none() || token.refresh_token.as_ref().map(|s| s.is_empty()).unwrap_or(true) {
287            token.refresh_token = Some(refresh_token.to_string());
288        }
289
290        Ok(token)
291    }
292
293    /// Get valid access token, refreshing if needed.
294    async fn get_token(&self) -> Result<String> {
295        {
296            let token = self.token.read().await;
297            if let Some(t) = token.as_ref() {
298                if !t.is_expired() {
299                    return Ok(t.access_token.clone());
300                }
301            }
302        }
303
304        let refresh_token = {
305            let token = self.token.read().await;
306            token
307                .as_ref()
308                .and_then(|t| t.refresh_token.clone())
309                .ok_or(crate::Error::TokenExpired)?
310        };
311
312        let new_token = self.refresh_token(&refresh_token).await?;
313        let access_token = new_token.access_token.clone();
314
315        let mut token = self.token.write().await;
316        *token = Some(new_token.clone());
317
318        if let Some(store) = &self.token_store {
319            store.store("google_oauth", &new_token).await?;
320        }
321
322        Ok(access_token)
323    }
324}
325
326#[async_trait]
327impl AuthStrategy for OAuth2Auth {
328    async fn apply(&self, headers: &mut HeaderMap) -> Result<()> {
329        let token = self.get_token().await?;
330        headers.insert(
331            "Authorization",
332            format!("Bearer {}", token).parse().unwrap(),
333        );
334        Ok(())
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_key_auth() {
        let auth = ApiKeyAuth::new("test-key");
        assert_eq!(auth.api_key, "test-key");
    }

    #[test]
    fn test_authorization_url() {
        let auth = OAuth2Auth::new(
            "client-id",
            "client-secret",
            vec!["https://www.googleapis.com/auth/drive".to_string()],
        );
        let url = auth.authorization_url("http://localhost:8080/callback", "state123");
        assert!(url.contains("client_id=client-id"));
        assert!(url.contains("state=state123"));

        // Pin the whole URL. This checks parameter order, percent-encoding of the
        // redirect URI and scope, and that the line-continued format string adds
        // no stray whitespace.
        let expected = concat!(
            "https://accounts.google.com/o/oauth2/v2/auth?",
            "client_id=client-id&",
            "redirect_uri=http%3A%2F%2Flocalhost%3A8080%2Fcallback&",
            "response_type=code&",
            "scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive&",
            "state=state123&",
            "access_type=offline&",
            "prompt=consent",
        );
        assert_eq!(url, expected);
    }

    #[test]
    fn test_authorization_url_joins_and_encodes_scopes() {
        let auth = OAuth2Auth::new(
            "id",
            "secret",
            vec!["scope.a".to_string(), "scope.b".to_string()],
        );
        let url = auth.authorization_url("http://localhost/cb", "s");
        // Scopes are space-joined, and the space is percent-encoded.
        assert!(url.contains("&scope=scope.a%20scope.b&"));
        assert!(!url.contains(' '));
    }
}