use arrow_array::cast::AsArray;
use arrow_array::{Array, RecordBatch};
use async_trait::async_trait;
use hirn_core::error::HirnResult;
use hirn_storage::NamespacePolicy;
use std::collections::HashSet;
use std::sync::Arc;
use super::{OpContext, Operator};
pub struct PolicyFilter {
pub policy: Arc<dyn NamespacePolicy>,
}
#[async_trait]
impl Operator for PolicyFilter {
async fn execute(
&self,
input: Vec<RecordBatch>,
ctx: &OpContext,
) -> HirnResult<Vec<RecordBatch>> {
let principal = match &ctx.principal {
Some(p) => p.as_str(),
None => return Ok(input), };
let allowed = match self.policy.allowed_namespaces(principal).await {
Some(ns) => ns.into_iter().collect::<HashSet<String>>(),
None => return Ok(input), };
let mut out = Vec::new();
for batch in &input {
if let Some(filtered) = filter_batch_by_namespace(batch, &allowed)? {
if filtered.num_rows() > 0 {
out.push(filtered);
}
}
}
Ok(out)
}
}
fn filter_batch_by_namespace(
batch: &RecordBatch,
allowed: &HashSet<String>,
) -> HirnResult<Option<RecordBatch>> {
let ns_col = match batch.column_by_name("namespace") {
Some(c) => c,
None => return Ok(Some(batch.clone())), };
let str_arr = ns_col.as_string::<i32>();
let mut keep = Vec::with_capacity(batch.num_rows());
for i in 0..str_arr.len() {
if !str_arr.is_null(i) && allowed.contains(str_arr.value(i)) {
keep.push(i as u32);
}
}
let indices = arrow_array::UInt32Array::from(keep);
let columns: Vec<_> = batch
.columns()
.iter()
.map(|col| arrow_select::take::take(col.as_ref(), &indices, None))
.collect::<Result<_, _>>()
.map_err(|e| hirn_core::error::HirnError::storage(e))?;
let filtered = RecordBatch::try_new(batch.schema(), columns)
.map_err(|e| hirn_core::error::HirnError::storage(e))?;
Ok(Some(filtered))
}