use std::collections::HashMap;
use super::config::BqConfig;
use super::utils::{hamming_distance, quantize_to_binary};
use crate::vectordb::{
SearchResult, VectorDbError, VectorPoint, WriteConsistency, cosine_similarity,
};
#[derive(Default, Clone)]
pub struct MockBqClient {
collections: std::sync::Arc<std::sync::RwLock<HashMap<String, MockBqCollection>>>,
config: BqConfig,
}
#[derive(Default, Clone)]
struct MockBqCollection {
vector_size: u64,
points: HashMap<u64, MockBqPoint>,
}
#[derive(Clone)]
struct MockBqPoint {
vector: Vec<f32>,
binary: Vec<u8>,
tenant_id: u64,
context_hash: u64,
timestamp: i64,
storage_key: Option<String>,
}
impl MockBqClient {
pub fn new() -> Self {
Self::default()
}
pub fn with_config(config: BqConfig) -> Self {
Self {
collections: std::sync::Arc::new(std::sync::RwLock::new(HashMap::new())),
config,
}
}
pub fn point_count(&self, collection: &str) -> Option<usize> {
self.collections
.read()
.ok()?
.get(collection)
.map(|c| c.points.len())
}
#[cfg(any(test, feature = "mock"))]
pub fn poison_lock(&self) {
use std::thread;
let collections_clone = self.collections.clone();
let handle = thread::spawn(move || {
let _guard = collections_clone.write().unwrap();
panic!("Intentional panic to poison lock for testing");
});
let _ = handle.join();
}
pub async fn ensure_bq_collection(
&self,
name: &str,
vector_size: u64,
) -> Result<(), VectorDbError> {
let mut collections =
self.collections
.write()
.map_err(|_| VectorDbError::CreateCollectionFailed {
collection: name.to_string(),
message: "lock poisoned".to_string(),
})?;
collections
.entry(name.to_string())
.or_insert(MockBqCollection {
vector_size,
points: HashMap::new(),
});
Ok(())
}
pub async fn upsert_points(
&self,
collection: &str,
points: Vec<VectorPoint>,
_consistency: WriteConsistency,
) -> Result<(), VectorDbError> {
let mut collections =
self.collections
.write()
.map_err(|_| VectorDbError::UpsertFailed {
collection: collection.to_string(),
message: "lock poisoned".to_string(),
})?;
let coll =
collections
.get_mut(collection)
.ok_or_else(|| VectorDbError::CollectionNotFound {
collection: collection.to_string(),
})?;
for point in points {
if point.vector.len() as u64 != coll.vector_size {
return Err(VectorDbError::InvalidDimension {
expected: coll.vector_size as usize,
actual: point.vector.len(),
});
}
let binary = quantize_to_binary(&point.vector);
coll.points.insert(
point.id,
MockBqPoint {
vector: point.vector,
binary,
tenant_id: point.tenant_id,
context_hash: point.context_hash,
timestamp: point.timestamp,
storage_key: point.storage_key,
},
);
}
Ok(())
}
pub async fn search_bq(
&self,
collection: &str,
query: Vec<f32>,
limit: u64,
tenant_filter: Option<u64>,
) -> Result<Vec<SearchResult>, VectorDbError> {
let collections = self
.collections
.read()
.map_err(|_| VectorDbError::SearchFailed {
collection: collection.to_string(),
message: "lock poisoned".to_string(),
})?;
let coll =
collections
.get(collection)
.ok_or_else(|| VectorDbError::CollectionNotFound {
collection: collection.to_string(),
})?;
let query_binary = quantize_to_binary(&query);
let mut candidates: Vec<(u64, &MockBqPoint, u32)> = coll
.points
.iter()
.filter(|(_, p)| tenant_filter.is_none() || tenant_filter == Some(p.tenant_id))
.map(|(&id, p)| {
let hamming = hamming_distance(&query_binary, &p.binary);
(id, p, hamming)
})
.collect();
candidates.sort_by_key(|(_, _, h)| *h);
let rescore_limit = if self.config.rescore {
self.config.rescore_limit as usize
} else {
limit as usize
};
candidates.truncate(rescore_limit);
let mut results: Vec<SearchResult> = if self.config.rescore {
candidates
.into_iter()
.map(|(id, p, _)| {
let score = cosine_similarity(&query, &p.vector);
SearchResult {
id,
score,
tenant_id: p.tenant_id,
context_hash: p.context_hash,
timestamp: p.timestamp,
storage_key: p.storage_key.clone(),
}
})
.collect()
} else {
let max_hamming = (query_binary.len() as u32) * 8;
candidates
.into_iter()
.map(|(id, p, hamming)| {
let score = 1.0 - (hamming as f32 / max_hamming as f32);
SearchResult {
id,
score,
tenant_id: p.tenant_id,
context_hash: p.context_hash,
timestamp: p.timestamp,
storage_key: p.storage_key.clone(),
}
})
.collect()
};
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(limit as usize);
Ok(results)
}
pub async fn delete_points(
&self,
collection: &str,
ids: Vec<u64>,
) -> Result<(), VectorDbError> {
let mut collections =
self.collections
.write()
.map_err(|_| VectorDbError::DeleteFailed {
collection: collection.to_string(),
message: "lock poisoned".to_string(),
})?;
let coll =
collections
.get_mut(collection)
.ok_or_else(|| VectorDbError::CollectionNotFound {
collection: collection.to_string(),
})?;
for id in ids {
coll.points.remove(&id);
}
Ok(())
}
}