Skip to main content

databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22    SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28    error::{Error, RequestKind, Result},
29    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30    response::QueryResponse,
31    session::SessionState,
32    QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48    Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// Header names exchanged with the Databend HTTP handler.

/// Query id of a request; also read back to tag errors with the query id.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Tenant, sent when the DSN sets `tenant`.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Node id a request should be routed to (sticky routing).
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Current warehouse, sent when one is configured.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Route hint; sent on requests and refreshed from the first page's response.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";

/// Media type of Arrow IPC stream result pages (matched against `Content-Type`).
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
/// `Accept` value sent when Arrow results are wanted.
/// NOTE(review): despite the name, only the Arrow media type is listed; `handle_page`
/// accepts either format based on the response `Content-Type`. Confirm with the server
/// whether `application/json` should be appended before changing it.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";

/// Session transaction state meaning a transaction is open (compared ASCII case-insensitively).
const TXN_STATE_ACTIVE: &str = "Active";
77
78static VERSION: Lazy<String> = Lazy::new(|| {
79    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
80    version.to_string()
81});
82
83#[derive(Clone)]
84pub(crate) struct QueryState {
85    pub node_id: String,
86    pub last_access_time: Arc<Mutex<Instant>>,
87    pub timeout_secs: u64,
88}
89
90impl QueryState {
91    pub fn need_heartbeat(&self, now: Instant) -> bool {
92        let t = *self.last_access_time.lock();
93        now.duration_since(t).as_secs() > self.timeout_secs / 2
94    }
95}
96
/// A fully buffered HTTP response, handed to `handle_page` for decoding.
struct HttpResponseData {
    /// HTTP status; anything other than `200 OK` is turned into an error by `handle_page`.
    status: StatusCode,
    /// Response headers; `Content-Type` selects Arrow vs JSON decoding and the
    /// route-hint header is read on the first page.
    headers: HeaderMap,
    /// Entire response body: a JSON `QueryResponse` or an Arrow IPC stream.
    body: Bytes,
}
102
/// HTTP client for a Databend query endpoint.
///
/// Holds the connection settings parsed from the DSN, the authentication method, a
/// client-side copy of the server session, and per-client bookkeeping (last node/query
/// id, tracked queries). Dropping it calls `close_with_spawn`.
pub struct APIClient {
    /// Session id; `new` sets it to `no_login_<uuid>` if it is still empty after login.
    pub(crate) session_id: String,
    /// Underlying reqwest client, built by `build_client`.
    cli: HttpClient,
    /// `http` or `https`, chosen by the DSN `sslmode` option (default `https`).
    scheme: String,
    /// Server host taken from the DSN.
    host: String,
    /// Explicit DSN port, otherwise 80 for `http` and 443 for `https`.
    port: u16,

    /// Base URL `{scheme}://{host}:{port}` that request paths are joined onto.
    endpoint: Url,

    /// Credentials applied to requests when no session token is in use.
    auth: Arc<dyn Auth>,

    /// Sent as the `X-DATABEND-TENANT` header when set.
    tenant: Option<String>,
    /// Sent as the `X-DATABEND-WAREHOUSE` header when set; also updated from a
    /// `warehouse` setting in the session returned by the server.
    warehouse: Mutex<Option<String>>,
    /// Client-side session, sent with each query and replaced by the session the
    /// server returns.
    session_state: Mutex<SessionState>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header value.
    route_hint: RouteHintGenerator,

    /// `true` when the DSN has `login=disable`; `new` then skips `login()`.
    disable_login: bool,
    /// Requested result format, `"json"` or `"arrow"` (DSN `query_result_format`/`body_format`).
    query_result_format: String,
    /// `true` when the DSN has `session_token=disable`.
    disable_session_token: bool,
    /// When present, its session token is sent as a bearer token instead of using `auth`.
    /// NOTE(review): the `Instant` presumably records when the token was obtained — confirm.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    /// NOTE(review): presumably marks the client as closed; not used in this part of the file.
    closed: AtomicBool,

    /// NOTE(review): presumably the server version learned at login; not used in this part of the file.
    server_version: Option<Version>,

    // Optional pagination settings from the DSN; see `make_pagination`.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// TCP connect timeout of the HTTP client (DSN `connect_timeout`, in seconds).
    connect_timeout: Duration,
    /// Per-request timeout for query start and page requests
    /// (DSN `page_request_timeout_secs`).
    page_request_timeout: Duration,

    /// PEM file whose certificate is added as a trusted root for `https`
    /// (only when built with a TLS feature).
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Most recently recorded node id; sent as the sticky-node header when the
    /// session requires stickiness.
    last_node_id: Mutex<Option<String>>,
    /// Most recently recorded query id.
    last_query_id: Mutex<Option<String>>,

    /// Server capabilities, e.g. whether Arrow result pages are supported.
    capability: Capability,

    /// Tracked queries keyed by query id (see `QueryState::need_heartbeat`);
    /// `finalize_query` removes an entry and finalizes that query on the server.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    // Retry settings from the DSN (`retry_count`, `retry_delay_secs`).
    // NOTE(review): consumed by the request helper outside this view — confirm semantics.
    retry_count: u32,
    retry_delay_secs: u64,
}
148
149impl Drop for APIClient {
150    fn drop(&mut self) {
151        self.close_with_spawn()
152    }
153}
154
155impl APIClient {
156    fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
157        while let Some(err) = source {
158            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
159                match io_err.kind() {
160                    ErrorKind::TimedOut
161                    | ErrorKind::UnexpectedEof
162                    | ErrorKind::ConnectionReset
163                    | ErrorKind::ConnectionAborted
164                    | ErrorKind::BrokenPipe
165                    | ErrorKind::NotConnected
166                    | ErrorKind::WouldBlock
167                    | ErrorKind::Interrupted => return true,
168                    _ => {}
169                }
170            }
171            source = err.source();
172        }
173        false
174    }
175
176    fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
177        if err.is_timeout() {
178            Some("request timeout")
179        } else if err.is_connect() {
180            Some("connection error")
181        } else if err.is_request() {
182            Some("request error")
183        } else if err.is_body() || err.is_decode() {
184            if Self::has_transient_io_source(err.source()) {
185                Some("response read transient I/O error")
186            } else {
187                None
188            }
189        } else {
190            None
191        }
192    }
193
194    fn request_query_id(request: &Request) -> Option<String> {
195        request
196            .headers()
197            .get(HEADER_QUERY_ID)
198            .and_then(|v| v.to_str().ok())
199            .map(|v| v.to_string())
200    }
201
202    fn with_request_context(
203        reqwest_error: ReqwestError,
204        request_kind: &RequestKind,
205        request: &Request,
206        retry_times: Option<u32>,
207    ) -> Error {
208        let error: Error = reqwest_error.into();
209        let error = error.with_context(request_kind.clone());
210        let error = if let Some(retry_times) = retry_times {
211            error.with_retry_times(retry_times)
212        } else {
213            error
214        };
215        if let Some(query_id) = Self::request_query_id(request) {
216            error.with_query_id(query_id)
217        } else {
218            error
219        }
220    }
221
222    fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
223        match serde_json::from_slice::<ResponseWithErrorCode>(body) {
224            Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
225            Ok(resp) => Error::Logic(status, resp.error),
226            Err(_) => Error::response_error(status, body),
227        }
228    }
229
230    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
231        let mut client = Self::from_dsn(dsn).await?;
232        client.build_client(name).await?;
233        if !client.disable_login {
234            client.login().await?;
235        }
236        if client.session_id.is_empty() {
237            client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
238        }
239        let client = Arc::new(client);
240        client.check_presign().await?;
241        GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
242        Ok(client)
243    }
244
245    pub fn capability(&self) -> &Capability {
246        &self.capability
247    }
248
249    fn set_presign_mode(&self, mode: PresignMode) {
250        *self.presign.lock() = mode
251    }
252    fn get_presign_mode(&self) -> PresignMode {
253        *self.presign.lock()
254    }
255
256    async fn from_dsn(dsn: &str) -> Result<Self> {
257        let u = Url::parse(dsn)?;
258        let mut client = Self::default();
259        if let Some(host) = u.host_str() {
260            client.host = host.to_string();
261        }
262
263        if u.username() != "" {
264            let password = u.password().unwrap_or_default();
265            let password = percent_decode_str(password).decode_utf8()?;
266            client.auth = Arc::new(BasicAuth::new(u.username(), password));
267        }
268
269        let mut session_state = SessionState::default();
270
271        let database = u.path().trim_start_matches('/');
272        if !database.is_empty() {
273            session_state.set_database(database);
274        }
275
276        let mut scheme = "https";
277        for (k, v) in u.query_pairs() {
278            match k.as_ref() {
279                "wait_time_secs" => {
280                    client.wait_time_secs = Some(v.parse()?);
281                }
282                "max_rows_in_buffer" => {
283                    client.max_rows_in_buffer = Some(v.parse()?);
284                }
285                "max_rows_per_page" => {
286                    client.max_rows_per_page = Some(v.parse()?);
287                }
288                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
289                "page_request_timeout_secs" => {
290                    client.page_request_timeout = {
291                        let secs: u64 = v.parse()?;
292                        Duration::from_secs(secs)
293                    };
294                }
295                "presign" => {
296                    let presign_mode = match v.as_ref() {
297                        "auto" => PresignMode::Auto,
298                        "detect" => PresignMode::Detect,
299                        "on" => PresignMode::On,
300                        "off" => PresignMode::Off,
301                        _ => {
302                            return Err(Error::BadArgument(format!(
303                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
304                        )))
305                        }
306                    };
307                    client.set_presign_mode(presign_mode);
308                }
309                "tenant" => {
310                    client.tenant = Some(v.to_string());
311                }
312                "warehouse" => {
313                    client.warehouse = Mutex::new(Some(v.to_string()));
314                }
315                "role" => session_state.set_role(v),
316                "sslmode" => match v.as_ref() {
317                    "disable" => scheme = "http",
318                    "require" | "enable" => scheme = "https",
319                    _ => {
320                        return Err(Error::BadArgument(format!(
321                            "Invalid value for sslmode: {v}"
322                        )))
323                    }
324                },
325                "tls_ca_file" => {
326                    client.tls_ca_file = Some(v.to_string());
327                }
328                "access_token" => {
329                    client.auth = Arc::new(AccessTokenAuth::new(v));
330                }
331                "access_token_file" => {
332                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
333                }
334                "login" => {
335                    client.disable_login = match v.as_ref() {
336                        "disable" => true,
337                        "enable" => false,
338                        _ => {
339                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
340                        }
341                    }
342                }
343                "session_token" => {
344                    client.disable_session_token = match v.as_ref() {
345                        "disable" => true,
346                        "enable" => false,
347                        _ => {
348                            return Err(Error::BadArgument(format!(
349                                "Invalid value for session_token: {v}"
350                            )))
351                        }
352                    }
353                }
354                "body_format" | "query_result_format" => {
355                    let v = v.to_string().to_lowercase();
356                    match v.as_str() {
357                        "json" | "arrow" => client.query_result_format = v.to_string(),
358                        _ => {
359                            return Err(Error::BadArgument(format!(
360                                "Invalid value for query_result_format: {v}"
361                            )))
362                        }
363                    }
364                }
365                "retry_count" => {
366                    client.retry_count = v.parse()?;
367                }
368                "retry_delay_secs" => {
369                    client.retry_delay_secs = v.parse()?;
370                }
371                _ => {
372                    session_state.set(k, v);
373                }
374            }
375        }
376        client.port = match u.port() {
377            Some(p) => p,
378            None => match scheme {
379                "http" => 80,
380                "https" => 443,
381                _ => unreachable!(),
382            },
383        };
384        client.scheme = scheme.to_string();
385        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
386        client.session_state = Mutex::new(session_state);
387
388        Ok(client)
389    }
390
391    pub fn host(&self) -> &str {
392        self.host.as_str()
393    }
394
395    pub fn port(&self) -> u16 {
396        self.port
397    }
398
399    pub fn scheme(&self) -> &str {
400        self.scheme.as_str()
401    }
402
403    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
404        let ua = match name {
405            Some(n) => n,
406            None => format!("databend-client-rust/{}", VERSION.as_str()),
407        };
408        let cookie_provider = GlobalCookieStore::new();
409        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
410        let mut initial_cookies = [&cookie].into_iter();
411        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
412        let mut cli_builder = HttpClient::builder()
413            .user_agent(ua)
414            .connect_timeout(self.connect_timeout)
415            .cookie_provider(Arc::new(cookie_provider))
416            .pool_idle_timeout(Duration::from_secs(1));
417        #[cfg(any(feature = "rustls", feature = "native-tls"))]
418        if self.scheme == "https" {
419            if let Some(ref ca_file) = self.tls_ca_file {
420                let cert_pem = tokio::fs::read(ca_file).await?;
421                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
422                cli_builder = cli_builder.add_root_certificate(cert);
423            }
424        }
425        self.cli = cli_builder.build()?;
426        Ok(())
427    }
428
429    async fn check_presign(self: &Arc<Self>) -> Result<()> {
430        let mode = match self.get_presign_mode() {
431            PresignMode::Auto => {
432                if self.host.ends_with(".databend.com")
433                    || self.host.ends_with(".databend.cn")
434                    || self.host.ends_with(".tidbcloud.com")
435                {
436                    PresignMode::On
437                } else {
438                    PresignMode::Off
439                }
440            }
441            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
442                Ok(_) => PresignMode::On,
443                Err(e) => {
444                    warn!("presign mode off with error detected: {e}");
445                    PresignMode::Off
446                }
447            },
448            mode => mode,
449        };
450        self.set_presign_mode(mode);
451        Ok(())
452    }
453
454    pub fn current_warehouse(&self) -> Option<String> {
455        let guard = self.warehouse.lock();
456        guard.clone()
457    }
458
459    pub fn current_catalog(&self) -> Option<String> {
460        let guard = self.session_state.lock();
461        guard.catalog.clone()
462    }
463
464    pub fn current_database(&self) -> Option<String> {
465        let guard = self.session_state.lock();
466        guard.database.clone()
467    }
468
469    pub fn set_warehouse(&self, warehouse: impl Into<String>) {
470        let mut guard = self.warehouse.lock();
471        *guard = Some(warehouse.into());
472    }
473
474    pub fn set_database(&self, database: impl Into<String>) {
475        let mut guard = self.session_state.lock();
476        guard.set_database(database);
477    }
478
479    pub fn set_role(&self, role: impl Into<String>) {
480        let mut guard = self.session_state.lock();
481        guard.set_role(role);
482    }
483
484    pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
485        let mut guard = self.session_state.lock();
486        guard.set(key, value);
487    }
488
489    pub async fn current_role(&self) -> Option<String> {
490        let guard = self.session_state.lock();
491        guard.role.clone()
492    }
493
494    fn in_active_transaction(&self) -> bool {
495        let guard = self.session_state.lock();
496        guard
497            .txn_state
498            .as_ref()
499            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
500            .unwrap_or(false)
501    }
502
503    pub fn username(&self) -> String {
504        self.auth.username()
505    }
506
507    fn gen_query_id(&self) -> String {
508        uuid::Uuid::now_v7().simple().to_string()
509    }
510
511    async fn handle_session(&self, session: &Option<SessionState>) {
512        let session = match session {
513            Some(session) => session,
514            None => return,
515        };
516
517        // save the updated session state from the server side
518        {
519            let mut session_state = self.session_state.lock();
520            *session_state = session.clone();
521        }
522
523        // process warehouse changed via session settings
524        if let Some(settings) = session.settings.as_ref() {
525            if let Some(v) = settings.get("warehouse") {
526                let mut warehouse = self.warehouse.lock();
527                *warehouse = Some(v.clone());
528            }
529        }
530    }
531
532    pub fn set_last_node_id(&self, node_id: String) {
533        *self.last_node_id.lock() = Some(node_id)
534    }
535
536    pub fn set_last_query_id(&self, query_id: Option<String>) {
537        *self.last_query_id.lock() = query_id
538    }
539
540    pub fn last_query_id(&self) -> Option<String> {
541        self.last_query_id.lock().clone()
542    }
543
544    fn last_node_id(&self) -> Option<String> {
545        self.last_node_id.lock().clone()
546    }
547
548    fn handle_warnings(&self, resp: &QueryResponse) {
549        if let Some(warnings) = &resp.warnings {
550            for w in warnings {
551                warn!(target: "server_warnings", "server warning: {w}");
552            }
553        }
554    }
555
556    pub async fn start_query(
557        self: &Arc<Self>,
558        sql: &str,
559        need_progress: bool,
560        params: Option<serde_json::Value>,
561    ) -> Result<Pages> {
562        info!("start query: {sql}");
563        let (resp, batches) = self.start_query_inner(sql, None, false, params).await?;
564        Pages::new(self.clone(), resp, batches, need_progress)
565    }
566
567    pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
568        let mut mgr = self.queries_need_heartbeat.lock();
569        if let Some(state) = mgr.remove(query_id) {
570            let self_cloned = self.clone();
571            let query_id = query_id.to_owned();
572            GLOBAL_RUNTIME.spawn(async move {
573                if let Err(e) = self_cloned
574                    .end_query(&query_id, false, Some(state.node_id.as_str()))
575                    .await
576                {
577                    error!("failed to final query {query_id}: {e}");
578                }
579            });
580        }
581    }
582
583    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
584        if let Some(info) = &self.session_token_info {
585            let info = info.lock();
586            Ok(builder.bearer_auth(info.0.session_token.clone()))
587        } else {
588            self.auth.wrap(builder)
589        }
590    }
591
592    async fn start_query_inner(
593        &self,
594        sql: &str,
595        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
596        force_json_body: bool,
597        params: Option<serde_json::Value>,
598    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
599        if !self.in_active_transaction() {
600            self.route_hint.next();
601        }
602        let endpoint = self.endpoint.join("v1/query")?;
603
604        // body
605        let session_state = self.session_state();
606        let need_sticky = session_state.need_sticky.unwrap_or(false);
607        let mut req = QueryRequest::new(sql)
608            .with_pagination(self.make_pagination())
609            .with_session(Some(session_state))
610            .with_stage_attachment(stage_attachment_config)
611            .with_params(params);
612
613        // headers
614        let query_id = self.gen_query_id();
615        let mut headers = self.make_headers(Some(&query_id))?;
616        if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
617            debug!("accept arrow data");
618            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
619            req = req.with_arrow();
620        }
621
622        if need_sticky {
623            if let Some(node_id) = self.last_node_id() {
624                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
625            }
626        }
627        let mut builder = self.cli.post(endpoint.clone()).json(&req);
628        builder = self.wrap_auth_or_session_token(builder)?;
629        let request = builder
630            .headers(headers.clone())
631            .timeout(self.page_request_timeout)
632            .build()?;
633        let response = self
634            .query_request_helper(request, true, true, true, RequestKind::QueryStart)
635            .await?;
636        self.handle_page(response, true).await
637    }
638
639    fn is_arrow_data(headers: &HeaderMap) -> bool {
640        if let Some(typ) = headers.get(CONTENT_TYPE) {
641            if let Ok(t) = typ.to_str() {
642                return t == CONTENT_TYPE_ARROW;
643            }
644        }
645        false
646    }
647
648    async fn handle_page(
649        &self,
650        response: HttpResponseData,
651        is_first: bool,
652    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
653        let status = response.status;
654        if status != StatusCode::OK {
655            let error = Self::status_body_to_error(status, &response.body);
656            if !is_first {
657                if let Error::Logic(code, ec) = &error {
658                    if *code == StatusCode::NOT_FOUND {
659                        return Err(Error::QueryNotFound(ec.message.clone()));
660                    }
661                }
662            }
663            return Err(error);
664        }
665        let is_arrow_data = Self::is_arrow_data(&response.headers);
666        if is_first {
667            if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
668                self.route_hint.set(route_hint.to_str().unwrap_or_default());
669            }
670        }
671        let mut body = response.body;
672        let mut batches = vec![];
673        if is_arrow_data {
674            if is_first {
675                debug!("received arrow data");
676            }
677            let cursor = std::io::Cursor::new(body.as_ref());
678            let reader = StreamReader::try_new(cursor, None)
679                .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
680            let schema = reader.schema();
681            let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
682                bytes::Bytes::copy_from_slice(json_resp.as_bytes())
683            } else {
684                return Err(Error::Decode(
685                    "missing response_header metadata in arrow payload".to_string(),
686                ));
687            };
688            for batch in reader {
689                let batch = batch
690                    .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
691                batches.push(batch);
692            }
693            body = json_body
694        };
695        let resp: QueryResponse = json_from_slice(&body)?;
696        self.handle_session(&resp.session).await;
697        if let Some(err) = &resp.error {
698            return Err(Error::QueryFailed(err.clone()));
699        }
700        if is_first {
701            self.handle_warnings(&resp);
702            self.set_last_query_id(Some(resp.id.clone()));
703            if let Some(node_id) = &resp.node_id {
704                self.set_last_node_id(node_id.clone());
705            }
706        }
707        Ok((resp, batches))
708    }
709
710    pub async fn query_page(
711        &self,
712        query_id: &str,
713        next_uri: &str,
714        node_id: &Option<String>,
715    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
716        info!("query page: {next_uri}");
717        let endpoint = self.endpoint.join(next_uri)?;
718        let mut headers = self.make_headers(Some(query_id))?;
719        if self.capability.arrow_data && self.query_result_format == "arrow" {
720            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
721        }
722        let mut builder = self.cli.get(endpoint.clone());
723        builder = self
724            .wrap_auth_or_session_token(builder)?
725            .headers(headers.clone())
726            .timeout(self.page_request_timeout);
727        if let Some(node_id) = node_id {
728            builder = builder.header(HEADER_STICKY_NODE, node_id)
729        }
730        let request = builder.build()?;
731
732        let response = self
733            .query_request_helper(request, false, true, true, RequestKind::QueryPage)
734            .await?;
735        self.handle_page(response, false).await
736    }
737
738    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
739        self.end_query(query_id, true, None).await
740    }
741
742    pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
743        self.end_query(query_id, false, node_id).await
744    }
745
746    pub async fn end_query(
747        &self,
748        query_id: &str,
749        is_kill: bool,
750        node_id: Option<&str>,
751    ) -> Result<()> {
752        let (method, request_kind) = if is_kill {
753            ("kill", RequestKind::QueryKill)
754        } else {
755            ("final", RequestKind::QueryFinal)
756        };
757        let uri = format!("/v1/query/{query_id}/{method}");
758        let endpoint = self.endpoint.join(&uri)?;
759        let headers = self.make_headers(Some(query_id))?;
760
761        info!("{method} query: {uri}");
762
763        let mut builder = self.cli.post(endpoint.clone());
764        if let Some(node_id) = node_id {
765            builder = builder.header(HEADER_STICKY_NODE, node_id)
766        }
767        builder = self.wrap_auth_or_session_token(builder)?;
768        let resp = builder.headers(headers.clone()).send().await?;
769        if resp.status() != 200 {
770            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
771                .with_context(request_kind)
772                .with_query_id(query_id));
773        }
774        Ok(())
775    }
776
777    pub async fn query_all(
778        self: &Arc<Self>,
779        sql: &str,
780        params: Option<serde_json::Value>,
781    ) -> Result<Page> {
782        self.query_all_inner(sql, false, params).await
783    }
784
785    pub async fn query_all_inner(
786        self: &Arc<Self>,
787        sql: &str,
788        force_json_body: bool,
789        params: Option<serde_json::Value>,
790    ) -> Result<Page> {
791        let (resp, batches) = self
792            .start_query_inner(sql, None, force_json_body, params)
793            .await?;
794        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
795        let mut all = Page::default();
796        while let Some(page) = pages.next().await {
797            all.update(page?);
798        }
799        Ok(all)
800    }
801
802    fn session_state(&self) -> SessionState {
803        self.session_state.lock().clone()
804    }
805
806    fn make_pagination(&self) -> Option<PaginationConfig> {
807        if self.wait_time_secs.is_none()
808            && self.max_rows_in_buffer.is_none()
809            && self.max_rows_per_page.is_none()
810        {
811            return None;
812        }
813        let mut pagination = PaginationConfig {
814            wait_time_secs: None,
815            max_rows_in_buffer: None,
816            max_rows_per_page: None,
817        };
818        if let Some(wait_time_secs) = self.wait_time_secs {
819            pagination.wait_time_secs = Some(wait_time_secs);
820        }
821        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
822            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
823        }
824        if let Some(max_rows_per_page) = self.max_rows_per_page {
825            pagination.max_rows_per_page = Some(max_rows_per_page);
826        }
827        Some(pagination)
828    }
829
830    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
831        let mut headers = HeaderMap::new();
832        if let Some(tenant) = &self.tenant {
833            headers.insert(HEADER_TENANT, tenant.parse()?);
834        }
835        let warehouse = self.warehouse.lock().clone();
836        if let Some(warehouse) = warehouse {
837            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
838        }
839        let route_hint = self.route_hint.current();
840        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
841        if let Some(query_id) = query_id {
842            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
843        }
844        Ok(headers)
845    }
846
847    pub async fn insert_with_stage(
848        self: &Arc<Self>,
849        sql: &str,
850        stage: &str,
851        file_format_options: BTreeMap<&str, &str>,
852        copy_options: BTreeMap<&str, &str>,
853    ) -> Result<QueryStats> {
854        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
855        let stage_attachment = Some(StageAttachmentConfig {
856            location: stage,
857            file_format_options: Some(file_format_options),
858            copy_options: Some(copy_options),
859        });
860        let (resp, batches) = self
861            .start_query_inner(sql, stage_attachment, true, None)
862            .await?;
863        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
864        let mut all = Page::default();
865        while let Some(page) = pages.next().await {
866            all.update(page?);
867        }
868        Ok(all.stats)
869    }
870
871    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
872        info!("get presigned upload url: {stage}");
873        let sql = format!("PRESIGN UPLOAD {stage}");
874        let resp = self.query_all_inner(&sql, true, None).await?;
875        if resp.data.len() != 1 {
876            return Err(Error::Decode(
877                "Empty response from server for presigned request".to_string(),
878            ));
879        }
880        if resp.data[0].len() != 3 {
881            return Err(Error::Decode(
882                "Invalid response from server for presigned request".to_string(),
883            ));
884        }
885        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
886        let method = resp.data[0][0].clone().unwrap_or_default();
887        if method != "PUT" {
888            return Err(Error::Decode(format!(
889                "Invalid method for presigned upload request: {method}"
890            )));
891        }
892        let headers: BTreeMap<String, String> =
893            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
894        let url = resp.data[0][2].clone().unwrap_or_default();
895        Ok(PresignedResponse {
896            method,
897            headers,
898            url,
899        })
900    }
901
902    pub async fn upload_to_stage(
903        self: &Arc<Self>,
904        stage: &str,
905        data: Reader,
906        size: u64,
907    ) -> Result<()> {
908        match self.get_presign_mode() {
909            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
910            PresignMode::On => {
911                let presigned = self.get_presigned_upload_url(stage).await?;
912                presign_upload_to_stage(presigned, data, size).await
913            }
914            PresignMode::Auto => {
915                unreachable!("PresignMode::Auto should be handled during client initialization")
916            }
917            PresignMode::Detect => {
918                unreachable!("PresignMode::Detect should be handled during client initialization")
919            }
920        }
921    }
922
923    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
924    async fn upload_to_stage_with_stream(
925        &self,
926        stage: &str,
927        data: Reader,
928        size: u64,
929    ) -> Result<()> {
930        info!("upload to stage with stream: {stage}, size: {size}");
931        if let Some(info) = self.need_pre_refresh_session().await {
932            self.refresh_session_token(info).await?;
933        }
934        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
935        let location = StageLocation::try_from(stage)?;
936        let query_id = self.gen_query_id();
937        let mut headers = self.make_headers(Some(&query_id))?;
938        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
939        let stream = Body::wrap_stream(ReaderStream::new(data));
940        let part = Part::stream_with_length(stream, size).file_name(location.path);
941        let form = Form::new().part("upload", part);
942        let mut builder = self.cli.put(endpoint.clone());
943        builder = self.wrap_auth_or_session_token(builder)?;
944        let resp = builder.headers(headers).multipart(form).send().await?;
945        let status = resp.status();
946        if status != 200 {
947            return Err(Error::response_error(status, &resp.bytes().await?)
948                .with_context(RequestKind::UploadToStage)
949                .with_query_id(query_id));
950        }
951        Ok(())
952    }
953
954    // use base64 encode whenever possible for safety
955    // but also accept raw JSON for test/debug/one-shot operations
956    pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
957    where
958        T: de::DeserializeOwned,
959    {
960        if value.starts_with("{") {
961            serde_json::from_slice(value.as_bytes())
962                .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
963        } else {
964            let json = URL_SAFE.decode(value).map_err(|e| {
965                format!(
966                    "Invalid value {} for {key}, base64 decode error: {}",
967                    value, e
968                )
969            })?;
970            serde_json::from_slice(&json).map_err(|e| {
971                format!(
972                    "Invalid value {value} for {key}, JSON value {},  decode error: {e}",
973                    String::from_utf8_lossy(&json)
974                )
975            })
976        }
977    }
978
979    pub async fn streaming_load(
980        &self,
981        sql: &str,
982        data: Reader,
983        file_name: &str,
984    ) -> Result<LoadResponse> {
985        let body = Body::wrap_stream(ReaderStream::new(data));
986        let part = Part::stream(body).file_name(file_name.to_string());
987        let endpoint = self.endpoint.join("v1/streaming_load")?;
988        let mut builder = self.cli.put(endpoint.clone());
989        builder = self.wrap_auth_or_session_token(builder)?;
990        let query_id = self.gen_query_id();
991        let mut headers = self.make_headers(Some(&query_id))?;
992        headers.insert(HEADER_SQL, sql.parse()?);
993        let session = serde_json::to_string(&*self.session_state.lock())
994            .expect("serialize session state should not fail");
995        headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
996        let form = Form::new().part("upload", part);
997        let resp = builder.headers(headers).multipart(form).send().await?;
998        let status = resp.status();
999        if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
1000            match Self::decode_json_header::<SessionState>(
1001                HEADER_QUERY_CONTEXT,
1002                value.to_str().unwrap(),
1003            ) {
1004                Ok(session) => *self.session_state.lock() = session,
1005                Err(e) => {
1006                    error!("Error decoding session state when streaming load: {e}");
1007                }
1008            }
1009        };
1010        if status != 200 {
1011            return Err(Error::response_error(status, &resp.bytes().await?)
1012                .with_context(RequestKind::StreamingLoad)
1013                .with_query_id(query_id));
1014        }
1015        let resp = resp.json::<LoadResponse>().await?;
1016        Ok(resp)
1017    }
1018
1019    async fn login(&mut self) -> Result<()> {
1020        let endpoint = self.endpoint.join("/v1/session/login")?;
1021        let headers = self.make_headers(None)?;
1022        let body = LoginRequest::from(&*self.session_state.lock());
1023        let mut builder = self.cli.post(endpoint.clone()).json(&body);
1024        if self.disable_session_token {
1025            builder = builder.query(&[("disable_session_token", true)]);
1026        }
1027        let builder = self.auth.wrap(builder)?;
1028        let request = builder
1029            .headers(headers.clone())
1030            .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1031            .build()?;
1032        let response = self
1033            .query_request_helper(request, true, false, true, RequestKind::Login)
1034            .await?;
1035        if response.status == StatusCode::NOT_FOUND {
1036            info!("login return 404, skip login on the old version server");
1037            return Ok(());
1038        }
1039        if response.status != StatusCode::OK {
1040            return Err(Self::status_body_to_error(response.status, &response.body)
1041                .with_context(RequestKind::Login));
1042        }
1043        if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
1044            if let Ok(s) = v.to_str() {
1045                self.session_id = s.to_string();
1046            }
1047        }
1048
1049        let response = json_from_slice(&response.body)?;
1050        match response {
1051            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
1052            LoginResponseResult::Ok(info) => {
1053                let server_version = info
1054                    .version
1055                    .parse()
1056                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
1057                self.capability = Capability::from_server_version(&server_version);
1058                self.server_version = Some(server_version.clone());
1059                let session_id = self.session_id.as_str();
1060                if let Some(tokens) = info.tokens {
1061                    info!(
1062                        "[session {session_id}] login success with session token version = {server_version}",
1063                    );
1064                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
1065                } else {
1066                    info!("[session {session_id}] login success, version = {server_version}");
1067                }
1068            }
1069        }
1070        Ok(())
1071    }
1072
1073    pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1074        let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1075        let queries = self.queries_need_heartbeat.lock().clone();
1076        let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1077        let now = Instant::now();
1078
1079        let mut query_ids = Vec::new();
1080        for (qid, state) in queries {
1081            if state.need_heartbeat(now) {
1082                query_ids.push(qid.to_string());
1083                if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1084                    arr.push(qid);
1085                } else {
1086                    node_to_queries.insert(state.node_id, vec![qid]);
1087                }
1088            }
1089        }
1090
1091        if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1092        {
1093            return Ok(());
1094        }
1095
1096        let body = json!({
1097           "node_to_queries": node_to_queries
1098        });
1099        let headers = self.make_headers(None)?;
1100        let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1101        let request = self.wrap_auth_or_session_token(builder)?.build()?;
1102        let response = self
1103            .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1104            .await?;
1105        if response.status != StatusCode::OK {
1106            return Err(Self::status_body_to_error(response.status, &response.body)
1107                .with_context(RequestKind::Heartbeat));
1108        }
1109        let json: Value = json_from_slice(&response.body)?;
1110        let session_id = self.session_id.as_str();
1111        info!("[session {session_id}] heartbeat request={body}, response={json}");
1112        if let Some(queries_to_remove) = json.get("queries_to_remove") {
1113            if let Some(arr) = queries_to_remove.as_array() {
1114                if !arr.is_empty() {
1115                    let mut queries = self.queries_need_heartbeat.lock();
1116                    for q in arr {
1117                        if let Some(q) = q.as_str() {
1118                            queries.remove(q);
1119                        }
1120                    }
1121                }
1122            }
1123        }
1124        let now = Instant::now();
1125        let mut queries = self.queries_need_heartbeat.lock();
1126        for qid in query_ids {
1127            if let Some(state) = queries.get_mut(&qid) {
1128                *state.last_access_time.lock() = now;
1129            }
1130        }
1131        Ok(())
1132    }
1133
1134    fn build_log_out_request(&self) -> Result<Request> {
1135        let endpoint = self.endpoint.join("/v1/session/logout")?;
1136
1137        let session_state = self.session_state();
1138        let need_sticky = session_state.need_sticky.unwrap_or(false);
1139        let mut headers = self.make_headers(None)?;
1140        if need_sticky {
1141            if let Some(node_id) = self.last_node_id() {
1142                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1143            }
1144        }
1145        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1146
1147        let builder = self.wrap_auth_or_session_token(builder)?;
1148        let req = builder.build()?;
1149        Ok(req)
1150    }
1151
1152    pub(crate) fn need_logout(&self) -> bool {
1153        self.session_token_info.is_some()
1154            || self.session_state.lock().need_keep_alive.unwrap_or(false)
1155    }
1156
1157    async fn refresh_session_token(
1158        &self,
1159        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1160    ) -> Result<()> {
1161        let (session_token_info, _) = { self_login_info.lock().clone() };
1162        let endpoint = self.endpoint.join("/v1/session/refresh")?;
1163        let body = RefreshSessionTokenRequest {
1164            session_token: session_token_info.session_token.clone(),
1165        };
1166        let headers = self.make_headers(None)?;
1167        let request = self
1168            .cli
1169            .post(endpoint.clone())
1170            .json(&body)
1171            .headers(headers.clone())
1172            .bearer_auth(session_token_info.refresh_token.clone())
1173            .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1174            .build()?;
1175        let response = self
1176            .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1177            .await?;
1178        if response.status != StatusCode::OK {
1179            return Err(Self::status_body_to_error(response.status, &response.body)
1180                .with_context(RequestKind::SessionRefresh));
1181        }
1182        let response = json_from_slice(&response.body)?;
1183        match response {
1184            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1185            RefreshResponse::Ok(info) => {
1186                *self_login_info.lock() = (info, Instant::now());
1187                Ok(())
1188            }
1189        }
1190    }
1191
1192    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1193        if let Some(info) = &self.session_token_info {
1194            let (start, ttl) = {
1195                let guard = info.lock();
1196                (guard.1, guard.0.session_token_ttl_in_secs)
1197            };
1198            if Instant::now() > start + Duration::from_secs(ttl) {
1199                return Some(info.clone());
1200            }
1201        }
1202        None
1203    }
1204
    /// Send `request` with retries and return the raw HTTP status, headers and body.
    ///
    /// Retries, sleeping a jittered `retry_delay_secs` between attempts, on:
    ///   - transport, request or body-read errors that `retry_reason_for_reqwest`
    ///     classifies as retryable;
    ///   - 503 Service Unavailable, when `retry_if_503` is set;
    ///   - 401 Unauthorized:
    ///     - once, after refreshing the databend session token, when
    ///       `refresh_if_401` is set and the error code satisfies `need_refresh_token`;
    ///     - after re-wrapping the request with reloaded credentials, when
    ///       `reload_auth_if_401` is set and `self.auth.can_reload()`.
    ///
    /// Every retry cause except the session-token refresh shares one budget. A retry
    /// is taken only while `retries < retry_count - 1` (saturating), which allows at
    /// most `max(retry_count, 1)` attempts. A session-token refresh resets that
    /// counter and happens at most once per call.
    ///
    /// Non-2xx responses, including a final 503 or 401, are returned as `Ok`; callers
    /// must inspect `status`. Errors are returned for:
    /// - non-retryable transport/body errors, or retryable ones once the budget is
    ///   exhausted (wrapped with request context);
    /// - a failed session-token refresh;
    /// - failures re-wrapping or building the request.
    ///
    /// # Panics
    /// Panics if `request` cannot be cloned; per reqwest, that is the case when its
    /// body is a stream.
    async fn query_request_helper(
        &self,
        mut request: Request,
        retry_if_503: bool,
        refresh_if_401: bool,
        reload_auth_if_401: bool,
        request_kind: RequestKind,
    ) -> Result<HttpResponseData> {
        let mut refreshed = false;
        let mut retries = 0;
        let max_retries = self.retry_count;
        let retry_delay = Duration::from_secs(self.retry_delay_secs);

        loop {
            // Send a clone so `request` stays available for retries and error context.
            let req = request.try_clone().expect("request not cloneable");
            let response = match self.cli.execute(req).await {
                Ok(response) => response,
                Err(err) => {
                    // Retryable transport error: retry until the shared budget runs out.
                    if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
                        if retries >= max_retries.saturating_sub(1) {
                            return Err(Self::with_request_context(
                                err,
                                &request_kind,
                                &request,
                                Some(retries),
                            ));
                        }
                        retries += 1;
                        warn!(
                            "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
                            retries,
                            max_retries,
                            request.url(),
                            reason,
                            err,
                            retry_delay.as_secs(),
                        );
                        sleep(jitter(retry_delay)).await;
                        continue;
                    }
                    return Err(Self::with_request_context(
                        err,
                        &request_kind,
                        &request,
                        Some(retries),
                    ));
                }
            };

            let status = response.status();
            let headers = response.headers().clone();
            // Reading the body can also fail; same retry policy as the send above.
            let body = match response.bytes().await {
                Ok(body) => body,
                Err(err) => {
                    if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
                        if retries >= max_retries.saturating_sub(1) {
                            return Err(Self::with_request_context(
                                err,
                                &request_kind,
                                &request,
                                Some(retries),
                            ));
                        }
                        retries += 1;
                        warn!(
                            "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
                            retries,
                            max_retries,
                            request.url(),
                            reason,
                            err,
                            retry_delay.as_secs(),
                        );
                        sleep(jitter(retry_delay)).await;
                        continue;
                    }
                    return Err(Self::with_request_context(
                        err,
                        &request_kind,
                        &request,
                        Some(retries),
                    ));
                }
            };

            // Once the budget is exhausted, the 503 itself is handed back to the caller.
            if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
                if retries >= max_retries.saturating_sub(1) {
                    return Ok(HttpResponseData {
                        status,
                        headers,
                        body,
                    });
                }
                retries += 1;
                warn!(
                    "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
                    retries,
                    max_retries,
                    request.url(),
                    retry_delay.as_secs()
                );
                sleep(jitter(retry_delay)).await;
                continue;
            }

            if status == StatusCode::UNAUTHORIZED {
                // NOTE(review): the Authorization header is removed on every 401. The
                // session-token refresh branch below then `continue`s with this stripped
                // request and never re-attaches credentials. Confirm how
                // `wrap_auth_or_session_token` (not visible here) carries the session
                // token. If it uses this header, the post-refresh retry goes out
                // unauthenticated.
                request.headers_mut().remove(reqwest::header::AUTHORIZATION);
                let unauthorized_error =
                    serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
                if let Some(session_token_info) = &self.session_token_info {
                    let should_refresh = unauthorized_error
                        .as_ref()
                        .map(|r| need_refresh_token(r.error.code))
                        .unwrap_or(false);
                    if refresh_if_401 && should_refresh && !refreshed {
                        // Boxed because refresh_session_token itself calls this helper
                        // (recursive async call).
                        Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
                        // A refreshed token starts a fresh retry budget; `refreshed`
                        // prevents a second refresh in this call.
                        refreshed = true;
                        retries = 0;
                        warn!(
                            "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
                            request.url(),
                            retry_delay.as_secs()
                        );
                        sleep(jitter(retry_delay)).await;
                        continue;
                    }
                }
                if reload_auth_if_401 && self.auth.can_reload() {
                    if retries >= max_retries.saturating_sub(1) {
                        return Ok(HttpResponseData {
                            status,
                            headers,
                            body,
                        });
                    }
                    retries += 1;
                    // Re-apply (reloaded) credentials. `from_parts` needs a client only
                    // to hold the request; the rebuilt request is still executed with
                    // `self.cli` at the top of the loop.
                    let builder =
                        RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
                    let builder = self.auth.wrap(builder)?;
                    request = builder.build()?;
                    warn!(
                        "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
                        retries,
                        max_retries,
                        request.url(),
                        retry_delay.as_secs()
                    );
                    sleep(jitter(retry_delay)).await;
                    continue;
                }
            }

            return Ok(HttpResponseData {
                status,
                headers,
                body,
            });
        }
    }
1371
1372    pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1373        if let Err(err) = http_client.execute(request).await {
1374            error!("[session {session_id}] logout request failed: {err}");
1375        } else {
1376            info!("[session {session_id}] logout success");
1377        };
1378    }
1379
1380    pub async fn close(&self) {
1381        let session_id = &self.session_id;
1382        info!("[session {session_id}] try closing now");
1383        if self
1384            .closed
1385            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1386            .is_ok()
1387        {
1388            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1389            if self.need_logout() {
1390                let cli = self.cli.clone();
1391                let req = self
1392                    .build_log_out_request()
1393                    .expect("failed to build logout request");
1394                Self::logout(cli, req, &self.session_id).await;
1395            }
1396        }
1397    }
1398    pub fn close_with_spawn(&self) {
1399        let session_id = &self.session_id;
1400        info!("[session {session_id}]: try closing with spawn");
1401        if self
1402            .closed
1403            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1404            .is_ok()
1405        {
1406            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1407            if self.need_logout() {
1408                let cli = self.cli.clone();
1409                let req = self
1410                    .build_log_out_request()
1411                    .expect("failed to build logout request");
1412                let session_id = self.session_id.clone();
1413                GLOBAL_RUNTIME.spawn(async move {
1414                    Self::logout(cli, req, session_id.as_str()).await;
1415                });
1416            }
1417        }
1418    }
1419
1420    pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1421        let mut queries = self.queries_need_heartbeat.lock();
1422        queries.insert(query_id.to_string(), state);
1423    }
1424}
1425
1426fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1427where
1428    T: Deserialize<'a>,
1429{
1430    serde_json::from_slice::<T>(body).map_err(|e| {
1431        Error::Decode(format!(
1432            "fail to decode JSON response: {e}, body: {}",
1433            String::from_utf8_lossy(body)
1434        ))
1435    })
1436}
1437
1438impl Default for APIClient {
1439    fn default() -> Self {
1440        Self {
1441            session_id: Default::default(),
1442            cli: HttpClient::new(),
1443            scheme: "http".to_string(),
1444            endpoint: Url::parse("http://localhost:8080").unwrap(),
1445            host: "localhost".to_string(),
1446            port: 8000,
1447            tenant: None,
1448            warehouse: Mutex::new(None),
1449            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1450            session_state: Mutex::new(SessionState::default()),
1451            wait_time_secs: None,
1452            max_rows_in_buffer: None,
1453            max_rows_per_page: None,
1454            connect_timeout: Duration::from_secs(10),
1455            page_request_timeout: Duration::from_secs(300),
1456            tls_ca_file: None,
1457            presign: Mutex::new(PresignMode::Auto),
1458            route_hint: RouteHintGenerator::new(),
1459            last_node_id: Default::default(),
1460            disable_session_token: true,
1461            disable_login: false,
1462            query_result_format: "json".to_string(),
1463            session_token_info: None,
1464            closed: AtomicBool::new(false),
1465            last_query_id: Default::default(),
1466            server_version: None,
1467            capability: Default::default(),
1468            queries_need_heartbeat: Default::default(),
1469            retry_count: 3,
1470            retry_delay_secs: 10,
1471        }
1472    }
1473}
1474
1475struct RouteHintGenerator {
1476    nonce: AtomicU64,
1477    current: std::sync::Mutex<String>,
1478}
1479
1480impl RouteHintGenerator {
1481    fn new() -> Self {
1482        let gen = Self {
1483            nonce: AtomicU64::new(0),
1484            current: std::sync::Mutex::new("".to_string()),
1485        };
1486        gen.next();
1487        gen
1488    }
1489
1490    fn current(&self) -> String {
1491        let guard = self.current.lock().unwrap();
1492        guard.clone()
1493    }
1494
1495    fn set(&self, hint: &str) {
1496        let mut guard = self.current.lock().unwrap();
1497        *guard = hint.to_string();
1498    }
1499
1500    fn next(&self) -> String {
1501        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1502        let uuid = uuid::Uuid::new_v4();
1503        let current = format!("rh:{uuid}:{nonce:06}");
1504        let mut guard = self.current.lock().unwrap();
1505        guard.clone_from(&current);
1506        current
1507    }
1508}
1509
#[cfg(test)]
mod test {
    use super::*;

    // `sslmode=disable` selects plain HTTP on port 80, and the query parameters
    // populate the pagination settings and the warehouse.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = concat!(
            "databend://username:password@app.databend.com/test",
            "?wait_time_secs=10&max_rows_in_buffer=5000000",
            "&max_rows_per_page=10000&warehouse=wh&sslmode=disable"
        );
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(
            (
                client.wait_time_secs,
                client.max_rows_in_buffer,
                client.max_rows_per_page
            ),
            (Some(10), Some(5000000), Some(10000))
        );
        assert_eq!(client.tenant, None);
        let warehouse = client.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    // A percent-encoded password parses. With no explicit port and no sslmode,
    // the client uses port 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    // An un-encoded password containing '@', '=', '{' still parses, and the
    // explicit port is honoured.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}