databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22    SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28    error::{Error, Result},
29    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30    response::QueryResponse,
31    session::SessionState,
32    QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::{de, Deserialize};
49use serde_json::{json, Value};
50use std::collections::{BTreeMap, HashMap};
51use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tokio::time::sleep;
55use tokio_retry::strategy::jitter;
56use tokio_stream::StreamExt;
57use tokio_util::io::ReaderStream;
58use url::Url;
59
// Names of the custom HTTP headers exchanged with the query service.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
// Value of the session `txn_state` (compared ASCII case-insensitively) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
// Response `Content-Type` that identifies an Arrow IPC stream payload.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Sent as the `Accept` header when arrow results are requested.
// NOTE(review): despite the name, this is currently identical to `CONTENT_TYPE_ARROW` — confirm intended.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
72
73static VERSION: Lazy<String> = Lazy::new(|| {
74    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
75    version.to_string()
76});
77
/// Per-query bookkeeping used to decide when a running query needs a heartbeat.
#[derive(Clone)]
pub(crate) struct QueryState {
    /// Node identifier associated with the query.
    pub node_id: String,
    /// Instant of the most recent access; shared so clones observe updates.
    pub last_access_time: Arc<Mutex<Instant>>,
    /// Timeout in seconds; `need_heartbeat` triggers after half of it has elapsed.
    pub timeout_secs: u64,
}
84
85impl QueryState {
86    pub fn need_heartbeat(&self, now: Instant) -> bool {
87        let t = *self.last_access_time.lock();
88        now.duration_since(t).as_secs() > self.timeout_secs / 2
89    }
90}
91
/// HTTP client for the query service, configured from a DSN.
pub struct APIClient {
    /// Session identifier; `new` falls back to `no_login_<uuid>` when it is still empty.
    pub(crate) session_id: String,
    /// Underlying HTTP client, built in `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, chosen by the DSN `sslmode` parameter.
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on the scheme.
    port: u16,

    /// Base URL `scheme://host:port` that request paths are joined onto.
    endpoint: Url,

    /// Authentication applied to requests when no session token is in use.
    auth: Arc<dyn Auth>,

    /// Optional tenant, sent in the tenant header.
    tenant: Option<String>,
    /// Current warehouse, sent in the warehouse header; updated from server session settings.
    warehouse: Mutex<Option<String>>,
    /// Latest session state; replaced by the state returned in query responses.
    session_state: Mutex<SessionState>,
    /// Route hint sent with every request.
    route_hint: RouteHintGenerator,

    /// Set by the DSN parameter `login=disable`; skips `login()` in `new`.
    disable_login: bool,
    /// Result format from the DSN (`"json"` or `"arrow"`).
    query_result_format: String,
    /// Set by the DSN parameter `session_token=disable`.
    disable_session_token: bool,
    /// When present, its session token is used as bearer auth instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // NOTE(review): presumably marks the client as closed; its use is not visible in this chunk.
    closed: AtomicBool,

    // NOTE(review): presumably filled in from the login response — confirm.
    server_version: Option<Version>,

    // Pagination settings from the DSN; forwarded through `make_pagination`.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Parsed from the DSN parameter `connect_timeout` (seconds).
    connect_timeout: Duration,
    /// Timeout applied to each page-fetch request.
    page_request_timeout: Duration,

    /// Path of a PEM CA certificate added as a root certificate for https.
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the first page of the most recent query.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recent query, recorded from its first page.
    last_query_id: Mutex<Option<String>>,

    /// Server capabilities (e.g. whether arrow result data is supported).
    capability: Capability,

    /// Queries tracked for heartbeats, keyed by query id.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,
}
134
135impl Drop for APIClient {
136    fn drop(&mut self) {
137        self.close_with_spawn()
138    }
139}
140
141impl APIClient {
142    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
143        let mut client = Self::from_dsn(dsn).await?;
144        client.build_client(name).await?;
145        if !client.disable_login {
146            client.login().await?;
147        }
148        if client.session_id.is_empty() {
149            client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
150        }
151        let client = Arc::new(client);
152        client.check_presign().await?;
153        GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
154        Ok(client)
155    }
156
    /// Returns the server capabilities known to this client.
    pub fn capability(&self) -> &Capability {
        &self.capability
    }
160
161    fn set_presign_mode(&self, mode: PresignMode) {
162        *self.presign.lock() = mode
163    }
164    fn get_presign_mode(&self) -> PresignMode {
165        *self.presign.lock()
166    }
167
    /// Parses `dsn` as a URL and returns a client populated from it.
    ///
    /// The HTTP client itself is not built here (see `build_client`).
    ///
    /// - host, port and user/password come from the URL authority; the password
    ///   is percent-decoded;
    /// - the URL path (without the leading `/`) becomes the session database;
    /// - recognised query parameters configure the client; any other parameter
    ///   is stored via `SessionState::set`.
    ///
    /// The scheme defaults to `https`; `sslmode=disable` switches to `http`.
    /// Without an explicit port, 80 (http) or 443 (https) is used.
    ///
    /// # Errors
    /// Returns an error when the DSN is not a valid URL, a numeric parameter
    /// fails to parse, or an enumerated parameter has an unsupported value.
    async fn from_dsn(dsn: &str) -> Result<Self> {
        let u = Url::parse(dsn)?;
        let mut client = Self::default();
        if let Some(host) = u.host_str() {
            client.host = host.to_string();
        }

        // Basic auth from the URL userinfo; may be replaced below by
        // `access_token` / `access_token_file` query parameters.
        if u.username() != "" {
            let password = u.password().unwrap_or_default();
            let password = percent_decode_str(password).decode_utf8()?;
            client.auth = Arc::new(BasicAuth::new(u.username(), password));
        }

        let mut session_state = SessionState::default();

        let database = u.path().trim_start_matches('/');
        if !database.is_empty() {
            session_state.set_database(database);
        }

        let mut scheme = "https";
        // Parameters are applied in order, so a later occurrence overrides an earlier one.
        for (k, v) in u.query_pairs() {
            match k.as_ref() {
                "wait_time_secs" => {
                    client.wait_time_secs = Some(v.parse()?);
                }
                "max_rows_in_buffer" => {
                    client.max_rows_in_buffer = Some(v.parse()?);
                }
                "max_rows_per_page" => {
                    client.max_rows_per_page = Some(v.parse()?);
                }
                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
                "page_request_timeout_secs" => {
                    client.page_request_timeout = {
                        let secs: u64 = v.parse()?;
                        Duration::from_secs(secs)
                    };
                }
                "presign" => {
                    let presign_mode = match v.as_ref() {
                        "auto" => PresignMode::Auto,
                        "detect" => PresignMode::Detect,
                        "on" => PresignMode::On,
                        "off" => PresignMode::Off,
                        _ => {
                            return Err(Error::BadArgument(format!(
                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
                        )))
                        }
                    };
                    client.set_presign_mode(presign_mode);
                }
                "tenant" => {
                    client.tenant = Some(v.to_string());
                }
                "warehouse" => {
                    client.warehouse = Mutex::new(Some(v.to_string()));
                }
                "role" => session_state.set_role(v),
                "sslmode" => match v.as_ref() {
                    "disable" => scheme = "http",
                    "require" | "enable" => scheme = "https",
                    _ => {
                        return Err(Error::BadArgument(format!(
                            "Invalid value for sslmode: {v}"
                        )))
                    }
                },
                "tls_ca_file" => {
                    client.tls_ca_file = Some(v.to_string());
                }
                "access_token" => {
                    client.auth = Arc::new(AccessTokenAuth::new(v));
                }
                "access_token_file" => {
                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
                }
                "login" => {
                    client.disable_login = match v.as_ref() {
                        "disable" => true,
                        "enable" => false,
                        _ => {
                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
                        }
                    }
                }
                "session_token" => {
                    client.disable_session_token = match v.as_ref() {
                        "disable" => true,
                        "enable" => false,
                        _ => {
                            return Err(Error::BadArgument(format!(
                                "Invalid value for session_token: {v}"
                            )))
                        }
                    }
                }
                "body_format" | "query_result_format" => {
                    // The format name is matched case-insensitively and stored lowercased.
                    let v = v.to_string().to_lowercase();
                    match v.as_str() {
                        "json" | "arrow" => client.query_result_format = v.to_string(),
                        _ => {
                            return Err(Error::BadArgument(format!(
                                "Invalid value for query_result_format: {v}"
                            )))
                        }
                    }
                }
                // Anything unrecognised is forwarded as a session setting.
                _ => {
                    session_state.set(k, v);
                }
            }
        }
        client.port = match u.port() {
            Some(p) => p,
            None => match scheme {
                "http" => 80,
                "https" => 443,
                // `scheme` is only ever assigned "http" or "https" above.
                _ => unreachable!(),
            },
        };
        client.scheme = scheme.to_string();
        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
        client.session_state = Mutex::new(session_state);

        Ok(client)
    }
296
297    pub fn host(&self) -> &str {
298        self.host.as_str()
299    }
300
    /// Returns the port this client connects to.
    pub fn port(&self) -> u16 {
        self.port
    }
304
305    pub fn scheme(&self) -> &str {
306        self.scheme.as_str()
307    }
308
309    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
310        let ua = match name {
311            Some(n) => n,
312            None => format!("databend-client-rust/{}", VERSION.as_str()),
313        };
314        let cookie_provider = GlobalCookieStore::new();
315        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
316        let mut initial_cookies = [&cookie].into_iter();
317        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
318        let mut cli_builder = HttpClient::builder()
319            .user_agent(ua)
320            .cookie_provider(Arc::new(cookie_provider))
321            .pool_idle_timeout(Duration::from_secs(1));
322        #[cfg(any(feature = "rustls", feature = "native-tls"))]
323        if self.scheme == "https" {
324            if let Some(ref ca_file) = self.tls_ca_file {
325                let cert_pem = tokio::fs::read(ca_file).await?;
326                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
327                cli_builder = cli_builder.add_root_certificate(cert);
328            }
329        }
330        self.cli = cli_builder.build()?;
331        Ok(())
332    }
333
334    async fn check_presign(self: &Arc<Self>) -> Result<()> {
335        let mode = match self.get_presign_mode() {
336            PresignMode::Auto => {
337                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
338                    PresignMode::On
339                } else {
340                    PresignMode::Off
341                }
342            }
343            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
344                Ok(_) => PresignMode::On,
345                Err(e) => {
346                    warn!("presign mode off with error detected: {e}");
347                    PresignMode::Off
348                }
349            },
350            mode => mode,
351        };
352        self.set_presign_mode(mode);
353        Ok(())
354    }
355
356    pub fn current_warehouse(&self) -> Option<String> {
357        let guard = self.warehouse.lock();
358        guard.clone()
359    }
360
361    pub fn current_catalog(&self) -> Option<String> {
362        let guard = self.session_state.lock();
363        guard.catalog.clone()
364    }
365
366    pub fn current_database(&self) -> Option<String> {
367        let guard = self.session_state.lock();
368        guard.database.clone()
369    }
370
371    pub fn set_warehouse(&self, warehouse: impl Into<String>) {
372        let mut guard = self.warehouse.lock();
373        *guard = Some(warehouse.into());
374    }
375
376    pub fn set_database(&self, database: impl Into<String>) {
377        let mut guard = self.session_state.lock();
378        guard.set_database(database);
379    }
380
381    pub fn set_role(&self, role: impl Into<String>) {
382        let mut guard = self.session_state.lock();
383        guard.set_role(role);
384    }
385
386    pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
387        let mut guard = self.session_state.lock();
388        guard.set(key, value);
389    }
390
391    pub async fn current_role(&self) -> Option<String> {
392        let guard = self.session_state.lock();
393        guard.role.clone()
394    }
395
396    fn in_active_transaction(&self) -> bool {
397        let guard = self.session_state.lock();
398        guard
399            .txn_state
400            .as_ref()
401            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
402            .unwrap_or(false)
403    }
404
    /// Returns the username reported by the configured authentication method.
    pub fn username(&self) -> String {
        self.auth.username()
    }
408
409    fn gen_query_id(&self) -> String {
410        uuid::Uuid::now_v7().simple().to_string()
411    }
412
413    async fn handle_session(&self, session: &Option<SessionState>) {
414        let session = match session {
415            Some(session) => session,
416            None => return,
417        };
418
419        // save the updated session state from the server side
420        {
421            let mut session_state = self.session_state.lock();
422            *session_state = session.clone();
423        }
424
425        // process warehouse changed via session settings
426        if let Some(settings) = session.settings.as_ref() {
427            if let Some(v) = settings.get("warehouse") {
428                let mut warehouse = self.warehouse.lock();
429                *warehouse = Some(v.clone());
430            }
431        }
432    }
433
434    pub fn set_last_node_id(&self, node_id: String) {
435        *self.last_node_id.lock() = Some(node_id)
436    }
437
438    pub fn set_last_query_id(&self, query_id: Option<String>) {
439        *self.last_query_id.lock() = query_id
440    }
441
442    pub fn last_query_id(&self) -> Option<String> {
443        self.last_query_id.lock().clone()
444    }
445
446    fn last_node_id(&self) -> Option<String> {
447        self.last_node_id.lock().clone()
448    }
449
450    fn handle_warnings(&self, resp: &QueryResponse) {
451        if let Some(warnings) = &resp.warnings {
452            for w in warnings {
453                warn!(target: "server_warnings", "server warning: {w}");
454            }
455        }
456    }
457
458    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
459        info!("start query: {sql}");
460        let (resp, batches) = self.start_query_inner(sql, None, false).await?;
461        Pages::new(self.clone(), resp, batches, need_progress)
462    }
463
464    pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
465        let mut mgr = self.queries_need_heartbeat.lock();
466        if let Some(state) = mgr.remove(query_id) {
467            let self_cloned = self.clone();
468            let query_id = query_id.to_owned();
469            GLOBAL_RUNTIME.spawn(async move {
470                if let Err(e) = self_cloned
471                    .end_query(&query_id, "final", Some(state.node_id.as_str()))
472                    .await
473                {
474                    error!("failed to final query {query_id}: {e}");
475                }
476            });
477        }
478    }
479
480    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
481        if let Some(info) = &self.session_token_info {
482            let info = info.lock();
483            Ok(builder.bearer_auth(info.0.session_token.clone()))
484        } else {
485            self.auth.wrap(builder)
486        }
487    }
488
489    async fn start_query_inner(
490        &self,
491        sql: &str,
492        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
493        force_json_body: bool,
494    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
495        if !self.in_active_transaction() {
496            self.route_hint.next();
497        }
498        let endpoint = self.endpoint.join("v1/query")?;
499
500        // body
501        let session_state = self.session_state();
502        let need_sticky = session_state.need_sticky.unwrap_or(false);
503        let req = QueryRequest::new(sql)
504            .with_pagination(self.make_pagination())
505            .with_session(Some(session_state))
506            .with_stage_attachment(stage_attachment_config);
507
508        // headers
509        let query_id = self.gen_query_id();
510        let mut headers = self.make_headers(Some(&query_id))?;
511        if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
512            debug!("accept arrow data");
513            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
514        }
515
516        if need_sticky {
517            if let Some(node_id) = self.last_node_id() {
518                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
519            }
520        }
521        let mut builder = self.cli.post(endpoint.clone()).json(&req);
522        builder = self.wrap_auth_or_session_token(builder)?;
523        let request = builder.headers(headers.clone()).build()?;
524        let response = self.query_request_helper(request, true, true).await?;
525        self.handle_page(response, true).await
526    }
527
528    fn is_arrow_data(response: &Response) -> bool {
529        if let Some(typ) = response.headers().get(CONTENT_TYPE) {
530            if let Ok(t) = typ.to_str() {
531                return t == CONTENT_TYPE_ARROW;
532            }
533        }
534        false
535    }
536
    /// Decodes one page of a query response into its JSON header and arrow batches.
    ///
    /// `is_first` marks the first page of a query: only then are the route
    /// hint, server warnings, last query id and last node id recorded.
    ///
    /// # Errors
    /// - a response error for any non-200 status;
    /// - `Error::Decode` when an arrow payload cannot be decoded or lacks the
    ///   `response_header` metadata;
    /// - `Error::QueryNotFound` when JSON decoding yields a logic error with status 404;
    /// - `Error::QueryFailed` when the decoded response carries an error.
    async fn handle_page(
        &self,
        response: Response,
        is_first: bool,
    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
        let status = response.status();
        if status != 200 {
            return Err(Error::response_error(status, &response.bytes().await?));
        }
        // Headers must be inspected before the body is consumed below.
        let is_arrow_data = Self::is_arrow_data(&response);
        if is_first {
            if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
                self.route_hint.set(route_hint.to_str().unwrap_or_default());
            }
        }
        let mut body = response.bytes().await?;
        let mut batches = vec![];
        if is_arrow_data {
            if is_first {
                debug!("received arrow data");
            }
            let cursor = std::io::Cursor::new(body.as_ref());
            let reader = StreamReader::try_new(cursor, None)
                .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
            let schema = reader.schema();
            // For arrow payloads the JSON response travels in the schema metadata.
            let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
                bytes::Bytes::copy_from_slice(json_resp.as_bytes())
            } else {
                return Err(Error::Decode(
                    "missing response_header metadata in arrow payload".to_string(),
                ));
            };
            for batch in reader {
                let batch = batch
                    .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
                batches.push(batch);
            }
            // From here on `body` holds the JSON header instead of the raw arrow stream.
            body = json_body
        };
        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
            if let Error::Logic(status, ec) = &e {
                if *status == 404 {
                    return Error::QueryNotFound(ec.message.clone());
                }
            }
            e
        })?;
        // The session is applied before the error check, so it is updated even for failed queries.
        self.handle_session(&resp.session).await;
        if let Some(err) = &resp.error {
            return Err(Error::QueryFailed(err.clone()));
        }
        if is_first {
            self.handle_warnings(&resp);
            self.set_last_query_id(Some(resp.id.clone()));
            if let Some(node_id) = &resp.node_id {
                self.set_last_node_id(node_id.clone());
            }
        }
        Ok((resp, batches))
    }
597
598    pub async fn query_page(
599        &self,
600        query_id: &str,
601        next_uri: &str,
602        node_id: &Option<String>,
603    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
604        info!("query page: {next_uri}");
605        let endpoint = self.endpoint.join(next_uri)?;
606        let mut headers = self.make_headers(Some(query_id))?;
607        if self.capability.arrow_data && self.query_result_format == "arrow" {
608            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
609        }
610        let mut builder = self.cli.get(endpoint.clone());
611        builder = self
612            .wrap_auth_or_session_token(builder)?
613            .headers(headers.clone())
614            .timeout(self.page_request_timeout);
615        if let Some(node_id) = node_id {
616            builder = builder.header(HEADER_STICKY_NODE, node_id)
617        }
618        let request = builder.build()?;
619
620        let response = self.query_request_helper(request, false, true).await?;
621        self.handle_page(response, false).await
622    }
623
    /// Sends a `kill` request for `query_id`, without a sticky-node hint.
    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
        self.end_query(query_id, "kill", None).await
    }
627
    /// Sends a `final` request for `query_id`, optionally pinned to `node_id`.
    pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
        self.end_query(query_id, "final", node_id).await
    }
631
632    pub async fn end_query(
633        &self,
634        query_id: &str,
635        method: &str,
636        node_id: Option<&str>,
637    ) -> Result<()> {
638        let uri = format!("/v1/query/{query_id}/{method}");
639        let endpoint = self.endpoint.join(&uri)?;
640        let headers = self.make_headers(Some(query_id))?;
641
642        info!("{method} query: {uri}");
643
644        let mut builder = self.cli.post(endpoint);
645        if let Some(node_id) = node_id {
646            builder = builder.header(HEADER_STICKY_NODE, node_id)
647        }
648        builder = self.wrap_auth_or_session_token(builder)?;
649        let resp = builder.headers(headers.clone()).send().await?;
650        if resp.status() != 200 {
651            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
652                .with_context(&format!("{method} query")));
653        }
654        Ok(())
655    }
656
    /// Runs `sql` and returns all of its pages merged into one `Page`.
    ///
    /// Equivalent to `query_all_inner(sql, false)`.
    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
        self.query_all_inner(sql, false).await
    }
660
661    pub async fn query_all_inner(
662        self: &Arc<Self>,
663        sql: &str,
664        force_json_body: bool,
665    ) -> Result<Page> {
666        let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
667        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
668        let mut all = Page::default();
669        while let Some(page) = pages.next().await {
670            all.update(page?);
671        }
672        Ok(all)
673    }
674
675    fn session_state(&self) -> SessionState {
676        self.session_state.lock().clone()
677    }
678
679    fn make_pagination(&self) -> Option<PaginationConfig> {
680        if self.wait_time_secs.is_none()
681            && self.max_rows_in_buffer.is_none()
682            && self.max_rows_per_page.is_none()
683        {
684            return None;
685        }
686        let mut pagination = PaginationConfig {
687            wait_time_secs: None,
688            max_rows_in_buffer: None,
689            max_rows_per_page: None,
690        };
691        if let Some(wait_time_secs) = self.wait_time_secs {
692            pagination.wait_time_secs = Some(wait_time_secs);
693        }
694        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
695            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
696        }
697        if let Some(max_rows_per_page) = self.max_rows_per_page {
698            pagination.max_rows_per_page = Some(max_rows_per_page);
699        }
700        Some(pagination)
701    }
702
703    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
704        let mut headers = HeaderMap::new();
705        if let Some(tenant) = &self.tenant {
706            headers.insert(HEADER_TENANT, tenant.parse()?);
707        }
708        let warehouse = self.warehouse.lock().clone();
709        if let Some(warehouse) = warehouse {
710            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
711        }
712        let route_hint = self.route_hint.current();
713        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
714        if let Some(query_id) = query_id {
715            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
716        }
717        Ok(headers)
718    }
719
720    pub async fn insert_with_stage(
721        self: &Arc<Self>,
722        sql: &str,
723        stage: &str,
724        file_format_options: BTreeMap<&str, &str>,
725        copy_options: BTreeMap<&str, &str>,
726    ) -> Result<QueryStats> {
727        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
728        let stage_attachment = Some(StageAttachmentConfig {
729            location: stage,
730            file_format_options: Some(file_format_options),
731            copy_options: Some(copy_options),
732        });
733        let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
734        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
735        let mut all = Page::default();
736        while let Some(page) = pages.next().await {
737            all.update(page?);
738        }
739        Ok(all.stats)
740    }
741
742    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
743        info!("get presigned upload url: {stage}");
744        let sql = format!("PRESIGN UPLOAD {stage}");
745        let resp = self.query_all_inner(&sql, true).await?;
746        if resp.data.len() != 1 {
747            return Err(Error::Decode(
748                "Empty response from server for presigned request".to_string(),
749            ));
750        }
751        if resp.data[0].len() != 3 {
752            return Err(Error::Decode(
753                "Invalid response from server for presigned request".to_string(),
754            ));
755        }
756        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
757        let method = resp.data[0][0].clone().unwrap_or_default();
758        if method != "PUT" {
759            return Err(Error::Decode(format!(
760                "Invalid method for presigned upload request: {method}"
761            )));
762        }
763        let headers: BTreeMap<String, String> =
764            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
765        let url = resp.data[0][2].clone().unwrap_or_default();
766        Ok(PresignedResponse {
767            method,
768            headers,
769            url,
770        })
771    }
772
773    pub async fn upload_to_stage(
774        self: &Arc<Self>,
775        stage: &str,
776        data: Reader,
777        size: u64,
778    ) -> Result<()> {
779        match self.get_presign_mode() {
780            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
781            PresignMode::On => {
782                let presigned = self.get_presigned_upload_url(stage).await?;
783                presign_upload_to_stage(presigned, data, size).await
784            }
785            PresignMode::Auto => {
786                unreachable!("PresignMode::Auto should be handled during client initialization")
787            }
788            PresignMode::Detect => {
789                unreachable!("PresignMode::Detect should be handled during client initialization")
790            }
791        }
792    }
793
794    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
795    async fn upload_to_stage_with_stream(
796        &self,
797        stage: &str,
798        data: Reader,
799        size: u64,
800    ) -> Result<()> {
801        info!("upload to stage with stream: {stage}, size: {size}");
802        if let Some(info) = self.need_pre_refresh_session().await {
803            self.refresh_session_token(info).await?;
804        }
805        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
806        let location = StageLocation::try_from(stage)?;
807        let query_id = self.gen_query_id();
808        let mut headers = self.make_headers(Some(&query_id))?;
809        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
810        let stream = Body::wrap_stream(ReaderStream::new(data));
811        let part = Part::stream_with_length(stream, size).file_name(location.path);
812        let form = Form::new().part("upload", part);
813        let mut builder = self.cli.put(endpoint.clone());
814        builder = self.wrap_auth_or_session_token(builder)?;
815        let resp = builder.headers(headers).multipart(form).send().await?;
816        let status = resp.status();
817        if status != 200 {
818            return Err(
819                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
820            );
821        }
822        Ok(())
823    }
824
825    // use base64 encode whenever possible for safety
826    // but also accept raw JSON for test/debug/one-shot operations
827    pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
828    where
829        T: de::DeserializeOwned,
830    {
831        if value.starts_with("{") {
832            serde_json::from_slice(value.as_bytes())
833                .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
834        } else {
835            let json = URL_SAFE.decode(value).map_err(|e| {
836                format!(
837                    "Invalid value {} for {key}, base64 decode error: {}",
838                    value, e
839                )
840            })?;
841            serde_json::from_slice(&json).map_err(|e| {
842                format!(
843                    "Invalid value {value} for {key}, JSON value {},  decode error: {e}",
844                    String::from_utf8_lossy(&json)
845                )
846            })
847        }
848    }
849
850    pub async fn streaming_load(
851        &self,
852        sql: &str,
853        data: Reader,
854        file_name: &str,
855    ) -> Result<LoadResponse> {
856        let body = Body::wrap_stream(ReaderStream::new(data));
857        let part = Part::stream(body).file_name(file_name.to_string());
858        let endpoint = self.endpoint.join("v1/streaming_load")?;
859        let mut builder = self.cli.put(endpoint.clone());
860        builder = self.wrap_auth_or_session_token(builder)?;
861        let query_id = self.gen_query_id();
862        let mut headers = self.make_headers(Some(&query_id))?;
863        headers.insert(HEADER_SQL, sql.parse()?);
864        let session = serde_json::to_string(&*self.session_state.lock())
865            .expect("serialize session state should not fail");
866        headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
867        let form = Form::new().part("upload", part);
868        let resp = builder.headers(headers).multipart(form).send().await?;
869        let status = resp.status();
870        if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
871            match Self::decode_json_header::<SessionState>(
872                HEADER_QUERY_CONTEXT,
873                value.to_str().unwrap(),
874            ) {
875                Ok(session) => *self.session_state.lock() = session,
876                Err(e) => {
877                    error!("Error decoding session state when streaming load: {e}");
878                }
879            }
880        };
881        if status != 200 {
882            return Err(
883                Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
884            );
885        }
886        let resp = resp.json::<LoadResponse>().await?;
887        Ok(resp)
888    }
889
890    async fn login(&mut self) -> Result<()> {
891        let endpoint = self.endpoint.join("/v1/session/login")?;
892        let headers = self.make_headers(None)?;
893        let body = LoginRequest::from(&*self.session_state.lock());
894        let mut builder = self.cli.post(endpoint.clone()).json(&body);
895        if self.disable_session_token {
896            builder = builder.query(&[("disable_session_token", true)]);
897        }
898        let builder = self.auth.wrap(builder)?;
899        let request = builder
900            .headers(headers.clone())
901            .timeout(self.connect_timeout)
902            .build()?;
903        let response = self.query_request_helper(request, true, false).await;
904        let response = match response {
905            Ok(r) => r,
906            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
907                info!("login return 404, skip login on the old version server");
908                return Ok(());
909            }
910            Err(e) => return Err(e),
911        };
912        if let Some(v) = response.headers().get(HEADER_SESSION_ID) {
913            if let Ok(s) = v.to_str() {
914                self.session_id = s.to_string();
915            }
916        }
917
918        let body = response.bytes().await?;
919        let response = json_from_slice(&body)?;
920        match response {
921            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
922            LoginResponseResult::Ok(info) => {
923                let server_version = info
924                    .version
925                    .parse()
926                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
927                self.capability = Capability::from_server_version(&server_version);
928                self.server_version = Some(server_version.clone());
929                let session_id = self.session_id.as_str();
930                if let Some(tokens) = info.tokens {
931                    info!(
932                        "[session {session_id}] login success with session token version = {server_version}",
933                    );
934                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
935                } else {
936                    info!("[session {session_id}] login success, version = {server_version}");
937                }
938            }
939        }
940        Ok(())
941    }
942
943    pub(crate) async fn try_heartbeat(&self) -> Result<()> {
944        let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
945        let queries = self.queries_need_heartbeat.lock().clone();
946        let mut node_to_queries = HashMap::<String, Vec<String>>::new();
947        let now = Instant::now();
948
949        let mut query_ids = Vec::new();
950        for (qid, state) in queries {
951            if state.need_heartbeat(now) {
952                query_ids.push(qid.to_string());
953                if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
954                    arr.push(qid);
955                } else {
956                    node_to_queries.insert(state.node_id, vec![qid]);
957                }
958            }
959        }
960
961        if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
962        {
963            return Ok(());
964        }
965
966        let body = json!({
967           "node_to_queries": node_to_queries
968        });
969        let builder = self.cli.post(endpoint.clone()).json(&body);
970        let request = self.wrap_auth_or_session_token(builder)?.build()?;
971        let response = self.query_request_helper(request, true, false).await?;
972        let json: Value = response.json().await?;
973        let session_id = self.session_id.as_str();
974        info!("[session {session_id}] heartbeat request={body}, response={json}");
975        if let Some(queries_to_remove) = json.get("queries_to_remove") {
976            if let Some(arr) = queries_to_remove.as_array() {
977                if !arr.is_empty() {
978                    let mut queries = self.queries_need_heartbeat.lock();
979                    for q in arr {
980                        if let Some(q) = q.as_str() {
981                            queries.remove(q);
982                        }
983                    }
984                }
985            }
986        }
987        let now = Instant::now();
988        let mut queries = self.queries_need_heartbeat.lock();
989        for qid in query_ids {
990            if let Some(state) = queries.get_mut(&qid) {
991                *state.last_access_time.lock() = now;
992            }
993        }
994        Ok(())
995    }
996
997    fn build_log_out_request(&self) -> Result<Request> {
998        let endpoint = self.endpoint.join("/v1/session/logout")?;
999
1000        let session_state = self.session_state();
1001        let need_sticky = session_state.need_sticky.unwrap_or(false);
1002        let mut headers = self.make_headers(None)?;
1003        if need_sticky {
1004            if let Some(node_id) = self.last_node_id() {
1005                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1006            }
1007        }
1008        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1009
1010        let builder = self.wrap_auth_or_session_token(builder)?;
1011        let req = builder.build()?;
1012        Ok(req)
1013    }
1014
1015    pub(crate) fn need_logout(&self) -> bool {
1016        self.session_token_info.is_some()
1017            || self.session_state.lock().need_keep_alive.unwrap_or(false)
1018    }
1019
1020    async fn refresh_session_token(
1021        &self,
1022        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1023    ) -> Result<()> {
1024        let (session_token_info, _) = { self_login_info.lock().clone() };
1025        let endpoint = self.endpoint.join("/v1/session/refresh")?;
1026        let body = RefreshSessionTokenRequest {
1027            session_token: session_token_info.session_token.clone(),
1028        };
1029        let headers = self.make_headers(None)?;
1030        let request = self
1031            .cli
1032            .post(endpoint.clone())
1033            .json(&body)
1034            .headers(headers.clone())
1035            .bearer_auth(session_token_info.refresh_token.clone())
1036            .timeout(self.connect_timeout)
1037            .build()?;
1038
1039        // avoid recursively call request_helper
1040        for i in 0..3 {
1041            let req = request.try_clone().expect("request not cloneable");
1042            match self.cli.execute(req).await {
1043                Ok(response) => {
1044                    let status = response.status();
1045                    let body = response.bytes().await?;
1046                    if status == StatusCode::OK {
1047                        let response = json_from_slice(&body)?;
1048                        return match response {
1049                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1050                            RefreshResponse::Ok(info) => {
1051                                *self_login_info.lock() = (info, Instant::now());
1052                                Ok(())
1053                            }
1054                        };
1055                    }
1056                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
1057                        return Err(Error::response_error(status, &body));
1058                    }
1059                }
1060                Err(err) => {
1061                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
1062                        return Err(Error::Request(err.to_string()));
1063                    }
1064                }
1065            };
1066            sleep(jitter(Duration::from_secs(10))).await;
1067        }
1068        Ok(())
1069    }
1070
1071    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1072        if let Some(info) = &self.session_token_info {
1073            let (start, ttl) = {
1074                let guard = info.lock();
1075                (guard.1, guard.0.session_token_ttl_in_secs)
1076            };
1077            if Instant::now() > start + Duration::from_secs(ttl) {
1078                return Some(info.clone());
1079            }
1080        }
1081        None
1082    }
1083
1084    /// return Ok if and only if status code is 200.
1085    ///
1086    /// retry on
1087    ///   - network errors
1088    ///   - (optional) 503
1089    ///
1090    /// refresh databend token or reload jwt token if needed.
1091    async fn query_request_helper(
1092        &self,
1093        mut request: Request,
1094        retry_if_503: bool,
1095        refresh_if_401: bool,
1096    ) -> std::result::Result<Response, Error> {
1097        let mut refreshed = false;
1098        let mut retries = 0;
1099        loop {
1100            let req = request.try_clone().expect("request not cloneable");
1101            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
1102                Ok(response) => {
1103                    let status = response.status();
1104                    if status == StatusCode::OK {
1105                        return Ok(response);
1106                    }
1107                    let body = response.bytes().await?;
1108                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1109                        // waiting for server to start
1110                        (Error::response_error(status, &body), true)
1111                    } else {
1112                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
1113                        match resp {
1114                            Ok(r) => {
1115                                let e = r.error;
1116                                if status == StatusCode::UNAUTHORIZED {
1117                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1118                                    if let Some(session_token_info) = &self.session_token_info {
1119                                        info!(
1120                                            "will retry {} after refresh token on auth error {}",
1121                                            request.url(),
1122                                            e
1123                                        );
1124                                        let retry = if need_refresh_token(e.code)
1125                                            && !refreshed
1126                                            && refresh_if_401
1127                                        {
1128                                            self.refresh_session_token(session_token_info.clone())
1129                                                .await?;
1130                                            refreshed = true;
1131                                            true
1132                                        } else {
1133                                            false
1134                                        };
1135                                        (Error::AuthFailure(e), retry)
1136                                    } else if self.auth.can_reload() {
1137                                        info!(
1138                                            "will retry {} after reload token on auth error {}",
1139                                            request.url(),
1140                                            e
1141                                        );
1142                                        let builder = RequestBuilder::from_parts(
1143                                            HttpClient::new(),
1144                                            request.try_clone().unwrap(),
1145                                        );
1146                                        let builder = self.auth.wrap(builder)?;
1147                                        request = builder.build()?;
1148                                        (Error::AuthFailure(e), true)
1149                                    } else {
1150                                        (Error::AuthFailure(e), false)
1151                                    }
1152                                } else {
1153                                    (Error::Logic(status, e), false)
1154                                }
1155                            }
1156                            Err(_) => (
1157                                Error::Response {
1158                                    status,
1159                                    msg: String::from_utf8_lossy(&body).to_string(),
1160                                },
1161                                false,
1162                            ),
1163                        }
1164                    }
1165                }
1166                Err(err) => (
1167                    Error::Request(err.to_string()),
1168                    err.is_timeout() || err.is_connect() || err.is_request(),
1169                ),
1170            };
1171            if !retry {
1172                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
1173            }
1174            match &err {
1175                Error::AuthFailure(_) => {
1176                    if refreshed {
1177                        retries = 0;
1178                    } else if retries == 2 {
1179                        return Err(err.with_context(&format!(
1180                            "{} {} after 3 retries",
1181                            request.method(),
1182                            request.url()
1183                        )));
1184                    }
1185                }
1186                _ => {
1187                    if retries == 2 {
1188                        return Err(err.with_context(&format!(
1189                            "{} {} after 3 reties",
1190                            request.method(),
1191                            request.url()
1192                        )));
1193                    }
1194                    retries += 1;
1195                    info!(
1196                        "will retry {} the {retries}th times on error {}",
1197                        request.url(),
1198                        err
1199                    );
1200                }
1201            }
1202            warn!("will retry after 10 seconds");
1203            sleep(jitter(Duration::from_secs(10))).await;
1204        }
1205    }
1206
1207    pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1208        if let Err(err) = http_client.execute(request).await {
1209            error!("[session {session_id}] logout request failed: {err}");
1210        } else {
1211            info!("[session {session_id}] logout success");
1212        };
1213    }
1214
1215    pub async fn close(&self) {
1216        let session_id = &self.session_id;
1217        info!("[session {session_id}] try closing now");
1218        if self
1219            .closed
1220            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1221            .is_ok()
1222        {
1223            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1224            if self.need_logout() {
1225                let cli = self.cli.clone();
1226                let req = self
1227                    .build_log_out_request()
1228                    .expect("failed to build logout request");
1229                Self::logout(cli, req, &self.session_id).await;
1230            }
1231        }
1232    }
1233    pub fn close_with_spawn(&self) {
1234        let session_id = &self.session_id;
1235        info!("[session {session_id}]: try closing with spawn");
1236        if self
1237            .closed
1238            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1239            .is_ok()
1240        {
1241            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1242            if self.need_logout() {
1243                let cli = self.cli.clone();
1244                let req = self
1245                    .build_log_out_request()
1246                    .expect("failed to build logout request");
1247                let session_id = self.session_id.clone();
1248                GLOBAL_RUNTIME.spawn(async move {
1249                    Self::logout(cli, req, session_id.as_str()).await;
1250                });
1251            }
1252        }
1253    }
1254
1255    pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1256        let mut queries = self.queries_need_heartbeat.lock();
1257        queries.insert(query_id.to_string(), state);
1258    }
1259}
1260
1261fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1262where
1263    T: Deserialize<'a>,
1264{
1265    serde_json::from_slice::<T>(body).map_err(|e| {
1266        Error::Decode(format!(
1267            "fail to decode JSON response: {e}, body: {}",
1268            String::from_utf8_lossy(body)
1269        ))
1270    })
1271}
1272
1273impl Default for APIClient {
1274    fn default() -> Self {
1275        Self {
1276            session_id: Default::default(),
1277            cli: HttpClient::new(),
1278            scheme: "http".to_string(),
1279            endpoint: Url::parse("http://localhost:8080").unwrap(),
1280            host: "localhost".to_string(),
1281            port: 8000,
1282            tenant: None,
1283            warehouse: Mutex::new(None),
1284            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1285            session_state: Mutex::new(SessionState::default()),
1286            wait_time_secs: None,
1287            max_rows_in_buffer: None,
1288            max_rows_per_page: None,
1289            connect_timeout: Duration::from_secs(10),
1290            page_request_timeout: Duration::from_secs(30),
1291            tls_ca_file: None,
1292            presign: Mutex::new(PresignMode::Auto),
1293            route_hint: RouteHintGenerator::new(),
1294            last_node_id: Default::default(),
1295            disable_session_token: true,
1296            disable_login: false,
1297            query_result_format: "json".to_string(),
1298            session_token_info: None,
1299            closed: AtomicBool::new(false),
1300            last_query_id: Default::default(),
1301            server_version: None,
1302            capability: Default::default(),
1303            queries_need_heartbeat: Default::default(),
1304        }
1305    }
1306}
1307
/// Produces route hint strings of the form `rh:<uuid>:<nonce>` and remembers the
/// most recent one.
struct RouteHintGenerator {
    /// Counter embedded in each generated hint; advanced by one on every `next` call.
    nonce: AtomicU64,
    /// The hint currently in effect: the last one generated by `next` or stored by `set`.
    current: std::sync::Mutex<String>,
}
1312
1313impl RouteHintGenerator {
1314    fn new() -> Self {
1315        let gen = Self {
1316            nonce: AtomicU64::new(0),
1317            current: std::sync::Mutex::new("".to_string()),
1318        };
1319        gen.next();
1320        gen
1321    }
1322
1323    fn current(&self) -> String {
1324        let guard = self.current.lock().unwrap();
1325        guard.clone()
1326    }
1327
1328    fn set(&self, hint: &str) {
1329        let mut guard = self.current.lock().unwrap();
1330        *guard = hint.to_string();
1331    }
1332
1333    fn next(&self) -> String {
1334        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1335        let uuid = uuid::Uuid::new_v4();
1336        let current = format!("rh:{uuid}:{nonce:06}");
1337        let mut guard = self.current.lock().unwrap();
1338        guard.clone_from(&current);
1339        current
1340    }
1341}
1342
#[cfg(test)]
mod test {
    use super::*;

    /// Host, endpoint and the query parameters of a DSN with `sslmode=disable` end up
    /// on the matching client fields.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        assert_eq!(client.host, "app.databend.com");
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(client.endpoint, expected_endpoint);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert!(client.tenant.is_none());

        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(warehouse.as_deref(), Some("wh"));
        Ok(())
    }

    /// A percent-encoded password does not disturb host parsing; without an explicit
    /// port or `sslmode` the port is 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// A password containing raw special characters (including `@`) still yields the
    /// host and explicit port that follow the last `@`.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}