liquid_cache_client/
client_exec.rs

1use std::collections::HashMap;
2use std::task::{Context, Poll};
3use std::{any::Any, fmt::Formatter, sync::Arc};
4
5use arrow::array::RecordBatch;
6use arrow_flight::decode::FlightRecordBatchStream;
7use arrow_flight::error::FlightError;
8use arrow_flight::flight_service_client::FlightServiceClient;
9use arrow_schema::SchemaRef;
10use datafusion::common::Statistics;
11use datafusion::config::ConfigOptions;
12use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaMapper};
13use datafusion::execution::object_store::ObjectStoreUrl;
14use datafusion::physical_plan::Distribution;
15use datafusion::physical_plan::execution_plan::CardinalityEffect;
16use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
17use datafusion::physical_plan::projection::ProjectionExec;
18use datafusion::{
19    error::Result,
20    execution::{RecordBatchStream, SendableRecordBatchStream},
21    physical_plan::{
22        DisplayAs, DisplayFormatType, ExecutionPlan, stream::RecordBatchStreamAdapter,
23    },
24};
25use datafusion_proto::bytes::physical_plan_to_bytes;
26use fastrace::Span;
27use fastrace::future::FutureExt;
28use fastrace::prelude::*;
29use futures::{Stream, TryStreamExt, future::BoxFuture, ready};
30use liquid_cache_common::CacheMode;
31use liquid_cache_common::rpc::{
32    FetchResults, LiquidCacheActions, RegisterObjectStoreRequest, RegisterPlanRequest,
33};
34use tokio::sync::Mutex;
35use tonic::Request;
36use uuid::Uuid;
37
38use crate::metrics::FlightStreamMetrics;
39use crate::{flight_channel, to_df_err};
40
/// The execution plan for the LiquidCache client.
///
/// The node wraps another physical plan (`remote_plan`). When executed, that
/// plan is serialized, registered with the cache server, and its results are
/// read back over an Arrow Flight stream.
pub struct LiquidCacheClientExec {
    /// Plan that is serialized and registered with the cache server; it is
    /// also reported as this node's only child.
    remote_plan: Arc<dyn ExecutionPlan>,
    /// Address handed to `flight_channel` when connecting to the cache server.
    cache_server: String,
    /// `None` until the plan has been registered, then the handle it was
    /// registered under. The async mutex makes concurrently executed
    /// partitions wait so that only one of them performs the registration.
    plan_register_lock: Arc<Mutex<Option<Uuid>>>,
    /// Cache mode sent along with the plan registration request.
    cache_mode: CacheMode,
    /// Object stores (URL plus options) registered with the server before the
    /// plan itself is registered.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    /// Metrics set from which the per-partition stream metrics are created.
    metrics: ExecutionPlanMetricsSet,
}
50
51impl std::fmt::Debug for LiquidCacheClientExec {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        write!(f, "LiquidCacheClientExec")
54    }
55}
56
57impl LiquidCacheClientExec {
58    pub(crate) fn new(
59        remote_plan: Arc<dyn ExecutionPlan>,
60        cache_server: String,
61        cache_mode: CacheMode,
62        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
63    ) -> Self {
64        Self {
65            remote_plan,
66            cache_server,
67            plan_register_lock: Arc::new(Mutex::new(None)),
68            cache_mode,
69            object_stores,
70            metrics: ExecutionPlanMetricsSet::new(),
71        }
72    }
73
74    /// Get the UUID of the plan.
75    pub async fn get_plan_uuid(&self) -> Option<Uuid> {
76        *self.plan_register_lock.lock().await
77    }
78}
79
80impl DisplayAs for LiquidCacheClientExec {
81    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result {
82        match t {
83            DisplayFormatType::Default | DisplayFormatType::Verbose => {
84                write!(
85                    f,
86                    "LiquidCacheClientExec: server={}, mode={}, object_stores={:?}",
87                    self.cache_server, self.cache_mode, self.object_stores
88                )
89            }
90        }
91    }
92}
93
94impl ExecutionPlan for LiquidCacheClientExec {
95    fn as_any(&self) -> &dyn Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "LiquidCacheClientExec"
101    }
102
103    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
104        self.remote_plan.properties()
105    }
106
107    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
108        vec![&self.remote_plan]
109    }
110
111    fn with_new_children(
112        self: Arc<Self>,
113        children: Vec<Arc<dyn ExecutionPlan>>,
114    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
115        Ok(Arc::new(Self {
116            remote_plan: children.first().unwrap().clone(),
117            cache_server: self.cache_server.clone(),
118            plan_register_lock: self.plan_register_lock.clone(),
119            cache_mode: self.cache_mode,
120            object_stores: self.object_stores.clone(),
121            metrics: self.metrics.clone(),
122        }))
123    }
124
125    fn execute(
126        &self,
127        partition: usize,
128        context: Arc<datafusion::execution::TaskContext>,
129    ) -> datafusion::error::Result<datafusion::execution::SendableRecordBatchStream> {
130        let cache_server = self.cache_server.clone();
131        let plan = self.remote_plan.clone();
132        let lock = self.plan_register_lock.clone();
133        let stream_metrics = FlightStreamMetrics::new(&self.metrics, partition);
134
135        let span = context
136            .session_config()
137            .get_extension::<Span>()
138            .unwrap_or_default();
139        let exec_span = Span::enter_with_parent("exec_flight_stream", &span);
140        let create_stream_span = Span::enter_with_parent("create_flight_stream", &exec_span);
141        let stream = flight_stream(
142            cache_server,
143            plan,
144            lock,
145            partition,
146            self.object_stores.clone(),
147            self.cache_mode,
148        );
149        Ok(Box::pin(FlightStream::new(
150            Some(Box::pin(stream)),
151            self.remote_plan.schema().clone(),
152            stream_metrics,
153            exec_span,
154            create_stream_span,
155        )))
156    }
157
158    fn required_input_distribution(&self) -> Vec<Distribution> {
159        self.remote_plan.required_input_distribution()
160    }
161
162    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
163        self.remote_plan.benefits_from_input_partitioning()
164    }
165
166    fn repartitioned(
167        &self,
168        target_partitions: usize,
169        config: &ConfigOptions,
170    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
171        self.remote_plan.repartitioned(target_partitions, config)
172    }
173
174    fn statistics(&self) -> Result<Statistics> {
175        self.remote_plan.statistics()
176    }
177
178    fn supports_limit_pushdown(&self) -> bool {
179        self.remote_plan.supports_limit_pushdown()
180    }
181
182    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
183        self.remote_plan.with_fetch(limit)
184    }
185
186    fn fetch(&self) -> Option<usize> {
187        self.remote_plan.fetch()
188    }
189
190    fn cardinality_effect(&self) -> CardinalityEffect {
191        self.remote_plan.cardinality_effect()
192    }
193
194    fn try_swapping_with_projection(
195        &self,
196        projection: &ProjectionExec,
197    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
198        self.remote_plan.try_swapping_with_projection(projection)
199    }
200
201    fn metrics(&self) -> Option<MetricsSet> {
202        Some(self.metrics.clone_inner())
203    }
204}
205
206async fn flight_stream(
207    server: String,
208    plan: Arc<dyn ExecutionPlan>,
209    plan_register_lock: Arc<Mutex<Option<Uuid>>>,
210    partition: usize,
211    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
212    cache_mode: CacheMode,
213) -> Result<SendableRecordBatchStream> {
214    let channel = flight_channel(server)
215        .in_span(Span::enter_with_local_parent("connect_channel"))
216        .await?;
217
218    let mut client = FlightServiceClient::new(channel);
219    let schema = plan.schema().clone();
220
221    // Only one partition needs to register the plan
222    let handle = {
223        let _span = Span::enter_with_local_parent("register_plan");
224        let mut maybe_uuid = plan_register_lock.lock().await;
225        match maybe_uuid.as_ref() {
226            Some(uuid) => {
227                LocalSpan::add_event(Event::new("get_existing_plan"));
228                *uuid
229            }
230            None => {
231                // Register object stores
232                LocalSpan::add_event(Event::new("locked_register_plan"));
233                for (url, options) in &object_stores {
234                    let action =
235                        LiquidCacheActions::RegisterObjectStore(RegisterObjectStoreRequest {
236                            url: url.to_string(),
237                            options: options.clone(),
238                        })
239                        .into();
240                    client
241                        .do_action(Request::new(action))
242                        .await
243                        .map_err(to_df_err)?;
244                }
245                // Register plan
246                let plan_bytes = physical_plan_to_bytes(plan)?;
247                let handle = Uuid::new_v4();
248                let action = LiquidCacheActions::RegisterPlan(RegisterPlanRequest {
249                    plan: plan_bytes.to_vec(),
250                    handle: handle.into_bytes().to_vec().into(),
251                    cache_mode: cache_mode.to_string(),
252                })
253                .into();
254                client
255                    .do_action(Request::new(action))
256                    .await
257                    .map_err(to_df_err)?;
258                *maybe_uuid = Some(handle);
259                LocalSpan::add_event(Event::new("unlocked_register_plan"));
260                handle
261            }
262        }
263    };
264
265    let current = SpanContext::current_local_parent().unwrap_or_else(SpanContext::random);
266
267    let fetch_results = FetchResults {
268        handle: handle.into_bytes().to_vec().into(),
269        partition: partition as u32,
270        traceparent: current.encode_w3c_traceparent(),
271    };
272    let ticket = fetch_results.into_ticket();
273    let (md, response_stream, _ext) = client.do_get(ticket).await.map_err(to_df_err)?.into_parts();
274    LocalSpan::add_event(Event::new("get_flight_stream"));
275    let stream =
276        FlightRecordBatchStream::new_from_flight_data(response_stream.map_err(FlightError::Tonic))
277            .with_headers(md)
278            .map_err(to_df_err);
279    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
280}
281
/// Progress of a [`FlightStream`] from creation to forwarding batches.
enum FlightStreamState {
    /// Nothing has been polled yet; the first poll moves the pending future
    /// out of `FlightStream::future_stream` and switches to `GetStream`.
    Init,
    /// The future that produces the remote record batch stream is being polled.
    GetStream(BoxFuture<'static, Result<SendableRecordBatchStream>>),
    /// The remote stream is available and its batches are forwarded.
    Processing(SendableRecordBatchStream),
}
287
/// Record batch stream that lazily resolves a future into the remote stream
/// and then forwards its batches, coerced to `schema`.
struct FlightStream {
    /// Future yielding the remote stream; taken out on the first poll.
    future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
    /// Current position in the Init -> GetStream -> Processing progression.
    state: FlightStreamState,
    /// Schema reported by `RecordBatchStream::schema` and used as the target
    /// when mapping incoming batches.
    schema: SchemaRef,
    /// Mapper from the incoming batch schema to `schema`; built from the first
    /// batch and reused for the following ones.
    schema_mapper: Option<Arc<dyn SchemaMapper>>,
    /// Timers and counters updated while polling.
    metrics: FlightStreamMetrics,
    /// Span set as local parent for the duration of every `poll_next` call.
    poll_stream_span: fastrace::Span,
    /// Span set as local parent while the stream-creating future is polled;
    /// dropped once that future has resolved.
    create_stream_span: Option<fastrace::Span>,
}
297
298impl FlightStream {
299    fn new(
300        future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
301        schema: SchemaRef,
302        metrics: FlightStreamMetrics,
303        poll_stream_span: fastrace::Span,
304        create_stream_span: fastrace::Span,
305    ) -> Self {
306        Self {
307            future_stream,
308            state: FlightStreamState::Init,
309            schema,
310            schema_mapper: None,
311            metrics,
312            poll_stream_span,
313            create_stream_span: Some(create_stream_span),
314        }
315    }
316}
317
318use futures::StreamExt;
319impl FlightStream {
320    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
321        loop {
322            match &mut self.state {
323                FlightStreamState::Init => {
324                    self.metrics.time_reading_total.start();
325                    self.state = FlightStreamState::GetStream(self.future_stream.take().unwrap());
326                    continue;
327                }
328                FlightStreamState::GetStream(fut) => {
329                    let _guard = self.create_stream_span.as_ref().unwrap().set_local_parent();
330                    let stream = ready!(fut.as_mut().poll(cx)).unwrap();
331                    self.create_stream_span.take();
332                    self.state = FlightStreamState::Processing(stream);
333                    continue;
334                }
335                FlightStreamState::Processing(stream) => {
336                    let result = stream.poll_next_unpin(cx);
337                    self.metrics.poll_count.add(1);
338                    return result;
339                }
340            }
341        }
342    }
343}
344
345impl Stream for FlightStream {
346    type Item = Result<RecordBatch>;
347
348    fn poll_next(
349        mut self: std::pin::Pin<&mut Self>,
350        cx: &mut std::task::Context<'_>,
351    ) -> std::task::Poll<Option<Self::Item>> {
352        let _guard = self.poll_stream_span.set_local_parent();
353        self.metrics.time_processing.start();
354        let result = self.poll_inner(cx);
355        match result {
356            Poll::Ready(Some(Ok(batch))) => {
357                let coerced_batch = if let Some(schema_mapper) = &self.schema_mapper {
358                    schema_mapper.map_batch(batch).unwrap()
359                } else {
360                    let (schema_mapper, _) =
361                        DefaultSchemaAdapterFactory::from_schema(self.schema.clone())
362                            .map_schema(&batch.schema())
363                            .unwrap();
364                    let batch = schema_mapper.map_batch(batch).unwrap();
365
366                    self.schema_mapper = Some(schema_mapper);
367                    batch
368                };
369                self.metrics.output_rows.add(coerced_batch.num_rows());
370                self.metrics
371                    .bytes_decoded
372                    .add(coerced_batch.get_array_memory_size());
373                self.metrics.time_processing.stop();
374                LocalSpan::add_event(Event::new("emit_batch"));
375                Poll::Ready(Some(Ok(coerced_batch)))
376            }
377            Poll::Ready(None) => {
378                self.metrics.time_processing.stop();
379                self.metrics.time_reading_total.stop();
380                Poll::Ready(None)
381            }
382            Poll::Ready(Some(Err(e))) => {
383                panic!("Error in flight stream: {:?}", e);
384            }
385            Poll::Pending => {
386                self.metrics.time_processing.stop();
387                Poll::Pending
388            }
389        }
390    }
391}
392
393impl RecordBatchStream for FlightStream {
394    fn schema(&self) -> SchemaRef {
395        self.schema.clone()
396    }
397}