1use std::env;
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::Deserialize;
6use tokio::process::Command;
7
8use crate::error::{Error, Result};
9
/// Identifies the kind of credential source behind a [`TokenProvider`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Reported by [`GcloudTokenProvider`].
    GcloudOauth,
    /// Reported by [`EnvTokenProvider`].
    EnvAccessToken,
    /// Reported by [`StaticTokenProvider`]; also the trait's default kind.
    StaticToken,
    /// The only kind currently flagged as experimental.
    UserOauth,
}

impl ProviderKind {
    /// Returns the fixed kebab-case label for this kind.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::GcloudOauth => "gcloud-oauth",
            Self::EnvAccessToken => "env-access-token",
            Self::StaticToken => "static-token",
            Self::UserOauth => "user-oauth",
        }
    }

    /// Returns `true` only for [`ProviderKind::UserOauth`].
    pub fn is_experimental(&self) -> bool {
        *self == Self::UserOauth
    }
}
32
33#[async_trait]
34pub trait TokenProvider: Send + Sync {
35 async fn access_token(&self) -> Result<String>;
36 async fn refresh_token(&self) -> Result<String> {
37 self.access_token().await
38 }
39
40 fn kind(&self) -> ProviderKind {
41 ProviderKind::StaticToken
42 }
43}
44
/// Default tokeninfo URL queried when validating a token's scopes.
const TOKENINFO_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/tokeninfo";
/// Scope string for the general Drive scope; accepted by the scope check.
const DRIVE_SCOPE: &str = "https://www.googleapis.com/auth/drive";
/// Scope string for the `drive.file` scope; accepted by the scope check.
const DRIVE_FILE_SCOPE: &str = "https://www.googleapis.com/auth/drive.file";
48
49#[derive(Debug, Deserialize)]
50struct TokenInfoResponse {
51 scope: Option<String>,
52}
53
54pub async fn ensure_drive_scope(provider: &dyn TokenProvider) -> Result<()> {
55 let client = Client::new();
56 let endpoint =
57 std::env::var("NBLM_TOKENINFO_ENDPOINT").unwrap_or_else(|_| TOKENINFO_ENDPOINT.to_string());
58 ensure_drive_scope_internal(provider, &client, &endpoint).await
59}
60
61async fn ensure_drive_scope_internal(
62 provider: &dyn TokenProvider,
63 client: &Client,
64 endpoint: &str,
65) -> Result<()> {
66 let access_token = provider.access_token().await?;
67
68 let response = client
69 .get(endpoint)
70 .query(&[("access_token", access_token.as_str())])
71 .send()
72 .await
73 .map_err(|err| {
74 Error::TokenProvider(format!("failed to validate Google Drive token: {err}"))
75 })?;
76
77 if !response.status().is_success() {
78 let status = response.status();
79 let body = response
80 .text()
81 .await
82 .unwrap_or_else(|_| String::from("<failed to read body>"));
83 return Err(Error::TokenProvider(format!(
84 "failed to validate Google Drive token (status {}): {}",
85 status.as_u16(),
86 body.trim()
87 )));
88 }
89
90 let info: TokenInfoResponse = response
91 .json()
92 .await
93 .map_err(|err| Error::TokenProvider(format!("invalid tokeninfo response: {err}")))?;
94
95 let scopes = info.scope.unwrap_or_default();
96 if scope_grants_drive_access(&scopes) {
97 Ok(())
98 } else {
99 Err(Error::TokenProvider(
100 "Google Drive access token is missing the required drive.file scope. Run `gcloud auth login --enable-gdrive-access` and retry.".to_string(),
101 ))
102 }
103}
104
105fn scope_grants_drive_access(scopes: &str) -> bool {
106 scopes
107 .split_whitespace()
108 .any(|scope| scope == DRIVE_FILE_SCOPE || scope == DRIVE_SCOPE)
109}
110
/// Test-only entry point that runs the scope check against a caller-supplied
/// HTTP client and tokeninfo endpoint.
#[cfg(test)]
pub(crate) async fn ensure_drive_scope_with_endpoint(
    provider: &dyn TokenProvider,
    client: &Client,
    endpoint: &str,
) -> Result<()> {
    // Forwards unchanged to the private implementation.
    ensure_drive_scope_internal(provider, client, endpoint).await
}
119
/// Token provider that obtains access tokens by running the gcloud CLI.
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program name or path used to spawn the CLI.
    binary: String,
}

impl GcloudTokenProvider {
    /// Creates a provider that invokes `binary` to obtain tokens.
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Creates a provider that invokes `gcloud`.
    ///
    /// A derived `Default` would leave the binary name empty, producing a
    /// provider whose every token request fails at process spawn.
    fn default() -> Self {
        Self::new("gcloud")
    }
}
132
133#[async_trait]
134impl TokenProvider for GcloudTokenProvider {
135 async fn access_token(&self) -> Result<String> {
136 let output = Command::new(&self.binary)
137 .arg("auth")
138 .arg("print-access-token")
139 .output()
140 .await
141 .map_err(|err| {
142 Error::TokenProvider(format!(
143 "Failed to execute gcloud command. Make sure gcloud CLI is installed and in PATH.\nError: {}",
144 err
145 ))
146 })?;
147
148 if !output.status.success() {
149 let stderr = String::from_utf8_lossy(&output.stderr);
150 return Err(Error::TokenProvider(format!(
151 "Failed to get access token from gcloud. Please run 'gcloud auth login' to authenticate.\nError: {}",
152 stderr.trim()
153 )));
154 }
155
156 let token = String::from_utf8(output.stdout)
157 .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;
158
159 Ok(token.trim().to_owned())
160 }
161
162 fn kind(&self) -> ProviderKind {
163 ProviderKind::GcloudOauth
164 }
165}
166
/// Token provider backed by an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable that holds the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider that reads the token from the variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key = key.into();
        Self { key }
    }
}
177
178#[async_trait]
179impl TokenProvider for EnvTokenProvider {
180 async fn access_token(&self) -> Result<String> {
181 env::var(&self.key)
182 .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
183 }
184
185 fn kind(&self) -> ProviderKind {
186 ProviderKind::EnvAccessToken
187 }
188}
189
/// Token provider that always hands out the token it was created with.
#[derive(Debug, Clone)]
pub struct StaticTokenProvider {
    /// The fixed token value.
    token: String,
}

impl StaticTokenProvider {
    /// Creates a provider that returns `token` on every request.
    pub fn new(token: impl Into<String>) -> Self {
        let token = token.into();
        Self { token }
    }
}
202
203#[async_trait]
204impl TokenProvider for StaticTokenProvider {
205 async fn access_token(&self) -> Result<String> {
206 Ok(self.token.clone())
207 }
208
209 fn kind(&self) -> ProviderKind {
210 ProviderKind::StaticToken
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path served by the mock tokeninfo endpoint.
    const TOKENINFO_PATH: &str = "/oauth2/v3/tokeninfo";

    /// Starts a mock server that answers a tokeninfo GET for `token` with `reply`.
    async fn tokeninfo_server(token: &str, reply: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path(TOKENINFO_PATH))
            .and(query_param("access_token", token))
            .respond_with(reply)
            .mount(&server)
            .await;
        server
    }

    /// Runs the scope check for a static `token` against the mock `server`.
    async fn check_scope(server: &MockServer, token: &str) -> Result<()> {
        let provider = StaticTokenProvider::new(token);
        let client = reqwest::Client::new();
        let endpoint = format!("{}{}", server.uri(), TOKENINFO_PATH);
        ensure_drive_scope_with_endpoint(&provider, &client, &endpoint).await
    }

    /// Extracts the message of a `TokenProvider` error, panicking on any other variant.
    fn token_provider_message(err: Error) -> String {
        match err {
            Error::TokenProvider(message) => message,
            _ => panic!("expected TokenProvider error"),
        }
    }

    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let provider = StaticTokenProvider::new("test-token-123");
        assert_eq!(provider.access_token().await.unwrap(), "test-token-123");
    }

    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        std::env::set_var("TEST_NBLM_TOKEN", "env-token-456");
        let provider = EnvTokenProvider::new("TEST_NBLM_TOKEN");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "env-token-456");
        std::env::remove_var("TEST_NBLM_TOKEN");
    }

    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        std::env::remove_var("NONEXISTENT_TOKEN");
        let provider = EnvTokenProvider::new("NONEXISTENT_TOKEN");
        let message = provider.access_token().await.unwrap_err().to_string();
        assert!(message.contains("environment variable NONEXISTENT_TOKEN missing"));
    }

    #[test]
    fn provider_kind_as_str_returns_correct_labels() {
        let expected = [
            (ProviderKind::GcloudOauth, "gcloud-oauth"),
            (ProviderKind::EnvAccessToken, "env-access-token"),
            (ProviderKind::StaticToken, "static-token"),
            (ProviderKind::UserOauth, "user-oauth"),
        ];
        for (kind, label) in expected {
            assert_eq!(kind.as_str(), label);
        }
    }

    #[test]
    fn provider_kind_is_experimental_only_for_user_oauth() {
        let stable = [
            ProviderKind::GcloudOauth,
            ProviderKind::EnvAccessToken,
            ProviderKind::StaticToken,
        ];
        for kind in stable {
            assert!(!kind.is_experimental());
        }
        assert!(ProviderKind::UserOauth.is_experimental());
    }

    #[test]
    fn gcloud_token_provider_returns_correct_kind() {
        assert_eq!(
            GcloudTokenProvider::new("gcloud").kind(),
            ProviderKind::GcloudOauth
        );
    }

    #[test]
    fn env_token_provider_returns_correct_kind() {
        assert_eq!(
            EnvTokenProvider::new("TEST_TOKEN").kind(),
            ProviderKind::EnvAccessToken
        );
    }

    #[test]
    fn static_token_provider_returns_correct_kind() {
        assert_eq!(
            StaticTokenProvider::new("token").kind(),
            ProviderKind::StaticToken
        );
    }

    #[test]
    fn scope_grants_drive_access_detects_required_scopes() {
        let mixed = format!("{DRIVE_FILE_SCOPE} https://www.googleapis.com/auth/calendar");
        let cases = [
            (DRIVE_FILE_SCOPE, true),
            (DRIVE_SCOPE, true),
            (
                "https://www.googleapis.com/auth/spreadsheets.readonly",
                false,
            ),
            (mixed.as_str(), true),
        ];
        for (scopes, expected) in cases {
            assert_eq!(scope_grants_drive_access(scopes), expected);
        }
    }

    #[tokio::test]
    async fn ensure_drive_scope_accepts_valid_scope() {
        let reply = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "scope": DRIVE_FILE_SCOPE
        }));
        let server = tokeninfo_server("valid-token", reply).await;

        let result = check_scope(&server, "valid-token").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn ensure_drive_scope_rejects_missing_scope() {
        let reply = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "scope": "https://www.googleapis.com/auth/spreadsheets.readonly"
        }));
        let server = tokeninfo_server("no-scope", reply).await;

        let err = check_scope(&server, "no-scope").await.unwrap_err();
        assert!(token_provider_message(err).contains("drive.file scope"));
    }

    #[tokio::test]
    async fn ensure_drive_scope_converts_http_failures() {
        let reply = ResponseTemplate::new(400).set_body_string("invalid_token");
        let server = tokeninfo_server("bad-token", reply).await;

        let err = check_scope(&server, "bad-token").await.unwrap_err();
        assert!(token_provider_message(err).contains("status 400"));
    }
}