1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::{de, Deserialize};
49use serde_json::{json, Value};
50use std::collections::{BTreeMap, HashMap};
51use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tokio::time::sleep;
55use tokio_retry::strategy::jitter;
56use tokio_stream::StreamExt;
57use tokio_util::io::ReaderStream;
58use url::Url;
59
// Names of the HTTP headers this client reads or writes.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";

// Transaction state string compared (case-insensitively) against the session's `txn_state`.
const TXN_STATE_ACTIVE: &str = "Active";

// Content type that marks a response body as an Arrow IPC stream.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Value sent in the `Accept` header when arrow bodies are requested.
// NOTE(review): currently identical to CONTENT_TYPE_ARROW despite the name — confirm
// whether a JSON alternative was meant to be listed as well.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
72
73static VERSION: Lazy<String> = Lazy::new(|| {
74 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
75 version.to_string()
76});
77
78#[derive(Clone)]
79pub(crate) struct QueryState {
80 pub node_id: String,
81 pub last_access_time: Arc<Mutex<Instant>>,
82 pub timeout_secs: u64,
83}
84
85impl QueryState {
86 pub fn need_heartbeat(&self, now: Instant) -> bool {
87 let t = *self.last_access_time.lock();
88 now.duration_since(t).as_secs() > self.timeout_secs / 2
89 }
90}
91
92pub struct APIClient {
93 pub(crate) session_id: String,
94 cli: HttpClient,
95 scheme: String,
96 host: String,
97 port: u16,
98
99 endpoint: Url,
100
101 auth: Arc<dyn Auth>,
102
103 tenant: Option<String>,
104 warehouse: Mutex<Option<String>>,
105 session_state: Mutex<SessionState>,
106 route_hint: RouteHintGenerator,
107
108 disable_login: bool,
109 body_format: String,
110 disable_session_token: bool,
111 session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,
112
113 closed: AtomicBool,
114
115 server_version: Option<Version>,
116
117 wait_time_secs: Option<i64>,
118 max_rows_in_buffer: Option<i64>,
119 max_rows_per_page: Option<i64>,
120
121 connect_timeout: Duration,
122 page_request_timeout: Duration,
123
124 tls_ca_file: Option<String>,
125
126 presign: Mutex<PresignMode>,
127 last_node_id: Mutex<Option<String>>,
128 last_query_id: Mutex<Option<String>>,
129
130 capability: Capability,
131
132 queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,
133}
134
135impl Drop for APIClient {
136 fn drop(&mut self) {
137 self.close_with_spawn()
138 }
139}
140
141impl APIClient {
142 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
143 let mut client = Self::from_dsn(dsn).await?;
144 client.build_client(name).await?;
145 if !client.disable_login {
146 client.login().await?;
147 }
148 if client.session_id.is_empty() {
149 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
150 }
151 let client = Arc::new(client);
152 client.check_presign().await?;
153 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
154 Ok(client)
155 }
156
157 pub fn capability(&self) -> &Capability {
158 &self.capability
159 }
160
161 fn set_presign_mode(&self, mode: PresignMode) {
162 *self.presign.lock() = mode
163 }
164 fn get_presign_mode(&self) -> PresignMode {
165 *self.presign.lock()
166 }
167
168 async fn from_dsn(dsn: &str) -> Result<Self> {
169 let u = Url::parse(dsn)?;
170 let mut client = Self::default();
171 if let Some(host) = u.host_str() {
172 client.host = host.to_string();
173 }
174
175 if u.username() != "" {
176 let password = u.password().unwrap_or_default();
177 let password = percent_decode_str(password).decode_utf8()?;
178 client.auth = Arc::new(BasicAuth::new(u.username(), password));
179 }
180 let database = match u.path().trim_start_matches('/') {
181 "" => None,
182 s => Some(s.to_string()),
183 };
184 let mut role = None;
185 let mut scheme = "https";
186 let mut session_settings = BTreeMap::new();
187 for (k, v) in u.query_pairs() {
188 match k.as_ref() {
189 "wait_time_secs" => {
190 client.wait_time_secs = Some(v.parse()?);
191 }
192 "max_rows_in_buffer" => {
193 client.max_rows_in_buffer = Some(v.parse()?);
194 }
195 "max_rows_per_page" => {
196 client.max_rows_per_page = Some(v.parse()?);
197 }
198 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
199 "page_request_timeout_secs" => {
200 client.page_request_timeout = {
201 let secs: u64 = v.parse()?;
202 Duration::from_secs(secs)
203 };
204 }
205 "presign" => {
206 let presign_mode = match v.as_ref() {
207 "auto" => PresignMode::Auto,
208 "detect" => PresignMode::Detect,
209 "on" => PresignMode::On,
210 "off" => PresignMode::Off,
211 _ => {
212 return Err(Error::BadArgument(format!(
213 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
214 )))
215 }
216 };
217 client.set_presign_mode(presign_mode);
218 }
219 "tenant" => {
220 client.tenant = Some(v.to_string());
221 }
222 "warehouse" => {
223 client.warehouse = Mutex::new(Some(v.to_string()));
224 }
225 "role" => role = Some(v.to_string()),
226 "sslmode" => match v.as_ref() {
227 "disable" => scheme = "http",
228 "require" | "enable" => scheme = "https",
229 _ => {
230 return Err(Error::BadArgument(format!(
231 "Invalid value for sslmode: {v}"
232 )))
233 }
234 },
235 "tls_ca_file" => {
236 client.tls_ca_file = Some(v.to_string());
237 }
238 "access_token" => {
239 client.auth = Arc::new(AccessTokenAuth::new(v));
240 }
241 "access_token_file" => {
242 client.auth = Arc::new(AccessTokenFileAuth::new(v));
243 }
244 "login" => {
245 client.disable_login = match v.as_ref() {
246 "disable" => true,
247 "enable" => false,
248 _ => {
249 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
250 }
251 }
252 }
253 "session_token" => {
254 client.disable_session_token = match v.as_ref() {
255 "disable" => true,
256 "enable" => false,
257 _ => {
258 return Err(Error::BadArgument(format!(
259 "Invalid value for session_token: {v}"
260 )))
261 }
262 }
263 }
264 "body_format" => {
265 let v = v.to_string().to_lowercase();
266 match v.as_str() {
267 "json" | "arrow" => client.body_format = v.to_string(),
268 _ => {
269 return Err(Error::BadArgument(format!(
270 "Invalid value for body_format: {v}"
271 )))
272 }
273 }
274 }
275 _ => {
276 session_settings.insert(k.to_string(), v.to_string());
277 }
278 }
279 }
280 client.port = match u.port() {
281 Some(p) => p,
282 None => match scheme {
283 "http" => 80,
284 "https" => 443,
285 _ => unreachable!(),
286 },
287 };
288 client.scheme = scheme.to_string();
289
290 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
291 client.session_state = Mutex::new(
292 SessionState::default()
293 .with_settings(Some(session_settings))
294 .with_role(role)
295 .with_database(database),
296 );
297
298 Ok(client)
299 }
300
301 pub fn host(&self) -> &str {
302 self.host.as_str()
303 }
304
305 pub fn port(&self) -> u16 {
306 self.port
307 }
308
309 pub fn scheme(&self) -> &str {
310 self.scheme.as_str()
311 }
312
313 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
314 let ua = match name {
315 Some(n) => n,
316 None => format!("databend-client-rust/{}", VERSION.as_str()),
317 };
318 let cookie_provider = GlobalCookieStore::new();
319 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
320 let mut initial_cookies = [&cookie].into_iter();
321 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
322 let mut cli_builder = HttpClient::builder()
323 .user_agent(ua)
324 .cookie_provider(Arc::new(cookie_provider))
325 .pool_idle_timeout(Duration::from_secs(1));
326 #[cfg(any(feature = "rustls", feature = "native-tls"))]
327 if self.scheme == "https" {
328 if let Some(ref ca_file) = self.tls_ca_file {
329 let cert_pem = tokio::fs::read(ca_file).await?;
330 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
331 cli_builder = cli_builder.add_root_certificate(cert);
332 }
333 }
334 self.cli = cli_builder.build()?;
335 Ok(())
336 }
337
338 async fn check_presign(self: &Arc<Self>) -> Result<()> {
339 let mode = match self.get_presign_mode() {
340 PresignMode::Auto => {
341 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
342 PresignMode::On
343 } else {
344 PresignMode::Off
345 }
346 }
347 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
348 Ok(_) => PresignMode::On,
349 Err(e) => {
350 warn!("presign mode off with error detected: {e}");
351 PresignMode::Off
352 }
353 },
354 mode => mode,
355 };
356 self.set_presign_mode(mode);
357 Ok(())
358 }
359
360 pub fn current_warehouse(&self) -> Option<String> {
361 let guard = self.warehouse.lock();
362 guard.clone()
363 }
364
365 pub fn current_catalog(&self) -> Option<String> {
366 let guard = self.session_state.lock();
367 guard.catalog.clone()
368 }
369
370 pub fn current_database(&self) -> Option<String> {
371 let guard = self.session_state.lock();
372 guard.database.clone()
373 }
374
375 pub async fn current_role(&self) -> Option<String> {
376 let guard = self.session_state.lock();
377 guard.role.clone()
378 }
379
380 fn in_active_transaction(&self) -> bool {
381 let guard = self.session_state.lock();
382 guard
383 .txn_state
384 .as_ref()
385 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
386 .unwrap_or(false)
387 }
388
389 pub fn username(&self) -> String {
390 self.auth.username()
391 }
392
393 fn gen_query_id(&self) -> String {
394 uuid::Uuid::now_v7().simple().to_string()
395 }
396
397 async fn handle_session(&self, session: &Option<SessionState>) {
398 let session = match session {
399 Some(session) => session,
400 None => return,
401 };
402
403 {
405 let mut session_state = self.session_state.lock();
406 *session_state = session.clone();
407 }
408
409 if let Some(settings) = session.settings.as_ref() {
411 if let Some(v) = settings.get("warehouse") {
412 let mut warehouse = self.warehouse.lock();
413 *warehouse = Some(v.clone());
414 }
415 }
416 }
417
418 pub fn set_last_node_id(&self, node_id: String) {
419 *self.last_node_id.lock() = Some(node_id)
420 }
421
422 pub fn set_last_query_id(&self, query_id: Option<String>) {
423 *self.last_query_id.lock() = query_id
424 }
425
426 pub fn last_query_id(&self) -> Option<String> {
427 self.last_query_id.lock().clone()
428 }
429
430 fn last_node_id(&self) -> Option<String> {
431 self.last_node_id.lock().clone()
432 }
433
434 fn handle_warnings(&self, resp: &QueryResponse) {
435 if let Some(warnings) = &resp.warnings {
436 for w in warnings {
437 warn!(target: "server_warnings", "server warning: {w}");
438 }
439 }
440 }
441
442 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
443 info!("start query: {sql}");
444 let (resp, batches) = self.start_query_inner(sql, None, false).await?;
445 Pages::new(self.clone(), resp, batches, need_progress)
446 }
447
448 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
449 let mut mgr = self.queries_need_heartbeat.lock();
450 if let Some(state) = mgr.remove(query_id) {
451 let self_cloned = self.clone();
452 let query_id = query_id.to_owned();
453 GLOBAL_RUNTIME.spawn(async move {
454 if let Err(e) = self_cloned
455 .end_query(&query_id, "final", Some(state.node_id.as_str()))
456 .await
457 {
458 error!("failed to final query {query_id}: {e}");
459 }
460 });
461 }
462 }
463
464 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
465 if let Some(info) = &self.session_token_info {
466 let info = info.lock();
467 Ok(builder.bearer_auth(info.0.session_token.clone()))
468 } else {
469 self.auth.wrap(builder)
470 }
471 }
472
473 async fn start_query_inner(
474 &self,
475 sql: &str,
476 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
477 force_json_body: bool,
478 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
479 if !self.in_active_transaction() {
480 self.route_hint.next();
481 }
482 let endpoint = self.endpoint.join("v1/query")?;
483
484 let session_state = self.session_state();
486 let need_sticky = session_state.need_sticky.unwrap_or(false);
487 let req = QueryRequest::new(sql)
488 .with_pagination(self.make_pagination())
489 .with_session(Some(session_state))
490 .with_stage_attachment(stage_attachment_config);
491
492 let query_id = self.gen_query_id();
494 let mut headers = self.make_headers(Some(&query_id))?;
495 if self.capability.arrow_data && self.body_format == "arrow" && !force_json_body {
496 debug!("accept arrow data");
497 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
498 }
499
500 if need_sticky {
501 if let Some(node_id) = self.last_node_id() {
502 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
503 }
504 }
505 let mut builder = self.cli.post(endpoint.clone()).json(&req);
506 builder = self.wrap_auth_or_session_token(builder)?;
507 let request = builder.headers(headers.clone()).build()?;
508 let response = self.query_request_helper(request, true, true).await?;
509 self.handle_page(response, true).await
510 }
511
512 fn is_arrow_data(response: &Response) -> bool {
513 if let Some(typ) = response.headers().get(CONTENT_TYPE) {
514 if let Ok(t) = typ.to_str() {
515 return t == CONTENT_TYPE_ARROW;
516 }
517 }
518 false
519 }
520
521 async fn handle_page(
522 &self,
523 response: Response,
524 is_first: bool,
525 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
526 let status = response.status();
527 if status != 200 {
528 return Err(Error::response_error(status, &response.bytes().await?));
529 }
530 let is_arrow_data = Self::is_arrow_data(&response);
531 if is_first {
532 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
533 self.route_hint.set(route_hint.to_str().unwrap_or_default());
534 }
535 }
536 let mut body = response.bytes().await?;
537 let mut batches = vec![];
538 if is_arrow_data {
539 if is_first {
540 debug!("received arrow data");
541 }
542 let cursor = std::io::Cursor::new(body.as_ref());
543 let reader = StreamReader::try_new(cursor, None)
544 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
545 let schema = reader.schema();
546 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
547 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
548 } else {
549 return Err(Error::Decode(
550 "missing response_header metadata in arrow payload".to_string(),
551 ));
552 };
553 for batch in reader {
554 let batch = batch
555 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
556 batches.push(batch);
557 }
558 body = json_body
559 };
560 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
561 if let Error::Logic(status, ec) = &e {
562 if *status == 404 {
563 return Error::QueryNotFound(ec.message.clone());
564 }
565 }
566 e
567 })?;
568 self.handle_session(&resp.session).await;
569 if let Some(err) = &resp.error {
570 return Err(Error::QueryFailed(err.clone()));
571 }
572 if is_first {
573 self.handle_warnings(&resp);
574 self.set_last_query_id(Some(resp.id.clone()));
575 if let Some(node_id) = &resp.node_id {
576 self.set_last_node_id(node_id.clone());
577 }
578 }
579 Ok((resp, batches))
580 }
581
582 pub async fn query_page(
583 &self,
584 query_id: &str,
585 next_uri: &str,
586 node_id: &Option<String>,
587 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
588 info!("query page: {next_uri}");
589 let endpoint = self.endpoint.join(next_uri)?;
590 let mut headers = self.make_headers(Some(query_id))?;
591 if self.capability.arrow_data && self.body_format == "arrow" {
592 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
593 }
594 let mut builder = self.cli.get(endpoint.clone());
595 builder = self
596 .wrap_auth_or_session_token(builder)?
597 .headers(headers.clone())
598 .timeout(self.page_request_timeout);
599 if let Some(node_id) = node_id {
600 builder = builder.header(HEADER_STICKY_NODE, node_id)
601 }
602 let request = builder.build()?;
603
604 let response = self.query_request_helper(request, false, true).await?;
605 self.handle_page(response, false).await
606 }
607
608 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
609 self.end_query(query_id, "kill", None).await
610 }
611
612 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
613 self.end_query(query_id, "final", node_id).await
614 }
615
616 pub async fn end_query(
617 &self,
618 query_id: &str,
619 method: &str,
620 node_id: Option<&str>,
621 ) -> Result<()> {
622 let uri = format!("/v1/query/{query_id}/{method}");
623 let endpoint = self.endpoint.join(&uri)?;
624 let headers = self.make_headers(Some(query_id))?;
625
626 info!("{method} query: {uri}");
627
628 let mut builder = self.cli.post(endpoint);
629 if let Some(node_id) = node_id {
630 builder = builder.header(HEADER_STICKY_NODE, node_id)
631 }
632 builder = self.wrap_auth_or_session_token(builder)?;
633 let resp = builder.headers(headers.clone()).send().await?;
634 if resp.status() != 200 {
635 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
636 .with_context(&format!("{method} query")));
637 }
638 Ok(())
639 }
640
641 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
642 self.query_all_inner(sql, false).await
643 }
644
645 pub async fn query_all_inner(
646 self: &Arc<Self>,
647 sql: &str,
648 force_json_body: bool,
649 ) -> Result<Page> {
650 let (resp, batches) = self.start_query_inner(sql, None, force_json_body).await?;
651 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
652 let mut all = Page::default();
653 while let Some(page) = pages.next().await {
654 all.update(page?);
655 }
656 Ok(all)
657 }
658
659 fn session_state(&self) -> SessionState {
660 self.session_state.lock().clone()
661 }
662
663 fn make_pagination(&self) -> Option<PaginationConfig> {
664 if self.wait_time_secs.is_none()
665 && self.max_rows_in_buffer.is_none()
666 && self.max_rows_per_page.is_none()
667 {
668 return None;
669 }
670 let mut pagination = PaginationConfig {
671 wait_time_secs: None,
672 max_rows_in_buffer: None,
673 max_rows_per_page: None,
674 };
675 if let Some(wait_time_secs) = self.wait_time_secs {
676 pagination.wait_time_secs = Some(wait_time_secs);
677 }
678 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
679 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
680 }
681 if let Some(max_rows_per_page) = self.max_rows_per_page {
682 pagination.max_rows_per_page = Some(max_rows_per_page);
683 }
684 Some(pagination)
685 }
686
687 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
688 let mut headers = HeaderMap::new();
689 if let Some(tenant) = &self.tenant {
690 headers.insert(HEADER_TENANT, tenant.parse()?);
691 }
692 let warehouse = self.warehouse.lock().clone();
693 if let Some(warehouse) = warehouse {
694 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
695 }
696 let route_hint = self.route_hint.current();
697 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
698 if let Some(query_id) = query_id {
699 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
700 }
701 Ok(headers)
702 }
703
704 pub async fn insert_with_stage(
705 self: &Arc<Self>,
706 sql: &str,
707 stage: &str,
708 file_format_options: BTreeMap<&str, &str>,
709 copy_options: BTreeMap<&str, &str>,
710 ) -> Result<QueryStats> {
711 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
712 let stage_attachment = Some(StageAttachmentConfig {
713 location: stage,
714 file_format_options: Some(file_format_options),
715 copy_options: Some(copy_options),
716 });
717 let (resp, batches) = self.start_query_inner(sql, stage_attachment, true).await?;
718 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
719 let mut all = Page::default();
720 while let Some(page) = pages.next().await {
721 all.update(page?);
722 }
723 Ok(all.stats)
724 }
725
726 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
727 info!("get presigned upload url: {stage}");
728 let sql = format!("PRESIGN UPLOAD {stage}");
729 let resp = self.query_all_inner(&sql, true).await?;
730 if resp.data.len() != 1 {
731 return Err(Error::Decode(
732 "Empty response from server for presigned request".to_string(),
733 ));
734 }
735 if resp.data[0].len() != 3 {
736 return Err(Error::Decode(
737 "Invalid response from server for presigned request".to_string(),
738 ));
739 }
740 let method = resp.data[0][0].clone().unwrap_or_default();
742 if method != "PUT" {
743 return Err(Error::Decode(format!(
744 "Invalid method for presigned upload request: {method}"
745 )));
746 }
747 let headers: BTreeMap<String, String> =
748 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
749 let url = resp.data[0][2].clone().unwrap_or_default();
750 Ok(PresignedResponse {
751 method,
752 headers,
753 url,
754 })
755 }
756
757 pub async fn upload_to_stage(
758 self: &Arc<Self>,
759 stage: &str,
760 data: Reader,
761 size: u64,
762 ) -> Result<()> {
763 match self.get_presign_mode() {
764 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
765 PresignMode::On => {
766 let presigned = self.get_presigned_upload_url(stage).await?;
767 presign_upload_to_stage(presigned, data, size).await
768 }
769 PresignMode::Auto => {
770 unreachable!("PresignMode::Auto should be handled during client initialization")
771 }
772 PresignMode::Detect => {
773 unreachable!("PresignMode::Detect should be handled during client initialization")
774 }
775 }
776 }
777
778 async fn upload_to_stage_with_stream(
780 &self,
781 stage: &str,
782 data: Reader,
783 size: u64,
784 ) -> Result<()> {
785 info!("upload to stage with stream: {stage}, size: {size}");
786 if let Some(info) = self.need_pre_refresh_session().await {
787 self.refresh_session_token(info).await?;
788 }
789 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
790 let location = StageLocation::try_from(stage)?;
791 let query_id = self.gen_query_id();
792 let mut headers = self.make_headers(Some(&query_id))?;
793 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
794 let stream = Body::wrap_stream(ReaderStream::new(data));
795 let part = Part::stream_with_length(stream, size).file_name(location.path);
796 let form = Form::new().part("upload", part);
797 let mut builder = self.cli.put(endpoint.clone());
798 builder = self.wrap_auth_or_session_token(builder)?;
799 let resp = builder.headers(headers).multipart(form).send().await?;
800 let status = resp.status();
801 if status != 200 {
802 return Err(
803 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
804 );
805 }
806 Ok(())
807 }
808
809 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
812 where
813 T: de::DeserializeOwned,
814 {
815 if value.starts_with("{") {
816 serde_json::from_slice(value.as_bytes())
817 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
818 } else {
819 let json = URL_SAFE.decode(value).map_err(|e| {
820 format!(
821 "Invalid value {} for {key}, base64 decode error: {}",
822 value, e
823 )
824 })?;
825 serde_json::from_slice(&json).map_err(|e| {
826 format!(
827 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
828 String::from_utf8_lossy(&json)
829 )
830 })
831 }
832 }
833
834 pub async fn streaming_load(
835 &self,
836 sql: &str,
837 data: Reader,
838 file_name: &str,
839 ) -> Result<LoadResponse> {
840 let body = Body::wrap_stream(ReaderStream::new(data));
841 let part = Part::stream(body).file_name(file_name.to_string());
842 let endpoint = self.endpoint.join("v1/streaming_load")?;
843 let mut builder = self.cli.put(endpoint.clone());
844 builder = self.wrap_auth_or_session_token(builder)?;
845 let query_id = self.gen_query_id();
846 let mut headers = self.make_headers(Some(&query_id))?;
847 headers.insert(HEADER_SQL, sql.parse()?);
848 let session = serde_json::to_string(&*self.session_state.lock())
849 .expect("serialize session state should not fail");
850 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
851 let form = Form::new().part("upload", part);
852 let resp = builder.headers(headers).multipart(form).send().await?;
853 let status = resp.status();
854 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
855 match Self::decode_json_header::<SessionState>(
856 HEADER_QUERY_CONTEXT,
857 value.to_str().unwrap(),
858 ) {
859 Ok(session) => *self.session_state.lock() = session,
860 Err(e) => {
861 error!("Error decoding session state when streaming load: {e}");
862 }
863 }
864 };
865 if status != 200 {
866 return Err(
867 Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
868 );
869 }
870 let resp = resp.json::<LoadResponse>().await?;
871 Ok(resp)
872 }
873
874 async fn login(&mut self) -> Result<()> {
875 let endpoint = self.endpoint.join("/v1/session/login")?;
876 let headers = self.make_headers(None)?;
877 let body = LoginRequest::from(&*self.session_state.lock());
878 let mut builder = self.cli.post(endpoint.clone()).json(&body);
879 if self.disable_session_token {
880 builder = builder.query(&[("disable_session_token", true)]);
881 }
882 let builder = self.auth.wrap(builder)?;
883 let request = builder
884 .headers(headers.clone())
885 .timeout(self.connect_timeout)
886 .build()?;
887 let response = self.query_request_helper(request, true, false).await;
888 let response = match response {
889 Ok(r) => r,
890 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
891 info!("login return 404, skip login on the old version server");
892 return Ok(());
893 }
894 Err(e) => return Err(e),
895 };
896 if let Some(v) = response.headers().get(HEADER_SESSION_ID) {
897 if let Ok(s) = v.to_str() {
898 self.session_id = s.to_string();
899 }
900 }
901
902 let body = response.bytes().await?;
903 let response = json_from_slice(&body)?;
904 match response {
905 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
906 LoginResponseResult::Ok(info) => {
907 let server_version = info
908 .version
909 .parse()
910 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
911 self.capability = Capability::from_server_version(&server_version);
912 self.server_version = Some(server_version.clone());
913 let session_id = self.session_id.as_str();
914 if let Some(tokens) = info.tokens {
915 info!(
916 "[session {session_id}] login success with session token version = {server_version}",
917 );
918 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
919 } else {
920 info!("[session {session_id}] login success, version = {server_version}");
921 }
922 }
923 }
924 Ok(())
925 }
926
927 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
928 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
929 let queries = self.queries_need_heartbeat.lock().clone();
930 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
931 let now = Instant::now();
932
933 let mut query_ids = Vec::new();
934 for (qid, state) in queries {
935 if state.need_heartbeat(now) {
936 query_ids.push(qid.to_string());
937 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
938 arr.push(qid);
939 } else {
940 node_to_queries.insert(state.node_id, vec![qid]);
941 }
942 }
943 }
944
945 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
946 {
947 return Ok(());
948 }
949
950 let body = json!({
951 "node_to_queries": node_to_queries
952 });
953 let builder = self.cli.post(endpoint.clone()).json(&body);
954 let request = self.wrap_auth_or_session_token(builder)?.build()?;
955 let response = self.query_request_helper(request, true, false).await?;
956 let json: Value = response.json().await?;
957 let session_id = self.session_id.as_str();
958 info!("[session {session_id}] heartbeat request={body}, response={json}");
959 if let Some(queries_to_remove) = json.get("queries_to_remove") {
960 if let Some(arr) = queries_to_remove.as_array() {
961 if !arr.is_empty() {
962 let mut queries = self.queries_need_heartbeat.lock();
963 for q in arr {
964 if let Some(q) = q.as_str() {
965 queries.remove(q);
966 }
967 }
968 }
969 }
970 }
971 let now = Instant::now();
972 let mut queries = self.queries_need_heartbeat.lock();
973 for qid in query_ids {
974 if let Some(state) = queries.get_mut(&qid) {
975 *state.last_access_time.lock() = now;
976 }
977 }
978 Ok(())
979 }
980
981 fn build_log_out_request(&self) -> Result<Request> {
982 let endpoint = self.endpoint.join("/v1/session/logout")?;
983
984 let session_state = self.session_state();
985 let need_sticky = session_state.need_sticky.unwrap_or(false);
986 let mut headers = self.make_headers(None)?;
987 if need_sticky {
988 if let Some(node_id) = self.last_node_id() {
989 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
990 }
991 }
992 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
993
994 let builder = self.wrap_auth_or_session_token(builder)?;
995 let req = builder.build()?;
996 Ok(req)
997 }
998
999 pub(crate) fn need_logout(&self) -> bool {
1000 self.session_token_info.is_some()
1001 || self.session_state.lock().need_keep_alive.unwrap_or(false)
1002 }
1003
1004 async fn refresh_session_token(
1005 &self,
1006 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
1007 ) -> Result<()> {
1008 let (session_token_info, _) = { self_login_info.lock().clone() };
1009 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1010 let body = RefreshSessionTokenRequest {
1011 session_token: session_token_info.session_token.clone(),
1012 };
1013 let headers = self.make_headers(None)?;
1014 let request = self
1015 .cli
1016 .post(endpoint.clone())
1017 .json(&body)
1018 .headers(headers.clone())
1019 .bearer_auth(session_token_info.refresh_token.clone())
1020 .timeout(self.connect_timeout)
1021 .build()?;
1022
1023 for i in 0..3 {
1025 let req = request.try_clone().expect("request not cloneable");
1026 match self.cli.execute(req).await {
1027 Ok(response) => {
1028 let status = response.status();
1029 let body = response.bytes().await?;
1030 if status == StatusCode::OK {
1031 let response = json_from_slice(&body)?;
1032 return match response {
1033 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1034 RefreshResponse::Ok(info) => {
1035 *self_login_info.lock() = (info, Instant::now());
1036 Ok(())
1037 }
1038 };
1039 }
1040 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
1041 return Err(Error::response_error(status, &body));
1042 }
1043 }
1044 Err(err) => {
1045 if !(err.is_timeout() || err.is_connect()) || i > 2 {
1046 return Err(Error::Request(err.to_string()));
1047 }
1048 }
1049 };
1050 sleep(jitter(Duration::from_secs(10))).await;
1051 }
1052 Ok(())
1053 }
1054
1055 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1056 if let Some(info) = &self.session_token_info {
1057 let (start, ttl) = {
1058 let guard = info.lock();
1059 (guard.1, guard.0.session_token_ttl_in_secs)
1060 };
1061 if Instant::now() > start + Duration::from_secs(ttl) {
1062 return Some(info.clone());
1063 }
1064 }
1065 None
1066 }
1067
1068 async fn query_request_helper(
1076 &self,
1077 mut request: Request,
1078 retry_if_503: bool,
1079 refresh_if_401: bool,
1080 ) -> std::result::Result<Response, Error> {
1081 let mut refreshed = false;
1082 let mut retries = 0;
1083 loop {
1084 let req = request.try_clone().expect("request not cloneable");
1085 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
1086 Ok(response) => {
1087 let status = response.status();
1088 if status == StatusCode::OK {
1089 return Ok(response);
1090 }
1091 let body = response.bytes().await?;
1092 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1093 (Error::response_error(status, &body), true)
1095 } else {
1096 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
1097 match resp {
1098 Ok(r) => {
1099 let e = r.error;
1100 if status == StatusCode::UNAUTHORIZED {
1101 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1102 if let Some(session_token_info) = &self.session_token_info {
1103 info!(
1104 "will retry {} after refresh token on auth error {}",
1105 request.url(),
1106 e
1107 );
1108 let retry = if need_refresh_token(e.code)
1109 && !refreshed
1110 && refresh_if_401
1111 {
1112 self.refresh_session_token(session_token_info.clone())
1113 .await?;
1114 refreshed = true;
1115 true
1116 } else {
1117 false
1118 };
1119 (Error::AuthFailure(e), retry)
1120 } else if self.auth.can_reload() {
1121 info!(
1122 "will retry {} after reload token on auth error {}",
1123 request.url(),
1124 e
1125 );
1126 let builder = RequestBuilder::from_parts(
1127 HttpClient::new(),
1128 request.try_clone().unwrap(),
1129 );
1130 let builder = self.auth.wrap(builder)?;
1131 request = builder.build()?;
1132 (Error::AuthFailure(e), true)
1133 } else {
1134 (Error::AuthFailure(e), false)
1135 }
1136 } else {
1137 (Error::Logic(status, e), false)
1138 }
1139 }
1140 Err(_) => (
1141 Error::Response {
1142 status,
1143 msg: String::from_utf8_lossy(&body).to_string(),
1144 },
1145 false,
1146 ),
1147 }
1148 }
1149 }
1150 Err(err) => (
1151 Error::Request(err.to_string()),
1152 err.is_timeout() || err.is_connect(),
1153 ),
1154 };
1155 if !retry {
1156 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
1157 }
1158 match &err {
1159 Error::AuthFailure(_) => {
1160 if refreshed {
1161 retries = 0;
1162 } else if retries == 2 {
1163 return Err(err.with_context(&format!(
1164 "{} {} after 3 reties",
1165 request.method(),
1166 request.url()
1167 )));
1168 }
1169 }
1170 _ => {
1171 if retries == 2 {
1172 return Err(err.with_context(&format!(
1173 "{} {} after 3 reties",
1174 request.method(),
1175 request.url()
1176 )));
1177 }
1178 retries += 1;
1179 info!(
1180 "will retry {} the {retries}th times on error {}",
1181 request.url(),
1182 err
1183 );
1184 }
1185 }
1186 sleep(jitter(Duration::from_secs(10))).await;
1187 }
1188 }
1189
1190 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1191 if let Err(err) = http_client.execute(request).await {
1192 error!("[session {session_id}] logout request failed: {err}");
1193 } else {
1194 info!("[session {session_id}] logout success");
1195 };
1196 }
1197
1198 pub async fn close(&self) {
1199 let session_id = &self.session_id;
1200 info!("[session {session_id}] try closing now");
1201 if self
1202 .closed
1203 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1204 .is_ok()
1205 {
1206 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1207 if self.need_logout() {
1208 let cli = self.cli.clone();
1209 let req = self
1210 .build_log_out_request()
1211 .expect("failed to build logout request");
1212 Self::logout(cli, req, &self.session_id).await;
1213 }
1214 }
1215 }
1216 pub fn close_with_spawn(&self) {
1217 let session_id = &self.session_id;
1218 info!("[session {session_id}]: try closing with spawn");
1219 if self
1220 .closed
1221 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1222 .is_ok()
1223 {
1224 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1225 if self.need_logout() {
1226 let cli = self.cli.clone();
1227 let req = self
1228 .build_log_out_request()
1229 .expect("failed to build logout request");
1230 let session_id = self.session_id.clone();
1231 GLOBAL_RUNTIME.spawn(async move {
1232 Self::logout(cli, req, session_id.as_str()).await;
1233 });
1234 }
1235 }
1236 }
1237
1238 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1239 let mut queries = self.queries_need_heartbeat.lock();
1240 queries.insert(query_id.to_string(), state);
1241 }
1242}
1243
1244fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1245where
1246 T: Deserialize<'a>,
1247{
1248 serde_json::from_slice::<T>(body).map_err(|e| {
1249 Error::Decode(format!(
1250 "fail to decode JSON response: {e}, body: {}",
1251 String::from_utf8_lossy(body)
1252 ))
1253 })
1254}
1255
1256impl Default for APIClient {
1257 fn default() -> Self {
1258 Self {
1259 session_id: Default::default(),
1260 cli: HttpClient::new(),
1261 scheme: "http".to_string(),
1262 endpoint: Url::parse("http://localhost:8080").unwrap(),
1263 host: "localhost".to_string(),
1264 port: 8000,
1265 tenant: None,
1266 warehouse: Mutex::new(None),
1267 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1268 session_state: Mutex::new(SessionState::default()),
1269 wait_time_secs: None,
1270 max_rows_in_buffer: None,
1271 max_rows_per_page: None,
1272 connect_timeout: Duration::from_secs(10),
1273 page_request_timeout: Duration::from_secs(30),
1274 tls_ca_file: None,
1275 presign: Mutex::new(PresignMode::Auto),
1276 route_hint: RouteHintGenerator::new(),
1277 last_node_id: Default::default(),
1278 disable_session_token: true,
1279 disable_login: false,
1280 body_format: "json".to_string(),
1281 session_token_info: None,
1282 closed: AtomicBool::new(false),
1283 last_query_id: Default::default(),
1284 server_version: None,
1285 capability: Default::default(),
1286 queries_need_heartbeat: Default::default(),
1287 }
1288 }
1289}
1290
1291struct RouteHintGenerator {
1292 nonce: AtomicU64,
1293 current: std::sync::Mutex<String>,
1294}
1295
1296impl RouteHintGenerator {
1297 fn new() -> Self {
1298 let gen = Self {
1299 nonce: AtomicU64::new(0),
1300 current: std::sync::Mutex::new("".to_string()),
1301 };
1302 gen.next();
1303 gen
1304 }
1305
1306 fn current(&self) -> String {
1307 let guard = self.current.lock().unwrap();
1308 guard.clone()
1309 }
1310
1311 fn set(&self, hint: &str) {
1312 let mut guard = self.current.lock().unwrap();
1313 *guard = hint.to_string();
1314 }
1315
1316 fn next(&self) -> String {
1317 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1318 let uuid = uuid::Uuid::new_v4();
1319 let current = format!("rh:{uuid}:{nonce:06}");
1320 let mut guard = self.current.lock().unwrap();
1321 guard.clone_from(¤t);
1322 current
1323 }
1324}
1325
#[cfg(test)]
mod test {
    use super::*;

    /// Query parameters in the DSN end up in the matching client settings.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        // Connection target.
        assert_eq!(client.host, "app.databend.com");
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(client.endpoint, expected_endpoint);

        // Paging and buffering options.
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));

        // Tenant is absent from the DSN; warehouse is present.
        assert_eq!(client.tenant, None);
        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(*warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password does not disturb host and port parsing.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// A password with raw special characters does not disturb host and port
    /// parsing.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}