use std::{fmt::Debug, iter::zip, sync::Arc};
use arrow_schema::DataType;
use datafusion_common::{plan_err, Result};
use sedona_common::sedona_internal_err;
use crate::datatypes::{Edges, SedonaType, RASTER, WKB_GEOGRAPHY, WKB_GEOMETRY};
/// Describes a function signature as an ordered list of argument type matchers
/// together with the output type reported when a list of argument types
/// satisfies that signature.
#[derive(Debug)]
pub struct ArgMatcher {
// One matcher per signature position, in argument order. Matchers whose
// `optional()` is true may be skipped while matching.
matchers: Vec<Arc<dyn TypeMatcher + Send + Sync>>,
// Output type returned by `match_args()`. When it is `Wkb`/`WkbView`, its CRS
// is replaced by the CRS shared by the `Wkb`/`WkbView` arguments.
out_type: SedonaType,
}
impl ArgMatcher {
/// Creates a signature from its argument `matchers` (in argument order) and
/// the `out_type` reported when the arguments match.
pub fn new(matchers: Vec<Arc<dyn TypeMatcher + Send + Sync>>, out_type: SedonaType) -> Self {
Self { matchers, out_type }
}
/// Resolves the output type for `args`.
///
/// Returns `Ok(None)` when `args` does not satisfy `matches()`. Otherwise the
/// CRS of every `Wkb`/`WkbView` argument is compared with the CRS of the first
/// such argument. If there are no such arguments, the configured output type is
/// returned as is. If the output type is itself `Wkb`/`WkbView`, the shared CRS
/// is copied onto it; any other output type is returned unchanged.
///
/// # Errors
///
/// Returns a `plan_err!` error naming both CRSes when two `Wkb`/`WkbView`
/// arguments carry different CRSes.
pub fn match_args(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
    if !self.matches(args) {
        return Ok(None);
    }

    // Lazily yields the CRS of each geometry-like argument, in argument order.
    let geo_matcher = IsGeometryOrGeography {};
    let mut arg_crses = args
        .iter()
        .filter(|candidate| geo_matcher.match_type(candidate))
        .map(|candidate| match candidate {
            SedonaType::Wkb(_, crs) | SedonaType::WkbView(_, crs) => crs.clone(),
            _ => None,
        });

    // Without geometry-like arguments there is no CRS to check or propagate.
    let out_crs = match arg_crses.next() {
        Some(first_crs) => first_crs,
        None => return Ok(Some(self.out_type.clone())),
    };

    // Every remaining CRS must agree with the first one.
    for this_crs in arg_crses {
        if out_crs != this_crs {
            let hint = "Use ST_Transform() or ST_SetSRID() to ensure arguments are compatible.";
            return match (out_crs, this_crs) {
                (None, Some(rhs_crs)) => {
                    plan_err!("Mismatched CRS arguments: None vs {rhs_crs}\n{hint}")
                }
                (Some(lhs_crs), None) => {
                    plan_err!("Mismatched CRS arguments: {lhs_crs} vs None\n{hint}")
                }
                (Some(lhs_crs), Some(rhs_crs)) => {
                    plan_err!("Mismatched CRS arguments: {lhs_crs} vs {rhs_crs}\n{hint}")
                }
                _ => sedona_internal_err!("None vs. None should be considered equal"),
            };
        }
    }

    // Only WKB-encoded outputs carry a CRS; other outputs are returned as configured.
    let resolved = match &self.out_type {
        SedonaType::Wkb(edges, _) => SedonaType::Wkb(*edges, out_crs),
        SedonaType::WkbView(edges, _) => SedonaType::WkbView(*edges, out_crs),
        other => other.clone(),
    };
    Ok(Some(resolved))
}
/// Returns `true` when `args` satisfies this signature.
///
/// Matchers are visited in order and each is offered the next unconsumed
/// argument. A `Null`-typed argument is accepted by any matcher. A matcher that
/// rejects its argument, or has no argument left, is skipped when it is
/// optional; otherwise matching fails. All arguments must be consumed.
///
/// Matching is greedy: a matcher that accepts an argument always consumes it,
/// and earlier decisions are never revisited.
pub fn matches(&self, args: &[SedonaType]) -> bool {
    if args.len() > self.matchers.len() {
        return false;
    }

    let null_type = SedonaType::Arrow(DataType::Null);
    // Arguments that have not been consumed by a matcher yet.
    let mut remaining = args;

    for matcher in &self.matchers {
        match remaining.split_first() {
            // The matcher accepts the next argument: consume it.
            Some((head, tail)) if *head == null_type || matcher.match_type(head) => {
                remaining = tail;
            }
            // Rejected or nothing left to offer: only optional matchers may be skipped.
            _ if matcher.optional() => {}
            _ => return false,
        }
    }

    remaining.is_empty()
}
pub fn types_if_null(&self, args: &[SedonaType]) -> Result<Vec<SedonaType>> {
let mut out = Vec::new();
for (arg, matcher) in zip(args, &self.matchers) {
if let SedonaType::Arrow(DataType::Null) = arg {
if let Some(type_if_null) = matcher.type_if_null() {
out.push(type_if_null);
} else {
return sedona_internal_err!(
"Matcher {matcher:?} does not provide type_if_null()"
);
}
} else {
out.push(arg.clone());
}
}
Ok(out)
}
/// Matcher that accepts every argument type.
pub fn is_any() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsAny {})
}
/// Matcher for the Arrow `data_type`; shorthand for `is_exact()` applied to
/// `SedonaType::Arrow(data_type)`.
pub fn is_arrow(data_type: DataType) -> Arc<dyn TypeMatcher + Send + Sync> {
Self::is_exact(SedonaType::Arrow(data_type))
}
/// Matcher that accepts arguments for which
/// `exact_type.match_signature(arg)` is true.
pub fn is_exact(exact_type: SedonaType) -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsExact { exact_type })
}
/// Matcher that accepts any `SedonaType::Wkb` or `SedonaType::WkbView`
/// argument, whatever its edges or CRS.
pub fn is_geometry_or_geography() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsGeometryOrGeography {})
}
/// Matcher that accepts `Wkb`/`WkbView` arguments with `Edges::Planar`.
pub fn is_geometry() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsGeometry {})
}
/// Matcher that accepts `Wkb`/`WkbView` arguments with `Edges::Spherical`.
pub fn is_geography() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsGeography {})
}
/// Matcher that accepts an Arrow struct whose fields are named exactly
/// `item` and `crs`, in that order.
pub fn is_item_crs() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsItemCrs {})
}
/// Matcher for the `RASTER` type; shorthand for `is_exact(RASTER)`.
pub fn is_raster() -> Arc<dyn TypeMatcher + Send + Sync> {
Self::is_exact(RASTER)
}
/// Matcher that accepts only the Arrow `Null` type.
pub fn is_null() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsNull {})
}
/// Matcher that accepts Arrow types for which `DataType::is_numeric()` is true.
pub fn is_numeric() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsNumeric {})
}
/// Matcher that accepts Arrow types for which `DataType::is_integer()` is true.
pub fn is_integer() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsInteger {})
}
/// Matcher that accepts the Arrow `Utf8`, `Utf8View` and `LargeUtf8` types.
pub fn is_string() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsString {})
}
/// Matcher that accepts the Arrow `Binary` and `BinaryView` types.
pub fn is_binary() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsBinary {})
}
/// Matcher that accepts the Arrow `Boolean` type.
pub fn is_boolean() -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(IsBoolean {})
}
/// Wraps `matcher` so that its signature position may be skipped: the wrapper
/// reports `optional() == true` and otherwise behaves like `matcher`.
pub fn optional(
matcher: Arc<dyn TypeMatcher + Send + Sync>,
) -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(OptionalMatcher { inner: matcher })
}
/// Matcher that accepts an argument when at least one of `matchers` accepts it.
pub fn or(
matchers: Vec<Arc<dyn TypeMatcher + Send + Sync>>,
) -> Arc<dyn TypeMatcher + Send + Sync> {
Arc::new(OrMatcher { matchers })
}
}
/// A predicate over a single argument type, describing one position of a
/// function signature held by `ArgMatcher`.
pub trait TypeMatcher: Debug {
/// Returns `true` if `arg` is acceptable at this position.
fn match_type(&self, arg: &SedonaType) -> bool;
/// Whether this position may be skipped while matching; `false` unless
/// overridden.
fn optional(&self) -> bool {
false
}
/// The concrete type that should replace a `Null`-typed argument at this
/// position (consumed by `ArgMatcher::types_if_null()`); `None` unless
/// overridden.
fn type_if_null(&self) -> Option<SedonaType> {
None
}
}
/// Matcher that accepts every argument type. It keeps the default
/// `type_if_null()`, i.e. it offers no replacement type for `Null` arguments.
#[derive(Debug)]
struct IsAny;
impl TypeMatcher for IsAny {
fn match_type(&self, _arg: &SedonaType) -> bool {
true
}
}
/// Matcher that compares arguments against one specific type.
#[derive(Debug)]
struct IsExact {
// The type arguments are compared against.
exact_type: SedonaType,
}
impl TypeMatcher for IsExact {
// Delegates the comparison to `SedonaType::match_signature()`.
fn match_type(&self, arg: &SedonaType) -> bool {
self.exact_type.match_signature(arg)
}
// A `Null` argument at this position is typed as the exact type itself.
fn type_if_null(&self) -> Option<SedonaType> {
Some(self.exact_type.clone())
}
}
/// Wrapper that marks another matcher as optional. Type checks and `Null`
/// replacement are delegated to `inner`; only `optional()` differs.
#[derive(Debug)]
struct OptionalMatcher {
// The wrapped matcher that performs the actual type check.
inner: Arc<dyn TypeMatcher + Send + Sync>,
}
impl TypeMatcher for OptionalMatcher {
fn match_type(&self, arg: &SedonaType) -> bool {
self.inner.match_type(arg)
}
// Always optional, regardless of what `inner` reports.
fn optional(&self) -> bool {
true
}
fn type_if_null(&self) -> Option<SedonaType> {
self.inner.type_if_null()
}
}
/// Matcher that accepts an argument when at least one of its alternatives
/// accepts it.
#[derive(Debug)]
struct OrMatcher {
// The alternatives, tried in order until one accepts the argument.
matchers: Vec<Arc<dyn TypeMatcher + Send + Sync>>,
}
impl TypeMatcher for OrMatcher {
fn match_type(&self, arg: &SedonaType) -> bool {
self.matchers.iter().any(|m| m.match_type(arg))
}
// No replacement type is offered for `Null` arguments, so
// `ArgMatcher::types_if_null()` reports an internal error for a `Null`
// paired with this matcher.
// NOTE(review): presumably because the alternatives may suggest different
// types - confirm this is intended.
fn type_if_null(&self) -> Option<SedonaType> {
None
}
}
/// Matcher that accepts any `Wkb` or `WkbView` type, whatever its edges or
/// CRS. It keeps the default `type_if_null()` (`None`).
#[derive(Debug)]
struct IsGeometryOrGeography {}
impl TypeMatcher for IsGeometryOrGeography {
fn match_type(&self, arg: &SedonaType) -> bool {
matches!(arg, SedonaType::Wkb(_, _) | SedonaType::WkbView(_, _))
}
}
/// Matcher that accepts `Wkb`/`WkbView` types whose edges are planar.
#[derive(Debug)]
struct IsGeometry {}

impl TypeMatcher for IsGeometry {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(
            arg,
            SedonaType::Wkb(Edges::Planar, _) | SedonaType::WkbView(Edges::Planar, _)
        )
    }

    /// A `Null` argument in a geometry position is typed as `WKB_GEOMETRY`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(WKB_GEOMETRY)
    }
}
/// Matcher that accepts `Wkb`/`WkbView` types whose edges are spherical.
#[derive(Debug)]
struct IsGeography {}

impl TypeMatcher for IsGeography {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(
            arg,
            SedonaType::Wkb(Edges::Spherical, _) | SedonaType::WkbView(Edges::Spherical, _)
        )
    }

    /// A `Null` argument in a geography position is typed as `WKB_GEOGRAPHY`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(WKB_GEOGRAPHY)
    }
}
/// Matcher that accepts an Arrow struct whose fields are named exactly `item`
/// and `crs`, in that order. Only the field names are inspected, not the field
/// types.
#[derive(Debug)]
struct IsItemCrs {}

impl TypeMatcher for IsItemCrs {
    fn match_type(&self, arg: &SedonaType) -> bool {
        match arg {
            SedonaType::Arrow(DataType::Struct(fields)) => {
                let field_names = fields.iter().map(|field| field.name()).collect::<Vec<_>>();
                field_names == ["item", "crs"]
            }
            _ => false,
        }
    }
}
/// Matcher that accepts Arrow types for which `DataType::is_numeric()` is true.
#[derive(Debug)]
struct IsNumeric {}

impl TypeMatcher for IsNumeric {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(arg, SedonaType::Arrow(data_type) if data_type.is_numeric())
    }

    /// A `Null` argument in a numeric position is typed as `Float64`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(SedonaType::Arrow(DataType::Float64))
    }
}
/// Matcher that accepts Arrow types for which `DataType::is_integer()` is true.
#[derive(Debug)]
struct IsInteger {}

impl TypeMatcher for IsInteger {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(arg, SedonaType::Arrow(data_type) if data_type.is_integer())
    }

    /// A `Null` argument in an integer position is typed as `Int64`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(SedonaType::Arrow(DataType::Int64))
    }
}
/// Matcher that accepts the Arrow string types `Utf8`, `Utf8View` and
/// `LargeUtf8`.
#[derive(Debug)]
struct IsString {}

impl TypeMatcher for IsString {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(
            arg,
            SedonaType::Arrow(DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8)
        )
    }

    /// A `Null` argument in a string position is typed as `Utf8`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(SedonaType::Arrow(DataType::Utf8))
    }
}
/// Matcher that accepts the Arrow `Binary` and `BinaryView` types.
#[derive(Debug)]
struct IsBinary {}

impl TypeMatcher for IsBinary {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(
            arg,
            SedonaType::Arrow(DataType::Binary | DataType::BinaryView)
        )
    }

    /// A `Null` argument in a binary position is typed as `Binary`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(SedonaType::Arrow(DataType::Binary))
    }
}
/// Matcher that accepts the Arrow `Boolean` type.
#[derive(Debug)]
struct IsBoolean {}

impl TypeMatcher for IsBoolean {
    fn match_type(&self, arg: &SedonaType) -> bool {
        matches!(arg, SedonaType::Arrow(DataType::Boolean))
    }

    /// A `Null` argument in a boolean position is typed as `Boolean`.
    fn type_if_null(&self) -> Option<SedonaType> {
        Some(SedonaType::Arrow(DataType::Boolean))
    }
}
/// Matcher that accepts only the Arrow `Null` type. It keeps the default
/// `type_if_null()` (`None`).
#[derive(Debug)]
struct IsNull {}
impl TypeMatcher for IsNull {
fn match_type(&self, arg: &SedonaType) -> bool {
matches!(arg, SedonaType::Arrow(DataType::Null))
}
}
#[cfg(test)]
mod tests {
use crate::datatypes::{WKB_GEOGRAPHY, WKB_GEOMETRY};
use super::*;
/// Exercises every individual type matcher: accepted types, rejected types and
/// the replacement type offered for `Null` arguments.
#[test]
fn matchers() {
    // Shorthand for wrapping an Arrow data type.
    fn arrow(data_type: DataType) -> SedonaType {
        SedonaType::Arrow(data_type)
    }

    // Exact Arrow type
    assert!(ArgMatcher::is_arrow(DataType::Null).match_type(&arrow(DataType::Null)));

    // Geometry or geography
    let any_geo = ArgMatcher::is_geometry_or_geography();
    assert!(any_geo.match_type(&WKB_GEOMETRY));
    assert!(any_geo.match_type(&WKB_GEOGRAPHY));
    assert!(!any_geo.match_type(&arrow(DataType::Binary)));
    assert_eq!(any_geo.type_if_null(), None);

    // Geometry only
    let geometry = ArgMatcher::is_geometry();
    assert!(geometry.match_type(&WKB_GEOMETRY));
    assert!(!geometry.match_type(&WKB_GEOGRAPHY));
    assert_eq!(geometry.type_if_null(), Some(WKB_GEOMETRY));

    // Geography only
    let geography = ArgMatcher::is_geography();
    assert!(geography.match_type(&WKB_GEOGRAPHY));
    assert!(!geography.match_type(&WKB_GEOMETRY));
    assert_eq!(geography.type_if_null(), Some(WKB_GEOGRAPHY));

    // Numeric
    let numeric = ArgMatcher::is_numeric();
    assert!(numeric.match_type(&arrow(DataType::Int32)));
    assert!(numeric.match_type(&arrow(DataType::Float64)));
    assert_eq!(numeric.type_if_null(), Some(arrow(DataType::Float64)));

    // Integer
    let integer = ArgMatcher::is_integer();
    assert!(integer.match_type(&arrow(DataType::UInt32)));
    assert!(integer.match_type(&arrow(DataType::Int32)));
    assert!(!integer.match_type(&arrow(DataType::Float64)));

    // String
    let string = ArgMatcher::is_string();
    assert!(string.match_type(&arrow(DataType::Utf8)));
    assert!(string.match_type(&arrow(DataType::Utf8View)));
    assert!(string.match_type(&arrow(DataType::LargeUtf8)));
    assert!(!string.match_type(&arrow(DataType::Binary)));
    assert_eq!(string.type_if_null(), Some(arrow(DataType::Utf8)));

    // Binary
    let binary = ArgMatcher::is_binary();
    assert!(binary.match_type(&arrow(DataType::Binary)));
    assert!(binary.match_type(&arrow(DataType::BinaryView)));
    assert!(!binary.match_type(&arrow(DataType::Utf8)));
    assert_eq!(binary.type_if_null(), Some(arrow(DataType::Binary)));

    // Boolean and null
    let boolean = ArgMatcher::is_boolean();
    let null = ArgMatcher::is_null();
    assert!(boolean.match_type(&arrow(DataType::Boolean)));
    assert!(!boolean.match_type(&arrow(DataType::Int32)));
    assert!(null.match_type(&arrow(DataType::Null)));
    assert!(!null.match_type(&arrow(DataType::Int32)));
    assert_eq!(boolean.type_if_null(), Some(arrow(DataType::Boolean)));

    // Raster
    let raster = ArgMatcher::is_raster();
    assert!(raster.match_type(&RASTER));
    assert!(!raster.match_type(&arrow(DataType::Int32)));
    assert!(!raster.match_type(&WKB_GEOMETRY));
}
/// A signature with one required geometry followed by two optional positions
/// accepts argument lists that omit either or both optional arguments.
#[test]
fn optional_matcher() {
    // Shorthands for the Arrow argument types used below.
    let boolean = || SedonaType::Arrow(DataType::Boolean);
    let int32 = || SedonaType::Arrow(DataType::Int32);

    let signature = vec![
        ArgMatcher::is_geometry(),
        ArgMatcher::optional(ArgMatcher::is_boolean()),
        ArgMatcher::optional(ArgMatcher::is_numeric()),
    ];
    let matcher = ArgMatcher::new(signature, SedonaType::Arrow(DataType::Null));

    // All arguments present
    assert!(matcher.matches(&[WKB_GEOMETRY, boolean(), int32()]));

    // Both optional arguments omitted
    assert!(matcher.matches(&[WKB_GEOMETRY]));

    // First optional argument skipped
    assert!(matcher.matches(&[WKB_GEOMETRY, int32()]));

    // Required argument missing
    assert!(!matcher.matches(&[boolean()]));

    // Second argument fits neither optional position
    assert!(!matcher.matches(&[WKB_GEOMETRY, WKB_GEOMETRY]));

    // Wrong type in the required position
    assert!(!matcher.matches(&[boolean(), boolean()]));

    // More arguments than signature positions
    assert!(!matcher.matches(&[WKB_GEOGRAPHY, boolean(), int32(), int32()]));
}
/// An `or` position accepts any of its alternatives and offers no replacement
/// type for `Null` arguments.
#[test]
fn or_matcher() {
    // Builds the alternative "boolean or numeric" matcher under test.
    let boolean_or_numeric =
        || ArgMatcher::or(vec![ArgMatcher::is_boolean(), ArgMatcher::is_numeric()]);
    let boolean = || SedonaType::Arrow(DataType::Boolean);

    let matcher = ArgMatcher::new(
        vec![ArgMatcher::is_geometry(), boolean_or_numeric()],
        SedonaType::Arrow(DataType::Null),
    );

    // Either alternative is accepted in the second position
    assert!(matcher.matches(&[WKB_GEOMETRY, boolean()]));
    assert!(matcher.matches(&[WKB_GEOMETRY, SedonaType::Arrow(DataType::Int32)]));

    // A type matching neither alternative is rejected
    assert!(!matcher.matches(&[WKB_GEOMETRY, WKB_GEOMETRY]));

    // The first position is still checked
    assert!(!matcher.matches(&[boolean(), boolean()]));

    // No replacement type for null arguments
    assert_eq!(boolean_or_numeric().type_if_null(), None);
}
/// A single `Null`-typed argument satisfies a one-position signature no matter
/// which type matcher occupies that position.
#[test]
fn arg_matcher_matches_null() {
    let null_args = [SedonaType::Arrow(DataType::Null)];

    let candidates = vec![
        ArgMatcher::is_arrow(DataType::Null),
        ArgMatcher::is_arrow(DataType::Float32),
        ArgMatcher::is_geometry_or_geography(),
        ArgMatcher::is_geometry(),
        ArgMatcher::is_geography(),
        ArgMatcher::is_numeric(),
        ArgMatcher::is_string(),
        ArgMatcher::is_binary(),
        ArgMatcher::is_boolean(),
        ArgMatcher::optional(ArgMatcher::is_numeric()),
    ];

    for type_matcher in candidates {
        let signature = ArgMatcher::new(vec![type_matcher], SedonaType::Arrow(DataType::Null));
        assert!(signature.matches(&null_args));
    }
}
}