use std::net::IpAddr;
use std::sync::Arc;
use std::{any::Any, net::Ipv4Addr, net::Ipv6Addr};
use datafusion::arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
Volatility,
};
use super::string_utils::{scalar_to_str, STRING_TYPES};
/// Scalar UDF `hamelin_is_ipv4`: reports whether a string is an IPv4 address.
///
/// IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`) count as IPv4. A string that
/// does not parse as an IP address yields a null result rather than `false`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct IsIpv4Udf {
    signature: Signature,
}

impl IsIpv4Udf {
    /// Builds the UDF with a signature accepting exactly one argument of any type
    /// listed in `STRING_TYPES`.
    pub fn new() -> Self {
        let mut single_arg_forms = Vec::new();
        for string_type in STRING_TYPES.iter() {
            single_arg_forms.push(TypeSignature::Exact(vec![string_type.clone()]));
        }
        let signature = Signature::new(
            TypeSignature::OneOf(single_arg_forms),
            Volatility::Immutable,
        );
        Self { signature }
    }
}

impl Default for IsIpv4Udf {
    fn default() -> Self {
        Self::new()
    }
}
/// Maps every slot of a string array through `f`, collecting into a `BooleanArray`.
///
/// A null input slot becomes a null output slot without calling `f`, and a slot for
/// which `f` returns `None` is also null in the output.
fn map_string_to_bool<T, F>(array: &T, f: F) -> ArrayRef
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
    F: Fn(&str) -> Option<bool>,
{
    let flags: BooleanArray = array
        .into_iter()
        .map(|slot| {
            let text = slot?;
            f(text)
        })
        .collect();
    Arc::new(flags)
}
impl ScalarUDFImpl for IsIpv4Udf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_is_ipv4"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `Boolean`, whatever string type was passed in.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Applies `check_is_ipv4` to a scalar, or element-wise to a string array.
    ///
    /// # Errors
    /// Returns an execution error when the argument count is not 1, when
    /// `scalar_to_str` rejects a scalar, or when the array is not `Utf8`,
    /// `LargeUtf8` or `Utf8View`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [input] = args.args.as_slice() else {
            return exec_err!("is_ipv4 expects exactly 1 argument, got {}", args.args.len());
        };
        let array = match input {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(scalar) => {
                // When `scalar_to_str` yields no string, the answer is a null Boolean.
                let verdict = scalar_to_str(scalar)?.and_then(check_is_ipv4);
                return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(verdict)));
            }
        };
        let flags = match array.data_type() {
            DataType::Utf8 => map_string_to_bool(array.as_string::<i32>(), check_is_ipv4),
            DataType::LargeUtf8 => map_string_to_bool(array.as_string::<i64>(), check_is_ipv4),
            DataType::Utf8View => map_string_to_bool(array.as_string_view(), check_is_ipv4),
            other => return exec_err!("is_ipv4 expects string array, got {}", other),
        };
        Ok(ColumnarValue::Array(flags))
    }
}
/// Classifies `s` as IPv4 or not.
///
/// Returns `Some(true)` for a dotted-quad IPv4 address or an IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`), and `Some(false)` for any other IPv6 address. Returns `None`
/// when `s` is not an IP address at all.
fn check_is_ipv4(s: &str) -> Option<bool> {
    let addr: IpAddr = s.parse().ok()?;
    let is_v4 = match addr {
        IpAddr::V4(_) => true,
        IpAddr::V6(v6) => v6.to_ipv4_mapped().is_some(),
    };
    Some(is_v4)
}
/// Wraps a freshly built [`IsIpv4Udf`] in a `ScalarUDF`.
pub fn is_ipv4_udf() -> ScalarUDF {
    let implementation = IsIpv4Udf::new();
    ScalarUDF::new_from_impl(implementation)
}
/// Scalar UDF `hamelin_is_ipv6`: reports whether a string is a (non-IPv4-mapped)
/// IPv6 address.
///
/// IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`) are reported as `false`. A
/// string that does not parse as an IP address yields a null result.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct IsIpv6Udf {
    signature: Signature,
}

impl IsIpv6Udf {
    /// Builds the UDF with a signature accepting exactly one argument of any type
    /// listed in `STRING_TYPES`.
    pub fn new() -> Self {
        let one_arg_forms = STRING_TYPES
            .iter()
            .cloned()
            .map(|string_type| TypeSignature::Exact(vec![string_type]))
            .collect::<Vec<_>>();
        let signature = Signature::new(TypeSignature::OneOf(one_arg_forms), Volatility::Immutable);
        Self { signature }
    }
}

impl Default for IsIpv6Udf {
    fn default() -> Self {
        Self::new()
    }
}
impl ScalarUDFImpl for IsIpv6Udf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_is_ipv6"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always `Boolean`, whatever string type was passed in.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Boolean)
    }

    /// Applies `check_is_ipv6` to a scalar, or element-wise to a string array.
    ///
    /// # Errors
    /// Returns an execution error when the argument count is not 1, when
    /// `scalar_to_str` rejects a scalar, or when the array is not `Utf8`,
    /// `LargeUtf8` or `Utf8View`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [input] = args.args.as_slice() else {
            return exec_err!("is_ipv6 expects exactly 1 argument, got {}", args.args.len());
        };
        let array = match input {
            ColumnarValue::Array(array) => array,
            ColumnarValue::Scalar(scalar) => {
                // When `scalar_to_str` yields no string, the answer is a null Boolean.
                let verdict = scalar_to_str(scalar)?.and_then(check_is_ipv6);
                return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(verdict)));
            }
        };
        let flags = match array.data_type() {
            DataType::Utf8 => map_string_to_bool(array.as_string::<i32>(), check_is_ipv6),
            DataType::LargeUtf8 => map_string_to_bool(array.as_string::<i64>(), check_is_ipv6),
            DataType::Utf8View => map_string_to_bool(array.as_string_view(), check_is_ipv6),
            other => return exec_err!("is_ipv6 expects string array, got {}", other),
        };
        Ok(ColumnarValue::Array(flags))
    }
}
/// Classifies `s` as a "true" IPv6 address or not.
///
/// Returns `Some(true)` for an IPv6 address that is not IPv4-mapped, and
/// `Some(false)` for an IPv4 address or an IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`). Returns `None` when `s` is not an IP address at all.
fn check_is_ipv6(s: &str) -> Option<bool> {
    s.parse::<IpAddr>()
        .ok()
        .map(|addr| matches!(addr, IpAddr::V6(v6) if v6.to_ipv4_mapped().is_none()))
}
/// Wraps a freshly built [`IsIpv6Udf`] in a `ScalarUDF`.
pub fn is_ipv6_udf() -> ScalarUDF {
    let implementation = IsIpv6Udf::new();
    ScalarUDF::new_from_impl(implementation)
}
/// Scalar UDF `hamelin_cidr_contains(cidr, ip)`: reports whether `ip` falls inside
/// the CIDR block `cidr`.
///
/// Either argument may be a scalar or an array. Null inputs and unparseable
/// CIDR or IP strings yield null results.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CidrContainsUdf {
    signature: Signature,
}

impl CidrContainsUdf {
    /// Builds the UDF with a signature accepting every `(cidr, ip)` pairing of the
    /// string types in `STRING_TYPES` (cidr type varies slowest).
    pub fn new() -> Self {
        let mut pairings = Vec::new();
        for cidr_type in STRING_TYPES.iter() {
            pairings.extend(
                STRING_TYPES
                    .iter()
                    .map(|ip_type| TypeSignature::Exact(vec![cidr_type.clone(), ip_type.clone()])),
            );
        }
        let signature = Signature::new(TypeSignature::OneOf(pairings), Volatility::Immutable);
        Self { signature }
    }
}

impl Default for CidrContainsUdf {
    fn default() -> Self {
        Self::new()
    }
}
/// Row-wise `cidr_contains` over a CIDR string array and an IP string array.
///
/// Output row `i` is null when either input row `i` is null, or when
/// `cidr_contains_impl` returns `None` for that pair. Rows are paired with `zip`, so
/// the output has the length of the shorter input.
fn cidr_contains_both_arrays<C, I>(cidrs: &C, ips: &I) -> ArrayRef
where
    C: Array + 'static,
    I: Array + 'static,
    for<'a> &'a C: IntoIterator<Item = Option<&'a str>>,
    for<'a> &'a I: IntoIterator<Item = Option<&'a str>>,
{
    let rows = cidrs.into_iter().zip(ips);
    let flags: BooleanArray = rows
        .map(|(cidr, ip)| cidr_contains_impl(cidr?, ip?))
        .collect();
    Arc::new(flags)
}
/// Tests every IP in `ips` against one pre-parsed CIDR block.
///
/// When `parsed_cidr` is `None` (null or unparseable CIDR scalar), every output row is
/// null. Otherwise a row is null when its IP is null or does not parse.
fn cidr_scalar_ip_array<T>(parsed_cidr: &Option<ParsedCidr>, ips: &T) -> ArrayRef
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let flags: BooleanArray = ips
        .into_iter()
        .map(|ip| {
            let block = parsed_cidr.as_ref()?;
            block.contains(ip?)
        })
        .collect();
    Arc::new(flags)
}
/// Tests one IP string against every CIDR in `cidrs`.
///
/// When `ip` is `None`, every output row is null. Otherwise a row is null when its
/// CIDR is null, or when `cidr_contains_impl` returns `None` for it.
fn cidr_array_ip_scalar<T>(cidrs: &T, ip: Option<&str>) -> ArrayRef
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let flags: BooleanArray = cidrs
        .into_iter()
        .map(|cidr| cidr_contains_impl(cidr?, ip?))
        .collect();
    Arc::new(flags)
}
// Downcasts `$array` to its concrete string-array type and passes it to `$func`.
//
// `$array` is expanded more than once: in the `data_type()` match and again in the
// downcast of each arm. `$func` is also expanded separately in each arm, so a closure
// argument such as `|ips| ...` is instantiated once per concrete array type. That is
// what lets a single closure accept `Utf8`, `LargeUtf8` and `Utf8View` arrays.
//
// For any other data type the macro executes `return exec_err!(...)`, which returns
// from the *enclosing function*, not just from this match.
macro_rules! dispatch_string_array {
    ($array:expr, $func:expr) => {
        match $array.data_type() {
            DataType::Utf8 => $func($array.as_string::<i32>()),
            DataType::LargeUtf8 => $func($array.as_string::<i64>()),
            DataType::Utf8View => $func($array.as_string_view()),
            other => {
                // Unsupported input type: bail out of the caller with an execution error.
                return exec_err!("cidr_contains expects string array, got {}", other);
            }
        }
    };
}
impl ScalarUDFImpl for CidrContainsUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"hamelin_cidr_contains"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = args.args;
if args.len() != 2 {
return exec_err!(
"cidr_contains expects exactly 2 arguments, got {}",
args.len()
);
}
match (&args[0], &args[1]) {
(ColumnarValue::Scalar(cidr_scalar), ColumnarValue::Scalar(ip_scalar)) => {
let cidr_str = scalar_to_str(cidr_scalar)?;
let ip_str = scalar_to_str(ip_scalar)?;
let result = match (cidr_str, ip_str) {
(Some(cidr), Some(ip)) => cidr_contains_impl(cidr, ip),
_ => None,
};
Ok(ColumnarValue::Scalar(ScalarValue::Boolean(result)))
}
(ColumnarValue::Array(cidr_array), ColumnarValue::Array(ip_array)) => {
let result = match cidr_array.data_type() {
DataType::Utf8 => {
let cidrs = cidr_array.as_string::<i32>();
dispatch_string_array!(ip_array, |ips| cidr_contains_both_arrays(
cidrs, ips
))
}
DataType::LargeUtf8 => {
let cidrs = cidr_array.as_string::<i64>();
dispatch_string_array!(ip_array, |ips| cidr_contains_both_arrays(
cidrs, ips
))
}
DataType::Utf8View => {
let cidrs = cidr_array.as_string_view();
dispatch_string_array!(ip_array, |ips| cidr_contains_both_arrays(
cidrs, ips
))
}
other => {
return exec_err!(
"cidr_contains expects string array for cidr, got {}",
other
)
}
};
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Scalar(cidr_scalar), ColumnarValue::Array(ip_array)) => {
let parsed_cidr = scalar_to_str(cidr_scalar)?.and_then(ParsedCidr::parse);
let result =
dispatch_string_array!(ip_array, |ips| cidr_scalar_ip_array(&parsed_cidr, ips));
Ok(ColumnarValue::Array(result))
}
(ColumnarValue::Array(cidr_array), ColumnarValue::Scalar(ip_scalar)) => {
let ip_str = scalar_to_str(ip_scalar)?;
let result =
dispatch_string_array!(cidr_array, |cidrs| cidr_array_ip_scalar(cidrs, ip_str));
Ok(ColumnarValue::Array(result))
}
}
}
}
/// A CIDR block parsed once so that many addresses can be tested against it cheaply.
///
/// `masked_network` is the network address with every host bit cleared by `mask`.
/// A CIDR written with host bits set (e.g. `10.1.2.3/8`) therefore behaves like its
/// canonical network (`10.0.0.0/8`).
enum ParsedCidr {
    /// IPv4 block; `mask` has the top `prefix_len` of 32 bits set.
    V4 { masked_network: u32, mask: u32 },
    /// IPv6 block; `mask` has the top `prefix_len` of 128 bits set.
    V6 { masked_network: u128, mask: u128 },
}

impl ParsedCidr {
    /// Parses `"<address>/<prefix_len>"` into a [`ParsedCidr`].
    ///
    /// The address is tried as IPv4 first, then as IPv6. Returns `None` in any of
    /// these cases:
    /// - there is no `/`;
    /// - the prefix is not a plain run of decimal digits;
    /// - the prefix exceeds the address width (32 for IPv4, 128 for IPv6);
    /// - the address does not parse.
    fn parse(cidr: &str) -> Option<Self> {
        let (network_str, prefix_str) = cidr.split_once('/')?;
        // `u8::from_str` tolerates a leading `+` ("10.0.0.0/+8"), which is not valid
        // CIDR notation, so require ASCII digits only before parsing.
        if prefix_str.is_empty() || !prefix_str.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        let prefix_len: u8 = prefix_str.parse().ok()?;

        if let Ok(network) = network_str.parse::<Ipv4Addr>() {
            if prefix_len > 32 {
                return None;
            }
            // Prefix 0 would shift by the full width; `checked_shl` turns that into an
            // all-zero mask, which matches every IPv4 address.
            let mask = u32::MAX
                .checked_shl(32 - u32::from(prefix_len))
                .unwrap_or(0);
            return Some(ParsedCidr::V4 {
                masked_network: u32::from(network) & mask,
                mask,
            });
        }

        if let Ok(network) = network_str.parse::<Ipv6Addr>() {
            if prefix_len > 128 {
                return None;
            }
            let mask = u128::MAX
                .checked_shl(128 - u32::from(prefix_len))
                .unwrap_or(0);
            return Some(ParsedCidr::V6 {
                masked_network: u128::from(network) & mask,
                mask,
            });
        }

        None
    }

    /// Tests whether the textual address `ip` lies inside this block.
    ///
    /// Returns `None` if `ip` does not parse as an IP address. Returns `Some(false)` if
    /// `ip` is a valid address of the other family than this block.
    fn contains(&self, ip: &str) -> Option<bool> {
        let addr: IpAddr = ip.parse().ok()?;
        let inside = match (self, addr) {
            (ParsedCidr::V4 { masked_network, mask }, IpAddr::V4(v4)) => {
                (u32::from(v4) & mask) == *masked_network
            }
            (ParsedCidr::V6 { masked_network, mask }, IpAddr::V6(v6)) => {
                (u128::from(v6) & mask) == *masked_network
            }
            // Mixed families are never contained. NOTE(review): an IPv4-mapped IPv6
            // address (`::ffff:10.0.0.1`) therefore only matches IPv6 blocks, even though
            // `check_is_ipv4` classifies it as IPv4. Confirm this is intended.
            (ParsedCidr::V4 { .. }, IpAddr::V6(_)) | (ParsedCidr::V6 { .. }, IpAddr::V4(_)) => {
                false
            }
        };
        Some(inside)
    }
}

/// Parses `cidr` and tests `ip` against it in one step.
///
/// Returns `None` when `cidr` is not a valid CIDR block or `ip` is not a valid address.
fn cidr_contains_impl(cidr: &str, ip: &str) -> Option<bool> {
    ParsedCidr::parse(cidr)?.contains(ip)
}
/// Wraps a freshly built [`CidrContainsUdf`] in a `ScalarUDF`.
pub fn cidr_contains_udf() -> ScalarUDF {
    let implementation = CidrContainsUdf::new();
    ScalarUDF::new_from_impl(implementation)
}