1use serde::Serialize;
2use thiserror::Error;
3use tokio::stream;
4use tonic::RawRequestHeaderValue;
5
6#[cfg(feature = "jwt")]
7use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
8#[cfg(feature = "jwt")]
9use rsa::{
10 pkcs1::{DecodeRsaPrivateKey, Error as Pkcs1Error},
11 pkcs1v15::SigningKey,
12 pkcs8::{DecodePrivateKey, Error as Pkcs8Error},
13 signature::{RandomizedSigner, SignatureEncoding},
14 RsaPrivateKey,
15};
16#[cfg(feature = "jwt")]
17use serde::Deserialize;
18#[cfg(feature = "jwt")]
19use serde_json::json;
20#[cfg(feature = "jwt")]
21use sha2::Sha256;
22#[cfg(feature = "jwt")]
23use std::{
24 path::Path,
25 sync::Arc,
26 time::{Duration, SystemTime, SystemTimeError},
27};
28#[cfg(feature = "jwt")]
29use tokio::sync::RwLock;
30
/// Credentials attached to outgoing requests.
#[derive(Clone)]
pub enum Auth {
    /// A Google API key, sent in the `x-goog-api-key` request header.
    ApiKey(String),

    /// A service-account token source (only with the `jwt` feature).
    #[cfg(feature = "jwt")]
    TokenSource(TokenSource),
}

// `Debug` is written by hand so the API key is never printed (e.g. through a
// `{:?}` in logs or error reports); only the variant name is shown.
impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Auth::ApiKey(_) => f.debug_tuple("ApiKey").field(&"<redacted>").finish(),
            #[cfg(feature = "jwt")]
            Auth::TokenSource(source) => f.debug_tuple("TokenSource").field(source).finish(),
        }
    }
}
40
41impl<S: Into<String>> From<S> for Auth {
42 fn from(value: S) -> Self {
43 Auth::ApiKey(value.into())
44 }
45}
46
/// Service-account credentials used to sign JWTs.
///
/// Deserializes from a service-account key JSON document. Only the fields
/// below are read; other keys in the document are ignored (serde's default).
#[cfg(feature = "jwt")]
#[derive(Deserialize, Clone)]
pub struct JWTConfig {
    /// Service-account email; used as both the `iss` and `sub` claims.
    pub client_email: String,

    /// RSA private key (PEM, PKCS#8 or PKCS#1) used to sign tokens.
    pub private_key: String,

    /// Identifier of `private_key`; sent as the JWT `kid` header.
    pub private_key_id: String,

    /// Token lifetime. Never read from JSON; `None` selects the crate default
    /// of one hour.
    #[serde(skip)]
    pub lifetime: Option<Duration>,
}

// `Debug` is written by hand so the private key never ends up in logs.
#[cfg(feature = "jwt")]
impl std::fmt::Debug for JWTConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JWTConfig")
            .field("client_email", &self.client_email)
            .field("private_key", &"<redacted>")
            .field("private_key_id", &self.private_key_id)
            .field("lifetime", &self.lifetime)
            .finish()
    }
}
67
/// Where service-account bearer tokens come from.
#[derive(Clone, Debug)]
#[cfg(feature = "jwt")]
pub enum TokenSource {
    /// Legacy form holding an already-built signer; produced by the deprecated
    /// `TokenSource::from_jwt` / `from_service_account` constructors.
    #[deprecated(note = "use `TokenSource::JWT`(all-caps) instead")]
    Jwt { jwt: Box<JwtService> },

    /// Raw configuration; the key is parsed later, by `Auth::parsed`.
    JWT(JWTConfig),
}
79
80#[derive(Debug, Error)]
82pub enum Error {
83 #[error("Token generation failed: {0}")]
84 TokenGeneration(String),
85
86 #[error("Invalid header value")]
87 InvalidHeader,
88
89 #[cfg(feature = "jwt")]
90 #[error("I/O error: {0}")]
91 Io(#[from] std::io::Error),
92
93 #[cfg(feature = "jwt")]
94 #[error("JSON parsing error: {0}")]
95 Json(#[from] serde_json::Error),
96
97 #[error("Private key parsing failed: {0}")]
98 #[cfg(feature = "jwt")]
99 PrivateKey(#[from] PrivateKeyError),
100
101 #[error("System time error: {0}")]
102 #[cfg(feature = "jwt")]
103 SystemTime(#[from] SystemTimeError),
104
105 #[error("Invalid token lifetime")]
106 #[cfg(feature = "jwt")]
107 InvalidLifetime,
108}
109
/// Reasons an RSA private key could not be decoded.
#[derive(Debug, Error)]
#[cfg(feature = "jwt")]
pub enum PrivateKeyError {
    /// Decoding as PKCS#1 (`RSA PRIVATE KEY`) failed.
    #[error("PKCS#1 parsing error: {0}")]
    Pkcs1(#[from] Pkcs1Error),

    /// Decoding as PKCS#8 (`PRIVATE KEY`) failed.
    #[error("PKCS#8 parsing error: {0}")]
    Pkcs8(#[from] Pkcs8Error),

    /// The PEM framing around the key was malformed.
    #[error("PEM format error: {0}")]
    Pem(#[from] pem::PemError),
}
123
/// Token lifetime used when `JWTConfig::lifetime` is `None` (one hour).
#[cfg(feature = "jwt")]
const DEFAULT_TOKEN_LIFETIME: Duration = Duration::from_secs(60 * 60);

/// Value of the `aud` claim in every signed token.
#[cfg(feature = "jwt")]
const JWT_AUDIENCE: &str = "https://generativelanguage.googleapis.com/";

/// Header that carries the `Bearer` token for service-account auth.
#[cfg(feature = "jwt")]
const JWT_HEADER: &str = "authorization";

/// Header that carries the key for API-key auth.
const API_KEY_HEADER: &str = "x-goog-api-key";
132
133impl Auth {
134 pub fn new(api_key: &str) -> Self {
136 Self::ApiKey(api_key.to_owned())
137 }
138
139 #[deprecated(note = "use `Auth::service` instead")]
153 #[allow(deprecated)]
154 #[cfg(feature = "jwt")]
155 pub async fn service_account<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
156 Ok(Self::TokenSource(
157 TokenSource::from_service_account(path).await?,
158 ))
159 }
160
161 #[cfg(feature = "jwt")]
174 pub async fn service<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
175 Ok(Self::TokenSource(TokenSource::service(path).await?))
176 }
177
178 #[deprecated(note = "use `Auth::jwt` instead")]
180 #[allow(deprecated)]
181 #[cfg(feature = "jwt")]
182 pub async fn from_jwt_config(config: JWTConfig) -> Result<Self, Error> {
183 Ok(Self::TokenSource(TokenSource::from_jwt(config).await?))
184 }
185
186 #[cfg(feature = "jwt")]
188 pub fn jwt(config: JWTConfig) -> Self {
189 Self::TokenSource(TokenSource::jwt(config))
190 }
191}
192
#[cfg(feature = "jwt")]
impl TokenSource {
    /// Reads a service-account JSON file and eagerly builds a signer from it.
    ///
    /// # Errors
    ///
    /// Propagates file, JSON, key-parsing and token-generation errors.
    #[deprecated(note = "use `TokenSource::service` instead")]
    // Calls the equally deprecated `from_jwt`.
    #[allow(deprecated)]
    pub async fn from_service_account<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let raw = tokio::fs::read(path).await?;
        Self::from_jwt(serde_json::from_slice(&raw)?).await
    }

    /// Reads a service-account JSON file into [`TokenSource::JWT`] without
    /// parsing the key.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] or [`Error::Json`].
    pub async fn service<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
        let raw = tokio::fs::read(path).await?;
        serde_json::from_slice::<JWTConfig>(&raw)
            .map(Self::jwt)
            .map_err(Error::from)
    }

    /// Parses `config` immediately and stores the resulting signer.
    ///
    /// # Errors
    ///
    /// Propagates key-parsing and token-generation errors.
    #[deprecated(note = "use `TokenSource::jwt` instead")]
    // Constructs the deprecated `Jwt` variant.
    #[allow(deprecated)]
    pub async fn from_jwt(config: JWTConfig) -> Result<Self, Error> {
        let jwt = config.parsed()?;
        Ok(Self::Jwt { jwt: Box::new(jwt) })
    }

    /// Wraps `config` as-is; parsing is deferred.
    pub fn jwt(config: JWTConfig) -> Self {
        Self::JWT(config)
    }
}
227
228#[derive(Debug)]
229pub(crate) enum AuthParsed {
230 ApiKey(RawRequestHeaderValue),
231 #[cfg(feature = "jwt")]
232 JwtKind(JwtService),
233}
234
235impl Auth {
236 pub(crate) fn parsed(self) -> Result<AuthParsed, Error> {
237 match self {
238 Auth::ApiKey(api_key) => Ok(AuthParsed::ApiKey(
240 api_key.parse().map_err(|_| Error::InvalidHeader)?,
241 )),
242 #[cfg(feature = "jwt")]
243 Auth::TokenSource(token_source) => match token_source {
244 #[allow(deprecated)]
245 TokenSource::Jwt { jwt } => Ok(AuthParsed::JwtKind(*jwt)),
247 TokenSource::JWT(jwtconfig) => jwtconfig.parsed().map(AuthParsed::JwtKind),
248 },
249 }
250 }
251}
252
#[cfg(feature = "jwt")]
impl JWTConfig {
    /// Consumes the config and builds a `JwtService` from it.
    ///
    /// # Errors
    ///
    /// Propagates any error from `JwtService::new`.
    fn parsed(self) -> Result<JwtService, Error> {
        let service = JwtService::new(self)?;
        Ok(service)
    }
}
259
260impl AuthParsed {
261 #[cfg(not(feature = "jwt"))]
263 pub(crate) fn _into_request(self, request: &mut tonic::RawRequestHeader) {
264 match self {
265 Self::ApiKey(metadata_value) => {
266 request.insert(API_KEY_HEADER, metadata_value);
267 }
268 }
269 }
270
271 #[cfg(feature = "jwt")]
300 pub(crate) async fn to_request(&self, request: &mut tonic::RawRequestHeader) {
301 match self {
302 Self::ApiKey(metadata_value) => {
303 request.insert(API_KEY_HEADER, metadata_value.clone());
304 }
305 Self::JwtKind(jwt_service) => {
306 let token = jwt_service.get_token().await;
307 request.insert(JWT_HEADER, token);
308 }
309 }
310 }
311}
312
313#[cfg(feature = "jwt")]
314use hidden::*;
315
#[cfg(feature = "jwt")]
mod hidden {
    use std::ops::Deref;

    use super::*;

    /// Cheaply clonable handle to a shared JWT signer and its token cache.
    ///
    /// Declared `pub` inside this private module, so it can appear in public
    /// signatures while `hidden::JwtService` stays unnameable outside the
    /// parent module.
    #[derive(Clone, Debug)]
    pub struct JwtService {
        pub(super) inner: Arc<JwtServiceInner>,
    }

    /// Shared state behind [`JwtService`]: config, signing key and token cache.
    pub struct JwtServiceInner {
        pub(super) config: JWTConfig,
        pub(super) signing_key: SigningKey<Sha256>,
        pub(super) cache: RwLock<JwtCache>,
    }

    // `Debug` is written by hand: it identifies the service account but never
    // prints the raw private key, the signing key, or the cached bearer token.
    impl std::fmt::Debug for JwtServiceInner {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("JwtServiceInner")
                .field("client_email", &self.config.client_email)
                .field("private_key_id", &self.config.private_key_id)
                .finish_non_exhaustive()
        }
    }

    impl Deref for JwtService {
        type Target = JwtServiceInner;

        fn deref(&self) -> &Self::Target {
            &self.inner
        }
    }
}
344
/// Most recently signed bearer token and the instant from which `get_token`
/// stops reusing it.
#[cfg(feature = "jwt")]
struct JwtCache {
    token: RawRequestHeaderValue,
    expires_at: SystemTime,
}

// `Debug` is written by hand so the bearer token (a live credential) is never printed.
#[cfg(feature = "jwt")]
impl std::fmt::Debug for JwtCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtCache")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
352
#[cfg(feature = "jwt")]
impl JwtService {
    /// Builds a token service from `config`.
    ///
    /// The private key is parsed and a first token is signed immediately, so an
    /// unusable key or lifetime is reported here rather than on first use.
    ///
    /// # Errors
    ///
    /// Returns [`Error::PrivateKey`] when the key cannot be decoded, plus any
    /// error from `generate_token_`.
    fn new(config: JWTConfig) -> Result<Self, Error> {
        let private_key = parse_private_key(config.private_key.as_bytes())?;
        let signing_key = SigningKey::<Sha256>::new(private_key);

        let (token, expires_at) = Self::generate_token_(&config, &signing_key)?;

        Ok(JwtService {
            inner: Arc::new(JwtServiceInner {
                config,
                signing_key,
                cache: RwLock::new(JwtCache { token, expires_at }),
            }),
        })
    }

    /// Signs a fresh RS256 JWT for `config` and returns it as a
    /// `Bearer <jwt>` header value, together with the instant it expires.
    ///
    /// # Errors
    ///
    /// - [`Error::SystemTime`] if the system clock reads before the Unix epoch.
    /// - [`Error::InvalidLifetime`] if the lifetime pushes `exp` beyond what a
    ///   `u64` second count or `SystemTime` can represent.
    /// - [`Error::Json`] if the header or claims fail to serialize.
    /// - [`Error::InvalidHeader`] if the token is not a valid header value.
    fn generate_token_(
        config: &JWTConfig,
        signing_key: &SigningKey<Sha256>,
    ) -> Result<(RawRequestHeaderValue, SystemTime), Error> {
        #[derive(Serialize)]
        struct JwtHeader<'a> {
            alg: &'static str,
            typ: &'static str,
            kid: &'a str,
        }

        #[derive(Serialize)]
        struct JwtClaims<'a> {
            iss: &'a str,
            sub: &'a str,
            aud: &'static str,
            exp: u64,
            iat: u64,
        }

        let iat = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)?
            .as_secs();
        let lifetime = config.lifetime.unwrap_or(DEFAULT_TOKEN_LIFETIME);

        // Checked arithmetic: an absurd lifetime becomes an error, not a panic.
        let exp = iat
            .checked_add(lifetime.as_secs())
            .ok_or(Error::InvalidLifetime)?;
        // Derive the cache expiry from the signed `exp` claim itself (whole
        // seconds). Using `now + lifetime` would let the cache hand out a token
        // for up to ~2s after the server already considers it expired.
        let expires_at = SystemTime::UNIX_EPOCH
            .checked_add(Duration::from_secs(exp))
            .ok_or(Error::InvalidLifetime)?;

        let header = JwtHeader {
            alg: "RS256",
            typ: "JWT",
            kid: &config.private_key_id,
        };
        let claims = JwtClaims {
            iss: &config.client_email,
            sub: &config.client_email,
            aud: JWT_AUDIENCE,
            exp,
            iat,
        };

        let encoded_header = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?);
        let encoded_claims = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&claims)?);
        let message = format!("{encoded_header}.{encoded_claims}");

        let signature = signing_key
            .sign_with_rng(&mut rand::thread_rng(), message.as_bytes())
            .to_bytes();
        let encoded_sig = URL_SAFE_NO_PAD.encode(signature);

        let jwt_token = format!("Bearer {message}.{encoded_sig}")
            .parse()
            .map_err(|_| Error::InvalidHeader)?;

        Ok((jwt_token, expires_at))
    }

    /// Re-signs a token for an already-constructed service.
    ///
    /// # Panics
    ///
    /// Panics if generation fails. `new` already signed a token with the same
    /// config and key, so only clock-related failures are expected here.
    fn generate_token_infallibly(&self) -> (RawRequestHeaderValue, SystemTime) {
        Self::generate_token_(&self.config, &self.signing_key)
            .expect("JWT generation succeeded at construction and should not fail on refresh")
    }

    /// Returns the cached bearer token, signing a new one once it has expired.
    async fn get_token(&self) -> RawRequestHeaderValue {
        // Fast path: a shared lock suffices while the cached token is valid.
        {
            let cache = self.cache.read().await;
            if SystemTime::now() < cache.expires_at {
                return cache.token.clone();
            }
        }

        // Slow path: take the exclusive lock, then re-check. Another task may
        // have refreshed the token while this one waited; reuse it instead of
        // having every concurrent caller pay for its own RSA signature.
        let mut cache = self.cache.write().await;
        if SystemTime::now() >= cache.expires_at {
            let (token, expires_at) = self.generate_token_infallibly();
            *cache = JwtCache { token, expires_at };
        }
        cache.token.clone()
    }
}
461
/// Parses an RSA private key from PEM text or raw DER bytes.
///
/// PEM input is unwrapped first. The DER payload (or the raw input, when it is
/// not PEM) is then tried as PKCS#8 and, failing that, as PKCS#1. Both formats
/// are always attempted, so keys whose PEM label does not match their contents
/// are still accepted.
///
/// # Errors
///
/// When no attempt succeeds, the error most likely to explain the failure is
/// returned:
/// - PEM labelled `RSA PRIVATE KEY` → the PKCS#1 error;
/// - any other PEM label (e.g. `PRIVATE KEY`) → the PKCS#8 error;
/// - input that is not valid PEM but contains a `-----BEGIN` armour line → the
///   PEM error, since the framing itself is broken (e.g. mangled line breaks);
/// - anything else → the PKCS#1 error.
#[cfg(feature = "jwt")]
fn parse_private_key(bytes: &[u8]) -> Result<RsaPrivateKey, PrivateKeyError> {
    const PEM_ARMOR: &[u8] = b"-----BEGIN";

    match pem::parse(bytes) {
        Ok(pem) => parse_der_private_key(pem.contents()).map_err(|(pkcs8_err, pkcs1_err)| {
            if pem.tag() == "RSA PRIVATE KEY" {
                PrivateKeyError::from(pkcs1_err)
            } else {
                PrivateKeyError::from(pkcs8_err)
            }
        }),
        Err(pem_err) => parse_der_private_key(bytes).map_err(|(_, pkcs1_err)| {
            if bytes.windows(PEM_ARMOR.len()).any(|window| window == PEM_ARMOR) {
                PrivateKeyError::from(pem_err)
            } else {
                PrivateKeyError::from(pkcs1_err)
            }
        }),
    }
}

/// Decodes `der` as PKCS#8, then PKCS#1; if both fail, returns both errors.
#[cfg(feature = "jwt")]
fn parse_der_private_key(der: &[u8]) -> Result<RsaPrivateKey, (Pkcs8Error, Pkcs1Error)> {
    RsaPrivateKey::from_pkcs8_der(der).or_else(|pkcs8_err| {
        RsaPrivateKey::from_pkcs1_der(der).map_err(|pkcs1_err| (pkcs8_err, pkcs1_err))
    })
}