// nblm-core 0.1.0
//
// Core library for the NotebookLM Enterprise API client.
// Documentation
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::{env, fs};

use async_trait::async_trait;
use chrono::{Duration as ChronoDuration, Utc};
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use tokio::sync::Mutex;

use crate::error::{Error, Result};

/// Asynchronous source of access tokens.
///
/// Implementations must be `Send + Sync`, as required by the trait bound.
#[async_trait]
pub trait TokenProvider: Send + Sync {
    /// Returns an access token, or an error if one cannot be obtained.
    async fn access_token(&self) -> Result<String>;
    /// Returns an access token after the caller asked for a refresh.
    ///
    /// The default implementation simply delegates to `access_token`;
    /// providers that cache tokens can override it to discard the cached
    /// value first.
    async fn refresh_token(&self) -> Result<String> {
        self.access_token().await
    }
}

/// Token provider that obtains access tokens by running a `gcloud` binary.
#[derive(Debug, Clone)]
pub struct GcloudTokenProvider {
    /// Program name or path of the executable to spawn.
    binary: String,
}

impl GcloudTokenProvider {
    /// Program name used by [`Default`]; resolved through the process search path.
    const DEFAULT_BINARY: &'static str = "gcloud";

    /// Creates a provider that runs `binary` to obtain tokens.
    pub fn new(binary: impl Into<String>) -> Self {
        Self {
            binary: binary.into(),
        }
    }
}

impl Default for GcloudTokenProvider {
    /// Builds a provider that runs `gcloud`.
    ///
    /// A derived `Default` would leave `binary` empty, producing a provider
    /// that can never spawn its command, so the default is spelled out.
    fn default() -> Self {
        Self::new(Self::DEFAULT_BINARY)
    }
}

#[async_trait]
impl TokenProvider for GcloudTokenProvider {
    async fn access_token(&self) -> Result<String> {
        let output = Command::new(&self.binary)
            .arg("auth")
            .arg("print-access-token")
            .output()
            .await
            .map_err(|err| Error::TokenProvider(err.to_string()))?;

        if !output.status.success() {
            return Err(Error::TokenProvider(format!(
                "gcloud exited with status {}",
                output.status
            )));
        }

        let token = String::from_utf8(output.stdout)
            .map_err(|err| Error::TokenProvider(format!("invalid UTF-8 token: {err}")))?;

        Ok(token.trim().to_owned())
    }
}

/// Token provider backed by an environment variable.
#[derive(Debug, Clone)]
pub struct EnvTokenProvider {
    /// Name of the environment variable that holds the token.
    key: String,
}

impl EnvTokenProvider {
    /// Creates a provider that reads the variable named `key`.
    pub fn new(key: impl Into<String>) -> Self {
        let key: String = key.into();
        Self { key }
    }
}

#[async_trait]
impl TokenProvider for EnvTokenProvider {
    async fn access_token(&self) -> Result<String> {
        env::var(&self.key)
            .map_err(|_| Error::TokenProvider(format!("environment variable {} missing", self.key)))
    }
}

/// Token provider that always returns one fixed token.
#[derive(Clone)]
pub struct StaticTokenProvider {
    /// The token handed out on every request.
    token: String,
}

impl StaticTokenProvider {
    /// Creates a provider that always returns `token`.
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
        }
    }
}

impl std::fmt::Debug for StaticTokenProvider {
    /// Formats the provider without revealing the token.
    ///
    /// A derived `Debug` would print the credential verbatim, which leaks it
    /// into any log line or panic message that formats this value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticTokenProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}

#[async_trait]
impl TokenProvider for StaticTokenProvider {
    /// Hands back a fresh copy of the stored token; this never fails.
    async fn access_token(&self) -> Result<String> {
        let token = String::from(self.token.as_str());
        Ok(token)
    }
}

/// The fields of a service account key JSON document that this file reads.
#[derive(Clone, Deserialize)]
struct ServiceAccountKey {
    /// Used as the `iss` claim of the signed assertion.
    #[serde(rename = "client_email")]
    client_email: String,
    /// PEM text handed to the RSA signing key loader.
    #[serde(rename = "private_key")]
    private_key: String,
    /// Endpoint the assertion is posted to; also used as the `aud` claim.
    #[serde(rename = "token_uri")]
    token_uri: String,
}

impl std::fmt::Debug for ServiceAccountKey {
    /// Formats the key without revealing the private key material.
    ///
    /// A derived `Debug` would print `private_key` verbatim, and that output
    /// would also appear whenever a containing type is debug-formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceAccountKey")
            .field("client_email", &self.client_email)
            .field("private_key", &"<redacted>")
            .field("token_uri", &self.token_uri)
            .finish()
    }
}

/// An access token together with the instant after which it must not be reused.
#[derive(Clone)]
struct CachedToken {
    /// The cached access token.
    token: String,
    /// Monotonic deadline; the token is served only while `Instant::now()` is earlier.
    expires_at: Instant,
}

impl std::fmt::Debug for CachedToken {
    /// Formats the entry without revealing the token.
    ///
    /// A derived `Debug` would print the credential verbatim, which leaks it
    /// into any log line or panic message that formats this value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}

/// Token provider that exchanges a signed service-account assertion for an
/// access token and caches the result.
#[derive(Debug, Clone)]
pub struct ServiceAccountTokenProvider {
    /// Parsed service account key (issuer, signing key, token endpoint).
    key: ServiceAccountKey,
    /// Scopes requested; joined with single spaces into the `scope` claim.
    scopes: Vec<String>,
    /// Most recently stored token, if any. Held in an `Arc`, so clones of the
    /// provider share the same cache.
    cache: Arc<Mutex<Option<CachedToken>>>,
    /// HTTP client used to call the token endpoint.
    client: Client,
    /// Safety margin subtracted from a token's lifetime when it is cached
    /// (60 seconds unless changed via `with_leeway`).
    leeway: Duration,
    /// Per-request timeout for the token endpoint call (10 seconds unless
    /// changed via `with_http_timeout`).
    http_timeout: Duration,
}

impl ServiceAccountTokenProvider {
    /// Builds a provider from a service account key file at `path`.
    ///
    /// # Errors
    ///
    /// Returns `Error::TokenProvider` if the file cannot be read, or for any
    /// reason [`Self::from_json`] fails.
    pub fn from_file(path: impl AsRef<Path>, scopes: Vec<String>) -> Result<Self> {
        let data = fs::read_to_string(path).map_err(|err| Error::TokenProvider(err.to_string()))?;
        Self::from_json(&data, scopes)
    }

    /// Builds a provider from the JSON text of a service account key.
    ///
    /// The provider starts with an empty cache, a 60 second leeway and a
    /// 10 second HTTP timeout. The private key is not parsed here; it is
    /// only loaded when an assertion is signed.
    ///
    /// # Errors
    ///
    /// Returns `Error::TokenProvider` if `data` does not deserialize into the
    /// expected key fields or the HTTP client cannot be built.
    pub fn from_json(data: &str, scopes: Vec<String>) -> Result<Self> {
        let key: ServiceAccountKey = serde_json::from_str(data).map_err(|err| {
            Error::TokenProvider(format!("failed to parse service account key: {err}"))
        })?;
        let client = Client::builder()
            .build()
            .map_err(|err| Error::TokenProvider(format!("failed to build HTTP client: {err}")))?;
        Ok(Self {
            key,
            scopes,
            cache: Arc::new(Mutex::new(None)),
            client,
            leeway: Duration::from_secs(60),
            http_timeout: Duration::from_secs(10),
        })
    }

    /// Sets the safety margin subtracted from a token's lifetime when caching.
    pub fn with_leeway(mut self, leeway: Duration) -> Self {
        self.leeway = leeway;
        self
    }

    /// Sets the timeout applied to each token endpoint request.
    pub fn with_http_timeout(mut self, timeout: Duration) -> Self {
        self.http_timeout = timeout;
        self
    }

    /// Returns a copy of the cached token if its deadline has not passed yet.
    async fn cached_token(&self) -> Option<String> {
        let cache = self.cache.lock().await;
        cache
            .as_ref()
            .filter(|cached| Instant::now() < cached.expires_at)
            .map(|cached| cached.token.clone())
    }

    /// Caches `token`, treating it as usable for `expires_in` seconds minus
    /// the configured leeway.
    ///
    /// A negative `expires_in` counts as zero. When the leeway is at least as
    /// long as the lifetime, the deadline is "now", so the entry is never
    /// served from the cache.
    async fn store_token(&self, token: String, expires_in: i64) {
        // `max(0)` guarantees the value is non-negative, so the cast is lossless.
        let valid_for = Duration::from_secs(expires_in.max(0) as u64);
        // Apply the leeway to the duration instead of to the instant, so the
        // deadline can never land before `now`.
        let usable_for = valid_for.saturating_sub(self.leeway);
        let now = Instant::now();
        // `expires_in` comes from the token endpoint. `Instant + Duration`
        // panics when the sum is not representable, so use the checked form
        // and treat an absurdly large lifetime as not cacheable.
        let expires_at = now.checked_add(usable_for).unwrap_or(now);
        let mut cache = self.cache.lock().await;
        *cache = Some(CachedToken { token, expires_at });
    }

    /// Creates the RS256-signed JWT assertion posted to the token endpoint.
    ///
    /// The assertion is issued by the key's `client_email`, addressed to its
    /// `token_uri`, carries the space-joined scopes, and expires 3600 seconds
    /// after the current UTC time.
    ///
    /// # Errors
    ///
    /// Returns `Error::TokenProvider` if the private key cannot be loaded as
    /// an RSA PEM key or the token cannot be encoded.
    fn create_jwt(&self) -> Result<String> {
        #[derive(Serialize)]
        struct Claims<'a> {
            iss: &'a str,
            scope: String,
            aud: &'a str,
            exp: i64,
            iat: i64,
        }

        let now = Utc::now();
        let exp = now + ChronoDuration::seconds(3600);
        let claims = Claims {
            iss: &self.key.client_email,
            scope: self.scopes.join(" "),
            aud: &self.key.token_uri,
            exp: exp.timestamp(),
            iat: now.timestamp(),
        };

        let header = Header::new(Algorithm::RS256);
        encode(
            &header,
            &claims,
            &EncodingKey::from_rsa_pem(self.key.private_key.as_bytes())
                .map_err(|err| Error::TokenProvider(err.to_string()))?,
        )
        .map_err(|err| Error::TokenProvider(err.to_string()))
    }
}

/// The fields read from a successful token endpoint response body.
///
/// NOTE(review): the derived `Debug` prints `access_token` verbatim; avoid
/// debug-formatting this value in logs.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// The issued access token.
    access_token: String,
    /// Token lifetime; passed to `store_token`, which interprets it as seconds.
    expires_in: i64,
}

/// Form body sent to the token endpoint.
#[derive(Serialize)]
struct TokenRequest<'a> {
    /// Grant type identifier sent with the request.
    grant_type: &'a str,
    /// The signed JWT assertion produced by `create_jwt`.
    assertion: &'a str,
}

#[async_trait]
impl TokenProvider for ServiceAccountTokenProvider {
    async fn access_token(&self) -> Result<String> {
        if let Some(token) = self.cached_token().await {
            return Ok(token);
        }

        let assertion = self.create_jwt()?;
        let body = TokenRequest {
            grant_type: "urn:ietf:params:oauth:grant-type:jwt-bearer",
            assertion: &assertion,
        };

        let response = self
            .client
            .post(&self.key.token_uri)
            .timeout(self.http_timeout)
            .form(&body)
            .send()
            .await
            .map_err(|err| Error::TokenProvider(err.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "<failed to read body>".to_string());
            return Err(Error::TokenProvider(format!(
                "token endpoint error {}: {}",
                status, text
            )));
        }

        let token_response: TokenResponse = response
            .json()
            .await
            .map_err(|err| Error::TokenProvider(format!("invalid token response: {err}")))?;

        self.store_token(
            token_response.access_token.clone(),
            token_response.expires_in,
        )
        .await;
        Ok(token_response.access_token)
    }

    async fn refresh_token(&self) -> Result<String> {
        {
            let mut cache = self.cache.lock().await;
            *cache = None;
        }
        self.access_token().await
    }
}

/// Unit tests for the cache handling of the service account provider and for
/// the static and environment-variable providers.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal key document. The private key is a placeholder: `from_json`
    // only deserializes the fields, so it is accepted there; none of these
    // tests sign an assertion.
    const TEST_KEY: &str = r#"{
        "client_email": "test@example.com",
        "private_key": "-----BEGIN PRIVATE KEY-----\nTEST\n-----END PRIVATE KEY-----\n",
        "token_uri": "https://example.com"
    }"#;

    /// Builds a provider from `TEST_KEY` with a single dummy scope.
    fn provider() -> ServiceAccountTokenProvider {
        ServiceAccountTokenProvider::from_json(TEST_KEY, vec!["scope".to_string()]).unwrap()
    }

    // A 60 second token stored with a 90 second leeway must not be served
    // from the cache.
    #[tokio::test]
    async fn service_account_cache_respects_leeway() {
        let provider = provider().with_leeway(Duration::from_secs(90));
        provider.store_token("token".to_string(), 60).await;
        assert!(provider.cached_token().await.is_none());
    }

    // A 120 second token with the default 60 second leeway is still usable
    // immediately after being stored.
    #[tokio::test]
    async fn service_account_cache_keeps_valid_token() {
        let provider = provider();
        provider.store_token("token".to_string(), 120).await;
        assert_eq!(provider.cached_token().await, Some("token".to_string()));
    }

    // The static provider returns exactly the token it was built with.
    #[tokio::test]
    async fn static_token_provider_returns_token() {
        let provider = StaticTokenProvider::new("test-token-123");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "test-token-123");
    }

    // The variable is set for the duration of the test and removed afterwards;
    // the removal is skipped if an assertion above it panics.
    #[tokio::test]
    async fn env_token_provider_reads_from_env() {
        std::env::set_var("TEST_NBLM_TOKEN", "env-token-456");
        let provider = EnvTokenProvider::new("TEST_NBLM_TOKEN");
        let token = provider.access_token().await.unwrap();
        assert_eq!(token, "env-token-456");
        std::env::remove_var("TEST_NBLM_TOKEN");
    }

    // An unset variable yields an error whose message names the variable.
    #[tokio::test]
    async fn env_token_provider_errors_when_missing() {
        // Make sure the variable is really absent before asking for it.
        std::env::remove_var("NONEXISTENT_TOKEN");
        let provider = EnvTokenProvider::new("NONEXISTENT_TOKEN");
        let result = provider.access_token().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("environment variable NONEXISTENT_TOKEN missing"));
    }
}