1use std::env;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::process::Command;
7
8use crate::error::{Error, Result};
9
10pub mod oauth;
11
/// Identifies which authentication mechanism a [`TokenProvider`] uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Token obtained by running the gcloud CLI (`auth print-access-token`).
    GcloudOauth,
    /// Token read from an environment variable.
    EnvAccessToken,
    /// Fixed token supplied up front.
    StaticToken,
    /// User OAuth flow (presumably implemented in the `oauth` submodule — TODO confirm);
    /// flagged as experimental.
    UserOauth,
}

impl ProviderKind {
    /// Stable kebab-case label for this kind, suitable for display or logging.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::GcloudOauth => "gcloud-oauth",
            Self::EnvAccessToken => "env-access-token",
            Self::StaticToken => "static-token",
            Self::UserOauth => "user-oauth",
        }
    }

    /// Returns `true` for provider kinds that are still experimental
    /// (currently only [`ProviderKind::UserOauth`]).
    pub fn is_experimental(&self) -> bool {
        match self {
            Self::UserOauth => true,
            Self::GcloudOauth | Self::EnvAccessToken | Self::StaticToken => false,
        }
    }
}
34
35#[async_trait]
36pub trait TokenProvider: Send + Sync {
37 async fn access_token(&self) -> Result<String>;
38 async fn refresh_token(&self) -> Result<String> {
39 self.access_token().await
40 }
41
42 fn kind(&self) -> ProviderKind {
43 ProviderKind::StaticToken
44 }
45}
46
/// Default Google OAuth2 endpoint used to look up the scopes granted to an access token.
const TOKENINFO_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/tokeninfo";
/// Full Google Drive scope; accepted as an alternative to [`DRIVE_FILE_SCOPE`].
const DRIVE_SCOPE: &str = "https://www.googleapis.com/auth/drive";
/// Per-file Google Drive scope; this is the scope named in the "missing scope" error.
const DRIVE_FILE_SCOPE: &str = "https://www.googleapis.com/auth/drive.file";
50
51#[derive(Debug, Deserialize)]
52struct TokenInfoResponse {
53 scope: Option<String>,
54}
55
56pub async fn ensure_drive_scope(provider: &dyn TokenProvider) -> Result<()> {
57 let client = Client::new();
58 let endpoint =
59 std::env::var("NBLM_TOKENINFO_ENDPOINT").unwrap_or_else(|_| TOKENINFO_ENDPOINT.to_string());
60 ensure_drive_scope_internal(provider, &client, &endpoint).await
61}
62
63async fn ensure_drive_scope_internal(
64 provider: &dyn TokenProvider,
65 client: &Client,
66 endpoint: &str,
67) -> Result<()> {
68 let access_token = provider.access_token().await?;
69
70 let response = client
71 .get(endpoint)
72 .query(&[("access_token", access_token.as_str())])
73 .send()
74 .await
75 .map_err(|err| {
76 Error::TokenProvider(format!("failed to validate Google Drive token: {err}"))
77 })?;
78
79 if !response.status().is_success() {
80 let status = response.status();
81 let body = response
82 .text()
83 .await
84 .unwrap_or_else(|_| String::from("<failed to read body>"));
85 return Err(Error::TokenProvider(format!(
86 "failed to validate Google Drive token (status {}): {}",
87 status.as_u16(),
88 body.trim()
89 )));
90 }
91
92 let info: TokenInfoResponse = response
93 .json()
94 .await
95 .map_err(|err| Error::TokenProvider(format!("invalid tokeninfo response: {err}")))?;
96
97 let scopes = info.scope.unwrap_or_default();
98 if scope_grants_drive_access(&scopes) {
99 Ok(())
100 } else {
101 Err(Error::TokenProvider(
102 "Google Drive access token is missing the required drive.file scope. Run `gcloud auth login --enable-gdrive-access` and retry.".to_string(),
103 ))
104 }
105}
106
107fn scope_grants_drive_access(scopes: &str) -> bool {
108 scopes
109 .split_whitespace()
110 .any(|scope| scope == DRIVE_FILE_SCOPE || scope == DRIVE_SCOPE)
111}
112
/// Test-only entry point for the drive-scope check.
///
/// Unlike `ensure_drive_scope`, it takes an explicit tokeninfo `endpoint` (e.g. a mock
/// server URL) and a caller-supplied HTTP `client`.
#[cfg(test)]
pub(crate) async fn ensure_drive_scope_with_endpoint(
    provider: &dyn TokenProvider,
    client: &Client,
    endpoint: &str,
) -> Result<()> {
    ensure_drive_scope_internal(provider, client, endpoint).await
}
121
/// Token provider that shells out to the gcloud CLI (`<binary> auth print-access-token`).
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program name or path of the gcloud executable to run.
    binary: String,
}

impl GcloudTokenProvider {
    /// Creates a provider that runs `binary` (a program name looked up on `PATH`, or a path).
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Uses the `gcloud` executable from `PATH`.
    ///
    /// A derived `Default` would leave the binary name empty, so every token request would
    /// fail to spawn.
    fn default() -> Self {
        Self::new("gcloud")
    }
}
134
135#[async_trait]
136impl TokenProvider for GcloudTokenProvider {
137 async fn access_token(&self) -> Result<String> {
138 let output = Command::new(&self.binary)
139 .arg("auth")
140 .arg("print-access-token")
141 .output()
142 .await
143 .map_err(|err| {
144 Error::TokenProvider(format!(
145 "Failed to execute gcloud command. Make sure gcloud CLI is installed and in PATH.\nError: {}",
146 err
147 ))
148 })?;
149
150 if !output.status.success() {
151 let stderr = String::from_utf8_lossy(&output.stderr);
152 return Err(Error::TokenProvider(format!(
153 "Failed to get access token from gcloud. Please run 'gcloud auth login' to authenticate.\nError: {}",
154 stderr.trim()
155 )));
156 }
157
158 let token = String::from_utf8(output.stdout)
159 .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;
160
161 Ok(token.trim().to_owned())
162 }
163
164 fn kind(&self) -> ProviderKind {
165 ProviderKind::GcloudOauth
166 }
167}
168
/// Token provider that reads the access token from an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable holding the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider that reads the token from the environment variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key = key.into();
        Self { key }
    }
}
179
180#[async_trait]
181impl TokenProvider for EnvTokenProvider {
182 async fn access_token(&self) -> Result<String> {
183 env::var(&self.key)
184 .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
185 }
186
187 fn kind(&self) -> ProviderKind {
188 ProviderKind::EnvAccessToken
189 }
190}
191
/// Token provider that always returns the same, pre-configured token.
#[derive(Clone)]
pub struct StaticTokenProvider {
    /// The bearer token returned verbatim on every request.
    token: String,
}

impl std::fmt::Debug for StaticTokenProvider {
    // The token is a credential. A derived `Debug` would print it wherever the provider is
    // formatted with `{:?}` (logs, panic messages, error chains), so redact it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticTokenProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}

impl StaticTokenProvider {
    /// Creates a provider that always hands out `token`.
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
        }
    }
}
204
205#[async_trait]
206impl TokenProvider for StaticTokenProvider {
207 async fn access_token(&self) -> Result<String> {
208 Ok(self.token.clone())
209 }
210
211 fn kind(&self) -> ProviderKind {
212 ProviderKind::StaticToken
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path at which the mock tokeninfo endpoint is mounted.
    const TOKENINFO_PATH: &str = "/oauth2/v3/tokeninfo";

    /// Starts a mock tokeninfo server that answers a GET for `token` with `response`.
    ///
    /// Returns the server, which must be kept alive for the rest of the test, and the full
    /// endpoint URL to pass to `ensure_drive_scope_with_endpoint`.
    async fn mock_tokeninfo(token: &str, response: ResponseTemplate) -> (MockServer, String) {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(TOKENINFO_PATH))
            .and(query_param("access_token", token))
            .respond_with(response)
            .mount(&server)
            .await;
        let endpoint = format!("{}{TOKENINFO_PATH}", server.uri());
        (server, endpoint)
    }

    /// Runs the drive-scope check for a static `token` against `endpoint`.
    async fn check_scope(token: &str, endpoint: &str) -> Result<()> {
        let provider = StaticTokenProvider::new(token);
        let client = reqwest::Client::new();
        ensure_drive_scope_with_endpoint(&provider, &client, endpoint).await
    }

    /// Extracts the message of an `Error::TokenProvider`, panicking on any other outcome.
    fn token_provider_message(result: Result<()>) -> String {
        match result {
            Err(Error::TokenProvider(message)) => message,
            Err(other) => panic!("expected TokenProvider error, got {other:?}"),
            Ok(()) => panic!("expected TokenProvider error, got Ok"),
        }
    }

    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let provider = StaticTokenProvider::new("test-token-123");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "test-token-123");
    }

    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        std::env::set_var("TEST_NBLM_TOKEN", "env-token-456");
        let provider = EnvTokenProvider::new("TEST_NBLM_TOKEN");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "env-token-456");
        std::env::remove_var("TEST_NBLM_TOKEN");
    }

    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        std::env::remove_var("NONEXISTENT_TOKEN");
        let provider = EnvTokenProvider::new("NONEXISTENT_TOKEN");
        let result = provider.access_token().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("environment variable NONEXISTENT_TOKEN missing"));
    }

    #[test]
    fn provider_kind_as_str_returns_correct_labels() {
        assert_eq!(ProviderKind::GcloudOauth.as_str(), "gcloud-oauth");
        assert_eq!(ProviderKind::EnvAccessToken.as_str(), "env-access-token");
        assert_eq!(ProviderKind::StaticToken.as_str(), "static-token");
        assert_eq!(ProviderKind::UserOauth.as_str(), "user-oauth");
    }

    #[test]
    fn provider_kind_is_experimental_only_for_user_oauth() {
        assert!(!ProviderKind::GcloudOauth.is_experimental());
        assert!(!ProviderKind::EnvAccessToken.is_experimental());
        assert!(!ProviderKind::StaticToken.is_experimental());
        assert!(ProviderKind::UserOauth.is_experimental());
    }

    #[test]
    fn gcloud_token_provider_returns_correct_kind() {
        let provider = GcloudTokenProvider::new("gcloud");
        assert_eq!(provider.kind(), ProviderKind::GcloudOauth);
    }

    #[test]
    fn env_token_provider_returns_correct_kind() {
        let provider = EnvTokenProvider::new("TEST_TOKEN");
        assert_eq!(provider.kind(), ProviderKind::EnvAccessToken);
    }

    #[test]
    fn static_token_provider_returns_correct_kind() {
        let provider = StaticTokenProvider::new("token");
        assert_eq!(provider.kind(), ProviderKind::StaticToken);
    }

    #[test]
    fn scope_grants_drive_access_detects_required_scopes() {
        let with_other = format!("{DRIVE_FILE_SCOPE} https://www.googleapis.com/auth/calendar");
        let cases = [
            (DRIVE_FILE_SCOPE, true),
            (DRIVE_SCOPE, true),
            ("https://www.googleapis.com/auth/spreadsheets.readonly", false),
            (with_other.as_str(), true),
        ];
        for (scopes, expected) in cases {
            assert_eq!(scope_grants_drive_access(scopes), expected, "scopes: {scopes:?}");
        }
    }

    #[test]
    fn scope_grants_drive_access_rejects_empty_and_lookalike_scopes() {
        for scopes in [
            "",
            "   ",
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drivex",
        ] {
            assert!(!scope_grants_drive_access(scopes), "scopes: {scopes:?}");
        }
    }

    #[tokio::test]
    async fn ensure_drive_scope_accepts_valid_scope() {
        let response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "scope": DRIVE_FILE_SCOPE
        }));
        let (_server, endpoint) = mock_tokeninfo("valid-token", response).await;
        assert!(check_scope("valid-token", &endpoint).await.is_ok());
    }

    #[tokio::test]
    async fn ensure_drive_scope_accepts_full_drive_scope_among_others() {
        let response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "scope": format!("openid {DRIVE_SCOPE}")
        }));
        let (_server, endpoint) = mock_tokeninfo("full-drive", response).await;
        assert!(check_scope("full-drive", &endpoint).await.is_ok());
    }

    #[tokio::test]
    async fn ensure_drive_scope_rejects_missing_scope() {
        let response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "scope": "https://www.googleapis.com/auth/spreadsheets.readonly"
        }));
        let (_server, endpoint) = mock_tokeninfo("no-scope", response).await;
        let message = token_provider_message(check_scope("no-scope", &endpoint).await);
        assert!(message.contains("drive.file scope"));
    }

    #[tokio::test]
    async fn ensure_drive_scope_rejects_response_without_scope_field() {
        let response = ResponseTemplate::new(200).set_body_json(serde_json::json!({}));
        let (_server, endpoint) = mock_tokeninfo("scopeless", response).await;
        let message = token_provider_message(check_scope("scopeless", &endpoint).await);
        assert!(message.contains("drive.file scope"));
    }

    #[tokio::test]
    async fn ensure_drive_scope_converts_http_failures() {
        let response = ResponseTemplate::new(400).set_body_string("invalid_token");
        let (_server, endpoint) = mock_tokeninfo("bad-token", response).await;
        let message = token_provider_message(check_scope("bad-token", &endpoint).await);
        assert!(message.contains("status 400"));
    }

    #[tokio::test]
    async fn ensure_drive_scope_rejects_malformed_tokeninfo_body() {
        let response = ResponseTemplate::new(200).set_body_string("not json");
        let (_server, endpoint) = mock_tokeninfo("garbled", response).await;
        let message = token_provider_message(check_scope("garbled", &endpoint).await);
        assert!(message.contains("invalid tokeninfo response"));
    }
}