use std::sync::Arc;
use hirn_core::HirnResult;
use hirn_core::id::MemoryId;
use hirn_storage::PhysicalStore;
use hirn_storage::store::VectorSearchOptions;
use crate::admission::{AdmissionController, AdmissionDecision, MemoryCandidate};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DuplicateAction {
Reject,
Merge,
}
pub struct DuplicateDetector {
storage: Arc<dyn PhysicalStore>,
dataset: String,
threshold: f32,
action: DuplicateAction,
}
impl DuplicateDetector {
pub fn new(
storage: Arc<dyn PhysicalStore>,
dataset: impl Into<String>,
threshold: f32,
action: DuplicateAction,
) -> Self {
Self {
storage,
dataset: dataset.into(),
threshold,
action,
}
}
pub fn with_defaults(storage: Arc<dyn PhysicalStore>, dataset: impl Into<String>) -> Self {
Self::new(storage, dataset, 0.05, DuplicateAction::Reject)
}
}
#[async_trait::async_trait]
impl AdmissionController for DuplicateDetector {
fn name(&self) -> &str {
"duplicate_detector"
}
async fn evaluate(&self, candidate: &MemoryCandidate) -> HirnResult<AdmissionDecision> {
let embedding = match &candidate.embedding {
Some(emb) => emb,
None => {
return Ok(AdmissionDecision::Accept {
importance_override: None,
});
}
};
let exists = self
.storage
.exists(&self.dataset)
.await
.map_err(hirn_core::HirnError::storage)?;
if !exists {
return Ok(AdmissionDecision::Accept {
importance_override: None,
});
}
let options = VectorSearchOptions {
query: embedding.clone(),
column: "embedding".into(),
limit: 1,
..Default::default()
};
let batches = self
.storage
.vector_search(&self.dataset, options)
.await
.map_err(hirn_core::HirnError::storage)?;
match extract_nearest(&batches) {
None => Ok(AdmissionDecision::Accept {
importance_override: None,
}),
Some((distance, target_id)) => {
if distance <= self.threshold {
match self.action {
DuplicateAction::Reject => Ok(AdmissionDecision::Reject {
reason: format!("duplicate of {target_id} (distance {distance:.4})"),
}),
DuplicateAction::Merge => {
Ok(AdmissionDecision::Merge { target: target_id })
}
}
} else {
Ok(AdmissionDecision::Accept {
importance_override: None,
})
}
}
}
}
}
fn extract_nearest(batches: &[arrow_array::RecordBatch]) -> Option<(f32, MemoryId)> {
for batch in batches {
let distance = extract_distance(batch)?;
let id = extract_id(batch)?;
return Some((distance, id));
}
None
}
fn extract_distance(batch: &arrow_array::RecordBatch) -> Option<f32> {
use arrow_array::Array;
let col = batch.column_by_name("_distance")?;
if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float32Array>() {
if arr.len() > 0 {
return Some(arr.value(0));
}
}
if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
if arr.len() > 0 {
return Some(arr.value(0) as f32);
}
}
None
}
fn extract_id(batch: &arrow_array::RecordBatch) -> Option<MemoryId> {
use arrow_array::Array;
let col = batch.column_by_name("id")?;
if let Some(arr) = col.as_any().downcast_ref::<arrow_array::StringArray>() {
if arr.len() > 0 {
return MemoryId::parse(arr.value(0)).ok();
}
}
if let Some(arr) = col.as_any().downcast_ref::<arrow_array::LargeStringArray>() {
if arr.len() > 0 {
return MemoryId::parse(arr.value(0)).ok();
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use hirn_core::metadata::Metadata;
use hirn_core::types::{AgentId, Namespace};
use hirn_storage::{HirnDb, HirnDbConfig};
fn candidate(embedding: Option<Vec<f32>>) -> MemoryCandidate {
MemoryCandidate {
id: MemoryId::new(),
content: "test content".into(),
entities: vec![],
embedding,
agent_id: AgentId::new("test").unwrap(),
namespace: Namespace::shared(),
importance: 0.5,
surprise: 0.5,
metadata: Metadata::default(),
}
}
fn rand_vec(seed: u128) -> Vec<f32> {
(0..32)
.map(|i| (seed as f64 * 0.618_033 + i as f64 * 0.414_213).sin() as f32)
.collect()
}
async fn temp_storage() -> (Arc<dyn PhysicalStore>, tempfile::TempDir) {
let dir = tempfile::tempdir().unwrap();
let lance_path = dir.path().join("lance");
let config = HirnDbConfig::local(lance_path.to_str().unwrap());
let backend = HirnDb::open(config.clone()).await.unwrap();
(backend.store_arc(), dir)
}
async fn insert_record(storage: &Arc<dyn PhysicalStore>, emb: Vec<f32>) -> MemoryId {
let rec = hirn_core::episodic::EpisodicRecord::builder()
.content("existing memory")
.embedding(emb)
.agent_id(AgentId::new("test").unwrap())
.build()
.unwrap();
let id = rec.id;
let batch =
hirn_storage::datasets::episodic::to_batch(std::slice::from_ref(&rec), 32).unwrap();
storage.append("episodic", batch).await.unwrap();
id
}
#[tokio::test]
async fn no_embedding_accepts() {
let (storage, _dir) = temp_storage().await;
let det = DuplicateDetector::with_defaults(storage, "episodic");
let result = det.evaluate(&candidate(None)).await.unwrap();
assert!(result.is_accept());
}
#[tokio::test]
async fn empty_database_accepts() {
let (storage, _dir) = temp_storage().await;
let det = DuplicateDetector::with_defaults(storage, "episodic");
let result = det.evaluate(&candidate(Some(rand_vec(1)))).await.unwrap();
assert!(result.is_accept());
}
#[tokio::test]
async fn exact_duplicate_rejected() {
let (storage, _dir) = temp_storage().await;
let emb = rand_vec(42);
insert_record(&storage, emb.clone()).await;
let det = DuplicateDetector::new(storage, "episodic", 0.05, DuplicateAction::Reject);
let result = det.evaluate(&candidate(Some(emb))).await.unwrap();
assert!(result.is_reject());
}
#[tokio::test]
async fn exact_duplicate_merged() {
let (storage, _dir) = temp_storage().await;
let emb = rand_vec(42);
let target_id = insert_record(&storage, emb.clone()).await;
let det = DuplicateDetector::new(storage, "episodic", 0.05, DuplicateAction::Merge);
let result = det.evaluate(&candidate(Some(emb))).await.unwrap();
if let AdmissionDecision::Merge { target } = result {
assert_eq!(target, target_id);
} else {
panic!("expected Merge decision, got: {:?}", result);
}
}
#[tokio::test]
async fn distinct_memory_accepted() {
let (storage, _dir) = temp_storage().await;
let emb1 = rand_vec(1);
insert_record(&storage, emb1).await;
let emb2 = rand_vec(999);
let det = DuplicateDetector::with_defaults(storage, "episodic");
let result = det.evaluate(&candidate(Some(emb2))).await.unwrap();
assert!(result.is_accept());
}
#[tokio::test]
async fn configurable_threshold() {
let (storage, _dir) = temp_storage().await;
let emb = rand_vec(42);
insert_record(&storage, emb.clone()).await;
let det = DuplicateDetector::new(storage, "episodic", 100.0, DuplicateAction::Reject);
let result = det.evaluate(&candidate(Some(emb))).await.unwrap();
assert!(result.is_reject());
}
#[tokio::test]
async fn near_duplicate_detected() {
let (storage, _dir) = temp_storage().await;
let emb = rand_vec(42);
insert_record(&storage, emb.clone()).await;
let near_dup: Vec<f32> = emb
.iter()
.enumerate()
.map(|(i, &x)| x + (i as f32 * 0.0001))
.collect();
let det = DuplicateDetector::with_defaults(storage, "episodic");
let result = det.evaluate(&candidate(Some(near_dup))).await.unwrap();
assert!(
result.is_reject(),
"near-duplicate should be detected as duplicate"
);
}
}