1use std::{sync::Arc, time::Duration};
2
3use reqwest::{Client, Method, StatusCode, Url};
4use serde::de::DeserializeOwned;
5use serde::Serialize;
6
7use crate::auth::TokenProvider;
8use crate::error::{Error, Result};
9use crate::models::{
10 AccountRole, AudioOverviewRequest, AudioOverviewResponse, BatchCreateSourcesRequest,
11 BatchCreateSourcesResponse, BatchDeleteNotebooksRequest, BatchDeleteNotebooksResponse,
12 BatchDeleteSourcesRequest, BatchDeleteSourcesResponse, CreateNotebookRequest,
13 ListRecentlyViewedResponse, Notebook, ShareRequest, ShareResponse, UserContent,
14};
15use crate::retry::{RetryConfig, Retryer};
16
/// Timeout used both for the underlying HTTP client and, until
/// [`NblmClient::with_timeout`] overrides it, for every individual request.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Smallest `pageSize` ever sent to `listRecentlyViewed`; smaller values are raised to it.
const PAGE_SIZE_MIN: u32 = 1;
/// Largest `pageSize` ever sent to `listRecentlyViewed`; larger values are capped to it.
const PAGE_SIZE_MAX: u32 = 500;
20
21pub struct NblmClient {
22 http: Client,
23 token_provider: Arc<dyn TokenProvider>,
24 base: String,
25 parent: String,
26 timeout: Duration,
27 retryer: Retryer,
28 user_project: Option<String>,
29}
30
31impl NblmClient {
32 pub fn new(
33 token_provider: Arc<dyn TokenProvider>,
34 project_number: impl Into<String>,
35 location: impl Into<String>,
36 endpoint_location: impl Into<String>,
37 ) -> Result<Self> {
38 let project_number = project_number.into();
39 let location = location.into();
40 let endpoint_location = endpoint_location.into();
41 let base = format!(
42 "https://{}discoveryengine.googleapis.com/v1alpha",
43 normalize_endpoint_location(endpoint_location)?
44 );
45 let parent = format!("projects/{}/locations/{}", project_number, location);
46
47 let http = Client::builder()
48 .user_agent(concat!("nblm-cli/", env!("CARGO_PKG_VERSION")))
49 .timeout(DEFAULT_TIMEOUT)
50 .build()
51 .map_err(Error::from)?;
52
53 Ok(Self {
54 http,
55 token_provider,
56 base: base.trim_end_matches('/').to_string(),
57 parent,
58 timeout: DEFAULT_TIMEOUT,
59 retryer: Retryer::new(RetryConfig::default()),
60 user_project: None,
61 })
62 }
63
64 pub fn with_timeout(mut self, timeout: Duration) -> Self {
65 self.timeout = timeout;
66 self
67 }
68
69 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
70 self.retryer = Retryer::new(config);
71 self
72 }
73
74 pub fn with_user_project(mut self, project: impl Into<String>) -> Self {
75 self.user_project = Some(project.into());
76 self
77 }
78
79 pub fn with_base_url(mut self, base: impl Into<String>) -> Result<Self> {
81 let base = base.into().trim().trim_end_matches('/').to_string();
82 let _ = Url::parse(&base).map_err(Error::from)?;
84 self.base = base;
85 Ok(self)
86 }
87
88 fn notebooks_collection(&self) -> String {
89 format!("{}/notebooks", self.parent)
90 }
91
92 fn notebook_path(&self, notebook_id: &str) -> String {
93 format!("{}/notebooks/{}", self.parent, notebook_id)
94 }
95
96 fn build_url(&self, path: &str) -> Result<Url> {
97 let path = path.trim_start_matches('/');
98 Url::parse(&format!("{}/{}", self.base, path)).map_err(Error::from)
99 }
100
101 async fn request_json<B, R>(&self, method: Method, url: Url, body: Option<&B>) -> Result<R>
102 where
103 B: Serialize + ?Sized,
104 R: DeserializeOwned,
105 {
106 let client = self.http.clone();
107 let method_clone = method.clone();
108 let url_clone = url.clone();
109 let timeout = self.timeout;
110 let body_ref = body;
111 let provider = Arc::clone(&self.token_provider);
112 let user_project = self.user_project.clone();
113
114 let run = || {
115 let client = client.clone();
116 let method = method_clone.clone();
117 let url = url_clone.clone();
118 let provider = Arc::clone(&provider);
119 let user_project = user_project.clone();
120 async move {
121 let token = provider.access_token().await?;
122 let mut builder = client
123 .request(method, url)
124 .bearer_auth(token)
125 .timeout(timeout);
126 if let Some(project) = &user_project {
127 builder = builder.header("x-goog-user-project", project);
128 }
129 if let Some(body) = body_ref {
130 builder = builder.json(body);
131 }
132 let request = builder.build().map_err(Error::Request)?;
133 let response = client.execute(request).await.map_err(Error::Request)?;
134 Ok(response)
135 }
136 };
137
138 let mut response = self.retryer.run_with_retry(run).await?;
139
140 if response.status() == StatusCode::UNAUTHORIZED {
141 let _ = response.bytes().await;
142 let run_refresh = || {
143 let client = client.clone();
144 let method = method_clone.clone();
145 let url = url_clone.clone();
146 let provider = Arc::clone(&provider);
147 let user_project = user_project.clone();
148 async move {
149 let token = provider.refresh_token().await?;
150 let mut builder = client
151 .request(method, url)
152 .bearer_auth(token)
153 .timeout(timeout);
154 if let Some(project) = &user_project {
155 builder = builder.header("x-goog-user-project", project);
156 }
157 if let Some(body) = body_ref {
158 builder = builder.json(body);
159 }
160 let request = builder.build().map_err(Error::Request)?;
161 let response = client.execute(request).await.map_err(Error::Request)?;
162 Ok(response)
163 }
164 };
165 response = self.retryer.run_with_retry(run_refresh).await?;
166 if !response.status().is_success() {
167 let status = response.status();
168 let body = response.text().await.unwrap_or_default();
169 return Err(Error::http(status, body));
170 }
171 return Ok(response.json::<R>().await?);
172 }
173
174 if !response.status().is_success() {
175 let status = response.status();
176 let body = response.text().await.unwrap_or_default();
177 return Err(Error::http(status, body));
178 }
179
180 Ok(response.json::<R>().await?)
181 }
182
183 pub async fn create_notebook(&self, title: impl Into<String>) -> Result<Notebook> {
184 let url = self.build_url(&self.notebooks_collection())?;
185 let request = CreateNotebookRequest {
186 title: title.into(),
187 };
188 self.request_json(Method::POST, url, Some(&request)).await
189 }
190
191 pub async fn batch_delete_notebooks(
199 &self,
200 request: BatchDeleteNotebooksRequest,
201 ) -> Result<BatchDeleteNotebooksResponse> {
202 let path = format!("{}:batchDelete", self.notebooks_collection());
203 let url = self.build_url(&path)?;
204 self.request_json(Method::POST, url, Some(&request)).await
205 }
206
207 pub async fn delete_notebooks(
215 &self,
216 notebook_names: Vec<String>,
217 ) -> Result<BatchDeleteNotebooksResponse> {
218 for name in ¬ebook_names {
220 let request = BatchDeleteNotebooksRequest {
221 names: vec![name.clone()],
222 };
223 self.batch_delete_notebooks(request).await?;
224 }
225 Ok(BatchDeleteNotebooksResponse::default())
227 }
228
229 pub async fn batch_create_sources(
230 &self,
231 notebook_id: &str,
232 request: BatchCreateSourcesRequest,
233 ) -> Result<BatchCreateSourcesResponse> {
234 let path = format!("{}/sources:batchCreate", self.notebook_path(notebook_id));
235 let url = self.build_url(&path)?;
236 self.request_json(Method::POST, url, Some(&request)).await
237 }
238
239 pub async fn share_notebook(
241 &self,
242 notebook_id: &str,
243 accounts: Vec<AccountRole>,
244 ) -> Result<ShareResponse> {
245 let path = format!("{}:share", self.notebook_path(notebook_id));
246 let url = self.build_url(&path)?;
247 let request = ShareRequest {
248 account_and_roles: accounts,
249 };
250 self.request_json(Method::POST, url, Some(&request)).await
251 }
252
253 pub async fn create_audio_overview(
254 &self,
255 notebook_id: &str,
256 request: AudioOverviewRequest,
257 ) -> Result<AudioOverviewResponse> {
258 let path = format!("{}/audioOverviews", self.notebook_path(notebook_id));
259 let url = self.build_url(&path)?;
260 self.request_json(Method::POST, url, Some(&request)).await
261 }
262
263 pub async fn delete_audio_overview(&self, notebook_id: &str) -> Result<()> {
264 let path = format!("{}/audioOverviews/default", self.notebook_path(notebook_id));
265 let url = self.build_url(&path)?;
266 let _response: serde_json::Value =
267 self.request_json(Method::DELETE, url, None::<&()>).await?;
268 Ok(())
269 }
270
271 pub async fn list_recently_viewed(
284 &self,
285 page_size: Option<u32>,
286 page_token: Option<&str>,
287 ) -> Result<ListRecentlyViewedResponse> {
288 let path = format!("{}:listRecentlyViewed", self.notebooks_collection());
289 let mut url = self.build_url(&path)?;
290 {
291 let mut pairs = url.query_pairs_mut();
292 if let Some(size) = page_size {
293 let clamped = size.clamp(PAGE_SIZE_MIN, PAGE_SIZE_MAX);
294 pairs.append_pair("pageSize", &clamped.to_string());
295 }
296 if let Some(token) = page_token {
297 pairs.append_pair("pageToken", token);
298 }
299 }
300 self.request_json::<(), _>(Method::GET, url, None::<&()>)
301 .await
302 }
303
304 pub async fn add_sources(
305 &self,
306 notebook_id: &str,
307 contents: Vec<UserContent>,
308 ) -> Result<BatchCreateSourcesResponse> {
309 let request = BatchCreateSourcesRequest {
310 user_contents: contents,
311 };
312 self.batch_create_sources(notebook_id, request).await
313 }
314
315 pub async fn batch_delete_sources(
316 &self,
317 notebook_id: &str,
318 request: BatchDeleteSourcesRequest,
319 ) -> Result<BatchDeleteSourcesResponse> {
320 let path = format!("{}/sources:batchDelete", self.notebook_path(notebook_id));
321 let url = self.build_url(&path)?;
322 self.request_json(Method::POST, url, Some(&request)).await
323 }
324
325 pub async fn delete_sources(
326 &self,
327 notebook_id: &str,
328 source_names: Vec<String>,
329 ) -> Result<BatchDeleteSourcesResponse> {
330 let request = BatchDeleteSourcesRequest {
331 names: source_names,
332 };
333 self.batch_delete_sources(notebook_id, request).await
334 }
335}
336
337fn normalize_endpoint_location(input: String) -> Result<String> {
338 let trimmed = input.trim().trim_end_matches('-').to_lowercase();
339 let normalized = match trimmed.as_str() {
340 "us" => "us-",
341 "eu" => "eu-",
342 "global" => "global-",
343 other => {
344 return Err(Error::Endpoint(format!(
345 "unsupported endpoint location: {other}"
346 )))
347 }
348 };
349 Ok(normalized.to_string())
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client for project `123`, location `global`, endpoint `us`.
    fn test_client() -> NblmClient {
        let provider = Arc::new(crate::auth::StaticTokenProvider::new("test"));
        NblmClient::new(provider, "123", "global", "us").unwrap()
    }

    #[test]
    fn normalize_endpoint_location_variants() {
        assert_eq!(
            normalize_endpoint_location("us".into()).unwrap(),
            "us-".to_string()
        );
        assert_eq!(
            normalize_endpoint_location("eu-".into()).unwrap(),
            "eu-".to_string()
        );
        assert_eq!(
            normalize_endpoint_location(" global ".into()).unwrap(),
            "global-".to_string()
        );
    }

    #[test]
    fn normalize_endpoint_location_is_case_insensitive() {
        assert_eq!(normalize_endpoint_location("EU".into()).unwrap(), "eu-");
        assert_eq!(normalize_endpoint_location("Global-".into()).unwrap(), "global-");
    }

    #[test]
    fn normalize_endpoint_location_invalid() {
        let err = normalize_endpoint_location("asia".into()).unwrap_err();
        assert!(format!("{err}").contains("unsupported endpoint location"));
    }

    #[test]
    fn new_derives_base_and_resource_paths() {
        let client = test_client();
        assert_eq!(
            client.base,
            "https://us-discoveryengine.googleapis.com/v1alpha"
        );
        assert_eq!(client.parent, "projects/123/locations/global");
        assert_eq!(
            client.notebook_path("abc"),
            "projects/123/locations/global/notebooks/abc"
        );
    }

    #[test]
    fn new_rejects_unsupported_endpoint_location() {
        let provider = Arc::new(crate::auth::StaticTokenProvider::new("test"));
        // `is_err` rather than `unwrap_err`: `NblmClient` does not implement `Debug`.
        assert!(NblmClient::new(provider, "123", "global", "asia").is_err());
    }

    #[test]
    fn with_base_url_accepts_absolute_url() {
        let result = test_client().with_base_url("http://localhost:8080/v1alpha");
        assert!(result.is_ok());
    }

    #[test]
    fn with_base_url_trims_trailing_slash() {
        let client = test_client()
            .with_base_url("http://example.com/v1alpha/")
            .unwrap();
        assert_eq!(client.base, "http://example.com/v1alpha");
    }

    #[test]
    fn with_base_url_trims_surrounding_whitespace() {
        let client = test_client()
            .with_base_url("  http://example.com/v1alpha/  ")
            .unwrap();
        assert_eq!(client.base, "http://example.com/v1alpha");
    }

    #[test]
    fn with_base_url_rejects_relative_path() {
        let result = test_client().with_base_url("/relative/path");
        assert!(result.is_err());
    }

    #[test]
    fn build_url_combines_base_and_path_correctly() {
        let client = test_client()
            .with_base_url("http://example.com/v1alpha")
            .unwrap();

        // A leading slash on the path must not produce a double separator.
        let url = client.build_url("/projects/123/notebooks").unwrap();
        assert_eq!(
            url.as_str(),
            "http://example.com/v1alpha/projects/123/notebooks"
        );

        let url = client.build_url("projects/123/notebooks").unwrap();
        assert_eq!(
            url.as_str(),
            "http://example.com/v1alpha/projects/123/notebooks"
        );
    }
}