use crate::input::proto::substrait;
use crate::input::traits::ProtoEnum;
use crate::output::comment;
use crate::output::diagnostic;
use crate::output::type_system::data;
use crate::parse::context;
use crate::parse::expressions;
use crate::parse::extensions;
/// Parses the `direction` variant of a sort field's `sort_kind` and attaches a
/// description of the resulting ordering to the current node.
///
/// Unknown enum values and `Unspecified` emit an error diagnostic. Every case
/// still returns `Ok` with a short prefix ("Ascending sort by", "Coalesce", ...)
/// that the caller uses to describe the sort field as a whole.
fn parse_sort_direction(x: &i32, y: &mut context::Context) -> diagnostic::Result<&'static str> {
    use substrait::sort_field::SortDirection;

    // First make sure the raw integer maps onto a known enum variant.
    let direction = match SortDirection::proto_enum_from_i32(*x) {
        Some(direction) => direction,
        None => {
            diagnostic!(
                y,
                Error,
                IllegalValue,
                "unknown value {x} for {}",
                SortDirection::proto_enum_type()
            );
            return Ok("Invalid sort by");
        }
    };

    // Regular directions share one description path; the special cases
    // return early with their own handling.
    let (description, method) = match direction {
        SortDirection::Unspecified => {
            diagnostic!(y, Error, ProtoMissingField, "direction");
            return Ok("Invalid sort by");
        }
        SortDirection::AscNullsFirst => ("Sort ascending, nulls first", "Ascending sort by"),
        SortDirection::AscNullsLast => ("Sort ascending, nulls last", "Ascending sort by"),
        SortDirection::DescNullsFirst => ("Sort descending, nulls first", "Descending sort by"),
        SortDirection::DescNullsLast => ("Sort descending, nulls last", "Descending sort by"),
        SortDirection::Clustered => {
            describe!(y, Misc, "Coalesce equal values");
            summary!(
                y,
                "Equal values are grouped together, but no ordering is defined between clusters."
            );
            return Ok("Coalesce");
        }
    };
    describe!(y, Misc, "{description}");
    Ok(method)
}
/// Parses a reference to a custom comparison function used to order a sort
/// field.
///
/// The referenced function is bound as a scalar function taking two arguments
/// of the sort field's `data_type`. Based on the bound return type, a summary
/// is attached explaining how the function's result is interpreted:
///
///  - boolean: the result of `a < b`;
///  - `i8`/`i16`/`i32`/`i64`: negative / positive / zero (three-way compare);
///  - anything else: a `TypeMismatch` error is emitted (unless the type is
///    still unresolved), and a summary covering both interpretations is used.
///
/// Returns `"Custom sort"`; errors from resolving the function reference
/// itself are propagated via `?`.
fn parse_comparison_function_reference(
    x: &u32,
    y: &mut context::Context,
    data_type: &data::Type,
) -> diagnostic::Result<&'static str> {
    // Resolve the anchor, then bind it against the comparator signature (T, T).
    let functions = extensions::simple::parse_function_reference(x, y)?;
    let argument =
        expressions::functions::FunctionArgument::Value(data_type.clone(), Default::default());
    let context = expressions::functions::FunctionContext {
        function_type: expressions::functions::FunctionType::Scalar,
        arguments: vec![argument.clone(), argument],
        options: vec![],
        // Not constrained here; the bound return type is inspected below.
        return_type: data::new_unresolved_type(),
    };
    let binding = expressions::functions::FunctionBinding::new(Some(&functions), &context, y);
    let nullable = binding.return_type.nullable();

    let comment = comment::Comment::new()
        .plain("Comparison function for sorting. Taking two elements as input,")
        .plain("it must determine the correct sort order. The return value is");
    let comment = match binding.return_type.class() {
        data::Class::Simple(data::class::Simple::Boolean) => {
            let comment = comment
                .plain("interpreted as the result of a < b, so:")
                .lo()
                .plain("f(a, b) => true: a sorts before b.")
                .li()
                .plain("f(a, b) => false: b sorts before a.");
            let comment = if nullable {
                comment
                    .li()
                    .plain("f(a, b) => null: a and b have no defined sort order.")
            } else {
                comment
            };
            comment.lc()
        }
        data::Class::Simple(data::class::Simple::I8)
        | data::Class::Simple(data::class::Simple::I16)
        | data::Class::Simple(data::class::Simple::I32)
        | data::Class::Simple(data::class::Simple::I64) => {
            let comment = comment
                .plain("interpreted as follows:")
                .lo()
                .plain("f(a, b) => negative: a sorts before b.")
                .li()
                .plain("f(a, b) => positive: b sorts before a.")
                .li();
            // A non-nullable integer comparator can only express "no defined
            // order" by returning zero; null is only possible when nullable.
            let comment = if nullable {
                comment.plain("f(a, b) => zero or null: a and b have no defined sort order.")
            } else {
                comment.plain("f(a, b) => zero: a and b have no defined sort order.")
            };
            comment.lc()
        }
        _ => {
            // An unresolved type has presumably already been reported while
            // binding; only complain about types that are known to be wrong.
            if !binding.return_type.is_unresolved() {
                diagnostic!(
                    y,
                    Error,
                    TypeMismatch,
                    "comparison functions must yield booleans (a < b) or integers (a ?= b), but found {}",
                    binding.return_type
                );
            }
            comment
                .plain("interpreted as follows:")
                .lo()
                .plain("f(a, b) => true or negative: a sorts before b;")
                .li()
                .plain("f(a, b) => false or positive: b sorts before a;")
                .li()
                .plain("f(a, b) => 0 or null: a and b have no defined sort order.")
                .lc()
        }
    };
    y.push_summary(comment);
    Ok("Custom sort")
}
/// Dispatches on the variant of a sort field's `sort_kind` oneof.
///
/// `data_type` is the type of the expression being sorted on. It is only used
/// when the sort is defined by a custom comparison function, which must be
/// bound against two arguments of that type.
fn parse_sort_kind(
    x: &substrait::sort_field::SortKind,
    y: &mut context::Context,
    data_type: &data::Type,
) -> diagnostic::Result<&'static str> {
    use substrait::sort_field::SortKind;

    match x {
        SortKind::ComparisonFunctionReference(reference) => {
            parse_comparison_function_reference(reference, y, data_type)
        }
        SortKind::Direction(direction) => parse_sort_direction(direction, y),
    }
}
/// Parses a sort field: the expression to sort on, plus how to order it.
///
/// Both `expr` and `sort_kind` are required fields. The data type of the
/// parsed expression is forwarded to `parse_sort_kind` so that a custom
/// comparison function can be bound against it. The node is described and
/// summarized by combining the sort method with the expression.
///
/// Returns the parsed expression, falling back to `Expression::default()`
/// when no expression was produced.
pub fn parse_sort_field(
    x: &substrait::SortField,
    y: &mut context::Context,
) -> diagnostic::Result<expressions::Expression> {
    let (node, parsed) = proto_required_field!(x, y, expr, expressions::parse_expression);
    let expression = parsed.unwrap_or_default();

    let (_, kind) = proto_required_field!(x, y, sort_kind, parse_sort_kind, &node.data_type());
    let method = kind.unwrap_or("Invalid sort by");

    describe!(y, Misc, "{method} {expression}");
    summary!(y, "{method} {expression:#}.");
    Ok(expression)
}