Skip to main content

datafusion_distributed/worker/
session_builder.rs

1use async_trait::async_trait;
2use datafusion::error::DataFusionError;
3use datafusion::execution::{SessionState, SessionStateBuilder};
4use http::HeaderMap;
5use std::sync::Arc;
6
7#[derive(Debug, Default)]
8pub struct WorkerQueryContext {
9    pub builder: SessionStateBuilder,
10    pub headers: HeaderMap,
11}
12
13/// builds a DataFusion's [SessionState] in each query issued to a worker.
14#[async_trait]
15pub trait WorkerSessionBuilder {
16    /// Builds a custom [SessionState] scoped to a single ArrowFlight gRPC call, allowing the
17    /// users to provide a customized DataFusion session with things like custom extension codecs,
18    /// custom physical optimization rules, UDFs, UDAFs, config extensions, etc...
19    ///
20    /// Example:
21    ///
22    /// ```rust
23    /// # use std::sync::Arc;
24    /// # use async_trait::async_trait;
25    /// # use datafusion::error::DataFusionError;
26    /// # use datafusion::execution::{FunctionRegistry, SessionState, SessionStateBuilder, TaskContext};
27    /// # use datafusion::physical_plan::ExecutionPlan;
28    /// # use datafusion_proto::physical_plan::PhysicalExtensionCodec;
29    /// # use datafusion_distributed::{DistributedExt, WorkerSessionBuilder, WorkerQueryContext};
30    ///
31    /// #[derive(Debug)]
32    /// struct CustomExecCodec;
33    ///
34    /// impl PhysicalExtensionCodec for CustomExecCodec {
35    ///     fn try_decode(&self, buf: &[u8], inputs: &[Arc<dyn ExecutionPlan>], ctx: &TaskContext) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
36    ///         todo!()
37    ///     }
38    ///
39    ///     fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> datafusion::common::Result<()> {
40    ///         todo!()
41    ///     }
42    /// }
43    ///
44    /// #[derive(Clone)]
45    /// struct CustomSessionBuilder;
46    ///
47    /// #[async_trait]
48    /// impl WorkerSessionBuilder for CustomSessionBuilder {
49    ///     async fn build_session_state(&self, ctx: WorkerQueryContext) -> Result<SessionState, DataFusionError> {
50    ///         Ok(ctx
51    ///             .builder
52    ///             .with_distributed_user_codec(CustomExecCodec)
53    ///             // Add your UDFs, optimization rules, etc...
54    ///             .build())
55    ///     }
56    /// }
57    /// ```
58    async fn build_session_state(
59        &self,
60        ctx: WorkerQueryContext,
61    ) -> Result<SessionState, DataFusionError>;
62}
63
64/// Noop implementation of the [WorkerSessionBuilder]. Used by default if no [WorkerSessionBuilder]
65/// is provided while building the Worker.
66#[derive(Debug, Clone)]
67pub struct DefaultSessionBuilder;
68
69#[async_trait]
70impl WorkerSessionBuilder for DefaultSessionBuilder {
71    async fn build_session_state(
72        &self,
73        ctx: WorkerQueryContext,
74    ) -> Result<SessionState, DataFusionError> {
75        Ok(ctx.builder.build())
76    }
77}
78
79/// Implementation of [WorkerSessionBuilder] for any async function that returns a [Result]
80#[async_trait]
81impl<F, Fut> WorkerSessionBuilder for F
82where
83    F: Fn(WorkerQueryContext) -> Fut + Send + Sync + 'static,
84    Fut: std::future::Future<Output = Result<SessionState, DataFusionError>> + Send + 'static,
85{
86    async fn build_session_state(
87        &self,
88        ctx: WorkerQueryContext,
89    ) -> Result<SessionState, DataFusionError> {
90        self(ctx).await
91    }
92}
93
94pub trait MappedWorkerSessionBuilderExt {
95    /// Maps an existing [WorkerSessionBuilder] allowing to add further extensions
96    /// to its already built [SessionStateBuilder].
97    ///
98    /// Useful if there's already a [WorkerSessionBuilder] that needs to be extended
99    /// with further capabilities.
100    ///
101    /// Example:
102    ///
103    /// ```rust
104    /// # use datafusion::execution::SessionStateBuilder;
105    /// # use datafusion_distributed::{DefaultSessionBuilder, MappedWorkerSessionBuilderExt};
106    ///
107    /// let session_builder = DefaultSessionBuilder
108    ///     .map(|b: SessionStateBuilder| {
109    ///         // Add further things.
110    ///         Ok(b.build())
111    ///     });
112    /// ```
113    fn map<F>(self, f: F) -> MappedWorkerSessionBuilder<Self, F>
114    where
115        Self: Sized,
116        F: Fn(SessionStateBuilder) -> Result<SessionState, DataFusionError>;
117}
118
119impl<T: WorkerSessionBuilder> MappedWorkerSessionBuilderExt for T {
120    fn map<F>(self, f: F) -> MappedWorkerSessionBuilder<Self, F>
121    where
122        Self: Sized,
123    {
124        MappedWorkerSessionBuilder {
125            inner: self,
126            f: Arc::new(f),
127        }
128    }
129}
130
/// Builder returned by [MappedWorkerSessionBuilderExt::map]: an inner builder plus a mapping
/// function that post-processes the session the inner builder produced.
pub struct MappedWorkerSessionBuilder<T, F> {
    /// Builder that runs first.
    inner: T,
    /// Mapping function applied afterwards; kept behind an [Arc] so clones share it.
    f: Arc<F>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would also demand `F: Clone`, whereas
// cloning the `Arc` works for any `F`.
impl<T: Clone, F> Clone for MappedWorkerSessionBuilder<T, F> {
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        let f = Arc::clone(&self.f);
        Self { inner, f }
    }
}
144
145#[async_trait]
146impl<T, F> WorkerSessionBuilder for MappedWorkerSessionBuilder<T, F>
147where
148    T: WorkerSessionBuilder + Send + Sync + 'static,
149    F: Fn(SessionStateBuilder) -> Result<SessionState, DataFusionError> + Send + Sync,
150{
151    async fn build_session_state(
152        &self,
153        ctx: WorkerQueryContext,
154    ) -> Result<SessionState, DataFusionError> {
155        let state = self.inner.build_session_state(ctx).await?;
156        let builder = SessionStateBuilder::new_from_existing(state);
157        (self.f)(builder)
158    }
159}