nblm_core/
auth.rs

1use std::env;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::process::Command;
7
8use crate::error::{Error, Result};
9
/// Identifies which mechanism supplied an access token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Token produced by running the gcloud CLI (`auth print-access-token`).
    GcloudOauth,
    /// Token read from an environment variable.
    EnvAccessToken,
    /// Token supplied verbatim by the caller; also the kind reported by providers
    /// that do not override `TokenProvider::kind`.
    StaticToken,
    /// User-OAuth based provider (no implementation in this file); the only kind
    /// reported as experimental.
    UserOauth,
}

impl ProviderKind {
    /// Returns the stable kebab-case label for this kind.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::GcloudOauth => "gcloud-oauth",
            Self::EnvAccessToken => "env-access-token",
            Self::StaticToken => "static-token",
            Self::UserOauth => "user-oauth",
        }
    }

    /// Returns `true` only for [`ProviderKind::UserOauth`].
    pub fn is_experimental(&self) -> bool {
        match *self {
            Self::UserOauth => true,
            Self::GcloudOauth | Self::EnvAccessToken | Self::StaticToken => false,
        }
    }
}
32
33#[async_trait]
34pub trait TokenProvider: Send + Sync {
35    async fn access_token(&self) -> Result<String>;
36    async fn refresh_token(&self) -> Result<String> {
37        self.access_token().await
38    }
39
40    fn kind(&self) -> ProviderKind {
41        ProviderKind::StaticToken
42    }
43}
44
/// Default Google OAuth2 tokeninfo URL, queried to learn which scopes a token grants.
const TOKENINFO_ENDPOINT: &'static str = "https://www.googleapis.com/oauth2/v3/tokeninfo";
/// Google Drive scope accepted as granting Drive access.
const DRIVE_SCOPE: &'static str = "https://www.googleapis.com/auth/drive";
/// Google Drive per-file scope, also accepted as granting Drive access.
const DRIVE_FILE_SCOPE: &'static str = "https://www.googleapis.com/auth/drive.file";
48
49#[derive(Debug, Deserialize)]
50struct TokenInfoResponse {
51    scope: Option<String>,
52}
53
54pub async fn ensure_drive_scope(provider: &dyn TokenProvider) -> Result<()> {
55    let client = Client::new();
56    let endpoint =
57        std::env::var("NBLM_TOKENINFO_ENDPOINT").unwrap_or_else(|_| TOKENINFO_ENDPOINT.to_string());
58    ensure_drive_scope_internal(provider, &client, &endpoint).await
59}
60
61async fn ensure_drive_scope_internal(
62    provider: &dyn TokenProvider,
63    client: &Client,
64    endpoint: &str,
65) -> Result<()> {
66    let access_token = provider.access_token().await?;
67
68    let response = client
69        .get(endpoint)
70        .query(&[("access_token", access_token.as_str())])
71        .send()
72        .await
73        .map_err(|err| {
74            Error::TokenProvider(format!("failed to validate Google Drive token: {err}"))
75        })?;
76
77    if !response.status().is_success() {
78        let status = response.status();
79        let body = response
80            .text()
81            .await
82            .unwrap_or_else(|_| String::from("<failed to read body>"));
83        return Err(Error::TokenProvider(format!(
84            "failed to validate Google Drive token (status {}): {}",
85            status.as_u16(),
86            body.trim()
87        )));
88    }
89
90    let info: TokenInfoResponse = response
91        .json()
92        .await
93        .map_err(|err| Error::TokenProvider(format!("invalid tokeninfo response: {err}")))?;
94
95    let scopes = info.scope.unwrap_or_default();
96    if scope_grants_drive_access(&scopes) {
97        Ok(())
98    } else {
99        Err(Error::TokenProvider(
100            "Google Drive access token is missing the required drive.file scope. Run `gcloud auth login --enable-gdrive-access` and retry.".to_string(),
101        ))
102    }
103}
104
105fn scope_grants_drive_access(scopes: &str) -> bool {
106    scopes
107        .split_whitespace()
108        .any(|scope| scope == DRIVE_FILE_SCOPE || scope == DRIVE_SCOPE)
109}
110
/// Test-only variant of the drive-scope check.
///
/// The caller supplies the HTTP `client` and tokeninfo `endpoint` directly, bypassing
/// the `NBLM_TOKENINFO_ENDPOINT` lookup (e.g. to point at a mock server).
#[cfg(test)]
pub(crate) async fn ensure_drive_scope_with_endpoint(
    provider: &dyn TokenProvider,
    client: &Client,
    endpoint: &str,
) -> Result<()> {
    ensure_drive_scope_internal(provider, client, endpoint).await
}
119
/// Executable name used by `GcloudTokenProvider::default()`.
const DEFAULT_GCLOUD_BINARY: &str = "gcloud";

/// Token provider that shells out to the gcloud CLI (`<binary> auth print-access-token`).
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Command name or path of the gcloud executable to invoke.
    binary: String,
}

impl GcloudTokenProvider {
    /// Creates a provider that invokes `binary` (a command name or a path).
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Uses the `gcloud` executable.
    ///
    /// A derived `Default` would produce an empty binary name, which can never be
    /// spawned.
    fn default() -> Self {
        Self::new(DEFAULT_GCLOUD_BINARY)
    }
}
132
133#[async_trait]
134impl TokenProvider for GcloudTokenProvider {
135    async fn access_token(&self) -> Result<String> {
136        let output = Command::new(&self.binary)
137            .arg("auth")
138            .arg("print-access-token")
139            .output()
140            .await
141            .map_err(|err| {
142                Error::TokenProvider(format!(
143                    "Failed to execute gcloud command. Make sure gcloud CLI is installed and in PATH.\nError: {}",
144                    err
145                ))
146            })?;
147
148        if !output.status.success() {
149            let stderr = String::from_utf8_lossy(&output.stderr);
150            return Err(Error::TokenProvider(format!(
151                "Failed to get access token from gcloud. Please run 'gcloud auth login' to authenticate.\nError: {}",
152                stderr.trim()
153            )));
154        }
155
156        let token = String::from_utf8(output.stdout)
157            .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;
158
159        Ok(token.trim().to_owned())
160    }
161
162    fn kind(&self) -> ProviderKind {
163        ProviderKind::GcloudOauth
164    }
165}
166
/// Token provider that reads the access token from an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable that holds the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider that reads the variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key = key.into();
        Self { key }
    }
}
177
178#[async_trait]
179impl TokenProvider for EnvTokenProvider {
180    async fn access_token(&self) -> Result<String> {
181        env::var(&self.key)
182            .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
183    }
184
185    fn kind(&self) -> ProviderKind {
186        ProviderKind::EnvAccessToken
187    }
188}
189
/// Token provider that always returns the same caller-supplied token.
///
/// `Debug` is implemented by hand so the secret never appears in debug output or logs.
#[derive(Clone)]
pub struct StaticTokenProvider {
    /// The bearer token returned by every call.
    token: String,
}

impl StaticTokenProvider {
    /// Creates a provider that always yields `token`.
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
        }
    }
}

impl std::fmt::Debug for StaticTokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticTokenProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}
202
203#[async_trait]
204impl TokenProvider for StaticTokenProvider {
205    async fn access_token(&self) -> Result<String> {
206        Ok(self.token.clone())
207    }
208
209    fn kind(&self) -> ProviderKind {
210        ProviderKind::StaticToken
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path at which the mock tokeninfo endpoint is mounted.
    const TOKENINFO_PATH: &str = "/oauth2/v3/tokeninfo";

    /// Starts a mock server that answers `GET TOKENINFO_PATH?access_token=<token>`
    /// with `response`.
    async fn mock_tokeninfo(token: &str, response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(TOKENINFO_PATH))
            .and(query_param("access_token", token))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    /// Runs the drive-scope check for a static `token` against `server`.
    async fn check_scope(server: &MockServer, token: &str) -> Result<()> {
        let provider = StaticTokenProvider::new(token);
        let client = reqwest::Client::new();
        let endpoint = format!("{}{}", server.uri(), TOKENINFO_PATH);
        ensure_drive_scope_with_endpoint(&provider, &client, &endpoint).await
    }

    /// Returns the message of a `TokenProvider` error, panicking on any other variant.
    fn token_provider_message(err: Error) -> String {
        match err {
            Error::TokenProvider(message) => message,
            _ => panic!("expected TokenProvider error"),
        }
    }

    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let token = StaticTokenProvider::new("test-token-123")
            .access_token()
            .await
            .unwrap();
        assert_eq!(token, "test-token-123");
    }

    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        const VAR: &str = "TEST_NBLM_TOKEN";
        std::env::set_var(VAR, "env-token-456");
        let result = EnvTokenProvider::new(VAR).access_token().await;
        // Clean up before asserting so a failure does not leak the variable.
        std::env::remove_var(VAR);
        assert_eq!(result.unwrap(), "env-token-456");
    }

    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        std::env::remove_var("NONEXISTENT_TOKEN");
        let err = EnvTokenProvider::new("NONEXISTENT_TOKEN")
            .access_token()
            .await
            .expect_err("missing variable should be an error");
        assert!(err
            .to_string()
            .contains("environment variable NONEXISTENT_TOKEN missing"));
    }

    #[test]
    fn provider_kind_as_str_returns_correct_labels() {
        let cases = [
            (ProviderKind::GcloudOauth, "gcloud-oauth"),
            (ProviderKind::EnvAccessToken, "env-access-token"),
            (ProviderKind::StaticToken, "static-token"),
            (ProviderKind::UserOauth, "user-oauth"),
        ];
        for (kind, label) in cases {
            assert_eq!(kind.as_str(), label);
        }
    }

    #[test]
    fn provider_kind_is_experimental_only_for_user_oauth() {
        let all = [
            ProviderKind::GcloudOauth,
            ProviderKind::EnvAccessToken,
            ProviderKind::StaticToken,
            ProviderKind::UserOauth,
        ];
        for kind in all {
            assert_eq!(kind.is_experimental(), kind == ProviderKind::UserOauth);
        }
    }

    #[test]
    fn gcloud_token_provider_returns_correct_kind() {
        assert_eq!(
            GcloudTokenProvider::new("gcloud").kind(),
            ProviderKind::GcloudOauth
        );
    }

    #[test]
    fn env_token_provider_returns_correct_kind() {
        assert_eq!(
            EnvTokenProvider::new("TEST_TOKEN").kind(),
            ProviderKind::EnvAccessToken
        );
    }

    #[test]
    fn static_token_provider_returns_correct_kind() {
        assert_eq!(
            StaticTokenProvider::new("token").kind(),
            ProviderKind::StaticToken
        );
    }

    #[test]
    fn scope_grants_drive_access_detects_required_scopes() {
        let with_calendar = format!("{DRIVE_FILE_SCOPE} https://www.googleapis.com/auth/calendar");
        let cases: [(&str, bool); 4] = [
            (DRIVE_FILE_SCOPE, true),
            (DRIVE_SCOPE, true),
            ("https://www.googleapis.com/auth/spreadsheets.readonly", false),
            (with_calendar.as_str(), true),
        ];
        for (scopes, expected) in cases {
            assert_eq!(scope_grants_drive_access(scopes), expected, "scopes: {scopes}");
        }
    }

    #[tokio::test]
    async fn ensure_drive_scope_accepts_valid_scope() {
        let body = serde_json::json!({ "scope": DRIVE_FILE_SCOPE });
        let server =
            mock_tokeninfo("valid-token", ResponseTemplate::new(200).set_body_json(body)).await;
        assert!(check_scope(&server, "valid-token").await.is_ok());
    }

    #[tokio::test]
    async fn ensure_drive_scope_rejects_missing_scope() {
        let body = serde_json::json!({
            "scope": "https://www.googleapis.com/auth/spreadsheets.readonly"
        });
        let server =
            mock_tokeninfo("no-scope", ResponseTemplate::new(200).set_body_json(body)).await;
        let err = check_scope(&server, "no-scope").await.unwrap_err();
        assert!(token_provider_message(err).contains("drive.file scope"));
    }

    #[tokio::test]
    async fn ensure_drive_scope_converts_http_failures() {
        let response = ResponseTemplate::new(400).set_body_string("invalid_token");
        let server = mock_tokeninfo("bad-token", response).await;
        let err = check_scope(&server, "bad-token").await.unwrap_err();
        assert!(token_provider_message(err).contains("status 400"));
    }
}