Skip to main content

hirn_exec/operators/
causal_query_read.rs

1//! `CausalQueryReadExec` — query-scoped terminal reads for causal HirnQL statements.
2
3use std::any::Any;
4use std::fmt;
5use std::sync::Arc;
6
7use arrow_array::{BinaryArray, RecordBatch};
8use arrow_schema::SchemaRef;
9use datafusion_common::{DataFusionError, Result};
10use datafusion_execution::{SendableRecordBatchStream, TaskContext};
11use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
12use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
13use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
14
15use crate::extensions::HirnSessionExt;
16
/// Selects which causal read a [`CausalQueryReadExec`] performs at execution time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CausalReadKind {
    /// Dispatches to the runtime's `explain_causes_json`; the only kind that uses `depth`.
    ExplainCauses,
    /// Dispatches to the runtime's `what_if_json`; requires a secondary operand (the outcome).
    WhatIf,
    /// Dispatches to the runtime's `counterfactual_json`; requires a secondary operand
    /// (the consequent).
    Counterfactual,
}
23
24#[derive(Debug, Clone)]
25pub struct CausalQueryReadExec {
26    schema: SchemaRef,
27    properties: PlanProperties,
28    kind: CausalReadKind,
29    primary: String,
30    secondary: Option<String>,
31    depth: u32,
32    namespace: Option<String>,
33}
34
35impl CausalQueryReadExec {
36    pub fn new(
37        schema: SchemaRef,
38        kind: CausalReadKind,
39        primary: String,
40        secondary: Option<String>,
41        depth: u32,
42        namespace: Option<String>,
43    ) -> Self {
44        let properties = PlanProperties::new(
45            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
46            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
47            EmissionType::Final,
48            Boundedness::Bounded,
49        );
50
51        Self {
52            schema,
53            properties,
54            kind,
55            primary,
56            secondary,
57            depth,
58            namespace,
59        }
60    }
61}
62
63impl DisplayAs for CausalQueryReadExec {
64    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        write!(
66            f,
67            "CausalQueryReadExec: kind={:?}, namespace={}",
68            self.kind,
69            self.namespace.as_deref().unwrap_or("*")
70        )
71    }
72}
73
74impl ExecutionPlan for CausalQueryReadExec {
75    fn name(&self) -> &str {
76        match self.kind {
77            CausalReadKind::ExplainCauses => "CausalExplainCausesExec",
78            CausalReadKind::WhatIf => "CausalWhatIfExec",
79            CausalReadKind::Counterfactual => "CausalCounterfactualExec",
80        }
81    }
82
83    fn as_any(&self) -> &dyn Any {
84        self
85    }
86
87    fn schema(&self) -> SchemaRef {
88        self.schema.clone()
89    }
90
91    fn properties(&self) -> &PlanProperties {
92        &self.properties
93    }
94
95    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
96        vec![]
97    }
98
99    fn with_new_children(
100        self: Arc<Self>,
101        children: Vec<Arc<dyn ExecutionPlan>>,
102    ) -> Result<Arc<dyn ExecutionPlan>> {
103        if !children.is_empty() {
104            return Err(DataFusionError::Plan(
105                "CausalQueryReadExec is a leaf node and does not accept children".to_string(),
106            ));
107        }
108        Ok(self)
109    }
110
111    fn execute(
112        &self,
113        _partition: usize,
114        context: Arc<TaskContext>,
115    ) -> Result<SendableRecordBatchStream> {
116        let schema = self.schema.clone();
117        let stream_schema = schema.clone();
118        let kind = self.kind;
119        let primary = self.primary.clone();
120        let secondary = self.secondary.clone();
121        let depth = self.depth;
122        let namespace = self.namespace.clone();
123        let ext = context
124            .session_config()
125            .options()
126            .extensions
127            .get::<HirnSessionExt>()
128            .cloned();
129
130        let fut = async move {
131            let Some(ext) = ext else {
132                return Err(DataFusionError::Execution(
133                    "CausalQueryReadExec requires HirnSessionExt".to_string(),
134                ));
135            };
136            let Some(runtime) = ext.query_read_runtime() else {
137                return Err(DataFusionError::Execution(
138                    "CausalQueryReadExec requires a query read runtime in HirnSessionExt"
139                        .to_string(),
140                ));
141            };
142
143            let payload = match kind {
144                CausalReadKind::ExplainCauses => {
145                    runtime
146                        .explain_causes_json(
147                            &primary,
148                            depth,
149                            namespace.as_deref(),
150                            ext.allowed_namespaces(),
151                        )
152                        .await
153                }
154                CausalReadKind::WhatIf => {
155                    let Some(secondary) = secondary.as_deref() else {
156                        return Err(DataFusionError::Execution(
157                            "CausalQueryReadExec WHAT_IF requires an outcome".to_string(),
158                        ));
159                    };
160                    runtime
161                        .what_if_json(
162                            &primary,
163                            secondary,
164                            namespace.as_deref(),
165                            ext.allowed_namespaces(),
166                        )
167                        .await
168                }
169                CausalReadKind::Counterfactual => {
170                    let Some(secondary) = secondary.as_deref() else {
171                        return Err(DataFusionError::Execution(
172                            "CausalQueryReadExec COUNTERFACTUAL requires a consequent".to_string(),
173                        ));
174                    };
175                    runtime
176                        .counterfactual_json(
177                            &primary,
178                            secondary,
179                            namespace.as_deref(),
180                            ext.allowed_namespaces(),
181                        )
182                        .await
183                }
184            }
185            .map_err(|error| DataFusionError::Execution(error.to_string()))?;
186
187            Ok::<_, DataFusionError>(RecordBatch::try_new(
188                stream_schema,
189                vec![Arc::new(BinaryArray::from(vec![payload.as_slice()]))],
190            )?)
191        };
192
193        let stream = futures::stream::once(fut);
194        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
195    }
196}