1#![warn(missing_docs)]
19#![doc = include_str!(concat!("../", std::env!("CARGO_PKG_README")))]
20
21use arrow::ipc::writer::IpcWriteOptions;
22use arrow_flight::{
23 Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse,
24 IpcMessage, SchemaAsIpc, Ticket,
25 encode::{DictionaryHandling, FlightDataEncoderBuilder},
26 flight_descriptor::DescriptorType,
27 flight_service_server::FlightService,
28 sql::{
29 Any, CommandGetDbSchemas, CommandPreparedStatementUpdate, CommandStatementQuery,
30 ProstMessageExt, SqlInfo,
31 server::{FlightSqlService, PeekableFlightDataStream},
32 },
33};
34use datafusion::{
35 error::DataFusionError,
36 execution::{SessionStateBuilder, object_store::ObjectStoreUrl},
37 physical_plan::{ExecutionPlan, ExecutionPlanProperties},
38 prelude::{SessionConfig, SessionContext},
39};
40use futures::{Stream, TryStreamExt};
41use liquid_cache_common::{
42 CacheMode,
43 rpc::{FetchResults, LiquidCacheActions},
44};
45use liquid_cache_parquet::LiquidCacheRef;
46use log::info;
47use prost::Message;
48use prost::bytes::Bytes;
49use service::LiquidCacheServiceInner;
50use std::{
51 path::PathBuf,
52 sync::{Arc, atomic::AtomicU64},
53};
54use std::{pin::Pin, str::FromStr};
55use tonic::{Request, Response, Status, Streaming};
56use url::Url;
57mod service;
58mod utils;
59use utils::FinalStream;
60mod local_cache;
61
62pub trait StatsCollector: Send + Sync {
66 fn start(&self, partition: usize, plan: &Arc<dyn ExecutionPlan>);
68 fn stop(&self, partition: usize, plan: &Arc<dyn ExecutionPlan>);
70}
71
72pub struct LiquidCacheService {
88 inner: LiquidCacheServiceInner,
89 stats_collector: Vec<Arc<dyn StatsCollector>>,
90 next_execution_id: AtomicU64,
91 most_recent_execution_id: AtomicU64,
92}
93
94impl Default for LiquidCacheService {
95 fn default() -> Self {
96 Self::try_new().unwrap()
97 }
98}
99
100impl LiquidCacheService {
101 pub fn try_new() -> Result<Self, DataFusionError> {
104 let ctx = Self::context(None)?;
105 Ok(Self::new(ctx, None, None))
106 }
107
108 pub fn new(
116 ctx: SessionContext,
117 max_cache_bytes: Option<usize>,
118 disk_cache_dir: Option<PathBuf>,
119 ) -> Self {
120 Self {
121 inner: LiquidCacheServiceInner::new(Arc::new(ctx), max_cache_bytes, disk_cache_dir),
122 stats_collector: vec![],
123 next_execution_id: AtomicU64::new(0),
124 most_recent_execution_id: AtomicU64::new(0),
125 }
126 }
127
128 pub fn cache(&self) -> &LiquidCacheRef {
130 self.inner.cache()
131 }
132
133 pub fn add_stats_collector(&mut self, collector: Arc<dyn StatsCollector>) {
135 self.stats_collector.push(collector);
136 }
137
138 pub fn context(partitions: Option<usize>) -> Result<SessionContext, DataFusionError> {
141 let mut session_config = SessionConfig::from_env()?;
142 let options_mut = session_config.options_mut();
143 options_mut.execution.parquet.pushdown_filters = true;
144 options_mut.execution.parquet.binary_as_string = true;
145
146 {
147 options_mut.execution.parquet.schema_force_view_types = false;
151 }
152
153 if let Some(partitions) = partitions {
154 options_mut.execution.target_partitions = partitions;
155 }
156
157 let object_store_url = ObjectStoreUrl::parse("file://").unwrap();
158 let object_store = object_store::local::LocalFileSystem::new();
159
160 let state = SessionStateBuilder::new()
161 .with_config(session_config)
162 .with_default_features()
163 .with_object_store(object_store_url.as_ref(), Arc::new(object_store))
164 .build();
165
166 let ctx = SessionContext::new_with_state(state);
167 Ok(ctx)
168 }
169
170 fn get_next_execution_id(&self) -> u64 {
171 self.next_execution_id
172 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
173 }
174}
175
176#[tonic::async_trait]
177impl FlightSqlService for LiquidCacheService {
178 type FlightService = LiquidCacheService;
179
180 async fn do_handshake(
181 &self,
182 _request: Request<Streaming<HandshakeRequest>>,
183 ) -> Result<
184 Response<Pin<Box<dyn Stream<Item = Result<HandshakeResponse, Status>> + Send>>>,
185 Status,
186 > {
187 unimplemented!("We don't do handshake")
188 }
189
190 async fn get_flight_info_schemas(
191 &self,
192 query: CommandGetDbSchemas,
193 _request: Request<FlightDescriptor>,
194 ) -> Result<Response<FlightInfo>, Status> {
195 let table_name = query
196 .db_schema_filter_pattern
197 .ok_or(Status::invalid_argument(
198 "db_schema_filter_pattern is required",
199 ))?;
200 let schema = self.inner.get_table_schema(&table_name).await?;
201
202 let mut info = FlightInfo::new();
203 info.schema = encode_schema_to_ipc_bytes(&schema);
204
205 Ok(Response::new(info))
206 }
207
208 async fn do_get_fallback(
209 &self,
210 _request: Request<Ticket>,
211 message: Any,
212 ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
213 if !message.is::<FetchResults>() {
214 Err(Status::unimplemented(format!(
215 "do_get: The defined request is invalid: {}",
216 message.type_url
217 )))?
218 }
219
220 let fetch_results: FetchResults = message
221 .unpack()
222 .map_err(|e| Status::internal(format!("{e:?}")))?
223 .ok_or_else(|| Status::internal("Expected FetchResults but got None!"))?;
224
225 let handle = fetch_results.handle;
226 let partition = fetch_results.partition as usize;
227 let stream = self.inner.execute_plan(handle, partition).await;
228 let execution_plan = self.inner.get_plan(handle).unwrap();
229 let stream = FinalStream::new(
230 stream,
231 self.stats_collector.clone(),
232 self.inner.batch_size(),
233 partition,
234 execution_plan,
235 )
236 .map_err(|e| {
237 panic!("Error executing plan: {:?}", e);
238 });
239
240 let ipc_options = IpcWriteOptions::default();
241 let stream = FlightDataEncoderBuilder::new()
242 .with_options(ipc_options)
243 .with_dictionary_handling(DictionaryHandling::Resend)
244 .build(stream)
245 .map_err(Status::from);
246 self.most_recent_execution_id
247 .store(handle, std::sync::atomic::Ordering::Relaxed);
248
249 Ok(Response::new(Box::pin(stream)))
250 }
251
252 async fn get_flight_info_statement(
253 &self,
254 cmd: CommandStatementQuery,
255 _request: Request<FlightDescriptor>,
256 ) -> Result<Response<FlightInfo>, Status> {
257 let user_query = cmd.query.as_str();
258 let handle = self.get_next_execution_id();
259 let physical_plan = self
260 .inner
261 .prepare_and_register_plan(user_query, handle)
262 .await?;
263 let partition_count = physical_plan.output_partitioning().partition_count();
264
265 let schema = physical_plan.schema();
266
267 let flight_desc = FlightDescriptor {
268 r#type: DescriptorType::Cmd.into(),
269 cmd: Default::default(),
270 path: vec![],
271 };
272
273 let mut info = FlightInfo::new().with_descriptor(flight_desc);
274 info.schema = encode_schema_to_ipc_bytes(&schema);
275
276 for partition in 0..partition_count {
277 let fetch = FetchResults {
278 handle,
279 partition: partition as u32,
280 };
281 let buf = fetch.as_any().encode_to_vec().into();
282 let ticket = Ticket { ticket: buf };
283 let endpoint = FlightEndpoint::new().with_ticket(ticket.clone());
284 info = info.with_endpoint(endpoint);
285 }
286
287 let resp = Response::new(info);
288 Ok(resp)
289 }
290
291 async fn do_put_prepared_statement_update(
292 &self,
293 _handle: CommandPreparedStatementUpdate,
294 _request: Request<PeekableFlightDataStream>,
295 ) -> Result<i64, Status> {
296 info!("do_put_prepared_statement_update");
297 Ok(-1)
300 }
301
302 async fn do_action_fallback(
303 &self,
304 request: Request<Action>,
305 ) -> Result<Response<<Self as FlightService>::DoActionStream>, Status> {
306 let action = LiquidCacheActions::from(request.into_inner());
307 match action {
308 LiquidCacheActions::RegisterObjectStore(cmd) => {
309 self.inner
310 .register_object_store(&Url::parse(&cmd.url).unwrap(), cmd.options)
311 .await
312 .map_err(df_error_to_status)?;
313
314 let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
315 body: Bytes::default(),
316 })]);
317 return Ok(Response::new(Box::pin(output)));
318 }
319 LiquidCacheActions::RegisterTable(cmd) => {
320 let parquet_mode = CacheMode::from_str(&cmd.cache_mode).unwrap();
321 self.inner
322 .register_table(&cmd.url, &cmd.table_name, parquet_mode)
323 .await
324 .map_err(df_error_to_status)?;
325
326 let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
327 body: Bytes::default(),
328 })]);
329 return Ok(Response::new(Box::pin(output)));
330 }
331 LiquidCacheActions::ExecutionMetrics => {
332 let execution_id = self
333 .most_recent_execution_id
334 .load(std::sync::atomic::Ordering::Relaxed);
335 let response = self.inner.get_metrics(execution_id).unwrap();
336 let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
337 body: response.as_any().encode_to_vec().into(),
338 })]);
339 return Ok(Response::new(Box::pin(output)));
340 }
341 LiquidCacheActions::ResetCache => {
342 self.inner.cache().reset();
343
344 let output = futures::stream::iter(vec![Ok(arrow_flight::Result {
345 body: Bytes::default(),
346 })]);
347 return Ok(Response::new(Box::pin(output)));
348 }
349 }
350 }
351
352 async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
353}
354
355fn df_error_to_status(err: datafusion::error::DataFusionError) -> Status {
356 Status::internal(format!("{err:?}"))
357}
358
359fn encode_schema_to_ipc_bytes(schema: &arrow_schema::Schema) -> Bytes {
362 let options = IpcWriteOptions::default();
363 let schema_as_ipc = SchemaAsIpc::new(schema, &options);
364 let IpcMessage(schema) = schema_as_ipc.try_into().unwrap();
365 schema
366}