1use std::collections::HashMap;
2use std::future::IntoFuture;
3use std::io::Write as IoWrite;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use bytes::Bytes;
9use futures_util::stream::{self, StreamExt, TryStreamExt};
10use reqwest::{
11 header::{AUTHORIZATION, CONTENT_ENCODING, CONTENT_TYPE, RETRY_AFTER},
12 Client as HttpClient, ClientBuilder, Response,
13};
14use tokio::sync::OnceCell;
15
16use crate::{
17 config::ClientConfig,
18 error::{Error, LineError, PartialWriteError},
19 flight::{BatchStream, FlightQueryClient},
20 precision::Precision,
21 query::{QueryOptions, QueryParameters, QueryResult, QueryType},
22 retry::{self, RetryConfig},
23 write::{WriteInput, WriteOptions},
24 Result,
25};
26
27pub struct Client {
31 config: ClientConfig,
32 http: HttpClient,
33 flight: OnceCell<FlightQueryClient>,
35}
36
37impl Client {
38 pub async fn new(config: ClientConfig) -> Result<Self> {
43 let http = build_http_client(&config)?;
44 Ok(Client {
45 config,
46 http,
47 flight: OnceCell::new(),
48 })
49 }
50
51 pub async fn from_connection_string(cs: &str) -> Result<Self> {
53 Client::new(ClientConfig::from_connection_string(cs)?).await
54 }
55
56 pub async fn from_env() -> Result<Self> {
59 Client::new(ClientConfig::from_env()?).await
60 }
61
62 pub fn config(&self) -> &ClientConfig {
64 &self.config
65 }
66
67 pub fn write<W: WriteInput>(&self, data: W) -> WriteRequest<'_, W> {
78 WriteRequest {
79 client: self,
80 data: Some(data),
81 options: self.config.write_options.clone(),
82 retry: None,
83 }
84 }
85
86 pub fn sql(&self, q: impl Into<String>) -> QueryRequest<'_> {
88 self.query(q, QueryType::Sql)
89 }
90
91 pub fn influxql(&self, q: impl Into<String>) -> QueryRequest<'_> {
93 self.query(q, QueryType::InfluxQL)
94 }
95
96 pub fn query(&self, q: impl Into<String>, language: QueryType) -> QueryRequest<'_> {
98 QueryRequest {
99 client: self,
100 query: q.into(),
101 query_type: language,
102 params: QueryParameters::new(),
103 headers: HashMap::new(),
104 retry: None,
105 }
106 }
107
108 pub async fn ping(&self) -> Result<String> {
110 let url = format!("{}/ping", self.config.host_url());
111 let mut req = self.http.get(&url);
112 if let Some(auth) = self.config.authorization_header()? {
113 req = req.header(AUTHORIZATION, auth);
114 }
115 let resp = req.send().await?;
116 let version = resp
117 .headers()
118 .get("x-influxdb-version")
119 .and_then(|v| v.to_str().ok())
120 .unwrap_or("unknown")
121 .to_owned();
122 Ok(version)
123 }
124
125 async fn flight(&self) -> Result<&FlightQueryClient> {
127 self.flight
128 .get_or_try_init(|| async {
129 FlightQueryClient::new(
130 &self.config.host,
131 self.config.token.as_deref(),
132 &self.config.auth_scheme,
133 self.config.ssl_roots_path.as_deref(),
134 self.config.query_timeout,
135 )
136 .await
137 })
138 .await
139 }
140
141 async fn send_lp(
143 &self,
144 body: Vec<u8>,
145 opts: &WriteOptions,
146 policy: &RetryConfig,
147 ) -> Result<()> {
148 let db = &self.config.database;
149
150 let (url, mut params) = if opts.use_v2_api {
151 let url = format!("{}/api/v2/write", self.config.host_url());
152 let mut p = vec![("bucket", db.clone())];
153 if let Some(org) = &self.config.org {
154 p.push(("org", org.clone()));
155 }
156 (url, p)
157 } else {
158 let url = format!("{}/api/v3/write_lp", self.config.host_url());
159 let mut p = vec![("db", db.clone())];
160 p.push(("accept_partial", opts.accept_partial.to_string()));
161 if opts.no_sync {
162 p.push(("no_sync", "true".to_string()));
163 }
164 (url, p)
165 };
166
167 params.push(("precision", opts.precision.as_str().to_string()));
168
169 let (final_body, compressed) = maybe_gzip(body, opts.gzip_threshold).await?;
171
172 let mut base = self
173 .http
174 .post(&url)
175 .query(¶ms)
176 .header(CONTENT_TYPE, "text/plain; charset=utf-8");
177 if compressed {
178 base = base.header(CONTENT_ENCODING, "gzip");
179 }
180 if let Some(auth) = self.config.authorization_header()? {
181 base = base.header(AUTHORIZATION, auth);
182 }
183 for (k, v) in &self.config.headers {
184 base = base.header(k, v);
185 }
186 let base = base.body(final_body);
187
188 let start = Instant::now();
189 let mut attempt: u32 = 0;
190 loop {
191 let req = base.try_clone().expect("write body is cloneable");
192 match send_write_once(req).await {
193 WriteAttempt::Ok => return Ok(()),
194 WriteAttempt::Fatal(e) => return Err(e),
195 WriteAttempt::Retry { after, last } => {
196 if attempt >= policy.max_retries {
197 return Err(last);
198 }
199 let delay = after
200 .filter(|_| policy.honor_retry_after)
201 .unwrap_or_else(|| policy.backoff(attempt));
202 if let Some(budget) = policy.max_elapsed {
203 if start.elapsed() + delay > budget {
204 return Err(last);
205 }
206 }
207 tokio::time::sleep(delay).await;
208 attempt += 1;
209 }
210 }
211 }
212 }
213}
214
215pub struct WriteRequest<'a, W: WriteInput> {
217 client: &'a Client,
218 data: Option<W>,
219 options: WriteOptions,
220 retry: Option<RetryConfig>,
221}
222
223impl<'a, W: WriteInput> WriteRequest<'a, W> {
224 pub fn precision(mut self, p: Precision) -> Self {
225 self.options.precision = p;
226 self
227 }
228 pub fn no_sync(mut self) -> Self {
229 self.options.no_sync = true;
230 self
231 }
232 pub fn accept_partial(mut self, accept: bool) -> Self {
233 self.options.accept_partial = accept;
234 self
235 }
236 pub fn use_v2_api(mut self) -> Self {
237 self.options.use_v2_api = true;
238 self
239 }
240 pub fn batch_size(mut self, n: usize) -> Self {
241 self.options.batch_size = n;
242 self
243 }
244 pub fn max_inflight(mut self, n: usize) -> Self {
245 self.options.max_inflight = n;
246 self
247 }
248 pub fn default_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
249 self.options.default_tags.insert(key.into(), value.into());
250 self
251 }
252 pub fn gzip_threshold(mut self, t: Option<usize>) -> Self {
256 self.options.gzip_threshold = t;
257 self
258 }
259 pub fn tag_order(mut self, order: impl IntoIterator<Item = impl Into<String>>) -> Self {
260 self.options.tag_order = order.into_iter().map(Into::into).collect();
261 self
262 }
263
264 pub fn with_options(mut self, opts: WriteOptions) -> Self {
266 self.options = opts;
267 self
268 }
269
270 pub fn retry(mut self, policy: RetryConfig) -> Self {
272 self.retry = Some(policy);
273 self
274 }
275
276 pub fn no_retry(mut self) -> Self {
278 self.retry = Some(RetryConfig::disabled());
279 self
280 }
281}
282
283impl<'a, W: WriteInput + Send + 'a> IntoFuture for WriteRequest<'a, W> {
284 type Output = Result<()>;
285 type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
286
287 fn into_future(mut self) -> Self::IntoFuture {
288 let client = self.client;
289 let data = self.data.take().expect("data already taken");
290 let options = self.options;
291 let policy = self.retry.unwrap_or_else(|| client.config.retry.clone());
292 Box::pin(async move {
293 let max_inflight = options.max_inflight.max(1);
294 let batches = data.into_lp_batches(&options);
295
296 if max_inflight == 1 {
297 for batch in batches {
298 let bytes = batch?;
299 client.send_lp(bytes, &options, &policy).await?;
300 }
301 return Ok(());
302 }
303
304 let options = Arc::new(options);
305 let policy = Arc::new(policy);
306 stream::iter(batches)
307 .map(|b| b.map(|bytes| (bytes, Arc::clone(&options), Arc::clone(&policy))))
308 .try_for_each_concurrent(Some(max_inflight), |(bytes, opts, pol)| async move {
309 client.send_lp(bytes, &opts, &pol).await
310 })
311 .await
312 })
313 }
314}
315
316pub struct QueryRequest<'a> {
320 client: &'a Client,
321 query: String,
322 query_type: QueryType,
323 params: QueryParameters,
324 headers: HashMap<String, String>,
325 retry: Option<RetryConfig>,
326}
327
328impl<'a> QueryRequest<'a> {
329 pub fn param(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
331 self.params.insert(key.into(), value.into());
332 self
333 }
334
335 pub fn params<K, V, I>(mut self, params: I) -> Self
337 where
338 I: IntoIterator<Item = (K, V)>,
339 K: Into<String>,
340 V: Into<serde_json::Value>,
341 {
342 for (k, v) in params {
343 self.params.insert(k.into(), v.into());
344 }
345 self
346 }
347
348 pub fn header(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
350 self.headers.insert(k.into(), v.into());
351 self
352 }
353
354 pub fn retry(mut self, policy: RetryConfig) -> Self {
356 self.retry = Some(policy);
357 self
358 }
359
360 pub fn no_retry(mut self) -> Self {
362 self.retry = Some(RetryConfig::disabled());
363 self
364 }
365
366 pub async fn stream(self) -> Result<BatchStream> {
372 let policy = self
373 .retry
374 .clone()
375 .unwrap_or_else(|| self.client.config.retry.clone());
376 let opts = QueryOptions {
377 query_type: self.query_type,
378 headers: self.headers,
379 };
380 let params = (!self.params.is_empty()).then_some(self.params);
381
382 let mut attempt: u32 = 0;
383 loop {
384 let result = async {
385 let flight = self.client.flight().await?;
386 flight
387 .stream(
388 &self.query,
389 &self.client.config.database,
390 &opts,
391 params.as_ref(),
392 )
393 .await
394 }
395 .await;
396
397 match result {
398 Ok(s) => return Ok(s),
399 Err(e) => {
400 if attempt >= policy.max_retries || !is_retryable_query_err(&e) {
401 return Err(e);
402 }
403 tokio::time::sleep(policy.backoff(attempt)).await;
404 attempt += 1;
405 }
406 }
407 }
408 }
409}
410
411fn is_retryable_query_err(e: &Error) -> bool {
413 match e {
414 Error::Flight(status) => retry::retryable_tonic(status.code()),
415 Error::Transport(_) => true,
416 Error::Timeout(_) => true,
417 _ => false,
418 }
419}
420
421impl<'a> IntoFuture for QueryRequest<'a> {
422 type Output = Result<QueryResult>;
423 type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'a>>;
424
425 fn into_future(self) -> Self::IntoFuture {
426 Box::pin(async move {
427 let timeout = self.client.config.query_timeout;
428 let policy = self
429 .retry
430 .clone()
431 .unwrap_or_else(|| self.client.config.retry.clone());
432 let opts = QueryOptions {
433 query_type: self.query_type,
434 headers: self.headers,
435 };
436 let params = (!self.params.is_empty()).then_some(self.params);
437
438 let start = Instant::now();
439 let mut attempt: u32 = 0;
440 loop {
441 let fut = async {
444 let flight = self.client.flight().await?;
445 flight
446 .query(
447 &self.query,
448 &self.client.config.database,
449 &opts,
450 params.as_ref(),
451 )
452 .await
453 };
454 let outcome = match tokio::time::timeout(timeout, fut).await {
455 Ok(inner) => inner,
456 Err(_) => Err(Error::Timeout(timeout)),
457 };
458
459 match outcome {
460 Ok(result) => return Ok(result),
461 Err(e) => {
462 if attempt >= policy.max_retries || !is_retryable_query_err(&e) {
463 return Err(e);
464 }
465 let delay = policy.backoff(attempt);
466 if let Some(budget) = policy.max_elapsed {
467 if start.elapsed() + delay > budget {
468 return Err(e);
469 }
470 }
471 tokio::time::sleep(delay).await;
472 attempt += 1;
473 }
474 }
475 }
476 })
477 }
478}
479
480fn build_http_client(config: &ClientConfig) -> Result<HttpClient> {
481 let mut builder = ClientBuilder::new()
482 .timeout(config.write_timeout)
483 .pool_idle_timeout(config.idle_connection_timeout)
484 .pool_max_idle_per_host(config.max_idle_connections)
485 .gzip(true)
486 .use_rustls_tls();
487
488 if let Some(proxy_url) = &config.proxy {
489 let proxy = reqwest::Proxy::all(proxy_url)
490 .map_err(|e| Error::Config(format!("invalid proxy URL: {e}")))?;
491 builder = builder.proxy(proxy);
492 }
493
494 if let Some(roots_path) = &config.ssl_roots_path {
495 let pem = std::fs::read(roots_path)
496 .map_err(|e| Error::Config(format!("cannot read SSL roots '{roots_path}': {e}")))?;
497 let cert = reqwest::tls::Certificate::from_pem(&pem)
498 .map_err(|e| Error::Config(format!("invalid SSL roots PEM: {e}")))?;
499 builder = builder.add_root_certificate(cert);
500 }
501
502 Ok(builder.build()?)
503}
504
505const SPAWN_BLOCKING_GZIP_THRESHOLD: usize = 64 * 1024;
508
509async fn maybe_gzip(data: Vec<u8>, threshold: Option<usize>) -> Result<(Bytes, bool)> {
511 let should_compress = matches!(threshold, Some(t) if data.len() > t);
512 if !should_compress {
513 return Ok((Bytes::from(data), false));
514 }
515
516 if data.len() < SPAWN_BLOCKING_GZIP_THRESHOLD {
517 let compressed = gzip_compress(data)?;
518 return Ok((Bytes::from(compressed), true));
519 }
520
521 let compressed = tokio::task::spawn_blocking(move || gzip_compress(data))
522 .await
523 .map_err(|e| Error::Config(format!("gzip task join error: {e}")))??;
524 Ok((Bytes::from(compressed), true))
525}
526
527fn gzip_compress(data: Vec<u8>) -> Result<Vec<u8>> {
528 let mut encoder = flate2::write::GzEncoder::new(
529 Vec::with_capacity(data.len() / 2),
530 flate2::Compression::default(),
531 );
532 encoder
533 .write_all(&data)
534 .map_err(|e| Error::Config(format!("gzip encoding failed: {e}")))?;
535 encoder
536 .finish()
537 .map_err(|e| Error::Config(format!("gzip finalization failed: {e}")))
538}
539
540enum WriteAttempt {
542 Ok,
543 Retry {
546 after: Option<Duration>,
547 last: Error,
548 },
549 Fatal(Error),
551}
552
553async fn send_write_once(req: reqwest::RequestBuilder) -> WriteAttempt {
555 match req.send().await {
556 Err(e) => {
557 let retryable = retry::retryable_reqwest(&e);
558 let err = Error::Http(e);
559 if retryable {
560 WriteAttempt::Retry {
561 after: None,
562 last: err,
563 }
564 } else {
565 WriteAttempt::Fatal(err)
566 }
567 }
568 Ok(resp) => {
569 let status = resp.status();
570 if status.is_success() {
571 return WriteAttempt::Ok;
572 }
573 let code = status.as_u16();
574 let retry_after = resp
575 .headers()
576 .get(RETRY_AFTER)
577 .and_then(|v| v.to_str().ok())
578 .and_then(retry::parse_retry_after);
579 let err = parse_write_error(code, resp).await;
582 if retry::retryable_status(code) {
583 WriteAttempt::Retry {
584 after: retry_after,
585 last: err,
586 }
587 } else {
588 WriteAttempt::Fatal(err)
589 }
590 }
591 }
592}
593
594async fn parse_write_error(code: u16, resp: Response) -> Error {
596 let body = resp.text().await.unwrap_or_default();
597
598 if let Ok(v) = serde_json::from_str::<serde_json::Value>(&body) {
599 let is_partial = v
600 .get("error")
601 .and_then(|e| e.as_str())
602 .map(|s| s.contains("partial write"))
603 .unwrap_or(false);
604
605 if is_partial && v.get("data").and_then(|d| d.as_array()).is_some() {
606 return Error::PartialWrite(PartialWriteError {
607 line_errors: parse_line_errors(&v),
608 });
609 }
610
611 let msg = v
612 .get("error")
613 .or_else(|| v.get("message"))
614 .and_then(|m| m.as_str())
615 .unwrap_or(&body)
616 .to_owned();
617
618 return Error::Server { code, message: msg };
619 }
620
621 Error::Server {
622 code,
623 message: body,
624 }
625}
626
627fn parse_line_errors(v: &serde_json::Value) -> Vec<LineError> {
628 v.get("data")
629 .and_then(|d| d.as_array())
630 .map(|arr| {
631 arr.iter()
632 .filter_map(|e| {
633 Some(LineError {
634 line: e.get("line_number")?.as_u64()?,
635 message: e.get("error_message")?.as_str()?.to_owned(),
636 original_line: e
637 .get("original_line")
638 .and_then(|s| s.as_str())
639 .map(str::to_owned),
640 })
641 })
642 .collect()
643 })
644 .unwrap_or_default()
645}