1use std::{any::Any, fmt::Display, sync::Arc};
16
17use crate::error::Result;
18use arrow::datatypes::SchemaRef;
19use datafusion::{
20 error::DataFusionError,
21 physical_plan::{expressions::BinaryExpr, PhysicalExpr},
22};
23use noodles::core::Region;
24
25use crate::error::{invalid_chrom::InvalidRegionNameError, invalid_region::InvalidRegionError};
26
27use super::{
28 pos_interval_physical_expr::PosIntervalPhysicalExpr,
29 region_name_physical_expr::RegionNamePhysicalExpr,
30};
31
32#[derive(Debug)]
34pub struct RegionPhysicalExpr {
35 region_name_expr: Arc<dyn PhysicalExpr>,
36 interval_expr: Option<Arc<dyn PhysicalExpr>>,
37}
38
39impl RegionPhysicalExpr {
40 pub fn new(
42 region_name_expr: Arc<dyn PhysicalExpr>,
43 interval_expr: Option<Arc<dyn PhysicalExpr>>,
44 ) -> Self {
45 Self {
46 region_name_expr,
47 interval_expr,
48 }
49 }
50
51 pub fn region(&self) -> Result<Region> {
53 let internal_region_name_expr = self.region_name_expr().ok_or(InvalidRegionNameError)?;
54 let field_value = internal_region_name_expr.field_value();
55
56 match self.interval_expr() {
57 Some(interval_expr) => {
58 let interval = interval_expr.interval()?;
59 let region = Region::new(field_value, interval);
60 Ok(region)
61 }
62 None => {
63 let region = field_value.parse().map_err(|_| InvalidRegionNameError)?;
64 Ok(region)
65 }
66 }
67 }
68
69 pub fn interval_expr(&self) -> Option<&PosIntervalPhysicalExpr> {
71 self.interval_expr
72 .as_ref()
73 .and_then(|expr| expr.as_any().downcast_ref::<PosIntervalPhysicalExpr>())
74 }
75
76 pub fn region_name_expr(&self) -> Option<&RegionNamePhysicalExpr> {
78 self.region_name_expr
79 .as_any()
80 .downcast_ref::<RegionNamePhysicalExpr>()
81 }
82
83 pub fn from_region(region: Region, schema: SchemaRef) -> Result<Self> {
85 let start = region.interval().start().map(usize::from).unwrap_or(1);
86
87 let end = region.interval().end().map(usize::from);
88
89 let interval_expr = PosIntervalPhysicalExpr::from_interval(start, end, &schema)?;
90
91 let region_name = std::str::from_utf8(region.name())?;
92 let chrom_expr = RegionNamePhysicalExpr::from_chrom(region_name, &schema)?;
93
94 let region_expr = Self::new(Arc::new(chrom_expr), Some(Arc::new(interval_expr)));
95
96 Ok(region_expr)
97 }
98}
99
100impl Display for RegionPhysicalExpr {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 write!(
103 f,
104 "RegionPhysicalExpr {{ region_name: {}, interval: {:?} }}",
105 self.region_name_expr, self.interval_expr,
106 )
107 }
108}
109
110impl From<RegionNamePhysicalExpr> for RegionPhysicalExpr {
111 fn from(value: RegionNamePhysicalExpr) -> Self {
112 let chrom_expr = Arc::new(value);
113
114 Self::new(chrom_expr, None)
115 }
116}
117
118impl TryFrom<BinaryExpr> for RegionPhysicalExpr {
119 type Error = DataFusionError;
120
121 fn try_from(expr: BinaryExpr) -> Result<Self, Self::Error> {
122 if let Ok(chrom) = RegionNamePhysicalExpr::try_from(expr.clone()) {
123 let new_region = Self::from(chrom);
124 return Ok(new_region);
125 }
126
127 let chrom_op = expr
128 .left()
129 .as_any()
130 .downcast_ref::<BinaryExpr>()
131 .map(|e| RegionNamePhysicalExpr::try_from(e.clone()))
132 .transpose()?;
133
134 let pos_op = expr
135 .right()
136 .as_any()
137 .downcast_ref::<BinaryExpr>()
138 .map(|binary_expr| PosIntervalPhysicalExpr::try_from(binary_expr.clone()))
139 .transpose()?;
140
141 match (chrom_op, pos_op) {
142 (Some(chrom), Some(pos)) => Ok(Self::new(Arc::new(chrom), Some(Arc::new(pos)))),
143 (_, _) => Err(DataFusionError::External(InvalidRegionError.into())),
144 }
145 }
146}
147
148impl TryFrom<Arc<dyn PhysicalExpr>> for RegionPhysicalExpr {
149 type Error = DataFusionError;
150
151 fn try_from(expr: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
152 if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
153 Self::try_from(binary_expr.clone())
154 } else {
155 Err(DataFusionError::External(InvalidRegionError.into()))
156 }
157 }
158}
159
160impl PartialEq<dyn Any> for RegionPhysicalExpr {
161 fn eq(&self, other: &dyn Any) -> bool {
162 if let Some(other) = other.downcast_ref::<RegionPhysicalExpr>() {
163 let left_interval = match self.interval_expr() {
164 Some(interval_expr) => interval_expr,
165 None => return false,
166 };
167
168 let right_interval = match other.interval_expr() {
169 Some(interval_expr) => interval_expr,
170 None => return false,
171 };
172
173 if left_interval != right_interval {
174 return false;
175 }
176
177 let left_chrom = match self.region_name_expr() {
178 Some(chrom_expr) => chrom_expr,
179 None => return false,
180 };
181
182 let right_chrom = match other.region_name_expr() {
183 Some(chrom_expr) => chrom_expr,
184 None => return false,
185 };
186
187 left_chrom == right_chrom
188 } else {
189 false
190 }
191 }
192}
193
194impl PhysicalExpr for RegionPhysicalExpr {
195 fn as_any(&self) -> &dyn std::any::Any {
196 self
197 }
198
199 fn data_type(
200 &self,
201 _input_schema: &arrow::datatypes::Schema,
202 ) -> datafusion::error::Result<arrow::datatypes::DataType> {
203 Ok(arrow::datatypes::DataType::Boolean)
204 }
205
206 fn nullable(
207 &self,
208 _input_schema: &arrow::datatypes::Schema,
209 ) -> datafusion::error::Result<bool> {
210 Ok(true)
211 }
212
213 fn evaluate(
214 &self,
215 batch: &arrow::record_batch::RecordBatch,
216 ) -> datafusion::error::Result<datafusion::physical_plan::ColumnarValue> {
217 let eval = match self.interval_expr {
218 Some(ref interval_expr) => {
219 let binary_expr = BinaryExpr::new(
220 Arc::clone(&self.region_name_expr),
221 datafusion::logical_expr::Operator::And,
222 Arc::clone(interval_expr),
223 );
224
225 binary_expr.evaluate(batch)
226 }
227 None => self.region_name_expr.evaluate(batch),
228 };
229
230 tracing::trace!("Got eval: {:?}", eval);
231
232 eval
233 }
234
235 fn children(&self) -> Vec<&std::sync::Arc<dyn PhysicalExpr>> {
236 vec![]
237 }
238
239 fn with_new_children(
240 self: std::sync::Arc<Self>,
241 _children: Vec<std::sync::Arc<dyn PhysicalExpr>>,
242 ) -> datafusion::error::Result<std::sync::Arc<dyn PhysicalExpr>> {
243 Ok(Arc::new(RegionPhysicalExpr::new(
244 Arc::clone(&self.region_name_expr),
245 self.interval_expr.clone(),
246 )))
247 }
248
249 fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
250 let mut s = state;
251
252 self.region_name_expr.dyn_hash(&mut s);
253
254 if let Some(ref interval_expr) = self.interval_expr {
255 interval_expr.dyn_hash(&mut s);
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::BooleanArray, record_batch::RecordBatch};
    use datafusion::{
        logical_expr::Operator,
        physical_plan::{
            expressions::{col, lit, BinaryExpr},
            PhysicalExpr,
        },
        scalar::ScalarValue,
    };
    use noodles::core::{Position, Region};

    /// `chrom = '1' AND pos = 4` converts into the region named `1` with the
    /// interval `4..=4`.
    #[test]
    fn test_from_binary_exprs() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("chrom", arrow::datatypes::DataType::Utf8, false),
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
        ]));

        let expr = BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                col("chrom", &schema).unwrap(),
                Operator::Eq,
                lit(ScalarValue::from("1")),
            )),
            Operator::And,
            Arc::new(BinaryExpr::new(
                col("pos", &schema).unwrap(),
                Operator::Eq,
                lit(ScalarValue::from(4)),
            )),
        );

        let region = super::RegionPhysicalExpr::try_from(expr).unwrap();

        assert_eq!(
            region.region().unwrap(),
            Region::new(
                "1",
                noodles::core::region::Interval::from(
                    Position::new(4).unwrap()..=Position::new(4).unwrap()
                )
            )
        );
    }

    /// Evaluating the region `chr1:1-1` over three rows selects only the first.
    #[tokio::test]
    async fn test_evaluate() {
        let batch = RecordBatch::try_new(
            Arc::new(arrow::datatypes::Schema::new(vec![
                arrow::datatypes::Field::new("chrom", arrow::datatypes::DataType::Utf8, false),
                arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
            ])),
            vec![
                Arc::new(arrow::array::StringArray::from(vec![
                    "chr1", "chr1", "chr2",
                ])),
                Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
            ],
        )
        .unwrap();

        let region = "chr1:1-1".parse::<Region>().unwrap();

        let expr = super::RegionPhysicalExpr::from_region(region, batch.schema()).unwrap();

        let result = match expr.evaluate(&batch).unwrap() {
            datafusion::physical_plan::ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        let result = result
            .as_any()
            .downcast_ref::<arrow::array::BooleanArray>()
            .unwrap();

        let expected = BooleanArray::from(vec![Some(true), Some(false), Some(false)]);

        // Compare the arrays as a whole: an element-wise `zip` stops at the
        // shorter side, so a result with too few rows would pass unnoticed.
        assert_eq!(result, &expected);
    }
}