google_ai_rs/
auth.rs

1use serde::Serialize;
2use thiserror::Error;
3use tokio::stream;
4use tonic::RawRequestHeaderValue;
5
6#[cfg(feature = "jwt")]
7use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
8#[cfg(feature = "jwt")]
9use rsa::{
10    pkcs1::{DecodeRsaPrivateKey, Error as Pkcs1Error},
11    pkcs1v15::SigningKey,
12    pkcs8::{DecodePrivateKey, Error as Pkcs8Error},
13    signature::{RandomizedSigner, SignatureEncoding},
14    RsaPrivateKey,
15};
16#[cfg(feature = "jwt")]
17use serde::Deserialize;
18#[cfg(feature = "jwt")]
19use serde_json::json;
20#[cfg(feature = "jwt")]
21use sha2::Sha256;
22#[cfg(feature = "jwt")]
23use std::{
24    path::Path,
25    sync::Arc,
26    time::{Duration, SystemTime, SystemTimeError},
27};
28#[cfg(feature = "jwt")]
29use tokio::sync::RwLock;
30
/// Authentication configuration options
///
/// `Debug` is implemented by hand so that formatting an `Auth` (for example in
/// logs or panic messages) never prints the API key.
#[derive(Clone)]
pub enum Auth {
    /// API key authentication (simple and fast but less secure)
    ApiKey(String),
    /// JWT-based service account authentication (more secure but with more overhead)
    #[cfg(feature = "jwt")]
    TokenSource(TokenSource),
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The key is a credential: show that one is present, not what it is.
            Self::ApiKey(_) => f
                .debug_tuple("ApiKey")
                .field(&format_args!("<redacted>"))
                .finish(),
            #[cfg(feature = "jwt")]
            Self::TokenSource(source) => f.debug_tuple("TokenSource").field(source).finish(),
        }
    }
}
40
41impl<S: Into<String>> From<S> for Auth {
42    fn from(value: S) -> Self {
43        Auth::ApiKey(value.into())
44    }
45}
46
/// JSON Web Token configuration for service account authentication
///
/// The serialized field names are exactly the Rust field names
/// (`client_email`, `private_key`, `private_key_id`).
///
/// `Debug` is implemented by hand so that the private key is never printed.
#[cfg(feature = "jwt")]
#[derive(Deserialize, Clone)]
pub struct JWTConfig {
    /// Service account client email (format: name@project.iam.gserviceaccount.com)
    pub client_email: String,

    /// RSA private key in PEM or DER format (keep secure!)
    pub private_key: String,

    /// Private key identifier from Google Cloud, sent as the `kid` JWT header.
    ///
    /// The field has no serde default, so deserialization fails when it is absent.
    pub private_key_id: String,

    /// Token lifetime duration (default: 1 hour)
    ///
    /// Skipped by serde, so it is always `None` right after deserialization.
    #[serde(skip)]
    pub lifetime: Option<Duration>,
}

#[cfg(feature = "jwt")]
impl std::fmt::Debug for JWTConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JWTConfig")
            .field("client_email", &self.client_email)
            // The key is a credential: never let it reach logs or panic messages.
            .field("private_key", &format_args!("<redacted>"))
            .field("private_key_id", &self.private_key_id)
            .field("lifetime", &self.lifetime)
            .finish()
    }
}
67
/// Sources from which an authentication token can be produced
#[cfg(feature = "jwt")]
#[derive(Debug, Clone)]
pub enum TokenSource {
    /// JSON Web Token flow holding an already constructed [`JwtService`]
    #[deprecated(note = "use `TokenSource::JWT`(all-caps) instead")]
    Jwt { jwt: Box<JwtService> },

    /// JSON Web Token flow described by a [`JWTConfig`]
    JWT(JWTConfig),
}
79
80/// Authentication error types
81#[derive(Debug, Error)]
82pub enum Error {
83    #[error("Token generation failed: {0}")]
84    TokenGeneration(String),
85
86    #[error("Invalid header value")]
87    InvalidHeader,
88
89    #[cfg(feature = "jwt")]
90    #[error("I/O error: {0}")]
91    Io(#[from] std::io::Error),
92
93    #[cfg(feature = "jwt")]
94    #[error("JSON parsing error: {0}")]
95    Json(#[from] serde_json::Error),
96
97    #[error("Private key parsing failed: {0}")]
98    #[cfg(feature = "jwt")]
99    PrivateKey(#[from] PrivateKeyError),
100
101    #[error("System time error: {0}")]
102    #[cfg(feature = "jwt")]
103    SystemTime(#[from] SystemTimeError),
104
105    #[error("Invalid token lifetime")]
106    #[cfg(feature = "jwt")]
107    InvalidLifetime,
108}
109
/// Failures that can occur while decoding a private key
#[cfg(feature = "jwt")]
#[derive(Error, Debug)]
pub enum PrivateKeyError {
    /// The key bytes are not a valid PKCS#1 document.
    #[error("PKCS#1 parsing error: {0}")]
    Pkcs1(#[from] Pkcs1Error),

    /// The key bytes are not a valid PKCS#8 document.
    #[error("PKCS#8 parsing error: {0}")]
    Pkcs8(#[from] Pkcs8Error),

    /// The PEM wrapper around the key is malformed.
    #[error("PEM format error: {0}")]
    Pem(#[from] pem::PemError),
}
123
/// Request header that carries an API key.
const API_KEY_HEADER: &str = "x-goog-api-key";

/// Request header that carries the signed JWT.
#[cfg(feature = "jwt")]
const JWT_HEADER: &str = "authorization";
/// Value of the `aud` claim placed in every generated JWT.
#[cfg(feature = "jwt")]
const JWT_AUDIENCE: &str = "https://generativelanguage.googleapis.com/";
/// Lifetime used when `JWTConfig::lifetime` is `None` (one hour).
#[cfg(feature = "jwt")]
const DEFAULT_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);
132
133impl Auth {
134    /// Creates API key authentication
135    pub fn new(api_key: &str) -> Self {
136        Self::ApiKey(api_key.to_owned())
137    }
138
139    /// Creates service account authentication from JSON file
140    ///
141    ///
142    /// # Example
143    /// ```
144    /// # use google_ai_rs::auth::Auth;
145    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
146    /// let auth = Auth::service_account("path/to/service-account.json")
147    ///     .await
148    ///     .expect("Valid service account");
149    /// # Ok(())
150    /// # }
151    /// ```
152    #[deprecated(note = "use `Auth::service` instead")]
153    #[allow(deprecated)]
154    #[cfg(feature = "jwt")]
155    pub async fn service_account<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
156        Ok(Self::TokenSource(
157            TokenSource::from_service_account(path).await?,
158        ))
159    }
160
161    /// Creates service account authentication from JSON file
162    ///
163    /// # Example
164    /// ```
165    /// # use google_ai_rs::auth::Auth;
166    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
167    /// let auth = Auth::service("path/to/service-account.json")
168    ///     .await
169    ///     .expect("Valid service account");
170    /// # Ok(())
171    /// # }
172    /// ```
173    #[cfg(feature = "jwt")]
174    pub async fn service<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
175        Ok(Self::TokenSource(TokenSource::service(path).await?))
176    }
177
178    /// Creates JWT authentication from configuration
179    #[deprecated(note = "use `Auth::jwt` instead")]
180    #[allow(deprecated)]
181    #[cfg(feature = "jwt")]
182    pub async fn from_jwt_config(config: JWTConfig) -> Result<Self, Error> {
183        Ok(Self::TokenSource(TokenSource::from_jwt(config).await?))
184    }
185
186    /// Creates JWT authentication from configuration
187    #[cfg(feature = "jwt")]
188    pub fn jwt(config: JWTConfig) -> Self {
189        Self::TokenSource(TokenSource::jwt(config))
190    }
191}
192
#[cfg(feature = "jwt")]
impl TokenSource {
    /// Reads the file at `path` and deserializes its JSON content into a [`JWTConfig`].
    async fn load_config(path: impl AsRef<Path>) -> Result<JWTConfig, Error> {
        let raw = tokio::fs::read(path).await?;
        let config: JWTConfig = serde_json::from_slice(&raw)?;
        Ok(config)
    }

    /// Builds a token source from a service account JSON file.
    ///
    /// Unlike [`TokenSource::service`], the loaded configuration is parsed into a
    /// `JwtService` right away.
    #[deprecated(note = "use `TokenSource::service` instead")]
    #[allow(deprecated)]
    pub async fn from_service_account<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let config = Self::load_config(path).await?;
        Self::from_jwt(config).await
    }

    /// Builds a token source from a service account JSON file.
    ///
    /// Only file reading and JSON decoding happen here; the configuration is
    /// stored unparsed.
    pub async fn service<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let config = Self::load_config(path).await?;
        Ok(Self::JWT(config))
    }

    /// Builds a token source by parsing `config` into a `JwtService` right away.
    #[deprecated(note = "use `TokenSource::jwt` instead")]
    #[allow(deprecated)]
    pub async fn from_jwt(config: JWTConfig) -> Result<Self, Error> {
        let jwt = Box::new(config.parsed()?);
        Ok(Self::Jwt { jwt })
    }

    /// Wraps `config` without parsing it.
    pub fn jwt(config: JWTConfig) -> Self {
        Self::JWT(config)
    }
}
227
228#[derive(Debug)]
229pub(crate) enum AuthParsed {
230    ApiKey(RawRequestHeaderValue),
231    #[cfg(feature = "jwt")]
232    JwtKind(JwtService),
233}
234
235impl Auth {
236    pub(crate) fn parsed(self) -> Result<AuthParsed, Error> {
237        match self {
238            // FIXME: InvalidHeader is such a poor error variant...
239            Auth::ApiKey(api_key) => Ok(AuthParsed::ApiKey(
240                api_key.parse().map_err(|_| Error::InvalidHeader)?,
241            )),
242            #[cfg(feature = "jwt")]
243            Auth::TokenSource(token_source) => match token_source {
244                #[allow(deprecated)]
245                // FIXME: Revalidate the unchecked jwt
246                TokenSource::Jwt { jwt } => Ok(AuthParsed::JwtKind(*jwt)),
247                TokenSource::JWT(jwtconfig) => jwtconfig.parsed().map(AuthParsed::JwtKind),
248            },
249        }
250    }
251}
252
#[cfg(feature = "jwt")]
impl JWTConfig {
    /// Consumes the configuration and builds a `JwtService` from it via `JwtService::new`.
    fn parsed(self) -> Result<JwtService, Error> {
        let service = JwtService::new(self)?;
        Ok(service)
    }
}
259
260impl AuthParsed {
261    /// Adds authentication headers to gRPC requests
262    #[cfg(not(feature = "jwt"))]
263    pub(crate) fn _into_request(self, request: &mut tonic::RawRequestHeader) {
264        match self {
265            Self::ApiKey(metadata_value) => {
266                request.insert(API_KEY_HEADER, metadata_value);
267            }
268        }
269    }
270
271    // /// Adds authentication headers to gRPC requests
272    // #[cfg(feature = "jwt")]
273    // pub(crate) async fn _into_request(
274    //     self,
275    //     request: &mut tonic::RawRequestHeader,
276    // ) {
277    //     match self {
278    //         Self::ApiKey(metadata_value) => {
279    //             request.insert(API_KEY_HEADER, metadata_value);
280    //         }
281    //         Self::JwtKind(jwt_service) => {
282    //             let token = jwt_service.get_token().await;
283    //             request.insert(JWT_HEADER, token);
284    //         }
285    //     }
286    // }
287
288    // /// Adds authentication headers to gRPC requests
289    // #[cfg(not(feature = "jwt"))]
290    // pub(crate) fn to_request(&self, request: &mut tonic::RawRequestHeader) {
291    //     match self {
292    //         Self::ApiKey(metadata_value) => {
293    //             request.insert(API_KEY_HEADER, metadata_value.clone());
294    //         }
295    //     }
296    // }
297
298    /// Adds authentication headers to gRPC requests
299    #[cfg(feature = "jwt")]
300    pub(crate) async fn to_request(&self, request: &mut tonic::RawRequestHeader) {
301        match self {
302            Self::ApiKey(metadata_value) => {
303                request.insert(API_KEY_HEADER, metadata_value.clone());
304            }
305            Self::JwtKind(jwt_service) => {
306                let token = jwt_service.get_token().await;
307                request.insert(JWT_HEADER, token);
308            }
309        }
310    }
311}
312
#[cfg(feature = "jwt")]
use hidden::*;

#[cfg(feature = "jwt")]
mod hidden {
    use super::*;
    use std::ops::Deref;

    /// Cloneable handle to a JWT token service; clones share one [`JwtServiceInner`].
    // TODO: Rename to JwtParsed
    #[derive(Debug, Clone)]
    pub struct JwtService {
        pub(super) inner: Arc<JwtServiceInner>,
    }

    /// State shared by all clones of a [`JwtService`]
    #[derive(Debug)]
    pub struct JwtServiceInner {
        /// Configuration the service was created from
        pub(super) config: JWTConfig,
        /// Key used to sign tokens
        pub(super) signing_key: SigningKey<Sha256>,
        /// Most recently generated token together with its expiry
        pub(super) cache: RwLock<JwtCache>,
    }

    /// Gives direct access to the shared state's fields through the handle.
    impl Deref for JwtService {
        type Target = JwtServiceInner;

        fn deref(&self) -> &JwtServiceInner {
            &*self.inner
        }
    }
}
344
/// A generated JWT header value together with the moment it stops being valid
#[cfg(feature = "jwt")]
#[derive(Debug)]
struct JwtCache {
    /// Header value produced by the last token generation
    token: RawRequestHeaderValue,
    /// Instant after which `token` is no longer served from the cache
    expires_at: SystemTime,
}
352
#[cfg(feature = "jwt")]
impl JwtService {
    /// Builds the service from `config`.
    ///
    /// The private key is parsed and a first token is signed immediately, so an
    /// unusable key or lifetime is reported here instead of on the first request.
    ///
    /// # Errors
    /// Returns [`Error::PrivateKey`] when the key cannot be decoded and any error
    /// of [`Self::generate_token_`] when the first token cannot be produced.
    fn new(config: JWTConfig) -> Result<Self, Error> {
        let private_key = parse_private_key(config.private_key.as_bytes())?;
        let signing_key = SigningKey::<Sha256>::new(private_key);

        let (token, expires_at) = Self::generate_token_(&config, &signing_key)?;

        Ok(JwtService {
            inner: Arc::new(JwtServiceInner {
                config,
                signing_key,
                cache: RwLock::new(JwtCache { token, expires_at }),
            }),
        })
    }

    /// Signs a fresh RS256 JWT and returns it as a `Bearer …` header value together
    /// with the instant at which it expires.
    ///
    /// `iss` and `sub` are both the client email, `aud` is `JWT_AUDIENCE`, and the
    /// lifetime falls back to `DEFAULT_TOKEN_LIFETIME` when the config has none.
    ///
    /// # Errors
    /// * [`Error::InvalidLifetime`] when the lifetime is shorter than one second
    ///   (the `exp` claim has second resolution, so such a token would expire the
    ///   moment it is issued) or so long that the expiry cannot be represented.
    /// * [`Error::SystemTime`] when the system clock is before the Unix epoch.
    /// * [`Error::InvalidHeader`] when the token is not a valid header value.
    fn generate_token_(
        config: &JWTConfig,
        signing_key: &SigningKey<Sha256>,
    ) -> Result<(RawRequestHeaderValue, SystemTime), Error> {
        #[derive(Serialize)]
        struct JwtHeader<'a> {
            alg: &'static str,
            typ: &'static str,
            kid: &'a str,
        }

        #[derive(Serialize)]
        struct JwtClaims<'a> {
            iss: &'a str,
            sub: &'a str,
            aud: &'static str,
            exp: u64,
            iat: u64,
        }

        let lifetime = config.lifetime.unwrap_or(DEFAULT_TOKEN_LIFETIME);
        if lifetime.as_secs() == 0 {
            return Err(Error::InvalidLifetime);
        }

        let now = SystemTime::now();
        // FIXME: Have fallback?
        let iat = now.duration_since(SystemTime::UNIX_EPOCH)?.as_secs();
        // Checked arithmetic: an absurdly large lifetime must surface as an error,
        // not as a wrapped `exp` claim or a panic in `SystemTime` addition.
        let exp = iat
            .checked_add(lifetime.as_secs())
            .ok_or(Error::InvalidLifetime)?;
        let expires_at = now.checked_add(lifetime).ok_or(Error::InvalidLifetime)?;

        let header = JwtHeader {
            alg: "RS256",
            typ: "JWT",
            kid: &config.private_key_id,
        };
        let claims = JwtClaims {
            iss: &config.client_email,
            sub: &config.client_email,
            aud: JWT_AUDIENCE,
            exp,
            iat,
        };

        let header_json =
            serde_json::to_vec(&header).expect("a struct of plain strings always serializes");
        let claims_json = serde_json::to_vec(&claims)
            .expect("a struct of plain strings and integers always serializes");
        let encoded_header = URL_SAFE_NO_PAD.encode(header_json);
        let encoded_claims = URL_SAFE_NO_PAD.encode(claims_json);
        let message = format!("{encoded_header}.{encoded_claims}");

        let signature = signing_key
            .sign_with_rng(&mut rand::thread_rng(), message.as_bytes())
            .to_bytes();
        let encoded_sig = URL_SAFE_NO_PAD.encode(signature);

        let jwt_token = format!("Bearer {message}.{encoded_sig}")
            .parse()
            .map_err(|_| Error::InvalidHeader)?;

        Ok((jwt_token, expires_at))
    }

    /// Generates a new signed JWT token for an already constructed service.
    ///
    /// # Panics
    /// Panics when generation fails. `new` has already generated one token from
    /// the same config and key, so this is treated as an invariant violation.
    fn generate_token_infallibly(&self) -> (RawRequestHeaderValue, SystemTime) {
        Self::generate_token_(&self.config, &self.signing_key).expect(
            "token generation succeeded when the service was built and must keep succeeding",
        )
    }

    /// Returns the cached token while it is still valid, otherwise signs a new one.
    async fn get_token(&self) -> RawRequestHeaderValue {
        // Fast path: check cache with read lock
        {
            let cache = self.cache.read().await;
            if SystemTime::now() < cache.expires_at {
                return cache.token.clone();
            }
        }

        // Slow path: regenerate token with write lock
        let mut cache = self.cache.write().await;

        // Another task may have refreshed the token while this one waited for the
        // write lock; re-checking avoids every waiter signing its own token.
        if SystemTime::now() < cache.expires_at {
            return cache.token.clone();
        }

        let (token, expires_at) = self.generate_token_infallibly();
        *cache = JwtCache { token, expires_at };

        cache.token.clone()
    }
}
461
/// Parses an RSA private key given either as PEM text or as raw DER bytes.
///
/// The DER payload (taken from inside the PEM wrapper when there is one) is
/// decoded as PKCS#8 first and as PKCS#1 second.
///
/// # Errors
/// * [`PrivateKeyError::Pem`] when the input starts with a PEM `-----BEGIN`
///   marker but the PEM structure is malformed. Such input cannot be DER, so the
///   PEM error is the useful one to report.
/// * [`PrivateKeyError::Pkcs1`] when the payload is neither valid PKCS#8 nor
///   valid PKCS#1 (the error of the last attempt is reported).
#[cfg(feature = "jwt")]
fn parse_private_key(bytes: &[u8]) -> Result<RsaPrivateKey, PrivateKeyError> {
    const PEM_MARKER: &[u8] = b"-----BEGIN";

    /// Decodes a DER payload, trying PKCS#8 before PKCS#1.
    fn decode_der(der: &[u8]) -> Result<RsaPrivateKey, PrivateKeyError> {
        RsaPrivateKey::from_pkcs8_der(der)
            .or_else(|_| RsaPrivateKey::from_pkcs1_der(der))
            .map_err(Into::into)
    }

    match pem::parse(bytes) {
        // Try PEM format first
        Ok(pem) => decode_der(pem.contents()),
        Err(pem_error) => {
            let looks_like_pem = bytes
                .iter()
                .skip_while(|byte| byte.is_ascii_whitespace())
                .take(PEM_MARKER.len())
                .eq(PEM_MARKER.iter());

            if looks_like_pem {
                Err(pem_error.into())
            } else {
                // Fallback to DER format
                decode_der(bytes)
            }
        }
    }
}
476}