// datafusion_distributed/networking/channel_resolver.rs
use crate::DistributedConfig;
2use crate::config_extension_ext::set_distributed_option_extension;
3use crate::worker::generated::worker::worker_service_client::WorkerServiceClient;
4use async_trait::async_trait;
5use datafusion::common::{DataFusionError, config_datafusion_err, exec_datafusion_err};
6use datafusion::execution::TaskContext;
7use datafusion::prelude::SessionConfig;
8use futures::FutureExt;
9use futures::future::Shared;
10use std::sync::{Arc, LazyLock};
11use std::time::Duration;
12use tonic::body::Body;
13use tonic::codegen::BoxFuture;
14use tonic::transport::Channel;
15use tower::ServiceExt;
16use url::Url;
17
18#[async_trait]
33pub trait ChannelResolver {
34 async fn get_worker_client_for_url(
44 &self,
45 url: &Url,
46 ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError>;
47}
48
49pub(crate) fn set_distributed_channel_resolver(
50 cfg: &mut SessionConfig,
51 channel_resolver: impl ChannelResolver + Send + Sync + 'static,
52) {
53 let opts = cfg.options_mut();
54 let channel_resolver = ChannelResolverExtension(Some(Arc::new(channel_resolver)));
55 if let Some(distributed_cfg) = opts.extensions.get_mut::<DistributedConfig>() {
56 distributed_cfg.__private_channel_resolver = channel_resolver;
57 } else {
58 set_distributed_option_extension(
59 cfg,
60 DistributedConfig {
61 __private_channel_resolver: channel_resolver,
62 ..Default::default()
63 },
64 )
65 }
66}
67
68static DEFAULT_CHANNEL_RESOLVER_PER_RUNTIME: LazyLock<
76 moka::sync::Cache<
77 usize,
78 Arc<DefaultChannelResolver>,
79 >,
80> = LazyLock::new(|| moka::sync::Cache::builder().max_capacity(256).build());
81
82pub fn get_distributed_channel_resolver(
83 task_ctx: &TaskContext,
84) -> Arc<dyn ChannelResolver + Send + Sync> {
85 let opts = task_ctx.session_config().options();
86 if let Some(distributed_cfg) = opts.extensions.get::<DistributedConfig>()
87 && let Some(cr) = &distributed_cfg.__private_channel_resolver.0
88 {
89 return Arc::clone(cr);
90 }
91 let runtime_addr = Arc::as_ptr(&task_ctx.runtime_env()) as usize;
92 DEFAULT_CHANNEL_RESOLVER_PER_RUNTIME
93 .get_with(runtime_addr, || Arc::new(DefaultChannelResolver::default()))
94}
95
96pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
97 http::Request<Body>,
98 http::Response<Body>,
99 tonic::transport::Error,
100>;
101
102type ChannelCacheValue = Shared<BoxFuture<BoxCloneSyncChannel, Arc<DataFusionError>>>;
103
104#[derive(Clone, Default)]
105pub(crate) struct ChannelResolverExtension(Option<Arc<dyn ChannelResolver + Send + Sync>>);
106
107#[derive(Clone)]
113pub struct DefaultChannelResolver {
114 cache: Arc<moka::sync::Cache<Url, ChannelCacheValue>>,
115}
116
117impl Default for DefaultChannelResolver {
118 fn default() -> Self {
119 Self {
120 cache: Arc::new(
121 moka::sync::Cache::builder()
122 .max_capacity(64556)
125 .time_to_idle(Duration::from_secs(5 * 60))
127 .build(),
128 ),
129 }
130 }
131}
132
133impl DefaultChannelResolver {
134 pub async fn get_channel(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
136 let channel = self.cache.get_with_by_ref(url, move || {
137 let url = url.to_string();
138 async move {
139 let endpoint = Channel::from_shared(url.clone()).map_err(|err| {
140 config_datafusion_err!(
141 "Invalid URL '{url}' returned by WorkerResolver implementation: {err}"
142 )
143 })?;
144 let mut channel = endpoint.connect().await.map_err(|err| {
145 DataFusionError::Context(
146 format!("{err:?}"),
147 Box::new(exec_datafusion_err!(
148 "Error connecting to Distributed DataFusion worker on '{url}': {err}"
149 )),
150 )
151 })?;
152 channel.ready().await.map_err(|err| {
153 DataFusionError::Context(
154 format!("{err:?}"),
155 Box::new(exec_datafusion_err!(
156 "Error waiting for Distributed DataFusion channel to be ready on '{url}': {err}"
157 )),
158 )
159 })?;
160 Ok(BoxCloneSyncChannel::new(channel))
161 }
162 .boxed()
163 .shared()
164 });
165
166 channel.await.map_err(|err| {
167 self.cache.invalidate(url);
168 DataFusionError::Shared(err)
169 })
170 }
171}
172
173#[async_trait]
174impl ChannelResolver for DefaultChannelResolver {
175 async fn get_worker_client_for_url(
176 &self,
177 url: &Url,
178 ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
179 self.get_channel(url).await.map(create_worker_client)
180 }
181}
182
183#[async_trait]
184impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
185 async fn get_worker_client_for_url(
186 &self,
187 url: &Url,
188 ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
189 self.as_ref().get_worker_client_for_url(url).await
190 }
191}
192
193pub fn create_worker_client(
221 channel: BoxCloneSyncChannel,
222) -> WorkerServiceClient<BoxCloneSyncChannel> {
223 WorkerServiceClient::new(channel)
224 .max_decoding_message_size(usize::MAX)
225 .max_encoding_message_size(usize::MAX)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Worker;
    use datafusion::common::assert_contains;
    use datafusion::common::runtime::SpawnedTask;
    use std::error::Error;
    use std::time::Instant;
    use tokio::net::TcpListener;
    use tonic::transport::Server;

    /// Result type shared by the tests in this module.
    type TestResult<T = ()> = Result<T, Box<dyn Error>>;

    /// Connecting to a worker whose serving task has already been dropped reports a TCP
    /// connection error.
    #[tokio::test]
    async fn fails_establishing_connection() -> TestResult {
        let (url, worker_task) = spawn_http_localhost_worker().await?;
        // Drop the serving task before connecting so nothing is expected to answer on `url`.
        drop(worker_task);

        let resolver = DefaultChannelResolver::default();
        let err = resolver.get_channel(&url).await.unwrap_err();

        assert_contains!(err.to_string(), "tcp connect error");
        Ok(())
    }

    /// Connecting to a running worker succeeds.
    #[tokio::test]
    async fn can_establish_connection() -> TestResult {
        let (url, _worker_task) = spawn_http_localhost_worker().await?;
        let resolver = DefaultChannelResolver::default();

        resolver.get_channel(&url).await?;
        Ok(())
    }

    /// A second lookup of the same URL is faster than the first one, which had to connect.
    #[tokio::test]
    async fn channel_resolve_is_cached() -> TestResult {
        let (url, _worker_task) = spawn_http_localhost_worker().await?;
        let resolver = DefaultChannelResolver::default();

        let cold_timer = Instant::now();
        resolver.get_channel(&url).await?;
        let cold_lookup = cold_timer.elapsed();

        let warm_timer = Instant::now();
        resolver.get_channel(&url).await?;
        let warm_lookup = warm_timer.elapsed();

        assert!(cold_lookup > warm_lookup);
        Ok(())
    }

    /// Starts a default [`Worker`] on a free localhost port and returns its URL together with
    /// the task that serves it.
    async fn spawn_http_localhost_worker() -> TestResult<(Url, SpawnedTask<()>)> {
        // Port 0 asks the OS for any free port; read back the one that was assigned.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let port = listener
            .local_addr()
            .expect("Failed to get local address")
            .port();

        let task = SpawnedTask::spawn(async {
            let worker = Worker::default();
            let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
            let served = Server::builder()
                .add_service(worker.into_worker_server())
                .serve_with_incoming(incoming)
                .await;
            if let Err(err) = served {
                panic!("{err}")
            }
        });

        let url = Url::parse(&format!("http://127.0.0.1:{port}"))?;
        Ok((url, task))
    }
}