use std::sync::Arc;
use arrow_array::{Array, RecordBatch};
use async_trait::async_trait;
use crate::error::HirnDbError;
use crate::store::{
ColumnTransform, CompactOptions, CompactResult, DatasetInfo, FtsSearchOptions,
HybridSearchOptions, IndexConfig, MultivectorSearchOptions, PhysicalStore, RecordBatchStream,
ScanOptions, VectorSearchOptions, VersionTag,
};
tokio::task_local! {
/// Identity of the principal on whose behalf the current task is running.
///
/// `PolicyEnforcedStore` reads this value to look up the caller's allowed
/// namespaces. Bind it for the duration of a future with
/// `CURRENT_PRINCIPAL.scope(principal, future)`; policy-checked operations
/// that run while no value is bound fail with a `PolicyViolation` error.
pub static CURRENT_PRINCIPAL: String;
}
/// Source of truth for which namespaces a principal may touch.
///
/// Implementations are shared behind an `Arc<dyn NamespacePolicy>`, hence the
/// `Send + Sync` bound.
#[async_trait]
pub trait NamespacePolicy: Send + Sync {
/// Returns the namespaces `principal` is allowed to use.
///
/// `None` means the principal is unrestricted: `PolicyEnforcedStore` then
/// leaves filters untouched and accepts any write. `Some(list)` restricts
/// the principal to the listed namespaces.
async fn allowed_namespaces(&self, principal: &str) -> Option<Vec<String>>;
}
/// A `PhysicalStore` decorator that applies a `NamespacePolicy` to the
/// operations of the wrapped store.
pub struct PolicyEnforcedStore<S: PhysicalStore> {
// The wrapped store that actually executes every operation.
inner: S,
// Policy consulted with the current task's principal before guarded calls.
policy: Arc<dyn NamespacePolicy>,
}
impl<S: PhysicalStore> PolicyEnforcedStore<S> {
    /// Wraps `inner` so that guarded operations are checked against `policy`.
    pub fn new(inner: S, policy: Arc<dyn NamespacePolicy>) -> Self {
        Self { inner, policy }
    }

    /// Returns the principal bound to the current task via `CURRENT_PRINCIPAL`.
    ///
    /// # Errors
    /// Returns `HirnDbError::PolicyViolation` when no principal is in scope, so
    /// callers fail closed instead of running unauthenticated.
    fn current_principal() -> Result<String, HirnDbError> {
        CURRENT_PRINCIPAL
            .try_with(|p| p.clone())
            .map_err(|_| HirnDbError::PolicyViolation("no principal set for current task".into()))
    }

    /// Builds a `namespace IN ('a', 'b', ...)` predicate for `allowed`.
    ///
    /// Single quotes inside a namespace are doubled so the value stays inside
    /// its string literal. Returns `None` for an empty list because an empty
    /// `IN ()` cannot be expressed; callers must treat that case explicitly.
    fn build_namespace_predicate(allowed: &[String]) -> Option<String> {
        if allowed.is_empty() {
            return None;
        }
        let literals: Vec<String> = allowed
            .iter()
            .map(|ns| format!("'{}'", ns.replace('\'', "''")))
            .collect();
        Some(format!("namespace IN ({})", literals.join(", ")))
    }

    /// Combines an optional caller filter with the mandatory namespace predicate.
    ///
    /// A missing or empty caller filter yields `ns_pred` alone; otherwise the
    /// caller filter is parenthesised so its own `OR`s cannot escape the `AND`.
    fn inject_filter(existing: Option<&str>, ns_pred: &str) -> String {
        match existing {
            Some(f) if !f.is_empty() => format!("({f}) AND {ns_pred}"),
            _ => ns_pred.to_string(),
        }
    }

    /// Returns `false` only for the resource blob dataset, which is exempt from
    /// namespace filtering.
    fn should_enforce_namespace_filter(dataset: &str) -> bool {
        dataset != crate::datasets::resource_blob::DATASET_NAME
    }

    /// Asks the policy which namespaces the current principal may use.
    ///
    /// `Ok(None)` means the policy places no restriction on the principal.
    ///
    /// # Errors
    /// Propagates the `PolicyViolation` raised when no principal is bound.
    async fn resolve_allowed(&self) -> Result<Option<Vec<String>>, HirnDbError> {
        let principal = Self::current_principal()?;
        Ok(self.policy.allowed_namespaces(&principal).await)
    }

    /// Resolves the predicate that must be AND-ed onto any row-selecting
    /// operation against `dataset`.
    ///
    /// Returns `Ok(None)` when the dataset is exempt or the principal is
    /// unrestricted.
    ///
    /// # Errors
    /// Returns `PolicyViolation` when no principal is bound, or when the policy
    /// grants the principal an empty allow-list. The latter previously produced
    /// no predicate at all, which silently exposed every row to a principal
    /// that should see none.
    async fn required_namespace_predicate(
        &self,
        dataset: &str,
    ) -> Result<Option<String>, HirnDbError> {
        // The exemption is checked first so exempt datasets need no principal.
        if !Self::should_enforce_namespace_filter(dataset) {
            return Ok(None);
        }
        let Some(allowed) = self.resolve_allowed().await? else {
            return Ok(None);
        };
        match Self::build_namespace_predicate(&allowed) {
            Some(ns_pred) => Ok(Some(ns_pred)),
            // Restricted to nothing: fail closed rather than skipping the filter.
            None => Err(HirnDbError::PolicyViolation(
                "principal has no allowed namespaces".into(),
            )),
        }
    }

    /// Returns `opts` with its filter narrowed to the principal's namespaces.
    ///
    /// # Errors
    /// See [`Self::required_namespace_predicate`].
    async fn enforce_scan(
        &self,
        dataset: &str,
        mut opts: ScanOptions,
    ) -> Result<ScanOptions, HirnDbError> {
        if let Some(ns_pred) = self.required_namespace_predicate(dataset).await? {
            opts.filter = Some(Self::inject_filter(opts.filter.as_deref(), &ns_pred));
        }
        Ok(opts)
    }

    /// Returns `filter` narrowed to the principal's namespaces.
    ///
    /// The input is returned unchanged when no restriction applies.
    ///
    /// # Errors
    /// See [`Self::required_namespace_predicate`].
    async fn enforce_filter(
        &self,
        dataset: &str,
        filter: Option<String>,
    ) -> Result<Option<String>, HirnDbError> {
        match self.required_namespace_predicate(dataset).await? {
            Some(ns_pred) => Ok(Some(Self::inject_filter(filter.as_deref(), &ns_pred))),
            None => Ok(filter),
        }
    }

    /// Returns `predicate` narrowed to the principal's namespaces.
    ///
    /// Unlike [`Self::inject_filter`], the caller predicate is always kept and
    /// parenthesised, even when empty, so an empty predicate is never widened
    /// into "every row in my namespaces".
    ///
    /// # Errors
    /// See [`Self::required_namespace_predicate`].
    async fn enforce_delete_predicate(
        &self,
        dataset: &str,
        predicate: &str,
    ) -> Result<String, HirnDbError> {
        Ok(match self.required_namespace_predicate(dataset).await? {
            Some(ns_pred) => format!("({predicate}) AND {ns_pred}"),
            None => predicate.to_string(),
        })
    }

    /// Verifies that every row of `batch` targets an allowed namespace.
    ///
    /// Batches are accepted unchecked when the principal is unrestricted or
    /// when the batch has no `namespace` column. Rows whose namespace is null
    /// are skipped.
    ///
    /// NOTE(review): null namespaces are therefore writable by restricted
    /// principals, and the resource blob exemption is not consulted here
    /// because the dataset name is not passed in. Confirm both are intended.
    ///
    /// # Errors
    /// Returns `PolicyViolation` when no principal is bound, when the
    /// `namespace` column is not a `StringArray`, or when any row names a
    /// namespace outside the allow-list.
    async fn enforce_append(&self, batch: &RecordBatch) -> Result<(), HirnDbError> {
        let Some(allowed) = self.resolve_allowed().await? else {
            return Ok(());
        };
        let schema = batch.schema();
        let Ok(ns_idx) = schema.index_of("namespace") else {
            return Ok(());
        };
        let ns_array = batch
            .column(ns_idx)
            .as_any()
            .downcast_ref::<arrow_array::StringArray>()
            .ok_or_else(|| HirnDbError::PolicyViolation("namespace column is not Utf8".into()))?;

        // Report the first offending namespace, in row order.
        let denied = (0..ns_array.len())
            .filter(|&row| !ns_array.is_null(row))
            .map(|row| ns_array.value(row))
            .find(|&ns| !allowed.iter().any(|a| a == ns));

        match denied {
            Some(ns) => Err(HirnDbError::PolicyViolation(format!(
                "write to namespace '{ns}' denied for current principal"
            ))),
            None => Ok(()),
        }
    }
}
/// `PhysicalStore` implementation that applies the namespace policy to row
/// reads, writes, deletes and updates, then delegates to the wrapped store.
/// Schema, index, versioning and catalogue operations are forwarded unchanged.
#[async_trait]
impl<S: PhysicalStore> PhysicalStore for PolicyEnforcedStore<S> {
    /// Appends `batch` after checking its `namespace` values against the policy.
    async fn append(&self, dataset: &str, batch: RecordBatch) -> Result<(), HirnDbError> {
        self.enforce_append(&batch).await?;
        self.inner.append(dataset, batch).await
    }

    /// Appends `batches` only if every batch passes the policy check; nothing
    /// is written when any batch is denied.
    async fn append_batches(
        &self,
        dataset: &str,
        batches: Vec<RecordBatch>,
    ) -> Result<(), HirnDbError> {
        for batch in &batches {
            self.enforce_append(batch).await?;
        }
        self.inner.append_batches(dataset, batches).await
    }

    /// Drains `stream`, policy-checking each non-empty batch and writing the
    /// accepted batches to the inner store in chunks.
    ///
    /// A stream error or denied batch aborts the call. Chunks flushed before
    /// that point remain written, while batches still buffered are discarded.
    async fn append_stream(
        &self,
        dataset: &str,
        mut stream: RecordBatchStream,
    ) -> Result<(), HirnDbError> {
        use futures::StreamExt as _;

        // Flush threshold: bounds how many rows are held in memory at once.
        const MAX_STREAM_BATCH_ROWS: usize = 50_000;

        let mut buffer: Vec<RecordBatch> = Vec::new();
        let mut buffered_rows: usize = 0;
        while let Some(result) = stream.next().await {
            let batch = result?;
            if batch.num_rows() == 0 {
                continue;
            }
            self.enforce_append(&batch).await?;
            buffered_rows += batch.num_rows();
            buffer.push(batch);
            if buffered_rows >= MAX_STREAM_BATCH_ROWS {
                self.inner
                    .append_batches(dataset, std::mem::take(&mut buffer))
                    .await?;
                buffered_rows = 0;
            }
        }
        // Write whatever is left below the flush threshold.
        if !buffer.is_empty() {
            self.inner.append_batches(dataset, buffer).await?;
        }
        Ok(())
    }

    /// Scans `dataset` with the filter narrowed to the principal's namespaces.
    async fn scan(
        &self,
        dataset: &str,
        opts: ScanOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError> {
        let opts = self.enforce_scan(dataset, opts).await?;
        self.inner.scan(dataset, opts).await
    }

    /// Streaming variant of [`Self::scan`] with the same filter narrowing.
    async fn scan_stream(
        &self,
        dataset: &str,
        opts: ScanOptions,
    ) -> Result<RecordBatchStream, HirnDbError> {
        let opts = self.enforce_scan(dataset, opts).await?;
        self.inner.scan_stream(dataset, opts).await
    }

    /// Deletes rows matching `predicate` within the principal's namespaces.
    async fn delete(&self, dataset: &str, predicate: &str) -> Result<u64, HirnDbError> {
        let predicate = self.enforce_delete_predicate(dataset, predicate).await?;
        self.inner.delete(dataset, &predicate).await
    }

    /// Merge-inserts `batch` after checking its `namespace` values.
    ///
    /// NOTE(review): only the incoming rows are checked. If the inner store
    /// overwrites existing rows matched on `on`, a row currently stored in a
    /// namespace outside the allow-list could be replaced. Confirm the inner
    /// store's merge semantics.
    async fn merge_insert(
        &self,
        dataset: &str,
        on: &[&str],
        batch: RecordBatch,
    ) -> Result<(), HirnDbError> {
        self.enforce_append(&batch).await?;
        self.inner.merge_insert(dataset, on, batch).await
    }

    /// Updates rows matching `filter` within the principal's namespaces.
    ///
    /// The filter is narrowed exactly like a delete predicate. Previously it
    /// was forwarded untouched, which let a restricted principal modify rows
    /// in namespaces it could neither read nor delete.
    ///
    /// NOTE(review): `updates` is not inspected, so an assignment to the
    /// `namespace` column can still move a row into another namespace.
    async fn update_where(
        &self,
        dataset: &str,
        filter: &str,
        updates: &[(&str, &str)],
    ) -> Result<u64, HirnDbError> {
        let filter = self.enforce_delete_predicate(dataset, filter).await?;
        self.inner.update_where(dataset, &filter, updates).await
    }

    /// Counts rows matching `filter` within the principal's namespaces.
    async fn count(&self, dataset: &str, filter: Option<&str>) -> Result<u64, HirnDbError> {
        let filter = self
            .enforce_filter(dataset, filter.map(str::to_string))
            .await?;
        self.inner.count(dataset, filter.as_deref()).await
    }

    /// Vector search with the filter narrowed to the principal's namespaces.
    async fn vector_search(
        &self,
        dataset: &str,
        mut opts: VectorSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError> {
        opts.filter = self.enforce_filter(dataset, opts.filter).await?;
        self.inner.vector_search(dataset, opts).await
    }

    /// Batched vector search; every query's filter is narrowed individually.
    async fn vector_search_many(
        &self,
        dataset: &str,
        mut queries: Vec<VectorSearchOptions>,
    ) -> Result<Vec<Vec<RecordBatch>>, HirnDbError> {
        for query in &mut queries {
            query.filter = self.enforce_filter(dataset, query.filter.take()).await?;
        }
        self.inner.vector_search_many(dataset, queries).await
    }

    /// Full-text search with the filter narrowed to the principal's namespaces.
    async fn fts_search(
        &self,
        dataset: &str,
        mut opts: FtsSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError> {
        opts.filter = self.enforce_filter(dataset, opts.filter).await?;
        self.inner.fts_search(dataset, opts).await
    }

    /// Hybrid search with the filter narrowed to the principal's namespaces.
    async fn hybrid_search(
        &self,
        dataset: &str,
        mut opts: HybridSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError> {
        opts.filter = self.enforce_filter(dataset, opts.filter).await?;
        self.inner.hybrid_search(dataset, opts).await
    }

    /// Multivector search with the filter narrowed to the principal's namespaces.
    async fn multivector_search(
        &self,
        dataset: &str,
        mut opts: MultivectorSearchOptions,
    ) -> Result<Vec<RecordBatch>, HirnDbError> {
        opts.filter = self.enforce_filter(dataset, opts.filter).await?;
        self.inner.multivector_search(dataset, opts).await
    }

    // ---------------------------------------------------------------------
    // The operations below are forwarded to the inner store unchanged; no
    // namespace policy is applied to them.
    // ---------------------------------------------------------------------

    /// Forwards index creation to the inner store.
    async fn create_index(&self, dataset: &str, config: IndexConfig) -> Result<(), HirnDbError> {
        self.inner.create_index(dataset, config).await
    }

    /// Forwards index optimisation to the inner store.
    async fn optimize_indices(&self, dataset: &str) -> Result<(), HirnDbError> {
        self.inner.optimize_indices(dataset).await
    }

    /// Forwards compaction to the inner store.
    async fn compact(
        &self,
        dataset: &str,
        opts: CompactOptions,
    ) -> Result<CompactResult, HirnDbError> {
        self.inner.compact(dataset, opts).await
    }

    /// Forwards the version query to the inner store.
    async fn version(&self, dataset: &str) -> Result<u64, HirnDbError> {
        self.inner.version(dataset).await
    }

    /// Forwards tagging to the inner store.
    async fn tag(&self, dataset: &str, tag: &str) -> Result<(), HirnDbError> {
        self.inner.tag(dataset, tag).await
    }

    /// Forwards the checkout to the inner store.
    async fn checkout(&self, dataset: &str, version: u64) -> Result<(), HirnDbError> {
        self.inner.checkout(dataset, version).await
    }

    /// Forwards tag listing to the inner store.
    async fn list_tags(&self, dataset: &str) -> Result<Vec<VersionTag>, HirnDbError> {
        self.inner.list_tags(dataset).await
    }

    /// Forwards dataset listing to the inner store.
    async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, HirnDbError> {
        self.inner.list_datasets().await
    }

    /// Forwards the existence check to the inner store.
    async fn exists(&self, dataset: &str) -> Result<bool, HirnDbError> {
        self.inner.exists(dataset).await
    }

    /// Forwards namespace listing to the inner store; the result is not
    /// filtered by the principal's allow-list.
    async fn list_namespaces(&self) -> Result<Vec<String>, HirnDbError> {
        self.inner.list_namespaces().await
    }

    /// Forwards namespace creation to the inner store.
    async fn create_namespace(&self, name: &str) -> Result<(), HirnDbError> {
        self.inner.create_namespace(name).await
    }

    /// Forwards namespace removal to the inner store.
    async fn drop_namespace(&self, name: &str) -> Result<(), HirnDbError> {
        self.inner.drop_namespace(name).await
    }

    /// Forwards column addition to the inner store.
    async fn add_columns(
        &self,
        dataset: &str,
        transforms: Vec<ColumnTransform>,
    ) -> Result<(), HirnDbError> {
        self.inner.add_columns(dataset, transforms).await
    }

    /// Forwards column removal to the inner store.
    async fn drop_columns(&self, dataset: &str, columns: &[&str]) -> Result<(), HirnDbError> {
        self.inner.drop_columns(dataset, columns).await
    }

    /// Returns the inner store's table provider for `dataset`.
    ///
    /// NOTE(review): the provider is handed out as-is, so queries issued
    /// through it are not subject to the namespace filter applied by the
    /// other read paths. Confirm callers are trusted or wrap the provider.
    async fn table_provider(
        &self,
        dataset: &str,
    ) -> Option<Arc<dyn datafusion::catalog::TableProvider>> {
        self.inner.table_provider(dataset).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory_store::MemoryStore;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema, SchemaRef};

    /// Policy backed by a fixed principal -> namespaces table. Principals that
    /// are absent from the table are unrestricted (`None`).
    struct TestPolicy {
        allowed: std::collections::HashMap<String, Vec<String>>,
    }

    impl TestPolicy {
        fn new(allowed: Vec<(&str, Vec<&str>)>) -> Self {
            let mut table = std::collections::HashMap::with_capacity(allowed.len());
            for (principal, namespaces) in allowed {
                let owned: Vec<String> = namespaces.into_iter().map(String::from).collect();
                table.insert(principal.to_string(), owned);
            }
            Self { allowed: table }
        }
    }

    #[async_trait]
    impl NamespacePolicy for TestPolicy {
        async fn allowed_namespaces(&self, principal: &str) -> Option<Vec<String>> {
            self.allowed.get(principal).cloned()
        }
    }

    /// Schema shared by most tests: `id`, `namespace`, `value`.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("namespace", DataType::Utf8, false),
            Field::new("value", DataType::Int64, false),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Builds a batch over [`test_schema`] from parallel column slices.
    fn test_batch(ids: &[&str], namespaces: &[&str], values: &[i64]) -> RecordBatch {
        let id_col = StringArray::from(ids.to_vec());
        let ns_col = StringArray::from(namespaces.to_vec());
        let value_col = Int64Array::from(values.to_vec());
        RecordBatch::try_new(
            test_schema(),
            vec![Arc::new(id_col), Arc::new(ns_col), Arc::new(value_col)],
        )
        .unwrap()
    }

    /// Batch with a single `id` column and no `namespace` column.
    fn id_only_batch(ids: Vec<&str>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));
        RecordBatch::try_new(schema, vec![Arc::new(StringArray::from(ids))]).unwrap()
    }

    /// In-memory store guarded by a [`TestPolicy`] built from `allowed`.
    fn setup_store(allowed: Vec<(&str, Vec<&str>)>) -> PolicyEnforcedStore<MemoryStore> {
        PolicyEnforcedStore::new(MemoryStore::new(), Arc::new(TestPolicy::new(allowed)))
    }

    /// Total number of rows across `batches`.
    fn row_total(batches: &[RecordBatch]) -> usize {
        batches.iter().map(|batch| batch.num_rows()).sum()
    }

    /// Runs `work` with `principal` bound as the task's current principal.
    async fn as_principal<F>(principal: &str, work: F) -> F::Output
    where
        F: std::future::Future,
    {
        CURRENT_PRINCIPAL.scope(principal.to_string(), work).await
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scan_injects_namespace_filter() {
        let store = setup_store(vec![("agent_a", vec!["ns1", "ns2"])]);
        let seed = test_batch(
            &["a", "b", "c", "d"],
            &["ns1", "ns2", "ns3", "ns1"],
            &[1, 2, 3, 4],
        );
        store.inner.append("test", seed).await.unwrap();

        let visible = as_principal("agent_a", async {
            store.scan("test", ScanOptions::default()).await
        })
        .await
        .unwrap();

        assert_eq!(row_total(&visible), 3, "should see 3 rows in ns1+ns2");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scan_with_existing_filter_combines() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let seed = test_batch(&["a", "b", "c"], &["ns1", "ns1", "ns2"], &[10, 20, 30]);
        store.inner.append("test", seed).await.unwrap();

        let visible = as_principal("agent_a", async {
            let opts = ScanOptions {
                filter: Some("value > 15".to_string()),
                ..Default::default()
            };
            store.scan("test", opts).await
        })
        .await
        .unwrap();

        assert_eq!(row_total(&visible), 1, "only ns1 row with value 20");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn append_allowed_namespace_succeeds() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let batch = test_batch(&["x"], &["ns1"], &[42]);

        let outcome = as_principal("agent_a", async { store.append("test", batch).await }).await;

        assert!(outcome.is_ok());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn append_denied_namespace_fails() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let batch = test_batch(&["x"], &["ns2"], &[42]);

        let outcome = as_principal("agent_a", async { store.append("test", batch).await }).await;

        assert!(outcome.is_err());
        let err = outcome.unwrap_err();
        assert!(
            matches!(err, HirnDbError::PolicyViolation(_)),
            "expected PolicyViolation, got {err:?}"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn no_principal_set_fails_closed() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let seed = test_batch(&["a"], &["ns1"], &[1]);
        store.inner.append("test", seed).await.unwrap();

        // Deliberately not wrapped in a principal scope.
        let outcome = store.scan("test", ScanOptions::default()).await;

        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            HirnDbError::PolicyViolation(_)
        ));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn no_policy_restriction_returns_all() {
        /// Policy that never restricts anyone.
        struct OpenPolicy;

        #[async_trait]
        impl NamespacePolicy for OpenPolicy {
            async fn allowed_namespaces(&self, _principal: &str) -> Option<Vec<String>> {
                None
            }
        }

        let store = PolicyEnforcedStore::new(MemoryStore::new(), Arc::new(OpenPolicy));
        let seed = test_batch(&["a", "b", "c"], &["ns1", "ns2", "ns3"], &[1, 2, 3]);
        store.inner.append("test", seed).await.unwrap();

        let visible = as_principal("anyone", async {
            store.scan("test", ScanOptions::default()).await
        })
        .await
        .unwrap();

        assert_eq!(row_total(&visible), 3, "open policy returns all rows");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn delete_scoped_to_allowed_namespaces() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let seed = test_batch(&["a", "b", "c"], &["ns1", "ns1", "ns2"], &[1, 2, 3]);
        store.inner.append("test", seed).await.unwrap();

        let deleted = as_principal("agent_a", async {
            store.delete("test", "value >= 0").await
        })
        .await
        .unwrap();
        assert_eq!(deleted, 2, "only ns1 rows deleted");

        // Inspect the raw store to see what actually survived.
        let survivors = store
            .inner
            .scan("test", ScanOptions::default())
            .await
            .unwrap();
        assert_eq!(row_total(&survivors), 1, "ns2 row survives");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn count_respects_policy() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let seed = test_batch(&["a", "b", "c"], &["ns1", "ns2", "ns1"], &[1, 2, 3]);
        store.inner.append("test", seed).await.unwrap();

        let counted = as_principal("agent_a", async { store.count("test", None).await })
            .await
            .unwrap();

        assert_eq!(counted, 2, "only counts ns1 rows");
    }

    #[test]
    fn build_namespace_predicate_escapes_quotes() {
        let namespaces = ["it's".to_string()];
        let pred = PolicyEnforcedStore::<MemoryStore>::build_namespace_predicate(&namespaces);
        assert_eq!(pred.as_deref(), Some("namespace IN ('it''s')"));
    }

    #[test]
    fn build_namespace_predicate_multiple() {
        let namespaces = ["a".to_string(), "b".to_string()];
        let pred = PolicyEnforcedStore::<MemoryStore>::build_namespace_predicate(&namespaces);
        assert_eq!(pred.as_deref(), Some("namespace IN ('a', 'b')"));
    }

    #[test]
    fn build_namespace_predicate_empty() {
        let pred = PolicyEnforcedStore::<MemoryStore>::build_namespace_predicate(&[]);
        assert!(pred.is_none());
    }

    #[test]
    fn inject_filter_no_existing() {
        let combined =
            PolicyEnforcedStore::<MemoryStore>::inject_filter(None, "namespace IN ('a')");
        assert_eq!(combined, "namespace IN ('a')");
    }

    #[test]
    fn inject_filter_with_existing() {
        let combined = PolicyEnforcedStore::<MemoryStore>::inject_filter(
            Some("value > 5"),
            "namespace IN ('a')",
        );
        assert_eq!(combined, "(value > 5) AND namespace IN ('a')");
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn scan_no_namespace_column_passes_through() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        store
            .inner
            .append("no_ns", id_only_batch(vec!["x", "y"]))
            .await
            .unwrap();

        let outcome = as_principal("agent_a", async {
            store.scan("no_ns", ScanOptions::default()).await
        })
        .await;

        // NOTE(review): this assertion is always true, so the test only proves
        // the scan does not panic. Pin the expected outcome once the store's
        // behaviour for a filter on a missing column is decided.
        assert!(outcome.is_ok() || outcome.is_err());
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn append_no_namespace_column_allowed() {
        let store = setup_store(vec![("agent_a", vec!["ns1"])]);
        let batch = id_only_batch(vec!["x"]);

        let outcome = as_principal("agent_a", async { store.append("no_ns", batch).await }).await;

        assert!(outcome.is_ok());
    }
}