use std::{any::Any, fmt::Display, hash::Hash, sync::Arc};
use arrow::datatypes::SchemaRef;
use datafusion::{
error::{DataFusionError, Result},
logical_expr::Operator,
physical_plan::{
expressions::{col, lit, BinaryExpr, Column, Literal},
PhysicalExpr,
},
};
use noodles::core::region::Interval;
use crate::error::invalid_interval::InvalidIntervalError;
/// Boolean physical expression that filters rows on the `pos` column by a position range.
///
/// `start`/`end` record the bounds the filter represents, and `inner` is the predicate that
/// is actually evaluated against record batches. `new` does not check that `inner` agrees
/// with the bounds; `from_interval` and the `TryFrom` conversions build both together.
#[derive(Debug)]
pub struct PosIntervalPhysicalExpr {
    /// Lower bound of the range (compared with `>=` in predicates built by `from_interval`).
    start: usize,
    /// Upper bound (compared with `<=`), or `None` when the range has no upper bound.
    end: Option<usize>,
    /// Predicate that `evaluate` delegates to.
    inner: Arc<dyn PhysicalExpr>,
}
impl PosIntervalPhysicalExpr {
pub fn new(start: usize, end: Option<usize>, inner: Arc<dyn PhysicalExpr>) -> Self {
Self { start, end, inner }
}
pub fn interval(&self) -> Result<Interval> {
match self.end {
Some(end) => {
let interval = format!("{}-{}", self.start, end)
.parse::<Interval>()
.map_err(|_| DataFusionError::External(InvalidIntervalError.into()))?;
Ok(interval)
}
None => Err(DataFusionError::External(InvalidIntervalError.into())),
}
}
pub fn start(&self) -> usize {
self.start
}
pub fn end(&self) -> Option<usize> {
self.end
}
pub fn interval_tuple(&self) -> (usize, Option<usize>) {
(self.start, self.end)
}
pub fn inner(&self) -> &Arc<dyn PhysicalExpr> {
&self.inner
}
pub fn from_interval(start: usize, end: Option<usize>, schema: &SchemaRef) -> Result<Self> {
let start_expr = BinaryExpr::new(col("pos", schema)?, Operator::GtEq, lit(start as i64));
match end {
Some(end) => {
let end_expr =
BinaryExpr::new(col("pos", schema)?, Operator::LtEq, lit(end as i64));
let interval_expr =
BinaryExpr::new(Arc::new(start_expr), Operator::And, Arc::new(end_expr));
Ok(Self::new(start, Some(end), Arc::new(interval_expr)))
}
None => {
Ok(Self::new(start, None, Arc::new(start_expr)))
}
}
}
}
impl Display for PosIntervalPhysicalExpr {
    /// Renders only the bounds; the wrapped predicate is not shown.
    // NOTE(review): the label reads `IntervalPhysicalExpr`, not the type name. It is kept
    // unchanged in case plan output or logs are matched against it elsewhere.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { start, end, .. } = self;
        f.write_fmt(format_args!(
            "IntervalPhysicalExpr {{ start: {}, end: {:?} }}",
            start, end
        ))
    }
}
impl TryFrom<BinaryExpr> for PosIntervalPhysicalExpr {
    type Error = DataFusionError;

    /// Converts a `pos <op> <literal>` comparison into a [`PosIntervalPhysicalExpr`].
    ///
    /// The column must be on the left and the literal on the right. Supported operators:
    ///
    /// * `pos = n`  → `start = n`, `end = Some(n)`
    /// * `pos >= n` → `start = n`, `end = None`
    /// * `pos <= n` → `start = 1`, `end = Some(n)`
    ///
    /// The original `expr` becomes the wrapped predicate.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::External` when the expression is not `Column <op> Literal`,
    /// when the column is not named `pos`, when the operator is not one of the above, or when
    /// the literal's string form does not parse as a `usize` (for example a negative or
    /// fractional value). Previously the last case panicked.
    fn try_from(expr: BinaryExpr) -> Result<Self, Self::Error> {
        let column = expr.left().as_any().downcast_ref::<Column>();
        let literal = expr.right().as_any().downcast_ref::<Literal>();

        let (column, literal) = match (column, literal) {
            (Some(column), Some(literal)) => (column, literal),
            _ => {
                return Err(DataFusionError::External(
                    format!("invalid expression for pos: {}", expr).into(),
                ))
            }
        };

        if column.name() != "pos" {
            return Err(DataFusionError::External("Invalid column for pos".into()));
        }

        // Choose the bounds shape before parsing the literal, so an unsupported operator is
        // reported even when the literal is also malformed.
        let bounds: fn(usize) -> (usize, Option<usize>) = match expr.op() {
            Operator::Eq => |pos| (pos, Some(pos)),
            Operator::GtEq => |pos| (pos, None),
            // An upper-bound-only filter starts at 1 (positions appear to be 1-based).
            Operator::LtEq => |pos| (1, Some(pos)),
            _ => return Err(DataFusionError::External("Invalid operator for pos".into())),
        };

        let value = literal.value();
        let pos = value.to_string().parse::<usize>().map_err(|err| {
            DataFusionError::External(
                format!("invalid position literal for pos: {} ({})", value, err).into(),
            )
        })?;

        let (start, end) = bounds(pos);
        Ok(Self::new(start, end, Arc::new(expr)))
    }
}
impl TryFrom<Arc<dyn PhysicalExpr>> for PosIntervalPhysicalExpr {
    type Error = DataFusionError;

    /// Accepts only a `BinaryExpr` (cloned and converted via `TryFrom<BinaryExpr>`); any other
    /// expression type yields `InvalidIntervalError`.
    fn try_from(expr: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
        let binary = expr
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .ok_or_else(|| DataFusionError::External(InvalidIntervalError.into()))?;
        Self::try_from(binary.clone())
    }
}
impl PartialEq<dyn Any> for PosIntervalPhysicalExpr {
    /// Equal only to another `PosIntervalPhysicalExpr` with the same `start` and `end`;
    /// the wrapped predicate is not compared.
    fn eq(&self, other: &dyn Any) -> bool {
        matches!(
            other.downcast_ref::<Self>(),
            Some(that) if self.start == that.start && self.end == that.end
        )
    }
}
impl PartialEq for PosIntervalPhysicalExpr {
    /// Compares bounds only; the wrapped predicate is ignored.
    fn eq(&self, other: &Self) -> bool {
        (self.start, self.end) == (other.start, other.end)
    }
}
impl PhysicalExpr for PosIntervalPhysicalExpr {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn data_type(
&self,
_input_schema: &arrow::datatypes::Schema,
) -> datafusion::error::Result<arrow::datatypes::DataType> {
Ok(arrow::datatypes::DataType::Boolean)
}
fn nullable(
&self,
_input_schema: &arrow::datatypes::Schema,
) -> datafusion::error::Result<bool> {
Ok(true)
}
fn evaluate(
&self,
batch: &arrow::record_batch::RecordBatch,
) -> datafusion::error::Result<datafusion::physical_plan::ColumnarValue> {
self.inner.evaluate(batch)
}
fn children(&self) -> Vec<&std::sync::Arc<dyn PhysicalExpr>> {
vec![]
}
fn with_new_children(
self: std::sync::Arc<Self>,
_children: Vec<std::sync::Arc<dyn PhysicalExpr>>,
) -> datafusion::error::Result<std::sync::Arc<dyn PhysicalExpr>> {
Ok(Arc::new(PosIntervalPhysicalExpr::new(
self.start,
self.end,
Arc::clone(&self.inner),
)))
}
fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
let mut s = state;
self.start.hash(&mut s);
self.end.hash(&mut s);
}
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::BooleanArray, record_batch::RecordBatch};
    use datafusion::{
        logical_expr::Operator,
        physical_plan::{
            expressions::{col, lit, BinaryExpr},
            PhysicalExpr,
        },
    };
    use noodles::core::Position;

    use crate::{
        physical_plan::pos_interval_physical_expr,
        tests::{eq, gteq},
    };

    use super::PosIntervalPhysicalExpr;

    #[test]
    fn test_call_interval_with_no_upper_bound() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
        ]));
        let expr = gteq(col("pos", &schema).unwrap(), lit(4));
        let interval_expr =
            pos_interval_physical_expr::PosIntervalPhysicalExpr::try_from(expr).unwrap();
        assert_eq!(interval_expr.start, 4);
        assert_eq!(interval_expr.end, None);
        assert!(interval_expr.interval().is_err());
    }

    #[test]
    fn test_from_binary_exprs() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
        ]));
        let pos_expr = eq(col("pos", &schema).unwrap(), lit(4));
        let interval = super::PosIntervalPhysicalExpr::try_from(pos_expr).unwrap();
        assert_eq!(
            interval.interval().unwrap(),
            noodles::core::region::Interval::from(
                Position::new(4).unwrap()..=Position::new(4).unwrap()
            )
        );
    }

    /// `pos <= n` converts to the range `1..=n`.
    #[test]
    fn test_lteq_starts_at_one() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
        ]));
        let expr = BinaryExpr::new(col("pos", &schema).unwrap(), Operator::LtEq, lit(10));
        let interval_expr = PosIntervalPhysicalExpr::try_from(expr).unwrap();
        assert_eq!(interval_expr.interval_tuple(), (1, Some(10)));
    }

    /// Only the `pos` column and the `=`, `>=`, `<=` operators are accepted.
    #[test]
    fn test_try_from_rejects_other_columns_and_operators() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
            arrow::datatypes::Field::new("start", arrow::datatypes::DataType::Int64, false),
        ]));
        let wrong_column = BinaryExpr::new(col("start", &schema).unwrap(), Operator::Eq, lit(4));
        assert!(PosIntervalPhysicalExpr::try_from(wrong_column).is_err());

        let wrong_operator = BinaryExpr::new(col("pos", &schema).unwrap(), Operator::Lt, lit(4));
        assert!(PosIntervalPhysicalExpr::try_from(wrong_operator).is_err());
    }

    #[tokio::test]
    async fn test_evaluate() -> Result<(), Box<dyn std::error::Error>> {
        let batch = RecordBatch::try_new(
            Arc::new(arrow::datatypes::Schema::new(vec![
                arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
            ])),
            vec![Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3]))],
        )
        .unwrap();

        let binary_expr = eq(col("pos", &batch.schema()).unwrap(), lit(1i64));
        let expr = pos_interval_physical_expr::PosIntervalPhysicalExpr::new(
            1,
            Some(1),
            Arc::new(binary_expr),
        );

        let result = match expr.evaluate(&batch)? {
            datafusion::physical_plan::ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };
        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Compare whole sequences: zipping two iterators would silently stop at the shorter
        // one and hide a length mismatch.
        let actual: Vec<Option<bool>> = result.iter().collect();
        let expected: Vec<Option<bool>> =
            BooleanArray::from(vec![Some(true), Some(false), Some(false)])
                .iter()
                .collect();
        assert_eq!(actual, expected);

        Ok(())
    }

    #[test]
    fn test_from_interval() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
        ]));

        let interval_expr = PosIntervalPhysicalExpr::from_interval(1, Some(10), &schema).unwrap();
        let inner_expr = interval_expr
            .inner
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap();
        match inner_expr.op() {
            Operator::And => {}
            _ => panic!("Expected AND operator"),
        }

        let interval_expr = PosIntervalPhysicalExpr::from_interval(1, None, &schema).unwrap();
        let inner_expr = interval_expr
            .inner
            .as_any()
            .downcast_ref::<BinaryExpr>()
            .unwrap();
        match inner_expr.op() {
            Operator::GtEq => {}
            _ => panic!("Expected GtEq operator"),
        }
    }
}