Skip to main content

hirn_engine/operators/
policy.rs

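//! Policy filtering operator.
//!
//! Filters input batches to include only rows whose `namespace` column
//! is in the set of namespaces allowed for the current principal.
//! Requests without a principal, principals without policy restrictions,
//! and batches without a `namespace` column pass through unchanged.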
5
6use arrow_array::cast::AsArray;
7use arrow_array::{Array, RecordBatch};
8use async_trait::async_trait;
9
10use hirn_core::error::HirnResult;
11use hirn_storage::NamespacePolicy;
12
13use std::collections::HashSet;
14use std::sync::Arc;
15
16use super::{OpContext, Operator};
17
18/// Operator that filters rows by namespace based on Cedar policy.
19///
20/// If the `OpContext` has no principal, all rows pass through (permissive).
21/// If the principal has no policy restrictions, all rows pass through.
22pub struct PolicyFilter {
23    pub policy: Arc<dyn NamespacePolicy>,
24}
25
26#[async_trait]
27impl Operator for PolicyFilter {
28    async fn execute(
29        &self,
30        input: Vec<RecordBatch>,
31        ctx: &OpContext,
32    ) -> HirnResult<Vec<RecordBatch>> {
33        let principal = match &ctx.principal {
34            Some(p) => p.as_str(),
35            None => return Ok(input), // No principal → permissive.
36        };
37
38        let allowed = match self.policy.allowed_namespaces(principal).await {
39            Some(ns) => ns.into_iter().collect::<HashSet<String>>(),
40            None => return Ok(input), // No restrictions → pass all.
41        };
42
43        let mut out = Vec::new();
44        for batch in &input {
45            if let Some(filtered) = filter_batch_by_namespace(batch, &allowed)? {
46                if filtered.num_rows() > 0 {
47                    out.push(filtered);
48                }
49            }
50        }
51        Ok(out)
52    }
53}
54
55/// Filter a single batch: keep only rows where `namespace` ∈ `allowed`.
56/// Returns `None` if the batch has no `namespace` column (passes through).
57fn filter_batch_by_namespace(
58    batch: &RecordBatch,
59    allowed: &HashSet<String>,
60) -> HirnResult<Option<RecordBatch>> {
61    let ns_col = match batch.column_by_name("namespace") {
62        Some(c) => c,
63        None => return Ok(Some(batch.clone())), // No namespace column → pass.
64    };
65
66    let str_arr = ns_col.as_string::<i32>();
67    let mut keep = Vec::with_capacity(batch.num_rows());
68    for i in 0..str_arr.len() {
69        if !str_arr.is_null(i) && allowed.contains(str_arr.value(i)) {
70            keep.push(i as u32);
71        }
72    }
73
74    let indices = arrow_array::UInt32Array::from(keep);
75    let columns: Vec<_> = batch
76        .columns()
77        .iter()
78        .map(|col| arrow_select::take::take(col.as_ref(), &indices, None))
79        .collect::<Result<_, _>>()
80        .map_err(|e| hirn_core::error::HirnError::storage(e))?;
81
82    let filtered = RecordBatch::try_new(batch.schema(), columns)
83        .map_err(|e| hirn_core::error::HirnError::storage(e))?;
84    Ok(Some(filtered))
85}