use std::{
collections::HashMap,
env, fs,
io::Write,
path::{Path, PathBuf},
sync::RwLock,
};
use async_trait::async_trait;
use rucora_core::{
error::ProviderError,
retrieval::{SearchResult, VectorQuery, VectorRecord, VectorStore},
};
use serde::{Deserialize, Serialize};
pub struct ChromaPersistentStore {
persist_dir: PathBuf,
collection: String,
cache: RwLock<HashMap<String, PersistedRecord>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedRecord {
id: String,
vector: Vec<f32>,
metadata: Option<serde_json::Value>,
}
impl ChromaPersistentStore {
pub fn from_env() -> Result<Self, ProviderError> {
let persist_dir =
env::var("CHROMA_PERSIST_DIR").unwrap_or_else(|_| "./chroma_db".to_string());
let collection = env::var("CHROMA_COLLECTION").unwrap_or_else(|_| "default".to_string());
Self::new(persist_dir, collection)
}
pub fn new(
persist_dir: impl AsRef<Path>,
collection: impl Into<String>,
) -> Result<Self, ProviderError> {
let persist_dir = persist_dir.as_ref().to_path_buf();
let collection = collection.into();
let collection_dir = persist_dir.join(&collection);
fs::create_dir_all(&collection_dir)
.map_err(|e| ProviderError::Message(format!("无法创建目录: {e}")))?;
let cache = Self::load_from_disk(&collection_dir)?;
Ok(Self {
persist_dir,
collection,
cache: RwLock::new(cache),
})
}
fn collection_dir(&self) -> PathBuf {
self.persist_dir.join(&self.collection)
}
fn data_file(&self) -> PathBuf {
self.collection_dir().join("data.json")
}
fn load_from_disk(
collection_dir: &Path,
) -> Result<HashMap<String, PersistedRecord>, ProviderError> {
let data_file = collection_dir.join("data.json");
if !data_file.exists() {
return Ok(HashMap::new());
}
let content = fs::read_to_string(&data_file)
.map_err(|e| ProviderError::Message(format!("读取数据文件失败: {e}")))?;
if content.trim().is_empty() {
return Ok(HashMap::new());
}
let records: Vec<PersistedRecord> = serde_json::from_str(&content)
.map_err(|e| ProviderError::Message(format!("解析数据文件失败: {e}")))?;
let cache: HashMap<String, PersistedRecord> =
records.into_iter().map(|r| (r.id.clone(), r)).collect();
Ok(cache)
}
fn save_to_disk(&self) -> Result<(), ProviderError> {
let cache = self
.cache
.read()
.map_err(|_| ProviderError::Message("无法获取读锁".to_string()))?;
let records: Vec<&PersistedRecord> = cache.values().collect();
let json = serde_json::to_string_pretty(&records)
.map_err(|e| ProviderError::Message(format!("序列化数据失败: {e}")))?;
drop(cache);
let data_file = self.data_file();
let mut file = fs::File::create(&data_file)
.map_err(|e| ProviderError::Message(format!("创建数据文件失败: {e}")))?;
file.write_all(json.as_bytes())
.map_err(|e| ProviderError::Message(format!("写入数据文件失败: {e}")))?;
Ok(())
}
pub fn with_collection(mut self, collection: impl Into<String>) -> Result<Self, ProviderError> {
let collection = collection.into();
let collection_dir = self.persist_dir.join(&collection);
fs::create_dir_all(&collection_dir)
.map_err(|e| ProviderError::Message(format!("无法创建目录: {e}")))?;
let cache = Self::load_from_disk(&collection_dir)?;
self.collection = collection;
self.cache = RwLock::new(cache);
Ok(self)
}
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot_product / (norm_a * norm_b)
}
pub fn collection_exists(&self) -> bool {
self.collection_dir().exists()
}
pub fn delete_collection(&self) -> Result<(), ProviderError> {
let collection_dir = self.collection_dir();
if collection_dir.exists() {
fs::remove_dir_all(&collection_dir)
.map_err(|e| ProviderError::Message(format!("删除集合失败: {e}")))?;
}
Ok(())
}
}
#[async_trait]
impl VectorStore for ChromaPersistentStore {
async fn upsert(&self, records: Vec<VectorRecord>) -> Result<(), ProviderError> {
if records.is_empty() {
return Ok(());
}
let mut cache = self
.cache
.write()
.map_err(|_| ProviderError::Message("无法获取写锁".to_string()))?;
for record in records {
let persisted = PersistedRecord {
id: record.id,
vector: record.vector,
metadata: record.metadata,
};
cache.insert(persisted.id.clone(), persisted);
}
drop(cache);
self.save_to_disk()?;
Ok(())
}
async fn delete(&self, ids: Vec<String>) -> Result<(), ProviderError> {
if ids.is_empty() {
return Ok(());
}
let mut cache = self
.cache
.write()
.map_err(|_| ProviderError::Message("无法获取写锁".to_string()))?;
for id in ids {
cache.remove(&id);
}
drop(cache);
self.save_to_disk()?;
Ok(())
}
async fn get(&self, ids: Vec<String>) -> Result<Vec<VectorRecord>, ProviderError> {
let cache = self
.cache
.read()
.map_err(|_| ProviderError::Message("无法获取读锁".to_string()))?;
let mut results = Vec::new();
for id in ids {
if let Some(record) = cache.get(&id) {
let text = record
.metadata
.as_ref()
.and_then(|m| m.get("text"))
.and_then(|t| t.as_str())
.map(|s| s.to_string());
results.push(VectorRecord {
id: record.id.clone(),
vector: record.vector.clone(),
text,
metadata: record.metadata.clone(),
});
}
}
Ok(results)
}
async fn search(&self, query: VectorQuery) -> Result<Vec<SearchResult>, ProviderError> {
let cache = self
.cache
.read()
.map_err(|_| ProviderError::Message("无法获取读锁".to_string()))?;
let query_vector = &query.vector;
let mut scores: Vec<(String, f32)> = cache
.values()
.map(|record| {
let score = Self::cosine_similarity(query_vector, &record.vector);
(record.id.clone(), score)
})
.collect();
scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let threshold = query.score_threshold.unwrap_or(0.0);
let top_k = query.top_k;
let mut results = Vec::new();
for (id, score) in scores.iter().take(top_k) {
if *score < threshold {
break;
}
if let Some(record) = cache.get(id) {
let text = record
.metadata
.as_ref()
.and_then(|m| m.get("text"))
.and_then(|t| t.as_str())
.map(|s| s.to_string());
results.push(SearchResult {
id: id.clone(),
score: *score,
vector: Some(record.vector.clone()),
text,
metadata: record.metadata.clone(),
});
}
}
Ok(results)
}
async fn clear(&self) -> Result<(), ProviderError> {
let mut cache = self
.cache
.write()
.map_err(|_| ProviderError::Message("无法获取写锁".to_string()))?;
cache.clear();
drop(cache);
self.save_to_disk()?;
Ok(())
}
async fn count(&self) -> Result<usize, ProviderError> {
let cache = self
.cache
.read()
.map_err(|_| ProviderError::Message("无法获取读锁".to_string()))?;
Ok(cache.len())
}
}