use std::collections::HashSet;
use std::sync::Arc;
use arrow::{
array::{ArrayRef, ListArray, StringArray},
compute::SortOptions,
datatypes::{DataType, Field, Fields},
};
use datafusion::{
common::DataFusionError,
physical_expr::{LexOrdering, PhysicalSortExpr, expressions::col},
};
use itertools::Itertools as _;
use re_arrow_util::{ArrowArrayDowncastRef as _, RecordBatchExt as _};
use re_chunk::ArrowArray as _;
/// Test/snapshot convenience helpers on [`arrow::array::RecordBatch`].
pub trait RecordBatchExt {
    /// Formats the batch as a human-readable snapshot string; `transposed`
    /// flips the rows/columns orientation of the rendering.
    fn format_snapshot(&self, transposed: bool) -> String;

    /// Formats only the batch's schema as a snapshot string.
    fn format_schema_snapshot(&self) -> String;

    /// Returns a copy of the batch with columns sorted alphabetically by name.
    fn horizontally_sorted(&self) -> Self;

    /// Sorts only the `property:`-prefixed columns among themselves; all other
    /// columns keep their relative order.
    fn sort_property_columns(&self) -> Self;

    /// Sorts rows lexicographically by the given columns.
    ///
    /// Errors if a column name is not present in the schema.
    fn sort_rows_by(&self, columns: &[&str]) -> Result<Self, DataFusionError>
    where
        Self: Sized;

    /// Sorts rows lexicographically by *all* columns, in schema order.
    fn auto_sort_rows(&self) -> Result<Self, DataFusionError>
    where
        Self: Sized;

    /// Projects the batch to the given columns, in the given order.
    ///
    /// Returns `None` if any requested column is missing.
    fn with_columns(&self, columns: &[&str]) -> Option<Self>
    where
        Self: Sized;

    /// Replaces every occurrence of `from` with `to` in the string column
    /// `column_name`.
    fn replace_str(&self, column_name: &str, from: &str, to: &str) -> Self;

    /// Replaces the values of the given columns with fixed placeholder values,
    /// preserving nulls and row count.
    fn redact(&self, columns: &[&str]) -> Self;

    /// Keeps only the named columns, reordered to match the order of `columns`.
    fn project_columns(&self, columns: &[&str]) -> Self;

    /// Keeps only the columns whose name starts with `prefix`.
    fn filter_columns_by_prefix(&self, prefix: &str) -> Self;

    /// Drops the named columns; all others keep their original order.
    fn remove_columns(&self, columns: &[&str]) -> Self;
}
impl RecordBatchExt for arrow::array::RecordBatch {
fn format_snapshot(&self, transposed: bool) -> String {
re_arrow_util::format_record_batch_opts(
self,
&re_arrow_util::RecordBatchFormatOpts {
transposed,
width: Some(800),
include_metadata: false,
include_column_metadata: false,
..Default::default()
},
)
.to_string()
}
#[inline]
fn format_schema_snapshot(&self) -> String {
self.schema().format_snapshot()
}
fn horizontally_sorted(&self) -> Self {
self.clone()
.sort_columns_by(|f1, f2| f1.name().cmp(f2.name()))
.expect("should be able to sort")
}
fn sort_property_columns(&self) -> Self {
self.clone()
.sort_columns_by(|f1, f2| {
if f1.name().starts_with("property:") && f2.name().starts_with("property:") {
f1.name().cmp(f2.name())
} else {
std::cmp::Ordering::Equal
}
})
.expect("should be able to sort")
}
fn sort_rows_by(&self, columns: &[&str]) -> Result<Self, DataFusionError> {
let sort_exprs = columns
.iter()
.map(|column| {
Ok(PhysicalSortExpr::new(
col(column, self.schema_ref())?,
SortOptions::default(),
))
})
.collect::<Result<Vec<_>, DataFusionError>>()?;
let Some(ordering) = LexOrdering::new(sort_exprs) else {
return Ok(self.clone());
};
datafusion::physical_plan::sorts::sort::sort_batch(self, &ordering, None)
}
fn auto_sort_rows(&self) -> Result<Self, DataFusionError> {
let sort_exprs = self
.schema()
.fields()
.iter()
.map(|column| {
Ok(PhysicalSortExpr::new(
col(column.name(), self.schema_ref())?,
SortOptions::default(),
))
})
.collect::<Result<Vec<_>, DataFusionError>>()?;
let Some(ordering) = LexOrdering::new(sort_exprs) else {
return Ok(self.clone());
};
datafusion::physical_plan::sorts::sort::sort_batch(self, &ordering, None)
}
fn with_columns(&self, columns: &[&str]) -> Option<Self>
where
Self: Sized,
{
let mut fields = Vec::new();
let mut arrays = Vec::new();
let schema = self.schema();
for column in columns {
let (_, field) = schema.column_with_name(column)?;
fields.push(field.clone());
let array = self.column_by_name(column)?;
arrays.push(array.clone());
}
let schema = arrow::datatypes::Schema::new_with_metadata(fields, schema.metadata().clone());
Some(Self::try_new(Arc::new(schema), arrays).expect("creating record batch"))
}
fn replace_str(&self, column_name: &str, from: &str, to: &str) -> Self {
let schema = self.schema();
schema
.field_with_name(column_name)
.expect("Column not found in schema");
let mut arrays: Vec<ArrayRef> = Vec::new();
for column in schema.fields() {
let array = self.column_by_name(column.name()).expect("no such column");
if column.name() == column_name {
let string_array = array
.try_downcast_array_ref::<StringArray>()
.expect("expected column to be StringArray");
let new_values = string_array
.iter()
.map(|opt| opt.map(|s| s.replace(from, to)))
.collect_vec();
arrays.push(Arc::new(StringArray::from(new_values)) as ArrayRef);
} else {
arrays.push(array.clone());
}
}
if schema.fields().is_empty() {
Self::new_empty(schema)
} else {
Self::try_new(schema, arrays).expect("creation should succeed")
}
}
fn redact(&self, columns: &[&str]) -> Self {
let mut arrays = Vec::new();
let schema = self.schema();
for column in schema.fields() {
let array = self.column_by_name(column.name()).expect("no such column");
if !columns.contains(&column.name().as_str()) {
arrays.push(array.clone());
continue;
}
macro_rules! redact_array {
($array:expr, $array_type:ty, $redact_fn:expr) => {{
let typed_array = $array
.try_downcast_array_ref::<$array_type>()
.expect(concat!("expected column to be ", stringify!($array_type)));
let redacted_values = typed_array.iter().map($redact_fn).collect_vec();
Arc::new(<$array_type>::from(redacted_values)) as ArrayRef
}};
}
match column.data_type() {
arrow::datatypes::DataType::Utf8 => {
arrays.push(redact_array!(array, StringArray, |opt| opt.map(|_| "redacted")));
}
arrow::datatypes::DataType::Int64 => {
arrays
.push(redact_array!(array, arrow::array::Int64Array, |opt| opt.map(|_| 0)));
}
arrow::datatypes::DataType::List(field) => {
let list_array = array
.try_downcast_array_ref::<arrow::array::ListArray>()
.expect("expected column to be ListArray");
let (redacted_values, inner_field) = match field.data_type() {
arrow::datatypes::DataType::Utf8 => {
let redacted = redact_array!(
list_array.values(),
arrow::array::StringArray,
|opt| opt.map(|_| "redacted")
);
let field = Arc::new(Field::new("item", DataType::Utf8, true));
(redacted, field)
}
arrow::datatypes::DataType::Int64 => {
let redacted = redact_array!(
list_array.values(),
arrow::array::Int64Array,
|opt| opt.map(|_| 0)
);
let field = Arc::new(Field::new("item", DataType::Int64, true));
(redacted, field)
}
_ => {
panic!(
"Redaction not implemented for type {} inside a List",
field.data_type()
);
}
};
let offsets = list_array.offsets();
let list_nulls = list_array.nulls().cloned();
let redacted_list = ListArray::try_new(
inner_field,
offsets.clone(),
Arc::new(redacted_values),
list_nulls,
)
.expect("Failed to create ListArray");
arrays.push(Arc::new(redacted_list) as ArrayRef);
}
arrow::datatypes::DataType::Binary => {
arrays.push(redact_array!(array, arrow::array::BinaryArray, |opt| opt
.map(|_| [0u8; 8].as_slice())));
}
_ => {
panic!("Redaction not implemented for type {}", column.data_type());
}
}
}
if schema.fields().is_empty() {
Self::new_empty(schema.clone())
} else {
Self::try_new(schema.clone(), arrays).expect("creation should succeed")
}
}
fn remove_columns(&self, columns: &[&str]) -> Self {
self.clone()
.filter_columns_by(|field| !columns.contains(&field.name().as_str()))
.expect("should be able to filter")
}
fn project_columns(&self, columns: &[&str]) -> Self {
let col_idx = |field: &Field| columns.iter().position(|c| c == field.name());
self.clone()
.filter_columns_by(|field| columns.contains(&field.name().as_str()))
.expect("should be able to filter")
.sort_columns_by(|f1, f2| col_idx(f1).cmp(&col_idx(f2)))
.expect("should be able to sort")
}
fn filter_columns_by_prefix(&self, prefix: &str) -> Self {
self.clone()
.filter_columns_by(|field| field.name().starts_with(prefix))
.expect("should be able to filter")
}
}
/// Snapshot-formatting helper on [`arrow::datatypes::Schema`].
pub trait SchemaExt {
    /// Formats the schema (top-level metadata plus fields, deterministically
    /// sorted) as a human-readable snapshot string.
    fn format_snapshot(&self) -> String;
}
impl SchemaExt for arrow::datatypes::Schema {
    /// Produces a deterministic snapshot of this schema: the optional
    /// top-level metadata block first, then every field sorted by name, each
    /// followed by its own (sorted) metadata when present.
    fn format_snapshot(&self) -> String {
        // Renders a metadata map as sorted `key:value` entries, one per line.
        let render_kv = |metadata: &std::collections::HashMap<String, String>| {
            metadata
                .iter()
                .map(|(k, v)| format!("{k}:{v}"))
                .sorted()
                .join("\n ")
        };

        let header = if self.metadata().is_empty() {
            None
        } else {
            Some(format!(
                "top-level metadata: [\n {}\n]",
                render_kv(self.metadata())
            ))
        };

        // Sort fields by name so the snapshot is independent of column order.
        let mut sorted_fields = self.fields.iter().collect_vec();
        sorted_fields.sort_by(|a, b| a.name().cmp(b.name()));

        let field_lines = sorted_fields.into_iter().map(|field| {
            let rendered_type = re_arrow_util::format_data_type(field.data_type());
            if field.metadata().is_empty() {
                format!("{}: {}", field.name(), rendered_type)
            } else {
                format!(
                    "{}: {} [\n {}\n]",
                    field.name(),
                    rendered_type,
                    render_kv(field.metadata())
                )
            }
        });

        header.into_iter().chain(field_lines).join("\n")
    }
}
/// Set-style queries on [`Fields`].
pub trait FieldsExt {
    /// Returns `true` if every field in `required_fields` is present in
    /// `self`, regardless of column order (exact `Field` equality).
    fn contains_unordered(
        &self,
        required_fields: impl IntoIterator<Item = impl AsRef<Field>>,
    ) -> bool;
}
impl FieldsExt for Fields {
    /// Checks that every required field occurs in `self`, ignoring order.
    ///
    /// Builds a hash set of the fields present once, then probes it for each
    /// required field (exact `Field` equality, including name and type).
    fn contains_unordered(
        &self,
        required_fields: impl IntoIterator<Item = impl AsRef<Field>>,
    ) -> bool {
        let present: HashSet<&Field> = self.iter().map(|field| field.as_ref()).collect();
        required_fields
            .into_iter()
            .all(|required| present.contains(required.as_ref()))
    }
}