1use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
22use crate::global_cookie_store::GlobalCookieStore;
23use crate::login::{
24 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
25 SessionTokenInfo,
26};
27use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
28use crate::stage::StageLocation;
29use crate::{
30 error::{Error, Result},
31 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
32 response::QueryResponse,
33 session::SessionState,
34 QueryStats,
35};
36use crate::{Page, Pages};
37use log::{debug, error, info, warn};
38use once_cell::sync::Lazy;
39use parking_lot::Mutex;
40use percent_encoding::percent_decode_str;
41use reqwest::cookie::CookieStore;
42use reqwest::header::{HeaderMap, HeaderValue};
43use reqwest::multipart::{Form, Part};
44use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
45use serde::Deserialize;
46use tokio::time::sleep;
47use tokio_retry::strategy::jitter;
48use tokio_stream::StreamExt;
49use tokio_util::io::ReaderStream;
50use url::Url;
51
/// Request header carrying the query id (set by `make_headers` when an id is supplied).
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Request header carrying the tenant taken from the DSN `tenant` parameter.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Request header carrying a node id, so a request can target a specific node.
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Request header carrying the currently selected warehouse.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
/// Request header naming the target stage for streaming uploads.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Header carrying the route hint; sent on requests and read back from query responses.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Value of the session `txn_state` (compared case-insensitively) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
59
60static VERSION: Lazy<String> = Lazy::new(|| {
61 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
62 version.to_string()
63});
64
/// HTTP client for the Databend query API.
///
/// Instances are normally created with `APIClient::new`, which parses a DSN, builds the
/// underlying HTTP client, optionally logs in and resolves the presign mode.
pub struct APIClient {
    /// Underlying `reqwest` client; replaced by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, selected by the DSN `sslmode` parameter.
    scheme: String,
    /// Host taken from the DSN.
    host: String,
    /// Port taken from the DSN, or the scheme default (80 / 443) when the DSN has none.
    port: u16,

    /// Base URL (`scheme://host:port`) that every request path is joined onto.
    endpoint: Url,

    /// Credential provider used for login and for requests made without a session token.
    auth: Arc<dyn Auth>,

    /// Tenant sent in the tenant header, if configured.
    tenant: Option<String>,
    /// Current warehouse; updated from the session settings returned by the server.
    warehouse: Mutex<Option<String>>,
    /// Latest session state; replaced whenever a response carries a session.
    session_state: Mutex<SessionState>,
    /// Generator and holder of the route hint sent with every request.
    route_hint: RouteHintGenerator,

    /// When true, `new` skips the login request.
    disable_login: bool,
    /// When true, login asks the server not to issue session tokens.
    disable_session_token: bool,
    /// Session/refresh tokens from login together with the instant they were obtained.
    session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

    /// Set by `need_logout` so that logout handling runs at most once.
    closed: AtomicBool,

    /// Server version reported by the login response.
    server_version: Option<String>,

    /// Pagination setting forwarded with each query when set via the DSN.
    wait_time_secs: Option<i64>,
    /// Pagination setting forwarded with each query when set via the DSN.
    max_rows_in_buffer: Option<i64>,
    /// Pagination setting forwarded with each query when set via the DSN.
    max_rows_per_page: Option<i64>,

    /// Timeout applied to the login and token-refresh requests.
    connect_timeout: Duration,
    /// Timeout applied to each `query_page` request.
    page_request_timeout: Duration,

    /// Path of a PEM file added as a root certificate when the scheme is `https`.
    tls_ca_file: Option<String>,

    /// How uploads are performed; `Auto`/`Detect` are resolved to `On`/`Off` by `check_presign`.
    presign: Mutex<PresignMode>,
    /// Node id reported by the most recent successful query start.
    last_node_id: Mutex<Option<String>>,
    /// Id of the most recent successfully started query.
    last_query_id: Mutex<Option<String>>,
}
101
102impl APIClient {
103 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
104 let mut client = Self::from_dsn(dsn).await?;
105 client.build_client(name).await?;
106 if !client.disable_login {
107 client.login().await?;
108 }
109 let client = Arc::new(client);
110 client.check_presign().await?;
111 Ok(client)
112 }
113
114 fn set_presign_mode(&self, mode: PresignMode) {
115 *self.presign.lock() = mode
116 }
117 fn get_presign_mode(&self) -> PresignMode {
118 *self.presign.lock()
119 }
120
121 async fn from_dsn(dsn: &str) -> Result<Self> {
122 let u = Url::parse(dsn)?;
123 let mut client = Self::default();
124 if let Some(host) = u.host_str() {
125 client.host = host.to_string();
126 }
127
128 if u.username() != "" {
129 let password = u.password().unwrap_or_default();
130 let password = percent_decode_str(password).decode_utf8()?;
131 client.auth = Arc::new(BasicAuth::new(u.username(), password));
132 }
133 let database = match u.path().trim_start_matches('/') {
134 "" => None,
135 s => Some(s.to_string()),
136 };
137 let mut role = None;
138 let mut scheme = "https";
139 let mut session_settings = BTreeMap::new();
140 for (k, v) in u.query_pairs() {
141 match k.as_ref() {
142 "wait_time_secs" => {
143 client.wait_time_secs = Some(v.parse()?);
144 }
145 "max_rows_in_buffer" => {
146 client.max_rows_in_buffer = Some(v.parse()?);
147 }
148 "max_rows_per_page" => {
149 client.max_rows_per_page = Some(v.parse()?);
150 }
151 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
152 "page_request_timeout_secs" => {
153 client.page_request_timeout = {
154 let secs: u64 = v.parse()?;
155 Duration::from_secs(secs)
156 };
157 }
158 "presign" => {
159 let presign_mode = match v.as_ref() {
160 "auto" => PresignMode::Auto,
161 "detect" => PresignMode::Detect,
162 "on" => PresignMode::On,
163 "off" => PresignMode::Off,
164 _ => {
165 return Err(Error::BadArgument(format!(
166 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
167 )))
168 }
169 };
170 client.set_presign_mode(presign_mode);
171 }
172 "tenant" => {
173 client.tenant = Some(v.to_string());
174 }
175 "warehouse" => {
176 client.warehouse = Mutex::new(Some(v.to_string()));
177 }
178 "role" => role = Some(v.to_string()),
179 "sslmode" => match v.as_ref() {
180 "disable" => scheme = "http",
181 "require" | "enable" => scheme = "https",
182 _ => {
183 return Err(Error::BadArgument(format!(
184 "Invalid value for sslmode: {v}"
185 )))
186 }
187 },
188 "tls_ca_file" => {
189 client.tls_ca_file = Some(v.to_string());
190 }
191 "access_token" => {
192 client.auth = Arc::new(AccessTokenAuth::new(v));
193 }
194 "access_token_file" => {
195 client.auth = Arc::new(AccessTokenFileAuth::new(v));
196 }
197 "login" => {
198 client.disable_login = match v.as_ref() {
199 "disable" => true,
200 "enable" => false,
201 _ => {
202 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
203 }
204 }
205 }
206 "session_token" => {
207 client.disable_session_token = match v.as_ref() {
208 "disable" => true,
209 "enable" => false,
210 _ => {
211 return Err(Error::BadArgument(format!(
212 "Invalid value for session_token: {v}"
213 )))
214 }
215 }
216 }
217 _ => {
218 session_settings.insert(k.to_string(), v.to_string());
219 }
220 }
221 }
222 client.port = match u.port() {
223 Some(p) => p,
224 None => match scheme {
225 "http" => 80,
226 "https" => 443,
227 _ => unreachable!(),
228 },
229 };
230 client.scheme = scheme.to_string();
231
232 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
233 client.session_state = Mutex::new(
234 SessionState::default()
235 .with_settings(Some(session_settings))
236 .with_role(role)
237 .with_database(database),
238 );
239
240 Ok(client)
241 }
242
243 pub fn host(&self) -> &str {
244 self.host.as_str()
245 }
246
    /// Port this client connects to.
    pub fn port(&self) -> u16 {
        self.port
    }
250
251 pub fn scheme(&self) -> &str {
252 self.scheme.as_str()
253 }
254
255 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
256 let ua = match name {
257 Some(n) => n,
258 None => format!("databend-client-rust/{}", VERSION.as_str()),
259 };
260 let cookie_provider = GlobalCookieStore::new();
261 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
262 let mut initial_cookies = [&cookie].into_iter();
263 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
264 let mut cli_builder = HttpClient::builder()
265 .user_agent(ua)
266 .cookie_provider(Arc::new(cookie_provider))
267 .pool_idle_timeout(Duration::from_secs(1));
268 #[cfg(any(feature = "rustls", feature = "native-tls"))]
269 if self.scheme == "https" {
270 if let Some(ref ca_file) = self.tls_ca_file {
271 let cert_pem = tokio::fs::read(ca_file).await?;
272 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
273 cli_builder = cli_builder.add_root_certificate(cert);
274 }
275 }
276 self.cli = cli_builder.build()?;
277 Ok(())
278 }
279
280 async fn check_presign(self: &Arc<Self>) -> Result<()> {
281 let mode = match self.get_presign_mode() {
282 PresignMode::Auto => {
283 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
284 PresignMode::On
285 } else {
286 PresignMode::Off
287 }
288 }
289 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
290 Ok(_) => PresignMode::On,
291 Err(e) => {
292 warn!("presign mode off with error detected: {e}");
293 PresignMode::Off
294 }
295 },
296 mode => mode,
297 };
298 self.set_presign_mode(mode);
299 Ok(())
300 }
301
302 pub fn current_warehouse(&self) -> Option<String> {
303 let guard = self.warehouse.lock();
304 guard.clone()
305 }
306
307 pub fn current_catalog(&self) -> Option<String> {
308 let guard = self.session_state.lock();
309 guard.catalog.clone()
310 }
311
312 pub fn current_database(&self) -> Option<String> {
313 let guard = self.session_state.lock();
314 guard.database.clone()
315 }
316
317 pub async fn current_role(&self) -> Option<String> {
318 let guard = self.session_state.lock();
319 guard.role.clone()
320 }
321
322 fn in_active_transaction(&self) -> bool {
323 let guard = self.session_state.lock();
324 guard
325 .txn_state
326 .as_ref()
327 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
328 .unwrap_or(false)
329 }
330
    /// Returns the user name reported by the configured auth provider.
    pub fn username(&self) -> String {
        self.auth.username()
    }
334
335 fn gen_query_id(&self) -> String {
336 uuid::Uuid::now_v7().simple().to_string()
337 }
338
339 async fn handle_session(&self, session: &Option<SessionState>) {
340 let session = match session {
341 Some(session) => session,
342 None => return,
343 };
344
345 {
347 let mut session_state = self.session_state.lock();
348 *session_state = session.clone();
349 }
350
351 if let Some(settings) = session.settings.as_ref() {
353 if let Some(v) = settings.get("warehouse") {
354 let mut warehouse = self.warehouse.lock();
355 *warehouse = Some(v.clone());
356 }
357 }
358 }
359
360 pub fn set_last_node_id(&self, node_id: String) {
361 *self.last_node_id.lock() = Some(node_id)
362 }
363
364 pub fn set_last_query_id(&self, query_id: Option<String>) {
365 *self.last_query_id.lock() = query_id
366 }
367
368 pub fn last_query_id(&self) -> Option<String> {
369 self.last_query_id.lock().clone()
370 }
371
372 fn last_node_id(&self) -> Option<String> {
373 self.last_node_id.lock().clone()
374 }
375
376 fn handle_warnings(&self, resp: &QueryResponse) {
377 if let Some(warnings) = &resp.warnings {
378 for w in warnings {
379 warn!(target: "server_warnings", "server warning: {w}");
380 }
381 }
382 }
383
384 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
385 info!("start query: {sql}");
386 let resp = self.start_query_inner(sql, None).await?;
387 let pages = Pages::new(self.clone(), resp, need_progress);
388 Ok(pages)
389 }
390
391 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
392 if let Some(info) = &self.session_token_info {
393 let info = info.lock();
394 Ok(builder.bearer_auth(info.0.session_token.clone()))
395 } else {
396 self.auth.wrap(builder)
397 }
398 }
399
400 async fn start_query_inner(
401 &self,
402 sql: &str,
403 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
404 ) -> Result<QueryResponse> {
405 if !self.in_active_transaction() {
406 self.route_hint.next();
407 }
408 let endpoint = self.endpoint.join("v1/query")?;
409
410 let session_state = self.session_state();
412 let need_sticky = session_state.need_sticky.unwrap_or(false);
413 let req = QueryRequest::new(sql)
414 .with_pagination(self.make_pagination())
415 .with_session(Some(session_state))
416 .with_stage_attachment(stage_attachment_config);
417
418 let query_id = self.gen_query_id();
420 let mut headers = self.make_headers(Some(&query_id))?;
421 if need_sticky {
422 if let Some(node_id) = self.last_node_id() {
423 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
424 }
425 }
426 let mut builder = self.cli.post(endpoint.clone()).json(&req);
427 builder = self.wrap_auth_or_session_token(builder)?;
428 let request = builder.headers(headers.clone()).build()?;
429 let response = self.query_request_helper(request, true, true).await?;
430 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
431 self.route_hint.set(route_hint.to_str().unwrap_or_default());
432 }
433 let body = response.bytes().await?;
434 let result: QueryResponse = json_from_slice(&body)?;
435 self.handle_session(&result.session).await;
436 if let Some(err) = result.error {
437 return Err(Error::QueryFailed(err));
438 }
439
440 self.set_last_query_id(Some(query_id));
441 self.handle_warnings(&result);
442 if let Some(node_id) = &result.node_id {
443 self.set_last_node_id(node_id.clone());
444 }
445 Ok(result)
446 }
447
448 pub async fn query_page(
449 &self,
450 query_id: &str,
451 next_uri: &str,
452 node_id: &Option<String>,
453 ) -> Result<QueryResponse> {
454 info!("query page: {next_uri}");
455 let endpoint = self.endpoint.join(next_uri)?;
456 let headers = self.make_headers(Some(query_id))?;
457 let mut builder = self.cli.get(endpoint.clone());
458 builder = self
459 .wrap_auth_or_session_token(builder)?
460 .headers(headers.clone())
461 .timeout(self.page_request_timeout);
462 if let Some(node_id) = node_id {
463 builder = builder.header(HEADER_STICKY_NODE, node_id)
464 }
465 let request = builder.build()?;
466
467 let response = self.query_request_helper(request, false, true).await?;
468 let body = response.bytes().await?;
469 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
470 if let Error::Logic(status, ec) = &e {
471 if *status == 404 {
472 return Error::QueryNotFound(ec.message.clone());
473 }
474 }
475 e
476 })?;
477 self.handle_session(&resp.session).await;
478 match resp.error {
479 Some(err) => Err(Error::QueryFailed(err)),
480 None => Ok(resp),
481 }
482 }
483
484 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
485 let kill_uri = format!("/v1/query/{query_id}/kill");
486 let endpoint = self.endpoint.join(&kill_uri)?;
487 let headers = self.make_headers(Some(query_id))?;
488 info!("kill query: {kill_uri}");
489
490 let mut builder = self.cli.post(endpoint);
491 builder = self.wrap_auth_or_session_token(builder)?;
492 let resp = builder.headers(headers.clone()).send().await?;
493 if resp.status() != 200 {
494 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
495 .with_context("kill query"));
496 }
497 Ok(())
498 }
499
500 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
501 let mut pages = self.start_query(sql, false).await?;
502 let mut all = Page::default();
503 while let Some(page) = pages.next().await {
504 all.update(page?);
505 }
506 Ok(all)
507 }
508
509 fn session_state(&self) -> SessionState {
510 self.session_state.lock().clone()
511 }
512
513 fn make_pagination(&self) -> Option<PaginationConfig> {
514 if self.wait_time_secs.is_none()
515 && self.max_rows_in_buffer.is_none()
516 && self.max_rows_per_page.is_none()
517 {
518 return None;
519 }
520 let mut pagination = PaginationConfig {
521 wait_time_secs: None,
522 max_rows_in_buffer: None,
523 max_rows_per_page: None,
524 };
525 if let Some(wait_time_secs) = self.wait_time_secs {
526 pagination.wait_time_secs = Some(wait_time_secs);
527 }
528 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
529 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
530 }
531 if let Some(max_rows_per_page) = self.max_rows_per_page {
532 pagination.max_rows_per_page = Some(max_rows_per_page);
533 }
534 Some(pagination)
535 }
536
537 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
538 let mut headers = HeaderMap::new();
539 if let Some(tenant) = &self.tenant {
540 headers.insert(HEADER_TENANT, tenant.parse()?);
541 }
542 let warehouse = self.warehouse.lock().clone();
543 if let Some(warehouse) = warehouse {
544 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
545 }
546 let route_hint = self.route_hint.current();
547 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
548 if let Some(query_id) = query_id {
549 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
550 }
551 Ok(headers)
552 }
553
554 pub async fn insert_with_stage(
555 self: &Arc<Self>,
556 sql: &str,
557 stage: &str,
558 file_format_options: BTreeMap<&str, &str>,
559 copy_options: BTreeMap<&str, &str>,
560 ) -> Result<QueryStats> {
561 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
562 let stage_attachment = Some(StageAttachmentConfig {
563 location: stage,
564 file_format_options: Some(file_format_options),
565 copy_options: Some(copy_options),
566 });
567 let resp = self.start_query_inner(sql, stage_attachment).await?;
568 let mut pages = Pages::new(self.clone(), resp, false);
569 let mut all = Page::default();
570 while let Some(page) = pages.next().await {
571 all.update(page?);
572 }
573 Ok(all.stats)
574 }
575
576 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
577 info!("get presigned upload url: {stage}");
578 let sql = format!("PRESIGN UPLOAD {stage}");
579 let resp = self.query_all(&sql).await?;
580 if resp.data.len() != 1 {
581 return Err(Error::Decode(
582 "Empty response from server for presigned request".to_string(),
583 ));
584 }
585 if resp.data[0].len() != 3 {
586 return Err(Error::Decode(
587 "Invalid response from server for presigned request".to_string(),
588 ));
589 }
590 let method = resp.data[0][0].clone().unwrap_or_default();
592 if method != "PUT" {
593 return Err(Error::Decode(format!(
594 "Invalid method for presigned upload request: {method}"
595 )));
596 }
597 let headers: BTreeMap<String, String> =
598 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
599 let url = resp.data[0][2].clone().unwrap_or_default();
600 Ok(PresignedResponse {
601 method,
602 headers,
603 url,
604 })
605 }
606
607 pub async fn upload_to_stage(
608 self: &Arc<Self>,
609 stage: &str,
610 data: Reader,
611 size: u64,
612 ) -> Result<()> {
613 match self.get_presign_mode() {
614 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
615 PresignMode::On => {
616 let presigned = self.get_presigned_upload_url(stage).await?;
617 presign_upload_to_stage(presigned, data, size).await
618 }
619 PresignMode::Auto => {
620 unreachable!("PresignMode::Auto should be handled during client initialization")
621 }
622 PresignMode::Detect => {
623 unreachable!("PresignMode::Detect should be handled during client initialization")
624 }
625 }
626 }
627
628 async fn upload_to_stage_with_stream(
630 &self,
631 stage: &str,
632 data: Reader,
633 size: u64,
634 ) -> Result<()> {
635 info!("upload to stage with stream: {stage}, size: {size}");
636 if let Some(info) = self.need_pre_refresh_session().await {
637 self.refresh_session_token(info).await?;
638 }
639 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
640 let location = StageLocation::try_from(stage)?;
641 let query_id = self.gen_query_id();
642 let mut headers = self.make_headers(Some(&query_id))?;
643 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
644 let stream = Body::wrap_stream(ReaderStream::new(data));
645 let part = Part::stream_with_length(stream, size).file_name(location.path);
646 let form = Form::new().part("upload", part);
647 let mut builder = self.cli.put(endpoint.clone());
648 builder = self.wrap_auth_or_session_token(builder)?;
649 let resp = builder.headers(headers).multipart(form).send().await?;
650 let status = resp.status();
651 if status != 200 {
652 return Err(
653 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
654 );
655 }
656 Ok(())
657 }
658
659 async fn login(&mut self) -> Result<()> {
660 let endpoint = self.endpoint.join("/v1/session/login")?;
661 let headers = self.make_headers(None)?;
662 let body = LoginRequest::from(&*self.session_state.lock());
663 let mut builder = self.cli.post(endpoint.clone()).json(&body);
664 if self.disable_session_token {
665 builder = builder.query(&[("disable_session_token", true)]);
666 }
667 let builder = self.auth.wrap(builder)?;
668 let request = builder
669 .headers(headers.clone())
670 .timeout(self.connect_timeout)
671 .build()?;
672 let response = self.query_request_helper(request, true, false).await;
673 let response = match response {
674 Ok(r) => r,
675 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
676 info!("login return 404, skip login on the old version server");
677 return Ok(());
678 }
679 Err(e) => return Err(e),
680 };
681 let body = response.bytes().await?;
682 let response = json_from_slice(&body)?;
683 match response {
684 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
685 LoginResponseResult::Ok(info) => {
686 self.server_version = Some(info.version.clone());
687 if let Some(tokens) = info.tokens {
688 info!("login success with session token");
689 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
690 }
691 info!("login success without session token");
692 }
693 }
694 Ok(())
695 }
696
697 fn build_log_out_request(&self) -> Result<Request> {
698 let endpoint = self.endpoint.join("/v1/session/logout")?;
699
700 let session_state = self.session_state();
701 let need_sticky = session_state.need_sticky.unwrap_or(false);
702 let mut headers = self.make_headers(None)?;
703 if need_sticky {
704 if let Some(node_id) = self.last_node_id() {
705 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
706 }
707 }
708 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
709
710 let builder = self.wrap_auth_or_session_token(builder)?;
711 let req = builder.build()?;
712 Ok(req)
713 }
714
715 fn need_logout(&self) -> bool {
716 (self.session_token_info.is_some()
717 || self.session_state.lock().need_keep_alive.unwrap_or(false))
718 && self
719 .closed
720 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
721 .is_ok()
722 }
723
724 async fn refresh_session_token(
725 &self,
726 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
727 ) -> Result<()> {
728 let (session_token_info, _) = { self_login_info.lock().clone() };
729 let endpoint = self.endpoint.join("/v1/session/refresh")?;
730 let body = RefreshSessionTokenRequest {
731 session_token: session_token_info.session_token.clone(),
732 };
733 let headers = self.make_headers(None)?;
734 let request = self
735 .cli
736 .post(endpoint.clone())
737 .json(&body)
738 .headers(headers.clone())
739 .bearer_auth(session_token_info.refresh_token.clone())
740 .timeout(self.connect_timeout)
741 .build()?;
742
743 for i in 0..3 {
745 let req = request.try_clone().expect("request not cloneable");
746 match self.cli.execute(req).await {
747 Ok(response) => {
748 let status = response.status();
749 let body = response.bytes().await?;
750 if status == StatusCode::OK {
751 let response = json_from_slice(&body)?;
752 return match response {
753 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
754 RefreshResponse::Ok(info) => {
755 *self_login_info.lock() = (info, Instant::now());
756 Ok(())
757 }
758 };
759 }
760 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
761 return Err(Error::response_error(status, &body));
762 }
763 }
764 Err(err) => {
765 if !(err.is_timeout() || err.is_connect()) || i > 2 {
766 return Err(Error::Request(err.to_string()));
767 }
768 }
769 };
770 sleep(jitter(Duration::from_secs(10))).await;
771 }
772 Ok(())
773 }
774
775 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
776 if let Some(info) = &self.session_token_info {
777 let (start, ttl) = {
778 let guard = info.lock();
779 (guard.1, guard.0.session_token_ttl_in_secs)
780 };
781 if Instant::now() > start + Duration::from_secs(ttl) {
782 return Some(info.clone());
783 }
784 }
785 None
786 }
787
788 async fn query_request_helper(
796 &self,
797 mut request: Request,
798 retry_if_503: bool,
799 refresh_if_401: bool,
800 ) -> std::result::Result<Response, Error> {
801 let mut refreshed = false;
802 let mut retries = 0;
803 loop {
804 let req = request.try_clone().expect("request not cloneable");
805 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
806 Ok(response) => {
807 let status = response.status();
808 if status == StatusCode::OK {
809 return Ok(response);
810 }
811 let body = response.bytes().await?;
812 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
813 (Error::response_error(status, &body), true)
815 } else {
816 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
817 match resp {
818 Ok(r) => {
819 let e = r.error;
820 if status == StatusCode::UNAUTHORIZED {
821 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
822 if let Some(session_token_info) = &self.session_token_info {
823 info!(
824 "will retry {} after refresh token on auth error {}",
825 request.url(),
826 e
827 );
828 let retry = if need_refresh_token(e.code)
829 && !refreshed
830 && refresh_if_401
831 {
832 self.refresh_session_token(session_token_info.clone())
833 .await?;
834 refreshed = true;
835 true
836 } else {
837 false
838 };
839 (Error::AuthFailure(e), retry)
840 } else if self.auth.can_reload() {
841 info!(
842 "will retry {} after reload token on auth error {}",
843 request.url(),
844 e
845 );
846 let builder = RequestBuilder::from_parts(
847 HttpClient::new(),
848 request.try_clone().unwrap(),
849 );
850 let builder = self.auth.wrap(builder)?;
851 request = builder.build()?;
852 (Error::AuthFailure(e), true)
853 } else {
854 (Error::AuthFailure(e), false)
855 }
856 } else {
857 (Error::Logic(status, e), false)
858 }
859 }
860 Err(_) => (
861 Error::Response {
862 status,
863 msg: String::from_utf8_lossy(&body).to_string(),
864 },
865 false,
866 ),
867 }
868 }
869 }
870 Err(err) => (
871 Error::Request(err.to_string()),
872 err.is_timeout() || err.is_connect(),
873 ),
874 };
875 if !retry {
876 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
877 }
878 match &err {
879 Error::AuthFailure(_) => {
880 if refreshed {
881 retries = 0;
882 } else if retries == 2 {
883 return Err(err.with_context(&format!(
884 "{} {} after 3 reties",
885 request.method(),
886 request.url()
887 )));
888 }
889 }
890 _ => {
891 if retries == 2 {
892 return Err(err.with_context(&format!(
893 "{} {} after 3 reties",
894 request.method(),
895 request.url()
896 )));
897 }
898 retries += 1;
899 info!(
900 "will retry {} the {retries}th times on error {}",
901 request.url(),
902 err
903 );
904 }
905 }
906 sleep(jitter(Duration::from_secs(10))).await;
907 }
908 }
909
910 pub async fn close(&self) {
911 if self.need_logout() {
912 let cli = self.cli.clone();
913 let req = self
914 .build_log_out_request()
915 .expect("failed to build logout request");
916 if let Err(err) = cli.execute(req).await {
917 error!("logout request failed: {err}");
918 } else {
919 debug!("logout success");
920 };
921 }
922 }
923}
924
925impl Drop for APIClient {
926 fn drop(&mut self) {
927 if self.need_logout() {
928 warn!("APIClient::close() was not called");
929 }
930 }
931}
932
933fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
934where
935 T: Deserialize<'a>,
936{
937 serde_json::from_slice::<T>(body).map_err(|e| {
938 Error::Decode(format!(
939 "fail to decode JSON response: {e}, body: {}",
940 String::from_utf8_lossy(body)
941 ))
942 })
943}
944
945impl Default for APIClient {
946 fn default() -> Self {
947 Self {
948 cli: HttpClient::new(),
949 scheme: "http".to_string(),
950 endpoint: Url::parse("http://localhost:8080").unwrap(),
951 host: "localhost".to_string(),
952 port: 8000,
953 tenant: None,
954 warehouse: Mutex::new(None),
955 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
956 session_state: Mutex::new(SessionState::default()),
957 wait_time_secs: None,
958 max_rows_in_buffer: None,
959 max_rows_per_page: None,
960 connect_timeout: Duration::from_secs(10),
961 page_request_timeout: Duration::from_secs(30),
962 tls_ca_file: None,
963 presign: Mutex::new(PresignMode::Auto),
964 route_hint: RouteHintGenerator::new(),
965 last_node_id: Default::default(),
966 disable_session_token: true,
967 disable_login: false,
968 session_token_info: None,
969 closed: Default::default(),
970 last_query_id: Default::default(),
971 server_version: None,
972 }
973 }
974}
975
/// Produces and stores the route hint string sent with each request.
struct RouteHintGenerator {
    /// Counter embedded in each generated hint; incremented on every `next()`.
    nonce: AtomicU64,
    /// The hint currently in effect (generated locally or set from a server response).
    current: std::sync::Mutex<String>,
}
980
981impl RouteHintGenerator {
982 fn new() -> Self {
983 let gen = Self {
984 nonce: AtomicU64::new(0),
985 current: std::sync::Mutex::new("".to_string()),
986 };
987 gen.next();
988 gen
989 }
990
991 fn current(&self) -> String {
992 let guard = self.current.lock().unwrap();
993 guard.clone()
994 }
995
996 fn set(&self, hint: &str) {
997 let mut guard = self.current.lock().unwrap();
998 *guard = hint.to_string();
999 }
1000
1001 fn next(&self) -> String {
1002 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1003 let uuid = uuid::Uuid::new_v4();
1004 let current = format!("rh:{uuid}:{nonce:06}");
1005 let mut guard = self.current.lock().unwrap();
1006 guard.clone_from(¤t);
1007 current
1008 }
1009}
1010
#[cfg(test)]
mod test {
    use super::*;

    /// A DSN with credentials, database, pagination settings, warehouse and
    /// `sslmode=disable` is parsed into the matching client fields; with no explicit
    /// port and plain HTTP the endpoint uses port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);
        assert_eq!(
            *client.warehouse.try_lock().unwrap(),
            Some("wh".to_string())
        );
        Ok(())
    }

    /// A percent-encoded password must not disturb host parsing; without an explicit
    /// port the default scheme (https) yields port 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// A password containing unencoded special characters (including `@`) still parses,
    /// and the explicit port is honoured.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000";
        let client = APIClient::from_dsn(dsn).await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}
1049}