databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::capability::Capability;
22use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
23use crate::global_cookie_store::GlobalCookieStore;
24use crate::login::{
25    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
26    SessionTokenInfo,
27};
28use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
29use crate::response::LoadResponse;
30use crate::stage::StageLocation;
31use crate::{
32    error::{Error, Result},
33    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
34    response::QueryResponse,
35    session::SessionState,
36    QueryStats,
37};
38use crate::{Page, Pages};
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::Deserialize;
49use tokio::time::sleep;
50use tokio_retry::strategy::jitter;
51use tokio_stream::StreamExt;
52use tokio_util::io::ReaderStream;
53use url::Url;
54
// HTTP header names understood by the Databend server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";

/// `txn_state` value (compared case-insensitively) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
63
64static VERSION: Lazy<String> = Lazy::new(|| {
65    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
66    version.to_string()
67});
68
/// Client for the Databend HTTP API.
///
/// Built by [`APIClient::new`] from a DSN. Connection parameters are fixed at
/// construction. Per-session state (session state, warehouse, presign mode,
/// last query/node ids, session tokens) sits behind mutexes or atomics, so it
/// can be updated through `&self`.
pub struct APIClient {
    /// Underlying HTTP client, built in `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, chosen by the DSN's `sslmode`.
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port from the DSN, defaulting to 80/443 for the scheme.
    port: u16,

    /// `scheme://host:port`; API paths are joined onto this.
    endpoint: Url,

    /// Credentials (basic auth or access token). Used for login, and for other
    /// requests when no session token is held.
    auth: Arc<dyn Auth>,

    /// Sent as the `X-DATABEND-TENANT` header when set.
    tenant: Option<String>,
    /// Sent as the `X-DATABEND-WAREHOUSE` header. Initialised from the DSN and
    /// updated from a `warehouse` session setting returned by the server.
    warehouse: Mutex<Option<String>>,
    /// Latest session state. Sent with each query and replaced by the copy the
    /// server returns.
    session_state: Mutex<SessionState>,
    /// Route hint sent in `X-DATABEND-ROUTE-HINT`. Advanced per query outside
    /// active transactions and overwritten from the server's response header.
    route_hint: RouteHintGenerator,

    /// `login=disable` in the DSN: `new` skips `/v1/session/login`.
    disable_login: bool,
    /// `session_token=disable` in the DSN: login asks the server not to issue
    /// session tokens.
    disable_session_token: bool,
    /// Session/refresh tokens from login, paired with the instant they were
    /// obtained or last refreshed. `None` when the server issued no tokens.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    /// Flipped by `need_logout` the first time it decides a logout is needed,
    /// so the logout is issued at most once.
    closed: AtomicBool,

    /// Server version reported at login.
    server_version: Option<Version>,

    /// Optional pagination parameters from the DSN, forwarded with each query.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Timeout applied (via `RequestBuilder::timeout`, i.e. to the whole
    /// request) to the login and session-refresh requests.
    connect_timeout: Duration,
    /// Timeout applied to each `query_page` request.
    page_request_timeout: Duration,

    /// Path of an extra PEM root certificate trusted for https connections.
    tls_ca_file: Option<String>,

    /// Presign mode. `Auto`/`Detect` are resolved to `On`/`Off` in `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node that served the last started query; used for sticky routing.
    last_node_id: Mutex<Option<String>>,
    /// Id of the last query successfully started by `start_query_inner`
    /// (also settable through `set_last_query_id`).
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version at login.
    capability: Capability,
}
107
108impl APIClient {
109    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
110        let mut client = Self::from_dsn(dsn).await?;
111        client.build_client(name).await?;
112        if !client.disable_login {
113            client.login().await?;
114        }
115        let client = Arc::new(client);
116        client.check_presign().await?;
117        Ok(client)
118    }
119
120    pub fn capability(&self) -> &Capability {
121        &self.capability
122    }
123
124    fn set_presign_mode(&self, mode: PresignMode) {
125        *self.presign.lock() = mode
126    }
127    fn get_presign_mode(&self) -> PresignMode {
128        *self.presign.lock()
129    }
130
131    async fn from_dsn(dsn: &str) -> Result<Self> {
132        let u = Url::parse(dsn)?;
133        let mut client = Self::default();
134        if let Some(host) = u.host_str() {
135            client.host = host.to_string();
136        }
137
138        if u.username() != "" {
139            let password = u.password().unwrap_or_default();
140            let password = percent_decode_str(password).decode_utf8()?;
141            client.auth = Arc::new(BasicAuth::new(u.username(), password));
142        }
143        let database = match u.path().trim_start_matches('/') {
144            "" => None,
145            s => Some(s.to_string()),
146        };
147        let mut role = None;
148        let mut scheme = "https";
149        let mut session_settings = BTreeMap::new();
150        for (k, v) in u.query_pairs() {
151            match k.as_ref() {
152                "wait_time_secs" => {
153                    client.wait_time_secs = Some(v.parse()?);
154                }
155                "max_rows_in_buffer" => {
156                    client.max_rows_in_buffer = Some(v.parse()?);
157                }
158                "max_rows_per_page" => {
159                    client.max_rows_per_page = Some(v.parse()?);
160                }
161                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
162                "page_request_timeout_secs" => {
163                    client.page_request_timeout = {
164                        let secs: u64 = v.parse()?;
165                        Duration::from_secs(secs)
166                    };
167                }
168                "presign" => {
169                    let presign_mode = match v.as_ref() {
170                        "auto" => PresignMode::Auto,
171                        "detect" => PresignMode::Detect,
172                        "on" => PresignMode::On,
173                        "off" => PresignMode::Off,
174                        _ => {
175                            return Err(Error::BadArgument(format!(
176                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
177                        )))
178                        }
179                    };
180                    client.set_presign_mode(presign_mode);
181                }
182                "tenant" => {
183                    client.tenant = Some(v.to_string());
184                }
185                "warehouse" => {
186                    client.warehouse = Mutex::new(Some(v.to_string()));
187                }
188                "role" => role = Some(v.to_string()),
189                "sslmode" => match v.as_ref() {
190                    "disable" => scheme = "http",
191                    "require" | "enable" => scheme = "https",
192                    _ => {
193                        return Err(Error::BadArgument(format!(
194                            "Invalid value for sslmode: {v}"
195                        )))
196                    }
197                },
198                "tls_ca_file" => {
199                    client.tls_ca_file = Some(v.to_string());
200                }
201                "access_token" => {
202                    client.auth = Arc::new(AccessTokenAuth::new(v));
203                }
204                "access_token_file" => {
205                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
206                }
207                "login" => {
208                    client.disable_login = match v.as_ref() {
209                        "disable" => true,
210                        "enable" => false,
211                        _ => {
212                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
213                        }
214                    }
215                }
216                "session_token" => {
217                    client.disable_session_token = match v.as_ref() {
218                        "disable" => true,
219                        "enable" => false,
220                        _ => {
221                            return Err(Error::BadArgument(format!(
222                                "Invalid value for session_token: {v}"
223                            )))
224                        }
225                    }
226                }
227                _ => {
228                    session_settings.insert(k.to_string(), v.to_string());
229                }
230            }
231        }
232        client.port = match u.port() {
233            Some(p) => p,
234            None => match scheme {
235                "http" => 80,
236                "https" => 443,
237                _ => unreachable!(),
238            },
239        };
240        client.scheme = scheme.to_string();
241
242        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
243        client.session_state = Mutex::new(
244            SessionState::default()
245                .with_settings(Some(session_settings))
246                .with_role(role)
247                .with_database(database),
248        );
249
250        Ok(client)
251    }
252
253    pub fn host(&self) -> &str {
254        self.host.as_str()
255    }
256
257    pub fn port(&self) -> u16 {
258        self.port
259    }
260
261    pub fn scheme(&self) -> &str {
262        self.scheme.as_str()
263    }
264
265    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
266        let ua = match name {
267            Some(n) => n,
268            None => format!("databend-client-rust/{}", VERSION.as_str()),
269        };
270        let cookie_provider = GlobalCookieStore::new();
271        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
272        let mut initial_cookies = [&cookie].into_iter();
273        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
274        let mut cli_builder = HttpClient::builder()
275            .user_agent(ua)
276            .cookie_provider(Arc::new(cookie_provider))
277            .pool_idle_timeout(Duration::from_secs(1));
278        #[cfg(any(feature = "rustls", feature = "native-tls"))]
279        if self.scheme == "https" {
280            if let Some(ref ca_file) = self.tls_ca_file {
281                let cert_pem = tokio::fs::read(ca_file).await?;
282                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
283                cli_builder = cli_builder.add_root_certificate(cert);
284            }
285        }
286        self.cli = cli_builder.build()?;
287        Ok(())
288    }
289
290    async fn check_presign(self: &Arc<Self>) -> Result<()> {
291        let mode = match self.get_presign_mode() {
292            PresignMode::Auto => {
293                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
294                    PresignMode::On
295                } else {
296                    PresignMode::Off
297                }
298            }
299            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
300                Ok(_) => PresignMode::On,
301                Err(e) => {
302                    warn!("presign mode off with error detected: {e}");
303                    PresignMode::Off
304                }
305            },
306            mode => mode,
307        };
308        self.set_presign_mode(mode);
309        Ok(())
310    }
311
312    pub fn current_warehouse(&self) -> Option<String> {
313        let guard = self.warehouse.lock();
314        guard.clone()
315    }
316
317    pub fn current_catalog(&self) -> Option<String> {
318        let guard = self.session_state.lock();
319        guard.catalog.clone()
320    }
321
322    pub fn current_database(&self) -> Option<String> {
323        let guard = self.session_state.lock();
324        guard.database.clone()
325    }
326
327    pub async fn current_role(&self) -> Option<String> {
328        let guard = self.session_state.lock();
329        guard.role.clone()
330    }
331
332    fn in_active_transaction(&self) -> bool {
333        let guard = self.session_state.lock();
334        guard
335            .txn_state
336            .as_ref()
337            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
338            .unwrap_or(false)
339    }
340
341    pub fn username(&self) -> String {
342        self.auth.username()
343    }
344
345    fn gen_query_id(&self) -> String {
346        uuid::Uuid::now_v7().simple().to_string()
347    }
348
349    async fn handle_session(&self, session: &Option<SessionState>) {
350        let session = match session {
351            Some(session) => session,
352            None => return,
353        };
354
355        // save the updated session state from the server side
356        {
357            let mut session_state = self.session_state.lock();
358            *session_state = session.clone();
359        }
360
361        // process warehouse changed via session settings
362        if let Some(settings) = session.settings.as_ref() {
363            if let Some(v) = settings.get("warehouse") {
364                let mut warehouse = self.warehouse.lock();
365                *warehouse = Some(v.clone());
366            }
367        }
368    }
369
370    pub fn set_last_node_id(&self, node_id: String) {
371        *self.last_node_id.lock() = Some(node_id)
372    }
373
374    pub fn set_last_query_id(&self, query_id: Option<String>) {
375        *self.last_query_id.lock() = query_id
376    }
377
378    pub fn last_query_id(&self) -> Option<String> {
379        self.last_query_id.lock().clone()
380    }
381
382    fn last_node_id(&self) -> Option<String> {
383        self.last_node_id.lock().clone()
384    }
385
386    fn handle_warnings(&self, resp: &QueryResponse) {
387        if let Some(warnings) = &resp.warnings {
388            for w in warnings {
389                warn!(target: "server_warnings", "server warning: {w}");
390            }
391        }
392    }
393
394    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
395        info!("start query: {sql}");
396        let resp = self.start_query_inner(sql, None).await?;
397        let pages = Pages::new(self.clone(), resp, need_progress);
398        Ok(pages)
399    }
400
401    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
402        if let Some(info) = &self.session_token_info {
403            let info = info.lock();
404            Ok(builder.bearer_auth(info.0.session_token.clone()))
405        } else {
406            self.auth.wrap(builder)
407        }
408    }
409
410    async fn start_query_inner(
411        &self,
412        sql: &str,
413        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
414    ) -> Result<QueryResponse> {
415        if !self.in_active_transaction() {
416            self.route_hint.next();
417        }
418        let endpoint = self.endpoint.join("v1/query")?;
419
420        // body
421        let session_state = self.session_state();
422        let need_sticky = session_state.need_sticky.unwrap_or(false);
423        let req = QueryRequest::new(sql)
424            .with_pagination(self.make_pagination())
425            .with_session(Some(session_state))
426            .with_stage_attachment(stage_attachment_config);
427
428        // headers
429        let query_id = self.gen_query_id();
430        let mut headers = self.make_headers(Some(&query_id))?;
431        if need_sticky {
432            if let Some(node_id) = self.last_node_id() {
433                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
434            }
435        }
436        let mut builder = self.cli.post(endpoint.clone()).json(&req);
437        builder = self.wrap_auth_or_session_token(builder)?;
438        let request = builder.headers(headers.clone()).build()?;
439        let response = self.query_request_helper(request, true, true).await?;
440        if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
441            self.route_hint.set(route_hint.to_str().unwrap_or_default());
442        }
443        let body = response.bytes().await?;
444        let result: QueryResponse = json_from_slice(&body)?;
445        self.handle_session(&result.session).await;
446        if let Some(err) = result.error {
447            return Err(Error::QueryFailed(err));
448        }
449
450        self.set_last_query_id(Some(query_id));
451        self.handle_warnings(&result);
452        if let Some(node_id) = &result.node_id {
453            self.set_last_node_id(node_id.clone());
454        }
455        Ok(result)
456    }
457
458    pub async fn query_page(
459        &self,
460        query_id: &str,
461        next_uri: &str,
462        node_id: &Option<String>,
463    ) -> Result<QueryResponse> {
464        info!("query page: {next_uri}");
465        let endpoint = self.endpoint.join(next_uri)?;
466        let headers = self.make_headers(Some(query_id))?;
467        let mut builder = self.cli.get(endpoint.clone());
468        builder = self
469            .wrap_auth_or_session_token(builder)?
470            .headers(headers.clone())
471            .timeout(self.page_request_timeout);
472        if let Some(node_id) = node_id {
473            builder = builder.header(HEADER_STICKY_NODE, node_id)
474        }
475        let request = builder.build()?;
476
477        let response = self.query_request_helper(request, false, true).await?;
478        let status = response.status();
479        if status != 200 {
480            return Err(Error::response_error(status, &response.bytes().await?));
481        }
482        let body = response.bytes().await?;
483        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
484            if let Error::Logic(status, ec) = &e {
485                if *status == 404 {
486                    return Error::QueryNotFound(ec.message.clone());
487                }
488            }
489            e
490        })?;
491        self.handle_session(&resp.session).await;
492        match resp.error {
493            Some(err) => Err(Error::QueryFailed(err)),
494            None => Ok(resp),
495        }
496    }
497
498    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
499        let kill_uri = format!("/v1/query/{query_id}/kill");
500        let endpoint = self.endpoint.join(&kill_uri)?;
501        let headers = self.make_headers(Some(query_id))?;
502        info!("kill query: {kill_uri}");
503
504        let mut builder = self.cli.post(endpoint);
505        builder = self.wrap_auth_or_session_token(builder)?;
506        let resp = builder.headers(headers.clone()).send().await?;
507        if resp.status() != 200 {
508            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
509                .with_context("kill query"));
510        }
511        Ok(())
512    }
513
514    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
515        let mut pages = self.start_query(sql, false).await?;
516        let mut all = Page::default();
517        while let Some(page) = pages.next().await {
518            all.update(page?);
519        }
520        Ok(all)
521    }
522
523    fn session_state(&self) -> SessionState {
524        self.session_state.lock().clone()
525    }
526
527    fn make_pagination(&self) -> Option<PaginationConfig> {
528        if self.wait_time_secs.is_none()
529            && self.max_rows_in_buffer.is_none()
530            && self.max_rows_per_page.is_none()
531        {
532            return None;
533        }
534        let mut pagination = PaginationConfig {
535            wait_time_secs: None,
536            max_rows_in_buffer: None,
537            max_rows_per_page: None,
538        };
539        if let Some(wait_time_secs) = self.wait_time_secs {
540            pagination.wait_time_secs = Some(wait_time_secs);
541        }
542        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
543            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
544        }
545        if let Some(max_rows_per_page) = self.max_rows_per_page {
546            pagination.max_rows_per_page = Some(max_rows_per_page);
547        }
548        Some(pagination)
549    }
550
551    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
552        let mut headers = HeaderMap::new();
553        if let Some(tenant) = &self.tenant {
554            headers.insert(HEADER_TENANT, tenant.parse()?);
555        }
556        let warehouse = self.warehouse.lock().clone();
557        if let Some(warehouse) = warehouse {
558            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
559        }
560        let route_hint = self.route_hint.current();
561        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
562        if let Some(query_id) = query_id {
563            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
564        }
565        Ok(headers)
566    }
567
568    pub async fn insert_with_stage(
569        self: &Arc<Self>,
570        sql: &str,
571        stage: &str,
572        file_format_options: BTreeMap<&str, &str>,
573        copy_options: BTreeMap<&str, &str>,
574    ) -> Result<QueryStats> {
575        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
576        let stage_attachment = Some(StageAttachmentConfig {
577            location: stage,
578            file_format_options: Some(file_format_options),
579            copy_options: Some(copy_options),
580        });
581        let resp = self.start_query_inner(sql, stage_attachment).await?;
582        let mut pages = Pages::new(self.clone(), resp, false);
583        let mut all = Page::default();
584        while let Some(page) = pages.next().await {
585            all.update(page?);
586        }
587        Ok(all.stats)
588    }
589
590    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
591        info!("get presigned upload url: {stage}");
592        let sql = format!("PRESIGN UPLOAD {stage}");
593        let resp = self.query_all(&sql).await?;
594        if resp.data.len() != 1 {
595            return Err(Error::Decode(
596                "Empty response from server for presigned request".to_string(),
597            ));
598        }
599        if resp.data[0].len() != 3 {
600            return Err(Error::Decode(
601                "Invalid response from server for presigned request".to_string(),
602            ));
603        }
604        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
605        let method = resp.data[0][0].clone().unwrap_or_default();
606        if method != "PUT" {
607            return Err(Error::Decode(format!(
608                "Invalid method for presigned upload request: {method}"
609            )));
610        }
611        let headers: BTreeMap<String, String> =
612            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
613        let url = resp.data[0][2].clone().unwrap_or_default();
614        Ok(PresignedResponse {
615            method,
616            headers,
617            url,
618        })
619    }
620
621    pub async fn upload_to_stage(
622        self: &Arc<Self>,
623        stage: &str,
624        data: Reader,
625        size: u64,
626    ) -> Result<()> {
627        match self.get_presign_mode() {
628            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
629            PresignMode::On => {
630                let presigned = self.get_presigned_upload_url(stage).await?;
631                presign_upload_to_stage(presigned, data, size).await
632            }
633            PresignMode::Auto => {
634                unreachable!("PresignMode::Auto should be handled during client initialization")
635            }
636            PresignMode::Detect => {
637                unreachable!("PresignMode::Detect should be handled during client initialization")
638            }
639        }
640    }
641
642    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
643    async fn upload_to_stage_with_stream(
644        &self,
645        stage: &str,
646        data: Reader,
647        size: u64,
648    ) -> Result<()> {
649        info!("upload to stage with stream: {stage}, size: {size}");
650        if let Some(info) = self.need_pre_refresh_session().await {
651            self.refresh_session_token(info).await?;
652        }
653        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
654        let location = StageLocation::try_from(stage)?;
655        let query_id = self.gen_query_id();
656        let mut headers = self.make_headers(Some(&query_id))?;
657        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
658        let stream = Body::wrap_stream(ReaderStream::new(data));
659        let part = Part::stream_with_length(stream, size).file_name(location.path);
660        let form = Form::new().part("upload", part);
661        let mut builder = self.cli.put(endpoint.clone());
662        builder = self.wrap_auth_or_session_token(builder)?;
663        let resp = builder.headers(headers).multipart(form).send().await?;
664        let status = resp.status();
665        if status != 200 {
666            return Err(
667                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
668            );
669        }
670        Ok(())
671    }
672
673    pub async fn streaming_load(
674        &self,
675        sql: &str,
676        data: Reader,
677        file_name: &str,
678    ) -> Result<LoadResponse> {
679        let body = Body::wrap_stream(ReaderStream::new(data));
680        let part = Part::stream(body).file_name(file_name.to_string());
681        let endpoint = self.endpoint.join("v1/streaming_load")?;
682        let mut builder = self.cli.put(endpoint.clone());
683        builder = self.wrap_auth_or_session_token(builder)?;
684        let query_id = self.gen_query_id();
685        let mut headers = self.make_headers(Some(&query_id))?;
686        headers.insert(HEADER_SQL, sql.parse()?);
687        let form = Form::new().part("upload", part);
688        let resp = builder.headers(headers).multipart(form).send().await?;
689        let status = resp.status();
690        if status != 200 {
691            return Err(
692                Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
693            );
694        }
695        let resp = resp.json::<LoadResponse>().await?;
696        Ok(resp)
697    }
698
699    async fn login(&mut self) -> Result<()> {
700        let endpoint = self.endpoint.join("/v1/session/login")?;
701        let headers = self.make_headers(None)?;
702        let body = LoginRequest::from(&*self.session_state.lock());
703        let mut builder = self.cli.post(endpoint.clone()).json(&body);
704        if self.disable_session_token {
705            builder = builder.query(&[("disable_session_token", true)]);
706        }
707        let builder = self.auth.wrap(builder)?;
708        let request = builder
709            .headers(headers.clone())
710            .timeout(self.connect_timeout)
711            .build()?;
712        let response = self.query_request_helper(request, true, false).await;
713        let response = match response {
714            Ok(r) => r,
715            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
716                info!("login return 404, skip login on the old version server");
717                return Ok(());
718            }
719            Err(e) => return Err(e),
720        };
721        let body = response.bytes().await?;
722        let response = json_from_slice(&body)?;
723        match response {
724            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
725            LoginResponseResult::Ok(info) => {
726                let server_version = info
727                    .version
728                    .parse()
729                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
730                self.capability = Capability::from_server_version(&server_version);
731                self.server_version = Some(server_version.clone());
732                if let Some(tokens) = info.tokens {
733                    info!("login success with session token version = {server_version}",);
734                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
735                } else {
736                    info!("login success, version = {server_version}");
737                }
738            }
739        }
740        Ok(())
741    }
742
743    fn build_log_out_request(&self) -> Result<Request> {
744        let endpoint = self.endpoint.join("/v1/session/logout")?;
745
746        let session_state = self.session_state();
747        let need_sticky = session_state.need_sticky.unwrap_or(false);
748        let mut headers = self.make_headers(None)?;
749        if need_sticky {
750            if let Some(node_id) = self.last_node_id() {
751                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
752            }
753        }
754        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
755
756        let builder = self.wrap_auth_or_session_token(builder)?;
757        let req = builder.build()?;
758        Ok(req)
759    }
760
761    fn need_logout(&self) -> bool {
762        (self.session_token_info.is_some()
763            || self.session_state.lock().need_keep_alive.unwrap_or(false))
764            && self
765                .closed
766                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
767                .is_ok()
768    }
769
770    async fn refresh_session_token(
771        &self,
772        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
773    ) -> Result<()> {
774        let (session_token_info, _) = { self_login_info.lock().clone() };
775        let endpoint = self.endpoint.join("/v1/session/refresh")?;
776        let body = RefreshSessionTokenRequest {
777            session_token: session_token_info.session_token.clone(),
778        };
779        let headers = self.make_headers(None)?;
780        let request = self
781            .cli
782            .post(endpoint.clone())
783            .json(&body)
784            .headers(headers.clone())
785            .bearer_auth(session_token_info.refresh_token.clone())
786            .timeout(self.connect_timeout)
787            .build()?;
788
789        // avoid recursively call request_helper
790        for i in 0..3 {
791            let req = request.try_clone().expect("request not cloneable");
792            match self.cli.execute(req).await {
793                Ok(response) => {
794                    let status = response.status();
795                    let body = response.bytes().await?;
796                    if status == StatusCode::OK {
797                        let response = json_from_slice(&body)?;
798                        return match response {
799                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
800                            RefreshResponse::Ok(info) => {
801                                *self_login_info.lock() = (info, Instant::now());
802                                Ok(())
803                            }
804                        };
805                    }
806                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
807                        return Err(Error::response_error(status, &body));
808                    }
809                }
810                Err(err) => {
811                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
812                        return Err(Error::Request(err.to_string()));
813                    }
814                }
815            };
816            sleep(jitter(Duration::from_secs(10))).await;
817        }
818        Ok(())
819    }
820
821    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
822        if let Some(info) = &self.session_token_info {
823            let (start, ttl) = {
824                let guard = info.lock();
825                (guard.1, guard.0.session_token_ttl_in_secs)
826            };
827            if Instant::now() > start + Duration::from_secs(ttl) {
828                return Some(info.clone());
829            }
830        }
831        None
832    }
833
834    /// return Ok if and only if status code is 200.
835    ///
836    /// retry on
837    ///   - network errors
838    ///   - (optional) 503
839    ///
840    /// refresh databend token or reload jwt token if needed.
841    async fn query_request_helper(
842        &self,
843        mut request: Request,
844        retry_if_503: bool,
845        refresh_if_401: bool,
846    ) -> std::result::Result<Response, Error> {
847        let mut refreshed = false;
848        let mut retries = 0;
849        loop {
850            let req = request.try_clone().expect("request not cloneable");
851            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
852                Ok(response) => {
853                    let status = response.status();
854                    if status == StatusCode::OK {
855                        return Ok(response);
856                    }
857                    let body = response.bytes().await?;
858                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
859                        // waiting for server to start
860                        (Error::response_error(status, &body), true)
861                    } else {
862                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
863                        match resp {
864                            Ok(r) => {
865                                let e = r.error;
866                                if status == StatusCode::UNAUTHORIZED {
867                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
868                                    if let Some(session_token_info) = &self.session_token_info {
869                                        info!(
870                                            "will retry {} after refresh token on auth error {}",
871                                            request.url(),
872                                            e
873                                        );
874                                        let retry = if need_refresh_token(e.code)
875                                            && !refreshed
876                                            && refresh_if_401
877                                        {
878                                            self.refresh_session_token(session_token_info.clone())
879                                                .await?;
880                                            refreshed = true;
881                                            true
882                                        } else {
883                                            false
884                                        };
885                                        (Error::AuthFailure(e), retry)
886                                    } else if self.auth.can_reload() {
887                                        info!(
888                                            "will retry {} after reload token on auth error {}",
889                                            request.url(),
890                                            e
891                                        );
892                                        let builder = RequestBuilder::from_parts(
893                                            HttpClient::new(),
894                                            request.try_clone().unwrap(),
895                                        );
896                                        let builder = self.auth.wrap(builder)?;
897                                        request = builder.build()?;
898                                        (Error::AuthFailure(e), true)
899                                    } else {
900                                        (Error::AuthFailure(e), false)
901                                    }
902                                } else {
903                                    (Error::Logic(status, e), false)
904                                }
905                            }
906                            Err(_) => (
907                                Error::Response {
908                                    status,
909                                    msg: String::from_utf8_lossy(&body).to_string(),
910                                },
911                                false,
912                            ),
913                        }
914                    }
915                }
916                Err(err) => (
917                    Error::Request(err.to_string()),
918                    err.is_timeout() || err.is_connect(),
919                ),
920            };
921            if !retry {
922                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
923            }
924            match &err {
925                Error::AuthFailure(_) => {
926                    if refreshed {
927                        retries = 0;
928                    } else if retries == 2 {
929                        return Err(err.with_context(&format!(
930                            "{} {} after 3 reties",
931                            request.method(),
932                            request.url()
933                        )));
934                    }
935                }
936                _ => {
937                    if retries == 2 {
938                        return Err(err.with_context(&format!(
939                            "{} {} after 3 reties",
940                            request.method(),
941                            request.url()
942                        )));
943                    }
944                    retries += 1;
945                    info!(
946                        "will retry {} the {retries}th times on error {}",
947                        request.url(),
948                        err
949                    );
950                }
951            }
952            sleep(jitter(Duration::from_secs(10))).await;
953        }
954    }
955
956    pub async fn close(&self) {
957        if self.need_logout() {
958            let cli = self.cli.clone();
959            let req = self
960                .build_log_out_request()
961                .expect("failed to build logout request");
962            if let Err(err) = cli.execute(req).await {
963                error!("logout request failed: {err}");
964            } else {
965                debug!("logout success");
966            };
967        }
968    }
969}
970
971impl Drop for APIClient {
972    fn drop(&mut self) {
973        if self.need_logout() {
974            warn!("APIClient::close() was not called");
975        }
976    }
977}
978
979fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
980where
981    T: Deserialize<'a>,
982{
983    serde_json::from_slice::<T>(body).map_err(|e| {
984        Error::Decode(format!(
985            "fail to decode JSON response: {e}, body: {}",
986            String::from_utf8_lossy(body)
987        ))
988    })
989}
990
991impl Default for APIClient {
992    fn default() -> Self {
993        Self {
994            cli: HttpClient::new(),
995            scheme: "http".to_string(),
996            endpoint: Url::parse("http://localhost:8080").unwrap(),
997            host: "localhost".to_string(),
998            port: 8000,
999            tenant: None,
1000            warehouse: Mutex::new(None),
1001            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1002            session_state: Mutex::new(SessionState::default()),
1003            wait_time_secs: None,
1004            max_rows_in_buffer: None,
1005            max_rows_per_page: None,
1006            connect_timeout: Duration::from_secs(10),
1007            page_request_timeout: Duration::from_secs(30),
1008            tls_ca_file: None,
1009            presign: Mutex::new(PresignMode::Auto),
1010            route_hint: RouteHintGenerator::new(),
1011            last_node_id: Default::default(),
1012            disable_session_token: true,
1013            disable_login: false,
1014            session_token_info: None,
1015            closed: Default::default(),
1016            last_query_id: Default::default(),
1017            server_version: None,
1018            capability: Default::default(),
1019        }
1020    }
1021}
1022
1023struct RouteHintGenerator {
1024    nonce: AtomicU64,
1025    current: std::sync::Mutex<String>,
1026}
1027
1028impl RouteHintGenerator {
1029    fn new() -> Self {
1030        let gen = Self {
1031            nonce: AtomicU64::new(0),
1032            current: std::sync::Mutex::new("".to_string()),
1033        };
1034        gen.next();
1035        gen
1036    }
1037
1038    fn current(&self) -> String {
1039        let guard = self.current.lock().unwrap();
1040        guard.clone()
1041    }
1042
1043    fn set(&self, hint: &str) {
1044        let mut guard = self.current.lock().unwrap();
1045        *guard = hint.to_string();
1046    }
1047
1048    fn next(&self) -> String {
1049        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1050        let uuid = uuid::Uuid::new_v4();
1051        let current = format!("rh:{uuid}:{nonce:06}");
1052        let mut guard = self.current.lock().unwrap();
1053        guard.clone_from(&current);
1054        current
1055    }
1056}
1057
#[cfg(test)]
mod test {
    use super::*;

    /// DSN query parameters populate the matching client settings, and
    /// `sslmode=disable` yields a plain-HTTP endpoint on port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let client = APIClient::from_dsn(dsn).await?;

        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.tenant, None);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));

        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(*warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password parses; without `sslmode` the port is 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
        let client = APIClient::from_dsn(dsn).await?;

        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// A raw (unencoded) password containing `@`, `=`, `(` and `{` still parses,
    /// and an explicit port is honoured.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000";
        let client = APIClient::from_dsn(dsn).await?;

        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}