1pub(crate) mod backends;
2
3use crate::client::NblmClient;
4use crate::error::Result;
5use crate::models::enterprise::{
6 audio::{AudioOverviewRequest, AudioOverviewResponse},
7 notebook::{
8 BatchDeleteNotebooksRequest, BatchDeleteNotebooksResponse, ListRecentlyViewedResponse,
9 Notebook,
10 },
11 source::{
12 BatchCreateSourcesRequest, BatchCreateSourcesResponse, BatchDeleteSourcesRequest,
13 BatchDeleteSourcesResponse, NotebookSource, UploadSourceFileResponse, UserContent,
14 },
15};
16
17impl NblmClient {
18 pub async fn create_notebook(&self, title: impl Into<String>) -> Result<Notebook> {
19 self.backends
20 .notebooks()
21 .create_notebook(title.into())
22 .await
23 }
24
25 pub async fn batch_delete_notebooks(
26 &self,
27 request: BatchDeleteNotebooksRequest,
28 ) -> Result<BatchDeleteNotebooksResponse> {
29 self.backends
30 .notebooks()
31 .batch_delete_notebooks(request)
32 .await
33 }
34
35 pub async fn delete_notebooks(
36 &self,
37 notebook_names: Vec<String>,
38 ) -> Result<BatchDeleteNotebooksResponse> {
39 self.backends
40 .notebooks()
41 .delete_notebooks(notebook_names)
42 .await
43 }
44
45 pub async fn list_recently_viewed(
46 &self,
47 page_size: Option<u32>,
48 ) -> Result<ListRecentlyViewedResponse> {
49 self.backends
50 .notebooks()
51 .list_recently_viewed(page_size)
52 .await
53 }
54
55 pub async fn batch_create_sources(
56 &self,
57 notebook_id: &str,
58 request: BatchCreateSourcesRequest,
59 ) -> Result<BatchCreateSourcesResponse> {
60 let includes_drive = has_drive_content(request.user_contents.iter());
61 self.ensure_drive_scope_if_needed(includes_drive).await?;
62 self.backends
63 .sources()
64 .batch_create_sources(notebook_id, request)
65 .await
66 }
67
68 pub async fn add_sources(
69 &self,
70 notebook_id: &str,
71 contents: Vec<UserContent>,
72 ) -> Result<BatchCreateSourcesResponse> {
73 let includes_drive = has_drive_content(contents.iter());
74 self.ensure_drive_scope_if_needed(includes_drive).await?;
75 self.backends
76 .sources()
77 .add_sources(notebook_id, contents)
78 .await
79 }
80
81 pub async fn batch_delete_sources(
82 &self,
83 notebook_id: &str,
84 request: BatchDeleteSourcesRequest,
85 ) -> Result<BatchDeleteSourcesResponse> {
86 self.backends
87 .sources()
88 .batch_delete_sources(notebook_id, request)
89 .await
90 }
91
92 pub async fn delete_sources(
93 &self,
94 notebook_id: &str,
95 source_names: Vec<String>,
96 ) -> Result<BatchDeleteSourcesResponse> {
97 self.backends
98 .sources()
99 .delete_sources(notebook_id, source_names)
100 .await
101 }
102
103 pub async fn upload_source_file(
104 &self,
105 notebook_id: &str,
106 file_name: &str,
107 content_type: &str,
108 data: Vec<u8>,
109 ) -> Result<UploadSourceFileResponse> {
110 self.backends
111 .sources()
112 .upload_source_file(notebook_id, file_name, content_type, data)
113 .await
114 }
115
116 pub async fn get_source(&self, notebook_id: &str, source_id: &str) -> Result<NotebookSource> {
117 self.backends
118 .sources()
119 .get_source(notebook_id, source_id)
120 .await
121 }
122
123 pub async fn create_audio_overview(
124 &self,
125 notebook_id: &str,
126 request: AudioOverviewRequest,
127 ) -> Result<AudioOverviewResponse> {
128 self.backends
129 .audio()
130 .create_audio_overview(notebook_id, request)
131 .await
132 }
133
134 pub async fn delete_audio_overview(&self, notebook_id: &str) -> Result<()> {
135 self.backends
136 .audio()
137 .delete_audio_overview(notebook_id)
138 .await
139 }
140}
141
142fn has_drive_content<'a, I>(contents: I) -> bool
143where
144 I: IntoIterator<Item = &'a UserContent>,
145{
146 contents
147 .into_iter()
148 .any(|content| matches!(content, UserContent::GoogleDrive { .. }))
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::StaticTokenProvider;
    use crate::env::EnvironmentConfig;
    use crate::error::Error;
    use serde_json::json;
    use serial_test::serial;
    use std::sync::Arc;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Environment variable the tests use to point token introspection at the mock server.
    const TOKENINFO_ENV: &str = "NBLM_TOKENINFO_ENDPOINT";

    /// Mock path of the batchCreate endpoint for the notebook used throughout these tests.
    const BATCH_CREATE_PATH: &str =
        "/v1alpha/projects/123/locations/global/notebooks/notebook-id/sources:batchCreate";

    /// Restores an environment variable to its pre-test value (or absence) on drop.
    struct EnvGuard {
        key: &'static str,
        original: Option<String>,
    }

    impl EnvGuard {
        /// Snapshots the current value of `key` so it can be restored later.
        fn new(key: &'static str) -> Self {
            Self {
                original: std::env::var(key).ok(),
                key,
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            match self.original.as_deref() {
                Some(previous) => std::env::set_var(self.key, previous),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// Redirects the tokeninfo endpoint to `server` and returns a guard that undoes it.
    fn redirect_tokeninfo(server: &MockServer) -> EnvGuard {
        let tokeninfo_url = format!("{}/tokeninfo", server.uri());
        let guard = EnvGuard::new(TOKENINFO_ENV);
        std::env::set_var(TOKENINFO_ENV, &tokeninfo_url);
        guard
    }

    /// Mounts a tokeninfo mock that reports `scope` for the static test token.
    ///
    /// The mock expects exactly one call.
    async fn mount_tokeninfo(server: &MockServer, scope: &str) {
        Mock::given(method("GET"))
            .and(path("/tokeninfo"))
            .and(query_param("access_token", "test-token"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "scope": scope })))
            .expect(1)
            .mount(server)
            .await;
    }

    /// Mounts a batchCreate mock that returns an empty success body.
    ///
    /// The mock expects exactly `calls` requests.
    async fn mount_batch_create(server: &MockServer, calls: u64) {
        Mock::given(method("POST"))
            .and(path(BATCH_CREATE_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "sources": [],
                "errorCount": 0
            })))
            .expect(calls)
            .mount(server)
            .await;
    }

    /// A single Google Drive PDF reference, which triggers the Drive-scope check.
    fn drive_pdf() -> UserContent {
        UserContent::google_drive("doc".to_string(), "application/pdf".to_string(), None)
    }

    /// Builds an enterprise client with a static token, rooted at `base_url`.
    async fn build_client(base_url: &str) -> NblmClient {
        let provider = Arc::new(StaticTokenProvider::new("test-token"));
        let env = EnvironmentConfig::enterprise("123", "global", "us").unwrap();
        let client = NblmClient::new(provider, env).unwrap();
        client.with_base_url(base_url).unwrap()
    }

    #[rstest::rstest]
    #[case::with_drive_scope("https://www.googleapis.com/auth/drive.file", true, 1)]
    #[case::without_drive_scope("https://www.googleapis.com/auth/cloud-platform", false, 0)]
    #[tokio::test]
    #[serial]
    async fn add_sources_validates_drive_scope(
        #[case] scope: &str,
        #[case] should_succeed: bool,
        #[case] api_call_count: u64,
    ) {
        let server = MockServer::start().await;
        let _guard = redirect_tokeninfo(&server);
        mount_tokeninfo(&server, scope).await;
        mount_batch_create(&server, api_call_count).await;

        let client = build_client(&format!("{}/v1alpha", server.uri())).await;
        let result = client.add_sources("notebook-id", vec![drive_pdf()]).await;

        if !should_succeed {
            match result.expect_err("expected add_sources to fail when drive scope is missing") {
                Error::TokenProvider(message) => assert!(
                    message.contains("drive.file"),
                    "unexpected message: {message}"
                ),
                other => panic!("expected TokenProvider error, got {other:?}"),
            }
            return;
        }

        assert!(
            result.is_ok(),
            "expected add_sources to succeed: {:?}",
            result
        );
    }

    #[tokio::test]
    #[serial]
    async fn batch_create_sources_validates_drive_scope() {
        let server = MockServer::start().await;
        let _guard = redirect_tokeninfo(&server);
        mount_tokeninfo(&server, "https://www.googleapis.com/auth/drive").await;
        mount_batch_create(&server, 1).await;

        let client = build_client(&format!("{}/v1alpha", server.uri())).await;
        let request = BatchCreateSourcesRequest {
            user_contents: vec![drive_pdf()],
        };

        let response = client
            .batch_create_sources("notebook-id", request)
            .await
            .expect("expected batch_create_sources to succeed");

        assert!(response.sources.is_empty());
    }
}