databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
22use crate::global_cookie_store::GlobalCookieStore;
23use crate::login::{
24    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
25    SessionTokenInfo,
26};
27use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
28use crate::stage::StageLocation;
29use crate::{
30    error::{Error, Result},
31    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
32    response::QueryResponse,
33    session::SessionState,
34    QueryStats,
35};
36use crate::{Page, Pages};
37use log::{debug, error, info, warn};
38use once_cell::sync::Lazy;
39use parking_lot::Mutex;
40use percent_encoding::percent_decode_str;
41use reqwest::cookie::CookieStore;
42use reqwest::header::{HeaderMap, HeaderValue};
43use reqwest::multipart::{Form, Part};
44use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
45use serde::Deserialize;
46use tokio::time::sleep;
47use tokio_retry::strategy::jitter;
48use tokio_stream::StreamExt;
49use tokio_util::io::ReaderStream;
50use url::Url;
51
// Request headers understood by the Databend HTTP handler.

/// Carries the client-generated id of the query a request belongs to.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Names the stage targeted by a stream upload.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Carries the node id a request should be pinned to (sticky sessions).
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Route hint sent on every request and updated from query responses.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Tenant selector, sent only when the DSN sets `tenant`.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Warehouse selector, sent only when a warehouse is configured.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";

/// Session `txn_state` value (compared ASCII-case-insensitively) meaning a transaction is open.
const TXN_STATE_ACTIVE: &str = "Active";
59
60static VERSION: Lazy<String> = Lazy::new(|| {
61    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
62    version.to_string()
63});
64
/// HTTP client for the Databend query API.
///
/// Created by `APIClient::new` from a DSN. Connection settings parsed from the DSN are plain
/// fields; state that changes while the client is shared behind an `Arc` (session, warehouse,
/// presign mode, last node/query ids) lives behind `parking_lot` mutexes.
pub struct APIClient {
    // Underlying reqwest client, built by `build_client`.
    cli: HttpClient,
    // "http" or "https", selected by the DSN `sslmode` option (https when absent).
    scheme: String,
    host: String,
    // Explicit DSN port, otherwise 80 for http / 443 for https.
    port: u16,

    // Base URL `scheme://host:port` that API paths are joined onto.
    endpoint: Url,

    // Credentials: basic auth from the DSN user info, or an access token (inline or from a file).
    auth: Arc<dyn Auth>,

    // Sent as the tenant header when set.
    tenant: Option<String>,
    // Sent as the warehouse header; overwritten when the server's session settings carry one.
    warehouse: Mutex<Option<String>>,
    // Session state sent with each query and replaced by the session the server returns.
    session_state: Mutex<SessionState>,
    // Route hint sent in request headers; advanced per query when no transaction is active.
    route_hint: RouteHintGenerator,

    // DSN `login=disable`: `new` skips the login call.
    disable_login: bool,
    // DSN `session_token=disable`: login adds `disable_session_token=true` to its query string.
    disable_session_token: bool,
    // Tokens returned by login, paired with the instant they were obtained or last refreshed.
    session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

    // Flipped false -> true by `need_logout`, so that method reports `true` at most once.
    closed: AtomicBool,

    // Server version reported by a successful login.
    server_version: Option<String>,

    // Optional pagination overrides from the DSN; see `make_pagination`.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    // Request timeout used by login and session-token refresh.
    connect_timeout: Duration,
    // Request timeout applied to each page fetch.
    page_request_timeout: Duration,

    // Extra PEM root certificate trusted for https (only with a TLS feature enabled).
    tls_ca_file: Option<String>,

    // Stage upload strategy; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    // Node id from the last query response, used for sticky routing.
    last_node_id: Mutex<Option<String>>,
    // Last query id recorded via `set_last_query_id`.
    last_query_id: Mutex<Option<String>>,
}
101
102impl APIClient {
103    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
104        let mut client = Self::from_dsn(dsn).await?;
105        client.build_client(name).await?;
106        if !client.disable_login {
107            client.login().await?;
108        }
109        let client = Arc::new(client);
110        client.check_presign().await?;
111        Ok(client)
112    }
113
114    fn set_presign_mode(&self, mode: PresignMode) {
115        *self.presign.lock() = mode
116    }
117    fn get_presign_mode(&self) -> PresignMode {
118        *self.presign.lock()
119    }
120
121    async fn from_dsn(dsn: &str) -> Result<Self> {
122        let u = Url::parse(dsn)?;
123        let mut client = Self::default();
124        if let Some(host) = u.host_str() {
125            client.host = host.to_string();
126        }
127
128        if u.username() != "" {
129            let password = u.password().unwrap_or_default();
130            let password = percent_decode_str(password).decode_utf8()?;
131            client.auth = Arc::new(BasicAuth::new(u.username(), password));
132        }
133        let database = match u.path().trim_start_matches('/') {
134            "" => None,
135            s => Some(s.to_string()),
136        };
137        let mut role = None;
138        let mut scheme = "https";
139        let mut session_settings = BTreeMap::new();
140        for (k, v) in u.query_pairs() {
141            match k.as_ref() {
142                "wait_time_secs" => {
143                    client.wait_time_secs = Some(v.parse()?);
144                }
145                "max_rows_in_buffer" => {
146                    client.max_rows_in_buffer = Some(v.parse()?);
147                }
148                "max_rows_per_page" => {
149                    client.max_rows_per_page = Some(v.parse()?);
150                }
151                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
152                "page_request_timeout_secs" => {
153                    client.page_request_timeout = {
154                        let secs: u64 = v.parse()?;
155                        Duration::from_secs(secs)
156                    };
157                }
158                "presign" => {
159                    let presign_mode = match v.as_ref() {
160                        "auto" => PresignMode::Auto,
161                        "detect" => PresignMode::Detect,
162                        "on" => PresignMode::On,
163                        "off" => PresignMode::Off,
164                        _ => {
165                            return Err(Error::BadArgument(format!(
166                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
167                        )))
168                        }
169                    };
170                    client.set_presign_mode(presign_mode);
171                }
172                "tenant" => {
173                    client.tenant = Some(v.to_string());
174                }
175                "warehouse" => {
176                    client.warehouse = Mutex::new(Some(v.to_string()));
177                }
178                "role" => role = Some(v.to_string()),
179                "sslmode" => match v.as_ref() {
180                    "disable" => scheme = "http",
181                    "require" | "enable" => scheme = "https",
182                    _ => {
183                        return Err(Error::BadArgument(format!(
184                            "Invalid value for sslmode: {v}"
185                        )))
186                    }
187                },
188                "tls_ca_file" => {
189                    client.tls_ca_file = Some(v.to_string());
190                }
191                "access_token" => {
192                    client.auth = Arc::new(AccessTokenAuth::new(v));
193                }
194                "access_token_file" => {
195                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
196                }
197                "login" => {
198                    client.disable_login = match v.as_ref() {
199                        "disable" => true,
200                        "enable" => false,
201                        _ => {
202                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
203                        }
204                    }
205                }
206                "session_token" => {
207                    client.disable_session_token = match v.as_ref() {
208                        "disable" => true,
209                        "enable" => false,
210                        _ => {
211                            return Err(Error::BadArgument(format!(
212                                "Invalid value for session_token: {v}"
213                            )))
214                        }
215                    }
216                }
217                _ => {
218                    session_settings.insert(k.to_string(), v.to_string());
219                }
220            }
221        }
222        client.port = match u.port() {
223            Some(p) => p,
224            None => match scheme {
225                "http" => 80,
226                "https" => 443,
227                _ => unreachable!(),
228            },
229        };
230        client.scheme = scheme.to_string();
231
232        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
233        client.session_state = Mutex::new(
234            SessionState::default()
235                .with_settings(Some(session_settings))
236                .with_role(role)
237                .with_database(database),
238        );
239
240        Ok(client)
241    }
242
243    pub fn host(&self) -> &str {
244        self.host.as_str()
245    }
246
247    pub fn port(&self) -> u16 {
248        self.port
249    }
250
251    pub fn scheme(&self) -> &str {
252        self.scheme.as_str()
253    }
254
255    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
256        let ua = match name {
257            Some(n) => n,
258            None => format!("databend-client-rust/{}", VERSION.as_str()),
259        };
260        let cookie_provider = GlobalCookieStore::new();
261        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
262        let mut initial_cookies = [&cookie].into_iter();
263        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
264        let mut cli_builder = HttpClient::builder()
265            .user_agent(ua)
266            .cookie_provider(Arc::new(cookie_provider))
267            .pool_idle_timeout(Duration::from_secs(1));
268        #[cfg(any(feature = "rustls", feature = "native-tls"))]
269        if self.scheme == "https" {
270            if let Some(ref ca_file) = self.tls_ca_file {
271                let cert_pem = tokio::fs::read(ca_file).await?;
272                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
273                cli_builder = cli_builder.add_root_certificate(cert);
274            }
275        }
276        self.cli = cli_builder.build()?;
277        Ok(())
278    }
279
280    async fn check_presign(self: &Arc<Self>) -> Result<()> {
281        let mode = match self.get_presign_mode() {
282            PresignMode::Auto => {
283                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
284                    PresignMode::On
285                } else {
286                    PresignMode::Off
287                }
288            }
289            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
290                Ok(_) => PresignMode::On,
291                Err(e) => {
292                    warn!("presign mode off with error detected: {e}");
293                    PresignMode::Off
294                }
295            },
296            mode => mode,
297        };
298        self.set_presign_mode(mode);
299        Ok(())
300    }
301
302    pub fn current_warehouse(&self) -> Option<String> {
303        let guard = self.warehouse.lock();
304        guard.clone()
305    }
306
307    pub fn current_catalog(&self) -> Option<String> {
308        let guard = self.session_state.lock();
309        guard.catalog.clone()
310    }
311
312    pub fn current_database(&self) -> Option<String> {
313        let guard = self.session_state.lock();
314        guard.database.clone()
315    }
316
317    pub async fn current_role(&self) -> Option<String> {
318        let guard = self.session_state.lock();
319        guard.role.clone()
320    }
321
322    fn in_active_transaction(&self) -> bool {
323        let guard = self.session_state.lock();
324        guard
325            .txn_state
326            .as_ref()
327            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
328            .unwrap_or(false)
329    }
330
331    pub fn username(&self) -> String {
332        self.auth.username()
333    }
334
335    fn gen_query_id(&self) -> String {
336        uuid::Uuid::now_v7().simple().to_string()
337    }
338
339    async fn handle_session(&self, session: &Option<SessionState>) {
340        let session = match session {
341            Some(session) => session,
342            None => return,
343        };
344
345        // save the updated session state from the server side
346        {
347            let mut session_state = self.session_state.lock();
348            *session_state = session.clone();
349        }
350
351        // process warehouse changed via session settings
352        if let Some(settings) = session.settings.as_ref() {
353            if let Some(v) = settings.get("warehouse") {
354                let mut warehouse = self.warehouse.lock();
355                *warehouse = Some(v.clone());
356            }
357        }
358    }
359
360    pub fn set_last_node_id(&self, node_id: String) {
361        *self.last_node_id.lock() = Some(node_id)
362    }
363
364    pub fn set_last_query_id(&self, query_id: Option<String>) {
365        *self.last_query_id.lock() = query_id
366    }
367
368    pub fn last_query_id(&self) -> Option<String> {
369        self.last_query_id.lock().clone()
370    }
371
372    fn last_node_id(&self) -> Option<String> {
373        self.last_node_id.lock().clone()
374    }
375
376    fn handle_warnings(&self, resp: &QueryResponse) {
377        if let Some(warnings) = &resp.warnings {
378            for w in warnings {
379                warn!(target: "server_warnings", "server warning: {w}");
380            }
381        }
382    }
383
384    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
385        info!("start query: {sql}");
386        let resp = self.start_query_inner(sql, None).await?;
387        let pages = Pages::new(self.clone(), resp, need_progress);
388        Ok(pages)
389    }
390
391    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
392        if let Some(info) = &self.session_token_info {
393            let info = info.lock();
394            Ok(builder.bearer_auth(info.0.session_token.clone()))
395        } else {
396            self.auth.wrap(builder)
397        }
398    }
399
400    async fn start_query_inner(
401        &self,
402        sql: &str,
403        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
404    ) -> Result<QueryResponse> {
405        if !self.in_active_transaction() {
406            self.route_hint.next();
407        }
408        let endpoint = self.endpoint.join("v1/query")?;
409
410        // body
411        let session_state = self.session_state();
412        let need_sticky = session_state.need_sticky.unwrap_or(false);
413        let req = QueryRequest::new(sql)
414            .with_pagination(self.make_pagination())
415            .with_session(Some(session_state))
416            .with_stage_attachment(stage_attachment_config);
417
418        // headers
419        let query_id = self.gen_query_id();
420        let mut headers = self.make_headers(Some(&query_id))?;
421        if need_sticky {
422            if let Some(node_id) = self.last_node_id() {
423                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
424            }
425        }
426        let mut builder = self.cli.post(endpoint.clone()).json(&req);
427        builder = self.wrap_auth_or_session_token(builder)?;
428        let request = builder.headers(headers.clone()).build()?;
429        let response = self.query_request_helper(request, true, true).await?;
430        if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
431            self.route_hint.set(route_hint.to_str().unwrap_or_default());
432        }
433        let body = response.bytes().await?;
434        let result: QueryResponse = json_from_slice(&body)?;
435        self.handle_session(&result.session).await;
436        if let Some(err) = result.error {
437            return Err(Error::QueryFailed(err));
438        }
439
440        self.set_last_query_id(Some(query_id));
441        self.handle_warnings(&result);
442        if let Some(node_id) = &result.node_id {
443            self.set_last_node_id(node_id.clone());
444        }
445        Ok(result)
446    }
447
448    pub async fn query_page(
449        &self,
450        query_id: &str,
451        next_uri: &str,
452        node_id: &Option<String>,
453    ) -> Result<QueryResponse> {
454        info!("query page: {next_uri}");
455        let endpoint = self.endpoint.join(next_uri)?;
456        let headers = self.make_headers(Some(query_id))?;
457        let mut builder = self.cli.get(endpoint.clone());
458        builder = self
459            .wrap_auth_or_session_token(builder)?
460            .headers(headers.clone())
461            .timeout(self.page_request_timeout);
462        if let Some(node_id) = node_id {
463            builder = builder.header(HEADER_STICKY_NODE, node_id)
464        }
465        let request = builder.build()?;
466
467        let response = self.query_request_helper(request, false, true).await?;
468        let body = response.bytes().await?;
469        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
470            if let Error::Logic(status, ec) = &e {
471                if *status == 404 {
472                    return Error::QueryNotFound(ec.message.clone());
473                }
474            }
475            e
476        })?;
477        self.handle_session(&resp.session).await;
478        match resp.error {
479            Some(err) => Err(Error::QueryFailed(err)),
480            None => Ok(resp),
481        }
482    }
483
484    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
485        let kill_uri = format!("/v1/query/{query_id}/kill");
486        let endpoint = self.endpoint.join(&kill_uri)?;
487        let headers = self.make_headers(Some(query_id))?;
488        info!("kill query: {kill_uri}");
489
490        let mut builder = self.cli.post(endpoint);
491        builder = self.wrap_auth_or_session_token(builder)?;
492        let resp = builder.headers(headers.clone()).send().await?;
493        if resp.status() != 200 {
494            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
495                .with_context("kill query"));
496        }
497        Ok(())
498    }
499
500    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
501        let mut pages = self.start_query(sql, false).await?;
502        let mut all = Page::default();
503        while let Some(page) = pages.next().await {
504            all.update(page?);
505        }
506        Ok(all)
507    }
508
509    fn session_state(&self) -> SessionState {
510        self.session_state.lock().clone()
511    }
512
513    fn make_pagination(&self) -> Option<PaginationConfig> {
514        if self.wait_time_secs.is_none()
515            && self.max_rows_in_buffer.is_none()
516            && self.max_rows_per_page.is_none()
517        {
518            return None;
519        }
520        let mut pagination = PaginationConfig {
521            wait_time_secs: None,
522            max_rows_in_buffer: None,
523            max_rows_per_page: None,
524        };
525        if let Some(wait_time_secs) = self.wait_time_secs {
526            pagination.wait_time_secs = Some(wait_time_secs);
527        }
528        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
529            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
530        }
531        if let Some(max_rows_per_page) = self.max_rows_per_page {
532            pagination.max_rows_per_page = Some(max_rows_per_page);
533        }
534        Some(pagination)
535    }
536
537    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
538        let mut headers = HeaderMap::new();
539        if let Some(tenant) = &self.tenant {
540            headers.insert(HEADER_TENANT, tenant.parse()?);
541        }
542        let warehouse = self.warehouse.lock().clone();
543        if let Some(warehouse) = warehouse {
544            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
545        }
546        let route_hint = self.route_hint.current();
547        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
548        if let Some(query_id) = query_id {
549            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
550        }
551        Ok(headers)
552    }
553
554    pub async fn insert_with_stage(
555        self: &Arc<Self>,
556        sql: &str,
557        stage: &str,
558        file_format_options: BTreeMap<&str, &str>,
559        copy_options: BTreeMap<&str, &str>,
560    ) -> Result<QueryStats> {
561        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
562        let stage_attachment = Some(StageAttachmentConfig {
563            location: stage,
564            file_format_options: Some(file_format_options),
565            copy_options: Some(copy_options),
566        });
567        let resp = self.start_query_inner(sql, stage_attachment).await?;
568        let mut pages = Pages::new(self.clone(), resp, false);
569        let mut all = Page::default();
570        while let Some(page) = pages.next().await {
571            all.update(page?);
572        }
573        Ok(all.stats)
574    }
575
576    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
577        info!("get presigned upload url: {stage}");
578        let sql = format!("PRESIGN UPLOAD {stage}");
579        let resp = self.query_all(&sql).await?;
580        if resp.data.len() != 1 {
581            return Err(Error::Decode(
582                "Empty response from server for presigned request".to_string(),
583            ));
584        }
585        if resp.data[0].len() != 3 {
586            return Err(Error::Decode(
587                "Invalid response from server for presigned request".to_string(),
588            ));
589        }
590        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
591        let method = resp.data[0][0].clone().unwrap_or_default();
592        if method != "PUT" {
593            return Err(Error::Decode(format!(
594                "Invalid method for presigned upload request: {method}"
595            )));
596        }
597        let headers: BTreeMap<String, String> =
598            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
599        let url = resp.data[0][2].clone().unwrap_or_default();
600        Ok(PresignedResponse {
601            method,
602            headers,
603            url,
604        })
605    }
606
607    pub async fn upload_to_stage(
608        self: &Arc<Self>,
609        stage: &str,
610        data: Reader,
611        size: u64,
612    ) -> Result<()> {
613        match self.get_presign_mode() {
614            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
615            PresignMode::On => {
616                let presigned = self.get_presigned_upload_url(stage).await?;
617                presign_upload_to_stage(presigned, data, size).await
618            }
619            PresignMode::Auto => {
620                unreachable!("PresignMode::Auto should be handled during client initialization")
621            }
622            PresignMode::Detect => {
623                unreachable!("PresignMode::Detect should be handled during client initialization")
624            }
625        }
626    }
627
628    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
629    async fn upload_to_stage_with_stream(
630        &self,
631        stage: &str,
632        data: Reader,
633        size: u64,
634    ) -> Result<()> {
635        info!("upload to stage with stream: {stage}, size: {size}");
636        if let Some(info) = self.need_pre_refresh_session().await {
637            self.refresh_session_token(info).await?;
638        }
639        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
640        let location = StageLocation::try_from(stage)?;
641        let query_id = self.gen_query_id();
642        let mut headers = self.make_headers(Some(&query_id))?;
643        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
644        let stream = Body::wrap_stream(ReaderStream::new(data));
645        let part = Part::stream_with_length(stream, size).file_name(location.path);
646        let form = Form::new().part("upload", part);
647        let mut builder = self.cli.put(endpoint.clone());
648        builder = self.wrap_auth_or_session_token(builder)?;
649        let resp = builder.headers(headers).multipart(form).send().await?;
650        let status = resp.status();
651        if status != 200 {
652            return Err(
653                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
654            );
655        }
656        Ok(())
657    }
658
659    async fn login(&mut self) -> Result<()> {
660        let endpoint = self.endpoint.join("/v1/session/login")?;
661        let headers = self.make_headers(None)?;
662        let body = LoginRequest::from(&*self.session_state.lock());
663        let mut builder = self.cli.post(endpoint.clone()).json(&body);
664        if self.disable_session_token {
665            builder = builder.query(&[("disable_session_token", true)]);
666        }
667        let builder = self.auth.wrap(builder)?;
668        let request = builder
669            .headers(headers.clone())
670            .timeout(self.connect_timeout)
671            .build()?;
672        let response = self.query_request_helper(request, true, false).await;
673        let response = match response {
674            Ok(r) => r,
675            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
676                info!("login return 404, skip login on the old version server");
677                return Ok(());
678            }
679            Err(e) => return Err(e),
680        };
681        let body = response.bytes().await?;
682        let response = json_from_slice(&body)?;
683        match response {
684            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
685            LoginResponseResult::Ok(info) => {
686                self.server_version = Some(info.version.clone());
687                if let Some(tokens) = info.tokens {
688                    info!("login success with session token");
689                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
690                }
691                info!("login success without session token");
692            }
693        }
694        Ok(())
695    }
696
697    fn build_log_out_request(&self) -> Result<Request> {
698        let endpoint = self.endpoint.join("/v1/session/logout")?;
699
700        let session_state = self.session_state();
701        let need_sticky = session_state.need_sticky.unwrap_or(false);
702        let mut headers = self.make_headers(None)?;
703        if need_sticky {
704            if let Some(node_id) = self.last_node_id() {
705                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
706            }
707        }
708        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
709
710        let builder = self.wrap_auth_or_session_token(builder)?;
711        let req = builder.build()?;
712        Ok(req)
713    }
714
715    fn need_logout(&self) -> bool {
716        (self.session_token_info.is_some()
717            || self.session_state.lock().need_keep_alive.unwrap_or(false))
718            && self
719                .closed
720                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
721                .is_ok()
722    }
723
724    async fn refresh_session_token(
725        &self,
726        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
727    ) -> Result<()> {
728        let (session_token_info, _) = { self_login_info.lock().clone() };
729        let endpoint = self.endpoint.join("/v1/session/refresh")?;
730        let body = RefreshSessionTokenRequest {
731            session_token: session_token_info.session_token.clone(),
732        };
733        let headers = self.make_headers(None)?;
734        let request = self
735            .cli
736            .post(endpoint.clone())
737            .json(&body)
738            .headers(headers.clone())
739            .bearer_auth(session_token_info.refresh_token.clone())
740            .timeout(self.connect_timeout)
741            .build()?;
742
743        // avoid recursively call request_helper
744        for i in 0..3 {
745            let req = request.try_clone().expect("request not cloneable");
746            match self.cli.execute(req).await {
747                Ok(response) => {
748                    let status = response.status();
749                    let body = response.bytes().await?;
750                    if status == StatusCode::OK {
751                        let response = json_from_slice(&body)?;
752                        return match response {
753                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
754                            RefreshResponse::Ok(info) => {
755                                *self_login_info.lock() = (info, Instant::now());
756                                Ok(())
757                            }
758                        };
759                    }
760                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
761                        return Err(Error::response_error(status, &body));
762                    }
763                }
764                Err(err) => {
765                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
766                        return Err(Error::Request(err.to_string()));
767                    }
768                }
769            };
770            sleep(jitter(Duration::from_secs(10))).await;
771        }
772        Ok(())
773    }
774
775    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
776        if let Some(info) = &self.session_token_info {
777            let (start, ttl) = {
778                let guard = info.lock();
779                (guard.1, guard.0.session_token_ttl_in_secs)
780            };
781            if Instant::now() > start + Duration::from_secs(ttl) {
782                return Some(info.clone());
783            }
784        }
785        None
786    }
787
788    /// return Ok if and only if status code is 200.
789    ///
790    /// retry on
791    ///   - network errors
792    ///   - (optional) 503
793    ///
794    /// refresh databend token or reload jwt token if needed.
795    async fn query_request_helper(
796        &self,
797        mut request: Request,
798        retry_if_503: bool,
799        refresh_if_401: bool,
800    ) -> std::result::Result<Response, Error> {
801        let mut refreshed = false;
802        let mut retries = 0;
803        loop {
804            let req = request.try_clone().expect("request not cloneable");
805            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
806                Ok(response) => {
807                    let status = response.status();
808                    if status == StatusCode::OK {
809                        return Ok(response);
810                    }
811                    let body = response.bytes().await?;
812                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
813                        // waiting for server to start
814                        (Error::response_error(status, &body), true)
815                    } else {
816                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
817                        match resp {
818                            Ok(r) => {
819                                let e = r.error;
820                                if status == StatusCode::UNAUTHORIZED {
821                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
822                                    if let Some(session_token_info) = &self.session_token_info {
823                                        info!(
824                                            "will retry {} after refresh token on auth error {}",
825                                            request.url(),
826                                            e
827                                        );
828                                        let retry = if need_refresh_token(e.code)
829                                            && !refreshed
830                                            && refresh_if_401
831                                        {
832                                            self.refresh_session_token(session_token_info.clone())
833                                                .await?;
834                                            refreshed = true;
835                                            true
836                                        } else {
837                                            false
838                                        };
839                                        (Error::AuthFailure(e), retry)
840                                    } else if self.auth.can_reload() {
841                                        info!(
842                                            "will retry {} after reload token on auth error {}",
843                                            request.url(),
844                                            e
845                                        );
846                                        let builder = RequestBuilder::from_parts(
847                                            HttpClient::new(),
848                                            request.try_clone().unwrap(),
849                                        );
850                                        let builder = self.auth.wrap(builder)?;
851                                        request = builder.build()?;
852                                        (Error::AuthFailure(e), true)
853                                    } else {
854                                        (Error::AuthFailure(e), false)
855                                    }
856                                } else {
857                                    (Error::Logic(status, e), false)
858                                }
859                            }
860                            Err(_) => (
861                                Error::Response {
862                                    status,
863                                    msg: String::from_utf8_lossy(&body).to_string(),
864                                },
865                                false,
866                            ),
867                        }
868                    }
869                }
870                Err(err) => (
871                    Error::Request(err.to_string()),
872                    err.is_timeout() || err.is_connect(),
873                ),
874            };
875            if !retry {
876                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
877            }
878            match &err {
879                Error::AuthFailure(_) => {
880                    if refreshed {
881                        retries = 0;
882                    } else if retries == 2 {
883                        return Err(err.with_context(&format!(
884                            "{} {} after 3 reties",
885                            request.method(),
886                            request.url()
887                        )));
888                    }
889                }
890                _ => {
891                    if retries == 2 {
892                        return Err(err.with_context(&format!(
893                            "{} {} after 3 reties",
894                            request.method(),
895                            request.url()
896                        )));
897                    }
898                    retries += 1;
899                    info!(
900                        "will retry {} the {retries}th times on error {}",
901                        request.url(),
902                        err
903                    );
904                }
905            }
906            sleep(jitter(Duration::from_secs(10))).await;
907        }
908    }
909
910    pub async fn close(&self) {
911        if self.need_logout() {
912            let cli = self.cli.clone();
913            let req = self
914                .build_log_out_request()
915                .expect("failed to build logout request");
916            if let Err(err) = cli.execute(req).await {
917                error!("logout request failed: {err}");
918            } else {
919                debug!("logout success");
920            };
921        }
922    }
923}
924
925impl Drop for APIClient {
926    fn drop(&mut self) {
927        if self.need_logout() {
928            warn!("APIClient::close() was not called");
929        }
930    }
931}
932
933fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
934where
935    T: Deserialize<'a>,
936{
937    serde_json::from_slice::<T>(body).map_err(|e| {
938        Error::Decode(format!(
939            "fail to decode JSON response: {e}, body: {}",
940            String::from_utf8_lossy(body)
941        ))
942    })
943}
944
945impl Default for APIClient {
946    fn default() -> Self {
947        Self {
948            cli: HttpClient::new(),
949            scheme: "http".to_string(),
950            endpoint: Url::parse("http://localhost:8080").unwrap(),
951            host: "localhost".to_string(),
952            port: 8000,
953            tenant: None,
954            warehouse: Mutex::new(None),
955            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
956            session_state: Mutex::new(SessionState::default()),
957            wait_time_secs: None,
958            max_rows_in_buffer: None,
959            max_rows_per_page: None,
960            connect_timeout: Duration::from_secs(10),
961            page_request_timeout: Duration::from_secs(30),
962            tls_ca_file: None,
963            presign: Mutex::new(PresignMode::Auto),
964            route_hint: RouteHintGenerator::new(),
965            last_node_id: Default::default(),
966            disable_session_token: true,
967            disable_login: false,
968            session_token_info: None,
969            closed: Default::default(),
970            last_query_id: Default::default(),
971            server_version: None,
972        }
973    }
974}
975
/// Generates and remembers route hints.
///
/// NOTE(review): presumably the hint is sent with requests so related calls are routed
/// to the same server node — confirm against the request-building code.
struct RouteHintGenerator {
    /// The most recently generated or explicitly set hint.
    current: std::sync::Mutex<String>,
    /// Counter mixed into each generated hint; incremented once per generation.
    nonce: AtomicU64,
}
980
981impl RouteHintGenerator {
982    fn new() -> Self {
983        let gen = Self {
984            nonce: AtomicU64::new(0),
985            current: std::sync::Mutex::new("".to_string()),
986        };
987        gen.next();
988        gen
989    }
990
991    fn current(&self) -> String {
992        let guard = self.current.lock().unwrap();
993        guard.clone()
994    }
995
996    fn set(&self, hint: &str) {
997        let mut guard = self.current.lock().unwrap();
998        *guard = hint.to_string();
999    }
1000
1001    fn next(&self) -> String {
1002        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1003        let uuid = uuid::Uuid::new_v4();
1004        let current = format!("rh:{uuid}:{nonce:06}");
1005        let mut guard = self.current.lock().unwrap();
1006        guard.clone_from(&current);
1007        current
1008    }
1009}
1010
#[cfg(test)]
mod test {
    use super::*;

    /// A full DSN with `sslmode=disable` populates host, plain-HTTP endpoint, paging
    /// options and warehouse, and leaves the tenant unset.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let client = APIClient::from_dsn(dsn).await?;
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;

        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, expected_endpoint);
        assert_eq!(client.tenant, None);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));

        let warehouse = client.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password parses; without an explicit port the TLS default applies.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!((client.host(), client.port()), ("localhost", 443));
        Ok(())
    }

    /// A raw password containing `@`, `=`, `(` and `{` still yields the right host and port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}