databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
22use crate::global_cookie_store::GlobalCookieStore;
23use crate::login::{
24    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
25    SessionTokenInfo,
26};
27use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
28use crate::stage::StageLocation;
29use crate::{
30    error::{Error, Result},
31    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
32    response::QueryResponse,
33    session::SessionState,
34    QueryStats,
35};
36use crate::{Page, Pages};
37use log::{debug, error, info, warn};
38use once_cell::sync::Lazy;
39use parking_lot::Mutex;
40use percent_encoding::percent_decode_str;
41use reqwest::cookie::CookieStore;
42use reqwest::header::{HeaderMap, HeaderValue};
43use reqwest::multipart::{Form, Part};
44use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
45use serde::Deserialize;
46use tokio::time::sleep;
47use tokio_retry::strategy::jitter;
48use tokio_stream::StreamExt;
49use tokio_util::io::ReaderStream;
50use url::Url;
51
/// Header carrying the client-generated id of the query a request belongs to.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Header carrying the tenant configured through the `tenant` DSN parameter.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Header carrying the warehouse currently in use.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
/// Header carrying the route hint; it is also read back from query responses.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Header naming the node a request should stick to (sent when the session asks for it).
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Header naming the target stage of a streaming upload.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";

/// Session `txn_state` value (compared ignoring ASCII case) meaning a transaction is open.
const TXN_STATE_ACTIVE: &str = "Active";
59
60static VERSION: Lazy<String> = Lazy::new(|| {
61    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
62    version.to_string()
63});
64
/// HTTP client for the Databend query API.
///
/// Constructed from a DSN by `APIClient::new`. It holds connection settings, authentication,
/// the server-side session state and per-client bookkeeping (route hint, last node/query id).
/// Mutable state is kept behind `parking_lot` mutexes and an atomic flag.
pub struct APIClient {
    /// Underlying reqwest client, built in `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, selected by the `sslmode` DSN parameter.
    scheme: String,
    /// Host name taken from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on the scheme.
    port: u16,

    /// Base URL (`scheme://host:port`) that API paths are joined onto.
    endpoint: Url,

    /// Credentials: basic auth from the DSN userinfo, or `access_token` / `access_token_file`.
    auth: Arc<dyn Auth>,

    /// Tenant sent in the `X-DATABEND-TENANT` header, if configured.
    tenant: Option<String>,
    /// Warehouse sent in the `X-DATABEND-WAREHOUSE` header; set from the DSN and updated
    /// when the server returns a `warehouse` session setting.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with each query and replaced by the state the server returns.
    session_state: Mutex<SessionState>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header; `next()` is called for each new query
    /// outside a transaction and `set()` with the hint echoed in query responses.
    route_hint: RouteHintGenerator,

    /// Skip the login handshake (`login=disable`).
    disable_login: bool,
    /// Ask the server not to issue session tokens at login (`session_token=disable`).
    disable_session_token: bool,
    /// Session/refresh tokens from login, paired with the instant they were obtained or last
    /// refreshed. When present, the session token is used as the bearer credential.
    session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

    /// Flipped from `false` to `true` once by `need_logout`, so logout is claimed only once.
    closed: AtomicBool,

    /// Server version reported by the login response.
    server_version: Option<String>,

    /// Pagination `wait_time_secs` forwarded with each query (see `make_pagination`).
    wait_time_secs: Option<i64>,
    /// Pagination `max_rows_in_buffer` forwarded with each query.
    max_rows_in_buffer: Option<i64>,
    /// Pagination `max_rows_per_page` forwarded with each query.
    max_rows_per_page: Option<i64>,

    /// Timeout applied to the login and session-token refresh requests.
    connect_timeout: Duration,
    /// Timeout applied to each page fetch in `query_page`.
    page_request_timeout: Duration,

    /// Optional PEM file added as a root certificate for https (only when a TLS feature is
    /// compiled in).
    tls_ca_file: Option<String>,

    /// Stage upload strategy; `Auto`/`Detect` are resolved to `On`/`Off` during `new`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the most recent query; sent as the sticky-node header when the
    /// session requests stickiness.
    last_node_id: Mutex<Option<String>>,
    /// Last query id recorded via `set_last_query_id` (set once a query starts without error).
    last_query_id: Mutex<Option<String>>,
}
101
102impl APIClient {
103    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
104        let mut client = Self::from_dsn(dsn).await?;
105        client.build_client(name).await?;
106        if !client.disable_login {
107            client.login().await?;
108        }
109        let client = Arc::new(client);
110        client.check_presign().await?;
111        Ok(client)
112    }
113
114    fn set_presign_mode(&self, mode: PresignMode) {
115        *self.presign.lock() = mode
116    }
117    fn get_presign_mode(&self) -> PresignMode {
118        *self.presign.lock()
119    }
120
121    async fn from_dsn(dsn: &str) -> Result<Self> {
122        let u = Url::parse(dsn)?;
123        let mut client = Self::default();
124        if let Some(host) = u.host_str() {
125            client.host = host.to_string();
126        }
127
128        if u.username() != "" {
129            let password = u.password().unwrap_or_default();
130            let password = percent_decode_str(password).decode_utf8()?;
131            client.auth = Arc::new(BasicAuth::new(u.username(), password));
132        }
133        let database = match u.path().trim_start_matches('/') {
134            "" => None,
135            s => Some(s.to_string()),
136        };
137        let mut role = None;
138        let mut scheme = "https";
139        let mut session_settings = BTreeMap::new();
140        for (k, v) in u.query_pairs() {
141            match k.as_ref() {
142                "wait_time_secs" => {
143                    client.wait_time_secs = Some(v.parse()?);
144                }
145                "max_rows_in_buffer" => {
146                    client.max_rows_in_buffer = Some(v.parse()?);
147                }
148                "max_rows_per_page" => {
149                    client.max_rows_per_page = Some(v.parse()?);
150                }
151                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
152                "page_request_timeout_secs" => {
153                    client.page_request_timeout = {
154                        let secs: u64 = v.parse()?;
155                        Duration::from_secs(secs)
156                    };
157                }
158                "presign" => {
159                    let presign_mode = match v.as_ref() {
160                        "auto" => PresignMode::Auto,
161                        "detect" => PresignMode::Detect,
162                        "on" => PresignMode::On,
163                        "off" => PresignMode::Off,
164                        _ => {
165                            return Err(Error::BadArgument(format!(
166                            "Invalid value for presign: {}, should be one of auto/detect/on/off",
167                            v
168                        )))
169                        }
170                    };
171                    client.set_presign_mode(presign_mode);
172                }
173                "tenant" => {
174                    client.tenant = Some(v.to_string());
175                }
176                "warehouse" => {
177                    client.warehouse = Mutex::new(Some(v.to_string()));
178                }
179                "role" => role = Some(v.to_string()),
180                "sslmode" => match v.as_ref() {
181                    "disable" => scheme = "http",
182                    "require" | "enable" => scheme = "https",
183                    _ => {
184                        return Err(Error::BadArgument(format!(
185                            "Invalid value for sslmode: {}",
186                            v
187                        )))
188                    }
189                },
190                "tls_ca_file" => {
191                    client.tls_ca_file = Some(v.to_string());
192                }
193                "access_token" => {
194                    client.auth = Arc::new(AccessTokenAuth::new(v));
195                }
196                "access_token_file" => {
197                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
198                }
199                "login" => {
200                    client.disable_login = match v.as_ref() {
201                        "disable" => true,
202                        "enable" => false,
203                        _ => {
204                            return Err(Error::BadArgument(format!(
205                                "Invalid value for login: {}",
206                                v
207                            )))
208                        }
209                    }
210                }
211                "session_token" => {
212                    client.disable_session_token = match v.as_ref() {
213                        "disable" => true,
214                        "enable" => false,
215                        _ => {
216                            return Err(Error::BadArgument(format!(
217                                "Invalid value for session_token: {}",
218                                v
219                            )))
220                        }
221                    }
222                }
223                _ => {
224                    session_settings.insert(k.to_string(), v.to_string());
225                }
226            }
227        }
228        client.port = match u.port() {
229            Some(p) => p,
230            None => match scheme {
231                "http" => 80,
232                "https" => 443,
233                _ => unreachable!(),
234            },
235        };
236        client.scheme = scheme.to_string();
237
238        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
239        client.session_state = Mutex::new(
240            SessionState::default()
241                .with_settings(Some(session_settings))
242                .with_role(role)
243                .with_database(database),
244        );
245
246        Ok(client)
247    }
248
249    pub fn host(&self) -> &str {
250        self.host.as_str()
251    }
252
253    pub fn port(&self) -> u16 {
254        self.port
255    }
256
257    pub fn scheme(&self) -> &str {
258        self.scheme.as_str()
259    }
260
261    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
262        let ua = match name {
263            Some(n) => n,
264            None => format!("databend-client-rust/{}", VERSION.as_str()),
265        };
266        let cookie_provider = GlobalCookieStore::new();
267        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
268        let mut initial_cookies = [&cookie].into_iter();
269        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
270        let mut cli_builder = HttpClient::builder()
271            .user_agent(ua)
272            .cookie_provider(Arc::new(cookie_provider))
273            .pool_idle_timeout(Duration::from_secs(1));
274        #[cfg(any(feature = "rustls", feature = "native-tls"))]
275        if self.scheme == "https" {
276            if let Some(ref ca_file) = self.tls_ca_file {
277                let cert_pem = tokio::fs::read(ca_file).await?;
278                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
279                cli_builder = cli_builder.add_root_certificate(cert);
280            }
281        }
282        self.cli = cli_builder.build()?;
283        Ok(())
284    }
285
286    async fn check_presign(self: &Arc<Self>) -> Result<()> {
287        let mode = match self.get_presign_mode() {
288            PresignMode::Auto => {
289                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
290                    PresignMode::On
291                } else {
292                    PresignMode::Off
293                }
294            }
295            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
296                Ok(_) => PresignMode::On,
297                Err(e) => {
298                    warn!("presign mode off with error detected: {}", e);
299                    PresignMode::Off
300                }
301            },
302            mode => mode,
303        };
304        self.set_presign_mode(mode);
305        Ok(())
306    }
307
308    pub fn current_warehouse(&self) -> Option<String> {
309        let guard = self.warehouse.lock();
310        guard.clone()
311    }
312
313    pub fn current_catalog(&self) -> Option<String> {
314        let guard = self.session_state.lock();
315        guard.catalog.clone()
316    }
317
318    pub fn current_database(&self) -> Option<String> {
319        let guard = self.session_state.lock();
320        guard.database.clone()
321    }
322
323    pub async fn current_role(&self) -> Option<String> {
324        let guard = self.session_state.lock();
325        guard.role.clone()
326    }
327
328    fn in_active_transaction(&self) -> bool {
329        let guard = self.session_state.lock();
330        guard
331            .txn_state
332            .as_ref()
333            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
334            .unwrap_or(false)
335    }
336
337    pub fn username(&self) -> String {
338        self.auth.username()
339    }
340
341    fn gen_query_id(&self) -> String {
342        uuid::Uuid::new_v4().to_string()
343    }
344
345    async fn handle_session(&self, session: &Option<SessionState>) {
346        let session = match session {
347            Some(session) => session,
348            None => return,
349        };
350
351        // save the updated session state from the server side
352        {
353            let mut session_state = self.session_state.lock();
354            *session_state = session.clone();
355        }
356
357        // process warehouse changed via session settings
358        if let Some(settings) = session.settings.as_ref() {
359            if let Some(v) = settings.get("warehouse") {
360                let mut warehouse = self.warehouse.lock();
361                *warehouse = Some(v.clone());
362            }
363        }
364    }
365
366    pub fn set_last_node_id(&self, node_id: String) {
367        *self.last_node_id.lock() = Some(node_id)
368    }
369
370    pub fn set_last_query_id(&self, query_id: Option<String>) {
371        *self.last_query_id.lock() = query_id
372    }
373
374    pub fn last_query_id(&self) -> Option<String> {
375        self.last_query_id.lock().clone()
376    }
377
378    fn last_node_id(&self) -> Option<String> {
379        self.last_node_id.lock().clone()
380    }
381
382    fn handle_warnings(&self, resp: &QueryResponse) {
383        if let Some(warnings) = &resp.warnings {
384            for w in warnings {
385                warn!(target: "server_warnings", "server warning: {}", w);
386            }
387        }
388    }
389
390    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
391        info!("start query: {}", sql);
392        let resp = self.start_query_inner(sql, None).await?;
393        let pages = Pages::new(self.clone(), resp, need_progress);
394        Ok(pages)
395    }
396
397    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
398        if let Some(info) = &self.session_token_info {
399            let info = info.lock();
400            Ok(builder.bearer_auth(info.0.session_token.clone()))
401        } else {
402            self.auth.wrap(builder)
403        }
404    }
405
406    async fn start_query_inner(
407        &self,
408        sql: &str,
409        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
410    ) -> Result<QueryResponse> {
411        if !self.in_active_transaction() {
412            self.route_hint.next();
413        }
414        let endpoint = self.endpoint.join("v1/query")?;
415
416        // body
417        let session_state = self.session_state();
418        let need_sticky = session_state.need_sticky.unwrap_or(false);
419        let req = QueryRequest::new(sql)
420            .with_pagination(self.make_pagination())
421            .with_session(Some(session_state))
422            .with_stage_attachment(stage_attachment_config);
423
424        // headers
425        let query_id = self.gen_query_id();
426        let mut headers = self.make_headers(Some(&query_id))?;
427        if need_sticky {
428            if let Some(node_id) = self.last_node_id() {
429                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
430            }
431        }
432        let mut builder = self.cli.post(endpoint.clone()).json(&req);
433        builder = self.wrap_auth_or_session_token(builder)?;
434        let request = builder.headers(headers.clone()).build()?;
435        let response = self.query_request_helper(request, true, true).await?;
436        if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
437            self.route_hint.set(route_hint.to_str().unwrap_or_default());
438        }
439        let body = response.bytes().await?;
440        let result: QueryResponse = json_from_slice(&body)?;
441        self.handle_session(&result.session).await;
442        if let Some(err) = result.error {
443            return Err(Error::QueryFailed(err));
444        }
445
446        self.set_last_query_id(Some(query_id));
447        self.handle_warnings(&result);
448        if let Some(node_id) = &result.node_id {
449            self.set_last_node_id(node_id.clone());
450        }
451        Ok(result)
452    }
453
454    pub async fn query_page(
455        &self,
456        query_id: &str,
457        next_uri: &str,
458        node_id: &Option<String>,
459    ) -> Result<QueryResponse> {
460        info!("query page: {}", next_uri);
461        let endpoint = self.endpoint.join(next_uri)?;
462        let headers = self.make_headers(Some(query_id))?;
463        let mut builder = self.cli.get(endpoint.clone());
464        builder = self
465            .wrap_auth_or_session_token(builder)?
466            .headers(headers.clone())
467            .timeout(self.page_request_timeout);
468        if let Some(node_id) = node_id {
469            builder = builder.header(HEADER_STICKY_NODE, node_id)
470        }
471        let request = builder.build()?;
472
473        let response = self.query_request_helper(request, false, true).await?;
474        let body = response.bytes().await?;
475        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
476            if let Error::Logic(status, ec) = &e {
477                if *status == 404 {
478                    return Error::QueryNotFound(ec.message.clone());
479                }
480            }
481            e
482        })?;
483        self.handle_session(&resp.session).await;
484        match resp.error {
485            Some(err) => Err(Error::QueryFailed(err)),
486            None => Ok(resp),
487        }
488    }
489
490    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
491        let kill_uri = format!("/v1/query/{}/kill", query_id);
492        let endpoint = self.endpoint.join(&kill_uri)?;
493        let headers = self.make_headers(Some(query_id))?;
494        info!("kill query: {}", kill_uri);
495
496        let mut builder = self.cli.post(endpoint);
497        builder = self.wrap_auth_or_session_token(builder)?;
498        let resp = builder.headers(headers.clone()).send().await?;
499        if resp.status() != 200 {
500            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
501                .with_context("kill query"));
502        }
503        Ok(())
504    }
505
506    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
507        let mut pages = self.start_query(sql, false).await?;
508        let mut all = Page::default();
509        while let Some(page) = pages.next().await {
510            all.update(page?);
511        }
512        Ok(all)
513    }
514
515    fn session_state(&self) -> SessionState {
516        self.session_state.lock().clone()
517    }
518
519    fn make_pagination(&self) -> Option<PaginationConfig> {
520        if self.wait_time_secs.is_none()
521            && self.max_rows_in_buffer.is_none()
522            && self.max_rows_per_page.is_none()
523        {
524            return None;
525        }
526        let mut pagination = PaginationConfig {
527            wait_time_secs: None,
528            max_rows_in_buffer: None,
529            max_rows_per_page: None,
530        };
531        if let Some(wait_time_secs) = self.wait_time_secs {
532            pagination.wait_time_secs = Some(wait_time_secs);
533        }
534        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
535            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
536        }
537        if let Some(max_rows_per_page) = self.max_rows_per_page {
538            pagination.max_rows_per_page = Some(max_rows_per_page);
539        }
540        Some(pagination)
541    }
542
543    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
544        let mut headers = HeaderMap::new();
545        if let Some(tenant) = &self.tenant {
546            headers.insert(HEADER_TENANT, tenant.parse()?);
547        }
548        let warehouse = self.warehouse.lock().clone();
549        if let Some(warehouse) = warehouse {
550            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
551        }
552        let route_hint = self.route_hint.current();
553        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
554        if let Some(query_id) = query_id {
555            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
556        }
557        Ok(headers)
558    }
559
560    pub async fn insert_with_stage(
561        self: &Arc<Self>,
562        sql: &str,
563        stage: &str,
564        file_format_options: BTreeMap<&str, &str>,
565        copy_options: BTreeMap<&str, &str>,
566    ) -> Result<QueryStats> {
567        info!(
568            "insert with stage: {}, format: {:?}, copy: {:?}",
569            sql, file_format_options, copy_options
570        );
571        let stage_attachment = Some(StageAttachmentConfig {
572            location: stage,
573            file_format_options: Some(file_format_options),
574            copy_options: Some(copy_options),
575        });
576        let resp = self.start_query_inner(sql, stage_attachment).await?;
577        let mut pages = Pages::new(self.clone(), resp, false);
578        let mut all = Page::default();
579        while let Some(page) = pages.next().await {
580            all.update(page?);
581        }
582        Ok(all.stats)
583    }
584
585    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
586        info!("get presigned upload url: {}", stage);
587        let sql = format!("PRESIGN UPLOAD {}", stage);
588        let resp = self.query_all(&sql).await?;
589        if resp.data.len() != 1 {
590            return Err(Error::Decode(
591                "Empty response from server for presigned request".to_string(),
592            ));
593        }
594        if resp.data[0].len() != 3 {
595            return Err(Error::Decode(
596                "Invalid response from server for presigned request".to_string(),
597            ));
598        }
599        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
600        let method = resp.data[0][0].clone().unwrap_or_default();
601        if method != "PUT" {
602            return Err(Error::Decode(format!(
603                "Invalid method for presigned upload request: {}",
604                method
605            )));
606        }
607        let headers: BTreeMap<String, String> =
608            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
609        let url = resp.data[0][2].clone().unwrap_or_default();
610        Ok(PresignedResponse {
611            method,
612            headers,
613            url,
614        })
615    }
616
617    pub async fn upload_to_stage(
618        self: &Arc<Self>,
619        stage: &str,
620        data: Reader,
621        size: u64,
622    ) -> Result<()> {
623        match self.get_presign_mode() {
624            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
625            PresignMode::On => {
626                let presigned = self.get_presigned_upload_url(stage).await?;
627                presign_upload_to_stage(presigned, data, size).await
628            }
629            PresignMode::Auto => {
630                unreachable!("PresignMode::Auto should be handled during client initialization")
631            }
632            PresignMode::Detect => {
633                unreachable!("PresignMode::Detect should be handled during client initialization")
634            }
635        }
636    }
637
638    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
639    async fn upload_to_stage_with_stream(
640        &self,
641        stage: &str,
642        data: Reader,
643        size: u64,
644    ) -> Result<()> {
645        info!("upload to stage with stream: {}, size: {}", stage, size);
646        if let Some(info) = self.need_pre_refresh_session().await {
647            self.refresh_session_token(info).await?;
648        }
649        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
650        let location = StageLocation::try_from(stage)?;
651        let query_id = self.gen_query_id();
652        let mut headers = self.make_headers(Some(&query_id))?;
653        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
654        let stream = Body::wrap_stream(ReaderStream::new(data));
655        let part = Part::stream_with_length(stream, size).file_name(location.path);
656        let form = Form::new().part("upload", part);
657        let mut builder = self.cli.put(endpoint.clone());
658        builder = self.wrap_auth_or_session_token(builder)?;
659        let resp = builder.headers(headers).multipart(form).send().await?;
660        let status = resp.status();
661        if status != 200 {
662            return Err(
663                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
664            );
665        }
666        Ok(())
667    }
668
669    async fn login(&mut self) -> Result<()> {
670        let endpoint = self.endpoint.join("/v1/session/login")?;
671        let headers = self.make_headers(None)?;
672        let body = LoginRequest::from(&*self.session_state.lock());
673        let mut builder = self.cli.post(endpoint.clone()).json(&body);
674        if self.disable_session_token {
675            builder = builder.query(&[("disable_session_token", true)]);
676        }
677        let builder = self.auth.wrap(builder)?;
678        let request = builder
679            .headers(headers.clone())
680            .timeout(self.connect_timeout)
681            .build()?;
682        let response = self.query_request_helper(request, true, false).await;
683        let response = match response {
684            Ok(r) => r,
685            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
686                info!("login return 404, skip login on the old version server");
687                return Ok(());
688            }
689            Err(e) => return Err(e),
690        };
691        let body = response.bytes().await?;
692        let response = json_from_slice(&body)?;
693        match response {
694            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
695            LoginResponseResult::Ok(info) => {
696                self.server_version = Some(info.version.clone());
697                if let Some(tokens) = info.tokens {
698                    info!("login success with session token");
699                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
700                }
701                info!("login success without session token");
702            }
703        }
704        Ok(())
705    }
706
707    fn build_log_out_request(&self) -> Result<Request> {
708        let endpoint = self.endpoint.join("/v1/session/logout")?;
709
710        let session_state = self.session_state();
711        let need_sticky = session_state.need_sticky.unwrap_or(false);
712        let mut headers = self.make_headers(None)?;
713        if need_sticky {
714            if let Some(node_id) = self.last_node_id() {
715                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
716            }
717        }
718        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
719
720        let builder = self.wrap_auth_or_session_token(builder)?;
721        let req = builder.build()?;
722        Ok(req)
723    }
724
725    fn need_logout(&self) -> bool {
726        (self.session_token_info.is_some()
727            || self.session_state.lock().need_keep_alive.unwrap_or(false))
728            && self
729                .closed
730                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
731                .is_ok()
732    }
733
734    async fn refresh_session_token(
735        &self,
736        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
737    ) -> Result<()> {
738        let (session_token_info, _) = { self_login_info.lock().clone() };
739        let endpoint = self.endpoint.join("/v1/session/refresh")?;
740        let body = RefreshSessionTokenRequest {
741            session_token: session_token_info.session_token.clone(),
742        };
743        let headers = self.make_headers(None)?;
744        let request = self
745            .cli
746            .post(endpoint.clone())
747            .json(&body)
748            .headers(headers.clone())
749            .bearer_auth(session_token_info.refresh_token.clone())
750            .timeout(self.connect_timeout)
751            .build()?;
752
753        // avoid recursively call request_helper
754        for i in 0..3 {
755            let req = request.try_clone().expect("request not cloneable");
756            match self.cli.execute(req).await {
757                Ok(response) => {
758                    let status = response.status();
759                    let body = response.bytes().await?;
760                    if status == StatusCode::OK {
761                        let response = json_from_slice(&body)?;
762                        return match response {
763                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
764                            RefreshResponse::Ok(info) => {
765                                *self_login_info.lock() = (info, Instant::now());
766                                Ok(())
767                            }
768                        };
769                    }
770                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
771                        return Err(Error::response_error(status, &body));
772                    }
773                }
774                Err(err) => {
775                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
776                        return Err(Error::Request(err.to_string()));
777                    }
778                }
779            };
780            sleep(jitter(Duration::from_secs(10))).await;
781        }
782        Ok(())
783    }
784
785    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
786        if let Some(info) = &self.session_token_info {
787            let (start, ttl) = {
788                let guard = info.lock();
789                (guard.1, guard.0.session_token_ttl_in_secs)
790            };
791            if Instant::now() > start + Duration::from_secs(ttl) {
792                return Some(info.clone());
793            }
794        }
795        None
796    }
797
798    /// return Ok if and only if status code is 200.
799    ///
800    /// retry on
801    ///   - network errors
802    ///   - (optional) 503
803    ///
804    /// refresh databend token or reload jwt token if needed.
805    async fn query_request_helper(
806        &self,
807        mut request: Request,
808        retry_if_503: bool,
809        refresh_if_401: bool,
810    ) -> std::result::Result<Response, Error> {
811        let mut refreshed = false;
812        let mut retries = 0;
813        loop {
814            let req = request.try_clone().expect("request not cloneable");
815            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
816                Ok(response) => {
817                    let status = response.status();
818                    if status == StatusCode::OK {
819                        return Ok(response);
820                    }
821                    let body = response.bytes().await?;
822                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
823                        // waiting for server to start
824                        (Error::response_error(status, &body), true)
825                    } else {
826                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
827                        match resp {
828                            Ok(r) => {
829                                let e = r.error;
830                                if status == StatusCode::UNAUTHORIZED {
831                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
832                                    if let Some(session_token_info) = &self.session_token_info {
833                                        info!(
834                                            "will retry {} after refresh token on auth error {}",
835                                            request.url(),
836                                            e
837                                        );
838                                        let retry = if need_refresh_token(e.code)
839                                            && !refreshed
840                                            && refresh_if_401
841                                        {
842                                            self.refresh_session_token(session_token_info.clone())
843                                                .await?;
844                                            refreshed = true;
845                                            true
846                                        } else {
847                                            false
848                                        };
849                                        (Error::AuthFailure(e), retry)
850                                    } else if self.auth.can_reload() {
851                                        info!(
852                                            "will retry {} after reload token on auth error {}",
853                                            request.url(),
854                                            e
855                                        );
856                                        let builder = RequestBuilder::from_parts(
857                                            HttpClient::new(),
858                                            request.try_clone().unwrap(),
859                                        );
860                                        let builder = self.auth.wrap(builder)?;
861                                        request = builder.build()?;
862                                        (Error::AuthFailure(e), true)
863                                    } else {
864                                        (Error::AuthFailure(e), false)
865                                    }
866                                } else {
867                                    (Error::Logic(status, e), false)
868                                }
869                            }
870                            Err(_) => (
871                                Error::Response {
872                                    status,
873                                    msg: String::from_utf8_lossy(&body).to_string(),
874                                },
875                                false,
876                            ),
877                        }
878                    }
879                }
880                Err(err) => (
881                    Error::Request(err.to_string()),
882                    err.is_timeout() || err.is_connect(),
883                ),
884            };
885            if !retry {
886                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
887            }
888            match &err {
889                Error::AuthFailure(_) => {
890                    if refreshed {
891                        retries = 0;
892                    } else if retries == 2 {
893                        return Err(err.with_context(&format!(
894                            "{} {} after 3 reties",
895                            request.method(),
896                            request.url()
897                        )));
898                    }
899                }
900                _ => {
901                    if retries == 2 {
902                        return Err(err.with_context(&format!(
903                            "{} {} after 3 reties",
904                            request.method(),
905                            request.url()
906                        )));
907                    }
908                    retries += 1;
909                    info!(
910                        "will retry {} the {retries}th times on error {}",
911                        request.url(),
912                        err
913                    );
914                }
915            }
916            sleep(jitter(Duration::from_secs(10))).await;
917        }
918    }
919
920    pub async fn close(&self) {
921        if self.need_logout() {
922            let cli = self.cli.clone();
923            let req = self
924                .build_log_out_request()
925                .expect("failed to build logout request");
926            if let Err(err) = cli.execute(req).await {
927                error!("logout request failed: {}", err);
928            } else {
929                debug!("logout success");
930            };
931        }
932    }
933}
934
935impl Drop for APIClient {
936    fn drop(&mut self) {
937        if self.need_logout() {
938            warn!("APIClient::close() was not called");
939        }
940    }
941}
942
943fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
944where
945    T: Deserialize<'a>,
946{
947    serde_json::from_slice::<T>(body).map_err(|e| {
948        Error::Decode(format!(
949            "fail to decode JSON response: {}, body: {}",
950            e,
951            String::from_utf8_lossy(body)
952        ))
953    })
954}
955
956impl Default for APIClient {
957    fn default() -> Self {
958        Self {
959            cli: HttpClient::new(),
960            scheme: "http".to_string(),
961            endpoint: Url::parse("http://localhost:8080").unwrap(),
962            host: "localhost".to_string(),
963            port: 8000,
964            tenant: None,
965            warehouse: Mutex::new(None),
966            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
967            session_state: Mutex::new(SessionState::default()),
968            wait_time_secs: None,
969            max_rows_in_buffer: None,
970            max_rows_per_page: None,
971            connect_timeout: Duration::from_secs(10),
972            page_request_timeout: Duration::from_secs(30),
973            tls_ca_file: None,
974            presign: Mutex::new(PresignMode::Auto),
975            route_hint: RouteHintGenerator::new(),
976            last_node_id: Default::default(),
977            disable_session_token: true,
978            disable_login: false,
979            session_token_info: None,
980            closed: Default::default(),
981            last_query_id: Default::default(),
982            server_version: None,
983        }
984    }
985}
986
987struct RouteHintGenerator {
988    nonce: AtomicU64,
989    current: std::sync::Mutex<String>,
990}
991
992impl RouteHintGenerator {
993    fn new() -> Self {
994        let gen = Self {
995            nonce: AtomicU64::new(0),
996            current: std::sync::Mutex::new("".to_string()),
997        };
998        gen.next();
999        gen
1000    }
1001
1002    fn current(&self) -> String {
1003        let guard = self.current.lock().unwrap();
1004        guard.clone()
1005    }
1006
1007    fn set(&self, hint: &str) {
1008        let mut guard = self.current.lock().unwrap();
1009        *guard = hint.to_string();
1010    }
1011
1012    fn next(&self) -> String {
1013        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1014        let uuid = uuid::Uuid::new_v4();
1015        let current = format!("rh:{}:{:06}", uuid, nonce);
1016        let mut guard = self.current.lock().unwrap();
1017        guard.clone_from(&current);
1018        current
1019    }
1020}
1021
#[cfg(test)]
mod test {
    use super::*;

    /// Query-string options in the DSN are reflected on the client.
    /// With `sslmode=disable` and no explicit port, the endpoint is plain HTTP on port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let parsed = APIClient::from_dsn("databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable").await?;
        assert_eq!(parsed.host, "app.databend.com");
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(parsed.endpoint, expected_endpoint);
        assert_eq!(parsed.wait_time_secs, Some(10));
        assert_eq!(parsed.max_rows_in_buffer, Some(5000000));
        assert_eq!(parsed.max_rows_per_page, Some(10000));
        assert_eq!(parsed.tenant, None);
        let warehouse = parsed.warehouse.try_lock().unwrap();
        assert_eq!(*warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password parses.
    /// Without an explicit port, the client reports port 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let parsed = APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(parsed.host(), "localhost");
        assert_eq!(parsed.port(), 443);
        Ok(())
    }

    /// A raw (unencoded) password containing `@`, `=`, `(` and `{` still parses,
    /// and the explicit port is kept.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let parsed = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(parsed.host(), "localhost");
        assert_eq!(parsed.port(), 8000);
        Ok(())
    }
}