1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::{de, Deserialize};
49use serde_json::{json, Value};
50use std::collections::{BTreeMap, HashMap};
51use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tokio::time::sleep;
55use tokio_retry::strategy::jitter;
56use tokio_stream::StreamExt;
57use tokio_util::io::ReaderStream;
58use url::Url;
59
/// Header carrying the client-generated query id (inserted by `make_headers`).
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Header carrying the `tenant` DSN option, when one was given.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Header carrying a node id: the last seen node id when the session needs
/// stickiness, or a query's node id for page/end requests.
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Header carrying the current warehouse, when one is set.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
/// Header naming the target stage for `v1/upload_to_stage`.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Route hint header: sent on every request and read back from the first page
/// of a query response.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Session `txn_state` value (compared case-insensitively) meaning a transaction
/// is open.
const TXN_STATE_ACTIVE: &str = "Active";
/// Header carrying the SQL statement for `v1/streaming_load`.
const HEADER_SQL: &str = "X-DATABEND-SQL";
/// Header carrying the session state for streaming load: sent as JSON, and read
/// back from the response as JSON or URL-safe base64 JSON.
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
/// Response header from which the session id is read after login.
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
/// `Content-Type` value that marks a response body as an Arrow IPC stream.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
/// `Accept` value sent when Arrow results are requested.
/// NOTE(review): despite its name this is identical to `CONTENT_TYPE_ARROW` and
/// lists no JSON fallback (e.g. `, application/json`); confirm against the
/// server's content negotiation before changing it.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
72
73static VERSION: Lazy<String> = Lazy::new(|| {
74 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
75 version.to_string()
76});
77
/// Bookkeeping for a query that is registered for heartbeats.
///
/// Clones share `last_access_time` through the `Arc`, so updating it through a
/// clone is visible to every holder.
#[derive(Clone)]
pub(crate) struct QueryState {
    /// Node the query belongs to; heartbeats are grouped per node.
    pub node_id: String,
    /// Last time the query was touched; `try_heartbeat` refreshes it after a
    /// heartbeat request.
    pub last_access_time: Arc<Mutex<Instant>>,
    /// Timeout in seconds; a heartbeat is due once more than half has elapsed.
    pub timeout_secs: u64,
}
84
85impl QueryState {
86 pub fn need_heartbeat(&self, now: Instant) -> bool {
87 let t = *self.last_access_time.lock();
88 now.duration_since(t).as_secs() > self.timeout_secs / 2
89 }
90}
91
/// HTTP client for the Databend query API, configured from a DSN.
pub struct APIClient {
    /// Session id from the login response header, or `no_login_<uuid>` when none.
    pub(crate) session_id: String,
    /// Underlying HTTP client (built in `build_client`).
    cli: HttpClient,
    /// `http` or `https`, chosen by the `sslmode` DSN option.
    scheme: String,
    /// Host from the DSN.
    host: String,
    /// Port from the DSN, defaulting to 80/443 by scheme.
    port: u16,

    /// Base URL `scheme://host:port` that request paths are joined onto.
    endpoint: Url,

    /// Credential provider: basic auth, access token, or access-token file.
    auth: Arc<dyn Auth>,

    /// Tenant sent in `X-DATABEND-TENANT`, if configured.
    tenant: Option<String>,
    /// Current warehouse; may be updated from server-returned session settings.
    warehouse: Mutex<Option<String>>,
    /// Session state sent with requests and replaced by the server's copy.
    session_state: Mutex<SessionState>,
    /// Route hint generator; advanced per query outside transactions.
    route_hint: RouteHintGenerator,

    /// Skip the login request (`login=disable`).
    disable_login: bool,
    /// Requested result encoding: `"json"` or `"arrow"`.
    query_result_format: String,
    /// Ask the server not to issue session tokens at login.
    disable_session_token: bool,
    /// Session/refresh tokens from login plus the instant they were obtained.
    session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,

    // NOTE(review): not read in the visible code; presumably set when the client
    // is closed — confirm.
    closed: AtomicBool,

    /// Server version reported at login.
    server_version: Option<Version>,

    /// Pagination options forwarded in query requests (see `make_pagination`).
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Timeout applied to login and session-refresh requests.
    connect_timeout: Duration,
    /// Timeout applied to each result-page request.
    page_request_timeout: Duration,

    /// Extra CA certificate (PEM file path) trusted for `https`.
    tls_ca_file: Option<String>,

    /// Presign mode; `Auto`/`Detect` are resolved to `On`/`Off` in `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id from the most recent first page, used for sticky routing.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recently started query.
    last_query_id: Mutex<Option<String>>,

    /// Feature flags derived from the server version at login.
    capability: Capability,

    /// Queries that need heartbeats, keyed by query id.
    queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,
}
134
/// Closes the client on drop by calling `close_with_spawn` (defined outside this
/// view). NOTE(review): presumably this hands the close work to a background
/// task because `drop` cannot `.await` — confirm.
impl Drop for APIClient {
    fn drop(&mut self) {
        self.close_with_spawn()
    }
}
140
141impl APIClient {
142 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
143 let mut client = Self::from_dsn(dsn).await?;
144 client.build_client(name).await?;
145 if !client.disable_login {
146 client.login().await?;
147 }
148 if client.session_id.is_empty() {
149 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
150 }
151 let client = Arc::new(client);
152 client.check_presign().await?;
153 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
154 Ok(client)
155 }
156
    /// Server feature flags, derived from the server version reported at login.
    pub fn capability(&self) -> &Capability {
        &self.capability
    }
160
161 fn set_presign_mode(&self, mode: PresignMode) {
162 *self.presign.lock() = mode
163 }
164 fn get_presign_mode(&self) -> PresignMode {
165 *self.presign.lock()
166 }
167
168 async fn from_dsn(dsn: &str) -> Result<Self> {
169 let u = Url::parse(dsn)?;
170 let mut client = Self::default();
171 if let Some(host) = u.host_str() {
172 client.host = host.to_string();
173 }
174
175 if u.username() != "" {
176 let password = u.password().unwrap_or_default();
177 let password = percent_decode_str(password).decode_utf8()?;
178 client.auth = Arc::new(BasicAuth::new(u.username(), password));
179 }
180
181 let mut session_state = SessionState::default();
182
183 let database = u.path().trim_start_matches('/');
184 if !database.is_empty() {
185 session_state.set_database(database);
186 }
187
188 let mut scheme = "https";
189 for (k, v) in u.query_pairs() {
190 match k.as_ref() {
191 "wait_time_secs" => {
192 client.wait_time_secs = Some(v.parse()?);
193 }
194 "max_rows_in_buffer" => {
195 client.max_rows_in_buffer = Some(v.parse()?);
196 }
197 "max_rows_per_page" => {
198 client.max_rows_per_page = Some(v.parse()?);
199 }
200 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
201 "page_request_timeout_secs" => {
202 client.page_request_timeout = {
203 let secs: u64 = v.parse()?;
204 Duration::from_secs(secs)
205 };
206 }
207 "presign" => {
208 let presign_mode = match v.as_ref() {
209 "auto" => PresignMode::Auto,
210 "detect" => PresignMode::Detect,
211 "on" => PresignMode::On,
212 "off" => PresignMode::Off,
213 _ => {
214 return Err(Error::BadArgument(format!(
215 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
216 )))
217 }
218 };
219 client.set_presign_mode(presign_mode);
220 }
221 "tenant" => {
222 client.tenant = Some(v.to_string());
223 }
224 "warehouse" => {
225 client.warehouse = Mutex::new(Some(v.to_string()));
226 }
227 "role" => session_state.set_role(v),
228 "sslmode" => match v.as_ref() {
229 "disable" => scheme = "http",
230 "require" | "enable" => scheme = "https",
231 _ => {
232 return Err(Error::BadArgument(format!(
233 "Invalid value for sslmode: {v}"
234 )))
235 }
236 },
237 "tls_ca_file" => {
238 client.tls_ca_file = Some(v.to_string());
239 }
240 "access_token" => {
241 client.auth = Arc::new(AccessTokenAuth::new(v));
242 }
243 "access_token_file" => {
244 client.auth = Arc::new(AccessTokenFileAuth::new(v));
245 }
246 "login" => {
247 client.disable_login = match v.as_ref() {
248 "disable" => true,
249 "enable" => false,
250 _ => {
251 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
252 }
253 }
254 }
255 "session_token" => {
256 client.disable_session_token = match v.as_ref() {
257 "disable" => true,
258 "enable" => false,
259 _ => {
260 return Err(Error::BadArgument(format!(
261 "Invalid value for session_token: {v}"
262 )))
263 }
264 }
265 }
266 "body_format" | "query_result_format" => {
267 let v = v.to_string().to_lowercase();
268 match v.as_str() {
269 "json" | "arrow" => client.query_result_format = v.to_string(),
270 _ => {
271 return Err(Error::BadArgument(format!(
272 "Invalid value for query_result_format: {v}"
273 )))
274 }
275 }
276 }
277 _ => {
278 session_state.set(k, v);
279 }
280 }
281 }
282 client.port = match u.port() {
283 Some(p) => p,
284 None => match scheme {
285 "http" => 80,
286 "https" => 443,
287 _ => unreachable!(),
288 },
289 };
290 client.scheme = scheme.to_string();
291 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
292 client.session_state = Mutex::new(session_state);
293
294 Ok(client)
295 }
296
297 pub fn host(&self) -> &str {
298 self.host.as_str()
299 }
300
    /// Port from the DSN, or 80/443 by scheme when the DSN has none.
    pub fn port(&self) -> u16 {
        self.port
    }
304
305 pub fn scheme(&self) -> &str {
306 self.scheme.as_str()
307 }
308
309 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
310 let ua = match name {
311 Some(n) => n,
312 None => format!("databend-client-rust/{}", VERSION.as_str()),
313 };
314 let cookie_provider = GlobalCookieStore::new();
315 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
316 let mut initial_cookies = [&cookie].into_iter();
317 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
318 let mut cli_builder = HttpClient::builder()
319 .user_agent(ua)
320 .cookie_provider(Arc::new(cookie_provider))
321 .pool_idle_timeout(Duration::from_secs(1));
322 #[cfg(any(feature = "rustls", feature = "native-tls"))]
323 if self.scheme == "https" {
324 if let Some(ref ca_file) = self.tls_ca_file {
325 let cert_pem = tokio::fs::read(ca_file).await?;
326 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
327 cli_builder = cli_builder.add_root_certificate(cert);
328 }
329 }
330 self.cli = cli_builder.build()?;
331 Ok(())
332 }
333
334 async fn check_presign(self: &Arc<Self>) -> Result<()> {
335 let mode = match self.get_presign_mode() {
336 PresignMode::Auto => {
337 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
338 PresignMode::On
339 } else {
340 PresignMode::Off
341 }
342 }
343 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
344 Ok(_) => PresignMode::On,
345 Err(e) => {
346 warn!("presign mode off with error detected: {e}");
347 PresignMode::Off
348 }
349 },
350 mode => mode,
351 };
352 self.set_presign_mode(mode);
353 Ok(())
354 }
355
356 pub fn current_warehouse(&self) -> Option<String> {
357 let guard = self.warehouse.lock();
358 guard.clone()
359 }
360
361 pub fn current_catalog(&self) -> Option<String> {
362 let guard = self.session_state.lock();
363 guard.catalog.clone()
364 }
365
366 pub fn current_database(&self) -> Option<String> {
367 let guard = self.session_state.lock();
368 guard.database.clone()
369 }
370
371 pub fn set_warehouse(&self, warehouse: impl Into<String>) {
372 let mut guard = self.warehouse.lock();
373 *guard = Some(warehouse.into());
374 }
375
376 pub fn set_database(&self, database: impl Into<String>) {
377 let mut guard = self.session_state.lock();
378 guard.set_database(database);
379 }
380
381 pub fn set_role(&self, role: impl Into<String>) {
382 let mut guard = self.session_state.lock();
383 guard.set_role(role);
384 }
385
386 pub fn set_session(&self, key: impl Into<String>, value: impl Into<String>) {
387 let mut guard = self.session_state.lock();
388 guard.set(key, value);
389 }
390
391 pub async fn current_role(&self) -> Option<String> {
392 let guard = self.session_state.lock();
393 guard.role.clone()
394 }
395
396 fn in_active_transaction(&self) -> bool {
397 let guard = self.session_state.lock();
398 guard
399 .txn_state
400 .as_ref()
401 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
402 .unwrap_or(false)
403 }
404
    /// User name reported by the configured auth provider.
    pub fn username(&self) -> String {
        self.auth.username()
    }
408
409 fn gen_query_id(&self) -> String {
410 uuid::Uuid::now_v7().simple().to_string()
411 }
412
413 async fn handle_session(&self, session: &Option<SessionState>) {
414 let session = match session {
415 Some(session) => session,
416 None => return,
417 };
418
419 {
421 let mut session_state = self.session_state.lock();
422 *session_state = session.clone();
423 }
424
425 if let Some(settings) = session.settings.as_ref() {
427 if let Some(v) = settings.get("warehouse") {
428 let mut warehouse = self.warehouse.lock();
429 *warehouse = Some(v.clone());
430 }
431 }
432 }
433
434 pub fn set_last_node_id(&self, node_id: String) {
435 *self.last_node_id.lock() = Some(node_id)
436 }
437
438 pub fn set_last_query_id(&self, query_id: Option<String>) {
439 *self.last_query_id.lock() = query_id
440 }
441
442 pub fn last_query_id(&self) -> Option<String> {
443 self.last_query_id.lock().clone()
444 }
445
446 fn last_node_id(&self) -> Option<String> {
447 self.last_node_id.lock().clone()
448 }
449
450 fn handle_warnings(&self, resp: &QueryResponse) {
451 if let Some(warnings) = &resp.warnings {
452 for w in warnings {
453 warn!(target: "server_warnings", "server warning: {w}");
454 }
455 }
456 }
457
458 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
459 info!("start query: {sql}");
460 let (resp, batches) = self.start_query_inner(sql, None, false).await?;
461 Pages::new(self.clone(), resp, batches, need_progress)
462 }
463
464 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
465 let mut mgr = self.queries_need_heartbeat.lock();
466 if let Some(state) = mgr.remove(query_id) {
467 let self_cloned = self.clone();
468 let query_id = query_id.to_owned();
469 GLOBAL_RUNTIME.spawn(async move {
470 if let Err(e) = self_cloned
471 .end_query(&query_id, "final", Some(state.node_id.as_str()))
472 .await
473 {
474 error!("failed to final query {query_id}: {e}");
475 }
476 });
477 }
478 }
479
480 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
481 if let Some(info) = &self.session_token_info {
482 let info = info.lock();
483 Ok(builder.bearer_auth(info.0.session_token.clone()))
484 } else {
485 self.auth.wrap(builder)
486 }
487 }
488
489 async fn start_query_inner(
490 &self,
491 sql: &str,
492 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
493 force_json_body: bool,
494 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
495 if !self.in_active_transaction() {
496 self.route_hint.next();
497 }
498 let endpoint = self.endpoint.join("v1/query")?;
499
500 let session_state = self.session_state();
502 let need_sticky = session_state.need_sticky.unwrap_or(false);
503 let req = QueryRequest::new(sql)
504 .with_pagination(self.make_pagination())
505 .with_session(Some(session_state))
506 .with_stage_attachment(stage_attachment_config);
507
508 let query_id = self.gen_query_id();
510 let mut headers = self.make_headers(Some(&query_id))?;
511 if self.capability.arrow_data && self.query_result_format == "arrow" && !force_json_body {
512 debug!("accept arrow data");
513 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
514 }
515
516 if need_sticky {
517 if let Some(node_id) = self.last_node_id() {
518 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
519 }
520 }
521 let mut builder = self.cli.post(endpoint.clone()).json(&req);
522 builder = self.wrap_auth_or_session_token(builder)?;
523 let request = builder.headers(headers.clone()).build()?;
524 let response = self.query_request_helper(request, true, true).await?;
525 self.handle_page(response, true).await
526 }
527
528 fn is_arrow_data(response: &Response) -> bool {
529 if let Some(typ) = response.headers().get(CONTENT_TYPE) {
530 if let Ok(t) = typ.to_str() {
531 return t == CONTENT_TYPE_ARROW;
532 }
533 }
534 false
535 }
536
    /// Decodes one page of a query response into the JSON `QueryResponse` and any
    /// Arrow record batches.
    ///
    /// Decoding rules:
    /// * A non-200 status becomes a response error.
    /// * Arrow payloads carry the JSON `QueryResponse` in the stream schema's
    ///   `response_header` metadata, with the rows as record batches.
    /// * Any other payload is decoded as JSON.
    /// * A logic error with status 404 is mapped to `Error::QueryNotFound`.
    /// * The returned session state is applied before the query error (if any)
    ///   is reported.
    ///
    /// When `is_first` is true, this also updates the route hint from the
    /// response header, logs server warnings, and records the query id and
    /// node id.
    async fn handle_page(
        &self,
        response: Response,
        is_first: bool,
    ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
        let status = response.status();
        if status != 200 {
            return Err(Error::response_error(status, &response.bytes().await?));
        }
        // Inspect headers before `bytes()` consumes the response.
        let is_arrow_data = Self::is_arrow_data(&response);
        if is_first {
            if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
                self.route_hint.set(route_hint.to_str().unwrap_or_default());
            }
        }
        let mut body = response.bytes().await?;
        let mut batches = vec![];
        if is_arrow_data {
            if is_first {
                debug!("received arrow data");
            }
            let cursor = std::io::Cursor::new(body.as_ref());
            let reader = StreamReader::try_new(cursor, None)
                .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
            let schema = reader.schema();
            // The JSON response envelope travels in the Arrow schema metadata.
            let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
                bytes::Bytes::copy_from_slice(json_resp.as_bytes())
            } else {
                return Err(Error::Decode(
                    "missing response_header metadata in arrow payload".to_string(),
                ));
            };
            for batch in reader {
                let batch = batch
                    .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
                batches.push(batch);
            }
            // From here on `body` is the JSON envelope, for both payload kinds.
            body = json_body
        };
        let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
            if let Error::Logic(status, ec) = &e {
                if *status == 404 {
                    return Error::QueryNotFound(ec.message.clone());
                }
            }
            e
        })?;
        // Apply session changes even when the query itself failed.
        self.handle_session(&resp.session).await;
        if let Some(err) = &resp.error {
            return Err(Error::QueryFailed(err.clone()));
        }
        if is_first {
            self.handle_warnings(&resp);
            self.set_last_query_id(Some(resp.id.clone()));
            if let Some(node_id) = &resp.node_id {
                self.set_last_node_id(node_id.clone());
            }
        }
        Ok((resp, batches))
    }
597
598 pub async fn query_page(
599 &self,
600 query_id: &str,
601 next_uri: &str,
602 node_id: &Option<String>,
603 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
604 info!("query page: {next_uri}");
605 let endpoint = self.endpoint.join(next_uri)?;
606 let mut headers = self.make_headers(Some(query_id))?;
607 if self.capability.arrow_data && self.query_result_format == "arrow" {
608 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
609 }
610 let mut builder = self.cli.get(endpoint.clone());
611 builder = self
612 .wrap_auth_or_session_token(builder)?
613 .headers(headers.clone())
614 .timeout(self.page_request_timeout);
615 if let Some(node_id) = node_id {
616 builder = builder.header(HEADER_STICKY_NODE, node_id)
617 }
618 let request = builder.build()?;
619
620 let response = self.query_request_helper(request, false, true).await?;
621 self.handle_page(response, false).await
622 }
623
    /// Asks the server to kill `query_id` (`POST /v1/query/{id}/kill`).
    pub async fn kill_query(&self, query_id: &str) -> Result<()> {
        self.end_query(query_id, "kill", None).await
    }
627
    /// Tells the server `query_id` is finished (`POST /v1/query/{id}/final`),
    /// optionally pinned to `node_id`.
    pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
        self.end_query(query_id, "final", node_id).await
    }
631
632 pub async fn end_query(
633 &self,
634 query_id: &str,
635 method: &str,
636 node_id: Option<&str>,
637 ) -> Result<()> {
638 let uri = format!("/v1/query/{query_id}/{method}");
639 let endpoint = self.endpoint.join(&uri)?;
640 let headers = self.make_headers(Some(query_id))?;
641
642 info!("{method} query: {uri}");
643
644 let mut builder = self.cli.post(endpoint);
645 if let Some(node_id) = node_id {
646 builder = builder.header(HEADER_STICKY_NODE, node_id)
647 }
648 builder = self.wrap_auth_or_session_token(builder)?;
649 let resp = builder.headers(headers.clone()).send().await?;
650 if resp.status() != 200 {
651 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
652 .with_context(&format!("{method} query")));
653 }
654 Ok(())
655 }
656
    /// Runs `sql` and collects every result page into a single [`Page`].
    pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
        self.query_all_inner(sql, false).await
    }
660
661 pub async fn query_all_inner(
662 self: &Arc<Self>,
663 sql: &str,
664 force_json_body: bool,
665 ) -> Result<Page> {
666 let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
667 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
668 let mut all = Page::default();
669 while let Some(page) = pages.next().await {
670 all.update(page?);
671 }
672 Ok(all)
673 }
674
675 fn session_state(&self) -> SessionState {
676 self.session_state.lock().clone()
677 }
678
679 fn make_pagination(&self) -> Option<PaginationConfig> {
680 if self.wait_time_secs.is_none()
681 && self.max_rows_in_buffer.is_none()
682 && self.max_rows_per_page.is_none()
683 {
684 return None;
685 }
686 let mut pagination = PaginationConfig {
687 wait_time_secs: None,
688 max_rows_in_buffer: None,
689 max_rows_per_page: None,
690 };
691 if let Some(wait_time_secs) = self.wait_time_secs {
692 pagination.wait_time_secs = Some(wait_time_secs);
693 }
694 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
695 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
696 }
697 if let Some(max_rows_per_page) = self.max_rows_per_page {
698 pagination.max_rows_per_page = Some(max_rows_per_page);
699 }
700 Some(pagination)
701 }
702
703 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
704 let mut headers = HeaderMap::new();
705 if let Some(tenant) = &self.tenant {
706 headers.insert(HEADER_TENANT, tenant.parse()?);
707 }
708 let warehouse = self.warehouse.lock().clone();
709 if let Some(warehouse) = warehouse {
710 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
711 }
712 let route_hint = self.route_hint.current();
713 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
714 if let Some(query_id) = query_id {
715 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
716 }
717 Ok(headers)
718 }
719
720 pub async fn insert_with_stage(
721 self: &Arc<Self>,
722 sql: &str,
723 stage: &str,
724 file_format_options: BTreeMap<&str, &str>,
725 copy_options: BTreeMap<&str, &str>,
726 ) -> Result<QueryStats> {
727 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
728 let stage_attachment = Some(StageAttachmentConfig {
729 location: stage,
730 file_format_options: Some(file_format_options),
731 copy_options: Some(copy_options),
732 });
733 let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
734 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
735 let mut all = Page::default();
736 while let Some(page) = pages.next().await {
737 all.update(page?);
738 }
739 Ok(all.stats)
740 }
741
742 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
743 info!("get presigned upload url: {stage}");
744 let sql = format!("PRESIGN UPLOAD {stage}");
745 let resp = self.query_all_inner(&sql, true).await?;
746 if resp.data.len() != 1 {
747 return Err(Error::Decode(
748 "Empty response from server for presigned request".to_string(),
749 ));
750 }
751 if resp.data[0].len() != 3 {
752 return Err(Error::Decode(
753 "Invalid response from server for presigned request".to_string(),
754 ));
755 }
756 let method = resp.data[0][0].clone().unwrap_or_default();
758 if method != "PUT" {
759 return Err(Error::Decode(format!(
760 "Invalid method for presigned upload request: {method}"
761 )));
762 }
763 let headers: BTreeMap<String, String> =
764 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
765 let url = resp.data[0][2].clone().unwrap_or_default();
766 Ok(PresignedResponse {
767 method,
768 headers,
769 url,
770 })
771 }
772
773 pub async fn upload_to_stage(
774 self: &Arc<Self>,
775 stage: &str,
776 data: Reader,
777 size: u64,
778 ) -> Result<()> {
779 match self.get_presign_mode() {
780 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
781 PresignMode::On => {
782 let presigned = self.get_presigned_upload_url(stage).await?;
783 presign_upload_to_stage(presigned, data, size).await
784 }
785 PresignMode::Auto => {
786 unreachable!("PresignMode::Auto should be handled during client initialization")
787 }
788 PresignMode::Detect => {
789 unreachable!("PresignMode::Detect should be handled during client initialization")
790 }
791 }
792 }
793
794 async fn upload_to_stage_with_stream(
796 &self,
797 stage: &str,
798 data: Reader,
799 size: u64,
800 ) -> Result<()> {
801 info!("upload to stage with stream: {stage}, size: {size}");
802 if let Some(info) = self.need_pre_refresh_session().await {
803 self.refresh_session_token(info).await?;
804 }
805 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
806 let location = StageLocation::try_from(stage)?;
807 let query_id = self.gen_query_id();
808 let mut headers = self.make_headers(Some(&query_id))?;
809 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
810 let stream = Body::wrap_stream(ReaderStream::new(data));
811 let part = Part::stream_with_length(stream, size).file_name(location.path);
812 let form = Form::new().part("upload", part);
813 let mut builder = self.cli.put(endpoint.clone());
814 builder = self.wrap_auth_or_session_token(builder)?;
815 let resp = builder.headers(headers).multipart(form).send().await?;
816 let status = resp.status();
817 if status != 200 {
818 return Err(
819 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
820 );
821 }
822 Ok(())
823 }
824
825 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
828 where
829 T: de::DeserializeOwned,
830 {
831 if value.starts_with("{") {
832 serde_json::from_slice(value.as_bytes())
833 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
834 } else {
835 let json = URL_SAFE.decode(value).map_err(|e| {
836 format!(
837 "Invalid value {} for {key}, base64 decode error: {}",
838 value, e
839 )
840 })?;
841 serde_json::from_slice(&json).map_err(|e| {
842 format!(
843 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
844 String::from_utf8_lossy(&json)
845 )
846 })
847 }
848 }
849
850 pub async fn streaming_load(
851 &self,
852 sql: &str,
853 data: Reader,
854 file_name: &str,
855 ) -> Result<LoadResponse> {
856 let body = Body::wrap_stream(ReaderStream::new(data));
857 let part = Part::stream(body).file_name(file_name.to_string());
858 let endpoint = self.endpoint.join("v1/streaming_load")?;
859 let mut builder = self.cli.put(endpoint.clone());
860 builder = self.wrap_auth_or_session_token(builder)?;
861 let query_id = self.gen_query_id();
862 let mut headers = self.make_headers(Some(&query_id))?;
863 headers.insert(HEADER_SQL, sql.parse()?);
864 let session = serde_json::to_string(&*self.session_state.lock())
865 .expect("serialize session state should not fail");
866 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
867 let form = Form::new().part("upload", part);
868 let resp = builder.headers(headers).multipart(form).send().await?;
869 let status = resp.status();
870 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
871 match Self::decode_json_header::<SessionState>(
872 HEADER_QUERY_CONTEXT,
873 value.to_str().unwrap(),
874 ) {
875 Ok(session) => *self.session_state.lock() = session,
876 Err(e) => {
877 error!("Error decoding session state when streaming load: {e}");
878 }
879 }
880 };
881 if status != 200 {
882 return Err(
883 Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
884 );
885 }
886 let resp = resp.json::<LoadResponse>().await?;
887 Ok(resp)
888 }
889
    /// Logs in via `POST /v1/session/login` with the configured credentials and
    /// the current session state.
    ///
    /// On success it records:
    /// * the session id, from the `X-DATABEND-SESSION-ID` response header;
    /// * the server version and the capabilities derived from it;
    /// * when the server returns tokens, the session token info together with
    ///   the time it was obtained.
    ///
    /// A 404 is treated as an old server without login support, not as an error.
    ///
    /// # Errors
    /// `Error::AuthFailure` if the server rejects the login. Request, decode and
    /// version-parse errors are propagated.
    async fn login(&mut self) -> Result<()> {
        let endpoint = self.endpoint.join("/v1/session/login")?;
        let headers = self.make_headers(None)?;
        let body = LoginRequest::from(&*self.session_state.lock());
        let mut builder = self.cli.post(endpoint.clone()).json(&body);
        if self.disable_session_token {
            builder = builder.query(&[("disable_session_token", true)]);
        }
        let builder = self.auth.wrap(builder)?;
        let request = builder
            .headers(headers.clone())
            .timeout(self.connect_timeout)
            .build()?;
        // Retry on 503, but never attempt a token refresh on 401 while logging in.
        let response = self.query_request_helper(request, true, false).await;
        let response = match response {
            Ok(r) => r,
            Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
                info!("login return 404, skip login on the old version server");
                return Ok(());
            }
            Err(e) => return Err(e),
        };
        // Read the session id header before `bytes()` consumes the response.
        if let Some(v) = response.headers().get(HEADER_SESSION_ID) {
            if let Ok(s) = v.to_str() {
                self.session_id = s.to_string();
            }
        }

        let body = response.bytes().await?;
        let response = json_from_slice(&body)?;
        match response {
            LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
            LoginResponseResult::Ok(info) => {
                let server_version = info
                    .version
                    .parse()
                    .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
                self.capability = Capability::from_server_version(&server_version);
                self.server_version = Some(server_version.clone());
                let session_id = self.session_id.as_str();
                if let Some(tokens) = info.tokens {
                    info!(
                        "[session {session_id}] login success with session token version = {server_version}",
                    );
                    // Remember when the tokens were issued; need_pre_refresh_session
                    // compares this instant with the token TTL.
                    self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
                } else {
                    info!("[session {session_id}] login success, version = {server_version}");
                }
            }
        }
        Ok(())
    }
942
943 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
944 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
945 let queries = self.queries_need_heartbeat.lock().clone();
946 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
947 let now = Instant::now();
948
949 let mut query_ids = Vec::new();
950 for (qid, state) in queries {
951 if state.need_heartbeat(now) {
952 query_ids.push(qid.to_string());
953 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
954 arr.push(qid);
955 } else {
956 node_to_queries.insert(state.node_id, vec![qid]);
957 }
958 }
959 }
960
961 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
962 {
963 return Ok(());
964 }
965
966 let body = json!({
967 "node_to_queries": node_to_queries
968 });
969 let builder = self.cli.post(endpoint.clone()).json(&body);
970 let request = self.wrap_auth_or_session_token(builder)?.build()?;
971 let response = self.query_request_helper(request, true, false).await?;
972 let json: Value = response.json().await?;
973 let session_id = self.session_id.as_str();
974 info!("[session {session_id}] heartbeat request={body}, response={json}");
975 if let Some(queries_to_remove) = json.get("queries_to_remove") {
976 if let Some(arr) = queries_to_remove.as_array() {
977 if !arr.is_empty() {
978 let mut queries = self.queries_need_heartbeat.lock();
979 for q in arr {
980 if let Some(q) = q.as_str() {
981 queries.remove(q);
982 }
983 }
984 }
985 }
986 }
987 let now = Instant::now();
988 let mut queries = self.queries_need_heartbeat.lock();
989 for qid in query_ids {
990 if let Some(state) = queries.get_mut(&qid) {
991 *state.last_access_time.lock() = now;
992 }
993 }
994 Ok(())
995 }
996
997 fn build_log_out_request(&self) -> Result<Request> {
998 let endpoint = self.endpoint.join("/v1/session/logout")?;
999
1000 let session_state = self.session_state();
1001 let need_sticky = session_state.need_sticky.unwrap_or(false);
1002 let mut headers = self.make_headers(None)?;
1003 if need_sticky {
1004 if let Some(node_id) = self.last_node_id() {
1005 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
1006 }
1007 }
1008 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
1009
1010 let builder = self.wrap_auth_or_session_token(builder)?;
1011 let req = builder.build()?;
1012 Ok(req)
1013 }
1014
1015 pub(crate) fn need_logout(&self) -> bool {
1016 self.session_token_info.is_some()
1017 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1018 }
1019
1020 async fn refresh_session_token(
1021 &self,
1022 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1023 ) -> Result<()> {
1024 let (session_token_info, _) = { self_login_info.lock().clone() };
1025 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1026 let body = RefreshSessionTokenRequest {
1027 session_token: session_token_info.session_token.clone(),
1028 };
1029 let headers = self.make_headers(None)?;
1030 let request = self
1031 .cli
1032 .post(endpoint.clone())
1033 .json(&body)
1034 .headers(headers.clone())
1035 .bearer_auth(session_token_info.refresh_token.clone())
1036 .timeout(self.connect_timeout)
1037 .build()?;
1038
1039 for i in 0..3 {
1041 let req = request.try_clone().expect("request not cloneable");
1042 match self.cli.execute(req).await {
1043 Ok(response) => {
1044 let status = response.status();
1045 let body = response.bytes().await?;
1046 if status == StatusCode::OK {
1047 let response = json_from_slice(&body)?;
1048 return match response {
1049 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1050 RefreshResponse::Ok(info) => {
1051 *self_login_info.lock() = (info, Instant::now());
1052 Ok(())
1053 }
1054 };
1055 }
1056 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
1057 return Err(Error::response_error(status, &body));
1058 }
1059 }
1060 Err(err) => {
1061 if !(err.is_timeout() || err.is_connect()) || i > 2 {
1062 return Err(Error::Request(err.to_string()));
1063 }
1064 }
1065 };
1066 sleep(jitter(Duration::from_secs(10))).await;
1067 }
1068 Ok(())
1069 }
1070
1071 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1072 if let Some(info) = &self.session_token_info {
1073 let (start, ttl) = {
1074 let guard = info.lock();
1075 (guard.1, guard.0.session_token_ttl_in_secs)
1076 };
1077 if Instant::now() > start + Duration::from_secs(ttl) {
1078 return Some(info.clone());
1079 }
1080 }
1081 None
1082 }
1083
1084 async fn query_request_helper(
1092 &self,
1093 mut request: Request,
1094 retry_if_503: bool,
1095 refresh_if_401: bool,
1096 ) -> std::result::Result<Response, Error> {
1097 let mut refreshed = false;
1098 let mut retries = 0;
1099 loop {
1100 let req = request.try_clone().expect("request not cloneable");
1101 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
1102 Ok(response) => {
1103 let status = response.status();
1104 if status == StatusCode::OK {
1105 return Ok(response);
1106 }
1107 let body = response.bytes().await?;
1108 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1109 (Error::response_error(status, &body), true)
1111 } else {
1112 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
1113 match resp {
1114 Ok(r) => {
1115 let e = r.error;
1116 if status == StatusCode::UNAUTHORIZED {
1117 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1118 if let Some(session_token_info) = &self.session_token_info {
1119 info!(
1120 "will retry {} after refresh token on auth error {}",
1121 request.url(),
1122 e
1123 );
1124 let retry = if need_refresh_token(e.code)
1125 && !refreshed
1126 && refresh_if_401
1127 {
1128 self.refresh_session_token(session_token_info.clone())
1129 .await?;
1130 refreshed = true;
1131 true
1132 } else {
1133 false
1134 };
1135 (Error::AuthFailure(e), retry)
1136 } else if self.auth.can_reload() {
1137 info!(
1138 "will retry {} after reload token on auth error {}",
1139 request.url(),
1140 e
1141 );
1142 let builder = RequestBuilder::from_parts(
1143 HttpClient::new(),
1144 request.try_clone().unwrap(),
1145 );
1146 let builder = self.auth.wrap(builder)?;
1147 request = builder.build()?;
1148 (Error::AuthFailure(e), true)
1149 } else {
1150 (Error::AuthFailure(e), false)
1151 }
1152 } else {
1153 (Error::Logic(status, e), false)
1154 }
1155 }
1156 Err(_) => (
1157 Error::Response {
1158 status,
1159 msg: String::from_utf8_lossy(&body).to_string(),
1160 },
1161 false,
1162 ),
1163 }
1164 }
1165 }
1166 Err(err) => (
1167 Error::Request(err.to_string()),
1168 err.is_timeout() || err.is_connect() || err.is_request(),
1169 ),
1170 };
1171 if !retry {
1172 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
1173 }
1174 match &err {
1175 Error::AuthFailure(_) => {
1176 if refreshed {
1177 retries = 0;
1178 } else if retries == 2 {
1179 return Err(err.with_context(&format!(
1180 "{} {} after 3 retries",
1181 request.method(),
1182 request.url()
1183 )));
1184 }
1185 }
1186 _ => {
1187 if retries == 2 {
1188 return Err(err.with_context(&format!(
1189 "{} {} after 3 reties",
1190 request.method(),
1191 request.url()
1192 )));
1193 }
1194 retries += 1;
1195 info!(
1196 "will retry {} the {retries}th times on error {}",
1197 request.url(),
1198 err
1199 );
1200 }
1201 }
1202 warn!("will retry after 10 seconds");
1203 sleep(jitter(Duration::from_secs(10))).await;
1204 }
1205 }
1206
1207 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1208 if let Err(err) = http_client.execute(request).await {
1209 error!("[session {session_id}] logout request failed: {err}");
1210 } else {
1211 info!("[session {session_id}] logout success");
1212 };
1213 }
1214
1215 pub async fn close(&self) {
1216 let session_id = &self.session_id;
1217 info!("[session {session_id}] try closing now");
1218 if self
1219 .closed
1220 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1221 .is_ok()
1222 {
1223 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1224 if self.need_logout() {
1225 let cli = self.cli.clone();
1226 let req = self
1227 .build_log_out_request()
1228 .expect("failed to build logout request");
1229 Self::logout(cli, req, &self.session_id).await;
1230 }
1231 }
1232 }
1233 pub fn close_with_spawn(&self) {
1234 let session_id = &self.session_id;
1235 info!("[session {session_id}]: try closing with spawn");
1236 if self
1237 .closed
1238 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1239 .is_ok()
1240 {
1241 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1242 if self.need_logout() {
1243 let cli = self.cli.clone();
1244 let req = self
1245 .build_log_out_request()
1246 .expect("failed to build logout request");
1247 let session_id = self.session_id.clone();
1248 GLOBAL_RUNTIME.spawn(async move {
1249 Self::logout(cli, req, session_id.as_str()).await;
1250 });
1251 }
1252 }
1253 }
1254
1255 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1256 let mut queries = self.queries_need_heartbeat.lock();
1257 queries.insert(query_id.to_string(), state);
1258 }
1259}
1260
1261fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1262where
1263 T: Deserialize<'a>,
1264{
1265 serde_json::from_slice::<T>(body).map_err(|e| {
1266 Error::Decode(format!(
1267 "fail to decode JSON response: {e}, body: {}",
1268 String::from_utf8_lossy(body)
1269 ))
1270 })
1271}
1272
1273impl Default for APIClient {
1274 fn default() -> Self {
1275 Self {
1276 session_id: Default::default(),
1277 cli: HttpClient::new(),
1278 scheme: "http".to_string(),
1279 endpoint: Url::parse("http://localhost:8080").unwrap(),
1280 host: "localhost".to_string(),
1281 port: 8000,
1282 tenant: None,
1283 warehouse: Mutex::new(None),
1284 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1285 session_state: Mutex::new(SessionState::default()),
1286 wait_time_secs: None,
1287 max_rows_in_buffer: None,
1288 max_rows_per_page: None,
1289 connect_timeout: Duration::from_secs(10),
1290 page_request_timeout: Duration::from_secs(30),
1291 tls_ca_file: None,
1292 presign: Mutex::new(PresignMode::Auto),
1293 route_hint: RouteHintGenerator::new(),
1294 last_node_id: Default::default(),
1295 disable_session_token: true,
1296 disable_login: false,
1297 query_result_format: "json".to_string(),
1298 session_token_info: None,
1299 closed: AtomicBool::new(false),
1300 last_query_id: Default::default(),
1301 server_version: None,
1302 capability: Default::default(),
1303 queries_need_heartbeat: Default::default(),
1304 }
1305 }
1306}
1307
1308struct RouteHintGenerator {
1309 nonce: AtomicU64,
1310 current: std::sync::Mutex<String>,
1311}
1312
1313impl RouteHintGenerator {
1314 fn new() -> Self {
1315 let gen = Self {
1316 nonce: AtomicU64::new(0),
1317 current: std::sync::Mutex::new("".to_string()),
1318 };
1319 gen.next();
1320 gen
1321 }
1322
1323 fn current(&self) -> String {
1324 let guard = self.current.lock().unwrap();
1325 guard.clone()
1326 }
1327
1328 fn set(&self, hint: &str) {
1329 let mut guard = self.current.lock().unwrap();
1330 *guard = hint.to_string();
1331 }
1332
1333 fn next(&self) -> String {
1334 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1335 let uuid = uuid::Uuid::new_v4();
1336 let current = format!("rh:{uuid}:{nonce:06}");
1337 let mut guard = self.current.lock().unwrap();
1338 guard.clone_from(¤t);
1339 current
1340 }
1341}
1342
#[cfg(test)]
mod test {
    use super::*;

    /// A DSN with host, database and client options populates the matching client fields.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, expected_endpoint);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);

        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(*warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// Percent-encoded characters in the password parse cleanly; with no explicit port the
    /// client reports 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// Raw special characters (`@`, `=`, `{`, `(`) in the password are tolerated, and the
    /// explicit port is kept.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}