use std::collections::HashMap;
use std::future::Future;
use std::sync::Mutex;
use crate::embedding::{EmbeddingError, EmbeddingModel};
use crate::memory::Scope;
use super::cosine::cosine_similarity;
/// Default inclusive cosine-similarity floor used by `EmbeddingEntityResolver::new`;
/// a candidate scoring below it is never treated as the same entity as the mention.
pub const MIN_ENTITY_SIMILARITY: f32 = 0.75_f32;
/// Outcome of resolving an entity mention against the candidates of one scope.
#[derive(Clone, Debug, PartialEq)]
pub enum Resolution {
    /// The mention matched a catalogued entity; `key` and `name` come from that entry.
    Existing { key: String, name: String },
    /// Nothing matched; `name` is the mention text exactly as it was supplied.
    New { name: String },
}
/// One catalogued entity paired with its stored embedding.
#[derive(Clone, Debug, PartialEq)]
pub struct EntityVector {
    /// Identifier of the entity, surfaced as `key` in `Resolution::Existing`.
    pub key: String,
    /// Name of the entity, surfaced as `name` in `Resolution::Existing`.
    pub name: String,
    /// Stored embedding, compared against the mention's embedding by cosine similarity.
    pub embedding: Vec<f32>,
}
/// Supplier of the candidate entities that a resolver may merge a mention into.
pub trait EntityCatalog: Send + Sync + 'static {
    /// Fetches every entity recorded for `scope`; an unknown scope yields no candidates.
    ///
    /// # Errors
    /// Implementations report backend read failures as a [`ResolveError`].
    fn candidates_in_scope(&self, scope: &Scope) -> impl Future<Output = Result<Vec<EntityVector>, ResolveError>> + Send;
}
/// Decides whether an entity mention refers to an already-catalogued entity or a new one.
pub trait EntityResolver: Send + Sync + 'static {
    /// Resolves the mention `entity` among the candidates of `scope`.
    ///
    /// # Errors
    /// Returns a [`ResolveError`] when candidates or embeddings cannot be obtained.
    fn resolve(
        &self,
        scope: &Scope,
        entity: &str,
    ) -> impl Future<Output = Result<Resolution, ResolveError>> + Send;
}
/// Failure modes of entity resolution.
#[derive(Debug, thiserror::Error)]
pub enum ResolveError {
    /// The candidate catalog could not be read; the payload describes why.
    #[error("entity catalog read failed: {0}")]
    Catalog(String),

    /// The embedding model failed; converted automatically from [`EmbeddingError`] by `?`.
    #[error("entity embedding failed: {0}")]
    Embed(#[from] EmbeddingError),
}
/// Resolver that reuses an existing entity only when its name is byte-for-byte
/// equal to the mention (so matching is case-sensitive).
pub struct ExactStringResolver<C> {
    /// Catalog queried for candidates on every call to `resolve`.
    catalog: C,
}
impl<C: EntityCatalog> ExactStringResolver<C> {
pub fn new(catalog: C) -> Self {
Self { catalog }
}
}
impl<C: EntityCatalog> EntityResolver for ExactStringResolver<C> {
async fn resolve(&self, scope: &Scope, entity: &str) -> Result<Resolution, ResolveError> {
let candidates = self.catalog.candidates_in_scope(scope).await?;
match candidates.into_iter().find(|candidate| candidate.name == entity) {
Some(matched) => Ok(Resolution::Existing {
key: matched.key,
name: matched.name,
}),
None => Ok(Resolution::New {
name: entity.to_string(),
}),
}
}
}
/// Resolver that merges a mention into the most cosine-similar catalogued entity,
/// provided the similarity clears a configurable threshold.
pub struct EmbeddingEntityResolver<E, C> {
    /// Model that embeds the incoming mention text.
    embedder: E,
    /// Catalog providing candidates together with their stored embeddings.
    catalog: C,
    /// Inclusive lower bound a cosine score must reach to count as a match.
    min_similarity: f32,
}
impl<E: EmbeddingModel, C: EntityCatalog> EmbeddingEntityResolver<E, C> {
pub fn new(embedder: E, catalog: C) -> Self {
Self {
embedder,
catalog,
min_similarity: MIN_ENTITY_SIMILARITY,
}
}
#[must_use]
pub fn with_min_similarity(mut self, min_similarity: f32) -> Self {
self.min_similarity = min_similarity;
self
}
}
impl<E: EmbeddingModel, C: EntityCatalog> EntityResolver for EmbeddingEntityResolver<E, C> {
    /// Resolves `entity` to the candidate in `scope` whose stored embedding is most
    /// cosine-similar to the mention's embedding, as long as that score is at least
    /// `min_similarity`; otherwise yields `Resolution::New` carrying `entity`.
    ///
    /// Candidates are fetched before the mention is embedded, so a scope with no
    /// candidates never triggers an embedding call. Candidates for which
    /// `cosine_similarity` produces no score are skipped. If several candidates
    /// tie for the best score, the last one in catalog order wins
    /// (`Iterator::max_by` semantics).
    ///
    /// # Errors
    /// Propagates catalog read errors, and returns `ResolveError::Embed` when the
    /// embedding model fails.
    async fn resolve(&self, scope: &Scope, entity: &str) -> Result<Resolution, ResolveError> {
        let candidates = self.catalog.candidates_in_scope(scope).await?;
        if candidates.is_empty() {
            // Nothing to compare against: skip the (potentially remote/paid) embedding call.
            return Ok(Resolution::New {
                name: entity.to_string(),
            });
        }

        let query = self.embedder.embed(entity).await?;
        let best = candidates
            .into_iter()
            .filter_map(|candidate| {
                cosine_similarity(&query, &candidate.embedding).map(|score| (score, candidate))
            })
            // `>=` makes the threshold inclusive; a NaN score or threshold never passes.
            .filter(|(score, _)| *score >= self.min_similarity)
            .max_by(|(a, _), (b, _)| a.total_cmp(b));

        match best {
            Some((_, matched)) => Ok(Resolution::Existing {
                key: matched.key,
                name: matched.name,
            }),
            None => Ok(Resolution::New {
                name: entity.to_string(),
            }),
        }
    }
}
/// Process-local [`EntityCatalog`] keeping candidate lists per scope behind a mutex.
#[derive(Default)]
pub struct InMemoryEntityCatalog {
    /// Candidates grouped by the scope they were inserted under, in insertion order.
    by_scope: Mutex<HashMap<Scope, Vec<EntityVector>>>,
}
impl InMemoryEntityCatalog {
pub fn new() -> Self {
Self::default()
}
pub fn insert(&self, scope: &Scope, entity: EntityVector) {
self.by_scope
.lock()
.expect("entity catalog mutex poisoned")
.entry(scope.clone())
.or_default()
.push(entity);
}
}
impl EntityCatalog for InMemoryEntityCatalog {
async fn candidates_in_scope(&self, scope: &Scope) -> Result<Vec<EntityVector>, ResolveError> {
Ok(self
.by_scope
.lock()
.expect("entity catalog mutex poisoned")
.get(scope)
.cloned()
.unwrap_or_default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn scope(user: &str) -> Scope {
        Scope {
            agent_id: "agent".to_string(),
            org_id: "org".to_string(),
            user_id: user.to_string(),
        }
    }

    /// Deterministic embedder: text starting with "Alice" maps to the x axis,
    /// exactly "Bob" maps to the y axis, anything else to the z axis.
    struct FakeEmbedding;

    impl EmbeddingModel for FakeEmbedding {
        async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
            let vector = if text.starts_with("Alice") {
                vec![1.0, 0.0, 0.0]
            } else if text == "Bob" {
                vec![0.0, 1.0, 0.0]
            } else {
                vec![0.0, 0.0, 1.0]
            };
            Ok(vector)
        }

        fn dimensions(&self) -> usize {
            3
        }
    }

    /// Builds a catalog whose scope "u" holds a single "Alice" entity on the x axis.
    // Synchronous on purpose: populating the in-memory catalog never awaits.
    fn catalog_with_alice() -> InMemoryEntityCatalog {
        let catalog = InMemoryEntityCatalog::new();
        catalog.insert(
            &scope("u"),
            EntityVector {
                key: "alice-node".to_string(),
                name: "Alice".to_string(),
                embedding: vec![1.0, 0.0, 0.0],
            },
        );
        catalog
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_match_exact_name_with_exact_resolver() {
        let resolver = ExactStringResolver::new(catalog_with_alice());
        let resolution = resolver.resolve(&scope("u"), "Alice").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::Existing {
                key: "alice-node".to_string(),
                name: "Alice".to_string(),
            }
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_create_new_for_surface_variant_with_exact_resolver() {
        let resolver = ExactStringResolver::new(catalog_with_alice());
        let resolution = resolver.resolve(&scope("u"), "Alice Smith").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::New {
                name: "Alice Smith".to_string(),
            }
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_merge_surface_variant_with_embedding_resolver() {
        let resolver = EmbeddingEntityResolver::new(FakeEmbedding, catalog_with_alice());
        let resolution = resolver.resolve(&scope("u"), "Alice Smith").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::Existing {
                key: "alice-node".to_string(),
                name: "Alice".to_string(),
            }
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_create_new_for_dissimilar_entity_with_embedding_resolver() {
        let resolver = EmbeddingEntityResolver::new(FakeEmbedding, catalog_with_alice());
        let resolution = resolver.resolve(&scope("u"), "Bob").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::New {
                name: "Bob".to_string()
            }
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_not_merge_across_scopes() {
        let resolver = EmbeddingEntityResolver::new(FakeEmbedding, catalog_with_alice());
        let resolution = resolver.resolve(&scope("other"), "Alice").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::New {
                name: "Alice".to_string()
            }
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_create_new_when_custom_threshold_is_unreachable() {
        // Cosine similarity never exceeds 1.0, so a 1.5 floor rejects every candidate.
        let resolver = EmbeddingEntityResolver::new(FakeEmbedding, catalog_with_alice())
            .with_min_similarity(1.5);
        let resolution = resolver.resolve(&scope("u"), "Alice Smith").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::New {
                name: "Alice Smith".to_string()
            }
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_pick_most_similar_candidate_with_embedding_resolver() {
        let catalog = catalog_with_alice();
        // Scores ~0.8 against the x axis: above the default threshold, below Alice's 1.0.
        catalog.insert(
            &scope("u"),
            EntityVector {
                key: "alicia-node".to_string(),
                name: "Alicia".to_string(),
                embedding: vec![0.8, 0.6, 0.0],
            },
        );
        let resolver = EmbeddingEntityResolver::new(FakeEmbedding, catalog);
        let resolution = resolver.resolve(&scope("u"), "Alice").await.unwrap();
        assert_eq!(
            resolution,
            Resolution::Existing {
                key: "alice-node".to_string(),
                name: "Alice".to_string(),
            }
        );
    }
}