1use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
16use crate::capability::Capability;
17use crate::error_code::{need_refresh_token, ResponseWithErrorCode};
18use crate::global_cookie_store::GlobalCookieStore;
19use crate::login::{
20 LoginRequest, LoginResponseResult, RefreshResponse, RefreshSessionTokenRequest,
21 SessionTokenInfo,
22};
23use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
24use crate::response::LoadResponse;
25use crate::stage::StageLocation;
26use crate::{
27 error::{Error, Result},
28 request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
29 response::QueryResponse,
30 session::SessionState,
31 QueryStats,
32};
33use crate::{Page, Pages};
34use base64::engine::general_purpose::URL_SAFE;
35use base64::Engine;
36use log::{debug, error, info, warn};
37use once_cell::sync::Lazy;
38use parking_lot::Mutex;
39use percent_encoding::percent_decode_str;
40use reqwest::cookie::CookieStore;
41use reqwest::header::{HeaderMap, HeaderValue};
42use reqwest::multipart::{Form, Part};
43use reqwest::{Body, Client as HttpClient, Request, RequestBuilder, Response, StatusCode};
44use semver::Version;
45use serde::{de, Deserialize};
46use std::collections::BTreeMap;
47use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::time::sleep;
51use tokio_retry::strategy::jitter;
52use tokio_stream::StreamExt;
53use tokio_util::io::ReaderStream;
54use url::Url;
55
// Names of the HTTP headers this client sets on requests or reads from responses.
const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
const HEADER_SQL: &str = "X-DATABEND-SQL";
const HEADER_QUERY_CONTEXT: &str = "X-DATABEND-QUERY-CONTEXT";

// Value of `SessionState::txn_state` (compared case-insensitively) that marks
// an open transaction.
const TXN_STATE_ACTIVE: &str = "Active";
65
66static VERSION: Lazy<String> = Lazy::new(|| {
67 let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
68 version.to_string()
69});
70
71pub struct APIClient {
72 cli: HttpClient,
73 scheme: String,
74 host: String,
75 port: u16,
76
77 endpoint: Url,
78
79 auth: Arc<dyn Auth>,
80
81 tenant: Option<String>,
82 warehouse: Mutex<Option<String>>,
83 session_state: Mutex<SessionState>,
84 route_hint: RouteHintGenerator,
85
86 disable_login: bool,
87 disable_session_token: bool,
88 session_token_info: Option<Arc<Mutex<(SessionTokenInfo, Instant)>>>,
89
90 closed: AtomicBool,
91
92 server_version: Option<Version>,
93
94 wait_time_secs: Option<i64>,
95 max_rows_in_buffer: Option<i64>,
96 max_rows_per_page: Option<i64>,
97
98 connect_timeout: Duration,
99 page_request_timeout: Duration,
100
101 tls_ca_file: Option<String>,
102
103 presign: Mutex<PresignMode>,
104 last_node_id: Mutex<Option<String>>,
105 last_query_id: Mutex<Option<String>>,
106
107 capability: Capability,
108}
109
110impl APIClient {
111 pub async fn new(dsn: &str, name: Option<String>) -> Result<Arc<Self>> {
112 let mut client = Self::from_dsn(dsn).await?;
113 client.build_client(name).await?;
114 if !client.disable_login {
115 client.login().await?;
116 }
117 let client = Arc::new(client);
118 client.check_presign().await?;
119 Ok(client)
120 }
121
122 pub fn capability(&self) -> &Capability {
123 &self.capability
124 }
125
126 fn set_presign_mode(&self, mode: PresignMode) {
127 *self.presign.lock() = mode
128 }
129 fn get_presign_mode(&self) -> PresignMode {
130 *self.presign.lock()
131 }
132
133 async fn from_dsn(dsn: &str) -> Result<Self> {
134 let u = Url::parse(dsn)?;
135 let mut client = Self::default();
136 if let Some(host) = u.host_str() {
137 client.host = host.to_string();
138 }
139
140 if u.username() != "" {
141 let password = u.password().unwrap_or_default();
142 let password = percent_decode_str(password).decode_utf8()?;
143 client.auth = Arc::new(BasicAuth::new(u.username(), password));
144 }
145 let database = match u.path().trim_start_matches('/') {
146 "" => None,
147 s => Some(s.to_string()),
148 };
149 let mut role = None;
150 let mut scheme = "https";
151 let mut session_settings = BTreeMap::new();
152 for (k, v) in u.query_pairs() {
153 match k.as_ref() {
154 "wait_time_secs" => {
155 client.wait_time_secs = Some(v.parse()?);
156 }
157 "max_rows_in_buffer" => {
158 client.max_rows_in_buffer = Some(v.parse()?);
159 }
160 "max_rows_per_page" => {
161 client.max_rows_per_page = Some(v.parse()?);
162 }
163 "connect_timeout" => client.connect_timeout = Duration::from_secs(v.parse()?),
164 "page_request_timeout_secs" => {
165 client.page_request_timeout = {
166 let secs: u64 = v.parse()?;
167 Duration::from_secs(secs)
168 };
169 }
170 "presign" => {
171 let presign_mode = match v.as_ref() {
172 "auto" => PresignMode::Auto,
173 "detect" => PresignMode::Detect,
174 "on" => PresignMode::On,
175 "off" => PresignMode::Off,
176 _ => {
177 return Err(Error::BadArgument(format!(
178 "Invalid value for presign: {v}, should be one of auto/detect/on/off"
179 )))
180 }
181 };
182 client.set_presign_mode(presign_mode);
183 }
184 "tenant" => {
185 client.tenant = Some(v.to_string());
186 }
187 "warehouse" => {
188 client.warehouse = Mutex::new(Some(v.to_string()));
189 }
190 "role" => role = Some(v.to_string()),
191 "sslmode" => match v.as_ref() {
192 "disable" => scheme = "http",
193 "require" | "enable" => scheme = "https",
194 _ => {
195 return Err(Error::BadArgument(format!(
196 "Invalid value for sslmode: {v}"
197 )))
198 }
199 },
200 "tls_ca_file" => {
201 client.tls_ca_file = Some(v.to_string());
202 }
203 "access_token" => {
204 client.auth = Arc::new(AccessTokenAuth::new(v));
205 }
206 "access_token_file" => {
207 client.auth = Arc::new(AccessTokenFileAuth::new(v));
208 }
209 "login" => {
210 client.disable_login = match v.as_ref() {
211 "disable" => true,
212 "enable" => false,
213 _ => {
214 return Err(Error::BadArgument(format!("Invalid value for login: {v}")))
215 }
216 }
217 }
218 "session_token" => {
219 client.disable_session_token = match v.as_ref() {
220 "disable" => true,
221 "enable" => false,
222 _ => {
223 return Err(Error::BadArgument(format!(
224 "Invalid value for session_token: {v}"
225 )))
226 }
227 }
228 }
229 _ => {
230 session_settings.insert(k.to_string(), v.to_string());
231 }
232 }
233 }
234 client.port = match u.port() {
235 Some(p) => p,
236 None => match scheme {
237 "http" => 80,
238 "https" => 443,
239 _ => unreachable!(),
240 },
241 };
242 client.scheme = scheme.to_string();
243
244 client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;
245 client.session_state = Mutex::new(
246 SessionState::default()
247 .with_settings(Some(session_settings))
248 .with_role(role)
249 .with_database(database),
250 );
251
252 Ok(client)
253 }
254
255 pub fn host(&self) -> &str {
256 self.host.as_str()
257 }
258
259 pub fn port(&self) -> u16 {
260 self.port
261 }
262
263 pub fn scheme(&self) -> &str {
264 self.scheme.as_str()
265 }
266
267 async fn build_client(&mut self, name: Option<String>) -> Result<()> {
268 let ua = match name {
269 Some(n) => n,
270 None => format!("databend-client-rust/{}", VERSION.as_str()),
271 };
272 let cookie_provider = GlobalCookieStore::new();
273 let cookie = HeaderValue::from_str("cookie_enabled=true").unwrap();
274 let mut initial_cookies = [&cookie].into_iter();
275 cookie_provider.set_cookies(&mut initial_cookies, &Url::parse("https://a.com").unwrap());
276 let mut cli_builder = HttpClient::builder()
277 .user_agent(ua)
278 .cookie_provider(Arc::new(cookie_provider))
279 .pool_idle_timeout(Duration::from_secs(1));
280 #[cfg(any(feature = "rustls", feature = "native-tls"))]
281 if self.scheme == "https" {
282 if let Some(ref ca_file) = self.tls_ca_file {
283 let cert_pem = tokio::fs::read(ca_file).await?;
284 let cert = reqwest::Certificate::from_pem(&cert_pem)?;
285 cli_builder = cli_builder.add_root_certificate(cert);
286 }
287 }
288 self.cli = cli_builder.build()?;
289 Ok(())
290 }
291
292 async fn check_presign(self: &Arc<Self>) -> Result<()> {
293 let mode = match self.get_presign_mode() {
294 PresignMode::Auto => {
295 if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
296 PresignMode::On
297 } else {
298 PresignMode::Off
299 }
300 }
301 PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
302 Ok(_) => PresignMode::On,
303 Err(e) => {
304 warn!("presign mode off with error detected: {e}");
305 PresignMode::Off
306 }
307 },
308 mode => mode,
309 };
310 self.set_presign_mode(mode);
311 Ok(())
312 }
313
314 pub fn current_warehouse(&self) -> Option<String> {
315 let guard = self.warehouse.lock();
316 guard.clone()
317 }
318
319 pub fn current_catalog(&self) -> Option<String> {
320 let guard = self.session_state.lock();
321 guard.catalog.clone()
322 }
323
324 pub fn current_database(&self) -> Option<String> {
325 let guard = self.session_state.lock();
326 guard.database.clone()
327 }
328
329 pub async fn current_role(&self) -> Option<String> {
330 let guard = self.session_state.lock();
331 guard.role.clone()
332 }
333
334 fn in_active_transaction(&self) -> bool {
335 let guard = self.session_state.lock();
336 guard
337 .txn_state
338 .as_ref()
339 .map(|s| s.eq_ignore_ascii_case(TXN_STATE_ACTIVE))
340 .unwrap_or(false)
341 }
342
343 pub fn username(&self) -> String {
344 self.auth.username()
345 }
346
347 fn gen_query_id(&self) -> String {
348 uuid::Uuid::now_v7().simple().to_string()
349 }
350
351 async fn handle_session(&self, session: &Option<SessionState>) {
352 let session = match session {
353 Some(session) => session,
354 None => return,
355 };
356
357 {
359 let mut session_state = self.session_state.lock();
360 *session_state = session.clone();
361 }
362
363 if let Some(settings) = session.settings.as_ref() {
365 if let Some(v) = settings.get("warehouse") {
366 let mut warehouse = self.warehouse.lock();
367 *warehouse = Some(v.clone());
368 }
369 }
370 }
371
372 pub fn set_last_node_id(&self, node_id: String) {
373 *self.last_node_id.lock() = Some(node_id)
374 }
375
376 pub fn set_last_query_id(&self, query_id: Option<String>) {
377 *self.last_query_id.lock() = query_id
378 }
379
380 pub fn last_query_id(&self) -> Option<String> {
381 self.last_query_id.lock().clone()
382 }
383
384 fn last_node_id(&self) -> Option<String> {
385 self.last_node_id.lock().clone()
386 }
387
388 fn handle_warnings(&self, resp: &QueryResponse) {
389 if let Some(warnings) = &resp.warnings {
390 for w in warnings {
391 warn!(target: "server_warnings", "server warning: {w}");
392 }
393 }
394 }
395
396 pub async fn start_query(self: &Arc<Self>, sql: &str, need_progress: bool) -> Result<Pages> {
397 info!("start query: {sql}");
398 let resp = self.start_query_inner(sql, None).await?;
399 let pages = Pages::new(self.clone(), resp, need_progress);
400 Ok(pages)
401 }
402
403 fn wrap_auth_or_session_token(&self, builder: RequestBuilder) -> Result<RequestBuilder> {
404 if let Some(info) = &self.session_token_info {
405 let info = info.lock();
406 Ok(builder.bearer_auth(info.0.session_token.clone()))
407 } else {
408 self.auth.wrap(builder)
409 }
410 }
411
412 async fn start_query_inner(
413 &self,
414 sql: &str,
415 stage_attachment_config: Option<StageAttachmentConfig<'_>>,
416 ) -> Result<QueryResponse> {
417 if !self.in_active_transaction() {
418 self.route_hint.next();
419 }
420 let endpoint = self.endpoint.join("v1/query")?;
421
422 let session_state = self.session_state();
424 let need_sticky = session_state.need_sticky.unwrap_or(false);
425 let req = QueryRequest::new(sql)
426 .with_pagination(self.make_pagination())
427 .with_session(Some(session_state))
428 .with_stage_attachment(stage_attachment_config);
429
430 let query_id = self.gen_query_id();
432 let mut headers = self.make_headers(Some(&query_id))?;
433 if need_sticky {
434 if let Some(node_id) = self.last_node_id() {
435 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
436 }
437 }
438 let mut builder = self.cli.post(endpoint.clone()).json(&req);
439 builder = self.wrap_auth_or_session_token(builder)?;
440 let request = builder.headers(headers.clone()).build()?;
441 let response = self.query_request_helper(request, true, true).await?;
442 if let Some(route_hint) = response.headers().get(HEADER_ROUTE_HINT) {
443 self.route_hint.set(route_hint.to_str().unwrap_or_default());
444 }
445 let body = response.bytes().await?;
446 let result: QueryResponse = json_from_slice(&body)?;
447 self.handle_session(&result.session).await;
448 if let Some(err) = result.error {
449 return Err(Error::QueryFailed(err));
450 }
451
452 self.set_last_query_id(Some(query_id));
453 self.handle_warnings(&result);
454 if let Some(node_id) = &result.node_id {
455 self.set_last_node_id(node_id.clone());
456 }
457 Ok(result)
458 }
459
460 pub async fn query_page(
461 &self,
462 query_id: &str,
463 next_uri: &str,
464 node_id: &Option<String>,
465 ) -> Result<QueryResponse> {
466 info!("query page: {next_uri}");
467 let endpoint = self.endpoint.join(next_uri)?;
468 let headers = self.make_headers(Some(query_id))?;
469 let mut builder = self.cli.get(endpoint.clone());
470 builder = self
471 .wrap_auth_or_session_token(builder)?
472 .headers(headers.clone())
473 .timeout(self.page_request_timeout);
474 if let Some(node_id) = node_id {
475 builder = builder.header(HEADER_STICKY_NODE, node_id)
476 }
477 let request = builder.build()?;
478
479 let response = self.query_request_helper(request, false, true).await?;
480 let status = response.status();
481 if status != 200 {
482 return Err(Error::response_error(status, &response.bytes().await?));
483 }
484 let body = response.bytes().await?;
485 let resp: QueryResponse = json_from_slice(&body).map_err(|e| {
486 if let Error::Logic(status, ec) = &e {
487 if *status == 404 {
488 return Error::QueryNotFound(ec.message.clone());
489 }
490 }
491 e
492 })?;
493 self.handle_session(&resp.session).await;
494 match resp.error {
495 Some(err) => Err(Error::QueryFailed(err)),
496 None => Ok(resp),
497 }
498 }
499
500 pub async fn kill_query(&self, query_id: &str) -> Result<()> {
501 let kill_uri = format!("/v1/query/{query_id}/kill");
502 let endpoint = self.endpoint.join(&kill_uri)?;
503 let headers = self.make_headers(Some(query_id))?;
504 info!("kill query: {kill_uri}");
505
506 let mut builder = self.cli.post(endpoint);
507 builder = self.wrap_auth_or_session_token(builder)?;
508 let resp = builder.headers(headers.clone()).send().await?;
509 if resp.status() != 200 {
510 return Err(Error::response_error(resp.status(), &resp.bytes().await?)
511 .with_context("kill query"));
512 }
513 Ok(())
514 }
515
516 pub async fn query_all(self: &Arc<Self>, sql: &str) -> Result<Page> {
517 let mut pages = self.start_query(sql, false).await?;
518 let mut all = Page::default();
519 while let Some(page) = pages.next().await {
520 all.update(page?);
521 }
522 Ok(all)
523 }
524
525 fn session_state(&self) -> SessionState {
526 self.session_state.lock().clone()
527 }
528
529 fn make_pagination(&self) -> Option<PaginationConfig> {
530 if self.wait_time_secs.is_none()
531 && self.max_rows_in_buffer.is_none()
532 && self.max_rows_per_page.is_none()
533 {
534 return None;
535 }
536 let mut pagination = PaginationConfig {
537 wait_time_secs: None,
538 max_rows_in_buffer: None,
539 max_rows_per_page: None,
540 };
541 if let Some(wait_time_secs) = self.wait_time_secs {
542 pagination.wait_time_secs = Some(wait_time_secs);
543 }
544 if let Some(max_rows_in_buffer) = self.max_rows_in_buffer {
545 pagination.max_rows_in_buffer = Some(max_rows_in_buffer);
546 }
547 if let Some(max_rows_per_page) = self.max_rows_per_page {
548 pagination.max_rows_per_page = Some(max_rows_per_page);
549 }
550 Some(pagination)
551 }
552
553 fn make_headers(&self, query_id: Option<&str>) -> Result<HeaderMap> {
554 let mut headers = HeaderMap::new();
555 if let Some(tenant) = &self.tenant {
556 headers.insert(HEADER_TENANT, tenant.parse()?);
557 }
558 let warehouse = self.warehouse.lock().clone();
559 if let Some(warehouse) = warehouse {
560 headers.insert(HEADER_WAREHOUSE, warehouse.parse()?);
561 }
562 let route_hint = self.route_hint.current();
563 headers.insert(HEADER_ROUTE_HINT, route_hint.parse()?);
564 if let Some(query_id) = query_id {
565 headers.insert(HEADER_QUERY_ID, query_id.parse()?);
566 }
567 Ok(headers)
568 }
569
570 pub async fn insert_with_stage(
571 self: &Arc<Self>,
572 sql: &str,
573 stage: &str,
574 file_format_options: BTreeMap<&str, &str>,
575 copy_options: BTreeMap<&str, &str>,
576 ) -> Result<QueryStats> {
577 info!("insert with stage: {sql}, format: {file_format_options:?}, copy: {copy_options:?}");
578 let stage_attachment = Some(StageAttachmentConfig {
579 location: stage,
580 file_format_options: Some(file_format_options),
581 copy_options: Some(copy_options),
582 });
583 let resp = self.start_query_inner(sql, stage_attachment).await?;
584 let mut pages = Pages::new(self.clone(), resp, false);
585 let mut all = Page::default();
586 while let Some(page) = pages.next().await {
587 all.update(page?);
588 }
589 Ok(all.stats)
590 }
591
592 async fn get_presigned_upload_url(self: &Arc<Self>, stage: &str) -> Result<PresignedResponse> {
593 info!("get presigned upload url: {stage}");
594 let sql = format!("PRESIGN UPLOAD {stage}");
595 let resp = self.query_all(&sql).await?;
596 if resp.data.len() != 1 {
597 return Err(Error::Decode(
598 "Empty response from server for presigned request".to_string(),
599 ));
600 }
601 if resp.data[0].len() != 3 {
602 return Err(Error::Decode(
603 "Invalid response from server for presigned request".to_string(),
604 ));
605 }
606 let method = resp.data[0][0].clone().unwrap_or_default();
608 if method != "PUT" {
609 return Err(Error::Decode(format!(
610 "Invalid method for presigned upload request: {method}"
611 )));
612 }
613 let headers: BTreeMap<String, String> =
614 serde_json::from_str(resp.data[0][1].clone().unwrap_or("{}".to_string()).as_str())?;
615 let url = resp.data[0][2].clone().unwrap_or_default();
616 Ok(PresignedResponse {
617 method,
618 headers,
619 url,
620 })
621 }
622
623 pub async fn upload_to_stage(
624 self: &Arc<Self>,
625 stage: &str,
626 data: Reader,
627 size: u64,
628 ) -> Result<()> {
629 match self.get_presign_mode() {
630 PresignMode::Off => self.upload_to_stage_with_stream(stage, data, size).await,
631 PresignMode::On => {
632 let presigned = self.get_presigned_upload_url(stage).await?;
633 presign_upload_to_stage(presigned, data, size).await
634 }
635 PresignMode::Auto => {
636 unreachable!("PresignMode::Auto should be handled during client initialization")
637 }
638 PresignMode::Detect => {
639 unreachable!("PresignMode::Detect should be handled during client initialization")
640 }
641 }
642 }
643
644 async fn upload_to_stage_with_stream(
646 &self,
647 stage: &str,
648 data: Reader,
649 size: u64,
650 ) -> Result<()> {
651 info!("upload to stage with stream: {stage}, size: {size}");
652 if let Some(info) = self.need_pre_refresh_session().await {
653 self.refresh_session_token(info).await?;
654 }
655 let endpoint = self.endpoint.join("v1/upload_to_stage")?;
656 let location = StageLocation::try_from(stage)?;
657 let query_id = self.gen_query_id();
658 let mut headers = self.make_headers(Some(&query_id))?;
659 headers.insert(HEADER_STAGE_NAME, location.name.parse()?);
660 let stream = Body::wrap_stream(ReaderStream::new(data));
661 let part = Part::stream_with_length(stream, size).file_name(location.path);
662 let form = Form::new().part("upload", part);
663 let mut builder = self.cli.put(endpoint.clone());
664 builder = self.wrap_auth_or_session_token(builder)?;
665 let resp = builder.headers(headers).multipart(form).send().await?;
666 let status = resp.status();
667 if status != 200 {
668 return Err(
669 Error::response_error(status, &resp.bytes().await?).with_context("upload_to_stage")
670 );
671 }
672 Ok(())
673 }
674
675 pub fn decode_json_header<T>(key: &str, value: &str) -> Result<T, String>
678 where
679 T: de::DeserializeOwned,
680 {
681 if value.starts_with("{") {
682 serde_json::from_slice(value.as_bytes())
683 .map_err(|e| format!("Invalid value {value} for {key} JSON decode error: {e}",))?
684 } else {
685 let json = URL_SAFE.decode(value).map_err(|e| {
686 format!(
687 "Invalid value {} for {key}, base64 decode error: {}",
688 value, e
689 )
690 })?;
691 serde_json::from_slice(&json).map_err(|e| {
692 format!(
693 "Invalid value {value} for {key}, JSON value {}, decode error: {e}",
694 String::from_utf8_lossy(&json)
695 )
696 })
697 }
698 }
699
700 pub async fn streaming_load(
701 &self,
702 sql: &str,
703 data: Reader,
704 file_name: &str,
705 ) -> Result<LoadResponse> {
706 let body = Body::wrap_stream(ReaderStream::new(data));
707 let part = Part::stream(body).file_name(file_name.to_string());
708 let endpoint = self.endpoint.join("v1/streaming_load")?;
709 let mut builder = self.cli.put(endpoint.clone());
710 builder = self.wrap_auth_or_session_token(builder)?;
711 let query_id = self.gen_query_id();
712 let mut headers = self.make_headers(Some(&query_id))?;
713 headers.insert(HEADER_SQL, sql.parse()?);
714 let session = serde_json::to_string(&*self.session_state.lock())
715 .expect("serialize session state should not fail");
716 headers.insert(HEADER_QUERY_CONTEXT, session.parse()?);
717 let form = Form::new().part("upload", part);
718 let resp = builder.headers(headers).multipart(form).send().await?;
719 let status = resp.status();
720 if let Some(value) = resp.headers().get(HEADER_QUERY_CONTEXT) {
721 match Self::decode_json_header::<SessionState>(
722 HEADER_QUERY_CONTEXT,
723 value.to_str().unwrap(),
724 ) {
725 Ok(session) => *self.session_state.lock() = session,
726 Err(e) => {
727 error!("Error decoding session state when streaming load: {e}");
728 }
729 }
730 };
731 if status != 200 {
732 return Err(
733 Error::response_error(status, &resp.bytes().await?).with_context("streaming_load")
734 );
735 }
736 let resp = resp.json::<LoadResponse>().await?;
737 Ok(resp)
738 }
739
740 async fn login(&mut self) -> Result<()> {
741 let endpoint = self.endpoint.join("/v1/session/login")?;
742 let headers = self.make_headers(None)?;
743 let body = LoginRequest::from(&*self.session_state.lock());
744 let mut builder = self.cli.post(endpoint.clone()).json(&body);
745 if self.disable_session_token {
746 builder = builder.query(&[("disable_session_token", true)]);
747 }
748 let builder = self.auth.wrap(builder)?;
749 let request = builder
750 .headers(headers.clone())
751 .timeout(self.connect_timeout)
752 .build()?;
753 let response = self.query_request_helper(request, true, false).await;
754 let response = match response {
755 Ok(r) => r,
756 Err(e) if e.status_code() == Some(StatusCode::NOT_FOUND) => {
757 info!("login return 404, skip login on the old version server");
758 return Ok(());
759 }
760 Err(e) => return Err(e),
761 };
762 let body = response.bytes().await?;
763 let response = json_from_slice(&body)?;
764 match response {
765 LoginResponseResult::Err { error } => return Err(Error::AuthFailure(error)),
766 LoginResponseResult::Ok(info) => {
767 let server_version = info
768 .version
769 .parse()
770 .map_err(|e| Error::Decode(format!("invalid server version: {e}")))?;
771 self.capability = Capability::from_server_version(&server_version);
772 self.server_version = Some(server_version.clone());
773 if let Some(tokens) = info.tokens {
774 info!("login success with session token version = {server_version}",);
775 self.session_token_info = Some(Arc::new(Mutex::new((tokens, Instant::now()))))
776 } else {
777 info!("login success, version = {server_version}");
778 }
779 }
780 }
781 Ok(())
782 }
783
784 fn build_log_out_request(&self) -> Result<Request> {
785 let endpoint = self.endpoint.join("/v1/session/logout")?;
786
787 let session_state = self.session_state();
788 let need_sticky = session_state.need_sticky.unwrap_or(false);
789 let mut headers = self.make_headers(None)?;
790 if need_sticky {
791 if let Some(node_id) = self.last_node_id() {
792 headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
793 }
794 }
795 let builder = self.cli.post(endpoint.clone()).headers(headers.clone());
796
797 let builder = self.wrap_auth_or_session_token(builder)?;
798 let req = builder.build()?;
799 Ok(req)
800 }
801
802 fn need_logout(&self) -> bool {
803 (self.session_token_info.is_some()
804 || self.session_state.lock().need_keep_alive.unwrap_or(false))
805 && self
806 .closed
807 .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
808 .is_ok()
809 }
810
811 async fn refresh_session_token(
812 &self,
813 self_login_info: Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>,
814 ) -> Result<()> {
815 let (session_token_info, _) = { self_login_info.lock().clone() };
816 let endpoint = self.endpoint.join("/v1/session/refresh")?;
817 let body = RefreshSessionTokenRequest {
818 session_token: session_token_info.session_token.clone(),
819 };
820 let headers = self.make_headers(None)?;
821 let request = self
822 .cli
823 .post(endpoint.clone())
824 .json(&body)
825 .headers(headers.clone())
826 .bearer_auth(session_token_info.refresh_token.clone())
827 .timeout(self.connect_timeout)
828 .build()?;
829
830 for i in 0..3 {
832 let req = request.try_clone().expect("request not cloneable");
833 match self.cli.execute(req).await {
834 Ok(response) => {
835 let status = response.status();
836 let body = response.bytes().await?;
837 if status == StatusCode::OK {
838 let response = json_from_slice(&body)?;
839 return match response {
840 RefreshResponse::Err { error } => Err(Error::AuthFailure(error)),
841 RefreshResponse::Ok(info) => {
842 *self_login_info.lock() = (info, Instant::now());
843 Ok(())
844 }
845 };
846 }
847 if status != StatusCode::SERVICE_UNAVAILABLE || i >= 2 {
848 return Err(Error::response_error(status, &body));
849 }
850 }
851 Err(err) => {
852 if !(err.is_timeout() || err.is_connect()) || i > 2 {
853 return Err(Error::Request(err.to_string()));
854 }
855 }
856 };
857 sleep(jitter(Duration::from_secs(10))).await;
858 }
859 Ok(())
860 }
861
862 async fn need_pre_refresh_session(&self) -> Option<Arc<Mutex<(SessionTokenInfo, Instant)>>> {
863 if let Some(info) = &self.session_token_info {
864 let (start, ttl) = {
865 let guard = info.lock();
866 (guard.1, guard.0.session_token_ttl_in_secs)
867 };
868 if Instant::now() > start + Duration::from_secs(ttl) {
869 return Some(info.clone());
870 }
871 }
872 None
873 }
874
875 async fn query_request_helper(
883 &self,
884 mut request: Request,
885 retry_if_503: bool,
886 refresh_if_401: bool,
887 ) -> std::result::Result<Response, Error> {
888 let mut refreshed = false;
889 let mut retries = 0;
890 loop {
891 let req = request.try_clone().expect("request not cloneable");
892 let (err, retry): (Error, bool) = match self.cli.execute(req).await {
893 Ok(response) => {
894 let status = response.status();
895 if status == StatusCode::OK {
896 return Ok(response);
897 }
898 let body = response.bytes().await?;
899 if retry_if_503 && status == StatusCode::SERVICE_UNAVAILABLE {
900 (Error::response_error(status, &body), true)
902 } else {
903 let resp = serde_json::from_slice::<ResponseWithErrorCode>(&body);
904 match resp {
905 Ok(r) => {
906 let e = r.error;
907 if status == StatusCode::UNAUTHORIZED {
908 request.headers_mut().remove(reqwest::header::AUTHORIZATION);
909 if let Some(session_token_info) = &self.session_token_info {
910 info!(
911 "will retry {} after refresh token on auth error {}",
912 request.url(),
913 e
914 );
915 let retry = if need_refresh_token(e.code)
916 && !refreshed
917 && refresh_if_401
918 {
919 self.refresh_session_token(session_token_info.clone())
920 .await?;
921 refreshed = true;
922 true
923 } else {
924 false
925 };
926 (Error::AuthFailure(e), retry)
927 } else if self.auth.can_reload() {
928 info!(
929 "will retry {} after reload token on auth error {}",
930 request.url(),
931 e
932 );
933 let builder = RequestBuilder::from_parts(
934 HttpClient::new(),
935 request.try_clone().unwrap(),
936 );
937 let builder = self.auth.wrap(builder)?;
938 request = builder.build()?;
939 (Error::AuthFailure(e), true)
940 } else {
941 (Error::AuthFailure(e), false)
942 }
943 } else {
944 (Error::Logic(status, e), false)
945 }
946 }
947 Err(_) => (
948 Error::Response {
949 status,
950 msg: String::from_utf8_lossy(&body).to_string(),
951 },
952 false,
953 ),
954 }
955 }
956 }
957 Err(err) => (
958 Error::Request(err.to_string()),
959 err.is_timeout() || err.is_connect(),
960 ),
961 };
962 if !retry {
963 return Err(err.with_context(&format!("{} {}", request.method(), request.url())));
964 }
965 match &err {
966 Error::AuthFailure(_) => {
967 if refreshed {
968 retries = 0;
969 } else if retries == 2 {
970 return Err(err.with_context(&format!(
971 "{} {} after 3 reties",
972 request.method(),
973 request.url()
974 )));
975 }
976 }
977 _ => {
978 if retries == 2 {
979 return Err(err.with_context(&format!(
980 "{} {} after 3 reties",
981 request.method(),
982 request.url()
983 )));
984 }
985 retries += 1;
986 info!(
987 "will retry {} the {retries}th times on error {}",
988 request.url(),
989 err
990 );
991 }
992 }
993 sleep(jitter(Duration::from_secs(10))).await;
994 }
995 }
996
997 pub async fn close(&self) {
998 if self.need_logout() {
999 let cli = self.cli.clone();
1000 let req = self
1001 .build_log_out_request()
1002 .expect("failed to build logout request");
1003 if let Err(err) = cli.execute(req).await {
1004 error!("logout request failed: {err}");
1005 } else {
1006 debug!("logout success");
1007 };
1008 }
1009 }
1010}
1011
1012impl Drop for APIClient {
1013 fn drop(&mut self) {
1014 if self.need_logout() {
1015 warn!("APIClient::close() was not called");
1016 }
1017 }
1018}
1019
1020fn json_from_slice<'a, T>(body: &'a [u8]) -> Result<T>
1021where
1022 T: Deserialize<'a>,
1023{
1024 serde_json::from_slice::<T>(body).map_err(|e| {
1025 Error::Decode(format!(
1026 "fail to decode JSON response: {e}, body: {}",
1027 String::from_utf8_lossy(body)
1028 ))
1029 })
1030}
1031
1032impl Default for APIClient {
1033 fn default() -> Self {
1034 Self {
1035 cli: HttpClient::new(),
1036 scheme: "http".to_string(),
1037 endpoint: Url::parse("http://localhost:8080").unwrap(),
1038 host: "localhost".to_string(),
1039 port: 8000,
1040 tenant: None,
1041 warehouse: Mutex::new(None),
1042 auth: Arc::new(BasicAuth::new("root", "")) as Arc<dyn Auth>,
1043 session_state: Mutex::new(SessionState::default()),
1044 wait_time_secs: None,
1045 max_rows_in_buffer: None,
1046 max_rows_per_page: None,
1047 connect_timeout: Duration::from_secs(10),
1048 page_request_timeout: Duration::from_secs(30),
1049 tls_ca_file: None,
1050 presign: Mutex::new(PresignMode::Auto),
1051 route_hint: RouteHintGenerator::new(),
1052 last_node_id: Default::default(),
1053 disable_session_token: true,
1054 disable_login: false,
1055 session_token_info: None,
1056 closed: Default::default(),
1057 last_query_id: Default::default(),
1058 server_version: None,
1059 capability: Default::default(),
1060 }
1061 }
1062}
1063
/// Produces and remembers the value sent in the route-hint header.
struct RouteHintGenerator {
    /// Counter embedded in each generated hint; incremented per generation.
    nonce: AtomicU64,
    /// The hint currently in effect (generated locally or set explicitly).
    current: std::sync::Mutex<String>,
}
1068
1069impl RouteHintGenerator {
1070 fn new() -> Self {
1071 let gen = Self {
1072 nonce: AtomicU64::new(0),
1073 current: std::sync::Mutex::new("".to_string()),
1074 };
1075 gen.next();
1076 gen
1077 }
1078
1079 fn current(&self) -> String {
1080 let guard = self.current.lock().unwrap();
1081 guard.clone()
1082 }
1083
1084 fn set(&self, hint: &str) {
1085 let mut guard = self.current.lock().unwrap();
1086 *guard = hint.to_string();
1087 }
1088
1089 fn next(&self) -> String {
1090 let nonce = self.nonce.fetch_add(1, Ordering::AcqRel);
1091 let uuid = uuid::Uuid::new_v4();
1092 let current = format!("rh:{uuid}:{nonce:06}");
1093 let mut guard = self.current.lock().unwrap();
1094 guard.clone_from(¤t);
1095 current
1096 }
1097}
1098
#[cfg(test)]
mod test {
    use super::*;

    /// Query parameters, `sslmode=disable` and the implied default port are
    /// all picked up from the DSN.
    #[tokio::test]
    async fn parse_dsn() -> Result<()> {
        let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
        let parsed = APIClient::from_dsn(dsn).await?;

        let expected_endpoint = Url::parse("http://app.databend.com:80")?;
        assert_eq!(parsed.host, "app.databend.com");
        assert_eq!(parsed.endpoint, expected_endpoint);
        assert_eq!(parsed.wait_time_secs, Some(10));
        assert_eq!(parsed.max_rows_in_buffer, Some(5000000));
        assert_eq!(parsed.max_rows_per_page, Some(10000));
        assert_eq!(parsed.tenant, None);

        let warehouse = parsed.warehouse.try_lock().unwrap();
        assert_eq!(*warehouse, Some("wh".to_string()));
        Ok(())
    }

    /// A percent-encoded password parses, and HTTPS implies port 443.
    #[tokio::test]
    async fn parse_encoded_password() -> Result<()> {
        let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
        let parsed = APIClient::from_dsn(dsn).await?;
        assert_eq!(parsed.host(), "localhost");
        assert_eq!(parsed.port(), 443);
        Ok(())
    }

    /// A password with unencoded special characters still yields the right
    /// host and explicit port.
    #[tokio::test]
    async fn parse_special_chars_password() -> Result<()> {
        let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000";
        let parsed = APIClient::from_dsn(dsn).await?;
        assert_eq!(parsed.host(), "localhost");
        assert_eq!(parsed.port(), 8000);
        Ok(())
    }
}
1137}