1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, RequestKind, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48 Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// Header names exchanged with the Databend HTTP handler.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";

// Session transaction state that marks an open transaction
// (compared ASCII case-insensitively by the client).
const TXN_STATE_ACTIVE: &str = "Active";

// `Content-Type` of an Arrow IPC stream response body.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// `Accept` value sent when Arrow results are wanted.
// NOTE(review): despite the name this lists only the Arrow media type, although
// JSON bodies are still handled on receipt — confirm against server negotiation.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
77
78static VERSION: Lazy<String> = Lazy::new(|| {
79 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
80 version.to_string()
81});
82
83#[derive(Clone)]
84pub(crate) struct QueryState {
85 pub node_id: String,
86 pub last_access_time: Arc<Mutex<Instant>>,
87 pub timeout_secs: u64,
88}
89
90impl QueryState {
91 pub fn need_heartbeat(&self, now: Instant) -> bool {
92 let t = *self.last_access_time.lock();
93 now.duration_since(t).as_secs() > self.timeout_secs / 2
94 }
95}
96
/// A buffered HTTP response: status, headers and the whole body as `Bytes`.
struct HttpResponseData {
    status: StatusCode,
    headers: HeaderMap,
    body: Bytes,
}
102
/// HTTP client for a Databend server, configured from a DSN (see `from_dsn`).
pub struct APIClient {
    /// Session id from the login response header, or a generated
    /// `no_login_<uuid>` placeholder when none was obtained.
    pub(crate) session_id: String,
    /// Underlying reqwest client, built by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, selected by the `sslmode` DSN parameter.
    scheme: String,
    host: String,
    /// Port from the DSN, or 80/443 depending on the scheme.
    port: u16,

    /// Base URL (`scheme://host:port`) that request paths are joined onto.
    endpoint: Url,

    /// Credentials applied to requests when no session token is in use.
    auth: Arc<dyn Auth>,

    /// Sent as `X-DATABEND-TENANT` when set.
    tenant: Option<String>,
    /// Sent as `X-DATABEND-WAREHOUSE`; also updated from server session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with requests and replaced from server responses.
    session_state: Mutex<SessionState>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header value.
    route_hint: RouteHintGenerator,

    /// When true the login request is skipped (`login=disable`).
    disable_login: bool,
    /// `"json"` or `"arrow"`; Arrow is only requested if the server supports it.
    query_result_format: String,
    /// Forwarded to login as `disable_session_token` (`session_token=disable`).
    disable_session_token: bool,
    /// Session token issued at login, with the instant it was received.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // NOTE(review): presumably marks the client as closed — confirm against close logic.
    closed: AtomicBool,

    /// Version reported by the server at login.
    server_version: Option<Version>,

    // Optional pagination settings from the DSN (see `make_pagination`).
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    connect_timeout: Duration,
    /// Per-request timeout applied to query start and page requests.
    page_request_timeout: Duration,

    /// PEM file with an extra root certificate trusted for `https` connections.
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto`/`Detect` are resolved to `On`/`Off` during `new`.
    presign: Mutex<PresignMode>,
    /// Node that served the last query (used for sticky routing).
    last_node_id: Mutex<Option<String>>,
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version.
    capability: Capability,

    /// Queries to keep alive with heartbeats, keyed by query id.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    // NOTE(review): presumably consumed by the request retry helper (not shown) — confirm.
    retry_count: u32,
    retry_delay_secs: u64,
}
148
// Dropping the client calls `close_with_spawn` (defined outside this view);
// presumably it schedules session cleanup on a background task — confirm.
impl Drop for APIClient {
    fn drop(&mut self) {
        self.close_with_spawn()
    }
}
154
155impl APIClient {
156 fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
157 while let Some(err) = source {
158 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
159 match io_err.kind() {
160 ErrorKind::TimedOut
161 | ErrorKind::UnexpectedEof
162 | ErrorKind::ConnectionReset
163 | ErrorKind::ConnectionAborted
164 | ErrorKind::BrokenPipe
165 | ErrorKind::NotConnected
166 | ErrorKind::WouldBlock
167 | ErrorKind::Interrupted => return true,
168 _ => {}
169 }
170 }
171 source = err.source();
172 }
173 false
174 }
175
176 fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
177 if err.is_timeout() {
178 Some("request timeout")
179 } else if err.is_connect() {
180 Some("connection error")
181 } else if err.is_request() {
182 Some("request error")
183 } else if err.is_body() || err.is_decode() {
184 if Self::has_transient_io_source(err.source()) {
185 Some("response read transient I/O error")
186 } else {
187 None
188 }
189 } else {
190 None
191 }
192 }
193
194 fn request_query_id(request: &Request) -> Option<String> {
195 request
196 .headers()
197 .get(HEADER_QUERY_ID)
198 .and_then(|v| v.to_str().ok())
199 .map(|v| v.to_string())
200 }
201
202 fn with_request_context(
203 reqwest_error: ReqwestError,
204 request_kind: &RequestKind,
205 request: &Request,
206 retry_times: Option<u32>,
207 ) -> Error {
208 let error: Error = reqwest_error.into();
209 let error = error.with_context(request_kind.clone());
210 let error = if let Some(retry_times) = retry_times {
211 error.with_retry_times(retry_times)
212 } else {
213 error
214 };
215 if let Some(query_id) = Self::request_query_id(request) {
216 error.with_query_id(query_id)
217 } else {
218 error
219 }
220 }
221
222 fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
223 match serde_json::from_slice::<ResponseWithErrorCode>(body) {
224 Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
225 Ok(resp) => Error::Logic(status, resp.error),
226 Err(_) => Error::response_error(status, body),
227 }
228 }
229
230 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
231 let mut client = Self::from_dsn(dsn).await?;
232 client.build_client(name).await?;
233 if !client.disable_login {
234 client.login().await?;
235 }
236 if client.session_id.is_empty() {
237 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
238 }
239 let client = Arc::new(client);
240 client.check_presign().await?;
241 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
242 Ok(client)
243 }
244
245 pub fn capability(&self) -> &Capability {
246 &self.capability
247 }
248
249 fn set_presign_mode(&self, mode: PresignMode) {
250 *self.presign.lock() = mode
251 }
252 fn get_presign_mode(&self) -> PresignMode {
253 *self.presign.lock()
254 }
255
256 async fn from_dsn(dsn: &str) -> Result<Self> {
257 let u = Url::parse(dsn)?;
258 let mut client = Self::default();
259 if let Some(host) = u.host_str() {
260 client.host = host.to_string();
261 }
262
263 if u.username() != "" {
264 let password = u.password().unwrap_or_default();
265 let password = percent_decode_str(password).decode_utf8()?;
266 client.auth = Arc::new(BasicAuth::new(u.username(), password));
267 }
268
269 let mut session_state = SessionState::default();
270
271 let database = u.path().trim_start_matches('/');
272 if !database.is_empty() {
273 session_state.set_database(database);
274 }
275
276 let mut scheme = "https";
277 for (k, v) in u.query_pairs() {
278 match k.as_ref() {
279 "wait_time_secs" => {
280 client.wait_time_secs = Some(v.parse()?);
281 }
282 "max_rows_in_buffer" => {
283 client.max_rows_in_buffer = Some(v.parse()?);
284 }
285 "max_rows_per_page" => {
286 client.max_rows_per_page = Some(v.parse()?);
287 }
288 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
289 "page_request_timeout_secs" => {
290 client.page_request_timeout = {
291 let secs: u64 = v.parse()?;
292 Duration::from_secs(secs)
293 };
294 }
295 "presign" => {
296 let presign_mode = match v.as_ref() {
297 "auto" => PresignMode::Auto,
298 "detect" => PresignMode::Detect,
299 "on" => PresignMode::On,
300 "off" => PresignMode::Off,
301 _ => {
302 return Err(Error::BadArgument(format!(
303 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
304 )))
305 }
306 };
307 client.set_presign_mode(presign_mode);
308 }
309 "tenant" => {
310 client.tenant = Some(v.to_string());
311 }
312 "warehouse" => {
313 client.warehouse = Mutex::new(Some(v.to_string()));
314 }
315 "role" => session_state.set_role(v),
316 "sslmode" => match v.as_ref() {
317 "disable" => scheme = "http",
318 "require" | "enable" => scheme = "https",
319 _ => {
320 return Err(Error::BadArgument(format!(
321 "Invalid value for sslmode: {v}"
322 )))
323 }
324 },
325 "tls_ca_file" => {
326 client.tls_ca_file = Some(v.to_string());
327 }
328 "access_token" => {
329 client.auth = Arc::new(AccessTokenAuth::new(v));
330 }
331 "access_token_file" => {
332 client.auth = Arc::new(AccessTokenFileAuth::new(v));
333 }
334 "login" => {
335 client.disable_login = match v.as_ref() {
336 "disable" => true,
337 "enable" => false,
338 _ => {
339 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
340 }
341 }
342 }
343 "session_token" => {
344 client.disable_session_token = match v.as_ref() {
345 "disable" => true,
346 "enable" => false,
347 _ => {
348 return Err(Error::BadArgument(format!(
349 "Invalid value for session_token: {v}"
350 )))
351 }
352 }
353 }
354 "body_format" | "query_result_format" => {
355 let v = v.to_string().to_lowercase();
356 match v.as_str() {
357 "json" | "arrow" => client.query_result_format = v.to_string(),
358 _ => {
359 return Err(Error::BadArgument(format!(
360 "Invalid value for query_result_format: {v}"
361 )))
362 }
363 }
364 }
365 "retry_count" => {
366 client.retry_count = v.parse()?;
367 }
368 "retry_delay_secs" => {
369 client.retry_delay_secs = v.parse()?;
370 }
371 _ => {
372 session_state.set(k, v);
373 }
374 }
375 }
376 client.port = match u.port() {
377 Some(p) => p,
378 None => match scheme {
379 "http" => 80,
380 "https" => 443,
381 _ => unreachable!(),
382 },
383 };
384 client.scheme = scheme.to_string();
385 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
386 client.session_state = Mutex::new(session_state);
387
388 Ok(client)
389 }
390
391 pub fn host(&self) -> &str {
392 self.host.as_str()
393 }
394
395 pub fn port(&self) -> u16 {
396 self.port
397 }
398
399 pub fn scheme(&self) -> &str {
400 self.scheme.as_str()
401 }
402
403 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
404 let ua = match name {
405 Some(n) => n,
406 None => format!("databend-client-rust/{}", VERSION.as_str()),
407 };
408 let cookie_provider = GlobalCookieStore::new();
409 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
410 let mut initial_cookies = [&cookie].into_iter();
411 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
412 let mut cli_builder = HttpClient::builder()
413 .user_agent(ua)
414 .connect_timeout(self.connect_timeout)
415 .cookie_provider(Arc::new(cookie_provider))
416 .pool_idle_timeout(Duration::from_secs(1));
417 #[cfg(any(feature = "rustls", feature = "native-tls"))]
418 if self.scheme == "https" {
419 if let Some(ref ca_file) = self.tls_ca_file {
420 let cert_pem = tokio::fs::read(ca_file).await?;
421 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
422 cli_builder = cli_builder.add_root_certificate(cert);
423 }
424 }
425 self.cli = cli_builder.build()?;
426 Ok(())
427 }
428
429 async fn check_presign(self: &Arc<Self>) -> Result<()> {
430 let mode = match self.get_presign_mode() {
431 PresignMode::Auto => {
432 if self.host.ends_with(".databend.com")
433 || self.host.ends_with(".databend.cn")
434 || self.host.ends_with(".tidbcloud.com")
435 {
436 PresignMode::On
437 } else {
438 PresignMode::Off
439 }
440 }
441 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
442 Ok(_) => PresignMode::On,
443 Err(e) => {
444 warn!("presign mode off with error detected: {e}");
445 PresignMode::Off
446 }
447 },
448 mode => mode,
449 };
450 self.set_presign_mode(mode);
451 Ok(())
452 }
453
454 pub fn current_warehouse(&self) -> Option<String> {
455 let guard = self.warehouse.lock();
456 guard.clone()
457 }
458
459 pub fn current_catalog(&self) -> Option<String> {
460 let guard = self.session_state.lock();
461 guard.catalog.clone()
462 }
463
464 pub fn current_database(&self) -> Option<String> {
465 let guard = self.session_state.lock();
466 guard.database.clone()
467 }
468
469 pub fn set_warehouse(&self, warehouse: impl Into<String>) {
470 let mut guard = self.warehouse.lock();
471 *guard = Some(warehouse.into());
472 }
473
474 pub fn set_database(&self, database: impl Into<String>) {
475 let mut guard = self.session_state.lock();
476 guard.set_database(database);
477 }
478
479 pub fn set_role(&self, role: impl Into<String>) {
480 let mut guard = self.session_state.lock();
481 guard.set_role(role);
482 }
483
484 pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
485 let mut guard = self.session_state.lock();
486 guard.set(key, value);
487 }
488
489 pub async fn current_role(&self) -> Option<String> {
490 let guard = self.session_state.lock();
491 guard.role.clone()
492 }
493
494 fn in_active_transaction(&self) -> bool {
495 let guard = self.session_state.lock();
496 guard
497 .txn_state
498 .as_ref()
499 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
500 .unwrap_or(false)
501 }
502
503 pub fn username(&self) -> String {
504 self.auth.username()
505 }
506
507 fn gen_query_id(&self) -> String {
508 uuid::Uuid::now_v7().simple().to_string()
509 }
510
511 async fn handle_session(&self, session: &Option<SessionState>) {
512 let session = match session {
513 Some(session) => session,
514 None => return,
515 };
516
517 {
519 let mut session_state = self.session_state.lock();
520 *session_state = session.clone();
521 }
522
523 if let Some(settings) = session.settings.as_ref() {
525 if let Some(v) = settings.get("warehouse") {
526 let mut warehouse = self.warehouse.lock();
527 *warehouse = Some(v.clone());
528 }
529 }
530 }
531
532 pub fn set_last_node_id(&self, node_id: String) {
533 *self.last_node_id.lock() = Some(node_id)
534 }
535
536 pub fn set_last_query_id(&self, query_id: Option<String>) {
537 *self.last_query_id.lock() = query_id
538 }
539
540 pub fn last_query_id(&self) -> Option<String> {
541 self.last_query_id.lock().clone()
542 }
543
544 fn last_node_id(&self) -> Option<String> {
545 self.last_node_id.lock().clone()
546 }
547
548 fn handle_warnings(&self, resp: &QueryResponse) {
549 if let Some(warnings) = &resp.warnings {
550 for w in warnings {
551 warn!(target: "server_warnings", "server warning: {w}");
552 }
553 }
554 }
555
556 pub async fn start_query(
557 self: &Arc<Self>,
558 sql: &str,
559 need_progress: bool,
560 params: Option<serde_json::Value>,
561 ) -> Result<Pages> {
562 info!("start query: {sql}");
563 let (resp, batches) = self.start_query_inner(sql, None, false, params).await?;
564 Pages::new(self.clone(), resp, batches, need_progress)
565 }
566
567 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
568 let mut mgr = self.queries_need_heartbeat.lock();
569 if let Some(state) = mgr.remove(query_id) {
570 let self_cloned = self.clone();
571 let query_id = query_id.to_owned();
572 GLOBAL_RUNTIME.spawn(async move {
573 if let Err(e) = self_cloned
574 .end_query(&query_id, false, Some(state.node_id.as_str()))
575 .await
576 {
577 error!("failed to final query {query_id}: {e}");
578 }
579 });
580 }
581 }
582
583 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
584 if let Some(info) = &self.session_token_info {
585 let info = info.lock();
586 Ok(builder.bearer_auth(info.0.session_token.clone()))
587 } else {
588 self.auth.wrap(builder)
589 }
590 }
591
592 async fn start_query_inner(
593 &self,
594 sql: &str,
595 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
596 force_json_body: bool,
597 params: Option<serde_json::Value>,
598 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
599 if !self.in_active_transaction() {
600 self.route_hint.next();
601 }
602 let endpoint = self.endpoint.join("v1/query")?;
603
604 let session_state = self.session_state();
606 let need_sticky = session_state.need_sticky.unwrap_or(false);
607 let mut req = QueryRequest::new(sql)
608 .with_pagination(self.make_pagination())
609 .with_session(Some(session_state))
610 .with_stage_attachment(stage_attachment_config)
611 .with_params(params);
612
613 let query_id = self.gen_query_id();
615 let mut headers = self.make_headers(Some(&query_id))?;
616 if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
617 debug!("accept arrow data");
618 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
619 req = req.with_arrow();
620 }
621
622 if need_sticky {
623 if let Some(node_id) = self.last_node_id() {
624 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
625 }
626 }
627 let mut builder = self.cli.post(endpoint.clone()).json(&req);
628 builder = self.wrap_auth_or_session_token(builder)?;
629 let request = builder
630 .headers(headers.clone())
631 .timeout(self.page_request_timeout)
632 .build()?;
633 let response = self
634 .query_request_helper(request, true, true, true, RequestKind::QueryStart)
635 .await?;
636 self.handle_page(response, true).await
637 }
638
639 fn is_arrow_data(headers: &HeaderMap) -> bool {
640 if let Some(typ) = headers.get(CONTENT_TYPE) {
641 if let Ok(t) = typ.to_str() {
642 return t == CONTENT_TYPE_ARROW;
643 }
644 }
645 false
646 }
647
648 async fn handle_page(
649 &self,
650 response: HttpResponseData,
651 is_first: bool,
652 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
653 let status = response.status;
654 if status != StatusCode::OK {
655 let error = Self::status_body_to_error(status, &response.body);
656 if !is_first {
657 if let Error::Logic(code, ec) = &error {
658 if *code == StatusCode::NOT_FOUND {
659 return Err(Error::QueryNotFound(ec.message.clone()));
660 }
661 }
662 }
663 return Err(error);
664 }
665 let is_arrow_data = Self::is_arrow_data(&response.headers);
666 if is_first {
667 if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
668 self.route_hint.set(route_hint.to_str().unwrap_or_default());
669 }
670 }
671 let mut body = response.body;
672 let mut batches = vec![];
673 if is_arrow_data {
674 if is_first {
675 debug!("received arrow data");
676 }
677 let cursor = std::io::Cursor::new(body.as_ref());
678 let reader = StreamReader::try_new(cursor, None)
679 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
680 let schema = reader.schema();
681 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
682 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
683 } else {
684 return Err(Error::Decode(
685 "missing response_header metadata in arrow payload".to_string(),
686 ));
687 };
688 for batch in reader {
689 let batch = batch
690 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
691 batches.push(batch);
692 }
693 body = json_body
694 };
695 let resp: QueryResponse = json_from_slice(&body)?;
696 self.handle_session(&resp.session).await;
697 if let Some(err) = &resp.error {
698 return Err(Error::QueryFailed(err.clone()));
699 }
700 if is_first {
701 self.handle_warnings(&resp);
702 self.set_last_query_id(Some(resp.id.clone()));
703 if let Some(node_id) = &resp.node_id {
704 self.set_last_node_id(node_id.clone());
705 }
706 }
707 Ok((resp, batches))
708 }
709
710 pub async fn query_page(
711 &self,
712 query_id: &str,
713 next_uri: &str,
714 node_id: &Option<String>,
715 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
716 info!("query page: {next_uri}");
717 let endpoint = self.endpoint.join(next_uri)?;
718 let mut headers = self.make_headers(Some(query_id))?;
719 if self.capability.arrow_data && self.query_result_format == "arrow" {
720 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
721 }
722 let mut builder = self.cli.get(endpoint.clone());
723 builder = self
724 .wrap_auth_or_session_token(builder)?
725 .headers(headers.clone())
726 .timeout(self.page_request_timeout);
727 if let Some(node_id) = node_id {
728 builder = builder.header(HEADER_STICKY_NODE, node_id)
729 }
730 let request = builder.build()?;
731
732 let response = self
733 .query_request_helper(request, false, true, true, RequestKind::QueryPage)
734 .await?;
735 self.handle_page(response, false).await
736 }
737
738 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
739 self.end_query(query_id, true, None).await
740 }
741
742 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
743 self.end_query(query_id, false, node_id).await
744 }
745
746 pub async fn end_query(
747 &self,
748 query_id: &str,
749 is_kill: bool,
750 node_id: Option<&str>,
751 ) -> Result<()> {
752 let (method, request_kind) = if is_kill {
753 ("kill", RequestKind::QueryKill)
754 } else {
755 ("final", RequestKind::QueryFinal)
756 };
757 let uri = format!("/v1/query/{query_id}/{method}");
758 let endpoint = self.endpoint.join(&uri)?;
759 let headers = self.make_headers(Some(query_id))?;
760
761 info!("{method} query: {uri}");
762
763 let mut builder = self.cli.post(endpoint.clone());
764 if let Some(node_id) = node_id {
765 builder = builder.header(HEADER_STICKY_NODE, node_id)
766 }
767 builder = self.wrap_auth_or_session_token(builder)?;
768 let resp = builder.headers(headers.clone()).send().await?;
769 if resp.status() != 200 {
770 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
771 .with_context(request_kind)
772 .with_query_id(query_id));
773 }
774 Ok(())
775 }
776
777 pub async fn query_all(
778 self: &Arc<Self>,
779 sql: &str,
780 params: Option<serde_json::Value>,
781 ) -> Result<Page> {
782 self.query_all_inner(sql, false, params).await
783 }
784
785 pub async fn query_all_inner(
786 self: &Arc<Self>,
787 sql: &str,
788 force_json_body: bool,
789 params: Option<serde_json::Value>,
790 ) -> Result<Page> {
791 let (resp, batches) = self
792 .start_query_inner(sql, None, force_json_body, params)
793 .await?;
794 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
795 let mut all = Page::default();
796 while let Some(page) = pages.next().await {
797 all.update(page?);
798 }
799 Ok(all)
800 }
801
802 fn session_state(&self) -> SessionState {
803 self.session_state.lock().clone()
804 }
805
806 fn make_pagination(&self) -> Option<PaginationConfig> {
807 if self.wait_time_secs.is_none()
808 && self.max_rows_in_buffer.is_none()
809 && self.max_rows_per_page.is_none()
810 {
811 return None;
812 }
813 let mut pagination = PaginationConfig {
814 wait_time_secs: None,
815 max_rows_in_buffer: None,
816 max_rows_per_page: None,
817 };
818 if let Some(wait_time_secs) = self.wait_time_secs {
819 pagination.wait_time_secs = Some(wait_time_secs);
820 }
821 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
822 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
823 }
824 if let Some(max_rows_per_page) = self.max_rows_per_page {
825 pagination.max_rows_per_page = Some(max_rows_per_page);
826 }
827 Some(pagination)
828 }
829
830 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
831 let mut headers = HeaderMap::new();
832 if let Some(tenant) = &self.tenant {
833 headers.insert(HEADER_TENANT, tenant.parse()?);
834 }
835 let warehouse = self.warehouse.lock().clone();
836 if let Some(warehouse) = warehouse {
837 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
838 }
839 let route_hint = self.route_hint.current();
840 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
841 if let Some(query_id) = query_id {
842 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
843 }
844 Ok(headers)
845 }
846
847 pub async fn insert_with_stage(
848 self: &Arc<Self>,
849 sql: &str,
850 stage: &str,
851 file_format_options: BTreeMap<&str, &str>,
852 copy_options: BTreeMap<&str, &str>,
853 ) -> Result<QueryStats> {
854 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
855 let stage_attachment = Some(StageAttachmentConfig {
856 location: stage,
857 file_format_options: Some(file_format_options),
858 copy_options: Some(copy_options),
859 });
860 let (resp, batches) = self
861 .start_query_inner(sql, stage_attachment, true, None)
862 .await?;
863 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
864 let mut all = Page::default();
865 while let Some(page) = pages.next().await {
866 all.update(page?);
867 }
868 Ok(all.stats)
869 }
870
871 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
872 info!("get presigned upload url: {stage}");
873 let sql = format!("PRESIGN UPLOAD {stage}");
874 let resp = self.query_all_inner(&sql, true, None).await?;
875 if resp.data.len() != 1 {
876 return Err(Error::Decode(
877 "Empty response from server for presigned request".to_string(),
878 ));
879 }
880 if resp.data[0].len() != 3 {
881 return Err(Error::Decode(
882 "Invalid response from server for presigned request".to_string(),
883 ));
884 }
885 let method = resp.data[0][0].clone().unwrap_or_default();
887 if method != "PUT" {
888 return Err(Error::Decode(format!(
889 "Invalid method for presigned upload request: {method}"
890 )));
891 }
892 let headers: BTreeMap<String, String> =
893 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
894 let url = resp.data[0][2].clone().unwrap_or_default();
895 Ok(PresignedResponse {
896 method,
897 headers,
898 url,
899 })
900 }
901
902 pub async fn upload_to_stage(
903 self: &Arc<Self>,
904 stage: &str,
905 data: Reader,
906 size: u64,
907 ) -> Result<()> {
908 match self.get_presign_mode() {
909 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
910 PresignMode::On => {
911 let presigned = self.get_presigned_upload_url(stage).await?;
912 presign_upload_to_stage(presigned, data, size).await
913 }
914 PresignMode::Auto => {
915 unreachable!("PresignMode::Auto should be handled during client initialization")
916 }
917 PresignMode::Detect => {
918 unreachable!("PresignMode::Detect should be handled during client initialization")
919 }
920 }
921 }
922
923 async fn upload_to_stage_with_stream(
925 &self,
926 stage: &str,
927 data: Reader,
928 size: u64,
929 ) -> Result<()> {
930 info!("upload to stage with stream: {stage}, size: {size}");
931 if let Some(info) = self.need_pre_refresh_session().await {
932 self.refresh_session_token(info).await?;
933 }
934 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
935 let location = StageLocation::try_from(stage)?;
936 let query_id = self.gen_query_id();
937 let mut headers = self.make_headers(Some(&query_id))?;
938 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
939 let stream = Body::wrap_stream(ReaderStream::new(data));
940 let part = Part::stream_with_length(stream, size).file_name(location.path);
941 let form = Form::new().part("upload", part);
942 let mut builder = self.cli.put(endpoint.clone());
943 builder = self.wrap_auth_or_session_token(builder)?;
944 let resp = builder.headers(headers).multipart(form).send().await?;
945 let status = resp.status();
946 if status != 200 {
947 return Err(Error::response_error(status, &resp.bytes().await?)
948 .with_context(RequestKind::UploadToStage)
949 .with_query_id(query_id));
950 }
951 Ok(())
952 }
953
954 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
957 where
958 T: de::DeserializeOwned,
959 {
960 if value.starts_with("{") {
961 serde_json::from_slice(value.as_bytes())
962 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
963 } else {
964 let json = URL_SAFE.decode(value).map_err(|e| {
965 format!(
966 "Invalid value {} for {key}, base64 decode error: {}",
967 value, e
968 )
969 })?;
970 serde_json::from_slice(&json).map_err(|e| {
971 format!(
972 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
973 String::from_utf8_lossy(&json)
974 )
975 })
976 }
977 }
978
979 pub async fn streaming_load(
980 &self,
981 sql: &str,
982 data: Reader,
983 file_name: &str,
984 ) -> Result<LoadResponse> {
985 let body = Body::wrap_stream(ReaderStream::new(data));
986 let part = Part::stream(body).file_name(file_name.to_string());
987 let endpoint = self.endpoint.join("v1/streaming_load")?;
988 let mut builder = self.cli.put(endpoint.clone());
989 builder = self.wrap_auth_or_session_token(builder)?;
990 let query_id = self.gen_query_id();
991 let mut headers = self.make_headers(Some(&query_id))?;
992 headers.insert(HEADER_SQL, sql.parse()?);
993 let session = serde_json::to_string(&*self.session_state.lock())
994 .expect("serialize session state should not fail");
995 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
996 let form = Form::new().part("upload", part);
997 let resp = builder.headers(headers).multipart(form).send().await?;
998 let status = resp.status();
999 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
1000 match Self::decode_json_header::<SessionState>(
1001 HEADER_QUERY_CONTEXT,
1002 value.to_str().unwrap(),
1003 ) {
1004 Ok(session) => *self.session_state.lock() = session,
1005 Err(e) => {
1006 error!("Error decoding session state when streaming load: {e}");
1007 }
1008 }
1009 };
1010 if status != 200 {
1011 return Err(Error::response_error(status, &resp.bytes().await?)
1012 .with_context(RequestKind::StreamingLoad)
1013 .with_query_id(query_id));
1014 }
1015 let resp = resp.json::<LoadResponse>().await?;
1016 Ok(resp)
1017 }
1018
    /// Logs in with `POST /v1/session/login`, sending the current session state.
    ///
    /// On success this records:
    /// - the session id, from the `X-DATABEND-SESSION-ID` response header;
    /// - the server version and the capabilities derived from it;
    /// - the session token, when the server issued one.
    ///
    /// A 404 is treated as an older server without the login endpoint and is
    /// not an error.
    ///
    /// # Errors
    /// Returns an error in these cases:
    /// - the request fails;
    /// - the status is neither 200 nor 404;
    /// - the body carries an error payload (`AuthFailure`);
    /// - the server version cannot be parsed.
    async fn login(&mut self) -> Result<()> {
        let endpoint = self.endpoint.join("/v1/session/login")?;
        let headers = self.make_headers(None)?;
        let body = LoginRequest::from(&*self.session_state.lock());
        let mut builder = self.cli.post(endpoint.clone()).json(&body);
        if self.disable_session_token {
            // Presumably asks the server not to issue a session token.
            builder = builder.query(&[("disable_session_token", true)]);
        }
        // Login always authenticates with the configured credentials.
        let builder = self.auth.wrap(builder)?;
        let request = builder
            .headers(headers.clone())
            // Allow 10 seconds on top of the connect timeout for the login round trip.
            .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
            .build()?;
        let response = self
            .query_request_helper(request, true, false, true, RequestKind::Login)
            .await?;
        if response.status == StatusCode::NOT_FOUND {
            info!("login return 404, skip login on the old version server");
            return Ok(());
        }
        if response.status != StatusCode::OK {
            return Err(Self::status_body_to_error(response.status, &response.body)
                .with_context(RequestKind::Login));
        }
        // The session id is taken from the header before the body is checked.
        if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
            if let Ok(s) = v.to_str() {
                self.session_id = s.to_string();
            }
        }

        let response = json_from_slice(&response.body)?;
        match response {
            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
            LoginResponseResult::Ok(info) => {
                let server_version = info
                    .version
                    .parse()
                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
                self.capability = Capability::from_server_version(&server_version);
                self.server_version = Some(server_version.clone());
                let session_id = self.session_id.as_str();
                if let Some(tokens) = info.tokens {
                    info!(
                        "[session {session_id}] login success with session token version = {server_version}",
                    );
                    // Store the token together with the time it was received.
                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
                } else {
                    info!("[session {session_id}] login success, version = {server_version}");
                }
            }
        }
        Ok(())
    }
1072
1073 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1074 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1075 let queries = self.queries_need_heartbeat.lock().clone();
1076 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1077 let now = Instant::now();
1078
1079 let mut query_ids = Vec::new();
1080 for (qid, state) in queries {
1081 if state.need_heartbeat(now) {
1082 query_ids.push(qid.to_string());
1083 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1084 arr.push(qid);
1085 } else {
1086 node_to_queries.insert(state.node_id, vec![qid]);
1087 }
1088 }
1089 }
1090
1091 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1092 {
1093 return Ok(());
1094 }
1095
1096 let body = json!({
1097 "node_to_queries": node_to_queries
1098 });
1099 let headers = self.make_headers(None)?;
1100 let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1101 let request = self.wrap_auth_or_session_token(builder)?.build()?;
1102 let response = self
1103 .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1104 .await?;
1105 if response.status != StatusCode::OK {
1106 return Err(Self::status_body_to_error(response.status, &response.body)
1107 .with_context(RequestKind::Heartbeat));
1108 }
1109 let json: Value = json_from_slice(&response.body)?;
1110 let session_id = self.session_id.as_str();
1111 info!("[session {session_id}] heartbeat request={body}, response={json}");
1112 if let Some(queries_to_remove) = json.get("queries_to_remove") {
1113 if let Some(arr) = queries_to_remove.as_array() {
1114 if !arr.is_empty() {
1115 let mut queries = self.queries_need_heartbeat.lock();
1116 for q in arr {
1117 if let Some(q) = q.as_str() {
1118 queries.remove(q);
1119 }
1120 }
1121 }
1122 }
1123 }
1124 let now = Instant::now();
1125 let mut queries = self.queries_need_heartbeat.lock();
1126 for qid in query_ids {
1127 if let Some(state) = queries.get_mut(&qid) {
1128 *state.last_access_time.lock() = now;
1129 }
1130 }
1131 Ok(())
1132 }
1133
1134 fn build_log_out_request(&self) -> Result<Request> {
1135 let endpoint = self.endpoint.join("/v1/session/logout")?;
1136
1137 let session_state = self.session_state();
1138 let need_sticky = session_state.need_sticky.unwrap_or(false);
1139 let mut headers = self.make_headers(None)?;
1140 if need_sticky {
1141 if let Some(node_id) = self.last_node_id() {
1142 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1143 }
1144 }
1145 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1146
1147 let builder = self.wrap_auth_or_session_token(builder)?;
1148 let req = builder.build()?;
1149 Ok(req)
1150 }
1151
1152 pub(crate) fn need_logout(&self) -> bool {
1153 self.session_token_info.is_some()
1154 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1155 }
1156
1157 async fn refresh_session_token(
1158 &self,
1159 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1160 ) -> Result<()> {
1161 let (session_token_info, _) = { self_login_info.lock().clone() };
1162 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1163 let body = RefreshSessionTokenRequest {
1164 session_token: session_token_info.session_token.clone(),
1165 };
1166 let headers = self.make_headers(None)?;
1167 let request = self
1168 .cli
1169 .post(endpoint.clone())
1170 .json(&body)
1171 .headers(headers.clone())
1172 .bearer_auth(session_token_info.refresh_token.clone())
1173 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1174 .build()?;
1175 let response = self
1176 .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1177 .await?;
1178 if response.status != StatusCode::OK {
1179 return Err(Self::status_body_to_error(response.status, &response.body)
1180 .with_context(RequestKind::SessionRefresh));
1181 }
1182 let response = json_from_slice(&response.body)?;
1183 match response {
1184 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1185 RefreshResponse::Ok(info) => {
1186 *self_login_info.lock() = (info, Instant::now());
1187 Ok(())
1188 }
1189 }
1190 }
1191
1192 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1193 if let Some(info) = &self.session_token_info {
1194 let (start, ttl) = {
1195 let guard = info.lock();
1196 (guard.1, guard.0.session_token_ttl_in_secs)
1197 };
1198 if Instant::now() > start + Duration::from_secs(ttl) {
1199 return Some(info.clone());
1200 }
1201 }
1202 None
1203 }
1204
1205 async fn query_request_helper(
1213 &self,
1214 mut request: Request,
1215 retry_if_503: bool,
1216 refresh_if_401: bool,
1217 reload_auth_if_401: bool,
1218 request_kind: RequestKind,
1219 ) -> Result<HttpResponseData> {
1220 let mut refreshed = false;
1221 let mut retries = 0;
1222 let max_retries = self.retry_count;
1223 let retry_delay = Duration::from_secs(self.retry_delay_secs);
1224
1225 loop {
1226 let req = request.try_clone().expect("request not cloneable");
1227 let response = match self.cli.execute(req).await {
1228 Ok(response) => response,
1229 Err(err) => {
1230 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1231 if retries >= max_retries.saturating_sub(1) {
1232 return Err(Self::with_request_context(
1233 err,
1234 &request_kind,
1235 &request,
1236 Some(retries),
1237 ));
1238 }
1239 retries += 1;
1240 warn!(
1241 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1242 retries,
1243 max_retries,
1244 request.url(),
1245 reason,
1246 err,
1247 retry_delay.as_secs(),
1248 );
1249 sleep(jitter(retry_delay)).await;
1250 continue;
1251 }
1252 return Err(Self::with_request_context(
1253 err,
1254 &request_kind,
1255 &request,
1256 Some(retries),
1257 ));
1258 }
1259 };
1260
1261 let status = response.status();
1262 let headers = response.headers().clone();
1263 let body = match response.bytes().await {
1264 Ok(body) => body,
1265 Err(err) => {
1266 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1267 if retries >= max_retries.saturating_sub(1) {
1268 return Err(Self::with_request_context(
1269 err,
1270 &request_kind,
1271 &request,
1272 Some(retries),
1273 ));
1274 }
1275 retries += 1;
1276 warn!(
1277 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1278 retries,
1279 max_retries,
1280 request.url(),
1281 reason,
1282 err,
1283 retry_delay.as_secs(),
1284 );
1285 sleep(jitter(retry_delay)).await;
1286 continue;
1287 }
1288 return Err(Self::with_request_context(
1289 err,
1290 &request_kind,
1291 &request,
1292 Some(retries),
1293 ));
1294 }
1295 };
1296
1297 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1298 if retries >= max_retries.saturating_sub(1) {
1299 return Ok(HttpResponseData {
1300 status,
1301 headers,
1302 body,
1303 });
1304 }
1305 retries += 1;
1306 warn!(
1307 "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
1308 retries,
1309 max_retries,
1310 request.url(),
1311 retry_delay.as_secs()
1312 );
1313 sleep(jitter(retry_delay)).await;
1314 continue;
1315 }
1316
1317 if status == StatusCode::UNAUTHORIZED {
1318 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1319 let unauthorized_error =
1320 serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
1321 if let Some(session_token_info) = &self.session_token_info {
1322 let should_refresh = unauthorized_error
1323 .as_ref()
1324 .map(|r| need_refresh_token(r.error.code))
1325 .unwrap_or(false);
1326 if refresh_if_401 && should_refresh && !refreshed {
1327 Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
1328 refreshed = true;
1329 retries = 0;
1330 warn!(
1331 "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
1332 request.url(),
1333 retry_delay.as_secs()
1334 );
1335 sleep(jitter(retry_delay)).await;
1336 continue;
1337 }
1338 }
1339 if reload_auth_if_401 && self.auth.can_reload() {
1340 if retries >= max_retries.saturating_sub(1) {
1341 return Ok(HttpResponseData {
1342 status,
1343 headers,
1344 body,
1345 });
1346 }
1347 retries += 1;
1348 let builder =
1349 RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
1350 let builder = self.auth.wrap(builder)?;
1351 request = builder.build()?;
1352 warn!(
1353 "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
1354 retries,
1355 max_retries,
1356 request.url(),
1357 retry_delay.as_secs()
1358 );
1359 sleep(jitter(retry_delay)).await;
1360 continue;
1361 }
1362 }
1363
1364 return Ok(HttpResponseData {
1365 status,
1366 headers,
1367 body,
1368 });
1369 }
1370 }
1371
1372 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1373 if let Err(err) = http_client.execute(request).await {
1374 error!("[session {session_id}] logout request failed: {err}");
1375 } else {
1376 info!("[session {session_id}] logout success");
1377 };
1378 }
1379
1380 pub async fn close(&self) {
1381 let session_id = &self.session_id;
1382 info!("[session {session_id}] try closing now");
1383 if self
1384 .closed
1385 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1386 .is_ok()
1387 {
1388 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1389 if self.need_logout() {
1390 let cli = self.cli.clone();
1391 let req = self
1392 .build_log_out_request()
1393 .expect("failed to build logout request");
1394 Self::logout(cli, req, &self.session_id).await;
1395 }
1396 }
1397 }
1398 pub fn close_with_spawn(&self) {
1399 let session_id = &self.session_id;
1400 info!("[session {session_id}]: try closing with spawn");
1401 if self
1402 .closed
1403 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1404 .is_ok()
1405 {
1406 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1407 if self.need_logout() {
1408 let cli = self.cli.clone();
1409 let req = self
1410 .build_log_out_request()
1411 .expect("failed to build logout request");
1412 let session_id = self.session_id.clone();
1413 GLOBAL_RUNTIME.spawn(async move {
1414 Self::logout(cli, req, session_id.as_str()).await;
1415 });
1416 }
1417 }
1418 }
1419
1420 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1421 let mut queries = self.queries_need_heartbeat.lock();
1422 queries.insert(query_id.to_string(), state);
1423 }
1424}
1425
1426fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1427where
1428 T: Deserialize<'a>,
1429{
1430 serde_json::from_slice::<T>(body).map_err(|e| {
1431 Error::Decode(format!(
1432 "fail to decode JSON response: {e}, body: {}",
1433 String::from_utf8_lossy(body)
1434 ))
1435 })
1436}
1437
1438impl Default for APIClient {
1439 fn default() -> Self {
1440 Self {
1441 session_id: Default::default(),
1442 cli: HttpClient::new(),
1443 scheme: "http".to_string(),
1444 endpoint: Url::parse("http://localhost:8080").unwrap(),
1445 host: "localhost".to_string(),
1446 port: 8000,
1447 tenant: None,
1448 warehouse: Mutex::new(None),
1449 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1450 session_state: Mutex::new(SessionState::default()),
1451 wait_time_secs: None,
1452 max_rows_in_buffer: None,
1453 max_rows_per_page: None,
1454 connect_timeout: Duration::from_secs(10),
1455 page_request_timeout: Duration::from_secs(300),
1456 tls_ca_file: None,
1457 presign: Mutex::new(PresignMode::Auto),
1458 route_hint: RouteHintGenerator::new(),
1459 last_node_id: Default::default(),
1460 disable_session_token: true,
1461 disable_login: false,
1462 query_result_format: "json".to_string(),
1463 session_token_info: None,
1464 closed: AtomicBool::new(false),
1465 last_query_id: Default::default(),
1466 server_version: None,
1467 capability: Default::default(),
1468 queries_need_heartbeat: Default::default(),
1469 retry_count: 3,
1470 retry_delay_secs: 10,
1471 }
1472 }
1473}
1474
1475struct RouteHintGenerator {
1476 nonce: AtomicU64,
1477 current: std::sync::Mutex<String>,
1478}
1479
1480impl RouteHintGenerator {
1481 fn new() -> Self {
1482 let gen = Self {
1483 nonce: AtomicU64::new(0),
1484 current: std::sync::Mutex::new("".to_string()),
1485 };
1486 gen.next();
1487 gen
1488 }
1489
1490 fn current(&self) -> String {
1491 let guard = self.current.lock().unwrap();
1492 guard.clone()
1493 }
1494
1495 fn set(&self, hint: &str) {
1496 let mut guard = self.current.lock().unwrap();
1497 *guard = hint.to_string();
1498 }
1499
1500 fn next(&self) -> String {
1501 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1502 let uuid = uuid::Uuid::new_v4();
1503 let current = format!("rh:{uuid}:{nonce:06}");
1504 let mut guard = self.current.lock().unwrap();
1505 guard.clone_from(¤t);
1506 current
1507 }
1508}
1509
#[cfg(test)]
mod test {
    use super::*;

    /// DSN query parameters map onto client settings. With `sslmode=disable` and no port,
    /// the endpoint is plain HTTP on port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;

        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, expected_endpoint);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);
        assert_eq!(client.warehouse.try_lock().unwrap().as_deref(), Some("wh"));
        Ok(())
    }

    /// A percent-encoded password is accepted. Without an explicit port or `sslmode`, the
    /// client uses port 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;

        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// An unencoded password containing `@`, `(`, `=` and `{` still yields the correct host
    /// and explicit port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;

        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}