datafusion_distributed/worker/session_builder.rs
1use async_trait::async_trait;
2use datafusion::error::DataFusionError;
3use datafusion::execution::{SessionState, SessionStateBuilder};
4use http::HeaderMap;
5use std::sync::Arc;
6
7#[derive(Debug, Default)]
8pub struct WorkerQueryContext {
9 pub builder: SessionStateBuilder,
10 pub headers: HeaderMap,
11}
12
/// Builds a DataFusion [SessionState] for each query issued to a worker.
///
/// Implementors receive a [WorkerQueryContext] (holding a [SessionStateBuilder] and a
/// [HeaderMap]) and return the [SessionState] the worker should run the query with.
#[async_trait]
pub trait WorkerSessionBuilder {
    /// Builds a custom [SessionState] scoped to a single ArrowFlight gRPC call, allowing
    /// users to provide a customized DataFusion session with things like custom extension
    /// codecs, custom physical optimization rules, UDFs, UDAFs, config extensions, etc.
    ///
    /// # Errors
    ///
    /// Implementations return a [DataFusionError] when the session state cannot be built.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use std::sync::Arc;
    /// # use async_trait::async_trait;
    /// # use datafusion::error::DataFusionError;
    /// # use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder, TaskContext};
    /// # use datafusion::physical_plan::ExecutionPlan;
    /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
    /// # use datafusion_distributed::{DistributedExt, WorkerSessionBuilder, WorkerQueryContext};
    ///
    /// #[derive(Debug)]
    /// struct CustomExecCodec;
    ///
    /// impl PhysicalExtensionCodec for CustomExecCodec {
    ///     fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], ctx: &TaskContext) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
    ///         todo!()
    ///     }
    ///
    ///     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
    ///         todo!()
    ///     }
    /// }
    ///
    /// #[derive(Clone)]
    /// struct CustomSessionBuilder;
    ///
    /// #[async_trait]
    /// impl WorkerSessionBuilder for CustomSessionBuilder {
    ///     async fn build_session_state(&self, ctx: WorkerQueryContext) -> Result<SessionState, DataFusionError> {
    ///         Ok(ctx
    ///             .builder
    ///             .with_distributed_user_codec(CustomExecCodec)
    ///             // Add your UDFs, optimization rules, etc...
    ///             .build())
    ///     }
    /// }
    /// ```
    async fn build_session_state(
        &self,
        ctx: WorkerQueryContext,
    ) -> Result<SessionState, DataFusionError>;
}
63
64/// Noop implementation of the [WorkerSessionBuilder]. Used by default if no [WorkerSessionBuilder]
65/// is provided while building the Worker.
66#[derive(Debug, Clone)]
67pub struct DefaultSessionBuilder;
68
69#[async_trait]
70impl WorkerSessionBuilder for DefaultSessionBuilder {
71 async fn build_session_state(
72 &self,
73 ctx: WorkerQueryContext,
74 ) -> Result<SessionState, DataFusionError> {
75 Ok(ctx.builder.build())
76 }
77}
78
79/// Implementation of [WorkerSessionBuilder] for any async function that returns a [Result]
80#[async_trait]
81impl<F, Fut> WorkerSessionBuilder for F
82where
83 F: Fn(WorkerQueryContext) -> Fut + Send + Sync + 'static,
84 Fut: std::future::Future<Output = Result<SessionState, DataFusionError>> + Send + 'static,
85{
86 async fn build_session_state(
87 &self,
88 ctx: WorkerQueryContext,
89 ) -> Result<SessionState, DataFusionError> {
90 self(ctx).await
91 }
92}
93
/// Extension trait providing the [`map`](MappedWorkerSessionBuilderExt::map) combinator
/// on [WorkerSessionBuilder] implementations.
pub trait MappedWorkerSessionBuilderExt {
    /// Wraps an existing [WorkerSessionBuilder], allowing further customization of the
    /// [SessionState] it produces.
    ///
    /// When the returned builder runs, the wrapped builder is invoked first. Its resulting
    /// [SessionState] is reopened as a [SessionStateBuilder] and passed to `f`, and whatever
    /// `f` returns becomes the final session state.
    ///
    /// Useful if there's already a [WorkerSessionBuilder] that needs to be extended
    /// with further capabilities.
    ///
    /// Note: the returned [MappedWorkerSessionBuilder] only implements [WorkerSessionBuilder]
    /// when `f` is also `Send + Sync` and the wrapped builder is `Send + Sync + 'static`.
    ///
    /// # Example
    ///
    /// ```rust
    /// # use datafusion::execution::SessionStateBuilder;
    /// # use datafusion_distributed::{DefaultSessionBuilder, MappedWorkerSessionBuilderExt};
    ///
    /// let session_builder = DefaultSessionBuilder
    ///     .map(|b: SessionStateBuilder| {
    ///         // Add further things.
    ///         Ok(b.build())
    ///     });
    /// ```
    fn map<F>(self, f: F) -> MappedWorkerSessionBuilder<Self, F>
    where
        Self: Sized,
        F: Fn(SessionStateBuilder) -> Result<SessionState, DataFusionError>;
}
118
119impl<T: WorkerSessionBuilder> MappedWorkerSessionBuilderExt for T {
120 fn map<F>(self, f: F) -> MappedWorkerSessionBuilder<Self, F>
121 where
122 Self: Sized,
123 {
124 MappedWorkerSessionBuilder {
125 inner: self,
126 f: Arc::new(f),
127 }
128 }
129}
130
/// [WorkerSessionBuilder] adapter that delegates to `inner` and then post-processes the
/// resulting session with `f`. Created via [MappedWorkerSessionBuilderExt::map].
pub struct MappedWorkerSessionBuilder<T, F> {
    /// Wrapped builder producing the initial session state.
    inner: T,
    /// Mapping function applied to the wrapped builder's output, shared through an [Arc].
    f: Arc<F>,
}
135
136impl<T: Clone, F> Clone for MappedWorkerSessionBuilder<T, F> {
137 fn clone(&self) -> Self {
138 Self {
139 inner: self.inner.clone(),
140 f: self.f.clone(),
141 }
142 }
143}
144
145#[async_trait]
146impl<T, F> WorkerSessionBuilder for MappedWorkerSessionBuilder<T, F>
147where
148 T: WorkerSessionBuilder + Send + Sync + 'static,
149 F: Fn(SessionStateBuilder) -> Result<SessionState, DataFusionError> + Send + Sync,
150{
151 async fn build_session_state(
152 &self,
153 ctx: WorkerQueryContext,
154 ) -> Result<SessionState, DataFusionError> {
155 let state = self.inner.build_session_state(ctx).await?;
156 let builder = SessionStateBuilder::new_from_existing(state);
157 (self.f)(builder)
158 }
159}