exon/physical_plan/
region_physical_expr.rs

1// Copyright 2023 WHERE TRUE Technologies.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{any::Any, fmt::Display, sync::Arc};
16
17use crate::error::Result;
18use arrow::datatypes::SchemaRef;
19use datafusion::{
20    error::DataFusionError,
21    physical_plan::{expressions::BinaryExpr, PhysicalExpr},
22};
23use noodles::core::Region;
24
25use crate::error::{invalid_chrom::InvalidRegionNameError, invalid_region::InvalidRegionError};
26
27use super::{
28    pos_interval_physical_expr::PosIntervalPhysicalExpr,
29    region_name_physical_expr::RegionNamePhysicalExpr,
30};
31
32/// A physical expression that represents a region, e.g. chr1:100-200.
33#[derive(Debug)]
34pub struct RegionPhysicalExpr {
35    region_name_expr: Arc<dyn PhysicalExpr>,
36    interval_expr: Option<Arc<dyn PhysicalExpr>>,
37}
38
39impl RegionPhysicalExpr {
40    /// Create a new `RegionPhysicalExpr` from a region and two inner expressions.
41    pub fn new(
42        region_name_expr: Arc<dyn PhysicalExpr>,
43        interval_expr: Option<Arc<dyn PhysicalExpr>>,
44    ) -> Self {
45        Self {
46            region_name_expr,
47            interval_expr,
48        }
49    }
50
51    /// Get the region.
52    pub fn region(&self) -> Result<Region> {
53        let internal_region_name_expr = self.region_name_expr().ok_or(InvalidRegionNameError)?;
54        let field_value = internal_region_name_expr.field_value();
55
56        match self.interval_expr() {
57            Some(interval_expr) => {
58                let interval = interval_expr.interval()?;
59                let region = Region::new(field_value, interval);
60                Ok(region)
61            }
62            None => {
63                let region = field_value.parse().map_err(|_| InvalidRegionNameError)?;
64                Ok(region)
65            }
66        }
67    }
68
69    /// Get the interval expression.
70    pub fn interval_expr(&self) -> Option<&PosIntervalPhysicalExpr> {
71        self.interval_expr
72            .as_ref()
73            .and_then(|expr| expr.as_any().downcast_ref::<PosIntervalPhysicalExpr>())
74    }
75
76    /// Get the chromosome expression.
77    pub fn region_name_expr(&self) -> Option<&RegionNamePhysicalExpr> {
78        self.region_name_expr
79            .as_any()
80            .downcast_ref::<RegionNamePhysicalExpr>()
81    }
82
83    /// Create a new `RegionPhysicalExpr` from a region and a schema.
84    pub fn from_region(region: Region, schema: SchemaRef) -> Result<Self> {
85        let start = region.interval().start().map(usize::from).unwrap_or(1);
86
87        let end = region.interval().end().map(usize::from);
88
89        let interval_expr = PosIntervalPhysicalExpr::from_interval(start, end, &schema)?;
90
91        let region_name = std::str::from_utf8(region.name())?;
92        let chrom_expr = RegionNamePhysicalExpr::from_chrom(region_name, &schema)?;
93
94        let region_expr = Self::new(Arc::new(chrom_expr), Some(Arc::new(interval_expr)));
95
96        Ok(region_expr)
97    }
98}
99
100impl Display for RegionPhysicalExpr {
101    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102        write!(
103            f,
104            "RegionPhysicalExpr {{ region_name: {}, interval: {:?} }}",
105            self.region_name_expr, self.interval_expr,
106        )
107    }
108}
109
110impl From<RegionNamePhysicalExpr> for RegionPhysicalExpr {
111    fn from(value: RegionNamePhysicalExpr) -> Self {
112        let chrom_expr = Arc::new(value);
113
114        Self::new(chrom_expr, None)
115    }
116}
117
118impl TryFrom<BinaryExpr> for RegionPhysicalExpr {
119    type Error = DataFusionError;
120
121    fn try_from(expr: BinaryExpr) -> Result<Self, Self::Error> {
122        if let Ok(chrom) = RegionNamePhysicalExpr::try_from(expr.clone()) {
123            let new_region = Self::from(chrom);
124            return Ok(new_region);
125        }
126
127        let chrom_op = expr
128            .left()
129            .as_any()
130            .downcast_ref::<BinaryExpr>()
131            .map(|e| RegionNamePhysicalExpr::try_from(e.clone()))
132            .transpose()?;
133
134        let pos_op = expr
135            .right()
136            .as_any()
137            .downcast_ref::<BinaryExpr>()
138            .map(|binary_expr| PosIntervalPhysicalExpr::try_from(binary_expr.clone()))
139            .transpose()?;
140
141        match (chrom_op, pos_op) {
142            (Some(chrom), Some(pos)) => Ok(Self::new(Arc::new(chrom), Some(Arc::new(pos)))),
143            (_, _) => Err(DataFusionError::External(InvalidRegionError.into())),
144        }
145    }
146}
147
148impl TryFrom<Arc<dyn PhysicalExpr>> for RegionPhysicalExpr {
149    type Error = DataFusionError;
150
151    fn try_from(expr: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
152        if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
153            Self::try_from(binary_expr.clone())
154        } else {
155            Err(DataFusionError::External(InvalidRegionError.into()))
156        }
157    }
158}
159
160impl PartialEq<dyn Any> for RegionPhysicalExpr {
161    fn eq(&self, other: &dyn Any) -> bool {
162        if let Some(other) = other.downcast_ref::<RegionPhysicalExpr>() {
163            let left_interval = match self.interval_expr() {
164                Some(interval_expr) => interval_expr,
165                None => return false,
166            };
167
168            let right_interval = match other.interval_expr() {
169                Some(interval_expr) => interval_expr,
170                None => return false,
171            };
172
173            if left_interval != right_interval {
174                return false;
175            }
176
177            let left_chrom = match self.region_name_expr() {
178                Some(chrom_expr) => chrom_expr,
179                None => return false,
180            };
181
182            let right_chrom = match other.region_name_expr() {
183                Some(chrom_expr) => chrom_expr,
184                None => return false,
185            };
186
187            left_chrom == right_chrom
188        } else {
189            false
190        }
191    }
192}
193
194impl PhysicalExpr for RegionPhysicalExpr {
195    fn as_any(&self) -> &dyn std::any::Any {
196        self
197    }
198
199    fn data_type(
200        &self,
201        _input_schema: &arrow::datatypes::Schema,
202    ) -> datafusion::error::Result<arrow::datatypes::DataType> {
203        Ok(arrow::datatypes::DataType::Boolean)
204    }
205
206    fn nullable(
207        &self,
208        _input_schema: &arrow::datatypes::Schema,
209    ) -> datafusion::error::Result<bool> {
210        Ok(true)
211    }
212
213    fn evaluate(
214        &self,
215        batch: &arrow::record_batch::RecordBatch,
216    ) -> datafusion::error::Result<datafusion::physical_plan::ColumnarValue> {
217        let eval = match self.interval_expr {
218            Some(ref interval_expr) => {
219                let binary_expr = BinaryExpr::new(
220                    Arc::clone(&self.region_name_expr),
221                    datafusion::logical_expr::Operator::And,
222                    Arc::clone(interval_expr),
223                );
224
225                binary_expr.evaluate(batch)
226            }
227            None => self.region_name_expr.evaluate(batch),
228        };
229
230        tracing::trace!("Got eval: {:?}", eval);
231
232        eval
233    }
234
235    fn children(&self) -> Vec<&std::sync::Arc<dyn PhysicalExpr>> {
236        vec![]
237    }
238
239    fn with_new_children(
240        self: std::sync::Arc<Self>,
241        _children: Vec<std::sync::Arc<dyn PhysicalExpr>>,
242    ) -> datafusion::error::Result<std::sync::Arc<dyn PhysicalExpr>> {
243        Ok(Arc::new(RegionPhysicalExpr::new(
244            Arc::clone(&self.region_name_expr),
245            self.interval_expr.clone(),
246        )))
247    }
248
249    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
250        let mut s = state;
251
252        self.region_name_expr.dyn_hash(&mut s);
253
254        if let Some(ref interval_expr) = self.interval_expr {
255            interval_expr.dyn_hash(&mut s);
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::{array::BooleanArray, record_batch::RecordBatch};
    use datafusion::{
        logical_expr::Operator,
        physical_plan::{
            expressions::{col, lit, BinaryExpr},
            PhysicalExpr,
        },
        scalar::ScalarValue,
    };
    use noodles::core::{Position, Region};

    /// `chrom = '1' AND pos = 4` converts to the single-base region `1:4-4`.
    #[test]
    fn test_from_binary_exprs() {
        let schema = Arc::new(arrow::datatypes::Schema::new(vec![
            arrow::datatypes::Field::new("chrom", arrow::datatypes::DataType::Utf8, false),
            arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
        ]));

        let expr = BinaryExpr::new(
            Arc::new(BinaryExpr::new(
                col("chrom", &schema).unwrap(),
                Operator::Eq,
                lit(ScalarValue::from("1")),
            )),
            Operator::And,
            Arc::new(BinaryExpr::new(
                col("pos", &schema).unwrap(),
                Operator::Eq,
                lit(ScalarValue::from(4)),
            )),
        );

        let region = super::RegionPhysicalExpr::try_from(expr).unwrap();

        assert_eq!(
            region.region().unwrap(),
            Region::new(
                "1",
                noodles::core::region::Interval::from(
                    Position::new(4).unwrap()..=Position::new(4).unwrap()
                )
            )
        );
    }

    /// Evaluating the region `chr1:1-1` selects only the rows matching both the
    /// name and the position.
    // Nothing here awaits, so a plain synchronous test is sufficient.
    #[test]
    fn test_evaluate() {
        let batch = RecordBatch::try_new(
            Arc::new(arrow::datatypes::Schema::new(vec![
                arrow::datatypes::Field::new("chrom", arrow::datatypes::DataType::Utf8, false),
                arrow::datatypes::Field::new("pos", arrow::datatypes::DataType::Int64, false),
            ])),
            vec![
                Arc::new(arrow::array::StringArray::from(vec![
                    "chr1", "chr1", "chr2",
                ])),
                Arc::new(arrow::array::Int64Array::from(vec![1, 2, 3])),
            ],
        )
        .unwrap();

        let region = "chr1:1-1".parse::<Region>().unwrap();

        let expr = super::RegionPhysicalExpr::from_region(region, batch.schema()).unwrap();

        let result = match expr.evaluate(&batch).unwrap() {
            datafusion::physical_plan::ColumnarValue::Array(array) => array,
            _ => panic!("Expected array"),
        };

        // Convert the result to a boolean array
        let result = result.as_any().downcast_ref::<BooleanArray>().unwrap();

        // Compare whole vectors: a pairwise `zip` would stop at the shorter side
        // and silently accept a result with missing rows.
        let actual: Vec<Option<bool>> = result.iter().collect();
        assert_eq!(actual, vec![Some(true), Some(false), Some(false)]);
    }
}