databend_client/
client.rs

1// Copyright 2021 Datafuse Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
18use crate::global_cookie_store::GlobalCookieStore;
19use crate::login::{
20    LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
21    SessionTokenInfo,
22};
23use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
24use crate::response::LoadResponse;
25use crate::stage::StageLocation;
26use crate::{
27    error::{Error, Result},
28    request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
29    response::QueryResponse,
30    session::SessionState,
31    QueryStats,
32};
33use crate::{Page, Pages};
34use base64::engine::general_purpose::URL_SAFE;
35use base64::Engine;
36use log::{debug, error, info, warn};
37use once_cell::sync::Lazy;
38use parking_lot::Mutex;
39use percent_encoding::percent_decode_str;
40use reqwest::cookie::CookieStore;
41use reqwest::header::{HeaderMap, HeaderValue};
42use reqwest::multipart::{Form, Part};
43use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
44use semver::Version;
45use serde::{de, Deserialize};
46use std::collections::BTreeMap;
47use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::time::sleep;
51use tokio_retry::strategy::jitter;
52use tokio_stream::StreamExt;
53use tokio_util::io::ReaderStream;
54use url::Url;
55
/// Header carrying the client-generated query id.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Header carrying the tenant configured through the `tenant` DSN parameter.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Header carrying a node id; sent when the session reports `need_sticky` or a page request
/// names a node.
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Header carrying the current warehouse.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
/// Header naming the target stage for `v1/upload_to_stage`.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Route hint header: sent on every request and read back from query responses.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Header carrying the SQL text for `v1/streaming_load`.
const HEADER_SQL: &str = "X-DATABEND-SQL";
/// Header carrying serialized session state for `v1/streaming_load` (request and response).
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";

/// `txn_state` value (compared ignoring ASCII case) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
65
66static VERSION: Lazy<String> = Lazy::new(|| {
67    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
68    version.to_string()
69});
70
/// HTTP client for a Databend server.
///
/// Created by `APIClient::new`, which parses a DSN, builds the underlying `reqwest` client,
/// optionally logs in, and resolves the presign mode before returning an `Arc<APIClient>`.
/// State that changes after construction is kept behind `parking_lot::Mutex`es so the
/// shared client can be used through `&self`.
pub struct APIClient {
    // Underlying reqwest client; replaced by `build_client`.
    cli: HttpClient,
    // "http" or "https", selected by the `sslmode` DSN parameter (default "https").
    scheme: String,
    // Host taken from the DSN.
    host: String,
    // Port from the DSN, or 80/443 depending on `scheme`.
    port: u16,

    // Base URL `scheme://host:port` that request paths are joined onto.
    endpoint: Url,

    // Credentials: basic auth from the DSN user info, or an access-token variant.
    auth: Arc<dyn Auth>,

    // Sent as `X-DATABEND-TENANT` when set.
    tenant: Option<String>,
    // Sent as `X-DATABEND-WAREHOUSE`; updated when server session settings contain `warehouse`.
    warehouse: Mutex<Option<String>>,
    // Current session state; sent with queries and replaced by the server's returned session.
    session_state: Mutex<SessionState>,
    // Produces the `X-DATABEND-ROUTE-HINT` value; advanced before queries outside a transaction.
    route_hint: RouteHintGenerator,

    // Set by `login=disable` in the DSN; `new` then skips the login call.
    disable_login: bool,
    // Set by `session_token=disable`; login then adds `disable_session_token=true` to the query.
    disable_session_token: bool,
    // Tokens returned by login, plus the `Instant` they were obtained or last refreshed.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // Flipped from false to true by `need_logout`, so only the first caller sees `true`.
    closed: AtomicBool,

    // Server version parsed from the login response.
    server_version: Option<Version>,

    // Optional pagination settings forwarded with each query (see `make_pagination`).
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    // Used as the request timeout for login and session-token refresh requests.
    connect_timeout: Duration,
    // Used as the request timeout for each `query_page` call.
    page_request_timeout: Duration,

    // Optional PEM file added as an extra root certificate for https (TLS features only).
    tls_ca_file: Option<String>,

    // Upload strategy; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    // Node id recorded after a query starts; used for sticky routing.
    last_node_id: Mutex<Option<String>>,
    // Query id recorded after a query starts without error (also settable publicly).
    last_query_id: Mutex<Option<String>>,

    // Feature flags derived from the server version during login.
    capability: Capability,
}
109
110impl APIClient {
111    pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
112        let mut client = Self::from_dsn(dsn).await?;
113        client.build_client(name).await?;
114        if !client.disable_login {
115            client.login().await?;
116        }
117        let client = Arc::new(client);
118        client.check_presign().await?;
119        Ok(client)
120    }
121
122    pub fn capability(&self) -> &Capability {
123        &self.capability
124    }
125
126    fn set_presign_mode(&self, mode: PresignMode) {
127        *self.presign.lock() = mode
128    }
129    fn get_presign_mode(&self) -> PresignMode {
130        *self.presign.lock()
131    }
132
133    async fn from_dsn(dsn: &str) -> Result<Self> {
134        let u = Url::parse(dsn)?;
135        let mut client = Self::default();
136        if let Some(host) = u.host_str() {
137            client.host = host.to_string();
138        }
139
140        if u.username() != "" {
141            let password = u.password().unwrap_or_default();
142            let password = percent_decode_str(password).decode_utf8()?;
143            client.auth = Arc::new(BasicAuth::new(u.username(), password));
144        }
145        let database = match u.path().trim_start_matches('/') {
146            "" => None,
147            s => Some(s.to_string()),
148        };
149        let mut role = None;
150        let mut scheme = "https";
151        let mut session_settings = BTreeMap::new();
152        for (k, v) in u.query_pairs() {
153            match k.as_ref() {
154                "wait_time_secs" => {
155                    client.wait_time_secs = Some(v.parse()?);
156                }
157                "max_rows_in_buffer" => {
158                    client.max_rows_in_buffer = Some(v.parse()?);
159                }
160                "max_rows_per_page" => {
161                    client.max_rows_per_page = Some(v.parse()?);
162                }
163                "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
164                "page_request_timeout_secs" => {
165                    client.page_request_timeout = {
166                        let secs: u64 = v.parse()?;
167                        Duration::from_secs(secs)
168                    };
169                }
170                "presign" => {
171                    let presign_mode = match v.as_ref() {
172                        "auto" => PresignMode::Auto,
173                        "detect" => PresignMode::Detect,
174                        "on" => PresignMode::On,
175                        "off" => PresignMode::Off,
176                        _ => {
177                            return Err(Error::BadArgument(format!(
178                            "Invalid value for presign: {v}, should be one of auto/detect/on/off"
179                        )))
180                        }
181                    };
182                    client.set_presign_mode(presign_mode);
183                }
184                "tenant" => {
185                    client.tenant = Some(v.to_string());
186                }
187                "warehouse" => {
188                    client.warehouse = Mutex::new(Some(v.to_string()));
189                }
190                "role" => role = Some(v.to_string()),
191                "sslmode" => match v.as_ref() {
192                    "disable" => scheme = "http",
193                    "require" | "enable" => scheme = "https",
194                    _ => {
195                        return Err(Error::BadArgument(format!(
196                            "Invalid value for sslmode: {v}"
197                        )))
198                    }
199                },
200                "tls_ca_file" => {
201                    client.tls_ca_file = Some(v.to_string());
202                }
203                "access_token" => {
204                    client.auth = Arc::new(AccessTokenAuth::new(v));
205                }
206                "access_token_file" => {
207                    client.auth = Arc::new(AccessTokenFileAuth::new(v));
208                }
209                "login" => {
210                    client.disable_login = match v.as_ref() {
211                        "disable" => true,
212                        "enable" => false,
213                        _ => {
214                            return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
215                        }
216                    }
217                }
218                "session_token" => {
219                    client.disable_session_token = match v.as_ref() {
220                        "disable" => true,
221                        "enable" => false,
222                        _ => {
223                            return Err(Error::BadArgument(format!(
224                                "Invalid value for session_token: {v}"
225                            )))
226                        }
227                    }
228                }
229                _ => {
230                    session_settings.insert(k.to_string(), v.to_string());
231                }
232            }
233        }
234        client.port = match u.port() {
235            Some(p) => p,
236            None => match scheme {
237                "http" => 80,
238                "https" => 443,
239                _ => unreachable!(),
240            },
241        };
242        client.scheme = scheme.to_string();
243
244        client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
245        client.session_state = Mutex::new(
246            SessionState::default()
247                .with_settings(Some(session_settings))
248                .with_role(role)
249                .with_database(database),
250        );
251
252        Ok(client)
253    }
254
255    pub fn host(&self) -> &str {
256        self.host.as_str()
257    }
258
259    pub fn port(&self) -> u16 {
260        self.port
261    }
262
263    pub fn scheme(&self) -> &str {
264        self.scheme.as_str()
265    }
266
267    async fn build_client(&mut self, name: Option<String>) -> Result<()> {
268        let ua = match name {
269            Some(n) => n,
270            None => format!("databend-client-rust/{}", VERSION.as_str()),
271        };
272        let cookie_provider = GlobalCookieStore::new();
273        let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
274        let mut initial_cookies = [&cookie].into_iter();
275        cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
276        let mut cli_builder = HttpClient::builder()
277            .user_agent(ua)
278            .cookie_provider(Arc::new(cookie_provider))
279            .pool_idle_timeout(Duration::from_secs(1));
280        #[cfg(any(feature = "rustls", feature = "native-tls"))]
281        if self.scheme == "https" {
282            if let Some(ref ca_file) = self.tls_ca_file {
283                let cert_pem = tokio::fs::read(ca_file).await?;
284                let cert = reqwest::Certificate::from_pem(&cert_pem)?;
285                cli_builder = cli_builder.add_root_certificate(cert);
286            }
287        }
288        self.cli = cli_builder.build()?;
289        Ok(())
290    }
291
292    async fn check_presign(self: &Arc<Self>) -> Result<()> {
293        let mode = match self.get_presign_mode() {
294            PresignMode::Auto => {
295                if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
296                    PresignMode::On
297                } else {
298                    PresignMode::Off
299                }
300            }
301            PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
302                Ok(_) => PresignMode::On,
303                Err(e) => {
304                    warn!("presign mode off with error detected: {e}");
305                    PresignMode::Off
306                }
307            },
308            mode => mode,
309        };
310        self.set_presign_mode(mode);
311        Ok(())
312    }
313
314    pub fn current_warehouse(&self) -> Option<String> {
315        let guard = self.warehouse.lock();
316        guard.clone()
317    }
318
319    pub fn current_catalog(&self) -> Option<String> {
320        let guard = self.session_state.lock();
321        guard.catalog.clone()
322    }
323
324    pub fn current_database(&self) -> Option<String> {
325        let guard = self.session_state.lock();
326        guard.database.clone()
327    }
328
329    pub async fn current_role(&self) -> Option<String> {
330        let guard = self.session_state.lock();
331        guard.role.clone()
332    }
333
334    fn in_active_transaction(&self) -> bool {
335        let guard = self.session_state.lock();
336        guard
337            .txn_state
338            .as_ref()
339            .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
340            .unwrap_or(false)
341    }
342
343    pub fn username(&self) -> String {
344        self.auth.username()
345    }
346
347    fn gen_query_id(&self) -> String {
348        uuid::Uuid::now_v7().simple().to_string()
349    }
350
351    async fn handle_session(&self, session: &Option<SessionState>) {
352        let session = match session {
353            Some(session) => session,
354            None => return,
355        };
356
357        // save the updated session state from the server side
358        {
359            let mut session_state = self.session_state.lock();
360            *session_state = session.clone();
361        }
362
363        // process warehouse changed via session settings
364        if let Some(settings) = session.settings.as_ref() {
365            if let Some(v) = settings.get("warehouse") {
366                let mut warehouse = self.warehouse.lock();
367                *warehouse = Some(v.clone());
368            }
369        }
370    }
371
372    pub fn set_last_node_id(&self, node_id: String) {
373        *self.last_node_id.lock() = Some(node_id)
374    }
375
376    pub fn set_last_query_id(&self, query_id: Option<String>) {
377        *self.last_query_id.lock() = query_id
378    }
379
380    pub fn last_query_id(&self) -> Option<String> {
381        self.last_query_id.lock().clone()
382    }
383
384    fn last_node_id(&self) -> Option<String> {
385        self.last_node_id.lock().clone()
386    }
387
388    fn handle_warnings(&self, resp: &QueryResponse) {
389        if let Some(warnings) = &resp.warnings {
390            for w in warnings {
391                warn!(target: "server_warnings", "server warning: {w}");
392            }
393        }
394    }
395
396    pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
397        info!("start query: {sql}");
398        let resp = self.start_query_inner(sql, None).await?;
399        let pages = Pages::new(self.clone(), resp, need_progress);
400        Ok(pages)
401    }
402
403    fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
404        if let Some(info) = &self.session_token_info {
405            let info = info.lock();
406            Ok(builder.bearer_auth(info.0.session_token.clone()))
407        } else {
408            self.auth.wrap(builder)
409        }
410    }
411
412    async fn start_query_inner(
413        &self,
414        sql: &str,
415        stage_attachment_config: Option<StageAttachmentConfig<'_>>,
416    ) -> Result<QueryResponse> {
417        if !self.in_active_transaction() {
418            self.route_hint.next();
419        }
420        let endpoint = self.endpoint.join("v1/query")?;
421
422        // body
423        let session_state = self.session_state();
424        let need_sticky = session_state.need_sticky.unwrap_or(false);
425        let req = QueryRequest::new(sql)
426            .with_pagination(self.make_pagination())
427            .with_session(Some(session_state))
428            .with_stage_attachment(stage_attachment_config);
429
430        // headers
431        let query_id = self.gen_query_id();
432        let mut headers = self.make_headers(Some(&query_id))?;
433        if need_sticky {
434            if let Some(node_id) = self.last_node_id() {
435                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
436            }
437        }
438        let mut builder = self.cli.post(endpoint.clone()).json(&req);
439        builder = self.wrap_auth_or_session_token(builder)?;
440        let request = builder.headers(headers.clone()).build()?;
441        let response = self.query_request_helper(request, true, true).await?;
442        if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
443            self.route_hint.set(route_hint.to_str().unwrap_or_default());
444        }
445        let body = response.bytes().await?;
446        let result: QueryResponse = json_from_slice(&body)?;
447        self.handle_session(&result.session).await;
448        if let Some(err) = result.error {
449            return Err(Error::QueryFailed(err));
450        }
451
452        self.set_last_query_id(Some(query_id));
453        self.handle_warnings(&result);
454        if let Some(node_id) = &result.node_id {
455            self.set_last_node_id(node_id.clone());
456        }
457        Ok(result)
458    }
459
460    pub async fn query_page(
461        &self,
462        query_id: &str,
463        next_uri: &str,
464        node_id: &Option<String>,
465    ) -> Result<QueryResponse> {
466        info!("query page: {next_uri}");
467        let endpoint = self.endpoint.join(next_uri)?;
468        let headers = self.make_headers(Some(query_id))?;
469        let mut builder = self.cli.get(endpoint.clone());
470        builder = self
471            .wrap_auth_or_session_token(builder)?
472            .headers(headers.clone())
473            .timeout(self.page_request_timeout);
474        if let Some(node_id) = node_id {
475            builder = builder.header(HEADER_STICKY_NODE, node_id)
476        }
477        let request = builder.build()?;
478
479        let response = self.query_request_helper(request, false, true).await?;
480        let status = response.status();
481        if status != 200 {
482            return Err(Error::response_error(status, &response.bytes().await?));
483        }
484        let body = response.bytes().await?;
485        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
486            if let Error::Logic(status, ec) = &e {
487                if *status == 404 {
488                    return Error::QueryNotFound(ec.message.clone());
489                }
490            }
491            e
492        })?;
493        self.handle_session(&resp.session).await;
494        match resp.error {
495            Some(err) => Err(Error::QueryFailed(err)),
496            None => Ok(resp),
497        }
498    }
499
500    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
501        let kill_uri = format!("/v1/query/{query_id}/kill");
502        let endpoint = self.endpoint.join(&kill_uri)?;
503        let headers = self.make_headers(Some(query_id))?;
504        info!("kill query: {kill_uri}");
505
506        let mut builder = self.cli.post(endpoint);
507        builder = self.wrap_auth_or_session_token(builder)?;
508        let resp = builder.headers(headers.clone()).send().await?;
509        if resp.status() != 200 {
510            return Err(Error::response_error(resp.status(), &resp.bytes().await?)
511                .with_context("kill query"));
512        }
513        Ok(())
514    }
515
516    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
517        let mut pages = self.start_query(sql, false).await?;
518        let mut all = Page::default();
519        while let Some(page) = pages.next().await {
520            all.update(page?);
521        }
522        Ok(all)
523    }
524
525    fn session_state(&self) -> SessionState {
526        self.session_state.lock().clone()
527    }
528
529    fn make_pagination(&self) -> Option<PaginationConfig> {
530        if self.wait_time_secs.is_none()
531            && self.max_rows_in_buffer.is_none()
532            && self.max_rows_per_page.is_none()
533        {
534            return None;
535        }
536        let mut pagination = PaginationConfig {
537            wait_time_secs: None,
538            max_rows_in_buffer: None,
539            max_rows_per_page: None,
540        };
541        if let Some(wait_time_secs) = self.wait_time_secs {
542            pagination.wait_time_secs = Some(wait_time_secs);
543        }
544        if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
545            pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
546        }
547        if let Some(max_rows_per_page) = self.max_rows_per_page {
548            pagination.max_rows_per_page = Some(max_rows_per_page);
549        }
550        Some(pagination)
551    }
552
553    fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
554        let mut headers = HeaderMap::new();
555        if let Some(tenant) = &self.tenant {
556            headers.insert(HEADER_TENANT, tenant.parse()?);
557        }
558        let warehouse = self.warehouse.lock().clone();
559        if let Some(warehouse) = warehouse {
560            headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
561        }
562        let route_hint = self.route_hint.current();
563        headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
564        if let Some(query_id) = query_id {
565            headers.insert(HEADER_QUERY_ID, query_id.parse()?);
566        }
567        Ok(headers)
568    }
569
570    pub async fn insert_with_stage(
571        self: &Arc<Self>,
572        sql: &str,
573        stage: &str,
574        file_format_options: BTreeMap<&str, &str>,
575        copy_options: BTreeMap<&str, &str>,
576    ) -> Result<QueryStats> {
577        info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
578        let stage_attachment = Some(StageAttachmentConfig {
579            location: stage,
580            file_format_options: Some(file_format_options),
581            copy_options: Some(copy_options),
582        });
583        let resp = self.start_query_inner(sql, stage_attachment).await?;
584        let mut pages = Pages::new(self.clone(), resp, false);
585        let mut all = Page::default();
586        while let Some(page) = pages.next().await {
587            all.update(page?);
588        }
589        Ok(all.stats)
590    }
591
592    async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
593        info!("get presigned upload url: {stage}");
594        let sql = format!("PRESIGN UPLOAD {stage}");
595        let resp = self.query_all(&sql).await?;
596        if resp.data.len() != 1 {
597            return Err(Error::Decode(
598                "Empty response from server for presigned request".to_string(),
599            ));
600        }
601        if resp.data[0].len() != 3 {
602            return Err(Error::Decode(
603                "Invalid response from server for presigned request".to_string(),
604            ));
605        }
606        // resp.data[0]: [ "PUT", "{\"host\":\"s3.us-east-2.amazonaws.com\"}", "https://s3.us-east-2.amazonaws.com/query-storage-xxxxx/tnxxxxx/stage/user/xxxx/xxx?" ]
607        let method = resp.data[0][0].clone().unwrap_or_default();
608        if method != "PUT" {
609            return Err(Error::Decode(format!(
610                "Invalid method for presigned upload request: {method}"
611            )));
612        }
613        let headers: BTreeMap<String, String> =
614            serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
615        let url = resp.data[0][2].clone().unwrap_or_default();
616        Ok(PresignedResponse {
617            method,
618            headers,
619            url,
620        })
621    }
622
623    pub async fn upload_to_stage(
624        self: &Arc<Self>,
625        stage: &str,
626        data: Reader,
627        size: u64,
628    ) -> Result<()> {
629        match self.get_presign_mode() {
630            PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
631            PresignMode::On => {
632                let presigned = self.get_presigned_upload_url(stage).await?;
633                presign_upload_to_stage(presigned, data, size).await
634            }
635            PresignMode::Auto => {
636                unreachable!("PresignMode::Auto should be handled during client initialization")
637            }
638            PresignMode::Detect => {
639                unreachable!("PresignMode::Detect should be handled during client initialization")
640            }
641        }
642    }
643
644    /// Upload data to stage with stream api, should not be used directly, use `upload_to_stage` instead.
645    async fn upload_to_stage_with_stream(
646        &self,
647        stage: &str,
648        data: Reader,
649        size: u64,
650    ) -> Result<()> {
651        info!("upload to stage with stream: {stage}, size: {size}");
652        if let Some(info) = self.need_pre_refresh_session().await {
653            self.refresh_session_token(info).await?;
654        }
655        let endpoint = self.endpoint.join("v1/upload_to_stage")?;
656        let location = StageLocation::try_from(stage)?;
657        let query_id = self.gen_query_id();
658        let mut headers = self.make_headers(Some(&query_id))?;
659        headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
660        let stream = Body::wrap_stream(ReaderStream::new(data));
661        let part = Part::stream_with_length(stream, size).file_name(location.path);
662        let form = Form::new().part("upload", part);
663        let mut builder = self.cli.put(endpoint.clone());
664        builder = self.wrap_auth_or_session_token(builder)?;
665        let resp = builder.headers(headers).multipart(form).send().await?;
666        let status = resp.status();
667        if status != 200 {
668            return Err(
669                Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
670            );
671        }
672        Ok(())
673    }
674
675    // use base64 encode whenever possible for safety
676    // but also accept raw JSON for test/debug/one-shot operations
677    pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
678    where
679        T: de::DeserializeOwned,
680    {
681        if value.starts_with("{") {
682            serde_json::from_slice(value.as_bytes())
683                .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
684        } else {
685            let json = URL_SAFE.decode(value).map_err(|e| {
686                format!(
687                    "Invalid value {} for {key}, base64 decode error: {}",
688                    value, e
689                )
690            })?;
691            serde_json::from_slice(&json).map_err(|e| {
692                format!(
693                    "Invalid value {value} for {key}, JSON value {},  decode error: {e}",
694                    String::from_utf8_lossy(&json)
695                )
696            })
697        }
698    }
699
700    pub async fn streaming_load(
701        &self,
702        sql: &str,
703        data: Reader,
704        file_name: &str,
705    ) -> Result<LoadResponse> {
706        let body = Body::wrap_stream(ReaderStream::new(data));
707        let part = Part::stream(body).file_name(file_name.to_string());
708        let endpoint = self.endpoint.join("v1/streaming_load")?;
709        let mut builder = self.cli.put(endpoint.clone());
710        builder = self.wrap_auth_or_session_token(builder)?;
711        let query_id = self.gen_query_id();
712        let mut headers = self.make_headers(Some(&query_id))?;
713        headers.insert(HEADER_SQL, sql.parse()?);
714        let session = serde_json::to_string(&*self.session_state.lock())
715            .expect("serialize session state should not fail");
716        headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
717        let form = Form::new().part("upload", part);
718        let resp = builder.headers(headers).multipart(form).send().await?;
719        let status = resp.status();
720        if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
721            match Self::decode_json_header::<SessionState>(
722                HEADER_QUERY_CONTEXT,
723                value.to_str().unwrap(),
724            ) {
725                Ok(session) => *self.session_state.lock() = session,
726                Err(e) => {
727                    error!("Error decoding session state when streaming load: {e}");
728                }
729            }
730        };
731        if status != 200 {
732            return Err(
733                Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
734            );
735        }
736        let resp = resp.json::<LoadResponse>().await?;
737        Ok(resp)
738    }
739
740    async fn login(&mut self) -> Result<()> {
741        let endpoint = self.endpoint.join("/v1/session/login")?;
742        let headers = self.make_headers(None)?;
743        let body = LoginRequest::from(&*self.session_state.lock());
744        let mut builder = self.cli.post(endpoint.clone()).json(&body);
745        if self.disable_session_token {
746            builder = builder.query(&[("disable_session_token", true)]);
747        }
748        let builder = self.auth.wrap(builder)?;
749        let request = builder
750            .headers(headers.clone())
751            .timeout(self.connect_timeout)
752            .build()?;
753        let response = self.query_request_helper(request, true, false).await;
754        let response = match response {
755            Ok(r) => r,
756            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
757                info!("login return 404, skip login on the old version server");
758                return Ok(());
759            }
760            Err(e) => return Err(e),
761        };
762        let body = response.bytes().await?;
763        let response = json_from_slice(&body)?;
764        match response {
765            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
766            LoginResponseResult::Ok(info) => {
767                let server_version = info
768                    .version
769                    .parse()
770                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
771                self.capability = Capability::from_server_version(&server_version);
772                self.server_version = Some(server_version.clone());
773                if let Some(tokens) = info.tokens {
774                    info!("login success with session token version = {server_version}",);
775                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
776                } else {
777                    info!("login success, version = {server_version}");
778                }
779            }
780        }
781        Ok(())
782    }
783
784    fn build_log_out_request(&self) -> Result<Request> {
785        let endpoint = self.endpoint.join("/v1/session/logout")?;
786
787        let session_state = self.session_state();
788        let need_sticky = session_state.need_sticky.unwrap_or(false);
789        let mut headers = self.make_headers(None)?;
790        if need_sticky {
791            if let Some(node_id) = self.last_node_id() {
792                headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
793            }
794        }
795        let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
796
797        let builder = self.wrap_auth_or_session_token(builder)?;
798        let req = builder.build()?;
799        Ok(req)
800    }
801
802    fn need_logout(&self) -> bool {
803        (self.session_token_info.is_some()
804            || self.session_state.lock().need_keep_alive.unwrap_or(false))
805            && self
806                .closed
807                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
808                .is_ok()
809    }
810
811    async fn refresh_session_token(
812        &self,
813        self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
814    ) -> Result<()> {
815        let (session_token_info, _) = { self_login_info.lock().clone() };
816        let endpoint = self.endpoint.join("/v1/session/refresh")?;
817        let body = RefreshSessionTokenRequest {
818            session_token: session_token_info.session_token.clone(),
819        };
820        let headers = self.make_headers(None)?;
821        let request = self
822            .cli
823            .post(endpoint.clone())
824            .json(&body)
825            .headers(headers.clone())
826            .bearer_auth(session_token_info.refresh_token.clone())
827            .timeout(self.connect_timeout)
828            .build()?;
829
830        // avoid recursively call request_helper
831        for i in 0..3 {
832            let req = request.try_clone().expect("request not cloneable");
833            match self.cli.execute(req).await {
834                Ok(response) => {
835                    let status = response.status();
836                    let body = response.bytes().await?;
837                    if status == StatusCode::OK {
838                        let response = json_from_slice(&body)?;
839                        return match response {
840                            RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
841                            RefreshResponse::Ok(info) => {
842                                *self_login_info.lock() = (info, Instant::now());
843                                Ok(())
844                            }
845                        };
846                    }
847                    if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
848                        return Err(Error::response_error(status, &body));
849                    }
850                }
851                Err(err) => {
852                    if !(err.is_timeout() || err.is_connect()) || i > 2 {
853                        return Err(Error::Request(err.to_string()));
854                    }
855                }
856            };
857            sleep(jitter(Duration::from_secs(10))).await;
858        }
859        Ok(())
860    }
861
862    async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
863        if let Some(info) = &self.session_token_info {
864            let (start, ttl) = {
865                let guard = info.lock();
866                (guard.1, guard.0.session_token_ttl_in_secs)
867            };
868            if Instant::now() > start + Duration::from_secs(ttl) {
869                return Some(info.clone());
870            }
871        }
872        None
873    }
874
875    /// return Ok if and only if status code is 200.
876    ///
877    /// retry on
878    ///   - network errors
879    ///   - (optional) 503
880    ///
881    /// refresh databend token or reload jwt token if needed.
882    async fn query_request_helper(
883        &self,
884        mut request: Request,
885        retry_if_503: bool,
886        refresh_if_401: bool,
887    ) -> std::result::Result<Response, Error> {
888        let mut refreshed = false;
889        let mut retries = 0;
890        loop {
891            let req = request.try_clone().expect("request not cloneable");
892            let (err, retry): (Error, bool) = match self.cli.execute(req).await {
893                Ok(response) => {
894                    let status = response.status();
895                    if status == StatusCode::OK {
896                        return Ok(response);
897                    }
898                    let body = response.bytes().await?;
899                    if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
900                        // waiting for server to start
901                        (Error::response_error(status, &body), true)
902                    } else {
903                        let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
904                        match resp {
905                            Ok(r) => {
906                                let e = r.error;
907                                if status == StatusCode::UNAUTHORIZED {
908                                    request.headers_mut().remove(reqwest::header::AUTHORIZATION);
909                                    if let Some(session_token_info) = &self.session_token_info {
910                                        info!(
911                                            "will retry {} after refresh token on auth error {}",
912                                            request.url(),
913                                            e
914                                        );
915                                        let retry = if need_refresh_token(e.code)
916                                            && !refreshed
917                                            && refresh_if_401
918                                        {
919                                            self.refresh_session_token(session_token_info.clone())
920                                                .await?;
921                                            refreshed = true;
922                                            true
923                                        } else {
924                                            false
925                                        };
926                                        (Error::AuthFailure(e), retry)
927                                    } else if self.auth.can_reload() {
928                                        info!(
929                                            "will retry {} after reload token on auth error {}",
930                                            request.url(),
931                                            e
932                                        );
933                                        let builder = RequestBuilder::from_parts(
934                                            HttpClient::new(),
935                                            request.try_clone().unwrap(),
936                                        );
937                                        let builder = self.auth.wrap(builder)?;
938                                        request = builder.build()?;
939                                        (Error::AuthFailure(e), true)
940                                    } else {
941                                        (Error::AuthFailure(e), false)
942                                    }
943                                } else {
944                                    (Error::Logic(status, e), false)
945                                }
946                            }
947                            Err(_) => (
948                                Error::Response {
949                                    status,
950                                    msg: String::from_utf8_lossy(&body).to_string(),
951                                },
952                                false,
953                            ),
954                        }
955                    }
956                }
957                Err(err) => (
958                    Error::Request(err.to_string()),
959                    err.is_timeout() || err.is_connect(),
960                ),
961            };
962            if !retry {
963                return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
964            }
965            match &err {
966                Error::AuthFailure(_) => {
967                    if refreshed {
968                        retries = 0;
969                    } else if retries == 2 {
970                        return Err(err.with_context(&format!(
971                            "{} {} after 3 reties",
972                            request.method(),
973                            request.url()
974                        )));
975                    }
976                }
977                _ => {
978                    if retries == 2 {
979                        return Err(err.with_context(&format!(
980                            "{} {} after 3 reties",
981                            request.method(),
982                            request.url()
983                        )));
984                    }
985                    retries += 1;
986                    info!(
987                        "will retry {} the {retries}th times on error {}",
988                        request.url(),
989                        err
990                    );
991                }
992            }
993            sleep(jitter(Duration::from_secs(10))).await;
994        }
995    }
996
997    pub async fn close(&self) {
998        if self.need_logout() {
999            let cli = self.cli.clone();
1000            let req = self
1001                .build_log_out_request()
1002                .expect("failed to build logout request");
1003            if let Err(err) = cli.execute(req).await {
1004                error!("logout request failed: {err}");
1005            } else {
1006                debug!("logout success");
1007            };
1008        }
1009    }
1010}
1011
1012impl Drop for APIClient {
1013    fn drop(&mut self) {
1014        if self.need_logout() {
1015            warn!("APIClient::close() was not called");
1016        }
1017    }
1018}
1019
1020fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1021where
1022    T: Deserialize<'a>,
1023{
1024    serde_json::from_slice::<T>(body).map_err(|e| {
1025        Error::Decode(format!(
1026            "fail to decode JSON response: {e}, body: {}",
1027            String::from_utf8_lossy(body)
1028        ))
1029    })
1030}
1031
1032impl Default for APIClient {
1033    fn default() -> Self {
1034        Self {
1035            cli: HttpClient::new(),
1036            scheme: "http".to_string(),
1037            endpoint: Url::parse("http://localhost:8080").unwrap(),
1038            host: "localhost".to_string(),
1039            port: 8000,
1040            tenant: None,
1041            warehouse: Mutex::new(None),
1042            auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1043            session_state: Mutex::new(SessionState::default()),
1044            wait_time_secs: None,
1045            max_rows_in_buffer: None,
1046            max_rows_per_page: None,
1047            connect_timeout: Duration::from_secs(10),
1048            page_request_timeout: Duration::from_secs(30),
1049            tls_ca_file: None,
1050            presign: Mutex::new(PresignMode::Auto),
1051            route_hint: RouteHintGenerator::new(),
1052            last_node_id: Default::default(),
1053            disable_session_token: true,
1054            disable_login: false,
1055            session_token_info: None,
1056            closed: Default::default(),
1057            last_query_id: Default::default(),
1058            server_version: None,
1059            capability: Default::default(),
1060        }
1061    }
1062}
1063
/// Generates and remembers route hints, strings of the form `rh:<uuid>:<nonce>`.
struct RouteHintGenerator {
    // Counter embedded in each generated hint; incremented by one per `next()` call.
    nonce: AtomicU64,
    // The hint most recently produced by `next()` or stored via `set()`.
    current: std::sync::Mutex<String>,
}
1068
1069impl RouteHintGenerator {
1070    fn new() -> Self {
1071        let gen = Self {
1072            nonce: AtomicU64::new(0),
1073            current: std::sync::Mutex::new("".to_string()),
1074        };
1075        gen.next();
1076        gen
1077    }
1078
1079    fn current(&self) -> String {
1080        let guard = self.current.lock().unwrap();
1081        guard.clone()
1082    }
1083
1084    fn set(&self, hint: &str) {
1085        let mut guard = self.current.lock().unwrap();
1086        *guard = hint.to_string();
1087    }
1088
1089    fn next(&self) -> String {
1090        let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1091        let uuid = uuid::Uuid::new_v4();
1092        let current = format!("rh:{uuid}:{nonce:06}");
1093        let mut guard = self.current.lock().unwrap();
1094        guard.clone_from(&current);
1095        current
1096    }
1097}
1098
#[cfg(test)]
mod test {
    use super::*;

    // DSN query parameters should populate the matching client settings; with
    // `sslmode=disable` and no explicit port the endpoint is expected to be plain http on :80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);
        assert_eq!(
            *client.warehouse.try_lock().unwrap(),
            Some("wh".to_string())
        );
        Ok(())
    }

    // A percent-encoded password must not disturb host parsing. With neither a port nor
    // `sslmode` in the DSN, the port is expected to default to 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    // An unencoded password containing '@', '(', '=' and '{' is still expected to yield
    // host `localhost` and the explicit port 8000.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}