use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::RwLock;
use cognis_core::documents::Document;
use cognis_core::error::Result;
use cognis_core::vectorstores::base::VectorStore;
use crate::text_splitter::TextSplitter;
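/// Tracks which document keys (content hashes here) have already been
/// indexed, so that repeated passes can skip unchanged documents and
/// detect stale ones. Implementations must be shareable across async tasks.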
#[async_trait]
pub trait RecordManager: Send + Sync {
async fn exists(&self, keys: &[String]) -> Result<Vec<bool>>;
async fn update(&self, keys: &[String], group_ids: Option<&[String]>) -> Result<()>;
async fn delete_keys(&self, keys: &[String]) -> Result<()>;
async fn list_keys(&self, group_id: Option<&str>) -> Result<Vec<String>>;
}
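/// A [`RecordManager`] backed by an in-process `HashMap` mapping each key
/// to an optional group id. Intended for tests and single-process use;
/// records do not survive a restart.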
pub struct InMemoryRecordManager {
records: RwLock<HashMap<String, Option<String>>>,
}
impl InMemoryRecordManager {
pub fn new() -> Self {
Self {
records: RwLock::new(HashMap::new()),
}
}
}
impl Default for InMemoryRecordManager {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl RecordManager for InMemoryRecordManager {
async fn exists(&self, keys: &[String]) -> Result<Vec<bool>> {
let records = self.records.read().await;
Ok(keys.iter().map(|k| records.contains_key(k)).collect())
}
async fn update(&self, keys: &[String], group_ids: Option<&[String]>) -> Result<()> {
let mut records = self.records.write().await;
for (i, key) in keys.iter().enumerate() {
let group_id = group_ids.and_then(|g| g.get(i).cloned());
records.insert(key.clone(), group_id);
}
Ok(())
}
async fn delete_keys(&self, keys: &[String]) -> Result<()> {
let mut records = self.records.write().await;
for key in keys {
records.remove(key);
}
Ok(())
}
async fn list_keys(&self, group_id: Option<&str>) -> Result<Vec<String>> {
let records = self.records.read().await;
let keys = match group_id {
Some(gid) => records
.iter()
.filter(|(_, v)| v.as_deref() == Some(gid))
.map(|(k, _)| k.clone())
.collect(),
None => records.keys().cloned().collect(),
};
Ok(keys)
}
}
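/// How the pipeline treats previously indexed documents.
///
/// - `None`: add new documents and skip known ones; never delete.
/// - `Incremental`: like `None`, but also delete records whose keys are
///   absent from the current batch.
/// - `Full`: wipe every previously indexed record, then re-add the whole
///   batch.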
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum CleanupMode {
#[default]
None,
Incremental,
Full,
}
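/// Counters reported by a single [`IndexingPipeline::index`] pass.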
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IndexingResult {
pub num_added: usize,
pub num_skipped: usize,
pub num_deleted: usize,
}
pub fn content_hash(doc: &Document) -> String {
let mut hasher = DefaultHasher::new();
doc.page_content.hash(&mut hasher);
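// HashMap iteration order is unspecified, so sort the metadata keys to
// keep the hash independent of insertion order.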
let mut keys: Vec<&String> = doc.metadata.keys().collect();
keys.sort();
for key in keys {
key.hash(&mut hasher);
let val = &doc.metadata[key];
val.to_string().hash(&mut hasher);
}
format!("{:016x}", hasher.finish())
}
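/// Indexes documents into a [`VectorStore`], optionally splitting them
/// first and deduplicating / cleaning up via a [`RecordManager`].
///
/// A minimal usage sketch (mirroring the tests below; `store` and `docs`
/// are assumed to exist):
///
/// ```ignore
/// let pipeline = IndexingPipeline::new(store)
///     .with_record_manager(Box::new(InMemoryRecordManager::new()))
///     .with_cleanup_mode(CleanupMode::Incremental);
/// let result = pipeline.index(docs).await?;
/// println!(
///     "added {}, skipped {}, deleted {}",
///     result.num_added, result.num_skipped, result.num_deleted
/// );
/// ```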
pub struct IndexingPipeline {
pub text_splitter: Option<Box<dyn TextSplitter>>,
pub vectorstore: Arc<dyn VectorStore>,
pub record_manager: Option<Box<dyn RecordManager>>,
pub cleanup_mode: CleanupMode,
}
impl IndexingPipeline {
pub fn new(vectorstore: Arc<dyn VectorStore>) -> Self {
Self {
text_splitter: None,
vectorstore,
record_manager: None,
cleanup_mode: CleanupMode::None,
}
}
pub fn with_text_splitter(mut self, splitter: Box<dyn TextSplitter>) -> Self {
self.text_splitter = Some(splitter);
self
}
pub fn with_record_manager(mut self, manager: Box<dyn RecordManager>) -> Self {
self.record_manager = Some(manager);
self
}
pub fn with_cleanup_mode(mut self, mode: CleanupMode) -> Self {
self.cleanup_mode = mode;
self
}
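/// Runs one indexing pass over `documents` and reports how many were
/// added, skipped as already-indexed duplicates, or deleted as stale.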
pub async fn index(&self, documents: Vec<Document>) -> Result<IndexingResult> {
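// Split documents into chunks when a text splitter is configured.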
let docs = match &self.text_splitter {
Some(splitter) => splitter.split_documents(&documents),
None => documents,
};
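// One content hash per (possibly split) document; the hashes double as
// record-manager keys.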
let hashes: Vec<String> = docs.iter().map(content_hash).collect();
// Decide what to add. Full mode wipes the old index up front and re-adds
// everything; the other modes skip documents whose hash is already known.
let mut num_deleted = 0usize;
let (docs_to_add, num_skipped) = match &self.record_manager {
Some(rm) => {
if self.cleanup_mode == CleanupMode::Full {
let all_keys = rm.list_keys(None).await?;
if !all_keys.is_empty() {
// Count the wiped records as deletions, even though some of
// them may be re-added immediately below.
num_deleted = all_keys.len();
self.vectorstore.delete(Some(&all_keys)).await?;
rm.delete_keys(&all_keys).await?;
}
(docs, 0usize)
} else {
let existence = rm.exists(&hashes).await?;
let mut new_docs = Vec::new();
let mut skipped = 0usize;
for (doc, exists) in docs.into_iter().zip(existence) {
if exists {
skipped += 1;
} else {
new_docs.push(doc);
}
}
(new_docs, skipped)
}
}
None => (docs, 0usize),
};
let num_added = docs_to_add.len();
if !docs_to_add.is_empty() {
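// NOTE: the cleanup paths delete from the vector store by content hash,
// which assumes the store keys these documents by the same hashes. If
// add_documents accepts explicit IDs as its second argument, passing the
// matching hashes here would make that association explicit.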
self.vectorstore.add_documents(docs_to_add, None).await?;
}
if let Some(rm) = &self.record_manager {
// Record the whole batch (including skipped documents) so the record
// manager reflects this pass; every cleanup mode needs this.
rm.update(&hashes, None).await?;
if self.cleanup_mode == CleanupMode::Incremental {
// Any previously recorded key missing from the current batch is
// stale: its source document was removed or its content changed.
let all_keys = rm.list_keys(None).await?;
let current_set: std::collections::HashSet<&str> =
hashes.iter().map(|h| h.as_str()).collect();
let stale_keys: Vec<String> = all_keys
.iter()
.filter(|k| !current_set.contains(k.as_str()))
.cloned()
.collect();
if !stale_keys.is_empty() {
num_deleted = stale_keys.len();
self.vectorstore.delete(Some(&stale_keys)).await?;
rm.delete_keys(&stale_keys).await?;
}
}
}
Ok(IndexingResult {
num_added,
num_skipped,
num_deleted,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::vectorstores::in_memory::InMemoryVectorStore;
use cognis_core::embeddings::Embeddings;
use cognis_core::embeddings_fake::DeterministicFakeEmbedding;
use std::sync::Arc;
fn make_embeddings() -> Arc<dyn Embeddings> {
Arc::new(DeterministicFakeEmbedding::new(16))
}
fn make_store() -> Arc<InMemoryVectorStore> {
Arc::new(InMemoryVectorStore::new(make_embeddings()))
}
fn make_docs(contents: &[&str]) -> Vec<Document> {
contents.iter().map(|c| Document::new(*c)).collect()
}
#[tokio::test]
async fn test_basic_indexing_adds_all_documents() {
let store = make_store();
let pipeline = IndexingPipeline::new(store.clone());
let docs = make_docs(&["hello", "world", "foo"]);
let result = pipeline.index(docs).await.unwrap();
assert_eq!(result.num_added, 3);
assert_eq!(result.num_skipped, 0);
assert_eq!(result.num_deleted, 0);
let found = store.similarity_search("hello", 10).await.unwrap();
assert_eq!(found.len(), 3);
}
#[tokio::test]
async fn test_incremental_skips_unchanged() {
let store = make_store();
let rm = Box::new(InMemoryRecordManager::new());
let pipeline = IndexingPipeline::new(store.clone())
.with_record_manager(rm)
.with_cleanup_mode(CleanupMode::None);
let docs = make_docs(&["alpha", "beta"]);
let r1 = pipeline.index(docs.clone()).await.unwrap();
assert_eq!(r1.num_added, 2);
assert_eq!(r1.num_skipped, 0);
let r2 = pipeline.index(docs).await.unwrap();
assert_eq!(r2.num_added, 0);
assert_eq!(r2.num_skipped, 2);
}
#[tokio::test]
async fn test_content_change_triggers_reindex() {
let store = make_store();
let rm = Box::new(InMemoryRecordManager::new());
let pipeline = IndexingPipeline::new(store.clone())
.with_record_manager(rm)
.with_cleanup_mode(CleanupMode::None);
let docs = make_docs(&["version1"]);
let r1 = pipeline.index(docs).await.unwrap();
assert_eq!(r1.num_added, 1);
let docs2 = make_docs(&["version2"]);
let r2 = pipeline.index(docs2).await.unwrap();
assert_eq!(r2.num_added, 1);
assert_eq!(r2.num_skipped, 0);
}
#[tokio::test]
async fn test_full_cleanup_reindexes_all() {
let store = make_store();
let rm = Box::new(InMemoryRecordManager::new());
let pipeline = IndexingPipeline::new(store.clone())
.with_record_manager(rm)
.with_cleanup_mode(CleanupMode::Full);
let docs = make_docs(&["a", "b"]);
let r1 = pipeline.index(docs.clone()).await.unwrap();
assert_eq!(r1.num_added, 2);
let r2 = pipeline.index(docs).await.unwrap();
assert_eq!(r2.num_added, 2);
assert_eq!(r2.num_skipped, 0);
}
#[tokio::test]
async fn test_indexing_with_text_splitter() {
use crate::text_splitter::CharacterTextSplitter;
let store = make_store();
let splitter = CharacterTextSplitter::new()
.with_chunk_size(10)
.with_chunk_overlap(0)
.with_separator("\n");
let pipeline = IndexingPipeline::new(store.clone()).with_text_splitter(Box::new(splitter));
let docs = vec![Document::new("line one\nline two\nline three")];
let result = pipeline.index(docs).await.unwrap();
assert!(
result.num_added >= 2,
"Expected at least 2 chunks, got {}",
result.num_added
);
}
#[tokio::test]
async fn test_indexing_result_counts() {
let store = make_store();
let rm = Box::new(InMemoryRecordManager::new());
let pipeline = IndexingPipeline::new(store.clone())
.with_record_manager(rm)
.with_cleanup_mode(CleanupMode::Incremental);
let docs = make_docs(&["x", "y", "z"]);
let r1 = pipeline.index(docs).await.unwrap();
assert_eq!(r1.num_added, 3);
assert_eq!(r1.num_skipped, 0);
assert_eq!(r1.num_deleted, 0);
let docs2 = make_docs(&["x", "y"]);
let r2 = pipeline.index(docs2).await.unwrap();
assert_eq!(r2.num_skipped, 2);
assert_eq!(r2.num_added, 0);
assert_eq!(r2.num_deleted, 1);
}
#[tokio::test]
async fn test_empty_document_list() {
let store = make_store();
let pipeline = IndexingPipeline::new(store.clone());
let result = pipeline.index(vec![]).await.unwrap();
assert_eq!(result.num_added, 0);
assert_eq!(result.num_skipped, 0);
assert_eq!(result.num_deleted, 0);
}
#[tokio::test]
async fn test_record_manager_tracks_keys() {
let rm = InMemoryRecordManager::new();
let keys = rm.list_keys(None).await.unwrap();
assert!(keys.is_empty());
rm.update(&["k1".into(), "k2".into()], None).await.unwrap();
let keys = rm.list_keys(None).await.unwrap();
assert_eq!(keys.len(), 2);
let exists = rm.exists(&["k1".into(), "k3".into()]).await.unwrap();
assert_eq!(exists, vec![true, false]);
rm.delete_keys(&["k1".into()]).await.unwrap();
let keys = rm.list_keys(None).await.unwrap();
assert_eq!(keys.len(), 1);
assert_eq!(keys[0], "k2");
}
#[tokio::test]
async fn test_multiple_indexing_passes() {
let store = make_store();
let rm = Box::new(InMemoryRecordManager::new());
let pipeline = IndexingPipeline::new(store.clone())
.with_record_manager(rm)
.with_cleanup_mode(CleanupMode::None);
let r1 = pipeline.index(make_docs(&["a", "b"])).await.unwrap();
assert_eq!(r1.num_added, 2);
let r2 = pipeline.index(make_docs(&["a", "b"])).await.unwrap();
assert_eq!(r2.num_added, 0);
assert_eq!(r2.num_skipped, 2);
let r3 = pipeline.index(make_docs(&["a", "b", "c"])).await.unwrap();
assert_eq!(r3.num_added, 1);
assert_eq!(r3.num_skipped, 2);
}
#[tokio::test]
async fn test_content_hash_determinism() {
let doc = Document::new("hello world");
let h1 = content_hash(&doc);
let h2 = content_hash(&doc);
assert_eq!(h1, h2, "Same document must produce the same hash");
let doc2 = Document::new("goodbye world");
let h3 = content_hash(&doc2);
assert_ne!(h1, h3, "Different content must produce different hashes");
}
#[tokio::test]
async fn test_content_hash_includes_metadata() {
let doc1 = Document::new("same content");
let mut meta = HashMap::new();
meta.insert("key".to_string(), serde_json::json!("value"));
let doc2 = Document::new("same content").with_metadata(meta);
let h1 = content_hash(&doc1);
let h2 = content_hash(&doc2);
assert_ne!(h1, h2, "Metadata differences should change the hash");
}
#[tokio::test]
async fn test_record_manager_group_ids() {
let rm = InMemoryRecordManager::new();
rm.update(
&["k1".into(), "k2".into(), "k3".into()],
Some(&["g1".into(), "g1".into(), "g2".into()]),
)
.await
.unwrap();
let g1_keys = rm.list_keys(Some("g1")).await.unwrap();
assert_eq!(g1_keys.len(), 2);
let g2_keys = rm.list_keys(Some("g2")).await.unwrap();
assert_eq!(g2_keys.len(), 1);
let all_keys = rm.list_keys(None).await.unwrap();
assert_eq!(all_keys.len(), 3);
}
}