1use std::collections::{HashMap, HashSet};
2
3use http::header::{ACCEPT_ENCODING, USER_AGENT};
4use http::StatusCode;
5use iterable::*;
6use log::*;
7use reqwest::header::HeaderValue;
8use reqwest::{RequestBuilder, Response, Url};
9use tokio::sync::RwLock;
10use tokio::time::{sleep, Duration};
11
12use crate::auth::Auth;
13use crate::error::{Error, Result};
14#[cfg(not(feature = "presto"))]
15use crate::header::*;
16#[cfg(feature = "presto")]
17use crate::presto_header::*;
18use crate::selected_role::SelectedRole;
19use crate::session::{Session, SessionBuilder};
20use crate::ssl::Ssl;
21use crate::transaction::TransactionId;
22use crate::{DataSet, Presto, QueryResult, Row};
23
/// Asynchronous HTTP client that submits SQL statements to a Trino (or Presto)
/// server and follows result pages.
///
/// Created with [`ClientBuilder::build`]. The session state is sent as request
/// headers and is updated from the headers of each successful response.
pub struct Client {
    // Underlying HTTP client; `ClientBuilder::build` sets its timeout and optional root certificate.
    client: reqwest::Client,
    // Session settings read when building request headers and rewritten by `update_session`.
    session: RwLock<Session>,
    // Optional credentials; `ClientBuilder::build` rejects them when the URL scheme is `http`.
    auth: Option<Auth>,
    // Number of attempts per request for errors that `need_retry` accepts.
    max_attempt: usize,
    // Statement endpoint, copied from the session URL at build time; statements are POSTed here.
    url: Url,
}
36
37pub struct ClientBuilder {
38 session: SessionBuilder,
39 auth: Option<Auth>,
40 max_attempt: usize,
41 ssl: Option<Ssl>,
42}
43
/// Result of [`Client::execute`].
///
/// Carries no data: it only signals that every result page was consumed
/// without error. The private field prevents construction outside this module.
#[derive(Debug)]
pub struct ExecuteResult {
    _m: (),
}
48
49impl ClientBuilder {
50 pub fn new(user: impl ToString, host: impl ToString) -> Self {
51 let builder = SessionBuilder::new(user, host);
52 Self {
53 session: builder,
54 auth: None,
55 max_attempt: 3,
56 ssl: None,
57 }
58 }
59
60 pub fn port(mut self, s: u16) -> Self {
61 self.session.port = s;
62 self
63 }
64
65 pub fn secure(mut self, s: bool) -> Self {
66 self.session.secure = s;
67 self
68 }
69
70 pub fn source(mut self, s: impl ToString) -> Self {
71 self.session.source = s.to_string();
72 self
73 }
74
75 pub fn trace_token(mut self, s: impl ToString) -> Self {
76 self.session.trace_token = Some(s.to_string());
77 self
78 }
79
80 pub fn client_tags(mut self, s: HashSet<String>) -> Self {
81 self.session.client_tags = s;
82 self
83 }
84
85 pub fn client_tag(mut self, s: impl ToString) -> Self {
86 self.session.client_tags.insert(s.to_string());
87 self
88 }
89
90 pub fn client_info(mut self, s: impl ToString) -> Self {
91 self.session.client_info = Some(s.to_string());
92 self
93 }
94
95 pub fn catalog(mut self, s: impl ToString) -> Self {
96 self.session.catalog = Some(s.to_string());
97 self
98 }
99
100 pub fn schema(mut self, s: impl ToString) -> Self {
101 self.session.schema = Some(s.to_string());
102 self
103 }
104
105 pub fn path(mut self, s: impl ToString) -> Self {
106 self.session.path = Some(s.to_string());
107 self
108 }
109
110 pub fn resource_estimates(mut self, s: HashMap<String, String>) -> Self {
111 self.session.resource_estimates = s;
112 self
113 }
114
115 pub fn resource_estimate(mut self, k: impl ToString, v: impl ToString) -> Self {
116 self.session
117 .resource_estimates
118 .insert(k.to_string(), v.to_string());
119 self
120 }
121
122 pub fn properties(mut self, s: HashMap<String, String>) -> Self {
123 self.session.properties = s;
124 self
125 }
126
127 pub fn property(mut self, k: impl ToString, v: impl ToString) -> Self {
128 self.session.properties.insert(k.to_string(), v.to_string());
129 self
130 }
131
132 pub fn prepared_statements(mut self, s: HashMap<String, String>) -> Self {
133 self.session.prepared_statements = s;
134 self
135 }
136
137 pub fn prepared_statement(mut self, k: impl ToString, v: impl ToString) -> Self {
138 self.session
139 .prepared_statements
140 .insert(k.to_string(), v.to_string());
141 self
142 }
143
144 pub fn extra_credentials(mut self, s: HashMap<String, String>) -> Self {
145 self.session.extra_credentials = s;
146 self
147 }
148
149 pub fn extra_credential(mut self, k: impl ToString, v: impl ToString) -> Self {
150 self.session
151 .extra_credentials
152 .insert(k.to_string(), v.to_string());
153 self
154 }
155
156 pub fn transaction_id(mut self, s: TransactionId) -> Self {
157 self.session.transaction_id = s;
158 self
159 }
160
161 pub fn client_request_timeout(mut self, s: Duration) -> Self {
162 self.session.client_request_timeout = s;
163 self
164 }
165
166 pub fn compression_disabled(mut self, s: bool) -> Self {
167 self.session.compression_disabled = s;
168 self
169 }
170
171 pub fn auth(mut self, s: Auth) -> Self {
174 self.auth = Some(s);
175 self
176 }
177
178 pub fn max_attempt(mut self, s: usize) -> Self {
179 self.max_attempt = s;
180 self
181 }
182
183 pub fn ssl(mut self, ssl: Ssl) -> Self {
184 self.ssl = Some(ssl);
185 self
186 }
187
188 pub fn build(self) -> Result<Client> {
189 let session = self.session.build()?;
190 let max_attempt = self.max_attempt;
191
192 if self.auth.is_some() && session.url.scheme() == "http" {
193 return Err(Error::BasicAuthWithHttp);
194 }
195
196 let mut client_builder =
197 reqwest::ClientBuilder::new().timeout(session.client_request_timeout);
198
199 if let Some(ssl) = &self.ssl {
200 if let Some(root) = &ssl.root_cert {
201 client_builder = client_builder.add_root_certificate(root.0.clone());
202 }
203 }
204
205 let cli = Client {
206 auth: self.auth,
207 url: session.url.clone(),
208 session: RwLock::new(session),
209 client: client_builder.build()?,
210 max_attempt,
211 };
212
213 Ok(cli)
214 }
215}
216
217fn add_prepare_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
218 builder = builder.header(HEADER_USER, &session.user);
219 builder = builder.header(USER_AGENT, "trino-rust-client");
221 if session.compression_disabled {
222 builder = builder.header(ACCEPT_ENCODING, "identity")
223 }
224 builder
225}
226
227fn add_session_header(mut builder: RequestBuilder, session: &Session) -> RequestBuilder {
228 builder = add_prepare_header(builder, session);
229 builder = builder.header(HEADER_SOURCE, &session.source);
230
231 if let Some(v) = &session.trace_token {
232 builder = builder.header(HEADER_TRACE_TOKEN, v);
233 }
234
235 if !session.client_tags.is_empty() {
236 builder = builder.header(HEADER_CLIENT_TAGS, session.client_tags.by_ref().join(","));
237 }
238
239 if let Some(v) = &session.client_info {
240 builder = builder.header(HEADER_CLIENT_INFO, v);
241 }
242
243 if let Some(v) = &session.catalog {
244 builder = builder.header(HEADER_CATALOG, v);
245 }
246
247 if let Some(v) = &session.schema {
248 builder = builder.header(HEADER_SCHEMA, v);
249 }
250
251 if let Some(v) = &session.path {
252 builder = builder.header(HEADER_PATH, v);
253 }
254 if let Some(v) = &session.timezone {
255 builder = builder.header(HEADER_TIME_ZONE, v.to_string())
256 }
257 builder = add_header_map(builder, HEADER_SESSION, &session.properties);
259 builder = add_header_map(
260 builder,
261 HEADER_RESOURCE_ESTIMATE,
262 &session.resource_estimates,
263 );
264 builder = add_header_map(
265 builder,
266 HEADER_ROLE,
267 &session
268 .roles
269 .by_ref()
270 .map_kv(|(k, v)| (k.to_string(), v.to_string())),
271 );
272 builder = add_header_map(builder, HEADER_EXTRA_CREDENTIAL, &session.extra_credentials);
273 builder = add_header_map(
274 builder,
275 HEADER_PREPARED_STATEMENT,
276 &session.prepared_statements,
277 );
278 builder = builder.header(HEADER_TRANSACTION, session.transaction_id.to_str());
279 builder = builder.header(HEADER_CLIENT_CAPABILITIES, "PATH,PARAMETRIC_DATETIME");
280 builder
281}
282
283fn add_header_map<'a>(
284 mut builder: RequestBuilder,
285 header: &str,
286 map: impl IntoIterator<Item = (&'a String, &'a String)>,
287) -> RequestBuilder {
288 for (k, v) in map {
289 let kv = encode_kv(k, v);
290 builder = builder.header(header, kv);
291 }
292 builder
293}
294
/// Calls `$self.$f($param.clone())` up to `$max_attempt` times.
///
/// * An `Ok` response whose `error` field is set becomes `Err(Error::QueryError(..))`.
/// * An `Ok` response without `error` is returned as `Ok`.
/// * An error accepted by `need_retry` is retried after a 100 ms pause. There
///   is no pause after the final attempt, because nothing follows it.
/// * Any other error is returned immediately.
///
/// If every attempt failed with a retryable error, the result is
/// `Err(Error::ReachMaxAttempt($max_attempt))`. The expansion uses `.await` and
/// `return`, so it must appear in an `async fn` returning `Result<_>`.
macro_rules! retry {
    ($self:expr, $f:ident, $param:expr, $max_attempt:expr) => {{
        let max_attempt = $max_attempt;
        for attempt in 1..=max_attempt {
            match $self.$f($param.clone()).await {
                Ok(d) => match d.error {
                    Some(e) => return Err(Error::QueryError(e)),
                    None => return Ok(d),
                },
                // Transient failure: back off briefly unless this was the last attempt.
                Err(e) if need_retry(&e) => {
                    if attempt < max_attempt {
                        sleep(Duration::from_millis(100)).await;
                    }
                }
                Err(e) => return Err(e),
            }
        }

        Err(Error::ReachMaxAttempt(max_attempt))
    }};
}
315
/// Overwrites `$session` from response header `$header`, if that header is present.
///
/// If the header value is not valid visible ASCII, a warning is logged and
/// `$session` is left unchanged. `$from_str` maps the header text to an
/// `Option`; `None` also leaves `$session` unchanged. The three-argument form
/// stores `Some(text)`, i.e. it targets `Option<String>` fields.
macro_rules! set_header {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header!($session, $header, $resp, |raw: &str| Some(Some(String::from(raw))));
    };

    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        match $resp.headers().get($header).map(|raw| raw.to_str()) {
            Some(Ok(text)) => {
                if let Some(parsed) = $from_str(text) {
                    $session = parsed;
                }
            }
            Some(Err(e)) => warn!("parse header {} failed, reason: {}", $header, e),
            None => {}
        }
    };
}
336
/// Resets `$session` to its `Default` value when response header `$header` is
/// present. The header's value is ignored; only its presence matters.
macro_rules! clear_header {
    ($session:expr, $header:expr, $resp:expr) => {
        if $resp.headers().contains_key($header) {
            $session = Default::default();
        }
    };
}
344
/// Inserts the `key=value` entries of every `$header` response header into the
/// map `$session`.
///
/// Each decoded value goes through `$from_str` (by default a plain copy into a
/// `String`); entries for which it returns `None` are skipped. Entries that
/// `decode_kv_from_header` cannot parse are logged and skipped.
macro_rules! set_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        set_header_map!($session, $header, $resp, |raw: &str| Some(String::from(raw)));
    };
    ($session:expr, $header:expr, $resp:expr, $from_str:expr) => {
        for entry in $resp.headers().get_all($header) {
            match decode_kv_from_header(entry) {
                Some((key, value)) => {
                    if let Some(parsed) = $from_str(&value) {
                        $session.insert(key, parsed);
                    }
                }
                None => warn!("decode '{:?}' failed", entry),
            }
        }
    };
}
361
/// Removes from the map `$session` every key named by a `$header` response header.
///
/// The header value is used verbatim as the key (no URL decoding). Values that
/// are not valid visible ASCII are logged and skipped.
macro_rules! clear_header_map {
    ($session:expr, $header:expr, $resp:expr) => {
        for entry in $resp.headers().get_all($header) {
            match entry.to_str() {
                Err(e) => warn!("parse header {} failed, reason: {}", $header, e),
                Ok(name) => {
                    $session.remove(name);
                }
            }
        }
    };
}
374
375fn need_retry(e: &Error) -> bool {
376 match e {
377 Error::HttpError(e) => e.status() == Some(StatusCode::SERVICE_UNAVAILABLE),
378 Error::HttpNotOk(code, _) => code == &StatusCode::SERVICE_UNAVAILABLE,
379 _ => false,
380 }
381}
382
383impl Client {
384 pub async fn get_all<T: Presto + 'static>(&self, sql: String) -> Result<DataSet<T>> {
385 let res = self.get_retry(sql).await?;
386 let mut ret = res.data_set;
387
388 let mut next = res.next_uri;
389 while let Some(url) = &next {
390 let res = self.get_next_retry(url).await?;
391 next = res.next_uri;
392 if let Some(d) = res.data_set {
393 match &mut ret {
394 Some(ret) => ret.merge(d),
395 None => ret = Some(d),
396 }
397 }
398 }
399
400 if let Some(d) = ret {
401 Ok(d)
402 } else {
403 Err(Error::EmptyData)
404 }
405 }
406
407 pub async fn execute(&self, sql: String) -> Result<ExecuteResult> {
408 let res = self.get_retry::<Row>(sql).await?;
409
410 let mut next = res.next_uri;
411 while let Some(url) = &next {
412 let res = self.get_next_retry::<Row>(url).await?;
413 next = res.next_uri;
414 }
415 Ok(ExecuteResult { _m: () })
416 }
417
418 async fn get_retry<T: Presto + 'static>(&self, sql: String) -> Result<QueryResult<T>> {
419 retry!(self, get, sql, self.max_attempt)
420 }
421
422 async fn get_next_retry<T: Presto + 'static>(&self, url: &str) -> Result<QueryResult<T>> {
423 retry!(self, get_next, url, self.max_attempt)
424 }
425
426 pub async fn get<T: Presto + 'static>(&self, sql: String) -> Result<QueryResult<T>> {
427 let req = self.client.post(self.url.clone()).body(sql);
428 let req = {
429 let session = self.session.read().await;
430 add_session_header(req, &session)
431 };
432
433 let req = if let Some(auth) = self.auth.as_ref() {
434 match auth {
435 Auth::Basic(u, p) => req.basic_auth(u, p.as_ref()),
436 }
437 } else {
438 req
439 };
440
441 self.send(req).await
442 }
443
444 pub async fn get_next<T: Presto + 'static>(&self, url: &str) -> Result<QueryResult<T>> {
445 let req = self.client.get(url);
446 let req = {
447 let session = self.session.read().await;
448 add_prepare_header(req, &session)
449 };
450
451 self.send(req).await
452 }
453
454 async fn send<T: Presto + 'static>(&self, req: RequestBuilder) -> Result<QueryResult<T>> {
455 let resp = req.send().await?;
456 let status = resp.status();
457 if status != StatusCode::OK {
458 let data = resp.text().await.unwrap_or("".to_string());
459 Err(Error::HttpNotOk(status, data))
460 } else {
461 self.update_session(&resp).await;
462 let data = resp.json::<QueryResult<T>>().await?;
463 Ok(data)
464 }
465 }
466
467 async fn update_session(&self, resp: &Response) {
468 let mut session = self.session.write().await;
469
470 set_header!(session.catalog, HEADER_SET_CATALOG, resp);
471 set_header!(session.schema, HEADER_SET_SCHEMA, resp);
472 set_header!(session.path, HEADER_SET_PATH, resp);
473
474 set_header_map!(session.properties, HEADER_SET_SESSION, resp);
475 clear_header_map!(session.properties, HEADER_CLEAR_SESSION, resp);
476
477 set_header_map!(session.roles, HEADER_SET_ROLE, resp, SelectedRole::from_str);
478
479 set_header_map!(session.prepared_statements, HEADER_ADDED_PREPARE, resp);
480 clear_header_map!(
481 session.prepared_statements,
482 HEADER_DEALLOCATED_PREPARE,
483 resp
484 );
485
486 set_header!(
487 session.transaction_id,
488 HEADER_STARTED_TRANSACTION_ID,
489 resp,
490 TransactionId::from_str
491 );
492 clear_header!(session.transaction_id, HEADER_CLEAR_TRANSACTION_ID, resp);
493 }
494}
495
496fn encode_kv(k: &str, v: &str) -> String {
500 format!("{}={}", k, urlencoding::encode(v))
501}
502
503fn decode_kv_from_header(input: &HeaderValue) -> Option<(String, String)> {
504 let s = input.to_str().ok()?;
505 let kv = s.split('=').collect::<Vec<_>>();
506 if kv.len() != 2 {
507 return None;
508 }
509 let k = kv[0].to_string();
510 let v = urlencoding::decode(kv[1]).ok()?;
511 Some((k, v.to_string()))
512}