databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22    SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28    error::{Error, Result},
29    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30    response::QueryResponse,
31    session::SessionState,
32    QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::{de, Deserialize};
49use serde_json::{json, Value};
50use std::collections::{BTreeMap, HashMap};
51use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tokio::time::sleep;
55use tokio_retry::strategy::jitter;
56use tokio_stream::StreamExt;
57use tokio_util::io::ReaderStream;
58use url::Url;
59
// Names of the HTTP headers this client sends to or reads from the server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
// Value of the session's `txn_state` (compared case-insensitively) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
// Media type compared against a response's `Content-Type` to detect an Arrow IPC stream body.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Value sent in the `Accept` header when arrow bodies are requested.
// NOTE(review): despite the name, this is identical to `CONTENT_TYPE_ARROW` and lists no JSON
// media type — confirm that is intended.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
72
73static VERSION: Lazy<String> = Lazy::new(|| {
74    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
75    version.to_string()
76});
77
/// Per-query bookkeeping used to decide when a running query needs a heartbeat.
#[derive(Clone)]
pub(crate) struct QueryState {
    /// Node the query runs on; passed as the sticky node when the query is finalized.
    pub node_id: String,
    /// Instant of the most recent access, shared behind a mutex so clones see updates.
    pub last_access_time: Arc<Mutex<Instant>>,
    /// Timeout in seconds; a heartbeat is due once more than half of it has elapsed.
    pub timeout_secs: u64,
}
84
85impl QueryState {
86    pub fn need_heartbeat(&self, now: Instant) -> bool {
87        let t = *self.last_access_time.lock();
88        now.duration_since(t).as_secs() > self.timeout_secs / 2
89    }
90}
91
/// HTTP client for the Databend query API, configured from a DSN.
pub struct APIClient {
    /// Session identifier; `new` fills in `no_login_<uuid>` when it is still empty after the
    /// optional login step.
    pub(crate) session_id: String,
    /// Underlying HTTP client, created by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, selected by the DSN `sslmode` parameter (defaults to `"https"`).
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on `scheme` when the DSN has none.
    port: u16,

    /// Base URL (`scheme://host:port`) that request paths are joined onto.
    endpoint: Url,

    /// Credentials applied to requests when no session token is available.
    auth: Arc<dyn Auth>,

    /// Optional tenant, sent as the `X-DATABEND-TENANT` header.
    tenant: Option<String>,
    /// Current warehouse, sent as the `X-DATABEND-WAREHOUSE` header; updated when the server
    /// returns a `warehouse` session setting.
    warehouse: Mutex<Option<String>>,
    /// Latest session state; replaced by the one the server returns with each page.
    session_state: Mutex<SessionState>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header value.
    route_hint: RouteHintGenerator,

    /// Set by the DSN parameter `login=disable`; `new` skips the login call when `true`.
    disable_login: bool,
    /// `"json"` or `"arrow"` when set through the DSN `body_format` parameter.
    body_format: String,
    /// Set by the DSN parameter `session_token=disable`.
    disable_session_token: bool,
    /// When present, its session token is sent as a bearer token instead of using `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // NOTE(review): presumably marks the client as closed; it is not used in the visible part of
    // this file — confirm against `close_with_spawn`.
    closed: AtomicBool,

    // NOTE(review): not read or written in the visible part of this file.
    server_version: Option<Version>,

    /// Optional pagination settings forwarded with each new query.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Parsed from the DSN parameter `connect_timeout` (seconds).
    connect_timeout: Duration,
    /// Per-request timeout applied when fetching follow-up pages.
    page_request_timeout: Duration,

    /// Optional PEM file added as a root certificate for `https` connections.
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto` and `Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the most recent first-page response.
    last_node_id: Mutex<Option<String>>,
    /// Query id reported by the most recent first-page response.
    last_query_id: Mutex<Option<String>>,

    capability: Capability,

    /// Queries tracked for heartbeats, keyed by query id.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,
}
134
135impl Drop for APIClient {
136    fn drop(&mut self) {
137        self.close_with_spawn()
138    }
139}
140
141impl APIClient {
142    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
143        let mut client = Self::from_dsn(dsn).await?;
144        client.build_client(name).await?;
145        if !client.disable_login {
146            client.login().await?;
147        }
148        if client.session_id.is_empty() {
149            client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
150        }
151        let client = Arc::new(client);
152        client.check_presign().await?;
153        GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
154        Ok(client)
155    }
156
    /// Returns a reference to the capability set stored on this client.
    pub fn capability(&self) -> &Capability {
        &self.capability
    }
160
161    fn set_presign_mode(&self, mode: PresignMode) {
162        *self.presign.lock() = mode
163    }
164    fn get_presign_mode(&self) -> PresignMode {
165        *self.presign.lock()
166    }
167
168    async fn from_dsn(dsn: &str) -> Result<Self> {
169        let u = Url::parse(dsn)?;
170        let mut client = Self::default();
171        if let Some(host) = u.host_str() {
172            client.host = host.to_string();
173        }
174
175        if u.username() != "" {
176            let password = u.password().unwrap_or_default();
177            let password = percent_decode_str(password).decode_utf8()?;
178            client.auth = Arc::new(BasicAuth::new(u.username(), password));
179        }
180        let database = match u.path().trim_start_matches('/') {
181            "" => None,
182            s => Some(s.to_string()),
183        };
184        let mut role = None;
185        let mut scheme = "https";
186        let mut session_settings = BTreeMap::new();
187        for (k, v) in u.query_pairs() {
188            match k.as_ref() {
189                "wait_time_secs" => {
190                    client.wait_time_secs = Some(v.parse()?);
191                }
192                "max_rows_in_buffer" => {
193                    client.max_rows_in_buffer = Some(v.parse()?);
194                }
195                "max_rows_per_page" => {
196                    client.max_rows_per_page = Some(v.parse()?);
197                }
198                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
199                "page_request_timeout_secs" => {
200                    client.page_request_timeout = {
201                        let secs: u64 = v.parse()?;
202                        Duration::from_secs(secs)
203                    };
204                }
205                "presign" => {
206                    let presign_mode = match v.as_ref() {
207                        "auto" => PresignMode::Auto,
208                        "detect" => PresignMode::Detect,
209                        "on" => PresignMode::On,
210                        "off" => PresignMode::Off,
211                        _ => {
212                            return Err(Error::BadArgument(format!(
213                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
214                        )))
215                        }
216                    };
217                    client.set_presign_mode(presign_mode);
218                }
219                "tenant" => {
220                    client.tenant = Some(v.to_string());
221                }
222                "warehouse" => {
223                    client.warehouse = Mutex::new(Some(v.to_string()));
224                }
225                "role" => role = Some(v.to_string()),
226                "sslmode" => match v.as_ref() {
227                    "disable" => scheme = "http",
228                    "require" | "enable" => scheme = "https",
229                    _ => {
230                        return Err(Error::BadArgument(format!(
231                            "Invalid value for sslmode: {v}"
232                        )))
233                    }
234                },
235                "tls_ca_file" => {
236                    client.tls_ca_file = Some(v.to_string());
237                }
238                "access_token" => {
239                    client.auth = Arc::new(AccessTokenAuth::new(v));
240                }
241                "access_token_file" => {
242                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
243                }
244                "login" => {
245                    client.disable_login = match v.as_ref() {
246                        "disable" => true,
247                        "enable" => false,
248                        _ => {
249                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
250                        }
251                    }
252                }
253                "session_token" => {
254                    client.disable_session_token = match v.as_ref() {
255                        "disable" => true,
256                        "enable" => false,
257                        _ => {
258                            return Err(Error::BadArgument(format!(
259                                "Invalid value for session_token: {v}"
260                            )))
261                        }
262                    }
263                }
264                "body_format" => {
265                    let v = v.to_string().to_lowercase();
266                    match v.as_str() {
267                        "json" | "arrow" => client.body_format = v.to_string(),
268                        _ => {
269                            return Err(Error::BadArgument(format!(
270                                "Invalid value for body_format: {v}"
271                            )))
272                        }
273                    }
274                }
275                _ => {
276                    session_settings.insert(k.to_string(), v.to_string());
277                }
278            }
279        }
280        client.port = match u.port() {
281            Some(p) => p,
282            None => match scheme {
283                "http" => 80,
284                "https" => 443,
285                _ => unreachable!(),
286            },
287        };
288        client.scheme = scheme.to_string();
289
290        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
291        client.session_state = Mutex::new(
292            SessionState::default()
293                .with_settings(Some(session_settings))
294                .with_role(role)
295                .with_database(database),
296        );
297
298        Ok(client)
299    }
300
301    pub fn host(&self) -> &str {
302        self.host.as_str()
303    }
304
    /// Returns the port this client is configured to talk to.
    pub fn port(&self) -> u16 {
        self.port
    }
308
309    pub fn scheme(&self) -> &str {
310        self.scheme.as_str()
311    }
312
313    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
314        let ua = match name {
315            Some(n) => n,
316            None => format!("databend-client-rust/{}", VERSION.as_str()),
317        };
318        let cookie_provider = GlobalCookieStore::new();
319        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
320        let mut initial_cookies = [&cookie].into_iter();
321        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
322        let mut cli_builder = HttpClient::builder()
323            .user_agent(ua)
324            .cookie_provider(Arc::new(cookie_provider))
325            .pool_idle_timeout(Duration::from_secs(1));
326        #[cfg(any(feature = "rustls", feature = "native-tls"))]
327        if self.scheme == "https" {
328            if let Some(ref ca_file) = self.tls_ca_file {
329                let cert_pem = tokio::fs::read(ca_file).await?;
330                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
331                cli_builder = cli_builder.add_root_certificate(cert);
332            }
333        }
334        self.cli = cli_builder.build()?;
335        Ok(())
336    }
337
338    async fn check_presign(self: &Arc<Self>) -> Result<()> {
339        let mode = match self.get_presign_mode() {
340            PresignMode::Auto => {
341                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
342                    PresignMode::On
343                } else {
344                    PresignMode::Off
345                }
346            }
347            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
348                Ok(_) => PresignMode::On,
349                Err(e) => {
350                    warn!("presign mode off with error detected: {e}");
351                    PresignMode::Off
352                }
353            },
354            mode => mode,
355        };
356        self.set_presign_mode(mode);
357        Ok(())
358    }
359
360    pub fn current_warehouse(&self) -> Option<String> {
361        let guard = self.warehouse.lock();
362        guard.clone()
363    }
364
365    pub fn current_catalog(&self) -> Option<String> {
366        let guard = self.session_state.lock();
367        guard.catalog.clone()
368    }
369
370    pub fn current_database(&self) -> Option<String> {
371        let guard = self.session_state.lock();
372        guard.database.clone()
373    }
374
375    pub async fn current_role(&self) -> Option<String> {
376        let guard = self.session_state.lock();
377        guard.role.clone()
378    }
379
380    fn in_active_transaction(&self) -> bool {
381        let guard = self.session_state.lock();
382        guard
383            .txn_state
384            .as_ref()
385            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
386            .unwrap_or(false)
387    }
388
    /// Returns the user name reported by the configured auth provider.
    pub fn username(&self) -> String {
        self.auth.username()
    }
392
393    fn gen_query_id(&self) -> String {
394        uuid::Uuid::now_v7().simple().to_string()
395    }
396
397    async fn handle_session(&self, session: &Option<SessionState>) {
398        let session = match session {
399            Some(session) => session,
400            None => return,
401        };
402
403        // save the updated session state from the server side
404        {
405            let mut session_state = self.session_state.lock();
406            *session_state = session.clone();
407        }
408
409        // process warehouse changed via session settings
410        if let Some(settings) = session.settings.as_ref() {
411            if let Some(v) = settings.get("warehouse") {
412                let mut warehouse = self.warehouse.lock();
413                *warehouse = Some(v.clone());
414            }
415        }
416    }
417
418    pub fn set_last_node_id(&self, node_id: String) {
419        *self.last_node_id.lock() = Some(node_id)
420    }
421
422    pub fn set_last_query_id(&self, query_id: Option<String>) {
423        *self.last_query_id.lock() = query_id
424    }
425
426    pub fn last_query_id(&self) -> Option<String> {
427        self.last_query_id.lock().clone()
428    }
429
430    fn last_node_id(&self) -> Option<String> {
431        self.last_node_id.lock().clone()
432    }
433
434    fn handle_warnings(&self, resp: &QueryResponse) {
435        if let Some(warnings) = &resp.warnings {
436            for w in warnings {
437                warn!(target: "server_warnings", "server warning: {w}");
438            }
439        }
440    }
441
442    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
443        info!("start query: {sql}");
444        let (resp, batches) = self.start_query_inner(sql, None, false).await?;
445        Pages::new(self.clone(), resp, batches, need_progress)
446    }
447
448    pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
449        let mut mgr = self.queries_need_heartbeat.lock();
450        if let Some(state) = mgr.remove(query_id) {
451            let self_cloned = self.clone();
452            let query_id = query_id.to_owned();
453            GLOBAL_RUNTIME.spawn(async move {
454                if let Err(e) = self_cloned
455                    .end_query(&query_id, "final", Some(state.node_id.as_str()))
456                    .await
457                {
458                    error!("failed to final query {query_id}: {e}");
459                }
460            });
461        }
462    }
463
464    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
465        if let Some(info) = &self.session_token_info {
466            let info = info.lock();
467            Ok(builder.bearer_auth(info.0.session_token.clone()))
468        } else {
469            self.auth.wrap(builder)
470        }
471    }
472
473    async fn start_query_inner(
474        &self,
475        sql: &str,
476        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
477        force_json_body: bool,
478    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
479        if !self.in_active_transaction() {
480            self.route_hint.next();
481        }
482        let endpoint = self.endpoint.join("v1/query")?;
483
484        // body
485        let session_state = self.session_state();
486        let need_sticky = session_state.need_sticky.unwrap_or(false);
487        let req = QueryRequest::new(sql)
488            .with_pagination(self.make_pagination())
489            .with_session(Some(session_state))
490            .with_stage_attachment(stage_attachment_config);
491
492        // headers
493        let query_id = self.gen_query_id();
494        let mut headers = self.make_headers(Some(&query_id))?;
495        if self.capability.arrow_data && self.body_format == "arrow" && !force_json_body {
496            debug!("accept arrow data");
497            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
498        }
499
500        if need_sticky {
501            if let Some(node_id) = self.last_node_id() {
502                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
503            }
504        }
505        let mut builder = self.cli.post(endpoint.clone()).json(&req);
506        builder = self.wrap_auth_or_session_token(builder)?;
507        let request = builder.headers(headers.clone()).build()?;
508        let response = self.query_request_helper(request, true, true).await?;
509        self.handle_page(response, true).await
510    }
511
512    fn is_arrow_data(response: &Response) -> bool {
513        if let Some(typ) = response.headers().get(CONTENT_TYPE) {
514            if let Ok(t) = typ.to_str() {
515                return t == CONTENT_TYPE_ARROW;
516            }
517        }
518        false
519    }
520
521    async fn handle_page(
522        &self,
523        response: Response,
524        is_first: bool,
525    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
526        let status = response.status();
527        if status != 200 {
528            return Err(Error::response_error(status, &response.bytes().await?));
529        }
530        let is_arrow_data = Self::is_arrow_data(&response);
531        if is_first {
532            if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
533                self.route_hint.set(route_hint.to_str().unwrap_or_default());
534            }
535        }
536        let mut body = response.bytes().await?;
537        let mut batches = vec![];
538        if is_arrow_data {
539            if is_first {
540                debug!("received arrow data");
541            }
542            let cursor = std::io::Cursor::new(body.as_ref());
543            let reader = StreamReader::try_new(cursor, None)
544                .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
545            let schema = reader.schema();
546            let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
547                bytes::Bytes::copy_from_slice(json_resp.as_bytes())
548            } else {
549                return Err(Error::Decode(
550                    "missing response_header metadata in arrow payload".to_string(),
551                ));
552            };
553            for batch in reader {
554                let batch = batch
555                    .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
556                batches.push(batch);
557            }
558            body = json_body
559        };
560        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
561            if let Error::Logic(status, ec) = &e {
562                if *status == 404 {
563                    return Error::QueryNotFound(ec.message.clone());
564                }
565            }
566            e
567        })?;
568        self.handle_session(&resp.session).await;
569        if let Some(err) = &resp.error {
570            return Err(Error::QueryFailed(err.clone()));
571        }
572        if is_first {
573            self.handle_warnings(&resp);
574            self.set_last_query_id(Some(resp.id.clone()));
575            if let Some(node_id) = &resp.node_id {
576                self.set_last_node_id(node_id.clone());
577            }
578        }
579        Ok((resp, batches))
580    }
581
582    pub async fn query_page(
583        &self,
584        query_id: &str,
585        next_uri: &str,
586        node_id: &Option<String>,
587    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
588        info!("query page: {next_uri}");
589        let endpoint = self.endpoint.join(next_uri)?;
590        let mut headers = self.make_headers(Some(query_id))?;
591        if self.capability.arrow_data && self.body_format == "arrow" {
592            headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
593        }
594        let mut builder = self.cli.get(endpoint.clone());
595        builder = self
596            .wrap_auth_or_session_token(builder)?
597            .headers(headers.clone())
598            .timeout(self.page_request_timeout);
599        if let Some(node_id) = node_id {
600            builder = builder.header(HEADER_STICKY_NODE, node_id)
601        }
602        let request = builder.build()?;
603
604        let response = self.query_request_helper(request, false, true).await?;
605        self.handle_page(response, false).await
606    }
607
    /// Sends the `kill` request for `query_id` (no sticky node) via `end_query`.
    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
        self.end_query(query_id, "kill", None).await
    }
611
    /// Sends the `final` request for `query_id` via `end_query`, pinned to `node_id` when given.
    pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
        self.end_query(query_id, "final", node_id).await
    }
615
616    pub async fn end_query(
617        &self,
618        query_id: &str,
619        method: &str,
620        node_id: Option<&str>,
621    ) -> Result<()> {
622        let uri = format!("/v1/query/{query_id}/{method}");
623        let endpoint = self.endpoint.join(&uri)?;
624        let headers = self.make_headers(Some(query_id))?;
625
626        info!("{method} query: {uri}");
627
628        let mut builder = self.cli.post(endpoint);
629        if let Some(node_id) = node_id {
630            builder = builder.header(HEADER_STICKY_NODE, node_id)
631        }
632        builder = self.wrap_auth_or_session_token(builder)?;
633        let resp = builder.headers(headers.clone()).send().await?;
634        if resp.status() != 200 {
635            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
636                .with_context(&format!("{method} query")));
637        }
638        Ok(())
639    }
640
    /// Runs `sql` and returns all result pages merged into one `Page`, without forcing a JSON
    /// body.
    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
        self.query_all_inner(sql, false).await
    }
644
645    pub async fn query_all_inner(
646        self: &Arc<Self>,
647        sql: &str,
648        force_json_body: bool,
649    ) -> Result<Page> {
650        let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
651        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
652        let mut all = Page::default();
653        while let Some(page) = pages.next().await {
654            all.update(page?);
655        }
656        Ok(all)
657    }
658
659    fn session_state(&self) -> SessionState {
660        self.session_state.lock().clone()
661    }
662
663    fn make_pagination(&self) -> Option<PaginationConfig> {
664        if self.wait_time_secs.is_none()
665            && self.max_rows_in_buffer.is_none()
666            && self.max_rows_per_page.is_none()
667        {
668            return None;
669        }
670        let mut pagination = PaginationConfig {
671            wait_time_secs: None,
672            max_rows_in_buffer: None,
673            max_rows_per_page: None,
674        };
675        if let Some(wait_time_secs) = self.wait_time_secs {
676            pagination.wait_time_secs = Some(wait_time_secs);
677        }
678        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
679            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
680        }
681        if let Some(max_rows_per_page) = self.max_rows_per_page {
682            pagination.max_rows_per_page = Some(max_rows_per_page);
683        }
684        Some(pagination)
685    }
686
687    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
688        let mut headers = HeaderMap::new();
689        if let Some(tenant) = &self.tenant {
690            headers.insert(HEADER_TENANT, tenant.parse()?);
691        }
692        let warehouse = self.warehouse.lock().clone();
693        if let Some(warehouse) = warehouse {
694            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
695        }
696        let route_hint = self.route_hint.current();
697        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
698        if let Some(query_id) = query_id {
699            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
700        }
701        Ok(headers)
702    }
703
704    pub async fn insert_with_stage(
705        self: &Arc<Self>,
706        sql: &str,
707        stage: &str,
708        file_format_options: BTreeMap<&str, &str>,
709        copy_options: BTreeMap<&str, &str>,
710    ) -> Result<QueryStats> {
711        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
712        let stage_attachment = Some(StageAttachmentConfig {
713            location: stage,
714            file_format_options: Some(file_format_options),
715            copy_options: Some(copy_options),
716        });
717        let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
718        let mut pages = Pages::new(self.clone(), resp, batches, false)?;
719        let mut all = Page::default();
720        while let Some(page) = pages.next().await {
721            all.update(page?);
722        }
723        Ok(all.stats)
724    }
725
726    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
727        info!("get presigned upload url: {stage}");
728        let sql = format!("PRESIGN UPLOAD {stage}");
729        let resp = self.query_all_inner(&sql, true).await?;
730        if resp.data.len() != 1 {
731            return Err(Error::Decode(
732                "Empty response from server for presigned request".to_string(),
733            ));
734        }
735        if resp.data[0].len() != 3 {
736            return Err(Error::Decode(
737                "Invalid response from server for presigned request".to_string(),
738            ));
739        }
740        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
741        let method = resp.data[0][0].clone().unwrap_or_default();
742        if method != "PUT" {
743            return Err(Error::Decode(format!(
744                "Invalid method for presigned upload request: {method}"
745            )));
746        }
747        let headers: BTreeMap<String, String> =
748            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
749        let url = resp.data[0][2].clone().unwrap_or_default();
750        Ok(PresignedResponse {
751            method,
752            headers,
753            url,
754        })
755    }
756
757    pub async fn upload_to_stage(
758        self: &Arc<Self>,
759        stage: &str,
760        data: Reader,
761        size: u64,
762    ) -> Result<()> {
763        match self.get_presign_mode() {
764            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
765            PresignMode::On => {
766                let presigned = self.get_presigned_upload_url(stage).await?;
767                presign_upload_to_stage(presigned, data, size).await
768            }
769            PresignMode::Auto => {
770                unreachable!("PresignMode::Auto should be handled during client initialization")
771            }
772            PresignMode::Detect => {
773                unreachable!("PresignMode::Detect should be handled during client initialization")
774            }
775        }
776    }
777
778    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
779    async fn upload_to_stage_with_stream(
780        &self,
781        stage: &str,
782        data: Reader,
783        size: u64,
784    ) -> Result<()> {
785        info!("upload to stage with stream: {stage}, size: {size}");
786        if let Some(info) = self.need_pre_refresh_session().await {
787            self.refresh_session_token(info).await?;
788        }
789        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
790        let location = StageLocation::try_from(stage)?;
791        let query_id = self.gen_query_id();
792        let mut headers = self.make_headers(Some(&query_id))?;
793        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
794        let stream = Body::wrap_stream(ReaderStream::new(data));
795        let part = Part::stream_with_length(stream, size).file_name(location.path);
796        let form = Form::new().part("upload", part);
797        let mut builder = self.cli.put(endpoint.clone());
798        builder = self.wrap_auth_or_session_token(builder)?;
799        let resp = builder.headers(headers).multipart(form).send().await?;
800        let status = resp.status();
801        if status != 200 {
802            return Err(
803                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
804            );
805        }
806        Ok(())
807    }
808
809    // use base64 encode whenever possible for safety
810    // but also accept raw JSON for test/debug/one-shot operations
811    pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
812    where
813        T: de::DeserializeOwned,
814    {
815        if value.starts_with("{") {
816            serde_json::from_slice(value.as_bytes())
817                .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
818        } else {
819            let json = URL_SAFE.decode(value).map_err(|e| {
820                format!(
821                    "Invalid value {} for {key}, base64 decode error: {}",
822                    value, e
823                )
824            })?;
825            serde_json::from_slice(&json).map_err(|e| {
826                format!(
827                    "Invalid value {value} for {key}, JSON value {},  decode error: {e}",
828                    String::from_utf8_lossy(&json)
829                )
830            })
831        }
832    }
833
834    pub async fn streaming_load(
835        &self,
836        sql: &str,
837        data: Reader,
838        file_name: &str,
839    ) -> Result<LoadResponse> {
840        let body = Body::wrap_stream(ReaderStream::new(data));
841        let part = Part::stream(body).file_name(file_name.to_string());
842        let endpoint = self.endpoint.join("v1/streaming_load")?;
843        let mut builder = self.cli.put(endpoint.clone());
844        builder = self.wrap_auth_or_session_token(builder)?;
845        let query_id = self.gen_query_id();
846        let mut headers = self.make_headers(Some(&query_id))?;
847        headers.insert(HEADER_SQL, sql.parse()?);
848        let session = serde_json::to_string(&*self.session_state.lock())
849            .expect("serialize session state should not fail");
850        headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
851        let form = Form::new().part("upload", part);
852        let resp = builder.headers(headers).multipart(form).send().await?;
853        let status = resp.status();
854        if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
855            match Self::decode_json_header::<SessionState>(
856                HEADER_QUERY_CONTEXT,
857                value.to_str().unwrap(),
858            ) {
859                Ok(session) => *self.session_state.lock() = session,
860                Err(e) => {
861                    error!("Error decoding session state when streaming load: {e}");
862                }
863            }
864        };
865        if status != 200 {
866            return Err(
867                Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
868            );
869        }
870        let resp = resp.json::<LoadResponse>().await?;
871        Ok(resp)
872    }
873
874    async fn login(&mut self) -> Result<()> {
875        let endpoint = self.endpoint.join("/v1/session/login")?;
876        let headers = self.make_headers(None)?;
877        let body = LoginRequest::from(&*self.session_state.lock());
878        let mut builder = self.cli.post(endpoint.clone()).json(&body);
879        if self.disable_session_token {
880            builder = builder.query(&[("disable_session_token", true)]);
881        }
882        let builder = self.auth.wrap(builder)?;
883        let request = builder
884            .headers(headers.clone())
885            .timeout(self.connect_timeout)
886            .build()?;
887        let response = self.query_request_helper(request, true, false).await;
888        let response = match response {
889            Ok(r) => r,
890            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
891                info!("login return 404, skip login on the old version server");
892                return Ok(());
893            }
894            Err(e) => return Err(e),
895        };
896        if let Some(v) = response.headers().get(HEADER_SESSION_ID) {
897            if let Ok(s) = v.to_str() {
898                self.session_id = s.to_string();
899            }
900        }
901
902        let body = response.bytes().await?;
903        let response = json_from_slice(&body)?;
904        match response {
905            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
906            LoginResponseResult::Ok(info) => {
907                let server_version = info
908                    .version
909                    .parse()
910                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
911                self.capability = Capability::from_server_version(&server_version);
912                self.server_version = Some(server_version.clone());
913                let session_id = self.session_id.as_str();
914                if let Some(tokens) = info.tokens {
915                    info!(
916                        "[session {session_id}] login success with session token version = {server_version}",
917                    );
918                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
919                } else {
920                    info!("[session {session_id}] login success, version = {server_version}");
921                }
922            }
923        }
924        Ok(())
925    }
926
927    pub(crate) async fn try_heartbeat(&self) -> Result<()> {
928        let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
929        let queries = self.queries_need_heartbeat.lock().clone();
930        let mut node_to_queries = HashMap::<String, Vec<String>>::new();
931        let now = Instant::now();
932
933        let mut query_ids = Vec::new();
934        for (qid, state) in queries {
935            if state.need_heartbeat(now) {
936                query_ids.push(qid.to_string());
937                if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
938                    arr.push(qid);
939                } else {
940                    node_to_queries.insert(state.node_id, vec![qid]);
941                }
942            }
943        }
944
945        if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
946        {
947            return Ok(());
948        }
949
950        let body = json!({
951           "node_to_queries": node_to_queries
952        });
953        let builder = self.cli.post(endpoint.clone()).json(&body);
954        let request = self.wrap_auth_or_session_token(builder)?.build()?;
955        let response = self.query_request_helper(request, true, false).await?;
956        let json: Value = response.json().await?;
957        let session_id = self.session_id.as_str();
958        info!("[session {session_id}] heartbeat request={body}, response={json}");
959        if let Some(queries_to_remove) = json.get("queries_to_remove") {
960            if let Some(arr) = queries_to_remove.as_array() {
961                if !arr.is_empty() {
962                    let mut queries = self.queries_need_heartbeat.lock();
963                    for q in arr {
964                        if let Some(q) = q.as_str() {
965                            queries.remove(q);
966                        }
967                    }
968                }
969            }
970        }
971        let now = Instant::now();
972        let mut queries = self.queries_need_heartbeat.lock();
973        for qid in query_ids {
974            if let Some(state) = queries.get_mut(&qid) {
975                *state.last_access_time.lock() = now;
976            }
977        }
978        Ok(())
979    }
980
981    fn build_log_out_request(&self) -> Result<Request> {
982        let endpoint = self.endpoint.join("/v1/session/logout")?;
983
984        let session_state = self.session_state();
985        let need_sticky = session_state.need_sticky.unwrap_or(false);
986        let mut headers = self.make_headers(None)?;
987        if need_sticky {
988            if let Some(node_id) = self.last_node_id() {
989                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
990            }
991        }
992        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
993
994        let builder = self.wrap_auth_or_session_token(builder)?;
995        let req = builder.build()?;
996        Ok(req)
997    }
998
999    pub(crate) fn need_logout(&self) -> bool {
1000        self.session_token_info.is_some()
1001            || self.session_state.lock().need_keep_alive.unwrap_or(false)
1002    }
1003
1004    async fn refresh_session_token(
1005        &self,
1006        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1007    ) -> Result<()> {
1008        let (session_token_info, _) = { self_login_info.lock().clone() };
1009        let endpoint = self.endpoint.join("/v1/session/refresh")?;
1010        let body = RefreshSessionTokenRequest {
1011            session_token: session_token_info.session_token.clone(),
1012        };
1013        let headers = self.make_headers(None)?;
1014        let request = self
1015            .cli
1016            .post(endpoint.clone())
1017            .json(&body)
1018            .headers(headers.clone())
1019            .bearer_auth(session_token_info.refresh_token.clone())
1020            .timeout(self.connect_timeout)
1021            .build()?;
1022
1023        // avoid recursively call request_helper
1024        for i in 0..3 {
1025            let req = request.try_clone().expect("request not cloneable");
1026            match self.cli.execute(req).await {
1027                Ok(response) => {
1028                    let status = response.status();
1029                    let body = response.bytes().await?;
1030                    if status == StatusCode::OK {
1031                        let response = json_from_slice(&body)?;
1032                        return match response {
1033                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1034                            RefreshResponse::Ok(info) => {
1035                                *self_login_info.lock() = (info, Instant::now());
1036                                Ok(())
1037                            }
1038                        };
1039                    }
1040                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
1041                        return Err(Error::response_error(status, &body));
1042                    }
1043                }
1044                Err(err) => {
1045                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
1046                        return Err(Error::Request(err.to_string()));
1047                    }
1048                }
1049            };
1050            sleep(jitter(Duration::from_secs(10))).await;
1051        }
1052        Ok(())
1053    }
1054
1055    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1056        if let Some(info) = &self.session_token_info {
1057            let (start, ttl) = {
1058                let guard = info.lock();
1059                (guard.1, guard.0.session_token_ttl_in_secs)
1060            };
1061            if Instant::now() > start + Duration::from_secs(ttl) {
1062                return Some(info.clone());
1063            }
1064        }
1065        None
1066    }
1067
1068    /// return Ok if and only if status code is 200.
1069    ///
1070    /// retry on
1071    ///   - network errors
1072    ///   - (optional) 503
1073    ///
1074    /// refresh databend token or reload jwt token if needed.
1075    async fn query_request_helper(
1076        &self,
1077        mut request: Request,
1078        retry_if_503: bool,
1079        refresh_if_401: bool,
1080    ) -> std::result::Result<Response, Error> {
1081        let mut refreshed = false;
1082        let mut retries = 0;
1083        loop {
1084            let req = request.try_clone().expect("request not cloneable");
1085            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
1086                Ok(response) => {
1087                    let status = response.status();
1088                    if status == StatusCode::OK {
1089                        return Ok(response);
1090                    }
1091                    let body = response.bytes().await?;
1092                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1093                        // waiting for server to start
1094                        (Error::response_error(status, &body), true)
1095                    } else {
1096                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
1097                        match resp {
1098                            Ok(r) => {
1099                                let e = r.error;
1100                                if status == StatusCode::UNAUTHORIZED {
1101                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1102                                    if let Some(session_token_info) = &self.session_token_info {
1103                                        info!(
1104                                            "will retry {} after refresh token on auth error {}",
1105                                            request.url(),
1106                                            e
1107                                        );
1108                                        let retry = if need_refresh_token(e.code)
1109                                            && !refreshed
1110                                            && refresh_if_401
1111                                        {
1112                                            self.refresh_session_token(session_token_info.clone())
1113                                                .await?;
1114                                            refreshed = true;
1115                                            true
1116                                        } else {
1117                                            false
1118                                        };
1119                                        (Error::AuthFailure(e), retry)
1120                                    } else if self.auth.can_reload() {
1121                                        info!(
1122                                            "will retry {} after reload token on auth error {}",
1123                                            request.url(),
1124                                            e
1125                                        );
1126                                        let builder = RequestBuilder::from_parts(
1127                                            HttpClient::new(),
1128                                            request.try_clone().unwrap(),
1129                                        );
1130                                        let builder = self.auth.wrap(builder)?;
1131                                        request = builder.build()?;
1132                                        (Error::AuthFailure(e), true)
1133                                    } else {
1134                                        (Error::AuthFailure(e), false)
1135                                    }
1136                                } else {
1137                                    (Error::Logic(status, e), false)
1138                                }
1139                            }
1140                            Err(_) => (
1141                                Error::Response {
1142                                    status,
1143                                    msg: String::from_utf8_lossy(&body).to_string(),
1144                                },
1145                                false,
1146                            ),
1147                        }
1148                    }
1149                }
1150                Err(err) => (
1151                    Error::Request(err.to_string()),
1152                    err.is_timeout() || err.is_connect(),
1153                ),
1154            };
1155            if !retry {
1156                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
1157            }
1158            match &err {
1159                Error::AuthFailure(_) => {
1160                    if refreshed {
1161                        retries = 0;
1162                    } else if retries == 2 {
1163                        return Err(err.with_context(&format!(
1164                            "{} {} after 3 reties",
1165                            request.method(),
1166                            request.url()
1167                        )));
1168                    }
1169                }
1170                _ => {
1171                    if retries == 2 {
1172                        return Err(err.with_context(&format!(
1173                            "{} {} after 3 reties",
1174                            request.method(),
1175                            request.url()
1176                        )));
1177                    }
1178                    retries += 1;
1179                    info!(
1180                        "will retry {} the {retries}th times on error {}",
1181                        request.url(),
1182                        err
1183                    );
1184                }
1185            }
1186            sleep(jitter(Duration::from_secs(10))).await;
1187        }
1188    }
1189
1190    pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1191        if let Err(err) = http_client.execute(request).await {
1192            error!("[session {session_id}] logout request failed: {err}");
1193        } else {
1194            info!("[session {session_id}] logout success");
1195        };
1196    }
1197
1198    pub async fn close(&self) {
1199        let session_id = &self.session_id;
1200        info!("[session {session_id}] try closing now");
1201        if self
1202            .closed
1203            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1204            .is_ok()
1205        {
1206            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1207            if self.need_logout() {
1208                let cli = self.cli.clone();
1209                let req = self
1210                    .build_log_out_request()
1211                    .expect("failed to build logout request");
1212                Self::logout(cli, req, &self.session_id).await;
1213            }
1214        }
1215    }
1216    pub fn close_with_spawn(&self) {
1217        let session_id = &self.session_id;
1218        info!("[session {session_id}]: try closing with spawn");
1219        if self
1220            .closed
1221            .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1222            .is_ok()
1223        {
1224            GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1225            if self.need_logout() {
1226                let cli = self.cli.clone();
1227                let req = self
1228                    .build_log_out_request()
1229                    .expect("failed to build logout request");
1230                let session_id = self.session_id.clone();
1231                GLOBAL_RUNTIME.spawn(async move {
1232                    Self::logout(cli, req, session_id.as_str()).await;
1233                });
1234            }
1235        }
1236    }
1237
1238    pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1239        let mut queries = self.queries_need_heartbeat.lock();
1240        queries.insert(query_id.to_string(), state);
1241    }
1242}
1243
1244fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1245where
1246    T: Deserialize<'a>,
1247{
1248    serde_json::from_slice::<T>(body).map_err(|e| {
1249        Error::Decode(format!(
1250            "fail to decode JSON response: {e}, body: {}",
1251            String::from_utf8_lossy(body)
1252        ))
1253    })
1254}
1255
1256impl Default for APIClient {
1257    fn default() -> Self {
1258        Self {
1259            session_id: Default::default(),
1260            cli: HttpClient::new(),
1261            scheme: "http".to_string(),
1262            endpoint: Url::parse("http://localhost:8080").unwrap(),
1263            host: "localhost".to_string(),
1264            port: 8000,
1265            tenant: None,
1266            warehouse: Mutex::new(None),
1267            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1268            session_state: Mutex::new(SessionState::default()),
1269            wait_time_secs: None,
1270            max_rows_in_buffer: None,
1271            max_rows_per_page: None,
1272            connect_timeout: Duration::from_secs(10),
1273            page_request_timeout: Duration::from_secs(30),
1274            tls_ca_file: None,
1275            presign: Mutex::new(PresignMode::Auto),
1276            route_hint: RouteHintGenerator::new(),
1277            last_node_id: Default::default(),
1278            disable_session_token: true,
1279            disable_login: false,
1280            body_format: "json".to_string(),
1281            session_token_info: None,
1282            closed: AtomicBool::new(false),
1283            last_query_id: Default::default(),
1284            server_version: None,
1285            capability: Default::default(),
1286            queries_need_heartbeat: Default::default(),
1287        }
1288    }
1289}
1290
/// Generates route hint strings and remembers the current one.
///
/// NOTE(review): presumably the hint is sent along with requests so that related
/// requests reach the same backend — confirm against the callers.
struct RouteHintGenerator {
    /// Counter incremented for every generated hint; its value is embedded in the hint.
    nonce: AtomicU64,
    /// The most recently generated or explicitly set hint.
    current: std::sync::Mutex<String>,
}
1295
1296impl RouteHintGenerator {
1297    fn new() -> Self {
1298        let gen = Self {
1299            nonce: AtomicU64::new(0),
1300            current: std::sync::Mutex::new("".to_string()),
1301        };
1302        gen.next();
1303        gen
1304    }
1305
1306    fn current(&self) -> String {
1307        let guard = self.current.lock().unwrap();
1308        guard.clone()
1309    }
1310
1311    fn set(&self, hint: &str) {
1312        let mut guard = self.current.lock().unwrap();
1313        *guard = hint.to_string();
1314    }
1315
1316    fn next(&self) -> String {
1317        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1318        let uuid = uuid::Uuid::new_v4();
1319        let current = format!("rh:{uuid}:{nonce:06}");
1320        let mut guard = self.current.lock().unwrap();
1321        guard.clone_from(&current);
1322        current
1323    }
1324}
1325
#[cfg(test)]
mod test {
    use super::*;

    /// Query parameters of the DSN end up in the matching client settings.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let parsed = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;
        assert_eq!(parsed.host, "app.databend.com");
        assert_eq!(parsed.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(parsed.wait_time_secs, Some(10));
        assert_eq!(parsed.max_rows_in_buffer, Some(5000000));
        assert_eq!(parsed.max_rows_per_page, Some(10000));
        assert_eq!(parsed.tenant, None);
        let warehouse = parsed.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password does not disturb host and port parsing.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let parsed =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(parsed.host(), "localhost");
        assert_eq!(parsed.port(), 443);
        Ok(())
    }

    /// A password with raw special characters (including `@`) is still accepted.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let parsed =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(parsed.host(), "localhost");
        assert_eq!(parsed.port(), 8000);
        Ok(())
    }
}