hirn_engine/operators/
policy.rs1use arrow_array::cast::AsArray;
7use arrow_array::{Array, RecordBatch};
8use async_trait::async_trait;
9
10use hirn_core::error::HirnResult;
11use hirn_storage::NamespacePolicy;
12
13use std::collections::HashSet;
14use std::sync::Arc;
15
16use super::{OpContext, Operator};
17
/// Operator that removes rows whose `namespace` value the requesting principal
/// is not allowed to read, according to a [`NamespacePolicy`].
pub struct PolicyFilter {
    /// Policy queried (via `allowed_namespaces`) for the principal's namespace
    /// allow-list each time the operator executes with a principal present.
    pub policy: Arc<dyn NamespacePolicy>,
}
25
26#[async_trait]
27impl Operator for PolicyFilter {
28 async fn execute(
29 &self,
30 input: Vec<RecordBatch>,
31 ctx: &OpContext,
32 ) -> HirnResult<Vec<RecordBatch>> {
33 let principal = match &ctx.principal {
34 Some(p) => p.as_str(),
35 None => return Ok(input), };
37
38 let allowed = match self.policy.allowed_namespaces(principal).await {
39 Some(ns) => ns.into_iter().collect::<HashSet<String>>(),
40 None => return Ok(input), };
42
43 let mut out = Vec::new();
44 for batch in &input {
45 if let Some(filtered) = filter_batch_by_namespace(batch, &allowed)? {
46 if filtered.num_rows() > 0 {
47 out.push(filtered);
48 }
49 }
50 }
51 Ok(out)
52 }
53}
54
55fn filter_batch_by_namespace(
58 batch: &RecordBatch,
59 allowed: &HashSet<String>,
60) -> HirnResult<Option<RecordBatch>> {
61 let ns_col = match batch.column_by_name("namespace") {
62 Some(c) => c,
63 None => return Ok(Some(batch.clone())), };
65
66 let str_arr = ns_col.as_string::<i32>();
67 let mut keep = Vec::with_capacity(batch.num_rows());
68 for i in 0..str_arr.len() {
69 if !str_arr.is_null(i) && allowed.contains(str_arr.value(i)) {
70 keep.push(i as u32);
71 }
72 }
73
74 let indices = arrow_array::UInt32Array::from(keep);
75 let columns: Vec<_> = batch
76 .columns()
77 .iter()
78 .map(|col| arrow_select::take::take(col.as_ref(), &indices, None))
79 .collect::<Result<_, _>>()
80 .map_err(|e| hirn_core::error::HirnError::storage(e))?;
81
82 let filtered = RecordBatch::try_new(batch.schema(), columns)
83 .map_err(|e| hirn_core::error::HirnError::storage(e))?;
84 Ok(Some(filtered))
85}