nblm_core/auth/
mod.rs

1use std::env;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::process::Command;
7
8use crate::error::{Error, Result};
9
10pub mod oauth;
11
/// Identifies which kind of [`TokenProvider`] produced an access token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Token obtained by running `gcloud auth print-access-token`.
    GcloudOauth,
    /// Token read from an environment variable.
    EnvAccessToken,
    /// Token supplied directly as a fixed string (also the trait's default kind).
    StaticToken,
    /// User OAuth flow (presumably backed by the `oauth` submodule — confirm); experimental.
    UserOauth,
}

impl ProviderKind {
    /// Returns the kebab-case label used to describe this provider kind.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::GcloudOauth => "gcloud-oauth",
            Self::EnvAccessToken => "env-access-token",
            Self::StaticToken => "static-token",
            Self::UserOauth => "user-oauth",
        }
    }

    /// Returns `true` when this provider kind is still experimental.
    ///
    /// Only [`ProviderKind::UserOauth`] is experimental.
    pub fn is_experimental(&self) -> bool {
        *self == Self::UserOauth
    }
}
34
35#[async_trait]
36pub trait TokenProvider: Send + Sync {
37    async fn access_token(&self) -> Result<String>;
38    async fn refresh_token(&self) -> Result<String> {
39        self.access_token().await
40    }
41
42    fn kind(&self) -> ProviderKind {
43        ProviderKind::StaticToken
44    }
45}
46
47const TOKENINFO_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/tokeninfo";
48const DRIVE_SCOPE: &str = "https://www.googleapis.com/auth/drive";
49const DRIVE_FILE_SCOPE: &str = "https://www.googleapis.com/auth/drive.file";
50
51#[derive(Debug, Deserialize)]
52struct TokenInfoResponse {
53    scope: Option<String>,
54}
55
56pub async fn ensure_drive_scope(provider: &dyn TokenProvider) -> Result<()> {
57    let client = Client::new();
58    let endpoint =
59        std::env::var("NBLM_TOKENINFO_ENDPOINT").unwrap_or_else(|_| TOKENINFO_ENDPOINT.to_string());
60    ensure_drive_scope_internal(provider, &client, &endpoint).await
61}
62
63async fn ensure_drive_scope_internal(
64    provider: &dyn TokenProvider,
65    client: &Client,
66    endpoint: &str,
67) -> Result<()> {
68    let access_token = provider.access_token().await?;
69
70    let response = client
71        .get(endpoint)
72        .query(&[("access_token", access_token.as_str())])
73        .send()
74        .await
75        .map_err(|err| {
76            Error::TokenProvider(format!("failed to validate Google Drive token: {err}"))
77        })?;
78
79    if !response.status().is_success() {
80        let status = response.status();
81        let body = response
82            .text()
83            .await
84            .unwrap_or_else(|_| String::from("<failed to read body>"));
85        return Err(Error::TokenProvider(format!(
86            "failed to validate Google Drive token (status {}): {}",
87            status.as_u16(),
88            body.trim()
89        )));
90    }
91
92    let info: TokenInfoResponse = response
93        .json()
94        .await
95        .map_err(|err| Error::TokenProvider(format!("invalid tokeninfo response: {err}")))?;
96
97    let scopes = info.scope.unwrap_or_default();
98    if scope_grants_drive_access(&scopes) {
99        Ok(())
100    } else {
101        Err(Error::TokenProvider(
102            "Google Drive access token is missing the required drive.file scope. Run `gcloud auth login --enable-gdrive-access` and retry.".to_string(),
103        ))
104    }
105}
106
107fn scope_grants_drive_access(scopes: &str) -> bool {
108    scopes
109        .split_whitespace()
110        .any(|scope| scope == DRIVE_FILE_SCOPE || scope == DRIVE_SCOPE)
111}
112
/// Test-only entry point that runs the Drive scope check against an arbitrary
/// tokeninfo `endpoint` (such as a mock server URL) with the caller's `client`.
#[cfg(test)]
pub(crate) async fn ensure_drive_scope_with_endpoint(
    provider: &dyn TokenProvider,
    client: &Client,
    endpoint: &str,
) -> Result<()> {
    ensure_drive_scope_internal(provider, client, endpoint).await
}
121
/// Token provider that asks the Google Cloud CLI for an access token
/// (`<binary> auth print-access-token`).
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program to execute; a bare name is looked up on `PATH` by the OS.
    binary: String,
}

impl GcloudTokenProvider {
    /// Creates a provider that runs `binary`, e.g. `"gcloud"` or an absolute
    /// path to the CLI executable.
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Uses the `gcloud` executable found on `PATH`.
    ///
    /// A derived `Default` would leave `binary` empty, and an empty program
    /// name can never be spawned, so the default names the CLI explicitly.
    fn default() -> Self {
        Self::new("gcloud")
    }
}
134
135#[async_trait]
136impl TokenProvider for GcloudTokenProvider {
137    async fn access_token(&self) -> Result<String> {
138        let output = Command::new(&self.binary)
139            .arg("auth")
140            .arg("print-access-token")
141            .output()
142            .await
143            .map_err(|err| {
144                Error::TokenProvider(format!(
145                    "Failed to execute gcloud command. Make sure gcloud CLI is installed and in PATH.\nError: {}",
146                    err
147                ))
148            })?;
149
150        if !output.status.success() {
151            let stderr = String::from_utf8_lossy(&output.stderr);
152            return Err(Error::TokenProvider(format!(
153                "Failed to get access token from gcloud. Please run 'gcloud auth login' to authenticate.\nError: {}",
154                stderr.trim()
155            )));
156        }
157
158        let token = String::from_utf8(output.stdout)
159            .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;
160
161        Ok(token.trim().to_owned())
162    }
163
164    fn kind(&self) -> ProviderKind {
165        ProviderKind::GcloudOauth
166    }
167}
168
/// Token provider that reads the access token from an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable holding the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider backed by the environment variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key = key.into();
        Self { key }
    }
}
178}
179
180#[async_trait]
181impl TokenProvider for EnvTokenProvider {
182    async fn access_token(&self) -> Result<String> {
183        env::var(&self.key)
184            .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
185    }
186
187    fn kind(&self) -> ProviderKind {
188        ProviderKind::EnvAccessToken
189    }
190}
191
/// Token provider that always returns the same, caller-supplied token.
#[derive(Debug, Clone)]
pub struct StaticTokenProvider {
    /// The token handed out on every request.
    token: String,
}

impl StaticTokenProvider {
    /// Creates a provider that will always return `token`.
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }
}
204
205#[async_trait]
206impl TokenProvider for StaticTokenProvider {
207    async fn access_token(&self) -> Result<String> {
208        Ok(self.token.clone())
209    }
210
211    fn kind(&self) -> ProviderKind {
212        ProviderKind::StaticToken
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path under which the mocked tokeninfo endpoint is served.
    const TOKENINFO_PATH: &str = "/oauth2/v3/tokeninfo";

    /// Starts a mock tokeninfo server that answers `GET` requests carrying
    /// `access_token=<token>` with `response`.
    ///
    /// Returns the server, which must stay alive for the whole test, and the
    /// full endpoint URL to hand to the scope check.
    async fn mock_tokeninfo(token: &str, response: ResponseTemplate) -> (MockServer, String) {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(TOKENINFO_PATH))
            .and(query_param("access_token", token))
            .respond_with(response)
            .mount(&server)
            .await;
        let endpoint = format!("{}{}", server.uri(), TOKENINFO_PATH);
        (server, endpoint)
    }

    /// Runs the Drive scope check for a static `token` against `endpoint`.
    async fn check_scope(token: &str, endpoint: &str) -> Result<()> {
        let provider = StaticTokenProvider::new(token);
        let client = reqwest::Client::new();
        ensure_drive_scope_with_endpoint(&provider, &client, endpoint).await
    }

    /// Unwraps the message of an `Error::TokenProvider`, panicking on any other variant.
    fn token_provider_message(err: Error) -> String {
        match err {
            Error::TokenProvider(message) => message,
            _ => panic!("expected TokenProvider error"),
        }
    }

    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let token = StaticTokenProvider::new("test-token-123")
            .access_token()
            .await
            .unwrap();
        assert_eq!(token, "test-token-123");
    }

    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        let var = "TEST_NBLM_TOKEN";
        std::env::set_var(var, "env-token-456");
        let token = EnvTokenProvider::new(var).access_token().await.unwrap();
        assert_eq!(token, "env-token-456");
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        std::env::remove_var("NONEXISTENT_TOKEN");
        let message = EnvTokenProvider::new("NONEXISTENT_TOKEN")
            .access_token()
            .await
            .expect_err("an unset variable must produce an error")
            .to_string();
        assert!(message.contains("environment variable NONEXISTENT_TOKEN missing"));
    }

    #[test]
    fn provider_kind_as_str_returns_correct_labels() {
        let expected = [
            (ProviderKind::GcloudOauth, "gcloud-oauth"),
            (ProviderKind::EnvAccessToken, "env-access-token"),
            (ProviderKind::StaticToken, "static-token"),
            (ProviderKind::UserOauth, "user-oauth"),
        ];
        for (kind, label) in expected {
            assert_eq!(kind.as_str(), label, "label for {kind:?}");
        }
    }

    #[test]
    fn provider_kind_is_experimental_only_for_user_oauth() {
        let kinds = [
            ProviderKind::GcloudOauth,
            ProviderKind::EnvAccessToken,
            ProviderKind::StaticToken,
            ProviderKind::UserOauth,
        ];
        for kind in kinds {
            assert_eq!(
                kind.is_experimental(),
                kind == ProviderKind::UserOauth,
                "experimental flag for {kind:?}"
            );
        }
    }

    #[test]
    fn gcloud_token_provider_returns_correct_kind() {
        assert_eq!(
            GcloudTokenProvider::new("gcloud").kind(),
            ProviderKind::GcloudOauth
        );
    }

    #[test]
    fn env_token_provider_returns_correct_kind() {
        assert_eq!(
            EnvTokenProvider::new("TEST_TOKEN").kind(),
            ProviderKind::EnvAccessToken
        );
    }

    #[test]
    fn static_token_provider_returns_correct_kind() {
        assert_eq!(
            StaticTokenProvider::new("token").kind(),
            ProviderKind::StaticToken
        );
    }

    #[test]
    fn scope_grants_drive_access_detects_required_scopes() {
        let cases = [
            (DRIVE_FILE_SCOPE.to_string(), true),
            (DRIVE_SCOPE.to_string(), true),
            (
                "https://www.googleapis.com/auth/spreadsheets.readonly".to_string(),
                false,
            ),
            (
                format!("{DRIVE_FILE_SCOPE} https://www.googleapis.com/auth/calendar"),
                true,
            ),
        ];
        for (scopes, expected) in cases {
            assert_eq!(
                scope_grants_drive_access(&scopes),
                expected,
                "scopes: {scopes}"
            );
        }
    }

    #[tokio::test]
    async fn ensure_drive_scope_accepts_valid_scope() {
        let body = serde_json::json!({ "scope": DRIVE_FILE_SCOPE });
        let (_server, endpoint) =
            mock_tokeninfo("valid-token", ResponseTemplate::new(200).set_body_json(body)).await;
        assert!(check_scope("valid-token", &endpoint).await.is_ok());
    }

    #[tokio::test]
    async fn ensure_drive_scope_rejects_missing_scope() {
        let body = serde_json::json!({
            "scope": "https://www.googleapis.com/auth/spreadsheets.readonly"
        });
        let (_server, endpoint) =
            mock_tokeninfo("no-scope", ResponseTemplate::new(200).set_body_json(body)).await;
        let err = check_scope("no-scope", &endpoint).await.unwrap_err();
        assert!(token_provider_message(err).contains("drive.file scope"));
    }

    #[tokio::test]
    async fn ensure_drive_scope_converts_http_failures() {
        let (_server, endpoint) = mock_tokeninfo(
            "bad-token",
            ResponseTemplate::new(400).set_body_string("invalid_token"),
        )
        .await;
        let err = check_scope("bad-token", &endpoint).await.unwrap_err();
        assert!(token_provider_message(err).contains("status 400"));
    }
}