nblm_core/
client.rs

1use std::{sync::Arc, time::Duration};
2
3use reqwest::{Client, Method, StatusCode, Url};
4use serde::de::DeserializeOwned;
5use serde::Serialize;
6
7use crate::auth::TokenProvider;
8use crate::error::{Error, Result};
9use crate::models::{
10    AccountRole, AudioOverviewRequest, AudioOverviewResponse, BatchCreateSourcesRequest,
11    BatchCreateSourcesResponse, BatchDeleteNotebooksRequest, BatchDeleteNotebooksResponse,
12    BatchDeleteSourcesRequest, BatchDeleteSourcesResponse, CreateNotebookRequest,
13    ListRecentlyViewedResponse, Notebook, ShareRequest, ShareResponse, UserContent,
14};
15use crate::retry::{RetryConfig, Retryer};
16
/// Timeout used both as the HTTP client's default and as each request's own timeout
/// until overridden with `NblmClient::with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Smallest `pageSize` value ever sent; smaller requests are clamped up to it.
const PAGE_SIZE_MIN: u32 = 1;
/// Largest `pageSize` value ever sent; larger requests are clamped down to it.
const PAGE_SIZE_MAX: u32 = 500;
20
21pub struct NblmClient {
22    http: Client,
23    token_provider: Arc<dyn TokenProvider>,
24    base: String,
25    parent: String,
26    timeout: Duration,
27    retryer: Retryer,
28    user_project: Option<String>,
29}
30
31impl NblmClient {
32    pub fn new(
33        token_provider: Arc<dyn TokenProvider>,
34        project_number: impl Into<String>,
35        location: impl Into<String>,
36        endpoint_location: impl Into<String>,
37    ) -> Result<Self> {
38        let project_number = project_number.into();
39        let location = location.into();
40        let endpoint_location = endpoint_location.into();
41        let base = format!(
42            "https://{}discoveryengine.googleapis.com/v1alpha",
43            normalize_endpoint_location(endpoint_location)?
44        );
45        let parent = format!("projects/{}/locations/{}", project_number, location);
46
47        let http = Client::builder()
48            .user_agent(concat!("nblm-cli/", env!("CARGO_PKG_VERSION")))
49            .timeout(DEFAULT_TIMEOUT)
50            .build()
51            .map_err(Error::from)?;
52
53        Ok(Self {
54            http,
55            token_provider,
56            base: base.trim_end_matches('/').to_string(),
57            parent,
58            timeout: DEFAULT_TIMEOUT,
59            retryer: Retryer::new(RetryConfig::default()),
60            user_project: None,
61        })
62    }
63
64    pub fn with_timeout(mut self, timeout: Duration) -> Self {
65        self.timeout = timeout;
66        self
67    }
68
69    pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
70        self.retryer = Retryer::new(config);
71        self
72    }
73
74    pub fn with_user_project(mut self, project: impl Into<String>) -> Self {
75        self.user_project = Some(project.into());
76        self
77    }
78
79    /// Override API base URL (for tests). Accepts absolute URL. Trims trailing slash.
80    pub fn with_base_url(mut self, base: impl Into<String>) -> Result<Self> {
81        let base = base.into().trim().trim_end_matches('/').to_string();
82        // Basic sanity check: absolute URL
83        let _ = Url::parse(&base).map_err(Error::from)?;
84        self.base = base;
85        Ok(self)
86    }
87
88    fn notebooks_collection(&self) -> String {
89        format!("{}/notebooks", self.parent)
90    }
91
92    fn notebook_path(&self, notebook_id: &str) -> String {
93        format!("{}/notebooks/{}", self.parent, notebook_id)
94    }
95
96    fn build_url(&self, path: &str) -> Result<Url> {
97        let path = path.trim_start_matches('/');
98        Url::parse(&format!("{}/{}", self.base, path)).map_err(Error::from)
99    }
100
101    async fn request_json<B, R>(&self, method: Method, url: Url, body: Option<&B>) -> Result<R>
102    where
103        B: Serialize + ?Sized,
104        R: DeserializeOwned,
105    {
106        let client = self.http.clone();
107        let method_clone = method.clone();
108        let url_clone = url.clone();
109        let timeout = self.timeout;
110        let body_ref = body;
111        let provider = Arc::clone(&self.token_provider);
112        let user_project = self.user_project.clone();
113
114        let run = || {
115            let client = client.clone();
116            let method = method_clone.clone();
117            let url = url_clone.clone();
118            let provider = Arc::clone(&provider);
119            let user_project = user_project.clone();
120            async move {
121                let token = provider.access_token().await?;
122                let mut builder = client
123                    .request(method, url)
124                    .bearer_auth(token)
125                    .timeout(timeout);
126                if let Some(project) = &user_project {
127                    builder = builder.header("x-goog-user-project", project);
128                }
129                if let Some(body) = body_ref {
130                    builder = builder.json(body);
131                }
132                let request = builder.build().map_err(Error::Request)?;
133                let response = client.execute(request).await.map_err(Error::Request)?;
134                Ok(response)
135            }
136        };
137
138        let mut response = self.retryer.run_with_retry(run).await?;
139
140        if response.status() == StatusCode::UNAUTHORIZED {
141            let _ = response.bytes().await;
142            let run_refresh = || {
143                let client = client.clone();
144                let method = method_clone.clone();
145                let url = url_clone.clone();
146                let provider = Arc::clone(&provider);
147                let user_project = user_project.clone();
148                async move {
149                    let token = provider.refresh_token().await?;
150                    let mut builder = client
151                        .request(method, url)
152                        .bearer_auth(token)
153                        .timeout(timeout);
154                    if let Some(project) = &user_project {
155                        builder = builder.header("x-goog-user-project", project);
156                    }
157                    if let Some(body) = body_ref {
158                        builder = builder.json(body);
159                    }
160                    let request = builder.build().map_err(Error::Request)?;
161                    let response = client.execute(request).await.map_err(Error::Request)?;
162                    Ok(response)
163                }
164            };
165            response = self.retryer.run_with_retry(run_refresh).await?;
166            if !response.status().is_success() {
167                let status = response.status();
168                let body = response.text().await.unwrap_or_default();
169                return Err(Error::http(status, body));
170            }
171            return Ok(response.json::<R>().await?);
172        }
173
174        if !response.status().is_success() {
175            let status = response.status();
176            let body = response.text().await.unwrap_or_default();
177            return Err(Error::http(status, body));
178        }
179
180        Ok(response.json::<R>().await?)
181    }
182
183    pub async fn create_notebook(&self, title: impl Into<String>) -> Result<Notebook> {
184        let url = self.build_url(&self.notebooks_collection())?;
185        let request = CreateNotebookRequest {
186            title: title.into(),
187        };
188        self.request_json(Method::POST, url, Some(&request)).await
189    }
190
191    /// Delete notebooks using the batchDelete API.
192    ///
193    /// # Known Issues (as of 2025-10-19)
194    ///
195    /// The API only accepts a single notebook name despite being named "batchDelete".
196    /// Multiple names result in HTTP 400 error. Use `delete_notebooks` which handles
197    /// this limitation by calling the API once per notebook.
198    pub async fn batch_delete_notebooks(
199        &self,
200        request: BatchDeleteNotebooksRequest,
201    ) -> Result<BatchDeleteNotebooksResponse> {
202        let path = format!("{}:batchDelete", self.notebooks_collection());
203        let url = self.build_url(&path)?;
204        self.request_json(Method::POST, url, Some(&request)).await
205    }
206
207    /// Delete one or more notebooks.
208    ///
209    /// # Implementation Note
210    ///
211    /// Despite the underlying API being named "batchDelete", it only accepts one notebook
212    /// at a time (as of 2025-10-19). This method works around this limitation by calling
213    /// the API sequentially for each notebook.
214    pub async fn delete_notebooks(
215        &self,
216        notebook_names: Vec<String>,
217    ) -> Result<BatchDeleteNotebooksResponse> {
218        // TODO: Remove sequential processing when API supports true batch deletion
219        for name in &notebook_names {
220            let request = BatchDeleteNotebooksRequest {
221                names: vec![name.clone()],
222            };
223            self.batch_delete_notebooks(request).await?;
224        }
225        // Return empty response after all deletions succeed
226        Ok(BatchDeleteNotebooksResponse::default())
227    }
228
229    pub async fn batch_create_sources(
230        &self,
231        notebook_id: &str,
232        request: BatchCreateSourcesRequest,
233    ) -> Result<BatchCreateSourcesResponse> {
234        let path = format!("{}/sources:batchCreate", self.notebook_path(notebook_id));
235        let url = self.build_url(&path)?;
236        self.request_json(Method::POST, url, Some(&request)).await
237    }
238
239    // TODO: This method has not been tested due to the requirement of setting up additional user accounts.
240    pub async fn share_notebook(
241        &self,
242        notebook_id: &str,
243        accounts: Vec<AccountRole>,
244    ) -> Result<ShareResponse> {
245        let path = format!("{}:share", self.notebook_path(notebook_id));
246        let url = self.build_url(&path)?;
247        let request = ShareRequest {
248            account_and_roles: accounts,
249        };
250        self.request_json(Method::POST, url, Some(&request)).await
251    }
252
253    pub async fn create_audio_overview(
254        &self,
255        notebook_id: &str,
256        request: AudioOverviewRequest,
257    ) -> Result<AudioOverviewResponse> {
258        let path = format!("{}/audioOverviews", self.notebook_path(notebook_id));
259        let url = self.build_url(&path)?;
260        self.request_json(Method::POST, url, Some(&request)).await
261    }
262
263    pub async fn delete_audio_overview(&self, notebook_id: &str) -> Result<()> {
264        let path = format!("{}/audioOverviews/default", self.notebook_path(notebook_id));
265        let url = self.build_url(&path)?;
266        let _response: serde_json::Value =
267            self.request_json(Method::DELETE, url, None::<&()>).await?;
268        Ok(())
269    }
270
271    /// List recently viewed notebooks.
272    ///
273    /// # Pagination (Not Implemented by API)
274    ///
275    /// While this method accepts `page_size` and `page_token` parameters,
276    /// the NotebookLM API does not currently implement pagination:
277    /// - `page_size` parameter is accepted but ignored by the API
278    /// - `next_page_token` is never returned in the response
279    /// - All available notebooks are returned regardless of page_size
280    ///
281    /// These parameters are included for future compatibility if the API
282    /// implements pagination in the future.
283    pub async fn list_recently_viewed(
284        &self,
285        page_size: Option<u32>,
286        page_token: Option<&str>,
287    ) -> Result<ListRecentlyViewedResponse> {
288        let path = format!("{}:listRecentlyViewed", self.notebooks_collection());
289        let mut url = self.build_url(&path)?;
290        {
291            let mut pairs = url.query_pairs_mut();
292            if let Some(size) = page_size {
293                let clamped = size.clamp(PAGE_SIZE_MIN, PAGE_SIZE_MAX);
294                pairs.append_pair("pageSize", &clamped.to_string());
295            }
296            if let Some(token) = page_token {
297                pairs.append_pair("pageToken", token);
298            }
299        }
300        self.request_json::<(), _>(Method::GET, url, None::<&()>)
301            .await
302    }
303
304    pub async fn add_sources(
305        &self,
306        notebook_id: &str,
307        contents: Vec<UserContent>,
308    ) -> Result<BatchCreateSourcesResponse> {
309        let request = BatchCreateSourcesRequest {
310            user_contents: contents,
311        };
312        self.batch_create_sources(notebook_id, request).await
313    }
314
315    pub async fn batch_delete_sources(
316        &self,
317        notebook_id: &str,
318        request: BatchDeleteSourcesRequest,
319    ) -> Result<BatchDeleteSourcesResponse> {
320        let path = format!("{}/sources:batchDelete", self.notebook_path(notebook_id));
321        let url = self.build_url(&path)?;
322        self.request_json(Method::POST, url, Some(&request)).await
323    }
324
325    pub async fn delete_sources(
326        &self,
327        notebook_id: &str,
328        source_names: Vec<String>,
329    ) -> Result<BatchDeleteSourcesResponse> {
330        let request = BatchDeleteSourcesRequest {
331            names: source_names,
332        };
333        self.batch_delete_sources(notebook_id, request).await
334    }
335}
336
337fn normalize_endpoint_location(input: String) -> Result<String> {
338    let trimmed = input.trim().trim_end_matches('-').to_lowercase();
339    let normalized = match trimmed.as_str() {
340        "us" => "us-",
341        "eu" => "eu-",
342        "global" => "global-",
343        other => {
344            return Err(Error::Endpoint(format!(
345                "unsupported endpoint location: {other}"
346            )))
347        }
348    };
349    Ok(normalized.to_string())
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Client with a static token and the default (`us`) endpoint, ready for overrides.
    fn test_client() -> NblmClient {
        let provider = Arc::new(crate::auth::StaticTokenProvider::new("test"));
        NblmClient::new(provider, "123", "global", "us").unwrap()
    }

    #[test]
    fn normalize_endpoint_location_variants() {
        let cases = [("us", "us-"), ("eu-", "eu-"), (" global ", "global-")];
        for (input, expected) in cases {
            assert_eq!(
                normalize_endpoint_location(input.into()).unwrap(),
                expected.to_string()
            );
        }
    }

    #[test]
    fn normalize_endpoint_location_invalid() {
        let message = normalize_endpoint_location("asia".into())
            .unwrap_err()
            .to_string();
        assert!(message.contains("unsupported endpoint location"));
    }

    #[test]
    fn with_base_url_accepts_absolute_url() {
        let outcome = test_client().with_base_url("http://localhost:8080/v1alpha");
        assert!(outcome.is_ok());
    }

    #[test]
    fn with_base_url_trims_trailing_slash() {
        let client = test_client()
            .with_base_url("http://example.com/v1alpha/")
            .unwrap();
        assert_eq!(client.base, "http://example.com/v1alpha");
    }

    #[test]
    fn with_base_url_rejects_relative_path() {
        let outcome = test_client().with_base_url("/relative/path");
        assert!(outcome.is_err());
    }

    #[test]
    fn build_url_combines_base_and_path_correctly() {
        let client = test_client()
            .with_base_url("http://example.com/v1alpha")
            .unwrap();

        // A leading slash on the path must not produce a double slash.
        for path in ["/projects/123/notebooks", "projects/123/notebooks"] {
            let url = client.build_url(path).unwrap();
            assert_eq!(
                url.as_str(),
                "http://example.com/v1alpha/projects/123/notebooks"
            );
        }
    }
}