liquid_cache_client/
lib.rs

#![warn(missing_docs)]
#![doc = include_str!(concat!("../", std::env!("CARGO_PKG_README")))]
use std::any::Any;
use std::collections::HashMap;
use std::error::Error;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
mod exec;
mod metrics;
mod sql;
use arrow_flight::FlightInfo;
use arrow_flight::error::FlightError;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::{
    catalog::{Session, TableProvider},
    common::{Statistics, ToDFSchema, stats::Precision},
    datasource::{DefaultTableSource, TableType, empty::EmptyTable},
    error::{DataFusionError, Result},
    logical_expr::{LogicalPlan, TableProviderFilterPushDown, TableScan},
    physical_plan::ExecutionPlan,
    prelude::*,
    sql::{
        TableReference,
        unparser::{Unparser, dialect::PostgreSqlDialect},
    },
};
use exec::FlightExec;
use liquid_cache_common::CacheMode;
use log::info;
use owo_colors::OwoColorize;
use serde::{Deserialize, Serialize};
use sql::FlightSqlDriver;
use tonic::transport::Channel;

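/// Convert the flight (wire) schema into the schema exposed to users.
///
/// String and binary columns travel as `Dictionary(UInt16, Utf8)` or
/// `Dictionary(UInt16, Binary)` for transmission; they are surfaced as plain
/// `Utf8`/`Binary` fields in the output schema.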
fn transform_flight_schema_to_output_schema(schema: &SchemaRef) -> Schema {
    let transformed_fields: Vec<Arc<Field>> = schema
        .fields
        .iter()
        .map(|field| match field.data_type() {
            DataType::Dictionary(key, value) => {
                if key.as_ref() == &DataType::UInt16 && value.as_ref() == &DataType::Utf8 {
                    Arc::new(field.as_ref().clone().with_data_type(DataType::Utf8))
                } else if key.as_ref() == &DataType::UInt16 && value.as_ref() == &DataType::Binary {
                    Arc::new(field.as_ref().clone().with_data_type(DataType::Binary))
                } else {
                    field.clone()
                }
            }
            _ => field.clone(),
        })
        .collect();
    Schema::new_with_metadata(transformed_fields, schema.metadata.clone())
}

/// The builder for a [LiquidCacheTable].
///
/// # Example
///
/// ```no_run
/// # use std::sync::Arc;
/// # use datafusion::prelude::*;
/// # use liquid_cache_client::LiquidCacheTableBuilder;
/// # async fn example() -> datafusion::error::Result<()> {
/// let mut session_config = SessionConfig::from_env()?;
/// session_config
///     .options_mut()
///     .execution
///     .parquet
///     .pushdown_filters = true;
/// let ctx = Arc::new(SessionContext::new_with_config(session_config));
///
/// // Illustrative values: point these at your own cache server and table.
/// let table = LiquidCacheTableBuilder::new(
///     "http://localhost:50051",
///     "my_table",
///     "s3://my_bucket/my_table.parquet",
/// )
/// .with_object_store("s3://my_bucket", None)
/// .build()
/// .await?;
/// ctx.register_table("my_table", Arc::new(table))?;
/// ctx.sql("SELECT COUNT(*) FROM my_table").await?.show().await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug)]
pub struct LiquidCacheTableBuilder {
    driver: Arc<FlightSqlDriver>,
    object_stores: Vec<(String, HashMap<String, String>)>,
    cache_mode: CacheMode,
    cache_server: String,
    table_name: String,
    table_url: String,
}

impl LiquidCacheTableBuilder {
    /// Create a new builder for a [LiquidCacheTable].
    ///
    /// # Arguments
    ///
    /// * `cache_server` - The address of the cache server
    /// * `table_name` - The name of the table
    /// * `table_url` - The url of the table
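    ///
    /// # Example
    ///
    /// ```rust
    /// use liquid_cache_client::LiquidCacheTableBuilder;
    ///
    /// // Illustrative endpoint and paths.
    /// let builder = LiquidCacheTableBuilder::new(
    ///     "http://localhost:50051",
    ///     "my_table",
    ///     "s3://my_bucket/my_table.parquet",
    /// );
    /// ```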
    pub fn new(
        cache_server: impl AsRef<str>,
        table_name: impl AsRef<str>,
        table_url: impl AsRef<str>,
    ) -> Self {
        Self {
            driver: Arc::new(FlightSqlDriver::default()),
            object_stores: vec![],
            cache_mode: CacheMode::Liquid,
            cache_server: cache_server.as_ref().to_string(),
            table_name: table_name.as_ref().to_string(),
            table_url: table_url.as_ref().to_string(),
        }
    }

    /// Add an object store to the builder.
    ///
    /// # Arguments
    ///
    /// * `url` - The url of the object store
    /// * `object_store_options` - The options for the object store
    ///
    /// # Example
    ///
    /// ```rust
    /// use liquid_cache_client::LiquidCacheTableBuilder;
    /// use std::collections::HashMap;
    ///
    /// let builder = LiquidCacheTableBuilder::new("localhost:50051", "my_table", "my_table_url")
    ///     .with_object_store("s3://my_bucket", Some(HashMap::new()));
    /// // Call `builder.build().await` to obtain the table.
    /// ```
    pub fn with_object_store(
        mut self,
        url: impl AsRef<str>,
        object_store_options: Option<HashMap<String, String>>,
    ) -> Self {
        self.object_stores.push((
            url.as_ref().to_string(),
            object_store_options.unwrap_or_default(),
        ));
        self
    }

    /// Set the cache mode for the builder.
    ///
    /// # Arguments
    ///
    /// * `cache_mode` - The cache mode to use
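    ///
    /// # Example
    ///
    /// ```rust
    /// use liquid_cache_client::LiquidCacheTableBuilder;
    /// // `CacheMode` comes from the `liquid_cache_common` dependency.
    /// use liquid_cache_common::CacheMode;
    ///
    /// let builder = LiquidCacheTableBuilder::new("localhost:50051", "my_table", "my_table_url")
    ///     .with_cache_mode(CacheMode::Liquid);
    /// ```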
    pub fn with_cache_mode(mut self, cache_mode: CacheMode) -> Self {
        self.cache_mode = cache_mode;
        self
    }

    /// Build the [LiquidCacheTable].
    ///
    /// This contacts the cache server to register any object stores and fetch the table metadata.
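    ///
    /// # Example
    ///
    /// ```no_run
    /// # use liquid_cache_client::LiquidCacheTableBuilder;
    /// # async fn example() -> datafusion::error::Result<()> {
    /// // Illustrative values; requires a reachable cache server.
    /// let table = LiquidCacheTableBuilder::new(
    ///     "http://localhost:50051",
    ///     "my_table",
    ///     "s3://my_bucket/my_table.parquet",
    /// )
    /// .build()
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```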
    pub async fn build(self) -> Result<LiquidCacheTable> {
        let channel = flight_channel(&self.cache_server).await?;
        for (url, object_store_options) in self.object_stores {
            self.driver
                .register_object_store(channel.clone(), &url, object_store_options)
                .await
                .map_err(to_df_err)?;
        }

        let metadata = self
            .driver
            .metadata(
                channel.clone(),
                &self.table_name,
                &self.table_url,
                self.cache_mode,
            )
            .await
            .map_err(to_df_err)?;
        let num_rows = precision(metadata.info.total_records);
        let total_byte_size = precision(metadata.info.total_bytes);
        let output_schema = Arc::new(transform_flight_schema_to_output_schema(&metadata.schema));
        let flight_schema = metadata.schema;
        let stats = Statistics {
            num_rows,
            total_byte_size,
            column_statistics: vec![],
        };
        Ok(LiquidCacheTable {
            driver: self.driver.clone(),
            channel,
            origin: self.cache_server,
            table_name: self.table_name.into(),
            flight_schema,
            output_schema,
            stats,
        })
    }
}

/// The information that a [FlightSqlDriver] must produce
/// in order to register flights as DataFusion tables.
#[derive(Clone, Debug)]
pub(crate) struct FlightMetadata {
    /// FlightInfo object produced by the driver
    pub(crate) info: FlightInfo,
    /// Various knobs that control execution
    pub(crate) props: FlightProperties,
    /// Arrow schema. Can be enforced by the driver or inferred from the FlightInfo
    pub(crate) schema: SchemaRef,
}

impl FlightMetadata {
    /// Customize everything that is in the driver's control
    pub fn new(info: FlightInfo, props: FlightProperties, schema: SchemaRef) -> Self {
        Self {
            info,
            props,
            schema,
        }
    }

    /// Customize flight properties and try to use the FlightInfo schema
    // TODO: fix from upstream
    #[allow(clippy::result_large_err)]
    pub fn try_new(info: FlightInfo, props: FlightProperties) -> arrow_flight::error::Result<Self> {
        let schema = Arc::new(info.clone().try_decode_schema()?);
        Ok(Self::new(info, props, schema))
    }
}

impl TryFrom<FlightInfo> for FlightMetadata {
    type Error = FlightError;

    fn try_from(info: FlightInfo) -> Result<Self, Self::Error> {
        Self::try_new(info, FlightProperties::default())
    }
}

/// Meant to gradually encapsulate all sorts of knobs required
/// for controlling the protocol and query execution details.
#[derive(Clone, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub(crate) struct FlightProperties {
    pub(crate) unbounded_stream: bool,
    pub(crate) grpc_headers: HashMap<String, String>,
}

impl FlightProperties {
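    /// Set the gRPC headers attached to every Flight request.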
    pub fn grpc_headers(mut self, grpc_headers: HashMap<String, String>) -> Self {
        self.grpc_headers = grpc_headers;
        self
    }
}

/// Table provider that wraps a specific flight from an Arrow Flight service.
///
/// Constructed via [LiquidCacheTableBuilder]; see its docs for a usage example.
#[derive(Debug)]
pub struct LiquidCacheTable {
    driver: Arc<FlightSqlDriver>,
    channel: Channel,
    origin: String,
    table_name: TableReference,
    /// The flight schema is the schema used on the wire.
    flight_schema: SchemaRef,
    /// The output schema is the schema we expose to users.
    /// The flight schema is optimized for transmission only, so a schema
    /// adapter converts batches back to the output schema.
    output_schema: SchemaRef,
    stats: Statistics,
}

#[async_trait]
impl TableProvider for LiquidCacheTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn table_type(&self) -> TableType {
        TableType::View
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        projection: Option<&Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let unparsed_sql = {
            // The actual source doesn't matter for unparsing the SQL;
            // an empty table with the right schema is enough.
            let empty_table_provider = EmptyTable::new(self.schema().clone());
            let table_source = Arc::new(DefaultTableSource::new(Arc::new(empty_table_provider)));

            let logical_plan = TableScan {
                table_name: self.table_name.clone(),
                source: table_source,
                projection: projection.map(|p| p.to_vec()),
                filters: filters.to_vec(),
                fetch: limit,
                projected_schema: Arc::new(self.schema().as_ref().clone().to_dfschema().unwrap()),
            };
            let unparser = Unparser::new(&PostgreSqlDialect {});
            let unparsed_sql = unparser
                .plan_to_sql(&LogicalPlan::TableScan(logical_plan))
                .unwrap();
            unparsed_sql.to_string()
        };

        info!("SQL sent to cache: \n{}", unparsed_sql.cyan());

        let metadata = self
            .driver
            .run_sql(self.channel.clone(), &unparsed_sql)
            .await
            .map_err(to_df_err)?;

        Ok(Arc::new(FlightExec::try_new(
            self.flight_schema.clone(),
            self.output_schema.clone(),
            metadata,
            projection,
            &self.origin,
            limit,
        )?))
    }

    fn statistics(&self) -> Option<Statistics> {
        Some(self.stats.clone())
    }

    fn supports_filters_pushdown(
        &self,
        filters: &[&Expr],
    ) -> Result<Vec<TableProviderFilterPushDown>> {
        // A filter can be pushed down exactly when it can be unparsed back to SQL.
        let filter_push_down: Vec<TableProviderFilterPushDown> = filters
            .iter()
            .map(
                |f| match Unparser::new(&PostgreSqlDialect {}).expr_to_sql(f) {
                    Ok(_) => TableProviderFilterPushDown::Exact,
                    Err(_) => TableProviderFilterPushDown::Unsupported,
                },
            )
            .collect();

        Ok(filter_push_down)
    }
}

/// Wrap any error as a [DataFusionError::External].
pub(crate) fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
    DataFusionError::External(Box::new(err))
}

/// Open a gRPC channel to the cache server.
pub(crate) async fn flight_channel(source: impl Into<String>) -> Result<Channel> {
    // No TLS here: we assume both server and client run on a trusted network,
    // so we skip the TLS overhead.
    let endpoint = Channel::from_shared(source.into())
        .map_err(to_df_err)?
        .tcp_keepalive(Some(Duration::from_secs(10)));
    endpoint.connect().await.map_err(to_df_err)
}

/// Convert a row/byte count from [FlightInfo] into a [Precision].
/// Negative values indicate an unknown total.
fn precision(total: i64) -> Precision<usize> {
    if total < 0 {
        Precision::Absent
    } else {
        Precision::Exact(total as usize)
    }
}