1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, RequestKind, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48 Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// Custom `X-DATABEND-*` HTTP header names sent to, or read back from, the server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";

// Value of the session's `txn_state` (compared ignoring ASCII case) that marks an open
// transaction; while it is set the route hint is not rotated between queries.
const TXN_STATE_ACTIVE: &str = "Active";

// Media type of an Arrow IPC stream response body.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// `Accept` value sent when Arrow results are preferred.
// NOTE(review): despite the name this currently equals CONTENT_TYPE_ARROW; JSON replies are
// still handled because the response Content-Type is checked before decoding.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
77
78static VERSION: Lazy<String> = Lazy::new(|| {
79 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
80 version.to_string()
81});
82
83#[derive(Clone)]
84pub(crate) struct QueryState {
85 pub node_id: String,
86 pub last_access_time: Arc<Mutex<Instant>>,
87 pub timeout_secs: u64,
88}
89
90impl QueryState {
91 pub fn need_heartbeat(&self, now: Instant) -> bool {
92 let t = *self.last_access_time.lock();
93 now.duration_since(t).as_secs() > self.timeout_secs / 2
94 }
95}
96
/// A buffered HTTP response: status, headers and the complete body bytes.
struct HttpResponseData {
    /// HTTP status code of the response.
    status: StatusCode,
    /// Response headers (e.g. `Content-Type`, route hint, session id).
    headers: HeaderMap,
    /// Entire response body.
    body: Bytes,
}
102
/// HTTP client for the Databend query API.
///
/// Holds the connection settings parsed from the DSN, the configured authentication, and
/// mutable per-session state (warehouse, session settings, presign mode, last node/query id,
/// queries needing heartbeats). Dropping the client calls `close_with_spawn`.
pub struct APIClient {
    /// Session id from the login response header, or a generated `no_login_<uuid>` value.
    pub(crate) session_id: String,
    /// Underlying reqwest client, built by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, chosen by the `sslmode` DSN option.
    scheme: String,
    host: String,
    /// Explicit DSN port, otherwise 80/443 depending on `scheme`.
    port: u16,

    /// Base URL (`scheme://host:port`) that request paths are joined onto.
    endpoint: Url,

    /// Credentials used when no session token is available.
    auth: Arc<dyn Auth>,

    /// Optional tenant, sent as `X-DATABEND-TENANT`.
    tenant: Option<String>,
    /// Current warehouse, sent as `X-DATABEND-WAREHOUSE`; updated from server session settings.
    warehouse: Mutex<Option<String>>,
    /// Latest session state; sent with queries and replaced by the server's copy.
    session_state: Mutex<SessionState>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header value.
    route_hint: RouteHintGenerator,

    /// `login=disable`: skip the `/v1/session/login` call.
    disable_login: bool,
    /// Preferred result encoding; the DSN only allows `json` or `arrow`.
    query_result_format: String,
    /// `session_token=disable`: forwarded to the login request as a query parameter.
    disable_session_token: bool,
    /// Session token from login plus the instant it was obtained; when present it is sent as
    /// a bearer token instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // NOTE(review): not read in this part of the file; presumably used by the close logic.
    closed: AtomicBool,

    /// Server version reported at login, if login happened.
    server_version: Option<Version>,

    // Optional pagination overrides from the DSN.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Connect timeout of the HTTP client (DSN `connect_timeout`, seconds).
    connect_timeout: Duration,
    /// Per-request timeout for query start and page requests.
    page_request_timeout: Duration,

    /// Optional PEM file with an extra trusted root certificate (https only).
    tls_ca_file: Option<String>,

    /// Stage-upload mode; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node that served the most recent first page; used as sticky node when required.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recently started query.
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version at login.
    capability: Capability,

    /// Running queries, keyed by query id, that may need heartbeats.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    // DSN `retry_count` / `retry_delay_secs` options.
    // NOTE(review): consumed outside this chunk, presumably by query_request_helper.
    retry_count: u32,
    retry_delay_secs: u64,
}
148
149impl Drop for APIClient {
150 fn drop(&mut self) {
151 self.close_with_spawn()
152 }
153}
154
155impl APIClient {
156 fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
157 while let Some(err) = source {
158 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
159 match io_err.kind() {
160 ErrorKind::TimedOut
161 | ErrorKind::UnexpectedEof
162 | ErrorKind::ConnectionReset
163 | ErrorKind::ConnectionAborted
164 | ErrorKind::BrokenPipe
165 | ErrorKind::NotConnected
166 | ErrorKind::WouldBlock
167 | ErrorKind::Interrupted => return true,
168 _ => {}
169 }
170 }
171 source = err.source();
172 }
173 false
174 }
175
176 fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
177 if err.is_timeout() {
178 Some("request timeout")
179 } else if err.is_connect() {
180 Some("connection error")
181 } else if err.is_request() {
182 Some("request error")
183 } else if err.is_body() || err.is_decode() {
184 if Self::has_transient_io_source(err.source()) {
185 Some("response read transient I/O error")
186 } else {
187 None
188 }
189 } else {
190 None
191 }
192 }
193
194 fn request_query_id(request: &Request) -> Option<String> {
195 request
196 .headers()
197 .get(HEADER_QUERY_ID)
198 .and_then(|v| v.to_str().ok())
199 .map(|v| v.to_string())
200 }
201
202 fn with_request_context(
203 reqwest_error: ReqwestError,
204 request_kind: &RequestKind,
205 request: &Request,
206 retry_times: Option<u32>,
207 ) -> Error {
208 let error: Error = reqwest_error.into();
209 let error = error.with_context(request_kind.clone());
210 let error = if let Some(retry_times) = retry_times {
211 error.with_retry_times(retry_times)
212 } else {
213 error
214 };
215 if let Some(query_id) = Self::request_query_id(request) {
216 error.with_query_id(query_id)
217 } else {
218 error
219 }
220 }
221
222 fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
223 match serde_json::from_slice::<ResponseWithErrorCode>(body) {
224 Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
225 Ok(resp) => Error::Logic(status, resp.error),
226 Err(_) => Error::response_error(status, body),
227 }
228 }
229
230 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
231 let mut client = Self::from_dsn(dsn).await?;
232 client.build_client(name).await?;
233 if !client.disable_login {
234 client.login().await?;
235 }
236 if client.session_id.is_empty() {
237 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
238 }
239 let client = Arc::new(client);
240 client.check_presign().await?;
241 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
242 Ok(client)
243 }
244
245 pub fn capability(&self) -> &Capability {
246 &self.capability
247 }
248
249 fn set_presign_mode(&self, mode: PresignMode) {
250 *self.presign.lock() = mode
251 }
252 fn get_presign_mode(&self) -> PresignMode {
253 *self.presign.lock()
254 }
255
256 async fn from_dsn(dsn: &str) -> Result<Self> {
257 let u = Url::parse(dsn)?;
258 let mut client = Self::default();
259 if let Some(host) = u.host_str() {
260 client.host = host.to_string();
261 }
262
263 if u.username() != "" {
264 let password = u.password().unwrap_or_default();
265 let password = percent_decode_str(password).decode_utf8()?;
266 client.auth = Arc::new(BasicAuth::new(u.username(), password));
267 }
268
269 let mut session_state = SessionState::default();
270
271 let database = u.path().trim_start_matches('/');
272 if !database.is_empty() {
273 session_state.set_database(database);
274 }
275
276 let mut scheme = "https";
277 for (k, v) in u.query_pairs() {
278 match k.as_ref() {
279 "wait_time_secs" => {
280 client.wait_time_secs = Some(v.parse()?);
281 }
282 "max_rows_in_buffer" => {
283 client.max_rows_in_buffer = Some(v.parse()?);
284 }
285 "max_rows_per_page" => {
286 client.max_rows_per_page = Some(v.parse()?);
287 }
288 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
289 "page_request_timeout_secs" => {
290 client.page_request_timeout = {
291 let secs: u64 = v.parse()?;
292 Duration::from_secs(secs)
293 };
294 }
295 "presign" => {
296 let presign_mode = match v.as_ref() {
297 "auto" => PresignMode::Auto,
298 "detect" => PresignMode::Detect,
299 "on" => PresignMode::On,
300 "off" => PresignMode::Off,
301 _ => {
302 return Err(Error::BadArgument(format!(
303 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
304 )))
305 }
306 };
307 client.set_presign_mode(presign_mode);
308 }
309 "tenant" => {
310 client.tenant = Some(v.to_string());
311 }
312 "warehouse" => {
313 client.warehouse = Mutex::new(Some(v.to_string()));
314 }
315 "role" => session_state.set_role(v),
316 "sslmode" => match v.as_ref() {
317 "disable" => scheme = "http",
318 "require" | "enable" => scheme = "https",
319 _ => {
320 return Err(Error::BadArgument(format!(
321 "Invalid value for sslmode: {v}"
322 )))
323 }
324 },
325 "tls_ca_file" => {
326 client.tls_ca_file = Some(v.to_string());
327 }
328 "access_token" => {
329 client.auth = Arc::new(AccessTokenAuth::new(v));
330 }
331 "access_token_file" => {
332 client.auth = Arc::new(AccessTokenFileAuth::new(v));
333 }
334 "login" => {
335 client.disable_login = match v.as_ref() {
336 "disable" => true,
337 "enable" => false,
338 _ => {
339 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
340 }
341 }
342 }
343 "session_token" => {
344 client.disable_session_token = match v.as_ref() {
345 "disable" => true,
346 "enable" => false,
347 _ => {
348 return Err(Error::BadArgument(format!(
349 "Invalid value for session_token: {v}"
350 )))
351 }
352 }
353 }
354 "body_format" | "query_result_format" => {
355 let v = v.to_string().to_lowercase();
356 match v.as_str() {
357 "json" | "arrow" => client.query_result_format = v.to_string(),
358 _ => {
359 return Err(Error::BadArgument(format!(
360 "Invalid value for query_result_format: {v}"
361 )))
362 }
363 }
364 }
365 "retry_count" => {
366 client.retry_count = v.parse()?;
367 }
368 "retry_delay_secs" => {
369 client.retry_delay_secs = v.parse()?;
370 }
371 _ => {
372 session_state.set(k, v);
373 }
374 }
375 }
376 client.port = match u.port() {
377 Some(p) => p,
378 None => match scheme {
379 "http" => 80,
380 "https" => 443,
381 _ => unreachable!(),
382 },
383 };
384 client.scheme = scheme.to_string();
385 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
386 client.session_state = Mutex::new(session_state);
387
388 Ok(client)
389 }
390
391 pub fn host(&self) -> &str {
392 self.host.as_str()
393 }
394
395 pub fn port(&self) -> u16 {
396 self.port
397 }
398
399 pub fn scheme(&self) -> &str {
400 self.scheme.as_str()
401 }
402
403 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
404 let ua = match name {
405 Some(n) => n,
406 None => format!("databend-client-rust/{}", VERSION.as_str()),
407 };
408 let cookie_provider = GlobalCookieStore::new();
409 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
410 let mut initial_cookies = [&cookie].into_iter();
411 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
412 let mut cli_builder = HttpClient::builder()
413 .user_agent(ua)
414 .connect_timeout(self.connect_timeout)
415 .cookie_provider(Arc::new(cookie_provider))
416 .pool_idle_timeout(Duration::from_secs(1));
417 #[cfg(any(feature = "rustls", feature = "native-tls"))]
418 if self.scheme == "https" {
419 if let Some(ref ca_file) = self.tls_ca_file {
420 let cert_pem = tokio::fs::read(ca_file).await?;
421 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
422 cli_builder = cli_builder.add_root_certificate(cert);
423 }
424 }
425 self.cli = cli_builder.build()?;
426 Ok(())
427 }
428
429 async fn check_presign(self: &Arc<Self>) -> Result<()> {
430 let mode = match self.get_presign_mode() {
431 PresignMode::Auto => {
432 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
433 PresignMode::On
434 } else {
435 PresignMode::Off
436 }
437 }
438 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
439 Ok(_) => PresignMode::On,
440 Err(e) => {
441 warn!("presign mode off with error detected: {e}");
442 PresignMode::Off
443 }
444 },
445 mode => mode,
446 };
447 self.set_presign_mode(mode);
448 Ok(())
449 }
450
451 pub fn current_warehouse(&self) -> Option<String> {
452 let guard = self.warehouse.lock();
453 guard.clone()
454 }
455
456 pub fn current_catalog(&self) -> Option<String> {
457 let guard = self.session_state.lock();
458 guard.catalog.clone()
459 }
460
461 pub fn current_database(&self) -> Option<String> {
462 let guard = self.session_state.lock();
463 guard.database.clone()
464 }
465
466 pub fn set_warehouse(&self, warehouse: impl Into<String>) {
467 let mut guard = self.warehouse.lock();
468 *guard = Some(warehouse.into());
469 }
470
471 pub fn set_database(&self, database: impl Into<String>) {
472 let mut guard = self.session_state.lock();
473 guard.set_database(database);
474 }
475
476 pub fn set_role(&self, role: impl Into<String>) {
477 let mut guard = self.session_state.lock();
478 guard.set_role(role);
479 }
480
481 pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
482 let mut guard = self.session_state.lock();
483 guard.set(key, value);
484 }
485
486 pub async fn current_role(&self) -> Option<String> {
487 let guard = self.session_state.lock();
488 guard.role.clone()
489 }
490
491 fn in_active_transaction(&self) -> bool {
492 let guard = self.session_state.lock();
493 guard
494 .txn_state
495 .as_ref()
496 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
497 .unwrap_or(false)
498 }
499
500 pub fn username(&self) -> String {
501 self.auth.username()
502 }
503
504 fn gen_query_id(&self) -> String {
505 uuid::Uuid::now_v7().simple().to_string()
506 }
507
508 async fn handle_session(&self, session: &Option<SessionState>) {
509 let session = match session {
510 Some(session) => session,
511 None => return,
512 };
513
514 {
516 let mut session_state = self.session_state.lock();
517 *session_state = session.clone();
518 }
519
520 if let Some(settings) = session.settings.as_ref() {
522 if let Some(v) = settings.get("warehouse") {
523 let mut warehouse = self.warehouse.lock();
524 *warehouse = Some(v.clone());
525 }
526 }
527 }
528
529 pub fn set_last_node_id(&self, node_id: String) {
530 *self.last_node_id.lock() = Some(node_id)
531 }
532
533 pub fn set_last_query_id(&self, query_id: Option<String>) {
534 *self.last_query_id.lock() = query_id
535 }
536
537 pub fn last_query_id(&self) -> Option<String> {
538 self.last_query_id.lock().clone()
539 }
540
541 fn last_node_id(&self) -> Option<String> {
542 self.last_node_id.lock().clone()
543 }
544
545 fn handle_warnings(&self, resp: &QueryResponse) {
546 if let Some(warnings) = &resp.warnings {
547 for w in warnings {
548 warn!(target: "server_warnings", "server warning: {w}");
549 }
550 }
551 }
552
553 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
554 info!("start query: {sql}");
555 let (resp, batches) = self.start_query_inner(sql, None, false).await?;
556 Pages::new(self.clone(), resp, batches, need_progress)
557 }
558
559 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
560 let mut mgr = self.queries_need_heartbeat.lock();
561 if let Some(state) = mgr.remove(query_id) {
562 let self_cloned = self.clone();
563 let query_id = query_id.to_owned();
564 GLOBAL_RUNTIME.spawn(async move {
565 if let Err(e) = self_cloned
566 .end_query(&query_id, false, Some(state.node_id.as_str()))
567 .await
568 {
569 error!("failed to final query {query_id}: {e}");
570 }
571 });
572 }
573 }
574
575 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
576 if let Some(info) = &self.session_token_info {
577 let info = info.lock();
578 Ok(builder.bearer_auth(info.0.session_token.clone()))
579 } else {
580 self.auth.wrap(builder)
581 }
582 }
583
584 async fn start_query_inner(
585 &self,
586 sql: &str,
587 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
588 force_json_body: bool,
589 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
590 if !self.in_active_transaction() {
591 self.route_hint.next();
592 }
593 let endpoint = self.endpoint.join("v1/query")?;
594
595 let session_state = self.session_state();
597 let need_sticky = session_state.need_sticky.unwrap_or(false);
598 let req = QueryRequest::new(sql)
599 .with_pagination(self.make_pagination())
600 .with_session(Some(session_state))
601 .with_stage_attachment(stage_attachment_config);
602
603 let query_id = self.gen_query_id();
605 let mut headers = self.make_headers(Some(&query_id))?;
606 if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
607 debug!("accept arrow data");
608 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
609 }
610
611 if need_sticky {
612 if let Some(node_id) = self.last_node_id() {
613 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
614 }
615 }
616 let mut builder = self.cli.post(endpoint.clone()).json(&req);
617 builder = self.wrap_auth_or_session_token(builder)?;
618 let request = builder
619 .headers(headers.clone())
620 .timeout(self.page_request_timeout)
621 .build()?;
622 let response = self
623 .query_request_helper(request, true, true, true, RequestKind::QueryStart)
624 .await?;
625 self.handle_page(response, true).await
626 }
627
628 fn is_arrow_data(headers: &HeaderMap) -> bool {
629 if let Some(typ) = headers.get(CONTENT_TYPE) {
630 if let Ok(t) = typ.to_str() {
631 return t == CONTENT_TYPE_ARROW;
632 }
633 }
634 false
635 }
636
637 async fn handle_page(
638 &self,
639 response: HttpResponseData,
640 is_first: bool,
641 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
642 let status = response.status;
643 if status != StatusCode::OK {
644 let error = Self::status_body_to_error(status, &response.body);
645 if !is_first {
646 if let Error::Logic(code, ec) = &error {
647 if *code == StatusCode::NOT_FOUND {
648 return Err(Error::QueryNotFound(ec.message.clone()));
649 }
650 }
651 }
652 return Err(error);
653 }
654 let is_arrow_data = Self::is_arrow_data(&response.headers);
655 if is_first {
656 if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
657 self.route_hint.set(route_hint.to_str().unwrap_or_default());
658 }
659 }
660 let mut body = response.body;
661 let mut batches = vec![];
662 if is_arrow_data {
663 if is_first {
664 debug!("received arrow data");
665 }
666 let cursor = std::io::Cursor::new(body.as_ref());
667 let reader = StreamReader::try_new(cursor, None)
668 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
669 let schema = reader.schema();
670 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
671 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
672 } else {
673 return Err(Error::Decode(
674 "missing response_header metadata in arrow payload".to_string(),
675 ));
676 };
677 for batch in reader {
678 let batch = batch
679 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
680 batches.push(batch);
681 }
682 body = json_body
683 };
684 let resp: QueryResponse = json_from_slice(&body)?;
685 self.handle_session(&resp.session).await;
686 if let Some(err) = &resp.error {
687 return Err(Error::QueryFailed(err.clone()));
688 }
689 if is_first {
690 self.handle_warnings(&resp);
691 self.set_last_query_id(Some(resp.id.clone()));
692 if let Some(node_id) = &resp.node_id {
693 self.set_last_node_id(node_id.clone());
694 }
695 }
696 Ok((resp, batches))
697 }
698
699 pub async fn query_page(
700 &self,
701 query_id: &str,
702 next_uri: &str,
703 node_id: &Option<String>,
704 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
705 info!("query page: {next_uri}");
706 let endpoint = self.endpoint.join(next_uri)?;
707 let mut headers = self.make_headers(Some(query_id))?;
708 if self.capability.arrow_data && self.query_result_format == "arrow" {
709 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
710 }
711 let mut builder = self.cli.get(endpoint.clone());
712 builder = self
713 .wrap_auth_or_session_token(builder)?
714 .headers(headers.clone())
715 .timeout(self.page_request_timeout);
716 if let Some(node_id) = node_id {
717 builder = builder.header(HEADER_STICKY_NODE, node_id)
718 }
719 let request = builder.build()?;
720
721 let response = self
722 .query_request_helper(request, false, true, true, RequestKind::QueryPage)
723 .await?;
724 self.handle_page(response, false).await
725 }
726
727 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
728 self.end_query(query_id, true, None).await
729 }
730
731 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
732 self.end_query(query_id, false, node_id).await
733 }
734
735 pub async fn end_query(
736 &self,
737 query_id: &str,
738 is_kill: bool,
739 node_id: Option<&str>,
740 ) -> Result<()> {
741 let (method, request_kind) = if is_kill {
742 ("kill", RequestKind::QueryKill)
743 } else {
744 ("final", RequestKind::QueryFinal)
745 };
746 let uri = format!("/v1/query/{query_id}/{method}");
747 let endpoint = self.endpoint.join(&uri)?;
748 let headers = self.make_headers(Some(query_id))?;
749
750 info!("{method} query: {uri}");
751
752 let mut builder = self.cli.post(endpoint.clone());
753 if let Some(node_id) = node_id {
754 builder = builder.header(HEADER_STICKY_NODE, node_id)
755 }
756 builder = self.wrap_auth_or_session_token(builder)?;
757 let resp = builder.headers(headers.clone()).send().await?;
758 if resp.status() != 200 {
759 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
760 .with_context(request_kind)
761 .with_query_id(query_id));
762 }
763 Ok(())
764 }
765
766 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
767 self.query_all_inner(sql, false).await
768 }
769
770 pub async fn query_all_inner(
771 self: &Arc<Self>,
772 sql: &str,
773 force_json_body: bool,
774 ) -> Result<Page> {
775 let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
776 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
777 let mut all = Page::default();
778 while let Some(page) = pages.next().await {
779 all.update(page?);
780 }
781 Ok(all)
782 }
783
784 fn session_state(&self) -> SessionState {
785 self.session_state.lock().clone()
786 }
787
788 fn make_pagination(&self) -> Option<PaginationConfig> {
789 if self.wait_time_secs.is_none()
790 && self.max_rows_in_buffer.is_none()
791 && self.max_rows_per_page.is_none()
792 {
793 return None;
794 }
795 let mut pagination = PaginationConfig {
796 wait_time_secs: None,
797 max_rows_in_buffer: None,
798 max_rows_per_page: None,
799 };
800 if let Some(wait_time_secs) = self.wait_time_secs {
801 pagination.wait_time_secs = Some(wait_time_secs);
802 }
803 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
804 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
805 }
806 if let Some(max_rows_per_page) = self.max_rows_per_page {
807 pagination.max_rows_per_page = Some(max_rows_per_page);
808 }
809 Some(pagination)
810 }
811
812 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
813 let mut headers = HeaderMap::new();
814 if let Some(tenant) = &self.tenant {
815 headers.insert(HEADER_TENANT, tenant.parse()?);
816 }
817 let warehouse = self.warehouse.lock().clone();
818 if let Some(warehouse) = warehouse {
819 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
820 }
821 let route_hint = self.route_hint.current();
822 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
823 if let Some(query_id) = query_id {
824 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
825 }
826 Ok(headers)
827 }
828
829 pub async fn insert_with_stage(
830 self: &Arc<Self>,
831 sql: &str,
832 stage: &str,
833 file_format_options: BTreeMap<&str, &str>,
834 copy_options: BTreeMap<&str, &str>,
835 ) -> Result<QueryStats> {
836 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
837 let stage_attachment = Some(StageAttachmentConfig {
838 location: stage,
839 file_format_options: Some(file_format_options),
840 copy_options: Some(copy_options),
841 });
842 let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
843 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
844 let mut all = Page::default();
845 while let Some(page) = pages.next().await {
846 all.update(page?);
847 }
848 Ok(all.stats)
849 }
850
851 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
852 info!("get presigned upload url: {stage}");
853 let sql = format!("PRESIGN UPLOAD {stage}");
854 let resp = self.query_all_inner(&sql, true).await?;
855 if resp.data.len() != 1 {
856 return Err(Error::Decode(
857 "Empty response from server for presigned request".to_string(),
858 ));
859 }
860 if resp.data[0].len() != 3 {
861 return Err(Error::Decode(
862 "Invalid response from server for presigned request".to_string(),
863 ));
864 }
865 let method = resp.data[0][0].clone().unwrap_or_default();
867 if method != "PUT" {
868 return Err(Error::Decode(format!(
869 "Invalid method for presigned upload request: {method}"
870 )));
871 }
872 let headers: BTreeMap<String, String> =
873 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
874 let url = resp.data[0][2].clone().unwrap_or_default();
875 Ok(PresignedResponse {
876 method,
877 headers,
878 url,
879 })
880 }
881
882 pub async fn upload_to_stage(
883 self: &Arc<Self>,
884 stage: &str,
885 data: Reader,
886 size: u64,
887 ) -> Result<()> {
888 match self.get_presign_mode() {
889 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
890 PresignMode::On => {
891 let presigned = self.get_presigned_upload_url(stage).await?;
892 presign_upload_to_stage(presigned, data, size).await
893 }
894 PresignMode::Auto => {
895 unreachable!("PresignMode::Auto should be handled during client initialization")
896 }
897 PresignMode::Detect => {
898 unreachable!("PresignMode::Detect should be handled during client initialization")
899 }
900 }
901 }
902
903 async fn upload_to_stage_with_stream(
905 &self,
906 stage: &str,
907 data: Reader,
908 size: u64,
909 ) -> Result<()> {
910 info!("upload to stage with stream: {stage}, size: {size}");
911 if let Some(info) = self.need_pre_refresh_session().await {
912 self.refresh_session_token(info).await?;
913 }
914 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
915 let location = StageLocation::try_from(stage)?;
916 let query_id = self.gen_query_id();
917 let mut headers = self.make_headers(Some(&query_id))?;
918 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
919 let stream = Body::wrap_stream(ReaderStream::new(data));
920 let part = Part::stream_with_length(stream, size).file_name(location.path);
921 let form = Form::new().part("upload", part);
922 let mut builder = self.cli.put(endpoint.clone());
923 builder = self.wrap_auth_or_session_token(builder)?;
924 let resp = builder.headers(headers).multipart(form).send().await?;
925 let status = resp.status();
926 if status != 200 {
927 return Err(Error::response_error(status, &resp.bytes().await?)
928 .with_context(RequestKind::UploadToStage)
929 .with_query_id(query_id));
930 }
931 Ok(())
932 }
933
934 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
937 where
938 T: de::DeserializeOwned,
939 {
940 if value.starts_with("{") {
941 serde_json::from_slice(value.as_bytes())
942 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
943 } else {
944 let json = URL_SAFE.decode(value).map_err(|e| {
945 format!(
946 "Invalid value {} for {key}, base64 decode error: {}",
947 value, e
948 )
949 })?;
950 serde_json::from_slice(&json).map_err(|e| {
951 format!(
952 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
953 String::from_utf8_lossy(&json)
954 )
955 })
956 }
957 }
958
959 pub async fn streaming_load(
960 &self,
961 sql: &str,
962 data: Reader,
963 file_name: &str,
964 ) -> Result<LoadResponse> {
965 let body = Body::wrap_stream(ReaderStream::new(data));
966 let part = Part::stream(body).file_name(file_name.to_string());
967 let endpoint = self.endpoint.join("v1/streaming_load")?;
968 let mut builder = self.cli.put(endpoint.clone());
969 builder = self.wrap_auth_or_session_token(builder)?;
970 let query_id = self.gen_query_id();
971 let mut headers = self.make_headers(Some(&query_id))?;
972 headers.insert(HEADER_SQL, sql.parse()?);
973 let session = serde_json::to_string(&*self.session_state.lock())
974 .expect("serialize session state should not fail");
975 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
976 let form = Form::new().part("upload", part);
977 let resp = builder.headers(headers).multipart(form).send().await?;
978 let status = resp.status();
979 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
980 match Self::decode_json_header::<SessionState>(
981 HEADER_QUERY_CONTEXT,
982 value.to_str().unwrap(),
983 ) {
984 Ok(session) => *self.session_state.lock() = session,
985 Err(e) => {
986 error!("Error decoding session state when streaming load: {e}");
987 }
988 }
989 };
990 if status != 200 {
991 return Err(Error::response_error(status, &resp.bytes().await?)
992 .with_context(RequestKind::StreamingLoad)
993 .with_query_id(query_id));
994 }
995 let resp = resp.json::<LoadResponse>().await?;
996 Ok(resp)
997 }
998
999 async fn login(&mut self) -> Result<()> {
1000 let endpoint = self.endpoint.join("/v1/session/login")?;
1001 let headers = self.make_headers(None)?;
1002 let body = LoginRequest::from(&*self.session_state.lock());
1003 let mut builder = self.cli.post(endpoint.clone()).json(&body);
1004 if self.disable_session_token {
1005 builder = builder.query(&[("disable_session_token", true)]);
1006 }
1007 let builder = self.auth.wrap(builder)?;
1008 let request = builder
1009 .headers(headers.clone())
1010 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1011 .build()?;
1012 let response = self
1013 .query_request_helper(request, true, false, true, RequestKind::Login)
1014 .await?;
1015 if response.status == StatusCode::NOT_FOUND {
1016 info!("login return 404, skip login on the old version server");
1017 return Ok(());
1018 }
1019 if response.status != StatusCode::OK {
1020 return Err(Self::status_body_to_error(response.status, &response.body)
1021 .with_context(RequestKind::Login));
1022 }
1023 if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
1024 if let Ok(s) = v.to_str() {
1025 self.session_id = s.to_string();
1026 }
1027 }
1028
1029 let response = json_from_slice(&response.body)?;
1030 match response {
1031 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
1032 LoginResponseResult::Ok(info) => {
1033 let server_version = info
1034 .version
1035 .parse()
1036 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
1037 self.capability = Capability::from_server_version(&server_version);
1038 self.server_version = Some(server_version.clone());
1039 let session_id = self.session_id.as_str();
1040 if let Some(tokens) = info.tokens {
1041 info!(
1042 "[session {session_id}] login success with session token version = {server_version}",
1043 );
1044 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
1045 } else {
1046 info!("[session {session_id}] login success, version = {server_version}");
1047 }
1048 }
1049 }
1050 Ok(())
1051 }
1052
1053 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1054 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1055 let queries = self.queries_need_heartbeat.lock().clone();
1056 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1057 let now = Instant::now();
1058
1059 let mut query_ids = Vec::new();
1060 for (qid, state) in queries {
1061 if state.need_heartbeat(now) {
1062 query_ids.push(qid.to_string());
1063 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1064 arr.push(qid);
1065 } else {
1066 node_to_queries.insert(state.node_id, vec![qid]);
1067 }
1068 }
1069 }
1070
1071 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1072 {
1073 return Ok(());
1074 }
1075
1076 let body = json!({
1077 "node_to_queries": node_to_queries
1078 });
1079 let headers = self.make_headers(None)?;
1080 let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1081 let request = self.wrap_auth_or_session_token(builder)?.build()?;
1082 let response = self
1083 .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1084 .await?;
1085 if response.status != StatusCode::OK {
1086 return Err(Self::status_body_to_error(response.status, &response.body)
1087 .with_context(RequestKind::Heartbeat));
1088 }
1089 let json: Value = json_from_slice(&response.body)?;
1090 let session_id = self.session_id.as_str();
1091 info!("[session {session_id}] heartbeat request={body}, response={json}");
1092 if let Some(queries_to_remove) = json.get("queries_to_remove") {
1093 if let Some(arr) = queries_to_remove.as_array() {
1094 if !arr.is_empty() {
1095 let mut queries = self.queries_need_heartbeat.lock();
1096 for q in arr {
1097 if let Some(q) = q.as_str() {
1098 queries.remove(q);
1099 }
1100 }
1101 }
1102 }
1103 }
1104 let now = Instant::now();
1105 let mut queries = self.queries_need_heartbeat.lock();
1106 for qid in query_ids {
1107 if let Some(state) = queries.get_mut(&qid) {
1108 *state.last_access_time.lock() = now;
1109 }
1110 }
1111 Ok(())
1112 }
1113
1114 fn build_log_out_request(&self) -> Result<Request> {
1115 let endpoint = self.endpoint.join("/v1/session/logout")?;
1116
1117 let session_state = self.session_state();
1118 let need_sticky = session_state.need_sticky.unwrap_or(false);
1119 let mut headers = self.make_headers(None)?;
1120 if need_sticky {
1121 if let Some(node_id) = self.last_node_id() {
1122 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1123 }
1124 }
1125 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1126
1127 let builder = self.wrap_auth_or_session_token(builder)?;
1128 let req = builder.build()?;
1129 Ok(req)
1130 }
1131
1132 pub(crate) fn need_logout(&self) -> bool {
1133 self.session_token_info.is_some()
1134 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1135 }
1136
1137 async fn refresh_session_token(
1138 &self,
1139 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1140 ) -> Result<()> {
1141 let (session_token_info, _) = { self_login_info.lock().clone() };
1142 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1143 let body = RefreshSessionTokenRequest {
1144 session_token: session_token_info.session_token.clone(),
1145 };
1146 let headers = self.make_headers(None)?;
1147 let request = self
1148 .cli
1149 .post(endpoint.clone())
1150 .json(&body)
1151 .headers(headers.clone())
1152 .bearer_auth(session_token_info.refresh_token.clone())
1153 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1154 .build()?;
1155 let response = self
1156 .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1157 .await?;
1158 if response.status != StatusCode::OK {
1159 return Err(Self::status_body_to_error(response.status, &response.body)
1160 .with_context(RequestKind::SessionRefresh));
1161 }
1162 let response = json_from_slice(&response.body)?;
1163 match response {
1164 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1165 RefreshResponse::Ok(info) => {
1166 *self_login_info.lock() = (info, Instant::now());
1167 Ok(())
1168 }
1169 }
1170 }
1171
1172 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1173 if let Some(info) = &self.session_token_info {
1174 let (start, ttl) = {
1175 let guard = info.lock();
1176 (guard.1, guard.0.session_token_ttl_in_secs)
1177 };
1178 if Instant::now() > start + Duration::from_secs(ttl) {
1179 return Some(info.clone());
1180 }
1181 }
1182 None
1183 }
1184
1185 async fn query_request_helper(
1193 &self,
1194 mut request: Request,
1195 retry_if_503: bool,
1196 refresh_if_401: bool,
1197 reload_auth_if_401: bool,
1198 request_kind: RequestKind,
1199 ) -> Result<HttpResponseData> {
1200 let mut refreshed = false;
1201 let mut retries = 0;
1202 let max_retries = self.retry_count;
1203 let retry_delay = Duration::from_secs(self.retry_delay_secs);
1204
1205 loop {
1206 let req = request.try_clone().expect("request not cloneable");
1207 let response = match self.cli.execute(req).await {
1208 Ok(response) => response,
1209 Err(err) => {
1210 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1211 if retries >= max_retries.saturating_sub(1) {
1212 return Err(Self::with_request_context(
1213 err,
1214 &request_kind,
1215 &request,
1216 Some(retries),
1217 ));
1218 }
1219 retries += 1;
1220 warn!(
1221 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1222 retries,
1223 max_retries,
1224 request.url(),
1225 reason,
1226 err,
1227 retry_delay.as_secs(),
1228 );
1229 sleep(jitter(retry_delay)).await;
1230 continue;
1231 }
1232 return Err(Self::with_request_context(
1233 err,
1234 &request_kind,
1235 &request,
1236 Some(retries),
1237 ));
1238 }
1239 };
1240
1241 let status = response.status();
1242 let headers = response.headers().clone();
1243 let body = match response.bytes().await {
1244 Ok(body) => body,
1245 Err(err) => {
1246 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1247 if retries >= max_retries.saturating_sub(1) {
1248 return Err(Self::with_request_context(
1249 err,
1250 &request_kind,
1251 &request,
1252 Some(retries),
1253 ));
1254 }
1255 retries += 1;
1256 warn!(
1257 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1258 retries,
1259 max_retries,
1260 request.url(),
1261 reason,
1262 err,
1263 retry_delay.as_secs(),
1264 );
1265 sleep(jitter(retry_delay)).await;
1266 continue;
1267 }
1268 return Err(Self::with_request_context(
1269 err,
1270 &request_kind,
1271 &request,
1272 Some(retries),
1273 ));
1274 }
1275 };
1276
1277 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1278 if retries >= max_retries.saturating_sub(1) {
1279 return Ok(HttpResponseData {
1280 status,
1281 headers,
1282 body,
1283 });
1284 }
1285 retries += 1;
1286 warn!(
1287 "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
1288 retries,
1289 max_retries,
1290 request.url(),
1291 retry_delay.as_secs()
1292 );
1293 sleep(jitter(retry_delay)).await;
1294 continue;
1295 }
1296
1297 if status == StatusCode::UNAUTHORIZED {
1298 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1299 let unauthorized_error =
1300 serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
1301 if let Some(session_token_info) = &self.session_token_info {
1302 let should_refresh = unauthorized_error
1303 .as_ref()
1304 .map(|r| need_refresh_token(r.error.code))
1305 .unwrap_or(false);
1306 if refresh_if_401 && should_refresh && !refreshed {
1307 Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
1308 refreshed = true;
1309 retries = 0;
1310 warn!(
1311 "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
1312 request.url(),
1313 retry_delay.as_secs()
1314 );
1315 sleep(jitter(retry_delay)).await;
1316 continue;
1317 }
1318 }
1319 if reload_auth_if_401 && self.auth.can_reload() {
1320 if retries >= max_retries.saturating_sub(1) {
1321 return Ok(HttpResponseData {
1322 status,
1323 headers,
1324 body,
1325 });
1326 }
1327 retries += 1;
1328 let builder =
1329 RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
1330 let builder = self.auth.wrap(builder)?;
1331 request = builder.build()?;
1332 warn!(
1333 "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
1334 retries,
1335 max_retries,
1336 request.url(),
1337 retry_delay.as_secs()
1338 );
1339 sleep(jitter(retry_delay)).await;
1340 continue;
1341 }
1342 }
1343
1344 return Ok(HttpResponseData {
1345 status,
1346 headers,
1347 body,
1348 });
1349 }
1350 }
1351
1352 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1353 if let Err(err) = http_client.execute(request).await {
1354 error!("[session {session_id}] logout request failed: {err}");
1355 } else {
1356 info!("[session {session_id}] logout success");
1357 };
1358 }
1359
1360 pub async fn close(&self) {
1361 let session_id = &self.session_id;
1362 info!("[session {session_id}] try closing now");
1363 if self
1364 .closed
1365 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1366 .is_ok()
1367 {
1368 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1369 if self.need_logout() {
1370 let cli = self.cli.clone();
1371 let req = self
1372 .build_log_out_request()
1373 .expect("failed to build logout request");
1374 Self::logout(cli, req, &self.session_id).await;
1375 }
1376 }
1377 }
1378 pub fn close_with_spawn(&self) {
1379 let session_id = &self.session_id;
1380 info!("[session {session_id}]: try closing with spawn");
1381 if self
1382 .closed
1383 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1384 .is_ok()
1385 {
1386 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1387 if self.need_logout() {
1388 let cli = self.cli.clone();
1389 let req = self
1390 .build_log_out_request()
1391 .expect("failed to build logout request");
1392 let session_id = self.session_id.clone();
1393 GLOBAL_RUNTIME.spawn(async move {
1394 Self::logout(cli, req, session_id.as_str()).await;
1395 });
1396 }
1397 }
1398 }
1399
1400 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1401 let mut queries = self.queries_need_heartbeat.lock();
1402 queries.insert(query_id.to_string(), state);
1403 }
1404}
1405
1406fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1407where
1408 T: Deserialize<'a>,
1409{
1410 serde_json::from_slice::<T>(body).map_err(|e| {
1411 Error::Decode(format!(
1412 "fail to decode JSON response: {e}, body: {}",
1413 String::from_utf8_lossy(body)
1414 ))
1415 })
1416}
1417
1418impl Default for APIClient {
1419 fn default() -> Self {
1420 Self {
1421 session_id: Default::default(),
1422 cli: HttpClient::new(),
1423 scheme: "http".to_string(),
1424 endpoint: Url::parse("http://localhost:8080").unwrap(),
1425 host: "localhost".to_string(),
1426 port: 8000,
1427 tenant: None,
1428 warehouse: Mutex::new(None),
1429 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1430 session_state: Mutex::new(SessionState::default()),
1431 wait_time_secs: None,
1432 max_rows_in_buffer: None,
1433 max_rows_per_page: None,
1434 connect_timeout: Duration::from_secs(10),
1435 page_request_timeout: Duration::from_secs(300),
1436 tls_ca_file: None,
1437 presign: Mutex::new(PresignMode::Auto),
1438 route_hint: RouteHintGenerator::new(),
1439 last_node_id: Default::default(),
1440 disable_session_token: true,
1441 disable_login: false,
1442 query_result_format: "json".to_string(),
1443 session_token_info: None,
1444 closed: AtomicBool::new(false),
1445 last_query_id: Default::default(),
1446 server_version: None,
1447 capability: Default::default(),
1448 queries_need_heartbeat: Default::default(),
1449 retry_count: 3,
1450 retry_delay_secs: 10,
1451 }
1452 }
1453}
1454
1455struct RouteHintGenerator {
1456 nonce: AtomicU64,
1457 current: std::sync::Mutex<String>,
1458}
1459
1460impl RouteHintGenerator {
1461 fn new() -> Self {
1462 let gen = Self {
1463 nonce: AtomicU64::new(0),
1464 current: std::sync::Mutex::new("".to_string()),
1465 };
1466 gen.next();
1467 gen
1468 }
1469
1470 fn current(&self) -> String {
1471 let guard = self.current.lock().unwrap();
1472 guard.clone()
1473 }
1474
1475 fn set(&self, hint: &str) {
1476 let mut guard = self.current.lock().unwrap();
1477 *guard = hint.to_string();
1478 }
1479
1480 fn next(&self) -> String {
1481 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1482 let uuid = uuid::Uuid::new_v4();
1483 let current = format!("rh:{uuid}:{nonce:06}");
1484 let mut guard = self.current.lock().unwrap();
1485 guard.clone_from(¤t);
1486 current
1487 }
1488}
1489
#[cfg(test)]
mod test {
    use super::*;

    /// Every tunable query parameter in the DSN lands in the matching client
    /// field, and `sslmode=disable` yields a plain-HTTP endpoint on port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);

        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(warehouse.as_deref(), Some("wh"));
        Ok(())
    }

    /// A percent-encoded password parses. Without `sslmode`, the port
    /// defaults to 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client = APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!((client.host(), client.port()), ("localhost", 443));
        Ok(())
    }

    /// A raw password containing `@`, `=` and `{` still parses, and the
    /// explicit port is honored.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}