nblm_core/
auth.rs

1use std::env;
2
3use async_trait::async_trait;
4use tokio::process::Command;
5
6use crate::error::{Error, Result};
7
8#[async_trait]
9pub trait TokenProvider: Send + Sync {
10    async fn access_token(&self) -> Result<String>;
11    async fn refresh_token(&self) -> Result<String> {
12        self.access_token().await
13    }
14}
15
/// [`TokenProvider`] that shells out to the Google Cloud CLI
/// (`<binary> auth print-access-token`) to obtain an access token.
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program name or path of the gcloud executable to invoke.
    binary: String,
}

impl Default for GcloudTokenProvider {
    /// Uses the `gcloud` executable, looked up through `PATH`.
    ///
    /// A derived `Default` would leave `binary` empty, and an empty program
    /// name can never be spawned, so every call would fail.
    fn default() -> Self {
        Self {
            binary: String::from("gcloud"),
        }
    }
}
20
21impl GcloudTokenProvider {
22    pub fn new(binary: impl Into<String>) -> Self {
23        Self {
24            binary: binary.into(),
25        }
26    }
27}
28
29#[async_trait]
30impl TokenProvider for GcloudTokenProvider {
31    async fn access_token(&self) -> Result<String> {
32        let output = Command::new(&self.binary)
33            .arg("auth")
34            .arg("print-access-token")
35            .output()
36            .await
37            .map_err(|err| {
38                Error::TokenProvider(format!(
39                    "Failed to execute gcloud command. Make sure gcloud CLI is installed and in PATH.\nError: {}",
40                    err
41                ))
42            })?;
43
44        if !output.status.success() {
45            let stderr = String::from_utf8_lossy(&output.stderr);
46            return Err(Error::TokenProvider(format!(
47                "Failed to get access token from gcloud. Please run 'gcloud auth login' to authenticate.\nError: {}",
48                stderr.trim()
49            )));
50        }
51
52        let token = String::from_utf8(output.stdout)
53            .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;
54
55        Ok(token.trim().to_owned())
56    }
57}
58
/// [`TokenProvider`] that reads the access token from an environment variable.
#[derive(Clone, Debug)]
pub struct EnvTokenProvider {
    /// Name of the environment variable that holds the token.
    key: String,
}
63
64impl EnvTokenProvider {
65    pub fn new(key: impl Into<String>) -> Self {
66        Self { key: key.into() }
67    }
68}
69
70#[async_trait]
71impl TokenProvider for EnvTokenProvider {
72    async fn access_token(&self) -> Result<String> {
73        env::var(&self.key)
74            .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
75    }
76}
77
/// [`TokenProvider`] that always hands out the same pre-configured token.
#[derive(Clone, Debug)]
pub struct StaticTokenProvider {
    /// Token returned on every call.
    token: String,
}
82
83impl StaticTokenProvider {
84    pub fn new(token: impl Into<String>) -> Self {
85        Self {
86            token: token.into(),
87        }
88    }
89}
90
91#[async_trait]
92impl TokenProvider for StaticTokenProvider {
93    async fn access_token(&self) -> Result<String> {
94        Ok(self.token.clone())
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let provider = StaticTokenProvider::new("test-token-123");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "test-token-123");
    }

    #[tokio::test]
    async fn refresh_token_defaults_to_access_token() {
        let provider = StaticTokenProvider::new("test-token-123");
        let token = provider.refresh_token().await.unwrap();
        assert_eq!(token, "test-token-123");
    }

    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        std::env::set_var("TEST_NBLM_TOKEN", "env-token-456");
        let provider = EnvTokenProvider::new("TEST_NBLM_TOKEN");
        let result = provider.access_token().await;
        // Clean up before asserting: tests share one process, so a failing
        // assertion must not leave the variable set for other tests.
        std::env::remove_var("TEST_NBLM_TOKEN");
        assert_eq!(result.unwrap(), "env-token-456");
    }

    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        std::env::remove_var("NONEXISTENT_TOKEN");
        let provider = EnvTokenProvider::new("NONEXISTENT_TOKEN");
        let result = provider.access_token().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("environment variable NONEXISTENT_TOKEN missing"));
    }

    #[tokio::test]
    async fn gcloud_token_provider_errors_when_binary_missing() {
        let provider = GcloudTokenProvider::new("nblm-nonexistent-gcloud-binary");
        let err = provider.access_token().await.unwrap_err();
        assert!(err.to_string().contains("Failed to execute gcloud command"));
    }
}