Skip to main content

datafusion_distributed/networking/
channel_resolver.rs

1use crate::DistributedConfig;
2use crate::config_extension_ext::set_distributed_option_extension;
3use crate::worker::generated::worker::worker_service_client::WorkerServiceClient;
4use async_trait::async_trait;
5use datafusion::common::{DataFusionError, config_datafusion_err, exec_datafusion_err};
6use datafusion::execution::TaskContext;
7use datafusion::prelude::SessionConfig;
8use futures::FutureExt;
9use futures::future::Shared;
10use std::sync::{Arc, LazyLock};
11use std::time::Duration;
12use tonic::body::Body;
13use tonic::codegen::BoxFuture;
14use tonic::transport::Channel;
15use tower::ServiceExt;
16use url::Url;
17
18/// Allows users to customize the way Worker clients are created. A common use case is to
19/// wrap the client with tower layers or schedule it in an IO-specific tokio runtime.
20///
21/// There is a default implementation of this trait that should be enough for the most common
22/// use-cases.
23///
24/// # Implementation Notes
25/// - This is called per gRPC request, so implementors of this trait should make sure that
26///   clients are reused across method calls instead of building a new Worker client every time.
27///
28/// - When implementing `get_worker_client_for_url`, it is recommended to use the
29///   [`create_worker_client`] helper function to ensure clients are configured with
30///   appropriate message size limits for internal communication. This helps avoid message
31///   size errors when transferring large datasets.
32#[async_trait]
33pub trait ChannelResolver {
34    /// For a given URL, get a Worker gRPC client for communicating to it.
35    ///
36    /// *WARNING*: This method is called for every gRPC request, so to not create
37    /// one client connection for each request, users are required to reuse generated clients.
38    /// It's recommended to rely on [DefaultChannelResolver] either by delegating method calls
39    /// to it or by copying the implementation.
40    ///
41    /// Consider using [`create_worker_client`] to create the client with appropriate
42    /// default message size limits.
43    async fn get_worker_client_for_url(
44        &self,
45        url: &Url,
46    ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError>;
47}
48
49pub(crate) fn set_distributed_channel_resolver(
50    cfg: &mut SessionConfig,
51    channel_resolver: impl ChannelResolver + Send + Sync + 'static,
52) {
53    let opts = cfg.options_mut();
54    let channel_resolver = ChannelResolverExtension(Some(Arc::new(channel_resolver)));
55    if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
56        distributed_cfg.__private_channel_resolver = channel_resolver;
57    } else {
58        set_distributed_option_extension(
59            cfg,
60            DistributedConfig {
61                __private_channel_resolver: channel_resolver,
62                ..Default::default()
63            },
64        )
65    }
66}
67
68// Unlike TaskContext, a DataFusion RuntimeEnv does not allow to introduce user-defined extensions.
69// For the default implementation of the ChannelResolvers, we cannot inject one DefaultChannelResolver
70// per TaskContext, as this holds reference to Tonic channels that must outlive a single TaskContext.
71//
72// The Tonic channels need to be established and reused under a whole RuntimeEnv scope, not a single
73// TaskContext, which forces us to put the default implementation in a static global variable that
74// stores and reuses tonic channels per RuntimeEnv's pointer address.
75static DEFAULT_CHANNEL_RESOLVER_PER_RUNTIME: LazyLock<
76    moka::sync::Cache<
77        /* Arc<RuntimeEnv> pointer address */ usize,
78        /* ChannelResolver that reuses built channels */ Arc<DefaultChannelResolver>,
79    >,
80> = LazyLock::new(|| moka::sync::Cache::builder().max_capacity(256).build());
81
82pub fn get_distributed_channel_resolver(
83    task_ctx: &TaskContext,
84) -> Arc<dyn ChannelResolver + Send + Sync> {
85    let opts = task_ctx.session_config().options();
86    if let Some(distributed_cfg) = opts.extensions.get::<DistributedConfig>()
87        && let Some(cr) = &distributed_cfg.__private_channel_resolver.0
88    {
89        return Arc::clone(cr);
90    }
91    let runtime_addr = Arc::as_ptr(&task_ctx.runtime_env()) as usize;
92    DEFAULT_CHANNEL_RESOLVER_PER_RUNTIME
93        .get_with(runtime_addr, || Arc::new(DefaultChannelResolver::default()))
94}
95
96pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
97    http::Request<Body>,
98    http::Response<Body>,
99    tonic::transport::Error,
100>;
101
102type ChannelCacheValue = Shared<BoxFuture<BoxCloneSyncChannel, Arc<DataFusionError>>>;
103
104#[derive(Clone, Default)]
105pub(crate) struct ChannelResolverExtension(Option<Arc<dyn ChannelResolver + Send + Sync>>);
106
107/// Default implementation of a [ChannelResolver] that connects to the workers given the URL once
108/// and stores the connection instance in a TTI cache.
109///
110/// Sane default over which other [ChannelResolver] can be built for better customization of the
111/// [WorkerServiceClient]s.
112#[derive(Clone)]
113pub struct DefaultChannelResolver {
114    cache: Arc<moka::sync::Cache<Url, ChannelCacheValue>>,
115}
116
117impl Default for DefaultChannelResolver {
118    fn default() -> Self {
119        Self {
120            cache: Arc::new(
121                moka::sync::Cache::builder()
122                    // Use an unrealistic max capacity, just in case there is a logic error on the
123                    // user part that produces an unreasonable amount of URLs.
124                    .max_capacity(64556)
125                    // If a channel has not been used in 5 mins, delete it.
126                    .time_to_idle(Duration::from_secs(5 * 60))
127                    .build(),
128            ),
129        }
130    }
131}
132
133impl DefaultChannelResolver {
134    /// Gets the cached [BoxCloneSyncChannel] for the given URL, or builds a new one.
135    pub async fn get_channel(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
136        let channel = self.cache.get_with_by_ref(url, move || {
137            let url = url.to_string();
138            async move {
139                let endpoint = Channel::from_shared(url.clone()).map_err(|err| {
140                    config_datafusion_err!(
141                        "Invalid URL '{url}' returned by WorkerResolver implementation: {err}"
142                    )
143                })?;
144                let mut channel = endpoint.connect().await.map_err(|err| {
145                    DataFusionError::Context(
146                        format!("{err:?}"),
147                        Box::new(exec_datafusion_err!(
148                            "Error connecting to Distributed DataFusion worker on '{url}': {err}"
149                        )),
150                    )
151                })?;
152                channel.ready().await.map_err(|err| {
153                    DataFusionError::Context(
154                        format!("{err:?}"),
155                        Box::new(exec_datafusion_err!(
156                            "Error waiting for Distributed DataFusion channel to be ready on '{url}': {err}"
157                        )),
158                    )
159                })?;
160                Ok(BoxCloneSyncChannel::new(channel))
161            }
162            .boxed()
163            .shared()
164        });
165
166        channel.await.map_err(|err| {
167            self.cache.invalidate(url);
168            DataFusionError::Shared(err)
169        })
170    }
171}
172
173#[async_trait]
174impl ChannelResolver for DefaultChannelResolver {
175    async fn get_worker_client_for_url(
176        &self,
177        url: &Url,
178    ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
179        self.get_channel(url).await.map(create_worker_client)
180    }
181}
182
183#[async_trait]
184impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
185    async fn get_worker_client_for_url(
186        &self,
187        url: &Url,
188    ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
189        self.as_ref().get_worker_client_for_url(url).await
190    }
191}
192
193/// Creates a [`WorkerServiceClient`] with high default message size limits.
194///
195/// This is a convenience function that wraps [`WorkerServiceClient::new`] and configures
196/// it with `max_decoding_message_size(usize::MAX)` and `max_encoding_message_size(usize::MAX)`
197/// to avoid message size limitations for internal communication.
198///
199/// Users implementing custom [`ChannelResolver`]s should use this function in their
200/// `get_worker_client_for_url` implementations to ensure consistent behavior with built-in
201/// implementations.
202///
203/// # Example
204///
205/// ```rust,ignore
206/// use datafusion_distributed::{create_worker_client, BoxCloneSyncChannel, ChannelResolver};
207/// /// use tonic::transport::Channel;
208///
209/// #[async_trait]
210/// impl ChannelResolver for MyResolver {
211///     async fn get_worker_client_for_url(
212///         &self,
213///         url: &Url,
214///     ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
215///         let channel = Channel::from_shared(url.to_string())?.connect().await?;
216///         Ok(create_worker_client(BoxCloneSyncChannel::new(channel)))
217///     }
218/// }
219/// ```
220pub fn create_worker_client(
221    channel: BoxCloneSyncChannel,
222) -> WorkerServiceClient<BoxCloneSyncChannel> {
223    WorkerServiceClient::new(channel)
224        .max_decoding_message_size(usize::MAX)
225        .max_encoding_message_size(usize::MAX)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Worker;
    use datafusion::common::assert_contains;
    use datafusion::common::runtime::SpawnedTask;
    use std::error::Error;
    use std::time::Instant;
    use tokio::net::TcpListener;
    use tonic::transport::Server;

    type TestResult<T = ()> = Result<T, Box<dyn Error>>;

    #[tokio::test]
    async fn fails_establishing_connection() -> TestResult {
        let (url, worker) = spawn_http_localhost_worker().await?;
        // Shut the worker down so that nothing is serving `url` any more.
        drop(worker);
        let resolver = DefaultChannelResolver::default();
        let failure = resolver.get_channel(&url).await.unwrap_err();
        assert_contains!(failure.to_string(), "tcp connect error");
        Ok(())
    }

    #[tokio::test]
    async fn can_establish_connection() -> TestResult {
        let (url, _worker) = spawn_http_localhost_worker().await?;
        let resolver = DefaultChannelResolver::default();
        resolver.get_channel(&url).await?;
        Ok(())
    }

    #[tokio::test]
    async fn channel_resolve_is_cached() -> TestResult {
        let (url, _worker) = spawn_http_localhost_worker().await?;
        let resolver = DefaultChannelResolver::default();

        let cold_started = Instant::now();
        resolver.get_channel(&url).await?;
        let cold = cold_started.elapsed();

        let warm_started = Instant::now();
        resolver.get_channel(&url).await?;
        let warm = warm_started.elapsed();

        // The second lookup is served from the cache and skips connecting entirely.
        assert!(cold > warm);
        Ok(())
    }

    /// Starts a Worker gRPC server on an ephemeral localhost port. The server keeps running for
    /// as long as the returned task handle is alive.
    async fn spawn_http_localhost_worker() -> TestResult<(Url, SpawnedTask<()>)> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;

        let port = listener
            .local_addr()
            .expect("Failed to get local address")
            .port();

        let server_task = SpawnedTask::spawn(async move {
            let worker = Worker::default();
            let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
            let outcome = Server::builder()
                .add_service(worker.into_worker_server())
                .serve_with_incoming(incoming)
                .await;
            if let Err(err) = outcome {
                panic!("{err}")
            }
        });

        let url = Url::parse(&format!("http://127.0.0.1:{port}"))?;
        Ok((url, server_task))
    }
}