liquid_cache_server/
lib.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#![warn(missing_docs)]
19#![doc = include_str!(concat!("../", std::env!("CARGO_PKG_README")))]
20
21use arrow::ipc::writer::IpcWriteOptions;
22use arrow_flight::{
23    Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
24    IpcMessage, SchemaAsIpc, Ticket,
25    encode::{DictionaryHandling, FlightDataEncoderBuilder},
26    flight_descriptor::DescriptorType,
27    flight_service_server::FlightService,
28    sql::{
29        Any, CommandGetDbSchemas, CommandPreparedStatementUpdate, CommandStatementQuery,
30        ProstMessageExt, SqlInfo,
31        server::{FlightSqlService, PeekableFlightDataStream},
32    },
33};
34use datafusion::{
35    error::DataFusionError,
36    execution::{SessionStateBuilder, object_store::ObjectStoreUrl},
37    physical_plan::{ExecutionPlan, ExecutionPlanProperties},
38    prelude::{SessionConfig, SessionContext},
39};
40use futures::{Stream, TryStreamExt};
41use liquid_cache_common::{
42    CacheMode,
43    rpc::{FetchResults, LiquidCacheActions},
44};
45use liquid_cache_parquet::LiquidCacheRef;
46use log::info;
47use prost::Message;
48use prost::bytes::Bytes;
49use service::LiquidCacheServiceInner;
50use std::{
51    path::PathBuf,
52    sync::{Arc, atomic::AtomicU64},
53};
54use std::{pin::Pin, str::FromStr};
55use tonic::{Request, Response, Status, Streaming};
56use url::Url;
57mod service;
58mod utils;
59use utils::FinalStream;
60mod local_cache;
61
/// A trait to collect stats for the execution plan.
///
/// The server calls `start` right before polling the stream,
/// and calls `stop` right after exhausting the stream.
///
/// Implementations must be `Send + Sync`: the service stores collectors as
/// `Arc<dyn StatsCollector>` and hands a copy of the list to every result
/// stream it creates.
pub trait StatsCollector: Send + Sync {
    /// Start the stats collector.
    ///
    /// * `partition` - The partition of the plan that is being fetched
    /// * `plan` - The execution plan registered for the request
    fn start(&self, partition: usize, plan: &Arc<dyn ExecutionPlan>);
    /// Stop the stats collector.
    ///
    /// * `partition` - The partition of the plan that was fetched
    /// * `plan` - The execution plan registered for the request
    fn stop(&self, partition: usize, plan: &Arc<dyn ExecutionPlan>);
}
71
/// The LiquidCache server.
///
/// Wraps a `LiquidCacheServiceInner` and exposes it through the Flight SQL
/// service interface, optionally notifying registered [`StatsCollector`]s
/// about every result stream it serves.
///
/// # Example
///
/// ```rust
/// use arrow_flight::flight_service_server::FlightServiceServer;
/// use datafusion::prelude::SessionContext;
/// use liquid_cache_server::LiquidCacheService;
/// use tonic::transport::Server;
/// let liquid_cache = LiquidCacheService::new(SessionContext::new(), None, None);
/// let flight = FlightServiceServer::new(liquid_cache);
/// Server::builder()
///     .add_service(flight)
///     .serve("0.0.0.0:50051".parse().unwrap());
/// ```
pub struct LiquidCacheService {
    /// Core service state, built in `new` from the session context, the
    /// optional memory budget and the optional disk cache directory.
    inner: LiquidCacheServiceInner,
    /// Collectors registered through `add_stats_collector`; the list is cloned
    /// into every result stream created by `do_get_fallback`.
    stats_collector: Vec<Arc<dyn StatsCollector>>,
    /// Counter used to hand out execution ids (plan handles); starts at 0 and
    /// is incremented by one for each id handed out.
    next_execution_id: AtomicU64,
    /// Handle of the execution most recently served by `do_get_fallback`; the
    /// `ExecutionMetrics` action reports the metrics of this execution.
    most_recent_execution_id: AtomicU64,
}
93
94impl Default for LiquidCacheService {
95    fn default() -> Self {
96        Self::try_new().unwrap()
97    }
98}
99
100impl LiquidCacheService {
101    /// Create a new LiquidCacheService with a default SessionContext
102    /// With no disk cache and unbounded memory usage.
103    pub fn try_new() -> Result<Self, DataFusionError> {
104        let ctx = Self::context(None)?;
105        Ok(Self::new(ctx, None, None))
106    }
107
108    /// Create a new LiquidCacheService with a custom SessionContext
109    ///
110    /// # Arguments
111    ///
112    /// * `ctx` - The SessionContext to use
113    /// * `max_cache_bytes` - The maximum number of bytes to cache in memory
114    /// * `disk_cache_dir` - The directory to store the disk cache
115    pub fn new(
116        ctx: SessionContext,
117        max_cache_bytes: Option<usize>,
118        disk_cache_dir: Option<PathBuf>,
119    ) -> Self {
120        Self {
121            inner: LiquidCacheServiceInner::new(Arc::new(ctx), max_cache_bytes, disk_cache_dir),
122            stats_collector: vec![],
123            next_execution_id: AtomicU64::new(0),
124            most_recent_execution_id: AtomicU64::new(0),
125        }
126    }
127
128    /// Get a reference to the cache
129    pub fn cache(&self) -> &LiquidCacheRef {
130        self.inner.cache()
131    }
132
133    /// Add a stats collector to the service
134    pub fn add_stats_collector(&mut self, collector: Arc<dyn StatsCollector>) {
135        self.stats_collector.push(collector);
136    }
137
138    /// Create a new SessionContext with good defaults
139    /// This is the recommended way to create a SessionContext for LiquidCache
140    pub fn context(partitions: Option<usize>) -> Result<SessionContext, DataFusionError> {
141        let mut session_config = SessionConfig::from_env()?;
142        let options_mut = session_config.options_mut();
143        options_mut.execution.parquet.pushdown_filters = true;
144        options_mut.execution.parquet.binary_as_string = true;
145
146        {
147            // view types cause excessive memory usage because they are not gced.
148            // For Arrow memory mode, we need to read as UTF-8
149            // For Liquid cache, we have our own way of handling string columns
150            options_mut.execution.parquet.schema_force_view_types = false;
151        }
152
153        if let Some(partitions) = partitions {
154            options_mut.execution.target_partitions = partitions;
155        }
156
157        let object_store_url = ObjectStoreUrl::parse("file://").unwrap();
158        let object_store = object_store::local::LocalFileSystem::new();
159
160        let state = SessionStateBuilder::new()
161            .with_config(session_config)
162            .with_default_features()
163            .with_object_store(object_store_url.as_ref(), Arc::new(object_store))
164            .build();
165
166        let ctx = SessionContext::new_with_state(state);
167        Ok(ctx)
168    }
169
170    fn get_next_execution_id(&self) -> u64 {
171        self.next_execution_id
172            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
173    }
174}
175
176#[tonic::async_trait]
177impl FlightSqlService for LiquidCacheService {
178    type FlightService = LiquidCacheService;
179
180    async fn do_handshake(
181        &self,
182        _request: Request<Streaming<HandshakeRequest>>,
183    ) -> Result<
184        Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
185        Status,
186    > {
187        unimplemented!("We don't do handshake")
188    }
189
190    async fn get_flight_info_schemas(
191        &self,
192        query: CommandGetDbSchemas,
193        _request: Request<FlightDescriptor>,
194    ) -> Result<Response<FlightInfo>, Status> {
195        let table_name = query
196            .db_schema_filter_pattern
197            .ok_or(Status::invalid_argument(
198                "db_schema_filter_pattern is required",
199            ))?;
200        let schema = self.inner.get_table_schema(&table_name).await?;
201
202        let mut info = FlightInfo::new();
203        info.schema = encode_schema_to_ipc_bytes(&schema);
204
205        Ok(Response::new(info))
206    }
207
208    async fn do_get_fallback(
209        &self,
210        _request: Request<Ticket>,
211        message: Any,
212    ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
213        if !message.is::<FetchResults>() {
214            Err(Status::unimplemented(format!(
215                "do_get: The defined request is invalid: {}",
216                message.type_url
217            )))?
218        }
219
220        let fetch_results: FetchResults = message
221            .unpack()
222            .map_err(|e| Status::internal(format!("{e:?}")))?
223            .ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?;
224
225        let handle = fetch_results.handle;
226        let partition = fetch_results.partition as usize;
227        let stream = self.inner.execute_plan(handle, partition).await;
228        let execution_plan = self.inner.get_plan(handle).unwrap();
229        let stream = FinalStream::new(
230            stream,
231            self.stats_collector.clone(),
232            self.inner.batch_size(),
233            partition,
234            execution_plan,
235        )
236        .map_err(|e| {
237            panic!("Error executing plan: {:?}", e);
238        });
239
240        let ipc_options = IpcWriteOptions::default();
241        let stream = FlightDataEncoderBuilder::new()
242            .with_options(ipc_options)
243            .with_dictionary_handling(DictionaryHandling::Resend)
244            .build(stream)
245            .map_err(Status::from);
246        self.most_recent_execution_id
247            .store(handle, std::sync::atomic::Ordering::Relaxed);
248
249        Ok(Response::new(Box::pin(stream)))
250    }
251
252    async fn get_flight_info_statement(
253        &self,
254        cmd: CommandStatementQuery,
255        _request: Request<FlightDescriptor>,
256    ) -> Result<Response<FlightInfo>, Status> {
257        let user_query = cmd.query.as_str();
258        let handle = self.get_next_execution_id();
259        let physical_plan = self
260            .inner
261            .prepare_and_register_plan(user_query, handle)
262            .await?;
263        let partition_count = physical_plan.output_partitioning().partition_count();
264
265        let schema = physical_plan.schema();
266
267        let flight_desc = FlightDescriptor {
268            r#type: DescriptorType::Cmd.into(),
269            cmd: Default::default(),
270            path: vec![],
271        };
272
273        let mut info = FlightInfo::new().with_descriptor(flight_desc);
274        info.schema = encode_schema_to_ipc_bytes(&schema);
275
276        for partition in 0..partition_count {
277            let fetch = FetchResults {
278                handle,
279                partition: partition as u32,
280            };
281            let buf = fetch.as_any().encode_to_vec().into();
282            let ticket = Ticket { ticket: buf };
283            let endpoint = FlightEndpoint::new().with_ticket(ticket.clone());
284            info = info.with_endpoint(endpoint);
285        }
286
287        let resp = Response::new(info);
288        Ok(resp)
289    }
290
291    async fn do_put_prepared_statement_update(
292        &self,
293        _handle: CommandPreparedStatementUpdate,
294        _request: Request<PeekableFlightDataStream>,
295    ) -> Result<i64, Status> {
296        info!("do_put_prepared_statement_update");
297        // statements like "CREATE TABLE.." or "SET datafusion.nnn.." call this function
298        // and we are required to return some row count here
299        Ok(-1)
300    }
301
302    async fn do_action_fallback(
303        &self,
304        request: Request<Action>,
305    ) -> Result<Response<<Self as FlightService>::DoActionStream>, Status> {
306        let action = LiquidCacheActions::from(request.into_inner());
307        match action {
308            LiquidCacheActions::RegisterObjectStore(cmd) => {
309                self.inner
310                    .register_object_store(&Url::parse(&cmd.url).unwrap(), cmd.options)
311                    .await
312                    .map_err(df_error_to_status)?;
313
314                let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
315                    body: Bytes::default(),
316                })]);
317                return Ok(Response::new(Box::pin(output)));
318            }
319            LiquidCacheActions::RegisterTable(cmd) => {
320                let parquet_mode = CacheMode::from_str(&cmd.cache_mode).unwrap();
321                self.inner
322                    .register_table(&cmd.url, &cmd.table_name, parquet_mode)
323                    .await
324                    .map_err(df_error_to_status)?;
325
326                let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
327                    body: Bytes::default(),
328                })]);
329                return Ok(Response::new(Box::pin(output)));
330            }
331            LiquidCacheActions::ExecutionMetrics => {
332                let execution_id = self
333                    .most_recent_execution_id
334                    .load(std::sync::atomic::Ordering::Relaxed);
335                let response = self.inner.get_metrics(execution_id).unwrap();
336                let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
337                    body: response.as_any().encode_to_vec().into(),
338                })]);
339                return Ok(Response::new(Box::pin(output)));
340            }
341            LiquidCacheActions::ResetCache => {
342                self.inner.cache().reset();
343
344                let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
345                    body: Bytes::default(),
346                })]);
347                return Ok(Response::new(Box::pin(output)));
348            }
349        }
350    }
351
352    async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
353}
354
355fn df_error_to_status(err: datafusion::error::DataFusionError) -> Status {
356    Status::internal(format!("{err:?}"))
357}
358
359// TODO: we need to workaround a arrow-flight bug here:
360// https://github.com/apache/arrow-rs/issues/7058
361fn encode_schema_to_ipc_bytes(schema: &arrow_schema::Schema) -> Bytes {
362    let options = IpcWriteOptions::default();
363    let schema_as_ipc = SchemaAsIpc::new(schema, &options);
364    let IpcMessage(schema) = schema_as_ipc.try_into().unwrap();
365    schema
366}