nblm_core/client/api/mod.rs

1pub(crate) mod backends;
2
3use crate::client::NblmClient;
4use crate::error::Result;
5use crate::models::enterprise::{
6    audio::{AudioOverviewRequest, AudioOverviewResponse},
7    notebook::{
8        BatchDeleteNotebooksRequest, BatchDeleteNotebooksResponse, ListRecentlyViewedResponse,
9        Notebook,
10    },
11    source::{
12        BatchCreateSourcesRequest, BatchCreateSourcesResponse, BatchDeleteSourcesRequest,
13        BatchDeleteSourcesResponse, NotebookSource, UploadSourceFileResponse, UserContent,
14    },
15};
16
17impl NblmClient {
18    pub async fn create_notebook(&self, title: impl Into<String>) -> Result<Notebook> {
19        self.backends
20            .notebooks()
21            .create_notebook(title.into())
22            .await
23    }
24
25    pub async fn batch_delete_notebooks(
26        &self,
27        request: BatchDeleteNotebooksRequest,
28    ) -> Result<BatchDeleteNotebooksResponse> {
29        self.backends
30            .notebooks()
31            .batch_delete_notebooks(request)
32            .await
33    }
34
35    pub async fn delete_notebooks(
36        &self,
37        notebook_names: Vec<String>,
38    ) -> Result<BatchDeleteNotebooksResponse> {
39        self.backends
40            .notebooks()
41            .delete_notebooks(notebook_names)
42            .await
43    }
44
45    pub async fn list_recently_viewed(
46        &self,
47        page_size: Option<u32>,
48    ) -> Result<ListRecentlyViewedResponse> {
49        self.backends
50            .notebooks()
51            .list_recently_viewed(page_size)
52            .await
53    }
54
55    pub async fn batch_create_sources(
56        &self,
57        notebook_id: &str,
58        request: BatchCreateSourcesRequest,
59    ) -> Result<BatchCreateSourcesResponse> {
60        let includes_drive = has_drive_content(request.user_contents.iter());
61        self.ensure_drive_scope_if_needed(includes_drive).await?;
62        self.backends
63            .sources()
64            .batch_create_sources(notebook_id, request)
65            .await
66    }
67
68    pub async fn add_sources(
69        &self,
70        notebook_id: &str,
71        contents: Vec<UserContent>,
72    ) -> Result<BatchCreateSourcesResponse> {
73        let includes_drive = has_drive_content(contents.iter());
74        self.ensure_drive_scope_if_needed(includes_drive).await?;
75        self.backends
76            .sources()
77            .add_sources(notebook_id, contents)
78            .await
79    }
80
81    pub async fn batch_delete_sources(
82        &self,
83        notebook_id: &str,
84        request: BatchDeleteSourcesRequest,
85    ) -> Result<BatchDeleteSourcesResponse> {
86        self.backends
87            .sources()
88            .batch_delete_sources(notebook_id, request)
89            .await
90    }
91
92    pub async fn delete_sources(
93        &self,
94        notebook_id: &str,
95        source_names: Vec<String>,
96    ) -> Result<BatchDeleteSourcesResponse> {
97        self.backends
98            .sources()
99            .delete_sources(notebook_id, source_names)
100            .await
101    }
102
103    pub async fn upload_source_file(
104        &self,
105        notebook_id: &str,
106        file_name: &str,
107        content_type: &str,
108        data: Vec<u8>,
109    ) -> Result<UploadSourceFileResponse> {
110        self.backends
111            .sources()
112            .upload_source_file(notebook_id, file_name, content_type, data)
113            .await
114    }
115
116    pub async fn get_source(&self, notebook_id: &str, source_id: &str) -> Result<NotebookSource> {
117        self.backends
118            .sources()
119            .get_source(notebook_id, source_id)
120            .await
121    }
122
123    pub async fn create_audio_overview(
124        &self,
125        notebook_id: &str,
126        request: AudioOverviewRequest,
127    ) -> Result<AudioOverviewResponse> {
128        self.backends
129            .audio()
130            .create_audio_overview(notebook_id, request)
131            .await
132    }
133
134    pub async fn delete_audio_overview(&self, notebook_id: &str) -> Result<()> {
135        self.backends
136            .audio()
137            .delete_audio_overview(notebook_id)
138            .await
139    }
140}
141
142fn has_drive_content<'a, I>(contents: I) -> bool
143where
144    I: IntoIterator<Item = &'a UserContent>,
145{
146    contents
147        .into_iter()
148        .any(|content| matches!(content, UserContent::GoogleDrive { .. }))
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::StaticTokenProvider;
    use crate::env::EnvironmentConfig;
    use crate::error::Error;
    use serde_json::json;
    use serial_test::serial;
    use std::sync::Arc;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Environment variable these tests set to redirect the tokeninfo lookup to the mock
    /// server.
    const TOKENINFO_ENV: &str = "NBLM_TOKENINFO_ENDPOINT";

    /// `sources:batchCreate` path for notebook `notebook-id` under the project/location
    /// configured by [`build_client`].
    const BATCH_CREATE_PATH: &str =
        "/v1alpha/projects/123/locations/global/notebooks/notebook-id/sources:batchCreate";

    /// Snapshot of an environment variable that is restored (or removed) on drop.
    ///
    /// Environment variables are process-global, so every test using this guard is also
    /// marked `#[serial]`.
    struct EnvGuard {
        key: &'static str,
        original: Option<String>,
    }

    impl EnvGuard {
        /// Records the current value of `key` without modifying it.
        fn new(key: &'static str) -> Self {
            let original = std::env::var(key).ok();
            Self { key, original }
        }

        /// Records the current value of `key`, then sets it to `value`.
        ///
        /// The snapshot is taken before mutating, so dropping the guard (including during
        /// a panic unwind) restores the original value.
        fn set(key: &'static str, value: &str) -> Self {
            let guard = Self::new(key);
            std::env::set_var(key, value);
            guard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            if let Some(value) = &self.original {
                std::env::set_var(self.key, value);
            } else {
                std::env::remove_var(self.key);
            }
        }
    }

    /// Builds an enterprise client authenticated with the static token `test-token` and
    /// pointed at `base_url`.
    async fn build_client(base_url: &str) -> NblmClient {
        let provider = Arc::new(StaticTokenProvider::new("test-token"));
        let env = EnvironmentConfig::enterprise("123", "global", "us").unwrap();
        NblmClient::new(provider, env)
            .unwrap()
            .with_base_url(base_url)
            .unwrap()
    }

    /// Starts a mock server, redirects the tokeninfo lookup to it, and mounts a tokeninfo
    /// mock that reports `scope` for `test-token` and must be queried exactly once.
    ///
    /// Keep the returned guard alive for the whole test so the override stays in place.
    async fn start_server_with_scope(scope: &str) -> (MockServer, EnvGuard) {
        let server = MockServer::start().await;
        let guard = EnvGuard::set(TOKENINFO_ENV, &format!("{}/tokeninfo", server.uri()));

        Mock::given(method("GET"))
            .and(path("/tokeninfo"))
            .and(query_param("access_token", "test-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "scope": scope
            })))
            .expect(1)
            .mount(&server)
            .await;

        (server, guard)
    }

    /// Mounts an empty-success `sources:batchCreate` mock that must be hit exactly
    /// `expected_calls` times (0 asserts the request is never sent).
    async fn mount_batch_create(server: &MockServer, expected_calls: u64) {
        Mock::given(method("POST"))
            .and(path(BATCH_CREATE_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "sources": [],
                "errorCount": 0
            })))
            .expect(expected_calls)
            .mount(server)
            .await;
    }

    /// A single Google Drive content item, so the Drive scope check is requested.
    fn drive_contents() -> Vec<UserContent> {
        vec![UserContent::google_drive(
            "doc".to_string(),
            "application/pdf".to_string(),
            None,
        )]
    }

    /// Asserts that `err` is the token-provider error reported for a missing Drive scope.
    fn assert_missing_drive_scope(err: Error) {
        match err {
            Error::TokenProvider(message) => {
                assert!(
                    message.contains("drive.file"),
                    "unexpected message: {message}"
                );
            }
            other => panic!("expected TokenProvider error, got {other:?}"),
        }
    }

    #[rstest::rstest]
    #[case::with_drive_scope("https://www.googleapis.com/auth/drive.file", true, 1)]
    #[case::without_drive_scope("https://www.googleapis.com/auth/cloud-platform", false, 0)]
    #[tokio::test]
    #[serial]
    async fn add_sources_validates_drive_scope(
        #[case] scope: &str,
        #[case] should_succeed: bool,
        #[case] api_call_count: u64,
    ) {
        let (server, _guard) = start_server_with_scope(scope).await;
        mount_batch_create(&server, api_call_count).await;
        let client = build_client(&format!("{}/v1alpha", server.uri())).await;

        let result = client.add_sources("notebook-id", drive_contents()).await;

        if should_succeed {
            assert!(
                result.is_ok(),
                "expected add_sources to succeed: {:?}",
                result
            );
        } else {
            assert_missing_drive_scope(
                result.expect_err("expected add_sources to fail when drive scope is missing"),
            );
        }
    }

    // Mirrors `add_sources_validates_drive_scope`: both entry points share the same scope
    // check, so the missing-scope path must also short-circuit before `batchCreate`.
    #[rstest::rstest]
    #[case::with_full_drive_scope("https://www.googleapis.com/auth/drive", true, 1)]
    #[case::without_drive_scope("https://www.googleapis.com/auth/cloud-platform", false, 0)]
    #[tokio::test]
    #[serial]
    async fn batch_create_sources_validates_drive_scope(
        #[case] scope: &str,
        #[case] should_succeed: bool,
        #[case] api_call_count: u64,
    ) {
        let (server, _guard) = start_server_with_scope(scope).await;
        mount_batch_create(&server, api_call_count).await;
        let client = build_client(&format!("{}/v1alpha", server.uri())).await;

        let request = BatchCreateSourcesRequest {
            user_contents: drive_contents(),
        };
        let result = client.batch_create_sources("notebook-id", request).await;

        if should_succeed {
            let response = result.expect("expected batch_create_sources to succeed");
            assert!(response.sources.is_empty());
        } else {
            assert_missing_drive_scope(result.expect_err(
                "expected batch_create_sources to fail when drive scope is missing",
            ));
        }
    }
}