nblm-core 0.2.2

Core library for NotebookLM Enterprise API client
Documentation
use std::borrow::Cow;
use std::sync::{Arc, OnceLock};

use bytes::Bytes;
use reqwest::{header::HeaderMap, Client, Method, RequestBuilder, StatusCode, Url};
use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::auth::TokenProvider;
use crate::error::{Error, Result};

use super::retry::Retryer;

/// HTTP layer implementation for NBLM API requests
#[derive(Clone)]
pub(crate) struct HttpClient {
    pub(super) client: Client,
    pub(super) token_provider: Arc<dyn TokenProvider>,
    pub(super) retryer: Retryer,
    pub(super) user_project: Option<String>,
}

impl HttpClient {
    pub fn new(
        client: Client,
        token_provider: Arc<dyn TokenProvider>,
        retryer: Retryer,
        user_project: Option<String>,
    ) -> Self {
        Self {
            client,
            token_provider,
            retryer,
            user_project,
        }
    }

    pub async fn request_json<B, R>(&self, method: Method, url: Url, body: Option<&B>) -> Result<R>
    where
        B: Serialize + ?Sized,
        R: DeserializeOwned,
    {
        let body_bytes = match body {
            Some(value) => Some(Bytes::from(serde_json::to_vec(value).map_err(Error::Json)?)),
            None => None,
        };
        let body_bytes = Arc::new(body_bytes);
        let builder_fn = {
            let body_bytes = Arc::clone(&body_bytes);
            move |mut builder: RequestBuilder| -> Result<RequestBuilder> {
                if let Some(bytes) = &*body_bytes {
                    builder = builder
                        .header(reqwest::header::CONTENT_TYPE, "application/json")
                        .body(bytes.clone());
                }
                Ok(builder)
            }
        };

        let method_for_parse = method.clone();
        let url_for_parse = url.clone();
        let response = self.execute_with_builder(method, url, builder_fn).await?;

        parse_json_response(&method_for_parse, &url_for_parse, response).await
    }

    pub async fn request_binary<R>(
        &self,
        method: Method,
        url: Url,
        headers: HeaderMap,
        body: Bytes,
    ) -> Result<R>
    where
        R: DeserializeOwned,
    {
        let header_entries: Vec<(reqwest::header::HeaderName, reqwest::header::HeaderValue)> =
            headers
                .iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect();
        let header_entries = Arc::new(header_entries);
        let body = body;
        let builder_fn = {
            let header_entries = Arc::clone(&header_entries);
            let body = body.clone();
            move |mut builder: RequestBuilder| -> Result<RequestBuilder> {
                for (key, value) in header_entries.iter() {
                    builder = builder.header(key.clone(), value.clone());
                }
                builder = builder.body(body.clone());
                Ok(builder)
            }
        };

        let method_for_parse = method.clone();
        let url_for_parse = url.clone();
        let response = self.execute_with_builder(method, url, builder_fn).await?;

        parse_json_response(&method_for_parse, &url_for_parse, response).await
    }
}

const MAX_BODY_PREVIEW: usize = 2048;

fn debug_http_enabled() -> bool {
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| match std::env::var("NBLM_DEBUG_HTTP") {
        Ok(value) => matches!(value.to_ascii_lowercase().as_str(), "1" | "true" | "yes"),
        Err(_) => false,
    })
}

fn build_body_preview(body: &[u8]) -> Cow<'_, str> {
    match std::str::from_utf8(body) {
        Ok(text) => {
            if text.len() > MAX_BODY_PREVIEW {
                let mut preview = text[..MAX_BODY_PREVIEW].to_string();
                preview.push('');
                Cow::Owned(preview)
            } else {
                Cow::Borrowed(text)
            }
        }
        Err(_) => Cow::Owned(format!("<non-utf8 body: {} bytes>", body.len())),
    }
}

fn log_http_response(method: &Method, url: &Url, status: StatusCode, body: &[u8]) {
    if !debug_http_enabled() {
        return;
    }

    let preview = build_body_preview(body);
    eprintln!(
        "[nblm::http] method={} status={} url={} body_len={} body={}",
        method,
        status.as_u16(),
        url,
        body.len(),
        preview
    );
}

async fn parse_json_response<R>(
    method: &Method,
    url: &Url,
    response: reqwest::Response,
) -> Result<R>
where
    R: DeserializeOwned,
{
    let status = response.status();
    let body = response.bytes().await.map_err(Error::Request)?;
    log_http_response(method, url, status, &body);

    if !status.is_success() {
        let text = String::from_utf8_lossy(&body).into_owned();
        return Err(Error::http(status, text));
    }

    let parsed = serde_json::from_slice::<R>(&body)?;
    Ok(parsed)
}

impl HttpClient {
    async fn execute_with_builder<F>(
        &self,
        method: Method,
        url: Url,
        builder_fn: F,
    ) -> Result<reqwest::Response>
    where
        F: Fn(RequestBuilder) -> Result<RequestBuilder> + Send + Sync + 'static,
    {
        let client = self.client.clone();
        let method_clone = method.clone();
        let url_clone = url.clone();
        let provider = Arc::clone(&self.token_provider);
        let user_project = self.user_project.clone();
        let builder_fn = Arc::new(builder_fn);

        let run = {
            let client = client.clone();
            let method = method_clone.clone();
            let url = url_clone.clone();
            let provider = Arc::clone(&provider);
            let user_project = user_project.clone();
            let builder_fn = Arc::clone(&builder_fn);
            move || {
                let client = client.clone();
                let method = method.clone();
                let url = url.clone();
                let provider = Arc::clone(&provider);
                let user_project = user_project.clone();
                let builder_fn = Arc::clone(&builder_fn);
                async move {
                    let token = provider.access_token().await?;
                    let mut builder = client.request(method, url).bearer_auth(token);
                    if let Some(project) = &user_project {
                        builder = builder.header("x-goog-user-project", project);
                    }
                    builder = builder_fn(builder)?;
                    let request = builder.build().map_err(Error::Request)?;
                    let response = client.execute(request).await.map_err(Error::Request)?;
                    Ok(response)
                }
            }
        };

        let mut response = self.retryer.run_with_retry(run).await?;

        if response.status() == StatusCode::UNAUTHORIZED {
            let status = response.status();
            let body = response.bytes().await.map_err(Error::Request)?;
            log_http_response(&method, &url, status, &body);
            let run_refresh = {
                let client = client.clone();
                let method = method_clone.clone();
                let url = url_clone.clone();
                let provider = Arc::clone(&provider);
                let user_project = user_project.clone();
                let builder_fn = Arc::clone(&builder_fn);
                move || {
                    let client = client.clone();
                    let method = method.clone();
                    let url = url.clone();
                    let provider = Arc::clone(&provider);
                    let user_project = user_project.clone();
                    let builder_fn = Arc::clone(&builder_fn);
                    async move {
                        let token = provider.refresh_token().await?;
                        let mut builder = client.request(method, url).bearer_auth(token);
                        if let Some(project) = &user_project {
                            builder = builder.header("x-goog-user-project", project);
                        }
                        builder = builder_fn(builder)?;
                        let request = builder.build().map_err(Error::Request)?;
                        let response = client.execute(request).await.map_err(Error::Request)?;
                        Ok(response)
                    }
                }
            };
            response = self.retryer.run_with_retry(run_refresh).await?;
        }

        Ok(response)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_body_preview_returns_borrowed_for_short_utf8() {
        let input = b"short body";
        let preview = build_body_preview(input);
        assert!(matches!(preview, Cow::Borrowed("short body")));
    }

    #[test]
    fn build_body_preview_truncates_long_utf8() {
        let long_text = "x".repeat(MAX_BODY_PREVIEW + 10);
        let preview = build_body_preview(long_text.as_bytes());
        let expected = format!("{}", "x".repeat(MAX_BODY_PREVIEW));
        match preview {
            Cow::Owned(truncated) => assert_eq!(truncated, expected),
            _ => panic!("expected owned truncated preview"),
        }
    }

    #[test]
    fn build_body_preview_handles_non_utf8() {
        let binary = [0xffu8, 0x00, 0xfe];
        let preview = build_body_preview(&binary);
        let expected = format!("<non-utf8 body: {} bytes>", binary.len());
        match preview {
            Cow::Owned(msg) => assert_eq!(msg, expected),
            _ => panic!("expected owned message for non-utf8 body"),
        }
    }
}