influxdb3_client/
flight.rs1use std::collections::HashMap;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use arrow::record_batch::RecordBatch;
8use arrow_flight::{
9 decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Ticket,
10};
11use arrow_schema::SchemaRef;
12use bytes::Bytes;
13use futures_util::{Stream, TryStreamExt};
14use serde_json::{json, Value as JsonValue};
15use tonic::{
16 metadata::MetadataValue,
17 transport::{Channel, ClientTlsConfig, Endpoint},
18 Request,
19};
20use url::Url;
21
22use crate::{
23 error::Error,
24 query::{QueryOptions, QueryResult},
25};
26
27pub(crate) struct FlightQueryClient {
31 inner: FlightServiceClient<Channel>,
32 token: Option<String>,
33 auth_scheme: String,
34}
35
36impl FlightQueryClient {
37 pub(crate) async fn new(
38 host_url: &str,
39 token: Option<&str>,
40 auth_scheme: &str,
41 ssl_roots_path: Option<&str>,
42 connect_timeout: Duration,
43 ) -> Result<Self, Error> {
44 let parsed = Url::parse(host_url)?;
45 let tls = parsed.scheme() == "https";
46
47 let host_str = parsed
48 .host_str()
49 .ok_or_else(|| Error::Config(format!("no host in URL: {host_url}")))?;
50 let port = parsed.port().unwrap_or(if tls { 443 } else { 80 });
51
52 let endpoint_url = if tls {
53 format!("https://{host_str}:{port}")
54 } else {
55 format!("http://{host_str}:{port}")
56 };
57
58 let endpoint: Endpoint = Channel::from_shared(endpoint_url)
59 .map_err(|e| Error::Config(e.to_string()))?
60 .connect_timeout(connect_timeout);
61
62 let endpoint = if tls {
63 let mut tls_config = ClientTlsConfig::new().with_native_roots();
64 if let Some(path) = ssl_roots_path {
65 let pem = std::fs::read(path)
66 .map_err(|e| Error::Config(format!("cannot read SSL roots '{path}': {e}")))?;
67 let cert = tonic::transport::Certificate::from_pem(pem);
68 tls_config = tls_config.ca_certificate(cert);
69 }
70 endpoint.tls_config(tls_config)?
71 } else {
72 endpoint
73 };
74
75 let channel = endpoint.connect().await?;
76 let inner = FlightServiceClient::new(channel);
77
78 Ok(FlightQueryClient {
79 inner,
80 token: token.map(str::to_owned),
81 auth_scheme: auth_scheme.to_owned(),
82 })
83 }
84
85 pub(crate) async fn stream(
90 &self,
91 query_str: &str,
92 database: &str,
93 options: &QueryOptions,
94 params: Option<&HashMap<String, JsonValue>>,
95 ) -> Result<BatchStream, Error> {
96 let ticket_payload = build_ticket(query_str, database, options, params);
97 let ticket = Ticket {
98 ticket: Bytes::from(ticket_payload),
99 };
100
101 let mut request = Request::new(ticket);
102
103 if let Some(tok) = &self.token {
104 let auth_value = format!("{} {}", self.auth_scheme, tok);
105 let meta: MetadataValue<tonic::metadata::Ascii> = auth_value.parse().map_err(|_| {
106 Error::Config("token contains characters invalid in gRPC metadata".into())
107 })?;
108 request.metadata_mut().insert("authorization", meta);
109 }
110
111 for (k, v) in &options.headers {
112 if let (Ok(name), Ok(val)) = (
113 tonic::metadata::MetadataKey::<tonic::metadata::Ascii>::from_bytes(k.as_bytes()),
114 v.parse::<MetadataValue<tonic::metadata::Ascii>>(),
115 ) {
116 request.metadata_mut().insert(name, val);
117 }
118 }
119
120 let mut client = self.inner.clone();
122 let response = client.do_get(request).await?;
123 let raw = response.into_inner();
124 let batch_stream = FlightRecordBatchStream::new_from_flight_data(
125 raw.map_err(arrow_flight::error::FlightError::Tonic),
126 )
127 .map_err(|e| Error::Arrow(arrow::error::ArrowError::ExternalError(Box::new(e))));
128
129 Ok(BatchStream {
130 inner: Box::pin(batch_stream),
131 })
132 }
133
134 pub(crate) async fn query(
136 &self,
137 query_str: &str,
138 database: &str,
139 options: &QueryOptions,
140 params: Option<&HashMap<String, JsonValue>>,
141 ) -> Result<QueryResult, Error> {
142 let mut stream = self.stream(query_str, database, options, params).await?;
143
144 let mut schema: Option<SchemaRef> = None;
145 let mut batches: Vec<RecordBatch> = Vec::new();
146
147 while let Some(batch) = stream.try_next().await? {
148 if schema.is_none() {
149 schema = Some(batch.schema());
150 }
151 batches.push(batch);
152 }
153
154 let schema = schema.unwrap_or_else(|| std::sync::Arc::new(arrow_schema::Schema::empty()));
155
156 Ok(QueryResult { schema, batches })
157 }
158}
159
160pub struct BatchStream {
176 inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, Error>> + Send>>,
177}
178
179impl Stream for BatchStream {
180 type Item = Result<RecordBatch, Error>;
181
182 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183 self.inner.as_mut().poll_next(cx)
184 }
185}
186
187fn build_ticket(
189 query_str: &str,
190 database: &str,
191 options: &QueryOptions,
192 params: Option<&HashMap<String, JsonValue>>,
193) -> Vec<u8> {
194 let mut ticket = json!({
195 "database": database,
196 "sql_query": query_str,
197 "query_type": options.query_type.as_str(),
198 });
199
200 if let Some(p) = params {
201 if !p.is_empty() {
202 ticket["params"] = json!(p);
203 }
204 }
205
206 ticket.to_string().into_bytes()
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{QueryOptions, QueryType};

    /// Parses a ticket produced by `build_ticket` back into JSON.
    fn decode(ticket: &[u8]) -> serde_json::Value {
        serde_json::from_slice(ticket).expect("build_ticket must emit valid JSON")
    }

    #[test]
    fn ticket_shape() {
        let v = decode(&build_ticket("SELECT 1", "mydb", &QueryOptions::default(), None));
        assert_eq!(v["database"], "mydb");
        assert_eq!(v["sql_query"], "SELECT 1");
        assert_eq!(v["query_type"], "sql");
        assert!(v.get("params").is_none());

        let opts = QueryOptions {
            query_type: QueryType::InfluxQL,
            ..Default::default()
        };
        let mut p = HashMap::new();
        p.insert("loc".into(), json!("Paris"));
        let v = decode(&build_ticket("SHOW MEASUREMENTS", "db", &opts, Some(&p)));
        assert_eq!(v["query_type"], "influxql");
        assert_eq!(v["params"]["loc"], "Paris");
    }

    /// An empty parameter map behaves like `None`: no `params` key is emitted.
    #[test]
    fn ticket_omits_empty_params() {
        let empty: HashMap<String, serde_json::Value> = HashMap::new();
        let v = decode(&build_ticket("SELECT 1", "mydb", &QueryOptions::default(), Some(&empty)));
        assert!(v.get("params").is_none());
    }

    /// Non-string parameter values keep their JSON types in the ticket.
    #[test]
    fn ticket_preserves_param_types() {
        let mut p: HashMap<String, serde_json::Value> = HashMap::new();
        p.insert("limit".into(), json!(10));
        p.insert("strict".into(), json!(true));
        let v = decode(&build_ticket("SELECT 1", "db", &QueryOptions::default(), Some(&p)));
        assert_eq!(v["params"]["limit"], 10);
        assert_eq!(v["params"]["strict"], true);
    }
}