1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::client_mgr::{GLOBAL_CLIENT_MANAGER, GLOBAL_RUNTIME};
18use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
19use crate::global_cookie_store::GlobalCookieStore;
20use crate::login::{
21 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
22 SessionTokenInfo,
23};
24use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
25use crate::response::LoadResponse;
26use crate::stage::StageLocation;
27use crate::{
28 error::{Error, Result},
29 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
30 response::QueryResponse,
31 session::SessionState,
32 QueryStats,
33};
34use crate::{Page, Pages};
35use arrow_array::RecordBatch;
36use arrow_ipc::reader::StreamReader;
37use base64::engine::general_purpose::URL_SAFE;
38use base64::Engine;
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, CONTENT_TYPE};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::{de, Deserialize};
49use serde_json::{json, Value};
50use std::collections::{BTreeMap, HashMap};
51use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
52use std::sync::Arc;
53use std::time::{Duration, Instant};
54use tokio::time::sleep;
55use tokio_retry::strategy::jitter;
56use tokio_stream::StreamExt;
57use tokio_util::io::ReaderStream;
58use url::Url;
59
// Names of the custom HTTP headers exchanged with the server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";
const HEADER_SESSION_ID: &str = "X-DATABEND-SESSION-ID";

// Value of the session's `txn_state` that marks an open transaction
// (compared case-insensitively by the client).
const TXN_STATE_ACTIVE: &str = "Active";

// `Content-Type` that identifies an Arrow IPC stream response body.
const CONTENT_TYPE_ARROW: &str = "application/vnd.apache.arrow.stream";
// Value sent in the `Accept` header when Arrow results are requested.
// NOTE(review): identical to `CONTENT_TYPE_ARROW` even though the name suggests
// JSON is also advertised - confirm this is intentional.
const CONTENT_TYPE_ARROW_OR_JSON: &str = "application/vnd.apache.arrow.stream";
72
73static VERSION: Lazy<String> = Lazy::new(|| {
74 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
75 version.to_string()
76});
77
78#[derive(Clone)]
79pub(crate) struct QueryState {
80 pub node_id: String,
81 pub last_access_time: Arc<Mutex<Instant>>,
82 pub timeout_secs: u64,
83}
84
85impl QueryState {
86 pub fn need_heartbeat(&self, now: Instant) -> bool {
87 let t = *self.last_access_time.lock();
88 now.duration_since(t).as_secs() > self.timeout_secs / 2
89 }
90}
91
92pub struct APIClient {
93 pub(crate) session_id: String,
94 cli: HttpClient,
95 scheme: String,
96 host: String,
97 port: u16,
98
99 endpoint: Url,
100
101 auth: Arc<dyn Auth>,
102
103 tenant: Option<String>,
104 warehouse: Mutex<Option<String>>,
105 session_state: Mutex<SessionState>,
106 route_hint: RouteHintGenerator,
107
108 disable_login: bool,
109 body_format: String,
110 disable_session_token: bool,
111 session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,
112
113 closed: AtomicBool,
114
115 server_version: Option<Version>,
116
117 wait_time_secs: Option<i64>,
118 max_rows_in_buffer: Option<i64>,
119 max_rows_per_page: Option<i64>,
120
121 connect_timeout: Duration,
122 page_request_timeout: Duration,
123
124 tls_ca_file: Option<String>,
125
126 presign: Mutex<PresignMode>,
127 last_node_id: Mutex<Option<String>>,
128 last_query_id: Mutex<Option<String>>,
129
130 capability: Capability,
131
132 queries_need_heartbeat: Mutex<HashMap<String, QueryState>>,
133}
134
135impl Drop for APIClient {
136 fn drop(&mut self) {
137 self.close_with_spawn()
138 }
139}
140
141impl APIClient {
142 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
143 let mut client = Self::from_dsn(dsn).await?;
144 client.build_client(name).await?;
145 if !client.disable_login {
146 client.login().await?;
147 }
148 if client.session_id.is_empty() {
149 client.session_id = format!("no_login_{}", uuid::Uuid::new_v4());
150 }
151 let client = Arc::new(client);
152 client.check_presign().await?;
153 GLOBAL_CLIENT_MANAGER.register_client(client.clone()).await;
154 Ok(client)
155 }
156
157 pub fn capability(&self) -> &Capability {
158 &self.capability
159 }
160
161 fn set_presign_mode(&self, mode: PresignMode) {
162 *self.presign.lock() = mode
163 }
164 fn get_presign_mode(&self) -> PresignMode {
165 *self.presign.lock()
166 }
167
168 async fn from_dsn(dsn: &str) -> Result<Self> {
169 let u = Url::parse(dsn)?;
170 let mut client = Self::default();
171 if let Some(host) = u.host_str() {
172 client.host = host.to_string();
173 }
174
175 if u.username() != "" {
176 let password = u.password().unwrap_or_default();
177 let password = percent_decode_str(password).decode_utf8()?;
178 client.auth = Arc::new(BasicAuth::new(u.username(), password));
179 }
180 let database = match u.path().trim_start_matches('/') {
181 "" => None,
182 s => Some(s.to_string()),
183 };
184 let mut role = None;
185 let mut scheme = "https";
186 let mut session_settings = BTreeMap::new();
187 for (k, v) in u.query_pairs() {
188 match k.as_ref() {
189 "wait_time_secs" => {
190 client.wait_time_secs = Some(v.parse()?);
191 }
192 "max_rows_in_buffer" => {
193 client.max_rows_in_buffer = Some(v.parse()?);
194 }
195 "max_rows_per_page" => {
196 client.max_rows_per_page = Some(v.parse()?);
197 }
198 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
199 "page_request_timeout_secs" => {
200 client.page_request_timeout = {
201 let secs: u64 = v.parse()?;
202 Duration::from_secs(secs)
203 };
204 }
205 "presign" => {
206 let presign_mode = match v.as_ref() {
207 "auto" => PresignMode::Auto,
208 "detect" => PresignMode::Detect,
209 "on" => PresignMode::On,
210 "off" => PresignMode::Off,
211 _ => {
212 return Err(Error::BadArgument(format!(
213 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
214 )))
215 }
216 };
217 client.set_presign_mode(presign_mode);
218 }
219 "tenant" => {
220 client.tenant = Some(v.to_string());
221 }
222 "warehouse" => {
223 client.warehouse = Mutex::new(Some(v.to_string()));
224 }
225 "role" => role = Some(v.to_string()),
226 "sslmode" => match v.as_ref() {
227 "disable" => scheme = "http",
228 "require" | "enable" => scheme = "https",
229 _ => {
230 return Err(Error::BadArgument(format!(
231 "Invalid value for sslmode: {v}"
232 )))
233 }
234 },
235 "tls_ca_file" => {
236 client.tls_ca_file = Some(v.to_string());
237 }
238 "access_token" => {
239 client.auth = Arc::new(AccessTokenAuth::new(v));
240 }
241 "access_token_file" => {
242 client.auth = Arc::new(AccessTokenFileAuth::new(v));
243 }
244 "login" => {
245 client.disable_login = match v.as_ref() {
246 "disable" => true,
247 "enable" => false,
248 _ => {
249 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
250 }
251 }
252 }
253 "session_token" => {
254 client.disable_session_token = match v.as_ref() {
255 "disable" => true,
256 "enable" => false,
257 _ => {
258 return Err(Error::BadArgument(format!(
259 "Invalid value for session_token: {v}"
260 )))
261 }
262 }
263 }
264 "body_format" => {
265 let v = v.to_string().to_lowercase();
266 match v.as_str() {
267 "json" | "arrow" => client.body_format = v.to_string(),
268 _ => {
269 return Err(Error::BadArgument(format!(
270 "Invalid value for body_format: {v}"
271 )))
272 }
273 }
274 }
275 _ => {
276 session_settings.insert(k.to_string(), v.to_string());
277 }
278 }
279 }
280 client.port = match u.port() {
281 Some(p) => p,
282 None => match scheme {
283 "http" => 80,
284 "https" => 443,
285 _ => unreachable!(),
286 },
287 };
288 client.scheme = scheme.to_string();
289
290 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
291 client.session_state = Mutex::new(
292 SessionState::default()
293 .with_settings(Some(session_settings))
294 .with_role(role)
295 .with_database(database),
296 );
297
298 Ok(client)
299 }
300
301 pub fn host(&self) -> &str {
302 self.host.as_str()
303 }
304
305 pub fn port(&self) -> u16 {
306 self.port
307 }
308
309 pub fn scheme(&self) -> &str {
310 self.scheme.as_str()
311 }
312
313 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
314 let ua = match name {
315 Some(n) => n,
316 None => format!("databend-client-rust/{}", VERSION.as_str()),
317 };
318 let cookie_provider = GlobalCookieStore::new();
319 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
320 let mut initial_cookies = [&cookie].into_iter();
321 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
322 let mut cli_builder = HttpClient::builder()
323 .user_agent(ua)
324 .cookie_provider(Arc::new(cookie_provider))
325 .pool_idle_timeout(Duration::from_secs(1));
326 #[cfg(any(feature = "rustls", feature = "native-tls"))]
327 if self.scheme == "https" {
328 if let Some(ref ca_file) = self.tls_ca_file {
329 let cert_pem = tokio::fs::read(ca_file).await?;
330 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
331 cli_builder = cli_builder.add_root_certificate(cert);
332 }
333 }
334 self.cli = cli_builder.build()?;
335 Ok(())
336 }
337
338 async fn check_presign(self: &Arc<Self>) -> Result<()> {
339 let mode = match self.get_presign_mode() {
340 PresignMode::Auto => {
341 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
342 PresignMode::On
343 } else {
344 PresignMode::Off
345 }
346 }
347 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
348 Ok(_) => PresignMode::On,
349 Err(e) => {
350 warn!("presign mode off with error detected: {e}");
351 PresignMode::Off
352 }
353 },
354 mode => mode,
355 };
356 self.set_presign_mode(mode);
357 Ok(())
358 }
359
360 pub fn current_warehouse(&self) -> Option<String> {
361 let guard = self.warehouse.lock();
362 guard.clone()
363 }
364
365 pub fn current_catalog(&self) -> Option<String> {
366 let guard = self.session_state.lock();
367 guard.catalog.clone()
368 }
369
370 pub fn current_database(&self) -> Option<String> {
371 let guard = self.session_state.lock();
372 guard.database.clone()
373 }
374
375 pub async fn current_role(&self) -> Option<String> {
376 let guard = self.session_state.lock();
377 guard.role.clone()
378 }
379
380 fn in_active_transaction(&self) -> bool {
381 let guard = self.session_state.lock();
382 guard
383 .txn_state
384 .as_ref()
385 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
386 .unwrap_or(false)
387 }
388
389 pub fn username(&self) -> String {
390 self.auth.username()
391 }
392
393 fn gen_query_id(&self) -> String {
394 uuid::Uuid::now_v7().simple().to_string()
395 }
396
397 async fn handle_session(&self, session: &Option<SessionState>) {
398 let session = match session {
399 Some(session) => session,
400 None => return,
401 };
402
403 {
405 let mut session_state = self.session_state.lock();
406 *session_state = session.clone();
407 }
408
409 if let Some(settings) = session.settings.as_ref() {
411 if let Some(v) = settings.get("warehouse") {
412 let mut warehouse = self.warehouse.lock();
413 *warehouse = Some(v.clone());
414 }
415 }
416 }
417
418 pub fn set_last_node_id(&self, node_id: String) {
419 *self.last_node_id.lock() = Some(node_id)
420 }
421
422 pub fn set_last_query_id(&self, query_id: Option<String>) {
423 *self.last_query_id.lock() = query_id
424 }
425
426 pub fn last_query_id(&self) -> Option<String> {
427 self.last_query_id.lock().clone()
428 }
429
430 fn last_node_id(&self) -> Option<String> {
431 self.last_node_id.lock().clone()
432 }
433
434 fn handle_warnings(&self, resp: &QueryResponse) {
435 if let Some(warnings) = &resp.warnings {
436 for w in warnings {
437 warn!(target: "server_warnings", "server warning: {w}");
438 }
439 }
440 }
441
442 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
443 info!("start query: {sql}");
444 let (resp, batches) = self.start_query_inner(sql, None).await?;
445 Pages::new(self.clone(), resp, batches, need_progress)
446 }
447
448 pub fn finalize_query(self: &Arc<Self>, query_id: &str) {
449 let mut mgr = self.queries_need_heartbeat.lock();
450 if let Some(state) = mgr.remove(query_id) {
451 let self_cloned = self.clone();
452 let query_id = query_id.to_owned();
453 GLOBAL_RUNTIME.spawn(async move {
454 if let Err(e) = self_cloned
455 .end_query(&query_id, "final", Some(state.node_id.as_str()))
456 .await
457 {
458 error!("failed to final query {query_id}: {e}");
459 }
460 });
461 }
462 }
463
464 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
465 if let Some(info) = &self.session_token_info {
466 let info = info.lock();
467 Ok(builder.bearer_auth(info.0.session_token.clone()))
468 } else {
469 self.auth.wrap(builder)
470 }
471 }
472
473 async fn start_query_inner(
474 &self,
475 sql: &str,
476 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
477 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
478 if !self.in_active_transaction() {
479 self.route_hint.next();
480 }
481 let endpoint = self.endpoint.join("v1/query")?;
482
483 let session_state = self.session_state();
485 let need_sticky = session_state.need_sticky.unwrap_or(false);
486 let req = QueryRequest::new(sql)
487 .with_pagination(self.make_pagination())
488 .with_session(Some(session_state))
489 .with_stage_attachment(stage_attachment_config);
490
491 let query_id = self.gen_query_id();
493 let mut headers = self.make_headers(Some(&query_id))?;
494 if self.capability.arrow_data && self.body_format == "arrow" {
495 debug!("accept arrow data");
496 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
497 }
498
499 if need_sticky {
500 if let Some(node_id) = self.last_node_id() {
501 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
502 }
503 }
504 let mut builder = self.cli.post(endpoint.clone()).json(&req);
505 builder = self.wrap_auth_or_session_token(builder)?;
506 let request = builder.headers(headers.clone()).build()?;
507 let response = self.query_request_helper(request, true, true).await?;
508 self.handle_page(response, true).await
509 }
510
511 fn is_arrow_data(response: &Response) -> bool {
512 if let Some(typ) = response.headers().get(CONTENT_TYPE) {
513 if let Ok(t) = typ.to_str() {
514 return t == CONTENT_TYPE_ARROW;
515 }
516 }
517 false
518 }
519
520 async fn handle_page(
521 &self,
522 response: Response,
523 is_first: bool,
524 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
525 let status = response.status();
526 if status != 200 {
527 return Err(Error::response_error(status, &response.bytes().await?));
528 }
529 let is_arrow_data = Self::is_arrow_data(&response);
530 if is_first {
531 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
532 self.route_hint.set(route_hint.to_str().unwrap_or_default());
533 }
534 }
535 let mut body = response.bytes().await?;
536 let mut batches = vec![];
537 if is_arrow_data {
538 if is_first {
539 debug!("received arrow data");
540 }
541 let cursor = std::io::Cursor::new(body.as_ref());
542 let reader = StreamReader::try_new(cursor, None)
543 .map_err(|e| Error::Decode(format!("failed to decode arrow stream: {e}")))?;
544 let schema = reader.schema();
545 let json_body = if let Some(json_resp) = schema.metadata.get("response_header") {
546 bytes::Bytes::copy_from_slice(json_resp.as_bytes())
547 } else {
548 return Err(Error::Decode(
549 "missing response_header metadata in arrow payload".to_string(),
550 ));
551 };
552 for batch in reader {
553 let batch = batch
554 .map_err(|e| Error::Decode(format!("failed to decode arrow batch: {e}")))?;
555 batches.push(batch);
556 }
557 body = json_body
558 };
559 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
560 if let Error::Logic(status, ec) = &e {
561 if *status == 404 {
562 return Error::QueryNotFound(ec.message.clone());
563 }
564 }
565 e
566 })?;
567 self.handle_session(&resp.session).await;
568 if let Some(err) = &resp.error {
569 return Err(Error::QueryFailed(err.clone()));
570 }
571 if is_first {
572 self.handle_warnings(&resp);
573 self.set_last_query_id(Some(resp.id.clone()));
574 if let Some(node_id) = &resp.node_id {
575 self.set_last_node_id(node_id.clone());
576 }
577 }
578 Ok((resp, batches))
579 }
580
581 pub async fn query_page(
582 &self,
583 query_id: &str,
584 next_uri: &str,
585 node_id: &Option<String>,
586 ) -> Result<(QueryResponse, Vec<RecordBatch>)> {
587 info!("query page: {next_uri}");
588 let endpoint = self.endpoint.join(next_uri)?;
589 let mut headers = self.make_headers(Some(query_id))?;
590 if self.capability.arrow_data && self.body_format == "arrow" {
591 headers.insert(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_ARROW_OR_JSON));
592 }
593 let mut builder = self.cli.get(endpoint.clone());
594 builder = self
595 .wrap_auth_or_session_token(builder)?
596 .headers(headers.clone())
597 .timeout(self.page_request_timeout);
598 if let Some(node_id) = node_id {
599 builder = builder.header(HEADER_STICKY_NODE, node_id)
600 }
601 let request = builder.build()?;
602
603 let response = self.query_request_helper(request, false, true).await?;
604 self.handle_page(response, false).await
605 }
606
607 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
608 self.end_query(query_id, "kill", None).await
609 }
610
611 pub async fn final_query(&self, query_id: &str, node_id: Option<&str>) -> Result<()> {
612 self.end_query(query_id, "final", node_id).await
613 }
614
615 pub async fn end_query(
616 &self,
617 query_id: &str,
618 method: &str,
619 node_id: Option<&str>,
620 ) -> Result<()> {
621 let uri = format!("/v1/query/{query_id}/{method}");
622 let endpoint = self.endpoint.join(&uri)?;
623 let headers = self.make_headers(Some(query_id))?;
624
625 info!("{method} query: {uri}");
626
627 let mut builder = self.cli.post(endpoint);
628 if let Some(node_id) = node_id {
629 builder = builder.header(HEADER_STICKY_NODE, node_id)
630 }
631 builder = self.wrap_auth_or_session_token(builder)?;
632 let resp = builder.headers(headers.clone()).send().await?;
633 if resp.status() != 200 {
634 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
635 .with_context(&format!("{method} query")));
636 }
637 Ok(())
638 }
639
640 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
641 let mut pages = self.start_query(sql, false).await?;
642 let mut all = Page::default();
643 while let Some(page) = pages.next().await {
644 all.update(page?);
645 }
646 Ok(all)
647 }
648
649 fn session_state(&self) -> SessionState {
650 self.session_state.lock().clone()
651 }
652
653 fn make_pagination(&self) -> Option<PaginationConfig> {
654 if self.wait_time_secs.is_none()
655 && self.max_rows_in_buffer.is_none()
656 && self.max_rows_per_page.is_none()
657 {
658 return None;
659 }
660 let mut pagination = PaginationConfig {
661 wait_time_secs: None,
662 max_rows_in_buffer: None,
663 max_rows_per_page: None,
664 };
665 if let Some(wait_time_secs) = self.wait_time_secs {
666 pagination.wait_time_secs = Some(wait_time_secs);
667 }
668 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
669 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
670 }
671 if let Some(max_rows_per_page) = self.max_rows_per_page {
672 pagination.max_rows_per_page = Some(max_rows_per_page);
673 }
674 Some(pagination)
675 }
676
677 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
678 let mut headers = HeaderMap::new();
679 if let Some(tenant) = &self.tenant {
680 headers.insert(HEADER_TENANT, tenant.parse()?);
681 }
682 let warehouse = self.warehouse.lock().clone();
683 if let Some(warehouse) = warehouse {
684 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
685 }
686 let route_hint = self.route_hint.current();
687 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
688 if let Some(query_id) = query_id {
689 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
690 }
691 Ok(headers)
692 }
693
694 pub async fn insert_with_stage(
695 self: &Arc<Self>,
696 sql: &str,
697 stage: &str,
698 file_format_options: BTreeMap<&str, &str>,
699 copy_options: BTreeMap<&str, &str>,
700 ) -> Result<QueryStats> {
701 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
702 let stage_attachment = Some(StageAttachmentConfig {
703 location: stage,
704 file_format_options: Some(file_format_options),
705 copy_options: Some(copy_options),
706 });
707 let (resp, batches) = self.start_query_inner(sql, stage_attachment).await?;
708 let mut pages = Pages::new(self.clone(), resp, batches, false)?;
709 let mut all = Page::default();
710 while let Some(page) = pages.next().await {
711 all.update(page?);
712 }
713 Ok(all.stats)
714 }
715
716 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
717 info!("get presigned upload url: {stage}");
718 let sql = format!("PRESIGN UPLOAD {stage}");
719 let resp = self.query_all(&sql).await?;
720 if resp.data.len() != 1 {
721 return Err(Error::Decode(
722 "Empty response from server for presigned request".to_string(),
723 ));
724 }
725 if resp.data[0].len() != 3 {
726 return Err(Error::Decode(
727 "Invalid response from server for presigned request".to_string(),
728 ));
729 }
730 let method = resp.data[0][0].clone().unwrap_or_default();
732 if method != "PUT" {
733 return Err(Error::Decode(format!(
734 "Invalid method for presigned upload request: {method}"
735 )));
736 }
737 let headers: BTreeMap<String, String> =
738 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
739 let url = resp.data[0][2].clone().unwrap_or_default();
740 Ok(PresignedResponse {
741 method,
742 headers,
743 url,
744 })
745 }
746
747 pub async fn upload_to_stage(
748 self: &Arc<Self>,
749 stage: &str,
750 data: Reader,
751 size: u64,
752 ) -> Result<()> {
753 match self.get_presign_mode() {
754 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
755 PresignMode::On => {
756 let presigned = self.get_presigned_upload_url(stage).await?;
757 presign_upload_to_stage(presigned, data, size).await
758 }
759 PresignMode::Auto => {
760 unreachable!("PresignMode::Auto should be handled during client initialization")
761 }
762 PresignMode::Detect => {
763 unreachable!("PresignMode::Detect should be handled during client initialization")
764 }
765 }
766 }
767
768 async fn upload_to_stage_with_stream(
770 &self,
771 stage: &str,
772 data: Reader,
773 size: u64,
774 ) -> Result<()> {
775 info!("upload to stage with stream: {stage}, size: {size}");
776 if let Some(info) = self.need_pre_refresh_session().await {
777 self.refresh_session_token(info).await?;
778 }
779 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
780 let location = StageLocation::try_from(stage)?;
781 let query_id = self.gen_query_id();
782 let mut headers = self.make_headers(Some(&query_id))?;
783 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
784 let stream = Body::wrap_stream(ReaderStream::new(data));
785 let part = Part::stream_with_length(stream, size).file_name(location.path);
786 let form = Form::new().part("upload", part);
787 let mut builder = self.cli.put(endpoint.clone());
788 builder = self.wrap_auth_or_session_token(builder)?;
789 let resp = builder.headers(headers).multipart(form).send().await?;
790 let status = resp.status();
791 if status != 200 {
792 return Err(
793 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
794 );
795 }
796 Ok(())
797 }
798
799 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
802 where
803 T: de::DeserializeOwned,
804 {
805 if value.starts_with("{") {
806 serde_json::from_slice(value.as_bytes())
807 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
808 } else {
809 let json = URL_SAFE.decode(value).map_err(|e| {
810 format!(
811 "Invalid value {} for {key}, base64 decode error: {}",
812 value, e
813 )
814 })?;
815 serde_json::from_slice(&json).map_err(|e| {
816 format!(
817 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
818 String::from_utf8_lossy(&json)
819 )
820 })
821 }
822 }
823
824 pub async fn streaming_load(
825 &self,
826 sql: &str,
827 data: Reader,
828 file_name: &str,
829 ) -> Result<LoadResponse> {
830 let body = Body::wrap_stream(ReaderStream::new(data));
831 let part = Part::stream(body).file_name(file_name.to_string());
832 let endpoint = self.endpoint.join("v1/streaming_load")?;
833 let mut builder = self.cli.put(endpoint.clone());
834 builder = self.wrap_auth_or_session_token(builder)?;
835 let query_id = self.gen_query_id();
836 let mut headers = self.make_headers(Some(&query_id))?;
837 headers.insert(HEADER_SQL, sql.parse()?);
838 let session = serde_json::to_string(&*self.session_state.lock())
839 .expect("serialize session state should not fail");
840 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
841 let form = Form::new().part("upload", part);
842 let resp = builder.headers(headers).multipart(form).send().await?;
843 let status = resp.status();
844 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
845 match Self::decode_json_header::<SessionState>(
846 HEADER_QUERY_CONTEXT,
847 value.to_str().unwrap(),
848 ) {
849 Ok(session) => *self.session_state.lock() = session,
850 Err(e) => {
851 error!("Error decoding session state when streaming load: {e}");
852 }
853 }
854 };
855 if status != 200 {
856 return Err(
857 Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
858 );
859 }
860 let resp = resp.json::<LoadResponse>().await?;
861 Ok(resp)
862 }
863
864 async fn login(&mut self) -> Result<()> {
865 let endpoint = self.endpoint.join("/v1/session/login")?;
866 let headers = self.make_headers(None)?;
867 let body = LoginRequest::from(&*self.session_state.lock());
868 let mut builder = self.cli.post(endpoint.clone()).json(&body);
869 if self.disable_session_token {
870 builder = builder.query(&[("disable_session_token", true)]);
871 }
872 let builder = self.auth.wrap(builder)?;
873 let request = builder
874 .headers(headers.clone())
875 .timeout(self.connect_timeout)
876 .build()?;
877 let response = self.query_request_helper(request, true, false).await;
878 let response = match response {
879 Ok(r) => r,
880 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
881 info!("login return 404, skip login on the old version server");
882 return Ok(());
883 }
884 Err(e) => return Err(e),
885 };
886 if let Some(v) = response.headers().get(HEADER_SESSION_ID) {
887 if let Ok(s) = v.to_str() {
888 self.session_id = s.to_string();
889 }
890 }
891
892 let body = response.bytes().await?;
893 let response = json_from_slice(&body)?;
894 match response {
895 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
896 LoginResponseResult::Ok(info) => {
897 let server_version = info
898 .version
899 .parse()
900 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
901 self.capability = Capability::from_server_version(&server_version);
902 self.server_version = Some(server_version.clone());
903 let session_id = self.session_id.as_str();
904 if let Some(tokens) = info.tokens {
905 info!(
906 "[session {session_id}] login success with session token version = {server_version}",
907 );
908 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
909 } else {
910 info!("[session {session_id}] login success, version = {server_version}");
911 }
912 }
913 }
914 Ok(())
915 }
916
917 pub(crate) async fn try_heartbeat(&self) -> Result<()> {
918 let endpoint = self.endpoint.join("/v1/session/heartbeat")?;
919 let queries = self.queries_need_heartbeat.lock().clone();
920 let mut node_to_queries = HashMap::<String, Vec<String>>::new();
921 let now = Instant::now();
922
923 let mut query_ids = Vec::new();
924 for (qid, state) in queries {
925 if state.need_heartbeat(now) {
926 query_ids.push(qid.to_string());
927 if let Some(arr) = node_to_queries.get_mut(&state.node_id) {
928 arr.push(qid);
929 } else {
930 node_to_queries.insert(state.node_id, vec![qid]);
931 }
932 }
933 }
934
935 if node_to_queries.is_empty() && !self.session_state.lock().need_sticky.unwrap_or_default()
936 {
937 return Ok(());
938 }
939
940 let body = json!({
941 "node_to_queries": node_to_queries
942 });
943 let builder = self.cli.post(endpoint.clone()).json(&body);
944 let request = self.wrap_auth_or_session_token(builder)?.build()?;
945 let response = self.query_request_helper(request, true, false).await?;
946 let json: Value = response.json().await?;
947 let session_id = self.session_id.as_str();
948 info!("[session {session_id}] heartbeat request={body}, response={json}");
949 if let Some(queries_to_remove) = json.get("queries_to_remove") {
950 if let Some(arr) = queries_to_remove.as_array() {
951 if !arr.is_empty() {
952 let mut queries = self.queries_need_heartbeat.lock();
953 for q in arr {
954 if let Some(q) = q.as_str() {
955 queries.remove(q);
956 }
957 }
958 }
959 }
960 }
961 let now = Instant::now();
962 let mut queries = self.queries_need_heartbeat.lock();
963 for qid in query_ids {
964 if let Some(state) = queries.get_mut(&qid) {
965 *state.last_access_time.lock() = now;
966 }
967 }
968 Ok(())
969 }
970
971 fn build_log_out_request(&self) -> Result<Request> {
972 let endpoint = self.endpoint.join("/v1/session/logout")?;
973
974 let session_state = self.session_state();
975 let need_sticky = session_state.need_sticky.unwrap_or(false);
976 let mut headers = self.make_headers(None)?;
977 if need_sticky {
978 if let Some(node_id) = self.last_node_id() {
979 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
980 }
981 }
982 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
983
984 let builder = self.wrap_auth_or_session_token(builder)?;
985 let req = builder.build()?;
986 Ok(req)
987 }
988
989 pub(crate) fn need_logout(&self) -> bool {
990 self.session_token_info.is_some()
991 || self.session_state.lock().need_keep_alive.unwrap_or(false)
992 }
993
994 async fn refresh_session_token(
995 &self,
996 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
997 ) -> Result<()> {
998 let (session_token_info, _) = { self_login_info.lock().clone() };
999 let endpoint = self.endpoint.join("/v1/session/refresh")?;
1000 let body = RefreshSessionTokenRequest {
1001 session_token: session_token_info.session_token.clone(),
1002 };
1003 let headers = self.make_headers(None)?;
1004 let request = self
1005 .cli
1006 .post(endpoint.clone())
1007 .json(&body)
1008 .headers(headers.clone())
1009 .bearer_auth(session_token_info.refresh_token.clone())
1010 .timeout(self.connect_timeout)
1011 .build()?;
1012
1013 for i in 0..3 {
1015 let req = request.try_clone().expect("request not cloneable");
1016 match self.cli.execute(req).await {
1017 Ok(response) => {
1018 let status = response.status();
1019 let body = response.bytes().await?;
1020 if status == StatusCode::OK {
1021 let response = json_from_slice(&body)?;
1022 return match response {
1023 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
1024 RefreshResponse::Ok(info) => {
1025 *self_login_info.lock() = (info, Instant::now());
1026 Ok(())
1027 }
1028 };
1029 }
1030 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
1031 return Err(Error::response_error(status, &body));
1032 }
1033 }
1034 Err(err) => {
1035 if !(err.is_timeout() || err.is_connect()) || i > 2 {
1036 return Err(Error::Request(err.to_string()));
1037 }
1038 }
1039 };
1040 sleep(jitter(Duration::from_secs(10))).await;
1041 }
1042 Ok(())
1043 }
1044
1045 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
1046 if let Some(info) = &self.session_token_info {
1047 let (start, ttl) = {
1048 let guard = info.lock();
1049 (guard.1, guard.0.session_token_ttl_in_secs)
1050 };
1051 if Instant::now() > start + Duration::from_secs(ttl) {
1052 return Some(info.clone());
1053 }
1054 }
1055 None
1056 }
1057
1058 async fn query_request_helper(
1066 &self,
1067 mut request: Request,
1068 retry_if_503: bool,
1069 refresh_if_401: bool,
1070 ) -> std::result::Result<Response, Error> {
1071 let mut refreshed = false;
1072 let mut retries = 0;
1073 loop {
1074 let req = request.try_clone().expect("request not cloneable");
1075 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
1076 Ok(response) => {
1077 let status = response.status();
1078 if status == StatusCode::OK {
1079 return Ok(response);
1080 }
1081 let body = response.bytes().await?;
1082 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
1083 (Error::response_error(status, &body), true)
1085 } else {
1086 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
1087 match resp {
1088 Ok(r) => {
1089 let e = r.error;
1090 if status == StatusCode::UNAUTHORIZED {
1091 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
1092 if let Some(session_token_info) = &self.session_token_info {
1093 info!(
1094 "will retry {} after refresh token on auth error {}",
1095 request.url(),
1096 e
1097 );
1098 let retry = if need_refresh_token(e.code)
1099 && !refreshed
1100 && refresh_if_401
1101 {
1102 self.refresh_session_token(session_token_info.clone())
1103 .await?;
1104 refreshed = true;
1105 true
1106 } else {
1107 false
1108 };
1109 (Error::AuthFailure(e), retry)
1110 } else if self.auth.can_reload() {
1111 info!(
1112 "will retry {} after reload token on auth error {}",
1113 request.url(),
1114 e
1115 );
1116 let builder = RequestBuilder::from_parts(
1117 HttpClient::new(),
1118 request.try_clone().unwrap(),
1119 );
1120 let builder = self.auth.wrap(builder)?;
1121 request = builder.build()?;
1122 (Error::AuthFailure(e), true)
1123 } else {
1124 (Error::AuthFailure(e), false)
1125 }
1126 } else {
1127 (Error::Logic(status, e), false)
1128 }
1129 }
1130 Err(_) => (
1131 Error::Response {
1132 status,
1133 msg: String::from_utf8_lossy(&body).to_string(),
1134 },
1135 false,
1136 ),
1137 }
1138 }
1139 }
1140 Err(err) => (
1141 Error::Request(err.to_string()),
1142 err.is_timeout() || err.is_connect(),
1143 ),
1144 };
1145 if !retry {
1146 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
1147 }
1148 match &err {
1149 Error::AuthFailure(_) => {
1150 if refreshed {
1151 retries = 0;
1152 } else if retries == 2 {
1153 return Err(err.with_context(&format!(
1154 "{} {} after 3 reties",
1155 request.method(),
1156 request.url()
1157 )));
1158 }
1159 }
1160 _ => {
1161 if retries == 2 {
1162 return Err(err.with_context(&format!(
1163 "{} {} after 3 reties",
1164 request.method(),
1165 request.url()
1166 )));
1167 }
1168 retries += 1;
1169 info!(
1170 "will retry {} the {retries}th times on error {}",
1171 request.url(),
1172 err
1173 );
1174 }
1175 }
1176 sleep(jitter(Duration::from_secs(10))).await;
1177 }
1178 }
1179
1180 pub async fn logout(http_client: HttpClient, request: Request, session_id: &str) {
1181 if let Err(err) = http_client.execute(request).await {
1182 error!("[session {session_id}] logout request failed: {err}");
1183 } else {
1184 info!("[session {session_id}] logout success");
1185 };
1186 }
1187
1188 pub async fn close(&self) {
1189 let session_id = &self.session_id;
1190 info!("[session {session_id}] try closing now");
1191 if self
1192 .closed
1193 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1194 .is_ok()
1195 {
1196 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1197 if self.need_logout() {
1198 let cli = self.cli.clone();
1199 let req = self
1200 .build_log_out_request()
1201 .expect("failed to build logout request");
1202 Self::logout(cli, req, &self.session_id).await;
1203 }
1204 }
1205 }
1206 pub fn close_with_spawn(&self) {
1207 let session_id = &self.session_id;
1208 info!("[session {session_id}]: try closing with spawn");
1209 if self
1210 .closed
1211 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
1212 .is_ok()
1213 {
1214 GLOBAL_CLIENT_MANAGER.unregister_client(&self.session_id);
1215 if self.need_logout() {
1216 let cli = self.cli.clone();
1217 let req = self
1218 .build_log_out_request()
1219 .expect("failed to build logout request");
1220 let session_id = self.session_id.clone();
1221 GLOBAL_RUNTIME.spawn(async move {
1222 Self::logout(cli, req, session_id.as_str()).await;
1223 });
1224 }
1225 }
1226 }
1227
1228 pub(crate) fn register_query_for_heartbeat(&self, query_id: &str, state: QueryState) {
1229 let mut queries = self.queries_need_heartbeat.lock();
1230 queries.insert(query_id.to_string(), state);
1231 }
1232}
1233
1234fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1235where
1236 T: Deserialize<'a>,
1237{
1238 serde_json::from_slice::<T>(body).map_err(|e| {
1239 Error::Decode(format!(
1240 "fail to decode JSON response: {e}, body: {}",
1241 String::from_utf8_lossy(body)
1242 ))
1243 })
1244}
1245
1246impl Default for APIClient {
1247 fn default() -> Self {
1248 Self {
1249 session_id: Default::default(),
1250 cli: HttpClient::new(),
1251 scheme: "http".to_string(),
1252 endpoint: Url::parse("http://localhost:8080").unwrap(),
1253 host: "localhost".to_string(),
1254 port: 8000,
1255 tenant: None,
1256 warehouse: Mutex::new(None),
1257 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1258 session_state: Mutex::new(SessionState::default()),
1259 wait_time_secs: None,
1260 max_rows_in_buffer: None,
1261 max_rows_per_page: None,
1262 connect_timeout: Duration::from_secs(10),
1263 page_request_timeout: Duration::from_secs(30),
1264 tls_ca_file: None,
1265 presign: Mutex::new(PresignMode::Auto),
1266 route_hint: RouteHintGenerator::new(),
1267 last_node_id: Default::default(),
1268 disable_session_token: true,
1269 disable_login: false,
1270 body_format: "json".to_string(),
1271 session_token_info: None,
1272 closed: AtomicBool::new(false),
1273 last_query_id: Default::default(),
1274 server_version: None,
1275 capability: Default::default(),
1276 queries_need_heartbeat: Default::default(),
1277 }
1278 }
1279}
1280
1281struct RouteHintGenerator {
1282 nonce: AtomicU64,
1283 current: std::sync::Mutex<String>,
1284}
1285
1286impl RouteHintGenerator {
1287 fn new() -> Self {
1288 let gen = Self {
1289 nonce: AtomicU64::new(0),
1290 current: std::sync::Mutex::new("".to_string()),
1291 };
1292 gen.next();
1293 gen
1294 }
1295
1296 fn current(&self) -> String {
1297 let guard = self.current.lock().unwrap();
1298 guard.clone()
1299 }
1300
1301 fn set(&self, hint: &str) {
1302 let mut guard = self.current.lock().unwrap();
1303 *guard = hint.to_string();
1304 }
1305
1306 fn next(&self) -> String {
1307 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1308 let uuid = uuid::Uuid::new_v4();
1309 let current = format!("rh:{uuid}:{nonce:06}");
1310 let mut guard = self.current.lock().unwrap();
1311 guard.clone_from(¤t);
1312 current
1313 }
1314}
1315
#[cfg(test)]
mod test {
    use super::*;

    /// Host, endpoint, paging options and warehouse are taken from the DSN.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert!(client.tenant.is_none());

        let warehouse = client.warehouse.try_lock().unwrap();
        assert_eq!(warehouse.as_deref(), Some("wh"));
        Ok(())
    }

    /// A percent-encoded password parses; with no port in the DSN the test expects 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;

        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// A password containing unescaped special characters parses, and the explicit
    /// port is kept.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;

        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}