1use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
22use crate::global_cookie_store::GlobalCookieStore;
23use crate::login::{
24 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
25 SessionTokenInfo,
26};
27use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
28use crate::stage::StageLocation;
29use crate::{
30 error::{Error, Result},
31 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
32 response::QueryResponse,
33 session::SessionState,
34};
35use log::{debug, error, info, warn};
36use once_cell::sync::Lazy;
37use percent_encoding::percent_decode_str;
38use reqwest::cookie::CookieStore;
39use reqwest::header::{HeaderMap, HeaderValue};
40use reqwest::multipart::{Form, Part};
41use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
42use serde::Deserialize;
43use tokio::time::sleep;
44use tokio_retry::strategy::jitter;
45use tokio_util::io::ReaderStream;
46use url::Url;
47
/// Request header carrying the client-generated id of the query.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
/// Request header carrying the tenant from the DSN's `tenant` option.
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
/// Request header naming the server node that should handle the request;
/// sent only when the session reports it needs stickiness.
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
/// Request header carrying the client's current warehouse.
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
/// Request header naming the target stage of a `/v1/upload_to_stage` request.
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
/// Route-hint header: sent on every request, and read back from `/v1/query`
/// responses to replace the client's current hint.
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
/// Session `txn_state` value (compared ASCII case-insensitively) that marks an
/// open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
55
56static VERSION: Lazy<String> = Lazy::new(|| {
57 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
58 version.to_string()
59});
60
/// Client for the Databend HTTP API (`/v1/query`, `/v1/session/*`,
/// `/v1/upload_to_stage`).
///
/// Mutable shared state (warehouse, session state, route hint, session token,
/// closed flag, last node/query ids) is held behind `Arc`s, so clones of a
/// client share it.
#[derive(Clone)]
pub struct APIClient {
    /// Underlying HTTP client; replaced by `build_client`.
    cli: HttpClient,
    /// `"http"` or `"https"`, resolved from the DSN's `sslmode` option.
    scheme: String,
    /// Host parsed from the DSN.
    host: String,
    /// Port from the DSN, or 80/443 depending on `scheme`.
    port: u16,

    /// Base URL (`scheme://host:port`) that API paths are joined onto.
    endpoint: Url,

    /// Credentials: basic auth from the DSN userinfo, or `access_token` /
    /// `access_token_file` options.
    auth: Arc<dyn Auth>,

    /// Sent as the `X-DATABEND-TENANT` header when set.
    tenant: Option<String>,
    /// Sent as the `X-DATABEND-WAREHOUSE` header; also updated from the
    /// `warehouse` setting of sessions returned by the server.
    warehouse: Arc<parking_lot::Mutex<Option<String>>>,
    /// Session state sent with each query and replaced by the session the
    /// server returns.
    session_state: Arc<parking_lot::Mutex<SessionState>>,
    /// Source of the `X-DATABEND-ROUTE-HINT` header value.
    route_hint: Arc<RouteHintGenerator>,

    /// When true, `new` skips `/v1/session/login` (DSN `login=disable`).
    disable_login: bool,
    /// When true, login asks the server not to issue session tokens
    /// (DSN `session_token=disable`).
    disable_session_token: bool,
    /// Session/refresh tokens returned by login, plus the instant they were
    /// obtained; `None` when no tokens were issued.
    session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

    /// Flipped to true by `need_logout` the first time it decides a logout is
    /// needed, so the logout happens at most once.
    closed: Arc<AtomicBool>,

    /// Server version reported by the login response.
    server_version: Option<String>,

    /// Optional pagination settings forwarded with each new query.
    wait_time_secs: Option<i64>,
    max_rows_in_buffer: Option<i64>,
    max_rows_per_page: Option<i64>,

    /// Timeout applied to the login and session-token refresh requests.
    connect_timeout: Duration,
    /// Timeout applied to each result-page request.
    page_request_timeout: Duration,

    /// PEM CA certificate trusted for https connections (only used when a
    /// TLS feature is compiled in).
    tls_ca_file: Option<String>,

    /// How stage uploads are performed; `Auto`/`Detect` are resolved to
    /// `On`/`Off` by `check_presign` during `new`.
    presign: PresignMode,
    /// Node id sent as `X-DATABEND-STICKY-NODE` when the session needs stickiness.
    last_node_id: Arc<parking_lot::Mutex<Option<String>>>,
    /// Id of the most recent successfully started query.
    last_query_id: Arc<parking_lot::Mutex<Option<String>>>,
}
98
99impl APIClient {
100 pub async fn new(dsn: &str, name: Option<String>) -> Result<Self> {
101 let mut client = Self::from_dsn(dsn).await?;
102 client.build_client(name).await?;
103 client.check_presign().await?;
104 if !client.disable_login {
105 client.login().await?;
106 }
107 Ok(client)
108 }
109
110 async fn from_dsn(dsn: &str) -> Result<Self> {
111 let u = Url::parse(dsn)?;
112 let mut client = Self::default();
113 if let Some(host) = u.host_str() {
114 client.host = host.to_string();
115 }
116
117 if u.username() != "" {
118 let password = u.password().unwrap_or_default();
119 let password = percent_decode_str(password).decode_utf8()?;
120 client.auth = Arc::new(BasicAuth::new(u.username(), password));
121 }
122 let database = match u.path().trim_start_matches('/') {
123 "" => None,
124 s => Some(s.to_string()),
125 };
126 let mut role = None;
127 let mut scheme = "https";
128 let mut session_settings = BTreeMap::new();
129 for (k, v) in u.query_pairs() {
130 match k.as_ref() {
131 "wait_time_secs" => {
132 client.wait_time_secs = Some(v.parse()?);
133 }
134 "max_rows_in_buffer" => {
135 client.max_rows_in_buffer = Some(v.parse()?);
136 }
137 "max_rows_per_page" => {
138 client.max_rows_per_page = Some(v.parse()?);
139 }
140 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
141 "page_request_timeout_secs" => {
142 client.page_request_timeout = {
143 let secs: u64 = v.parse()?;
144 Duration::from_secs(secs)
145 };
146 }
147 "presign" => {
148 client.presign = match v.as_ref() {
149 "auto" => PresignMode::Auto,
150 "detect" => PresignMode::Detect,
151 "on" => PresignMode::On,
152 "off" => PresignMode::Off,
153 _ => {
154 return Err(Error::BadArgument(format!(
155 "Invalid value for presign: {}, should be one of auto/detect/on/off",
156 v
157 )))
158 }
159 }
160 }
161 "tenant" => {
162 client.tenant = Some(v.to_string());
163 }
164 "warehouse" => {
165 client.warehouse = Arc::new(parking_lot::Mutex::new(Some(v.to_string())));
166 }
167 "role" => role = Some(v.to_string()),
168 "sslmode" => match v.as_ref() {
169 "disable" => scheme = "http",
170 "require" | "enable" => scheme = "https",
171 _ => {
172 return Err(Error::BadArgument(format!(
173 "Invalid value for sslmode: {}",
174 v
175 )))
176 }
177 },
178 "tls_ca_file" => {
179 client.tls_ca_file = Some(v.to_string());
180 }
181 "access_token" => {
182 client.auth = Arc::new(AccessTokenAuth::new(v));
183 }
184 "access_token_file" => {
185 client.auth = Arc::new(AccessTokenFileAuth::new(v));
186 }
187 "login" => {
188 client.disable_login = match v.as_ref() {
189 "disable" => true,
190 "enable" => false,
191 _ => {
192 return Err(Error::BadArgument(format!(
193 "Invalid value for login: {}",
194 v
195 )))
196 }
197 }
198 }
199 "session_token" => {
200 client.disable_session_token = match v.as_ref() {
201 "disable" => true,
202 "enable" => false,
203 _ => {
204 return Err(Error::BadArgument(format!(
205 "Invalid value for session_token: {}",
206 v
207 )))
208 }
209 }
210 }
211 _ => {
212 session_settings.insert(k.to_string(), v.to_string());
213 }
214 }
215 }
216 client.port = match u.port() {
217 Some(p) => p,
218 None => match scheme {
219 "http" => 80,
220 "https" => 443,
221 _ => unreachable!(),
222 },
223 };
224 client.scheme = scheme.to_string();
225
226 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
227 client.session_state = Arc::new(parking_lot::Mutex::new(
228 SessionState::default()
229 .with_settings(Some(session_settings))
230 .with_role(role)
231 .with_database(database),
232 ));
233
234 Ok(client)
235 }
236
237 pub fn host(&self) -> &str {
238 self.host.as_str()
239 }
240
241 pub fn port(&self) -> u16 {
242 self.port
243 }
244
245 pub fn scheme(&self) -> &str {
246 self.scheme.as_str()
247 }
248
249 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
250 let ua = match name {
251 Some(n) => n,
252 None => format!("databend-client-rust/{}", VERSION.as_str()),
253 };
254 let cookie_provider = GlobalCookieStore::new();
255 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
256 let mut initial_cookies = [&cookie].into_iter();
257 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
258 let mut cli_builder = HttpClient::builder()
259 .user_agent(ua)
260 .cookie_provider(Arc::new(cookie_provider))
261 .pool_idle_timeout(Duration::from_secs(1));
262 #[cfg(any(feature = "rustls", feature = "native-tls"))]
263 if self.scheme == "https" {
264 if let Some(ref ca_file) = self.tls_ca_file {
265 let cert_pem = tokio::fs::read(ca_file).await?;
266 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
267 cli_builder = cli_builder.add_root_certificate(cert);
268 }
269 }
270 self.cli = cli_builder.build()?;
271 Ok(())
272 }
273
274 async fn check_presign(&mut self) -> Result<()> {
275 match self.presign {
276 PresignMode::Auto => {
277 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
278 self.presign = PresignMode::On;
279 } else {
280 self.presign = PresignMode::Off;
281 }
282 }
283 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
284 Ok(_) => self.presign = PresignMode::On,
285 Err(e) => {
286 warn!("presign mode off with error detected: {}", e);
287 self.presign = PresignMode::Off;
288 }
289 },
290 _ => {}
291 }
292 Ok(())
293 }
294
295 pub fn current_warehouse(&self) -> Option<String> {
296 let guard = self.warehouse.lock();
297 guard.clone()
298 }
299
300 pub fn current_database(&self) -> Option<String> {
301 let guard = self.session_state.lock();
302 guard.database.clone()
303 }
304
305 pub async fn current_role(&self) -> Option<String> {
306 let guard = self.session_state.lock();
307 guard.role.clone()
308 }
309
310 fn in_active_transaction(&self) -> bool {
311 let guard = self.session_state.lock();
312 guard
313 .txn_state
314 .as_ref()
315 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
316 .unwrap_or(false)
317 }
318
319 pub fn username(&self) -> String {
320 self.auth.username()
321 }
322
323 fn gen_query_id(&self) -> String {
324 uuid::Uuid::new_v4().to_string()
325 }
326
327 async fn handle_session(&self, session: &Option<SessionState>) {
328 let session = match session {
329 Some(session) => session,
330 None => return,
331 };
332
333 {
335 let mut session_state = self.session_state.lock();
336 *session_state = session.clone();
337 }
338
339 if let Some(settings) = session.settings.as_ref() {
341 if let Some(v) = settings.get("warehouse") {
342 let mut warehouse = self.warehouse.lock();
343 *warehouse = Some(v.clone());
344 }
345 }
346 }
347
348 pub fn set_last_node_id(&self, node_id: String) {
349 *self.last_node_id.lock() = Some(node_id)
350 }
351
352 pub fn set_last_query_id(&self, query_id: Option<String>) {
353 *self.last_query_id.lock() = query_id
354 }
355
356 pub fn last_query_id(&self) -> Option<String> {
357 self.last_query_id.lock().clone()
358 }
359
360 fn last_node_id(&self) -> Option<String> {
361 self.last_node_id.lock().clone()
362 }
363
364 fn handle_warnings(&self, resp: &QueryResponse) {
365 if let Some(warnings) = &resp.warnings {
366 for w in warnings {
367 warn!(target: "server_warnings", "server warning: {}", w);
368 }
369 }
370 }
371
372 pub async fn start_query(&self, sql: &str) -> Result<QueryResponse> {
373 info!("start query: {}", sql);
374 self.start_query_inner(sql, None).await
375 }
376
377 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
378 if let Some(info) = &self.session_token_info {
379 let info = info.lock();
380 Ok(builder.bearer_auth(info.0.session_token.clone()))
381 } else {
382 self.auth.wrap(builder)
383 }
384 }
385
386 async fn start_query_inner(
387 &self,
388 sql: &str,
389 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
390 ) -> Result<QueryResponse> {
391 if !self.in_active_transaction() {
392 self.route_hint.next();
393 }
394 let endpoint = self.endpoint.join("v1/query")?;
395
396 let session_state = self.session_state();
398 let need_sticky = session_state.need_sticky.unwrap_or(false);
399 let req = QueryRequest::new(sql)
400 .with_pagination(self.make_pagination())
401 .with_session(Some(session_state))
402 .with_stage_attachment(stage_attachment_config);
403
404 let query_id = self.gen_query_id();
406 let mut headers = self.make_headers(Some(&query_id))?;
407 if need_sticky {
408 if let Some(node_id) = self.last_node_id() {
409 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
410 }
411 }
412 let mut builder = self.cli.post(endpoint.clone()).json(&req);
413 builder = self.wrap_auth_or_session_token(builder)?;
414 let request = builder.headers(headers.clone()).build()?;
415 let response = self.query_request_helper(request, true, true).await?;
416 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
417 self.route_hint.set(route_hint.to_str().unwrap_or_default());
418 }
419 let body = response.bytes().await?;
420 let result: QueryResponse = json_from_slice(&body)?;
421 self.handle_session(&result.session).await;
422 if let Some(err) = result.error {
423 return Err(Error::QueryFailed(err));
424 }
425
426 self.set_last_query_id(Some(query_id));
427 self.handle_warnings(&result);
428 Ok(result)
429 }
430
431 pub async fn query_page(
432 &self,
433 query_id: &str,
434 next_uri: &str,
435 node_id: &Option<String>,
436 ) -> Result<QueryResponse> {
437 info!("query page: {}", next_uri);
438 let endpoint = self.endpoint.join(next_uri)?;
439 let headers = self.make_headers(Some(query_id))?;
440 let mut builder = self.cli.get(endpoint.clone());
441 builder = self
442 .wrap_auth_or_session_token(builder)?
443 .headers(headers.clone())
444 .timeout(self.page_request_timeout);
445 if let Some(node_id) = node_id {
446 builder = builder.header(HEADER_STICKY_NODE, node_id)
447 }
448 let request = builder.build()?;
449
450 let response = self.query_request_helper(request, false, true).await?;
451 let body = response.bytes().await?;
452 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
453 if let Error::Logic(status, ec) = &e {
454 if *status == 404 {
455 return Error::QueryNotFound(ec.message.clone());
456 }
457 }
458 e
459 })?;
460 self.handle_session(&resp.session).await;
461 match resp.error {
462 Some(err) => Err(Error::QueryFailed(err)),
463 None => Ok(resp),
464 }
465 }
466
467 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
468 let kill_uri = format!("/v1/query/{}/kill", query_id);
469 let endpoint = self.endpoint.join(&kill_uri)?;
470 let headers = self.make_headers(Some(query_id))?;
471 info!("kill query: {}", kill_uri);
472
473 let mut builder = self.cli.post(endpoint);
474 builder = self.wrap_auth_or_session_token(builder)?;
475 let resp = builder.headers(headers.clone()).send().await?;
476 if resp.status() != 200 {
477 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
478 .with_context("kill query"));
479 }
480 Ok(())
481 }
482
483 async fn wait_for_query(&self, resp: QueryResponse) -> Result<QueryResponse> {
484 info!("wait for query: {}", resp.id);
485 let node_id = resp.node_id.clone();
486 if let Some(node_id) = self.last_node_id() {
487 self.set_last_node_id(node_id.clone());
488 }
489
490 if let Some(next_uri) = &resp.next_uri {
491 let schema = resp.schema;
492 let mut data = resp.data;
493 let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?;
494 while let Some(next_uri) = &resp.next_uri {
495 resp = self.query_page(&resp.id, next_uri, &node_id).await?;
496 data.append(&mut resp.data);
497 }
498 resp.schema = schema;
499 resp.data = data;
500 Ok(resp)
501 } else {
502 Ok(resp)
503 }
504 }
505
506 pub async fn query(&self, sql: &str) -> Result<QueryResponse> {
507 let resp = self.start_query(sql).await?;
508 self.wait_for_query(resp).await
509 }
510
511 fn session_state(&self) -> SessionState {
512 self.session_state.lock().clone()
513 }
514
515 fn make_pagination(&self) -> Option<PaginationConfig> {
516 if self.wait_time_secs.is_none()
517 && self.max_rows_in_buffer.is_none()
518 && self.max_rows_per_page.is_none()
519 {
520 return None;
521 }
522 let mut pagination = PaginationConfig {
523 wait_time_secs: None,
524 max_rows_in_buffer: None,
525 max_rows_per_page: None,
526 };
527 if let Some(wait_time_secs) = self.wait_time_secs {
528 pagination.wait_time_secs = Some(wait_time_secs);
529 }
530 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
531 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
532 }
533 if let Some(max_rows_per_page) = self.max_rows_per_page {
534 pagination.max_rows_per_page = Some(max_rows_per_page);
535 }
536 Some(pagination)
537 }
538
539 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
540 let mut headers = HeaderMap::new();
541 if let Some(tenant) = &self.tenant {
542 headers.insert(HEADER_TENANT, tenant.parse()?);
543 }
544 let warehouse = self.warehouse.lock().clone();
545 if let Some(warehouse) = warehouse {
546 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
547 }
548 let route_hint = self.route_hint.current();
549 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
550 if let Some(query_id) = query_id {
551 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
552 }
553 Ok(headers)
554 }
555
556 pub async fn insert_with_stage(
557 &self,
558 sql: &str,
559 stage: &str,
560 file_format_options: BTreeMap<&str, &str>,
561 copy_options: BTreeMap<&str, &str>,
562 ) -> Result<QueryResponse> {
563 info!(
564 "insert with stage: {}, format: {:?}, copy: {:?}",
565 sql, file_format_options, copy_options
566 );
567 let stage_attachment = Some(StageAttachmentConfig {
568 location: stage,
569 file_format_options: Some(file_format_options),
570 copy_options: Some(copy_options),
571 });
572 let resp = self.start_query_inner(sql, stage_attachment).await?;
573 let resp = self.wait_for_query(resp).await?;
574 Ok(resp)
575 }
576
577 async fn get_presigned_upload_url(&self, stage: &str) -> Result<PresignedResponse> {
578 info!("get presigned upload url: {}", stage);
579 let sql = format!("PRESIGN UPLOAD {}", stage);
580 let resp = self.query(&sql).await?;
581 if resp.data.len() != 1 {
582 return Err(Error::Decode(
583 "Empty response from server for presigned request".to_string(),
584 ));
585 }
586 if resp.data[0].len() != 3 {
587 return Err(Error::Decode(
588 "Invalid response from server for presigned request".to_string(),
589 ));
590 }
591 let method = resp.data[0][0].clone().unwrap_or_default();
593 if method != "PUT" {
594 return Err(Error::Decode(format!(
595 "Invalid method for presigned upload request: {}",
596 method
597 )));
598 }
599 let headers: BTreeMap<String, String> =
600 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
601 let url = resp.data[0][2].clone().unwrap_or_default();
602 Ok(PresignedResponse {
603 method,
604 headers,
605 url,
606 })
607 }
608
609 pub async fn upload_to_stage(&self, stage: &str, data: Reader, size: u64) -> Result<()> {
610 match self.presign {
611 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
612 PresignMode::On => {
613 let presigned = self.get_presigned_upload_url(stage).await?;
614 presign_upload_to_stage(presigned, data, size).await
615 }
616 PresignMode::Auto => {
617 unreachable!("PresignMode::Auto should be handled during client initialization")
618 }
619 PresignMode::Detect => {
620 unreachable!("PresignMode::Detect should be handled during client initialization")
621 }
622 }
623 }
624
625 async fn upload_to_stage_with_stream(
627 &self,
628 stage: &str,
629 data: Reader,
630 size: u64,
631 ) -> Result<()> {
632 info!("upload to stage with stream: {}, size: {}", stage, size);
633 if let Some(info) = self.need_pre_refresh_session().await {
634 self.refresh_session_token(info).await?;
635 }
636 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
637 let location = StageLocation::try_from(stage)?;
638 let query_id = self.gen_query_id();
639 let mut headers = self.make_headers(Some(&query_id))?;
640 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
641 let stream = Body::wrap_stream(ReaderStream::new(data));
642 let part = Part::stream_with_length(stream, size).file_name(location.path);
643 let form = Form::new().part("upload", part);
644 let mut builder = self.cli.put(endpoint.clone());
645 builder = self.wrap_auth_or_session_token(builder)?;
646 let resp = builder.headers(headers).multipart(form).send().await?;
647 let status = resp.status();
648 if status != 200 {
649 return Err(
650 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
651 );
652 }
653 Ok(())
654 }
655
656 async fn login(&mut self) -> Result<()> {
657 let endpoint = self.endpoint.join("/v1/session/login")?;
658 let headers = self.make_headers(None)?;
659 let body = LoginRequest::from(&*self.session_state.lock());
660 let mut builder = self.cli.post(endpoint.clone()).json(&body);
661 if self.disable_session_token {
662 builder = builder.query(&[("disable_session_token", true)]);
663 }
664 let builder = self.auth.wrap(builder)?;
665 let request = builder
666 .headers(headers.clone())
667 .timeout(self.connect_timeout)
668 .build()?;
669 let response = self.query_request_helper(request, true, false).await;
670 let response = match response {
671 Ok(r) => r,
672 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
673 info!("login return 404, skip login on the old version server");
674 return Ok(());
675 }
676 Err(e) => return Err(e),
677 };
678 let body = response.bytes().await?;
679 let response = json_from_slice(&body)?;
680 match response {
681 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
682 LoginResponseResult::Ok(info) => {
683 self.server_version = Some(info.version.clone());
684 if let Some(tokens) = info.tokens {
685 info!("login success with session token");
686 self.session_token_info =
687 Some(Arc::new(parking_lot::Mutex::new((tokens, Instant::now()))))
688 }
689 info!("login success without session token");
690 }
691 }
692 Ok(())
693 }
694
695 fn build_log_out_request(&self) -> Result<Request> {
696 let endpoint = self.endpoint.join("/v1/session/logout")?;
697
698 let session_state = self.session_state();
699 let need_sticky = session_state.need_sticky.unwrap_or(false);
700 let mut headers = self.make_headers(None)?;
701 if need_sticky {
702 if let Some(node_id) = self.last_node_id() {
703 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
704 }
705 }
706 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
707
708 let builder = self.wrap_auth_or_session_token(builder)?;
709 let req = builder.build()?;
710 Ok(req)
711 }
712
713 fn need_logout(&self) -> bool {
714 (self.session_token_info.is_some()
715 || self.session_state.lock().need_keep_alive.unwrap_or(false))
716 && self
717 .closed
718 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
719 .is_ok()
720 }
721
722 async fn refresh_session_token(
723 &self,
724 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
725 ) -> Result<()> {
726 let (session_token_info, _) = { self_login_info.lock().clone() };
727 let endpoint = self.endpoint.join("/v1/session/refresh")?;
728 let body = RefreshSessionTokenRequest {
729 session_token: session_token_info.session_token.clone(),
730 };
731 let headers = self.make_headers(None)?;
732 let request = self
733 .cli
734 .post(endpoint.clone())
735 .json(&body)
736 .headers(headers.clone())
737 .bearer_auth(session_token_info.refresh_token.clone())
738 .timeout(self.connect_timeout)
739 .build()?;
740
741 for i in 0..3 {
743 let req = request.try_clone().expect("request not cloneable");
744 match self.cli.execute(req).await {
745 Ok(response) => {
746 let status = response.status();
747 let body = response.bytes().await?;
748 if status == StatusCode::OK {
749 let response = json_from_slice(&body)?;
750 return match response {
751 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
752 RefreshResponse::Ok(info) => {
753 *self_login_info.lock() = (info, Instant::now());
754 Ok(())
755 }
756 };
757 }
758 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
759 return Err(Error::response_error(status, &body));
760 }
761 }
762 Err(err) => {
763 if !(err.is_timeout() || err.is_connect()) || i > 2 {
764 return Err(Error::Request(err.to_string()));
765 }
766 }
767 };
768 sleep(jitter(Duration::from_secs(10))).await;
769 }
770 Ok(())
771 }
772
773 async fn need_pre_refresh_session(
774 &self,
775 ) -> Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>> {
776 if let Some(info) = &self.session_token_info {
777 let (start, ttl) = {
778 let guard = info.lock();
779 (guard.1, guard.0.session_token_ttl_in_secs)
780 };
781 if Instant::now() > start + Duration::from_secs(ttl) {
782 return Some(info.clone());
783 }
784 }
785 None
786 }
787
788 async fn query_request_helper(
796 &self,
797 mut request: Request,
798 retry_if_503: bool,
799 refresh_if_401: bool,
800 ) -> std::result::Result<Response, Error> {
801 let mut refreshed = false;
802 let mut retries = 0;
803 loop {
804 let req = request.try_clone().expect("request not cloneable");
805 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
806 Ok(response) => {
807 let status = response.status();
808 if status == StatusCode::OK {
809 return Ok(response);
810 }
811 let body = response.bytes().await?;
812 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
813 (Error::response_error(status, &body), true)
815 } else {
816 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
817 match resp {
818 Ok(r) => {
819 let e = r.error;
820 if status == StatusCode::UNAUTHORIZED {
821 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
822 if let Some(session_token_info) = &self.session_token_info {
823 info!(
824 "will retry {} after refresh token on auth error {}",
825 request.url(),
826 e
827 );
828 let retry = if need_refresh_token(e.code)
829 && !refreshed
830 && refresh_if_401
831 {
832 self.refresh_session_token(session_token_info.clone())
833 .await?;
834 refreshed = true;
835 true
836 } else {
837 false
838 };
839 (Error::AuthFailure(e), retry)
840 } else if self.auth.can_reload() {
841 info!(
842 "will retry {} after reload token on auth error {}",
843 request.url(),
844 e
845 );
846 let builder = RequestBuilder::from_parts(
847 HttpClient::new(),
848 request.try_clone().unwrap(),
849 );
850 let builder = self.auth.wrap(builder)?;
851 request = builder.build()?;
852 (Error::AuthFailure(e), true)
853 } else {
854 (Error::AuthFailure(e), false)
855 }
856 } else {
857 (Error::Logic(status, e), false)
858 }
859 }
860 Err(_) => (
861 Error::Response {
862 status,
863 msg: String::from_utf8_lossy(&body).to_string(),
864 },
865 false,
866 ),
867 }
868 }
869 }
870 Err(err) => (
871 Error::Request(err.to_string()),
872 err.is_timeout() || err.is_connect(),
873 ),
874 };
875 if !retry {
876 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
877 }
878 match &err {
879 Error::AuthFailure(_) => {
880 if refreshed {
881 retries = 0;
882 } else if retries == 2 {
883 return Err(err.with_context(&format!(
884 "{} {} after 3 reties",
885 request.method(),
886 request.url()
887 )));
888 }
889 }
890 _ => {
891 if retries == 2 {
892 return Err(err.with_context(&format!(
893 "{} {} after 3 reties",
894 request.method(),
895 request.url()
896 )));
897 }
898 retries += 1;
899 info!(
900 "will retry {} the {retries}th times on error {}",
901 request.url(),
902 err
903 );
904 }
905 }
906 sleep(jitter(Duration::from_secs(10))).await;
907 }
908 }
909
910 pub async fn close(&self) {
911 if self.need_logout() {
912 let cli = self.cli.clone();
913 let req = self
914 .build_log_out_request()
915 .expect("failed to build logout request");
916 if let Err(err) = cli.execute(req).await {
917 error!("logout request failed: {}", err);
918 } else {
919 debug!("logout success");
920 };
921 }
922 }
923}
924
925impl Drop for APIClient {
926 fn drop(&mut self) {
927 if self.need_logout() {
928 warn!("APIClient::close() was not called");
929 }
930 }
931}
932
933fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
934where
935 T: Deserialize<'a>,
936{
937 serde_json::from_slice::<T>(body).map_err(|e| {
938 Error::Decode(format!(
939 "fail to decode JSON response: {}, body: {}",
940 e,
941 String::from_utf8_lossy(body)
942 ))
943 })
944}
945
946impl Default for APIClient {
947 fn default() -> Self {
948 Self {
949 cli: HttpClient::new(),
950 scheme: "http".to_string(),
951 endpoint: Url::parse("http://localhost:8080").unwrap(),
952 host: "localhost".to_string(),
953 port: 8000,
954 tenant: None,
955 warehouse: Arc::new(parking_lot::Mutex::new(None)),
956 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
957 session_state: Arc::new(parking_lot::Mutex::new(SessionState::default())),
958 wait_time_secs: None,
959 max_rows_in_buffer: None,
960 max_rows_per_page: None,
961 connect_timeout: Duration::from_secs(10),
962 page_request_timeout: Duration::from_secs(30),
963 tls_ca_file: None,
964 presign: PresignMode::Auto,
965 route_hint: Arc::new(RouteHintGenerator::new()),
966 last_node_id: Arc::new(Default::default()),
967 disable_session_token: true,
968 disable_login: false,
969 session_token_info: None,
970 closed: Arc::new(Default::default()),
971 last_query_id: Arc::new(Default::default()),
972 server_version: None,
973 }
974 }
975}
976
977struct RouteHintGenerator {
978 nonce: AtomicU64,
979 current: std::sync::Mutex<String>,
980}
981
982impl RouteHintGenerator {
983 fn new() -> Self {
984 let gen = Self {
985 nonce: AtomicU64::new(0),
986 current: std::sync::Mutex::new("".to_string()),
987 };
988 gen.next();
989 gen
990 }
991
992 fn current(&self) -> String {
993 let guard = self.current.lock().unwrap();
994 guard.clone()
995 }
996
997 fn set(&self, hint: &str) {
998 let mut guard = self.current.lock().unwrap();
999 *guard = hint.to_string();
1000 }
1001
1002 fn next(&self) -> String {
1003 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1004 let uuid = uuid::Uuid::new_v4();
1005 let current = format!("rh:{}:{:06}", uuid, nonce);
1006 let mut guard = self.current.lock().unwrap();
1007 guard.clone_from(¤t);
1008 current
1009 }
1010}
1011
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `dsn` and checks the resolved host and port.
    async fn assert_host_and_port(dsn: &str, host: &str, port: u16) -> Result<()> {
        let parsed = APIClient::from_dsn(dsn).await?;
        assert_eq!(parsed.host(), host);
        assert_eq!(parsed.port(), port);
        Ok(())
    }

    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let parsed = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;
        assert_eq!(parsed.host, "app.databend.com");
        assert_eq!(parsed.endpoint, Url::parse("http://app.databend.com:80")?);
        assert_eq!(parsed.wait_time_secs, Some(10));
        assert_eq!(parsed.max_rows_in_buffer, Some(5000000));
        assert_eq!(parsed.max_rows_per_page, Some(10000));
        assert_eq!(parsed.tenant, None);
        let warehouse = parsed.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        assert_host_and_port("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost", "localhost", 443)
            .await
    }

    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        assert_host_and_port("databend://username:3a@SC(nYE1k={{R@localhost:8000", "localhost", 8000)
            .await
    }
}
1050}