prusto_rs/
client.rs

1use std::collections::{HashMap, HashSet};
2
3use http::header::{ACCEPT_ENCODING, USER_AGENT};
4use http::StatusCode;
5use iterable::*;
6use log::*;
7use reqwest::header::HeaderValue;
8use reqwest::{RequestBuilder, Response, Url};
9use tokio::sync::RwLock;
10use tokio::time::{sleep, Duration};
11
12use crate::auth::Auth;
13use crate::error::{Error, Result};
14#[cfg(not(feature = "presto"))]
15use crate::header::*;
16#[cfg(feature = "presto")]
17use crate::presto_header::*;
18use crate::selected_role::SelectedRole;
19use crate::session::{Session, SessionBuilder};
20use crate::ssl::Ssl;
21use crate::transaction::TransactionId;
22use crate::{DataSet, Presto, QueryResult, Row};
23
24// TODO:
25// allow_redirects
26// proxies
27// cancel
28
29pub struct Client {
30    client: reqwest::Client,
31    session: RwLock<Session>,
32    auth: Option<Auth>,
33    max_attempt: usize,
34    url: Url,
35}
36
37pub struct ClientBuilder {
38    session: SessionBuilder,
39    auth: Option<Auth>,
40    max_attempt: usize,
41    ssl: Option<Ssl>,
42}
43
/// Outcome of [`Client::execute`]: the statement was driven through all of its pages.
///
/// It carries no data. The private field prevents construction outside this module.
#[derive(Debug)]
pub struct ExecuteResult {
    _m: (),
}
48
49impl ClientBuilder {
50    pub fn new(user: impl ToString, host: impl ToString) -> Self {
51        let builder = SessionBuilder::new(user, host);
52        Self {
53            session: builder,
54            auth: None,
55            max_attempt: 3,
56            ssl: None,
57        }
58    }
59
60    pub fn port(mut self, s: u16) -> Self {
61        self.session.port = s;
62        self
63    }
64
65    pub fn secure(mut self, s: bool) -> Self {
66        self.session.secure = s;
67        self
68    }
69
70    pub fn source(mut self, s: impl ToString) -> Self {
71        self.session.source = s.to_string();
72        self
73    }
74
75    pub fn trace_token(mut self, s: impl ToString) -> Self {
76        self.session.trace_token = Some(s.to_string());
77        self
78    }
79
80    pub fn client_tags(mut self, s: HashSet<String>) -> Self {
81        self.session.client_tags = s;
82        self
83    }
84
85    pub fn client_tag(mut self, s: impl ToString) -> Self {
86        self.session.client_tags.insert(s.to_string());
87        self
88    }
89
90    pub fn client_info(mut self, s: impl ToString) -> Self {
91        self.session.client_info = Some(s.to_string());
92        self
93    }
94
95    pub fn catalog(mut self, s: impl ToString) -> Self {
96        self.session.catalog = Some(s.to_string());
97        self
98    }
99
100    pub fn schema(mut self, s: impl ToString) -> Self {
101        self.session.schema = Some(s.to_string());
102        self
103    }
104
105    pub fn path(mut self, s: impl ToString) -> Self {
106        self.session.path = Some(s.to_string());
107        self
108    }
109
110    pub fn resource_estimates(mut self, s: HashMap<String, String>) -> Self {
111        self.session.resource_estimates = s;
112        self
113    }
114
115    pub fn resource_estimate(mut self, k: impl ToString, v: impl ToString) -> Self {
116        self.session
117            .resource_estimates
118            .insert(k.to_string(), v.to_string());
119        self
120    }
121
122    pub fn properties(mut self, s: HashMap<String, String>) -> Self {
123        self.session.properties = s;
124        self
125    }
126
127    pub fn property(mut self, k: impl ToString, v: impl ToString) -> Self {
128        self.session.properties.insert(k.to_string(), v.to_string());
129        self
130    }
131
132    pub fn prepared_statements(mut self, s: HashMap<String, String>) -> Self {
133        self.session.prepared_statements = s;
134        self
135    }
136
137    pub fn prepared_statement(mut self, k: impl ToString, v: impl ToString) -> Self {
138        self.session
139            .prepared_statements
140            .insert(k.to_string(), v.to_string());
141        self
142    }
143
144    pub fn extra_credentials(mut self, s: HashMap<String, String>) -> Self {
145        self.session.extra_credentials = s;
146        self
147    }
148
149    pub fn extra_credential(mut self, k: impl ToString, v: impl ToString) -> Self {
150        self.session
151            .extra_credentials
152            .insert(k.to_string(), v.to_string());
153        self
154    }
155
156    pub fn transaction_id(mut self, s: TransactionId) -> Self {
157        self.session.transaction_id = s;
158        self
159    }
160
161    pub fn client_request_timeout(mut self, s: Duration) -> Self {
162        self.session.client_request_timeout = s;
163        self
164    }
165
166    pub fn compression_disabled(mut self, s: bool) -> Self {
167        self.session.compression_disabled = s;
168        self
169    }
170
171    ////////////////////////////////////////////////////////////////////////////////////////////////
172
173    pub fn auth(mut self, s: Auth) -> Self {
174        self.auth = Some(s);
175        self
176    }
177
178    pub fn max_attempt(mut self, s: usize) -> Self {
179        self.max_attempt = s;
180        self
181    }
182
183    pub fn ssl(mut self, ssl: Ssl) -> Self {
184        self.ssl = Some(ssl);
185        self
186    }
187
188    pub fn build(self) -> Result<Client> {
189        let session = self.session.build()?;
190        let max_attempt = self.max_attempt;
191
192        if self.auth.is_some() && session.url.scheme() == "http" {
193            return Err(Error::BasicAuthWithHttp);
194        }
195
196        let mut client_builder =
197            reqwest::ClientBuilder::new().timeout(session.client_request_timeout);
198
199        if let Some(ssl) = &self.ssl {
200            if let Some(root) = &ssl.root_cert {
201                client_builder = client_builder.add_root_certificate(root.0.clone());
202            }
203        }
204
205        let cli = Client {
206            auth: self.auth,
207            url: session.url.clone(),
208            session: RwLock::new(session),
209            client: client_builder.build()?,
210            max_attempt,
211        };
212
213        Ok(cli)
214    }
215}
216
217fn add_prepare_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
218    builder = builder.header(HEADER_USER, &session.user);
219    // TODO: difference with session.source?
220    builder = builder.header(USER_AGENT, "trino-rust-client");
221    if session.compression_disabled {
222        builder = builder.header(ACCEPT_ENCODING, "identity")
223    }
224    builder
225}
226
227fn add_session_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
228    builder = add_prepare_header(builder, session);
229    builder = builder.header(HEADER_SOURCE, &session.source);
230
231    if let Some(v) = &session.trace_token {
232        builder = builder.header(HEADER_TRACE_TOKEN, v);
233    }
234
235    if !session.client_tags.is_empty() {
236        builder = builder.header(HEADER_CLIENT_TAGS, session.client_tags.by_ref().join(","));
237    }
238
239    if let Some(v) = &session.client_info {
240        builder = builder.header(HEADER_CLIENT_INFO, v);
241    }
242
243    if let Some(v) = &session.catalog {
244        builder = builder.header(HEADER_CATALOG, v);
245    }
246
247    if let Some(v) = &session.schema {
248        builder = builder.header(HEADER_SCHEMA, v);
249    }
250
251    if let Some(v) = &session.path {
252        builder = builder.header(HEADER_PATH, v);
253    }
254    if let Some(v) = &session.timezone {
255        builder = builder.header(HEADER_TIME_ZONE, v.to_string())
256    }
257    // TODO: add locale
258    builder = add_header_map(builder, HEADER_SESSION, &session.properties);
259    builder = add_header_map(
260        builder,
261        HEADER_RESOURCE_ESTIMATE,
262        &session.resource_estimates,
263    );
264    builder = add_header_map(
265        builder,
266        HEADER_ROLE,
267        &session
268            .roles
269            .by_ref()
270            .map_kv(|(k, v)| (k.to_string(), v.to_string())),
271    );
272    builder = add_header_map(builder, HEADER_EXTRA_CREDENTIAL, &session.extra_credentials);
273    builder = add_header_map(
274        builder,
275        HEADER_PREPARED_STATEMENT,
276        &session.prepared_statements,
277    );
278    builder = builder.header(HEADER_TRANSACTION, session.transaction_id.to_str());
279    builder = builder.header(HEADER_CLIENT_CAPABILITIES, "PATH,PARAMETRIC_DATETIME");
280    builder
281}
282
283fn add_header_map<'a>(
284    mut builder: RequestBuilder,
285    header: &str,
286    map: impl IntoIterator<Item = (&'a String, &'a String)>,
287) -> RequestBuilder {
288    for (k, v) in map {
289        let kv = encode_kv(k, v);
290        builder = builder.header(header, kv);
291    }
292    builder
293}
294
/// Call `$self.$f($param.clone())` up to `$max_attempt` times and `return` the outcome from
/// the enclosing async fn.
///
/// * A successful response whose body carries `error` becomes `Error::QueryError`.
/// * Errors for which `need_retry` is true are retried after a 100 ms pause. The pause is
///   skipped after the final attempt, so an exhausted budget is reported without delay.
/// * Any other error is returned unchanged.
///
/// If every attempt failed with a retryable error (or `$max_attempt` is 0), the macro evaluates
/// to `Err(Error::ReachMaxAttempt(..))`. `$max_attempt` is evaluated exactly once.
macro_rules! retry {
    ($self:expr, $f:ident, $param:expr, $max_attempt:expr) => {{
        let max_attempt = $max_attempt;
        for attempt in 0..max_attempt {
            match $self.$f($param.clone()).await {
                Ok(d) => match d.error {
                    Some(e) => return Err(Error::QueryError(e)),
                    None => return Ok(d),
                },
                Err(e) if need_retry(&e) => {
                    // Only back off if another attempt will actually follow.
                    if attempt + 1 < max_attempt {
                        sleep(Duration::from_millis(100)).await;
                    }
                }
                Err(e) => return Err(e),
            }
        }

        Err(Error::ReachMaxAttempt(max_attempt))
    }};
}
315
/// Overwrite `$session` from response header `$header`, if present.
///
/// The 3-argument form stores the header text as `Some(String)`. The 4-argument form passes
/// the text to `$from_str`, whose `Some(x)` result is assigned and whose `None` is ignored.
/// A header that is not valid visible ASCII is logged and skipped.
macro_rules! set_header {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header!($session, $header, $resp, |raw: &str| Some(Some(raw.to_string())));
    };

    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        match $resp.headers().get($header).map(|value| value.to_str()) {
            None => {}
            Some(Err(err)) => warn!("parse header {} failed, reason: {}", $header, err),
            Some(Ok(text)) => {
                if let Some(parsed) = $from_str(text) {
                    $session = parsed;
                }
            }
        }
    };
}
336
/// Reset `$session` to its `Default` when response header `$header` is present.
/// The header's value is not inspected.
macro_rules! clear_header {
    ($session:expr, $header:expr, $resp:expr) => {
        if $resp.headers().get($header).is_some() {
            $session = Default::default();
        }
    };
}
344
/// Insert every `key=value` pair carried by repeated response header `$header` into the map
/// `$session`.
///
/// Pairs are split by `decode_kv_from_header`. The optional `$from_str` converts the decoded
/// value; pairs for which it returns `None` are skipped. Undecodable headers are logged.
macro_rules! set_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header_map!($session, $header, $resp, |raw: &str| Some(raw.to_string()));
    };
    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        for raw in $resp.headers().get_all($header) {
            match decode_kv_from_header(raw) {
                Some((key, value)) => {
                    if let Some(parsed) = $from_str(&value) {
                        $session.insert(key, parsed);
                    }
                }
                None => warn!("decode '{:?}' failed", raw),
            }
        }
    };
}
361
/// Remove from map `$session` every key named by repeated response header `$header`.
/// Header values that are not valid visible ASCII are logged and skipped.
macro_rules! clear_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        for raw in $resp.headers().get_all($header) {
            match raw.to_str() {
                Err(err) => warn!("parse header {} failed, reason: {}", $header, err),
                Ok(name) => {
                    $session.remove(name);
                }
            }
        }
    };
}
374
375fn need_retry(e: &Error) -> bool {
376    match e {
377        Error::HttpError(e) => e.status() == Some(StatusCode::SERVICE_UNAVAILABLE),
378        Error::HttpNotOk(code, _) => code == &StatusCode::SERVICE_UNAVAILABLE,
379        _ => false,
380    }
381}
382
383impl Client {
384    pub async fn get_all<T: Presto + 'static>(&self, sql: String) -> Result<DataSet<T>> {
385        let res = self.get_retry(sql).await?;
386        let mut ret = res.data_set;
387
388        let mut next = res.next_uri;
389        while let Some(url) = &next {
390            let res = self.get_next_retry(url).await?;
391            next = res.next_uri;
392            if let Some(d) = res.data_set {
393                match &mut ret {
394                    Some(ret) => ret.merge(d),
395                    None => ret = Some(d),
396                }
397            }
398        }
399
400        if let Some(d) = ret {
401            Ok(d)
402        } else {
403            Err(Error::EmptyData)
404        }
405    }
406
407    pub async fn execute(&self, sql: String) -> Result<ExecuteResult> {
408        let res = self.get_retry::<Row>(sql).await?;
409
410        let mut next = res.next_uri;
411        while let Some(url) = &next {
412            let res = self.get_next_retry::<Row>(url).await?;
413            next = res.next_uri;
414        }
415        Ok(ExecuteResult { _m: () })
416    }
417
418    async fn get_retry<T: Presto + 'static>(&self, sql: String) -> Result<QueryResult<T>> {
419        retry!(self, get, sql, self.max_attempt)
420    }
421
422    async fn get_next_retry<T: Presto + 'static>(&self, url: &str) -> Result<QueryResult<T>> {
423        retry!(self, get_next, url, self.max_attempt)
424    }
425
426    pub async fn get<T: Presto + 'static>(&self, sql: String) -> Result<QueryResult<T>> {
427        let req = self.client.post(self.url.clone()).body(sql);
428        let req = {
429            let session = self.session.read().await;
430            add_session_header(req, &session)
431        };
432
433        let req = if let Some(auth) = self.auth.as_ref() {
434            match auth {
435                Auth::Basic(u, p) => req.basic_auth(u, p.as_ref()),
436            }
437        } else {
438            req
439        };
440
441        self.send(req).await
442    }
443
444    pub async fn get_next<T: Presto + 'static>(&self, url: &str) -> Result<QueryResult<T>> {
445        let req = self.client.get(url);
446        let req = {
447            let session = self.session.read().await;
448            add_prepare_header(req, &session)
449        };
450
451        self.send(req).await
452    }
453
454    async fn send<T: Presto + 'static>(&self, req: RequestBuilder) -> Result<QueryResult<T>> {
455        let resp = req.send().await?;
456        let status = resp.status();
457        if status != StatusCode::OK {
458            let data = resp.text().await.unwrap_or("".to_string());
459            Err(Error::HttpNotOk(status, data))
460        } else {
461            self.update_session(&resp).await;
462            let data = resp.json::<QueryResult<T>>().await?;
463            Ok(data)
464        }
465    }
466
467    async fn update_session(&self, resp: &Response) {
468        let mut session = self.session.write().await;
469
470        set_header!(session.catalog, HEADER_SET_CATALOG, resp);
471        set_header!(session.schema, HEADER_SET_SCHEMA, resp);
472        set_header!(session.path, HEADER_SET_PATH, resp);
473
474        set_header_map!(session.properties, HEADER_SET_SESSION, resp);
475        clear_header_map!(session.properties, HEADER_CLEAR_SESSION, resp);
476
477        set_header_map!(session.roles, HEADER_SET_ROLE, resp, SelectedRole::from_str);
478
479        set_header_map!(session.prepared_statements, HEADER_ADDED_PREPARE, resp);
480        clear_header_map!(
481            session.prepared_statements,
482            HEADER_DEALLOCATED_PREPARE,
483            resp
484        );
485
486        set_header!(
487            session.transaction_id,
488            HEADER_STARTED_TRANSACTION_ID,
489            resp,
490            TransactionId::from_str
491        );
492        clear_header!(session.transaction_id, HEADER_CLEAR_TRANSACTION_ID, resp);
493    }
494}
495
496////////////////////////////////////////////////////////////////////////////////////////////////
497// helper functions
498
499fn encode_kv(k: &str, v: &str) -> String {
500    format!("{}={}", k, urlencoding::encode(v))
501}
502
503fn decode_kv_from_header(input: &HeaderValue) -> Option<(String, String)> {
504    let s = input.to_str().ok()?;
505    let kv = s.split('=').collect::<Vec<_>>();
506    if kv.len() != 2 {
507        return None;
508    }
509    let k = kv[0].to_string();
510    let v = urlencoding::decode(kv[1]).ok()?;
511    Some((k, v.to_string()))
512}