use crate::DistributedConfig;
use crate::config_extension_ext::set_distributed_option_extension;
use crate::worker::generated::worker::worker_service_client::WorkerServiceClient;
use async_trait::async_trait;
use datafusion::common::{DataFusionError, config_datafusion_err, exec_datafusion_err};
use datafusion::execution::TaskContext;
use datafusion::prelude::SessionConfig;
use futures::FutureExt;
use futures::future::Shared;
use std::sync::{Arc, LazyLock};
use std::time::Duration;
use tonic::body::Body;
use tonic::codegen::BoxFuture;
use tonic::transport::Channel;
use tower::ServiceExt;
use url::Url;
/// Produces a [`WorkerServiceClient`] for a worker reachable at a given URL.
///
/// Implementations decide how the underlying transport is created and whether it is
/// reused; [`DefaultChannelResolver`] in this module caches one channel per URL.
#[async_trait]
pub trait ChannelResolver {
    /// Returns a gRPC client for the worker at `url`.
    ///
    /// # Errors
    /// Returns a [`DataFusionError`] when no client can be produced for `url` (for the
    /// default implementation: the URL is rejected, or connecting to it fails).
    async fn get_worker_client_for_url(
        &self,
        url: &Url,
    ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError>;
}
/// Installs `channel_resolver` as the [`ChannelResolver`] used by distributed execution
/// for sessions built from `cfg`.
///
/// If a [`DistributedConfig`] extension is already registered on `cfg`, only its
/// resolver slot is replaced and all other distributed settings are left as they were.
/// Otherwise a default [`DistributedConfig`] carrying this resolver is handed to
/// `set_distributed_option_extension`.
pub(crate) fn set_distributed_channel_resolver(
    cfg: &mut SessionConfig,
    channel_resolver: impl ChannelResolver + Send + Sync + 'static,
) {
    let resolver_slot = ChannelResolverExtension(Some(Arc::new(channel_resolver)));
    match cfg.options_mut().extensions.get_mut::<DistributedConfig>() {
        // Keep the caller's other distributed options untouched.
        Some(existing) => existing.__private_channel_resolver = resolver_slot,
        None => set_distributed_option_extension(
            cfg,
            DistributedConfig {
                __private_channel_resolver: resolver_slot,
                ..Default::default()
            },
        ),
    }
}
/// Upper bound on how many per-runtime default resolvers are kept at once.
const MAX_CACHED_DEFAULT_RESOLVERS: u64 = 256;

/// Process-wide registry of fallback [`DefaultChannelResolver`]s, one per runtime,
/// keyed by the address of the runtime's `RuntimeEnv` allocation (as computed in
/// `get_distributed_channel_resolver`).
///
/// Entries leave this cache only through capacity-based eviction.
///
/// NOTE(review): keying by pointer address means a `RuntimeEnv` allocated at the
/// address of a previously freed one will reuse the old resolver (and its cached
/// channels) if that entry has not been evicted. Confirm this is acceptable.
static DEFAULT_CHANNEL_RESOLVER_PER_RUNTIME: LazyLock<
    moka::sync::Cache<usize, Arc<DefaultChannelResolver>>,
> = LazyLock::new(|| {
    moka::sync::Cache::builder()
        .max_capacity(MAX_CACHED_DEFAULT_RESOLVERS)
        .build()
});
/// Returns the [`ChannelResolver`] that distributed execution should use for `task_ctx`.
///
/// A resolver installed on the session's [`DistributedConfig`] always wins. When none is
/// configured, a [`DefaultChannelResolver`] is returned that is shared by every task whose
/// context points at the same `RuntimeEnv` (for as long as that entry stays cached), so
/// those tasks reuse each other's channels.
pub fn get_distributed_channel_resolver(
    task_ctx: &TaskContext,
) -> Arc<dyn ChannelResolver + Send + Sync> {
    let configured = task_ctx
        .session_config()
        .options()
        .extensions
        .get::<DistributedConfig>()
        .and_then(|distributed_cfg| distributed_cfg.__private_channel_resolver.0.as_ref());

    if let Some(resolver) = configured {
        return Arc::clone(resolver);
    }

    // The runtime's allocation address identifies it while it is alive.
    let runtime_key = Arc::as_ptr(&task_ctx.runtime_env()) as usize;
    DEFAULT_CHANNEL_RESOLVER_PER_RUNTIME
        .get_with(runtime_key, || Arc::new(DefaultChannelResolver::default()))
}
/// Type-erased, cloneable, `Send + Sync` HTTP service used as the transport underneath
/// [`WorkerServiceClient`]. [`DefaultChannelResolver`] wraps a tonic [`Channel`] in it.
pub type BoxCloneSyncChannel = tower::util::BoxCloneSyncService<
    http::Request<Body>,
    http::Response<Body>,
    tonic::transport::Error,
>;
/// Cached connection attempt for one URL. It is a shared, boxed future, so concurrent
/// callers asking for the same URL await the same attempt instead of dialing separately.
/// The error sits behind an `Arc` so the future's output can be cloned, which `Shared`
/// requires.
type ChannelCacheValue = Shared<BoxFuture<BoxCloneSyncChannel, Arc<DataFusionError>>>;
/// Slot on [`DistributedConfig`] holding the user-supplied [`ChannelResolver`], if any.
/// `None` makes [`get_distributed_channel_resolver`] fall back to a per-runtime default.
#[derive(Clone, Default)]
pub(crate) struct ChannelResolverExtension(Option<Arc<dyn ChannelResolver + Send + Sync>>);
/// Default [`ChannelResolver`] that opens one tonic channel per worker URL and caches it.
///
/// Cloning is cheap, and clones share the same cache.
#[derive(Clone)]
pub struct DefaultChannelResolver {
    // Pending or established connection attempts, keyed by worker URL.
    cache: Arc<moka::sync::Cache<Url, ChannelCacheValue>>,
}
impl Default for DefaultChannelResolver {
fn default() -> Self {
Self {
cache: Arc::new(
moka::sync::Cache::builder()
.max_capacity(64556)
.time_to_idle(Duration::from_secs(5 * 60))
.build(),
),
}
}
}
impl DefaultChannelResolver {
pub async fn get_channel(&self, url: &Url) -> Result<BoxCloneSyncChannel, DataFusionError> {
let channel = self.cache.get_with_by_ref(url, move || {
let url = url.to_string();
async move {
let endpoint = Channel::from_shared(url.clone()).map_err(|err| {
config_datafusion_err!(
"Invalid URL '{url}' returned by WorkerResolver implementation: {err}"
)
})?;
let mut channel = endpoint.connect().await.map_err(|err| {
DataFusionError::Context(
format!("{err:?}"),
Box::new(exec_datafusion_err!(
"Error connecting to Distributed DataFusion worker on '{url}': {err}"
)),
)
})?;
channel.ready().await.map_err(|err| {
DataFusionError::Context(
format!("{err:?}"),
Box::new(exec_datafusion_err!(
"Error waiting for Distributed DataFusion channel to be ready on '{url}': {err}"
)),
)
})?;
Ok(BoxCloneSyncChannel::new(channel))
}
.boxed()
.shared()
});
channel.await.map_err(|err| {
self.cache.invalidate(url);
DataFusionError::Shared(err)
})
}
}
#[async_trait]
impl ChannelResolver for DefaultChannelResolver {
async fn get_worker_client_for_url(
&self,
url: &Url,
) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
self.get_channel(url).await.map(create_worker_client)
}
}
#[async_trait]
impl ChannelResolver for Arc<dyn ChannelResolver + Send + Sync> {
    /// Forwards to the resolver behind the `Arc`.
    async fn get_worker_client_for_url(
        &self,
        url: &Url,
    ) -> Result<WorkerServiceClient<BoxCloneSyncChannel>, DataFusionError> {
        // Deref explicitly to the trait object. Calling the method directly on `self`
        // would resolve to this very impl and recurse.
        let inner: &(dyn ChannelResolver + Send + Sync) = self;
        inner.get_worker_client_for_url(url).await
    }
}
/// Builds a [`WorkerServiceClient`] over `channel` with both the encoding and decoding
/// message-size limits set to `usize::MAX`, so tonic's default size caps do not reject
/// large messages.
pub fn create_worker_client(
    channel: BoxCloneSyncChannel,
) -> WorkerServiceClient<BoxCloneSyncChannel> {
    let unlimited = usize::MAX;
    let client = WorkerServiceClient::new(channel);
    client
        .max_encoding_message_size(unlimited)
        .max_decoding_message_size(unlimited)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Worker;
    use datafusion::common::assert_contains;
    use datafusion::common::runtime::SpawnedTask;
    use std::error::Error;
    use std::time::Instant;
    use tokio::net::TcpListener;
    use tonic::transport::Server;

    /// Dialing a worker whose server task has been dropped must fail with a TCP
    /// connect error.
    #[tokio::test]
    async fn fails_establishing_connection() -> Result<(), Box<dyn Error>> {
        let (url, _guard) = spawn_http_localhost_worker().await?;
        // Dropping the guard is meant to stop the server before we dial it
        // (presumably `SpawnedTask` aborts its task on drop; confirm against DataFusion docs).
        drop(_guard);
        let channel_resolver = DefaultChannelResolver::default();
        let err = channel_resolver.get_channel(&url).await.unwrap_err();
        assert_contains!(err.to_string(), "tcp connect error");
        Ok(())
    }

    /// A live local worker can be reached through the default resolver.
    #[tokio::test]
    async fn can_establish_connection() -> Result<(), Box<dyn Error>> {
        let (url, _guard) = spawn_http_localhost_worker().await?;
        let channel_resolver = DefaultChannelResolver::default();
        channel_resolver.get_channel(&url).await?;
        Ok(())
    }

    /// A second lookup of the same URL should be served from the cache. This is checked
    /// indirectly by asserting that it is faster than the first (connecting) call.
    /// NOTE(review): wall-clock comparisons like this can be flaky on loaded machines.
    #[tokio::test]
    async fn channel_resolve_is_cached() -> Result<(), Box<dyn Error>> {
        let (url, _guard) = spawn_http_localhost_worker().await?;
        let channel_resolver = DefaultChannelResolver::default();
        let start = Instant::now();
        channel_resolver.get_channel(&url).await?;
        let first_call = start.elapsed();
        let start = Instant::now();
        channel_resolver.get_channel(&url).await?;
        let second_call = start.elapsed();
        assert!(first_call > second_call);
        Ok(())
    }

    /// Serves a default `Worker` over gRPC on an OS-assigned port of 127.0.0.1 inside a
    /// spawned task, and returns the worker's `http://` URL plus the task handle.
    ///
    /// The tests hold the returned handle as a guard for the server's lifetime. The
    /// spawned task panics if the server exits with an error.
    async fn spawn_http_localhost_worker() -> Result<(Url, SpawnedTask<()>), Box<dyn Error>> {
        // Port 0 asks the OS for any free port; read it back to build the URL.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let port = listener
            .local_addr()
            .expect("Failed to get local address")
            .port();
        let task = SpawnedTask::spawn(async {
            let worker = Worker::default();
            let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
            if let Err(err) = Server::builder()
                .add_service(worker.into_worker_server())
                .serve_with_incoming(incoming)
                .await
            {
                panic!("{err}")
            }
        });
        Ok((Url::parse(&format!("http://127.0.0.1:{port}"))?, task))
    }
}