#![warn(missing_docs)]
#![doc = include_str!(concat!("../", std::env!("CARGO_PKG_README")))]
use std::any::Any;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
mod exec;
mod metrics;
mod sql;
use arrow_flight::FlightInfo;
use arrow_flight::error::FlightError;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::{
    catalog::{Session, TableProvider},
    common::{Statistics, ToDFSchema, stats::Precision},
    datasource::{DefaultTableSource, TableType, empty::EmptyTable},
    error::{DataFusionError, Result},
    logical_expr::{LogicalPlan, TableProviderFilterPushDown, TableScan},
    physical_plan::ExecutionPlan,
    prelude::*,
    sql::{
        TableReference,
        unparser::{Unparser, dialect::PostgreSqlDialect},
    },
};
use exec::FlightExec;
use liquid_cache_common::CacheMode;
use log::info;
use owo_colors::OwoColorize;
use serde::{Deserialize, Serialize};
use sql::FlightSqlDriver;
use tonic::transport::Channel;

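/// Converts the Flight-side schema into the schema exposed to DataFusion:
/// `Dictionary(UInt16, Utf8)` fields become plain `Utf8`,
/// `Dictionary(UInt16, Binary)` fields become plain `Binary`, and all other
/// fields pass through unchanged. A minimal illustration (the field name is
/// invented for the example):
///
/// ```ignore
/// use std::sync::Arc;
/// use arrow_schema::{DataType, Field, Schema};
///
/// let flight_schema = Arc::new(Schema::new(vec![Field::new(
///     "city",
///     DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)),
///     true,
/// )]));
/// let output = transform_flight_schema_to_output_schema(&flight_schema);
/// assert_eq!(output.field(0).data_type(), &DataType::Utf8);
/// ```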
fn transform_flight_schema_to_output_schema(schema: &SchemaRef) -> Schema {
    let transformed_fields: Vec<Arc<Field>> = schema
        .fields
        .iter()
        .map(|field| match field.data_type() {
            DataType::Dictionary(key, value) => {
                if key.as_ref() == &DataType::UInt16 && value.as_ref() == &DataType::Utf8 {
                    Arc::new(field.as_ref().clone().with_data_type(DataType::Utf8))
                } else if key.as_ref() == &DataType::UInt16 && value.as_ref() == &DataType::Binary {
                    Arc::new(field.as_ref().clone().with_data_type(DataType::Binary))
                } else {
                    field.clone()
                }
            }
            _ => field.clone(),
        })
        .collect();
    Schema::new_with_metadata(transformed_fields, schema.metadata.clone())
}

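/// Builder for a [`LiquidCacheTable`], a DataFusion table whose scans are
/// executed by a LiquidCache server.
///
/// A sketch of typical usage, to be run inside an async context (assuming
/// this crate is named `liquid_cache_client`; the server address and table
/// URL are placeholders):
///
/// ```ignore
/// use liquid_cache_client::LiquidCacheTableBuilder;
/// use liquid_cache_common::CacheMode;
///
/// let table = LiquidCacheTableBuilder::new(
///     "http://localhost:50051",
///     "hits",
///     "file:///data/hits.parquet",
/// )
/// .with_cache_mode(CacheMode::Liquid)
/// .build()
/// .await?;
/// ```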
#[derive(Clone, Debug)]
pub struct LiquidCacheTableBuilder {
    driver: Arc<FlightSqlDriver>,
    object_stores: Vec<(String, HashMap<String, String>)>,
    cache_mode: CacheMode,
    cache_server: String,
    table_name: String,
    table_url: String,
}

impl LiquidCacheTableBuilder {
    /// Creates a new builder.
    ///
    /// # Arguments
    ///
    /// * `cache_server` - Address of the LiquidCache server, e.g. `http://localhost:50051`
    /// * `table_name` - Name to register the table under
    /// * `table_url` - URL of the table's underlying data
    pub fn new(
        cache_server: impl AsRef<str>,
        table_name: impl AsRef<str>,
        table_url: impl AsRef<str>,
    ) -> Self {
        Self {
            driver: Arc::new(FlightSqlDriver::default()),
            object_stores: vec![],
            cache_mode: CacheMode::Liquid,
            cache_server: cache_server.as_ref().to_string(),
            table_name: table_name.as_ref().to_string(),
            table_url: table_url.as_ref().to_string(),
        }
    }

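    /// Registers an object store with the cache server so it can read the
    /// table's underlying data, with optional store-specific options.
    ///
    /// A sketch with hypothetical S3 credentials; the option keys follow the
    /// `object_store` crate's S3 configuration and are illustrative only:
    ///
    /// ```ignore
    /// let builder = builder.with_object_store(
    ///     "s3://my-bucket",
    ///     Some(HashMap::from([
    ///         ("aws_access_key_id".to_string(), "...".to_string()),
    ///         ("aws_secret_access_key".to_string(), "...".to_string()),
    ///     ])),
    /// );
    /// ```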
    pub fn with_object_store(
        mut self,
        url: impl AsRef<str>,
        object_store_options: Option<HashMap<String, String>>,
    ) -> Self {
        self.object_stores.push((
            url.as_ref().to_string(),
            object_store_options.unwrap_or_default(),
        ));
        self
    }

    /// Sets the cache mode used by the server (defaults to [`CacheMode::Liquid`]).
    pub fn with_cache_mode(mut self, cache_mode: CacheMode) -> Self {
        self.cache_mode = cache_mode;
        self
    }

    /// Connects to the cache server, registers the configured object stores,
    /// fetches the table's Flight metadata, and builds the [`LiquidCacheTable`].
    pub async fn build(self) -> Result<LiquidCacheTable> {
        let channel = flight_channel(&self.cache_server).await?;
        for (url, object_store_options) in self.object_stores {
            self.driver
                .register_object_store(channel.clone(), &url, object_store_options)
                .await
                .map_err(to_df_err)?;
        }

        let metadata = self
            .driver
            .metadata(
                channel.clone(),
                &self.table_name,
                &self.table_url,
                self.cache_mode,
            )
            .await
            .map_err(to_df_err)?;
        let num_rows = precision(metadata.info.total_records);
        let total_byte_size = precision(metadata.info.total_bytes);
        let output_schema = Arc::new(transform_flight_schema_to_output_schema(&metadata.schema));
        let flight_schema = metadata.schema;
        let stats = Statistics {
            num_rows,
            total_byte_size,
            column_statistics: vec![],
        };
        Ok(LiquidCacheTable {
            driver: self.driver.clone(),
            channel,
            origin: self.cache_server,
            table_name: self.table_name.into(),
            flight_schema,
            output_schema,
            stats,
        })
    }
}

/// The metadata of a table served over Arrow Flight.
#[derive(Clone, Debug)]
pub(crate) struct FlightMetadata {
    /// The [`FlightInfo`] returned by the server.
    pub(crate) info: FlightInfo,
    /// Properties to apply when fetching the data.
    pub(crate) props: FlightProperties,
    /// The Arrow schema decoded from the [`FlightInfo`].
    pub(crate) schema: SchemaRef,
}

impl FlightMetadata {
    /// Creates metadata from its parts.
    pub fn new(info: FlightInfo, props: FlightProperties, schema: SchemaRef) -> Self {
        Self {
            info,
            props,
            schema,
        }
    }

    /// Creates metadata by decoding the schema embedded in the [`FlightInfo`].
    #[allow(clippy::result_large_err)]
    pub fn try_new(info: FlightInfo, props: FlightProperties) -> arrow_flight::error::Result<Self> {
        let schema = Arc::new(info.clone().try_decode_schema()?);
        Ok(Self::new(info, props, schema))
    }
}

impl TryFrom<FlightInfo> for FlightMetadata {
    type Error = FlightError;

    fn try_from(info: FlightInfo) -> Result<Self, Self::Error> {
        Self::try_new(info, FlightProperties::default())
    }
}

/// Extra gRPC properties applied when fetching data from the server.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub(crate) struct FlightProperties {
    pub(crate) unbounded_stream: bool,
    pub(crate) grpc_headers: HashMap<String, String>,
}

impl FlightProperties {
    /// Sets the headers to attach to every gRPC request.
    pub fn grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
        self.grpc_headers = grpc_headers;
        self
    }
}

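/// A [`TableProvider`] whose scans are pushed down (as SQL) to a LiquidCache
/// server and executed there.
///
/// A sketch of registering one with DataFusion, to be run inside an async
/// context (assuming this crate is named `liquid_cache_client`; the server
/// address and paths are placeholders):
///
/// ```ignore
/// use std::sync::Arc;
/// use datafusion::prelude::SessionContext;
/// use liquid_cache_client::LiquidCacheTableBuilder;
///
/// let ctx = SessionContext::new();
/// let table = LiquidCacheTableBuilder::new(
///     "http://localhost:50051",
///     "hits",
///     "file:///data/hits.parquet",
/// )
/// .build()
/// .await?;
/// ctx.register_table("hits", Arc::new(table))?;
/// let df = ctx.sql("SELECT COUNT(*) FROM hits").await?;
/// ```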
#[derive(Debug)]
pub struct LiquidCacheTable {
    driver: Arc<FlightSqlDriver>,
    channel: Channel,
    origin: String,
    table_name: TableReference,
    /// Schema of the data as it arrives over Flight (may contain dictionary-encoded columns).
    flight_schema: SchemaRef,
    /// Schema exposed to DataFusion, with dictionary columns unwrapped to plain types.
    output_schema: SchemaRef,
    stats: Statistics,
}

#[async_trait]
impl TableProvider for LiquidCacheTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::View
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // Rebuild the scan (projection, filters, limit) as a logical plan over
        // an empty table, then unparse it to SQL to send to the cache server.
        let unparsed_sql = {
            let empty_table_provider = EmptyTable::new(self.schema());
            let table_source = Arc::new(DefaultTableSource::new(Arc::new(empty_table_provider)));

            let logical_plan = TableScan {
                table_name: self.table_name.clone(),
                source: table_source,
                projection: projection.map(|p| p.to_vec()),
                filters: filters.to_vec(),
                fetch: limit,
                projected_schema: Arc::new(self.schema().as_ref().clone().to_dfschema()?),
            };
            let unparser = Unparser::new(&PostgreSqlDialect {});
            let unparsed_sql = unparser.plan_to_sql(&LogicalPlan::TableScan(logical_plan))?;
            unparsed_sql.to_string()
        };

        info!("SQL sent to cache: \n{}", unparsed_sql.cyan());

        let metadata = self
            .driver
            .run_sql(self.channel.clone(), &unparsed_sql)
            .await
            .map_err(to_df_err)?;

        Ok(Arc::new(FlightExec::try_new(
            self.flight_schema.clone(),
            self.output_schema.clone(),
            metadata,
            projection,
            &self.origin,
            limit,
        )?))
    }

    fn statistics(&self) -> Option<Statistics> {
        Some(self.stats.clone())
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>> {
        // A filter can be pushed down exactly iff it can be unparsed to SQL.
        let filter_push_down: Vec<TableProviderFilterPushDown> = filters
            .iter()
            .map(
                |f| match Unparser::new(&PostgreSqlDialect {}).expr_to_sql(f) {
                    Ok(_) => TableProviderFilterPushDown::Exact,
                    Err(_) => TableProviderFilterPushDown::Unsupported,
                },
            )
            .collect();

        Ok(filter_push_down)
    }
}

/// Wraps any error into a [`DataFusionError::External`].
pub(crate) fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
    DataFusionError::External(Box::new(err))
}

/// Opens a gRPC channel to the given endpoint, with TCP keepalive enabled.
pub(crate) async fn flight_channel(source: impl Into<String>) -> Result<Channel> {
    let endpoint = Channel::from_shared(source.into())
        .map_err(to_df_err)?
        .tcp_keepalive(Some(Duration::from_secs(10)));
    endpoint.connect().await.map_err(to_df_err)
}

/// Converts a Flight count to a [`Precision`]; negative values mean the
/// server does not know the count.
fn precision(total: i64) -> Precision<usize> {
    if total < 0 {
        Precision::Absent
    } else {
        Precision::Exact(total as usize)
    }
}