use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::path::Path;
use std::sync::Arc;
use arc_swap::ArcSwap;
use lru::LruCache;
use serde::Deserialize;
use tokio::sync::Mutex;
use crate::error::MemoryError;
/// How many values a predicate is declared to carry in the ontology.
///
/// `OntologyState::build` selects `One` only when the configured cardinality
/// string is exactly `"1"`; every other value (or a missing one) becomes `Many`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Cardinality {
/// Predicate configured with cardinality `"1"`.
One,
/// Fallback variant; also what `Default::default()` returns.
#[default]
Many,
}
/// Immutable lookup tables derived from a list of predicate definitions.
///
/// All keys are stored in the form produced by `normalize`.
#[derive(Debug, Default)]
struct OntologyState {
// Normalized alias (or canonical name) -> normalized canonical name.
// Every canonical name also maps to itself.
alias_to_canonical: HashMap<String, String>,
// Normalized canonical name -> declared cardinality.
cardinality: HashMap<String, Cardinality>,
}
impl OntologyState {
fn build(predicates: &[PredicateToml]) -> Self {
let mut alias_to_canonical = HashMap::new();
let mut cardinality = HashMap::new();
for entry in predicates {
let canonical = normalize(&entry.canonical);
let card = match entry.cardinality.as_deref() {
Some("1") => Cardinality::One,
_ => Cardinality::Many,
};
alias_to_canonical.insert(canonical.clone(), canonical.clone());
cardinality.insert(canonical.clone(), card);
for alias in &entry.aliases {
alias_to_canonical.insert(normalize(alias), canonical.clone());
}
}
Self {
alias_to_canonical,
cardinality,
}
}
}
/// Predicate resolution table: swappable ontology state plus an LRU cache of
/// previously resolved predicates.
pub struct OntologyTable {
// Current lookup tables; replaced wholesale by `reload`.
state: ArcSwap<OntologyState>,
// Normalized predicate -> canonical name, filled by `resolve` and cleared by `reload`.
cache: Mutex<LruCache<String, String>>,
// Cache size as passed to the constructor (the LRU itself is created with at least 1 slot).
cache_max: usize,
}
impl std::fmt::Debug for OntologyTable {
    /// Formats the table without touching the state or the cache: both are
    /// rendered as fixed placeholder strings, only `cache_max` is printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OntologyTable");
        out.field("state", &"<ArcSwap<OntologyState>>");
        out.field("cache", &"<Mutex<LruCache>>");
        out.field("cache_max", &self.cache_max);
        out.finish()
    }
}
impl OntologyTable {
fn new_with_state(state: OntologyState, cache_max: usize) -> Self {
let cap = NonZeroUsize::new(cache_max.max(1)).expect("cache_max >= 1");
Self {
state: ArcSwap::new(Arc::new(state)),
cache: Mutex::new(LruCache::new(cap)),
cache_max,
}
}
#[must_use]
pub fn from_default(cache_max: usize) -> Self {
let state = OntologyState::build(default_predicates());
Self::new_with_state(state, cache_max)
}
pub async fn from_path(path: &Path, cache_max: usize) -> Result<Self, MemoryError> {
let predicates = if path.as_os_str().is_empty() {
default_predicates().to_vec()
} else {
load_toml_file(path).await?
};
let state = OntologyState::build(&predicates);
Ok(Self::new_with_state(state, cache_max))
}
pub async fn reload(&self, path: &Path) -> Result<(), MemoryError> {
let predicates = if path.as_os_str().is_empty() {
default_predicates().to_vec()
} else {
load_toml_file(path).await?
};
let new_state = Arc::new(OntologyState::build(&predicates));
let mut cache = self.cache.lock().await;
cache.clear();
self.state.store(new_state);
Ok(())
}
pub async fn resolve(&self, raw_predicate: &str) -> (String, bool) {
let key = normalize(raw_predicate);
tracing::debug!(target: "memory.graph.apex.ontology_resolve", predicate = raw_predicate);
{
let mut cache = self.cache.lock().await;
if let Some(canonical) = cache.get(&key) {
return (canonical.clone(), false);
}
}
let state = self.state.load();
if let Some(canonical) = state.alias_to_canonical.get(&key) {
let canonical = canonical.clone();
let mut cache = self.cache.lock().await;
cache.put(key, canonical.clone());
return (canonical, false);
}
let canonical = key.clone();
let mut cache = self.cache.lock().await;
cache.put(key, canonical.clone());
(canonical, true)
}
#[must_use]
pub fn cardinality(&self, canonical_predicate: &str) -> Cardinality {
let key = normalize(canonical_predicate);
self.state
.load()
.cardinality
.get(&key)
.copied()
.unwrap_or_default()
}
}
/// Produces the lookup key for a predicate: control characters are removed,
/// surrounding whitespace is trimmed and the result is lowercased.
///
/// Control characters are stripped *before* trimming so that whitespace sitting
/// next to a leading or trailing control character (e.g. `"\0 works_at"`) does
/// not survive in the key.
pub(crate) fn normalize(s: &str) -> String {
    let cleaned: String = s.chars().filter(|c| !c.is_control()).collect();
    cleaned.trim().to_lowercase()
}
/// Top-level shape of the ontology TOML document.
#[derive(Debug, Clone, Deserialize)]
struct OntologyToml {
// Read from the `predicate` key of the document (e.g. `[[predicate]]` tables).
#[serde(rename = "predicate")]
predicates: Vec<PredicateToml>,
}
/// One predicate definition as written in the ontology TOML.
#[derive(Debug, Clone, Deserialize)]
struct PredicateToml {
// Canonical predicate name; normalized when the state is built.
canonical: String,
// Alternative spellings that resolve to `canonical`; empty when omitted.
#[serde(default)]
aliases: Vec<String>,
// Cardinality as a string, parsed by `de_cardinality`; `None` when omitted.
// Only the exact value "1" is treated as `Cardinality::One` by `OntologyState::build`.
#[serde(default, deserialize_with = "de_cardinality")]
cardinality: Option<String>,
}
fn de_cardinality<'de, D>(deserializer: D) -> Result<Option<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Visitor;
struct CardVisitor;
impl<'de> Visitor<'de> for CardVisitor {
type Value = Option<String>;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, r#"cardinality string "1" or "n", or integer 1"#)
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
Ok(Some(v.to_string()))
}
fn visit_u64<E: serde::de::Error>(self, v: u64) -> Result<Self::Value, E> {
Ok(Some(if v == 1 {
"1".to_string()
} else {
"n".to_string()
}))
}
fn visit_none<E: serde::de::Error>(self) -> Result<Self::Value, E> {
Ok(None)
}
fn visit_some<D2: serde::Deserializer<'de>>(self, d: D2) -> Result<Self::Value, D2::Error> {
d.deserialize_any(self)
}
}
deserializer.deserialize_option(CardVisitor)
}
/// Reads the file at `path` and parses it as an ontology TOML document,
/// returning its predicate list.
///
/// # Errors
///
/// Both read failures and parse failures are reported as
/// `MemoryError::InvalidInput` with a message naming the failing step.
async fn load_toml_file(path: &Path) -> Result<Vec<PredicateToml>, MemoryError> {
    let raw = match tokio::fs::read_to_string(path).await {
        Ok(text) => text,
        Err(e) => {
            return Err(MemoryError::InvalidInput(format!(
                "ontology TOML read error: {e}"
            )));
        }
    };
    match toml::from_str::<OntologyToml>(&raw) {
        Ok(document) => Ok(document.predicates),
        Err(e) => Err(MemoryError::InvalidInput(format!(
            "ontology TOML parse error: {e}"
        ))),
    }
}
/// Builds an owned `PredicateToml` from borrowed parts; the cardinality is
/// always stored as `Some(..)`.
fn make(canonical: &str, aliases: &[&str], cardinality: &str) -> PredicateToml {
    let owned_aliases: Vec<String> = aliases.iter().copied().map(String::from).collect();
    PredicateToml {
        canonical: String::from(canonical),
        aliases: owned_aliases,
        cardinality: Some(String::from(cardinality)),
    }
}
/// Returns the built-in predicate definitions.
///
/// The list is built once on first use and shared for the rest of the process.
fn default_predicates() -> &'static [PredicateToml] {
    use std::sync::OnceLock;

    // (canonical, aliases, cardinality) for every built-in predicate.
    const BUILTIN: &[(&str, &[&str], &str)] = &[
        ("works_at", &["employed_by", "job_at", "works_for"], "1"),
        ("lives_in", &["resides_in", "based_in"], "1"),
        ("born_in", &["birthplace", "born_at"], "1"),
        ("manages", &["manages_team", "leads", "supervises"], "1"),
        ("owns", &["has", "possesses"], "n"),
        ("depends_on", &["requires", "needs"], "n"),
        ("knows", &[], "n"),
    ];

    static DEFAULTS: OnceLock<Vec<PredicateToml>> = OnceLock::new();
    DEFAULTS
        .get_or_init(|| {
            BUILTIN
                .iter()
                .map(|&(canonical, aliases, cardinality)| make(canonical, aliases, cardinality))
                .collect()
        })
        .as_slice()
}
// Unit tests for predicate resolution, cardinality lookup, normalization,
// caching and reload, all run against the built-in default predicates.
#[cfg(test)]
mod tests {
use super::*;
// An alias listed under a predicate resolves to that predicate's canonical name.
#[tokio::test]
async fn resolves_alias_to_canonical() {
let table = OntologyTable::from_default(64);
let (canonical, unmapped) = table.resolve("employed_by").await;
assert_eq!(canonical, "works_at");
assert!(!unmapped);
}
// A canonical name resolves to itself and counts as mapped.
#[tokio::test]
async fn resolves_canonical_to_itself() {
let table = OntologyTable::from_default(64);
let (canonical, unmapped) = table.resolve("works_at").await;
assert_eq!(canonical, "works_at");
assert!(!unmapped);
}
// First lookup of an unknown predicate returns the normalized input with `unmapped == true`.
#[tokio::test]
async fn unknown_predicate_returns_raw_and_unmapped() {
let table = OntologyTable::from_default(64);
let (canonical, unmapped) = table.resolve("some_new_predicate").await;
assert_eq!(canonical, "some_new_predicate");
assert!(unmapped);
}
// Defaults declared with cardinality "1" report `Cardinality::One`.
#[tokio::test]
async fn cardinality_one_predicates() {
let table = OntologyTable::from_default(64);
assert_eq!(table.cardinality("works_at"), Cardinality::One);
assert_eq!(table.cardinality("lives_in"), Cardinality::One);
assert_eq!(table.cardinality("born_in"), Cardinality::One);
assert_eq!(table.cardinality("manages"), Cardinality::One);
}
// Defaults declared with "n", and predicates not in the table, report `Cardinality::Many`.
#[tokio::test]
async fn cardinality_many_predicates() {
let table = OntologyTable::from_default(64);
assert_eq!(table.cardinality("owns"), Cardinality::Many);
assert_eq!(table.cardinality("depends_on"), Cardinality::Many);
assert_eq!(table.cardinality("unknown_pred"), Cardinality::Many);
}
// `normalize` strips surrounding whitespace and lowercases.
#[tokio::test]
async fn normalize_trims_and_lowercases() {
assert_eq!(normalize(" Works_At "), "works_at");
assert_eq!(normalize("EMPLOYED_BY"), "employed_by");
}
// Resolving the same alias twice (second call served from the cache) gives the same answer.
#[tokio::test]
async fn cache_hit_on_second_resolve() {
let table = OntologyTable::from_default(64);
let (c1, _) = table.resolve("job_at").await;
let (c2, _) = table.resolve("job_at").await;
assert_eq!(c1, c2);
assert_eq!(c1, "works_at");
}
// Reloading with an empty path rebuilds from the defaults; resolution still works afterwards.
#[tokio::test]
async fn reload_clears_cache_and_preserves_resolution() {
let table = OntologyTable::from_default(64);
let _ = table.resolve("job_at").await;
table.reload(Path::new("")).await.unwrap();
let (canonical, _) = table.resolve("job_at").await;
assert_eq!(canonical, "works_at");
}
}