1use std::collections::{HashMap, HashSet};
2
3use backon::ExponentialBuilder;
4use backon::Retryable;
5use http::header::{ACCEPT_ENCODING, USER_AGENT};
6use http::StatusCode;
7use iterable::*;
8use log::*;
9use reqwest::header::HeaderValue;
10use reqwest::{RequestBuilder, Response, Url};
11use tokio::sync::RwLock;
12use tokio::time::Duration;
13
14use crate::auth::Auth;
15use crate::build_dataset;
16use crate::error::TrinoRetryResult;
17use crate::error::{Error, Result};
18use crate::header::*;
19use crate::models::QueryResultData;
20#[cfg(feature = "spooling")]
21use crate::models::SpooledData;
22use crate::selected_role::SelectedRole;
23use crate::session::{Session, SessionBuilder};
24#[cfg(feature = "spooling")]
25use crate::spooling::decompress_segment_bytes;
26#[cfg(feature = "spooling")]
27use crate::spooling::{SegmentFetcher, SpoolingEncoding};
28use crate::ssl::Ssl;
29use crate::transaction::TransactionId;
30use crate::{DataSet, QueryResult, Row, Trino};
31
32pub struct Client {
37 client: reqwest::Client,
38 session: RwLock<Session>,
39 auth: Option<Auth>,
40 max_attempt: usize,
41 url: Url,
42 #[cfg(feature = "spooling")]
43 segment_fetcher: SegmentFetcher,
44}
45
46pub struct ClientBuilder {
47 session: SessionBuilder,
48 auth: Option<Auth>,
49 auth_http_insecure: bool,
50 max_attempt: usize,
51 ssl: Option<Ssl>,
52 no_verify: bool,
53 #[cfg(feature = "spooling")]
54 segment_fetcher: Option<SegmentFetcher>,
55 #[cfg(feature = "spooling")]
56 max_concurrent_segments: Option<usize>,
57}
58
/// Outcome of `Client::execute`: the update type and affected-row count reported by the server.
///
/// Plain data, so it derives the usual value traits (`Clone`, `Default`, equality) in addition
/// to `Debug`, letting callers copy, compare and assert on results.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecuteResult {
    /// Output location; `Client::execute` currently always leaves this `None`.
    pub output_uri: Option<String>,
    /// Update type reported in the final result, if any.
    pub update_type: Option<String>,
    /// Number of affected rows reported in the final result, if any.
    pub update_count: Option<u64>,
}
65
66impl ClientBuilder {
67 pub fn new(user: impl ToString, host: impl ToString) -> Self {
68 let builder = SessionBuilder::new(user, host);
69 Self {
70 session: builder,
71 auth: None,
72 auth_http_insecure: false,
73 max_attempt: 3,
74 ssl: None,
75 no_verify: false,
76 #[cfg(feature = "spooling")]
77 segment_fetcher: None,
78 #[cfg(feature = "spooling")]
79 max_concurrent_segments: None,
80 }
81 }
82
83 pub fn port(mut self, s: u16) -> Self {
84 self.session.port = s;
85 self
86 }
87
88 pub fn secure(mut self, s: bool) -> Self {
89 self.session.secure = s;
90 self
91 }
92
93 pub fn no_verify(mut self, nv: bool) -> Self {
94 self.no_verify = nv;
95 self
96 }
97
98 pub fn source(mut self, s: impl ToString) -> Self {
99 self.session.source = s.to_string();
100 self
101 }
102
103 pub fn trace_token(mut self, s: impl ToString) -> Self {
104 self.session.trace_token = Some(s.to_string());
105 self
106 }
107
108 pub fn client_tags(mut self, s: HashSet<String>) -> Self {
109 self.session.client_tags = s;
110 self
111 }
112
113 pub fn client_tag(mut self, s: impl ToString) -> Self {
114 self.session.client_tags.insert(s.to_string());
115 self
116 }
117
118 pub fn client_info(mut self, s: impl ToString) -> Self {
119 self.session.client_info = Some(s.to_string());
120 self
121 }
122
123 pub fn catalog(mut self, s: impl ToString) -> Self {
124 self.session.catalog = Some(s.to_string());
125 self
126 }
127
128 pub fn schema(mut self, s: impl ToString) -> Self {
129 self.session.schema = Some(s.to_string());
130 self
131 }
132
133 pub fn path(mut self, s: impl ToString) -> Self {
134 self.session.path = Some(s.to_string());
135 self
136 }
137
138 pub fn resource_estimates(mut self, s: HashMap<String, String>) -> Self {
139 self.session.resource_estimates = s;
140 self
141 }
142
143 pub fn resource_estimate(mut self, k: impl ToString, v: impl ToString) -> Self {
144 self.session
145 .resource_estimates
146 .insert(k.to_string(), v.to_string());
147 self
148 }
149
150 pub fn properties(mut self, s: HashMap<String, String>) -> Self {
151 self.session.properties = s;
152 self
153 }
154
155 pub fn property(mut self, k: impl ToString, v: impl ToString) -> Self {
156 self.session.properties.insert(k.to_string(), v.to_string());
157 self
158 }
159
160 pub fn prepared_statements(mut self, s: HashMap<String, String>) -> Self {
161 self.session.prepared_statements = s;
162 self
163 }
164
165 pub fn prepared_statement(mut self, k: impl ToString, v: impl ToString) -> Self {
166 self.session
167 .prepared_statements
168 .insert(k.to_string(), v.to_string());
169 self
170 }
171
172 pub fn extra_credentials(mut self, s: HashMap<String, String>) -> Self {
173 self.session.extra_credentials = s;
174 self
175 }
176
177 pub fn extra_credential(mut self, k: impl ToString, v: impl ToString) -> Self {
178 self.session
179 .extra_credentials
180 .insert(k.to_string(), v.to_string());
181 self
182 }
183
184 pub fn transaction_id(mut self, s: TransactionId) -> Self {
185 self.session.transaction_id = s;
186 self
187 }
188
189 pub fn client_request_timeout(mut self, s: Duration) -> Self {
190 self.session.client_request_timeout = s;
191 self
192 }
193
194 pub fn compression_disabled(mut self, s: bool) -> Self {
195 self.session.compression_disabled = s;
196 self
197 }
198
199 #[cfg(feature = "spooling")]
200 pub fn segment_fetcher(mut self, segment_fetcher: SegmentFetcher) -> Self {
201 self.segment_fetcher = Some(segment_fetcher);
202 self
203 }
204
205 #[cfg(feature = "spooling")]
206 pub fn max_concurrent_segments(mut self, count: usize) -> Self {
209 self.max_concurrent_segments = Some(count);
210 self
211 }
212
213 #[cfg(feature = "spooling")]
214 pub fn spooling_encoding(mut self, encoding: impl ToString) -> Self {
217 let encoding_str = encoding.to_string();
218
219 match SpoolingEncoding::try_from(encoding_str.as_str()) {
220 Ok(_) => {
221 self.session.spooling_encoding = Some(encoding_str);
222 }
223 Err(_) => {
224 log::warn!(
225 "Invalid spooling encoding '{}', using default 'json+zstd'. Valid values: json, json+zstd, json+lz4",
226 encoding_str
227 );
228 self.session.spooling_encoding = Some("json+zstd".to_string());
229 }
230 }
231
232 self
233 }
234
235 pub fn auth(mut self, s: Auth) -> Self {
238 self.auth = Some(s);
239 self
240 }
241
242 pub fn auth_http_insecure(mut self, ahi: bool) -> Self {
243 self.auth_http_insecure = ahi;
244 self
245 }
246
247 pub fn max_attempt(mut self, s: usize) -> Self {
248 self.max_attempt = s;
249 self
250 }
251
252 pub fn ssl(mut self, ssl: Ssl) -> Self {
253 self.ssl = Some(ssl);
254 self
255 }
256
257 pub fn build(self) -> Result<Client> {
258 let session = self.session.build()?;
259 let max_attempt = self.max_attempt;
260
261 if (self.auth.is_some() && session.url.scheme() == "http") && !self.auth_http_insecure {
262 return Err(Error::BasicAuthWithHttp);
263 }
264
265 let mut client_builder =
266 reqwest::ClientBuilder::new().timeout(session.client_request_timeout);
267
268 if self.no_verify {
269 client_builder = client_builder.danger_accept_invalid_certs(true);
270 }
271
272 if let Some(ssl) = &self.ssl {
273 if let Some(root) = &ssl.root_cert {
274 client_builder = client_builder.add_root_certificate(root.0.clone());
275 }
276 }
277
278 let client = client_builder.build()?;
279
280 #[cfg(feature = "spooling")]
281 let segment_fetcher = self.segment_fetcher.unwrap_or_else(|| {
282 let mut fetcher = SegmentFetcher::new(client.clone());
283 if let Some(max_concurrent) = self.max_concurrent_segments {
284 fetcher = fetcher.with_max_concurrent(max_concurrent);
285 }
286 fetcher
287 });
288
289 let cli = Client {
290 auth: self.auth,
291 url: session.url.clone(),
292 session: RwLock::new(session),
293 client,
294 max_attempt,
295 #[cfg(feature = "spooling")]
296 segment_fetcher,
297 };
298
299 Ok(cli)
300 }
301}
302
303fn add_prepare_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
304 builder = builder.header(HEADER_USER, &session.user);
306 builder = builder.header(USER_AGENT, "trino-rust-client");
308 if session.compression_disabled {
309 builder = builder.header(ACCEPT_ENCODING, "identity")
310 }
311 builder
312}
313
314fn add_session_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
315 builder = add_prepare_header(builder, session);
316 builder = builder.header(HEADER_SOURCE, &session.source);
317
318 if let Some(v) = &session.trace_token {
319 builder = builder.header(HEADER_TRACE_TOKEN, v);
320 }
321
322 if !session.client_tags.is_empty() {
323 builder = builder.header(HEADER_CLIENT_TAGS, session.client_tags.by_ref().join(","));
324 }
325
326 if let Some(v) = &session.client_info {
327 builder = builder.header(HEADER_CLIENT_INFO, v);
328 }
329
330 if let Some(v) = &session.catalog {
331 builder = builder.header(HEADER_CATALOG, v);
332 }
333
334 if let Some(v) = &session.schema {
335 builder = builder.header(HEADER_SCHEMA, v);
336 }
337
338 if let Some(v) = &session.path {
339 builder = builder.header(HEADER_PATH, v);
340 }
341 if let Some(v) = &session.timezone {
342 builder = builder.header(HEADER_TIME_ZONE, v.to_string())
343 }
344 builder = add_header_map(builder, HEADER_SESSION, &session.properties);
346 builder = add_header_map(
347 builder,
348 HEADER_RESOURCE_ESTIMATE,
349 &session.resource_estimates,
350 );
351 builder = add_header_map(
352 builder,
353 HEADER_ROLE,
354 &session
355 .roles
356 .by_ref()
357 .map_kv(|(k, v)| (k.to_string(), v.to_string())),
358 );
359 builder = add_header_map(builder, HEADER_EXTRA_CREDENTIAL, &session.extra_credentials);
360 builder = add_header_map(
361 builder,
362 HEADER_PREPARED_STATEMENT,
363 &session.prepared_statements,
364 );
365 builder = builder.header(HEADER_TRANSACTION, session.transaction_id.to_str());
366 builder = builder.header(HEADER_CLIENT_CAPABILITIES, "PATH,PARAMETRIC_DATETIME");
367
368 #[cfg(feature = "spooling")]
370 {
371 if let Some(encoding) = &session.spooling_encoding {
372 builder = builder.header(HEADER_SPOOLING, encoding);
373 }
374 }
375
376 builder
377}
378
379fn add_header_map<'a>(
380 mut builder: RequestBuilder,
381 header: &str,
382 map: impl IntoIterator<Item = (&'a String, &'a String)>,
383) -> RequestBuilder {
384 for (k, v) in map {
385 let kv = encode_kv(k, v);
386 builder = builder.header(header, kv);
387 }
388 builder
389}
390
/// Copies response header `$header` into the session field `$session` when present.
///
/// The header text is passed to `$from_str`, which returns `Some(new_value)` to assign or
/// `None` to leave the field untouched. The default converter wraps the text as
/// `Some(String)`, so `Option<String>` fields can be set directly. A header whose bytes are
/// not valid text is logged with `warn!` and ignored.
macro_rules! set_header {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header!($session, $header, $resp, |raw: &str| Some(Some(raw.to_string())));
    };

    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        match $resp.headers().get($header).map(|value| value.to_str()) {
            None => {}
            Some(Err(e)) => warn!("parse header {} failed, reason: {}", $header, e),
            Some(Ok(text)) => {
                if let Some(parsed) = $from_str(text) {
                    $session = parsed;
                }
            }
        }
    };
}
411
/// Resets the session field `$session` to its `Default` value when response header
/// `$header` is present; the header's content is not inspected.
macro_rules! clear_header {
    ($session:expr, $header:expr, $resp:expr) => {
        // `.is_some()` instead of `if let Some(_)` (clippy: redundant_pattern_matching).
        if $resp.headers().get($header).is_some() {
            $session = Default::default();
        }
    };
}
419
/// Merges every occurrence of response header `$header` into the map `$session`.
///
/// Each value is decoded by `decode_kv_from_header` into a `(key, value)` pair. The value is
/// passed through `$from_str` (default: keep it as a `String`) and inserted only when the
/// converter returns `Some`. Values that cannot be decoded are logged with `warn!` and skipped.
macro_rules! set_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header_map!($session, $header, $resp, |x: &str| Some(x.to_string()));
    };
    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        for raw in $resp.headers().get_all($header) {
            match decode_kv_from_header(raw) {
                None => warn!("decode '{:?}' failed", raw),
                Some((key, text)) => {
                    if let Some(parsed) = $from_str(&text) {
                        $session.insert(key, parsed);
                    }
                }
            }
        }
    };
}
436
/// Removes from the map `$session` every key named by an occurrence of response header
/// `$header`.
///
/// The raw header text is used as the key verbatim. NOTE(review): unlike `set_header_map!`, it is
/// not URL-decoded; confirm this matches how the server encodes these headers. Non-text
/// values are logged with `warn!` and skipped.
macro_rules! clear_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        for raw in $resp.headers().get_all($header) {
            match raw.to_str() {
                Err(e) => warn!("parse header {} failed, reason: {}", $header, e),
                Ok(key) => {
                    $session.remove(key);
                }
            }
        }
    };
}
449
450fn need_retry(e: &Error) -> bool {
451 match e {
452 Error::HttpError(e) => e.status() == Some(StatusCode::SERVICE_UNAVAILABLE),
453 Error::HttpNotOk(code, _) => code == &StatusCode::SERVICE_UNAVAILABLE,
454 _ => false,
455 }
456}
457
458impl Client {
459 pub async fn get_all<T>(&self, sql: String) -> Result<DataSet<T>>
460 where
461 T: Trino + 'static,
462 for<'de> T: serde::Deserialize<'de> + serde::Serialize,
463 {
464 let res = self.get_retry(sql).await?;
465
466 let mut columns = res.columns;
468
469 match res.data {
470 Some(QueryResultData::Direct(rows)) => {
471 let mut all_rows = rows;
473
474 let mut next = res.next_uri;
475 while let Some(url) = &next {
476 let mut res = self.get_next_retry(url).await?;
477 next = res.next_uri;
478
479 if columns.is_none() {
481 columns = res.columns.take();
482 }
483
484 if let Some(error) = res.error {
485 if error.error_code == 4 {
486 return Err(Error::Forbidden {
487 message: error.message,
488 });
489 } else {
490 return Err(Error::InternalError(format!(
491 "Query failed with {} (error code {}): {}",
492 error.error_name, error.error_code, error.message
493 )));
494 }
495 }
496
497 if let Some(data) = res.data {
498 match data {
499 QueryResultData::Direct(rows) => {
500 all_rows.extend(rows);
501 }
502 #[cfg(feature = "spooling")]
503 QueryResultData::Spooled(_) => {
504 return Err(Error::InternalError(
505 "Cannot mix Direct and Spooled protocols in same query".to_string(),
506 ));
507 }
508 #[cfg(not(feature = "spooling"))]
509 QueryResultData::Spooled(_) => {
510 return Err(Error::InternalError(
511 "Server sent spooled data but 'spooling' feature is not enabled. \
512 Add features = [\"spooling\"] to your trino-rust-client dependency in Cargo.toml.".to_string(),
513 ));
514 }
515 }
516 }
517 }
518
519 build_dataset(all_rows, columns)
520 }
521 #[cfg(feature = "spooling")]
522 Some(QueryResultData::Spooled(spooled)) => {
523 let mut dataset = self
524 .fetch_spooled_data::<T>(spooled, columns.clone())
525 .await?;
526
527 let mut next = res.next_uri;
528 while let Some(url) = &next {
529 let mut res = self.get_next_retry::<T>(url).await?;
530 next = res.next_uri;
531
532 if columns.is_none() {
533 columns = res.columns.take();
534 }
535
536 if let Some(error) = res.error {
537 if error.error_code == 4 {
538 return Err(Error::Forbidden {
539 message: error.message,
540 });
541 } else {
542 return Err(Error::InternalError(format!(
543 "Query failed with {} (error code {}): {}",
544 error.error_name, error.error_code, error.message
545 )));
546 }
547 }
548
549 if let Some(data) = res.data {
550 match data {
551 QueryResultData::Direct(_) => {
552 return Err(Error::InternalError(
553 "Cannot mix Direct and Spooled protocols in same query".to_string(),
554 ));
555 }
556 QueryResultData::Spooled(spooled) => {
557 log::info!("🗄️ Received SPOOLED protocol data - fetching from S3/MinIO");
558 let cols_for_spooled = columns.clone().or_else(|| res.columns.take());
559 let next_dataset = self
560 .fetch_spooled_data::<T>(spooled, cols_for_spooled)
561 .await?;
562 dataset.merge(next_dataset);
563 }
564 }
565 }
566 }
567
568 Ok(dataset)
569 }
570 #[cfg(not(feature = "spooling"))]
571 Some(QueryResultData::Spooled(_)) => {
572 Err(Error::InternalError(
573 "Server sent spooled data but 'spooling' feature is not enabled. \
574 Add features = [\"spooling\"] to your trino-rust-client dependency in Cargo.toml.".to_string(),
575 ))
576 }
577 None => {
578 let mut next = res.next_uri;
580 let mut protocol_detected = false;
581 let mut all_rows: Vec<T> = Vec::new();
582 #[cfg(feature = "spooling")]
583 let mut dataset: Option<DataSet<T>> = None;
584
585 while let Some(url) = &next {
586 let mut res = self.get_next_retry::<T>(url).await?;
587 next = res.next_uri;
588
589 if columns.is_none() {
590 columns = res.columns.take();
591 }
592
593 if let Some(error) = res.error {
594 if error.error_code == 4 {
595 return Err(Error::Forbidden {
596 message: error.message,
597 });
598 } else {
599 return Err(Error::InternalError(format!(
600 "Query failed with {} (error code {}): {}",
601 error.error_name, error.error_code, error.message
602 )));
603 }
604 }
605
606 if let Some(data) = res.data {
607 match data {
608 QueryResultData::Direct(rows) => {
609 if !protocol_detected {
610 protocol_detected = true;
611 }
612 all_rows.extend(rows);
613 }
614 #[cfg(feature = "spooling")]
615 QueryResultData::Spooled(spooled) => {
616 if !protocol_detected {
617 protocol_detected = true;
618 let cols_for_spooled = columns.clone().or_else(|| res.columns.take());
619 dataset = Some(self.fetch_spooled_data::<T>(spooled, cols_for_spooled).await?);
620 } else {
621 let cols_for_spooled = columns.clone().or_else(|| res.columns.take());
622 let next_dataset = self.fetch_spooled_data::<T>(spooled, cols_for_spooled).await?;
623 if let Some(ref mut ds) = dataset {
624 ds.merge(next_dataset);
625 }
626 }
627 }
628 #[cfg(not(feature = "spooling"))]
629 QueryResultData::Spooled(_) => {
630 return Err(Error::InternalError(
631 "Server sent spooled data but 'spooling' feature is not enabled. \
632 Add features = [\"spooling\"] to your trino-rust-client dependency in Cargo.toml.".to_string(),
633 ));
634 }
635 }
636 }
637 }
638
639 #[cfg(feature = "spooling")]
640 if let Some(ds) = dataset {
641 Ok(ds)
642 } else if !all_rows.is_empty() {
643 build_dataset(all_rows, columns)
644 } else {
645 Err(Error::EmptyData)
646 }
647 #[cfg(not(feature = "spooling"))]
648 if !all_rows.is_empty() {
649 build_dataset(all_rows, columns)
650 } else {
651 Err(Error::EmptyData)
652 }
653 }
654 }
655 }
656
657 #[cfg(feature = "spooling")]
658 async fn fetch_spooled_data<T: Trino + 'static>(
659 &self,
660 spooled: SpooledData,
661 columns: Option<Vec<crate::models::Column>>,
662 ) -> Result<DataSet<T>> {
663 let segment_bytes = self
664 .segment_fetcher
665 .fetch_segments(spooled.segments)
666 .await?;
667
668 let dataset = self.decode_segments::<T>(&spooled.encoding, segment_bytes, columns)?;
669
670 Ok(dataset)
671 }
672
673 #[cfg(feature = "spooling")]
674 fn decode_segments<T: Trino + 'static>(
675 &self,
676 encoding: &str,
677 segment_bytes: Vec<Vec<u8>>,
678 columns: Option<Vec<crate::models::Column>>,
679 ) -> Result<DataSet<T>> {
680 let cols = columns.ok_or_else(|| {
681 Error::InternalError("Column metadata required for spooling protocol".to_string())
682 })?;
683
684 let mut all_rows: Vec<Vec<serde_json::Value>> = Vec::new();
685
686 let encoding = SpoolingEncoding::try_from(encoding).map_err(|e| {
687 Error::InternalError(format!(
688 "Failed to parse encoding: {}. Only 'json' based formats are supported.",
689 e
690 ))
691 })?;
692
693 for bytes in segment_bytes {
694 let json_str = decompress_segment_bytes(&bytes, &encoding)?;
695
696 let mut rows: Vec<Vec<serde_json::Value>> =
697 serde_json::from_str(&json_str).map_err(|e| {
698 Error::InternalError(format!("Failed to parse segment JSON: {}", e))
699 })?;
700
701 all_rows.append(&mut rows);
702 }
703
704 let json_obj = serde_json::json!({
705 "columns": cols,
706 "data": all_rows
707 });
708
709 let dataset: DataSet<T> = serde_json::from_value(json_obj)
710 .map_err(|e| Error::InternalError(format!("Failed to deserialize DataSet: {}", e)))?;
711
712 Ok(dataset)
713 }
714
715 pub async fn execute(&self, sql: String) -> Result<ExecuteResult> {
722 let res = self.get_retry::<Row>(sql).await?;
724
725 let mut next = res.next_uri;
726 let mut final_uri = next.clone();
727
728 while let Some(url) = &next {
731 let res = self.get_next_retry::<Row>(url).await?;
732
733 let next_uri = res.next_uri;
734
735 if next_uri.is_some() {
737 final_uri = next_uri.clone();
738 }
739 next = next_uri;
740 }
741
742 let url = final_uri.ok_or_else(|| {
743 Error::InternalError("No next URI available for execution result".to_string())
744 })?;
745
746 let result = self.try_get_retry_result(&url).await?;
748
749 if let Some(error) = result.error {
750 return Err(error.into());
751 }
752
753 Ok(ExecuteResult {
754 output_uri: None,
755 update_type: result.update_type,
756 update_count: result.update_count,
757 })
758 }
759
760 async fn try_get_retry_result(&self, url: &str) -> Result<TrinoRetryResult> {
761 let response = self.client.get(url).send().await?;
762
763 let result = response.json::<TrinoRetryResult>().await?;
764
765 Ok(result)
766 }
767
768 fn retry_policy(&self) -> ExponentialBuilder {
769 ExponentialBuilder::default()
770 .with_max_times(self.max_attempt)
771 .with_max_delay(Duration::from_secs(2))
772 }
773
774 async fn get_retry<T>(&self, sql: String) -> Result<QueryResult<T>>
775 where
776 T: Trino + 'static,
777 for<'de> T: serde::Deserialize<'de>,
778 {
779 let result = || async { self.get::<T>(sql.clone()).await };
780
781 result.retry(self.retry_policy()).when(need_retry).await
782 }
783
784 async fn get_next_retry<T>(&self, url: &str) -> Result<QueryResult<T>>
785 where
786 T: Trino + 'static,
787 for<'de> T: serde::Deserialize<'de>,
788 {
789 let result = || async { self.get_next(url).await };
790
791 result.retry(self.retry_policy()).when(need_retry).await
792 }
793
794 pub async fn get<T>(&self, sql: String) -> Result<QueryResult<T>>
795 where
796 T: Trino + 'static,
797 for<'de> T: serde::Deserialize<'de>,
798 {
799 let req = self
800 .client
801 .post(format!("{}v1/statement", self.url))
802 .body(sql);
803 let req = {
804 let session = self.session.read().await;
805 add_session_header(req, &session)
806 };
807
808 let req = self.auth_req(req);
809 self.send(req, StatusCode::OK, |resp| async {
810 let text = resp.text().await?;
811
812 let data: QueryResult<T> = serde_json::from_str(&text)
813 .map_err(|e| Error::InternalError(format!("Failed to parse response: {}", e)))?;
814 Ok(data)
815 })
816 .await
817 }
818
819 pub async fn get_next<T>(&self, url: &str) -> Result<QueryResult<T>>
820 where
821 T: Trino + 'static,
822 for<'de> T: serde::Deserialize<'de>,
823 {
824 let req = self.client.get(url);
825 let req = {
826 let session = self.session.read().await;
827 add_prepare_header(req, &session)
828 };
829
830 let req = self.auth_req(req);
831 self.send(req, StatusCode::OK, |resp| async {
832 let text = resp.text().await?;
833 let data: QueryResult<T> = serde_json::from_str(&text)
834 .map_err(|e| Error::InternalError(format!("Failed to parse response: {}", e)))?;
835 Ok(data)
836 })
837 .await
838 }
839
840 pub async fn cancel(&self, query_id: &str) -> Result<()> {
841 let url = format!("{}v1/query/{}", self.url, query_id);
842 let req = self.client.delete(url);
843 let req = {
844 let session = self.session.read().await;
845 add_prepare_header(req, &session)
846 };
847
848 let req = self.auth_req(req);
849 self.send(req, StatusCode::NO_CONTENT, |_| async { Ok(()) })
850 .await
851 }
852
853 fn auth_req(&self, req: RequestBuilder) -> RequestBuilder {
854 if let Some(auth) = self.auth.as_ref() {
855 match auth {
856 Auth::Basic(u, p) => req.basic_auth(u, p.as_ref()),
857 Auth::Jwt(t) => req.bearer_auth(t),
858 }
859 } else {
860 req
861 }
862 }
863
864 async fn send<R, F, Fut>(
865 &self,
866 req: RequestBuilder,
867 expected_status: StatusCode,
868 handle_response: F,
869 ) -> Result<R>
870 where
871 F: FnOnce(Response) -> Fut,
872 Fut: std::future::Future<Output = Result<R>>,
873 {
874 let resp = req.send().await?;
875 let status = resp.status();
876 if status != expected_status {
877 let data = resp.text().await.unwrap_or("".to_string());
878 Err(Error::HttpNotOk(status, data))
879 } else {
880 self.update_session(&resp).await;
881 handle_response(resp).await
882 }
883 }
884
885 async fn update_session(&self, resp: &Response) {
886 let mut session = self.session.write().await;
887
888 set_header!(session.catalog, HEADER_SET_CATALOG, resp);
889 set_header!(session.schema, HEADER_SET_SCHEMA, resp);
890 set_header!(session.path, HEADER_SET_PATH, resp);
891
892 set_header_map!(session.properties, HEADER_SET_SESSION, resp);
893 clear_header_map!(session.properties, HEADER_CLEAR_SESSION, resp);
894
895 set_header_map!(session.roles, HEADER_SET_ROLE, resp, SelectedRole::from_str);
896
897 set_header_map!(session.prepared_statements, HEADER_ADDED_PREPARE, resp);
898 clear_header_map!(
899 session.prepared_statements,
900 HEADER_DEALLOCATED_PREPARE,
901 resp
902 );
903
904 set_header!(
905 session.transaction_id,
906 HEADER_STARTED_TRANSACTION_ID,
907 resp,
908 TransactionId::from_str
909 );
910 clear_header!(session.transaction_id, HEADER_CLEAR_TRANSACTION_ID, resp);
911 }
912}
913
914fn encode_kv(k: &str, v: &str) -> String {
918 url::form_urlencoded::Serializer::new(String::new())
919 .append_pair(k, v)
920 .finish()
921}
922
923fn decode_kv_from_header(input: &HeaderValue) -> Option<(String, String)> {
924 let kvs = url::form_urlencoded::parse(input.as_bytes()).collect::<Vec<_>>();
925 if kvs.is_empty() {
926 None
927 } else {
928 Some((kvs[0].0.to_string(), kvs[0].1.to_string()))
929 }
930}
931
#[cfg(test)]
mod tests {
    use reqwest::header::HeaderValue;

    use crate::client::{decode_kv_from_header, encode_kv};

    #[test]
    fn test_decode_kv_from_header_plus_sign_to_space() {
        let header_value = HeaderValue::from_static("statement=show+tables");
        let result = decode_kv_from_header(&header_value);
        assert!(result.is_some());
        let (key, value) = result.unwrap();
        assert_eq!(key, "statement");
        assert_eq!(value, "show tables");
    }

    #[test]
    fn test_decode_kv_from_header_percent_encoding() {
        let header_value = HeaderValue::from_static("statement=show%20tables");
        let result = decode_kv_from_header(&header_value);
        assert!(result.is_some());
        let (key, value) = result.unwrap();
        assert_eq!(key, "statement");
        assert_eq!(value, "show tables");
    }

    /// An empty header value contains no pair at all.
    #[test]
    fn test_decode_kv_from_header_empty_is_none() {
        let header_value = HeaderValue::from_static("");
        assert_eq!(decode_kv_from_header(&header_value), None);
    }

    /// Only the first pair of a multi-pair value is used.
    #[test]
    fn test_decode_kv_from_header_takes_first_pair() {
        let header_value = HeaderValue::from_static("a=1&b=2");
        assert_eq!(
            decode_kv_from_header(&header_value),
            Some(("a".to_string(), "1".to_string()))
        );
    }

    /// Whatever `encode_kv` produces for request headers decodes back to the same pair,
    /// including spaces and the `&`, `=` and `%` separators.
    #[test]
    fn test_encode_decode_round_trip() {
        let encoded = encode_kv("query max", "a b&c=d%");
        let header_value =
            HeaderValue::from_str(&encoded).expect("form-urlencoded text is a valid header value");
        assert_eq!(
            decode_kv_from_header(&header_value),
            Some(("query max".to_string(), "a b&c=d%".to_string()))
        );
    }
}