1use std::collections::BTreeMap;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
21use crate::capability::Capability;
22use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
23use crate::global_cookie_store::GlobalCookieStore;
24use crate::login::{
25 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
26 SessionTokenInfo,
27};
28use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
29use crate::response::LoadResponse;
30use crate::stage::StageLocation;
31use crate::{
32 error::{Error, Result},
33 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
34 response::QueryResponse,
35 session::SessionState,
36 QueryStats,
37};
38use crate::{Page, Pages};
39use log::{debug, error, info, warn};
40use once_cell::sync::Lazy;
41use parking_lot::Mutex;
42use percent_encoding::percent_decode_str;
43use reqwest::cookie::CookieStore;
44use reqwest::header::{HeaderMap, HeaderValue};
45use reqwest::multipart::{Form, Part};
46use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
47use semver::Version;
48use serde::Deserialize;
49use tokio::time::sleep;
50use tokio_retry::strategy::jitter;
51use tokio_stream::StreamExt;
52use tokio_util::io::ReaderStream;
53use url::Url;
54
// Custom HTTP header names exchanged with the server.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";

// Session `txn_state` value (compared ASCII-case-insensitively) that marks an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
63
64static VERSION: Lazy<String> = Lazy::new(|| {
65 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
66 version.to_string()
67});
68
69pub struct APIClient {
70 cli: HttpClient,
71 scheme: String,
72 host: String,
73 port: u16,
74
75 endpoint: Url,
76
77 auth: Arc<dyn Auth>,
78
79 tenant: Option<String>,
80 warehouse: Mutex<Option<String>>,
81 session_state: Mutex<SessionState>,
82 route_hint: RouteHintGenerator,
83
84 disable_login: bool,
85 disable_session_token: bool,
86 session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,
87
88 closed: AtomicBool,
89
90 server_version: Option<Version>,
91
92 wait_time_secs: Option<i64>,
93 max_rows_in_buffer: Option<i64>,
94 max_rows_per_page: Option<i64>,
95
96 connect_timeout: Duration,
97 page_request_timeout: Duration,
98
99 tls_ca_file: Option<String>,
100
101 presign: Mutex<PresignMode>,
102 last_node_id: Mutex<Option<String>>,
103 last_query_id: Mutex<Option<String>>,
104
105 capability: Capability,
106}
107
108impl APIClient {
109 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
110 let mut client = Self::from_dsn(dsn).await?;
111 client.build_client(name).await?;
112 if !client.disable_login {
113 client.login().await?;
114 }
115 let client = Arc::new(client);
116 client.check_presign().await?;
117 Ok(client)
118 }
119
120 pub fn capability(&self) -> &Capability {
121 &self.capability
122 }
123
124 fn set_presign_mode(&self, mode: PresignMode) {
125 *self.presign.lock() = mode
126 }
127 fn get_presign_mode(&self) -> PresignMode {
128 *self.presign.lock()
129 }
130
131 async fn from_dsn(dsn: &str) -> Result<Self> {
132 let u = Url::parse(dsn)?;
133 let mut client = Self::default();
134 if let Some(host) = u.host_str() {
135 client.host = host.to_string();
136 }
137
138 if u.username() != "" {
139 let password = u.password().unwrap_or_default();
140 let password = percent_decode_str(password).decode_utf8()?;
141 client.auth = Arc::new(BasicAuth::new(u.username(), password));
142 }
143 let database = match u.path().trim_start_matches('/') {
144 "" => None,
145 s => Some(s.to_string()),
146 };
147 let mut role = None;
148 let mut scheme = "https";
149 let mut session_settings = BTreeMap::new();
150 for (k, v) in u.query_pairs() {
151 match k.as_ref() {
152 "wait_time_secs" => {
153 client.wait_time_secs = Some(v.parse()?);
154 }
155 "max_rows_in_buffer" => {
156 client.max_rows_in_buffer = Some(v.parse()?);
157 }
158 "max_rows_per_page" => {
159 client.max_rows_per_page = Some(v.parse()?);
160 }
161 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
162 "page_request_timeout_secs" => {
163 client.page_request_timeout = {
164 let secs: u64 = v.parse()?;
165 Duration::from_secs(secs)
166 };
167 }
168 "presign" => {
169 let presign_mode = match v.as_ref() {
170 "auto" => PresignMode::Auto,
171 "detect" => PresignMode::Detect,
172 "on" => PresignMode::On,
173 "off" => PresignMode::Off,
174 _ => {
175 return Err(Error::BadArgument(format!(
176 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
177 )))
178 }
179 };
180 client.set_presign_mode(presign_mode);
181 }
182 "tenant" => {
183 client.tenant = Some(v.to_string());
184 }
185 "warehouse" => {
186 client.warehouse = Mutex::new(Some(v.to_string()));
187 }
188 "role" => role = Some(v.to_string()),
189 "sslmode" => match v.as_ref() {
190 "disable" => scheme = "http",
191 "require" | "enable" => scheme = "https",
192 _ => {
193 return Err(Error::BadArgument(format!(
194 "Invalid value for sslmode: {v}"
195 )))
196 }
197 },
198 "tls_ca_file" => {
199 client.tls_ca_file = Some(v.to_string());
200 }
201 "access_token" => {
202 client.auth = Arc::new(AccessTokenAuth::new(v));
203 }
204 "access_token_file" => {
205 client.auth = Arc::new(AccessTokenFileAuth::new(v));
206 }
207 "login" => {
208 client.disable_login = match v.as_ref() {
209 "disable" => true,
210 "enable" => false,
211 _ => {
212 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
213 }
214 }
215 }
216 "session_token" => {
217 client.disable_session_token = match v.as_ref() {
218 "disable" => true,
219 "enable" => false,
220 _ => {
221 return Err(Error::BadArgument(format!(
222 "Invalid value for session_token: {v}"
223 )))
224 }
225 }
226 }
227 _ => {
228 session_settings.insert(k.to_string(), v.to_string());
229 }
230 }
231 }
232 client.port = match u.port() {
233 Some(p) => p,
234 None => match scheme {
235 "http" => 80,
236 "https" => 443,
237 _ => unreachable!(),
238 },
239 };
240 client.scheme = scheme.to_string();
241
242 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
243 client.session_state = Mutex::new(
244 SessionState::default()
245 .with_settings(Some(session_settings))
246 .with_role(role)
247 .with_database(database),
248 );
249
250 Ok(client)
251 }
252
253 pub fn host(&self) -> &str {
254 self.host.as_str()
255 }
256
257 pub fn port(&self) -> u16 {
258 self.port
259 }
260
261 pub fn scheme(&self) -> &str {
262 self.scheme.as_str()
263 }
264
265 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
266 let ua = match name {
267 Some(n) => n,
268 None => format!("databend-client-rust/{}", VERSION.as_str()),
269 };
270 let cookie_provider = GlobalCookieStore::new();
271 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
272 let mut initial_cookies = [&cookie].into_iter();
273 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
274 let mut cli_builder = HttpClient::builder()
275 .user_agent(ua)
276 .cookie_provider(Arc::new(cookie_provider))
277 .pool_idle_timeout(Duration::from_secs(1));
278 #[cfg(any(feature = "rustls", feature = "native-tls"))]
279 if self.scheme == "https" {
280 if let Some(ref ca_file) = self.tls_ca_file {
281 let cert_pem = tokio::fs::read(ca_file).await?;
282 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
283 cli_builder = cli_builder.add_root_certificate(cert);
284 }
285 }
286 self.cli = cli_builder.build()?;
287 Ok(())
288 }
289
290 async fn check_presign(self: &Arc<Self>) -> Result<()> {
291 let mode = match self.get_presign_mode() {
292 PresignMode::Auto => {
293 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
294 PresignMode::On
295 } else {
296 PresignMode::Off
297 }
298 }
299 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
300 Ok(_) => PresignMode::On,
301 Err(e) => {
302 warn!("presign mode off with error detected: {e}");
303 PresignMode::Off
304 }
305 },
306 mode => mode,
307 };
308 self.set_presign_mode(mode);
309 Ok(())
310 }
311
312 pub fn current_warehouse(&self) -> Option<String> {
313 let guard = self.warehouse.lock();
314 guard.clone()
315 }
316
317 pub fn current_catalog(&self) -> Option<String> {
318 let guard = self.session_state.lock();
319 guard.catalog.clone()
320 }
321
322 pub fn current_database(&self) -> Option<String> {
323 let guard = self.session_state.lock();
324 guard.database.clone()
325 }
326
327 pub async fn current_role(&self) -> Option<String> {
328 let guard = self.session_state.lock();
329 guard.role.clone()
330 }
331
332 fn in_active_transaction(&self) -> bool {
333 let guard = self.session_state.lock();
334 guard
335 .txn_state
336 .as_ref()
337 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
338 .unwrap_or(false)
339 }
340
341 pub fn username(&self) -> String {
342 self.auth.username()
343 }
344
345 fn gen_query_id(&self) -> String {
346 uuid::Uuid::now_v7().simple().to_string()
347 }
348
349 async fn handle_session(&self, session: &Option<SessionState>) {
350 let session = match session {
351 Some(session) => session,
352 None => return,
353 };
354
355 {
357 let mut session_state = self.session_state.lock();
358 *session_state = session.clone();
359 }
360
361 if let Some(settings) = session.settings.as_ref() {
363 if let Some(v) = settings.get("warehouse") {
364 let mut warehouse = self.warehouse.lock();
365 *warehouse = Some(v.clone());
366 }
367 }
368 }
369
370 pub fn set_last_node_id(&self, node_id: String) {
371 *self.last_node_id.lock() = Some(node_id)
372 }
373
374 pub fn set_last_query_id(&self, query_id: Option<String>) {
375 *self.last_query_id.lock() = query_id
376 }
377
378 pub fn last_query_id(&self) -> Option<String> {
379 self.last_query_id.lock().clone()
380 }
381
382 fn last_node_id(&self) -> Option<String> {
383 self.last_node_id.lock().clone()
384 }
385
386 fn handle_warnings(&self, resp: &QueryResponse) {
387 if let Some(warnings) = &resp.warnings {
388 for w in warnings {
389 warn!(target: "server_warnings", "server warning: {w}");
390 }
391 }
392 }
393
394 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
395 info!("start query: {sql}");
396 let resp = self.start_query_inner(sql, None).await?;
397 let pages = Pages::new(self.clone(), resp, need_progress);
398 Ok(pages)
399 }
400
401 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
402 if let Some(info) = &self.session_token_info {
403 let info = info.lock();
404 Ok(builder.bearer_auth(info.0.session_token.clone()))
405 } else {
406 self.auth.wrap(builder)
407 }
408 }
409
410 async fn start_query_inner(
411 &self,
412 sql: &str,
413 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
414 ) -> Result<QueryResponse> {
415 if !self.in_active_transaction() {
416 self.route_hint.next();
417 }
418 let endpoint = self.endpoint.join("v1/query")?;
419
420 let session_state = self.session_state();
422 let need_sticky = session_state.need_sticky.unwrap_or(false);
423 let req = QueryRequest::new(sql)
424 .with_pagination(self.make_pagination())
425 .with_session(Some(session_state))
426 .with_stage_attachment(stage_attachment_config);
427
428 let query_id = self.gen_query_id();
430 let mut headers = self.make_headers(Some(&query_id))?;
431 if need_sticky {
432 if let Some(node_id) = self.last_node_id() {
433 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
434 }
435 }
436 let mut builder = self.cli.post(endpoint.clone()).json(&req);
437 builder = self.wrap_auth_or_session_token(builder)?;
438 let request = builder.headers(headers.clone()).build()?;
439 let response = self.query_request_helper(request, true, true).await?;
440 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
441 self.route_hint.set(route_hint.to_str().unwrap_or_default());
442 }
443 let body = response.bytes().await?;
444 let result: QueryResponse = json_from_slice(&body)?;
445 self.handle_session(&result.session).await;
446 if let Some(err) = result.error {
447 return Err(Error::QueryFailed(err));
448 }
449
450 self.set_last_query_id(Some(query_id));
451 self.handle_warnings(&result);
452 if let Some(node_id) = &result.node_id {
453 self.set_last_node_id(node_id.clone());
454 }
455 Ok(result)
456 }
457
458 pub async fn query_page(
459 &self,
460 query_id: &str,
461 next_uri: &str,
462 node_id: &Option<String>,
463 ) -> Result<QueryResponse> {
464 info!("query page: {next_uri}");
465 let endpoint = self.endpoint.join(next_uri)?;
466 let headers = self.make_headers(Some(query_id))?;
467 let mut builder = self.cli.get(endpoint.clone());
468 builder = self
469 .wrap_auth_or_session_token(builder)?
470 .headers(headers.clone())
471 .timeout(self.page_request_timeout);
472 if let Some(node_id) = node_id {
473 builder = builder.header(HEADER_STICKY_NODE, node_id)
474 }
475 let request = builder.build()?;
476
477 let response = self.query_request_helper(request, false, true).await?;
478 let status = response.status();
479 if status != 200 {
480 return Err(Error::response_error(status, &response.bytes().await?));
481 }
482 let body = response.bytes().await?;
483 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
484 if let Error::Logic(status, ec) = &e {
485 if *status == 404 {
486 return Error::QueryNotFound(ec.message.clone());
487 }
488 }
489 e
490 })?;
491 self.handle_session(&resp.session).await;
492 match resp.error {
493 Some(err) => Err(Error::QueryFailed(err)),
494 None => Ok(resp),
495 }
496 }
497
498 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
499 let kill_uri = format!("/v1/query/{query_id}/kill");
500 let endpoint = self.endpoint.join(&kill_uri)?;
501 let headers = self.make_headers(Some(query_id))?;
502 info!("kill query: {kill_uri}");
503
504 let mut builder = self.cli.post(endpoint);
505 builder = self.wrap_auth_or_session_token(builder)?;
506 let resp = builder.headers(headers.clone()).send().await?;
507 if resp.status() != 200 {
508 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
509 .with_context("kill query"));
510 }
511 Ok(())
512 }
513
514 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
515 let mut pages = self.start_query(sql, false).await?;
516 let mut all = Page::default();
517 while let Some(page) = pages.next().await {
518 all.update(page?);
519 }
520 Ok(all)
521 }
522
523 fn session_state(&self) -> SessionState {
524 self.session_state.lock().clone()
525 }
526
527 fn make_pagination(&self) -> Option<PaginationConfig> {
528 if self.wait_time_secs.is_none()
529 && self.max_rows_in_buffer.is_none()
530 && self.max_rows_per_page.is_none()
531 {
532 return None;
533 }
534 let mut pagination = PaginationConfig {
535 wait_time_secs: None,
536 max_rows_in_buffer: None,
537 max_rows_per_page: None,
538 };
539 if let Some(wait_time_secs) = self.wait_time_secs {
540 pagination.wait_time_secs = Some(wait_time_secs);
541 }
542 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
543 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
544 }
545 if let Some(max_rows_per_page) = self.max_rows_per_page {
546 pagination.max_rows_per_page = Some(max_rows_per_page);
547 }
548 Some(pagination)
549 }
550
551 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
552 let mut headers = HeaderMap::new();
553 if let Some(tenant) = &self.tenant {
554 headers.insert(HEADER_TENANT, tenant.parse()?);
555 }
556 let warehouse = self.warehouse.lock().clone();
557 if let Some(warehouse) = warehouse {
558 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
559 }
560 let route_hint = self.route_hint.current();
561 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
562 if let Some(query_id) = query_id {
563 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
564 }
565 Ok(headers)
566 }
567
568 pub async fn insert_with_stage(
569 self: &Arc<Self>,
570 sql: &str,
571 stage: &str,
572 file_format_options: BTreeMap<&str, &str>,
573 copy_options: BTreeMap<&str, &str>,
574 ) -> Result<QueryStats> {
575 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
576 let stage_attachment = Some(StageAttachmentConfig {
577 location: stage,
578 file_format_options: Some(file_format_options),
579 copy_options: Some(copy_options),
580 });
581 let resp = self.start_query_inner(sql, stage_attachment).await?;
582 let mut pages = Pages::new(self.clone(), resp, false);
583 let mut all = Page::default();
584 while let Some(page) = pages.next().await {
585 all.update(page?);
586 }
587 Ok(all.stats)
588 }
589
590 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
591 info!("get presigned upload url: {stage}");
592 let sql = format!("PRESIGN UPLOAD {stage}");
593 let resp = self.query_all(&sql).await?;
594 if resp.data.len() != 1 {
595 return Err(Error::Decode(
596 "Empty response from server for presigned request".to_string(),
597 ));
598 }
599 if resp.data[0].len() != 3 {
600 return Err(Error::Decode(
601 "Invalid response from server for presigned request".to_string(),
602 ));
603 }
604 let method = resp.data[0][0].clone().unwrap_or_default();
606 if method != "PUT" {
607 return Err(Error::Decode(format!(
608 "Invalid method for presigned upload request: {method}"
609 )));
610 }
611 let headers: BTreeMap<String, String> =
612 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
613 let url = resp.data[0][2].clone().unwrap_or_default();
614 Ok(PresignedResponse {
615 method,
616 headers,
617 url,
618 })
619 }
620
621 pub async fn upload_to_stage(
622 self: &Arc<Self>,
623 stage: &str,
624 data: Reader,
625 size: u64,
626 ) -> Result<()> {
627 match self.get_presign_mode() {
628 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
629 PresignMode::On => {
630 let presigned = self.get_presigned_upload_url(stage).await?;
631 presign_upload_to_stage(presigned, data, size).await
632 }
633 PresignMode::Auto => {
634 unreachable!("PresignMode::Auto should be handled during client initialization")
635 }
636 PresignMode::Detect => {
637 unreachable!("PresignMode::Detect should be handled during client initialization")
638 }
639 }
640 }
641
642 async fn upload_to_stage_with_stream(
644 &self,
645 stage: &str,
646 data: Reader,
647 size: u64,
648 ) -> Result<()> {
649 info!("upload to stage with stream: {stage}, size: {size}");
650 if let Some(info) = self.need_pre_refresh_session().await {
651 self.refresh_session_token(info).await?;
652 }
653 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
654 let location = StageLocation::try_from(stage)?;
655 let query_id = self.gen_query_id();
656 let mut headers = self.make_headers(Some(&query_id))?;
657 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
658 let stream = Body::wrap_stream(ReaderStream::new(data));
659 let part = Part::stream_with_length(stream, size).file_name(location.path);
660 let form = Form::new().part("upload", part);
661 let mut builder = self.cli.put(endpoint.clone());
662 builder = self.wrap_auth_or_session_token(builder)?;
663 let resp = builder.headers(headers).multipart(form).send().await?;
664 let status = resp.status();
665 if status != 200 {
666 return Err(
667 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
668 );
669 }
670 Ok(())
671 }
672
673 pub async fn streaming_load(
674 &self,
675 sql: &str,
676 data: Reader,
677 file_name: &str,
678 ) -> Result<LoadResponse> {
679 let body = Body::wrap_stream(ReaderStream::new(data));
680 let part = Part::stream(body).file_name(file_name.to_string());
681 let endpoint = self.endpoint.join("v1/streaming_load")?;
682 let mut builder = self.cli.put(endpoint.clone());
683 builder = self.wrap_auth_or_session_token(builder)?;
684 let query_id = self.gen_query_id();
685 let mut headers = self.make_headers(Some(&query_id))?;
686 headers.insert(HEADER_SQL, sql.parse()?);
687 let form = Form::new().part("upload", part);
688 let resp = builder.headers(headers).multipart(form).send().await?;
689 let status = resp.status();
690 if status != 200 {
691 return Err(
692 Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
693 );
694 }
695 let resp = resp.json::<LoadResponse>().await?;
696 Ok(resp)
697 }
698
699 async fn login(&mut self) -> Result<()> {
700 let endpoint = self.endpoint.join("/v1/session/login")?;
701 let headers = self.make_headers(None)?;
702 let body = LoginRequest::from(&*self.session_state.lock());
703 let mut builder = self.cli.post(endpoint.clone()).json(&body);
704 if self.disable_session_token {
705 builder = builder.query(&[("disable_session_token", true)]);
706 }
707 let builder = self.auth.wrap(builder)?;
708 let request = builder
709 .headers(headers.clone())
710 .timeout(self.connect_timeout)
711 .build()?;
712 let response = self.query_request_helper(request, true, false).await;
713 let response = match response {
714 Ok(r) => r,
715 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
716 info!("login return 404, skip login on the old version server");
717 return Ok(());
718 }
719 Err(e) => return Err(e),
720 };
721 let body = response.bytes().await?;
722 let response = json_from_slice(&body)?;
723 match response {
724 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
725 LoginResponseResult::Ok(info) => {
726 let server_version = info
727 .version
728 .parse()
729 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
730 self.capability = Capability::from_server_version(&server_version);
731 self.server_version = Some(server_version.clone());
732 if let Some(tokens) = info.tokens {
733 info!("login success with session token version = {server_version}",);
734 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
735 } else {
736 info!("login success, version = {server_version}");
737 }
738 }
739 }
740 Ok(())
741 }
742
743 fn build_log_out_request(&self) -> Result<Request> {
744 let endpoint = self.endpoint.join("/v1/session/logout")?;
745
746 let session_state = self.session_state();
747 let need_sticky = session_state.need_sticky.unwrap_or(false);
748 let mut headers = self.make_headers(None)?;
749 if need_sticky {
750 if let Some(node_id) = self.last_node_id() {
751 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
752 }
753 }
754 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
755
756 let builder = self.wrap_auth_or_session_token(builder)?;
757 let req = builder.build()?;
758 Ok(req)
759 }
760
761 fn need_logout(&self) -> bool {
762 (self.session_token_info.is_some()
763 || self.session_state.lock().need_keep_alive.unwrap_or(false))
764 && self
765 .closed
766 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
767 .is_ok()
768 }
769
770 async fn refresh_session_token(
771 &self,
772 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
773 ) -> Result<()> {
774 let (session_token_info, _) = { self_login_info.lock().clone() };
775 let endpoint = self.endpoint.join("/v1/session/refresh")?;
776 let body = RefreshSessionTokenRequest {
777 session_token: session_token_info.session_token.clone(),
778 };
779 let headers = self.make_headers(None)?;
780 let request = self
781 .cli
782 .post(endpoint.clone())
783 .json(&body)
784 .headers(headers.clone())
785 .bearer_auth(session_token_info.refresh_token.clone())
786 .timeout(self.connect_timeout)
787 .build()?;
788
789 for i in 0..3 {
791 let req = request.try_clone().expect("request not cloneable");
792 match self.cli.execute(req).await {
793 Ok(response) => {
794 let status = response.status();
795 let body = response.bytes().await?;
796 if status == StatusCode::OK {
797 let response = json_from_slice(&body)?;
798 return match response {
799 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
800 RefreshResponse::Ok(info) => {
801 *self_login_info.lock() = (info, Instant::now());
802 Ok(())
803 }
804 };
805 }
806 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
807 return Err(Error::response_error(status, &body));
808 }
809 }
810 Err(err) => {
811 if !(err.is_timeout() || err.is_connect()) || i > 2 {
812 return Err(Error::Request(err.to_string()));
813 }
814 }
815 };
816 sleep(jitter(Duration::from_secs(10))).await;
817 }
818 Ok(())
819 }
820
821 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
822 if let Some(info) = &self.session_token_info {
823 let (start, ttl) = {
824 let guard = info.lock();
825 (guard.1, guard.0.session_token_ttl_in_secs)
826 };
827 if Instant::now() > start + Duration::from_secs(ttl) {
828 return Some(info.clone());
829 }
830 }
831 None
832 }
833
834 async fn query_request_helper(
842 &self,
843 mut request: Request,
844 retry_if_503: bool,
845 refresh_if_401: bool,
846 ) -> std::result::Result<Response, Error> {
847 let mut refreshed = false;
848 let mut retries = 0;
849 loop {
850 let req = request.try_clone().expect("request not cloneable");
851 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
852 Ok(response) => {
853 let status = response.status();
854 if status == StatusCode::OK {
855 return Ok(response);
856 }
857 let body = response.bytes().await?;
858 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
859 (Error::response_error(status, &body), true)
861 } else {
862 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
863 match resp {
864 Ok(r) => {
865 let e = r.error;
866 if status == StatusCode::UNAUTHORIZED {
867 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
868 if let Some(session_token_info) = &self.session_token_info {
869 info!(
870 "will retry {} after refresh token on auth error {}",
871 request.url(),
872 e
873 );
874 let retry = if need_refresh_token(e.code)
875 && !refreshed
876 && refresh_if_401
877 {
878 self.refresh_session_token(session_token_info.clone())
879 .await?;
880 refreshed = true;
881 true
882 } else {
883 false
884 };
885 (Error::AuthFailure(e), retry)
886 } else if self.auth.can_reload() {
887 info!(
888 "will retry {} after reload token on auth error {}",
889 request.url(),
890 e
891 );
892 let builder = RequestBuilder::from_parts(
893 HttpClient::new(),
894 request.try_clone().unwrap(),
895 );
896 let builder = self.auth.wrap(builder)?;
897 request = builder.build()?;
898 (Error::AuthFailure(e), true)
899 } else {
900 (Error::AuthFailure(e), false)
901 }
902 } else {
903 (Error::Logic(status, e), false)
904 }
905 }
906 Err(_) => (
907 Error::Response {
908 status,
909 msg: String::from_utf8_lossy(&body).to_string(),
910 },
911 false,
912 ),
913 }
914 }
915 }
916 Err(err) => (
917 Error::Request(err.to_string()),
918 err.is_timeout() || err.is_connect(),
919 ),
920 };
921 if !retry {
922 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
923 }
924 match &err {
925 Error::AuthFailure(_) => {
926 if refreshed {
927 retries = 0;
928 } else if retries == 2 {
929 return Err(err.with_context(&format!(
930 "{} {} after 3 reties",
931 request.method(),
932 request.url()
933 )));
934 }
935 }
936 _ => {
937 if retries == 2 {
938 return Err(err.with_context(&format!(
939 "{} {} after 3 reties",
940 request.method(),
941 request.url()
942 )));
943 }
944 retries += 1;
945 info!(
946 "will retry {} the {retries}th times on error {}",
947 request.url(),
948 err
949 );
950 }
951 }
952 sleep(jitter(Duration::from_secs(10))).await;
953 }
954 }
955
956 pub async fn close(&self) {
957 if self.need_logout() {
958 let cli = self.cli.clone();
959 let req = self
960 .build_log_out_request()
961 .expect("failed to build logout request");
962 if let Err(err) = cli.execute(req).await {
963 error!("logout request failed: {err}");
964 } else {
965 debug!("logout success");
966 };
967 }
968 }
969}
970
971impl Drop for APIClient {
972 fn drop(&mut self) {
973 if self.need_logout() {
974 warn!("APIClient::close() was not called");
975 }
976 }
977}
978
979fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
980where
981 T: Deserialize<'a>,
982{
983 serde_json::from_slice::<T>(body).map_err(|e| {
984 Error::Decode(format!(
985 "fail to decode JSON response: {e}, body: {}",
986 String::from_utf8_lossy(body)
987 ))
988 })
989}
990
991impl Default for APIClient {
992 fn default() -> Self {
993 Self {
994 cli: HttpClient::new(),
995 scheme: "http".to_string(),
996 endpoint: Url::parse("http://localhost:8080").unwrap(),
997 host: "localhost".to_string(),
998 port: 8000,
999 tenant: None,
1000 warehouse: Mutex::new(None),
1001 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1002 session_state: Mutex::new(SessionState::default()),
1003 wait_time_secs: None,
1004 max_rows_in_buffer: None,
1005 max_rows_per_page: None,
1006 connect_timeout: Duration::from_secs(10),
1007 page_request_timeout: Duration::from_secs(30),
1008 tls_ca_file: None,
1009 presign: Mutex::new(PresignMode::Auto),
1010 route_hint: RouteHintGenerator::new(),
1011 last_node_id: Default::default(),
1012 disable_session_token: true,
1013 disable_login: false,
1014 session_token_info: None,
1015 closed: Default::default(),
1016 last_query_id: Default::default(),
1017 server_version: None,
1018 capability: Default::default(),
1019 }
1020 }
1021}
1022
1023struct RouteHintGenerator {
1024 nonce: AtomicU64,
1025 current: std::sync::Mutex<String>,
1026}
1027
1028impl RouteHintGenerator {
1029 fn new() -> Self {
1030 let gen = Self {
1031 nonce: AtomicU64::new(0),
1032 current: std::sync::Mutex::new("".to_string()),
1033 };
1034 gen.next();
1035 gen
1036 }
1037
1038 fn current(&self) -> String {
1039 let guard = self.current.lock().unwrap();
1040 guard.clone()
1041 }
1042
1043 fn set(&self, hint: &str) {
1044 let mut guard = self.current.lock().unwrap();
1045 *guard = hint.to_string();
1046 }
1047
1048 fn next(&self) -> String {
1049 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1050 let uuid = uuid::Uuid::new_v4();
1051 let current = format!("rh:{uuid}:{nonce:06}");
1052 let mut guard = self.current.lock().unwrap();
1053 guard.clone_from(¤t);
1054 current
1055 }
1056}
1057
#[cfg(test)]
mod test {
    use super::*;

    /// Query parameters in the DSN populate the corresponding client fields, and
    /// `sslmode=disable` selects plain HTTP on port 80.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let client = APIClient::from_dsn(
            "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable",
        )
        .await?;

        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(client.host, "app.databend.com");
        assert_eq!(client.endpoint, expected_endpoint);
        assert_eq!(client.wait_time_secs, Some(10));
        assert_eq!(client.max_rows_in_buffer, Some(5000000));
        assert_eq!(client.max_rows_per_page, Some(10000));
        assert_eq!(client.tenant, None);

        let warehouse = client.warehouse.try_lock().unwrap().clone();
        assert_eq!(warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password parses; the default https port is used.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 443);
        Ok(())
    }

    /// Raw special characters in the password still yield the right host and port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let client =
            APIClient::from_dsn("databend://username:3a@SC(nYE1k={{R@localhost:8000").await?;
        assert_eq!(client.host(), "localhost");
        assert_eq!(client.port(), 8000);
        Ok(())
    }
}