use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use serde::{Deserialize, Serialize};
use smooth_operator_core::{Document, KnowledgeBase, KnowledgeResult};
use crate::access_control::{AccessContext, DocAcl};
/// Multiplier applied to the caller's `limit` when a reader asks the inner store for candidates.
const OVERFETCH_FACTOR: usize = 5;
/// Smallest number of candidates a reader requests from the inner store, however small `limit` is.
const OVERFETCH_FLOOR: usize = 20;
/// Boost used when a document carries no usable `boost` metadata value.
pub const DEFAULT_BOOST: f32 = 1.0;
/// Curation attributes of a single document, derived from its string metadata map.
#[derive(Debug, Clone, PartialEq)]
pub struct DocMeta {
/// Names of the document sets the document belongs to.
pub document_sets: Vec<String>,
/// Score multiplier for the document.
pub boost: f32,
/// The document's key/value metadata.
pub metadata: HashMap<String, String>,
}
impl DocMeta {
    /// Metadata key under which the comma-separated document-set names are stored.
    pub const DOCUMENT_SET_KEY: &'static str = "document_set";
    /// Metadata key under which the boost is stored as text.
    pub const BOOST_KEY: &'static str = "boost";

    /// Extracts the document-set names stored under [`Self::DOCUMENT_SET_KEY`].
    ///
    /// The raw value is split on commas; each piece is trimmed and blank pieces are
    /// dropped. A missing key yields an empty list.
    #[must_use]
    pub fn parse_sets(metadata: &HashMap<String, String>) -> Vec<String> {
        let Some(raw) = metadata.get(Self::DOCUMENT_SET_KEY) else {
            return Vec::new();
        };
        let mut names = Vec::new();
        for piece in raw.split(',') {
            let name = piece.trim();
            if !name.is_empty() {
                names.push(name.to_owned());
            }
        }
        names
    }

    /// Reads the boost stored under [`Self::BOOST_KEY`].
    ///
    /// Returns [`DEFAULT_BOOST`] when the key is absent, the trimmed value does not parse
    /// as an `f32`, or the parsed value is not finite. Negative values are clamped to `0.0`.
    #[must_use]
    pub fn parse_boost(metadata: &HashMap<String, String>) -> f32 {
        let parsed = metadata
            .get(Self::BOOST_KEY)
            .map(|raw| raw.trim().parse::<f32>());
        match parsed {
            Some(Ok(value)) if value.is_finite() => value.max(0.0),
            _ => DEFAULT_BOOST,
        }
    }

    /// Builds a `DocMeta` from a metadata map, keeping a full copy of the map.
    #[must_use]
    pub fn from_metadata(metadata: &HashMap<String, String>) -> Self {
        let document_sets = Self::parse_sets(metadata);
        let boost = Self::parse_boost(metadata);
        Self {
            document_sets,
            boost,
            metadata: metadata.clone(),
        }
    }

    /// Builds a `DocMeta` from the metadata carried by `doc`.
    #[must_use]
    pub fn from_document(doc: &Document) -> Self {
        Self::from_metadata(&doc.metadata)
    }

    /// Returns `true` when `set` is exactly one of this document's set names.
    #[must_use]
    pub fn in_set(&self, set: &str) -> bool {
        self.document_sets
            .iter()
            .any(|member| member.as_str() == set)
    }
}
/// Tags `doc` with the given document-set names.
///
/// Names that are empty or whitespace-only are skipped; the remaining names are stored,
/// comma-joined and otherwise unmodified, under [`DocMeta::DOCUMENT_SET_KEY`]. When no
/// usable name remains the document is returned unchanged.
#[must_use]
pub fn with_document_set<I, S>(doc: Document, sets: I) -> Document
where
    I: IntoIterator<Item = S>,
    S: Into<String>,
{
    let mut names: Vec<String> = Vec::new();
    for set in sets {
        let name: String = set.into();
        if !name.trim().is_empty() {
            names.push(name);
        }
    }
    if names.is_empty() {
        return doc;
    }
    doc.with_metadata(DocMeta::DOCUMENT_SET_KEY, names.join(","))
}
/// Stores a ranking boost on `doc` under [`DocMeta::BOOST_KEY`].
///
/// Non-finite inputs (NaN, ±infinity) are replaced by [`DEFAULT_BOOST`]; finite negative
/// inputs are clamped to `0.0`. The value is written using its `Display` text.
#[must_use]
pub fn with_boost(doc: Document, boost: f32) -> Document {
    let normalized = match boost {
        finite if finite.is_finite() => finite.max(0.0),
        _ => DEFAULT_BOOST,
    };
    doc.with_metadata(DocMeta::BOOST_KEY, normalized.to_string())
}
/// Constraints on which documents may be returned, expressed over a document's [`DocMeta`].
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RetrievalFilter {
/// Set scope. `None` imposes no set constraint; `Some(list)` requires membership in at
/// least one listed set, so `Some(vec![])` matches no document.
#[serde(default)]
pub document_sets: Option<Vec<String>>,
/// Exact-match requirements: every `(key, value)` pair must appear in the document's metadata.
#[serde(default)]
pub metadata_eq: HashMap<String, String>,
}
impl RetrievalFilter {
    /// Returns a filter with no set scope and no metadata requirements.
    #[must_use]
    pub fn none() -> Self {
        Self {
            document_sets: None,
            metadata_eq: HashMap::new(),
        }
    }

    /// Returns a filter scoped to the given sets, with no metadata requirements.
    ///
    /// The names are stored as given; an empty input produces a scope that matches nothing.
    #[must_use]
    pub fn in_sets<I, S>(sets: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut scope: Vec<String> = Vec::new();
        for set in sets {
            scope.push(set.into());
        }
        Self {
            document_sets: Some(scope),
            metadata_eq: HashMap::new(),
        }
    }

    /// Adds (or overwrites) an exact-match requirement on metadata `key`.
    #[must_use]
    pub fn with_metadata_eq(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (required_key, required_value) = (key.into(), value.into());
        self.metadata_eq.insert(required_key, required_value);
        self
    }

    /// Returns `true` when the filter has neither a set scope nor metadata requirements.
    #[must_use]
    pub fn is_unconstrained(&self) -> bool {
        let has_scope = self.document_sets.is_some();
        let has_requirements = !self.metadata_eq.is_empty();
        !(has_scope || has_requirements)
    }

    /// Decides whether a document described by `meta` passes this filter.
    ///
    /// The set scope (if any) is satisfied by membership in at least one listed set;
    /// every metadata requirement must then match the document's metadata exactly.
    #[must_use]
    pub fn matches(&self, meta: &DocMeta) -> bool {
        let in_scope = match self.document_sets.as_deref() {
            None => true,
            Some(wanted) => wanted.iter().any(|name| meta.in_set(name)),
        };
        if !in_scope {
            return false;
        }
        for (key, expected) in &self.metadata_eq {
            if meta.metadata.get(key) != Some(expected) {
                return false;
            }
        }
        true
    }
}
/// Shared, lock-protected table of [`DocMeta`] entries keyed by `String`.
type MetaTable = Arc<RwLock<HashMap<String, DocMeta>>>;
/// Shared, lock-protected table of [`DocAcl`] entries keyed by `String`.
type AclTable = Arc<RwLock<HashMap<String, DocAcl>>>;
/// An inner [`KnowledgeBase`] paired with side tables of per-document curation metadata and ACLs.
///
/// All three fields are `Arc`s, so a clone shares the same inner store and the same tables.
#[derive(Clone)]
pub struct CuratedKnowledgeStore {
/// The wrapped knowledge base.
inner: Arc<dyn KnowledgeBase>,
/// Shared table of curation metadata.
meta: MetaTable,
/// Shared table of access-control lists.
acls: AclTable,
}
impl CuratedKnowledgeStore {
#[must_use]
pub fn new(inner: Arc<dyn KnowledgeBase>) -> Self {
Self {
inner,
meta: Arc::new(RwLock::new(HashMap::new())),
acls: Arc::new(RwLock::new(HashMap::new())),
}
}
#[must_use]
pub fn ingest_handle(&self) -> Arc<dyn KnowledgeBase> {
Arc::new(CuratedIngestHandle {
inner: Arc::clone(&self.inner),
meta: Arc::clone(&self.meta),
acls: Arc::clone(&self.acls),
})
}
#[must_use]
pub fn reader(&self, filter: RetrievalFilter, access: AccessContext) -> Arc<dyn KnowledgeBase> {
Arc::new(CuratedReader {
inner: Arc::clone(&self.inner),
meta: Arc::clone(&self.meta),
acls: Arc::clone(&self.acls),
filter,
access,
})
}
pub fn record_meta(&self, document_id: impl Into<String>, meta: DocMeta) -> anyhow::Result<()> {
let mut table = self
.meta
.write()
.map_err(|e| anyhow::anyhow!("curation meta table lock poisoned: {e}"))?;
table.insert(document_id.into(), meta);
Ok(())
}
}
/// [`KnowledgeBase`] handle holding an inner store together with the shared metadata and ACL tables.
struct CuratedIngestHandle {
/// The wrapped knowledge base that receives documents and answers queries.
inner: Arc<dyn KnowledgeBase>,
/// Shared table of curation metadata.
meta: MetaTable,
/// Shared table of access-control lists.
acls: AclTable,
}
/// Records the curation metadata, and the ACL if one can be derived, for `doc`.
///
/// The metadata entry for `doc.id` is always inserted or replaced. The ACL entry is only
/// written when `DocAcl::from_metadata` yields one; otherwise the ACL table is left untouched.
/// Each table's write lock is held only for its own insert.
///
/// # Errors
/// Returns an error if either table's lock is poisoned.
fn record_ingest_metadata(meta: &MetaTable, acls: &AclTable, doc: &Document) -> anyhow::Result<()> {
    let doc_meta = DocMeta::from_document(doc);
    // The write guard is a statement temporary, so it is released before the ACL table is touched.
    meta.write()
        .map_err(|e| anyhow::anyhow!("curation meta table lock poisoned: {e}"))?
        .insert(doc.id.clone(), doc_meta);

    let Some(acl) = DocAcl::from_metadata(&doc.metadata) else {
        return Ok(());
    };
    acls.write()
        .map_err(|e| anyhow::anyhow!("acl table lock poisoned: {e}"))?
        .insert(doc.id.clone(), acl);
    Ok(())
}
impl KnowledgeBase for CuratedIngestHandle {
/// Records the document's curation metadata and ACL (if any), then forwards it to the inner store.
///
/// The side tables are updated before the inner ingest runs, so entries written here
/// remain in place even if the inner store returns an error.
fn ingest(&self, doc: Document) -> anyhow::Result<()> {
record_ingest_metadata(&self.meta, &self.acls, &doc)?;
self.inner.ingest(doc)
}
/// Forwards the query to the inner store unchanged; no filter, ACL check or boost is applied here.
fn query(&self, query: &str, limit: usize) -> anyhow::Result<Vec<KnowledgeResult>> {
self.inner.query(query, limit)
}
}
/// [`KnowledgeBase`] view holding an inner store, the shared side tables, a filter and an access context.
struct CuratedReader {
/// The wrapped knowledge base.
inner: Arc<dyn KnowledgeBase>,
/// Shared table of curation metadata.
meta: MetaTable,
/// Shared table of access-control lists.
acls: AclTable,
/// Retrieval filter held by this reader.
filter: RetrievalFilter,
/// Access context held by this reader.
access: AccessContext,
}
impl KnowledgeBase for CuratedReader {
    /// Records the document's curation metadata and ACL (if any), then forwards it to the inner store.
    ///
    /// # Errors
    /// Returns an error if a side-table lock is poisoned or the inner store rejects the document.
    fn ingest(&self, doc: Document) -> anyhow::Result<()> {
        record_ingest_metadata(&self.meta, &self.acls, &doc)?;
        self.inner.ingest(doc)
    }

    /// Queries the inner store and returns at most `limit` results that pass this reader's
    /// ACL check and filter, re-ranked by boosted score (highest first).
    ///
    /// Candidates with a recorded ACL are dropped unless `access` may see them; candidates
    /// without an ACL entry are treated as accessible. Candidates without recorded metadata
    /// are matched against an empty `DocMeta` and keep their score (boost [`DEFAULT_BOOST`]).
    /// A `limit` of zero returns an empty list without touching the inner store.
    ///
    /// # Errors
    /// Returns an error if the inner query fails or a side-table lock is poisoned.
    fn query(&self, query: &str, limit: usize) -> anyhow::Result<Vec<KnowledgeResult>> {
        if limit == 0 {
            return Ok(Vec::new());
        }

        // Over-fetch so that ACL/filter drops and boost re-ranking can still fill `limit` slots.
        let candidate_n = limit.saturating_mul(OVERFETCH_FACTOR).max(OVERFETCH_FLOOR);
        let candidates = self.inner.query(query, candidate_n)?;

        let mut kept: Vec<KnowledgeResult> = Vec::with_capacity(candidates.len());
        {
            // Both read guards live only for this block, so sorting happens without locks held.
            let meta_table = self
                .meta
                .read()
                .map_err(|e| anyhow::anyhow!("curation meta table lock poisoned: {e}"))?;
            let acl_table = self
                .acls
                .read()
                .map_err(|e| anyhow::anyhow!("acl table lock poisoned: {e}"))?;

            // Stand-in for documents with no recorded metadata; built once instead of per
            // candidate so table entries can be borrowed rather than deep-cloned.
            let untracked = DocMeta {
                document_sets: Vec::new(),
                boost: DEFAULT_BOOST,
                metadata: HashMap::new(),
            };

            for mut result in candidates {
                if let Some(acl) = acl_table.get(&result.document_id) {
                    if !self.access.can_access(acl) {
                        continue;
                    }
                }
                let doc_meta = meta_table
                    .get(&result.document_id)
                    .unwrap_or(&untracked);
                if !self.filter.matches(doc_meta) {
                    continue;
                }
                result.score *= doc_meta.boost;
                kept.push(result);
            }
        }

        // Stable sort: equal boosted scores keep the order the inner store returned them in.
        kept.sort_by(|a, b| b.score.total_cmp(&a.score));
        kept.truncate(limit);
        Ok(kept)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use smooth_operator_core::DocumentType;

    /// Builds a documentation-type document and overrides its id with `id`.
    fn doc(id: &str, content: &str) -> Document {
        let mut built = Document::new(content, "s", DocumentType::Documentation);
        built.id = id.to_owned();
        built
    }

    #[test]
    fn parse_sets_single_and_multi() {
        let single = with_document_set(doc("a", "x"), ["alpha"]);
        assert_eq!(DocMeta::parse_sets(&single.metadata), ["alpha"]);

        let multi = with_document_set(doc("b", "x"), ["alpha", "beta"]);
        assert_eq!(DocMeta::parse_sets(&multi.metadata), ["alpha", "beta"]);
    }

    #[test]
    fn parse_sets_trims_and_drops_empties() {
        let messy = doc("c", "x").with_metadata(DocMeta::DOCUMENT_SET_KEY, " alpha , , beta ,");
        assert_eq!(DocMeta::parse_sets(&messy.metadata), ["alpha", "beta"]);
    }

    #[test]
    fn parse_sets_absent_is_empty() {
        let untagged = doc("d", "x");
        assert_eq!(
            DocMeta::parse_sets(&untagged.metadata),
            Vec::<String>::new()
        );
    }

    #[test]
    fn parse_boost_default_when_absent() {
        let unboosted = doc("e", "x");
        assert_eq!(DocMeta::parse_boost(&unboosted.metadata), DEFAULT_BOOST);
    }

    #[test]
    fn parse_boost_parses_valid() {
        let boosted = with_boost(doc("f", "x"), 3.0);
        let delta = (DocMeta::parse_boost(&boosted.metadata) - 3.0).abs();
        assert!(delta < f32::EPSILON);
    }

    #[test]
    fn parse_boost_malformed_falls_back_to_default() {
        let malformed = ["abc", "", " ", "NaN", "inf", "1.2.3"];
        for bad in malformed {
            let broken = doc("g", "x").with_metadata(DocMeta::BOOST_KEY, bad);
            assert_eq!(
                DocMeta::parse_boost(&broken.metadata),
                DEFAULT_BOOST,
                "malformed boost {bad:?} must fall back to default"
            );
        }
    }

    #[test]
    fn parse_boost_negative_is_clamped_to_zero() {
        let negative = doc("h", "x").with_metadata(DocMeta::BOOST_KEY, "-2.0");
        assert_eq!(DocMeta::parse_boost(&negative.metadata), 0.0);
    }

    #[test]
    fn with_boost_normalizes_non_finite() {
        for (id, raw) in [("i", f32::NAN), ("j", f32::INFINITY)] {
            let normalized = with_boost(doc(id, "x"), raw);
            assert_eq!(DocMeta::parse_boost(&normalized.metadata), DEFAULT_BOOST);
        }
    }

    /// Builds a `DocMeta` directly from set names, a boost and key/value pairs.
    fn meta(sets: &[&str], boost: f32, kv: &[(&str, &str)]) -> DocMeta {
        DocMeta {
            document_sets: sets.iter().map(|name| (*name).to_owned()).collect(),
            boost,
            metadata: kv
                .iter()
                .map(|(key, value)| ((*key).to_owned(), (*value).to_owned()))
                .collect(),
        }
    }

    #[test]
    fn unconstrained_filter_matches_everything() {
        let open = RetrievalFilter::none();
        assert!(open.is_unconstrained());
        assert!(open.matches(&meta(&[], 1.0, &[])));
        assert!(open.matches(&meta(&["alpha"], 1.0, &[("kind", "code")])));
    }

    #[test]
    fn set_scope_matches_only_members() {
        let scoped = RetrievalFilter::in_sets(["alpha"]);
        assert!(!scoped.is_unconstrained());

        let members: [&[&str]; 2] = [&["alpha"], &["alpha", "beta"]];
        for sets in members {
            assert!(scoped.matches(&meta(sets, 1.0, &[])));
        }

        let outsiders: [&[&str]; 2] = [&["beta"], &[]];
        for sets in outsiders {
            assert!(!scoped.matches(&meta(sets, 1.0, &[])));
        }
    }

    #[test]
    fn set_scope_union_across_listed_sets() {
        let scoped = RetrievalFilter::in_sets(["alpha", "gamma"]);
        for member in ["gamma", "alpha"] {
            assert!(scoped.matches(&meta(&[member], 1.0, &[])));
        }
        assert!(!scoped.matches(&meta(&["beta"], 1.0, &[])));
    }

    #[test]
    fn empty_set_list_matches_nothing() {
        let closed = RetrievalFilter {
            document_sets: Some(Vec::new()),
            ..RetrievalFilter::none()
        };
        let candidates: [&[&str]; 2] = [&["alpha"], &[]];
        for sets in candidates {
            assert!(!closed.matches(&meta(sets, 1.0, &[])));
        }
    }

    #[test]
    fn metadata_eq_requires_all_equalities() {
        let required = RetrievalFilter::none()
            .with_metadata_eq("kind", "prose")
            .with_metadata_eq("lang", "en");

        let complete = meta(&[], 1.0, &[("kind", "prose"), ("lang", "en")]);
        assert!(required.matches(&complete));

        let missing_lang = meta(&[], 1.0, &[("kind", "prose")]);
        assert!(!required.matches(&missing_lang));

        let wrong_kind = meta(&[], 1.0, &[("kind", "code"), ("lang", "en")]);
        assert!(!required.matches(&wrong_kind));
    }

    #[test]
    fn set_and_metadata_compose_with_and() {
        let combined = RetrievalFilter::in_sets(["alpha"]).with_metadata_eq("kind", "prose");
        assert!(combined.matches(&meta(&["alpha"], 1.0, &[("kind", "prose")])));

        let rejected = [
            meta(&["alpha"], 1.0, &[("kind", "code")]),
            meta(&["beta"], 1.0, &[("kind", "prose")]),
        ];
        for candidate in &rejected {
            assert!(!combined.matches(candidate));
        }
    }

    #[test]
    fn filter_round_trips_through_json() {
        let original = RetrievalFilter::in_sets(["alpha", "beta"]).with_metadata_eq("kind", "prose");
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded: RetrievalFilter = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, original);
    }

    /// Creates a curated store over a fresh in-memory knowledge base.
    fn curated_store() -> CuratedKnowledgeStore {
        let backing = smooth_operator_core::InMemoryKnowledge::new();
        CuratedKnowledgeStore::new(Arc::new(backing))
    }

    #[test]
    fn reader_with_no_filter_returns_all_with_boost_applied() {
        let store = curated_store();
        let ingest = store.ingest_handle();
        let tagged = with_document_set(doc("a", "clearance alpha fact"), ["alpha"]);
        ingest.ingest(tagged).unwrap();
        ingest.ingest(doc("plain", "clearance plain fact")).unwrap();

        let reader = store.reader(RetrievalFilter::none(), AccessContext::anonymous());
        let hits = reader.query("clearance", 10).unwrap();
        for wanted in ["a", "plain"] {
            assert!(hits.iter().any(|hit| hit.document_id == wanted));
        }
    }

    #[test]
    fn malformed_boost_metadata_yields_default_boost_at_read() {
        let store = curated_store();
        let ingest = store.ingest_handle();
        let broken = doc("bad", "clearance fact").with_metadata(DocMeta::BOOST_KEY, "not-a-number");
        ingest.ingest(broken).unwrap();

        let reader = store.reader(RetrievalFilter::none(), AccessContext::anonymous());
        let hits = reader.query("clearance", 10).unwrap();
        let ids: Vec<&str> = hits.iter().map(|hit| hit.document_id.as_str()).collect();
        assert_eq!(ids, ["bad"]);
    }
}