liquid_cache_client/
client_exec.rs

1use std::collections::HashMap;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::task::{Context, Poll};
4use std::time::Duration;
5use std::{any::Any, fmt::Formatter, sync::Arc};
6
7use arrow::array::RecordBatch;
8use arrow_flight::decode::FlightRecordBatchStream;
9use arrow_flight::flight_service_client::FlightServiceClient;
10use arrow_schema::SchemaRef;
11use datafusion::common::Statistics;
12use datafusion::config::ConfigOptions;
13use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaMapper};
14use datafusion::execution::object_store::ObjectStoreUrl;
15use datafusion::physical_plan::Distribution;
16use datafusion::physical_plan::execution_plan::CardinalityEffect;
17use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::projection::ProjectionExec;
19use datafusion::{
20    error::Result,
21    execution::{RecordBatchStream, SendableRecordBatchStream},
22    physical_plan::{
23        DisplayAs, DisplayFormatType, ExecutionPlan, stream::RecordBatchStreamAdapter,
24    },
25};
26use datafusion_proto::bytes::physical_plan_to_bytes;
27use fastrace::Span;
28use fastrace::future::FutureExt;
29use fastrace::prelude::*;
30use futures::{Stream, TryStreamExt, future::BoxFuture, ready};
31use liquid_cache_common::CacheMode;
32use liquid_cache_common::rpc::{
33    FetchResults, LiquidCacheActions, RegisterObjectStoreRequest, RegisterPlanRequest,
34};
35use tonic::Request;
36use uuid::Uuid;
37
38use crate::metrics::FlightStreamMetrics;
39use crate::{flight_channel, to_df_err};
40
/// Registration progress of the remote plan.
///
/// The value is stored as a `usize` inside the `AtomicUsize` that every
/// partition stream of one `LiquidCacheClientExec` shares.
#[repr(usize)]
enum PlanRegisterState {
    /// Initial state: no partition has started registering the plan.
    NotRegistered = 0,
    /// One `flight_stream` call won the compare-exchange and is registering.
    InProgress = 1,
    /// The plan was registered with the server; results can be fetched.
    Registered = 2,
}
47
/// The execution plan for the LiquidCache client.
///
/// It wraps a plan that is serialized and registered with the cache server;
/// each executed partition then fetches its results over Arrow Flight.
pub struct LiquidCacheClientExec {
    // Plan that is serialized and sent to the server; also this node's only child.
    remote_plan: Arc<dyn ExecutionPlan>,
    // Address handed to `flight_channel` when a partition connects.
    cache_server: String,
    // Cache mode; within this file it is only rendered in the plan display.
    cache_mode: CacheMode,
    // Object store URLs and their options, registered with the server before the plan.
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    // Metrics set from which the per-partition `FlightStreamMetrics` are created.
    metrics: ExecutionPlanMetricsSet,
    // Handle sent with both the plan registration and the result fetch requests.
    uuid: Uuid,
    // Holds a `PlanRegisterState` as `usize`; shared by all partitions of this plan.
    plan_registered: Arc<AtomicUsize>,
}
58
59impl std::fmt::Debug for LiquidCacheClientExec {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        write!(f, "LiquidCacheClientExec")
62    }
63}
64
65impl LiquidCacheClientExec {
66    pub(crate) fn new(
67        remote_plan: Arc<dyn ExecutionPlan>,
68        cache_server: String,
69        cache_mode: CacheMode,
70        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
71    ) -> Self {
72        let uuid = Uuid::new_v4();
73        Self {
74            remote_plan,
75            cache_server,
76            plan_registered: Arc::new(AtomicUsize::new(PlanRegisterState::NotRegistered as usize)),
77            cache_mode,
78            object_stores,
79            uuid,
80            metrics: ExecutionPlanMetricsSet::new(),
81        }
82    }
83
84    /// Get the UUID of the plan.
85    pub fn get_uuid(&self) -> Uuid {
86        self.uuid
87    }
88}
89
90impl DisplayAs for LiquidCacheClientExec {
91    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result {
92        match t {
93            DisplayFormatType::Default | DisplayFormatType::Verbose => {
94                write!(
95                    f,
96                    "LiquidCacheClientExec: server={}, mode={}, object_stores={:?}",
97                    self.cache_server, self.cache_mode, self.object_stores
98                )
99            }
100            DisplayFormatType::TreeRender => {
101                write!(
102                    f,
103                    "server={}, mode={}, object_stores={:?}",
104                    self.cache_server, self.cache_mode, self.object_stores
105                )
106            }
107        }
108    }
109}
110
111impl ExecutionPlan for LiquidCacheClientExec {
112    fn as_any(&self) -> &dyn Any {
113        self
114    }
115
116    fn name(&self) -> &str {
117        "LiquidCacheClientExec"
118    }
119
120    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
121        self.remote_plan.properties()
122    }
123
124    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
125        vec![&self.remote_plan]
126    }
127
128    fn with_new_children(
129        self: Arc<Self>,
130        children: Vec<Arc<dyn ExecutionPlan>>,
131    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
132        Ok(Arc::new(Self {
133            remote_plan: children.first().unwrap().clone(),
134            cache_server: self.cache_server.clone(),
135            plan_registered: self.plan_registered.clone(),
136            cache_mode: self.cache_mode,
137            object_stores: self.object_stores.clone(),
138            metrics: self.metrics.clone(),
139            uuid: self.uuid,
140        }))
141    }
142
143    fn execute(
144        &self,
145        partition: usize,
146        context: Arc<datafusion::execution::TaskContext>,
147    ) -> datafusion::error::Result<datafusion::execution::SendableRecordBatchStream> {
148        let cache_server = self.cache_server.clone();
149        let plan = self.remote_plan.clone();
150        let lock = self.plan_registered.clone();
151        let stream_metrics = FlightStreamMetrics::new(&self.metrics, partition);
152
153        let span = context
154            .session_config()
155            .get_extension::<Span>()
156            .unwrap_or_default();
157        let exec_span = Span::enter_with_parent("exec_flight_stream", &span);
158        let create_stream_span = Span::enter_with_parent("create_flight_stream", &exec_span);
159        let stream = flight_stream(
160            cache_server,
161            plan,
162            lock,
163            self.uuid,
164            partition,
165            self.object_stores.clone(),
166        );
167        Ok(Box::pin(FlightStream::new(
168            Some(Box::pin(stream)),
169            self.remote_plan.schema().clone(),
170            stream_metrics,
171            exec_span,
172            create_stream_span,
173        )))
174    }
175
176    fn required_input_distribution(&self) -> Vec<Distribution> {
177        self.remote_plan.required_input_distribution()
178    }
179
180    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
181        self.remote_plan.benefits_from_input_partitioning()
182    }
183
184    fn repartitioned(
185        &self,
186        target_partitions: usize,
187        config: &ConfigOptions,
188    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
189        self.remote_plan.repartitioned(target_partitions, config)
190    }
191
192    fn statistics(&self) -> Result<Statistics> {
193        self.remote_plan.statistics()
194    }
195
196    fn supports_limit_pushdown(&self) -> bool {
197        self.remote_plan.supports_limit_pushdown()
198    }
199
200    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
201        self.remote_plan.with_fetch(limit)
202    }
203
204    fn fetch(&self) -> Option<usize> {
205        self.remote_plan.fetch()
206    }
207
208    fn cardinality_effect(&self) -> CardinalityEffect {
209        self.remote_plan.cardinality_effect()
210    }
211
212    fn try_swapping_with_projection(
213        &self,
214        projection: &ProjectionExec,
215    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
216        self.remote_plan.try_swapping_with_projection(projection)
217    }
218
219    fn metrics(&self) -> Option<MetricsSet> {
220        Some(self.metrics.clone_inner())
221    }
222}
223
224async fn flight_stream(
225    server: String,
226    plan: Arc<dyn ExecutionPlan>,
227    plan_registered: Arc<AtomicUsize>,
228    handle: Uuid,
229    partition: usize,
230    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
231) -> Result<SendableRecordBatchStream> {
232    let channel = flight_channel(server)
233        .in_span(Span::enter_with_local_parent("connect_channel"))
234        .await?;
235
236    let mut client = FlightServiceClient::new(channel).max_decoding_message_size(1024 * 1024 * 8);
237    let schema = plan.schema().clone();
238
239    match plan_registered.compare_exchange(
240        PlanRegisterState::NotRegistered as usize,
241        PlanRegisterState::InProgress as usize,
242        Ordering::AcqRel,
243        Ordering::Relaxed,
244    ) {
245        Ok(_) => {
246            LocalSpan::add_event(Event::new("register_plan"));
247
248            for (url, options) in &object_stores {
249                let action = LiquidCacheActions::RegisterObjectStore(RegisterObjectStoreRequest {
250                    url: url.to_string(),
251                    options: options.clone(),
252                })
253                .into();
254                client
255                    .do_action(Request::new(action))
256                    .await
257                    .map_err(to_df_err)?;
258            }
259            // Register plan
260            let plan_bytes = physical_plan_to_bytes(plan)?;
261            let action = LiquidCacheActions::RegisterPlan(RegisterPlanRequest {
262                plan: plan_bytes.to_vec(),
263                handle: handle.into_bytes().to_vec().into(),
264            })
265            .into();
266            client
267                .do_action(Request::new(action))
268                .await
269                .map_err(to_df_err)?;
270            plan_registered.store(PlanRegisterState::Registered as usize, Ordering::Release);
271            LocalSpan::add_event(Event::new("register_plan_done"));
272        }
273        Err(_e) => {
274            LocalSpan::add_event(Event::new("getting_existing_plan"));
275            while plan_registered.load(Ordering::Acquire) != PlanRegisterState::Registered as usize
276            {
277                tokio::time::sleep(Duration::from_micros(100)).await;
278            }
279            LocalSpan::add_event(Event::new("got_existing_plan"));
280        }
281    };
282
283    let current = SpanContext::current_local_parent().unwrap_or_else(SpanContext::random);
284
285    let fetch_results = FetchResults {
286        handle: handle.into_bytes().to_vec().into(),
287        partition: partition as u32,
288        traceparent: current.encode_w3c_traceparent(),
289    };
290    let ticket = fetch_results.into_ticket();
291    let (md, response_stream, _ext) = client.do_get(ticket).await.map_err(to_df_err)?.into_parts();
292    LocalSpan::add_event(Event::new("get_flight_stream"));
293    let stream =
294        FlightRecordBatchStream::new_from_flight_data(response_stream.map_err(|e| e.into()))
295            .with_headers(md)
296            .map_err(to_df_err);
297    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
298}
299
/// Phases of a `FlightStream`, advanced by `FlightStream::poll_inner`.
enum FlightStreamState {
    /// Nothing has been polled yet; the stream-creating future is still stored
    /// in `FlightStream::future_stream`.
    Init,
    /// The future that produces the record batch stream is being polled.
    GetStream(BoxFuture<'static, Result<SendableRecordBatchStream>>),
    /// The record batch stream is available and batches are pulled from it.
    Processing(SendableRecordBatchStream),
}
305
/// Record batch stream for one partition: first drives the future that opens
/// the remote stream, then forwards its batches mapped onto `schema`.
struct FlightStream {
    // Future that opens the remote stream; moved into `state` on the first poll.
    future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
    // Current phase of the stream.
    state: FlightStreamState,
    // Schema reported by `RecordBatchStream::schema` and used to build the mapper.
    schema: SchemaRef,
    // Built from the first batch's schema and reused for all later batches.
    schema_mapper: Option<Arc<dyn SchemaMapper>>,
    // Timers and counters updated on every poll.
    metrics: FlightStreamMetrics,
    // Set as the local parent span for the duration of each `poll_next` call.
    poll_stream_span: fastrace::Span,
    // Local parent span while the opening future is polled; dropped once it resolves.
    create_stream_span: Option<fastrace::Span>,
}
315
316impl FlightStream {
317    fn new(
318        future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
319        schema: SchemaRef,
320        metrics: FlightStreamMetrics,
321        poll_stream_span: fastrace::Span,
322        create_stream_span: fastrace::Span,
323    ) -> Self {
324        Self {
325            future_stream,
326            state: FlightStreamState::Init,
327            schema,
328            schema_mapper: None,
329            metrics,
330            poll_stream_span,
331            create_stream_span: Some(create_stream_span),
332        }
333    }
334}
335
336use futures::StreamExt;
337impl FlightStream {
338    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
339        loop {
340            match &mut self.state {
341                FlightStreamState::Init => {
342                    self.metrics.time_reading_total.start();
343                    self.state = FlightStreamState::GetStream(self.future_stream.take().unwrap());
344                    continue;
345                }
346                FlightStreamState::GetStream(fut) => {
347                    let _guard = self.create_stream_span.as_ref().unwrap().set_local_parent();
348                    let stream = ready!(fut.as_mut().poll(cx)).unwrap();
349                    self.create_stream_span.take();
350                    self.state = FlightStreamState::Processing(stream);
351                    continue;
352                }
353                FlightStreamState::Processing(stream) => {
354                    let result = stream.poll_next_unpin(cx);
355                    self.metrics.poll_count.add(1);
356                    return result;
357                }
358            }
359        }
360    }
361}
362
363impl Stream for FlightStream {
364    type Item = Result<RecordBatch>;
365
366    fn poll_next(
367        mut self: std::pin::Pin<&mut Self>,
368        cx: &mut std::task::Context<'_>,
369    ) -> std::task::Poll<Option<Self::Item>> {
370        let _guard = self.poll_stream_span.set_local_parent();
371        self.metrics.time_processing.start();
372        let result = self.poll_inner(cx);
373        match result {
374            Poll::Ready(Some(Ok(batch))) => {
375                let coerced_batch = if let Some(schema_mapper) = &self.schema_mapper {
376                    schema_mapper.map_batch(batch).unwrap()
377                } else {
378                    let (schema_mapper, _) =
379                        DefaultSchemaAdapterFactory::from_schema(self.schema.clone())
380                            .map_schema(&batch.schema())
381                            .unwrap();
382                    let batch = schema_mapper.map_batch(batch).unwrap();
383
384                    self.schema_mapper = Some(schema_mapper);
385                    batch
386                };
387                self.metrics.output_rows.add(coerced_batch.num_rows());
388                self.metrics
389                    .bytes_decoded
390                    .add(coerced_batch.get_array_memory_size());
391                self.metrics.time_processing.stop();
392                LocalSpan::add_event(Event::new("emit_batch"));
393                Poll::Ready(Some(Ok(coerced_batch)))
394            }
395            Poll::Ready(None) => {
396                self.metrics.time_processing.stop();
397                self.metrics.time_reading_total.stop();
398                Poll::Ready(None)
399            }
400            Poll::Ready(Some(Err(e))) => {
401                panic!("Error in flight stream: {e:?}");
402            }
403            Poll::Pending => {
404                self.metrics.time_processing.stop();
405                Poll::Pending
406            }
407        }
408    }
409}
410
411impl RecordBatchStream for FlightStream {
412    fn schema(&self) -> SchemaRef {
413        self.schema.clone()
414    }
415}