hirn_exec/operators/
causal_query_read.rs1use std::any::Any;
4use std::fmt;
5use std::sync::Arc;
6
7use arrow_array::{BinaryArray, RecordBatch};
8use arrow_schema::SchemaRef;
9use datafusion_common::{DataFusionError, Result};
10use datafusion_execution::{SendableRecordBatchStream, TaskContext};
11use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
12use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
13use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
14
15use crate::extensions::HirnSessionExt;
16
/// Selects which causal read operation a [`CausalQueryReadExec`] performs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CausalReadKind {
    /// Routed to the runtime's `explain_causes_json` call (uses the depth setting).
    ExplainCauses,
    /// Routed to the runtime's `what_if_json` call; a secondary argument is required.
    WhatIf,
    /// Routed to the runtime's `counterfactual_json` call; a secondary argument is required.
    Counterfactual,
}
23
24#[derive(Debug, Clone)]
25pub struct CausalQueryReadExec {
26 schema: SchemaRef,
27 properties: PlanProperties,
28 kind: CausalReadKind,
29 primary: String,
30 secondary: Option<String>,
31 depth: u32,
32 namespace: Option<String>,
33}
34
35impl CausalQueryReadExec {
36 pub fn new(
37 schema: SchemaRef,
38 kind: CausalReadKind,
39 primary: String,
40 secondary: Option<String>,
41 depth: u32,
42 namespace: Option<String>,
43 ) -> Self {
44 let properties = PlanProperties::new(
45 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
46 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
47 EmissionType::Final,
48 Boundedness::Bounded,
49 );
50
51 Self {
52 schema,
53 properties,
54 kind,
55 primary,
56 secondary,
57 depth,
58 namespace,
59 }
60 }
61}
62
63impl DisplayAs for CausalQueryReadExec {
64 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 write!(
66 f,
67 "CausalQueryReadExec: kind={:?}, namespace={}",
68 self.kind,
69 self.namespace.as_deref().unwrap_or("*")
70 )
71 }
72}
73
74impl ExecutionPlan for CausalQueryReadExec {
75 fn name(&self) -> &str {
76 match self.kind {
77 CausalReadKind::ExplainCauses => "CausalExplainCausesExec",
78 CausalReadKind::WhatIf => "CausalWhatIfExec",
79 CausalReadKind::Counterfactual => "CausalCounterfactualExec",
80 }
81 }
82
83 fn as_any(&self) -> &dyn Any {
84 self
85 }
86
87 fn schema(&self) -> SchemaRef {
88 self.schema.clone()
89 }
90
91 fn properties(&self) -> &PlanProperties {
92 &self.properties
93 }
94
95 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
96 vec![]
97 }
98
99 fn with_new_children(
100 self: Arc<Self>,
101 children: Vec<Arc<dyn ExecutionPlan>>,
102 ) -> Result<Arc<dyn ExecutionPlan>> {
103 if !children.is_empty() {
104 return Err(DataFusionError::Plan(
105 "CausalQueryReadExec is a leaf node and does not accept children".to_string(),
106 ));
107 }
108 Ok(self)
109 }
110
111 fn execute(
112 &self,
113 _partition: usize,
114 context: Arc<TaskContext>,
115 ) -> Result<SendableRecordBatchStream> {
116 let schema = self.schema.clone();
117 let stream_schema = schema.clone();
118 let kind = self.kind;
119 let primary = self.primary.clone();
120 let secondary = self.secondary.clone();
121 let depth = self.depth;
122 let namespace = self.namespace.clone();
123 let ext = context
124 .session_config()
125 .options()
126 .extensions
127 .get::<HirnSessionExt>()
128 .cloned();
129
130 let fut = async move {
131 let Some(ext) = ext else {
132 return Err(DataFusionError::Execution(
133 "CausalQueryReadExec requires HirnSessionExt".to_string(),
134 ));
135 };
136 let Some(runtime) = ext.query_read_runtime() else {
137 return Err(DataFusionError::Execution(
138 "CausalQueryReadExec requires a query read runtime in HirnSessionExt"
139 .to_string(),
140 ));
141 };
142
143 let payload = match kind {
144 CausalReadKind::ExplainCauses => {
145 runtime
146 .explain_causes_json(
147 &primary,
148 depth,
149 namespace.as_deref(),
150 ext.allowed_namespaces(),
151 )
152 .await
153 }
154 CausalReadKind::WhatIf => {
155 let Some(secondary) = secondary.as_deref() else {
156 return Err(DataFusionError::Execution(
157 "CausalQueryReadExec WHAT_IF requires an outcome".to_string(),
158 ));
159 };
160 runtime
161 .what_if_json(
162 &primary,
163 secondary,
164 namespace.as_deref(),
165 ext.allowed_namespaces(),
166 )
167 .await
168 }
169 CausalReadKind::Counterfactual => {
170 let Some(secondary) = secondary.as_deref() else {
171 return Err(DataFusionError::Execution(
172 "CausalQueryReadExec COUNTERFACTUAL requires a consequent".to_string(),
173 ));
174 };
175 runtime
176 .counterfactual_json(
177 &primary,
178 secondary,
179 namespace.as_deref(),
180 ext.allowed_namespaces(),
181 )
182 .await
183 }
184 }
185 .map_err(|error| DataFusionError::Execution(error.to_string()))?;
186
187 Ok::<_, DataFusionError>(RecordBatch::try_new(
188 stream_schema,
189 vec![Arc::new(BinaryArray::from(vec![payload.as_slice()]))],
190 )?)
191 };
192
193 let stream = futures::stream::once(fut);
194 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
195 }
196}