use std::{any::Any, fmt::Display, str::FromStr, sync::Arc};
use datafusion::{
error::{DataFusionError, Result},
physical_plan::{expressions::BinaryExpr, PhysicalExpr},
};
use noodles::core::{region::Interval, Region};
use crate::error::{invalid_chrom::InvalidRegionNameError, invalid_region::InvalidRegionError};
use super::{
region_name_physical_expr::RegionNamePhysicalExpr,
start_end_interval_physical_expr::StartEndIntervalPhysicalExpr,
};
/// Physical expression pairing a region-name predicate with an optional start/end
/// interval predicate.
///
/// When an interval is present, `evaluate` combines the two sub-expressions with `AND`.
/// The pair can also be converted into a noodles [`Region`] via `region()`.
#[derive(Debug)]
pub struct StartEndRegionPhysicalExpr {
    // Expected to be a `RegionNamePhysicalExpr`; `region_name_expr()` downcasts to that type.
    region_name_expr: Arc<dyn PhysicalExpr>,
    // Expected to be a `StartEndIntervalPhysicalExpr` (see `interval_expr()`);
    // `None` means only the region name is used.
    interval_expr: Option<Arc<dyn PhysicalExpr>>,
}
impl StartEndRegionPhysicalExpr {
    /// Creates a region expression from a region-name expression and an optional
    /// start/end interval expression.
    pub fn new(
        region_name_expr: Arc<dyn PhysicalExpr>,
        interval_expr: Option<Arc<dyn PhysicalExpr>>,
    ) -> Self {
        Self { region_name_expr, interval_expr }
    }

    /// Converts this expression into a noodles [`Region`].
    ///
    /// With an interval, the region is built from the region name's field value plus an
    /// [`Interval`] parsed from the string `"{start}-{end}"`. Without one, the field value
    /// itself is parsed as a `Region`.
    ///
    /// # Errors
    ///
    /// * `InvalidRegionNameError` if the region-name expression is not a
    ///   `RegionNamePhysicalExpr`, or (no interval) if its field value fails to parse.
    /// * `InvalidRegionError` if the interval has no end, or if `"{start}-{end}"` fails to
    ///   parse as an `Interval`.
    pub fn region(&self) -> Result<Region> {
        let field_value = self
            .region_name_expr()
            .ok_or(InvalidRegionNameError)?
            .field_value();

        if let Some(interval_expr) = self.interval_expr() {
            let start = interval_expr.start();
            let end = interval_expr.end().ok_or(InvalidRegionError)?;
            let interval =
                Interval::from_str(&format!("{start}-{end}")).map_err(|_| InvalidRegionError)?;
            return Ok(Region::new(field_value, interval));
        }

        let whole_region: Region = field_value.parse().map_err(|_| InvalidRegionNameError)?;
        Ok(whole_region)
    }

    /// Returns the interval sub-expression as a `StartEndIntervalPhysicalExpr`.
    ///
    /// Returns `None` if there is no interval, or if it is some other expression type.
    pub fn interval_expr(&self) -> Option<&StartEndIntervalPhysicalExpr> {
        let expr = self.interval_expr.as_ref()?;
        expr.as_any().downcast_ref()
    }

    /// Returns the region-name sub-expression as a `RegionNamePhysicalExpr`.
    ///
    /// Returns `None` if it is some other expression type.
    pub fn region_name_expr(&self) -> Option<&RegionNamePhysicalExpr> {
        self.region_name_expr.as_any().downcast_ref()
    }
}
impl Display for StartEndRegionPhysicalExpr {
    /// Renders the region-name expression with `Display` and the optional interval
    /// expression with `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            region_name_expr,
            interval_expr,
        } = self;
        write!(
            f,
            "StartEndRegionPhysicalExpr {{ region_name: {region_name_expr}, interval: {interval_expr:?} }}"
        )
    }
}
impl From<RegionNamePhysicalExpr> for StartEndRegionPhysicalExpr {
    /// Wraps a bare region-name expression as a region with no interval.
    fn from(value: RegionNamePhysicalExpr) -> Self {
        Self::new(Arc::new(value), None)
    }
}
impl TryFrom<BinaryExpr> for StartEndRegionPhysicalExpr {
type Error = DataFusionError;
fn try_from(expr: BinaryExpr) -> Result<Self, Self::Error> {
if let Ok(chrom) = RegionNamePhysicalExpr::try_from(expr.clone()) {
let new_region = Self::from(chrom);
return Ok(new_region);
}
let chrom_op = expr
.left()
.as_any()
.downcast_ref::<BinaryExpr>()
.map(|e| RegionNamePhysicalExpr::try_from(e.clone()))
.transpose()?;
let pos_op = expr
.right()
.as_any()
.downcast_ref::<BinaryExpr>()
.map(|binary_expr| StartEndRegionPhysicalExpr::try_from(binary_expr.clone()))
.transpose()?;
match (chrom_op, pos_op) {
(Some(chrom), Some(pos)) => Ok(Self::new(Arc::new(chrom), Some(Arc::new(pos)))),
(_, _) => Err(DataFusionError::External(InvalidRegionError.into())),
}
}
}
impl TryFrom<Arc<dyn PhysicalExpr>> for StartEndRegionPhysicalExpr {
    type Error = DataFusionError;

    /// Downcasts to `BinaryExpr` and delegates to the `TryFrom<BinaryExpr>` conversion.
    ///
    /// # Errors
    ///
    /// Returns `InvalidRegionError` (wrapped in `DataFusionError::External`) if `expr` is not
    /// a `BinaryExpr`. Otherwise it returns whatever the `BinaryExpr` conversion returns.
    fn try_from(expr: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
        match expr.as_any().downcast_ref::<BinaryExpr>() {
            Some(binary) => Self::try_from(binary.clone()),
            None => Err(DataFusionError::External(InvalidRegionError.into())),
        }
    }
}
impl PartialEq<dyn Any> for StartEndRegionPhysicalExpr {
    /// Compares against a type-erased value.
    ///
    /// The result is `false` unless `other` is a `StartEndRegionPhysicalExpr`; in that case
    /// this defers to the typed `PartialEq` impl.
    fn eq(&self, other: &dyn Any) -> bool {
        other
            .downcast_ref::<StartEndRegionPhysicalExpr>()
            .map_or(false, |typed_other| self == typed_other)
    }
}
impl PartialEq for StartEndRegionPhysicalExpr {
    /// Two region expressions are equal when both of these hold:
    /// * their region-name expressions are equal `RegionNamePhysicalExpr`s;
    /// * their interval expressions are either both absent, or both present and equal
    ///   `StartEndIntervalPhysicalExpr`s.
    ///
    /// Sub-expressions that do not downcast to those concrete types never compare equal.
    fn eq(&self, other: &Self) -> bool {
        let intervals_match = match (&self.interval_expr, &other.interval_expr) {
            // Name-only regions have no interval to compare. Without this arm,
            // `x == x` would be false for them.
            (None, None) => true,
            (Some(_), Some(_)) => match (self.interval_expr(), other.interval_expr()) {
                (Some(left), Some(right)) => left == right,
                _ => false,
            },
            _ => false,
        };
        if !intervals_match {
            return false;
        }

        match (self.region_name_expr(), other.region_name_expr()) {
            (Some(left), Some(right)) => left == right,
            _ => false,
        }
    }
}
impl PhysicalExpr for StartEndRegionPhysicalExpr {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// The expression always yields a boolean result.
    fn data_type(
        &self,
        _input_schema: &arrow::datatypes::Schema,
    ) -> datafusion::error::Result<arrow::datatypes::DataType> {
        Ok(arrow::datatypes::DataType::Boolean)
    }

    /// Always reports nullable, regardless of the input schema.
    fn nullable(
        &self,
        _input_schema: &arrow::datatypes::Schema,
    ) -> datafusion::error::Result<bool> {
        Ok(true)
    }

    /// Evaluates `region_name_expr AND interval_expr` when an interval is present.
    /// Otherwise it evaluates the region-name expression alone.
    ///
    /// The result (success or error) is logged at trace level before being returned.
    fn evaluate(
        &self,
        batch: &arrow::record_batch::RecordBatch,
    ) -> datafusion::error::Result<datafusion::physical_plan::ColumnarValue> {
        let eval = if let Some(interval_expr) = &self.interval_expr {
            BinaryExpr::new(
                Arc::clone(&self.region_name_expr),
                datafusion::logical_expr::Operator::And,
                Arc::clone(interval_expr),
            )
            .evaluate(batch)
        } else {
            self.region_name_expr.evaluate(batch)
        };
        tracing::trace!("Got eval: {:?}", eval);
        eval
    }

    /// Reports no children, so the sub-expressions are not exposed to generic tree
    /// traversal or rewrites.
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr + 'static>> {
        Vec::new()
    }

    /// Ignores `_children` (consistent with `children()` being empty) and rebuilds an
    /// equivalent expression that shares the same sub-expression `Arc`s.
    fn with_new_children(
        self: std::sync::Arc<Self>,
        _children: Vec<std::sync::Arc<dyn PhysicalExpr>>,
    ) -> datafusion::error::Result<std::sync::Arc<dyn PhysicalExpr>> {
        let shared_interval = self.interval_expr.clone();
        Ok(Arc::new(StartEndRegionPhysicalExpr::new(
            Arc::clone(&self.region_name_expr),
            shared_interval,
        )))
    }

    /// Hashes the region-name expression, then the interval expression if present.
    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
        let mut hasher = state;
        self.region_name_expr.dyn_hash(&mut hasher);
        if let Some(interval_expr) = &self.interval_expr {
            interval_expr.dyn_hash(&mut hasher);
        }
    }
}