databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
22use crate::global_cookie_store::GlobalCookieStore;
23use crate::login::{
24    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
25    SessionTokenInfo,
26};
27use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
28use crate::stage::StageLocation;
29use crate::{
30    error::{Error, Result},
31    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
32    response::QueryResponse,
33    session::SessionState,
34};
35use log::{debug, error, info, warn};
36use once_cell::sync::Lazy;
37use percent_encoding::percent_decode_str;
38use reqwest::cookie::CookieStore;
39use reqwest::header::{HeaderMap, HeaderValue};
40use reqwest::multipart::{Form, Part};
41use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
42use serde::Deserialize;
43use tokio::time::sleep;
44use tokio_retry::strategy::jitter;
45use tokio_util::io::ReaderStream;
46use url::Url;
47
// Names of the HTTP headers this client attaches to requests or reads from responses.

/// Header carrying the query id; added by `make_headers` when a query id is supplied.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Header carrying the tenant configured through the `tenant` DSN parameter.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Header naming the node a request should be pinned to (set from the last known node id).
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Header carrying the currently selected warehouse.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
/// Header carrying the stage name for stream uploads to `v1/upload_to_stage`.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Route hint header: sent on every request and read back from query responses.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Value of the session `txn_state` (compared ignoring ASCII case) that marks an active transaction.
const TXN_STATE_ACTIVE: &str = "Active";
55
56static VERSION: Lazy<String> = Lazy::new(|| {
57    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
58    version.to_string()
59});
60
/// HTTP client for the Databend query API.
///
/// Cloning is cheap for the mutable parts: warehouse, session state, route hint,
/// session tokens, the `closed` flag and the last node/query ids live behind `Arc`,
/// so clones share them.
#[derive(Clone)]
pub struct APIClient {
    /// Underlying `reqwest` client used for every request.
    cli: HttpClient,
    /// `"http"` or `"https"`, selected by the `sslmode` DSN parameter.
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on the scheme when absent.
    port: u16,

    /// Base URL built as `scheme://host:port`; API paths are joined onto it.
    endpoint: Url,

    /// Credentials (basic auth, access token or access token file) from the DSN.
    auth: Arc<dyn Auth>,

    /// Tenant sent in the `X-DATABEND-TENANT` header, if configured.
    tenant: Option<String>,
    /// Current warehouse; updated when the server reports a `warehouse` session setting.
    warehouse: Arc<parking_lot::Mutex<Option<String>>>,
    /// Latest session state; replaced by the state returned with each query response.
    session_state: Arc<parking_lot::Mutex<SessionState>>,
    /// Source of the value sent in the `X-DATABEND-ROUTE-HINT` header.
    route_hint: Arc<RouteHintGenerator>,

    /// When true (`login=disable`), `new` skips the login request.
    disable_login: bool,
    /// When true (`session_token=disable`), login asks the server not to issue session tokens.
    disable_session_token: bool,
    /// Session tokens from login/refresh, paired with the instant they were stored.
    session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

    /// Flipped from false to true by `need_logout`, so it reports true at most once.
    closed: Arc<AtomicBool>,

    /// Server version reported by the login response, if login succeeded.
    server_version: Option<String>,

    // Optional pagination settings forwarded with each query request.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Timeout applied to login and session-token refresh requests.
    connect_timeout: Duration,
    /// Timeout applied to each page request issued by `query_page`.
    page_request_timeout: Duration,

    /// Path to a PEM file added as a root certificate for HTTPS connections.
    tls_ca_file: Option<String>,

    /// Upload strategy; `Auto` and `Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: PresignMode,
    /// Last node id recorded via `set_last_node_id`; used for the sticky-node header.
    last_node_id: Arc<parking_lot::Mutex<Option<String>>>,
    /// Id of the last query that was started successfully.
    last_query_id: Arc<parking_lot::Mutex<Option<String>>>,
}
98
99impl APIClient {
100    pub async fn new(dsn: &str, name: Option<String>) -> Result<Self> {
101        let mut client = Self::from_dsn(dsn).await?;
102        client.build_client(name).await?;
103        client.check_presign().await?;
104        if !client.disable_login {
105            client.login().await?;
106        }
107        Ok(client)
108    }
109
110    async fn from_dsn(dsn: &str) -> Result<Self> {
111        let u = Url::parse(dsn)?;
112        let mut client = Self::default();
113        if let Some(host) = u.host_str() {
114            client.host = host.to_string();
115        }
116
117        if u.username() != "" {
118            let password = u.password().unwrap_or_default();
119            let password = percent_decode_str(password).decode_utf8()?;
120            client.auth = Arc::new(BasicAuth::new(u.username(), password));
121        }
122        let database = match u.path().trim_start_matches('/') {
123            "" => None,
124            s => Some(s.to_string()),
125        };
126        let mut role = None;
127        let mut scheme = "https";
128        let mut session_settings = BTreeMap::new();
129        for (k, v) in u.query_pairs() {
130            match k.as_ref() {
131                "wait_time_secs" => {
132                    client.wait_time_secs = Some(v.parse()?);
133                }
134                "max_rows_in_buffer" => {
135                    client.max_rows_in_buffer = Some(v.parse()?);
136                }
137                "max_rows_per_page" => {
138                    client.max_rows_per_page = Some(v.parse()?);
139                }
140                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
141                "page_request_timeout_secs" => {
142                    client.page_request_timeout = {
143                        let secs: u64 = v.parse()?;
144                        Duration::from_secs(secs)
145                    };
146                }
147                "presign" => {
148                    client.presign = match v.as_ref() {
149                        "auto" => PresignMode::Auto,
150                        "detect" => PresignMode::Detect,
151                        "on" => PresignMode::On,
152                        "off" => PresignMode::Off,
153                        _ => {
154                            return Err(Error::BadArgument(format!(
155                            "Invalid value for presign: {}, should be one of auto/detect/on/off",
156                            v
157                        )))
158                        }
159                    }
160                }
161                "tenant" => {
162                    client.tenant = Some(v.to_string());
163                }
164                "warehouse" => {
165                    client.warehouse = Arc::new(parking_lot::Mutex::new(Some(v.to_string())));
166                }
167                "role" => role = Some(v.to_string()),
168                "sslmode" => match v.as_ref() {
169                    "disable" => scheme = "http",
170                    "require" | "enable" => scheme = "https",
171                    _ => {
172                        return Err(Error::BadArgument(format!(
173                            "Invalid value for sslmode: {}",
174                            v
175                        )))
176                    }
177                },
178                "tls_ca_file" => {
179                    client.tls_ca_file = Some(v.to_string());
180                }
181                "access_token" => {
182                    client.auth = Arc::new(AccessTokenAuth::new(v));
183                }
184                "access_token_file" => {
185                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
186                }
187                "login" => {
188                    client.disable_login = match v.as_ref() {
189                        "disable" => true,
190                        "enable" => false,
191                        _ => {
192                            return Err(Error::BadArgument(format!(
193                                "Invalid value for login: {}",
194                                v
195                            )))
196                        }
197                    }
198                }
199                "session_token" => {
200                    client.disable_session_token = match v.as_ref() {
201                        "disable" => true,
202                        "enable" => false,
203                        _ => {
204                            return Err(Error::BadArgument(format!(
205                                "Invalid value for session_token: {}",
206                                v
207                            )))
208                        }
209                    }
210                }
211                _ => {
212                    session_settings.insert(k.to_string(), v.to_string());
213                }
214            }
215        }
216        client.port = match u.port() {
217            Some(p) => p,
218            None => match scheme {
219                "http" => 80,
220                "https" => 443,
221                _ => unreachable!(),
222            },
223        };
224        client.scheme = scheme.to_string();
225
226        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
227        client.session_state = Arc::new(parking_lot::Mutex::new(
228            SessionState::default()
229                .with_settings(Some(session_settings))
230                .with_role(role)
231                .with_database(database),
232        ));
233
234        Ok(client)
235    }
236
237    pub fn host(&self) -> &str {
238        self.host.as_str()
239    }
240
241    pub fn port(&self) -> u16 {
242        self.port
243    }
244
245    pub fn scheme(&self) -> &str {
246        self.scheme.as_str()
247    }
248
249    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
250        let ua = match name {
251            Some(n) => n,
252            None => format!("databend-client-rust/{}", VERSION.as_str()),
253        };
254        let cookie_provider = GlobalCookieStore::new();
255        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
256        let mut initial_cookies = [&cookie].into_iter();
257        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
258        let mut cli_builder = HttpClient::builder()
259            .user_agent(ua)
260            .cookie_provider(Arc::new(cookie_provider))
261            .pool_idle_timeout(Duration::from_secs(1));
262        #[cfg(any(feature = "rustls", feature = "native-tls"))]
263        if self.scheme == "https" {
264            if let Some(ref ca_file) = self.tls_ca_file {
265                let cert_pem = tokio::fs::read(ca_file).await?;
266                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
267                cli_builder = cli_builder.add_root_certificate(cert);
268            }
269        }
270        self.cli = cli_builder.build()?;
271        Ok(())
272    }
273
274    async fn check_presign(&mut self) -> Result<()> {
275        match self.presign {
276            PresignMode::Auto => {
277                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
278                    self.presign = PresignMode::On;
279                } else {
280                    self.presign = PresignMode::Off;
281                }
282            }
283            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
284                Ok(_) => self.presign = PresignMode::On,
285                Err(e) => {
286                    warn!("presign mode off with error detected: {}", e);
287                    self.presign = PresignMode::Off;
288                }
289            },
290            _ => {}
291        }
292        Ok(())
293    }
294
295    pub fn current_warehouse(&self) -> Option<String> {
296        let guard = self.warehouse.lock();
297        guard.clone()
298    }
299
300    pub fn current_database(&self) -> Option<String> {
301        let guard = self.session_state.lock();
302        guard.database.clone()
303    }
304
305    pub async fn current_role(&self) -> Option<String> {
306        let guard = self.session_state.lock();
307        guard.role.clone()
308    }
309
310    fn in_active_transaction(&self) -> bool {
311        let guard = self.session_state.lock();
312        guard
313            .txn_state
314            .as_ref()
315            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
316            .unwrap_or(false)
317    }
318
319    pub fn username(&self) -> String {
320        self.auth.username()
321    }
322
323    fn gen_query_id(&self) -> String {
324        uuid::Uuid::new_v4().to_string()
325    }
326
327    async fn handle_session(&self, session: &Option<SessionState>) {
328        let session = match session {
329            Some(session) => session,
330            None => return,
331        };
332
333        // save the updated session state from the server side
334        {
335            let mut session_state = self.session_state.lock();
336            *session_state = session.clone();
337        }
338
339        // process warehouse changed via session settings
340        if let Some(settings) = session.settings.as_ref() {
341            if let Some(v) = settings.get("warehouse") {
342                let mut warehouse = self.warehouse.lock();
343                *warehouse = Some(v.clone());
344            }
345        }
346    }
347
348    pub fn set_last_node_id(&self, node_id: String) {
349        *self.last_node_id.lock() = Some(node_id)
350    }
351
352    pub fn set_last_query_id(&self, query_id: Option<String>) {
353        *self.last_query_id.lock() = query_id
354    }
355
356    pub fn last_query_id(&self) -> Option<String> {
357        self.last_query_id.lock().clone()
358    }
359
360    fn last_node_id(&self) -> Option<String> {
361        self.last_node_id.lock().clone()
362    }
363
364    fn handle_warnings(&self, resp: &QueryResponse) {
365        if let Some(warnings) = &resp.warnings {
366            for w in warnings {
367                warn!(target: "server_warnings", "server warning: {}", w);
368            }
369        }
370    }
371
372    pub async fn start_query(&self, sql: &str) -> Result<QueryResponse> {
373        info!("start query: {}", sql);
374        self.start_query_inner(sql, None).await
375    }
376
377    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
378        if let Some(info) = &self.session_token_info {
379            let info = info.lock();
380            Ok(builder.bearer_auth(info.0.session_token.clone()))
381        } else {
382            self.auth.wrap(builder)
383        }
384    }
385
386    async fn start_query_inner(
387        &self,
388        sql: &str,
389        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
390    ) -> Result<QueryResponse> {
391        if !self.in_active_transaction() {
392            self.route_hint.next();
393        }
394        let endpoint = self.endpoint.join("v1/query")?;
395
396        // body
397        let session_state = self.session_state();
398        let need_sticky = session_state.need_sticky.unwrap_or(false);
399        let req = QueryRequest::new(sql)
400            .with_pagination(self.make_pagination())
401            .with_session(Some(session_state))
402            .with_stage_attachment(stage_attachment_config);
403
404        // headers
405        let query_id = self.gen_query_id();
406        let mut headers = self.make_headers(Some(&query_id))?;
407        if need_sticky {
408            if let Some(node_id) = self.last_node_id() {
409                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
410            }
411        }
412        let mut builder = self.cli.post(endpoint.clone()).json(&req);
413        builder = self.wrap_auth_or_session_token(builder)?;
414        let request = builder.headers(headers.clone()).build()?;
415        let response = self.query_request_helper(request, true, true).await?;
416        if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
417            self.route_hint.set(route_hint.to_str().unwrap_or_default());
418        }
419        let body = response.bytes().await?;
420        let result: QueryResponse = json_from_slice(&body)?;
421        self.handle_session(&result.session).await;
422        if let Some(err) = result.error {
423            return Err(Error::QueryFailed(err));
424        }
425
426        self.set_last_query_id(Some(query_id));
427        self.handle_warnings(&result);
428        Ok(result)
429    }
430
431    pub async fn query_page(
432        &self,
433        query_id: &str,
434        next_uri: &str,
435        node_id: &Option<String>,
436    ) -> Result<QueryResponse> {
437        info!("query page: {}", next_uri);
438        let endpoint = self.endpoint.join(next_uri)?;
439        let headers = self.make_headers(Some(query_id))?;
440        let mut builder = self.cli.get(endpoint.clone());
441        builder = self
442            .wrap_auth_or_session_token(builder)?
443            .headers(headers.clone())
444            .timeout(self.page_request_timeout);
445        if let Some(node_id) = node_id {
446            builder = builder.header(HEADER_STICKY_NODE, node_id)
447        }
448        let request = builder.build()?;
449
450        let response = self.query_request_helper(request, false, true).await?;
451        let body = response.bytes().await?;
452        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
453            if let Error::Logic(status, ec) = &e {
454                if *status == 404 {
455                    return Error::QueryNotFound(ec.message.clone());
456                }
457            }
458            e
459        })?;
460        self.handle_session(&resp.session).await;
461        match resp.error {
462            Some(err) => Err(Error::QueryFailed(err)),
463            None => Ok(resp),
464        }
465    }
466
467    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
468        let kill_uri = format!("/v1/query/{}/kill", query_id);
469        let endpoint = self.endpoint.join(&kill_uri)?;
470        let headers = self.make_headers(Some(query_id))?;
471        info!("kill query: {}", kill_uri);
472
473        let mut builder = self.cli.post(endpoint);
474        builder = self.wrap_auth_or_session_token(builder)?;
475        let resp = builder.headers(headers.clone()).send().await?;
476        if resp.status() != 200 {
477            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
478                .with_context("kill query"));
479        }
480        Ok(())
481    }
482
483    async fn wait_for_query(&self, resp: QueryResponse) -> Result<QueryResponse> {
484        info!("wait for query: {}", resp.id);
485        let node_id = resp.node_id.clone();
486        if let Some(node_id) = self.last_node_id() {
487            self.set_last_node_id(node_id.clone());
488        }
489
490        if let Some(next_uri) = &resp.next_uri {
491            let schema = resp.schema;
492            let mut data = resp.data;
493            let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?;
494            while let Some(next_uri) = &resp.next_uri {
495                resp = self.query_page(&resp.id, next_uri, &node_id).await?;
496                data.append(&mut resp.data);
497            }
498            resp.schema = schema;
499            resp.data = data;
500            Ok(resp)
501        } else {
502            Ok(resp)
503        }
504    }
505
506    pub async fn query(&self, sql: &str) -> Result<QueryResponse> {
507        let resp = self.start_query(sql).await?;
508        self.wait_for_query(resp).await
509    }
510
511    fn session_state(&self) -> SessionState {
512        self.session_state.lock().clone()
513    }
514
515    fn make_pagination(&self) -> Option<PaginationConfig> {
516        if self.wait_time_secs.is_none()
517            && self.max_rows_in_buffer.is_none()
518            && self.max_rows_per_page.is_none()
519        {
520            return None;
521        }
522        let mut pagination = PaginationConfig {
523            wait_time_secs: None,
524            max_rows_in_buffer: None,
525            max_rows_per_page: None,
526        };
527        if let Some(wait_time_secs) = self.wait_time_secs {
528            pagination.wait_time_secs = Some(wait_time_secs);
529        }
530        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
531            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
532        }
533        if let Some(max_rows_per_page) = self.max_rows_per_page {
534            pagination.max_rows_per_page = Some(max_rows_per_page);
535        }
536        Some(pagination)
537    }
538
539    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
540        let mut headers = HeaderMap::new();
541        if let Some(tenant) = &self.tenant {
542            headers.insert(HEADER_TENANT, tenant.parse()?);
543        }
544        let warehouse = self.warehouse.lock().clone();
545        if let Some(warehouse) = warehouse {
546            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
547        }
548        let route_hint = self.route_hint.current();
549        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
550        if let Some(query_id) = query_id {
551            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
552        }
553        Ok(headers)
554    }
555
556    pub async fn insert_with_stage(
557        &self,
558        sql: &str,
559        stage: &str,
560        file_format_options: BTreeMap<&str, &str>,
561        copy_options: BTreeMap<&str, &str>,
562    ) -> Result<QueryResponse> {
563        info!(
564            "insert with stage: {}, format: {:?}, copy: {:?}",
565            sql, file_format_options, copy_options
566        );
567        let stage_attachment = Some(StageAttachmentConfig {
568            location: stage,
569            file_format_options: Some(file_format_options),
570            copy_options: Some(copy_options),
571        });
572        let resp = self.start_query_inner(sql, stage_attachment).await?;
573        let resp = self.wait_for_query(resp).await?;
574        Ok(resp)
575    }
576
577    async fn get_presigned_upload_url(&self, stage: &str) -> Result<PresignedResponse> {
578        info!("get presigned upload url: {}", stage);
579        let sql = format!("PRESIGN UPLOAD {}", stage);
580        let resp = self.query(&sql).await?;
581        if resp.data.len() != 1 {
582            return Err(Error::Decode(
583                "Empty response from server for presigned request".to_string(),
584            ));
585        }
586        if resp.data[0].len() != 3 {
587            return Err(Error::Decode(
588                "Invalid response from server for presigned request".to_string(),
589            ));
590        }
591        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
592        let method = resp.data[0][0].clone().unwrap_or_default();
593        if method != "PUT" {
594            return Err(Error::Decode(format!(
595                "Invalid method for presigned upload request: {}",
596                method
597            )));
598        }
599        let headers: BTreeMap<String, String> =
600            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
601        let url = resp.data[0][2].clone().unwrap_or_default();
602        Ok(PresignedResponse {
603            method,
604            headers,
605            url,
606        })
607    }
608
609    pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
610        match self.presign {
611            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
612            PresignMode::On => {
613                let presigned = self.get_presigned_upload_url(stage).await?;
614                presign_upload_to_stage(presigned, data, size).await
615            }
616            PresignMode::Auto => {
617                unreachable!("PresignMode::Auto should be handled during client initialization")
618            }
619            PresignMode::Detect => {
620                unreachable!("PresignMode::Detect should be handled during client initialization")
621            }
622        }
623    }
624
625    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
626    async fn upload_to_stage_with_stream(
627        &self,
628        stage: &str,
629        data: Reader,
630        size: u64,
631    ) -> Result<()> {
632        info!("upload to stage with stream: {}, size: {}", stage, size);
633        if let Some(info) = self.need_pre_refresh_session().await {
634            self.refresh_session_token(info).await?;
635        }
636        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
637        let location = StageLocation::try_from(stage)?;
638        let query_id = self.gen_query_id();
639        let mut headers = self.make_headers(Some(&query_id))?;
640        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
641        let stream = Body::wrap_stream(ReaderStream::new(data));
642        let part = Part::stream_with_length(stream, size).file_name(location.path);
643        let form = Form::new().part("upload", part);
644        let mut builder = self.cli.put(endpoint.clone());
645        builder = self.wrap_auth_or_session_token(builder)?;
646        let resp = builder.headers(headers).multipart(form).send().await?;
647        let status = resp.status();
648        if status != 200 {
649            return Err(
650                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
651            );
652        }
653        Ok(())
654    }
655
656    async fn login(&mut self) -> Result<()> {
657        let endpoint = self.endpoint.join("/v1/session/login")?;
658        let headers = self.make_headers(None)?;
659        let body = LoginRequest::from(&*self.session_state.lock());
660        let mut builder = self.cli.post(endpoint.clone()).json(&body);
661        if self.disable_session_token {
662            builder = builder.query(&[("disable_session_token", true)]);
663        }
664        let builder = self.auth.wrap(builder)?;
665        let request = builder
666            .headers(headers.clone())
667            .timeout(self.connect_timeout)
668            .build()?;
669        let response = self.query_request_helper(request, true, false).await;
670        let response = match response {
671            Ok(r) => r,
672            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
673                info!("login return 404, skip login on the old version server");
674                return Ok(());
675            }
676            Err(e) => return Err(e),
677        };
678        let body = response.bytes().await?;
679        let response = json_from_slice(&body)?;
680        match response {
681            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
682            LoginResponseResult::Ok(info) => {
683                self.server_version = Some(info.version.clone());
684                if let Some(tokens) = info.tokens {
685                    info!("login success with session token");
686                    self.session_token_info =
687                        Some(Arc::new(parking_lot::Mutex::new((tokens, Instant::now()))))
688                }
689                info!("login success without session token");
690            }
691        }
692        Ok(())
693    }
694
695    fn build_log_out_request(&self) -> Result<Request> {
696        let endpoint = self.endpoint.join("/v1/session/logout")?;
697
698        let session_state = self.session_state();
699        let need_sticky = session_state.need_sticky.unwrap_or(false);
700        let mut headers = self.make_headers(None)?;
701        if need_sticky {
702            if let Some(node_id) = self.last_node_id() {
703                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
704            }
705        }
706        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
707
708        let builder = self.wrap_auth_or_session_token(builder)?;
709        let req = builder.build()?;
710        Ok(req)
711    }
712
713    fn need_logout(&self) -> bool {
714        (self.session_token_info.is_some()
715            || self.session_state.lock().need_keep_alive.unwrap_or(false))
716            && self
717                .closed
718                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
719                .is_ok()
720    }
721
722    async fn refresh_session_token(
723        &self,
724        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
725    ) -> Result<()> {
726        let (session_token_info, _) = { self_login_info.lock().clone() };
727        let endpoint = self.endpoint.join("/v1/session/refresh")?;
728        let body = RefreshSessionTokenRequest {
729            session_token: session_token_info.session_token.clone(),
730        };
731        let headers = self.make_headers(None)?;
732        let request = self
733            .cli
734            .post(endpoint.clone())
735            .json(&body)
736            .headers(headers.clone())
737            .bearer_auth(session_token_info.refresh_token.clone())
738            .timeout(self.connect_timeout)
739            .build()?;
740
741        // avoid recursively call request_helper
742        for i in 0..3 {
743            let req = request.try_clone().expect("request not cloneable");
744            match self.cli.execute(req).await {
745                Ok(response) => {
746                    let status = response.status();
747                    let body = response.bytes().await?;
748                    if status == StatusCode::OK {
749                        let response = json_from_slice(&body)?;
750                        return match response {
751                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
752                            RefreshResponse::Ok(info) => {
753                                *self_login_info.lock() = (info, Instant::now());
754                                Ok(())
755                            }
756                        };
757                    }
758                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
759                        return Err(Error::response_error(status, &body));
760                    }
761                }
762                Err(err) => {
763                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
764                        return Err(Error::Request(err.to_string()));
765                    }
766                }
767            };
768            sleep(jitter(Duration::from_secs(10))).await;
769        }
770        Ok(())
771    }
772
773    async fn need_pre_refresh_session(
774        &self,
775    ) -> Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>> {
776        if let Some(info) = &self.session_token_info {
777            let (start, ttl) = {
778                let guard = info.lock();
779                (guard.1, guard.0.session_token_ttl_in_secs)
780            };
781            if Instant::now() > start + Duration::from_secs(ttl) {
782                return Some(info.clone());
783            }
784        }
785        None
786    }
787
788    /// return Ok if and only if status code is 200.
789    ///
790    /// retry on
791    ///   - network errors
792    ///   - (optional) 503
793    ///
794    /// refresh databend token or reload jwt token if needed.
795    async fn query_request_helper(
796        &self,
797        mut request: Request,
798        retry_if_503: bool,
799        refresh_if_401: bool,
800    ) -> std::result::Result<Response, Error> {
801        let mut refreshed = false;
802        let mut retries = 0;
803        loop {
804            let req = request.try_clone().expect("request not cloneable");
805            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
806                Ok(response) => {
807                    let status = response.status();
808                    if status == StatusCode::OK {
809                        return Ok(response);
810                    }
811                    let body = response.bytes().await?;
812                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
813                        // waiting for server to start
814                        (Error::response_error(status, &body), true)
815                    } else {
816                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
817                        match resp {
818                            Ok(r) => {
819                                let e = r.error;
820                                if status == StatusCode::UNAUTHORIZED {
821                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
822                                    if let Some(session_token_info) = &self.session_token_info {
823                                        info!(
824                                            "will retry {} after refresh token on auth error {}",
825                                            request.url(),
826                                            e
827                                        );
828                                        let retry = if need_refresh_token(e.code)
829                                            && !refreshed
830                                            && refresh_if_401
831                                        {
832                                            self.refresh_session_token(session_token_info.clone())
833                                                .await?;
834                                            refreshed = true;
835                                            true
836                                        } else {
837                                            false
838                                        };
839                                        (Error::AuthFailure(e), retry)
840                                    } else if self.auth.can_reload() {
841                                        info!(
842                                            "will retry {} after reload token on auth error {}",
843                                            request.url(),
844                                            e
845                                        );
846                                        let builder = RequestBuilder::from_parts(
847                                            HttpClient::new(),
848                                            request.try_clone().unwrap(),
849                                        );
850                                        let builder = self.auth.wrap(builder)?;
851                                        request = builder.build()?;
852                                        (Error::AuthFailure(e), true)
853                                    } else {
854                                        (Error::AuthFailure(e), false)
855                                    }
856                                } else {
857                                    (Error::Logic(status, e), false)
858                                }
859                            }
860                            Err(_) => (
861                                Error::Response {
862                                    status,
863                                    msg: String::from_utf8_lossy(&body).to_string(),
864                                },
865                                false,
866                            ),
867                        }
868                    }
869                }
870                Err(err) => (
871                    Error::Request(err.to_string()),
872                    err.is_timeout() || err.is_connect(),
873                ),
874            };
875            if !retry {
876                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
877            }
878            match &err {
879                Error::AuthFailure(_) => {
880                    if refreshed {
881                        retries = 0;
882                    } else if retries == 2 {
883                        return Err(err.with_context(&format!(
884                            "{} {} after 3 reties",
885                            request.method(),
886                            request.url()
887                        )));
888                    }
889                }
890                _ => {
891                    if retries == 2 {
892                        return Err(err.with_context(&format!(
893                            "{} {} after 3 reties",
894                            request.method(),
895                            request.url()
896                        )));
897                    }
898                    retries += 1;
899                    info!(
900                        "will retry {} the {retries}th times on error {}",
901                        request.url(),
902                        err
903                    );
904                }
905            }
906            sleep(jitter(Duration::from_secs(10))).await;
907        }
908    }
909
910    pub async fn close(&self) {
911        if self.need_logout() {
912            let cli = self.cli.clone();
913            let req = self
914                .build_log_out_request()
915                .expect("failed to build logout request");
916            if let Err(err) = cli.execute(req).await {
917                error!("logout request failed: {}", err);
918            } else {
919                debug!("logout success");
920            };
921        }
922    }
923}
924
925impl Drop for APIClient {
926    fn drop(&mut self) {
927        if self.need_logout() {
928            warn!("APIClient::close() was not called");
929        }
930    }
931}
932
933fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
934where
935    T: Deserialize<'a>,
936{
937    serde_json::from_slice::<T>(body).map_err(|e| {
938        Error::Decode(format!(
939            "fail to decode JSON response: {}, body: {}",
940            e,
941            String::from_utf8_lossy(body)
942        ))
943    })
944}
945
946impl Default for APIClient {
947    fn default() -> Self {
948        Self {
949            cli: HttpClient::new(),
950            scheme: "http".to_string(),
951            endpoint: Url::parse("http://localhost:8080").unwrap(),
952            host: "localhost".to_string(),
953            port: 8000,
954            tenant: None,
955            warehouse: Arc::new(parking_lot::Mutex::new(None)),
956            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
957            session_state: Arc::new(parking_lot::Mutex::new(SessionState::default())),
958            wait_time_secs: None,
959            max_rows_in_buffer: None,
960            max_rows_per_page: None,
961            connect_timeout: Duration::from_secs(10),
962            page_request_timeout: Duration::from_secs(30),
963            tls_ca_file: None,
964            presign: PresignMode::Auto,
965            route_hint: Arc::new(RouteHintGenerator::new()),
966            last_node_id: Arc::new(Default::default()),
967            disable_session_token: true,
968            disable_login: false,
969            session_token_info: None,
970            closed: Arc::new(Default::default()),
971            last_query_id: Arc::new(Default::default()),
972            server_version: None,
973        }
974    }
975}
976
977struct RouteHintGenerator {
978    nonce: AtomicU64,
979    current: std::sync::Mutex<String>,
980}
981
982impl RouteHintGenerator {
983    fn new() -> Self {
984        let gen = Self {
985            nonce: AtomicU64::new(0),
986            current: std::sync::Mutex::new("".to_string()),
987        };
988        gen.next();
989        gen
990    }
991
992    fn current(&self) -> String {
993        let guard = self.current.lock().unwrap();
994        guard.clone()
995    }
996
997    fn set(&self, hint: &str) {
998        let mut guard = self.current.lock().unwrap();
999        *guard = hint.to_string();
1000    }
1001
1002    fn next(&self) -> String {
1003        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1004        let uuid = uuid::Uuid::new_v4();
1005        let current = format!("rh:{}:{:06}", uuid, nonce);
1006        let mut guard = self.current.lock().unwrap();
1007        guard.clone_from(&current);
1008        current
1009    }
1010}
1011
#[cfg(test)]
mod test {
    use super::*;

    /// Host, endpoint, paging options and warehouse are all taken from the DSN; with
    /// `sslmode=disable` and no explicit port the endpoint is plain HTTP on port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let api = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;
        assert_eq!(api.host, "app.databend.com");
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(api.endpoint, expected_endpoint);
        assert_eq!(api.wait_time_secs, Some(10));
        assert_eq!(api.max_rows_in_buffer, Some(5000000));
        assert_eq!(api.max_rows_per_page, Some(10000));
        assert_eq!(api.tenant, None);
        let warehouse = api.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password must not confuse host/port parsing; without an
    /// explicit port or `sslmode` the port is 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let api =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(api.host(), "localhost");
        assert_eq!(api.port(), 443);
        Ok(())
    }

    /// A password containing raw `@`, `(`, `=` and `{` still yields the right host and
    /// the explicitly given port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let api =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(api.host(), "localhost");
        assert_eq!(api.port(), 8000);
        Ok(())
    }
}