Skip to main content

influxdb3_client/
flight.rs

//! Arrow Flight gRPC query transport for InfluxDB 3.
2use std::collections::HashMap;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use arrow::record_batch::RecordBatch;
8use arrow_flight::{
9    decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Ticket,
10};
11use arrow_schema::SchemaRef;
12use bytes::Bytes;
13use futures_util::{Stream, TryStreamExt};
14use serde_json::{json, Value as JsonValue};
15use tonic::{
16    metadata::MetadataValue,
17    transport::{Channel, ClientTlsConfig, Endpoint},
18    Request,
19};
20use url::Url;
21
22use crate::{
23    error::Error,
24    query::{QueryOptions, QueryResult},
25};
26
27/// Holds the gRPC channel.  `Channel` is internally an `Arc` over the HTTP/2
28/// connection, so cloning is cheap and concurrent calls multiplex on the same
29/// underlying transport.
30pub(crate) struct FlightQueryClient {
31    inner: FlightServiceClient<Channel>,
32    token: Option<String>,
33    auth_scheme: String,
34}
35
36impl FlightQueryClient {
37    pub(crate) async fn new(
38        host_url: &str,
39        token: Option<&str>,
40        auth_scheme: &str,
41        ssl_roots_path: Option<&str>,
42        connect_timeout: Duration,
43    ) -> Result<Self, Error> {
44        let parsed = Url::parse(host_url)?;
45        let tls = parsed.scheme() == "https";
46
47        let host_str = parsed
48            .host_str()
49            .ok_or_else(|| Error::Config(format!("no host in URL: {host_url}")))?;
50        let port = parsed.port().unwrap_or(if tls { 443 } else { 80 });
51
52        let endpoint_url = if tls {
53            format!("https://{host_str}:{port}")
54        } else {
55            format!("http://{host_str}:{port}")
56        };
57
58        let endpoint: Endpoint = Channel::from_shared(endpoint_url)
59            .map_err(|e| Error::Config(e.to_string()))?
60            .connect_timeout(connect_timeout);
61
62        let endpoint = if tls {
63            let mut tls_config = ClientTlsConfig::new().with_native_roots();
64            if let Some(path) = ssl_roots_path {
65                let pem = std::fs::read(path)
66                    .map_err(|e| Error::Config(format!("cannot read SSL roots '{path}': {e}")))?;
67                let cert = tonic::transport::Certificate::from_pem(pem);
68                tls_config = tls_config.ca_certificate(cert);
69            }
70            endpoint.tls_config(tls_config)?
71        } else {
72            endpoint
73        };
74
75        let channel = endpoint.connect().await?;
76        let inner = FlightServiceClient::new(channel);
77
78        Ok(FlightQueryClient {
79            inner,
80            token: token.map(str::to_owned),
81            auth_scheme: auth_scheme.to_owned(),
82        })
83    }
84
85    /// Open a streaming query and return a [`BatchStream`].
86    ///
87    /// Clones the underlying gRPC client per call; `Channel` is `Arc`-backed, so
88    /// concurrent queries multiplex on the same connection.
89    pub(crate) async fn stream(
90        &self,
91        query_str: &str,
92        database: &str,
93        options: &QueryOptions,
94        params: Option<&HashMap<String, JsonValue>>,
95    ) -> Result<BatchStream, Error> {
96        let ticket_payload = build_ticket(query_str, database, options, params);
97        let ticket = Ticket {
98            ticket: Bytes::from(ticket_payload),
99        };
100
101        let mut request = Request::new(ticket);
102
103        if let Some(tok) = &self.token {
104            let auth_value = format!("{} {}", self.auth_scheme, tok);
105            let meta: MetadataValue<tonic::metadata::Ascii> = auth_value.parse().map_err(|_| {
106                Error::Config("token contains characters invalid in gRPC metadata".into())
107            })?;
108            request.metadata_mut().insert("authorization", meta);
109        }
110
111        for (k, v) in &options.headers {
112            if let (Ok(name), Ok(val)) = (
113                tonic::metadata::MetadataKey::<tonic::metadata::Ascii>::from_bytes(k.as_bytes()),
114                v.parse::<MetadataValue<tonic::metadata::Ascii>>(),
115            ) {
116                request.metadata_mut().insert(name, val);
117            }
118        }
119
120        // Channel is Arc-backed, so cloning the client is cheap.
121        let mut client = self.inner.clone();
122        let response = client.do_get(request).await?;
123        let raw = response.into_inner();
124        let batch_stream = FlightRecordBatchStream::new_from_flight_data(
125            raw.map_err(arrow_flight::error::FlightError::Tonic),
126        )
127        .map_err(|e| Error::Arrow(arrow::error::ArrowError::ExternalError(Box::new(e))));
128
129        Ok(BatchStream {
130            inner: Box::pin(batch_stream),
131        })
132    }
133
134    /// Execute a query and collect all batches.
135    pub(crate) async fn query(
136        &self,
137        query_str: &str,
138        database: &str,
139        options: &QueryOptions,
140        params: Option<&HashMap<String, JsonValue>>,
141    ) -> Result<QueryResult, Error> {
142        let mut stream = self.stream(query_str, database, options, params).await?;
143
144        let mut schema: Option<SchemaRef> = None;
145        let mut batches: Vec<RecordBatch> = Vec::new();
146
147        while let Some(batch) = stream.try_next().await? {
148            if schema.is_none() {
149                schema = Some(batch.schema());
150            }
151            batches.push(batch);
152        }
153
154        let schema = schema.unwrap_or_else(|| std::sync::Arc::new(arrow_schema::Schema::empty()));
155
156        Ok(QueryResult { schema, batches })
157    }
158}
159
160/// Streaming iterator over query result [`RecordBatch`]es.
161///
162/// Use this when the result is too large to materialise in memory. The
163/// underlying gRPC stream is consumed lazily as you poll.
164///
165/// ```rust,no_run
166/// # use influxdb3_client::Client;
167/// # use futures_util::TryStreamExt;
168/// # async fn example(client: &Client) -> influxdb3_client::Result<()> {
169/// let mut stream = client.sql("SELECT * FROM huge_table").stream().await?;
170/// while let Some(batch) = stream.try_next().await? {
171///     println!("got {} rows", batch.num_rows());
172/// }
173/// # Ok(()) }
174/// ```
175pub struct BatchStream {
176    inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, Error>> + Send>>,
177}
178
179impl Stream for BatchStream {
180    type Item = Result<RecordBatch, Error>;
181
182    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183        self.inner.as_mut().poll_next(cx)
184    }
185}
186
187/// Build the JSON ticket that InfluxDB 3 expects on its Flight `DoGet` endpoint.
188fn build_ticket(
189    query_str: &str,
190    database: &str,
191    options: &QueryOptions,
192    params: Option<&HashMap<String, JsonValue>>,
193) -> Vec<u8> {
194    let mut ticket = json!({
195        "database": database,
196        "sql_query": query_str,
197        "query_type": options.query_type.as_str(),
198    });
199
200    if let Some(p) = params {
201        if !p.is_empty() {
202            ticket["params"] = json!(p);
203        }
204    }
205
206    ticket.to_string().into_bytes()
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{QueryOptions, QueryType};

    /// Decode a ticket produced by `build_ticket` back into a JSON value.
    fn decode(ticket: &[u8]) -> serde_json::Value {
        serde_json::from_slice(ticket).expect("ticket must be valid JSON")
    }

    #[test]
    fn ticket_shape() {
        // Default SQL
        let v = decode(&build_ticket("SELECT 1", "mydb", &QueryOptions::default(), None));
        assert_eq!(v["database"], "mydb");
        assert_eq!(v["sql_query"], "SELECT 1");
        assert_eq!(v["query_type"], "sql");
        assert!(v.get("params").is_none());

        // InfluxQL + params: the query text still travels in `sql_query`.
        let opts = QueryOptions {
            query_type: QueryType::InfluxQL,
            ..Default::default()
        };
        let mut p = HashMap::new();
        p.insert("loc".into(), json!("Paris"));
        let v = decode(&build_ticket("SHOW MEASUREMENTS", "db", &opts, Some(&p)));
        assert_eq!(v["query_type"], "influxql");
        assert_eq!(v["sql_query"], "SHOW MEASUREMENTS");
        assert_eq!(v["params"]["loc"], "Paris");
    }

    #[test]
    fn ticket_omits_empty_params() {
        // `Some(empty)` must produce the same minimal ticket as `None`.
        let empty: HashMap<String, JsonValue> = HashMap::new();
        let v = decode(&build_ticket(
            "SELECT 1",
            "db",
            &QueryOptions::default(),
            Some(&empty),
        ));
        assert!(v.get("params").is_none());
        assert_eq!(v["database"], "db");
    }
}