1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, RequestKind, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48 Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// Custom `X-DATABEND-*` HTTP header names this client sends and/or reads.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";

// Session `txn_state` value (compared ASCII case-insensitively) meaning a transaction is open.
const TXN_STATE_ACTIVE: &str = "Active";

// Media type of an Arrow IPC stream body; matched exactly against `Content-Type`.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// `Accept` value sent when Arrow results are wanted.
// NOTE(review): despite the name it is identical to CONTENT_TYPE_ARROW and does not list
// JSON — confirm the server's content negotiation before changing the value.
const CONTENT_TYPE_ARROW_OR_JSON: &str = CONTENT_TYPE_ARROW;
77
78static VERSION: Lazy<String> = Lazy::new(|| {
79 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
80 version.to_string()
81});
82
83#[derive(Clone)]
84pub(crate) struct QueryState {
85 pub node_id: String,
86 pub last_access_time: Arc<Mutex<Instant>>,
87 pub timeout_secs: u64,
88}
89
90impl QueryState {
91 pub fn need_heartbeat(&self, now: Instant) -> bool {
92 let t = *self.last_access_time.lock();
93 now.duration_since(t).as_secs() > self.timeout_secs / 2
94 }
95}
96
/// An HTTP response whose body has been read completely into memory.
struct HttpResponseData {
    /// Status code returned by the server.
    status: StatusCode,
    /// Response headers (inspected for content type, route hint and session id).
    headers: HeaderMap,
    /// The entire response body.
    body: Bytes,
}
102
/// HTTP client for the Databend query API.
///
/// Created from a DSN by `APIClient::new`. Holds connection settings and credentials,
/// plus mutable per-session state: database/role/warehouse, presign mode, and the
/// last query and node ids.
pub struct APIClient {
    /// Session id from the login response header, or a `no_login_<uuid>` placeholder.
    pub(crate) session_id: String,
    /// Underlying reqwest client (built in `build_client`).
    cli: HttpClient,
    /// `"http"` or `"https"`, chosen by the DSN `sslmode` parameter.
    scheme: String,
    host: String,
    port: u16,

    /// Base URL (`scheme://host:port`) that API paths are joined onto.
    endpoint: Url,

    /// Credentials applied to requests when no session token is in use.
    auth: Arc<dyn Auth>,

    /// Sent as `X-DATABEND-TENANT` when set.
    tenant: Option<String>,
    /// Sent as `X-DATABEND-WAREHOUSE` when set; also updated from server session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with queries and replaced by the server's copy in responses.
    session_state: Mutex<SessionState>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header value.
    route_hint: RouteHintGenerator,

    /// `login=disable` in the DSN skips the login request.
    disable_login: bool,
    /// `"json"` or `"arrow"` (DSN `query_result_format` / `body_format`).
    query_result_format: String,
    /// `session_token=disable` asks the server not to issue a session token at login.
    disable_session_token: bool,
    /// Session token from login plus the instant it was obtained; when present it is used
    /// as bearer auth instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    /// NOTE(review): not read in the visible code; presumably marks the client closed.
    closed: AtomicBool,

    /// Version reported by the server during login.
    server_version: Option<Version>,

    /// Optional pagination settings forwarded to the server (`None` = server default).
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    connect_timeout: Duration,
    /// Per-request timeout for query start and page requests.
    page_request_timeout: Duration,

    /// PEM root certificate to trust for https (only when a TLS feature is compiled in).
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    last_node_id: Mutex<Option<String>>,
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version at login.
    capability: Capability,

    /// Queries, keyed by query id, that still need heartbeats.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    /// Retry settings parsed from the DSN; presumably consumed by the request helper
    /// (not visible here).
    retry_count: u32,
    retry_delay_secs: u64,
}
148
149impl Drop for APIClient {
150 fn drop(&mut self) {
151 self.close_with_spawn()
152 }
153}
154
155impl APIClient {
156 fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
157 while let Some(err) = source {
158 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
159 match io_err.kind() {
160 ErrorKind::TimedOut
161 | ErrorKind::UnexpectedEof
162 | ErrorKind::ConnectionReset
163 | ErrorKind::ConnectionAborted
164 | ErrorKind::BrokenPipe
165 | ErrorKind::NotConnected
166 | ErrorKind::WouldBlock
167 | ErrorKind::Interrupted => return true,
168 _ => {}
169 }
170 }
171 source = err.source();
172 }
173 false
174 }
175
176 fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
177 if err.is_timeout() {
178 Some("request timeout")
179 } else if err.is_connect() {
180 Some("connection error")
181 } else if err.is_request() {
182 Some("request error")
183 } else if err.is_body() || err.is_decode() {
184 if Self::has_transient_io_source(err.source()) {
185 Some("response read transient I/O error")
186 } else {
187 None
188 }
189 } else {
190 None
191 }
192 }
193
194 fn request_query_id(request: &Request) -> Option<String> {
195 request
196 .headers()
197 .get(HEADER_QUERY_ID)
198 .and_then(|v| v.to_str().ok())
199 .map(|v| v.to_string())
200 }
201
202 fn with_request_context(
203 reqwest_error: ReqwestError,
204 request_kind: &RequestKind,
205 request: &Request,
206 retry_times: Option<u32>,
207 ) -> Error {
208 let error: Error = reqwest_error.into();
209 let error = error.with_context(request_kind.clone());
210 let error = if let Some(retry_times) = retry_times {
211 error.with_retry_times(retry_times)
212 } else {
213 error
214 };
215 if let Some(query_id) = Self::request_query_id(request) {
216 error.with_query_id(query_id)
217 } else {
218 error
219 }
220 }
221
222 fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
223 match serde_json::from_slice::<ResponseWithErrorCode>(body) {
224 Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
225 Ok(resp) => Error::Logic(status, resp.error),
226 Err(_) => Error::response_error(status, body),
227 }
228 }
229
230 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
231 let mut client = Self::from_dsn(dsn).await?;
232 client.build_client(name).await?;
233 if !client.disable_login {
234 client.login().await?;
235 }
236 if client.session_id.is_empty() {
237 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
238 }
239 let client = Arc::new(client);
240 client.check_presign().await?;
241 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
242 Ok(client)
243 }
244
245 pub fn capability(&self) -> &Capability {
246 &self.capability
247 }
248
249 fn set_presign_mode(&self, mode: PresignMode) {
250 *self.presign.lock() = mode
251 }
252 fn get_presign_mode(&self) -> PresignMode {
253 *self.presign.lock()
254 }
255
256 async fn from_dsn(dsn: &str) -> Result<Self> {
257 let u = Url::parse(dsn)?;
258 let mut client = Self::default();
259 if let Some(host) = u.host_str() {
260 client.host = host.to_string();
261 }
262
263 if u.username() != "" {
264 let password = u.password().unwrap_or_default();
265 let password = percent_decode_str(password).decode_utf8()?;
266 client.auth = Arc::new(BasicAuth::new(u.username(), password));
267 }
268
269 let mut session_state = SessionState::default();
270
271 let database = u.path().trim_start_matches('/');
272 if !database.is_empty() {
273 session_state.set_database(database);
274 }
275
276 let mut scheme = "https";
277 for (k, v) in u.query_pairs() {
278 match k.as_ref() {
279 "wait_time_secs" => {
280 client.wait_time_secs = Some(v.parse()?);
281 }
282 "max_rows_in_buffer" => {
283 client.max_rows_in_buffer = Some(v.parse()?);
284 }
285 "max_rows_per_page" => {
286 client.max_rows_per_page = Some(v.parse()?);
287 }
288 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
289 "page_request_timeout_secs" => {
290 client.page_request_timeout = {
291 let secs: u64 = v.parse()?;
292 Duration::from_secs(secs)
293 };
294 }
295 "presign" => {
296 let presign_mode = match v.as_ref() {
297 "auto" => PresignMode::Auto,
298 "detect" => PresignMode::Detect,
299 "on" => PresignMode::On,
300 "off" => PresignMode::Off,
301 _ => {
302 return Err(Error::BadArgument(format!(
303 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
304 )))
305 }
306 };
307 client.set_presign_mode(presign_mode);
308 }
309 "tenant" => {
310 client.tenant = Some(v.to_string());
311 }
312 "warehouse" => {
313 client.warehouse = Mutex::new(Some(v.to_string()));
314 }
315 "role" => session_state.set_role(v),
316 "sslmode" => match v.as_ref() {
317 "disable" => scheme = "http",
318 "require" | "enable" => scheme = "https",
319 _ => {
320 return Err(Error::BadArgument(format!(
321 "Invalid value for sslmode: {v}"
322 )))
323 }
324 },
325 "tls_ca_file" => {
326 client.tls_ca_file = Some(v.to_string());
327 }
328 "access_token" => {
329 client.auth = Arc::new(AccessTokenAuth::new(v));
330 }
331 "access_token_file" => {
332 client.auth = Arc::new(AccessTokenFileAuth::new(v));
333 }
334 "login" => {
335 client.disable_login = match v.as_ref() {
336 "disable" => true,
337 "enable" => false,
338 _ => {
339 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
340 }
341 }
342 }
343 "session_token" => {
344 client.disable_session_token = match v.as_ref() {
345 "disable" => true,
346 "enable" => false,
347 _ => {
348 return Err(Error::BadArgument(format!(
349 "Invalid value for session_token: {v}"
350 )))
351 }
352 }
353 }
354 "body_format" | "query_result_format" => {
355 let v = v.to_string().to_lowercase();
356 match v.as_str() {
357 "json" | "arrow" => client.query_result_format = v.to_string(),
358 _ => {
359 return Err(Error::BadArgument(format!(
360 "Invalid value for query_result_format: {v}"
361 )))
362 }
363 }
364 }
365 "retry_count" => {
366 client.retry_count = v.parse()?;
367 }
368 "retry_delay_secs" => {
369 client.retry_delay_secs = v.parse()?;
370 }
371 _ => {
372 session_state.set(k, v);
373 }
374 }
375 }
376 client.port = match u.port() {
377 Some(p) => p,
378 None => match scheme {
379 "http" => 80,
380 "https" => 443,
381 _ => unreachable!(),
382 },
383 };
384 client.scheme = scheme.to_string();
385 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
386 client.session_state = Mutex::new(session_state);
387
388 Ok(client)
389 }
390
391 pub fn host(&self) -> &str {
392 self.host.as_str()
393 }
394
395 pub fn port(&self) -> u16 {
396 self.port
397 }
398
399 pub fn scheme(&self) -> &str {
400 self.scheme.as_str()
401 }
402
403 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
404 let ua = match name {
405 Some(n) => n,
406 None => format!("databend-client-rust/{}", VERSION.as_str()),
407 };
408 let cookie_provider = GlobalCookieStore::new();
409 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
410 let mut initial_cookies = [&cookie].into_iter();
411 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
412 let mut cli_builder = HttpClient::builder()
413 .user_agent(ua)
414 .connect_timeout(self.connect_timeout)
415 .cookie_provider(Arc::new(cookie_provider))
416 .pool_idle_timeout(Duration::from_secs(1));
417 #[cfg(any(feature = "rustls", feature = "native-tls"))]
418 if self.scheme == "https" {
419 if let Some(ref ca_file) = self.tls_ca_file {
420 let cert_pem = tokio::fs::read(ca_file).await?;
421 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
422 cli_builder = cli_builder.add_root_certificate(cert);
423 }
424 }
425 self.cli = cli_builder.build()?;
426 Ok(())
427 }
428
429 async fn check_presign(self: &Arc<Self>) -> Result<()> {
430 let mode = match self.get_presign_mode() {
431 PresignMode::Auto => {
432 if self.host.ends_with(".databend.com")
433 || self.host.ends_with(".databend.cn")
434 || self.host.ends_with(".tidbcloud.com")
435 {
436 PresignMode::On
437 } else {
438 PresignMode::Off
439 }
440 }
441 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
442 Ok(_) => PresignMode::On,
443 Err(e) => {
444 warn!("presign mode off with error detected: {e}");
445 PresignMode::Off
446 }
447 },
448 mode => mode,
449 };
450 self.set_presign_mode(mode);
451 Ok(())
452 }
453
454 pub fn current_warehouse(&self) -> Option<String> {
455 let guard = self.warehouse.lock();
456 guard.clone()
457 }
458
459 pub fn current_catalog(&self) -> Option<String> {
460 let guard = self.session_state.lock();
461 guard.catalog.clone()
462 }
463
464 pub fn current_database(&self) -> Option<String> {
465 let guard = self.session_state.lock();
466 guard.database.clone()
467 }
468
469 pub fn set_warehouse(&self, warehouse: impl Into<String>) {
470 let mut guard = self.warehouse.lock();
471 *guard = Some(warehouse.into());
472 }
473
474 pub fn set_database(&self, database: impl Into<String>) {
475 let mut guard = self.session_state.lock();
476 guard.set_database(database);
477 }
478
479 pub fn set_role(&self, role: impl Into<String>) {
480 let mut guard = self.session_state.lock();
481 guard.set_role(role);
482 }
483
484 pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
485 let mut guard = self.session_state.lock();
486 guard.set(key, value);
487 }
488
489 pub async fn current_role(&self) -> Option<String> {
490 let guard = self.session_state.lock();
491 guard.role.clone()
492 }
493
494 fn in_active_transaction(&self) -> bool {
495 let guard = self.session_state.lock();
496 guard
497 .txn_state
498 .as_ref()
499 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
500 .unwrap_or(false)
501 }
502
503 pub fn username(&self) -> String {
504 self.auth.username()
505 }
506
507 fn gen_query_id(&self) -> String {
508 uuid::Uuid::now_v7().simple().to_string()
509 }
510
511 async fn handle_session(&self, session: &Option<SessionState>) {
512 let session = match session {
513 Some(session) => session,
514 None => return,
515 };
516
517 {
519 let mut session_state = self.session_state.lock();
520 *session_state = session.clone();
521 }
522
523 if let Some(settings) = session.settings.as_ref() {
525 if let Some(v) = settings.get("warehouse") {
526 let mut warehouse = self.warehouse.lock();
527 *warehouse = Some(v.clone());
528 }
529 }
530 }
531
532 pub fn set_last_node_id(&self, node_id: String) {
533 *self.last_node_id.lock() = Some(node_id)
534 }
535
536 pub fn set_last_query_id(&self, query_id: Option<String>) {
537 *self.last_query_id.lock() = query_id
538 }
539
540 pub fn last_query_id(&self) -> Option<String> {
541 self.last_query_id.lock().clone()
542 }
543
544 fn last_node_id(&self) -> Option<String> {
545 self.last_node_id.lock().clone()
546 }
547
548 fn handle_warnings(&self, resp: &QueryResponse) {
549 if let Some(warnings) = &resp.warnings {
550 for w in warnings {
551 warn!(target: "server_warnings", "server warning: {w}");
552 }
553 }
554 }
555
556 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
557 info!("start query: {sql}");
558 let (resp, batches) = self.start_query_inner(sql, None, false).await?;
559 Pages::new(self.clone(), resp, batches, need_progress)
560 }
561
562 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
563 let mut mgr = self.queries_need_heartbeat.lock();
564 if let Some(state) = mgr.remove(query_id) {
565 let self_cloned = self.clone();
566 let query_id = query_id.to_owned();
567 GLOBAL_RUNTIME.spawn(async move {
568 if let Err(e) = self_cloned
569 .end_query(&query_id, false, Some(state.node_id.as_str()))
570 .await
571 {
572 error!("failed to final query {query_id}: {e}");
573 }
574 });
575 }
576 }
577
578 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
579 if let Some(info) = &self.session_token_info {
580 let info = info.lock();
581 Ok(builder.bearer_auth(info.0.session_token.clone()))
582 } else {
583 self.auth.wrap(builder)
584 }
585 }
586
587 async fn start_query_inner(
588 &self,
589 sql: &str,
590 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
591 force_json_body: bool,
592 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
593 if !self.in_active_transaction() {
594 self.route_hint.next();
595 }
596 let endpoint = self.endpoint.join("v1/query")?;
597
598 let session_state = self.session_state();
600 let need_sticky = session_state.need_sticky.unwrap_or(false);
601 let req = QueryRequest::new(sql)
602 .with_pagination(self.make_pagination())
603 .with_session(Some(session_state))
604 .with_stage_attachment(stage_attachment_config);
605
606 let query_id = self.gen_query_id();
608 let mut headers = self.make_headers(Some(&query_id))?;
609 if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
610 debug!("accept arrow data");
611 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
612 }
613
614 if need_sticky {
615 if let Some(node_id) = self.last_node_id() {
616 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
617 }
618 }
619 let mut builder = self.cli.post(endpoint.clone()).json(&req);
620 builder = self.wrap_auth_or_session_token(builder)?;
621 let request = builder
622 .headers(headers.clone())
623 .timeout(self.page_request_timeout)
624 .build()?;
625 let response = self
626 .query_request_helper(request, true, true, true, RequestKind::QueryStart)
627 .await?;
628 self.handle_page(response, true).await
629 }
630
631 fn is_arrow_data(headers: &HeaderMap) -> bool {
632 if let Some(typ) = headers.get(CONTENT_TYPE) {
633 if let Ok(t) = typ.to_str() {
634 return t == CONTENT_TYPE_ARROW;
635 }
636 }
637 false
638 }
639
640 async fn handle_page(
641 &self,
642 response: HttpResponseData,
643 is_first: bool,
644 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
645 let status = response.status;
646 if status != StatusCode::OK {
647 let error = Self::status_body_to_error(status, &response.body);
648 if !is_first {
649 if let Error::Logic(code, ec) = &error {
650 if *code == StatusCode::NOT_FOUND {
651 return Err(Error::QueryNotFound(ec.message.clone()));
652 }
653 }
654 }
655 return Err(error);
656 }
657 let is_arrow_data = Self::is_arrow_data(&response.headers);
658 if is_first {
659 if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
660 self.route_hint.set(route_hint.to_str().unwrap_or_default());
661 }
662 }
663 let mut body = response.body;
664 let mut batches = vec![];
665 if is_arrow_data {
666 if is_first {
667 debug!("received arrow data");
668 }
669 let cursor = std::io::Cursor::new(body.as_ref());
670 let reader = StreamReader::try_new(cursor, None)
671 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
672 let schema = reader.schema();
673 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
674 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
675 } else {
676 return Err(Error::Decode(
677 "missing response_header metadata in arrow payload".to_string(),
678 ));
679 };
680 for batch in reader {
681 let batch = batch
682 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
683 batches.push(batch);
684 }
685 body = json_body
686 };
687 let resp: QueryResponse = json_from_slice(&body)?;
688 self.handle_session(&resp.session).await;
689 if let Some(err) = &resp.error {
690 return Err(Error::QueryFailed(err.clone()));
691 }
692 if is_first {
693 self.handle_warnings(&resp);
694 self.set_last_query_id(Some(resp.id.clone()));
695 if let Some(node_id) = &resp.node_id {
696 self.set_last_node_id(node_id.clone());
697 }
698 }
699 Ok((resp, batches))
700 }
701
702 pub async fn query_page(
703 &self,
704 query_id: &str,
705 next_uri: &str,
706 node_id: &Option<String>,
707 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
708 info!("query page: {next_uri}");
709 let endpoint = self.endpoint.join(next_uri)?;
710 let mut headers = self.make_headers(Some(query_id))?;
711 if self.capability.arrow_data && self.query_result_format == "arrow" {
712 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
713 }
714 let mut builder = self.cli.get(endpoint.clone());
715 builder = self
716 .wrap_auth_or_session_token(builder)?
717 .headers(headers.clone())
718 .timeout(self.page_request_timeout);
719 if let Some(node_id) = node_id {
720 builder = builder.header(HEADER_STICKY_NODE, node_id)
721 }
722 let request = builder.build()?;
723
724 let response = self
725 .query_request_helper(request, false, true, true, RequestKind::QueryPage)
726 .await?;
727 self.handle_page(response, false).await
728 }
729
730 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
731 self.end_query(query_id, true, None).await
732 }
733
734 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
735 self.end_query(query_id, false, node_id).await
736 }
737
738 pub async fn end_query(
739 &self,
740 query_id: &str,
741 is_kill: bool,
742 node_id: Option<&str>,
743 ) -> Result<()> {
744 let (method, request_kind) = if is_kill {
745 ("kill", RequestKind::QueryKill)
746 } else {
747 ("final", RequestKind::QueryFinal)
748 };
749 let uri = format!("/v1/query/{query_id}/{method}");
750 let endpoint = self.endpoint.join(&uri)?;
751 let headers = self.make_headers(Some(query_id))?;
752
753 info!("{method} query: {uri}");
754
755 let mut builder = self.cli.post(endpoint.clone());
756 if let Some(node_id) = node_id {
757 builder = builder.header(HEADER_STICKY_NODE, node_id)
758 }
759 builder = self.wrap_auth_or_session_token(builder)?;
760 let resp = builder.headers(headers.clone()).send().await?;
761 if resp.status() != 200 {
762 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
763 .with_context(request_kind)
764 .with_query_id(query_id));
765 }
766 Ok(())
767 }
768
769 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
770 self.query_all_inner(sql, false).await
771 }
772
773 pub async fn query_all_inner(
774 self: &Arc<Self>,
775 sql: &str,
776 force_json_body: bool,
777 ) -> Result<Page> {
778 let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
779 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
780 let mut all = Page::default();
781 while let Some(page) = pages.next().await {
782 all.update(page?);
783 }
784 Ok(all)
785 }
786
787 fn session_state(&self) -> SessionState {
788 self.session_state.lock().clone()
789 }
790
791 fn make_pagination(&self) -> Option<PaginationConfig> {
792 if self.wait_time_secs.is_none()
793 && self.max_rows_in_buffer.is_none()
794 && self.max_rows_per_page.is_none()
795 {
796 return None;
797 }
798 let mut pagination = PaginationConfig {
799 wait_time_secs: None,
800 max_rows_in_buffer: None,
801 max_rows_per_page: None,
802 };
803 if let Some(wait_time_secs) = self.wait_time_secs {
804 pagination.wait_time_secs = Some(wait_time_secs);
805 }
806 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
807 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
808 }
809 if let Some(max_rows_per_page) = self.max_rows_per_page {
810 pagination.max_rows_per_page = Some(max_rows_per_page);
811 }
812 Some(pagination)
813 }
814
815 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
816 let mut headers = HeaderMap::new();
817 if let Some(tenant) = &self.tenant {
818 headers.insert(HEADER_TENANT, tenant.parse()?);
819 }
820 let warehouse = self.warehouse.lock().clone();
821 if let Some(warehouse) = warehouse {
822 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
823 }
824 let route_hint = self.route_hint.current();
825 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
826 if let Some(query_id) = query_id {
827 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
828 }
829 Ok(headers)
830 }
831
832 pub async fn insert_with_stage(
833 self: &Arc<Self>,
834 sql: &str,
835 stage: &str,
836 file_format_options: BTreeMap<&str, &str>,
837 copy_options: BTreeMap<&str, &str>,
838 ) -> Result<QueryStats> {
839 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
840 let stage_attachment = Some(StageAttachmentConfig {
841 location: stage,
842 file_format_options: Some(file_format_options),
843 copy_options: Some(copy_options),
844 });
845 let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
846 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
847 let mut all = Page::default();
848 while let Some(page) = pages.next().await {
849 all.update(page?);
850 }
851 Ok(all.stats)
852 }
853
854 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
855 info!("get presigned upload url: {stage}");
856 let sql = format!("PRESIGN UPLOAD {stage}");
857 let resp = self.query_all_inner(&sql, true).await?;
858 if resp.data.len() != 1 {
859 return Err(Error::Decode(
860 "Empty response from server for presigned request".to_string(),
861 ));
862 }
863 if resp.data[0].len() != 3 {
864 return Err(Error::Decode(
865 "Invalid response from server for presigned request".to_string(),
866 ));
867 }
868 let method = resp.data[0][0].clone().unwrap_or_default();
870 if method != "PUT" {
871 return Err(Error::Decode(format!(
872 "Invalid method for presigned upload request: {method}"
873 )));
874 }
875 let headers: BTreeMap<String, String> =
876 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
877 let url = resp.data[0][2].clone().unwrap_or_default();
878 Ok(PresignedResponse {
879 method,
880 headers,
881 url,
882 })
883 }
884
885 pub async fn upload_to_stage(
886 self: &Arc<Self>,
887 stage: &str,
888 data: Reader,
889 size: u64,
890 ) -> Result<()> {
891 match self.get_presign_mode() {
892 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
893 PresignMode::On => {
894 let presigned = self.get_presigned_upload_url(stage).await?;
895 presign_upload_to_stage(presigned, data, size).await
896 }
897 PresignMode::Auto => {
898 unreachable!("PresignMode::Auto should be handled during client initialization")
899 }
900 PresignMode::Detect => {
901 unreachable!("PresignMode::Detect should be handled during client initialization")
902 }
903 }
904 }
905
906 async fn upload_to_stage_with_stream(
908 &self,
909 stage: &str,
910 data: Reader,
911 size: u64,
912 ) -> Result<()> {
913 info!("upload to stage with stream: {stage}, size: {size}");
914 if let Some(info) = self.need_pre_refresh_session().await {
915 self.refresh_session_token(info).await?;
916 }
917 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
918 let location = StageLocation::try_from(stage)?;
919 let query_id = self.gen_query_id();
920 let mut headers = self.make_headers(Some(&query_id))?;
921 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
922 let stream = Body::wrap_stream(ReaderStream::new(data));
923 let part = Part::stream_with_length(stream, size).file_name(location.path);
924 let form = Form::new().part("upload", part);
925 let mut builder = self.cli.put(endpoint.clone());
926 builder = self.wrap_auth_or_session_token(builder)?;
927 let resp = builder.headers(headers).multipart(form).send().await?;
928 let status = resp.status();
929 if status != 200 {
930 return Err(Error::response_error(status, &resp.bytes().await?)
931 .with_context(RequestKind::UploadToStage)
932 .with_query_id(query_id));
933 }
934 Ok(())
935 }
936
937 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
940 where
941 T: de::DeserializeOwned,
942 {
943 if value.starts_with("{") {
944 serde_json::from_slice(value.as_bytes())
945 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
946 } else {
947 let json = URL_SAFE.decode(value).map_err(|e| {
948 format!(
949 "Invalid value {} for {key}, base64 decode error: {}",
950 value, e
951 )
952 })?;
953 serde_json::from_slice(&json).map_err(|e| {
954 format!(
955 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
956 String::from_utf8_lossy(&json)
957 )
958 })
959 }
960 }
961
962 pub async fn streaming_load(
963 &self,
964 sql: &str,
965 data: Reader,
966 file_name: &str,
967 ) -> Result<LoadResponse> {
968 let body = Body::wrap_stream(ReaderStream::new(data));
969 let part = Part::stream(body).file_name(file_name.to_string());
970 let endpoint = self.endpoint.join("v1/streaming_load")?;
971 let mut builder = self.cli.put(endpoint.clone());
972 builder = self.wrap_auth_or_session_token(builder)?;
973 let query_id = self.gen_query_id();
974 let mut headers = self.make_headers(Some(&query_id))?;
975 headers.insert(HEADER_SQL, sql.parse()?);
976 let session = serde_json::to_string(&*self.session_state.lock())
977 .expect("serialize session state should not fail");
978 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
979 let form = Form::new().part("upload", part);
980 let resp = builder.headers(headers).multipart(form).send().await?;
981 let status = resp.status();
982 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
983 match Self::decode_json_header::<SessionState>(
984 HEADER_QUERY_CONTEXT,
985 value.to_str().unwrap(),
986 ) {
987 Ok(session) => *self.session_state.lock() = session,
988 Err(e) => {
989 error!("Error decoding session state when streaming load: {e}");
990 }
991 }
992 };
993 if status != 200 {
994 return Err(Error::response_error(status, &resp.bytes().await?)
995 .with_context(RequestKind::StreamingLoad)
996 .with_query_id(query_id));
997 }
998 let resp = resp.json::<LoadResponse>().await?;
999 Ok(resp)
1000 }
1001
1002 async fn login(&mut self) -> Result<()> {
1003 let endpoint = self.endpoint.join("/v1/session/login")?;
1004 let headers = self.make_headers(None)?;
1005 let body = LoginRequest::from(&*self.session_state.lock());
1006 let mut builder = self.cli.post(endpoint.clone()).json(&body);
1007 if self.disable_session_token {
1008 builder = builder.query(&[("disable_session_token", true)]);
1009 }
1010 let builder = self.auth.wrap(builder)?;
1011 let request = builder
1012 .headers(headers.clone())
1013 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1014 .build()?;
1015 let response = self
1016 .query_request_helper(request, true, false, true, RequestKind::Login)
1017 .await?;
1018 if response.status == StatusCode::NOT_FOUND {
1019 info!("login return 404, skip login on the old version server");
1020 return Ok(());
1021 }
1022 if response.status != StatusCode::OK {
1023 return Err(Self::status_body_to_error(response.status, &response.body)
1024 .with_context(RequestKind::Login));
1025 }
1026 if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
1027 if let Ok(s) = v.to_str() {
1028 self.session_id = s.to_string();
1029 }
1030 }
1031
1032 let response = json_from_slice(&response.body)?;
1033 match response {
1034 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
1035 LoginResponseResult::Ok(info) => {
1036 let server_version = info
1037 .version
1038 .parse()
1039 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
1040 self.capability = Capability::from_server_version(&server_version);
1041 self.server_version = Some(server_version.clone());
1042 let session_id = self.session_id.as_str();
1043 if let Some(tokens) = info.tokens {
1044 info!(
1045 "[session {session_id}] login success with session token version = {server_version}",
1046 );
1047 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
1048 } else {
1049 info!("[session {session_id}] login success, version = {server_version}");
1050 }
1051 }
1052 }
1053 Ok(())
1054 }
1055
1056 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1057 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1058 let queries = self.queries_need_heartbeat.lock().clone();
1059 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1060 let now = Instant::now();
1061
1062 let mut query_ids = Vec::new();
1063 for (qid, state) in queries {
1064 if state.need_heartbeat(now) {
1065 query_ids.push(qid.to_string());
1066 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1067 arr.push(qid);
1068 } else {
1069 node_to_queries.insert(state.node_id, vec![qid]);
1070 }
1071 }
1072 }
1073
1074 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1075 {
1076 return Ok(());
1077 }
1078
1079 let body = json!({
1080 "node_to_queries": node_to_queries
1081 });
1082 let headers = self.make_headers(None)?;
1083 let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1084 let request = self.wrap_auth_or_session_token(builder)?.build()?;
1085 let response = self
1086 .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1087 .await?;
1088 if response.status != StatusCode::OK {
1089 return Err(Self::status_body_to_error(response.status, &response.body)
1090 .with_context(RequestKind::Heartbeat));
1091 }
1092 let json: Value = json_from_slice(&response.body)?;
1093 let session_id = self.session_id.as_str();
1094 info!("[session {session_id}] heartbeat request={body}, response={json}");
1095 if let Some(queries_to_remove) = json.get("queries_to_remove") {
1096 if let Some(arr) = queries_to_remove.as_array() {
1097 if !arr.is_empty() {
1098 let mut queries = self.queries_need_heartbeat.lock();
1099 for q in arr {
1100 if let Some(q) = q.as_str() {
1101 queries.remove(q);
1102 }
1103 }
1104 }
1105 }
1106 }
1107 let now = Instant::now();
1108 let mut queries = self.queries_need_heartbeat.lock();
1109 for qid in query_ids {
1110 if let Some(state) = queries.get_mut(&qid) {
1111 *state.last_access_time.lock() = now;
1112 }
1113 }
1114 Ok(())
1115 }
1116
1117 fn build_log_out_request(&self) -> Result<Request> {
1118 let endpoint = self.endpoint.join("/v1/session/logout")?;
1119
1120 let session_state = self.session_state();
1121 let need_sticky = session_state.need_sticky.unwrap_or(false);
1122 let mut headers = self.make_headers(None)?;
1123 if need_sticky {
1124 if let Some(node_id) = self.last_node_id() {
1125 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1126 }
1127 }
1128 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1129
1130 let builder = self.wrap_auth_or_session_token(builder)?;
1131 let req = builder.build()?;
1132 Ok(req)
1133 }
1134
1135 pub(crate) fn need_logout(&self) -> bool {
1136 self.session_token_info.is_some()
1137 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1138 }
1139
1140 async fn refresh_session_token(
1141 &self,
1142 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1143 ) -> Result<()> {
1144 let (session_token_info, _) = { self_login_info.lock().clone() };
1145 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1146 let body = RefreshSessionTokenRequest {
1147 session_token: session_token_info.session_token.clone(),
1148 };
1149 let headers = self.make_headers(None)?;
1150 let request = self
1151 .cli
1152 .post(endpoint.clone())
1153 .json(&body)
1154 .headers(headers.clone())
1155 .bearer_auth(session_token_info.refresh_token.clone())
1156 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1157 .build()?;
1158 let response = self
1159 .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1160 .await?;
1161 if response.status != StatusCode::OK {
1162 return Err(Self::status_body_to_error(response.status, &response.body)
1163 .with_context(RequestKind::SessionRefresh));
1164 }
1165 let response = json_from_slice(&response.body)?;
1166 match response {
1167 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1168 RefreshResponse::Ok(info) => {
1169 *self_login_info.lock() = (info, Instant::now());
1170 Ok(())
1171 }
1172 }
1173 }
1174
1175 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1176 if let Some(info) = &self.session_token_info {
1177 let (start, ttl) = {
1178 let guard = info.lock();
1179 (guard.1, guard.0.session_token_ttl_in_secs)
1180 };
1181 if Instant::now() > start + Duration::from_secs(ttl) {
1182 return Some(info.clone());
1183 }
1184 }
1185 None
1186 }
1187
1188 async fn query_request_helper(
1196 &self,
1197 mut request: Request,
1198 retry_if_503: bool,
1199 refresh_if_401: bool,
1200 reload_auth_if_401: bool,
1201 request_kind: RequestKind,
1202 ) -> Result<HttpResponseData> {
1203 let mut refreshed = false;
1204 let mut retries = 0;
1205 let max_retries = self.retry_count;
1206 let retry_delay = Duration::from_secs(self.retry_delay_secs);
1207
1208 loop {
1209 let req = request.try_clone().expect("request not cloneable");
1210 let response = match self.cli.execute(req).await {
1211 Ok(response) => response,
1212 Err(err) => {
1213 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1214 if retries >= max_retries.saturating_sub(1) {
1215 return Err(Self::with_request_context(
1216 err,
1217 &request_kind,
1218 &request,
1219 Some(retries),
1220 ));
1221 }
1222 retries += 1;
1223 warn!(
1224 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1225 retries,
1226 max_retries,
1227 request.url(),
1228 reason,
1229 err,
1230 retry_delay.as_secs(),
1231 );
1232 sleep(jitter(retry_delay)).await;
1233 continue;
1234 }
1235 return Err(Self::with_request_context(
1236 err,
1237 &request_kind,
1238 &request,
1239 Some(retries),
1240 ));
1241 }
1242 };
1243
1244 let status = response.status();
1245 let headers = response.headers().clone();
1246 let body = match response.bytes().await {
1247 Ok(body) => body,
1248 Err(err) => {
1249 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1250 if retries >= max_retries.saturating_sub(1) {
1251 return Err(Self::with_request_context(
1252 err,
1253 &request_kind,
1254 &request,
1255 Some(retries),
1256 ));
1257 }
1258 retries += 1;
1259 warn!(
1260 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1261 retries,
1262 max_retries,
1263 request.url(),
1264 reason,
1265 err,
1266 retry_delay.as_secs(),
1267 );
1268 sleep(jitter(retry_delay)).await;
1269 continue;
1270 }
1271 return Err(Self::with_request_context(
1272 err,
1273 &request_kind,
1274 &request,
1275 Some(retries),
1276 ));
1277 }
1278 };
1279
1280 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1281 if retries >= max_retries.saturating_sub(1) {
1282 return Ok(HttpResponseData {
1283 status,
1284 headers,
1285 body,
1286 });
1287 }
1288 retries += 1;
1289 warn!(
1290 "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
1291 retries,
1292 max_retries,
1293 request.url(),
1294 retry_delay.as_secs()
1295 );
1296 sleep(jitter(retry_delay)).await;
1297 continue;
1298 }
1299
1300 if status == StatusCode::UNAUTHORIZED {
1301 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1302 let unauthorized_error =
1303 serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
1304 if let Some(session_token_info) = &self.session_token_info {
1305 let should_refresh = unauthorized_error
1306 .as_ref()
1307 .map(|r| need_refresh_token(r.error.code))
1308 .unwrap_or(false);
1309 if refresh_if_401 && should_refresh && !refreshed {
1310 Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
1311 refreshed = true;
1312 retries = 0;
1313 warn!(
1314 "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
1315 request.url(),
1316 retry_delay.as_secs()
1317 );
1318 sleep(jitter(retry_delay)).await;
1319 continue;
1320 }
1321 }
1322 if reload_auth_if_401 && self.auth.can_reload() {
1323 if retries >= max_retries.saturating_sub(1) {
1324 return Ok(HttpResponseData {
1325 status,
1326 headers,
1327 body,
1328 });
1329 }
1330 retries += 1;
1331 let builder =
1332 RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
1333 let builder = self.auth.wrap(builder)?;
1334 request = builder.build()?;
1335 warn!(
1336 "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
1337 retries,
1338 max_retries,
1339 request.url(),
1340 retry_delay.as_secs()
1341 );
1342 sleep(jitter(retry_delay)).await;
1343 continue;
1344 }
1345 }
1346
1347 return Ok(HttpResponseData {
1348 status,
1349 headers,
1350 body,
1351 });
1352 }
1353 }
1354
1355 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1356 if let Err(err) = http_client.execute(request).await {
1357 error!("[session {session_id}] logout request failed: {err}");
1358 } else {
1359 info!("[session {session_id}] logout success");
1360 };
1361 }
1362
1363 pub async fn close(&self) {
1364 let session_id = &self.session_id;
1365 info!("[session {session_id}] try closing now");
1366 if self
1367 .closed
1368 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1369 .is_ok()
1370 {
1371 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1372 if self.need_logout() {
1373 let cli = self.cli.clone();
1374 let req = self
1375 .build_log_out_request()
1376 .expect("failed to build logout request");
1377 Self::logout(cli, req, &self.session_id).await;
1378 }
1379 }
1380 }
1381 pub fn close_with_spawn(&self) {
1382 let session_id = &self.session_id;
1383 info!("[session {session_id}]: try closing with spawn");
1384 if self
1385 .closed
1386 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1387 .is_ok()
1388 {
1389 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1390 if self.need_logout() {
1391 let cli = self.cli.clone();
1392 let req = self
1393 .build_log_out_request()
1394 .expect("failed to build logout request");
1395 let session_id = self.session_id.clone();
1396 GLOBAL_RUNTIME.spawn(async move {
1397 Self::logout(cli, req, session_id.as_str()).await;
1398 });
1399 }
1400 }
1401 }
1402
1403 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1404 let mut queries = self.queries_need_heartbeat.lock();
1405 queries.insert(query_id.to_string(), state);
1406 }
1407}
1408
1409fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1410where
1411 T: Deserialize<'a>,
1412{
1413 serde_json::from_slice::<T>(body).map_err(|e| {
1414 Error::Decode(format!(
1415 "fail to decode JSON response: {e}, body: {}",
1416 String::from_utf8_lossy(body)
1417 ))
1418 })
1419}
1420
1421impl Default for APIClient {
1422 fn default() -> Self {
1423 Self {
1424 session_id: Default::default(),
1425 cli: HttpClient::new(),
1426 scheme: "http".to_string(),
1427 endpoint: Url::parse("http://localhost:8080").unwrap(),
1428 host: "localhost".to_string(),
1429 port: 8000,
1430 tenant: None,
1431 warehouse: Mutex::new(None),
1432 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1433 session_state: Mutex::new(SessionState::default()),
1434 wait_time_secs: None,
1435 max_rows_in_buffer: None,
1436 max_rows_per_page: None,
1437 connect_timeout: Duration::from_secs(10),
1438 page_request_timeout: Duration::from_secs(300),
1439 tls_ca_file: None,
1440 presign: Mutex::new(PresignMode::Auto),
1441 route_hint: RouteHintGenerator::new(),
1442 last_node_id: Default::default(),
1443 disable_session_token: true,
1444 disable_login: false,
1445 query_result_format: "json".to_string(),
1446 session_token_info: None,
1447 closed: AtomicBool::new(false),
1448 last_query_id: Default::default(),
1449 server_version: None,
1450 capability: Default::default(),
1451 queries_need_heartbeat: Default::default(),
1452 retry_count: 3,
1453 retry_delay_secs: 10,
1454 }
1455 }
1456}
1457
1458struct RouteHintGenerator {
1459 nonce: AtomicU64,
1460 current: std::sync::Mutex<String>,
1461}
1462
1463impl RouteHintGenerator {
1464 fn new() -> Self {
1465 let gen = Self {
1466 nonce: AtomicU64::new(0),
1467 current: std::sync::Mutex::new("".to_string()),
1468 };
1469 gen.next();
1470 gen
1471 }
1472
1473 fn current(&self) -> String {
1474 let guard = self.current.lock().unwrap();
1475 guard.clone()
1476 }
1477
1478 fn set(&self, hint: &str) {
1479 let mut guard = self.current.lock().unwrap();
1480 *guard = hint.to_string();
1481 }
1482
1483 fn next(&self) -> String {
1484 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1485 let uuid = uuid::Uuid::new_v4();
1486 let current = format!("rh:{uuid}:{nonce:06}");
1487 let mut guard = self.current.lock().unwrap();
1488 guard.clone_from(¤t);
1489 current
1490 }
1491}
1492
#[cfg(test)]
mod test {
    use super::*;

    /// Every DSN query option lands in the matching client field.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;
        assert_eq!(client.host, "app.databend.com");
        // With `sslmode=disable` the endpoint is expected to be plain HTTP on port 80.
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5_000_000));
        assert_eq!(client.max_rows_per_page, Some(10_000));
        assert_eq!(client.tenant, None);
        let warehouse = client.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some(String::from("wh")));
        Ok(())
    }

    /// A percent-encoded password does not disturb host parsing; with no explicit port
    /// or sslmode the port is expected to be 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// Raw (unencoded) special characters, including `@`, in the password still yield
    /// the right host and port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}