nblm_core/
auth.rs

1use std::path::Path;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4use std::{env, fs};
5
6use async_trait::async_trait;
7use chrono::{Duration as ChronoDuration, Utc};
8use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use tokio::process::Command;
12use tokio::sync::Mutex;
13
14use crate::error::{Error, Result};
15
16#[async_trait]
17pub trait TokenProvider: Send + Sync {
18    async fn access_token(&self) -> Result<String>;
19    async fn refresh_token(&self) -> Result<String> {
20        self.access_token().await
21    }
22}
23
/// Token provider that obtains access tokens by running a `gcloud` binary.
///
/// The binary is invoked as `<binary> auth print-access-token`; see the
/// [`TokenProvider`] implementation for details.
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program name or path handed to the process spawner.
    binary: String,
}

impl GcloudTokenProvider {
    /// Creates a provider that runs `binary` (a program name or a path).
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Creates a provider that runs the program named `gcloud`.
    ///
    /// A derived `Default` would leave the binary name empty, producing a
    /// provider whose every token request fails when the process is spawned.
    fn default() -> Self {
        Self::new("gcloud")
    }
}
36
37#[async_trait]
38impl TokenProvider for GcloudTokenProvider {
39    async fn access_token(&self) -> Result<String> {
40        let output = Command::new(&self.binary)
41            .arg("auth")
42            .arg("print-access-token")
43            .output()
44            .await
45            .map_err(|err| Error::TokenProvider(err.to_string()))?;
46
47        if !output.status.success() {
48            return Err(Error::TokenProvider(format!(
49                "gcloud exited with status {}",
50                output.status
51            )));
52        }
53
54        let token = String::from_utf8(output.stdout)
55            .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;
56
57        Ok(token.trim().to_owned())
58    }
59}
60
/// Token provider that reads the token from an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable holding the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider that looks up the environment variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key: String = key.into();
        Self { key }
    }
}
71
72#[async_trait]
73impl TokenProvider for EnvTokenProvider {
74    async fn access_token(&self) -> Result<String> {
75        env::var(&self.key)
76            .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
77    }
78}
79
/// Token provider that always hands out the token it was constructed with.
#[derive(Clone)]
pub struct StaticTokenProvider {
    /// The fixed bearer token.
    token: String,
}

impl StaticTokenProvider {
    /// Creates a provider that returns `token` on every request.
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
        }
    }
}

// Implemented by hand rather than derived so that debug-printing a provider
// (for example in logs or panic messages) never reveals the credential.
impl std::fmt::Debug for StaticTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticTokenProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}
92
93#[async_trait]
94impl TokenProvider for StaticTokenProvider {
95    async fn access_token(&self) -> Result<String> {
96        Ok(self.token.clone())
97    }
98}
99
100#[derive(Debug, Clone, Deserialize)]
101struct ServiceAccountKey {
102    #[serde(rename = "client_email")]
103    client_email: String,
104    #[serde(rename = "private_key")]
105    private_key: String,
106    #[serde(rename = "token_uri")]
107    token_uri: String,
108}
109
/// An access token together with the instant after which it must not be
/// served from the cache any more.
#[derive(Clone)]
struct CachedToken {
    /// The bearer token as returned by the token endpoint.
    token: String,
    /// Point in time from which the cached token counts as expired.
    expires_at: Instant,
}

// Implemented by hand rather than derived so that debug-printing the cache
// never reveals a live access token.
impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
115
116#[derive(Debug, Clone)]
117pub struct ServiceAccountTokenProvider {
118    key: ServiceAccountKey,
119    scopes: Vec<String>,
120    cache: Arc<Mutex<Option<CachedToken>>>,
121    client: Client,
122    leeway: Duration,
123    http_timeout: Duration,
124}
125
126impl ServiceAccountTokenProvider {
127    pub fn from_file(path: impl AsRef<Path>, scopes: Vec<String>) -> Result<Self> {
128        let data = fs::read_to_string(path).map_err(|err| Error::TokenProvider(err.to_string()))?;
129        Self::from_json(&data, scopes)
130    }
131
132    pub fn from_json(data: &str, scopes: Vec<String>) -> Result<Self> {
133        let key: ServiceAccountKey = serde_json::from_str(data).map_err(|err| {
134            Error::TokenProvider(format!("failed to parse service account key: {err}"))
135        })?;
136        let client = Client::builder()
137            .build()
138            .map_err(|err| Error::TokenProvider(format!("failed to build HTTP client: {err}")))?;
139        Ok(Self {
140            key,
141            scopes,
142            cache: Arc::new(Mutex::new(None)),
143            client,
144            leeway: Duration::from_secs(60),
145            http_timeout: Duration::from_secs(10),
146        })
147    }
148
149    pub fn with_leeway(mut self, leeway: Duration) -> Self {
150        self.leeway = leeway;
151        self
152    }
153
154    pub fn with_http_timeout(mut self, timeout: Duration) -> Self {
155        self.http_timeout = timeout;
156        self
157    }
158
159    async fn cached_token(&self) -> Option<String> {
160        let cache = self.cache.lock().await;
161        cache
162            .as_ref()
163            .filter(|cached| Instant::now() < cached.expires_at)
164            .map(|cached| cached.token.clone())
165    }
166
167    async fn store_token(&self, token: String, expires_in: i64) {
168        let valid_for = Duration::from_secs(expires_in.max(0) as u64);
169        let now = Instant::now();
170        let expires_at = now + valid_for;
171        let expires_at = expires_at.checked_sub(self.leeway).unwrap_or(now);
172        let mut cache = self.cache.lock().await;
173        *cache = Some(CachedToken { token, expires_at });
174    }
175
176    fn create_jwt(&self) -> Result<String> {
177        #[derive(Serialize)]
178        struct Claims<'a> {
179            iss: &'a str,
180            scope: String,
181            aud: &'a str,
182            exp: i64,
183            iat: i64,
184        }
185
186        let now = Utc::now();
187        let exp = now + ChronoDuration::seconds(3600);
188        let claims = Claims {
189            iss: &self.key.client_email,
190            scope: self.scopes.join(" "),
191            aud: &self.key.token_uri,
192            exp: exp.timestamp(),
193            iat: now.timestamp(),
194        };
195
196        let header = Header::new(Algorithm::RS256);
197        encode(
198            &header,
199            &claims,
200            &EncodingKey::from_rsa_pem(self.key.private_key.as_bytes())
201                .map_err(|err| Error::TokenProvider(err.to_string()))?,
202        )
203        .map_err(|err| Error::TokenProvider(err.to_string()))
204    }
205}
206
207#[derive(Debug, Deserialize)]
208struct TokenResponse {
209    access_token: String,
210    expires_in: i64,
211}
212
213#[derive(Serialize)]
214struct TokenRequest<'a> {
215    grant_type: &'a str,
216    assertion: &'a str,
217}
218
219#[async_trait]
220impl TokenProvider for ServiceAccountTokenProvider {
221    async fn access_token(&self) -> Result<String> {
222        if let Some(token) = self.cached_token().await {
223            return Ok(token);
224        }
225
226        let assertion = self.create_jwt()?;
227        let body = TokenRequest {
228            grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
229            assertion: &assertion,
230        };
231
232        let response = self
233            .client
234            .post(&self.key.token_uri)
235            .timeout(self.http_timeout)
236            .form(&body)
237            .send()
238            .await
239            .map_err(|err| Error::TokenProvider(err.to_string()))?;
240
241        if !response.status().is_success() {
242            let status = response.status();
243            let text = response
244                .text()
245                .await
246                .unwrap_or_else(|_| "<failed to read body>".to_string());
247            return Err(Error::TokenProvider(format!(
248                "token endpoint error {}: {}",
249                status, text
250            )));
251        }
252
253        let token_response: TokenResponse = response
254            .json()
255            .await
256            .map_err(|err| Error::TokenProvider(format!("invalid token response: {err}")))?;
257
258        self.store_token(
259            token_response.access_token.clone(),
260            token_response.expires_in,
261        )
262        .await;
263        Ok(token_response.access_token)
264    }
265
266    async fn refresh_token(&self) -> Result<String> {
267        {
268            let mut cache = self.cache.lock().await;
269            *cache = None;
270        }
271        self.access_token().await
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal key document: enough to deserialize. The private key is a
    // placeholder, so these tests never sign a JWT.
    const TEST_KEY: &str = r#"{
        "client_email": "test@example.com",
        "private_key": "-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n",
        "token_uri": "https://example.com"
    }"#;

    /// Builds a service-account provider from `TEST_KEY` with a single scope.
    fn provider() -> ServiceAccountTokenProvider {
        let scopes = vec!["scope".to_string()];
        ServiceAccountTokenProvider::from_json(TEST_KEY, scopes).unwrap()
    }

    // A leeway longer than the token lifetime must leave nothing usable.
    #[tokio::test]
    async fn service_account_cache_respects_leeway() {
        let sa = provider().with_leeway(Duration::from_secs(90));
        sa.store_token("token".to_string(), 60).await;
        let cached = sa.cached_token().await;
        assert!(cached.is_none());
    }

    // With the default leeway, a 120-second token is still served.
    #[tokio::test]
    async fn service_account_cache_keeps_valid_token() {
        let sa = provider();
        sa.store_token("token".to_string(), 120).await;
        let cached = sa.cached_token().await;
        assert_eq!(cached, Some("token".to_string()));
    }

    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let fixed = StaticTokenProvider::new("test-token-123");
        let token = fixed.access_token().await.unwrap();
        assert_eq!(token, "test-token-123");
    }

    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        std::env::set_var("TEST_NBLM_TOKEN", "env-token-456");
        let from_env = EnvTokenProvider::new("TEST_NBLM_TOKEN");
        let token = from_env.access_token().await.unwrap();
        assert_eq!(token, "env-token-456");
        std::env::remove_var("TEST_NBLM_TOKEN");
    }

    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        std::env::remove_var("NONEXISTENT_TOKEN");
        let from_env = EnvTokenProvider::new("NONEXISTENT_TOKEN");
        let outcome = from_env.access_token().await;
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("environment variable NONEXISTENT_TOKEN missing"));
    }
}
330}