use super::{Document, VectorStoreError};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// A chunk of a larger parent document, produced by fixed-size splitting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkDocument {
/// Unique id of this chunk (the split code builds `"{parent_id}_{segment}"`).
pub chunk_id: String,
/// Id of the parent document this chunk was split from.
pub parent_id: String,
/// The chunk's text content.
pub content: String,
/// Zero-based position of this chunk within its parent document.
pub segment: usize,
/// Arbitrary string key/value metadata attached to the chunk.
pub metadata: HashMap<String, String>,
}
impl ChunkDocument {
    /// Creates a chunk with the given identity and content and no metadata.
    pub fn new(chunk_id: String, parent_id: String, content: String, segment: usize) -> Self {
        Self {
            chunk_id,
            parent_id,
            content,
            segment,
            metadata: HashMap::new(),
        }
    }

    /// Builder-style helper: records one metadata entry and returns the chunk.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Converts this chunk into a generic [`Document`]; the resulting document's
    /// id is the chunk id, and content/metadata are copied.
    pub fn to_document(&self) -> Document {
        let id = Some(self.chunk_id.clone());
        let content = self.content.clone();
        let metadata = self.metadata.clone();
        Document { content, metadata, id }
    }
}
/// Minimal async CRUD interface for a flat (unchunked) document store.
#[async_trait]
pub trait DocumentStore: Send + Sync {
/// Stores `document`, generating an id when it has none; returns the id used.
async fn add_document(&self, document: Document) -> Result<String, VectorStoreError>;
/// Stores several documents; returns their ids in input order.
async fn add_documents(&self, documents: Vec<Document>) -> Result<Vec<String>, VectorStoreError>;
/// Returns the document with `id`, or `None` if absent.
async fn get_document(&self, id: &str) -> Result<Option<Document>, VectorStoreError>;
/// Removes the document with `id`; succeeds even when it does not exist.
async fn delete_document(&self, id: &str) -> Result<(), VectorStoreError>;
/// Number of stored documents.
async fn count(&self) -> usize;
/// Removes all documents.
async fn clear(&self) -> Result<(), VectorStoreError>;
}
#[async_trait]
pub trait ChunkedDocumentStoreTrait: Send + Sync {
async fn add_parent_document(
&self,
document: Document,
chunk_size: usize,
) -> Result<(String, Vec<String>), VectorStoreError>;
async fn add_parent_documents(
&self,
documents: Vec<Document>,
chunk_size: usize,
) -> Result<Vec<(String, Vec<String>)>, VectorStoreError>;
async fn get_parent_document(&self, parent_id: &str) -> Result<Option<Document>, VectorStoreError>;
async fn get_chunk(&self, chunk_id: &str) -> Result<Option<ChunkDocument>, VectorStoreError>;
async fn get_chunk_document(&self, chunk_id: &str) -> Result<Option<Document>, VectorStoreError>;
async fn get_chunks_for_parent(&self, parent_id: &str) -> Result<Vec<ChunkDocument>, VectorStoreError>;
async fn get_chunk_documents_for_parent(&self, parent_id: &str) -> Result<Vec<Document>, VectorStoreError>;
async fn delete_parent_document(&self, parent_id: &str) -> Result<(), VectorStoreError>;
async fn parent_count(&self) -> usize;
async fn chunk_count(&self) -> usize;
async fn get_all_chunks(&self) -> Result<Vec<ChunkDocument>, VectorStoreError>;
async fn clear(&self) -> Result<(), VectorStoreError>;
async fn save(&self, path: impl AsRef<Path> + Send) -> Result<(), VectorStoreError> {
Err(VectorStoreError::StorageError("save not implemented for this store".to_string()))
}
async fn load(path: impl AsRef<Path> + Send) -> Result<Self, VectorStoreError> where Self: Sized {
Err(VectorStoreError::StorageError("load not implemented for this store".to_string()))
}
}
/// Simple in-memory [`DocumentStore`] backed by a `HashMap` behind an async `RwLock`.
pub struct InMemoryDocumentStore {
/// Map from document id to document.
documents: Arc<RwLock<HashMap<String, Document>>>,
}
impl InMemoryDocumentStore {
    /// Builds a store containing no documents.
    pub fn new() -> Self {
        let documents = Arc::new(RwLock::new(HashMap::new()));
        Self { documents }
    }
}
impl Default for InMemoryDocumentStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl DocumentStore for InMemoryDocumentStore {
    /// Stores `document` under its own id, or a fresh UUID when it has none.
    /// An existing document with the same id is overwritten.
    async fn add_document(&self, document: Document) -> Result<String, VectorStoreError> {
        let id = document.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
        let mut store = self.documents.write().await;
        store.insert(id.clone(), document);
        Ok(id)
    }

    /// Stores a batch under a single write-lock acquisition; returns ids in
    /// input order.
    async fn add_documents(&self, documents: Vec<Document>) -> Result<Vec<String>, VectorStoreError> {
        let mut store = self.documents.write().await;
        // Preallocate: exactly one id per input document.
        let mut ids = Vec::with_capacity(documents.len());
        for doc in documents {
            let id = doc.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
            store.insert(id.clone(), doc);
            ids.push(id);
        }
        Ok(ids)
    }

    /// Returns a clone of the document with `id`, or `None` if absent.
    async fn get_document(&self, id: &str) -> Result<Option<Document>, VectorStoreError> {
        let store = self.documents.read().await;
        Ok(store.get(id).cloned())
    }

    /// Removes the document with `id`; a missing id is not an error.
    async fn delete_document(&self, id: &str) -> Result<(), VectorStoreError> {
        let mut store = self.documents.write().await;
        store.remove(id);
        Ok(())
    }

    /// Number of stored documents.
    async fn count(&self) -> usize {
        let store = self.documents.read().await;
        store.len()
    }

    /// Removes every document.
    async fn clear(&self) -> Result<(), VectorStoreError> {
        let mut store = self.documents.write().await;
        store.clear();
        Ok(())
    }
}
/// In-memory store of parent documents, their chunks, and the parent→chunk-id index.
pub struct InMemoryChunkedDocumentStore {
/// Parent id → original (unsplit) document.
parent_docs: Arc<RwLock<HashMap<String, Document>>>,
/// Chunk id → chunk.
chunks: Arc<RwLock<HashMap<String, ChunkDocument>>>,
/// Parent id → ids of its chunks, in segment order.
parent_to_chunks: Arc<RwLock<HashMap<String, Vec<String>>>>,
}
impl InMemoryChunkedDocumentStore {
pub fn new() -> Self {
Self {
parent_docs: Arc::new(RwLock::new(HashMap::new())),
chunks: Arc::new(RwLock::new(HashMap::new())),
parent_to_chunks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn add_parent_document_blocking(
&self,
document: Document,
chunk_size: usize,
) -> Result<(String, Vec<String>), VectorStoreError> {
let parent_id = document.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
{
let mut parents = self.parent_docs.blocking_write();
parents.insert(parent_id.clone(), document.clone());
}
let chunk_ids = self.split_and_store_chunks_blocking(&parent_id, &document.content, chunk_size)?;
Ok((parent_id, chunk_ids))
}
pub fn get_parent_document_blocking(&self, parent_id: &str) -> Result<Option<Document>, VectorStoreError> {
let parents = self.parent_docs.blocking_read();
Ok(parents.get(parent_id).cloned())
}
pub fn get_chunk_blocking(&self, chunk_id: &str) -> Result<Option<ChunkDocument>, VectorStoreError> {
let chunks = self.chunks.blocking_read();
Ok(chunks.get(chunk_id).cloned())
}
pub fn get_chunk_document_blocking(&self, chunk_id: &str) -> Result<Option<Document>, VectorStoreError> {
let chunks = self.chunks.blocking_read();
Ok(chunks.get(chunk_id).map(|c| c.to_document()))
}
pub fn blocking_get_chunks_for_parent(&self, parent_id: &str) -> Result<Vec<ChunkDocument>, VectorStoreError> {
let mapping = self.parent_to_chunks.blocking_read();
let chunks = self.chunks.blocking_read();
let chunk_ids = mapping.get(parent_id).cloned().unwrap_or_default();
let result = chunk_ids
.iter()
.filter_map(|id| chunks.get(id).cloned())
.collect();
Ok(result)
}
fn split_and_store_chunks_blocking(
&self,
parent_id: &str,
content: &str,
chunk_size: usize,
) -> Result<Vec<String>, VectorStoreError> {
let chars: Vec<char> = content.chars().collect();
let total_len = chars.len();
let mut chunk_ids = Vec::new();
let mut segment = 0;
let mut start = 0;
while start < total_len {
let end = std::cmp::min(start + chunk_size, total_len);
let chunk_content: String = chars[start..end].iter().collect();
let chunk_id = format!("{}_{}", parent_id, segment);
let chunk = ChunkDocument::new(
chunk_id.clone(),
parent_id.to_string(),
chunk_content,
segment,
);
{
let mut chunks = self.chunks.blocking_write();
chunks.insert(chunk_id.clone(), chunk);
}
{
let mut mapping = self.parent_to_chunks.blocking_write();
mapping
.entry(parent_id.to_string())
.or_insert_with(Vec::new)
.push(chunk_id.clone());
}
chunk_ids.push(chunk_id);
segment += 1;
start = end;
}
Ok(chunk_ids)
}
async fn split_and_store_chunks_async(
&self,
parent_id: &str,
content: &str,
chunk_size: usize,
) -> Result<Vec<String>, VectorStoreError> {
let chars: Vec<char> = content.chars().collect();
let total_len = chars.len();
let mut chunk_ids = Vec::new();
let mut segment = 0;
let mut start = 0;
while start < total_len {
let end = std::cmp::min(start + chunk_size, total_len);
let chunk_content: String = chars[start..end].iter().collect();
let chunk_id = format!("{}_{}", parent_id, segment);
let chunk = ChunkDocument::new(
chunk_id.clone(),
parent_id.to_string(),
chunk_content,
segment,
);
{
let mut chunks = self.chunks.write().await;
chunks.insert(chunk_id.clone(), chunk);
}
{
let mut mapping = self.parent_to_chunks.write().await;
mapping
.entry(parent_id.to_string())
.or_insert_with(Vec::new)
.push(chunk_id.clone());
}
chunk_ids.push(chunk_id);
segment += 1;
start = end;
}
Ok(chunk_ids)
}
}
impl Default for InMemoryChunkedDocumentStore {
fn default() -> Self {
Self::new()
}
}
// Async chunked-store implementation backed by the in-memory maps above.
#[async_trait]
impl ChunkedDocumentStoreTrait for InMemoryChunkedDocumentStore {
// Stores the parent document, then splits its content into chunks.
async fn add_parent_document(
&self,
document: Document,
chunk_size: usize,
) -> Result<(String, Vec<String>), VectorStoreError> {
// Reuse the document's own id when present so callers can look it up later.
let parent_id = document.id.clone().unwrap_or_else(|| Uuid::new_v4().to_string());
{
// Scoped so the parent lock is released before chunking takes other locks.
let mut parents = self.parent_docs.write().await;
parents.insert(parent_id.clone(), document.clone());
}
let chunk_ids = self.split_and_store_chunks_async(&parent_id, &document.content, chunk_size).await?;
Ok((parent_id, chunk_ids))
}
// Adds documents sequentially; the first error aborts the remainder
// (already-added documents are NOT rolled back).
async fn add_parent_documents(
&self,
documents: Vec<Document>,
chunk_size: usize,
) -> Result<Vec<(String, Vec<String>)>, VectorStoreError> {
let mut results = Vec::new();
for doc in documents {
let result = self.add_parent_document(doc, chunk_size).await?;
results.push(result);
}
Ok(results)
}
// Returns a clone of the original (unsplit) parent document.
async fn get_parent_document(&self, parent_id: &str) -> Result<Option<Document>, VectorStoreError> {
let parents = self.parent_docs.read().await;
Ok(parents.get(parent_id).cloned())
}
// Returns a clone of one chunk by chunk id.
async fn get_chunk(&self, chunk_id: &str) -> Result<Option<ChunkDocument>, VectorStoreError> {
let chunks = self.chunks.read().await;
Ok(chunks.get(chunk_id).cloned())
}
// Returns one chunk converted to a `Document`.
async fn get_chunk_document(&self, chunk_id: &str) -> Result<Option<Document>, VectorStoreError> {
let chunks = self.chunks.read().await;
Ok(chunks.get(chunk_id).map(|c| c.to_document()))
}
// Returns the parent's chunks in mapping (segment) order; ids that have no
// matching chunk entry are silently skipped by the filter_map.
async fn get_chunks_for_parent(&self, parent_id: &str) -> Result<Vec<ChunkDocument>, VectorStoreError> {
let mapping = self.parent_to_chunks.read().await;
let chunks = self.chunks.read().await;
let chunk_ids = mapping.get(parent_id).cloned().unwrap_or_default();
let result = chunk_ids
.iter()
.filter_map(|id| chunks.get(id).cloned())
.collect();
Ok(result)
}
// Same as get_chunks_for_parent, converted to `Document`s.
async fn get_chunk_documents_for_parent(&self, parent_id: &str) -> Result<Vec<Document>, VectorStoreError> {
let chunks = self.get_chunks_for_parent(parent_id).await?;
Ok(chunks.iter().map(|c| c.to_document()).collect())
}
// Removes a parent and all its chunks. Locks are taken one at a time, so a
// concurrent reader may briefly observe a partially deleted parent.
async fn delete_parent_document(&self, parent_id: &str) -> Result<(), VectorStoreError> {
let chunk_ids = {
let mapping = self.parent_to_chunks.read().await;
mapping.get(parent_id).cloned().unwrap_or_default()
};
{
let mut chunks = self.chunks.write().await;
for chunk_id in &chunk_ids {
chunks.remove(chunk_id);
}
}
{
let mut mapping = self.parent_to_chunks.write().await;
mapping.remove(parent_id);
}
{
let mut parents = self.parent_docs.write().await;
parents.remove(parent_id);
}
Ok(())
}
// Number of parent documents.
async fn parent_count(&self) -> usize {
let parents = self.parent_docs.read().await;
parents.len()
}
// Total number of chunks across all parents.
async fn chunk_count(&self) -> usize {
let chunks = self.chunks.read().await;
chunks.len()
}
// Clones every stored chunk; order is the HashMap's iteration order.
async fn get_all_chunks(&self) -> Result<Vec<ChunkDocument>, VectorStoreError> {
let chunks = self.chunks.read().await;
Ok(chunks.values().cloned().collect())
}
// Holds all three write locks at once so no reader can observe a partially
// cleared store. (Acquisition order here: parents, chunks, mapping.)
async fn clear(&self) -> Result<(), VectorStoreError> {
let mut parents = self.parent_docs.write().await;
let mut chunks = self.chunks.write().await;
let mut mapping = self.parent_to_chunks.write().await;
parents.clear();
chunks.clear();
mapping.clear();
Ok(())
}
// Serializes a full snapshot of the three maps with bincode and writes it to
// `path`. NOTE(review): `std::fs::write` is blocking I/O inside an async fn;
// consider `tokio::task::spawn_blocking` or `tokio::fs` — confirm the
// runtime's feature flags before changing.
async fn save(&self, path: impl AsRef<Path> + Send) -> Result<(), VectorStoreError> {
let parents = self.parent_docs.read().await;
let chunks = self.chunks.read().await;
let mapping = self.parent_to_chunks.read().await;
let data = ChunkedStoreData {
parent_docs: parents.clone(),
chunks: chunks.clone(),
parent_to_chunks: mapping.clone(),
};
let encoded = bincode::serialize(&data)
.map_err(|e| VectorStoreError::StorageError(e.to_string()))?;
std::fs::write(path.as_ref(), encoded)
.map_err(|e| VectorStoreError::StorageError(e.to_string()))?;
Ok(())
}
// Reads a snapshot produced by `save` and reconstructs a store from it.
// NOTE(review): `std::fs::read` is blocking I/O inside an async fn (see `save`).
async fn load(path: impl AsRef<Path> + Send) -> Result<Self, VectorStoreError> {
let bytes = std::fs::read(path.as_ref())
.map_err(|e| VectorStoreError::StorageError(e.to_string()))?;
let data: ChunkedStoreData = bincode::deserialize(&bytes)
.map_err(|e| VectorStoreError::StorageError(e.to_string()))?;
Ok(Self {
parent_docs: Arc::new(RwLock::new(data.parent_docs)),
chunks: Arc::new(RwLock::new(data.chunks)),
parent_to_chunks: Arc::new(RwLock::new(data.parent_to_chunks)),
})
}
}
#[async_trait]
impl DocumentStore for InMemoryChunkedDocumentStore {
async fn add_document(&self, document: Document) -> Result<String, VectorStoreError> {
let id = document.id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let mut chunks = self.chunks.write().await;
let chunk = ChunkDocument::new(
id.clone(),
id.clone(),
document.content.clone(),
0,
);
chunks.insert(id.clone(), chunk);
Ok(id)
}
async fn add_documents(&self, documents: Vec<Document>) -> Result<Vec<String>, VectorStoreError> {
let mut ids = Vec::new();
for doc in documents {
let id = self.add_document(doc).await?;
ids.push(id);
}
Ok(ids)
}
async fn get_document(&self, id: &str) -> Result<Option<Document>, VectorStoreError> {
self.get_chunk_document(id).await
}
async fn delete_document(&self, id: &str) -> Result<(), VectorStoreError> {
let mut chunks = self.chunks.write().await;
chunks.remove(id);
Ok(())
}
async fn count(&self) -> usize {
self.chunk_count().await
}
async fn clear(&self) -> Result<(), VectorStoreError> {
ChunkedDocumentStoreTrait::clear(self).await
}
}
/// Serializable snapshot of an `InMemoryChunkedDocumentStore`, used by `save`/`load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ChunkedStoreData {
parent_docs: HashMap<String, Document>,
chunks: HashMap<String, ChunkDocument>,
parent_to_chunks: HashMap<String, Vec<String>>,
}
/// Default chunked-store alias; currently the in-memory implementation.
pub type ChunkedDocumentStore = InMemoryChunkedDocumentStore;
#[cfg(test)]
mod tests {
use super::*;
// CRUD round-trip on the flat in-memory store: add, count, get, delete.
#[tokio::test]
async fn test_in_memory_document_store() {
let store = InMemoryDocumentStore::new();
let doc = Document::new("测试内容").with_id("doc_001");
let id = store.add_document(doc).await.unwrap();
// The document's own id is used, not a generated UUID.
assert_eq!(id, "doc_001");
assert_eq!(store.count().await, 1);
let retrieved = store.get_document("doc_001").await.unwrap();
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().content, "测试内容");
store.delete_document("doc_001").await.unwrap();
assert_eq!(store.count().await, 0);
}
// Splitting, chunk retrieval, and cascading delete on the chunked store.
#[tokio::test]
async fn test_chunked_document_store() {
let store = ChunkedDocumentStore::new();
let doc = Document::new("这是一段很长的测试文本,用于验证文档分割功能。").with_id("parent_001");
// chunk_size 20 < content length, so the text must split into >1 chunk.
let (parent_id, chunk_ids) = store.add_parent_document(doc, 20).await.unwrap();
assert_eq!(parent_id, "parent_001");
assert!(chunk_ids.len() > 1);
let parent = store.get_parent_document("parent_001").await.unwrap();
assert!(parent.is_some());
let chunks = store.get_chunks_for_parent("parent_001").await.unwrap();
assert_eq!(chunks.len(), chunk_ids.len());
let chunk = store.get_chunk(&chunk_ids[0]).await.unwrap();
assert!(chunk.is_some());
assert_eq!(chunk.unwrap().parent_id, "parent_001");
// Deleting the parent must also remove all of its chunks.
store.delete_parent_document("parent_001").await.unwrap();
assert_eq!(store.parent_count().await, 0);
assert_eq!(store.chunk_count().await, 0);
}
// ChunkDocument -> Document conversion preserves id, content, and metadata.
#[tokio::test]
async fn test_chunk_to_document() {
let chunk = ChunkDocument::new(
"chunk_001".to_string(),
"parent_001".to_string(),
"Chunk内容".to_string(),
0,
).with_metadata("source", "test");
let doc = chunk.to_document();
assert_eq!(doc.id, Some("chunk_001".to_string()));
assert_eq!(doc.content, "Chunk内容");
assert_eq!(doc.metadata.get("source"), Some(&"test".to_string()));
}
// save/load round-trip: the reloaded store matches counts and parent lookup.
#[tokio::test]
async fn test_persistence() {
let store = ChunkedDocumentStore::new();
let doc = Document::new("测试持久化功能的内容").with_id("parent_001");
store.add_parent_document(doc, 10).await.unwrap();
let temp_path = tempfile::NamedTempFile::new().unwrap();
store.save(temp_path.path()).await.unwrap();
let loaded = ChunkedDocumentStore::load(temp_path.path()).await.unwrap();
assert_eq!(loaded.parent_count().await, store.parent_count().await);
assert_eq!(loaded.chunk_count().await, store.chunk_count().await);
let parent = loaded.get_parent_document("parent_001").await.unwrap();
assert!(parent.is_some());
}
}