use arrow::util::pretty::pretty_format_batches;
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::{ParquetReadOptions, SessionContext};
use datafusion_distributed::{
DefaultChannelResolver, DistributedExt, GetWorkerInfoRequest, SessionStateBuilderExt,
WorkerResolver, create_worker_client, display_plan_ascii,
};
use futures::TryStreamExt;
use std::error::Error;
use structopt::StructOpt;
use url::Url;
// Command-line arguments for the `versioned_run` example binary.
//
// NOTE: plain `//` comments are used here on purpose. structopt turns `///` doc
// comments on the struct and its fields into help text, so doc comments would
// change the generated `--help` output.
#[derive(StructOpt)]
#[structopt(
name = "versioned_run",
about = "A localhost Distributed DataFusion runner with worker version filtering"
)]
struct Args {
// SQL text to run; `main` passes it unchanged to `SessionContext::sql`.
// No `long`/`short` is given, so structopt treats it as a positional argument.
#[structopt()]
query: String,
// Ports of the workers; each one is addressed as `http://localhost:{port}`.
// `use_delimiter = true` lets several ports be given in one delimited value
// (presumably comma-separated, e.g. `--cluster-ports 8080,8081` — confirm
// against the structopt/clap version in use).
#[structopt(long = "cluster-ports", use_delimiter = true)]
cluster_ports: Vec<u16>,
// When set, `main` keeps only the workers whose reported version equals this
// string. When absent, every port in `cluster_ports` is used without probing.
#[structopt(long)]
version: Option<String>,
// When set, `main` prints the physical plan via `display_plan_ascii` instead
// of executing the query.
#[structopt(long)]
show_distributed_plan: bool,
}
/// Reports whether the worker at `url` advertises exactly `expected_version`.
///
/// A channel to the worker is obtained from `channel_resolver`, then the worker is
/// asked for its info. Any failure along the way (no channel, failed RPC) is
/// treated the same as a mismatch and yields `false`; the cause is not reported.
async fn worker_has_version(
    channel_resolver: &DefaultChannelResolver,
    url: &Url,
    expected_version: &str,
) -> bool {
    // Step 1: open a channel; a worker we cannot reach cannot match.
    let channel = match channel_resolver.get_channel(url).await {
        Ok(open_channel) => open_channel,
        Err(_) => return false,
    };

    // Step 2: ask the worker for its info and compare the reported version.
    let mut worker = create_worker_client(channel);
    let reply = worker.get_worker_info(GetWorkerInfoRequest {}).await;
    match reply {
        Ok(info) => info.into_inner().version == expected_version,
        Err(_) => false,
    }
}
/// Probes each port in `cluster_ports` and returns those whose worker reports
/// `target_version`, preserving the input order.
///
/// Workers are probed one after another via `worker_has_version`. A line is
/// printed for every excluded port, and a summary line is printed on success.
///
/// # Errors
/// Returns an error if a localhost URL cannot be built for a port, or if no
/// worker matched `target_version`.
async fn select_compatible_ports(
    cluster_ports: &[u16],
    target_version: &str,
) -> Result<Vec<u16>, Box<dyn Error>> {
    let channel_resolver = DefaultChannelResolver::default();
    let mut compatible = Vec::with_capacity(cluster_ports.len());
    for &port in cluster_ports {
        let url = Url::parse(&format!("http://localhost:{port}"))?;
        if worker_has_version(&channel_resolver, &url, target_version).await {
            compatible.push(port);
        } else {
            // `worker_has_version` also returns false when the worker cannot be
            // reached or the RPC fails, so the message must not claim a mismatch.
            println!("Excluding worker on port {port} (unreachable or version mismatch)");
        }
    }
    if compatible.is_empty() {
        return Err(format!("No workers matched version '{target_version}'").into());
    }
    println!(
        "Using {}/{} workers matching version '{target_version}'\n",
        compatible.len(),
        cluster_ports.len()
    );
    Ok(compatible)
}

/// Entry point: parses the CLI arguments, optionally filters the workers by
/// version, then either prints the distributed physical plan for the query or
/// executes it and prints the result batches as a table.
///
/// # Errors
/// Propagates failures from worker selection, session construction, registering
/// the `weather` parquet table, planning, execution, and result formatting.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let args = Args::from_args();

    // Without `--version` every configured port is used as-is, with no probing.
    let ports = match &args.version {
        Some(target_version) => {
            select_compatible_ports(&args.cluster_ports, target_version).await?
        }
        None => args.cluster_ports,
    };

    let localhost_resolver = LocalhostWorkerResolver { ports };
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_distributed_worker_resolver(localhost_resolver)
        .with_distributed_planner()
        .with_distributed_file_scan_config_bytes_per_partition(1)?
        .build();
    let ctx = SessionContext::from(state);

    ctx.register_parquet("weather", "testdata/weather", ParquetReadOptions::default())
        .await?;
    let df = ctx.sql(&args.query).await?;

    if args.show_distributed_plan {
        // Plan only: render the physical plan without running the query.
        let plan = df.create_physical_plan().await?;
        println!("{}", display_plan_ascii(plan.as_ref(), false));
    } else {
        let stream = df.execute_stream().await?;
        let batches = stream.try_collect::<Vec<_>>().await?;
        let formatted = pretty_format_batches(&batches)?;
        println!("{formatted}");
    }
    Ok(())
}
/// A worker resolver backed by a fixed list of localhost ports.
///
/// Its `WorkerResolver` implementation in this file maps every port to the URL
/// `http://localhost:{port}`.
#[derive(Clone)]
struct LocalhostWorkerResolver {
/// Ports of the workers to expose, in the order they were supplied.
ports: Vec<u16>,
}
#[async_trait]
impl WorkerResolver for LocalhostWorkerResolver {
fn get_urls(&self) -> Result<Vec<Url>, DataFusionError> {
Ok(self
.ports
.iter()
.map(|port| Url::parse(&format!("http://localhost:{port}")).unwrap())
.collect())
}
}