// liquid_cache_client/client_exec.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::task::{Context, Poll};
4use std::time::Duration;
5use std::{any::Any, fmt::Formatter, sync::Arc};
6
7use arrow::array::RecordBatch;
8use arrow_flight::decode::FlightRecordBatchStream;
9use arrow_flight::flight_service_client::FlightServiceClient;
10use arrow_schema::SchemaRef;
11use datafusion::common::Statistics;
12use datafusion::config::ConfigOptions;
13use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaMapper};
14use datafusion::execution::object_store::ObjectStoreUrl;
15use datafusion::physical_plan::execution_plan::CardinalityEffect;
16use datafusion::physical_plan::filter_pushdown::{
17    ChildPushdownResult, FilterDescription, FilterPushdownPhase, FilterPushdownPropagation,
18};
19use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
20use datafusion::physical_plan::{ExecutionPlanProperties, PhysicalExpr, PlanProperties};
21use datafusion::{
22    error::Result,
23    execution::{RecordBatchStream, SendableRecordBatchStream},
24    physical_plan::{
25        DisplayAs, DisplayFormatType, ExecutionPlan, stream::RecordBatchStreamAdapter,
26    },
27};
28use datafusion_proto::bytes::physical_plan_to_bytes;
29use fastrace::Span;
30use fastrace::future::FutureExt;
31use fastrace::prelude::*;
32use futures::{Stream, TryStreamExt, future::BoxFuture, ready};
33use liquid_cache_common::rpc::{
34    FetchResults, LiquidCacheActions, RegisterObjectStoreRequest, RegisterPlanRequest,
35};
36use tonic::Request;
37use uuid::Uuid;
38
39use crate::metrics::FlightStreamMetrics;
40use crate::{flight_channel, to_df_err};
41
/// Registration state of the remote plan on the cache server.
///
/// Stored as a `usize` inside an `AtomicUsize` that all partitions of one
/// `LiquidCacheClientExec` share, so exactly one partition performs the
/// registration while the others wait for it.
#[repr(usize)]
enum PlanRegisterState {
    /// No partition has started registering the plan yet.
    NotRegistered = 0,
    /// A partition is currently registering the object stores and the plan.
    InProgress = 1,
    /// Registration finished; partitions may fetch their results.
    Registered = 2,
}
48
49/// The execution plan for the LiquidCache client.
50pub struct LiquidCacheClientExec {
51    remote_plan: Arc<dyn ExecutionPlan>,
52    cache_server: String,
53    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
54    metrics: ExecutionPlanMetricsSet,
55    uuid: Uuid,
56    plan_registered: Arc<AtomicUsize>,
57    properties: PlanProperties,
58}
59
60impl std::fmt::Debug for LiquidCacheClientExec {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        write!(f, "LiquidCacheClientExec")
63    }
64}
65
66impl LiquidCacheClientExec {
67    pub(crate) fn new(
68        remote_plan: Arc<dyn ExecutionPlan>,
69        cache_server: String,
70        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
71    ) -> Self {
72        let properties = PlanProperties::new(
73            remote_plan.equivalence_properties().clone(), // Equivalence Properties
74            remote_plan.output_partitioning().clone(),    // Output Partitioning
75            remote_plan.pipeline_behavior(),
76            remote_plan.boundedness(),
77        );
78        let uuid = Uuid::new_v4();
79        Self {
80            remote_plan,
81            cache_server,
82            plan_registered: Arc::new(AtomicUsize::new(PlanRegisterState::NotRegistered as usize)),
83            object_stores,
84            uuid,
85            metrics: ExecutionPlanMetricsSet::new(),
86            properties,
87        }
88    }
89
90    /// Get the UUID of the plan.
91    pub fn get_uuid(&self) -> Uuid {
92        self.uuid
93    }
94}
95
96impl DisplayAs for LiquidCacheClientExec {
97    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result {
98        match t {
99            DisplayFormatType::Default | DisplayFormatType::Verbose => {
100                write!(
101                    f,
102                    "LiquidCacheClientExec: server={}, object_stores={:?}",
103                    self.cache_server, self.object_stores
104                )
105            }
106            DisplayFormatType::TreeRender => {
107                write!(
108                    f,
109                    "server={}, object_stores={:?}",
110                    self.cache_server, self.object_stores
111                )
112            }
113        }
114    }
115}
116
117impl ExecutionPlan for LiquidCacheClientExec {
118    fn as_any(&self) -> &dyn Any {
119        self
120    }
121
122    fn name(&self) -> &str {
123        "LiquidCacheClientExec"
124    }
125
126    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
127        self.remote_plan.properties()
128    }
129
130    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
131        vec![&self.remote_plan]
132    }
133
134    fn with_new_children(
135        self: Arc<Self>,
136        children: Vec<Arc<dyn ExecutionPlan>>,
137    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
138        Ok(Arc::new(Self {
139            remote_plan: children.first().unwrap().clone(),
140            cache_server: self.cache_server.clone(),
141            plan_registered: self.plan_registered.clone(),
142            object_stores: self.object_stores.clone(),
143            metrics: self.metrics.clone(),
144            uuid: self.uuid,
145            properties: self.properties.clone(),
146        }))
147    }
148
149    fn execute(
150        &self,
151        partition: usize,
152        context: Arc<datafusion::execution::TaskContext>,
153    ) -> datafusion::error::Result<datafusion::execution::SendableRecordBatchStream> {
154        let cache_server = self.cache_server.clone();
155        let plan = self.remote_plan.clone();
156        let lock = self.plan_registered.clone();
157        let stream_metrics = FlightStreamMetrics::new(&self.metrics, partition);
158
159        let span = context
160            .session_config()
161            .get_extension::<Span>()
162            .unwrap_or_default();
163        let exec_span = Span::enter_with_parent("exec_flight_stream", &span);
164        let create_stream_span = Span::enter_with_parent("create_flight_stream", &exec_span);
165        let stream = flight_stream(
166            cache_server,
167            plan,
168            lock,
169            self.uuid,
170            partition,
171            self.object_stores.clone(),
172        );
173        Ok(Box::pin(FlightStream::new(
174            Some(Box::pin(stream)),
175            self.remote_plan.schema().clone(),
176            stream_metrics,
177            exec_span,
178            create_stream_span,
179        )))
180    }
181
182    fn statistics(&self) -> Result<Statistics> {
183        self.remote_plan.partition_statistics(None)
184    }
185
186    fn supports_limit_pushdown(&self) -> bool {
187        self.remote_plan.supports_limit_pushdown()
188    }
189
190    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
191        self.remote_plan.with_fetch(limit)
192    }
193
194    fn fetch(&self) -> Option<usize> {
195        self.remote_plan.fetch()
196    }
197
198    fn metrics(&self) -> Option<MetricsSet> {
199        Some(self.metrics.clone_inner())
200    }
201
202    fn maintains_input_order(&self) -> Vec<bool> {
203        vec![true]
204    }
205
206    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
207        vec![false]
208    }
209
210    fn cardinality_effect(&self) -> CardinalityEffect {
211        CardinalityEffect::Equal
212    }
213
214    fn gather_filters_for_pushdown(
215        &self,
216        _phase: FilterPushdownPhase,
217        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
218        _config: &ConfigOptions,
219    ) -> Result<FilterDescription> {
220        FilterDescription::from_children(parent_filters, &self.children())
221    }
222
223    fn handle_child_pushdown_result(
224        &self,
225        _phase: FilterPushdownPhase,
226        child_pushdown_result: ChildPushdownResult,
227        _config: &ConfigOptions,
228    ) -> Result<FilterPushdownPropagation<Arc<dyn ExecutionPlan>>> {
229        Ok(FilterPushdownPropagation::if_all(child_pushdown_result))
230    }
231}
232
233async fn flight_stream(
234    server: String,
235    plan: Arc<dyn ExecutionPlan>,
236    plan_registered: Arc<AtomicUsize>,
237    handle: Uuid,
238    partition: usize,
239    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
240) -> Result<SendableRecordBatchStream> {
241    let channel = flight_channel(server)
242        .in_span(Span::enter_with_local_parent("connect_channel"))
243        .await?;
244
245    let mut client = FlightServiceClient::new(channel).max_decoding_message_size(1024 * 1024 * 8);
246    let schema = plan.schema().clone();
247
248    match plan_registered.compare_exchange(
249        PlanRegisterState::NotRegistered as usize,
250        PlanRegisterState::InProgress as usize,
251        Ordering::AcqRel,
252        Ordering::Relaxed,
253    ) {
254        Ok(_) => {
255            LocalSpan::add_event(Event::new("register_plan"));
256
257            for (url, options) in &object_stores {
258                let action = LiquidCacheActions::RegisterObjectStore(RegisterObjectStoreRequest {
259                    url: url.to_string(),
260                    options: options.clone(),
261                })
262                .into();
263                client
264                    .do_action(Request::new(action))
265                    .await
266                    .map_err(to_df_err)?;
267            }
268            // Register plan
269            let plan_bytes = physical_plan_to_bytes(plan)?;
270            let action = LiquidCacheActions::RegisterPlan(RegisterPlanRequest {
271                plan: plan_bytes.to_vec(),
272                handle: handle.into_bytes().to_vec().into(),
273            })
274            .into();
275            client
276                .do_action(Request::new(action))
277                .await
278                .map_err(to_df_err)?;
279            plan_registered.store(PlanRegisterState::Registered as usize, Ordering::Release);
280            LocalSpan::add_event(Event::new("register_plan_done"));
281        }
282        Err(_e) => {
283            LocalSpan::add_event(Event::new("getting_existing_plan"));
284            while plan_registered.load(Ordering::Acquire) != PlanRegisterState::Registered as usize
285            {
286                tokio::time::sleep(Duration::from_micros(100)).await;
287            }
288            LocalSpan::add_event(Event::new("got_existing_plan"));
289        }
290    };
291
292    let current = SpanContext::current_local_parent().unwrap_or_else(SpanContext::random);
293
294    let fetch_results = FetchResults {
295        handle: handle.into_bytes().to_vec().into(),
296        partition: partition as u32,
297        traceparent: current.encode_w3c_traceparent(),
298    };
299    let ticket = fetch_results.into_ticket();
300    let (md, response_stream, _ext) = client.do_get(ticket).await.map_err(to_df_err)?.into_parts();
301    LocalSpan::add_event(Event::new("get_flight_stream"));
302    let stream =
303        FlightRecordBatchStream::new_from_flight_data(response_stream.map_err(|e| e.into()))
304            .with_headers(md)
305            .map_err(to_df_err);
306    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
307}
308
309enum FlightStreamState {
310    Init,
311    GetStream(BoxFuture<'static, Result<SendableRecordBatchStream>>),
312    Processing(SendableRecordBatchStream),
313}
314
315struct FlightStream {
316    future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
317    state: FlightStreamState,
318    schema: SchemaRef,
319    schema_mapper: Option<Arc<dyn SchemaMapper>>,
320    metrics: FlightStreamMetrics,
321    poll_stream_span: fastrace::Span,
322    create_stream_span: Option<fastrace::Span>,
323}
324
325impl FlightStream {
326    fn new(
327        future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
328        schema: SchemaRef,
329        metrics: FlightStreamMetrics,
330        poll_stream_span: fastrace::Span,
331        create_stream_span: fastrace::Span,
332    ) -> Self {
333        Self {
334            future_stream,
335            state: FlightStreamState::Init,
336            schema,
337            schema_mapper: None,
338            metrics,
339            poll_stream_span,
340            create_stream_span: Some(create_stream_span),
341        }
342    }
343}
344
345use futures::StreamExt;
346impl FlightStream {
347    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
348        loop {
349            match &mut self.state {
350                FlightStreamState::Init => {
351                    self.metrics.time_reading_total.start();
352                    self.state = FlightStreamState::GetStream(self.future_stream.take().unwrap());
353                    continue;
354                }
355                FlightStreamState::GetStream(fut) => {
356                    let _guard = self.create_stream_span.as_ref().unwrap().set_local_parent();
357                    let stream = ready!(fut.as_mut().poll(cx)).unwrap();
358                    self.create_stream_span.take();
359                    self.state = FlightStreamState::Processing(stream);
360                    continue;
361                }
362                FlightStreamState::Processing(stream) => {
363                    let result = stream.poll_next_unpin(cx);
364                    self.metrics.poll_count.add(1);
365                    return result;
366                }
367            }
368        }
369    }
370}
371
372impl Stream for FlightStream {
373    type Item = Result<RecordBatch>;
374
375    fn poll_next(
376        mut self: std::pin::Pin<&mut Self>,
377        cx: &mut std::task::Context<'_>,
378    ) -> std::task::Poll<Option<Self::Item>> {
379        let _guard = self.poll_stream_span.set_local_parent();
380        self.metrics.time_processing.start();
381        let result = self.poll_inner(cx);
382        match result {
383            Poll::Ready(Some(Ok(batch))) => {
384                let coerced_batch = if let Some(schema_mapper) = &self.schema_mapper {
385                    schema_mapper.map_batch(batch).unwrap()
386                } else {
387                    let (schema_mapper, _) =
388                        DefaultSchemaAdapterFactory::from_schema(self.schema.clone())
389                            .map_schema(&batch.schema())
390                            .unwrap();
391                    let batch = schema_mapper.map_batch(batch).unwrap();
392
393                    self.schema_mapper = Some(schema_mapper);
394                    batch
395                };
396                self.metrics.output_rows.add(coerced_batch.num_rows());
397                self.metrics
398                    .bytes_decoded
399                    .add(coerced_batch.get_array_memory_size());
400                self.metrics.time_processing.stop();
401                LocalSpan::add_event(Event::new("emit_batch"));
402                Poll::Ready(Some(Ok(coerced_batch)))
403            }
404            Poll::Ready(None) => {
405                self.metrics.time_processing.stop();
406                self.metrics.time_reading_total.stop();
407                Poll::Ready(None)
408            }
409            Poll::Ready(Some(Err(e))) => {
410                panic!("Error in flight stream: {e:?}");
411            }
412            Poll::Pending => {
413                self.metrics.time_processing.stop();
414                Poll::Pending
415            }
416        }
417    }
418}
419
420impl RecordBatchStream for FlightStream {
421    fn schema(&self) -> SchemaRef {
422        self.schema.clone()
423    }
424}