Skip to main content

lake_client/
client.rs

1// Copyright 2025 TiDB Cloud
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth, KeyPairAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22    SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28    error::{Error, RequestKind, Result},
29    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30    response::QueryResponse,
31    session::SessionState,
32    QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48    Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// Names of the custom HTTP headers exchanged with the server.
// Client-chosen id of the query a request belongs to.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
// Tenant configured through the DSN `tenant` parameter.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
// Node id that follow-up requests of a query (or sticky sessions) are pinned to.
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
// Currently selected warehouse.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
// Route hint; sent on every request and read back from the first page of a query.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
// Value of the session `txn_state` (compared ignoring ASCII case) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
// Response `Content-Type` that marks a page body as an Arrow IPC stream.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Value sent in `Accept` when Arrow results are requested.
// NOTE(review): despite the name this is identical to CONTENT_TYPE_ARROW and does not
// mention JSON — confirm that this is intentional.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
const DEFAULT_USERNAME: &str = "root";
78
79static VERSION: Lazy<String> = Lazy::new(|| {
80    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
81    version.to_string()
82});
83
/// Bookkeeping for a query that may need periodic heartbeats.
#[derive(Clone)]
pub(crate) struct QueryState {
    /// Node the query runs on; used as the sticky node when the query is finalized.
    pub node_id: String,
    /// Instant of the most recent access, shared between clones of this state.
    pub last_access_time: Arc<Mutex<Instant>>,
    /// Timeout in seconds; a heartbeat is due once more than half of it has elapsed.
    pub timeout_secs: u64,
}
90
91impl QueryState {
92    pub fn need_heartbeat(&self, now: Instant) -> bool {
93        let t = *self.last_access_time.lock();
94        now.duration_since(t).as_secs() > self.timeout_secs / 2
95    }
96}
97
/// An HTTP response held as plain data: status, headers and the complete body bytes.
struct HttpResponseData {
    /// HTTP status code of the response.
    status: StatusCode,
    /// Response headers (inspected for the content type and the route hint).
    headers: HeaderMap,
    /// Raw body bytes: JSON, or an Arrow IPC stream when the content type says so.
    body: Bytes,
}
103
/// HTTP client for the query API, configured from a DSN.
pub struct APIClient {
    /// Session identifier; `new` falls back to `no_login_<uuid>` when it is still empty
    /// after the optional login step.
    pub(crate) session_id: String,
    /// Underlying HTTP client, created by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, selected by the DSN `sslmode` parameter.
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on the scheme when absent.
    port: u16,

    /// Base URL `{scheme}://{host}:{port}` that request paths are joined onto.
    endpoint: Url,

    /// Credentials applied to requests when no session token is available.
    auth: Arc<dyn Auth>,

    /// Tenant sent with every request when set.
    tenant: Option<String>,
    /// Current warehouse; sent with every request and updated from server session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with each new query and replaced by what the server returns.
    session_state: Mutex<SessionState>,
    /// Source of the route hint header value.
    route_hint: RouteHintGenerator,

    /// Set by the DSN parameter `login=disable`; `new` then skips the login call.
    disable_login: bool,
    /// `"json"` or `"arrow"`, from the DSN `body_format` / `query_result_format` parameter.
    query_result_format: String,
    /// Set by the DSN parameter `session_token=disable`.
    disable_session_token: bool,
    /// When present, its session token is sent as bearer auth instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    closed: AtomicBool,

    server_version: Option<Version>,

    /// Optional pagination settings forwarded with each new query.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Connect timeout given to the HTTP client builder.
    connect_timeout: Duration,
    /// Per-request timeout for query start and page requests.
    page_request_timeout: Duration,

    /// Optional PEM file added as a root certificate for https connections.
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto` and `Detect` are resolved by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the first page of the most recent query.
    last_node_id: Mutex<Option<String>>,
    /// Id reported by the first page of the most recent query.
    last_query_id: Mutex<Option<String>>,

    capability: Capability,

    /// Queries tracked for heartbeats, keyed by query id.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    /// Values of the DSN `retry_count` / `retry_delay_secs` parameters.
    retry_count: u32,
    retry_delay_secs: u64,
}
149
150impl Drop for APIClient {
151    fn drop(&mut self) {
152        self.close_with_spawn()
153    }
154}
155
156impl APIClient {
157    fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
158        while let Some(err) = source {
159            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
160                match io_err.kind() {
161                    ErrorKind::TimedOut
162                    | ErrorKind::UnexpectedEof
163                    | ErrorKind::ConnectionReset
164                    | ErrorKind::ConnectionAborted
165                    | ErrorKind::BrokenPipe
166                    | ErrorKind::NotConnected
167                    | ErrorKind::WouldBlock
168                    | ErrorKind::Interrupted => return true,
169                    _ => {}
170                }
171            }
172            source = err.source();
173        }
174        false
175    }
176
177    fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
178        if err.is_timeout() {
179            Some("request timeout")
180        } else if err.is_connect() {
181            Some("connection error")
182        } else if err.is_request() {
183            Some("request error")
184        } else if err.is_body() || err.is_decode() {
185            if Self::has_transient_io_source(err.source()) {
186                Some("response read transient I/O error")
187            } else {
188                None
189            }
190        } else {
191            None
192        }
193    }
194
195    fn request_query_id(request: &Request) -> Option<String> {
196        request
197            .headers()
198            .get(HEADER_QUERY_ID)
199            .and_then(|v| v.to_str().ok())
200            .map(|v| v.to_string())
201    }
202
203    fn with_request_context(
204        reqwest_error: ReqwestError,
205        request_kind: &RequestKind,
206        request: &Request,
207        retry_times: Option<u32>,
208    ) -> Error {
209        let error: Error = reqwest_error.into();
210        let error = error.with_context(request_kind.clone());
211        let error = if let Some(retry_times) = retry_times {
212            error.with_retry_times(retry_times)
213        } else {
214            error
215        };
216        if let Some(query_id) = Self::request_query_id(request) {
217            error.with_query_id(query_id)
218        } else {
219            error
220        }
221    }
222
223    fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
224        match serde_json::from_slice::<ResponseWithErrorCode>(body) {
225            Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
226            Ok(resp) => Error::Logic(status, resp.error),
227            Err(_) => Error::response_error(status, body),
228        }
229    }
230
231    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
232        let mut client = Self::from_dsn(dsn).await?;
233        client.build_client(name).await?;
234        if !client.disable_login {
235            client.login().await?;
236        }
237        if client.session_id.is_empty() {
238            client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
239        }
240        let client = Arc::new(client);
241        client.check_presign().await?;
242        GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
243        Ok(client)
244    }
245
    /// Returns the capability set stored on this client.
    pub fn capability(&self) -> &Capability {
        &self.capability
    }
249
250    fn set_presign_mode(&self, mode: PresignMode) {
251        *self.presign.lock() = mode
252    }
253    fn get_presign_mode(&self) -> PresignMode {
254        *self.presign.lock()
255    }
256
257    async fn from_dsn(dsn: &str) -> Result<Self> {
258        let u = Url::parse(dsn)?;
259        let mut client = Self::default();
260        if let Some(host) = u.host_str() {
261            client.host = host.to_string();
262        }
263
264        let username = u.username().to_string();
265        if !username.is_empty() {
266            let password = u.password().unwrap_or_default();
267            let password = percent_decode_str(password).decode_utf8()?;
268            client.auth = Arc::new(BasicAuth::new(&username, password));
269        }
270
271        let mut session_state = SessionState::default();
272
273        let database = u.path().trim_start_matches('/');
274        if !database.is_empty() {
275            session_state.set_database(database);
276        }
277
278        let mut private_key_file: Option<String> = None;
279        let mut private_key_passphrase_file: Option<String> = None;
280
281        let mut scheme = "https";
282        for (k, v) in u.query_pairs() {
283            match k.as_ref() {
284                "wait_time_secs" => {
285                    client.wait_time_secs = Some(v.parse()?);
286                }
287                "max_rows_in_buffer" => {
288                    client.max_rows_in_buffer = Some(v.parse()?);
289                }
290                "max_rows_per_page" => {
291                    client.max_rows_per_page = Some(v.parse()?);
292                }
293                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
294                "page_request_timeout_secs" => {
295                    client.page_request_timeout = {
296                        let secs: u64 = v.parse()?;
297                        Duration::from_secs(secs)
298                    };
299                }
300                "presign" => {
301                    let presign_mode = match v.as_ref() {
302                        "auto" => PresignMode::Auto,
303                        "detect" => PresignMode::Detect,
304                        "on" => PresignMode::On,
305                        "off" => PresignMode::Off,
306                        _ => {
307                            return Err(Error::BadArgument(format!(
308                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
309                        )))
310                        }
311                    };
312                    client.set_presign_mode(presign_mode);
313                }
314                "tenant" => {
315                    client.tenant = Some(v.to_string());
316                }
317                "warehouse" => {
318                    client.warehouse = Mutex::new(Some(v.to_string()));
319                }
320                "role" => session_state.set_role(v),
321                "sslmode" => match v.as_ref() {
322                    "disable" => scheme = "http",
323                    "require" | "enable" => scheme = "https",
324                    _ => {
325                        return Err(Error::BadArgument(format!(
326                            "Invalid value for sslmode: {v}"
327                        )))
328                    }
329                },
330                "tls_ca_file" => {
331                    client.tls_ca_file = Some(v.to_string());
332                }
333                "access_token" => {
334                    client.auth = Arc::new(AccessTokenAuth::new(v));
335                }
336                "access_token_file" => {
337                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
338                }
339                "private_key_file" => {
340                    private_key_file = Some(v.to_string());
341                }
342                "private_key_passphrase_file" => {
343                    private_key_passphrase_file = Some(v.to_string());
344                }
345                "login" => {
346                    client.disable_login = match v.as_ref() {
347                        "disable" => true,
348                        "enable" => false,
349                        _ => {
350                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
351                        }
352                    }
353                }
354                "session_token" => {
355                    client.disable_session_token = match v.as_ref() {
356                        "disable" => true,
357                        "enable" => false,
358                        _ => {
359                            return Err(Error::BadArgument(format!(
360                                "Invalid value for session_token: {v}"
361                            )))
362                        }
363                    }
364                }
365                "body_format" | "query_result_format" => {
366                    let v = v.to_string().to_lowercase();
367                    match v.as_str() {
368                        "json" | "arrow" => client.query_result_format = v.to_string(),
369                        _ => {
370                            return Err(Error::BadArgument(format!(
371                                "Invalid value for query_result_format: {v}"
372                            )))
373                        }
374                    }
375                }
376                "retry_count" => {
377                    client.retry_count = v.parse()?;
378                }
379                "retry_delay_secs" => {
380                    client.retry_delay_secs = v.parse()?;
381                }
382                _ => {
383                    session_state.set(k, v);
384                }
385            }
386        }
387        // If private_key_file is specified, use KeyPairAuth.
388        if let Some(key_file) = private_key_file {
389            if username.is_empty() {
390                return Err(Error::BadArgument(
391                    "username is required for key-pair authentication".to_string(),
392                ));
393            }
394            client.auth = Arc::new(KeyPairAuth::new(
395                &username,
396                &key_file,
397                private_key_passphrase_file.as_deref(),
398            )?);
399        }
400        client.port = match u.port() {
401            Some(p) => p,
402            None => match scheme {
403                "http" => 80,
404                "https" => 443,
405                _ => unreachable!(),
406            },
407        };
408        client.scheme = scheme.to_string();
409        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
410        client.session_state = Mutex::new(session_state);
411
412        Ok(client)
413    }
414
415    pub fn host(&self) -> &str {
416        self.host.as_str()
417    }
418
    /// Port this client connects to.
    pub fn port(&self) -> u16 {
        self.port
    }
422
423    pub fn scheme(&self) -> &str {
424        self.scheme.as_str()
425    }
426
427    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
428        let ua = match name {
429            Some(n) => n,
430            None => format!("lake-client-rust/{}", VERSION.as_str()),
431        };
432        let cookie_provider = GlobalCookieStore::new();
433        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
434        let mut initial_cookies = [&cookie].into_iter();
435        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
436        let mut cli_builder = HttpClient::builder()
437            .user_agent(ua)
438            .connect_timeout(self.connect_timeout)
439            .cookie_provider(Arc::new(cookie_provider))
440            .pool_idle_timeout(Duration::from_secs(1));
441        #[cfg(any(feature = "rustls", feature = "native-tls"))]
442        if self.scheme == "https" {
443            if let Some(ref ca_file) = self.tls_ca_file {
444                let cert_pem = tokio::fs::read(ca_file).await?;
445                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
446                cli_builder = cli_builder.add_root_certificate(cert);
447            }
448        }
449        self.cli = cli_builder.build()?;
450        Ok(())
451    }
452
453    async fn check_presign(self: &Arc<Self>) -> Result<()> {
454        let mode = match self.get_presign_mode() {
455            PresignMode::Auto => {
456                if self.host.ends_with(".lake.tidbcloud.com")
457                    || self.host.ends_with(".lake.tidbcloud.cn")
458                    || self.host.ends_with(".tidbcloud.com")
459                {
460                    PresignMode::On
461                } else {
462                    PresignMode::Off
463                }
464            }
465            PresignMode::Detect => match self.get_presigned_upload_url("@~/.lakesql/check").await {
466                Ok(_) => PresignMode::On,
467                Err(e) => {
468                    warn!("presign mode off with error detected: {e}");
469                    PresignMode::Off
470                }
471            },
472            mode => mode,
473        };
474        self.set_presign_mode(mode);
475        Ok(())
476    }
477
478    pub fn current_warehouse(&self) -> Option<String> {
479        let guard = self.warehouse.lock();
480        guard.clone()
481    }
482
483    pub fn current_catalog(&self) -> Option<String> {
484        let guard = self.session_state.lock();
485        guard.catalog.clone()
486    }
487
488    pub fn current_database(&self) -> Option<String> {
489        let guard = self.session_state.lock();
490        guard.database.clone()
491    }
492
493    pub fn set_warehouse(&self, warehouse: impl Into<String>) {
494        let mut guard = self.warehouse.lock();
495        *guard = Some(warehouse.into());
496    }
497
498    pub fn set_database(&self, database: impl Into<String>) {
499        let mut guard = self.session_state.lock();
500        guard.set_database(database);
501    }
502
503    pub fn set_role(&self, role: impl Into<String>) {
504        let mut guard = self.session_state.lock();
505        guard.set_role(role);
506    }
507
508    pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
509        let mut guard = self.session_state.lock();
510        guard.set(key, value);
511    }
512
513    pub async fn current_role(&self) -> Option<String> {
514        let guard = self.session_state.lock();
515        guard.role.clone()
516    }
517
518    fn in_active_transaction(&self) -> bool {
519        let guard = self.session_state.lock();
520        guard
521            .txn_state
522            .as_ref()
523            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
524            .unwrap_or(false)
525    }
526
    /// Returns the username reported by the configured authentication method.
    pub fn username(&self) -> String {
        self.auth.username()
    }
530
531    fn gen_query_id(&self) -> String {
532        uuid::Uuid::now_v7().simple().to_string()
533    }
534
535    async fn handle_session(&self, session: &Option<SessionState>) {
536        let session = match session {
537            Some(session) => session,
538            None => return,
539        };
540
541        // save the updated session state from the server side
542        {
543            let mut session_state = self.session_state.lock();
544            *session_state = session.clone();
545        }
546
547        // process warehouse changed via session settings
548        if let Some(settings) = session.settings.as_ref() {
549            if let Some(v) = settings.get("warehouse") {
550                let mut warehouse = self.warehouse.lock();
551                *warehouse = Some(v.clone());
552            }
553        }
554    }
555
556    pub fn set_last_node_id(&self, node_id: String) {
557        *self.last_node_id.lock() = Some(node_id)
558    }
559
560    pub fn set_last_query_id(&self, query_id: Option<String>) {
561        *self.last_query_id.lock() = query_id
562    }
563
564    pub fn last_query_id(&self) -> Option<String> {
565        self.last_query_id.lock().clone()
566    }
567
568    fn last_node_id(&self) -> Option<String> {
569        self.last_node_id.lock().clone()
570    }
571
572    fn handle_warnings(&self, resp: &QueryResponse) {
573        if let Some(warnings) = &resp.warnings {
574            for w in warnings {
575                warn!(target: "server_warnings", "server warning: {w}");
576            }
577        }
578    }
579
580    pub async fn start_query(
581        self: &Arc<Self>,
582        sql: &str,
583        need_progress: bool,
584        params: Option<serde_json::Value>,
585    ) -> Result<Pages> {
586        info!("start query: {sql}");
587        let (resp, batches) = self.start_query_inner(sql, None, false, params).await?;
588        Pages::new(self.clone(), resp, batches, need_progress)
589    }
590
591    pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
592        let mut mgr = self.queries_need_heartbeat.lock();
593        if let Some(state) = mgr.remove(query_id) {
594            let self_cloned = self.clone();
595            let query_id = query_id.to_owned();
596            GLOBAL_RUNTIME.spawn(async move {
597                if let Err(e) = self_cloned
598                    .end_query(&query_id, false, Some(state.node_id.as_str()))
599                    .await
600                {
601                    error!("failed to final query {query_id}: {e}");
602                }
603            });
604        }
605    }
606
607    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
608        if let Some(info) = &self.session_token_info {
609            let info = info.lock();
610            Ok(builder.bearer_auth(info.0.session_token.clone()))
611        } else {
612            self.auth.wrap(builder)
613        }
614    }
615
616    async fn start_query_inner(
617        &self,
618        sql: &str,
619        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
620        force_json_body: bool,
621        params: Option<serde_json::Value>,
622    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
623        if !self.in_active_transaction() {
624            self.route_hint.next();
625        }
626        let endpoint = self.endpoint.join("v1/query")?;
627
628        // body
629        let session_state = self.session_state();
630        let need_sticky = session_state.need_sticky.unwrap_or(false);
631        let mut req = QueryRequest::new(sql)
632            .with_pagination(self.make_pagination())
633            .with_session(Some(session_state))
634            .with_stage_attachment(stage_attachment_config)
635            .with_params(params);
636
637        // headers
638        let query_id = self.gen_query_id();
639        let mut headers = self.make_headers(Some(&query_id))?;
640        if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
641            debug!("accept arrow data");
642            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
643            req = req.with_arrow();
644        }
645
646        if need_sticky {
647            if let Some(node_id) = self.last_node_id() {
648                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
649            }
650        }
651        let mut builder = self.cli.post(endpoint.clone()).json(&req);
652        builder = self.wrap_auth_or_session_token(builder)?;
653        let request = builder
654            .headers(headers.clone())
655            .timeout(self.page_request_timeout)
656            .build()?;
657        let response = self
658            .query_request_helper(request, true, true, true, RequestKind::QueryStart)
659            .await?;
660        self.handle_page(response, true).await
661    }
662
663    fn is_arrow_data(headers: &HeaderMap) -> bool {
664        if let Some(typ) = headers.get(CONTENT_TYPE) {
665            if let Ok(t) = typ.to_str() {
666                return t == CONTENT_TYPE_ARROW;
667            }
668        }
669        false
670    }
671
672    async fn handle_page(
673        &self,
674        response: HttpResponseData,
675        is_first: bool,
676    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
677        let status = response.status;
678        if status != StatusCode::OK {
679            let error = Self::status_body_to_error(status, &response.body);
680            if !is_first {
681                if let Error::Logic(code, ec) = &error {
682                    if *code == StatusCode::NOT_FOUND {
683                        return Err(Error::QueryNotFound(ec.message.clone()));
684                    }
685                }
686            }
687            return Err(error);
688        }
689        let is_arrow_data = Self::is_arrow_data(&response.headers);
690        if is_first {
691            if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
692                self.route_hint.set(route_hint.to_str().unwrap_or_default());
693            }
694        }
695        let mut body = response.body;
696        let mut batches = vec![];
697        if is_arrow_data {
698            if is_first {
699                debug!("received arrow data");
700            }
701            let cursor = std::io::Cursor::new(body.as_ref());
702            let reader = StreamReader::try_new(cursor, None)
703                .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
704            let schema = reader.schema();
705            let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
706                bytes::Bytes::copy_from_slice(json_resp.as_bytes())
707            } else {
708                return Err(Error::Decode(
709                    "missing response_header metadata in arrow payload".to_string(),
710                ));
711            };
712            for batch in reader {
713                let batch = batch
714                    .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
715                batches.push(batch);
716            }
717            body = json_body
718        };
719        let resp: QueryResponse = json_from_slice(&body)?;
720        self.handle_session(&resp.session).await;
721        if let Some(err) = &resp.error {
722            return Err(Error::QueryFailed(err.clone()));
723        }
724        if is_first {
725            self.handle_warnings(&resp);
726            self.set_last_query_id(Some(resp.id.clone()));
727            if let Some(node_id) = &resp.node_id {
728                self.set_last_node_id(node_id.clone());
729            }
730        }
731        Ok((resp, batches))
732    }
733
734    pub async fn query_page(
735        &self,
736        query_id: &str,
737        next_uri: &str,
738        node_id: &Option<String>,
739    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
740        info!("query page: {next_uri}");
741        let endpoint = self.endpoint.join(next_uri)?;
742        let mut headers = self.make_headers(Some(query_id))?;
743        if self.capability.arrow_data && self.query_result_format == "arrow" {
744            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
745        }
746        let mut builder = self.cli.get(endpoint.clone());
747        builder = self
748            .wrap_auth_or_session_token(builder)?
749            .headers(headers.clone())
750            .timeout(self.page_request_timeout);
751        if let Some(node_id) = node_id {
752            builder = builder.header(HEADER_STICKY_NODE, node_id)
753        }
754        let request = builder.build()?;
755
756        let response = self
757            .query_request_helper(request, false, true, true, RequestKind::QueryPage)
758            .await?;
759        self.handle_page(response, false).await
760    }
761
762    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
763        self.end_query(query_id, true, None).await
764    }
765
766    pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
767        self.end_query(query_id, false, node_id).await
768    }
769
770    pub async fn end_query(
771        &self,
772        query_id: &str,
773        is_kill: bool,
774        node_id: Option<&str>,
775    ) -> Result<()> {
776        let (method, request_kind) = if is_kill {
777            ("kill", RequestKind::QueryKill)
778        } else {
779            ("final", RequestKind::QueryFinal)
780        };
781        let uri = format!("/v1/query/{query_id}/{method}");
782        let endpoint = self.endpoint.join(&uri)?;
783        let headers = self.make_headers(Some(query_id))?;
784
785        info!("{method} query: {uri}");
786
787        let mut builder = self.cli.post(endpoint.clone());
788        if let Some(node_id) = node_id {
789            builder = builder.header(HEADER_STICKY_NODE, node_id)
790        }
791        builder = self.wrap_auth_or_session_token(builder)?;
792        let resp = builder.headers(headers.clone()).send().await?;
793        if resp.status() != 200 {
794            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
795                .with_context(request_kind)
796                .with_query_id(query_id));
797        }
798        Ok(())
799    }
800
801    pub async fn query_all(
802        self: &Arc<Self>,
803        sql: &str,
804        params: Option<serde_json::Value>,
805    ) -> Result<Page> {
806        self.query_all_inner(sql, false, params).await
807    }
808
809    pub async fn query_all_inner(
810        self: &Arc<Self>,
811        sql: &str,
812        force_json_body: bool,
813        params: Option<serde_json::Value>,
814    ) -> Result<Page> {
815        let (resp, batches) = self
816            .start_query_inner(sql, None, force_json_body, params)
817            .await?;
818        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
819        let mut all = Page::default();
820        while let Some(page) = pages.next().await {
821            all.update(page?);
822        }
823        Ok(all)
824    }
825
826    fn session_state(&self) -> SessionState {
827        self.session_state.lock().clone()
828    }
829
830    fn make_pagination(&self) -> Option<PaginationConfig> {
831        if self.wait_time_secs.is_none()
832            && self.max_rows_in_buffer.is_none()
833            && self.max_rows_per_page.is_none()
834        {
835            return None;
836        }
837        let mut pagination = PaginationConfig {
838            wait_time_secs: None,
839            max_rows_in_buffer: None,
840            max_rows_per_page: None,
841        };
842        if let Some(wait_time_secs) = self.wait_time_secs {
843            pagination.wait_time_secs = Some(wait_time_secs);
844        }
845        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
846            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
847        }
848        if let Some(max_rows_per_page) = self.max_rows_per_page {
849            pagination.max_rows_per_page = Some(max_rows_per_page);
850        }
851        Some(pagination)
852    }
853
854    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
855        let mut headers = HeaderMap::new();
856        if let Some(tenant) = &self.tenant {
857            headers.insert(HEADER_TENANT, tenant.parse()?);
858        }
859        let warehouse = self.warehouse.lock().clone();
860        if let Some(warehouse) = warehouse {
861            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
862        }
863        let route_hint = self.route_hint.current();
864        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
865        if let Some(query_id) = query_id {
866            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
867        }
868        Ok(headers)
869    }
870
871    pub async fn insert_with_stage(
872        self: &Arc<Self>,
873        sql: &str,
874        stage: &str,
875        file_format_options: BTreeMap<&str, &str>,
876        copy_options: BTreeMap<&str, &str>,
877    ) -> Result<QueryStats> {
878        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
879        let stage_attachment = Some(StageAttachmentConfig {
880            location: stage,
881            file_format_options: Some(file_format_options),
882            copy_options: Some(copy_options),
883        });
884        let (resp, batches) = self
885            .start_query_inner(sql, stage_attachment, true, None)
886            .await?;
887        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
888        let mut all = Page::default();
889        while let Some(page) = pages.next().await {
890            all.update(page?);
891        }
892        Ok(all.stats)
893    }
894
895    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
896        info!("get presigned upload url: {stage}");
897        let sql = format!("PRESIGN UPLOAD {stage}");
898        let resp = self.query_all_inner(&sql, true, None).await?;
899        if resp.data.len() != 1 {
900            return Err(Error::Decode(
901                "Empty response from server for presigned request".to_string(),
902            ));
903        }
904        if resp.data[0].len() != 3 {
905            return Err(Error::Decode(
906                "Invalid response from server for presigned request".to_string(),
907            ));
908        }
909        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
910        let method = resp.data[0][0].clone().unwrap_or_default();
911        if method != "PUT" {
912            return Err(Error::Decode(format!(
913                "Invalid method for presigned upload request: {method}"
914            )));
915        }
916        let headers: BTreeMap<String, String> =
917            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
918        let url = resp.data[0][2].clone().unwrap_or_default();
919        Ok(PresignedResponse {
920            method,
921            headers,
922            url,
923        })
924    }
925
926    pub async fn upload_to_stage(
927        self: &Arc<Self>,
928        stage: &str,
929        data: Reader,
930        size: u64,
931    ) -> Result<()> {
932        match self.get_presign_mode() {
933            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
934            PresignMode::On => {
935                let presigned = self.get_presigned_upload_url(stage).await?;
936                presign_upload_to_stage(presigned, data, size).await
937            }
938            PresignMode::Auto => {
939                unreachable!("PresignMode::Auto should be handled during client initialization")
940            }
941            PresignMode::Detect => {
942                unreachable!("PresignMode::Detect should be handled during client initialization")
943            }
944        }
945    }
946
947    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
948    async fn upload_to_stage_with_stream(
949        &self,
950        stage: &str,
951        data: Reader,
952        size: u64,
953    ) -> Result<()> {
954        info!("upload to stage with stream: {stage}, size: {size}");
955        if let Some(info) = self.need_pre_refresh_session().await {
956            self.refresh_session_token(info).await?;
957        }
958        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
959        let location = StageLocation::try_from(stage)?;
960        let query_id = self.gen_query_id();
961        let mut headers = self.make_headers(Some(&query_id))?;
962        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
963        let stream = Body::wrap_stream(ReaderStream::new(data));
964        let part = Part::stream_with_length(stream, size).file_name(location.path);
965        let form = Form::new().part("upload", part);
966        let mut builder = self.cli.put(endpoint.clone());
967        builder = self.wrap_auth_or_session_token(builder)?;
968        let resp = builder.headers(headers).multipart(form).send().await?;
969        let status = resp.status();
970        if status != 200 {
971            return Err(Error::response_error(status, &resp.bytes().await?)
972                .with_context(RequestKind::UploadToStage)
973                .with_query_id(query_id));
974        }
975        Ok(())
976    }
977
978    // use base64 encode whenever possible for safety
979    // but also accept raw JSON for test/debug/one-shot operations
980    pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
981    where
982        T: de::DeserializeOwned,
983    {
984        if value.starts_with("{") {
985            serde_json::from_slice(value.as_bytes())
986                .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
987        } else {
988            let json = URL_SAFE.decode(value).map_err(|e| {
989                format!(
990                    "Invalid value {} for {key}, base64 decode error: {}",
991                    value, e
992                )
993            })?;
994            serde_json::from_slice(&json).map_err(|e| {
995                format!(
996                    "Invalid value {value} for {key}, JSON value {},  decode error: {e}",
997                    String::from_utf8_lossy(&json)
998                )
999            })
1000        }
1001    }
1002
1003    pub async fn streaming_load(
1004        &self,
1005        sql: &str,
1006        data: Reader,
1007        file_name: &str,
1008    ) -> Result<LoadResponse> {
1009        let body = Body::wrap_stream(ReaderStream::new(data));
1010        let part = Part::stream(body).file_name(file_name.to_string());
1011        let endpoint = self.endpoint.join("v1/streaming_load")?;
1012        let mut builder = self.cli.put(endpoint.clone());
1013        builder = self.wrap_auth_or_session_token(builder)?;
1014        let query_id = self.gen_query_id();
1015        let mut headers = self.make_headers(Some(&query_id))?;
1016        headers.insert(HEADER_SQL, sql.parse()?);
1017        let session = serde_json::to_string(&*self.session_state.lock())
1018            .expect("serialize session state should not fail");
1019        headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
1020        let form = Form::new().part("upload", part);
1021        let resp = builder.headers(headers).multipart(form).send().await?;
1022        let status = resp.status();
1023        if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
1024            match Self::decode_json_header::<SessionState>(
1025                HEADER_QUERY_CONTEXT,
1026                value.to_str().unwrap(),
1027            ) {
1028                Ok(session) => *self.session_state.lock() = session,
1029                Err(e) => {
1030                    error!("Error decoding session state when streaming load: {e}");
1031                }
1032            }
1033        };
1034        if status != 200 {
1035            return Err(Error::response_error(status, &resp.bytes().await?)
1036                .with_context(RequestKind::StreamingLoad)
1037                .with_query_id(query_id));
1038        }
1039        let resp = resp.json::<LoadResponse>().await?;
1040        Ok(resp)
1041    }
1042
1043    async fn login(&mut self) -> Result<()> {
1044        let endpoint = self.endpoint.join("/v1/session/login")?;
1045        let headers = self.make_headers(None)?;
1046        let body = LoginRequest::from(&*self.session_state.lock());
1047        let mut builder = self.cli.post(endpoint.clone()).json(&body);
1048        if self.disable_session_token {
1049            builder = builder.query(&[("disable_session_token", true)]);
1050        }
1051        let builder = self.auth.wrap(builder)?;
1052        let request = builder
1053            .headers(headers.clone())
1054            .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1055            .build()?;
1056        let response = self
1057            .query_request_helper(request, true, false, true, RequestKind::Login)
1058            .await?;
1059        if response.status == StatusCode::NOT_FOUND {
1060            info!("login return 404, skip login on the old version server");
1061            return Ok(());
1062        }
1063        if response.status != StatusCode::OK {
1064            return Err(Self::status_body_to_error(response.status, &response.body)
1065                .with_context(RequestKind::Login));
1066        }
1067        if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
1068            if let Ok(s) = v.to_str() {
1069                self.session_id = s.to_string();
1070            }
1071        }
1072
1073        let response = json_from_slice(&response.body)?;
1074        match response {
1075            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
1076            LoginResponseResult::Ok(info) => {
1077                let server_version = info
1078                    .version
1079                    .parse()
1080                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
1081                self.capability = Capability::from_server_version(&server_version);
1082                self.server_version = Some(server_version.clone());
1083                let session_id = self.session_id.as_str();
1084                if let Some(tokens) = info.tokens {
1085                    info!(
1086                        "[session {session_id}] login success with session token version = {server_version}",
1087                    );
1088                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
1089                } else {
1090                    info!("[session {session_id}] login success, version = {server_version}");
1091                }
1092            }
1093        }
1094        Ok(())
1095    }
1096
1097    pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1098        let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1099        let queries = self.queries_need_heartbeat.lock().clone();
1100        let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1101        let now = Instant::now();
1102
1103        let mut query_ids = Vec::new();
1104        for (qid, state) in queries {
1105            if state.need_heartbeat(now) {
1106                query_ids.push(qid.to_string());
1107                if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1108                    arr.push(qid);
1109                } else {
1110                    node_to_queries.insert(state.node_id, vec![qid]);
1111                }
1112            }
1113        }
1114
1115        if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1116        {
1117            return Ok(());
1118        }
1119
1120        let body = json!({
1121           "node_to_queries": node_to_queries
1122        });
1123        let headers = self.make_headers(None)?;
1124        let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1125        let request = self.wrap_auth_or_session_token(builder)?.build()?;
1126        let response = self
1127            .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1128            .await?;
1129        if response.status != StatusCode::OK {
1130            return Err(Self::status_body_to_error(response.status, &response.body)
1131                .with_context(RequestKind::Heartbeat));
1132        }
1133        let json: Value = json_from_slice(&response.body)?;
1134        let session_id = self.session_id.as_str();
1135        info!("[session {session_id}] heartbeat request={body}, response={json}");
1136        if let Some(queries_to_remove) = json.get("queries_to_remove") {
1137            if let Some(arr) = queries_to_remove.as_array() {
1138                if !arr.is_empty() {
1139                    let mut queries = self.queries_need_heartbeat.lock();
1140                    for q in arr {
1141                        if let Some(q) = q.as_str() {
1142                            queries.remove(q);
1143                        }
1144                    }
1145                }
1146            }
1147        }
1148        let now = Instant::now();
1149        let mut queries = self.queries_need_heartbeat.lock();
1150        for qid in query_ids {
1151            if let Some(state) = queries.get_mut(&qid) {
1152                *state.last_access_time.lock() = now;
1153            }
1154        }
1155        Ok(())
1156    }
1157
1158    fn build_log_out_request(&self) -> Result<Request> {
1159        let endpoint = self.endpoint.join("/v1/session/logout")?;
1160
1161        let session_state = self.session_state();
1162        let need_sticky = session_state.need_sticky.unwrap_or(false);
1163        let mut headers = self.make_headers(None)?;
1164        if need_sticky {
1165            if let Some(node_id) = self.last_node_id() {
1166                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1167            }
1168        }
1169        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1170
1171        let builder = self.wrap_auth_or_session_token(builder)?;
1172        let req = builder.build()?;
1173        Ok(req)
1174    }
1175
1176    pub(crate) fn need_logout(&self) -> bool {
1177        self.session_token_info.is_some()
1178            || self.session_state.lock().need_keep_alive.unwrap_or(false)
1179    }
1180
1181    async fn refresh_session_token(
1182        &self,
1183        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1184    ) -> Result<()> {
1185        let (session_token_info, _) = { self_login_info.lock().clone() };
1186        let endpoint = self.endpoint.join("/v1/session/refresh")?;
1187        let body = RefreshSessionTokenRequest {
1188            session_token: session_token_info.session_token.clone(),
1189        };
1190        let headers = self.make_headers(None)?;
1191        let request = self
1192            .cli
1193            .post(endpoint.clone())
1194            .json(&body)
1195            .headers(headers.clone())
1196            .bearer_auth(session_token_info.refresh_token.clone())
1197            .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1198            .build()?;
1199        let response = self
1200            .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1201            .await?;
1202        if response.status != StatusCode::OK {
1203            return Err(Self::status_body_to_error(response.status, &response.body)
1204                .with_context(RequestKind::SessionRefresh));
1205        }
1206        let response = json_from_slice(&response.body)?;
1207        match response {
1208            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1209            RefreshResponse::Ok(info) => {
1210                *self_login_info.lock() = (info, Instant::now());
1211                Ok(())
1212            }
1213        }
1214    }
1215
1216    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1217        if let Some(info) = &self.session_token_info {
1218            let (start, ttl) = {
1219                let guard = info.lock();
1220                (guard.1, guard.0.session_token_ttl_in_secs)
1221            };
1222            if Instant::now() > start + Duration::from_secs(ttl) {
1223                return Some(info.clone());
1224            }
1225        }
1226        None
1227    }
1228
1229    /// Send request with retry and return raw HTTP status/headers/body.
1230    ///
1231    /// retry on
1232    ///   - network/request/body transient errors
1233    ///   - (optional) 503
1234    ///
1235    /// refresh lake token or reload jwt token if needed.
1236    async fn query_request_helper(
1237        &self,
1238        mut request: Request,
1239        retry_if_503: bool,
1240        refresh_if_401: bool,
1241        reload_auth_if_401: bool,
1242        request_kind: RequestKind,
1243    ) -> Result<HttpResponseData> {
1244        let mut refreshed = false;
1245        let mut retries = 0;
1246        let max_retries = self.retry_count;
1247        let retry_delay = Duration::from_secs(self.retry_delay_secs);
1248
1249        loop {
1250            let req = request.try_clone().expect("request not cloneable");
1251            let response = match self.cli.execute(req).await {
1252                Ok(response) => response,
1253                Err(err) => {
1254                    if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1255                        if retries >= max_retries.saturating_sub(1) {
1256                            return Err(Self::with_request_context(
1257                                err,
1258                                &request_kind,
1259                                &request,
1260                                Some(retries),
1261                            ));
1262                        }
1263                        retries += 1;
1264                        warn!(
1265                            "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1266                            retries,
1267                            max_retries,
1268                            request.url(),
1269                            reason,
1270                            err,
1271                            retry_delay.as_secs(),
1272                        );
1273                        sleep(jitter(retry_delay)).await;
1274                        continue;
1275                    }
1276                    return Err(Self::with_request_context(
1277                        err,
1278                        &request_kind,
1279                        &request,
1280                        Some(retries),
1281                    ));
1282                }
1283            };
1284
1285            let status = response.status();
1286            let headers = response.headers().clone();
1287            let body = match response.bytes().await {
1288                Ok(body) => body,
1289                Err(err) => {
1290                    if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1291                        if retries >= max_retries.saturating_sub(1) {
1292                            return Err(Self::with_request_context(
1293                                err,
1294                                &request_kind,
1295                                &request,
1296                                Some(retries),
1297                            ));
1298                        }
1299                        retries += 1;
1300                        warn!(
1301                            "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1302                            retries,
1303                            max_retries,
1304                            request.url(),
1305                            reason,
1306                            err,
1307                            retry_delay.as_secs(),
1308                        );
1309                        sleep(jitter(retry_delay)).await;
1310                        continue;
1311                    }
1312                    return Err(Self::with_request_context(
1313                        err,
1314                        &request_kind,
1315                        &request,
1316                        Some(retries),
1317                    ));
1318                }
1319            };
1320
1321            if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1322                if retries >= max_retries.saturating_sub(1) {
1323                    return Ok(HttpResponseData {
1324                        status,
1325                        headers,
1326                        body,
1327                    });
1328                }
1329                retries += 1;
1330                warn!(
1331                    "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
1332                    retries,
1333                    max_retries,
1334                    request.url(),
1335                    retry_delay.as_secs()
1336                );
1337                sleep(jitter(retry_delay)).await;
1338                continue;
1339            }
1340
1341            if status == StatusCode::UNAUTHORIZED {
1342                request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1343                let unauthorized_error =
1344                    serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
1345                if let Some(session_token_info) = &self.session_token_info {
1346                    let should_refresh = unauthorized_error
1347                        .as_ref()
1348                        .map(|r| need_refresh_token(r.error.code))
1349                        .unwrap_or(false);
1350                    if refresh_if_401 && should_refresh && !refreshed {
1351                        Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
1352                        refreshed = true;
1353                        retries = 0;
1354                        warn!(
1355                            "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
1356                            request.url(),
1357                            retry_delay.as_secs()
1358                        );
1359                        sleep(jitter(retry_delay)).await;
1360                        continue;
1361                    }
1362                }
1363                if reload_auth_if_401 && self.auth.can_reload() {
1364                    if retries >= max_retries.saturating_sub(1) {
1365                        return Ok(HttpResponseData {
1366                            status,
1367                            headers,
1368                            body,
1369                        });
1370                    }
1371                    retries += 1;
1372                    let builder =
1373                        RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
1374                    let builder = self.auth.wrap(builder)?;
1375                    request = builder.build()?;
1376                    warn!(
1377                        "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
1378                        retries,
1379                        max_retries,
1380                        request.url(),
1381                        retry_delay.as_secs()
1382                    );
1383                    sleep(jitter(retry_delay)).await;
1384                    continue;
1385                }
1386            }
1387
1388            return Ok(HttpResponseData {
1389                status,
1390                headers,
1391                body,
1392            });
1393        }
1394    }
1395
1396    pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1397        if let Err(err) = http_client.execute(request).await {
1398            error!("[session {session_id}] logout request failed: {err}");
1399        } else {
1400            info!("[session {session_id}] logout success");
1401        };
1402    }
1403
1404    pub async fn close(&self) {
1405        let session_id = &self.session_id;
1406        info!("[session {session_id}] try closing now");
1407        if self
1408            .closed
1409            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1410            .is_ok()
1411        {
1412            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1413            if self.need_logout() {
1414                let cli = self.cli.clone();
1415                let req = self
1416                    .build_log_out_request()
1417                    .expect("failed to build logout request");
1418                Self::logout(cli, req, &self.session_id).await;
1419            }
1420        }
1421    }
1422    pub fn close_with_spawn(&self) {
1423        let session_id = &self.session_id;
1424        info!("[session {session_id}]: try closing with spawn");
1425        if self
1426            .closed
1427            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1428            .is_ok()
1429        {
1430            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1431            if self.need_logout() {
1432                let cli = self.cli.clone();
1433                let req = self
1434                    .build_log_out_request()
1435                    .expect("failed to build logout request");
1436                let session_id = self.session_id.clone();
1437                GLOBAL_RUNTIME.spawn(async move {
1438                    Self::logout(cli, req, session_id.as_str()).await;
1439                });
1440            }
1441        }
1442    }
1443
1444    pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1445        let mut queries = self.queries_need_heartbeat.lock();
1446        queries.insert(query_id.to_string(), state);
1447    }
1448}
1449
1450fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1451where
1452    T: Deserialize<'a>,
1453{
1454    serde_json::from_slice::<T>(body).map_err(|e| {
1455        Error::Decode(format!(
1456            "fail to decode JSON response: {e}, body: {}",
1457            String::from_utf8_lossy(body)
1458        ))
1459    })
1460}
1461
1462impl Default for APIClient {
1463    fn default() -> Self {
1464        Self {
1465            session_id: Default::default(),
1466            cli: HttpClient::new(),
1467            scheme: "http".to_string(),
1468            endpoint: Url::parse("http://localhost:8080").unwrap(),
1469            host: "localhost".to_string(),
1470            port: 8000,
1471            tenant: None,
1472            warehouse: Mutex::new(None),
1473            auth: Arc::new(BasicAuth::new(DEFAULT_USERNAME, "")) as Arc<dyn Auth>,
1474            session_state: Mutex::new(SessionState::default()),
1475            wait_time_secs: None,
1476            max_rows_in_buffer: None,
1477            max_rows_per_page: None,
1478            connect_timeout: Duration::from_secs(10),
1479            page_request_timeout: Duration::from_secs(300),
1480            tls_ca_file: None,
1481            presign: Mutex::new(PresignMode::Auto),
1482            route_hint: RouteHintGenerator::new(),
1483            last_node_id: Default::default(),
1484            disable_session_token: true,
1485            disable_login: false,
1486            query_result_format: "json".to_string(),
1487            session_token_info: None,
1488            closed: AtomicBool::new(false),
1489            last_query_id: Default::default(),
1490            server_version: None,
1491            capability: Default::default(),
1492            queries_need_heartbeat: Default::default(),
1493            retry_count: 3,
1494            retry_delay_secs: 10,
1495        }
1496    }
1497}
1498
1499struct RouteHintGenerator {
1500    nonce: AtomicU64,
1501    current: std::sync::Mutex<String>,
1502}
1503
1504impl RouteHintGenerator {
1505    fn new() -> Self {
1506        let gen = Self {
1507            nonce: AtomicU64::new(0),
1508            current: std::sync::Mutex::new("".to_string()),
1509        };
1510        gen.next();
1511        gen
1512    }
1513
1514    fn current(&self) -> String {
1515        let guard = self.current.lock().unwrap();
1516        guard.clone()
1517    }
1518
1519    fn set(&self, hint: &str) {
1520        let mut guard = self.current.lock().unwrap();
1521        *guard = hint.to_string();
1522    }
1523
1524    fn next(&self) -> String {
1525        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1526        let uuid = uuid::Uuid::new_v4();
1527        let current = format!("rh:{uuid}:{nonce:06}");
1528        let mut guard = self.current.lock().unwrap();
1529        guard.clone_from(&current);
1530        current
1531    }
1532}
1533
1534#[cfg(test)]
1535mod test {
1536    use super::*;
1537    use std::io::Write;
1538    use tempfile::NamedTempFile;
1539
1540    #[tokio::test]
1541    async fn parse_dsn() -> Result<()> {
1542        let dsn = "lake://username:password@app.lake.tidbcloud.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
1543        let client = APIClient::from_dsn(dsn).await?;
1544        assert_eq!(client.host, "app.lake.tidbcloud.com");
1545        assert_eq!(
1546            client.endpoint,
1547            Url::parse("http://app.lake.tidbcloud.com:80")?
1548        );
1549        assert_eq!(client.wait_time_secs, Some(10));
1550        assert_eq!(client.max_rows_in_buffer, Some(5000000));
1551        assert_eq!(client.max_rows_per_page, Some(10000));
1552        assert_eq!(client.tenant, None);
1553        assert_eq!(
1554            *client.warehouse.try_lock().unwrap(),
1555            Some("wh".to_string())
1556        );
1557        Ok(())
1558    }
1559
1560    #[tokio::test]
1561    async fn parse_dsn_with_private_key_uses_keypair_auth() -> Result<()> {
1562        let output = std::process::Command::new("openssl")
1563            .args([
1564                "genpkey",
1565                "-algorithm",
1566                "RSA",
1567                "-pkeyopt",
1568                "rsa_keygen_bits:2048",
1569            ])
1570            .output();
1571        let output = match output {
1572            Ok(o) if o.status.success() => o,
1573            _ => return Ok(()),
1574        };
1575
1576        let mut key_file = NamedTempFile::new()?;
1577        key_file.write_all(&output.stdout)?;
1578
1579        let dsn = format!(
1580            "lake://username:password@app.tidbcloud.com/test?private_key_file={}&sslmode=disable",
1581            url::form_urlencoded::byte_serialize(key_file.path().to_string_lossy().as_bytes())
1582                .collect::<String>()
1583        );
1584        let client = APIClient::from_dsn(&dsn).await?;
1585        assert_eq!(client.auth.username(), "username");
1586        assert!(client.auth.can_reload());
1587
1588        let request = client
1589            .auth
1590            .wrap(client.cli.get(client.endpoint.join("v1/query")?))?
1591            .build()?;
1592        assert_eq!(
1593            request
1594                .headers()
1595                .get("X-DATABEND-AUTH-METHOD")
1596                .and_then(|value| value.to_str().ok()),
1597            Some("keypair")
1598        );
1599        let authorization = request
1600            .headers()
1601            .get(reqwest::header::AUTHORIZATION)
1602            .and_then(|value| value.to_str().ok())
1603            .unwrap_or_default();
1604        assert!(authorization.starts_with("Bearer "));
1605        assert_eq!(authorization["Bearer ".len()..].split('.').count(), 3);
1606        Ok(())
1607    }
1608
1609    #[tokio::test]
1610    async fn parse_dsn_with_private_key_requires_user() -> Result<()> {
1611        let output = std::process::Command::new("openssl")
1612            .args([
1613                "genpkey",
1614                "-algorithm",
1615                "RSA",
1616                "-pkeyopt",
1617                "rsa_keygen_bits:2048",
1618            ])
1619            .output();
1620        let output = match output {
1621            Ok(o) if o.status.success() => o,
1622            _ => return Ok(()),
1623        };
1624
1625        let mut key_file = NamedTempFile::new()?;
1626        key_file.write_all(&output.stdout)?;
1627
1628        let dsn = format!(
1629            "lake://app.tidbcloud.com/test?private_key_file={}&sslmode=disable",
1630            url::form_urlencoded::byte_serialize(key_file.path().to_string_lossy().as_bytes())
1631                .collect::<String>()
1632        );
1633        let err = APIClient::from_dsn(&dsn)
1634            .await
1635            .err()
1636            .expect("key-pair authentication should require an explicit username");
1637        assert!(
1638            err.to_string()
1639                .contains("username is required for key-pair authentication"),
1640            "unexpected error: {err}"
1641        );
1642
1643        Ok(())
1644    }
1645
1646    #[tokio::test]
1647    async fn parse_encoded_password() -> Result<()> {
1648        let dsn = "lake://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
1649        let client = APIClient::from_dsn(dsn).await?;
1650        assert_eq!(client.host(), "localhost");
1651        assert_eq!(client.port(), 443);
1652        Ok(())
1653    }
1654
1655    #[tokio::test]
1656    async fn parse_special_chars_password() -> Result<()> {
1657        let dsn = "lake://username:3a@SC(nYE1k={{R@localhost:8000";
1658        let client = APIClient::from_dsn(dsn).await?;
1659        assert_eq!(client.host(), "localhost");
1660        assert_eq!(client.port(), 8000);
1661        Ok(())
1662    }
1663}