use crate::storage::{
AccessContext, AccessLevel, QdrantConnectionConfig, EmbeddingCacheConfig,
AccessControlConfig, StorageStats,
};
use crate::{Document, Error, Result};
use async_trait::async_trait;
use qdrant_client::qdrant::{
CreateCollection, DeletePoints, Distance, PointId, PointStruct, QuantizationConfig,
ScalarQuantization, SearchPoints, UpsertPoints, VectorParams, VectorsConfig,
Filter as QdrantFilter, Condition, FieldCondition, Match,
};
use qdrant_client::Qdrant;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use uuid::Uuid;
/// Configuration for batching point operations against Qdrant.
///
/// NOTE(review): stored on `OptimizedQdrantStorage` but not consumed by any
/// method visible in this chunk — field semantics are presumed from names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Upper bound on points per batch (default 100).
    pub max_batch_size: usize,
    /// Batch flush timeout in milliseconds (default 1000).
    pub batch_timeout_ms: u64,
    /// Whether batches may be processed in parallel (default true).
    pub parallel_batching: bool,
}
impl Default for BatchConfig {
fn default() -> Self {
Self {
max_batch_size: 100,
batch_timeout_ms: 1000,
parallel_batching: true,
}
}
}
/// Configuration for the in-memory query-result cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryCacheConfig {
    /// Maximum cached queries before LRU eviction kicks in (default 1000).
    pub max_cache_entries: usize,
    /// Entry time-to-live in seconds (default 300).
    pub ttl_secs: u64,
    /// Master switch; when false, `QueryCache::insert` is a no-op (default true).
    pub enabled: bool,
}
impl Default for QueryCacheConfig {
fn default() -> Self {
Self {
max_cache_entries: 1000,
ttl_secs: 300, enabled: true,
}
}
}
/// Identity of a search query, used as the cache-lookup key.
///
/// The vector and filter are stored as hashes rather than by value, so two
/// distinct queries could in principle collide — acceptable for a cache.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryCacheKey {
    /// Target collection name.
    pub collection_name: String,
    /// Hash of the query vector's f32 bit patterns.
    pub query_vector_hash: u64,
    /// Requested result count.
    pub top_k: usize,
    /// Hash of the filter's `Debug` output, when a filter was supplied.
    pub filter_hash: Option<u64>,
}
impl QueryCacheKey {
    /// Builds a cache key from the query parameters.
    ///
    /// The vector is hashed bit-exactly (via `f32::to_bits`) so identical
    /// vectors always map to the same key; the optional filter is hashed
    /// through its `Debug` representation.
    pub fn new(
        collection_name: String,
        query_vector: &[f32],
        top_k: usize,
        filter: Option<&QdrantFilter>,
    ) -> Self {
        use std::collections::hash_map::DefaultHasher;

        // Hash raw bit patterns because f32 itself does not implement `Hash`.
        let mut vec_hasher = DefaultHasher::new();
        query_vector
            .iter()
            .for_each(|v| vec_hasher.write_u32(v.to_bits()));

        // NOTE(review): relies on the `Debug` output being stable for equal
        // filters within one process, which holds for derived Debug impls.
        let filter_hash = filter.map(|f| {
            let mut f_hasher = DefaultHasher::new();
            format!("{:?}", f).hash(&mut f_hasher);
            f_hasher.finish()
        });

        Self {
            collection_name,
            query_vector_hash: vec_hasher.finish(),
            top_k,
            filter_hash,
        }
    }
}
/// A cached search result with bookkeeping for TTL and LRU decisions.
#[derive(Debug, Clone)]
pub struct CachedQueryEntry {
    /// The (point id, score) pairs returned by the search.
    pub results: Vec<(Uuid, f32)>,
    /// Insertion time; entries older than the configured TTL are expired.
    pub created_at: Instant,
    /// Time of the most recent cache hit on this entry.
    pub last_accessed: Instant,
    /// Number of cache hits served from this entry.
    pub access_count: u64,
}
/// Running hit/miss counters for the query cache.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups answered from cache.
    pub hits: u64,
    /// Lookups that missed or hit an expired entry.
    pub misses: u64,
    /// Total `get` calls.
    pub total_queries: u64,
    /// hits / total_queries (stays 0.0 until the first lookup).
    pub hit_rate: f64,
}
/// In-memory LRU + TTL cache for vector-search results.
pub struct QueryCache {
    // Keyed result entries.
    cache: HashMap<QueryCacheKey, CachedQueryEntry>,
    // Keys ordered least- to most-recently-used; the front is the eviction victim.
    access_order: Vec<QueryCacheKey>,
    config: QueryCacheConfig,
    stats: CacheStats,
}
impl QueryCache {
    /// Creates an empty cache with the given configuration.
    pub fn new(config: QueryCacheConfig) -> Self {
        Self {
            cache: HashMap::new(),
            access_order: Vec::new(),
            config,
            stats: CacheStats::default(),
        }
    }

    /// Inserts (or replaces) the cached result set for `key`.
    ///
    /// No-op when caching is disabled. Evicts least-recently-used entries
    /// until the cache is back within `max_cache_entries`.
    pub fn insert(&mut self, key: QueryCacheKey, results: Vec<(Uuid, f32)>) {
        if !self.config.enabled {
            return;
        }
        let now = Instant::now();
        let entry = CachedQueryEntry {
            results,
            created_at: now,
            last_accessed: now,
            access_count: 0,
        };
        // Fix: drop any stale position for this key first. Previously a
        // re-inserted key left a duplicate in `access_order`, and the stale
        // front entry could evict the freshly inserted value.
        self.access_order.retain(|k| k != &key);
        self.cache.insert(key.clone(), entry);
        self.access_order.push(key);
        while self.cache.len() > self.config.max_cache_entries {
            self.evict_lru();
        }
    }

    /// Looks up `key`, returning a clone of the cached results on a hit.
    ///
    /// A hit refreshes the entry's LRU position and access statistics; an
    /// expired entry is removed. Every unsuccessful lookup counts as a miss
    /// (previously only expirations were counted, so `stats.misses` stayed
    /// at zero for keys that were never cached).
    pub fn get(&mut self, key: &QueryCacheKey) -> Option<Vec<(Uuid, f32)>> {
        self.stats.total_queries += 1;
        let mut expired = false;
        let result = match self.cache.get_mut(key) {
            Some(entry) if entry.created_at.elapsed().as_secs() <= self.config.ttl_secs => {
                entry.access_count += 1;
                entry.last_accessed = Instant::now();
                Some(entry.results.clone())
            }
            Some(_) => {
                // Present but past its TTL; remove it below.
                expired = true;
                None
            }
            None => None,
        };
        if result.is_some() {
            // Move the key to the most-recently-used position.
            self.access_order.retain(|k| k != key);
            self.access_order.push(key.clone());
            self.stats.hits += 1;
        } else {
            if expired {
                self.cache.remove(key);
                self.access_order.retain(|k| k != key);
            }
            // Fix: count plain misses too, not just expirations.
            self.stats.misses += 1;
        }
        self.update_hit_rate();
        result
    }

    /// Evicts the least-recently-used entry (front of `access_order`).
    fn evict_lru(&mut self) {
        if !self.access_order.is_empty() {
            let key = self.access_order.remove(0);
            self.cache.remove(&key);
        }
    }

    /// Recomputes the hit rate from the running counters.
    fn update_hit_rate(&mut self) {
        if self.stats.total_queries > 0 {
            self.stats.hit_rate = self.stats.hits as f64 / self.stats.total_queries as f64;
        }
    }

    /// Read-only view of the hit/miss statistics.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Drops all cached entries (statistics are kept).
    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// True when no entries are cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }
}
/// Qdrant-backed storage with a shared query-result cache.
pub struct OptimizedQdrantStorage {
    client: Arc<Qdrant>,
    // Shared, lock-guarded cache consulted by `search_cached`.
    query_cache: Arc<RwLock<QueryCache>>,
    // NOTE(review): held but not consumed by any method in this chunk.
    batch_config: BatchConfig,
}
impl OptimizedQdrantStorage {
pub async fn new(
url: &str,
cache_config: QueryCacheConfig,
batch_config: BatchConfig,
) -> Result<Self> {
let client = Qdrant::from_url(url).build().map_err(|e| {
Error::Storage(format!("Failed to connect to Qdrant: {}", e))
})?;
Ok(Self {
client: Arc::new(client),
query_cache: Arc::new(RwLock::new(QueryCache::new(cache_config))),
batch_config,
})
}
pub async fn search_cached(
&self,
collection_name: &str,
query_vector: Vec<f32>,
top_k: usize,
filter: Option<QdrantFilter>,
) -> Result<Vec<(Uuid, f32)>> {
let cache_key = QueryCacheKey::new(
collection_name.to_string(),
&query_vector,
top_k,
filter.as_ref(),
);
{
let mut cache = self.query_cache.write().await;
if let Some(cached_results) = cache.get(&cache_key) {
return Ok(cached_results);
}
}
let search_result = self
.client
.search_points(SearchPoints {
collection_name: collection_name.to_string(),
vector: query_vector.clone(),
limit: top_k as u64,
filter: filter.clone(),
..Default::default()
})
.await
.map_err(|e| Error::Storage(format!("Search failed: {}", e)))?;
let results: Vec<(Uuid, f32)> = search_result
.result
.iter()
.filter_map(|point| {
if let Some(id) = &point.id {
if let Some(qdrant_client::qdrant::point_id::PointIdOptions::Uuid(uuid_str)) =
&id.point_id_options
{
if let Ok(uuid) = Uuid::parse_str(uuid_str) {
return Some((uuid, point.score));
}
}
}
None
})
.collect();
{
let mut cache = self.query_cache.write().await;
cache.insert(cache_key, results.clone());
}
Ok(results)
}
pub async fn cache_stats(&self) -> CacheStats {
let cache = self.query_cache.read().await;
cache.stats().clone()
}
pub async fn clear_cache(&self) {
let mut cache = self.query_cache.write().await;
cache.clear();
}
}