use crate::{LogicalPlan, PlanVisitor};
use arrow::datatypes::Schema;
use std::fmt;
/// Formats a [`LogicalPlan`] as indented text, one node per line.
///
/// Each visited node is written on its own line, prefixed by two spaces per
/// level of nesting. When `with_schema` is set, the node's schema is appended
/// after its description.
pub struct IndentVisitor<'a, 'b> {
    /// Destination for the rendered plan.
    f: &'a mut fmt::Formatter<'b>,
    /// Whether to append each node's schema after its description.
    with_schema: bool,
    /// Current nesting depth; the next node is indented by `2 * indent` spaces.
    indent: usize,
}

impl<'a, 'b> IndentVisitor<'a, 'b> {
    /// Creates a visitor that writes into `f`, starting at depth zero.
    pub fn new(f: &'a mut fmt::Formatter<'b>, with_schema: bool) -> Self {
        let indent = 0;
        Self {
            indent,
            with_schema,
            f,
        }
    }
}
impl<'a, 'b> PlanVisitor for IndentVisitor<'a, 'b> {
    type Error = fmt::Error;

    /// Writes `plan`'s one-line description at the current depth, then
    /// increases the depth for its children.
    ///
    /// Every node after the first starts on a new line. It is indented by two
    /// spaces per depth level. When `with_schema` is set, the node's schema
    /// follows the description, separated by a space. Always returns `Ok(true)`
    /// unless writing to the formatter fails.
    fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<bool, fmt::Error> {
        // The root (depth 0) is written without a leading newline, so the output
        // has no empty first line.
        if self.indent > 0 {
            writeln!(self.f)?;
        }
        write!(self.f, "{:indent$}", "", indent = self.indent * 2)?;
        write!(self.f, "{}", plan.display())?;
        if self.with_schema {
            write!(
                self.f,
                " {}",
                display_schema(&plan.schema().as_ref().to_owned().into())
            )?;
        }
        self.indent += 1;
        Ok(true)
    }

    /// Returns to the parent's depth after all of `_plan`'s children were visited.
    ///
    /// # Errors
    /// Returns [`fmt::Error`] if called more times than `pre_visit`.
    fn post_visit(&mut self, _plan: &LogicalPlan) -> Result<bool, fmt::Error> {
        // A `post_visit` without a matching `pre_visit` would underflow `indent`
        // (a panic in debug builds). Report it as a formatting error instead,
        // the same way `GraphvizVisitor::post_visit` treats an empty parent stack.
        self.indent = self.indent.checked_sub(1).ok_or(fmt::Error)?;
        Ok(true)
    }
}
/// Returns a [`fmt::Display`] adapter that renders `schema` as a compact field list.
///
/// The output has the form `[name:DataType, ...]`. Each data type uses its `Debug`
/// form, and nullable fields carry a `;N` suffix, e.g. `[id:Int32, first_name:Utf8;N]`.
/// Nothing is formatted until the returned value is displayed.
pub fn display_schema(schema: &Schema) -> impl fmt::Display + '_ {
    /// Borrowing adapter so the schema is formatted lazily, straight into the
    /// caller's formatter.
    struct SchemaList<'s>(&'s Schema);

    impl fmt::Display for SchemaList<'_> {
        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
            f.write_str("[")?;
            // Empty before the first field, ", " before every later one.
            let mut separator = "";
            for field in self.0.fields().iter() {
                let nullable_marker = if field.is_nullable() { ";N" } else { "" };
                write!(
                    f,
                    "{separator}{}:{:?}{nullable_marker}",
                    field.name(),
                    field.data_type()
                )?;
                separator = ", ";
            }
            f.write_str("]")
        }
    }

    SchemaList(schema)
}
/// Hands out unique Graphviz ids and writes DOT cluster fragments.
#[derive(Default)]
struct GraphvizBuilder {
    /// Last id handed out by `next_id`; `0` means none has been issued yet.
    id_gen: usize,
}

impl GraphvizBuilder {
    /// Returns a fresh id. Ids start at 1 and grow by one per call.
    fn next_id(&mut self) -> usize {
        let fresh = self.id_gen + 1;
        self.id_gen = fresh;
        fresh
    }

    /// Opens a `subgraph cluster_<id>` block labelled with `title`.
    ///
    /// Must be paired with a later `end_cluster` call.
    fn start_cluster(&mut self, f: &mut fmt::Formatter, title: &str) -> fmt::Result {
        let cluster_id = self.next_id();
        let label = Self::quoted(title);
        writeln!(f, " subgraph cluster_{cluster_id}")?;
        writeln!(f, " {{")?;
        writeln!(f, " graph[label={label}]")
    }

    /// Closes the block opened by `start_cluster`.
    fn end_cluster(&mut self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(" }\n")
    }

    /// Wraps `label` in double quotes, replacing any embedded `"` with `_` so it
    /// cannot terminate the DOT string early.
    fn quoted(label: &str) -> String {
        let sanitized: String = label
            .chars()
            .map(|c| if c == '"' { '_' } else { c })
            .collect();
        format!("\"{sanitized}\"")
    }
}
/// Renders a [`LogicalPlan`] as Graphviz DOT nodes and edges.
///
/// Each plan node becomes a box, and each parent/child relation becomes an edge.
/// Callers may wrap a plan in a labelled cluster with
/// [`GraphvizVisitor::pre_visit_plan`] / [`GraphvizVisitor::post_visit_plan`].
pub struct GraphvizVisitor<'a, 'b> {
    /// Destination for the DOT output.
    f: &'a mut fmt::Formatter<'b>,
    /// Issues unique node/cluster ids and writes cluster fragments.
    graphviz_builder: GraphvizBuilder,
    /// Whether node labels include the node's schema.
    with_schema: bool,
    /// Ids of the nodes currently being visited, root first; the last entry is
    /// the parent of the next node emitted.
    parent_ids: Vec<usize>,
}

impl<'a, 'b> GraphvizVisitor<'a, 'b> {
    /// Creates a visitor writing into `f`, with schemas omitted from labels.
    pub fn new(f: &'a mut fmt::Formatter<'b>) -> Self {
        let graphviz_builder = GraphvizBuilder::default();
        Self {
            f,
            graphviz_builder,
            with_schema: false,
            parent_ids: vec![],
        }
    }

    /// Controls whether subsequent node labels include the node's schema.
    pub fn set_with_schema(&mut self, with_schema: bool) {
        self.with_schema = with_schema;
    }

    /// Opens a cluster labelled `label`; pair with [`GraphvizVisitor::post_visit_plan`].
    pub fn pre_visit_plan(&mut self, label: &str) -> fmt::Result {
        self.graphviz_builder.start_cluster(&mut *self.f, label)
    }

    /// Closes the cluster opened by [`GraphvizVisitor::pre_visit_plan`].
    pub fn post_visit_plan(&mut self) -> fmt::Result {
        self.graphviz_builder.end_cluster(&mut *self.f)
    }
}
impl<'a, 'b> PlanVisitor for GraphvizVisitor<'a, 'b> {
    type Error = fmt::Error;

    /// Emits a box node for `plan` and, when it has a parent, an edge from that parent.
    ///
    /// The node label is the plan's one-line description. When `with_schema` is
    /// set, the label also contains a literal `\n` escape (which Graphviz renders
    /// as a line break) followed by the schema. The new node's id is pushed so
    /// that children link to it.
    fn pre_visit(&mut self, plan: &LogicalPlan) -> Result<bool, fmt::Error> {
        let node_id = self.graphviz_builder.next_id();
        let description = plan.display();
        let label = if self.with_schema {
            // Raw string: `\n` stays a two-character DOT escape, not a real newline.
            format!(
                r"{}\nSchema: {}",
                description,
                display_schema(&plan.schema().as_ref().to_owned().into())
            )
        } else {
            description.to_string()
        };
        let quoted_label = GraphvizBuilder::quoted(&label);
        writeln!(self.f, " {node_id}[shape=box label={quoted_label}]")?;

        if let Some(&parent_id) = self.parent_ids.last() {
            writeln!(
                self.f,
                " {} -> {} [arrowhead=none, arrowtail=normal, dir=back]",
                parent_id, node_id
            )?;
        }

        self.parent_ids.push(node_id);
        Ok(true)
    }

    /// Pops the node finished by this call off the parent stack.
    ///
    /// # Errors
    /// Returns [`fmt::Error`] if the stack is already empty (unbalanced visit).
    fn post_visit(&mut self, _plan: &LogicalPlan) -> Result<bool, fmt::Error> {
        self.parent_ids.pop().map(|_| true).ok_or(fmt::Error)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field};

    /// A schema without fields renders as an empty bracket pair.
    #[test]
    fn test_display_empty_schema() {
        let empty = Schema::new(vec![]);
        let rendered = display_schema(&empty).to_string();
        assert_eq!(rendered, "[]");
    }

    /// Fields are comma-separated as `name:DataType`, with `;N` marking nullable ones.
    #[test]
    fn test_display_schema() {
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("first_name", DataType::Utf8, true),
        ];
        let rendered = display_schema(&Schema::new(fields)).to_string();
        assert_eq!(rendered, "[id:Int32, first_name:Utf8;N]");
    }
}