Skip to main content

trino_rust_client/
client.rs

1use std::collections::{HashMap, HashSet};
2
3use backon::ExponentialBuilder;
4use backon::Retryable;
5use http::header::{ACCEPT_ENCODING, USER_AGENT};
6use http::StatusCode;
7use iterable::*;
8use log::*;
9use reqwest::header::HeaderValue;
10use reqwest::{RequestBuilder, Response, Url};
11use tokio::sync::RwLock;
12use tokio::time::Duration;
13
14use crate::auth::Auth;
15use crate::build_dataset;
16use crate::error::TrinoRetryResult;
17use crate::error::{Error, Result};
18use crate::header::*;
19use crate::models::QueryResultData;
20#[cfg(feature = "spooling")]
21use crate::models::SpooledData;
22use crate::selected_role::SelectedRole;
23use crate::session::{Session, SessionBuilder};
24#[cfg(feature = "spooling")]
25use crate::spooling::decompress_segment_bytes;
26#[cfg(feature = "spooling")]
27use crate::spooling::{SegmentFetcher, SpoolingEncoding};
28use crate::ssl::Ssl;
29use crate::transaction::TransactionId;
30use crate::{DataSet, QueryResult, Row, Trino};
31
32// TODO:
33// allow_redirects
34// proxies
35
/// Asynchronous Trino client that sends statements over HTTP and keeps track of session state.
///
/// Instances are produced by [`ClientBuilder::build`].
pub struct Client {
    /// HTTP client used to send requests.
    client: reqwest::Client,
    /// Session state: read when request headers are built and rewritten from the
    /// response headers of each successful request.
    session: RwLock<Session>,
    /// Optional credentials attached to requests (basic or bearer).
    auth: Option<Auth>,
    /// Value handed to the retry policy's `with_max_times`.
    max_attempt: usize,
    /// Base URL, copied from the session when the client is built.
    url: Url,
    /// Fetches the segments referenced by spooled results.
    #[cfg(feature = "spooling")]
    segment_fetcher: SegmentFetcher,
}
45
46pub struct ClientBuilder {
47    session: SessionBuilder,
48    auth: Option<Auth>,
49    auth_http_insecure: bool,
50    max_attempt: usize,
51    ssl: Option<Ssl>,
52    no_verify: bool,
53    #[cfg(feature = "spooling")]
54    segment_fetcher: Option<SegmentFetcher>,
55    #[cfg(feature = "spooling")]
56    max_concurrent_segments: Option<usize>,
57}
58
/// Outcome of [`Client::execute`].
///
/// Derives `Clone`, `PartialEq`, `Eq` and `Default` in addition to `Debug` so callers can
/// copy, compare and construct results (for example in tests) without extra boilerplate.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ExecuteResult {
    /// Output location of the statement; `execute` currently always leaves this `None`.
    pub output_uri: Option<String>,
    /// Update type reported by the server for the final result, if any.
    pub update_type: Option<String>,
    /// Update count reported by the server for the final result, if any.
    pub update_count: Option<u64>,
}
65
66impl ClientBuilder {
67    pub fn new(user: impl ToString, host: impl ToString) -> Self {
68        let builder = SessionBuilder::new(user, host);
69        Self {
70            session: builder,
71            auth: None,
72            auth_http_insecure: false,
73            max_attempt: 3,
74            ssl: None,
75            no_verify: false,
76            #[cfg(feature = "spooling")]
77            segment_fetcher: None,
78            #[cfg(feature = "spooling")]
79            max_concurrent_segments: None,
80        }
81    }
82
83    pub fn port(mut self, s: u16) -> Self {
84        self.session.port = s;
85        self
86    }
87
88    pub fn secure(mut self, s: bool) -> Self {
89        self.session.secure = s;
90        self
91    }
92
93    pub fn no_verify(mut self, nv: bool) -> Self {
94        self.no_verify = nv;
95        self
96    }
97
98    pub fn source(mut self, s: impl ToString) -> Self {
99        self.session.source = s.to_string();
100        self
101    }
102
103    pub fn trace_token(mut self, s: impl ToString) -> Self {
104        self.session.trace_token = Some(s.to_string());
105        self
106    }
107
108    pub fn client_tags(mut self, s: HashSet<String>) -> Self {
109        self.session.client_tags = s;
110        self
111    }
112
113    pub fn client_tag(mut self, s: impl ToString) -> Self {
114        self.session.client_tags.insert(s.to_string());
115        self
116    }
117
118    pub fn client_info(mut self, s: impl ToString) -> Self {
119        self.session.client_info = Some(s.to_string());
120        self
121    }
122
123    pub fn catalog(mut self, s: impl ToString) -> Self {
124        self.session.catalog = Some(s.to_string());
125        self
126    }
127
128    pub fn schema(mut self, s: impl ToString) -> Self {
129        self.session.schema = Some(s.to_string());
130        self
131    }
132
133    pub fn path(mut self, s: impl ToString) -> Self {
134        self.session.path = Some(s.to_string());
135        self
136    }
137
138    pub fn resource_estimates(mut self, s: HashMap<String, String>) -> Self {
139        self.session.resource_estimates = s;
140        self
141    }
142
143    pub fn resource_estimate(mut self, k: impl ToString, v: impl ToString) -> Self {
144        self.session
145            .resource_estimates
146            .insert(k.to_string(), v.to_string());
147        self
148    }
149
150    pub fn properties(mut self, s: HashMap<String, String>) -> Self {
151        self.session.properties = s;
152        self
153    }
154
155    pub fn property(mut self, k: impl ToString, v: impl ToString) -> Self {
156        self.session.properties.insert(k.to_string(), v.to_string());
157        self
158    }
159
160    pub fn prepared_statements(mut self, s: HashMap<String, String>) -> Self {
161        self.session.prepared_statements = s;
162        self
163    }
164
165    pub fn prepared_statement(mut self, k: impl ToString, v: impl ToString) -> Self {
166        self.session
167            .prepared_statements
168            .insert(k.to_string(), v.to_string());
169        self
170    }
171
172    pub fn extra_credentials(mut self, s: HashMap<String, String>) -> Self {
173        self.session.extra_credentials = s;
174        self
175    }
176
177    pub fn extra_credential(mut self, k: impl ToString, v: impl ToString) -> Self {
178        self.session
179            .extra_credentials
180            .insert(k.to_string(), v.to_string());
181        self
182    }
183
184    pub fn transaction_id(mut self, s: TransactionId) -> Self {
185        self.session.transaction_id = s;
186        self
187    }
188
189    pub fn client_request_timeout(mut self, s: Duration) -> Self {
190        self.session.client_request_timeout = s;
191        self
192    }
193
194    pub fn compression_disabled(mut self, s: bool) -> Self {
195        self.session.compression_disabled = s;
196        self
197    }
198
199    #[cfg(feature = "spooling")]
200    pub fn segment_fetcher(mut self, segment_fetcher: SegmentFetcher) -> Self {
201        self.segment_fetcher = Some(segment_fetcher);
202        self
203    }
204
205    #[cfg(feature = "spooling")]
206    /// Set the maximum number of concurrent segment fetches
207    /// Default is based on available CPU parallelism (minimum 1)
208    pub fn max_concurrent_segments(mut self, count: usize) -> Self {
209        self.max_concurrent_segments = Some(count);
210        self
211    }
212
213    #[cfg(feature = "spooling")]
214    /// Set the spooling encoding format. Supported values: "json", "json+zstd", "json+lz4".
215    /// Defaults to "json+zstd" if not specified.
216    pub fn spooling_encoding(mut self, encoding: impl ToString) -> Self {
217        let encoding_str = encoding.to_string();
218
219        match SpoolingEncoding::try_from(encoding_str.as_str()) {
220            Ok(_) => {
221                self.session.spooling_encoding = Some(encoding_str);
222            }
223            Err(_) => {
224                log::warn!(
225                    "Invalid spooling encoding '{}', using default 'json+zstd'. Valid values: json, json+zstd, json+lz4",
226                    encoding_str
227                );
228                self.session.spooling_encoding = Some("json+zstd".to_string());
229            }
230        }
231
232        self
233    }
234
235    ////////////////////////////////////////////////////////////////////////////////////////////////
236
237    pub fn auth(mut self, s: Auth) -> Self {
238        self.auth = Some(s);
239        self
240    }
241
242    pub fn auth_http_insecure(mut self, ahi: bool) -> Self {
243        self.auth_http_insecure = ahi;
244        self
245    }
246
247    pub fn max_attempt(mut self, s: usize) -> Self {
248        self.max_attempt = s;
249        self
250    }
251
252    pub fn ssl(mut self, ssl: Ssl) -> Self {
253        self.ssl = Some(ssl);
254        self
255    }
256
257    pub fn build(self) -> Result<Client> {
258        let session = self.session.build()?;
259        let max_attempt = self.max_attempt;
260
261        if (self.auth.is_some() && session.url.scheme() == "http") && !self.auth_http_insecure {
262            return Err(Error::BasicAuthWithHttp);
263        }
264
265        let mut client_builder =
266            reqwest::ClientBuilder::new().timeout(session.client_request_timeout);
267
268        if self.no_verify {
269            client_builder = client_builder.danger_accept_invalid_certs(true);
270        }
271
272        if let Some(ssl) = &self.ssl {
273            if let Some(root) = &ssl.root_cert {
274                client_builder = client_builder.add_root_certificate(root.0.clone());
275            }
276        }
277
278        let client = client_builder.build()?;
279
280        #[cfg(feature = "spooling")]
281        let segment_fetcher = self.segment_fetcher.unwrap_or_else(|| {
282            let mut fetcher = SegmentFetcher::new(client.clone());
283            if let Some(max_concurrent) = self.max_concurrent_segments {
284                fetcher = fetcher.with_max_concurrent(max_concurrent);
285            }
286            fetcher
287        });
288
289        let cli = Client {
290            auth: self.auth,
291            url: session.url.clone(),
292            session: RwLock::new(session),
293            client,
294            max_attempt,
295            #[cfg(feature = "spooling")]
296            segment_fetcher,
297        };
298
299        Ok(cli)
300    }
301}
302
303fn add_prepare_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
304    //FIXME : set trino user from jwt ?
305    builder = builder.header(HEADER_USER, &session.user);
306    // TODO: difference with session.source?
307    builder = builder.header(USER_AGENT, "trino-rust-client");
308    if session.compression_disabled {
309        builder = builder.header(ACCEPT_ENCODING, "identity")
310    }
311    builder
312}
313
314fn add_session_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
315    builder = add_prepare_header(builder, session);
316    builder = builder.header(HEADER_SOURCE, &session.source);
317
318    if let Some(v) = &session.trace_token {
319        builder = builder.header(HEADER_TRACE_TOKEN, v);
320    }
321
322    if !session.client_tags.is_empty() {
323        builder = builder.header(HEADER_CLIENT_TAGS, session.client_tags.by_ref().join(","));
324    }
325
326    if let Some(v) = &session.client_info {
327        builder = builder.header(HEADER_CLIENT_INFO, v);
328    }
329
330    if let Some(v) = &session.catalog {
331        builder = builder.header(HEADER_CATALOG, v);
332    }
333
334    if let Some(v) = &session.schema {
335        builder = builder.header(HEADER_SCHEMA, v);
336    }
337
338    if let Some(v) = &session.path {
339        builder = builder.header(HEADER_PATH, v);
340    }
341    if let Some(v) = &session.timezone {
342        builder = builder.header(HEADER_TIME_ZONE, v.to_string())
343    }
344    // TODO: add locale
345    builder = add_header_map(builder, HEADER_SESSION, &session.properties);
346    builder = add_header_map(
347        builder,
348        HEADER_RESOURCE_ESTIMATE,
349        &session.resource_estimates,
350    );
351    builder = add_header_map(
352        builder,
353        HEADER_ROLE,
354        &session
355            .roles
356            .by_ref()
357            .map_kv(|(k, v)| (k.to_string(), v.to_string())),
358    );
359    builder = add_header_map(builder, HEADER_EXTRA_CREDENTIAL, &session.extra_credentials);
360    builder = add_header_map(
361        builder,
362        HEADER_PREPARED_STATEMENT,
363        &session.prepared_statements,
364    );
365    builder = builder.header(HEADER_TRANSACTION, session.transaction_id.to_str());
366    builder = builder.header(HEADER_CLIENT_CAPABILITIES, "PATH,PARAMETRIC_DATETIME");
367
368    // Add spooling header when feature is enabled
369    #[cfg(feature = "spooling")]
370    {
371        if let Some(encoding) = &session.spooling_encoding {
372            builder = builder.header(HEADER_SPOOLING, encoding);
373        }
374    }
375
376    builder
377}
378
379fn add_header_map<'a>(
380    mut builder: RequestBuilder,
381    header: &str,
382    map: impl IntoIterator<Item = (&'a String, &'a String)>,
383) -> RequestBuilder {
384    for (k, v) in map {
385        let kv = encode_kv(k, v);
386        builder = builder.header(header, kv);
387    }
388    builder
389}
390
/// Overwrites `$session` from the response header `$header`, if that header is present.
///
/// - `$session`: the place to assign to.
/// - `$header`: header name to look up on `$resp.headers()`.
/// - `$from_str` (optional): parser called with the header text; it returns `Some(value)`
///   to assign `value`, or `None` to leave `$session` untouched.
///
/// A header value that is not valid text is reported with `warn!` and otherwise ignored.
macro_rules! set_header {
    // Three-argument form: the target is an `Option<String>`, so the default parser always
    // succeeds (outer `Some`) and yields `Some(text)` as the value to store.
    ($session:expr, $header:expr, $resp:expr) => {
        set_header!($session, $header, $resp, |x: &str| Some(Some(
            x.to_string()
        )));
    };

    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        if let Some(v) = $resp.headers().get($header) {
            match v.to_str() {
                Ok(s) => {
                    // Only a successful parse overwrites the stored value.
                    if let Some(s) = $from_str(s) {
                        $session = s;
                    }
                }
                Err(e) => warn!("parse header {} failed, reason: {}", $header, e),
            }
        }
    };
}
411
/// Resets `$session` to its `Default` value when the response carries `$header`.
///
/// Only the presence of the header matters; its value is not read.
macro_rules! clear_header {
    ($session:expr, $header:expr, $resp:expr) => {
        if $resp.headers().get($header).is_some() {
            $session = Default::default();
        }
    };
}
419
/// Inserts into the map `$session` one entry per occurrence of `$header` in the response.
///
/// Each header value is decoded with `decode_kv_from_header` into a key and a value; the
/// value is then passed to `$from_str` (by default: copied into a `String`). Entries whose
/// value `$from_str` rejects (`None`) are skipped silently; header values that cannot be
/// decoded are reported with `warn!`.
macro_rules! set_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header_map!($session, $header, $resp, |x: &str| Some(x.to_string()));
    };
    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        for v in $resp.headers().get_all($header) {
            // Inside this branch `v` is shadowed by the decoded value; the `else` branch
            // still sees the raw header value.
            if let Some((k, v)) = decode_kv_from_header(v) {
                if let Some(v) = $from_str(&v) {
                    $session.insert(k, v);
                }
            } else {
                warn!("decode '{:?}' failed", v)
            }
        }
    };
}
436
/// Removes from the map `$session` every key named by an occurrence of `$header`.
///
/// The header text is used as the key as is (no key/value decoding). A header value that
/// is not valid text is reported with `warn!` and skipped.
macro_rules! clear_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        for v in $resp.headers().get_all($header) {
            match v.to_str() {
                Ok(s) => {
                    $session.remove(s);
                }
                Err(e) => warn!("parse header {} failed, reason: {}", $header, e),
            }
        }
    };
}
449
450fn need_retry(e: &Error) -> bool {
451    match e {
452        Error::HttpError(e) => e.status() == Some(StatusCode::SERVICE_UNAVAILABLE),
453        Error::HttpNotOk(code, _) => code == &StatusCode::SERVICE_UNAVAILABLE,
454        _ => false,
455    }
456}
457
458impl Client {
459    pub async fn get_all<T>(&self, sql: String) -> Result<DataSet<T>>
460    where
461        T: Trino + 'static,
462        for<'de> T: serde::Deserialize<'de> + serde::Serialize,
463    {
464        let res = self.get_retry(sql).await?;
465
466        // Store columns from responses (used for Direct protocol DataSet construction)
467        let mut columns = res.columns;
468
469        match res.data {
470            Some(QueryResultData::Direct(rows)) => {
471                // Direct protocol: accumulate Vec<T>, convert to DataSet at the end
472                let mut all_rows = rows;
473
474                let mut next = res.next_uri;
475                while let Some(url) = &next {
476                    let mut res = self.get_next_retry(url).await?;
477                    next = res.next_uri;
478
479                    // Collect columns from any response that has them
480                    if columns.is_none() {
481                        columns = res.columns.take();
482                    }
483
484                    if let Some(error) = res.error {
485                        if error.error_code == 4 {
486                            return Err(Error::Forbidden {
487                                message: error.message,
488                            });
489                        } else {
490                            return Err(Error::InternalError(format!(
491                                "Query failed with {} (error code {}): {}",
492                                error.error_name, error.error_code, error.message
493                            )));
494                        }
495                    }
496
497                    if let Some(data) = res.data {
498                        match data {
499                            QueryResultData::Direct(rows) => {
500                                all_rows.extend(rows);
501                            }
502                            #[cfg(feature = "spooling")]
503                            QueryResultData::Spooled(_) => {
504                                return Err(Error::InternalError(
505                                    "Cannot mix Direct and Spooled protocols in same query".to_string(),
506                                ));
507                            }
508                            #[cfg(not(feature = "spooling"))]
509                            QueryResultData::Spooled(_) => {
510                                return Err(Error::InternalError(
511                                    "Server sent spooled data but 'spooling' feature is not enabled. \
512                                     Add features = [\"spooling\"] to your trino-rust-client dependency in Cargo.toml.".to_string(),
513                                ));
514                            }
515                        }
516                    }
517                }
518
519                build_dataset(all_rows, columns)
520            }
521            #[cfg(feature = "spooling")]
522            Some(QueryResultData::Spooled(spooled)) => {
523                let mut dataset = self
524                    .fetch_spooled_data::<T>(spooled, columns.clone())
525                    .await?;
526
527                let mut next = res.next_uri;
528                while let Some(url) = &next {
529                    let mut res = self.get_next_retry::<T>(url).await?;
530                    next = res.next_uri;
531
532                    if columns.is_none() {
533                        columns = res.columns.take();
534                    }
535
536                    if let Some(error) = res.error {
537                        if error.error_code == 4 {
538                            return Err(Error::Forbidden {
539                                message: error.message,
540                            });
541                        } else {
542                            return Err(Error::InternalError(format!(
543                                "Query failed with {} (error code {}): {}",
544                                error.error_name, error.error_code, error.message
545                            )));
546                        }
547                    }
548
549                    if let Some(data) = res.data {
550                        match data {
551                            QueryResultData::Direct(_) => {
552                                return Err(Error::InternalError(
553                                    "Cannot mix Direct and Spooled protocols in same query".to_string(),
554                                ));
555                            }
556                            QueryResultData::Spooled(spooled) => {
557                                log::info!("🗄️  Received SPOOLED protocol data - fetching from S3/MinIO");
558                                let cols_for_spooled = columns.clone().or_else(|| res.columns.take());
559                                let next_dataset = self
560                                    .fetch_spooled_data::<T>(spooled, cols_for_spooled)
561                                    .await?;
562                                dataset.merge(next_dataset);
563                            }
564                        }
565                    }
566                }
567
568                Ok(dataset)
569            }
570            #[cfg(not(feature = "spooling"))]
571            Some(QueryResultData::Spooled(_)) => {
572                Err(Error::InternalError(
573                    "Server sent spooled data but 'spooling' feature is not enabled. \
574                     Add features = [\"spooling\"] to your trino-rust-client dependency in Cargo.toml.".to_string(),
575                ))
576            }
577            None => {
578                // No initial data, wait for next response to detect protocol
579                let mut next = res.next_uri;
580                let mut protocol_detected = false;
581                let mut all_rows: Vec<T> = Vec::new();
582                #[cfg(feature = "spooling")]
583                let mut dataset: Option<DataSet<T>> = None;
584
585                while let Some(url) = &next {
586                    let mut res = self.get_next_retry::<T>(url).await?;
587                    next = res.next_uri;
588
589                    if columns.is_none() {
590                        columns = res.columns.take();
591                    }
592
593                    if let Some(error) = res.error {
594                        if error.error_code == 4 {
595                            return Err(Error::Forbidden {
596                                message: error.message,
597                            });
598                        } else {
599                            return Err(Error::InternalError(format!(
600                                "Query failed with {} (error code {}): {}",
601                                error.error_name, error.error_code, error.message
602                            )));
603                        }
604                    }
605
606                    if let Some(data) = res.data {
607                        match data {
608                            QueryResultData::Direct(rows) => {
609                                if !protocol_detected {
610                                    protocol_detected = true;
611                                }
612                                all_rows.extend(rows);
613                            }
614                            #[cfg(feature = "spooling")]
615                            QueryResultData::Spooled(spooled) => {
616                                if !protocol_detected {
617                                    protocol_detected = true;
618                                    let cols_for_spooled = columns.clone().or_else(|| res.columns.take());
619                                    dataset = Some(self.fetch_spooled_data::<T>(spooled, cols_for_spooled).await?);
620                                } else {
621                                    let cols_for_spooled = columns.clone().or_else(|| res.columns.take());
622                                    let next_dataset = self.fetch_spooled_data::<T>(spooled, cols_for_spooled).await?;
623                                    if let Some(ref mut ds) = dataset {
624                                        ds.merge(next_dataset);
625                                    }
626                                }
627                            }
628                            #[cfg(not(feature = "spooling"))]
629                            QueryResultData::Spooled(_) => {
630                                return Err(Error::InternalError(
631                                    "Server sent spooled data but 'spooling' feature is not enabled. \
632                                     Add features = [\"spooling\"] to your trino-rust-client dependency in Cargo.toml.".to_string(),
633                                ));
634                            }
635                        }
636                    }
637                }
638
639                #[cfg(feature = "spooling")]
640                if let Some(ds) = dataset {
641                    Ok(ds)
642                } else if !all_rows.is_empty() {
643                    build_dataset(all_rows, columns)
644                } else {
645                    Err(Error::EmptyData)
646                }
647                #[cfg(not(feature = "spooling"))]
648                if !all_rows.is_empty() {
649                    build_dataset(all_rows, columns)
650                } else {
651                    Err(Error::EmptyData)
652                }
653            }
654        }
655    }
656
657    #[cfg(feature = "spooling")]
658    async fn fetch_spooled_data<T: Trino + 'static>(
659        &self,
660        spooled: SpooledData,
661        columns: Option<Vec<crate::models::Column>>,
662    ) -> Result<DataSet<T>> {
663        let segment_bytes = self
664            .segment_fetcher
665            .fetch_segments(spooled.segments)
666            .await?;
667
668        let dataset = self.decode_segments::<T>(&spooled.encoding, segment_bytes, columns)?;
669
670        Ok(dataset)
671    }
672
673    #[cfg(feature = "spooling")]
674    fn decode_segments<T: Trino + 'static>(
675        &self,
676        encoding: &str,
677        segment_bytes: Vec<Vec<u8>>,
678        columns: Option<Vec<crate::models::Column>>,
679    ) -> Result<DataSet<T>> {
680        let cols = columns.ok_or_else(|| {
681            Error::InternalError("Column metadata required for spooling protocol".to_string())
682        })?;
683
684        let mut all_rows: Vec<Vec<serde_json::Value>> = Vec::new();
685
686        let encoding = SpoolingEncoding::try_from(encoding).map_err(|e| {
687            Error::InternalError(format!(
688                "Failed to parse encoding: {}. Only 'json' based formats are supported.",
689                e
690            ))
691        })?;
692
693        for bytes in segment_bytes {
694            let json_str = decompress_segment_bytes(&bytes, &encoding)?;
695
696            let mut rows: Vec<Vec<serde_json::Value>> =
697                serde_json::from_str(&json_str).map_err(|e| {
698                    Error::InternalError(format!("Failed to parse segment JSON: {}", e))
699                })?;
700
701            all_rows.append(&mut rows);
702        }
703
704        let json_obj = serde_json::json!({
705            "columns": cols,
706            "data": all_rows
707        });
708
709        let dataset: DataSet<T> = serde_json::from_value(json_obj)
710            .map_err(|e| Error::InternalError(format!("Failed to deserialize DataSet: {}", e)))?;
711
712        Ok(dataset)
713    }
714
715    /**
716     * Execute a SQL statement and return the result.
717     * If the TRINO query returns an error, the method returns an error of type `Error::TrinoError`
718     * @param sql The SQL statement to execute
719     * @return Result<ExecuteResult> The result of the execution
720     * */
721    pub async fn execute(&self, sql: String) -> Result<ExecuteResult> {
722        // try the sql first
723        let res = self.get_retry::<Row>(sql).await?;
724
725        let mut next = res.next_uri;
726        let mut final_uri = next.clone();
727
728        // Trino attempts several times to execute a query before marking it as failed.
729        // At the end, retrieve the URL of the last request to get the result
730        while let Some(url) = &next {
731            let res = self.get_next_retry::<Row>(url).await?;
732
733            let next_uri = res.next_uri;
734
735            // If next_uri is not None, update final_uri
736            if next_uri.is_some() {
737                final_uri = next_uri.clone();
738            }
739            next = next_uri;
740        }
741
742        let url = final_uri.ok_or_else(|| {
743            Error::InternalError("No next URI available for execution result".to_string())
744        })?;
745
746        // Parse the final URI to get TrinoRetryResult
747        let result = self.try_get_retry_result(&url).await?;
748
749        if let Some(error) = result.error {
750            return Err(error.into());
751        }
752
753        Ok(ExecuteResult {
754            output_uri: None,
755            update_type: result.update_type,
756            update_count: result.update_count,
757        })
758    }
759
760    async fn try_get_retry_result(&self, url: &str) -> Result<TrinoRetryResult> {
761        let response = self.client.get(url).send().await?;
762
763        let result = response.json::<TrinoRetryResult>().await?;
764
765        Ok(result)
766    }
767
768    fn retry_policy(&self) -> ExponentialBuilder {
769        ExponentialBuilder::default()
770            .with_max_times(self.max_attempt)
771            .with_max_delay(Duration::from_secs(2))
772    }
773
774    async fn get_retry<T>(&self, sql: String) -> Result<QueryResult<T>>
775    where
776        T: Trino + 'static,
777        for<'de> T: serde::Deserialize<'de>,
778    {
779        let result = || async { self.get::<T>(sql.clone()).await };
780
781        result.retry(self.retry_policy()).when(need_retry).await
782    }
783
784    async fn get_next_retry<T>(&self, url: &str) -> Result<QueryResult<T>>
785    where
786        T: Trino + 'static,
787        for<'de> T: serde::Deserialize<'de>,
788    {
789        let result = || async { self.get_next(url).await };
790
791        result.retry(self.retry_policy()).when(need_retry).await
792    }
793
794    pub async fn get<T>(&self, sql: String) -> Result<QueryResult<T>>
795    where
796        T: Trino + 'static,
797        for<'de> T: serde::Deserialize<'de>,
798    {
799        let req = self
800            .client
801            .post(format!("{}v1/statement", self.url))
802            .body(sql);
803        let req = {
804            let session = self.session.read().await;
805            add_session_header(req, &session)
806        };
807
808        let req = self.auth_req(req);
809        self.send(req, StatusCode::OK, |resp| async {
810            let text = resp.text().await?;
811
812            let data: QueryResult<T> = serde_json::from_str(&text)
813                .map_err(|e| Error::InternalError(format!("Failed to parse response: {}", e)))?;
814            Ok(data)
815        })
816        .await
817    }
818
819    pub async fn get_next<T>(&self, url: &str) -> Result<QueryResult<T>>
820    where
821        T: Trino + 'static,
822        for<'de> T: serde::Deserialize<'de>,
823    {
824        let req = self.client.get(url);
825        let req = {
826            let session = self.session.read().await;
827            add_prepare_header(req, &session)
828        };
829
830        let req = self.auth_req(req);
831        self.send(req, StatusCode::OK, |resp| async {
832            let text = resp.text().await?;
833            let data: QueryResult<T> = serde_json::from_str(&text)
834                .map_err(|e| Error::InternalError(format!("Failed to parse response: {}", e)))?;
835            Ok(data)
836        })
837        .await
838    }
839
840    pub async fn cancel(&self, query_id: &str) -> Result<()> {
841        let url = format!("{}v1/query/{}", self.url, query_id);
842        let req = self.client.delete(url);
843        let req = {
844            let session = self.session.read().await;
845            add_prepare_header(req, &session)
846        };
847
848        let req = self.auth_req(req);
849        self.send(req, StatusCode::NO_CONTENT, |_| async { Ok(()) })
850            .await
851    }
852
853    fn auth_req(&self, req: RequestBuilder) -> RequestBuilder {
854        if let Some(auth) = self.auth.as_ref() {
855            match auth {
856                Auth::Basic(u, p) => req.basic_auth(u, p.as_ref()),
857                Auth::Jwt(t) => req.bearer_auth(t),
858            }
859        } else {
860            req
861        }
862    }
863
864    async fn send<R, F, Fut>(
865        &self,
866        req: RequestBuilder,
867        expected_status: StatusCode,
868        handle_response: F,
869    ) -> Result<R>
870    where
871        F: FnOnce(Response) -> Fut,
872        Fut: std::future::Future<Output = Result<R>>,
873    {
874        let resp = req.send().await?;
875        let status = resp.status();
876        if status != expected_status {
877            let data = resp.text().await.unwrap_or("".to_string());
878            Err(Error::HttpNotOk(status, data))
879        } else {
880            self.update_session(&resp).await;
881            handle_response(resp).await
882        }
883    }
884
    /// Applies the session-changing headers of a response to the stored session.
    ///
    /// Holds the write lock on `self.session` for the whole update. Headers that are absent
    /// leave the corresponding setting untouched.
    async fn update_session(&self, resp: &Response) {
        let mut session = self.session.write().await;

        // Single-valued settings: overwritten when the header is present and readable.
        set_header!(session.catalog, HEADER_SET_CATALOG, resp);
        set_header!(session.schema, HEADER_SET_SCHEMA, resp);
        set_header!(session.path, HEADER_SET_PATH, resp);

        // Session properties: additions are applied before removals.
        set_header_map!(session.properties, HEADER_SET_SESSION, resp);
        clear_header_map!(session.properties, HEADER_CLEAR_SESSION, resp);

        // Role values are parsed with `SelectedRole::from_str`; unparsable ones are skipped.
        set_header_map!(session.roles, HEADER_SET_ROLE, resp, SelectedRole::from_str);

        // Prepared statements: additions are applied before deallocations.
        set_header_map!(session.prepared_statements, HEADER_ADDED_PREPARE, resp);
        clear_header_map!(
            session.prepared_statements,
            HEADER_DEALLOCATED_PREPARE,
            resp
        );

        // Transaction id: a started transaction is recorded first; a clear header then
        // resets the id to its default, so the clear wins when both are present.
        set_header!(
            session.transaction_id,
            HEADER_STARTED_TRANSACTION_ID,
            resp,
            TransactionId::from_str
        );
        clear_header!(session.transaction_id, HEADER_CLEAR_TRANSACTION_ID, resp);
    }
912}
913
914////////////////////////////////////////////////////////////////////////////////////////////////
915// helper functions
916
917fn encode_kv(k: &str, v: &str) -> String {
918    url::form_urlencoded::Serializer::new(String::new())
919        .append_pair(k, v)
920        .finish()
921}
922
923fn decode_kv_from_header(input: &HeaderValue) -> Option<(String, String)> {
924    let kvs = url::form_urlencoded::parse(input.as_bytes()).collect::<Vec<_>>();
925    if kvs.is_empty() {
926        None
927    } else {
928        Some((kvs[0].0.to_string(), kvs[0].1.to_string()))
929    }
930}
931
#[cfg(test)]
mod tests {
    use reqwest::header::HeaderValue;

    use crate::client::decode_kv_from_header;

    /// Decodes a static header value, returning `None` when decoding fails.
    fn decode(raw: &'static str) -> Option<(String, String)> {
        decode_kv_from_header(&HeaderValue::from_static(raw))
    }

    /// The pair every case below is expected to decode to.
    fn show_tables() -> Option<(String, String)> {
        Some(("statement".to_string(), "show tables".to_string()))
    }

    // A `+` in the encoded value stands for a space.
    #[test]
    fn test_decode_kv_from_header_plus_sign_to_space() {
        assert_eq!(decode("statement=show+tables"), show_tables());
    }

    // `%20` is the percent-encoded form of a space.
    #[test]
    fn test_decode_kv_from_header_percent_encoding() {
        assert_eq!(decode("statement=show%20tables"), show_tables());
    }
}
957}