nblm-core 0.2.2

Core library for NotebookLM Enterprise API client
Documentation
use std::env;

use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use tokio::process::Command;

use crate::error::{Error, Result};

pub mod oauth;

/// Tag describing which kind of token provider produced a token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderKind {
    /// Reported by `GcloudTokenProvider`, which shells out to the gcloud binary.
    GcloudOauth,
    /// Reported by `EnvTokenProvider`, which reads an environment variable.
    EnvAccessToken,
    /// Reported by `StaticTokenProvider`; also the `TokenProvider::kind` default.
    StaticToken,
    /// The only kind flagged by `is_experimental`; its provider is not defined
    /// in this file.
    UserOauth,
}

impl ProviderKind {
    /// Returns the kebab-case label for this kind.
    pub fn as_str(&self) -> &'static str {
        use ProviderKind::{EnvAccessToken, GcloudOauth, StaticToken, UserOauth};

        match *self {
            GcloudOauth => "gcloud-oauth",
            EnvAccessToken => "env-access-token",
            StaticToken => "static-token",
            UserOauth => "user-oauth",
        }
    }

    /// Returns `true` only for `ProviderKind::UserOauth`.
    pub fn is_experimental(&self) -> bool {
        *self == ProviderKind::UserOauth
    }
}

/// A source of access tokens.
///
/// Implementors must be `Send + Sync`. Only `access_token` is required; the
/// other two methods have default implementations.
#[async_trait]
pub trait TokenProvider: Send + Sync {
    /// Produces an access token, or an error if none can be obtained.
    async fn access_token(&self) -> Result<String>;
    /// Produces a refreshed token.
    ///
    /// The default implementation simply calls `access_token` again.
    async fn refresh_token(&self) -> Result<String> {
        self.access_token().await
    }

    /// Reports which kind of provider this is.
    ///
    /// Defaults to `ProviderKind::StaticToken` for implementors that do not
    /// override it.
    fn kind(&self) -> ProviderKind {
        ProviderKind::StaticToken
    }
}

/// Tokeninfo URL used by `ensure_drive_scope` when the
/// `NBLM_TOKENINFO_ENDPOINT` environment variable cannot be read.
const TOKENINFO_ENDPOINT: &str = "https://www.googleapis.com/oauth2/v3/tokeninfo";
/// `auth/drive` scope URL; one of the two scopes `scope_grants_drive_access` accepts.
const DRIVE_SCOPE: &str = "https://www.googleapis.com/auth/drive";
/// `auth/drive.file` scope URL; the other scope `scope_grants_drive_access` accepts.
const DRIVE_FILE_SCOPE: &str = "https://www.googleapis.com/auth/drive.file";

/// The part of the tokeninfo JSON reply that this module reads.
#[derive(Debug, Deserialize)]
struct TokenInfoResponse {
    /// Scope string from the reply; the caller splits it on whitespace and
    /// treats `None` as an empty string.
    scope: Option<String>,
}

/// Checks that the token supplied by `provider` carries a Google Drive scope.
///
/// The tokeninfo URL is taken from the `NBLM_TOKENINFO_ENDPOINT` environment
/// variable; when that variable cannot be read, `TOKENINFO_ENDPOINT` is used.
/// A fresh HTTP client is created for the request.
///
/// # Errors
/// Returns whatever error the underlying validation produces.
pub async fn ensure_drive_scope(provider: &dyn TokenProvider) -> Result<()> {
    let tokeninfo_url = match std::env::var("NBLM_TOKENINFO_ENDPOINT") {
        Ok(overridden) => overridden,
        Err(_) => TOKENINFO_ENDPOINT.to_string(),
    };
    let http = Client::new();

    ensure_drive_scope_internal(provider, &http, &tokeninfo_url).await
}

async fn ensure_drive_scope_internal(
    provider: &dyn TokenProvider,
    client: &Client,
    endpoint: &str,
) -> Result<()> {
    let access_token = provider.access_token().await?;

    let response = client
        .get(endpoint)
        .query(&[("access_token", access_token.as_str())])
        .send()
        .await
        .map_err(|err| {
            Error::TokenProvider(format!("failed to validate Google Drive token: {err}"))
        })?;

    if !response.status().is_success() {
        let status = response.status();
        let body = response
            .text()
            .await
            .unwrap_or_else(|_| String::from("<failed to read body>"));
        return Err(Error::TokenProvider(format!(
            "failed to validate Google Drive token (status {}): {}",
            status.as_u16(),
            body.trim()
        )));
    }

    let info: TokenInfoResponse = response
        .json()
        .await
        .map_err(|err| Error::TokenProvider(format!("invalid tokeninfo response: {err}")))?;

    let scopes = info.scope.unwrap_or_default();
    if scope_grants_drive_access(&scopes) {
        Ok(())
    } else {
        Err(Error::TokenProvider(
            "Google Drive access token is missing the required drive.file scope. Run `gcloud auth login --enable-gdrive-access` and retry.".to_string(),
        ))
    }
}

/// Returns `true` when the whitespace-separated `scopes` string contains
/// `DRIVE_FILE_SCOPE` or `DRIVE_SCOPE` as an exact entry.
fn scope_grants_drive_access(scopes: &str) -> bool {
    let accepted = [DRIVE_FILE_SCOPE, DRIVE_SCOPE];

    for granted in scopes.split_whitespace() {
        if accepted.contains(&granted) {
            return true;
        }
    }
    false
}

/// Test-only entry point that forwards to `ensure_drive_scope_internal`, so
/// tests can supply their own HTTP client and tokeninfo endpoint.
#[cfg(test)]
pub(crate) async fn ensure_drive_scope_with_endpoint(
    provider: &dyn TokenProvider,
    client: &Client,
    endpoint: &str,
) -> Result<()> {
    ensure_drive_scope_internal(provider, client, endpoint).await
}

/// Token provider that obtains access tokens by running the gcloud CLI.
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program name or path used to launch gcloud.
    binary: String,
}

impl GcloudTokenProvider {
    /// Creates a provider that launches `binary` whenever a token is requested.
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Creates a provider that launches the program named `gcloud`.
    ///
    /// This is implemented by hand because a derived `Default` would leave the
    /// binary name empty, giving a provider that can never spawn the CLI.
    fn default() -> Self {
        Self::new("gcloud")
    }
}

#[async_trait]
impl TokenProvider for GcloudTokenProvider {
    async fn access_token(&self) -> Result<String> {
        let output = Command::new(&self.binary)
            .arg("auth")
            .arg("print-access-token")
            .output()
            .await
            .map_err(|err| {
                Error::TokenProvider(format!(
                    "Failed to execute gcloud command. Make sure gcloud CLI is installed and in PATH.\nError: {}",
                    err
                ))
            })?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(Error::TokenProvider(format!(
                "Failed to get access token from gcloud. Please run 'gcloud auth login' to authenticate.\nError: {}",
                stderr.trim()
            )));
        }

        let token = String::from_utf8(output.stdout)
            .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;

        Ok(token.trim().to_owned())
    }

    fn kind(&self) -> ProviderKind {
        ProviderKind::GcloudOauth
    }
}

/// Token provider backed by an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable that holds the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider that reads the environment variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key: String = key.into();
        Self { key }
    }
}

#[async_trait]
impl TokenProvider for EnvTokenProvider {
    async fn access_token(&self) -> Result<String> {
        env::var(&self.key)
            .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
    }

    fn kind(&self) -> ProviderKind {
        ProviderKind::EnvAccessToken
    }
}

/// Token provider that always returns the token it was constructed with.
#[derive(Debug, Clone)]
pub struct StaticTokenProvider {
    /// The token returned on every request.
    token: String,
}

impl StaticTokenProvider {
    /// Creates a provider that returns `token` on every request.
    pub fn new(token: impl Into<String>) -> Self {
        let token: String = token.into();
        Self { token }
    }
}

#[async_trait]
impl TokenProvider for StaticTokenProvider {
    /// Returns a copy of the stored token; this never produces an error.
    async fn access_token(&self) -> Result<String> {
        let token = self.token.to_owned();
        Ok(token)
    }

    /// Identifies this provider as `ProviderKind::StaticToken`.
    fn kind(&self) -> ProviderKind {
        ProviderKind::StaticToken
    }
}

// Unit tests for the token providers and the Drive-scope check. The HTTP
// tests run against a local wiremock server instead of the real endpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // The static provider returns exactly the token it was built with.
    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let provider = StaticTokenProvider::new("test-token-123");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "test-token-123");
    }

    // Sets a test-specific variable, reads it through the provider, then
    // removes it again.
    // NOTE(review): this mutates the process-wide environment; the dedicated
    // variable name is what keeps it apart from the other tests.
    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        std::env::set_var("TEST_NBLM_TOKEN", "env-token-456");
        let provider = EnvTokenProvider::new("TEST_NBLM_TOKEN");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "env-token-456");
        std::env::remove_var("TEST_NBLM_TOKEN");
    }

    // An unset variable yields an error whose message names the variable.
    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        // Make sure the variable is absent regardless of the outer environment.
        std::env::remove_var("NONEXISTENT_TOKEN");
        let provider = EnvTokenProvider::new("NONEXISTENT_TOKEN");
        let result = provider.access_token().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("environment variable NONEXISTENT_TOKEN missing"));
    }

    // Pins the label string of every variant.
    #[test]
    fn provider_kind_as_str_returns_correct_labels() {
        assert_eq!(ProviderKind::GcloudOauth.as_str(), "gcloud-oauth");
        assert_eq!(ProviderKind::EnvAccessToken.as_str(), "env-access-token");
        assert_eq!(ProviderKind::StaticToken.as_str(), "static-token");
        assert_eq!(ProviderKind::UserOauth.as_str(), "user-oauth");
    }

    // Only `UserOauth` is flagged as experimental.
    #[test]
    fn provider_kind_is_experimental_only_for_user_oauth() {
        assert!(!ProviderKind::GcloudOauth.is_experimental());
        assert!(!ProviderKind::EnvAccessToken.is_experimental());
        assert!(!ProviderKind::StaticToken.is_experimental());
        assert!(ProviderKind::UserOauth.is_experimental());
    }

    // Each provider overrides `kind()` with its own variant.
    #[test]
    fn gcloud_token_provider_returns_correct_kind() {
        let provider = GcloudTokenProvider::new("gcloud");
        assert_eq!(provider.kind(), ProviderKind::GcloudOauth);
    }

    #[test]
    fn env_token_provider_returns_correct_kind() {
        let provider = EnvTokenProvider::new("TEST_TOKEN");
        assert_eq!(provider.kind(), ProviderKind::EnvAccessToken);
    }

    #[test]
    fn static_token_provider_returns_correct_kind() {
        let provider = StaticTokenProvider::new("token");
        assert_eq!(provider.kind(), ProviderKind::StaticToken);
    }

    // Helper: asserts the scope check's verdict for one scope string.
    fn expect_scope_result(scopes: &str, expected: bool) {
        assert_eq!(scope_grants_drive_access(scopes), expected);
    }

    // Either Drive scope is accepted, alone or alongside other scopes; an
    // unrelated scope is rejected.
    #[test]
    fn scope_grants_drive_access_detects_required_scopes() {
        expect_scope_result(DRIVE_FILE_SCOPE, true);
        expect_scope_result(DRIVE_SCOPE, true);
        expect_scope_result(
            "https://www.googleapis.com/auth/spreadsheets.readonly",
            false,
        );
        expect_scope_result(
            &format!("{DRIVE_FILE_SCOPE} https://www.googleapis.com/auth/calendar"),
            true,
        );
    }

    // A 200 reply whose scope is `drive.file` makes the check succeed.
    #[tokio::test]
    async fn ensure_drive_scope_accepts_valid_scope() {
        let server = MockServer::start().await;
        // The mock only matches when the token arrives as the `access_token`
        // query parameter.
        Mock::given(method("GET"))
            .and(path("/oauth2/v3/tokeninfo"))
            .and(query_param("access_token", "valid-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "scope": DRIVE_FILE_SCOPE
            })))
            .mount(&server)
            .await;

        let provider = StaticTokenProvider::new("valid-token");
        let client = reqwest::Client::new();
        let endpoint = format!("{}/oauth2/v3/tokeninfo", server.uri());
        let result = ensure_drive_scope_with_endpoint(&provider, &client, &endpoint).await;
        assert!(result.is_ok());
    }

    // A 200 reply without a Drive scope produces a TokenProvider error that
    // mentions the missing scope.
    #[tokio::test]
    async fn ensure_drive_scope_rejects_missing_scope() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/oauth2/v3/tokeninfo"))
            .and(query_param("access_token", "no-scope"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "scope": "https://www.googleapis.com/auth/spreadsheets.readonly"
            })))
            .mount(&server)
            .await;

        let provider = StaticTokenProvider::new("no-scope");
        let client = reqwest::Client::new();
        let endpoint = format!("{}/oauth2/v3/tokeninfo", server.uri());
        let err = ensure_drive_scope_with_endpoint(&provider, &client, &endpoint)
            .await
            .unwrap_err();

        match err {
            Error::TokenProvider(message) => {
                assert!(message.contains("drive.file scope"));
            }
            _ => panic!("expected TokenProvider error"),
        }
    }

    // A non-success HTTP status becomes a TokenProvider error that carries
    // the status code.
    #[tokio::test]
    async fn ensure_drive_scope_converts_http_failures() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/oauth2/v3/tokeninfo"))
            .and(query_param("access_token", "bad-token"))
            .respond_with(ResponseTemplate::new(400).set_body_string("invalid_token"))
            .mount(&server)
            .await;

        let provider = StaticTokenProvider::new("bad-token");
        let client = reqwest::Client::new();
        let endpoint = format!("{}/oauth2/v3/tokeninfo", server.uri());
        let err = ensure_drive_scope_with_endpoint(&provider, &client, &endpoint)
            .await
            .unwrap_err();

        match err {
            Error::TokenProvider(message) => {
                assert!(message.contains("status 400"));
            }
            _ => panic!("expected TokenProvider error"),
        }
    }
}