1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, RequestKind, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48 Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// HTTP header names exchanged with the Databend server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";

// Transaction state value (compared case-insensitively) meaning a transaction is open.
const TXN_STATE_ACTIVE: &str = "Active";

// Media type of Arrow IPC stream payloads (matched against the response Content-Type).
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Sent as the `Accept` value when Arrow results are wanted.
// NOTE(review): despite the name this equals CONTENT_TYPE_ARROW; confirm whether a JSON
// fallback (e.g. appending ", application/json") was intended.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
77
78static VERSION: Lazy<String> = Lazy::new(|| {
79 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
80 version.to_string()
81});
82
83#[derive(Clone)]
84pub(crate) struct QueryState {
85 pub node_id: String,
86 pub last_access_time: Arc<Mutex<Instant>>,
87 pub timeout_secs: u64,
88}
89
90impl QueryState {
91 pub fn need_heartbeat(&self, now: Instant) -> bool {
92 let t = *self.last_access_time.lock();
93 now.duration_since(t).as_secs() > self.timeout_secs / 2
94 }
95}
96
/// An HTTP response whose body has already been read fully into memory.
struct HttpResponseData {
    /// Response status code.
    status: StatusCode,
    /// Response headers (content type, route hint, session id, ...).
    headers: HeaderMap,
    /// Complete response body.
    body: Bytes,
}
102
/// Client for Databend's HTTP API (`v1/query`, `/v1/session/*`, stage upload and streaming
/// load endpoints).
///
/// Built by [`APIClient::new`]: the DSN is parsed, the HTTP client is configured, the client
/// logs in (unless disabled) and is registered with the global client manager. Mutable state
/// lives behind `Mutex`/`AtomicBool` so a shared `Arc<APIClient>` can be updated from responses.
///
/// NOTE(review): fields are dropped in declaration order after `Drop::drop` runs; keep that in
/// mind before reordering them.
pub struct APIClient {
    /// Session id taken from the login response header, or `no_login_<uuid>` when none was set.
    pub(crate) session_id: String,
    /// Underlying `reqwest` client, configured in `build_client`.
    cli: HttpClient,
    /// `http` or `https`, selected by the DSN `sslmode` option.
    scheme: String,
    /// Host name from the DSN.
    host: String,
    /// Port from the DSN, defaulting to 80 (`http`) or 443 (`https`).
    port: u16,

    /// Base URL `scheme://host:port` that request paths are joined onto.
    endpoint: Url,

    /// Credentials used for login, and for requests when no session token is held.
    auth: Arc<dyn Auth>,

    /// Optional tenant, sent as the `X-DATABEND-TENANT` header.
    tenant: Option<String>,
    /// Current warehouse, sent as `X-DATABEND-WAREHOUSE`; updated from server session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with queries and replaced by the server's copy from responses.
    session_state: Mutex<SessionState>,
    /// Route hint sent as `X-DATABEND-ROUTE-HINT`; advanced per query outside transactions.
    route_hint: RouteHintGenerator,

    /// Skip the `/v1/session/login` call (DSN `login=disable`).
    disable_login: bool,
    /// Preferred result format: `"json"` or `"arrow"`.
    query_result_format: String,
    /// Ask the server not to issue a session token (DSN `session_token=disable`).
    disable_session_token: bool,
    /// Session token from login plus the instant it was obtained; when present it is sent as
    /// the bearer token instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    /// Closed flag. NOTE(review): not touched in this part of the file; confirm its use.
    closed: AtomicBool,

    /// Server version reported by the login response.
    server_version: Option<Version>,

    /// Pagination option `wait_time_secs` forwarded with each query.
    wait_time_secs: Option<i64>,
    /// Pagination option `max_rows_in_buffer` forwarded with each query.
    max_rows_in_buffer: Option<i64>,
    /// Pagination option `max_rows_per_page` forwarded with each query.
    max_rows_per_page: Option<i64>,

    /// Timeout applied to the login request.
    connect_timeout: Duration,
    /// Timeout applied to each follow-up page request.
    page_request_timeout: Duration,

    /// PEM file with an extra trusted root certificate for `https` (TLS features only).
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node that served the most recent first page, used for sticky routing.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recently started query.
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version at login.
    capability: Capability,

    /// Running queries (keyed by query id) that need periodic heartbeats.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    /// DSN `retry_count`; presumably consumed by the request helper outside this view.
    retry_count: u32,
    /// DSN `retry_delay_secs`; presumably consumed by the request helper outside this view.
    retry_delay_secs: u64,
}
148
149impl Drop for APIClient {
150 fn drop(&mut self) {
151 self.close_with_spawn()
152 }
153}
154
155impl APIClient {
156 fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
157 while let Some(err) = source {
158 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
159 match io_err.kind() {
160 ErrorKind::TimedOut
161 | ErrorKind::UnexpectedEof
162 | ErrorKind::ConnectionReset
163 | ErrorKind::ConnectionAborted
164 | ErrorKind::BrokenPipe
165 | ErrorKind::NotConnected
166 | ErrorKind::WouldBlock
167 | ErrorKind::Interrupted => return true,
168 _ => {}
169 }
170 }
171 source = err.source();
172 }
173 false
174 }
175
176 fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
177 if err.is_timeout() {
178 Some("request timeout")
179 } else if err.is_connect() {
180 Some("connection error")
181 } else if err.is_request() {
182 Some("request error")
183 } else if err.is_body() || err.is_decode() {
184 if Self::has_transient_io_source(err.source()) {
185 Some("response read transient I/O error")
186 } else {
187 None
188 }
189 } else {
190 None
191 }
192 }
193
194 fn request_query_id(request: &Request) -> Option<String> {
195 request
196 .headers()
197 .get(HEADER_QUERY_ID)
198 .and_then(|v| v.to_str().ok())
199 .map(|v| v.to_string())
200 }
201
202 fn with_request_context(
203 reqwest_error: ReqwestError,
204 request_kind: &RequestKind,
205 request: &Request,
206 retry_times: Option<u32>,
207 ) -> Error {
208 let error: Error = reqwest_error.into();
209 let error = error.with_context(request_kind.clone());
210 let error = if let Some(retry_times) = retry_times {
211 error.with_retry_times(retry_times)
212 } else {
213 error
214 };
215 if let Some(query_id) = Self::request_query_id(request) {
216 error.with_query_id(query_id)
217 } else {
218 error
219 }
220 }
221
222 fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
223 match serde_json::from_slice::<ResponseWithErrorCode>(body) {
224 Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
225 Ok(resp) => Error::Logic(status, resp.error),
226 Err(_) => Error::response_error(status, body),
227 }
228 }
229
230 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
231 let mut client = Self::from_dsn(dsn).await?;
232 client.build_client(name).await?;
233 if !client.disable_login {
234 client.login().await?;
235 }
236 if client.session_id.is_empty() {
237 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
238 }
239 let client = Arc::new(client);
240 client.check_presign().await?;
241 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
242 Ok(client)
243 }
244
245 pub fn capability(&self) -> &Capability {
246 &self.capability
247 }
248
249 fn set_presign_mode(&self, mode: PresignMode) {
250 *self.presign.lock() = mode
251 }
252 fn get_presign_mode(&self) -> PresignMode {
253 *self.presign.lock()
254 }
255
256 async fn from_dsn(dsn: &str) -> Result<Self> {
257 let u = Url::parse(dsn)?;
258 let mut client = Self::default();
259 if let Some(host) = u.host_str() {
260 client.host = host.to_string();
261 }
262
263 if u.username() != "" {
264 let password = u.password().unwrap_or_default();
265 let password = percent_decode_str(password).decode_utf8()?;
266 client.auth = Arc::new(BasicAuth::new(u.username(), password));
267 }
268
269 let mut session_state = SessionState::default();
270
271 let database = u.path().trim_start_matches('/');
272 if !database.is_empty() {
273 session_state.set_database(database);
274 }
275
276 let mut scheme = "https";
277 for (k, v) in u.query_pairs() {
278 match k.as_ref() {
279 "wait_time_secs" => {
280 client.wait_time_secs = Some(v.parse()?);
281 }
282 "max_rows_in_buffer" => {
283 client.max_rows_in_buffer = Some(v.parse()?);
284 }
285 "max_rows_per_page" => {
286 client.max_rows_per_page = Some(v.parse()?);
287 }
288 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
289 "page_request_timeout_secs" => {
290 client.page_request_timeout = {
291 let secs: u64 = v.parse()?;
292 Duration::from_secs(secs)
293 };
294 }
295 "presign" => {
296 let presign_mode = match v.as_ref() {
297 "auto" => PresignMode::Auto,
298 "detect" => PresignMode::Detect,
299 "on" => PresignMode::On,
300 "off" => PresignMode::Off,
301 _ => {
302 return Err(Error::BadArgument(format!(
303 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
304 )))
305 }
306 };
307 client.set_presign_mode(presign_mode);
308 }
309 "tenant" => {
310 client.tenant = Some(v.to_string());
311 }
312 "warehouse" => {
313 client.warehouse = Mutex::new(Some(v.to_string()));
314 }
315 "role" => session_state.set_role(v),
316 "sslmode" => match v.as_ref() {
317 "disable" => scheme = "http",
318 "require" | "enable" => scheme = "https",
319 _ => {
320 return Err(Error::BadArgument(format!(
321 "Invalid value for sslmode: {v}"
322 )))
323 }
324 },
325 "tls_ca_file" => {
326 client.tls_ca_file = Some(v.to_string());
327 }
328 "access_token" => {
329 client.auth = Arc::new(AccessTokenAuth::new(v));
330 }
331 "access_token_file" => {
332 client.auth = Arc::new(AccessTokenFileAuth::new(v));
333 }
334 "login" => {
335 client.disable_login = match v.as_ref() {
336 "disable" => true,
337 "enable" => false,
338 _ => {
339 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
340 }
341 }
342 }
343 "session_token" => {
344 client.disable_session_token = match v.as_ref() {
345 "disable" => true,
346 "enable" => false,
347 _ => {
348 return Err(Error::BadArgument(format!(
349 "Invalid value for session_token: {v}"
350 )))
351 }
352 }
353 }
354 "body_format" | "query_result_format" => {
355 let v = v.to_string().to_lowercase();
356 match v.as_str() {
357 "json" | "arrow" => client.query_result_format = v.to_string(),
358 _ => {
359 return Err(Error::BadArgument(format!(
360 "Invalid value for query_result_format: {v}"
361 )))
362 }
363 }
364 }
365 "retry_count" => {
366 client.retry_count = v.parse()?;
367 }
368 "retry_delay_secs" => {
369 client.retry_delay_secs = v.parse()?;
370 }
371 _ => {
372 session_state.set(k, v);
373 }
374 }
375 }
376 client.port = match u.port() {
377 Some(p) => p,
378 None => match scheme {
379 "http" => 80,
380 "https" => 443,
381 _ => unreachable!(),
382 },
383 };
384 client.scheme = scheme.to_string();
385 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
386 client.session_state = Mutex::new(session_state);
387
388 Ok(client)
389 }
390
391 pub fn host(&self) -> &str {
392 self.host.as_str()
393 }
394
395 pub fn port(&self) -> u16 {
396 self.port
397 }
398
399 pub fn scheme(&self) -> &str {
400 self.scheme.as_str()
401 }
402
403 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
404 let ua = match name {
405 Some(n) => n,
406 None => format!("databend-client-rust/{}", VERSION.as_str()),
407 };
408 let cookie_provider = GlobalCookieStore::new();
409 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
410 let mut initial_cookies = [&cookie].into_iter();
411 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
412 let mut cli_builder = HttpClient::builder()
413 .user_agent(ua)
414 .cookie_provider(Arc::new(cookie_provider))
415 .pool_idle_timeout(Duration::from_secs(1));
416 #[cfg(any(feature = "rustls", feature = "native-tls"))]
417 if self.scheme == "https" {
418 if let Some(ref ca_file) = self.tls_ca_file {
419 let cert_pem = tokio::fs::read(ca_file).await?;
420 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
421 cli_builder = cli_builder.add_root_certificate(cert);
422 }
423 }
424 self.cli = cli_builder.build()?;
425 Ok(())
426 }
427
428 async fn check_presign(self: &Arc<Self>) -> Result<()> {
429 let mode = match self.get_presign_mode() {
430 PresignMode::Auto => {
431 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
432 PresignMode::On
433 } else {
434 PresignMode::Off
435 }
436 }
437 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
438 Ok(_) => PresignMode::On,
439 Err(e) => {
440 warn!("presign mode off with error detected: {e}");
441 PresignMode::Off
442 }
443 },
444 mode => mode,
445 };
446 self.set_presign_mode(mode);
447 Ok(())
448 }
449
450 pub fn current_warehouse(&self) -> Option<String> {
451 let guard = self.warehouse.lock();
452 guard.clone()
453 }
454
455 pub fn current_catalog(&self) -> Option<String> {
456 let guard = self.session_state.lock();
457 guard.catalog.clone()
458 }
459
460 pub fn current_database(&self) -> Option<String> {
461 let guard = self.session_state.lock();
462 guard.database.clone()
463 }
464
465 pub fn set_warehouse(&self, warehouse: impl Into<String>) {
466 let mut guard = self.warehouse.lock();
467 *guard = Some(warehouse.into());
468 }
469
470 pub fn set_database(&self, database: impl Into<String>) {
471 let mut guard = self.session_state.lock();
472 guard.set_database(database);
473 }
474
475 pub fn set_role(&self, role: impl Into<String>) {
476 let mut guard = self.session_state.lock();
477 guard.set_role(role);
478 }
479
480 pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
481 let mut guard = self.session_state.lock();
482 guard.set(key, value);
483 }
484
485 pub async fn current_role(&self) -> Option<String> {
486 let guard = self.session_state.lock();
487 guard.role.clone()
488 }
489
490 fn in_active_transaction(&self) -> bool {
491 let guard = self.session_state.lock();
492 guard
493 .txn_state
494 .as_ref()
495 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
496 .unwrap_or(false)
497 }
498
499 pub fn username(&self) -> String {
500 self.auth.username()
501 }
502
503 fn gen_query_id(&self) -> String {
504 uuid::Uuid::now_v7().simple().to_string()
505 }
506
507 async fn handle_session(&self, session: &Option<SessionState>) {
508 let session = match session {
509 Some(session) => session,
510 None => return,
511 };
512
513 {
515 let mut session_state = self.session_state.lock();
516 *session_state = session.clone();
517 }
518
519 if let Some(settings) = session.settings.as_ref() {
521 if let Some(v) = settings.get("warehouse") {
522 let mut warehouse = self.warehouse.lock();
523 *warehouse = Some(v.clone());
524 }
525 }
526 }
527
528 pub fn set_last_node_id(&self, node_id: String) {
529 *self.last_node_id.lock() = Some(node_id)
530 }
531
532 pub fn set_last_query_id(&self, query_id: Option<String>) {
533 *self.last_query_id.lock() = query_id
534 }
535
536 pub fn last_query_id(&self) -> Option<String> {
537 self.last_query_id.lock().clone()
538 }
539
540 fn last_node_id(&self) -> Option<String> {
541 self.last_node_id.lock().clone()
542 }
543
544 fn handle_warnings(&self, resp: &QueryResponse) {
545 if let Some(warnings) = &resp.warnings {
546 for w in warnings {
547 warn!(target: "server_warnings", "server warning: {w}");
548 }
549 }
550 }
551
552 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
553 info!("start query: {sql}");
554 let (resp, batches) = self.start_query_inner(sql, None, false).await?;
555 Pages::new(self.clone(), resp, batches, need_progress)
556 }
557
558 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
559 let mut mgr = self.queries_need_heartbeat.lock();
560 if let Some(state) = mgr.remove(query_id) {
561 let self_cloned = self.clone();
562 let query_id = query_id.to_owned();
563 GLOBAL_RUNTIME.spawn(async move {
564 if let Err(e) = self_cloned
565 .end_query(&query_id, false, Some(state.node_id.as_str()))
566 .await
567 {
568 error!("failed to final query {query_id}: {e}");
569 }
570 });
571 }
572 }
573
574 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
575 if let Some(info) = &self.session_token_info {
576 let info = info.lock();
577 Ok(builder.bearer_auth(info.0.session_token.clone()))
578 } else {
579 self.auth.wrap(builder)
580 }
581 }
582
583 async fn start_query_inner(
584 &self,
585 sql: &str,
586 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
587 force_json_body: bool,
588 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
589 if !self.in_active_transaction() {
590 self.route_hint.next();
591 }
592 let endpoint = self.endpoint.join("v1/query")?;
593
594 let session_state = self.session_state();
596 let need_sticky = session_state.need_sticky.unwrap_or(false);
597 let req = QueryRequest::new(sql)
598 .with_pagination(self.make_pagination())
599 .with_session(Some(session_state))
600 .with_stage_attachment(stage_attachment_config);
601
602 let query_id = self.gen_query_id();
604 let mut headers = self.make_headers(Some(&query_id))?;
605 if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
606 debug!("accept arrow data");
607 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
608 }
609
610 if need_sticky {
611 if let Some(node_id) = self.last_node_id() {
612 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
613 }
614 }
615 let mut builder = self.cli.post(endpoint.clone()).json(&req);
616 builder = self.wrap_auth_or_session_token(builder)?;
617 let request = builder.headers(headers.clone()).build()?;
618 let response = self
619 .query_request_helper(request, true, true, true, RequestKind::QueryStart)
620 .await?;
621 self.handle_page(response, true).await
622 }
623
624 fn is_arrow_data(headers: &HeaderMap) -> bool {
625 if let Some(typ) = headers.get(CONTENT_TYPE) {
626 if let Ok(t) = typ.to_str() {
627 return t == CONTENT_TYPE_ARROW;
628 }
629 }
630 false
631 }
632
633 async fn handle_page(
634 &self,
635 response: HttpResponseData,
636 is_first: bool,
637 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
638 let status = response.status;
639 if status != StatusCode::OK {
640 let error = Self::status_body_to_error(status, &response.body);
641 if !is_first {
642 if let Error::Logic(code, ec) = &error {
643 if *code == StatusCode::NOT_FOUND {
644 return Err(Error::QueryNotFound(ec.message.clone()));
645 }
646 }
647 }
648 return Err(error);
649 }
650 let is_arrow_data = Self::is_arrow_data(&response.headers);
651 if is_first {
652 if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
653 self.route_hint.set(route_hint.to_str().unwrap_or_default());
654 }
655 }
656 let mut body = response.body;
657 let mut batches = vec![];
658 if is_arrow_data {
659 if is_first {
660 debug!("received arrow data");
661 }
662 let cursor = std::io::Cursor::new(body.as_ref());
663 let reader = StreamReader::try_new(cursor, None)
664 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
665 let schema = reader.schema();
666 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
667 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
668 } else {
669 return Err(Error::Decode(
670 "missing response_header metadata in arrow payload".to_string(),
671 ));
672 };
673 for batch in reader {
674 let batch = batch
675 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
676 batches.push(batch);
677 }
678 body = json_body
679 };
680 let resp: QueryResponse = json_from_slice(&body)?;
681 self.handle_session(&resp.session).await;
682 if let Some(err) = &resp.error {
683 return Err(Error::QueryFailed(err.clone()));
684 }
685 if is_first {
686 self.handle_warnings(&resp);
687 self.set_last_query_id(Some(resp.id.clone()));
688 if let Some(node_id) = &resp.node_id {
689 self.set_last_node_id(node_id.clone());
690 }
691 }
692 Ok((resp, batches))
693 }
694
695 pub async fn query_page(
696 &self,
697 query_id: &str,
698 next_uri: &str,
699 node_id: &Option<String>,
700 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
701 info!("query page: {next_uri}");
702 let endpoint = self.endpoint.join(next_uri)?;
703 let mut headers = self.make_headers(Some(query_id))?;
704 if self.capability.arrow_data && self.query_result_format == "arrow" {
705 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
706 }
707 let mut builder = self.cli.get(endpoint.clone());
708 builder = self
709 .wrap_auth_or_session_token(builder)?
710 .headers(headers.clone())
711 .timeout(self.page_request_timeout);
712 if let Some(node_id) = node_id {
713 builder = builder.header(HEADER_STICKY_NODE, node_id)
714 }
715 let request = builder.build()?;
716
717 let response = self
718 .query_request_helper(request, false, true, true, RequestKind::QueryPage)
719 .await?;
720 self.handle_page(response, false).await
721 }
722
723 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
724 self.end_query(query_id, true, None).await
725 }
726
727 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
728 self.end_query(query_id, false, node_id).await
729 }
730
731 pub async fn end_query(
732 &self,
733 query_id: &str,
734 is_kill: bool,
735 node_id: Option<&str>,
736 ) -> Result<()> {
737 let (method, request_kind) = if is_kill {
738 ("kill", RequestKind::QueryKill)
739 } else {
740 ("final", RequestKind::QueryFinal)
741 };
742 let uri = format!("/v1/query/{query_id}/{method}");
743 let endpoint = self.endpoint.join(&uri)?;
744 let headers = self.make_headers(Some(query_id))?;
745
746 info!("{method} query: {uri}");
747
748 let mut builder = self.cli.post(endpoint.clone());
749 if let Some(node_id) = node_id {
750 builder = builder.header(HEADER_STICKY_NODE, node_id)
751 }
752 builder = self.wrap_auth_or_session_token(builder)?;
753 let resp = builder.headers(headers.clone()).send().await?;
754 if resp.status() != 200 {
755 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
756 .with_context(request_kind)
757 .with_query_id(query_id));
758 }
759 Ok(())
760 }
761
762 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
763 self.query_all_inner(sql, false).await
764 }
765
766 pub async fn query_all_inner(
767 self: &Arc<Self>,
768 sql: &str,
769 force_json_body: bool,
770 ) -> Result<Page> {
771 let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
772 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
773 let mut all = Page::default();
774 while let Some(page) = pages.next().await {
775 all.update(page?);
776 }
777 Ok(all)
778 }
779
780 fn session_state(&self) -> SessionState {
781 self.session_state.lock().clone()
782 }
783
784 fn make_pagination(&self) -> Option<PaginationConfig> {
785 if self.wait_time_secs.is_none()
786 && self.max_rows_in_buffer.is_none()
787 && self.max_rows_per_page.is_none()
788 {
789 return None;
790 }
791 let mut pagination = PaginationConfig {
792 wait_time_secs: None,
793 max_rows_in_buffer: None,
794 max_rows_per_page: None,
795 };
796 if let Some(wait_time_secs) = self.wait_time_secs {
797 pagination.wait_time_secs = Some(wait_time_secs);
798 }
799 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
800 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
801 }
802 if let Some(max_rows_per_page) = self.max_rows_per_page {
803 pagination.max_rows_per_page = Some(max_rows_per_page);
804 }
805 Some(pagination)
806 }
807
808 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
809 let mut headers = HeaderMap::new();
810 if let Some(tenant) = &self.tenant {
811 headers.insert(HEADER_TENANT, tenant.parse()?);
812 }
813 let warehouse = self.warehouse.lock().clone();
814 if let Some(warehouse) = warehouse {
815 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
816 }
817 let route_hint = self.route_hint.current();
818 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
819 if let Some(query_id) = query_id {
820 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
821 }
822 Ok(headers)
823 }
824
825 pub async fn insert_with_stage(
826 self: &Arc<Self>,
827 sql: &str,
828 stage: &str,
829 file_format_options: BTreeMap<&str, &str>,
830 copy_options: BTreeMap<&str, &str>,
831 ) -> Result<QueryStats> {
832 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
833 let stage_attachment = Some(StageAttachmentConfig {
834 location: stage,
835 file_format_options: Some(file_format_options),
836 copy_options: Some(copy_options),
837 });
838 let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
839 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
840 let mut all = Page::default();
841 while let Some(page) = pages.next().await {
842 all.update(page?);
843 }
844 Ok(all.stats)
845 }
846
847 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
848 info!("get presigned upload url: {stage}");
849 let sql = format!("PRESIGN UPLOAD {stage}");
850 let resp = self.query_all_inner(&sql, true).await?;
851 if resp.data.len() != 1 {
852 return Err(Error::Decode(
853 "Empty response from server for presigned request".to_string(),
854 ));
855 }
856 if resp.data[0].len() != 3 {
857 return Err(Error::Decode(
858 "Invalid response from server for presigned request".to_string(),
859 ));
860 }
861 let method = resp.data[0][0].clone().unwrap_or_default();
863 if method != "PUT" {
864 return Err(Error::Decode(format!(
865 "Invalid method for presigned upload request: {method}"
866 )));
867 }
868 let headers: BTreeMap<String, String> =
869 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
870 let url = resp.data[0][2].clone().unwrap_or_default();
871 Ok(PresignedResponse {
872 method,
873 headers,
874 url,
875 })
876 }
877
878 pub async fn upload_to_stage(
879 self: &Arc<Self>,
880 stage: &str,
881 data: Reader,
882 size: u64,
883 ) -> Result<()> {
884 match self.get_presign_mode() {
885 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
886 PresignMode::On => {
887 let presigned = self.get_presigned_upload_url(stage).await?;
888 presign_upload_to_stage(presigned, data, size).await
889 }
890 PresignMode::Auto => {
891 unreachable!("PresignMode::Auto should be handled during client initialization")
892 }
893 PresignMode::Detect => {
894 unreachable!("PresignMode::Detect should be handled during client initialization")
895 }
896 }
897 }
898
899 async fn upload_to_stage_with_stream(
901 &self,
902 stage: &str,
903 data: Reader,
904 size: u64,
905 ) -> Result<()> {
906 info!("upload to stage with stream: {stage}, size: {size}");
907 if let Some(info) = self.need_pre_refresh_session().await {
908 self.refresh_session_token(info).await?;
909 }
910 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
911 let location = StageLocation::try_from(stage)?;
912 let query_id = self.gen_query_id();
913 let mut headers = self.make_headers(Some(&query_id))?;
914 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
915 let stream = Body::wrap_stream(ReaderStream::new(data));
916 let part = Part::stream_with_length(stream, size).file_name(location.path);
917 let form = Form::new().part("upload", part);
918 let mut builder = self.cli.put(endpoint.clone());
919 builder = self.wrap_auth_or_session_token(builder)?;
920 let resp = builder.headers(headers).multipart(form).send().await?;
921 let status = resp.status();
922 if status != 200 {
923 return Err(Error::response_error(status, &resp.bytes().await?)
924 .with_context(RequestKind::UploadToStage)
925 .with_query_id(query_id));
926 }
927 Ok(())
928 }
929
930 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
933 where
934 T: de::DeserializeOwned,
935 {
936 if value.starts_with("{") {
937 serde_json::from_slice(value.as_bytes())
938 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
939 } else {
940 let json = URL_SAFE.decode(value).map_err(|e| {
941 format!(
942 "Invalid value {} for {key}, base64 decode error: {}",
943 value, e
944 )
945 })?;
946 serde_json::from_slice(&json).map_err(|e| {
947 format!(
948 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
949 String::from_utf8_lossy(&json)
950 )
951 })
952 }
953 }
954
955 pub async fn streaming_load(
956 &self,
957 sql: &str,
958 data: Reader,
959 file_name: &str,
960 ) -> Result<LoadResponse> {
961 let body = Body::wrap_stream(ReaderStream::new(data));
962 let part = Part::stream(body).file_name(file_name.to_string());
963 let endpoint = self.endpoint.join("v1/streaming_load")?;
964 let mut builder = self.cli.put(endpoint.clone());
965 builder = self.wrap_auth_or_session_token(builder)?;
966 let query_id = self.gen_query_id();
967 let mut headers = self.make_headers(Some(&query_id))?;
968 headers.insert(HEADER_SQL, sql.parse()?);
969 let session = serde_json::to_string(&*self.session_state.lock())
970 .expect("serialize session state should not fail");
971 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
972 let form = Form::new().part("upload", part);
973 let resp = builder.headers(headers).multipart(form).send().await?;
974 let status = resp.status();
975 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
976 match Self::decode_json_header::<SessionState>(
977 HEADER_QUERY_CONTEXT,
978 value.to_str().unwrap(),
979 ) {
980 Ok(session) => *self.session_state.lock() = session,
981 Err(e) => {
982 error!("Error decoding session state when streaming load: {e}");
983 }
984 }
985 };
986 if status != 200 {
987 return Err(Error::response_error(status, &resp.bytes().await?)
988 .with_context(RequestKind::StreamingLoad)
989 .with_query_id(query_id));
990 }
991 let resp = resp.json::<LoadResponse>().await?;
992 Ok(resp)
993 }
994
995 async fn login(&mut self) -> Result<()> {
996 let endpoint = self.endpoint.join("/v1/session/login")?;
997 let headers = self.make_headers(None)?;
998 let body = LoginRequest::from(&*self.session_state.lock());
999 let mut builder = self.cli.post(endpoint.clone()).json(&body);
1000 if self.disable_session_token {
1001 builder = builder.query(&[("disable_session_token", true)]);
1002 }
1003 let builder = self.auth.wrap(builder)?;
1004 let request = builder
1005 .headers(headers.clone())
1006 .timeout(self.connect_timeout)
1007 .build()?;
1008 let response = self
1009 .query_request_helper(request, true, false, true, RequestKind::Login)
1010 .await?;
1011 if response.status == StatusCode::NOT_FOUND {
1012 info!("login return 404, skip login on the old version server");
1013 return Ok(());
1014 }
1015 if response.status != StatusCode::OK {
1016 return Err(Self::status_body_to_error(response.status, &response.body)
1017 .with_context(RequestKind::Login));
1018 }
1019 if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
1020 if let Ok(s) = v.to_str() {
1021 self.session_id = s.to_string();
1022 }
1023 }
1024
1025 let response = json_from_slice(&response.body)?;
1026 match response {
1027 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
1028 LoginResponseResult::Ok(info) => {
1029 let server_version = info
1030 .version
1031 .parse()
1032 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
1033 self.capability = Capability::from_server_version(&server_version);
1034 self.server_version = Some(server_version.clone());
1035 let session_id = self.session_id.as_str();
1036 if let Some(tokens) = info.tokens {
1037 info!(
1038 "[session {session_id}] login success with session token version = {server_version}",
1039 );
1040 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
1041 } else {
1042 info!("[session {session_id}] login success, version = {server_version}");
1043 }
1044 }
1045 }
1046 Ok(())
1047 }
1048
1049 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1050 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1051 let queries = self.queries_need_heartbeat.lock().clone();
1052 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1053 let now = Instant::now();
1054
1055 let mut query_ids = Vec::new();
1056 for (qid, state) in queries {
1057 if state.need_heartbeat(now) {
1058 query_ids.push(qid.to_string());
1059 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1060 arr.push(qid);
1061 } else {
1062 node_to_queries.insert(state.node_id, vec![qid]);
1063 }
1064 }
1065 }
1066
1067 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1068 {
1069 return Ok(());
1070 }
1071
1072 let body = json!({
1073 "node_to_queries": node_to_queries
1074 });
1075 let headers = self.make_headers(None)?;
1076 let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1077 let request = self.wrap_auth_or_session_token(builder)?.build()?;
1078 let response = self
1079 .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1080 .await?;
1081 if response.status != StatusCode::OK {
1082 return Err(Self::status_body_to_error(response.status, &response.body)
1083 .with_context(RequestKind::Heartbeat));
1084 }
1085 let json: Value = json_from_slice(&response.body)?;
1086 let session_id = self.session_id.as_str();
1087 info!("[session {session_id}] heartbeat request={body}, response={json}");
1088 if let Some(queries_to_remove) = json.get("queries_to_remove") {
1089 if let Some(arr) = queries_to_remove.as_array() {
1090 if !arr.is_empty() {
1091 let mut queries = self.queries_need_heartbeat.lock();
1092 for q in arr {
1093 if let Some(q) = q.as_str() {
1094 queries.remove(q);
1095 }
1096 }
1097 }
1098 }
1099 }
1100 let now = Instant::now();
1101 let mut queries = self.queries_need_heartbeat.lock();
1102 for qid in query_ids {
1103 if let Some(state) = queries.get_mut(&qid) {
1104 *state.last_access_time.lock() = now;
1105 }
1106 }
1107 Ok(())
1108 }
1109
1110 fn build_log_out_request(&self) -> Result<Request> {
1111 let endpoint = self.endpoint.join("/v1/session/logout")?;
1112
1113 let session_state = self.session_state();
1114 let need_sticky = session_state.need_sticky.unwrap_or(false);
1115 let mut headers = self.make_headers(None)?;
1116 if need_sticky {
1117 if let Some(node_id) = self.last_node_id() {
1118 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1119 }
1120 }
1121 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1122
1123 let builder = self.wrap_auth_or_session_token(builder)?;
1124 let req = builder.build()?;
1125 Ok(req)
1126 }
1127
1128 pub(crate) fn need_logout(&self) -> bool {
1129 self.session_token_info.is_some()
1130 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1131 }
1132
1133 async fn refresh_session_token(
1134 &self,
1135 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1136 ) -> Result<()> {
1137 let (session_token_info, _) = { self_login_info.lock().clone() };
1138 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1139 let body = RefreshSessionTokenRequest {
1140 session_token: session_token_info.session_token.clone(),
1141 };
1142 let headers = self.make_headers(None)?;
1143 let request = self
1144 .cli
1145 .post(endpoint.clone())
1146 .json(&body)
1147 .headers(headers.clone())
1148 .bearer_auth(session_token_info.refresh_token.clone())
1149 .timeout(self.connect_timeout)
1150 .build()?;
1151 let response = self
1152 .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1153 .await?;
1154 if response.status != StatusCode::OK {
1155 return Err(Self::status_body_to_error(response.status, &response.body)
1156 .with_context(RequestKind::SessionRefresh));
1157 }
1158 let response = json_from_slice(&response.body)?;
1159 match response {
1160 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1161 RefreshResponse::Ok(info) => {
1162 *self_login_info.lock() = (info, Instant::now());
1163 Ok(())
1164 }
1165 }
1166 }
1167
1168 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1169 if let Some(info) = &self.session_token_info {
1170 let (start, ttl) = {
1171 let guard = info.lock();
1172 (guard.1, guard.0.session_token_ttl_in_secs)
1173 };
1174 if Instant::now() > start + Duration::from_secs(ttl) {
1175 return Some(info.clone());
1176 }
1177 }
1178 None
1179 }
1180
1181 async fn query_request_helper(
1189 &self,
1190 mut request: Request,
1191 retry_if_503: bool,
1192 refresh_if_401: bool,
1193 reload_auth_if_401: bool,
1194 request_kind: RequestKind,
1195 ) -> Result<HttpResponseData> {
1196 let mut refreshed = false;
1197 let mut retries = 0;
1198 let max_retries = self.retry_count;
1199 let retry_delay = Duration::from_secs(self.retry_delay_secs);
1200
1201 loop {
1202 let req = request.try_clone().expect("request not cloneable");
1203 let response = match self.cli.execute(req).await {
1204 Ok(response) => response,
1205 Err(err) => {
1206 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1207 if retries >= max_retries.saturating_sub(1) {
1208 return Err(Self::with_request_context(
1209 err,
1210 &request_kind,
1211 &request,
1212 Some(retries),
1213 ));
1214 }
1215 retries += 1;
1216 warn!(
1217 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1218 retries,
1219 max_retries,
1220 request.url(),
1221 reason,
1222 err,
1223 retry_delay.as_secs(),
1224 );
1225 sleep(jitter(retry_delay)).await;
1226 continue;
1227 }
1228 return Err(Self::with_request_context(
1229 err,
1230 &request_kind,
1231 &request,
1232 Some(retries),
1233 ));
1234 }
1235 };
1236
1237 let status = response.status();
1238 let headers = response.headers().clone();
1239 let body = match response.bytes().await {
1240 Ok(body) => body,
1241 Err(err) => {
1242 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1243 if retries >= max_retries.saturating_sub(1) {
1244 return Err(Self::with_request_context(
1245 err,
1246 &request_kind,
1247 &request,
1248 Some(retries),
1249 ));
1250 }
1251 retries += 1;
1252 warn!(
1253 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1254 retries,
1255 max_retries,
1256 request.url(),
1257 reason,
1258 err,
1259 retry_delay.as_secs(),
1260 );
1261 sleep(jitter(retry_delay)).await;
1262 continue;
1263 }
1264 return Err(Self::with_request_context(
1265 err,
1266 &request_kind,
1267 &request,
1268 Some(retries),
1269 ));
1270 }
1271 };
1272
1273 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1274 if retries >= max_retries.saturating_sub(1) {
1275 return Ok(HttpResponseData {
1276 status,
1277 headers,
1278 body,
1279 });
1280 }
1281 retries += 1;
1282 warn!(
1283 "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
1284 retries,
1285 max_retries,
1286 request.url(),
1287 retry_delay.as_secs()
1288 );
1289 sleep(jitter(retry_delay)).await;
1290 continue;
1291 }
1292
1293 if status == StatusCode::UNAUTHORIZED {
1294 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1295 let unauthorized_error =
1296 serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
1297 if let Some(session_token_info) = &self.session_token_info {
1298 let should_refresh = unauthorized_error
1299 .as_ref()
1300 .map(|r| need_refresh_token(r.error.code))
1301 .unwrap_or(false);
1302 if refresh_if_401 && should_refresh && !refreshed {
1303 Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
1304 refreshed = true;
1305 retries = 0;
1306 warn!(
1307 "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
1308 request.url(),
1309 retry_delay.as_secs()
1310 );
1311 sleep(jitter(retry_delay)).await;
1312 continue;
1313 }
1314 }
1315 if reload_auth_if_401 && self.auth.can_reload() {
1316 if retries >= max_retries.saturating_sub(1) {
1317 return Ok(HttpResponseData {
1318 status,
1319 headers,
1320 body,
1321 });
1322 }
1323 retries += 1;
1324 let builder =
1325 RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
1326 let builder = self.auth.wrap(builder)?;
1327 request = builder.build()?;
1328 warn!(
1329 "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
1330 retries,
1331 max_retries,
1332 request.url(),
1333 retry_delay.as_secs()
1334 );
1335 sleep(jitter(retry_delay)).await;
1336 continue;
1337 }
1338 }
1339
1340 return Ok(HttpResponseData {
1341 status,
1342 headers,
1343 body,
1344 });
1345 }
1346 }
1347
1348 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1349 if let Err(err) = http_client.execute(request).await {
1350 error!("[session {session_id}] logout request failed: {err}");
1351 } else {
1352 info!("[session {session_id}] logout success");
1353 };
1354 }
1355
1356 pub async fn close(&self) {
1357 let session_id = &self.session_id;
1358 info!("[session {session_id}] try closing now");
1359 if self
1360 .closed
1361 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1362 .is_ok()
1363 {
1364 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1365 if self.need_logout() {
1366 let cli = self.cli.clone();
1367 let req = self
1368 .build_log_out_request()
1369 .expect("failed to build logout request");
1370 Self::logout(cli, req, &self.session_id).await;
1371 }
1372 }
1373 }
1374 pub fn close_with_spawn(&self) {
1375 let session_id = &self.session_id;
1376 info!("[session {session_id}]: try closing with spawn");
1377 if self
1378 .closed
1379 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1380 .is_ok()
1381 {
1382 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1383 if self.need_logout() {
1384 let cli = self.cli.clone();
1385 let req = self
1386 .build_log_out_request()
1387 .expect("failed to build logout request");
1388 let session_id = self.session_id.clone();
1389 GLOBAL_RUNTIME.spawn(async move {
1390 Self::logout(cli, req, session_id.as_str()).await;
1391 });
1392 }
1393 }
1394 }
1395
1396 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1397 let mut queries = self.queries_need_heartbeat.lock();
1398 queries.insert(query_id.to_string(), state);
1399 }
1400}
1401
1402fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1403where
1404 T: Deserialize<'a>,
1405{
1406 serde_json::from_slice::<T>(body).map_err(|e| {
1407 Error::Decode(format!(
1408 "fail to decode JSON response: {e}, body: {}",
1409 String::from_utf8_lossy(body)
1410 ))
1411 })
1412}
1413
1414impl Default for APIClient {
1415 fn default() -> Self {
1416 Self {
1417 session_id: Default::default(),
1418 cli: HttpClient::new(),
1419 scheme: "http".to_string(),
1420 endpoint: Url::parse("http://localhost:8080").unwrap(),
1421 host: "localhost".to_string(),
1422 port: 8000,
1423 tenant: None,
1424 warehouse: Mutex::new(None),
1425 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1426 session_state: Mutex::new(SessionState::default()),
1427 wait_time_secs: None,
1428 max_rows_in_buffer: None,
1429 max_rows_per_page: None,
1430 connect_timeout: Duration::from_secs(10),
1431 page_request_timeout: Duration::from_secs(30),
1432 tls_ca_file: None,
1433 presign: Mutex::new(PresignMode::Auto),
1434 route_hint: RouteHintGenerator::new(),
1435 last_node_id: Default::default(),
1436 disable_session_token: true,
1437 disable_login: false,
1438 query_result_format: "json".to_string(),
1439 session_token_info: None,
1440 closed: AtomicBool::new(false),
1441 last_query_id: Default::default(),
1442 server_version: None,
1443 capability: Default::default(),
1444 queries_need_heartbeat: Default::default(),
1445 retry_count: 3,
1446 retry_delay_secs: 10,
1447 }
1448 }
1449}
1450
1451struct RouteHintGenerator {
1452 nonce: AtomicU64,
1453 current: std::sync::Mutex<String>,
1454}
1455
1456impl RouteHintGenerator {
1457 fn new() -> Self {
1458 let gen = Self {
1459 nonce: AtomicU64::new(0),
1460 current: std::sync::Mutex::new("".to_string()),
1461 };
1462 gen.next();
1463 gen
1464 }
1465
1466 fn current(&self) -> String {
1467 let guard = self.current.lock().unwrap();
1468 guard.clone()
1469 }
1470
1471 fn set(&self, hint: &str) {
1472 let mut guard = self.current.lock().unwrap();
1473 *guard = hint.to_string();
1474 }
1475
1476 fn next(&self) -> String {
1477 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1478 let uuid = uuid::Uuid::new_v4();
1479 let current = format!("rh:{uuid}:{nonce:06}");
1480 let mut guard = self.current.lock().unwrap();
1481 guard.clone_from(¤t);
1482 current
1483 }
1484}
1485
#[cfg(test)]
mod test {
    use super::*;

    /// Every DSN component and query option should land in the matching client field.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let api = APIClient::from_dsn(dsn).await?;

        // sslmode=disable selects plain HTTP; with no explicit port the endpoint uses 80.
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(api.host, "app.databend.com");
        assert_eq!(api.endpoint, expected_endpoint);
        assert_eq!(api.tenant, None);

        assert_eq!(api.wait_time_secs, Some(10));
        assert_eq!(api.max_rows_in_buffer, Some(5000000));
        assert_eq!(api.max_rows_per_page, Some(10000));

        let warehouse = api.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password must not confuse host parsing; with no port and no
    /// sslmode the client defaults to port 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let api = APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(api.port(), 443);
        assert_eq!(api.host(), "localhost");
        Ok(())
    }

    /// A raw password containing `@`, `{` and `=` must still yield the right host and port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let api = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(api.port(), 8000);
        assert_eq!(api.host(), "localhost");
        Ok(())
    }
}