use crate::Confidence;
use std::collections::HashMap;
/// A set of label definitions paired with one dense embedding row per label.
///
/// Registries produced by [`SemanticRegistry::standard_ner`] and
/// [`SemanticRegistryBuilder::build_zero`] satisfy the following. Because every field is
/// public, these are not enforced afterwards:
///
/// - `embeddings.len() == labels.len() * hidden_dim`, laid out row-major, with row `i`
///   (`i * hidden_dim .. (i + 1) * hidden_dim`) belonging to `labels[i]`.
/// - `label_index[slug] == i` for the last `labels[i]` whose `slug` equals `slug`.
#[derive(Clone, Debug)]
pub struct SemanticRegistry {
    /// Flattened row-major embedding table; see the type-level docs for the layout.
    pub embeddings: Vec<f32>,
    /// Number of `f32` values in each embedding row.
    pub hidden_dim: usize,
    /// Label definitions in insertion order; a label's position is its embedding row.
    pub labels: Vec<LabelDefinition>,
    /// Maps a label slug to its position in `labels`.
    pub label_index: HashMap<String, usize>,
}
/// Describes a single label: its identifier, a human-readable description, and metadata.
#[derive(Clone, Debug)]
pub struct LabelDefinition {
    /// Identifier used as the key in `SemanticRegistry::label_index`.
    /// Expected to be unique within a registry, but this is not enforced.
    pub slug: String,
    /// Natural-language description of what the label denotes.
    pub description: String,
    /// Whether the label is an entity, relation, or attribute.
    pub category: LabelCategory,
    /// Modality hint for the label (presumably which inputs it applies to).
    pub modality: ModalityHint,
    /// Per-label confidence threshold; how it is applied is not shown in this module.
    pub threshold: Confidence,
}
/// The kind of thing a `LabelDefinition` names.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum LabelCategory {
    /// An entity label, such as `person` or `organization`.
    Entity,
    /// A relation label, such as `CEO_OF`.
    Relation,
    /// An attribute label.
    Attribute,
}
/// Hint describing which input modalities a label is meant for.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum ModalityHint {
    /// Text input only; this is the `Default` variant.
    #[default]
    TextOnly,
    /// No particular modality (presumably usable with any input kind).
    Any,
}
impl SemanticRegistry {
pub fn builder() -> SemanticRegistryBuilder {
SemanticRegistryBuilder::new()
}
pub fn len(&self) -> usize {
self.labels.len()
}
pub fn is_empty(&self) -> bool {
self.labels.is_empty()
}
pub fn get_embedding(&self, slug: &str) -> Option<&[f32]> {
let idx = self.label_index.get(slug)?;
let start = idx * self.hidden_dim;
let end = start + self.hidden_dim;
if end <= self.embeddings.len() {
Some(&self.embeddings[start..end])
} else {
None
}
}
pub fn entity_labels(&self) -> impl Iterator<Item = &LabelDefinition> {
self.labels
.iter()
.filter(|l| l.category == LabelCategory::Entity)
}
pub fn relation_labels(&self) -> impl Iterator<Item = &LabelDefinition> {
self.labels
.iter()
.filter(|l| l.category == LabelCategory::Relation)
}
pub fn standard_ner(hidden_dim: usize) -> Self {
let labels = vec![
LabelDefinition {
slug: "person".into(),
description: "A named individual human being".into(),
category: LabelCategory::Entity,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
},
LabelDefinition {
slug: "organization".into(),
description: "A company, institution, agency, or other group".into(),
category: LabelCategory::Entity,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
},
LabelDefinition {
slug: "location".into(),
description: "A geographical place, city, country, or region".into(),
category: LabelCategory::Entity,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
},
LabelDefinition {
slug: "date".into(),
description: "A calendar date or time expression".into(),
category: LabelCategory::Entity,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
},
LabelDefinition {
slug: "money".into(),
description: "A monetary amount with currency".into(),
category: LabelCategory::Entity,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
},
];
let num_labels = labels.len();
let label_index: HashMap<String, usize> = labels
.iter()
.enumerate()
.map(|(i, l)| (l.slug.clone(), i))
.collect();
let embeddings = vec![0.0f32; num_labels * hidden_dim];
Self {
embeddings,
hidden_dim,
labels,
label_index,
}
}
}
/// Incrementally collects label definitions and turns them into a `SemanticRegistry`.
///
/// Labels keep their insertion order, which later determines their embedding row.
#[derive(Default, Debug)]
pub struct SemanticRegistryBuilder {
    /// Labels added so far, in insertion order.
    labels: Vec<LabelDefinition>,
}
impl SemanticRegistryBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn add_entity(mut self, slug: &str, description: &str) -> Self {
self.labels.push(LabelDefinition {
slug: slug.into(),
description: description.into(),
category: LabelCategory::Entity,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
});
self
}
pub fn add_relation(mut self, slug: &str, description: &str) -> Self {
self.labels.push(LabelDefinition {
slug: slug.into(),
description: description.into(),
category: LabelCategory::Relation,
modality: ModalityHint::TextOnly,
threshold: Confidence::new(0.5),
});
self
}
pub fn add_label(mut self, label: LabelDefinition) -> Self {
self.labels.push(label);
self
}
pub fn build_zero(self, hidden_dim: usize) -> SemanticRegistry {
let num_labels = self.labels.len();
let label_index: HashMap<String, usize> = self
.labels
.iter()
.enumerate()
.map(|(i, l)| (l.slug.clone(), i))
.collect();
SemanticRegistry {
embeddings: vec![0.0f32; num_labels * hidden_dim],
hidden_dim,
labels: self.labels,
label_index,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn standard_ner_has_five_labels() {
        let reg = SemanticRegistry::standard_ner(4);
        assert_eq!(reg.len(), 5);
        assert!(!reg.is_empty());
    }
    #[test]
    fn standard_ner_label_index_consistent() {
        let reg = SemanticRegistry::standard_ner(8);
        for (i, label) in reg.labels.iter().enumerate() {
            assert_eq!(
                reg.label_index.get(&label.slug),
                Some(&i),
                "label_index[{slug}] should map to {i}",
                slug = label.slug,
            );
        }
    }
    #[test]
    fn standard_ner_embedding_dimensions() {
        let dim = 16;
        let reg = SemanticRegistry::standard_ner(dim);
        assert_eq!(reg.hidden_dim, dim);
        assert_eq!(reg.embeddings.len(), reg.len() * dim);
    }
    #[test]
    fn standard_ner_all_entity_category() {
        let reg = SemanticRegistry::standard_ner(4);
        // Was `®.labels` (mis-encoded `&reg`), which did not compile.
        for label in &reg.labels {
            assert_eq!(label.category, LabelCategory::Entity);
        }
    }
    #[test]
    fn get_embedding_returns_correct_slice() {
        let dim = 3;
        let mut reg = SemanticRegistry::standard_ner(dim);
        let idx = reg.label_index["organization"];
        let start = idx * dim;
        reg.embeddings[start] = 1.0;
        reg.embeddings[start + 1] = 2.0;
        reg.embeddings[start + 2] = 3.0;
        let emb = reg.get_embedding("organization").unwrap();
        assert_eq!(emb, &[1.0, 2.0, 3.0]);
    }
    #[test]
    fn get_embedding_returns_none_for_unknown() {
        let reg = SemanticRegistry::standard_ner(4);
        assert!(reg.get_embedding("nonexistent").is_none());
    }
    #[test]
    fn get_embedding_zero_dim_is_empty_slice() {
        let reg = SemanticRegistry::standard_ner(0);
        assert_eq!(reg.get_embedding("person").map(<[f32]>::len), Some(0));
    }
    #[test]
    fn get_embedding_none_when_row_out_of_bounds() {
        let mut reg = SemanticRegistry::standard_ner(4);
        // Keep only the first row: `person` (index 0) fits, `organization` (index 1) does not.
        reg.embeddings.truncate(4);
        assert_eq!(reg.get_embedding("person").map(<[f32]>::len), Some(4));
        assert!(reg.get_embedding("organization").is_none());
    }
    #[test]
    fn entity_and_relation_iterators() {
        let reg = SemanticRegistry::builder()
            .add_entity("person", "a human")
            .add_relation("CEO_OF", "chief executive of")
            .add_entity("org", "an organization")
            .build_zero(4);
        let entities: Vec<_> = reg.entity_labels().collect();
        let relations: Vec<_> = reg.relation_labels().collect();
        assert_eq!(entities.len(), 2);
        assert_eq!(relations.len(), 1);
        assert_eq!(relations[0].slug, "CEO_OF");
    }
    #[test]
    fn builder_empty_produces_empty_registry() {
        let reg = SemanticRegistryBuilder::new().build_zero(8);
        assert!(reg.is_empty());
        assert_eq!(reg.len(), 0);
        assert_eq!(reg.embeddings.len(), 0);
    }
    #[test]
    fn builder_add_label_custom() {
        let label = LabelDefinition {
            slug: "drug".into(),
            description: "a pharmaceutical compound".into(),
            category: LabelCategory::Entity,
            modality: ModalityHint::Any,
            threshold: Confidence::new(0.3),
        };
        let reg = SemanticRegistry::builder().add_label(label).build_zero(2);
        assert_eq!(reg.len(), 1);
        assert_eq!(reg.labels[0].modality, ModalityHint::Any);
        assert!((reg.labels[0].threshold.value() - 0.3).abs() < f64::EPSILON);
    }
    #[test]
    fn builder_embeddings_are_zeroed() {
        let reg = SemanticRegistry::builder()
            .add_entity("a", "desc a")
            .add_entity("b", "desc b")
            .build_zero(4);
        assert!(reg.embeddings.iter().all(|&v| v == 0.0));
    }
    #[test]
    fn builder_preserves_insertion_order() {
        let reg = SemanticRegistry::builder()
            .add_entity("alpha", "first")
            .add_relation("BETA", "second")
            .add_entity("gamma", "third")
            .build_zero(2);
        let slugs: Vec<&str> = reg.labels.iter().map(|l| l.slug.as_str()).collect();
        assert_eq!(slugs, vec!["alpha", "BETA", "gamma"]);
    }
    #[test]
    fn duplicate_slug_index_points_at_last_occurrence() {
        let reg = SemanticRegistry::builder()
            .add_entity("x", "first")
            .add_relation("x", "second")
            .build_zero(2);
        assert_eq!(reg.len(), 2);
        assert_eq!(reg.label_index["x"], 1);
    }
    #[test]
    fn modality_hint_default_is_text_only() {
        assert_eq!(ModalityHint::default(), ModalityHint::TextOnly);
    }
    #[test]
    fn label_category_equality() {
        assert_eq!(LabelCategory::Entity, LabelCategory::Entity);
        assert_ne!(LabelCategory::Entity, LabelCategory::Relation);
        assert_ne!(LabelCategory::Relation, LabelCategory::Attribute);
    }
}