// hirn_exec/rules/policy_pushdown.rs
use std::sync::Arc;
14
15use arrow_schema::DataType;
16use datafusion_common::Result;
17use datafusion_common::config::ConfigOptions;
18use datafusion_common::tree_node::{Transformed, TreeNode};
19use datafusion_physical_optimizer::PhysicalOptimizerRule;
20use datafusion_physical_plan::ExecutionPlan;
21use datafusion_physical_plan::empty::EmptyExec;
22use datafusion_physical_plan::filter::FilterExec;
23
24use crate::extensions::HirnSessionExt;
25
/// Physical optimizer rule that restricts leaf plan nodes to the namespaces
/// allowed for the current session.
///
/// The rule carries no state of its own: the allowed namespaces are read from
/// the `HirnSessionExt` registered in the optimizer's `ConfigOptions` each
/// time `optimize` runs.
#[derive(Default, Debug)]
pub struct PolicyPushdownRule;
43
44impl PolicyPushdownRule {
45 pub fn new() -> Self {
46 Self
47 }
48
49 fn has_namespace_column(plan: &dyn ExecutionPlan) -> bool {
51 plan.schema()
52 .fields()
53 .iter()
54 .any(|f| f.name() == "namespace" && f.data_type() == &DataType::Utf8)
55 }
56
57 fn build_namespace_filter(
59 input: Arc<dyn ExecutionPlan>,
60 namespaces: &[String],
61 ) -> Result<Arc<dyn ExecutionPlan>> {
62 use datafusion_physical_expr::expressions::{self, BinaryExpr, InListExpr};
63
64 let schema = input.schema();
65 let (_idx, _) = schema.column_with_name("namespace").ok_or_else(|| {
66 datafusion_common::DataFusionError::Internal(
67 "PolicyPushdownRule: expected 'namespace' column".into(),
68 )
69 })?;
70
71 let ns_col = expressions::col("namespace", &schema)?;
72
73 let predicate: Arc<dyn datafusion_physical_expr::PhysicalExpr> = if namespaces.len() == 1 {
74 let lit = expressions::lit(datafusion_common::ScalarValue::Utf8(Some(
76 namespaces[0].clone(),
77 )));
78 Arc::new(BinaryExpr::new(ns_col, datafusion_expr::Operator::Eq, lit))
79 } else {
80 let list: Vec<Arc<dyn datafusion_physical_expr::PhysicalExpr>> = namespaces
82 .iter()
83 .map(|ns| {
84 expressions::lit(datafusion_common::ScalarValue::Utf8(Some(ns.clone())))
85 as Arc<dyn datafusion_physical_expr::PhysicalExpr>
86 })
87 .collect();
88
89 Arc::new(InListExpr::try_new(ns_col, list, false, &schema)?)
90 };
91
92 let filter = FilterExec::try_new(predicate, input)?;
93 Ok(Arc::new(filter))
94 }
95}
96
97impl PhysicalOptimizerRule for PolicyPushdownRule {
98 fn optimize(
99 &self,
100 plan: Arc<dyn ExecutionPlan>,
101 config: &ConfigOptions,
102 ) -> Result<Arc<dyn ExecutionPlan>> {
103 let namespaces = config
105 .extensions
106 .get::<HirnSessionExt>()
107 .and_then(|ext| ext.allowed_namespaces().map(|ns| ns.to_vec()));
108
109 let Some(allowed) = namespaces else {
110 return Ok(plan);
112 };
113
114 if allowed.is_empty() {
115 return Ok(Arc::new(EmptyExec::new(plan.schema())));
117 }
118
119 let allowed = Arc::new(allowed);
121 plan.transform_up(|node| {
122 if !node.children().is_empty() {
124 return Ok(Transformed::no(node));
125 }
126
127 if !Self::has_namespace_column(node.as_ref()) {
128 return Ok(Transformed::no(node));
129 }
130
131 let filtered = Self::build_namespace_filter(node, &allowed)?;
132 Ok(Transformed::yes(filtered))
133 })
134 .map(|t| t.data)
135 }
136
137 fn name(&self) -> &str {
138 "PolicyPushdownRule"
139 }
140
141 fn schema_check(&self) -> bool {
142 true
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{RecordBatch, StringArray};
    use arrow_schema::{Field, Schema};
    use datafusion_datasource::memory::MemorySourceConfig;

    /// In-memory scan with `id`, `namespace` and `content` columns and three
    /// rows spread over namespaces `ns_a` and `ns_b`.
    fn scan_with_namespace() -> Arc<dyn ExecutionPlan> {
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("namespace", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let ids = StringArray::from(vec!["m1", "m2", "m3"]);
        let namespaces = StringArray::from(vec!["ns_a", "ns_b", "ns_a"]);
        let contents = StringArray::from(vec!["hello", "world", "foo"]);

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(ids), Arc::new(namespaces), Arc::new(contents)],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// In-memory scan with only `id` and `content` columns (one row).
    fn scan_without_namespace() -> Arc<dyn ExecutionPlan> {
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, true),
        ];
        let schema = Arc::new(Schema::new(fields));

        let ids = StringArray::from(vec!["m1"]);
        let contents = StringArray::from(vec!["hello"]);

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(ids), Arc::new(contents)],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// Default options carrying a `HirnSessionExt` with the given
    /// allowed-namespace setting.
    fn config_with_namespaces(namespaces: Option<Vec<String>>) -> ConfigOptions {
        let hirn_config = Arc::new(hirn_core::config::HirnConfig::default());
        let ext =
            HirnSessionExt::new(Arc::new(()), hirn_config, None).with_allowed_namespaces(namespaces);

        let mut config = ConfigOptions::default();
        config.extensions.insert(ext);
        config
    }

    /// Runs a fresh rule over `plan` and unwraps the optimized plan.
    fn run_rule(plan: Arc<dyn ExecutionPlan>, config: &ConfigOptions) -> Arc<dyn ExecutionPlan> {
        PolicyPushdownRule::new().optimize(plan, config).unwrap()
    }

    /// Views the plan root as a `FilterExec`, if it is one.
    fn as_filter(plan: &Arc<dyn ExecutionPlan>) -> Option<&FilterExec> {
        plan.as_any().downcast_ref::<FilterExec>()
    }

    #[test]
    fn open_mode_no_filter() {
        let config = config_with_namespaces(None);
        let optimized = run_rule(scan_with_namespace(), &config);
        assert!(as_filter(&optimized).is_none());
    }

    #[test]
    fn deny_all_returns_empty() {
        let config = config_with_namespaces(Some(Vec::new()));
        let optimized = run_rule(scan_with_namespace(), &config);
        assert!(optimized.as_any().downcast_ref::<EmptyExec>().is_some());
    }

    #[test]
    fn single_namespace_equality_filter() {
        let config = config_with_namespaces(Some(vec!["ns_a".to_string()]));
        let optimized = run_rule(scan_with_namespace(), &config);

        let filter = as_filter(&optimized).expect("expected FilterExec");
        let pred_str = filter.predicate().to_string();
        assert!(
            pred_str.contains("namespace") && pred_str.contains("ns_a"),
            "expected namespace = 'ns_a' predicate, got: {pred_str}"
        );
    }

    #[test]
    fn multiple_namespaces_in_list_filter() {
        let allowed = vec!["ns_a".to_string(), "ns_b".to_string()];
        let config = config_with_namespaces(Some(allowed));
        let optimized = run_rule(scan_with_namespace(), &config);

        let filter = as_filter(&optimized).expect("expected FilterExec");
        let pred_str = filter.predicate().to_string();
        assert!(
            pred_str.contains("namespace") && pred_str.contains("IN"),
            "expected IN predicate, got: {pred_str}"
        );
    }

    #[test]
    fn no_namespace_column_no_filter() {
        let config = config_with_namespaces(Some(vec!["ns_a".to_string()]));
        let optimized = run_rule(scan_without_namespace(), &config);
        assert!(as_filter(&optimized).is_none());
    }

    #[test]
    fn no_ext_registered_no_filter() {
        let config = ConfigOptions::default();
        let optimized = run_rule(scan_with_namespace(), &config);
        assert!(as_filter(&optimized).is_none());
    }
}