liquid_cache_client/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(not(doctest), doc = include_str!(concat!("../", std::env!("CARGO_PKG_README"))))]
3use std::collections::HashMap;
4use std::error::Error;
5use std::sync::Arc;
6use std::time::Duration;
7mod client_exec;
8mod metrics;
9mod optimizer;
10pub use client_exec::LiquidCacheClientExec;
11use datafusion::{
12    error::{DataFusionError, Result},
13    execution::{SessionStateBuilder, object_store::ObjectStoreUrl, runtime_env::RuntimeEnv},
14    prelude::*,
15};
16use fastrace_tonic::FastraceClientService;
17use liquid_cache_common::CacheMode;
18pub use optimizer::PushdownOptimizer;
19use tonic::transport::Channel;
20
21#[cfg(test)]
22mod tests;
23
/// The builder for LiquidCache client state.
///
/// Collects the cache server address, the cache mode and the object stores
/// (with their options), and turns them into a [SessionContext] through
/// [LiquidCacheBuilder::build].
///
/// # Example
///
/// ```ignore
/// use datafusion::execution::object_store::ObjectStoreUrl;
/// use datafusion::prelude::SessionConfig;
/// use liquid_cache_client::LiquidCacheBuilder;
/// use liquid_cache_common::CacheMode;
///
/// let ctx = LiquidCacheBuilder::new("localhost:15214")
///     .with_object_store(ObjectStoreUrl::parse("s3://my_bucket").unwrap(), None)
///     .with_cache_mode(CacheMode::Liquid)
///     .build(SessionConfig::from_env().unwrap())
///     .unwrap();
///
/// ctx.register_parquet("my_table", "s3://my_bucket/my_table.parquet", Default::default())
///     .await?;
/// // `show()` prints the result set itself and yields `()`.
/// ctx.sql("SELECT * FROM my_table").await?.show().await?;
/// ```
pub struct LiquidCacheBuilder {
    /// Object store URLs paired with their option maps, in the order they were added.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    /// Cache mode handed to the `PushdownOptimizer` when the context is built.
    cache_mode: CacheMode,
    /// Cache server address, stored as an owned copy of the string given to `new`.
    cache_server: String,
}
46
47impl LiquidCacheBuilder {
48    /// Create a new builder for LiquidCache client state.
49    pub fn new(cache_server: impl AsRef<str>) -> Self {
50        Self {
51            object_stores: vec![],
52            cache_mode: CacheMode::Liquid,
53            cache_server: cache_server.as_ref().to_string(),
54        }
55    }
56
57    /// Add an object store to the builder.
58    pub fn with_object_store(
59        mut self,
60        url: ObjectStoreUrl,
61        object_store_options: Option<HashMap<String, String>>,
62    ) -> Self {
63        self.object_stores
64            .push((url, object_store_options.unwrap_or_default()));
65        self
66    }
67
68    /// Set the cache mode for the builder.
69    pub fn with_cache_mode(mut self, cache_mode: CacheMode) -> Self {
70        self.cache_mode = cache_mode;
71        self
72    }
73
74    /// Build the [SessionContext].
75    pub fn build(self, config: SessionConfig) -> Result<SessionContext> {
76        let mut session_config = config;
77        session_config
78            .options_mut()
79            .execution
80            .parquet
81            .pushdown_filters = true;
82        session_config
83            .options_mut()
84            .execution
85            .parquet
86            .schema_force_view_types = false;
87        session_config
88            .options_mut()
89            .execution
90            .parquet
91            .binary_as_string = true;
92        session_config.options_mut().execution.batch_size = 8192 * 2;
93        let session_state = SessionStateBuilder::new()
94            .with_config(session_config)
95            .with_runtime_env(Arc::new(RuntimeEnv::default()))
96            .with_default_features()
97            .with_physical_optimizer_rule(Arc::new(PushdownOptimizer::new(
98                self.cache_server.clone(),
99                self.cache_mode,
100                self.object_stores.clone(),
101            )))
102            .build();
103        Ok(SessionContext::new_with_state(session_state))
104    }
105}
106
107pub(crate) fn to_df_err<E: Error + Send + Sync + 'static>(err: E) -> DataFusionError {
108    DataFusionError::External(Box::new(err))
109}
110
111pub(crate) async fn flight_channel(
112    source: impl Into<String>,
113) -> Result<FastraceClientService<Channel>> {
114    use fastrace_tonic::FastraceClientLayer;
115    use tower::ServiceBuilder;
116
117    // No tls here, to avoid the overhead of TLS
118    // we assume both server and client are running on the trusted network.
119    let endpoint = Channel::from_shared(source.into())
120        .map_err(to_df_err)?
121        .tcp_keepalive(Some(Duration::from_secs(10)));
122
123    let channel = endpoint.connect().await.map_err(to_df_err)?;
124    let channel = ServiceBuilder::new()
125        .layer(FastraceClientLayer)
126        .service(channel);
127    Ok(channel)
128}