1use crate::error::FireboltError;
2use crate::result::ResultSet;
3use std::collections::HashMap;
4use url::Url;
5
/// Response header carrying a replacement engine endpoint URL; its query pairs are
/// merged into the session parameters.
const HEADER_UPDATE_ENDPOINT: &str = "Firebolt-Update-Endpoint";
/// Response header carrying comma-separated `key=value` pairs to store as session parameters.
const HEADER_UPDATE_PARAMETERS: &str = "Firebolt-Update-Parameters";
/// Response header whose presence clears every session parameter except `database` and `engine`.
const HEADER_RESET_SESSION: &str = "Firebolt-Reset-Session";
/// Response header carrying a comma-separated list of session parameter names to remove.
const HEADER_REMOVE_PARAMETERS: &str = "Firebolt-Remove-Parameters";
10
/// Client that sends SQL to a Firebolt engine over HTTP.
///
/// Holds the credentials used to (re)authenticate, the current bearer token, the
/// engine URL queries are posted to, and the session parameters that are sent as
/// query-string parameters with every request.
pub struct FireboltClient {
    _client_id: String,
    _client_secret: String,
    _token: String,
    _parameters: HashMap<String, String>,
    _engine_url: String,
    _api_endpoint: String,
}

// Written by hand instead of derived so that the client secret and the bearer
// token never end up in logs or panic messages through `{:?}`.
impl std::fmt::Debug for FireboltClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FireboltClient")
            .field("_client_id", &self._client_id)
            .field("_client_secret", &"<redacted>")
            .field("_token", &"<redacted>")
            .field("_parameters", &self._parameters)
            .field("_engine_url", &self._engine_url)
            .field("_api_endpoint", &self._api_endpoint)
            .finish()
    }
}
20
21impl FireboltClient {
22 pub async fn query(&mut self, sql: &str) -> Result<ResultSet, FireboltError> {
23 let engine_url = self.engine_url();
24 let url = ensure_trailing_slash(engine_url);
25
26 let mut params = self.parameters().clone();
27 params.insert("output_format".to_string(), "JSON_Compact".to_string());
28
29 self.execute_query_request(&url, sql, ¶ms, true).await
30 }
31
32 async fn execute_query_request(
33 &mut self,
34 url: &str,
35 sql: &str,
36 params: &HashMap<String, String>,
37 should_retry: bool,
38 ) -> Result<ResultSet, FireboltError> {
39 let client = reqwest::Client::new();
40 let token = &self._token;
41
42 let response = client
43 .post(url)
44 .query(params)
45 .header("Authorization", format!("Bearer {token}"))
46 .header("User-Agent", crate::version::user_agent())
47 .header(
48 "Firebolt-Protocol-Version",
49 crate::version::PROTOCOL_VERSION,
50 )
51 .body(sql.to_string())
52 .send()
53 .await
54 .map_err(|e| FireboltError::Network(format!("Request failed: {e}")))?;
55
56 let status = response.status();
57
58 if status == 401 && should_retry {
59 let (new_token, _expiration) = crate::auth::authenticate(
60 self.client_id().to_string(),
61 self.client_secret().to_string(),
62 self.api_endpoint().to_string(),
63 )
64 .await
65 .map_err(|e| FireboltError::Authentication(format!("Token refresh failed: {e}")))?;
66
67 self.set_token(new_token);
68 Box::pin(self.execute_query_request(url, sql, params, false)).await
69 } else if status == 401 {
70 Err(FireboltError::Authentication(
71 "Authentication failed after token refresh".to_string(),
72 ))
73 } else if status.is_server_error() {
74 let body = response.text().await.map_err(|e| {
75 FireboltError::Network(format!("Failed to read error response: {e}"))
76 })?;
77 Err(crate::parser::parse_server_error(body))
78 } else if status.is_success() {
79 self.process_response_headers(&response)?;
80 let body = response
81 .text()
82 .await
83 .map_err(|e| FireboltError::Network(format!("Failed to read response: {e}")))?;
84 crate::parser::parse_response(body)
85 } else {
86 let body = response.text().await.map_err(|e| {
87 FireboltError::Network(format!("Failed to read error response: {e}"))
88 })?;
89 Err(crate::parser::parse_server_error(body))
90 }
91 }
92
93 pub fn client_id(&self) -> &str {
94 &self._client_id
95 }
96
97 pub fn client_secret(&self) -> &str {
98 &self._client_secret
99 }
100
101 pub fn api_endpoint(&self) -> &str {
102 &self._api_endpoint
103 }
104
105 pub fn engine_url(&self) -> &str {
106 &self._engine_url
107 }
108
109 pub fn parameters(&self) -> &HashMap<String, String> {
110 &self._parameters
111 }
112
113 pub fn set_token(&mut self, token: String) {
114 self._token = token;
115 }
116
117 pub fn builder() -> FireboltClientFactory {
118 FireboltClientFactory::new()
119 }
120
121 fn process_response_headers(
122 &mut self,
123 response: &reqwest::Response,
124 ) -> Result<(), FireboltError> {
125 if let Some(endpoint_header) = response.headers().get(HEADER_UPDATE_ENDPOINT) {
126 let endpoint_str = endpoint_header.to_str().map_err(|e| {
127 FireboltError::HeaderParsing(format!("Invalid endpoint header: {e}"))
128 })?;
129
130 let url = Url::parse(FireboltClientFactory::fix_schema(endpoint_str).as_str())
131 .map_err(|e| FireboltError::HeaderParsing(format!("Invalid endpoint URL: {e}")))?;
132
133 let base_url = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
134 let path = url.path();
135 self._engine_url = if path == "/" || path.is_empty() {
136 base_url
137 } else {
138 format!("{base_url}{path}")
139 };
140
141 for (key, value) in url.query_pairs() {
142 self._parameters.insert(key.to_string(), value.to_string());
143 }
144 }
145
146 if let Some(params_header) = response.headers().get(HEADER_UPDATE_PARAMETERS) {
147 let params_str = params_header.to_str().map_err(|e| {
148 FireboltError::HeaderParsing(format!("Invalid parameters header: {e}"))
149 })?;
150
151 for param_pair in params_str.split(',') {
152 let param_pair = param_pair.trim();
153 if param_pair.is_empty() {
154 continue;
155 }
156
157 let parts: Vec<&str> = param_pair.splitn(2, '=').collect();
158 if parts.len() != 2 {
159 return Err(FireboltError::HeaderParsing(format!(
160 "Invalid parameter format: {param_pair}"
161 )));
162 }
163
164 let key = parts[0].trim();
165 let value = parts[1].trim();
166
167 if key.is_empty() {
168 return Err(FireboltError::HeaderParsing(
169 "Parameter key cannot be empty".to_string(),
170 ));
171 }
172
173 self._parameters.insert(key.to_string(), value.to_string());
174 }
175 }
176
177 if response.headers().contains_key(HEADER_RESET_SESSION) {
178 let database = self._parameters.get("database").cloned();
179 let engine = self._parameters.get("engine").cloned();
180
181 self._parameters.clear();
182
183 if let Some(db) = database {
184 self._parameters.insert("database".to_string(), db);
185 }
186 if let Some(eng) = engine {
187 self._parameters.insert("engine".to_string(), eng);
188 }
189 }
190
191 if let Some(remove_header) = response.headers().get(HEADER_REMOVE_PARAMETERS) {
192 let remove_str = remove_header.to_str().map_err(|e| {
193 FireboltError::HeaderParsing(format!("Invalid remove parameters header: {e}"))
194 })?;
195
196 for param_name in remove_str.split(',') {
197 let param_name = param_name.trim();
198 if !param_name.is_empty() {
199 self._parameters.remove(param_name);
200 }
201 }
202 }
203
204 Ok(())
205 }
206}
207
/// Returns `url` as an owned string ending in `/`, appending one only when it is missing.
fn ensure_trailing_slash(url: &str) -> String {
    let mut normalized = String::with_capacity(url.len() + 1);
    normalized.push_str(url);
    if !normalized.ends_with('/') {
        normalized.push('/');
    }
    normalized
}
215
/// Builder for `FireboltClient`, obtained through `FireboltClient::builder()`.
pub struct FireboltClientFactory {
    /// Client ID; `build` fails with a configuration error when unset.
    client_id: Option<String>,
    /// Client secret; `build` fails with a configuration error when unset.
    client_secret: Option<String>,
    /// Optional database; when set, `build` issues `USE DATABASE` for it.
    database_name: Option<String>,
    /// Optional engine; when set, `build` issues `USE ENGINE` for it.
    engine_name: Option<String>,
    /// Account name; `build` fails with a configuration error when unset.
    account_name: Option<String>,
    /// Initialised by `new`. NOTE(review): `build` resolves the API endpoint from the
    /// environment via `get_api_endpoint` and does not read this field.
    _api_endpoint: String,
}
224
225impl FireboltClientFactory {
226 fn new() -> Self {
227 Self {
228 client_id: None,
229 client_secret: None,
230 database_name: None,
231 engine_name: None,
232 account_name: None,
233 _api_endpoint: "https://api.firebolt.io".to_string(),
234 }
235 }
236
237 fn fix_schema(url: &str) -> String {
238 if url.starts_with("https://") || url.starts_with("http://") {
239 url.to_string()
240 } else {
241 format!("https://{url}")
242 }
243 }
244
245 fn get_api_endpoint() -> String {
246 let api_endpoint = std::env::var("FIREBOLT_API_ENDPOINT")
247 .unwrap_or_else(|_| "api.app.firebolt.io".to_string());
248
249 Self::fix_schema(&api_endpoint)
250 }
251
252 async fn get_engine_url(
253 account_name: &str,
254 api_endpoint: &str,
255 token: &str,
256 ) -> Result<String, FireboltError> {
257 let engine_url_endpoint = format!("{api_endpoint}/web/v3/account/{account_name}/engineUrl");
258 let client = reqwest::Client::new();
259
260 let response = client
261 .get(&engine_url_endpoint)
262 .header("Authorization", format!("Bearer {token}"))
263 .header("User-Agent", crate::version::user_agent())
264 .send()
265 .await
266 .map_err(|e| FireboltError::Network(format!("Failed to get engine URL: {e}")))?;
267
268 let status = response.status();
269
270 match status.as_u16() {
271 200 => {
272 let body = response
273 .text()
274 .await
275 .map_err(|e| FireboltError::Network(format!("Failed to read response: {e}")))?;
276
277 let json: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
278 FireboltError::Query(format!("Failed to parse engine URL response: {e}"))
279 })?;
280
281 let engine_url =
282 json.get("engineUrl")
283 .and_then(|v| v.as_str())
284 .ok_or_else(|| {
285 FireboltError::Query("Missing engineUrl field in response".to_string())
286 })?;
287
288 Ok(Self::fix_schema(ensure_trailing_slash(engine_url).as_str()))
289 }
290 404 => Err(FireboltError::Configuration(format!(
291 "Account '{account_name}' not found"
292 ))),
293 _ => {
294 let body = response.text().await.map_err(|e| {
295 FireboltError::Network(format!("Failed to read error response: {e}"))
296 })?;
297 Err(FireboltError::Query(body))
298 }
299 }
300 }
301
302 pub fn with_credentials(mut self, client_id: String, client_secret: String) -> Self {
303 self.client_id = Some(client_id);
304 self.client_secret = Some(client_secret);
305 self
306 }
307
308 pub fn with_database(mut self, database_name: String) -> Self {
309 self.database_name = Some(database_name);
310 self
311 }
312
313 pub fn with_engine(mut self, engine_name: String) -> Self {
314 self.engine_name = Some(engine_name);
315 self
316 }
317
318 pub fn with_account(mut self, account_name: String) -> Self {
319 self.account_name = Some(account_name);
320 self
321 }
322
323 pub async fn build(self) -> Result<FireboltClient, FireboltError> {
324 let client_id = self
326 .client_id
327 .ok_or_else(|| FireboltError::Configuration("client_id is required".to_string()))?;
328 let client_secret = self
329 .client_secret
330 .ok_or_else(|| FireboltError::Configuration("client_secret is required".to_string()))?;
331 let account_name = self
332 .account_name
333 .ok_or_else(|| FireboltError::Configuration("account_name is required".to_string()))?;
334
335 let api_endpoint = Self::get_api_endpoint();
336
337 let (token, _expiration) = crate::auth::authenticate(
338 client_id.clone(),
339 client_secret.clone(),
340 api_endpoint.clone(),
341 )
342 .await
343 .map_err(FireboltError::Authentication)?;
344
345 let engine_url = Self::get_engine_url(&account_name, &api_endpoint, &token).await?;
346
347 let mut client = FireboltClient {
348 _client_id: client_id,
349 _client_secret: client_secret,
350 _token: token,
351 _parameters: HashMap::new(),
352 _engine_url: engine_url,
353 _api_endpoint: api_endpoint,
354 };
355
356 if let Some(database_name) = self.database_name {
357 let use_database_sql = format!("USE DATABASE \"{database_name}\"");
358 client.query(&use_database_sql).await.map_err(|e| {
359 FireboltError::Configuration(format!("Failed to set database: {e}"))
360 })?;
361 }
362
363 if let Some(engine_name) = self.engine_name {
364 let use_engine_sql = format!("USE ENGINE \"{engine_name}\"");
365 client
366 .query(&use_engine_sql)
367 .await
368 .map_err(|e| FireboltError::Configuration(format!("Failed to set engine: {e}")))?;
369 }
370
371 Ok(client)
372 }
373}
374
375#[cfg(test)]
376mod tests {
377 use super::*;
378
379 #[tokio::test]
380 async fn test_execute_query_request_success() {
381 let mut server = mockito::Server::new_async().await;
382 let mock = server
383 .mock("POST", "/")
384 .with_status(200)
385 .with_header("content-type", "application/json")
386 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
387 .create_async()
388 .await;
389
390 let mut client = create_test_client();
391 client._engine_url = server.url();
392
393 let result = client
394 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
395 .await;
396
397 mock.assert_async().await;
398 assert!(result.is_ok());
399 }
400
401 #[tokio::test]
402 async fn test_execute_query_request_retry_on_401() {
403 let mut server = mockito::Server::new_async().await;
404
405 let mock_401 = server
406 .mock("POST", "/")
407 .with_status(401)
408 .expect(1)
409 .create_async()
410 .await;
411
412 let _mock_success = server
413 .mock("POST", "/")
414 .with_status(200)
415 .with_header("content-type", "application/json")
416 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
417 .create_async()
418 .await;
419
420 let mut client = create_test_client();
421 client._engine_url = server.url();
422
423 let result = client
424 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
425 .await;
426
427 mock_401.assert_async().await;
428
429 assert!(result.is_err());
430 assert!(matches!(
431 result.unwrap_err(),
432 FireboltError::Authentication(_)
433 ));
434 }
435
436 #[tokio::test]
437 async fn test_execute_query_request_no_retry_on_second_401() {
438 let mut server = mockito::Server::new_async().await;
439 let mock = server
440 .mock("POST", "/")
441 .with_status(401)
442 .expect(1)
443 .create_async()
444 .await;
445
446 let mut client = create_test_client();
447 client._engine_url = server.url();
448
449 let result = client
450 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), false)
451 .await;
452
453 mock.assert_async().await;
454 assert!(result.is_err());
455 assert!(matches!(
456 result.unwrap_err(),
457 FireboltError::Authentication(_)
458 ));
459 }
460
461 #[tokio::test]
462 async fn test_execute_query_request_5xx_error() {
463 let mut server = mockito::Server::new_async().await;
464 let mock = server
465 .mock("POST", "/")
466 .with_status(500)
467 .with_body("Internal Server Error")
468 .create_async()
469 .await;
470
471 let mut client = create_test_client();
472 client._engine_url = server.url();
473
474 let result = client
475 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
476 .await;
477
478 mock.assert_async().await;
479 assert!(result.is_err());
480 let error = result.unwrap_err();
481 assert!(matches!(error, FireboltError::Query(_)));
482 assert!(format!("{error:?}").contains("Internal Server Error"));
483 }
484
485 #[test]
486 fn test_client_getters() {
487 let client = create_test_client();
488 assert_eq!(client.client_id(), "test_id");
489 assert_eq!(client.client_secret(), "test_secret");
490 assert_eq!(client.api_endpoint(), "https://api.test.firebolt.io");
491 assert_eq!(client.engine_url(), "https://test.engine.url/");
492 assert!(client.parameters().is_empty());
493 }
494
495 #[test]
496 fn test_set_token() {
497 let mut client = create_test_client();
498 client.set_token("new_token".to_string());
499 assert_eq!(client._token, "new_token".to_string());
500 }
501
502 #[tokio::test]
503 async fn test_execute_query_request_headers() {
504 let mut server = mockito::Server::new_async().await;
505 let mock = server
506 .mock("POST", "/")
507 .match_header("User-Agent", crate::version::user_agent().as_str())
508 .match_header(
509 "Firebolt-Protocol-Version",
510 crate::version::PROTOCOL_VERSION,
511 )
512 .match_header("Authorization", "Bearer test_token")
513 .with_status(200)
514 .with_header("content-type", "application/json")
515 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
516 .create_async()
517 .await;
518
519 let mut client = create_test_client();
520 client._engine_url = server.url();
521
522 let result = client
523 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
524 .await;
525
526 mock.assert_async().await;
527 assert!(result.is_ok());
528 }
529
530 #[test]
531 fn test_ensure_trailing_slash() {
532 assert_eq!(
533 ensure_trailing_slash("https://example.com"),
534 "https://example.com/"
535 );
536 assert_eq!(
537 ensure_trailing_slash("https://example.com/"),
538 "https://example.com/"
539 );
540 assert_eq!(ensure_trailing_slash(""), "/");
541 }
542
543 fn create_test_client() -> FireboltClient {
544 FireboltClient {
545 _client_id: "test_id".to_string(),
546 _client_secret: "test_secret".to_string(),
547 _token: "test_token".to_string(),
548 _parameters: HashMap::new(),
549 _engine_url: "https://test.engine.url/".to_string(),
550 _api_endpoint: "https://api.test.firebolt.io".to_string(),
551 }
552 }
553
554 #[tokio::test]
555 async fn test_build_missing_client_id() {
556 let mut server = mockito::Server::new_async().await;
557
558 let _auth_mock = server
559 .mock("POST", "/oauth/token")
560 .with_status(200)
561 .with_header("content-type", "application/json")
562 .with_body(r#"{"access_token": "test_token", "expires_in": 3600}"#)
563 .create_async()
564 .await;
565
566 let _engine_mock = server
567 .mock("GET", "/web/v3/account/test_account/engineUrl")
568 .with_status(200)
569 .with_header("content-type", "application/json")
570 .with_body(r#"{"engineUrl": "https://engine.test.firebolt.io/path"}"#)
571 .create_async()
572 .await;
573
574 let api_endpoint = server
575 .url()
576 .replace("http://", "https://api.test.firebolt.io");
577 std::env::set_var("FIREBOLT_API_ENDPOINT", &api_endpoint);
578
579 let factory_no_id = FireboltClientFactory {
580 client_id: None,
581 client_secret: Some("secret".to_string()),
582 database_name: None,
583 engine_name: None,
584 account_name: Some("test_account".to_string()),
585 _api_endpoint: api_endpoint,
586 };
587
588 let result = factory_no_id.build().await;
589
590 std::env::remove_var("FIREBOLT_API_ENDPOINT");
591
592 assert!(result.is_err());
593 assert!(matches!(
594 result.unwrap_err(),
595 FireboltError::Configuration(_)
596 ));
597 }
598
599 #[tokio::test]
600 async fn test_build_missing_client_secret() {
601 let mut server = mockito::Server::new_async().await;
602
603 let _auth_mock = server
604 .mock("POST", "/oauth/token")
605 .with_status(200)
606 .with_header("content-type", "application/json")
607 .with_body(r#"{"access_token": "test_token", "expires_in": 3600}"#)
608 .create_async()
609 .await;
610
611 let _engine_mock = server
612 .mock("GET", "/web/v3/account/test_account/engineUrl")
613 .with_status(200)
614 .with_header("content-type", "application/json")
615 .with_body(r#"{"engineUrl": "https://engine.test.firebolt.io/path"}"#)
616 .create_async()
617 .await;
618
619 let api_endpoint = server
620 .url()
621 .replace("http://", "https://api.test.firebolt.io");
622 std::env::set_var("FIREBOLT_API_ENDPOINT", &api_endpoint);
623
624 let factory_no_secret = FireboltClientFactory {
625 client_id: Some("client_id".to_string()),
626 client_secret: None,
627 database_name: None,
628 engine_name: None,
629 account_name: Some("test_account".to_string()),
630 _api_endpoint: api_endpoint,
631 };
632
633 let result = factory_no_secret.build().await;
634
635 std::env::remove_var("FIREBOLT_API_ENDPOINT");
636
637 assert!(result.is_err());
638 assert!(matches!(
639 result.unwrap_err(),
640 FireboltError::Configuration(_)
641 ));
642 }
643
644 #[tokio::test]
645 async fn test_build_missing_account_name() {
646 let mut server = mockito::Server::new_async().await;
647
648 let _auth_mock = server
649 .mock("POST", "/oauth/token")
650 .with_status(200)
651 .with_header("content-type", "application/json")
652 .with_body(r#"{"access_token": "test_token", "expires_in": 3600}"#)
653 .create_async()
654 .await;
655
656 let _engine_mock = server
657 .mock("GET", "/web/v3/account/test_account/engineUrl")
658 .with_status(200)
659 .with_header("content-type", "application/json")
660 .with_body(r#"{"engineUrl": "https://engine.test.firebolt.io/path"}"#)
661 .create_async()
662 .await;
663
664 let api_endpoint = server
665 .url()
666 .replace("http://", "https://api.test.firebolt.io");
667 std::env::set_var("FIREBOLT_API_ENDPOINT", &api_endpoint);
668
669 let factory_no_account = FireboltClientFactory {
670 client_id: Some("client_id".to_string()),
671 client_secret: Some("secret".to_string()),
672 database_name: None,
673 engine_name: None,
674 account_name: None,
675 _api_endpoint: api_endpoint,
676 };
677
678 let result = factory_no_account.build().await;
679
680 std::env::remove_var("FIREBOLT_API_ENDPOINT");
681
682 assert!(result.is_err());
683 assert!(matches!(
684 result.unwrap_err(),
685 FireboltError::Configuration(_)
686 ));
687 }
688
689 #[tokio::test]
690 async fn test_build_engine_url_success() {
691 std::env::set_var("FIREBOLT_API_ENDPOINT", "api.test.firebolt.io");
692
693 let factory = FireboltClientFactory {
694 client_id: Some("test_client_id".to_string()),
695 client_secret: Some("test_client_secret".to_string()),
696 database_name: None,
697 engine_name: None,
698 account_name: Some("test_account".to_string()),
699 _api_endpoint: "https://api.test.firebolt.io".to_string(),
700 };
701
702 let result = factory.build().await;
703
704 std::env::remove_var("FIREBOLT_API_ENDPOINT");
705
706 assert!(result.is_err());
707 assert!(matches!(
708 result.unwrap_err(),
709 FireboltError::Authentication(_)
710 ));
711 }
712
713 #[tokio::test]
714 async fn test_build_account_not_found() {
715 std::env::set_var("FIREBOLT_API_ENDPOINT", "api.test.firebolt.io");
716
717 let factory = FireboltClientFactory {
718 client_id: Some("test_client_id".to_string()),
719 client_secret: Some("test_client_secret".to_string()),
720 database_name: None,
721 engine_name: None,
722 account_name: Some("nonexistent_account".to_string()),
723 _api_endpoint: "https://api.test.firebolt.io".to_string(),
724 };
725
726 let result = factory.build().await;
727
728 std::env::remove_var("FIREBOLT_API_ENDPOINT");
729
730 assert!(result.is_err());
731 assert!(matches!(
732 result.unwrap_err(),
733 FireboltError::Authentication(_)
734 ));
735 }
736
737 #[tokio::test]
738 async fn test_build_server_error() {
739 std::env::set_var("FIREBOLT_API_ENDPOINT", "api.test.firebolt.io");
740
741 let factory = FireboltClientFactory {
742 client_id: Some("test_client_id".to_string()),
743 client_secret: Some("test_client_secret".to_string()),
744 database_name: None,
745 engine_name: None,
746 account_name: Some("test_account".to_string()),
747 _api_endpoint: "https://api.test.firebolt.io".to_string(),
748 };
749
750 let result = factory.build().await;
751
752 std::env::remove_var("FIREBOLT_API_ENDPOINT");
753
754 assert!(result.is_err());
755 assert!(matches!(
756 result.unwrap_err(),
757 FireboltError::Authentication(_)
758 ));
759 }
760
761 #[test]
762 fn test_get_api_endpoint_default() {
763 std::env::remove_var("FIREBOLT_API_ENDPOINT");
764
765 let result = FireboltClientFactory::get_api_endpoint();
766
767 assert_eq!(result, "https://api.app.firebolt.io");
768 }
769
770 #[test]
771 fn test_get_api_endpoint_from_env() {
772 std::env::set_var("FIREBOLT_API_ENDPOINT", "custom.api.firebolt.io");
773
774 let result = FireboltClientFactory::get_api_endpoint();
775
776 assert_eq!(result, "https://custom.api.firebolt.io");
777
778 std::env::remove_var("FIREBOLT_API_ENDPOINT");
779 }
780
781 #[test]
782 fn test_get_api_endpoint_with_https_prefix() {
783 std::env::set_var("FIREBOLT_API_ENDPOINT", "https://custom.api.firebolt.io");
784
785 let result = FireboltClientFactory::get_api_endpoint();
786
787 assert_eq!(result, "https://custom.api.firebolt.io");
788
789 std::env::remove_var("FIREBOLT_API_ENDPOINT");
790 }
791
792 #[test]
793 fn test_get_api_endpoint_with_http_prefix() {
794 std::env::set_var("FIREBOLT_API_ENDPOINT", "http://custom.api.firebolt.io");
795
796 let result = FireboltClientFactory::get_api_endpoint();
797
798 assert_eq!(result, "http://custom.api.firebolt.io");
799
800 std::env::remove_var("FIREBOLT_API_ENDPOINT");
801 }
802
803 #[tokio::test]
804 async fn test_get_engine_url_success() {
805 let mut server = mockito::Server::new_async().await;
806 let mock = server
807 .mock("GET", "/web/v3/account/test_account/engineUrl")
808 .with_status(200)
809 .with_header("content-type", "application/json")
810 .with_body(r#"{"engineUrl": "engine.test.firebolt.io"}"#)
811 .create_async()
812 .await;
813
814 let result =
815 FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
816 .await;
817
818 mock.assert_async().await;
819 assert!(result.is_ok());
820
821 let engine_url = result.unwrap();
822 assert_eq!(engine_url, "https://engine.test.firebolt.io/");
823 }
824
825 #[tokio::test]
826 async fn test_get_engine_url_account_not_found() {
827 let mut server = mockito::Server::new_async().await;
828 let mock = server
829 .mock("GET", "/web/v3/account/nonexistent/engineUrl")
830 .with_status(404)
831 .create_async()
832 .await;
833
834 let result =
835 FireboltClientFactory::get_engine_url("nonexistent", &server.url(), "test_token").await;
836
837 mock.assert_async().await;
838 assert!(result.is_err());
839 assert!(matches!(
840 result.unwrap_err(),
841 FireboltError::Configuration(_)
842 ));
843 }
844
845 #[tokio::test]
846 async fn test_get_engine_url_server_error() {
847 let mut server = mockito::Server::new_async().await;
848 let mock = server
849 .mock("GET", "/web/v3/account/test_account/engineUrl")
850 .with_status(500)
851 .with_body("Internal server error")
852 .create_async()
853 .await;
854
855 let result =
856 FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
857 .await;
858
859 mock.assert_async().await;
860 assert!(result.is_err());
861 assert!(matches!(result.unwrap_err(), FireboltError::Query(_)));
862 }
863
864 #[tokio::test]
865 async fn test_get_engine_url_invalid_json() {
866 let mut server = mockito::Server::new_async().await;
867 let mock = server
868 .mock("GET", "/web/v3/account/test_account/engineUrl")
869 .with_status(200)
870 .with_header("content-type", "application/json")
871 .with_body("invalid json")
872 .create_async()
873 .await;
874
875 let result =
876 FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
877 .await;
878
879 mock.assert_async().await;
880 assert!(result.is_err());
881 assert!(matches!(result.unwrap_err(), FireboltError::Query(_)));
882 }
883
884 #[tokio::test]
885 async fn test_get_engine_url_missing_engine_url_field() {
886 let mut server = mockito::Server::new_async().await;
887 let mock = server
888 .mock("GET", "/web/v3/account/test_account/engineUrl")
889 .with_status(200)
890 .with_header("content-type", "application/json")
891 .with_body(r#"{"otherField": "value"}"#)
892 .create_async()
893 .await;
894
895 let result =
896 FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
897 .await;
898
899 mock.assert_async().await;
900 assert!(result.is_err());
901 assert!(matches!(result.unwrap_err(), FireboltError::Query(_)));
902 }
903
904 #[tokio::test]
905 async fn test_process_response_headers_update_endpoint() {
906 let mut server = mockito::Server::new_async().await;
907 let mock = server
908 .mock("POST", "/")
909 .with_status(200)
910 .with_header("content-type", "application/json")
911 .with_header(
912 HEADER_UPDATE_ENDPOINT,
913 "https://new.engine.url/path?param1=value1¶m2=value2",
914 )
915 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
916 .create_async()
917 .await;
918
919 let mut client = create_test_client();
920 client._engine_url = server.url();
921
922 let result = client
923 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
924 .await;
925
926 mock.assert_async().await;
927 assert!(result.is_ok());
928 assert_eq!(client._engine_url, "https://new.engine.url/path");
929 assert_eq!(
930 client._parameters.get("param1"),
931 Some(&"value1".to_string())
932 );
933 assert_eq!(
934 client._parameters.get("param2"),
935 Some(&"value2".to_string())
936 );
937 }
938
939 #[tokio::test]
940 async fn test_process_response_headers_update_parameters() {
941 let mut server = mockito::Server::new_async().await;
942 let mock = server
943 .mock("POST", "/")
944 .with_status(200)
945 .with_header("content-type", "application/json")
946 .with_header(
947 HEADER_UPDATE_PARAMETERS,
948 "database=new_db,engine=new_engine,custom=value",
949 )
950 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
951 .create_async()
952 .await;
953
954 let mut client = create_test_client();
955 client._engine_url = server.url();
956
957 let result = client
958 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
959 .await;
960
961 mock.assert_async().await;
962 assert!(result.is_ok());
963 assert_eq!(
964 client._parameters.get("database"),
965 Some(&"new_db".to_string())
966 );
967 assert_eq!(
968 client._parameters.get("engine"),
969 Some(&"new_engine".to_string())
970 );
971 assert_eq!(client._parameters.get("custom"), Some(&"value".to_string()));
972 }
973
974 #[tokio::test]
975 async fn test_process_response_headers_reset_session() {
976 let mut server = mockito::Server::new_async().await;
977 let mock = server
978 .mock("POST", "/")
979 .with_status(200)
980 .with_header("content-type", "application/json")
981 .with_header(HEADER_RESET_SESSION, "true")
982 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
983 .create_async()
984 .await;
985
986 let mut client = create_test_client();
987 client._engine_url = server.url();
988 client
989 ._parameters
990 .insert("database".to_string(), "test_db".to_string());
991 client
992 ._parameters
993 .insert("engine".to_string(), "test_engine".to_string());
994 client
995 ._parameters
996 .insert("custom_param".to_string(), "custom_value".to_string());
997
998 let result = client
999 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1000 .await;
1001
1002 mock.assert_async().await;
1003 assert!(result.is_ok());
1004 assert_eq!(
1005 client._parameters.get("database"),
1006 Some(&"test_db".to_string())
1007 );
1008 assert_eq!(
1009 client._parameters.get("engine"),
1010 Some(&"test_engine".to_string())
1011 );
1012 assert_eq!(client._parameters.get("custom_param"), None);
1013 assert_eq!(client._parameters.len(), 2);
1014 }
1015
1016 #[tokio::test]
1017 async fn test_process_response_headers_remove_parameters() {
1018 let mut server = mockito::Server::new_async().await;
1019 let mock = server
1020 .mock("POST", "/")
1021 .with_status(200)
1022 .with_header("content-type", "application/json")
1023 .with_header(HEADER_REMOVE_PARAMETERS, "param1,param3")
1024 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
1025 .create_async()
1026 .await;
1027
1028 let mut client = create_test_client();
1029 client._engine_url = server.url();
1030 client
1031 ._parameters
1032 .insert("param1".to_string(), "value1".to_string());
1033 client
1034 ._parameters
1035 .insert("param2".to_string(), "value2".to_string());
1036 client
1037 ._parameters
1038 .insert("param3".to_string(), "value3".to_string());
1039
1040 let result = client
1041 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1042 .await;
1043
1044 mock.assert_async().await;
1045 assert!(result.is_ok());
1046 assert_eq!(client._parameters.get("param1"), None);
1047 assert_eq!(
1048 client._parameters.get("param2"),
1049 Some(&"value2".to_string())
1050 );
1051 assert_eq!(client._parameters.get("param3"), None);
1052 assert_eq!(client._parameters.len(), 1);
1053 }
1054
1055 #[tokio::test]
1056 async fn test_process_response_headers_invalid_parameters_format() {
1057 let mut server = mockito::Server::new_async().await;
1058 let mock = server
1059 .mock("POST", "/")
1060 .with_status(200)
1061 .with_header("content-type", "application/json")
1062 .with_header(HEADER_UPDATE_PARAMETERS, "invalid-format-no-equals")
1063 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
1064 .create_async()
1065 .await;
1066
1067 let mut client = create_test_client();
1068 client._engine_url = server.url();
1069
1070 let result = client
1071 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1072 .await;
1073
1074 mock.assert_async().await;
1075 assert!(result.is_err());
1076 assert!(matches!(
1077 result.unwrap_err(),
1078 FireboltError::HeaderParsing(_)
1079 ));
1080 }
1081
1082 #[tokio::test]
1083 async fn test_process_response_headers_empty_parameter_key() {
1084 let mut server = mockito::Server::new_async().await;
1085 let mock = server
1086 .mock("POST", "/")
1087 .with_status(200)
1088 .with_header("content-type", "application/json")
1089 .with_header(HEADER_UPDATE_PARAMETERS, "=value")
1090 .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
1091 .create_async()
1092 .await;
1093
1094 let mut client = create_test_client();
1095 client._engine_url = server.url();
1096
1097 let result = client
1098 .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1099 .await;
1100
1101 mock.assert_async().await;
1102 assert!(result.is_err());
1103 assert!(matches!(
1104 result.unwrap_err(),
1105 FireboltError::HeaderParsing(_)
1106 ));
1107 }
1108}