Skip to main content

hirn_exec/rules/
policy_pushdown.rs

1//! `PolicyPushdownRule` — injects namespace filter predicates into physical plans.
2//!
3//! At plan optimization time, reads the pre-resolved [`allowed_namespaces`]
4//! from [`HirnSessionExt`] and injects `namespace IN (...)` or
5//! `namespace = '...'` filter predicates above table scan operators.
6//!
//! When no namespaces are allowed (empty list), the entire plan is replaced
//! with an empty result. When in open mode (`None`), no filters are injected.
9//!
10//! [`allowed_namespaces`]: crate::extensions::HirnSessionExt::allowed_namespaces
11//! [`HirnSessionExt`]: crate::extensions::HirnSessionExt
12
13use std::sync::Arc;
14
15use arrow_schema::DataType;
16use datafusion_common::Result;
17use datafusion_common::config::ConfigOptions;
18use datafusion_common::tree_node::{Transformed, TreeNode};
19use datafusion_physical_optimizer::PhysicalOptimizerRule;
20use datafusion_physical_plan::ExecutionPlan;
21use datafusion_physical_plan::empty::EmptyExec;
22use datafusion_physical_plan::filter::FilterExec;
23
24use crate::extensions::HirnSessionExt;
25
26/// Injects Cedar-derived namespace filter predicates into physical plans.
27///
28/// This rule reads the agent's allowed namespace list from [`HirnSessionExt`]
29/// (pre-resolved at session setup time) and injects `namespace = '...'` or
30/// `namespace IN (...)` filters above any scan operator whose schema contains
31/// a `namespace` column.
32///
33/// # Behavior
34///
35/// | `allowed_namespaces` | Action |
36/// |---------------------|--------|
37/// | `None` | No filter injected (open mode) |
38/// | `Some([])` | Replace subtree with `EmptyExec` (deny all) |
39/// | `Some(["ns_a"])` | Inject `namespace = 'ns_a'` equality filter |
40/// | `Some(["ns_a", "ns_b"])` | Inject `namespace IN ('ns_a', 'ns_b')` filter |
41#[derive(Debug, Default)]
42pub struct PolicyPushdownRule;
43
44impl PolicyPushdownRule {
45    pub fn new() -> Self {
46        Self
47    }
48
49    /// Check if a plan node's output schema contains a `namespace` column.
50    fn has_namespace_column(plan: &dyn ExecutionPlan) -> bool {
51        plan.schema()
52            .fields()
53            .iter()
54            .any(|f| f.name() == "namespace" && f.data_type() == &DataType::Utf8)
55    }
56
57    /// Build a physical filter expression for namespace restriction.
58    fn build_namespace_filter(
59        input: Arc<dyn ExecutionPlan>,
60        namespaces: &[String],
61    ) -> Result<Arc<dyn ExecutionPlan>> {
62        use datafusion_physical_expr::expressions::{self, BinaryExpr, InListExpr};
63
64        let schema = input.schema();
65        let (_idx, _) = schema.column_with_name("namespace").ok_or_else(|| {
66            datafusion_common::DataFusionError::Internal(
67                "PolicyPushdownRule: expected 'namespace' column".into(),
68            )
69        })?;
70
71        let ns_col = expressions::col("namespace", &schema)?;
72
73        let predicate: Arc<dyn datafusion_physical_expr::PhysicalExpr> = if namespaces.len() == 1 {
74            // namespace = 'ns_a'
75            let lit = expressions::lit(datafusion_common::ScalarValue::Utf8(Some(
76                namespaces[0].clone(),
77            )));
78            Arc::new(BinaryExpr::new(ns_col, datafusion_expr::Operator::Eq, lit))
79        } else {
80            // namespace IN ('ns_a', 'ns_b', ...)
81            let list: Vec<Arc<dyn datafusion_physical_expr::PhysicalExpr>> = namespaces
82                .iter()
83                .map(|ns| {
84                    expressions::lit(datafusion_common::ScalarValue::Utf8(Some(ns.clone())))
85                        as Arc<dyn datafusion_physical_expr::PhysicalExpr>
86                })
87                .collect();
88
89            Arc::new(InListExpr::try_new(ns_col, list, false, &schema)?)
90        };
91
92        let filter = FilterExec::try_new(predicate, input)?;
93        Ok(Arc::new(filter))
94    }
95}
96
97impl PhysicalOptimizerRule for PolicyPushdownRule {
98    fn optimize(
99        &self,
100        plan: Arc<dyn ExecutionPlan>,
101        config: &ConfigOptions,
102    ) -> Result<Arc<dyn ExecutionPlan>> {
103        // Read allowed namespaces from session extensions.
104        let namespaces = config
105            .extensions
106            .get::<HirnSessionExt>()
107            .and_then(|ext| ext.allowed_namespaces().map(|ns| ns.to_vec()));
108
109        let Some(allowed) = namespaces else {
110            // Open mode — no policy filtering.
111            return Ok(plan);
112        };
113
114        if allowed.is_empty() {
115            // Deny all — replace entire plan with empty result.
116            return Ok(Arc::new(EmptyExec::new(plan.schema())));
117        }
118
119        // Walk the plan tree and inject filters above scan nodes.
120        let allowed = Arc::new(allowed);
121        plan.transform_up(|node| {
122            // Only inject on leaf nodes (scans) that have a namespace column.
123            if !node.children().is_empty() {
124                return Ok(Transformed::no(node));
125            }
126
127            if !Self::has_namespace_column(node.as_ref()) {
128                return Ok(Transformed::no(node));
129            }
130
131            let filtered = Self::build_namespace_filter(node, &allowed)?;
132            Ok(Transformed::yes(filtered))
133        })
134        .map(|t| t.data)
135    }
136
137    fn name(&self) -> &str {
138        "PolicyPushdownRule"
139    }
140
141    fn schema_check(&self) -> bool {
142        true
143    }
144}
145
146// ── Tests ───────────────────────────────────────────────────────────────
147
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{RecordBatch, StringArray};
    use arrow_schema::{Field, Schema};
    use datafusion_datasource::memory::MemorySourceConfig;

    /// In-memory scan with `id`, `namespace` (Utf8) and `content` columns:
    /// three rows spread over `ns_a` and `ns_b`.
    fn scan_with_namespace() -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("namespace", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["m1", "m2", "m3"])),
                Arc::new(StringArray::from(vec!["ns_a", "ns_b", "ns_a"])),
                Arc::new(StringArray::from(vec!["hello", "world", "foo"])),
            ],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// In-memory scan with no `namespace` column at all.
    fn scan_without_namespace() -> Arc<dyn ExecutionPlan> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("content", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec!["m1"])),
                Arc::new(StringArray::from(vec!["hello"])),
            ],
        )
        .unwrap();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// `ConfigOptions` carrying a `HirnSessionExt` with the given allow-list.
    fn config_with_namespaces(namespaces: Option<Vec<String>>) -> ConfigOptions {
        let mut config = ConfigOptions::default();
        let ext = HirnSessionExt::new(
            Arc::new(()),
            Arc::new(hirn_core::config::HirnConfig::default()),
            None,
        )
        .with_allowed_namespaces(namespaces);
        config.extensions.insert(ext);
        config
    }

    #[test]
    fn open_mode_no_filter() {
        let plan = scan_with_namespace();
        let rule = PolicyPushdownRule::new();
        let config = config_with_namespaces(None);
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        // The plan must come back untouched. Checking only "not a FilterExec"
        // would also accept a wrongly substituted EmptyExec.
        assert!(Arc::ptr_eq(&result, &plan), "open mode must not rewrite the plan");
    }

    #[test]
    fn deny_all_returns_empty() {
        let plan = scan_with_namespace();
        let rule = PolicyPushdownRule::new();
        let config = config_with_namespaces(Some(vec![]));
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        assert!(result.as_any().downcast_ref::<EmptyExec>().is_some());
        assert_eq!(
            result.schema(),
            plan.schema(),
            "deny-all must preserve the output schema"
        );
    }

    #[test]
    fn single_namespace_equality_filter() {
        let plan = scan_with_namespace();
        let rule = PolicyPushdownRule::new();
        let config = config_with_namespaces(Some(vec!["ns_a".to_string()]));
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        // Should be a FilterExec wrapping the original scan.
        let filter = result
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("expected FilterExec");
        assert!(
            Arc::ptr_eq(filter.input(), &plan),
            "filter must sit directly above the scan"
        );
        let pred_str = filter.predicate().to_string();
        assert!(
            pred_str.contains("namespace") && pred_str.contains("ns_a"),
            "expected namespace = 'ns_a' predicate, got: {pred_str}"
        );
    }

    #[test]
    fn multiple_namespaces_in_list_filter() {
        let plan = scan_with_namespace();
        let rule = PolicyPushdownRule::new();
        let config = config_with_namespaces(Some(vec!["ns_a".to_string(), "ns_b".to_string()]));
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        let filter = result
            .as_any()
            .downcast_ref::<FilterExec>()
            .expect("expected FilterExec");
        assert!(
            Arc::ptr_eq(filter.input(), &plan),
            "filter must sit directly above the scan"
        );
        let pred_str = filter.predicate().to_string();
        assert!(
            pred_str.contains("namespace")
                && pred_str.contains("IN")
                && pred_str.contains("ns_a")
                && pred_str.contains("ns_b"),
            "expected IN predicate over both namespaces, got: {pred_str}"
        );
    }

    #[test]
    fn no_namespace_column_no_filter() {
        let plan = scan_without_namespace();
        let rule = PolicyPushdownRule::new();
        let config = config_with_namespaces(Some(vec!["ns_a".to_string()]));
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        // No namespace column → the plan is returned unchanged.
        assert!(Arc::ptr_eq(&result, &plan), "plan without namespace must be untouched");
    }

    #[test]
    fn no_ext_registered_no_filter() {
        let plan = scan_with_namespace();
        let rule = PolicyPushdownRule::new();
        let config = ConfigOptions::default();
        let result = rule.optimize(Arc::clone(&plan), &config).unwrap();
        // No HirnSessionExt → open mode → plan returned unchanged.
        assert!(Arc::ptr_eq(&result, &plan), "missing extension must mean open mode");
    }
}