1use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
22use crate::global_cookie_store::GlobalCookieStore;
23use crate::login::{
24 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
25 SessionTokenInfo,
26};
27use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
28use crate::stage::StageLocation;
29use crate::{
30 error::{Error, Result},
31 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
32 response::QueryResponse,
33 session::SessionState,
34 QueryStats,
35};
36use crate::{Page, Pages};
37use log::{debug, error, info, warn};
38use once_cell::sync::Lazy;
39use parking_lot::Mutex;
40use percent_encoding::percent_decode_str;
41use reqwest::cookie::CookieStore;
42use reqwest::header::{HeaderMap, HeaderValue};
43use reqwest::multipart::{Form, Part};
44use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
45use serde::Deserialize;
46use tokio::time::sleep;
47use tokio_retry::strategy::jitter;
48use tokio_stream::StreamExt;
49use tokio_util::io::ReaderStream;
50use url::Url;
51
52const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
53const HEADER_TENANT: &str = "X-DATABEND-TENANT";
54const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
55const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
56const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
57const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
58const TXN_STATE_ACTIVE: &str = "Active";
59
60static VERSION: Lazy<String> = Lazy::new(|| {
61 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
62 version.to_string()
63});
64
/// HTTP client for a Databend server's query API.
///
/// Built from a DSN by [`APIClient::new`]; holds connection settings, the
/// authentication provider, and the session state that the server echoes back
/// on every response.
pub struct APIClient {
    // Underlying reqwest client; replaced in `build_client`.
    cli: HttpClient,
    // "http" or "https", chosen from the `sslmode` DSN option.
    scheme: String,
    host: String,
    port: u16,

    // Base URL (`scheme://host:port`) that request paths are joined onto.
    endpoint: Url,

    // Credential provider used whenever no session token is held.
    auth: Arc<dyn Auth>,

    tenant: Option<String>,
    // Can be updated from the `warehouse` entry of server-returned session settings.
    warehouse: Mutex<Option<String>>,
    // Latest session state from the server; sent along with each new query.
    session_state: Mutex<SessionState>,
    route_hint: RouteHintGenerator,

    // Set by the `login=disable` DSN option.
    disable_login: bool,
    // Set by the `session_token` DSN option.
    disable_session_token: bool,
    // Tokens returned by login, paired with the instant they were obtained.
    session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

    // Flipped once by `need_logout` so logout is attempted at most once.
    closed: AtomicBool,

    // Server version reported by the login response.
    server_version: Option<String>,

    // Optional pagination settings from the DSN (see `make_pagination`).
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    connect_timeout: Duration,
    page_request_timeout: Duration,

    // Extra root certificate (PEM file) used for https connections.
    tls_ca_file: Option<String>,

    presign: Mutex<PresignMode>,
    last_node_id: Mutex<Option<String>>,
    last_query_id: Mutex<Option<String>>,
}
101
102impl APIClient {
103 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
104 let mut client = Self::from_dsn(dsn).await?;
105 client.build_client(name).await?;
106 if !client.disable_login {
107 client.login().await?;
108 }
109 let client = Arc::new(client);
110 client.check_presign().await?;
111 Ok(client)
112 }
113
114 fn set_presign_mode(&self, mode: PresignMode) {
115 *self.presign.lock() = mode
116 }
117 fn get_presign_mode(&self) -> PresignMode {
118 *self.presign.lock()
119 }
120
121 async fn from_dsn(dsn: &str) -> Result<Self> {
122 let u = Url::parse(dsn)?;
123 let mut client = Self::default();
124 if let Some(host) = u.host_str() {
125 client.host = host.to_string();
126 }
127
128 if u.username() != "" {
129 let password = u.password().unwrap_or_default();
130 let password = percent_decode_str(password).decode_utf8()?;
131 client.auth = Arc::new(BasicAuth::new(u.username(), password));
132 }
133 let database = match u.path().trim_start_matches('/') {
134 "" => None,
135 s => Some(s.to_string()),
136 };
137 let mut role = None;
138 let mut scheme = "https";
139 let mut session_settings = BTreeMap::new();
140 for (k, v) in u.query_pairs() {
141 match k.as_ref() {
142 "wait_time_secs" => {
143 client.wait_time_secs = Some(v.parse()?);
144 }
145 "max_rows_in_buffer" => {
146 client.max_rows_in_buffer = Some(v.parse()?);
147 }
148 "max_rows_per_page" => {
149 client.max_rows_per_page = Some(v.parse()?);
150 }
151 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
152 "page_request_timeout_secs" => {
153 client.page_request_timeout = {
154 let secs: u64 = v.parse()?;
155 Duration::from_secs(secs)
156 };
157 }
158 "presign" => {
159 let presign_mode = match v.as_ref() {
160 "auto" => PresignMode::Auto,
161 "detect" => PresignMode::Detect,
162 "on" => PresignMode::On,
163 "off" => PresignMode::Off,
164 _ => {
165 return Err(Error::BadArgument(format!(
166 "Invalid value for presign: {}, should be one of auto/detect/on/off",
167 v
168 )))
169 }
170 };
171 client.set_presign_mode(presign_mode);
172 }
173 "tenant" => {
174 client.tenant = Some(v.to_string());
175 }
176 "warehouse" => {
177 client.warehouse = Mutex::new(Some(v.to_string()));
178 }
179 "role" => role = Some(v.to_string()),
180 "sslmode" => match v.as_ref() {
181 "disable" => scheme = "http",
182 "require" | "enable" => scheme = "https",
183 _ => {
184 return Err(Error::BadArgument(format!(
185 "Invalid value for sslmode: {}",
186 v
187 )))
188 }
189 },
190 "tls_ca_file" => {
191 client.tls_ca_file = Some(v.to_string());
192 }
193 "access_token" => {
194 client.auth = Arc::new(AccessTokenAuth::new(v));
195 }
196 "access_token_file" => {
197 client.auth = Arc::new(AccessTokenFileAuth::new(v));
198 }
199 "login" => {
200 client.disable_login = match v.as_ref() {
201 "disable" => true,
202 "enable" => false,
203 _ => {
204 return Err(Error::BadArgument(format!(
205 "Invalid value for login: {}",
206 v
207 )))
208 }
209 }
210 }
211 "session_token" => {
212 client.disable_session_token = match v.as_ref() {
213 "disable" => true,
214 "enable" => false,
215 _ => {
216 return Err(Error::BadArgument(format!(
217 "Invalid value for session_token: {}",
218 v
219 )))
220 }
221 }
222 }
223 _ => {
224 session_settings.insert(k.to_string(), v.to_string());
225 }
226 }
227 }
228 client.port = match u.port() {
229 Some(p) => p,
230 None => match scheme {
231 "http" => 80,
232 "https" => 443,
233 _ => unreachable!(),
234 },
235 };
236 client.scheme = scheme.to_string();
237
238 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
239 client.session_state = Mutex::new(
240 SessionState::default()
241 .with_settings(Some(session_settings))
242 .with_role(role)
243 .with_database(database),
244 );
245
246 Ok(client)
247 }
248
249 pub fn host(&self) -> &str {
250 self.host.as_str()
251 }
252
253 pub fn port(&self) -> u16 {
254 self.port
255 }
256
257 pub fn scheme(&self) -> &str {
258 self.scheme.as_str()
259 }
260
261 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
262 let ua = match name {
263 Some(n) => n,
264 None => format!("databend-client-rust/{}", VERSION.as_str()),
265 };
266 let cookie_provider = GlobalCookieStore::new();
267 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
268 let mut initial_cookies = [&cookie].into_iter();
269 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
270 let mut cli_builder = HttpClient::builder()
271 .user_agent(ua)
272 .cookie_provider(Arc::new(cookie_provider))
273 .pool_idle_timeout(Duration::from_secs(1));
274 #[cfg(any(feature = "rustls", feature = "native-tls"))]
275 if self.scheme == "https" {
276 if let Some(ref ca_file) = self.tls_ca_file {
277 let cert_pem = tokio::fs::read(ca_file).await?;
278 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
279 cli_builder = cli_builder.add_root_certificate(cert);
280 }
281 }
282 self.cli = cli_builder.build()?;
283 Ok(())
284 }
285
286 async fn check_presign(self: &Arc<Self>) -> Result<()> {
287 let mode = match self.get_presign_mode() {
288 PresignMode::Auto => {
289 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
290 PresignMode::On
291 } else {
292 PresignMode::Off
293 }
294 }
295 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
296 Ok(_) => PresignMode::On,
297 Err(e) => {
298 warn!("presign mode off with error detected: {}", e);
299 PresignMode::Off
300 }
301 },
302 mode => mode,
303 };
304 self.set_presign_mode(mode);
305 Ok(())
306 }
307
308 pub fn current_warehouse(&self) -> Option<String> {
309 let guard = self.warehouse.lock();
310 guard.clone()
311 }
312
313 pub fn current_catalog(&self) -> Option<String> {
314 let guard = self.session_state.lock();
315 guard.catalog.clone()
316 }
317
318 pub fn current_database(&self) -> Option<String> {
319 let guard = self.session_state.lock();
320 guard.database.clone()
321 }
322
323 pub async fn current_role(&self) -> Option<String> {
324 let guard = self.session_state.lock();
325 guard.role.clone()
326 }
327
328 fn in_active_transaction(&self) -> bool {
329 let guard = self.session_state.lock();
330 guard
331 .txn_state
332 .as_ref()
333 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
334 .unwrap_or(false)
335 }
336
337 pub fn username(&self) -> String {
338 self.auth.username()
339 }
340
341 fn gen_query_id(&self) -> String {
342 uuid::Uuid::new_v4().to_string()
343 }
344
345 async fn handle_session(&self, session: &Option<SessionState>) {
346 let session = match session {
347 Some(session) => session,
348 None => return,
349 };
350
351 {
353 let mut session_state = self.session_state.lock();
354 *session_state = session.clone();
355 }
356
357 if let Some(settings) = session.settings.as_ref() {
359 if let Some(v) = settings.get("warehouse") {
360 let mut warehouse = self.warehouse.lock();
361 *warehouse = Some(v.clone());
362 }
363 }
364 }
365
366 pub fn set_last_node_id(&self, node_id: String) {
367 *self.last_node_id.lock() = Some(node_id)
368 }
369
370 pub fn set_last_query_id(&self, query_id: Option<String>) {
371 *self.last_query_id.lock() = query_id
372 }
373
374 pub fn last_query_id(&self) -> Option<String> {
375 self.last_query_id.lock().clone()
376 }
377
378 fn last_node_id(&self) -> Option<String> {
379 self.last_node_id.lock().clone()
380 }
381
382 fn handle_warnings(&self, resp: &QueryResponse) {
383 if let Some(warnings) = &resp.warnings {
384 for w in warnings {
385 warn!(target: "server_warnings", "server warning: {}", w);
386 }
387 }
388 }
389
390 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
391 info!("start query: {}", sql);
392 let resp = self.start_query_inner(sql, None).await?;
393 let pages = Pages::new(self.clone(), resp, need_progress);
394 Ok(pages)
395 }
396
397 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
398 if let Some(info) = &self.session_token_info {
399 let info = info.lock();
400 Ok(builder.bearer_auth(info.0.session_token.clone()))
401 } else {
402 self.auth.wrap(builder)
403 }
404 }
405
406 async fn start_query_inner(
407 &self,
408 sql: &str,
409 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
410 ) -> Result<QueryResponse> {
411 if !self.in_active_transaction() {
412 self.route_hint.next();
413 }
414 let endpoint = self.endpoint.join("v1/query")?;
415
416 let session_state = self.session_state();
418 let need_sticky = session_state.need_sticky.unwrap_or(false);
419 let req = QueryRequest::new(sql)
420 .with_pagination(self.make_pagination())
421 .with_session(Some(session_state))
422 .with_stage_attachment(stage_attachment_config);
423
424 let query_id = self.gen_query_id();
426 let mut headers = self.make_headers(Some(&query_id))?;
427 if need_sticky {
428 if let Some(node_id) = self.last_node_id() {
429 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
430 }
431 }
432 let mut builder = self.cli.post(endpoint.clone()).json(&req);
433 builder = self.wrap_auth_or_session_token(builder)?;
434 let request = builder.headers(headers.clone()).build()?;
435 let response = self.query_request_helper(request, true, true).await?;
436 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
437 self.route_hint.set(route_hint.to_str().unwrap_or_default());
438 }
439 let body = response.bytes().await?;
440 let result: QueryResponse = json_from_slice(&body)?;
441 self.handle_session(&result.session).await;
442 if let Some(err) = result.error {
443 return Err(Error::QueryFailed(err));
444 }
445
446 self.set_last_query_id(Some(query_id));
447 self.handle_warnings(&result);
448 if let Some(node_id) = &result.node_id {
449 self.set_last_node_id(node_id.clone());
450 }
451 Ok(result)
452 }
453
454 pub async fn query_page(
455 &self,
456 query_id: &str,
457 next_uri: &str,
458 node_id: &Option<String>,
459 ) -> Result<QueryResponse> {
460 info!("query page: {}", next_uri);
461 let endpoint = self.endpoint.join(next_uri)?;
462 let headers = self.make_headers(Some(query_id))?;
463 let mut builder = self.cli.get(endpoint.clone());
464 builder = self
465 .wrap_auth_or_session_token(builder)?
466 .headers(headers.clone())
467 .timeout(self.page_request_timeout);
468 if let Some(node_id) = node_id {
469 builder = builder.header(HEADER_STICKY_NODE, node_id)
470 }
471 let request = builder.build()?;
472
473 let response = self.query_request_helper(request, false, true).await?;
474 let body = response.bytes().await?;
475 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
476 if let Error::Logic(status, ec) = &e {
477 if *status == 404 {
478 return Error::QueryNotFound(ec.message.clone());
479 }
480 }
481 e
482 })?;
483 self.handle_session(&resp.session).await;
484 match resp.error {
485 Some(err) => Err(Error::QueryFailed(err)),
486 None => Ok(resp),
487 }
488 }
489
490 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
491 let kill_uri = format!("/v1/query/{}/kill", query_id);
492 let endpoint = self.endpoint.join(&kill_uri)?;
493 let headers = self.make_headers(Some(query_id))?;
494 info!("kill query: {}", kill_uri);
495
496 let mut builder = self.cli.post(endpoint);
497 builder = self.wrap_auth_or_session_token(builder)?;
498 let resp = builder.headers(headers.clone()).send().await?;
499 if resp.status() != 200 {
500 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
501 .with_context("kill query"));
502 }
503 Ok(())
504 }
505
506 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
507 let mut pages = self.start_query(sql, false).await?;
508 let mut all = Page::default();
509 while let Some(page) = pages.next().await {
510 all.update(page?);
511 }
512 Ok(all)
513 }
514
515 fn session_state(&self) -> SessionState {
516 self.session_state.lock().clone()
517 }
518
519 fn make_pagination(&self) -> Option<PaginationConfig> {
520 if self.wait_time_secs.is_none()
521 && self.max_rows_in_buffer.is_none()
522 && self.max_rows_per_page.is_none()
523 {
524 return None;
525 }
526 let mut pagination = PaginationConfig {
527 wait_time_secs: None,
528 max_rows_in_buffer: None,
529 max_rows_per_page: None,
530 };
531 if let Some(wait_time_secs) = self.wait_time_secs {
532 pagination.wait_time_secs = Some(wait_time_secs);
533 }
534 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
535 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
536 }
537 if let Some(max_rows_per_page) = self.max_rows_per_page {
538 pagination.max_rows_per_page = Some(max_rows_per_page);
539 }
540 Some(pagination)
541 }
542
543 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
544 let mut headers = HeaderMap::new();
545 if let Some(tenant) = &self.tenant {
546 headers.insert(HEADER_TENANT, tenant.parse()?);
547 }
548 let warehouse = self.warehouse.lock().clone();
549 if let Some(warehouse) = warehouse {
550 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
551 }
552 let route_hint = self.route_hint.current();
553 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
554 if let Some(query_id) = query_id {
555 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
556 }
557 Ok(headers)
558 }
559
560 pub async fn insert_with_stage(
561 self: &Arc<Self>,
562 sql: &str,
563 stage: &str,
564 file_format_options: BTreeMap<&str, &str>,
565 copy_options: BTreeMap<&str, &str>,
566 ) -> Result<QueryStats> {
567 info!(
568 "insert with stage: {}, format: {:?}, copy: {:?}",
569 sql, file_format_options, copy_options
570 );
571 let stage_attachment = Some(StageAttachmentConfig {
572 location: stage,
573 file_format_options: Some(file_format_options),
574 copy_options: Some(copy_options),
575 });
576 let resp = self.start_query_inner(sql, stage_attachment).await?;
577 let mut pages = Pages::new(self.clone(), resp, false);
578 let mut all = Page::default();
579 while let Some(page) = pages.next().await {
580 all.update(page?);
581 }
582 Ok(all.stats)
583 }
584
585 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
586 info!("get presigned upload url: {}", stage);
587 let sql = format!("PRESIGN UPLOAD {}", stage);
588 let resp = self.query_all(&sql).await?;
589 if resp.data.len() != 1 {
590 return Err(Error::Decode(
591 "Empty response from server for presigned request".to_string(),
592 ));
593 }
594 if resp.data[0].len() != 3 {
595 return Err(Error::Decode(
596 "Invalid response from server for presigned request".to_string(),
597 ));
598 }
599 let method = resp.data[0][0].clone().unwrap_or_default();
601 if method != "PUT" {
602 return Err(Error::Decode(format!(
603 "Invalid method for presigned upload request: {}",
604 method
605 )));
606 }
607 let headers: BTreeMap<String, String> =
608 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
609 let url = resp.data[0][2].clone().unwrap_or_default();
610 Ok(PresignedResponse {
611 method,
612 headers,
613 url,
614 })
615 }
616
617 pub async fn upload_to_stage(
618 self: &Arc<Self>,
619 stage: &str,
620 data: Reader,
621 size: u64,
622 ) -> Result<()> {
623 match self.get_presign_mode() {
624 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
625 PresignMode::On => {
626 let presigned = self.get_presigned_upload_url(stage).await?;
627 presign_upload_to_stage(presigned, data, size).await
628 }
629 PresignMode::Auto => {
630 unreachable!("PresignMode::Auto should be handled during client initialization")
631 }
632 PresignMode::Detect => {
633 unreachable!("PresignMode::Detect should be handled during client initialization")
634 }
635 }
636 }
637
638 async fn upload_to_stage_with_stream(
640 &self,
641 stage: &str,
642 data: Reader,
643 size: u64,
644 ) -> Result<()> {
645 info!("upload to stage with stream: {}, size: {}", stage, size);
646 if let Some(info) = self.need_pre_refresh_session().await {
647 self.refresh_session_token(info).await?;
648 }
649 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
650 let location = StageLocation::try_from(stage)?;
651 let query_id = self.gen_query_id();
652 let mut headers = self.make_headers(Some(&query_id))?;
653 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
654 let stream = Body::wrap_stream(ReaderStream::new(data));
655 let part = Part::stream_with_length(stream, size).file_name(location.path);
656 let form = Form::new().part("upload", part);
657 let mut builder = self.cli.put(endpoint.clone());
658 builder = self.wrap_auth_or_session_token(builder)?;
659 let resp = builder.headers(headers).multipart(form).send().await?;
660 let status = resp.status();
661 if status != 200 {
662 return Err(
663 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
664 );
665 }
666 Ok(())
667 }
668
669 async fn login(&mut self) -> Result<()> {
670 let endpoint = self.endpoint.join("/v1/session/login")?;
671 let headers = self.make_headers(None)?;
672 let body = LoginRequest::from(&*self.session_state.lock());
673 let mut builder = self.cli.post(endpoint.clone()).json(&body);
674 if self.disable_session_token {
675 builder = builder.query(&[("disable_session_token", true)]);
676 }
677 let builder = self.auth.wrap(builder)?;
678 let request = builder
679 .headers(headers.clone())
680 .timeout(self.connect_timeout)
681 .build()?;
682 let response = self.query_request_helper(request, true, false).await;
683 let response = match response {
684 Ok(r) => r,
685 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
686 info!("login return 404, skip login on the old version server");
687 return Ok(());
688 }
689 Err(e) => return Err(e),
690 };
691 let body = response.bytes().await?;
692 let response = json_from_slice(&body)?;
693 match response {
694 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
695 LoginResponseResult::Ok(info) => {
696 self.server_version = Some(info.version.clone());
697 if let Some(tokens) = info.tokens {
698 info!("login success with session token");
699 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
700 }
701 info!("login success without session token");
702 }
703 }
704 Ok(())
705 }
706
707 fn build_log_out_request(&self) -> Result<Request> {
708 let endpoint = self.endpoint.join("/v1/session/logout")?;
709
710 let session_state = self.session_state();
711 let need_sticky = session_state.need_sticky.unwrap_or(false);
712 let mut headers = self.make_headers(None)?;
713 if need_sticky {
714 if let Some(node_id) = self.last_node_id() {
715 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
716 }
717 }
718 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
719
720 let builder = self.wrap_auth_or_session_token(builder)?;
721 let req = builder.build()?;
722 Ok(req)
723 }
724
725 fn need_logout(&self) -> bool {
726 (self.session_token_info.is_some()
727 || self.session_state.lock().need_keep_alive.unwrap_or(false))
728 && self
729 .closed
730 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
731 .is_ok()
732 }
733
734 async fn refresh_session_token(
735 &self,
736 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
737 ) -> Result<()> {
738 let (session_token_info, _) = { self_login_info.lock().clone() };
739 let endpoint = self.endpoint.join("/v1/session/refresh")?;
740 let body = RefreshSessionTokenRequest {
741 session_token: session_token_info.session_token.clone(),
742 };
743 let headers = self.make_headers(None)?;
744 let request = self
745 .cli
746 .post(endpoint.clone())
747 .json(&body)
748 .headers(headers.clone())
749 .bearer_auth(session_token_info.refresh_token.clone())
750 .timeout(self.connect_timeout)
751 .build()?;
752
753 for i in 0..3 {
755 let req = request.try_clone().expect("request not cloneable");
756 match self.cli.execute(req).await {
757 Ok(response) => {
758 let status = response.status();
759 let body = response.bytes().await?;
760 if status == StatusCode::OK {
761 let response = json_from_slice(&body)?;
762 return match response {
763 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
764 RefreshResponse::Ok(info) => {
765 *self_login_info.lock() = (info, Instant::now());
766 Ok(())
767 }
768 };
769 }
770 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
771 return Err(Error::response_error(status, &body));
772 }
773 }
774 Err(err) => {
775 if !(err.is_timeout() || err.is_connect()) || i > 2 {
776 return Err(Error::Request(err.to_string()));
777 }
778 }
779 };
780 sleep(jitter(Duration::from_secs(10))).await;
781 }
782 Ok(())
783 }
784
785 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
786 if let Some(info) = &self.session_token_info {
787 let (start, ttl) = {
788 let guard = info.lock();
789 (guard.1, guard.0.session_token_ttl_in_secs)
790 };
791 if Instant::now() > start + Duration::from_secs(ttl) {
792 return Some(info.clone());
793 }
794 }
795 None
796 }
797
798 async fn query_request_helper(
806 &self,
807 mut request: Request,
808 retry_if_503: bool,
809 refresh_if_401: bool,
810 ) -> std::result::Result<Response, Error> {
811 let mut refreshed = false;
812 let mut retries = 0;
813 loop {
814 let req = request.try_clone().expect("request not cloneable");
815 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
816 Ok(response) => {
817 let status = response.status();
818 if status == StatusCode::OK {
819 return Ok(response);
820 }
821 let body = response.bytes().await?;
822 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
823 (Error::response_error(status, &body), true)
825 } else {
826 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
827 match resp {
828 Ok(r) => {
829 let e = r.error;
830 if status == StatusCode::UNAUTHORIZED {
831 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
832 if let Some(session_token_info) = &self.session_token_info {
833 info!(
834 "will retry {} after refresh token on auth error {}",
835 request.url(),
836 e
837 );
838 let retry = if need_refresh_token(e.code)
839 && !refreshed
840 && refresh_if_401
841 {
842 self.refresh_session_token(session_token_info.clone())
843 .await?;
844 refreshed = true;
845 true
846 } else {
847 false
848 };
849 (Error::AuthFailure(e), retry)
850 } else if self.auth.can_reload() {
851 info!(
852 "will retry {} after reload token on auth error {}",
853 request.url(),
854 e
855 );
856 let builder = RequestBuilder::from_parts(
857 HttpClient::new(),
858 request.try_clone().unwrap(),
859 );
860 let builder = self.auth.wrap(builder)?;
861 request = builder.build()?;
862 (Error::AuthFailure(e), true)
863 } else {
864 (Error::AuthFailure(e), false)
865 }
866 } else {
867 (Error::Logic(status, e), false)
868 }
869 }
870 Err(_) => (
871 Error::Response {
872 status,
873 msg: String::from_utf8_lossy(&body).to_string(),
874 },
875 false,
876 ),
877 }
878 }
879 }
880 Err(err) => (
881 Error::Request(err.to_string()),
882 err.is_timeout() || err.is_connect(),
883 ),
884 };
885 if !retry {
886 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
887 }
888 match &err {
889 Error::AuthFailure(_) => {
890 if refreshed {
891 retries = 0;
892 } else if retries == 2 {
893 return Err(err.with_context(&format!(
894 "{} {} after 3 reties",
895 request.method(),
896 request.url()
897 )));
898 }
899 }
900 _ => {
901 if retries == 2 {
902 return Err(err.with_context(&format!(
903 "{} {} after 3 reties",
904 request.method(),
905 request.url()
906 )));
907 }
908 retries += 1;
909 info!(
910 "will retry {} the {retries}th times on error {}",
911 request.url(),
912 err
913 );
914 }
915 }
916 sleep(jitter(Duration::from_secs(10))).await;
917 }
918 }
919
920 pub async fn close(&self) {
921 if self.need_logout() {
922 let cli = self.cli.clone();
923 let req = self
924 .build_log_out_request()
925 .expect("failed to build logout request");
926 if let Err(err) = cli.execute(req).await {
927 error!("logout request failed: {}", err);
928 } else {
929 debug!("logout success");
930 };
931 }
932 }
933}
934
935impl Drop for APIClient {
936 fn drop(&mut self) {
937 if self.need_logout() {
938 warn!("APIClient::close() was not called");
939 }
940 }
941}
942
943fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
944where
945 T: Deserialize<'a>,
946{
947 serde_json::from_slice::<T>(body).map_err(|e| {
948 Error::Decode(format!(
949 "fail to decode JSON response: {}, body: {}",
950 e,
951 String::from_utf8_lossy(body)
952 ))
953 })
954}
955
956impl Default for APIClient {
957 fn default() -> Self {
958 Self {
959 cli: HttpClient::new(),
960 scheme: "http".to_string(),
961 endpoint: Url::parse("http://localhost:8080").unwrap(),
962 host: "localhost".to_string(),
963 port: 8000,
964 tenant: None,
965 warehouse: Mutex::new(None),
966 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
967 session_state: Mutex::new(SessionState::default()),
968 wait_time_secs: None,
969 max_rows_in_buffer: None,
970 max_rows_per_page: None,
971 connect_timeout: Duration::from_secs(10),
972 page_request_timeout: Duration::from_secs(30),
973 tls_ca_file: None,
974 presign: Mutex::new(PresignMode::Auto),
975 route_hint: RouteHintGenerator::new(),
976 last_node_id: Default::default(),
977 disable_session_token: true,
978 disable_login: false,
979 session_token_info: None,
980 closed: Default::default(),
981 last_query_id: Default::default(),
982 server_version: None,
983 }
984 }
985}
986
987struct RouteHintGenerator {
988 nonce: AtomicU64,
989 current: std::sync::Mutex<String>,
990}
991
992impl RouteHintGenerator {
993 fn new() -> Self {
994 let gen = Self {
995 nonce: AtomicU64::new(0),
996 current: std::sync::Mutex::new("".to_string()),
997 };
998 gen.next();
999 gen
1000 }
1001
1002 fn current(&self) -> String {
1003 let guard = self.current.lock().unwrap();
1004 guard.clone()
1005 }
1006
1007 fn set(&self, hint: &str) {
1008 let mut guard = self.current.lock().unwrap();
1009 *guard = hint.to_string();
1010 }
1011
1012 fn next(&self) -> String {
1013 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1014 let uuid = uuid::Uuid::new_v4();
1015 let current = format!("rh:{}:{:06}", uuid, nonce);
1016 let mut guard = self.current.lock().unwrap();
1017 guard.clone_from(¤t);
1018 current
1019 }
1020}
1021
#[cfg(test)]
mod test {
    use super::*;

    /// Every recognised option in a full DSN lands in the matching field.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let client = APIClient::from_dsn(dsn).await?;
        let expected_endpoint = Url::parse("http://app.databend.com:80")?;

        assert_eq!(client.host(), "app.databend.com");
        assert_eq!(client.endpoint, expected_endpoint);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);
        assert_eq!(client.current_warehouse(), Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password parses, and the port defaults to 443 (https).
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!((client.host(), client.port()), ("localhost", 443));
        Ok(())
    }

    /// A raw password containing `@`, `(`, `=` and `{` still yields the right host/port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client = APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!((client.host(), client.port()), ("localhost", 8000));
        Ok(())
    }
}
1060}