1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth, KeyPairAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, RequestKind, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use bytes::Bytes;
40use log::{debug, error, info, warn};
41use once_cell::sync::Lazy;
42use parking_lot::Mutex;
43use percent_encoding::percent_decode_str;
44use reqwest::cookie::CookieStore;
45use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
46use reqwest::multipart::{Form, Part};
47use reqwest::{
48 Body, Client as HttpClient, Error as ReqwestError, Request, RequestBuilder, StatusCode,
49};
50use semver::Version;
51use serde::{de, Deserialize};
52use serde_json::{json, Value};
53use std::collections::{BTreeMap, HashMap};
54use std::error::Error as StdError;
55use std::io::ErrorKind;
56use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
57use std::sync::Arc;
58use std::time::{Duration, Instant};
59use tokio::time::sleep;
60use tokio_retry::strategy::jitter;
61use tokio_stream::StreamExt;
62use tokio_util::io::ReaderStream;
63use url::Url;
64
// HTTP header names exchanged with the server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";

/// `Content-Type` that marks a response body as an Arrow IPC stream.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
/// Value sent in the `Accept` header when Arrow results are requested.
/// NOTE(review): despite its name this is exactly `CONTENT_TYPE_ARROW`; no JSON
/// alternative is listed. Confirm against the server before changing it.
const CONTENT_TYPE_ARROW_OR_JSON: &str = CONTENT_TYPE_ARROW;

/// Session transaction state (compared case-insensitively) meaning a transaction is open.
const TXN_STATE_ACTIVE: &str = "Active";
const DEFAULT_USERNAME: &str = "root";
78
79static VERSION: Lazy<String> = Lazy::new(|| {
80 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
81 version.to_string()
82});
83
84#[derive(Clone)]
85pub(crate) struct QueryState {
86 pub node_id: String,
87 pub last_access_time: Arc<Mutex<Instant>>,
88 pub timeout_secs: u64,
89}
90
91impl QueryState {
92 pub fn need_heartbeat(&self, now: Instant) -> bool {
93 let t = *self.last_access_time.lock();
94 now.duration_since(t).as_secs() > self.timeout_secs / 2
95 }
96}
97
98struct HttpResponseData {
99 status: StatusCode,
100 headers: HeaderMap,
101 body: Bytes,
102}
103
/// Client for the HTTP query service. It holds the connection settings parsed from the
/// DSN, the authentication method, and the mutable per-session state shared by requests.
pub struct APIClient {
    /// Session id taken from the login response header, or a generated
    /// `no_login_<uuid>` placeholder when none was obtained (see `new`).
    pub(crate) session_id: String,
    /// Underlying HTTP client, built by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, chosen by the DSN `sslmode` parameter.
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on `scheme` when omitted.
    port: u16,

    /// Base URL (`scheme://host:port`) that request paths are joined onto.
    endpoint: Url,

    /// Credentials from the DSN. Login always uses them; other requests use them
    /// only when no session token is held.
    auth: Arc<dyn Auth>,

    /// Optional tenant, sent as the tenant header.
    tenant: Option<String>,
    /// Current warehouse, sent as the warehouse header. Also updated from
    /// `warehouse` in server-returned session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with queries and replaced by the state the server returns.
    session_state: Mutex<SessionState>,
    /// Source of the route-hint header. Advanced before each query started outside
    /// an active transaction.
    route_hint: RouteHintGenerator,

    /// `login=disable` in the DSN: skip the login request in `new`.
    disable_login: bool,
    /// Requested result format, `"json"` or `"arrow"`.
    query_result_format: String,
    /// `session_token=disable` in the DSN: passed to the login endpoint as a query flag.
    disable_session_token: bool,
    /// Session token returned by login, paired with the instant it was received.
    /// When present it is used as the bearer credential instead of `auth`.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    /// NOTE(review): presumably set once the client is closed; the close logic is not in this view.
    closed: AtomicBool,

    /// Server version parsed from the login response.
    server_version: Option<Version>,

    /// Optional pagination settings from the DSN. A pagination config is sent only
    /// when at least one of the three is set.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// TCP connect timeout. The login request uses this plus 10 seconds as its total timeout.
    connect_timeout: Duration,
    /// Total timeout applied to query-start and query-page requests.
    page_request_timeout: Duration,

    /// Optional PEM CA file added to the trusted roots for https (only when a TLS feature is enabled).
    tls_ca_file: Option<String>,

    /// Presign mode. `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the most recent query start. Sent as the sticky node when
    /// the session asks for sticky routing.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recently started query.
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version at login.
    capability: Capability,

    /// Open queries, keyed by query id, that still need heartbeats. Entries are removed by `finalize_query`.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,

    /// Retry settings from the DSN.
    /// NOTE(review): presumably consumed by the request retry helper, which is not in this view.
    retry_count: u32,
    retry_delay_secs: u64,
}
149
150impl Drop for APIClient {
151 fn drop(&mut self) {
152 self.close_with_spawn()
153 }
154}
155
156impl APIClient {
/// Walks an error `source()` chain and reports whether any link is a
/// `std::io::Error` whose kind is considered transient (timeouts, resets, aborts,
/// broken pipes, not-connected, would-block, interrupted).
fn has_transient_io_source(mut source: Option<&(dyn StdError + 'static)>) -> bool {
    while let Some(current) = source {
        let transient = current
            .downcast_ref::<std::io::Error>()
            .is_some_and(|io| {
                matches!(
                    io.kind(),
                    ErrorKind::TimedOut
                        | ErrorKind::UnexpectedEof
                        | ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe
                        | ErrorKind::NotConnected
                        | ErrorKind::WouldBlock
                        | ErrorKind::Interrupted
                )
            });
        if transient {
            return true;
        }
        source = current.source();
    }
    false
}
176
177 fn retry_reason_for_reqwest(err: &ReqwestError) -> Option<&'static str> {
178 if err.is_timeout() {
179 Some("request timeout")
180 } else if err.is_connect() {
181 Some("connection error")
182 } else if err.is_request() {
183 Some("request error")
184 } else if err.is_body() || err.is_decode() {
185 if Self::has_transient_io_source(err.source()) {
186 Some("response read transient I/O error")
187 } else {
188 None
189 }
190 } else {
191 None
192 }
193 }
194
195 fn request_query_id(request: &Request) -> Option<String> {
196 request
197 .headers()
198 .get(HEADER_QUERY_ID)
199 .and_then(|v| v.to_str().ok())
200 .map(|v| v.to_string())
201 }
202
203 fn with_request_context(
204 reqwest_error: ReqwestError,
205 request_kind: &RequestKind,
206 request: &Request,
207 retry_times: Option<u32>,
208 ) -> Error {
209 let error: Error = reqwest_error.into();
210 let error = error.with_context(request_kind.clone());
211 let error = if let Some(retry_times) = retry_times {
212 error.with_retry_times(retry_times)
213 } else {
214 error
215 };
216 if let Some(query_id) = Self::request_query_id(request) {
217 error.with_query_id(query_id)
218 } else {
219 error
220 }
221 }
222
223 fn status_body_to_error(status: StatusCode, body: &[u8]) -> Error {
224 match serde_json::from_slice::<ResponseWithErrorCode>(body) {
225 Ok(resp) if status == StatusCode::UNAUTHORIZED => Error::AuthFailure(resp.error),
226 Ok(resp) => Error::Logic(status, resp.error),
227 Err(_) => Error::response_error(status, body),
228 }
229 }
230
231 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
232 let mut client = Self::from_dsn(dsn).await?;
233 client.build_client(name).await?;
234 if !client.disable_login {
235 client.login().await?;
236 }
237 if client.session_id.is_empty() {
238 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
239 }
240 let client = Arc::new(client);
241 client.check_presign().await?;
242 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
243 Ok(client)
244 }
245
246 pub fn capability(&self) -> &Capability {
247 &self.capability
248 }
249
250 fn set_presign_mode(&self, mode: PresignMode) {
251 *self.presign.lock() = mode
252 }
253 fn get_presign_mode(&self) -> PresignMode {
254 *self.presign.lock()
255 }
256
257 async fn from_dsn(dsn: &str) -> Result<Self> {
258 let u = Url::parse(dsn)?;
259 let mut client = Self::default();
260 if let Some(host) = u.host_str() {
261 client.host = host.to_string();
262 }
263
264 let username = u.username().to_string();
265 if !username.is_empty() {
266 let password = u.password().unwrap_or_default();
267 let password = percent_decode_str(password).decode_utf8()?;
268 client.auth = Arc::new(BasicAuth::new(&username, password));
269 }
270
271 let mut session_state = SessionState::default();
272
273 let database = u.path().trim_start_matches('/');
274 if !database.is_empty() {
275 session_state.set_database(database);
276 }
277
278 let mut private_key_file: Option<String> = None;
279 let mut private_key_passphrase_file: Option<String> = None;
280
281 let mut scheme = "https";
282 for (k, v) in u.query_pairs() {
283 match k.as_ref() {
284 "wait_time_secs" => {
285 client.wait_time_secs = Some(v.parse()?);
286 }
287 "max_rows_in_buffer" => {
288 client.max_rows_in_buffer = Some(v.parse()?);
289 }
290 "max_rows_per_page" => {
291 client.max_rows_per_page = Some(v.parse()?);
292 }
293 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
294 "page_request_timeout_secs" => {
295 client.page_request_timeout = {
296 let secs: u64 = v.parse()?;
297 Duration::from_secs(secs)
298 };
299 }
300 "presign" => {
301 let presign_mode = match v.as_ref() {
302 "auto" => PresignMode::Auto,
303 "detect" => PresignMode::Detect,
304 "on" => PresignMode::On,
305 "off" => PresignMode::Off,
306 _ => {
307 return Err(Error::BadArgument(format!(
308 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
309 )))
310 }
311 };
312 client.set_presign_mode(presign_mode);
313 }
314 "tenant" => {
315 client.tenant = Some(v.to_string());
316 }
317 "warehouse" => {
318 client.warehouse = Mutex::new(Some(v.to_string()));
319 }
320 "role" => session_state.set_role(v),
321 "sslmode" => match v.as_ref() {
322 "disable" => scheme = "http",
323 "require" | "enable" => scheme = "https",
324 _ => {
325 return Err(Error::BadArgument(format!(
326 "Invalid value for sslmode: {v}"
327 )))
328 }
329 },
330 "tls_ca_file" => {
331 client.tls_ca_file = Some(v.to_string());
332 }
333 "access_token" => {
334 client.auth = Arc::new(AccessTokenAuth::new(v));
335 }
336 "access_token_file" => {
337 client.auth = Arc::new(AccessTokenFileAuth::new(v));
338 }
339 "private_key_file" => {
340 private_key_file = Some(v.to_string());
341 }
342 "private_key_passphrase_file" => {
343 private_key_passphrase_file = Some(v.to_string());
344 }
345 "login" => {
346 client.disable_login = match v.as_ref() {
347 "disable" => true,
348 "enable" => false,
349 _ => {
350 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
351 }
352 }
353 }
354 "session_token" => {
355 client.disable_session_token = match v.as_ref() {
356 "disable" => true,
357 "enable" => false,
358 _ => {
359 return Err(Error::BadArgument(format!(
360 "Invalid value for session_token: {v}"
361 )))
362 }
363 }
364 }
365 "body_format" | "query_result_format" => {
366 let v = v.to_string().to_lowercase();
367 match v.as_str() {
368 "json" | "arrow" => client.query_result_format = v.to_string(),
369 _ => {
370 return Err(Error::BadArgument(format!(
371 "Invalid value for query_result_format: {v}"
372 )))
373 }
374 }
375 }
376 "retry_count" => {
377 client.retry_count = v.parse()?;
378 }
379 "retry_delay_secs" => {
380 client.retry_delay_secs = v.parse()?;
381 }
382 _ => {
383 session_state.set(k, v);
384 }
385 }
386 }
387 if let Some(key_file) = private_key_file {
389 if username.is_empty() {
390 return Err(Error::BadArgument(
391 "username is required for key-pair authentication".to_string(),
392 ));
393 }
394 client.auth = Arc::new(KeyPairAuth::new(
395 &username,
396 &key_file,
397 private_key_passphrase_file.as_deref(),
398 )?);
399 }
400 client.port = match u.port() {
401 Some(p) => p,
402 None => match scheme {
403 "http" => 80,
404 "https" => 443,
405 _ => unreachable!(),
406 },
407 };
408 client.scheme = scheme.to_string();
409 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
410 client.session_state = Mutex::new(session_state);
411
412 Ok(client)
413 }
414
415 pub fn host(&self) -> &str {
416 self.host.as_str()
417 }
418
419 pub fn port(&self) -> u16 {
420 self.port
421 }
422
423 pub fn scheme(&self) -> &str {
424 self.scheme.as_str()
425 }
426
427 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
428 let ua = match name {
429 Some(n) => n,
430 None => format!("lake-client-rust/{}", VERSION.as_str()),
431 };
432 let cookie_provider = GlobalCookieStore::new();
433 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
434 let mut initial_cookies = [&cookie].into_iter();
435 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
436 let mut cli_builder = HttpClient::builder()
437 .user_agent(ua)
438 .connect_timeout(self.connect_timeout)
439 .cookie_provider(Arc::new(cookie_provider))
440 .pool_idle_timeout(Duration::from_secs(1));
441 #[cfg(any(feature = "rustls", feature = "native-tls"))]
442 if self.scheme == "https" {
443 if let Some(ref ca_file) = self.tls_ca_file {
444 let cert_pem = tokio::fs::read(ca_file).await?;
445 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
446 cli_builder = cli_builder.add_root_certificate(cert);
447 }
448 }
449 self.cli = cli_builder.build()?;
450 Ok(())
451 }
452
453 async fn check_presign(self: &Arc<Self>) -> Result<()> {
454 let mode = match self.get_presign_mode() {
455 PresignMode::Auto => {
456 if self.host.ends_with(".lake.tidbcloud.com")
457 || self.host.ends_with(".lake.tidbcloud.cn")
458 || self.host.ends_with(".tidbcloud.com")
459 {
460 PresignMode::On
461 } else {
462 PresignMode::Off
463 }
464 }
465 PresignMode::Detect => match self.get_presigned_upload_url("@~/.lakesql/check").await {
466 Ok(_) => PresignMode::On,
467 Err(e) => {
468 warn!("presign mode off with error detected: {e}");
469 PresignMode::Off
470 }
471 },
472 mode => mode,
473 };
474 self.set_presign_mode(mode);
475 Ok(())
476 }
477
478 pub fn current_warehouse(&self) -> Option<String> {
479 let guard = self.warehouse.lock();
480 guard.clone()
481 }
482
483 pub fn current_catalog(&self) -> Option<String> {
484 let guard = self.session_state.lock();
485 guard.catalog.clone()
486 }
487
488 pub fn current_database(&self) -> Option<String> {
489 let guard = self.session_state.lock();
490 guard.database.clone()
491 }
492
493 pub fn set_warehouse(&self, warehouse: impl Into<String>) {
494 let mut guard = self.warehouse.lock();
495 *guard = Some(warehouse.into());
496 }
497
498 pub fn set_database(&self, database: impl Into<String>) {
499 let mut guard = self.session_state.lock();
500 guard.set_database(database);
501 }
502
503 pub fn set_role(&self, role: impl Into<String>) {
504 let mut guard = self.session_state.lock();
505 guard.set_role(role);
506 }
507
508 pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
509 let mut guard = self.session_state.lock();
510 guard.set(key, value);
511 }
512
513 pub async fn current_role(&self) -> Option<String> {
514 let guard = self.session_state.lock();
515 guard.role.clone()
516 }
517
518 fn in_active_transaction(&self) -> bool {
519 let guard = self.session_state.lock();
520 guard
521 .txn_state
522 .as_ref()
523 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
524 .unwrap_or(false)
525 }
526
527 pub fn username(&self) -> String {
528 self.auth.username()
529 }
530
531 fn gen_query_id(&self) -> String {
532 uuid::Uuid::now_v7().simple().to_string()
533 }
534
535 async fn handle_session(&self, session: &Option<SessionState>) {
536 let session = match session {
537 Some(session) => session,
538 None => return,
539 };
540
541 {
543 let mut session_state = self.session_state.lock();
544 *session_state = session.clone();
545 }
546
547 if let Some(settings) = session.settings.as_ref() {
549 if let Some(v) = settings.get("warehouse") {
550 let mut warehouse = self.warehouse.lock();
551 *warehouse = Some(v.clone());
552 }
553 }
554 }
555
556 pub fn set_last_node_id(&self, node_id: String) {
557 *self.last_node_id.lock() = Some(node_id)
558 }
559
560 pub fn set_last_query_id(&self, query_id: Option<String>) {
561 *self.last_query_id.lock() = query_id
562 }
563
564 pub fn last_query_id(&self) -> Option<String> {
565 self.last_query_id.lock().clone()
566 }
567
568 fn last_node_id(&self) -> Option<String> {
569 self.last_node_id.lock().clone()
570 }
571
572 fn handle_warnings(&self, resp: &QueryResponse) {
573 if let Some(warnings) = &resp.warnings {
574 for w in warnings {
575 warn!(target: "server_warnings", "server warning: {w}");
576 }
577 }
578 }
579
580 pub async fn start_query(
581 self: &Arc<Self>,
582 sql: &str,
583 need_progress: bool,
584 params: Option<serde_json::Value>,
585 ) -> Result<Pages> {
586 info!("start query: {sql}");
587 let (resp, batches) = self.start_query_inner(sql, None, false, params).await?;
588 Pages::new(self.clone(), resp, batches, need_progress)
589 }
590
591 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
592 let mut mgr = self.queries_need_heartbeat.lock();
593 if let Some(state) = mgr.remove(query_id) {
594 let self_cloned = self.clone();
595 let query_id = query_id.to_owned();
596 GLOBAL_RUNTIME.spawn(async move {
597 if let Err(e) = self_cloned
598 .end_query(&query_id, false, Some(state.node_id.as_str()))
599 .await
600 {
601 error!("failed to final query {query_id}: {e}");
602 }
603 });
604 }
605 }
606
607 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
608 if let Some(info) = &self.session_token_info {
609 let info = info.lock();
610 Ok(builder.bearer_auth(info.0.session_token.clone()))
611 } else {
612 self.auth.wrap(builder)
613 }
614 }
615
616 async fn start_query_inner(
617 &self,
618 sql: &str,
619 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
620 force_json_body: bool,
621 params: Option<serde_json::Value>,
622 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
623 if !self.in_active_transaction() {
624 self.route_hint.next();
625 }
626 let endpoint = self.endpoint.join("v1/query")?;
627
628 let session_state = self.session_state();
630 let need_sticky = session_state.need_sticky.unwrap_or(false);
631 let mut req = QueryRequest::new(sql)
632 .with_pagination(self.make_pagination())
633 .with_session(Some(session_state))
634 .with_stage_attachment(stage_attachment_config)
635 .with_params(params);
636
637 let query_id = self.gen_query_id();
639 let mut headers = self.make_headers(Some(&query_id))?;
640 if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
641 debug!("accept arrow data");
642 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
643 req = req.with_arrow();
644 }
645
646 if need_sticky {
647 if let Some(node_id) = self.last_node_id() {
648 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
649 }
650 }
651 let mut builder = self.cli.post(endpoint.clone()).json(&req);
652 builder = self.wrap_auth_or_session_token(builder)?;
653 let request = builder
654 .headers(headers.clone())
655 .timeout(self.page_request_timeout)
656 .build()?;
657 let response = self
658 .query_request_helper(request, true, true, true, RequestKind::QueryStart)
659 .await?;
660 self.handle_page(response, true).await
661 }
662
663 fn is_arrow_data(headers: &HeaderMap) -> bool {
664 if let Some(typ) = headers.get(CONTENT_TYPE) {
665 if let Ok(t) = typ.to_str() {
666 return t == CONTENT_TYPE_ARROW;
667 }
668 }
669 false
670 }
671
672 async fn handle_page(
673 &self,
674 response: HttpResponseData,
675 is_first: bool,
676 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
677 let status = response.status;
678 if status != StatusCode::OK {
679 let error = Self::status_body_to_error(status, &response.body);
680 if !is_first {
681 if let Error::Logic(code, ec) = &error {
682 if *code == StatusCode::NOT_FOUND {
683 return Err(Error::QueryNotFound(ec.message.clone()));
684 }
685 }
686 }
687 return Err(error);
688 }
689 let is_arrow_data = Self::is_arrow_data(&response.headers);
690 if is_first {
691 if let Some(route_hint) = response.headers.get(HEADER_ROUTE_HINT) {
692 self.route_hint.set(route_hint.to_str().unwrap_or_default());
693 }
694 }
695 let mut body = response.body;
696 let mut batches = vec![];
697 if is_arrow_data {
698 if is_first {
699 debug!("received arrow data");
700 }
701 let cursor = std::io::Cursor::new(body.as_ref());
702 let reader = StreamReader::try_new(cursor, None)
703 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
704 let schema = reader.schema();
705 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
706 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
707 } else {
708 return Err(Error::Decode(
709 "missing response_header metadata in arrow payload".to_string(),
710 ));
711 };
712 for batch in reader {
713 let batch = batch
714 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
715 batches.push(batch);
716 }
717 body = json_body
718 };
719 let resp: QueryResponse = json_from_slice(&body)?;
720 self.handle_session(&resp.session).await;
721 if let Some(err) = &resp.error {
722 return Err(Error::QueryFailed(err.clone()));
723 }
724 if is_first {
725 self.handle_warnings(&resp);
726 self.set_last_query_id(Some(resp.id.clone()));
727 if let Some(node_id) = &resp.node_id {
728 self.set_last_node_id(node_id.clone());
729 }
730 }
731 Ok((resp, batches))
732 }
733
734 pub async fn query_page(
735 &self,
736 query_id: &str,
737 next_uri: &str,
738 node_id: &Option<String>,
739 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
740 info!("query page: {next_uri}");
741 let endpoint = self.endpoint.join(next_uri)?;
742 let mut headers = self.make_headers(Some(query_id))?;
743 if self.capability.arrow_data && self.query_result_format == "arrow" {
744 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
745 }
746 let mut builder = self.cli.get(endpoint.clone());
747 builder = self
748 .wrap_auth_or_session_token(builder)?
749 .headers(headers.clone())
750 .timeout(self.page_request_timeout);
751 if let Some(node_id) = node_id {
752 builder = builder.header(HEADER_STICKY_NODE, node_id)
753 }
754 let request = builder.build()?;
755
756 let response = self
757 .query_request_helper(request, false, true, true, RequestKind::QueryPage)
758 .await?;
759 self.handle_page(response, false).await
760 }
761
762 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
763 self.end_query(query_id, true, None).await
764 }
765
766 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
767 self.end_query(query_id, false, node_id).await
768 }
769
770 pub async fn end_query(
771 &self,
772 query_id: &str,
773 is_kill: bool,
774 node_id: Option<&str>,
775 ) -> Result<()> {
776 let (method, request_kind) = if is_kill {
777 ("kill", RequestKind::QueryKill)
778 } else {
779 ("final", RequestKind::QueryFinal)
780 };
781 let uri = format!("/v1/query/{query_id}/{method}");
782 let endpoint = self.endpoint.join(&uri)?;
783 let headers = self.make_headers(Some(query_id))?;
784
785 info!("{method} query: {uri}");
786
787 let mut builder = self.cli.post(endpoint.clone());
788 if let Some(node_id) = node_id {
789 builder = builder.header(HEADER_STICKY_NODE, node_id)
790 }
791 builder = self.wrap_auth_or_session_token(builder)?;
792 let resp = builder.headers(headers.clone()).send().await?;
793 if resp.status() != 200 {
794 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
795 .with_context(request_kind)
796 .with_query_id(query_id));
797 }
798 Ok(())
799 }
800
801 pub async fn query_all(
802 self: &Arc<Self>,
803 sql: &str,
804 params: Option<serde_json::Value>,
805 ) -> Result<Page> {
806 self.query_all_inner(sql, false, params).await
807 }
808
809 pub async fn query_all_inner(
810 self: &Arc<Self>,
811 sql: &str,
812 force_json_body: bool,
813 params: Option<serde_json::Value>,
814 ) -> Result<Page> {
815 let (resp, batches) = self
816 .start_query_inner(sql, None, force_json_body, params)
817 .await?;
818 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
819 let mut all = Page::default();
820 while let Some(page) = pages.next().await {
821 all.update(page?);
822 }
823 Ok(all)
824 }
825
826 fn session_state(&self) -> SessionState {
827 self.session_state.lock().clone()
828 }
829
830 fn make_pagination(&self) -> Option<PaginationConfig> {
831 if self.wait_time_secs.is_none()
832 && self.max_rows_in_buffer.is_none()
833 && self.max_rows_per_page.is_none()
834 {
835 return None;
836 }
837 let mut pagination = PaginationConfig {
838 wait_time_secs: None,
839 max_rows_in_buffer: None,
840 max_rows_per_page: None,
841 };
842 if let Some(wait_time_secs) = self.wait_time_secs {
843 pagination.wait_time_secs = Some(wait_time_secs);
844 }
845 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
846 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
847 }
848 if let Some(max_rows_per_page) = self.max_rows_per_page {
849 pagination.max_rows_per_page = Some(max_rows_per_page);
850 }
851 Some(pagination)
852 }
853
854 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
855 let mut headers = HeaderMap::new();
856 if let Some(tenant) = &self.tenant {
857 headers.insert(HEADER_TENANT, tenant.parse()?);
858 }
859 let warehouse = self.warehouse.lock().clone();
860 if let Some(warehouse) = warehouse {
861 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
862 }
863 let route_hint = self.route_hint.current();
864 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
865 if let Some(query_id) = query_id {
866 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
867 }
868 Ok(headers)
869 }
870
871 pub async fn insert_with_stage(
872 self: &Arc<Self>,
873 sql: &str,
874 stage: &str,
875 file_format_options: BTreeMap<&str, &str>,
876 copy_options: BTreeMap<&str, &str>,
877 ) -> Result<QueryStats> {
878 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
879 let stage_attachment = Some(StageAttachmentConfig {
880 location: stage,
881 file_format_options: Some(file_format_options),
882 copy_options: Some(copy_options),
883 });
884 let (resp, batches) = self
885 .start_query_inner(sql, stage_attachment, true, None)
886 .await?;
887 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
888 let mut all = Page::default();
889 while let Some(page) = pages.next().await {
890 all.update(page?);
891 }
892 Ok(all.stats)
893 }
894
895 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
896 info!("get presigned upload url: {stage}");
897 let sql = format!("PRESIGN UPLOAD {stage}");
898 let resp = self.query_all_inner(&sql, true, None).await?;
899 if resp.data.len() != 1 {
900 return Err(Error::Decode(
901 "Empty response from server for presigned request".to_string(),
902 ));
903 }
904 if resp.data[0].len() != 3 {
905 return Err(Error::Decode(
906 "Invalid response from server for presigned request".to_string(),
907 ));
908 }
909 let method = resp.data[0][0].clone().unwrap_or_default();
911 if method != "PUT" {
912 return Err(Error::Decode(format!(
913 "Invalid method for presigned upload request: {method}"
914 )));
915 }
916 let headers: BTreeMap<String, String> =
917 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
918 let url = resp.data[0][2].clone().unwrap_or_default();
919 Ok(PresignedResponse {
920 method,
921 headers,
922 url,
923 })
924 }
925
926 pub async fn upload_to_stage(
927 self: &Arc<Self>,
928 stage: &str,
929 data: Reader,
930 size: u64,
931 ) -> Result<()> {
932 match self.get_presign_mode() {
933 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
934 PresignMode::On => {
935 let presigned = self.get_presigned_upload_url(stage).await?;
936 presign_upload_to_stage(presigned, data, size).await
937 }
938 PresignMode::Auto => {
939 unreachable!("PresignMode::Auto should be handled during client initialization")
940 }
941 PresignMode::Detect => {
942 unreachable!("PresignMode::Detect should be handled during client initialization")
943 }
944 }
945 }
946
947 async fn upload_to_stage_with_stream(
949 &self,
950 stage: &str,
951 data: Reader,
952 size: u64,
953 ) -> Result<()> {
954 info!("upload to stage with stream: {stage}, size: {size}");
955 if let Some(info) = self.need_pre_refresh_session().await {
956 self.refresh_session_token(info).await?;
957 }
958 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
959 let location = StageLocation::try_from(stage)?;
960 let query_id = self.gen_query_id();
961 let mut headers = self.make_headers(Some(&query_id))?;
962 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
963 let stream = Body::wrap_stream(ReaderStream::new(data));
964 let part = Part::stream_with_length(stream, size).file_name(location.path);
965 let form = Form::new().part("upload", part);
966 let mut builder = self.cli.put(endpoint.clone());
967 builder = self.wrap_auth_or_session_token(builder)?;
968 let resp = builder.headers(headers).multipart(form).send().await?;
969 let status = resp.status();
970 if status != 200 {
971 return Err(Error::response_error(status, &resp.bytes().await?)
972 .with_context(RequestKind::UploadToStage)
973 .with_query_id(query_id));
974 }
975 Ok(())
976 }
977
978 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
981 where
982 T: de::DeserializeOwned,
983 {
984 if value.starts_with("{") {
985 serde_json::from_slice(value.as_bytes())
986 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
987 } else {
988 let json = URL_SAFE.decode(value).map_err(|e| {
989 format!(
990 "Invalid value {} for {key}, base64 decode error: {}",
991 value, e
992 )
993 })?;
994 serde_json::from_slice(&json).map_err(|e| {
995 format!(
996 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
997 String::from_utf8_lossy(&json)
998 )
999 })
1000 }
1001 }
1002
1003 pub async fn streaming_load(
1004 &self,
1005 sql: &str,
1006 data: Reader,
1007 file_name: &str,
1008 ) -> Result<LoadResponse> {
1009 let body = Body::wrap_stream(ReaderStream::new(data));
1010 let part = Part::stream(body).file_name(file_name.to_string());
1011 let endpoint = self.endpoint.join("v1/streaming_load")?;
1012 let mut builder = self.cli.put(endpoint.clone());
1013 builder = self.wrap_auth_or_session_token(builder)?;
1014 let query_id = self.gen_query_id();
1015 let mut headers = self.make_headers(Some(&query_id))?;
1016 headers.insert(HEADER_SQL, sql.parse()?);
1017 let session = serde_json::to_string(&*self.session_state.lock())
1018 .expect("serialize session state should not fail");
1019 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
1020 let form = Form::new().part("upload", part);
1021 let resp = builder.headers(headers).multipart(form).send().await?;
1022 let status = resp.status();
1023 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
1024 match Self::decode_json_header::<SessionState>(
1025 HEADER_QUERY_CONTEXT,
1026 value.to_str().unwrap(),
1027 ) {
1028 Ok(session) => *self.session_state.lock() = session,
1029 Err(e) => {
1030 error!("Error decoding session state when streaming load: {e}");
1031 }
1032 }
1033 };
1034 if status != 200 {
1035 return Err(Error::response_error(status, &resp.bytes().await?)
1036 .with_context(RequestKind::StreamingLoad)
1037 .with_query_id(query_id));
1038 }
1039 let resp = resp.json::<LoadResponse>().await?;
1040 Ok(resp)
1041 }
1042
1043 async fn login(&mut self) -> Result<()> {
1044 let endpoint = self.endpoint.join("/v1/session/login")?;
1045 let headers = self.make_headers(None)?;
1046 let body = LoginRequest::from(&*self.session_state.lock());
1047 let mut builder = self.cli.post(endpoint.clone()).json(&body);
1048 if self.disable_session_token {
1049 builder = builder.query(&[("disable_session_token", true)]);
1050 }
1051 let builder = self.auth.wrap(builder)?;
1052 let request = builder
1053 .headers(headers.clone())
1054 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1055 .build()?;
1056 let response = self
1057 .query_request_helper(request, true, false, true, RequestKind::Login)
1058 .await?;
1059 if response.status == StatusCode::NOT_FOUND {
1060 info!("login return 404, skip login on the old version server");
1061 return Ok(());
1062 }
1063 if response.status != StatusCode::OK {
1064 return Err(Self::status_body_to_error(response.status, &response.body)
1065 .with_context(RequestKind::Login));
1066 }
1067 if let Some(v) = response.headers.get(HEADER_SESSION_ID) {
1068 if let Ok(s) = v.to_str() {
1069 self.session_id = s.to_string();
1070 }
1071 }
1072
1073 let response = json_from_slice(&response.body)?;
1074 match response {
1075 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
1076 LoginResponseResult::Ok(info) => {
1077 let server_version = info
1078 .version
1079 .parse()
1080 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
1081 self.capability = Capability::from_server_version(&server_version);
1082 self.server_version = Some(server_version.clone());
1083 let session_id = self.session_id.as_str();
1084 if let Some(tokens) = info.tokens {
1085 info!(
1086 "[session {session_id}] login success with session token version = {server_version}",
1087 );
1088 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
1089 } else {
1090 info!("[session {session_id}] login success, version = {server_version}");
1091 }
1092 }
1093 }
1094 Ok(())
1095 }
1096
1097 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
1098 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
1099 let queries = self.queries_need_heartbeat.lock().clone();
1100 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
1101 let now = Instant::now();
1102
1103 let mut query_ids = Vec::new();
1104 for (qid, state) in queries {
1105 if state.need_heartbeat(now) {
1106 query_ids.push(qid.to_string());
1107 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
1108 arr.push(qid);
1109 } else {
1110 node_to_queries.insert(state.node_id, vec![qid]);
1111 }
1112 }
1113 }
1114
1115 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
1116 {
1117 return Ok(());
1118 }
1119
1120 let body = json!({
1121 "node_to_queries": node_to_queries
1122 });
1123 let headers = self.make_headers(None)?;
1124 let builder = self.cli.post(endpoint.clone()).json(&body).headers(headers);
1125 let request = self.wrap_auth_or_session_token(builder)?.build()?;
1126 let response = self
1127 .query_request_helper(request, true, false, true, RequestKind::Heartbeat)
1128 .await?;
1129 if response.status != StatusCode::OK {
1130 return Err(Self::status_body_to_error(response.status, &response.body)
1131 .with_context(RequestKind::Heartbeat));
1132 }
1133 let json: Value = json_from_slice(&response.body)?;
1134 let session_id = self.session_id.as_str();
1135 info!("[session {session_id}] heartbeat request={body}, response={json}");
1136 if let Some(queries_to_remove) = json.get("queries_to_remove") {
1137 if let Some(arr) = queries_to_remove.as_array() {
1138 if !arr.is_empty() {
1139 let mut queries = self.queries_need_heartbeat.lock();
1140 for q in arr {
1141 if let Some(q) = q.as_str() {
1142 queries.remove(q);
1143 }
1144 }
1145 }
1146 }
1147 }
1148 let now = Instant::now();
1149 let mut queries = self.queries_need_heartbeat.lock();
1150 for qid in query_ids {
1151 if let Some(state) = queries.get_mut(&qid) {
1152 *state.last_access_time.lock() = now;
1153 }
1154 }
1155 Ok(())
1156 }
1157
1158 fn build_log_out_request(&self) -> Result<Request> {
1159 let endpoint = self.endpoint.join("/v1/session/logout")?;
1160
1161 let session_state = self.session_state();
1162 let need_sticky = session_state.need_sticky.unwrap_or(false);
1163 let mut headers = self.make_headers(None)?;
1164 if need_sticky {
1165 if let Some(node_id) = self.last_node_id() {
1166 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1167 }
1168 }
1169 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1170
1171 let builder = self.wrap_auth_or_session_token(builder)?;
1172 let req = builder.build()?;
1173 Ok(req)
1174 }
1175
1176 pub(crate) fn need_logout(&self) -> bool {
1177 self.session_token_info.is_some()
1178 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1179 }
1180
1181 async fn refresh_session_token(
1182 &self,
1183 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1184 ) -> Result<()> {
1185 let (session_token_info, _) = { self_login_info.lock().clone() };
1186 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1187 let body = RefreshSessionTokenRequest {
1188 session_token: session_token_info.session_token.clone(),
1189 };
1190 let headers = self.make_headers(None)?;
1191 let request = self
1192 .cli
1193 .post(endpoint.clone())
1194 .json(&body)
1195 .headers(headers.clone())
1196 .bearer_auth(session_token_info.refresh_token.clone())
1197 .timeout(self.connect_timeout.saturating_add(Duration::from_secs(10)))
1198 .build()?;
1199 let response = self
1200 .query_request_helper(request, true, false, false, RequestKind::SessionRefresh)
1201 .await?;
1202 if response.status != StatusCode::OK {
1203 return Err(Self::status_body_to_error(response.status, &response.body)
1204 .with_context(RequestKind::SessionRefresh));
1205 }
1206 let response = json_from_slice(&response.body)?;
1207 match response {
1208 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1209 RefreshResponse::Ok(info) => {
1210 *self_login_info.lock() = (info, Instant::now());
1211 Ok(())
1212 }
1213 }
1214 }
1215
1216 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1217 if let Some(info) = &self.session_token_info {
1218 let (start, ttl) = {
1219 let guard = info.lock();
1220 (guard.1, guard.0.session_token_ttl_in_secs)
1221 };
1222 if Instant::now() > start + Duration::from_secs(ttl) {
1223 return Some(info.clone());
1224 }
1225 }
1226 None
1227 }
1228
1229 async fn query_request_helper(
1237 &self,
1238 mut request: Request,
1239 retry_if_503: bool,
1240 refresh_if_401: bool,
1241 reload_auth_if_401: bool,
1242 request_kind: RequestKind,
1243 ) -> Result<HttpResponseData> {
1244 let mut refreshed = false;
1245 let mut retries = 0;
1246 let max_retries = self.retry_count;
1247 let retry_delay = Duration::from_secs(self.retry_delay_secs);
1248
1249 loop {
1250 let req = request.try_clone().expect("request not cloneable");
1251 let response = match self.cli.execute(req).await {
1252 Ok(response) => response,
1253 Err(err) => {
1254 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1255 if retries >= max_retries.saturating_sub(1) {
1256 return Err(Self::with_request_context(
1257 err,
1258 &request_kind,
1259 &request,
1260 Some(retries),
1261 ));
1262 }
1263 retries += 1;
1264 warn!(
1265 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1266 retries,
1267 max_retries,
1268 request.url(),
1269 reason,
1270 err,
1271 retry_delay.as_secs(),
1272 );
1273 sleep(jitter(retry_delay)).await;
1274 continue;
1275 }
1276 return Err(Self::with_request_context(
1277 err,
1278 &request_kind,
1279 &request,
1280 Some(retries),
1281 ));
1282 }
1283 };
1284
1285 let status = response.status();
1286 let headers = response.headers().clone();
1287 let body = match response.bytes().await {
1288 Ok(body) => body,
1289 Err(err) => {
1290 if let Some(reason) = Self::retry_reason_for_reqwest(&err) {
1291 if retries >= max_retries.saturating_sub(1) {
1292 return Err(Self::with_request_context(
1293 err,
1294 &request_kind,
1295 &request,
1296 Some(retries),
1297 ));
1298 }
1299 retries += 1;
1300 warn!(
1301 "retry {}/{} for {} due to: {} (error: {}), retrying after {} seconds",
1302 retries,
1303 max_retries,
1304 request.url(),
1305 reason,
1306 err,
1307 retry_delay.as_secs(),
1308 );
1309 sleep(jitter(retry_delay)).await;
1310 continue;
1311 }
1312 return Err(Self::with_request_context(
1313 err,
1314 &request_kind,
1315 &request,
1316 Some(retries),
1317 ));
1318 }
1319 };
1320
1321 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1322 if retries >= max_retries.saturating_sub(1) {
1323 return Ok(HttpResponseData {
1324 status,
1325 headers,
1326 body,
1327 });
1328 }
1329 retries += 1;
1330 warn!(
1331 "retry {}/{} for {} due to: service unavailable (503), server may be starting, retrying after {} seconds",
1332 retries,
1333 max_retries,
1334 request.url(),
1335 retry_delay.as_secs()
1336 );
1337 sleep(jitter(retry_delay)).await;
1338 continue;
1339 }
1340
1341 if status == StatusCode::UNAUTHORIZED {
1342 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1343 let unauthorized_error =
1344 serde_json::from_slice::<ResponseWithErrorCode>(&body).ok();
1345 if let Some(session_token_info) = &self.session_token_info {
1346 let should_refresh = unauthorized_error
1347 .as_ref()
1348 .map(|r| need_refresh_token(r.error.code))
1349 .unwrap_or(false);
1350 if refresh_if_401 && should_refresh && !refreshed {
1351 Box::pin(self.refresh_session_token(session_token_info.clone())).await?;
1352 refreshed = true;
1353 retries = 0;
1354 warn!(
1355 "retry for {} due to: session token expired and refreshed, retrying after {} seconds",
1356 request.url(),
1357 retry_delay.as_secs()
1358 );
1359 sleep(jitter(retry_delay)).await;
1360 continue;
1361 }
1362 }
1363 if reload_auth_if_401 && self.auth.can_reload() {
1364 if retries >= max_retries.saturating_sub(1) {
1365 return Ok(HttpResponseData {
1366 status,
1367 headers,
1368 body,
1369 });
1370 }
1371 retries += 1;
1372 let builder =
1373 RequestBuilder::from_parts(HttpClient::new(), request.try_clone().unwrap());
1374 let builder = self.auth.wrap(builder)?;
1375 request = builder.build()?;
1376 warn!(
1377 "retry {}/{} for {} due to: authentication token reloaded, retrying after {} seconds",
1378 retries,
1379 max_retries,
1380 request.url(),
1381 retry_delay.as_secs()
1382 );
1383 sleep(jitter(retry_delay)).await;
1384 continue;
1385 }
1386 }
1387
1388 return Ok(HttpResponseData {
1389 status,
1390 headers,
1391 body,
1392 });
1393 }
1394 }
1395
1396 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1397 if let Err(err) = http_client.execute(request).await {
1398 error!("[session {session_id}] logout request failed: {err}");
1399 } else {
1400 info!("[session {session_id}] logout success");
1401 };
1402 }
1403
1404 pub async fn close(&self) {
1405 let session_id = &self.session_id;
1406 info!("[session {session_id}] try closing now");
1407 if self
1408 .closed
1409 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1410 .is_ok()
1411 {
1412 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1413 if self.need_logout() {
1414 let cli = self.cli.clone();
1415 let req = self
1416 .build_log_out_request()
1417 .expect("failed to build logout request");
1418 Self::logout(cli, req, &self.session_id).await;
1419 }
1420 }
1421 }
1422 pub fn close_with_spawn(&self) {
1423 let session_id = &self.session_id;
1424 info!("[session {session_id}]: try closing with spawn");
1425 if self
1426 .closed
1427 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1428 .is_ok()
1429 {
1430 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1431 if self.need_logout() {
1432 let cli = self.cli.clone();
1433 let req = self
1434 .build_log_out_request()
1435 .expect("failed to build logout request");
1436 let session_id = self.session_id.clone();
1437 GLOBAL_RUNTIME.spawn(async move {
1438 Self::logout(cli, req, session_id.as_str()).await;
1439 });
1440 }
1441 }
1442 }
1443
1444 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1445 let mut queries = self.queries_need_heartbeat.lock();
1446 queries.insert(query_id.to_string(), state);
1447 }
1448}
1449
1450fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1451where
1452 T: Deserialize<'a>,
1453{
1454 serde_json::from_slice::<T>(body).map_err(|e| {
1455 Error::Decode(format!(
1456 "fail to decode JSON response: {e}, body: {}",
1457 String::from_utf8_lossy(body)
1458 ))
1459 })
1460}
1461
1462impl Default for APIClient {
1463 fn default() -> Self {
1464 Self {
1465 session_id: Default::default(),
1466 cli: HttpClient::new(),
1467 scheme: "http".to_string(),
1468 endpoint: Url::parse("http://localhost:8080").unwrap(),
1469 host: "localhost".to_string(),
1470 port: 8000,
1471 tenant: None,
1472 warehouse: Mutex::new(None),
1473 auth: Arc::new(BasicAuth::new(DEFAULT_USERNAME, "")) as Arc<dyn Auth>,
1474 session_state: Mutex::new(SessionState::default()),
1475 wait_time_secs: None,
1476 max_rows_in_buffer: None,
1477 max_rows_per_page: None,
1478 connect_timeout: Duration::from_secs(10),
1479 page_request_timeout: Duration::from_secs(300),
1480 tls_ca_file: None,
1481 presign: Mutex::new(PresignMode::Auto),
1482 route_hint: RouteHintGenerator::new(),
1483 last_node_id: Default::default(),
1484 disable_session_token: true,
1485 disable_login: false,
1486 query_result_format: "json".to_string(),
1487 session_token_info: None,
1488 closed: AtomicBool::new(false),
1489 last_query_id: Default::default(),
1490 server_version: None,
1491 capability: Default::default(),
1492 queries_need_heartbeat: Default::default(),
1493 retry_count: 3,
1494 retry_delay_secs: 10,
1495 }
1496 }
1497}
1498
1499struct RouteHintGenerator {
1500 nonce: AtomicU64,
1501 current: std::sync::Mutex<String>,
1502}
1503
1504impl RouteHintGenerator {
1505 fn new() -> Self {
1506 let gen = Self {
1507 nonce: AtomicU64::new(0),
1508 current: std::sync::Mutex::new("".to_string()),
1509 };
1510 gen.next();
1511 gen
1512 }
1513
1514 fn current(&self) -> String {
1515 let guard = self.current.lock().unwrap();
1516 guard.clone()
1517 }
1518
1519 fn set(&self, hint: &str) {
1520 let mut guard = self.current.lock().unwrap();
1521 *guard = hint.to_string();
1522 }
1523
1524 fn next(&self) -> String {
1525 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1526 let uuid = uuid::Uuid::new_v4();
1527 let current = format!("rh:{uuid}:{nonce:06}");
1528 let mut guard = self.current.lock().unwrap();
1529 guard.clone_from(¤t);
1530 current
1531 }
1532}
1533
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Generates a 2048-bit RSA private key with the `openssl` CLI and returns its stdout.
    ///
    /// Returns `None` when `openssl` cannot be run or exits unsuccessfully. The key-pair
    /// tests treat that as "skip".
    fn generate_rsa_private_key() -> Option<Vec<u8>> {
        let output = std::process::Command::new("openssl")
            .args([
                "genpkey",
                "-algorithm",
                "RSA",
                "-pkeyopt",
                "rsa_keygen_bits:2048",
            ])
            .output()
            .ok()?;
        if output.status.success() {
            Some(output.stdout)
        } else {
            None
        }
    }

    /// Writes `key` to a new temp file, which is removed when the handle is dropped.
    fn write_key_file(key: &[u8]) -> Result<NamedTempFile> {
        let mut key_file = NamedTempFile::new()?;
        key_file.write_all(key)?;
        Ok(key_file)
    }

    /// Form-urlencodes the temp file's path so it can be used as a DSN query value.
    fn dsn_encoded_path(file: &NamedTempFile) -> String {
        url::form_urlencoded::byte_serialize(file.path().to_string_lossy().as_bytes())
            .collect::<String>()
    }

    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "lake://username:password@app.lake.tidbcloud.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host, "app.lake.tidbcloud.com");
        assert_eq!(
            client.endpoint,
            Url::parse("http://app.lake.tidbcloud.com:80")?
        );
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);
        assert_eq!(
            *client.warehouse.try_lock().unwrap(),
            Some("wh".to_string())
        );
        Ok(())
    }

    #[tokio::test]
    async fn parse_dsn_with_private_key_uses_keypair_auth() -> Result<()> {
        let key = match generate_rsa_private_key() {
            Some(key) => key,
            None => return Ok(()),
        };
        let key_file = write_key_file(&key)?;

        let dsn = format!(
            "lake://username:password@app.tidbcloud.com/test?private_key_file={}&sslmode=disable",
            dsn_encoded_path(&key_file)
        );
        let client = APIClient::from_dsn(&dsn).await?;
        assert_eq!(client.auth.username(), "username");
        assert!(client.auth.can_reload());

        let request = client
            .auth
            .wrap(client.cli.get(client.endpoint.join("v1/query")?))?
            .build()?;
        assert_eq!(
            request
                .headers()
                .get("X-DATABEND-AUTH-METHOD")
                .and_then(|value| value.to_str().ok()),
            Some("keypair")
        );
        let authorization = request
            .headers()
            .get(reqwest::header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .unwrap_or_default();
        assert!(authorization.starts_with("Bearer "));
        // A JWT has exactly three dot-separated segments.
        assert_eq!(authorization["Bearer ".len()..].split('.').count(), 3);
        Ok(())
    }

    #[tokio::test]
    async fn parse_dsn_with_private_key_requires_user() -> Result<()> {
        let key = match generate_rsa_private_key() {
            Some(key) => key,
            None => return Ok(()),
        };
        let key_file = write_key_file(&key)?;

        let dsn = format!(
            "lake://app.tidbcloud.com/test?private_key_file={}&sslmode=disable",
            dsn_encoded_path(&key_file)
        );
        let err = APIClient::from_dsn(&dsn)
            .await
            .err()
            .expect("key-pair authentication should require an explicit username");
        assert!(
            err.to_string()
                .contains("username is required for key-pair authentication"),
            "unexpected error: {err}"
        );

        Ok(())
    }

    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let dsn = "lake://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let dsn = "lake://username:3a@SC(nYE1k={{R@localhost:8000";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}