databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22    SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::retry::RetryDecision;
27use crate::stage::StageLocation;
28use crate::{
29    error::{Error, Result},
30    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
31    response::QueryResponse,
32    session::SessionState,
33    QueryStats,
34};
35use crate::{Page, Pages};
36use arrow_array::RecordBatch;
37use arrow_ipc::reader::StreamReader;
38use base64::engine::general_purpose::URL_SAFE;
39use base64::Engine;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
48use semver::Version;
49use serde::{de, Deserialize};
50use serde_json::{json, Value};
51use std::collections::{BTreeMap, HashMap};
52use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
53use std::sync::Arc;
54use std::time::{Duration, Instant};
55use tokio::time::sleep;
56use tokio_retry::strategy::jitter;
57use tokio_stream::StreamExt;
58use tokio_util::io::ReaderStream;
59use url::Url;
60
// HTTP header names and fixed values used when building requests and
// inspecting responses in this file.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
// Value of the session's `txn_state` (compared ASCII case-insensitively) that
// marks the session as being inside an active transaction.
const TXN_STATE_ACTIVE: &str = "Active";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
// `Content-Type` value that identifies a response body as an Arrow IPC stream.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Value sent in the `Accept` header when Arrow results are requested.
// NOTE(review): despite the name, this lists only the Arrow media type and is
// identical to `CONTENT_TYPE_ARROW` — confirm whether JSON should be listed too.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
73
74static VERSION: Lazy<String> = Lazy::new(|| {
75    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
76    version.to_string()
77});
78
/// Per-query bookkeeping kept by the client for queries tracked in
/// `APIClient::queries_need_heartbeat`.
#[derive(Clone)]
pub(crate) struct QueryState {
    /// Node id passed as the sticky node when the query is finalized.
    pub node_id: String,
    /// Instant of the last recorded access; shared so clones observe updates.
    pub last_access_time: Arc<Mutex<Instant>>,
    /// Timeout in seconds; a heartbeat is due after more than half of it elapses.
    pub timeout_secs: u64,
}
85
86impl QueryState {
87    pub fn need_heartbeat(&self, now: Instant) -> bool {
88        let t = *self.last_access_time.lock();
89        now.duration_since(t).as_secs() > self.timeout_secs / 2
90    }
91}
92
/// HTTP client for the query API, configured from a DSN.
///
/// Mutable pieces of state are wrapped in mutexes so the client can be shared
/// behind an `Arc`.
pub struct APIClient {
    /// Session identifier; `new` falls back to `no_login_<uuid>` when it is
    /// still empty after the optional login step.
    pub(crate) session_id: String,
    /// Underlying HTTP client, (re)built by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, derived from the DSN `sslmode` option.
    scheme: String,
    host: String,
    port: u16,

    /// Base URL (`scheme://host:port`) that request paths are joined onto.
    endpoint: Url,

    /// Credentials taken from the DSN; used when no session token is present.
    auth: Arc<dyn Auth>,

    /// Sent as the tenant header when set.
    tenant: Option<String>,
    /// Sent as the warehouse header when set; may be updated from session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with each query and replaced by the state the server returns.
    session_state: Mutex<SessionState>,
    /// Source of the route-hint header value.
    route_hint: RouteHintGenerator,

    /// When `true`, `new` skips the login call.
    disable_login: bool,
    /// Requested result format; the DSN parser accepts `"json"` or `"arrow"`.
    query_result_format: String,
    /// Set from the DSN `session_token` option (`disable` => `true`).
    disable_session_token: bool,
    /// When present, the session token is sent as bearer auth instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // NOTE(review): presumably set once the client has been closed — usage is
    // outside this section.
    closed: AtomicBool,

    server_version: Option<Version>,

    // Optional pagination settings forwarded with each query when any is set.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Parsed from the DSN `connect_timeout` option (seconds).
    connect_timeout: Duration,
    /// Request timeout applied when fetching result pages.
    page_request_timeout: Duration,

    /// Optional PEM file added as a root certificate for `https` connections.
    tls_ca_file: Option<String>,

    /// Upload strategy; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the first page of the most recent query, if any.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recent query, recorded from its first page.
    last_query_id: Mutex<Option<String>>,

    capability: Capability,

    /// Queries tracked for heartbeats, keyed by query id.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    // Retry settings parsed from the DSN (`retry_count`, `retry_delay_secs`).
    retry_count: u32,
    retry_delay_secs: u64,
}
138
impl Drop for APIClient {
    /// Delegates cleanup to `close_with_spawn` when the client is dropped.
    fn drop(&mut self) {
        self.close_with_spawn()
    }
}
144
145impl APIClient {
    /// Creates a ready-to-use client from `dsn`.
    ///
    /// Steps, in order: parse the DSN, build the HTTP client (using `name` as
    /// the user agent when given), log in unless login is disabled, make sure a
    /// session id exists, resolve the presign mode, and register the client
    /// with the global client manager.
    ///
    /// # Errors
    /// Returns an error if DSN parsing, HTTP client construction, login, or
    /// the presign check fails.
    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
        let mut client = Self::from_dsn(dsn).await?;
        client.build_client(name).await?;
        if !client.disable_login {
            client.login().await?;
        }
        // Without a session id from login, fall back to a locally generated one.
        if client.session_id.is_empty() {
            client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
        }
        let client = Arc::new(client);
        client.check_presign().await?;
        GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
        Ok(client)
    }
160
    /// Returns the capability set stored on this client.
    pub fn capability(&self) -> &Capability {
        &self.capability
    }
164
165    fn set_presign_mode(&self, mode: PresignMode) {
166        *self.presign.lock() = mode
167    }
168    fn get_presign_mode(&self) -> PresignMode {
169        *self.presign.lock()
170    }
171
172    async fn from_dsn(dsn: &str) -> Result<Self> {
173        let u = Url::parse(dsn)?;
174        let mut client = Self::default();
175        if let Some(host) = u.host_str() {
176            client.host = host.to_string();
177        }
178
179        if u.username() != "" {
180            let password = u.password().unwrap_or_default();
181            let password = percent_decode_str(password).decode_utf8()?;
182            client.auth = Arc::new(BasicAuth::new(u.username(), password));
183        }
184
185        let mut session_state = SessionState::default();
186
187        let database = u.path().trim_start_matches('/');
188        if !database.is_empty() {
189            session_state.set_database(database);
190        }
191
192        let mut scheme = "https";
193        for (k, v) in u.query_pairs() {
194            match k.as_ref() {
195                "wait_time_secs" => {
196                    client.wait_time_secs = Some(v.parse()?);
197                }
198                "max_rows_in_buffer" => {
199                    client.max_rows_in_buffer = Some(v.parse()?);
200                }
201                "max_rows_per_page" => {
202                    client.max_rows_per_page = Some(v.parse()?);
203                }
204                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
205                "page_request_timeout_secs" => {
206                    client.page_request_timeout = {
207                        let secs: u64 = v.parse()?;
208                        Duration::from_secs(secs)
209                    };
210                }
211                "presign" => {
212                    let presign_mode = match v.as_ref() {
213                        "auto" => PresignMode::Auto,
214                        "detect" => PresignMode::Detect,
215                        "on" => PresignMode::On,
216                        "off" => PresignMode::Off,
217                        _ => {
218                            return Err(Error::BadArgument(format!(
219                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
220                        )))
221                        }
222                    };
223                    client.set_presign_mode(presign_mode);
224                }
225                "tenant" => {
226                    client.tenant = Some(v.to_string());
227                }
228                "warehouse" => {
229                    client.warehouse = Mutex::new(Some(v.to_string()));
230                }
231                "role" => session_state.set_role(v),
232                "sslmode" => match v.as_ref() {
233                    "disable" => scheme = "http",
234                    "require" | "enable" => scheme = "https",
235                    _ => {
236                        return Err(Error::BadArgument(format!(
237                            "Invalid value for sslmode: {v}"
238                        )))
239                    }
240                },
241                "tls_ca_file" => {
242                    client.tls_ca_file = Some(v.to_string());
243                }
244                "access_token" => {
245                    client.auth = Arc::new(AccessTokenAuth::new(v));
246                }
247                "access_token_file" => {
248                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
249                }
250                "login" => {
251                    client.disable_login = match v.as_ref() {
252                        "disable" => true,
253                        "enable" => false,
254                        _ => {
255                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
256                        }
257                    }
258                }
259                "session_token" => {
260                    client.disable_session_token = match v.as_ref() {
261                        "disable" => true,
262                        "enable" => false,
263                        _ => {
264                            return Err(Error::BadArgument(format!(
265                                "Invalid value for session_token: {v}"
266                            )))
267                        }
268                    }
269                }
270                "body_format" | "query_result_format" => {
271                    let v = v.to_string().to_lowercase();
272                    match v.as_str() {
273                        "json" | "arrow" => client.query_result_format = v.to_string(),
274                        _ => {
275                            return Err(Error::BadArgument(format!(
276                                "Invalid value for query_result_format: {v}"
277                            )))
278                        }
279                    }
280                }
281                "retry_count" => {
282                    client.retry_count = v.parse()?;
283                }
284                "retry_delay_secs" => {
285                    client.retry_delay_secs = v.parse()?;
286                }
287                _ => {
288                    session_state.set(k, v);
289                }
290            }
291        }
292        client.port = match u.port() {
293            Some(p) => p,
294            None => match scheme {
295                "http" => 80,
296                "https" => 443,
297                _ => unreachable!(),
298            },
299        };
300        client.scheme = scheme.to_string();
301        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
302        client.session_state = Mutex::new(session_state);
303
304        Ok(client)
305    }
306
307    pub fn host(&self) -> &str {
308        self.host.as_str()
309    }
310
    /// Returns the port this client connects to.
    pub fn port(&self) -> u16 {
        self.port
    }
314
315    pub fn scheme(&self) -> &str {
316        self.scheme.as_str()
317    }
318
319    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
320        let ua = match name {
321            Some(n) => n,
322            None => format!("databend-client-rust/{}", VERSION.as_str()),
323        };
324        let cookie_provider = GlobalCookieStore::new();
325        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
326        let mut initial_cookies = [&cookie].into_iter();
327        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
328        let mut cli_builder = HttpClient::builder()
329            .user_agent(ua)
330            .cookie_provider(Arc::new(cookie_provider))
331            .pool_idle_timeout(Duration::from_secs(1));
332        #[cfg(any(feature = "rustls", feature = "native-tls"))]
333        if self.scheme == "https" {
334            if let Some(ref ca_file) = self.tls_ca_file {
335                let cert_pem = tokio::fs::read(ca_file).await?;
336                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
337                cli_builder = cli_builder.add_root_certificate(cert);
338            }
339        }
340        self.cli = cli_builder.build()?;
341        Ok(())
342    }
343
344    async fn check_presign(self: &Arc<Self>) -> Result<()> {
345        let mode = match self.get_presign_mode() {
346            PresignMode::Auto => {
347                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
348                    PresignMode::On
349                } else {
350                    PresignMode::Off
351                }
352            }
353            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
354                Ok(_) => PresignMode::On,
355                Err(e) => {
356                    warn!("presign mode off with error detected: {e}");
357                    PresignMode::Off
358                }
359            },
360            mode => mode,
361        };
362        self.set_presign_mode(mode);
363        Ok(())
364    }
365
366    pub fn current_warehouse(&self) -> Option<String> {
367        let guard = self.warehouse.lock();
368        guard.clone()
369    }
370
371    pub fn current_catalog(&self) -> Option<String> {
372        let guard = self.session_state.lock();
373        guard.catalog.clone()
374    }
375
376    pub fn current_database(&self) -> Option<String> {
377        let guard = self.session_state.lock();
378        guard.database.clone()
379    }
380
381    pub fn set_warehouse(&self, warehouse: impl Into<String>) {
382        let mut guard = self.warehouse.lock();
383        *guard = Some(warehouse.into());
384    }
385
386    pub fn set_database(&self, database: impl Into<String>) {
387        let mut guard = self.session_state.lock();
388        guard.set_database(database);
389    }
390
391    pub fn set_role(&self, role: impl Into<String>) {
392        let mut guard = self.session_state.lock();
393        guard.set_role(role);
394    }
395
396    pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
397        let mut guard = self.session_state.lock();
398        guard.set(key, value);
399    }
400
401    pub async fn current_role(&self) -> Option<String> {
402        let guard = self.session_state.lock();
403        guard.role.clone()
404    }
405
406    fn in_active_transaction(&self) -> bool {
407        let guard = self.session_state.lock();
408        guard
409            .txn_state
410            .as_ref()
411            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
412            .unwrap_or(false)
413    }
414
    /// Returns the user name reported by the configured auth provider.
    pub fn username(&self) -> String {
        self.auth.username()
    }
418
419    fn gen_query_id(&self) -> String {
420        uuid::Uuid::now_v7().simple().to_string()
421    }
422
423    async fn handle_session(&self, session: &Option<SessionState>) {
424        let session = match session {
425            Some(session) => session,
426            None => return,
427        };
428
429        // save the updated session state from the server side
430        {
431            let mut session_state = self.session_state.lock();
432            *session_state = session.clone();
433        }
434
435        // process warehouse changed via session settings
436        if let Some(settings) = session.settings.as_ref() {
437            if let Some(v) = settings.get("warehouse") {
438                let mut warehouse = self.warehouse.lock();
439                *warehouse = Some(v.clone());
440            }
441        }
442    }
443
444    pub fn set_last_node_id(&self, node_id: String) {
445        *self.last_node_id.lock() = Some(node_id)
446    }
447
448    pub fn set_last_query_id(&self, query_id: Option<String>) {
449        *self.last_query_id.lock() = query_id
450    }
451
452    pub fn last_query_id(&self) -> Option<String> {
453        self.last_query_id.lock().clone()
454    }
455
456    fn last_node_id(&self) -> Option<String> {
457        self.last_node_id.lock().clone()
458    }
459
460    fn handle_warnings(&self, resp: &QueryResponse) {
461        if let Some(warnings) = &resp.warnings {
462            for w in warnings {
463                warn!(target: "server_warnings", "server warning: {w}");
464            }
465        }
466    }
467
468    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
469        info!("start query: {sql}");
470        let (resp, batches) = self.start_query_inner(sql, None, false).await?;
471        Pages::new(self.clone(), resp, batches, need_progress)
472    }
473
474    pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
475        let mut mgr = self.queries_need_heartbeat.lock();
476        if let Some(state) = mgr.remove(query_id) {
477            let self_cloned = self.clone();
478            let query_id = query_id.to_owned();
479            GLOBAL_RUNTIME.spawn(async move {
480                if let Err(e) = self_cloned
481                    .end_query(&query_id, "final", Some(state.node_id.as_str()))
482                    .await
483                {
484                    error!("failed to final query {query_id}: {e}");
485                }
486            });
487        }
488    }
489
490    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
491        if let Some(info) = &self.session_token_info {
492            let info = info.lock();
493            Ok(builder.bearer_auth(info.0.session_token.clone()))
494        } else {
495            self.auth.wrap(builder)
496        }
497    }
498
499    async fn start_query_inner(
500        &self,
501        sql: &str,
502        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
503        force_json_body: bool,
504    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
505        if !self.in_active_transaction() {
506            self.route_hint.next();
507        }
508        let endpoint = self.endpoint.join("v1/query")?;
509
510        // body
511        let session_state = self.session_state();
512        let need_sticky = session_state.need_sticky.unwrap_or(false);
513        let req = QueryRequest::new(sql)
514            .with_pagination(self.make_pagination())
515            .with_session(Some(session_state))
516            .with_stage_attachment(stage_attachment_config);
517
518        // headers
519        let query_id = self.gen_query_id();
520        let mut headers = self.make_headers(Some(&query_id))?;
521        if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
522            debug!("accept arrow data");
523            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
524        }
525
526        if need_sticky {
527            if let Some(node_id) = self.last_node_id() {
528                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
529            }
530        }
531        let mut builder = self.cli.post(endpoint.clone()).json(&req);
532        builder = self.wrap_auth_or_session_token(builder)?;
533        let request = builder.headers(headers.clone()).build()?;
534        let response = self.query_request_helper(request, true, true).await?;
535        self.handle_page(response, true).await
536    }
537
538    fn is_arrow_data(response: &Response) -> bool {
539        if let Some(typ) = response.headers().get(CONTENT_TYPE) {
540            if let Ok(t) = typ.to_str() {
541                return t == CONTENT_TYPE_ARROW;
542            }
543        }
544        false
545    }
546
547    async fn handle_page(
548        &self,
549        response: Response,
550        is_first: bool,
551    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
552        let status = response.status();
553        if status != 200 {
554            return Err(Error::response_error(status, &response.bytes().await?));
555        }
556        let is_arrow_data = Self::is_arrow_data(&response);
557        if is_first {
558            if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
559                self.route_hint.set(route_hint.to_str().unwrap_or_default());
560            }
561        }
562        let mut body = response.bytes().await?;
563        let mut batches = vec![];
564        if is_arrow_data {
565            if is_first {
566                debug!("received arrow data");
567            }
568            let cursor = std::io::Cursor::new(body.as_ref());
569            let reader = StreamReader::try_new(cursor, None)
570                .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
571            let schema = reader.schema();
572            let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
573                bytes::Bytes::copy_from_slice(json_resp.as_bytes())
574            } else {
575                return Err(Error::Decode(
576                    "missing response_header metadata in arrow payload".to_string(),
577                ));
578            };
579            for batch in reader {
580                let batch = batch
581                    .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
582                batches.push(batch);
583            }
584            body = json_body
585        };
586        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
587            if let Error::Logic(status, ec) = &e {
588                if *status == 404 {
589                    return Error::QueryNotFound(ec.message.clone());
590                }
591            }
592            e
593        })?;
594        self.handle_session(&resp.session).await;
595        if let Some(err) = &resp.error {
596            return Err(Error::QueryFailed(err.clone()));
597        }
598        if is_first {
599            self.handle_warnings(&resp);
600            self.set_last_query_id(Some(resp.id.clone()));
601            if let Some(node_id) = &resp.node_id {
602                self.set_last_node_id(node_id.clone());
603            }
604        }
605        Ok((resp, batches))
606    }
607
608    pub async fn query_page(
609        &self,
610        query_id: &str,
611        next_uri: &str,
612        node_id: &Option<String>,
613    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
614        info!("query page: {next_uri}");
615        let endpoint = self.endpoint.join(next_uri)?;
616        let mut headers = self.make_headers(Some(query_id))?;
617        if self.capability.arrow_data && self.query_result_format == "arrow" {
618            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
619        }
620        let mut builder = self.cli.get(endpoint.clone());
621        builder = self
622            .wrap_auth_or_session_token(builder)?
623            .headers(headers.clone())
624            .timeout(self.page_request_timeout);
625        if let Some(node_id) = node_id {
626            builder = builder.header(HEADER_STICKY_NODE, node_id)
627        }
628        let request = builder.build()?;
629
630        let response = self.query_request_helper(request, false, true).await?;
631        self.handle_page(response, false).await
632    }
633
    /// Sends a `kill` request for `query_id` (no sticky node).
    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
        self.end_query(query_id, "kill", None).await
    }
637
    /// Sends a `final` request for `query_id`, optionally pinned to `node_id`.
    pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
        self.end_query(query_id, "final", node_id).await
    }
641
642    pub async fn end_query(
643        &self,
644        query_id: &str,
645        method: &str,
646        node_id: Option<&str>,
647    ) -> Result<()> {
648        let uri = format!("/v1/query/{query_id}/{method}");
649        let endpoint = self.endpoint.join(&uri)?;
650        let headers = self.make_headers(Some(query_id))?;
651
652        info!("{method} query: {uri}");
653
654        let mut builder = self.cli.post(endpoint);
655        if let Some(node_id) = node_id {
656            builder = builder.header(HEADER_STICKY_NODE, node_id)
657        }
658        builder = self.wrap_auth_or_session_token(builder)?;
659        let resp = builder.headers(headers.clone()).send().await?;
660        if resp.status() != 200 {
661            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
662                .with_context(&format!("{method} query")));
663        }
664        Ok(())
665    }
666
    /// Runs `sql` and returns all result pages merged into one `Page`
    /// (without forcing a JSON response body).
    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
        self.query_all_inner(sql, false).await
    }
670
671    pub async fn query_all_inner(
672        self: &Arc<Self>,
673        sql: &str,
674        force_json_body: bool,
675    ) -> Result<Page> {
676        let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
677        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
678        let mut all = Page::default();
679        while let Some(page) = pages.next().await {
680            all.update(page?);
681        }
682        Ok(all)
683    }
684
685    fn session_state(&self) -> SessionState {
686        self.session_state.lock().clone()
687    }
688
689    fn make_pagination(&self) -> Option<PaginationConfig> {
690        if self.wait_time_secs.is_none()
691            && self.max_rows_in_buffer.is_none()
692            && self.max_rows_per_page.is_none()
693        {
694            return None;
695        }
696        let mut pagination = PaginationConfig {
697            wait_time_secs: None,
698            max_rows_in_buffer: None,
699            max_rows_per_page: None,
700        };
701        if let Some(wait_time_secs) = self.wait_time_secs {
702            pagination.wait_time_secs = Some(wait_time_secs);
703        }
704        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
705            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
706        }
707        if let Some(max_rows_per_page) = self.max_rows_per_page {
708            pagination.max_rows_per_page = Some(max_rows_per_page);
709        }
710        Some(pagination)
711    }
712
713    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
714        let mut headers = HeaderMap::new();
715        if let Some(tenant) = &self.tenant {
716            headers.insert(HEADER_TENANT, tenant.parse()?);
717        }
718        let warehouse = self.warehouse.lock().clone();
719        if let Some(warehouse) = warehouse {
720            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
721        }
722        let route_hint = self.route_hint.current();
723        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
724        if let Some(query_id) = query_id {
725            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
726        }
727        Ok(headers)
728    }
729
730    pub async fn insert_with_stage(
731        self: &Arc<Self>,
732        sql: &str,
733        stage: &str,
734        file_format_options: BTreeMap<&str, &str>,
735        copy_options: BTreeMap<&str, &str>,
736    ) -> Result<QueryStats> {
737        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
738        let stage_attachment = Some(StageAttachmentConfig {
739            location: stage,
740            file_format_options: Some(file_format_options),
741            copy_options: Some(copy_options),
742        });
743        let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
744        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
745        let mut all = Page::default();
746        while let Some(page) = pages.next().await {
747            all.update(page?);
748        }
749        Ok(all.stats)
750    }
751
752    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
753        info!("get presigned upload url: {stage}");
754        let sql = format!("PRESIGN UPLOAD {stage}");
755        let resp = self.query_all_inner(&sql, true).await?;
756        if resp.data.len() != 1 {
757            return Err(Error::Decode(
758                "Empty response from server for presigned request".to_string(),
759            ));
760        }
761        if resp.data[0].len() != 3 {
762            return Err(Error::Decode(
763                "Invalid response from server for presigned request".to_string(),
764            ));
765        }
766        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
767        let method = resp.data[0][0].clone().unwrap_or_default();
768        if method != "PUT" {
769            return Err(Error::Decode(format!(
770                "Invalid method for presigned upload request: {method}"
771            )));
772        }
773        let headers: BTreeMap<String, String> =
774            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
775        let url = resp.data[0][2].clone().unwrap_or_default();
776        Ok(PresignedResponse {
777            method,
778            headers,
779            url,
780        })
781    }
782
783    pub async fn upload_to_stage(
784        self: &Arc<Self>,
785        stage: &str,
786        data: Reader,
787        size: u64,
788    ) -> Result<()> {
789        match self.get_presign_mode() {
790            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
791            PresignMode::On => {
792                let presigned = self.get_presigned_upload_url(stage).await?;
793                presign_upload_to_stage(presigned, data, size).await
794            }
795            PresignMode::Auto => {
796                unreachable!("PresignMode::Auto should be handled during client initialization")
797            }
798            PresignMode::Detect => {
799                unreachable!("PresignMode::Detect should be handled during client initialization")
800            }
801        }
802    }
803
804    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
805    async fn upload_to_stage_with_stream(
806        &self,
807        stage: &str,
808        data: Reader,
809        size: u64,
810    ) -> Result<()> {
811        info!("upload to stage with stream: {stage}, size: {size}");
812        if let Some(info) = self.need_pre_refresh_session().await {
813            self.refresh_session_token(info).await?;
814        }
815        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
816        let location = StageLocation::try_from(stage)?;
817        let query_id = self.gen_query_id();
818        let mut headers = self.make_headers(Some(&query_id))?;
819        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
820        let stream = Body::wrap_stream(ReaderStream::new(data));
821        let part = Part::stream_with_length(stream, size).file_name(location.path);
822        let form = Form::new().part("upload", part);
823        let mut builder = self.cli.put(endpoint.clone());
824        builder = self.wrap_auth_or_session_token(builder)?;
825        let resp = builder.headers(headers).multipart(form).send().await?;
826        let status = resp.status();
827        if status != 200 {
828            return Err(
829                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
830            );
831        }
832        Ok(())
833    }
834
835    // use base64 encode whenever possible for safety
836    // but also accept raw JSON for test/debug/one-shot operations
837    pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
838    where
839        T: de::DeserializeOwned,
840    {
841        if value.starts_with("{") {
842            serde_json::from_slice(value.as_bytes())
843                .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
844        } else {
845            let json = URL_SAFE.decode(value).map_err(|e| {
846                format!(
847                    "Invalid value {} for {key}, base64 decode error: {}",
848                    value, e
849                )
850            })?;
851            serde_json::from_slice(&json).map_err(|e| {
852                format!(
853                    "Invalid value {value} for {key}, JSON value {},  decode error: {e}",
854                    String::from_utf8_lossy(&json)
855                )
856            })
857        }
858    }
859
860    pub async fn streaming_load(
861        &self,
862        sql: &str,
863        data: Reader,
864        file_name: &str,
865    ) -> Result<LoadResponse> {
866        let body = Body::wrap_stream(ReaderStream::new(data));
867        let part = Part::stream(body).file_name(file_name.to_string());
868        let endpoint = self.endpoint.join("v1/streaming_load")?;
869        let mut builder = self.cli.put(endpoint.clone());
870        builder = self.wrap_auth_or_session_token(builder)?;
871        let query_id = self.gen_query_id();
872        let mut headers = self.make_headers(Some(&query_id))?;
873        headers.insert(HEADER_SQL, sql.parse()?);
874        let session = serde_json::to_string(&*self.session_state.lock())
875            .expect("serialize session state should not fail");
876        headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
877        let form = Form::new().part("upload", part);
878        let resp = builder.headers(headers).multipart(form).send().await?;
879        let status = resp.status();
880        if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
881            match Self::decode_json_header::<SessionState>(
882                HEADER_QUERY_CONTEXT,
883                value.to_str().unwrap(),
884            ) {
885                Ok(session) => *self.session_state.lock() = session,
886                Err(e) => {
887                    error!("Error decoding session state when streaming load: {e}");
888                }
889            }
890        };
891        if status != 200 {
892            return Err(
893                Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
894            );
895        }
896        let resp = resp.json::<LoadResponse>().await?;
897        Ok(resp)
898    }
899
900    async fn login(&mut self) -> Result<()> {
901        let endpoint = self.endpoint.join("/v1/session/login")?;
902        let headers = self.make_headers(None)?;
903        let body = LoginRequest::from(&*self.session_state.lock());
904        let mut builder = self.cli.post(endpoint.clone()).json(&body);
905        if self.disable_session_token {
906            builder = builder.query(&[("disable_session_token", true)]);
907        }
908        let builder = self.auth.wrap(builder)?;
909        let request = builder
910            .headers(headers.clone())
911            .timeout(self.connect_timeout)
912            .build()?;
913        let response = self.query_request_helper(request, true, false).await;
914        let response = match response {
915            Ok(r) => r,
916            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
917                info!("login return 404, skip login on the old version server");
918                return Ok(());
919            }
920            Err(e) => return Err(e),
921        };
922        if let Some(v) = response.headers().get(HEADER_SESSION_ID) {
923            if let Ok(s) = v.to_str() {
924                self.session_id = s.to_string();
925            }
926        }
927
928        let body = response.bytes().await?;
929        let response = json_from_slice(&body)?;
930        match response {
931            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
932            LoginResponseResult::Ok(info) => {
933                let server_version = info
934                    .version
935                    .parse()
936                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
937                self.capability = Capability::from_server_version(&server_version);
938                self.server_version = Some(server_version.clone());
939                let session_id = self.session_id.as_str();
940                if let Some(tokens) = info.tokens {
941                    info!(
942                        "[session {session_id}] login success with session token version = {server_version}",
943                    );
944                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
945                } else {
946                    info!("[session {session_id}] login success, version = {server_version}");
947                }
948            }
949        }
950        Ok(())
951    }
952
953    pub(crate) async fn try_heartbeat(&self) -> Result<()> {
954        let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
955        let queries = self.queries_need_heartbeat.lock().clone();
956        let mut node_to_queries = HashMap::<String, Vec<String>>::new();
957        let now = Instant::now();
958
959        let mut query_ids = Vec::new();
960        for (qid, state) in queries {
961            if state.need_heartbeat(now) {
962                query_ids.push(qid.to_string());
963                if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
964                    arr.push(qid);
965                } else {
966                    node_to_queries.insert(state.node_id, vec![qid]);
967                }
968            }
969        }
970
971        if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
972        {
973            return Ok(());
974        }
975
976        let body = json!({
977           "node_to_queries": node_to_queries
978        });
979        let builder = self.cli.post(endpoint.clone()).json(&body);
980        let request = self.wrap_auth_or_session_token(builder)?.build()?;
981        let response = self.query_request_helper(request, true, false).await?;
982        let json: Value = response.json().await?;
983        let session_id = self.session_id.as_str();
984        info!("[session {session_id}] heartbeat request={body}, response={json}");
985        if let Some(queries_to_remove) = json.get("queries_to_remove") {
986            if let Some(arr) = queries_to_remove.as_array() {
987                if !arr.is_empty() {
988                    let mut queries = self.queries_need_heartbeat.lock();
989                    for q in arr {
990                        if let Some(q) = q.as_str() {
991                            queries.remove(q);
992                        }
993                    }
994                }
995            }
996        }
997        let now = Instant::now();
998        let mut queries = self.queries_need_heartbeat.lock();
999        for qid in query_ids {
1000            if let Some(state) = queries.get_mut(&qid) {
1001                *state.last_access_time.lock() = now;
1002            }
1003        }
1004        Ok(())
1005    }
1006
1007    fn build_log_out_request(&self) -> Result<Request> {
1008        let endpoint = self.endpoint.join("/v1/session/logout")?;
1009
1010        let session_state = self.session_state();
1011        let need_sticky = session_state.need_sticky.unwrap_or(false);
1012        let mut headers = self.make_headers(None)?;
1013        if need_sticky {
1014            if let Some(node_id) = self.last_node_id() {
1015                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1016            }
1017        }
1018        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1019
1020        let builder = self.wrap_auth_or_session_token(builder)?;
1021        let req = builder.build()?;
1022        Ok(req)
1023    }
1024
1025    pub(crate) fn need_logout(&self) -> bool {
1026        self.session_token_info.is_some()
1027            || self.session_state.lock().need_keep_alive.unwrap_or(false)
1028    }
1029
1030    async fn refresh_session_token(
1031        &self,
1032        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1033    ) -> Result<()> {
1034        let (session_token_info, _) = { self_login_info.lock().clone() };
1035        let endpoint = self.endpoint.join("/v1/session/refresh")?;
1036        let body = RefreshSessionTokenRequest {
1037            session_token: session_token_info.session_token.clone(),
1038        };
1039        let headers = self.make_headers(None)?;
1040        let request = self
1041            .cli
1042            .post(endpoint.clone())
1043            .json(&body)
1044            .headers(headers.clone())
1045            .bearer_auth(session_token_info.refresh_token.clone())
1046            .timeout(self.connect_timeout)
1047            .build()?;
1048
1049        let max_retries = self.retry_count;
1050        let retry_delay = Duration::from_secs(self.retry_delay_secs);
1051        // avoid recursively call request_helper
1052        for i in 0..max_retries {
1053            let req = request.try_clone().expect("request not cloneable");
1054            match self.cli.execute(req).await {
1055                Ok(response) => {
1056                    let status = response.status();
1057                    let body = response.bytes().await?;
1058                    if status == StatusCode::OK {
1059                        let response = json_from_slice(&body)?;
1060                        return match response {
1061                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1062                            RefreshResponse::Ok(info) => {
1063                                *self_login_info.lock() = (info, Instant::now());
1064                                Ok(())
1065                            }
1066                        };
1067                    }
1068                    if status != StatusCode::SERVICE_UNAVAILABLE
1069                        || i >= max_retries.saturating_sub(1)
1070                    {
1071                        return Err(Error::response_error(status, &body));
1072                    }
1073                    info!(
1074                        "retry {}/{} for session token refresh due to: service unavailable (503)",
1075                        i + 1,
1076                        max_retries
1077                    );
1078                }
1079                Err(err) => {
1080                    if !(err.is_timeout() || err.is_connect()) || i >= max_retries.saturating_sub(1)
1081                    {
1082                        return Err(Error::Request(err.to_string()));
1083                    }
1084                    let reason = if err.is_timeout() {
1085                        "request timeout"
1086                    } else if err.is_connect() {
1087                        "connection error"
1088                    } else {
1089                        "request error"
1090                    };
1091                    info!(
1092                        "retry {}/{} for session token refresh due to: {} (error: {})",
1093                        i + 1,
1094                        max_retries,
1095                        reason,
1096                        err
1097                    );
1098                }
1099            };
1100            warn!(
1101                "retrying session token refresh after {} seconds",
1102                retry_delay.as_secs()
1103            );
1104            sleep(jitter(retry_delay)).await;
1105        }
1106        Ok(())
1107    }
1108
1109    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1110        if let Some(info) = &self.session_token_info {
1111            let (start, ttl) = {
1112                let guard = info.lock();
1113                (guard.1, guard.0.session_token_ttl_in_secs)
1114            };
1115            if Instant::now() > start + Duration::from_secs(ttl) {
1116                return Some(info.clone());
1117            }
1118        }
1119        None
1120    }
1121
1122    /// return Ok if and only if status code is 200.
1123    ///
1124    /// retry on
1125    ///   - network errors
1126    ///   - (optional) 503
1127    ///
1128    /// refresh databend token or reload jwt token if needed.
1129    async fn query_request_helper(
1130        &self,
1131        mut request: Request,
1132        retry_if_503: bool,
1133        refresh_if_401: bool,
1134    ) -> std::result::Result<Response, Error> {
1135        let mut refreshed = false;
1136        let mut retries = 0;
1137        let max_retries = self.retry_count;
1138        let retry_delay = Duration::from_secs(self.retry_delay_secs);
1139        loop {
1140            let req = request.try_clone().expect("request not cloneable");
1141            let decision = match self.cli.execute(req).await {
1142                Ok(response) => {
1143                    let status = response.status();
1144                    if status == StatusCode::OK {
1145                        return Ok(response);
1146                    }
1147                    let body = response.bytes().await?;
1148                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1149                        // waiting for server to start
1150                        RetryDecision::retry_with_reason(
1151                            Error::response_error(status, &body),
1152                            "service unavailable (503), server may be starting",
1153                        )
1154                    } else {
1155                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
1156                        match resp {
1157                            Ok(r) => {
1158                                let e = r.error;
1159                                if status == StatusCode::UNAUTHORIZED {
1160                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1161                                    if let Some(session_token_info) = &self.session_token_info {
1162                                        if need_refresh_token(e.code)
1163                                            && !refreshed
1164                                            && refresh_if_401
1165                                        {
1166                                            self.refresh_session_token(session_token_info.clone())
1167                                                .await?;
1168                                            refreshed = true;
1169                                            RetryDecision::retry_with_reason(
1170                                                Error::AuthFailure(e),
1171                                                "session token expired and refreshed",
1172                                            )
1173                                        } else {
1174                                            RetryDecision::no_retry(Error::AuthFailure(e))
1175                                        }
1176                                    } else if self.auth.can_reload() {
1177                                        let builder = RequestBuilder::from_parts(
1178                                            HttpClient::new(),
1179                                            request.try_clone().unwrap(),
1180                                        );
1181                                        let builder = self.auth.wrap(builder)?;
1182                                        request = builder.build()?;
1183                                        RetryDecision::retry_with_reason(
1184                                            Error::AuthFailure(e),
1185                                            "authentication token reloaded",
1186                                        )
1187                                    } else {
1188                                        RetryDecision::no_retry(Error::AuthFailure(e))
1189                                    }
1190                                } else {
1191                                    RetryDecision::no_retry(Error::Logic(status, e))
1192                                }
1193                            }
1194                            Err(_) => RetryDecision::no_retry(Error::Response {
1195                                status,
1196                                msg: String::from_utf8_lossy(&body).to_string(),
1197                            }),
1198                        }
1199                    }
1200                }
1201                Err(err) => {
1202                    let is_retryable = err.is_timeout() || err.is_connect() || err.is_request();
1203                    if is_retryable {
1204                        let reason = if err.is_timeout() {
1205                            "request timeout"
1206                        } else if err.is_connect() {
1207                            "connection error"
1208                        } else {
1209                            "request error"
1210                        };
1211                        RetryDecision::retry_with_reason(Error::Request(err.to_string()), reason)
1212                    } else {
1213                        RetryDecision::no_retry(Error::Request(err.to_string()))
1214                    }
1215                }
1216            };
1217            if !decision.should_retry {
1218                return Err(decision.error.with_context(&format!(
1219                    "{} {}",
1220                    request.method(),
1221                    request.url()
1222                )));
1223            }
1224            match &decision.error {
1225                Error::AuthFailure(_) => {
1226                    if refreshed {
1227                        retries = 0;
1228                    } else if retries >= max_retries.saturating_sub(1) {
1229                        return Err(decision.error.with_context(&format!(
1230                            "{} {} after {} retries",
1231                            request.method(),
1232                            request.url(),
1233                            max_retries
1234                        )));
1235                    }
1236                }
1237                _ => {
1238                    if retries >= max_retries.saturating_sub(1) {
1239                        return Err(decision.error.with_context(&format!(
1240                            "{} {} after {} retries",
1241                            request.method(),
1242                            request.url(),
1243                            max_retries
1244                        )));
1245                    }
1246                    retries += 1;
1247                    if let Some(reason) = decision.reason {
1248                        info!(
1249                            "retry {}/{} for {} due to: {} (error: {})",
1250                            retries,
1251                            max_retries,
1252                            request.url(),
1253                            reason,
1254                            decision.error
1255                        );
1256                    } else {
1257                        info!(
1258                            "retry {}/{} for {} on error: {}",
1259                            retries,
1260                            max_retries,
1261                            request.url(),
1262                            decision.error
1263                        );
1264                    }
1265                }
1266            }
1267            if let Some(reason) = decision.reason {
1268                warn!(
1269                    "retrying after {} seconds due to: {}",
1270                    retry_delay.as_secs(),
1271                    reason
1272                );
1273            } else {
1274                warn!("retrying after {} seconds", retry_delay.as_secs());
1275            }
1276            sleep(jitter(retry_delay)).await;
1277        }
1278    }
1279
1280    pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1281        if let Err(err) = http_client.execute(request).await {
1282            error!("[session {session_id}] logout request failed: {err}");
1283        } else {
1284            info!("[session {session_id}] logout success");
1285        };
1286    }
1287
1288    pub async fn close(&self) {
1289        let session_id = &self.session_id;
1290        info!("[session {session_id}] try closing now");
1291        if self
1292            .closed
1293            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1294            .is_ok()
1295        {
1296            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1297            if self.need_logout() {
1298                let cli = self.cli.clone();
1299                let req = self
1300                    .build_log_out_request()
1301                    .expect("failed to build logout request");
1302                Self::logout(cli, req, &self.session_id).await;
1303            }
1304        }
1305    }
1306    pub fn close_with_spawn(&self) {
1307        let session_id = &self.session_id;
1308        info!("[session {session_id}]: try closing with spawn");
1309        if self
1310            .closed
1311            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1312            .is_ok()
1313        {
1314            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1315            if self.need_logout() {
1316                let cli = self.cli.clone();
1317                let req = self
1318                    .build_log_out_request()
1319                    .expect("failed to build logout request");
1320                let session_id = self.session_id.clone();
1321                GLOBAL_RUNTIME.spawn(async move {
1322                    Self::logout(cli, req, session_id.as_str()).await;
1323                });
1324            }
1325        }
1326    }
1327
1328    pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1329        let mut queries = self.queries_need_heartbeat.lock();
1330        queries.insert(query_id.to_string(), state);
1331    }
1332}
1333
1334fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1335where
1336    T: Deserialize<'a>,
1337{
1338    serde_json::from_slice::<T>(body).map_err(|e| {
1339        Error::Decode(format!(
1340            "fail to decode JSON response: {e}, body: {}",
1341            String::from_utf8_lossy(body)
1342        ))
1343    })
1344}
1345
1346impl Default for APIClient {
1347    fn default() -> Self {
1348        Self {
1349            session_id: Default::default(),
1350            cli: HttpClient::new(),
1351            scheme: "http".to_string(),
1352            endpoint: Url::parse("http://localhost:8080").unwrap(),
1353            host: "localhost".to_string(),
1354            port: 8000,
1355            tenant: None,
1356            warehouse: Mutex::new(None),
1357            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1358            session_state: Mutex::new(SessionState::default()),
1359            wait_time_secs: None,
1360            max_rows_in_buffer: None,
1361            max_rows_per_page: None,
1362            connect_timeout: Duration::from_secs(10),
1363            page_request_timeout: Duration::from_secs(30),
1364            tls_ca_file: None,
1365            presign: Mutex::new(PresignMode::Auto),
1366            route_hint: RouteHintGenerator::new(),
1367            last_node_id: Default::default(),
1368            disable_session_token: true,
1369            disable_login: false,
1370            query_result_format: "json".to_string(),
1371            session_token_info: None,
1372            closed: AtomicBool::new(false),
1373            last_query_id: Default::default(),
1374            server_version: None,
1375            capability: Default::default(),
1376            queries_need_heartbeat: Default::default(),
1377            retry_count: 3,
1378            retry_delay_secs: 10,
1379        }
1380    }
1381}
1382
1383struct RouteHintGenerator {
1384    nonce: AtomicU64,
1385    current: std::sync::Mutex<String>,
1386}
1387
1388impl RouteHintGenerator {
1389    fn new() -> Self {
1390        let gen = Self {
1391            nonce: AtomicU64::new(0),
1392            current: std::sync::Mutex::new("".to_string()),
1393        };
1394        gen.next();
1395        gen
1396    }
1397
1398    fn current(&self) -> String {
1399        let guard = self.current.lock().unwrap();
1400        guard.clone()
1401    }
1402
1403    fn set(&self, hint: &str) {
1404        let mut guard = self.current.lock().unwrap();
1405        *guard = hint.to_string();
1406    }
1407
1408    fn next(&self) -> String {
1409        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1410        let uuid = uuid::Uuid::new_v4();
1411        let current = format!("rh:{uuid}:{nonce:06}");
1412        let mut guard = self.current.lock().unwrap();
1413        guard.clone_from(&current);
1414        current
1415    }
1416}
1417
#[cfg(test)]
mod test {
    use super::*;

    /// Options given in the DSN end up in the matching client fields.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);

        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(*warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password does not disturb host parsing; without an explicit
    /// port the client reports 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// A password with unescaped special characters still yields the right host and
    /// the explicit port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}