nblm_core/client/mod.rs

1use std::{sync::Arc, time::Duration};
2
3use reqwest::{Client, Url};
4
5use crate::auth::{ensure_drive_scope, TokenProvider};
6use crate::env::EnvironmentConfig;
7use crate::error::Result;
8
9mod api;
10mod http;
11mod retry;
12mod url;
13
14pub use self::retry::{RetryConfig, Retryer};
15
16use self::api::backends::{BackendContext, ClientBackends};
17use self::http::HttpClient;
18use self::url::{new_url_builder, UrlBuilder};
19
20const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
21
22pub struct NblmClient {
23    pub(self) http: Arc<HttpClient>,
24    pub(self) url_builder: Arc<dyn UrlBuilder>,
25    backends: ClientBackends,
26    environment: EnvironmentConfig,
27    timeout: Duration,
28}
29
30impl NblmClient {
31    pub fn new(
32        token_provider: Arc<dyn TokenProvider>,
33        environment: EnvironmentConfig,
34    ) -> Result<Self> {
35        let client = Client::builder()
36            .user_agent(concat!("nblm-cli/", env!("CARGO_PKG_VERSION")))
37            .timeout(DEFAULT_TIMEOUT)
38            .build()
39            .map_err(crate::error::Error::from)?;
40
41        let retryer = Retryer::new(RetryConfig::default());
42        let http = Arc::new(HttpClient::new(client, token_provider, retryer, None));
43        let url_builder = new_url_builder(
44            environment.profile(),
45            environment.base_url().to_string(),
46            environment.parent_path().to_string(),
47        );
48        let ctx = BackendContext::new(Arc::clone(&http), Arc::clone(&url_builder));
49        let backends = ClientBackends::new(environment.profile(), ctx);
50
51        Ok(Self {
52            http,
53            url_builder,
54            backends,
55            environment,
56            timeout: DEFAULT_TIMEOUT,
57        })
58    }
59
60    #[deprecated(note = "Use EnvironmentConfig::enterprise(...) with NblmClient::new")]
61    pub fn new_enterprise(
62        token_provider: Arc<dyn TokenProvider>,
63        project_number: impl Into<String>,
64        location: impl Into<String>,
65        endpoint_location: impl Into<String>,
66    ) -> Result<Self> {
67        let env = EnvironmentConfig::enterprise(project_number, location, endpoint_location)?;
68        Self::new(token_provider, env)
69    }
70
71    pub fn with_timeout(mut self, timeout: Duration) -> Self {
72        self.timeout = timeout;
73        // Update the underlying HTTP client's timeout
74        let client = Client::builder()
75            .user_agent(concat!("nblm-cli/", env!("CARGO_PKG_VERSION")))
76            .timeout(timeout)
77            .build()
78            .expect("Failed to rebuild client with new timeout");
79
80        let token_provider = Arc::clone(&self.http.token_provider);
81        let retryer = self.http.retryer.clone();
82        let user_project = self.http.user_project.clone();
83        self.http = Arc::new(HttpClient::new(
84            client,
85            token_provider,
86            retryer,
87            user_project,
88        ));
89        self.rebuild_backends();
90        self
91    }
92
93    pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
94        let client = Client::builder()
95            .user_agent(concat!("nblm-cli/", env!("CARGO_PKG_VERSION")))
96            .timeout(self.timeout)
97            .build()
98            .expect("Failed to rebuild client");
99
100        let token_provider = Arc::clone(&self.http.token_provider);
101        let retryer = Retryer::new(config);
102        let user_project = self.http.user_project.clone();
103        self.http = Arc::new(HttpClient::new(
104            client,
105            token_provider,
106            retryer,
107            user_project,
108        ));
109        self.rebuild_backends();
110        self
111    }
112
113    pub fn with_user_project(mut self, project: impl Into<String>) -> Self {
114        let client = Client::builder()
115            .user_agent(concat!("nblm-cli/", env!("CARGO_PKG_VERSION")))
116            .timeout(self.timeout)
117            .build()
118            .expect("Failed to rebuild client");
119
120        let token_provider = Arc::clone(&self.http.token_provider);
121        let retryer = self.http.retryer.clone();
122        let user_project = Some(project.into());
123        self.http = Arc::new(HttpClient::new(
124            client,
125            token_provider,
126            retryer,
127            user_project,
128        ));
129        self.rebuild_backends();
130        self
131    }
132
133    /// Override API base URL (for tests). Accepts absolute URL. Trims trailing slash.
134    pub fn with_base_url(mut self, base: impl Into<String>) -> Result<Self> {
135        let base = base.into().trim().trim_end_matches('/').to_string();
136        // Basic sanity check: absolute URL
137        let _ = Url::parse(&base).map_err(crate::error::Error::from)?;
138        self.environment = self.environment.clone().with_base_url(base.clone());
139        let parent = self.environment.parent_path().to_string();
140        self.url_builder = new_url_builder(self.environment.profile(), base, parent);
141        self.rebuild_backends();
142        Ok(self)
143    }
144}
145
146impl NblmClient {
147    fn rebuild_backends(&mut self) {
148        let ctx = BackendContext::new(Arc::clone(&self.http), Arc::clone(&self.url_builder));
149        self.backends = ClientBackends::new(self.environment.profile(), ctx);
150    }
151
152    pub(crate) async fn ensure_drive_scope_if_needed(&self, includes_drive: bool) -> Result<()> {
153        if includes_drive {
154            ensure_drive_scope(self.http.token_provider.as_ref()).await?;
155        }
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client for a fixed enterprise environment authenticated with a static token.
    fn enterprise_test_client() -> NblmClient {
        let provider = Arc::new(crate::auth::StaticTokenProvider::new("test"));
        let env = EnvironmentConfig::enterprise("123", "global", "us").unwrap();
        NblmClient::new(provider, env).unwrap()
    }

    #[test]
    fn with_base_url_accepts_absolute_url() {
        let result = enterprise_test_client().with_base_url("http://localhost:8080/v1alpha");
        assert!(result.is_ok());
    }

    #[test]
    fn with_base_url_trims_trailing_slash() {
        let client = enterprise_test_client()
            .with_base_url("http://example.com/v1alpha/")
            .unwrap();

        // The trailing slash must not produce a double slash when paths are appended.
        let url = client.url_builder.build_url("/test").unwrap();
        assert_eq!(url.as_str(), "http://example.com/v1alpha/test");
    }

    #[test]
    fn with_base_url_rejects_relative_path() {
        let result = enterprise_test_client().with_base_url("/relative/path");
        assert!(result.is_err());
    }

    #[test]
    fn builder_methods_preserve_earlier_settings() {
        // Each builder rebuilds the HTTP transport; settings applied earlier must survive.
        let client = enterprise_test_client()
            .with_user_project("billing-project")
            .with_timeout(Duration::from_secs(5))
            .with_retry_config(RetryConfig::default());

        assert_eq!(client.timeout, Duration::from_secs(5));
        assert_eq!(client.http.user_project.as_deref(), Some("billing-project"));
    }

    #[test]
    fn with_timeout_keeps_overridden_base_url() {
        let client = enterprise_test_client()
            .with_base_url("http://localhost:8080/v1alpha")
            .unwrap()
            .with_timeout(Duration::from_secs(1));

        let url = client.url_builder.build_url("/test").unwrap();
        assert_eq!(url.as_str(), "http://localhost:8080/v1alpha/test");
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated constructor on purpose
    fn new_enterprise_constructs_client_correctly() {
        let provider = Arc::new(crate::auth::StaticTokenProvider::new("test"));
        let client = NblmClient::new_enterprise(provider, "123", "global", "us").unwrap();

        // Verify base URL is constructed correctly
        let url = client.url_builder.build_url("/test").unwrap();
        assert!(url
            .as_str()
            .starts_with("https://us-discoveryengine.googleapis.com/v1alpha"));

        // Verify parent path is set correctly
        let notebooks_url = client.url_builder.notebooks_collection();
        assert_eq!(notebooks_url, "projects/123/locations/global/notebooks");
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated constructor on purpose
    fn new_enterprise_handles_invalid_endpoint() {
        let provider = Arc::new(crate::auth::StaticTokenProvider::new("test"));
        let result = NblmClient::new_enterprise(provider, "123", "global", "invalid");
        assert!(result.is_err());
    }
}