use std::ops::Index;
use arrow::datatypes::Schema;
use crate::Error;
use crate::Result;
use crate::errors::PipelinePlanningError;
/// Selects which slice of output is displayed.
///
/// NOTE(review): the meaning below is inferred from the variant names only;
/// the code that consumes this enum is not in this file — confirm there.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum DisplaySlice {
    /// Presumably the first `n` rows — TODO confirm with the consumer.
    Head(usize),
    /// Presumably the last `n` rows — TODO confirm with the consumer.
    Tail(usize),
    /// Presumably a sample of `n` rows — TODO confirm with the consumer.
    Sample(usize),
}
/// How a column name in a select (or group-by) list is matched against a schema.
///
/// Derives `Eq` and `Hash` (all payloads are `String`) so specs can be used as
/// keys in sets and maps, e.g. to de-duplicate requested columns.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ColumnSpec {
    /// Matches only a field whose name is exactly equal to the string.
    Exact(String),
    /// Matches a field whose name is equal to the string ignoring ASCII case.
    CaseInsensitive(String),
}
/// One entry of a select list: either a plain column or an aggregate over a column.
///
/// `SelectItem::is_aggregate` returns `true` for `Sum`, `Avg`, `Min` and `Max`.
/// NOTE(review): how aggregates are evaluated lives outside this file — confirm there.
#[derive(Clone, Debug, PartialEq)]
pub enum SelectItem {
    /// A plain projected column.
    Column(ColumnSpec),
    /// Sum of the referenced column.
    Sum(ColumnSpec),
    /// Average of the referenced column.
    Avg(ColumnSpec),
    /// Minimum of the referenced column.
    Min(ColumnSpec),
    /// Maximum of the referenced column.
    Max(ColumnSpec),
}
/// Builds a `SelectSpec` containing only plain column items and no `group_by`.
///
/// Two forms are accepted, and each invocation must use a single form throughout:
///
/// * literals, e.g. `select_spec!("one", "two")`, produce `ColumnSpec::Exact`
///   items holding the literal's `to_string()` text;
/// * `:`-prefixed identifiers, e.g. `select_spec!(:one, :two)`, produce
///   `ColumnSpec::CaseInsensitive` items whose name is the identifier text
///   (via `stringify!`).
///
/// At least one item is required, and a trailing comma is allowed. The expansion
/// names the types through `$crate::pipeline::...`, so they must be reachable
/// at that path.
#[macro_export]
macro_rules! select_spec {
    // Literal form: exact, case-sensitive column names.
    ( $($col:literal),+ $(,)? ) => {
        $crate::pipeline::SelectSpec {
            columns: vec![
                $(
                    $crate::pipeline::SelectItem::Column(
                        $crate::pipeline::ColumnSpec::Exact($col.to_string())
                    )
                ),+
            ],
            group_by: None,
        }
    };
    // `:ident` form: case-insensitive names taken from the identifier's text.
    ( $(:$col:ident),+ $(,)? ) => {
        $crate::pipeline::SelectSpec {
            columns: vec![
                $(
                    $crate::pipeline::SelectItem::Column(
                        $crate::pipeline::ColumnSpec::CaseInsensitive(stringify!($col).to_string())
                    )
                ),+
            ],
            group_by: None,
        }
    };
}
impl ColumnSpec {
    /// Resolves this spec to the name of a field in `schema`.
    ///
    /// * [`ColumnSpec::Exact`]: the name must exist in `schema` exactly as written.
    ///   It is returned unchanged.
    /// * [`ColumnSpec::CaseInsensitive`]: returns the schema's own spelling of the
    ///   matching field. A field whose name matches exactly is preferred. Otherwise
    ///   the first field (in schema order) whose name is equal ignoring ASCII case
    ///   is used. Non-ASCII letters must still match case exactly.
    ///
    /// # Errors
    ///
    /// * `Exact`: the error returned by [`Schema::index_of`], converted into the
    ///   crate [`Error`] by `?`.
    /// * `CaseInsensitive`: `Error::PipelinePlanningError(ColumnNotFound(name))`
    ///   when no field matches.
    pub fn resolve(&self, schema: &Schema) -> Result<String> {
        match self {
            ColumnSpec::Exact(name) => {
                // Only the existence check matters here; the index is discarded.
                schema.index_of(name)?;
                Ok(name.clone())
            }
            ColumnSpec::CaseInsensitive(name) => {
                let fields = schema.fields();
                // Try an exact-case match first. Otherwise, a schema holding both
                // `email` and `Email` would resolve `:Email` to whichever field
                // happens to come first.
                fields
                    .iter()
                    .find(|f| f.name() == name)
                    .or_else(|| fields.iter().find(|f| f.name().eq_ignore_ascii_case(name)))
                    .map(|f| f.name().clone())
                    .ok_or_else(|| {
                        Error::PipelinePlanningError(PipelinePlanningError::ColumnNotFound(
                            name.clone(),
                        ))
                    })
            }
        }
    }
}
impl SelectItem {
    /// Reports whether this item is an aggregate (`Sum`, `Avg`, `Min` or `Max`)
    /// rather than a plain column reference.
    pub fn is_aggregate(&self) -> bool {
        // Exhaustive on purpose: adding a new variant forces a decision here.
        match self {
            SelectItem::Column(_) => false,
            SelectItem::Sum(_) | SelectItem::Avg(_) | SelectItem::Min(_) | SelectItem::Max(_) => {
                true
            }
        }
    }
}
/// An ordered select list, optionally paired with group-by columns.
#[derive(Clone, Debug, PartialEq)]
pub struct SelectSpec {
    /// Items to select, in output order.
    pub columns: Vec<SelectItem>,
    /// Group-by columns. `SelectSpec::has_group_by` treats both `None` and an
    /// empty list as "no grouping".
    pub group_by: Option<Vec<ColumnSpec>>,
}
impl SelectSpec {
pub fn is_empty(&self) -> bool {
self.columns.is_empty()
}
pub fn len(&self) -> usize {
self.columns.len()
}
pub fn has_aggregates(&self) -> bool {
self.columns.iter().any(SelectItem::is_aggregate)
}
pub fn is_aggregate_only(&self) -> bool {
!self.columns.is_empty() && self.columns.iter().all(SelectItem::is_aggregate)
}
pub fn has_group_by(&self) -> bool {
self.group_by.as_ref().is_some_and(|g| !g.is_empty())
}
pub fn from_cli_args(select: &Option<Vec<String>>) -> Option<Self> {
let inner = select.as_ref()?;
let mut columns = Vec::new();
for s in inner {
columns.extend(s.split(',').filter_map(|c| {
let c = c.trim();
if c.is_empty() {
None
} else {
Some(SelectItem::Column(ColumnSpec::Exact(c.to_string())))
}
}));
}
if columns.is_empty() {
None
} else {
Some(Self {
columns,
group_by: None,
})
}
}
pub fn resolve_names(&self, schema: &Schema) -> Result<Vec<String>> {
self.columns
.iter()
.map(|item| match item {
SelectItem::Column(s) => s.resolve(schema),
SelectItem::Sum(_)
| SelectItem::Avg(_)
| SelectItem::Min(_)
| SelectItem::Max(_) => Err(Error::PipelinePlanningError(
PipelinePlanningError::AggregatesInProjectionSelect,
)),
})
.collect()
}
}
impl Index<usize> for SelectSpec {
    type Output = SelectItem;

    /// Returns the select item at `position`.
    ///
    /// # Panics
    ///
    /// Panics if `position >= self.len()`, exactly like slice indexing.
    fn index(&self, position: usize) -> &SelectItem {
        &self.columns.as_slice()[position]
    }
}
#[cfg(test)]
mod tests {
    use arrow::datatypes::DataType;
    use arrow::datatypes::Field;
    use arrow::datatypes::Schema;

    use super::ColumnSpec;
    use super::SelectItem;
    use super::SelectSpec;

    /// Shorthand for an exact-match plain column item.
    fn exact(name: &str) -> SelectItem {
        SelectItem::Column(ColumnSpec::Exact(name.into()))
    }

    #[test]
    fn test_select_item_column_is_not_aggregate() {
        assert!(!exact("x").is_aggregate());
    }

    #[test]
    fn test_select_item_avg_is_aggregate() {
        let item = SelectItem::Avg(ColumnSpec::CaseInsensitive("x".into()));
        assert!(item.is_aggregate());
    }

    #[test]
    fn test_select_item_min_max_are_aggregate() {
        let min_item = SelectItem::Min(ColumnSpec::CaseInsensitive("x".into()));
        let max_item = SelectItem::Max(ColumnSpec::CaseInsensitive("y".into()));
        assert!(min_item.is_aggregate());
        assert!(max_item.is_aggregate());
    }

    #[test]
    fn test_select_spec_aggregate_classification() {
        let mixed = SelectSpec {
            columns: vec![exact("a"), SelectItem::Sum(ColumnSpec::Exact("b".into()))],
            group_by: None,
        };
        assert!(mixed.has_aggregates());
        assert!(!mixed.is_aggregate_only());

        let only = SelectSpec {
            columns: vec![
                SelectItem::Sum(ColumnSpec::Exact("a".into())),
                SelectItem::Max(ColumnSpec::Exact("b".into())),
            ],
            group_by: None,
        };
        assert!(only.has_aggregates());
        assert!(only.is_aggregate_only());

        // An empty spec is neither "has aggregates" nor "aggregate only".
        let empty = SelectSpec {
            columns: vec![],
            group_by: None,
        };
        assert!(empty.is_empty());
        assert!(!empty.has_aggregates());
        assert!(!empty.is_aggregate_only());
    }

    #[test]
    fn test_select_spec_has_group_by() {
        let mut spec = SelectSpec {
            columns: vec![exact("a")],
            group_by: None,
        };
        assert!(!spec.has_group_by());
        spec.group_by = Some(vec![]);
        assert!(!spec.has_group_by());
        spec.group_by = Some(vec![ColumnSpec::Exact("a".into())]);
        assert!(spec.has_group_by());
    }

    #[test]
    fn test_select_spec_len_and_index() {
        let spec = SelectSpec {
            columns: vec![exact("a"), exact("b")],
            group_by: None,
        };
        assert_eq!(spec.len(), 2);
        assert_eq!(spec[1], exact("b"));
    }

    #[test]
    fn test_from_cli_args_none_is_none() {
        assert_eq!(SelectSpec::from_cli_args(&None), None);
        assert_eq!(SelectSpec::from_cli_args(&Some(vec![])), None);
    }

    #[test]
    fn test_from_cli_args_splits_trims_and_skips_empty() {
        let args = Some(vec!["one, two".to_string(), " , three ,".to_string()]);
        let spec = SelectSpec::from_cli_args(&args).unwrap();
        assert_eq!(spec.columns, vec![exact("one"), exact("two"), exact("three")]);
        assert_eq!(spec.group_by, None);
    }

    #[test]
    fn test_from_cli_args_only_separators_is_none() {
        let args = Some(vec![" , ,".to_string(), String::new()]);
        assert_eq!(SelectSpec::from_cli_args(&args), None);
    }

    fn schema_with_columns(names: &[&str]) -> Schema {
        let fields: Vec<Field> = names
            .iter()
            .map(|n| Field::new(*n, DataType::Utf8, true))
            .collect();
        Schema::new(fields)
    }

    #[test]
    fn test_select_spec_resolve_exact_match() {
        let schema = schema_with_columns(&["one", "two", "three"]);
        let spec = crate::select_spec!("one", "three");
        let resolved = spec.resolve_names(&schema).unwrap();
        assert_eq!(resolved, vec!["one", "three"]);
    }

    #[test]
    fn test_select_spec_resolve_exact_no_match_wrong_case() {
        let schema = schema_with_columns(&["one", "two"]);
        let spec = crate::select_spec!("ONE");
        let result = spec.resolve_names(&schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_spec_resolve_case_insensitive_match() {
        let schema = schema_with_columns(&["One", "two", "Email"]);
        let spec = crate::select_spec!(:ONE, :email);
        let resolved = spec.resolve_names(&schema).unwrap();
        assert_eq!(resolved, vec!["One", "Email"]);
    }

    #[test]
    fn test_select_spec_resolve_case_insensitive_no_match() {
        let schema = schema_with_columns(&["one", "two"]);
        let spec = crate::select_spec!(:missing);
        let result = spec.resolve_names(&schema);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_spec_resolve_rejects_aggregates() {
        let schema = schema_with_columns(&["x"]);
        let spec = SelectSpec {
            columns: vec![SelectItem::Sum(ColumnSpec::Exact("x".into()))],
            group_by: None,
        };
        assert!(spec.resolve_names(&schema).is_err());
    }

    #[test]
    fn test_select_spec_macro_strings_create_exact_columns() {
        let spec = crate::select_spec!("one", "two");
        assert_eq!(
            spec.columns,
            vec![
                SelectItem::Column(ColumnSpec::Exact("one".into())),
                SelectItem::Column(ColumnSpec::Exact("two".into())),
            ]
        );
        assert_eq!(spec.group_by, None);
    }

    #[test]
    fn test_select_spec_macro_symbols_create_case_insensitive_columns() {
        let spec = crate::select_spec!(:one, :two);
        assert_eq!(
            spec.columns,
            vec![
                SelectItem::Column(ColumnSpec::CaseInsensitive("one".into())),
                SelectItem::Column(ColumnSpec::CaseInsensitive("two".into())),
            ]
        );
        assert_eq!(spec.group_by, None);
    }
}