firebolt/
client.rs

1use crate::error::FireboltError;
2use crate::result::ResultSet;
3use std::collections::HashMap;
4use url::Url;
5
/// Response header carrying a replacement engine endpoint; query-string pairs
/// on that endpoint are merged into the session parameters.
const HEADER_UPDATE_ENDPOINT: &str = "Firebolt-Update-Endpoint";

/// Response header carrying comma-separated `key=value` session parameters to set.
const HEADER_UPDATE_PARAMETERS: &str = "Firebolt-Update-Parameters";
/// Response header carrying comma-separated session parameter names to drop.
const HEADER_REMOVE_PARAMETERS: &str = "Firebolt-Remove-Parameters";

/// Response header whose mere presence clears every session parameter except
/// `database` and `engine`.
const HEADER_RESET_SESSION: &str = "Firebolt-Reset-Session";
10
/// Authenticated client that executes SQL against a Firebolt engine.
///
/// Holds the credentials used to refresh the bearer token, the current token,
/// the session parameters sent with every query, the engine URL queries are
/// posted to, and the API endpoint used for authentication.
pub struct FireboltClient {
    _client_id: String,
    _client_secret: String,
    _token: String,
    _parameters: HashMap<String, String>,
    _engine_url: String,
    _api_endpoint: String,
}

// Written by hand instead of derived so that `{:?}` output (logs, panic
// messages, test failures) never exposes the client secret or bearer token.
impl std::fmt::Debug for FireboltClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FireboltClient")
            .field("_client_id", &self._client_id)
            .field("_client_secret", &"<redacted>")
            .field("_token", &"<redacted>")
            .field("_parameters", &self._parameters)
            .field("_engine_url", &self._engine_url)
            .field("_api_endpoint", &self._api_endpoint)
            .finish()
    }
}
20
21impl FireboltClient {
22    pub async fn query(&mut self, sql: &str) -> Result<ResultSet, FireboltError> {
23        let engine_url = self.engine_url();
24        let url = ensure_trailing_slash(engine_url);
25
26        let mut params = self.parameters().clone();
27        params.insert("output_format".to_string(), "JSON_Compact".to_string());
28
29        self.execute_query_request(&url, sql, &params, true).await
30    }
31
32    async fn execute_query_request(
33        &mut self,
34        url: &str,
35        sql: &str,
36        params: &HashMap<String, String>,
37        should_retry: bool,
38    ) -> Result<ResultSet, FireboltError> {
39        let client = reqwest::Client::new();
40        let token = &self._token;
41
42        let response = client
43            .post(url)
44            .query(params)
45            .header("Authorization", format!("Bearer {token}"))
46            .header("User-Agent", crate::version::user_agent())
47            .header(
48                "Firebolt-Protocol-Version",
49                crate::version::PROTOCOL_VERSION,
50            )
51            .body(sql.to_string())
52            .send()
53            .await
54            .map_err(|e| FireboltError::Network(format!("Request failed: {e}")))?;
55
56        let status = response.status();
57
58        if status == 401 && should_retry {
59            let (new_token, _expiration) = crate::auth::authenticate(
60                self.client_id().to_string(),
61                self.client_secret().to_string(),
62                self.api_endpoint().to_string(),
63            )
64            .await
65            .map_err(|e| FireboltError::Authentication(format!("Token refresh failed: {e}")))?;
66
67            self.set_token(new_token);
68            Box::pin(self.execute_query_request(url, sql, params, false)).await
69        } else if status == 401 {
70            Err(FireboltError::Authentication(
71                "Authentication failed after token refresh".to_string(),
72            ))
73        } else if status.is_server_error() {
74            let body = response.text().await.map_err(|e| {
75                FireboltError::Network(format!("Failed to read error response: {e}"))
76            })?;
77            Err(crate::parser::parse_server_error(body))
78        } else if status.is_success() {
79            self.process_response_headers(&response)?;
80            let body = response
81                .text()
82                .await
83                .map_err(|e| FireboltError::Network(format!("Failed to read response: {e}")))?;
84            crate::parser::parse_response(body)
85        } else {
86            let body = response.text().await.map_err(|e| {
87                FireboltError::Network(format!("Failed to read error response: {e}"))
88            })?;
89            Err(crate::parser::parse_server_error(body))
90        }
91    }
92
93    pub fn client_id(&self) -> &str {
94        &self._client_id
95    }
96
97    pub fn client_secret(&self) -> &str {
98        &self._client_secret
99    }
100
101    pub fn api_endpoint(&self) -> &str {
102        &self._api_endpoint
103    }
104
105    pub fn engine_url(&self) -> &str {
106        &self._engine_url
107    }
108
109    pub fn parameters(&self) -> &HashMap<String, String> {
110        &self._parameters
111    }
112
113    pub fn set_token(&mut self, token: String) {
114        self._token = token;
115    }
116
117    pub fn builder() -> FireboltClientFactory {
118        FireboltClientFactory::new()
119    }
120
121    fn process_response_headers(
122        &mut self,
123        response: &reqwest::Response,
124    ) -> Result<(), FireboltError> {
125        if let Some(endpoint_header) = response.headers().get(HEADER_UPDATE_ENDPOINT) {
126            let endpoint_str = endpoint_header.to_str().map_err(|e| {
127                FireboltError::HeaderParsing(format!("Invalid endpoint header: {e}"))
128            })?;
129
130            let url = Url::parse(FireboltClientFactory::fix_schema(endpoint_str).as_str())
131                .map_err(|e| FireboltError::HeaderParsing(format!("Invalid endpoint URL: {e}")))?;
132
133            let base_url = format!("{}://{}", url.scheme(), url.host_str().unwrap_or(""));
134            let path = url.path();
135            self._engine_url = if path == "/" || path.is_empty() {
136                base_url
137            } else {
138                format!("{base_url}{path}")
139            };
140
141            for (key, value) in url.query_pairs() {
142                self._parameters.insert(key.to_string(), value.to_string());
143            }
144        }
145
146        if let Some(params_header) = response.headers().get(HEADER_UPDATE_PARAMETERS) {
147            let params_str = params_header.to_str().map_err(|e| {
148                FireboltError::HeaderParsing(format!("Invalid parameters header: {e}"))
149            })?;
150
151            for param_pair in params_str.split(',') {
152                let param_pair = param_pair.trim();
153                if param_pair.is_empty() {
154                    continue;
155                }
156
157                let parts: Vec<&str> = param_pair.splitn(2, '=').collect();
158                if parts.len() != 2 {
159                    return Err(FireboltError::HeaderParsing(format!(
160                        "Invalid parameter format: {param_pair}"
161                    )));
162                }
163
164                let key = parts[0].trim();
165                let value = parts[1].trim();
166
167                if key.is_empty() {
168                    return Err(FireboltError::HeaderParsing(
169                        "Parameter key cannot be empty".to_string(),
170                    ));
171                }
172
173                self._parameters.insert(key.to_string(), value.to_string());
174            }
175        }
176
177        if response.headers().contains_key(HEADER_RESET_SESSION) {
178            let database = self._parameters.get("database").cloned();
179            let engine = self._parameters.get("engine").cloned();
180
181            self._parameters.clear();
182
183            if let Some(db) = database {
184                self._parameters.insert("database".to_string(), db);
185            }
186            if let Some(eng) = engine {
187                self._parameters.insert("engine".to_string(), eng);
188            }
189        }
190
191        if let Some(remove_header) = response.headers().get(HEADER_REMOVE_PARAMETERS) {
192            let remove_str = remove_header.to_str().map_err(|e| {
193                FireboltError::HeaderParsing(format!("Invalid remove parameters header: {e}"))
194            })?;
195
196            for param_name in remove_str.split(',') {
197                let param_name = param_name.trim();
198                if !param_name.is_empty() {
199                    self._parameters.remove(param_name);
200                }
201            }
202        }
203
204        Ok(())
205    }
206}
207
/// Returns an owned copy of `url` that is guaranteed to end in `/`.
///
/// A slash is appended only when one is not already present, so an empty
/// input becomes `"/"`.
fn ensure_trailing_slash(url: &str) -> String {
    let mut normalized = String::with_capacity(url.len() + 1);
    normalized.push_str(url);
    if !normalized.ends_with('/') {
        normalized.push('/');
    }
    normalized
}
215
/// Builder for a Firebolt client.
///
/// `client_id`, `client_secret` and `account_name` must be set before
/// building; the database and engine names are optional.
pub struct FireboltClientFactory {
    client_id: Option<String>,
    client_secret: Option<String>,
    database_name: Option<String>,
    engine_name: Option<String>,
    account_name: Option<String>,
    _api_endpoint: String,
}

// Written by hand (rather than derived) so the builder can be debug-printed
// without exposing the client secret.
impl std::fmt::Debug for FireboltClientFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FireboltClientFactory")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("database_name", &self.database_name)
            .field("engine_name", &self.engine_name)
            .field("account_name", &self.account_name)
            .field("_api_endpoint", &self._api_endpoint)
            .finish()
    }
}
224
225impl FireboltClientFactory {
226    fn new() -> Self {
227        Self {
228            client_id: None,
229            client_secret: None,
230            database_name: None,
231            engine_name: None,
232            account_name: None,
233            _api_endpoint: "https://api.firebolt.io".to_string(),
234        }
235    }
236
237    fn fix_schema(url: &str) -> String {
238        if url.starts_with("https://") || url.starts_with("http://") {
239            url.to_string()
240        } else {
241            format!("https://{url}")
242        }
243    }
244
245    fn get_api_endpoint() -> String {
246        let api_endpoint = std::env::var("FIREBOLT_API_ENDPOINT")
247            .unwrap_or_else(|_| "api.app.firebolt.io".to_string());
248
249        Self::fix_schema(&api_endpoint)
250    }
251
252    async fn get_engine_url(
253        account_name: &str,
254        api_endpoint: &str,
255        token: &str,
256    ) -> Result<String, FireboltError> {
257        let engine_url_endpoint = format!("{api_endpoint}/web/v3/account/{account_name}/engineUrl");
258        let client = reqwest::Client::new();
259
260        let response = client
261            .get(&engine_url_endpoint)
262            .header("Authorization", format!("Bearer {token}"))
263            .header("User-Agent", crate::version::user_agent())
264            .send()
265            .await
266            .map_err(|e| FireboltError::Network(format!("Failed to get engine URL: {e}")))?;
267
268        let status = response.status();
269
270        match status.as_u16() {
271            200 => {
272                let body = response
273                    .text()
274                    .await
275                    .map_err(|e| FireboltError::Network(format!("Failed to read response: {e}")))?;
276
277                let json: serde_json::Value = serde_json::from_str(&body).map_err(|e| {
278                    FireboltError::Query(format!("Failed to parse engine URL response: {e}"))
279                })?;
280
281                let engine_url =
282                    json.get("engineUrl")
283                        .and_then(|v| v.as_str())
284                        .ok_or_else(|| {
285                            FireboltError::Query("Missing engineUrl field in response".to_string())
286                        })?;
287
288                Ok(Self::fix_schema(ensure_trailing_slash(engine_url).as_str()))
289            }
290            404 => Err(FireboltError::Configuration(format!(
291                "Account '{account_name}' not found"
292            ))),
293            _ => {
294                let body = response.text().await.map_err(|e| {
295                    FireboltError::Network(format!("Failed to read error response: {e}"))
296                })?;
297                Err(FireboltError::Query(body))
298            }
299        }
300    }
301
302    pub fn with_credentials(mut self, client_id: String, client_secret: String) -> Self {
303        self.client_id = Some(client_id);
304        self.client_secret = Some(client_secret);
305        self
306    }
307
308    pub fn with_database(mut self, database_name: String) -> Self {
309        self.database_name = Some(database_name);
310        self
311    }
312
313    pub fn with_engine(mut self, engine_name: String) -> Self {
314        self.engine_name = Some(engine_name);
315        self
316    }
317
318    pub fn with_account(mut self, account_name: String) -> Self {
319        self.account_name = Some(account_name);
320        self
321    }
322
323    pub async fn build(self) -> Result<FireboltClient, FireboltError> {
324        // 1. Validate required parameters
325        let client_id = self
326            .client_id
327            .ok_or_else(|| FireboltError::Configuration("client_id is required".to_string()))?;
328        let client_secret = self
329            .client_secret
330            .ok_or_else(|| FireboltError::Configuration("client_secret is required".to_string()))?;
331        let account_name = self
332            .account_name
333            .ok_or_else(|| FireboltError::Configuration("account_name is required".to_string()))?;
334
335        let api_endpoint = Self::get_api_endpoint();
336
337        let (token, _expiration) = crate::auth::authenticate(
338            client_id.clone(),
339            client_secret.clone(),
340            api_endpoint.clone(),
341        )
342        .await
343        .map_err(FireboltError::Authentication)?;
344
345        let engine_url = Self::get_engine_url(&account_name, &api_endpoint, &token).await?;
346
347        let mut client = FireboltClient {
348            _client_id: client_id,
349            _client_secret: client_secret,
350            _token: token,
351            _parameters: HashMap::new(),
352            _engine_url: engine_url,
353            _api_endpoint: api_endpoint,
354        };
355
356        if let Some(database_name) = self.database_name {
357            let use_database_sql = format!("USE DATABASE \"{database_name}\"");
358            client.query(&use_database_sql).await.map_err(|e| {
359                FireboltError::Configuration(format!("Failed to set database: {e}"))
360            })?;
361        }
362
363        if let Some(engine_name) = self.engine_name {
364            let use_engine_sql = format!("USE ENGINE \"{engine_name}\"");
365            client
366                .query(&use_engine_sql)
367                .await
368                .map_err(|e| FireboltError::Configuration(format!("Failed to set engine: {e}")))?;
369        }
370
371        Ok(client)
372    }
373}
374
375#[cfg(test)]
376mod tests {
377    use super::*;
378
379    #[tokio::test]
380    async fn test_execute_query_request_success() {
381        let mut server = mockito::Server::new_async().await;
382        let mock = server
383            .mock("POST", "/")
384            .with_status(200)
385            .with_header("content-type", "application/json")
386            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
387            .create_async()
388            .await;
389
390        let mut client = create_test_client();
391        client._engine_url = server.url();
392
393        let result = client
394            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
395            .await;
396
397        mock.assert_async().await;
398        assert!(result.is_ok());
399    }
400
401    #[tokio::test]
402    async fn test_execute_query_request_retry_on_401() {
403        let mut server = mockito::Server::new_async().await;
404
405        let mock_401 = server
406            .mock("POST", "/")
407            .with_status(401)
408            .expect(1)
409            .create_async()
410            .await;
411
412        let _mock_success = server
413            .mock("POST", "/")
414            .with_status(200)
415            .with_header("content-type", "application/json")
416            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
417            .create_async()
418            .await;
419
420        let mut client = create_test_client();
421        client._engine_url = server.url();
422
423        let result = client
424            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
425            .await;
426
427        mock_401.assert_async().await;
428
429        assert!(result.is_err());
430        assert!(matches!(
431            result.unwrap_err(),
432            FireboltError::Authentication(_)
433        ));
434    }
435
436    #[tokio::test]
437    async fn test_execute_query_request_no_retry_on_second_401() {
438        let mut server = mockito::Server::new_async().await;
439        let mock = server
440            .mock("POST", "/")
441            .with_status(401)
442            .expect(1)
443            .create_async()
444            .await;
445
446        let mut client = create_test_client();
447        client._engine_url = server.url();
448
449        let result = client
450            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), false)
451            .await;
452
453        mock.assert_async().await;
454        assert!(result.is_err());
455        assert!(matches!(
456            result.unwrap_err(),
457            FireboltError::Authentication(_)
458        ));
459    }
460
461    #[tokio::test]
462    async fn test_execute_query_request_5xx_error() {
463        let mut server = mockito::Server::new_async().await;
464        let mock = server
465            .mock("POST", "/")
466            .with_status(500)
467            .with_body("Internal Server Error")
468            .create_async()
469            .await;
470
471        let mut client = create_test_client();
472        client._engine_url = server.url();
473
474        let result = client
475            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
476            .await;
477
478        mock.assert_async().await;
479        assert!(result.is_err());
480        let error = result.unwrap_err();
481        assert!(matches!(error, FireboltError::Query(_)));
482        assert!(format!("{error:?}").contains("Internal Server Error"));
483    }
484
485    #[test]
486    fn test_client_getters() {
487        let client = create_test_client();
488        assert_eq!(client.client_id(), "test_id");
489        assert_eq!(client.client_secret(), "test_secret");
490        assert_eq!(client.api_endpoint(), "https://api.test.firebolt.io");
491        assert_eq!(client.engine_url(), "https://test.engine.url/");
492        assert!(client.parameters().is_empty());
493    }
494
495    #[test]
496    fn test_set_token() {
497        let mut client = create_test_client();
498        client.set_token("new_token".to_string());
499        assert_eq!(client._token, "new_token".to_string());
500    }
501
502    #[tokio::test]
503    async fn test_execute_query_request_headers() {
504        let mut server = mockito::Server::new_async().await;
505        let mock = server
506            .mock("POST", "/")
507            .match_header("User-Agent", crate::version::user_agent().as_str())
508            .match_header(
509                "Firebolt-Protocol-Version",
510                crate::version::PROTOCOL_VERSION,
511            )
512            .match_header("Authorization", "Bearer test_token")
513            .with_status(200)
514            .with_header("content-type", "application/json")
515            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
516            .create_async()
517            .await;
518
519        let mut client = create_test_client();
520        client._engine_url = server.url();
521
522        let result = client
523            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
524            .await;
525
526        mock.assert_async().await;
527        assert!(result.is_ok());
528    }
529
530    #[test]
531    fn test_ensure_trailing_slash() {
532        assert_eq!(
533            ensure_trailing_slash("https://example.com"),
534            "https://example.com/"
535        );
536        assert_eq!(
537            ensure_trailing_slash("https://example.com/"),
538            "https://example.com/"
539        );
540        assert_eq!(ensure_trailing_slash(""), "/");
541    }
542
543    fn create_test_client() -> FireboltClient {
544        FireboltClient {
545            _client_id: "test_id".to_string(),
546            _client_secret: "test_secret".to_string(),
547            _token: "test_token".to_string(),
548            _parameters: HashMap::new(),
549            _engine_url: "https://test.engine.url/".to_string(),
550            _api_endpoint: "https://api.test.firebolt.io".to_string(),
551        }
552    }
553
554    #[tokio::test]
555    async fn test_build_missing_client_id() {
556        let mut server = mockito::Server::new_async().await;
557
558        let _auth_mock = server
559            .mock("POST", "/oauth/token")
560            .with_status(200)
561            .with_header("content-type", "application/json")
562            .with_body(r#"{"access_token": "test_token", "expires_in": 3600}"#)
563            .create_async()
564            .await;
565
566        let _engine_mock = server
567            .mock("GET", "/web/v3/account/test_account/engineUrl")
568            .with_status(200)
569            .with_header("content-type", "application/json")
570            .with_body(r#"{"engineUrl": "https://engine.test.firebolt.io/path"}"#)
571            .create_async()
572            .await;
573
574        let api_endpoint = server
575            .url()
576            .replace("http://", "https://api.test.firebolt.io");
577        std::env::set_var("FIREBOLT_API_ENDPOINT", &api_endpoint);
578
579        let factory_no_id = FireboltClientFactory {
580            client_id: None,
581            client_secret: Some("secret".to_string()),
582            database_name: None,
583            engine_name: None,
584            account_name: Some("test_account".to_string()),
585            _api_endpoint: api_endpoint,
586        };
587
588        let result = factory_no_id.build().await;
589
590        std::env::remove_var("FIREBOLT_API_ENDPOINT");
591
592        assert!(result.is_err());
593        assert!(matches!(
594            result.unwrap_err(),
595            FireboltError::Configuration(_)
596        ));
597    }
598
599    #[tokio::test]
600    async fn test_build_missing_client_secret() {
601        let mut server = mockito::Server::new_async().await;
602
603        let _auth_mock = server
604            .mock("POST", "/oauth/token")
605            .with_status(200)
606            .with_header("content-type", "application/json")
607            .with_body(r#"{"access_token": "test_token", "expires_in": 3600}"#)
608            .create_async()
609            .await;
610
611        let _engine_mock = server
612            .mock("GET", "/web/v3/account/test_account/engineUrl")
613            .with_status(200)
614            .with_header("content-type", "application/json")
615            .with_body(r#"{"engineUrl": "https://engine.test.firebolt.io/path"}"#)
616            .create_async()
617            .await;
618
619        let api_endpoint = server
620            .url()
621            .replace("http://", "https://api.test.firebolt.io");
622        std::env::set_var("FIREBOLT_API_ENDPOINT", &api_endpoint);
623
624        let factory_no_secret = FireboltClientFactory {
625            client_id: Some("client_id".to_string()),
626            client_secret: None,
627            database_name: None,
628            engine_name: None,
629            account_name: Some("test_account".to_string()),
630            _api_endpoint: api_endpoint,
631        };
632
633        let result = factory_no_secret.build().await;
634
635        std::env::remove_var("FIREBOLT_API_ENDPOINT");
636
637        assert!(result.is_err());
638        assert!(matches!(
639            result.unwrap_err(),
640            FireboltError::Configuration(_)
641        ));
642    }
643
644    #[tokio::test]
645    async fn test_build_missing_account_name() {
646        let mut server = mockito::Server::new_async().await;
647
648        let _auth_mock = server
649            .mock("POST", "/oauth/token")
650            .with_status(200)
651            .with_header("content-type", "application/json")
652            .with_body(r#"{"access_token": "test_token", "expires_in": 3600}"#)
653            .create_async()
654            .await;
655
656        let _engine_mock = server
657            .mock("GET", "/web/v3/account/test_account/engineUrl")
658            .with_status(200)
659            .with_header("content-type", "application/json")
660            .with_body(r#"{"engineUrl": "https://engine.test.firebolt.io/path"}"#)
661            .create_async()
662            .await;
663
664        let api_endpoint = server
665            .url()
666            .replace("http://", "https://api.test.firebolt.io");
667        std::env::set_var("FIREBOLT_API_ENDPOINT", &api_endpoint);
668
669        let factory_no_account = FireboltClientFactory {
670            client_id: Some("client_id".to_string()),
671            client_secret: Some("secret".to_string()),
672            database_name: None,
673            engine_name: None,
674            account_name: None,
675            _api_endpoint: api_endpoint,
676        };
677
678        let result = factory_no_account.build().await;
679
680        std::env::remove_var("FIREBOLT_API_ENDPOINT");
681
682        assert!(result.is_err());
683        assert!(matches!(
684            result.unwrap_err(),
685            FireboltError::Configuration(_)
686        ));
687    }
688
689    #[tokio::test]
690    async fn test_build_engine_url_success() {
691        std::env::set_var("FIREBOLT_API_ENDPOINT", "api.test.firebolt.io");
692
693        let factory = FireboltClientFactory {
694            client_id: Some("test_client_id".to_string()),
695            client_secret: Some("test_client_secret".to_string()),
696            database_name: None,
697            engine_name: None,
698            account_name: Some("test_account".to_string()),
699            _api_endpoint: "https://api.test.firebolt.io".to_string(),
700        };
701
702        let result = factory.build().await;
703
704        std::env::remove_var("FIREBOLT_API_ENDPOINT");
705
706        assert!(result.is_err());
707        assert!(matches!(
708            result.unwrap_err(),
709            FireboltError::Authentication(_)
710        ));
711    }
712
713    #[tokio::test]
714    async fn test_build_account_not_found() {
715        std::env::set_var("FIREBOLT_API_ENDPOINT", "api.test.firebolt.io");
716
717        let factory = FireboltClientFactory {
718            client_id: Some("test_client_id".to_string()),
719            client_secret: Some("test_client_secret".to_string()),
720            database_name: None,
721            engine_name: None,
722            account_name: Some("nonexistent_account".to_string()),
723            _api_endpoint: "https://api.test.firebolt.io".to_string(),
724        };
725
726        let result = factory.build().await;
727
728        std::env::remove_var("FIREBOLT_API_ENDPOINT");
729
730        assert!(result.is_err());
731        assert!(matches!(
732            result.unwrap_err(),
733            FireboltError::Authentication(_)
734        ));
735    }
736
737    #[tokio::test]
738    async fn test_build_server_error() {
739        std::env::set_var("FIREBOLT_API_ENDPOINT", "api.test.firebolt.io");
740
741        let factory = FireboltClientFactory {
742            client_id: Some("test_client_id".to_string()),
743            client_secret: Some("test_client_secret".to_string()),
744            database_name: None,
745            engine_name: None,
746            account_name: Some("test_account".to_string()),
747            _api_endpoint: "https://api.test.firebolt.io".to_string(),
748        };
749
750        let result = factory.build().await;
751
752        std::env::remove_var("FIREBOLT_API_ENDPOINT");
753
754        assert!(result.is_err());
755        assert!(matches!(
756            result.unwrap_err(),
757            FireboltError::Authentication(_)
758        ));
759    }
760
761    #[test]
762    fn test_get_api_endpoint_default() {
763        std::env::remove_var("FIREBOLT_API_ENDPOINT");
764
765        let result = FireboltClientFactory::get_api_endpoint();
766
767        assert_eq!(result, "https://api.app.firebolt.io");
768    }
769
770    #[test]
771    fn test_get_api_endpoint_from_env() {
772        std::env::set_var("FIREBOLT_API_ENDPOINT", "custom.api.firebolt.io");
773
774        let result = FireboltClientFactory::get_api_endpoint();
775
776        assert_eq!(result, "https://custom.api.firebolt.io");
777
778        std::env::remove_var("FIREBOLT_API_ENDPOINT");
779    }
780
781    #[test]
782    fn test_get_api_endpoint_with_https_prefix() {
783        std::env::set_var("FIREBOLT_API_ENDPOINT", "https://custom.api.firebolt.io");
784
785        let result = FireboltClientFactory::get_api_endpoint();
786
787        assert_eq!(result, "https://custom.api.firebolt.io");
788
789        std::env::remove_var("FIREBOLT_API_ENDPOINT");
790    }
791
792    #[test]
793    fn test_get_api_endpoint_with_http_prefix() {
794        std::env::set_var("FIREBOLT_API_ENDPOINT", "http://custom.api.firebolt.io");
795
796        let result = FireboltClientFactory::get_api_endpoint();
797
798        assert_eq!(result, "http://custom.api.firebolt.io");
799
800        std::env::remove_var("FIREBOLT_API_ENDPOINT");
801    }
802
803    #[tokio::test]
804    async fn test_get_engine_url_success() {
805        let mut server = mockito::Server::new_async().await;
806        let mock = server
807            .mock("GET", "/web/v3/account/test_account/engineUrl")
808            .with_status(200)
809            .with_header("content-type", "application/json")
810            .with_body(r#"{"engineUrl": "engine.test.firebolt.io"}"#)
811            .create_async()
812            .await;
813
814        let result =
815            FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
816                .await;
817
818        mock.assert_async().await;
819        assert!(result.is_ok());
820
821        let engine_url = result.unwrap();
822        assert_eq!(engine_url, "https://engine.test.firebolt.io/");
823    }
824
825    #[tokio::test]
826    async fn test_get_engine_url_account_not_found() {
827        let mut server = mockito::Server::new_async().await;
828        let mock = server
829            .mock("GET", "/web/v3/account/nonexistent/engineUrl")
830            .with_status(404)
831            .create_async()
832            .await;
833
834        let result =
835            FireboltClientFactory::get_engine_url("nonexistent", &server.url(), "test_token").await;
836
837        mock.assert_async().await;
838        assert!(result.is_err());
839        assert!(matches!(
840            result.unwrap_err(),
841            FireboltError::Configuration(_)
842        ));
843    }
844
845    #[tokio::test]
846    async fn test_get_engine_url_server_error() {
847        let mut server = mockito::Server::new_async().await;
848        let mock = server
849            .mock("GET", "/web/v3/account/test_account/engineUrl")
850            .with_status(500)
851            .with_body("Internal server error")
852            .create_async()
853            .await;
854
855        let result =
856            FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
857                .await;
858
859        mock.assert_async().await;
860        assert!(result.is_err());
861        assert!(matches!(result.unwrap_err(), FireboltError::Query(_)));
862    }
863
864    #[tokio::test]
865    async fn test_get_engine_url_invalid_json() {
866        let mut server = mockito::Server::new_async().await;
867        let mock = server
868            .mock("GET", "/web/v3/account/test_account/engineUrl")
869            .with_status(200)
870            .with_header("content-type", "application/json")
871            .with_body("invalid json")
872            .create_async()
873            .await;
874
875        let result =
876            FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
877                .await;
878
879        mock.assert_async().await;
880        assert!(result.is_err());
881        assert!(matches!(result.unwrap_err(), FireboltError::Query(_)));
882    }
883
884    #[tokio::test]
885    async fn test_get_engine_url_missing_engine_url_field() {
886        let mut server = mockito::Server::new_async().await;
887        let mock = server
888            .mock("GET", "/web/v3/account/test_account/engineUrl")
889            .with_status(200)
890            .with_header("content-type", "application/json")
891            .with_body(r#"{"otherField": "value"}"#)
892            .create_async()
893            .await;
894
895        let result =
896            FireboltClientFactory::get_engine_url("test_account", &server.url(), "test_token")
897                .await;
898
899        mock.assert_async().await;
900        assert!(result.is_err());
901        assert!(matches!(result.unwrap_err(), FireboltError::Query(_)));
902    }
903
904    #[tokio::test]
905    async fn test_process_response_headers_update_endpoint() {
906        let mut server = mockito::Server::new_async().await;
907        let mock = server
908            .mock("POST", "/")
909            .with_status(200)
910            .with_header("content-type", "application/json")
911            .with_header(
912                HEADER_UPDATE_ENDPOINT,
913                "https://new.engine.url/path?param1=value1&param2=value2",
914            )
915            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
916            .create_async()
917            .await;
918
919        let mut client = create_test_client();
920        client._engine_url = server.url();
921
922        let result = client
923            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
924            .await;
925
926        mock.assert_async().await;
927        assert!(result.is_ok());
928        assert_eq!(client._engine_url, "https://new.engine.url/path");
929        assert_eq!(
930            client._parameters.get("param1"),
931            Some(&"value1".to_string())
932        );
933        assert_eq!(
934            client._parameters.get("param2"),
935            Some(&"value2".to_string())
936        );
937    }
938
939    #[tokio::test]
940    async fn test_process_response_headers_update_parameters() {
941        let mut server = mockito::Server::new_async().await;
942        let mock = server
943            .mock("POST", "/")
944            .with_status(200)
945            .with_header("content-type", "application/json")
946            .with_header(
947                HEADER_UPDATE_PARAMETERS,
948                "database=new_db,engine=new_engine,custom=value",
949            )
950            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
951            .create_async()
952            .await;
953
954        let mut client = create_test_client();
955        client._engine_url = server.url();
956
957        let result = client
958            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
959            .await;
960
961        mock.assert_async().await;
962        assert!(result.is_ok());
963        assert_eq!(
964            client._parameters.get("database"),
965            Some(&"new_db".to_string())
966        );
967        assert_eq!(
968            client._parameters.get("engine"),
969            Some(&"new_engine".to_string())
970        );
971        assert_eq!(client._parameters.get("custom"), Some(&"value".to_string()));
972    }
973
974    #[tokio::test]
975    async fn test_process_response_headers_reset_session() {
976        let mut server = mockito::Server::new_async().await;
977        let mock = server
978            .mock("POST", "/")
979            .with_status(200)
980            .with_header("content-type", "application/json")
981            .with_header(HEADER_RESET_SESSION, "true")
982            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
983            .create_async()
984            .await;
985
986        let mut client = create_test_client();
987        client._engine_url = server.url();
988        client
989            ._parameters
990            .insert("database".to_string(), "test_db".to_string());
991        client
992            ._parameters
993            .insert("engine".to_string(), "test_engine".to_string());
994        client
995            ._parameters
996            .insert("custom_param".to_string(), "custom_value".to_string());
997
998        let result = client
999            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1000            .await;
1001
1002        mock.assert_async().await;
1003        assert!(result.is_ok());
1004        assert_eq!(
1005            client._parameters.get("database"),
1006            Some(&"test_db".to_string())
1007        );
1008        assert_eq!(
1009            client._parameters.get("engine"),
1010            Some(&"test_engine".to_string())
1011        );
1012        assert_eq!(client._parameters.get("custom_param"), None);
1013        assert_eq!(client._parameters.len(), 2);
1014    }
1015
1016    #[tokio::test]
1017    async fn test_process_response_headers_remove_parameters() {
1018        let mut server = mockito::Server::new_async().await;
1019        let mock = server
1020            .mock("POST", "/")
1021            .with_status(200)
1022            .with_header("content-type", "application/json")
1023            .with_header(HEADER_REMOVE_PARAMETERS, "param1,param3")
1024            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
1025            .create_async()
1026            .await;
1027
1028        let mut client = create_test_client();
1029        client._engine_url = server.url();
1030        client
1031            ._parameters
1032            .insert("param1".to_string(), "value1".to_string());
1033        client
1034            ._parameters
1035            .insert("param2".to_string(), "value2".to_string());
1036        client
1037            ._parameters
1038            .insert("param3".to_string(), "value3".to_string());
1039
1040        let result = client
1041            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1042            .await;
1043
1044        mock.assert_async().await;
1045        assert!(result.is_ok());
1046        assert_eq!(client._parameters.get("param1"), None);
1047        assert_eq!(
1048            client._parameters.get("param2"),
1049            Some(&"value2".to_string())
1050        );
1051        assert_eq!(client._parameters.get("param3"), None);
1052        assert_eq!(client._parameters.len(), 1);
1053    }
1054
1055    #[tokio::test]
1056    async fn test_process_response_headers_invalid_parameters_format() {
1057        let mut server = mockito::Server::new_async().await;
1058        let mock = server
1059            .mock("POST", "/")
1060            .with_status(200)
1061            .with_header("content-type", "application/json")
1062            .with_header(HEADER_UPDATE_PARAMETERS, "invalid-format-no-equals")
1063            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
1064            .create_async()
1065            .await;
1066
1067        let mut client = create_test_client();
1068        client._engine_url = server.url();
1069
1070        let result = client
1071            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1072            .await;
1073
1074        mock.assert_async().await;
1075        assert!(result.is_err());
1076        assert!(matches!(
1077            result.unwrap_err(),
1078            FireboltError::HeaderParsing(_)
1079        ));
1080    }
1081
1082    #[tokio::test]
1083    async fn test_process_response_headers_empty_parameter_key() {
1084        let mut server = mockito::Server::new_async().await;
1085        let mock = server
1086            .mock("POST", "/")
1087            .with_status(200)
1088            .with_header("content-type", "application/json")
1089            .with_header(HEADER_UPDATE_PARAMETERS, "=value")
1090            .with_body(r#"{"meta": [{"name": "test", "type": "int"}], "data": [[1]]}"#)
1091            .create_async()
1092            .await;
1093
1094        let mut client = create_test_client();
1095        client._engine_url = server.url();
1096
1097        let result = client
1098            .execute_query_request(&server.url(), "SELECT 1", &HashMap::new(), true)
1099            .await;
1100
1101        mock.assert_async().await;
1102        assert!(result.is_err());
1103        assert!(matches!(
1104            result.unwrap_err(),
1105            FireboltError::HeaderParsing(_)
1106        ));
1107    }
1108}