use std::collections::HashMap;
use std::task::{Context, Poll};
use std::{any::Any, fmt::Formatter, sync::Arc};

use arrow::array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_schema::SchemaRef;
use datafusion::common::Statistics;
use datafusion::config::ConfigOptions;
use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaMapper};
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::physical_plan::Distribution;
use datafusion::physical_plan::execution_plan::CardinalityEffect;
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::{
    error::Result,
    execution::{RecordBatchStream, SendableRecordBatchStream},
    physical_plan::{
        DisplayAs, DisplayFormatType, ExecutionPlan, stream::RecordBatchStreamAdapter,
    },
};
use datafusion_proto::bytes::physical_plan_to_bytes;
use fastrace::Span;
use fastrace::future::FutureExt;
use fastrace::prelude::*;
use futures::{Stream, TryStreamExt, future::BoxFuture, ready};
use liquid_cache_common::CacheMode;
use liquid_cache_common::rpc::{
    FetchResults, LiquidCacheActions, RegisterObjectStoreRequest, RegisterPlanRequest,
};
use tokio::sync::Mutex;
use tonic::Request;
use uuid::Uuid;

use crate::metrics::FlightStreamMetrics;
use crate::{flight_channel, to_df_err};

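/// Client-side [`ExecutionPlan`] that ships the wrapped physical plan to a
/// LiquidCache server and streams the resulting record batches back over
/// Arrow Flight.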
pub struct LiquidCacheClientExec {
    remote_plan: Arc<dyn ExecutionPlan>,
    cache_server: String,
    plan_register_lock: Arc<Mutex<Option<Uuid>>>,
    cache_mode: CacheMode,
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    metrics: ExecutionPlanMetricsSet,
}

impl std::fmt::Debug for LiquidCacheClientExec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "LiquidCacheClientExec")
    }
}

impl LiquidCacheClientExec {
    /// Create a new client exec node that executes `remote_plan` on the given
    /// cache server with the given cache mode and object store configurations.
    pub(crate) fn new(
        remote_plan: Arc<dyn ExecutionPlan>,
        cache_server: String,
        cache_mode: CacheMode,
        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    ) -> Self {
        Self {
            remote_plan,
            cache_server,
            plan_register_lock: Arc::new(Mutex::new(None)),
            cache_mode,
            object_stores,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

    /// The UUID under which the plan was registered on the server, if
    /// registration has already happened.
    pub async fn get_plan_uuid(&self) -> Option<Uuid> {
        *self.plan_register_lock.lock().await
    }
}

impl DisplayAs for LiquidCacheClientExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "LiquidCacheClientExec: server={}, mode={}, object_stores={:?}",
                    self.cache_server, self.cache_mode, self.object_stores
                )
            }
        }
    }
}

impl ExecutionPlan for LiquidCacheClientExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "LiquidCacheClientExec"
    }

    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
        self.remote_plan.properties()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.remote_plan]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(Self {
            remote_plan: children.first().unwrap().clone(),
            cache_server: self.cache_server.clone(),
            plan_register_lock: self.plan_register_lock.clone(),
            cache_mode: self.cache_mode,
            object_stores: self.object_stores.clone(),
            metrics: self.metrics.clone(),
        }))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::TaskContext>,
    ) -> datafusion::error::Result<datafusion::execution::SendableRecordBatchStream> {
        let cache_server = self.cache_server.clone();
        let plan = self.remote_plan.clone();
        let lock = self.plan_register_lock.clone();
        let stream_metrics = FlightStreamMetrics::new(&self.metrics, partition);

        // Attach to the tracing span propagated through the session config, if any.
        let span = context
            .session_config()
            .get_extension::<Span>()
            .unwrap_or_default();
        let exec_span = Span::enter_with_parent("exec_flight_stream", &span);
        let create_stream_span = Span::enter_with_parent("create_flight_stream", &exec_span);
        let stream = flight_stream(
            cache_server,
            plan,
            lock,
            partition,
            self.object_stores.clone(),
            self.cache_mode,
        );
        Ok(Box::pin(FlightStream::new(
            Some(Box::pin(stream)),
            self.remote_plan.schema(),
            stream_metrics,
            exec_span,
            create_stream_span,
        )))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        self.remote_plan.required_input_distribution()
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        self.remote_plan.benefits_from_input_partitioning()
    }

    fn repartitioned(
        &self,
        target_partitions: usize,
        config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        self.remote_plan.repartitioned(target_partitions, config)
    }

    fn statistics(&self) -> Result<Statistics> {
        self.remote_plan.statistics()
    }

    fn supports_limit_pushdown(&self) -> bool {
        self.remote_plan.supports_limit_pushdown()
    }

    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        self.remote_plan.with_fetch(limit)
    }

    fn fetch(&self) -> Option<usize> {
        self.remote_plan.fetch()
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        self.remote_plan.cardinality_effect()
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        self.remote_plan.try_swapping_with_projection(projection)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}

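/// Connects to the cache server, registers the object stores and the plan
/// (once per `LiquidCacheClientExec`, guarded by `plan_register_lock`), then
/// issues a `do_get` for the given partition and returns the decoded stream.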
async fn flight_stream(
    server: String,
    plan: Arc<dyn ExecutionPlan>,
    plan_register_lock: Arc<Mutex<Option<Uuid>>>,
    partition: usize,
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    cache_mode: CacheMode,
) -> Result<SendableRecordBatchStream> {
    let channel = flight_channel(server)
        .in_span(Span::enter_with_local_parent("connect_channel"))
        .await?;

    let mut client = FlightServiceClient::new(channel);
    let schema = plan.schema();

    let handle = {
        let _span = Span::enter_with_local_parent("register_plan");
        let mut maybe_uuid = plan_register_lock.lock().await;
        match maybe_uuid.as_ref() {
            Some(uuid) => {
                // Another partition already registered the plan; reuse its handle.
                LocalSpan::add_event(Event::new("get_existing_plan"));
                *uuid
            }
            None => {
                LocalSpan::add_event(Event::new("locked_register_plan"));
                // Register every object store the plan needs before registering the plan itself.
                for (url, options) in &object_stores {
                    let action =
                        LiquidCacheActions::RegisterObjectStore(RegisterObjectStoreRequest {
                            url: url.to_string(),
                            options: options.clone(),
                        })
                        .into();
                    client
                        .do_action(Request::new(action))
                        .await
                        .map_err(to_df_err)?;
                }
                // Serialize the physical plan and register it under a fresh UUID.
                let plan_bytes = physical_plan_to_bytes(plan)?;
                let handle = Uuid::new_v4();
                let action = LiquidCacheActions::RegisterPlan(RegisterPlanRequest {
                    plan: plan_bytes.to_vec(),
                    handle: handle.into_bytes().to_vec().into(),
                    cache_mode: cache_mode.to_string(),
                })
                .into();
                client
                    .do_action(Request::new(action))
                    .await
                    .map_err(to_df_err)?;
                *maybe_uuid = Some(handle);
                LocalSpan::add_event(Event::new("unlocked_register_plan"));
                handle
            }
        }
    };

    let current = SpanContext::current_local_parent().unwrap_or_else(SpanContext::random);

    let fetch_results = FetchResults {
        handle: handle.into_bytes().to_vec().into(),
        partition: partition as u32,
        traceparent: current.encode_w3c_traceparent(),
    };
    let ticket = fetch_results.into_ticket();
    let (md, response_stream, _ext) = client.do_get(ticket).await.map_err(to_df_err)?.into_parts();
    LocalSpan::add_event(Event::new("get_flight_stream"));
    let stream =
        FlightRecordBatchStream::new_from_flight_data(response_stream.map_err(FlightError::Tonic))
            .with_headers(md)
            .map_err(to_df_err);
    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}

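/// Polling state for [`FlightStream`]: first await the connection/registration
/// future, then forward batches from the resulting stream.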
enum FlightStreamState {
    Init,
    GetStream(BoxFuture<'static, Result<SendableRecordBatchStream>>),
    Processing(SendableRecordBatchStream),
}

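/// A [`RecordBatchStream`] that lazily opens the Flight stream on first poll,
/// coerces each incoming batch to the plan schema, and records metrics.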
struct FlightStream {
    future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
    state: FlightStreamState,
    schema: SchemaRef,
    schema_mapper: Option<Arc<dyn SchemaMapper>>,
    metrics: FlightStreamMetrics,
    poll_stream_span: fastrace::Span,
    create_stream_span: Option<fastrace::Span>,
}

impl FlightStream {
    fn new(
        future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
        schema: SchemaRef,
        metrics: FlightStreamMetrics,
        poll_stream_span: fastrace::Span,
        create_stream_span: fastrace::Span,
    ) -> Self {
        Self {
            future_stream,
            state: FlightStreamState::Init,
            schema,
            schema_mapper: None,
            metrics,
            poll_stream_span,
            create_stream_span: Some(create_stream_span),
        }
    }
}

use futures::StreamExt;

impl FlightStream {
    /// Drive the state machine: kick off the stream-creation future on the
    /// first poll, then forward batches from the underlying stream.
    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            match &mut self.state {
                FlightStreamState::Init => {
                    self.metrics.time_reading_total.start();
                    self.state = FlightStreamState::GetStream(self.future_stream.take().unwrap());
                    continue;
                }
                FlightStreamState::GetStream(fut) => {
                    let _guard = self.create_stream_span.as_ref().unwrap().set_local_parent();
                    let stream = match ready!(fut.as_mut().poll(cx)) {
                        Ok(stream) => stream,
                        // Surface connection/registration errors instead of panicking.
                        Err(e) => return Poll::Ready(Some(Err(e))),
                    };
                    self.create_stream_span.take();
                    self.state = FlightStreamState::Processing(stream);
                    continue;
                }
                FlightStreamState::Processing(stream) => {
                    let result = stream.poll_next_unpin(cx);
                    self.metrics.poll_count.add(1);
                    return result;
                }
            }
        }
    }
}

impl Stream for FlightStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let _guard = self.poll_stream_span.set_local_parent();
        self.metrics.time_processing.start();
        let result = self.poll_inner(cx);
        match result {
            Poll::Ready(Some(Ok(batch))) => {
                // Coerce the incoming batch to the plan schema, building the
                // schema mapper lazily from the first batch.
                let coerced_batch = if let Some(schema_mapper) = &self.schema_mapper {
                    schema_mapper.map_batch(batch).unwrap()
                } else {
                    let (schema_mapper, _) =
                        DefaultSchemaAdapterFactory::from_schema(self.schema.clone())
                            .map_schema(&batch.schema())
                            .unwrap();
                    let batch = schema_mapper.map_batch(batch).unwrap();

                    self.schema_mapper = Some(schema_mapper);
                    batch
                };
                self.metrics.output_rows.add(coerced_batch.num_rows());
                self.metrics
                    .bytes_decoded
                    .add(coerced_batch.get_array_memory_size());
                self.metrics.time_processing.stop();
                LocalSpan::add_event(Event::new("emit_batch"));
                Poll::Ready(Some(Ok(coerced_batch)))
            }
            Poll::Ready(None) => {
                self.metrics.time_processing.stop();
                self.metrics.time_reading_total.stop();
                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(e))) => {
                // Propagate stream errors to the caller rather than panicking.
                self.metrics.time_processing.stop();
                Poll::Ready(Some(Err(e)))
            }
            Poll::Pending => {
                self.metrics.time_processing.stop();
                Poll::Pending
            }
        }
    }
}

impl RecordBatchStream for FlightStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}