use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use std::{any::Any, fmt::Formatter, sync::Arc};

use arrow::array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_schema::SchemaRef;
use datafusion::common::Statistics;
use datafusion::config::ConfigOptions;
use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaMapper};
use datafusion::execution::object_store::ObjectStoreUrl;
use datafusion::physical_plan::Distribution;
use datafusion::physical_plan::execution_plan::CardinalityEffect;
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::projection::ProjectionExec;
use datafusion::{
    error::Result,
    execution::{RecordBatchStream, SendableRecordBatchStream},
    physical_plan::{
        DisplayAs, DisplayFormatType, ExecutionPlan, stream::RecordBatchStreamAdapter,
    },
};
use datafusion_proto::bytes::physical_plan_to_bytes;
use fastrace::Span;
use fastrace::future::FutureExt;
use fastrace::prelude::*;
use futures::{Stream, TryStreamExt, future::BoxFuture, ready};
use liquid_cache_common::CacheMode;
use liquid_cache_common::rpc::{
    FetchResults, LiquidCacheActions, RegisterObjectStoreRequest, RegisterPlanRequest,
};
use tonic::Request;
use uuid::Uuid;

use crate::metrics::FlightStreamMetrics;
use crate::{flight_channel, to_df_err};

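/// Registration state of the remote plan, stored in an `AtomicUsize` so that
/// concurrent partition streams can agree on which one registers the plan with
/// the cache server.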
#[repr(usize)]
enum PlanRegisterState {
    NotRegistered = 0,
    InProgress = 1,
    Registered = 2,
}

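/// Client-side [`ExecutionPlan`] that ships the wrapped plan to a LiquidCache
/// server over Arrow Flight and streams the resulting record batches back.
/// Plan properties, statistics, and pushdowns are delegated to the wrapped plan.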
pub struct LiquidCacheClientExec {
    remote_plan: Arc<dyn ExecutionPlan>,
    cache_server: String,
    cache_mode: CacheMode,
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    metrics: ExecutionPlanMetricsSet,
    uuid: Uuid,
    plan_registered: Arc<AtomicUsize>,
}

impl std::fmt::Debug for LiquidCacheClientExec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "LiquidCacheClientExec")
    }
}

impl LiquidCacheClientExec {
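    /// Creates a new client exec for `remote_plan`, assigning a fresh UUID that
    /// identifies the plan on the cache server.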
    pub(crate) fn new(
        remote_plan: Arc<dyn ExecutionPlan>,
        cache_server: String,
        cache_mode: CacheMode,
        object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
    ) -> Self {
        let uuid = Uuid::new_v4();
        Self {
            remote_plan,
            cache_server,
            plan_registered: Arc::new(AtomicUsize::new(PlanRegisterState::NotRegistered as usize)),
            cache_mode,
            object_stores,
            uuid,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

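    /// Returns the UUID under which the plan is registered on the cache server.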
    pub fn get_uuid(&self) -> Uuid {
        self.uuid
    }
}

impl DisplayAs for LiquidCacheClientExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter<'_>) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "LiquidCacheClientExec: server={}, mode={}, object_stores={:?}",
                    self.cache_server, self.cache_mode, self.object_stores
                )
            }
            DisplayFormatType::TreeRender => {
                write!(
                    f,
                    "server={}, mode={}, object_stores={:?}",
                    self.cache_server, self.cache_mode, self.object_stores
                )
            }
        }
    }
}

impl ExecutionPlan for LiquidCacheClientExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "LiquidCacheClientExec"
    }

    fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
        self.remote_plan.properties()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.remote_plan]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(Self {
            remote_plan: children.first().unwrap().clone(),
            cache_server: self.cache_server.clone(),
            plan_registered: self.plan_registered.clone(),
            cache_mode: self.cache_mode,
            object_stores: self.object_stores.clone(),
            metrics: self.metrics.clone(),
            uuid: self.uuid,
        }))
    }

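    // `execute` does not contact the server eagerly: the `flight_stream` future is
    // only polled once the returned `FlightStream` is first polled.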
    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::TaskContext>,
    ) -> datafusion::error::Result<datafusion::execution::SendableRecordBatchStream> {
        let cache_server = self.cache_server.clone();
        let plan = self.remote_plan.clone();
        let lock = self.plan_registered.clone();
        let stream_metrics = FlightStreamMetrics::new(&self.metrics, partition);

        let span = context
            .session_config()
            .get_extension::<Span>()
            .unwrap_or_default();
        let exec_span = Span::enter_with_parent("exec_flight_stream", &span);
        let create_stream_span = Span::enter_with_parent("create_flight_stream", &exec_span);
        let stream = flight_stream(
            cache_server,
            plan,
            lock,
            self.uuid,
            partition,
            self.object_stores.clone(),
        );
        Ok(Box::pin(FlightStream::new(
            Some(Box::pin(stream)),
            self.remote_plan.schema().clone(),
            stream_metrics,
            exec_span,
            create_stream_span,
        )))
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        self.remote_plan.required_input_distribution()
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        self.remote_plan.benefits_from_input_partitioning()
    }

    fn repartitioned(
        &self,
        target_partitions: usize,
        config: &ConfigOptions,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        self.remote_plan.repartitioned(target_partitions, config)
    }

    fn statistics(&self) -> Result<Statistics> {
        self.remote_plan.statistics()
    }

    fn supports_limit_pushdown(&self) -> bool {
        self.remote_plan.supports_limit_pushdown()
    }

    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        self.remote_plan.with_fetch(limit)
    }

    fn fetch(&self) -> Option<usize> {
        self.remote_plan.fetch()
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        self.remote_plan.cardinality_effect()
    }

    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        self.remote_plan.try_swapping_with_projection(projection)
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}

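/// Connects to the cache server, registers the object stores and the serialized
/// plan exactly once (coordinated through `plan_registered`), then issues a
/// `DoGet` for `partition` and returns the decoded record batch stream.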
async fn flight_stream(
    server: String,
    plan: Arc<dyn ExecutionPlan>,
    plan_registered: Arc<AtomicUsize>,
    handle: Uuid,
    partition: usize,
    object_stores: Vec<(ObjectStoreUrl, HashMap<String, String>)>,
) -> Result<SendableRecordBatchStream> {
    let channel = flight_channel(server)
        .in_span(Span::enter_with_local_parent("connect_channel"))
        .await?;

    let mut client = FlightServiceClient::new(channel).max_decoding_message_size(1024 * 1024 * 8);
    let schema = plan.schema().clone();

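    // Only the first caller transitions the state to `InProgress` and performs
    // registration; concurrent partitions fall into the `Err` arm and wait until
    // the plan is marked `Registered`.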
    match plan_registered.compare_exchange(
        PlanRegisterState::NotRegistered as usize,
        PlanRegisterState::InProgress as usize,
        Ordering::AcqRel,
        Ordering::Relaxed,
    ) {
        Ok(_) => {
            LocalSpan::add_event(Event::new("register_plan"));

            for (url, options) in &object_stores {
                let action = LiquidCacheActions::RegisterObjectStore(RegisterObjectStoreRequest {
                    url: url.to_string(),
                    options: options.clone(),
                })
                .into();
                client
                    .do_action(Request::new(action))
                    .await
                    .map_err(to_df_err)?;
            }
            let plan_bytes = physical_plan_to_bytes(plan)?;
            let action = LiquidCacheActions::RegisterPlan(RegisterPlanRequest {
                plan: plan_bytes.to_vec(),
                handle: handle.into_bytes().to_vec().into(),
            })
            .into();
            client
                .do_action(Request::new(action))
                .await
                .map_err(to_df_err)?;
            plan_registered.store(PlanRegisterState::Registered as usize, Ordering::Release);
            LocalSpan::add_event(Event::new("register_plan_done"));
        }
        Err(_e) => {
            LocalSpan::add_event(Event::new("getting_existing_plan"));
            while plan_registered.load(Ordering::Acquire) != PlanRegisterState::Registered as usize
            {
                tokio::time::sleep(Duration::from_micros(100)).await;
            }
            LocalSpan::add_event(Event::new("got_existing_plan"));
        }
    };

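    // The ticket carries the handle, the partition, and the current trace context
    // (W3C traceparent) so the server can attach its spans to this trace.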
    let current = SpanContext::current_local_parent().unwrap_or_else(SpanContext::random);

    let fetch_results = FetchResults {
        handle: handle.into_bytes().to_vec().into(),
        partition: partition as u32,
        traceparent: current.encode_w3c_traceparent(),
    };
    let ticket = fetch_results.into_ticket();
    let (md, response_stream, _ext) = client.do_get(ticket).await.map_err(to_df_err)?.into_parts();
    LocalSpan::add_event(Event::new("get_flight_stream"));
    let stream =
        FlightRecordBatchStream::new_from_flight_data(response_stream.map_err(|e| e.into()))
            .with_headers(md)
            .map_err(to_df_err);
    Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
}

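/// States of [`FlightStream`]: not yet started, awaiting the boxed
/// `flight_stream` future, or forwarding batches from the established stream.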
enum FlightStreamState {
    Init,
    GetStream(BoxFuture<'static, Result<SendableRecordBatchStream>>),
    Processing(SendableRecordBatchStream),
}

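/// A [`RecordBatchStream`] that lazily creates the remote Flight stream on first
/// poll, coerces incoming batches to the plan schema, and records stream metrics.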
struct FlightStream {
    future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
    state: FlightStreamState,
    schema: SchemaRef,
    schema_mapper: Option<Arc<dyn SchemaMapper>>,
    metrics: FlightStreamMetrics,
    poll_stream_span: fastrace::Span,
    create_stream_span: Option<fastrace::Span>,
}

impl FlightStream {
    fn new(
        future_stream: Option<BoxFuture<'static, Result<SendableRecordBatchStream>>>,
        schema: SchemaRef,
        metrics: FlightStreamMetrics,
        poll_stream_span: fastrace::Span,
        create_stream_span: fastrace::Span,
    ) -> Self {
        Self {
            future_stream,
            state: FlightStreamState::Init,
            schema,
            schema_mapper: None,
            metrics,
            poll_stream_span,
            create_stream_span: Some(create_stream_span),
        }
    }
}

use futures::StreamExt;
impl FlightStream {
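    /// Drives the state machine: `Init` -> `GetStream` (await the boxed
    /// `flight_stream` future) -> `Processing` (forward batches from the server).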
    fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            match &mut self.state {
                FlightStreamState::Init => {
                    self.metrics.time_reading_total.start();
                    self.state = FlightStreamState::GetStream(self.future_stream.take().unwrap());
                    continue;
                }
                FlightStreamState::GetStream(fut) => {
                    let _guard = self.create_stream_span.as_ref().unwrap().set_local_parent();
                    let stream = match ready!(fut.as_mut().poll(cx)) {
                        Ok(stream) => stream,
                        Err(e) => return Poll::Ready(Some(Err(e))),
                    };
                    self.create_stream_span.take();
                    self.state = FlightStreamState::Processing(stream);
                    continue;
                }
                FlightStreamState::Processing(stream) => {
                    let result = stream.poll_next_unpin(cx);
                    self.metrics.poll_count.add(1);
                    return result;
                }
            }
        }
    }
}

impl Stream for FlightStream {
    type Item = Result<RecordBatch>;

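    // Each batch from the server is coerced to the declared schema through a
    // `SchemaMapper`, which is built lazily from the schema of the first batch.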
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let _guard = self.poll_stream_span.set_local_parent();
        self.metrics.time_processing.start();
        let result = self.poll_inner(cx);
        match result {
            Poll::Ready(Some(Ok(batch))) => {
                let coerced_batch = if let Some(schema_mapper) = &self.schema_mapper {
                    schema_mapper.map_batch(batch).unwrap()
                } else {
                    let (schema_mapper, _) =
                        DefaultSchemaAdapterFactory::from_schema(self.schema.clone())
                            .map_schema(&batch.schema())
                            .unwrap();
                    let batch = schema_mapper.map_batch(batch).unwrap();

                    self.schema_mapper = Some(schema_mapper);
                    batch
                };
                self.metrics.output_rows.add(coerced_batch.num_rows());
                self.metrics
                    .bytes_decoded
                    .add(coerced_batch.get_array_memory_size());
                self.metrics.time_processing.stop();
                LocalSpan::add_event(Event::new("emit_batch"));
                Poll::Ready(Some(Ok(coerced_batch)))
            }
            Poll::Ready(None) => {
                self.metrics.time_processing.stop();
                self.metrics.time_reading_total.stop();
                Poll::Ready(None)
            }
            Poll::Ready(Some(Err(e))) => {
                // Propagate stream errors to the caller.
                self.metrics.time_processing.stop();
                Poll::Ready(Some(Err(e)))
            }
            Poll::Pending => {
                self.metrics.time_processing.stop();
                Poll::Pending
            }
        }
    }
}

impl RecordBatchStream for FlightStream {
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }
}