use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use serde::{Deserialize, Serialize};
use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
/// Multiplier applied to the caller's `limit` when `AclReader::query` asks the
/// inner store for candidates, so that results hidden by ACL filtering can be
/// replaced by further candidates from the same inner query.
const OVERFETCH_FACTOR: usize = 5;
/// Minimum number of candidates `AclReader::query` requests from the inner
/// store, however small the caller's `limit` is.
const OVERFETCH_FLOOR: usize = 20;
/// Access-control list for a single document.
///
/// `AccessContext::can_access` admits a caller when the ACL is `public`, when
/// the caller's user id appears in `users`, or when any of the caller's groups
/// appears in `groups`. The default ACL (`public == false`, empty lists)
/// therefore admits nobody.
///
/// Stored as JSON in document metadata under `DocAcl::ACL_METADATA_KEY`; every
/// field falls back to its default when missing from that JSON.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct DocAcl {
    /// When `true`, every caller (including anonymous ones) may access the document.
    #[serde(default)]
    pub public: bool,
    /// User ids that may access the document (exact string match).
    #[serde(default)]
    pub users: Vec<String>,
    /// Group names whose members may access the document (exact string match).
    #[serde(default)]
    pub groups: Vec<String>,
}
impl DocAcl {
    /// Document-metadata key under which [`DocAcl::attach_to`] stores the
    /// JSON-encoded ACL and from which [`DocAcl::from_metadata`] reads it.
    pub const ACL_METADATA_KEY: &'static str = "acl_v2";

    /// Returns an ACL that admits every caller, including anonymous ones.
    #[must_use]
    pub fn public() -> Self {
        Self {
            public: true,
            users: Vec::new(),
            groups: Vec::new(),
        }
    }

    /// Returns a non-public ACL whose only grants are the given user ids.
    #[must_use]
    pub fn for_users<I, S>(users: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        // The default ACL is non-public with empty lists, so extending it is
        // the same as building the user list from scratch.
        Self::default().with_users(users)
    }

    /// Returns a non-public ACL whose only grants are the given group names.
    #[must_use]
    pub fn for_groups<I, S>(groups: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self::default().with_groups(groups)
    }

    /// Appends user ids to the allow-list and returns the updated ACL.
    /// Existing entries are kept and duplicates are not removed.
    #[must_use]
    pub fn with_users<I, S>(mut self, users: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let added = users.into_iter().map(Into::into);
        self.users.extend(added);
        self
    }

    /// Appends group names to the allow-list and returns the updated ACL.
    /// Existing entries are kept and duplicates are not removed.
    #[must_use]
    pub fn with_groups<I, S>(mut self, groups: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let added = groups.into_iter().map(Into::into);
        self.groups.extend(added);
        self
    }

    /// Returns `doc` with this ACL serialized to JSON under
    /// [`Self::ACL_METADATA_KEY`].
    ///
    /// # Panics
    ///
    /// Only if `serde_json` fails to serialize a `bool` plus two string
    /// vectors, which is not expected to happen.
    #[must_use]
    pub fn attach_to(&self, doc: Document) -> Document {
        let encoded = serde_json::to_string(self).expect("DocAcl always serializes");
        doc.with_metadata(Self::ACL_METADATA_KEY, encoded)
    }

    /// Parses the ACL stored under [`Self::ACL_METADATA_KEY`].
    ///
    /// Returns `None` both when the key is absent and when its value is not
    /// valid `DocAcl` JSON. Callers that must tell those cases apart have to
    /// check for the key themselves.
    #[must_use]
    pub fn from_metadata(metadata: &HashMap<String, String>) -> Option<Self> {
        metadata
            .get(Self::ACL_METADATA_KEY)
            .and_then(|raw| serde_json::from_str(raw).ok())
    }
}
/// Identity of the caller whose permissions an ACL-filtered reader enforces.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AccessContext {
    /// The caller's user id; `None` for an anonymous caller.
    pub user_id: Option<String>,
    /// Groups the caller belongs to, matched against `DocAcl::groups`.
    pub groups: Vec<String>,
}
impl AccessContext {
    /// Builds a context from an optional user id and a list of group names.
    #[must_use]
    pub fn new(user_id: Option<String>, groups: Vec<String>) -> Self {
        Self { groups, user_id }
    }

    /// Builds a context for `user_id` with no group memberships.
    #[must_use]
    pub fn for_user(user_id: impl Into<String>) -> Self {
        Self::new(Some(user_id.into()), Vec::new())
    }

    /// Builds a context with no user id and no groups. `can_access` admits
    /// such a caller only for public ACLs.
    #[must_use]
    pub fn anonymous() -> Self {
        Self {
            user_id: None,
            groups: Vec::new(),
        }
    }

    /// Adds one group membership and returns the updated context.
    #[must_use]
    pub fn with_group(mut self, group: impl Into<String>) -> Self {
        let name: String = group.into();
        self.groups.push(name);
        self
    }

    /// Returns `true` if `acl` admits this caller.
    ///
    /// Access is granted when the ACL is public, when the caller's user id is
    /// listed in `acl.users`, or when at least one of the caller's groups is
    /// listed in `acl.groups`. Matching is exact string equality, and user ids
    /// are never compared against group names.
    #[must_use]
    pub fn can_access(&self, acl: &DocAcl) -> bool {
        // `Option::iter` yields zero or one id, so an anonymous caller simply
        // never matches the user allow-list.
        acl.public
            || self.user_id.iter().any(|uid| acl.users.contains(uid))
            || self.groups.iter().any(|group| acl.groups.contains(group))
    }
}
/// Shared map from document id to its recorded ACL. Every handle created by
/// `AclKnowledgeStore` holds a clone of the same `Arc`, so they all read and
/// write one table.
type AclTable = Arc<RwLock<HashMap<String, DocAcl>>>;
/// Wraps an inner [`KnowledgeBase`] with per-document access control.
///
/// ACLs are kept in a table keyed by document id and enforced at query time by
/// the handles returned from [`AclKnowledgeStore::reader`]. Cloning the store
/// shares both the inner store and the ACL table.
#[derive(Clone)]
pub struct AclKnowledgeStore {
    /// The wrapped store that holds documents and answers queries.
    inner: Arc<dyn KnowledgeBase>,
    /// Document id -> ACL, shared with every handle this store hands out.
    acls: AclTable,
}
impl AclKnowledgeStore {
#[must_use]
pub fn new(inner: Arc<dyn KnowledgeBase>) -> Self {
Self {
inner,
acls: Arc::new(RwLock::new(HashMap::new())),
}
}
#[must_use]
pub fn ingest_handle(&self) -> Arc<dyn KnowledgeBase> {
Arc::new(AclIngestHandle {
inner: Arc::clone(&self.inner),
acls: Arc::clone(&self.acls),
})
}
#[must_use]
pub fn reader(&self, ctx: AccessContext) -> Arc<dyn KnowledgeBase> {
Arc::new(AclReader {
inner: Arc::clone(&self.inner),
acls: Arc::clone(&self.acls),
ctx,
})
}
pub fn record_acl(&self, document_id: impl Into<String>, acl: DocAcl) -> anyhow::Result<()> {
let mut table = self
.acls
.write()
.map_err(|e| anyhow::anyhow!("acl table lock poisoned: {e}"))?;
table.insert(document_id.into(), acl);
Ok(())
}
}
/// Shared ingest path for every ACL-aware handle.
///
/// If `doc` carries ACL metadata under `DocAcl::ACL_METADATA_KEY`, the parsed
/// ACL is recorded in `acls` *before* the document reaches `inner`, so the
/// document is never retrievable without its ACL. A document without that key
/// is forwarded unchanged; any ACL previously recorded for the same id is left
/// in place, and if none exists, ACL-filtered readers will show the document
/// to everyone.
///
/// # Errors
///
/// * The ACL metadata key is present but its value is not valid `DocAcl` JSON.
///   Such a document is rejected rather than ingested without an ACL, which
///   would make it visible to every reader.
/// * The ACL table lock is poisoned.
/// * Any error returned by `inner.ingest`.
fn record_acl_and_ingest(
    inner: &Arc<dyn KnowledgeBase>,
    acls: &AclTable,
    doc: Document,
) -> anyhow::Result<()> {
    match DocAcl::from_metadata(&doc.metadata) {
        Some(acl) => {
            // The guard is scoped to this arm, so the lock is released before
            // calling into the inner store.
            let mut table = acls
                .write()
                .map_err(|e| anyhow::anyhow!("acl table lock poisoned: {e}"))?;
            table.insert(doc.id.clone(), acl);
        }
        None if doc.metadata.contains_key(DocAcl::ACL_METADATA_KEY) => {
            anyhow::bail!(
                "document {:?} has malformed `{}` metadata; refusing to ingest it without an ACL",
                doc.id,
                DocAcl::ACL_METADATA_KEY
            );
        }
        None => {}
    }
    inner.ingest(doc)
}

/// Ingestion handle returned by `AclKnowledgeStore::ingest_handle`.
struct AclIngestHandle {
    /// Store that receives every ingested document and answers queries.
    inner: Arc<dyn KnowledgeBase>,
    /// ACL table shared with the owning `AclKnowledgeStore`.
    acls: AclTable,
}

impl KnowledgeBase for AclIngestHandle {
    /// Records the document's ACL (if any), then forwards it to the inner
    /// store. Fails on malformed ACL metadata, a poisoned ACL table lock, or
    /// an inner ingest error.
    fn ingest(&self, doc: Document) -> anyhow::Result<()> {
        record_acl_and_ingest(&self.inner, &self.acls, doc)
    }

    /// Queries the inner store directly. Results are NOT ACL-filtered.
    fn query(&self, query: &str, limit: usize) -> anyhow::Result<Vec<KnowledgeResult>> {
        self.inner.query(query, limit)
    }
}

/// Read handle returned by `AclKnowledgeStore::reader`. It filters query
/// results through `ctx`.
struct AclReader {
    /// Store that holds the documents and produces unfiltered candidates.
    inner: Arc<dyn KnowledgeBase>,
    /// ACL table shared with the owning `AclKnowledgeStore`.
    acls: AclTable,
    /// Identity whose permissions are enforced on every query.
    ctx: AccessContext,
}

impl KnowledgeBase for AclReader {
    /// Records the document's ACL (if any), then forwards it to the inner
    /// store. Fails on malformed ACL metadata, a poisoned ACL table lock, or
    /// an inner ingest error.
    ///
    /// NOTE(review): any holder of a reader can ingest documents and
    /// (re)assign ACLs for arbitrary ids, independent of `ctx`; confirm that
    /// this is intended.
    fn ingest(&self, doc: Document) -> anyhow::Result<()> {
        record_acl_and_ingest(&self.inner, &self.acls, doc)
    }

    /// Returns up to `limit` results that `ctx` may access, in the order the
    /// inner store returned them.
    ///
    /// The inner store is asked for `max(limit * OVERFETCH_FACTOR,
    /// OVERFETCH_FLOOR)` candidates, which are then filtered. If too many of
    /// them are hidden, fewer than `limit` results come back even when more
    /// visible matches exist beyond that candidate window. Documents with no
    /// recorded ACL are visible to every reader.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner query, and fails if the ACL table lock
    /// is poisoned.
    fn query(&self, query: &str, limit: usize) -> anyhow::Result<Vec<KnowledgeResult>> {
        // A zero limit can never yield results, so skip the inner query.
        if limit == 0 {
            return Ok(Vec::new());
        }
        // Over-fetch so that candidates hidden by ACLs can be replaced by
        // further visible ones from the same inner query.
        let candidate_n = limit.saturating_mul(OVERFETCH_FACTOR).max(OVERFETCH_FLOOR);
        let candidates = self.inner.query(query, candidate_n)?;
        let table = self
            .acls
            .read()
            .map_err(|e| anyhow::anyhow!("acl table lock poisoned: {e}"))?;
        // `take` stops pulling candidates once `limit` visible results have
        // been collected, so the cap holds for every `limit`.
        let visible: Vec<KnowledgeResult> = candidates
            .into_iter()
            .filter(|result| match table.get(&result.document_id) {
                Some(acl) => self.ctx.can_access(acl),
                // No recorded ACL: the document is visible to every reader.
                None => true,
            })
            .take(limit)
            .collect();
        Ok(visible)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use smooth_operator_core::DocumentType;

    #[test]
    fn can_access_public_allows_anyone() {
        let open = DocAcl::public();
        for ctx in [AccessContext::anonymous(), AccessContext::for_user("anyone")] {
            assert!(ctx.can_access(&open), "{ctx:?} should see a public doc");
        }
    }

    #[test]
    fn can_access_user_match() {
        let alice = AccessContext::for_user("alice");
        assert!(alice.can_access(&DocAcl::for_users(["alice"])));
    }

    #[test]
    fn can_access_user_no_match_is_denied() {
        let only_alice = DocAcl::for_users(["alice"]);
        for ctx in [AccessContext::for_user("bob"), AccessContext::anonymous()] {
            assert!(!ctx.can_access(&only_alice), "{ctx:?} must be denied");
        }
    }

    #[test]
    fn can_access_group_match() {
        let carol = AccessContext::new(Some("carol".into()), vec!["support".into()]);
        assert!(carol.can_access(&DocAcl::for_groups(["support"])));
    }

    #[test]
    fn can_access_group_no_match_is_denied() {
        let dave = AccessContext::new(Some("dave".into()), vec!["billing".into()]);
        assert!(!dave.can_access(&DocAcl::for_groups(["support"])));
    }

    #[test]
    fn can_access_empty_acl_is_fully_locked() {
        let locked = DocAcl::default();
        let callers = [
            AccessContext::for_user("alice"),
            AccessContext::anonymous(),
            AccessContext::new(Some("x".into()), vec!["g".into()]),
        ];
        assert!(callers.iter().all(|ctx| !ctx.can_access(&locked)));
    }

    #[test]
    fn can_access_mixed_user_or_group() {
        let acl = DocAcl::for_users(["alice"]).with_groups(["support"]);
        let cases = [
            (AccessContext::for_user("alice"), true),
            (
                AccessContext::new(Some("zed".into()), vec!["support".into()]),
                true,
            ),
            (
                AccessContext::new(Some("zed".into()), vec!["billing".into()]),
                false,
            ),
        ];
        for (ctx, expected) in cases {
            assert_eq!(ctx.can_access(&acl), expected, "{ctx:?}");
        }
    }

    #[test]
    fn docacl_round_trips_through_metadata() {
        let original = DocAcl::for_users(["alice", "bob"]).with_groups(["support"]);
        let doc = original.attach_to(Document::new("c", "s", DocumentType::Documentation));
        let decoded = DocAcl::from_metadata(&doc.metadata).expect("acl present");
        assert_eq!(decoded, original);
    }

    #[test]
    fn from_metadata_absent_is_none() {
        let doc = Document::new("c", "s", DocumentType::Documentation);
        assert_eq!(DocAcl::from_metadata(&doc.metadata), None);
    }

    #[test]
    fn from_metadata_malformed_is_none() {
        let metadata = HashMap::from([(
            DocAcl::ACL_METADATA_KEY.to_string(),
            "{not json".to_string(),
        )]);
        assert_eq!(DocAcl::from_metadata(&metadata), None);
    }
}