use crate::cascade::{CascadeConfig, CascadeIndex, SearchResult as CascadeSearchResult};
use crate::error::SynaError;
use crate::gwi::{GravityWellIndex, GwiConfig, GwiSearchResult};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::path::Path;
/// A single hit produced by [`HybridVectorStore::search`].
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
/// Key under which the matching vector is stored.
pub key: String,
/// Score copied from the layer that produced the hit. `search` ranks hits
/// by ascending score, i.e. smaller values come first.
pub score: f32,
/// Layer whose score was kept for this key.
pub source: ResultSource,
}
/// Identifies which layer of the hybrid store a search hit came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResultSource {
/// The hit was produced by the hot layer (`GravityWellIndex`).
Hot,
/// The hit was produced by the cold layer (`CascadeIndex`).
Cold,
}
/// Configuration bundle consumed by [`HybridVectorStore::new`].
#[derive(Debug, Clone)]
pub struct HybridConfig {
/// Configuration handed to `GravityWellIndex::new` for the hot layer.
pub hot: GwiConfig,
/// Configuration handed to `CascadeIndex::new` for the cold layer.
pub cold: CascadeConfig,
}
/// Vector store that pairs a "hot" `GravityWellIndex` with a "cold"
/// `CascadeIndex`.
///
/// New vectors are written to the hot index (`ingest`, `ingest_batch`);
/// `promote_to_cold` copies them into the cold index; `search` queries both.
pub struct HybridVectorStore {
// Layer that receives all ingested vectors.
hot_index: GravityWellIndex,
// Layer that receives vectors through `promote_to_cold`.
cold_index: CascadeIndex,
// Lossy string form of the cold path given at construction; `save_cold`
// passes it to `CascadeIndex::save`.
cold_path: String,
}
impl HybridVectorStore {
pub fn new<P: AsRef<Path>>(
hot_path: P,
cold_path: P,
config: HybridConfig,
) -> Result<Self, SynaError> {
let cold_path_str = cold_path.as_ref().to_string_lossy().to_string();
let hot = GravityWellIndex::new(&hot_path, config.hot)?;
let cold = CascadeIndex::new(&cold_path, config.cold)?;
Ok(Self {
hot_index: hot,
cold_index: cold,
cold_path: cold_path_str,
})
}
/// Builds a store on top of an existing hot index.
///
/// The hot layer is opened with `GravityWellIndex::open(hot_path)`; the cold
/// layer is constructed with `CascadeIndex::new(cold_path, cold_config)`.
/// `cold_path` is remembered (as a lossy UTF-8 string) for `save_cold`.
///
/// NOTE(review): whether `CascadeIndex::new` picks up data already present
/// at `cold_path` is not visible from this file — confirm before relying on
/// this to reopen a previously saved cold index.
///
/// # Errors
///
/// Returns the first error reported while opening the hot index or
/// constructing the cold index.
pub fn open<P: AsRef<Path>>(
    hot_path: P,
    cold_path: P,
    cold_config: CascadeConfig,
) -> Result<Self, SynaError> {
    // Hot first, then cold: a hot-layer failure returns before the cold
    // constructor is ever invoked.
    let hot_index = GravityWellIndex::open(&hot_path)?;
    let cold_index = CascadeIndex::new(&cold_path, cold_config)?;
    let cold_path = cold_path.as_ref().to_string_lossy().into_owned();

    Ok(Self {
        hot_index,
        cold_index,
        cold_path,
    })
}
/// Forwards `sample_vectors` to the hot index's `initialize_attractors` and
/// returns its result unchanged.
///
/// NOTE(review): presumably this must run before vectors are ingested or
/// searched in the hot layer — confirm against `GravityWellIndex`.
pub fn initialize_hot(&mut self, sample_vectors: &[&[f32]]) -> Result<(), SynaError> {
self.hot_index.initialize_attractors(sample_vectors)
}
/// Inserts `vector` under `key` into the hot index.
///
/// Only the hot layer is written; the cold layer is not touched. Any error
/// from the hot index's `insert` is returned as-is.
pub fn ingest(&mut self, key: &str, vector: &[f32]) -> Result<(), SynaError> {
self.hot_index.insert(key, vector)
}
/// Forwards `keys` and `vectors` to the hot index's `insert_batch` and
/// returns its result unchanged.
///
/// Only the hot layer is written. The returned `usize` is presumably the
/// number of vectors inserted, and `keys`/`vectors` are presumably expected
/// to be the same length — both are decided by `GravityWellIndex`; confirm there.
pub fn ingest_batch(&mut self, keys: &[&str], vectors: &[&[f32]]) -> Result<usize, SynaError> {
self.hot_index.insert_batch(keys, vectors)
}
/// Searches both layers and merges the hits into one ranked list.
///
/// Each layer is asked for up to `k` results. A key returned by both layers
/// is reported once, keeping the smaller score together with the layer that
/// produced it. The merged list is sorted by ascending score and cut down to
/// at most `k` entries.
///
/// The ordering is deterministic: hits with equal scores are ordered by key,
/// and `NaN` scores sort after every real score so they can never push a
/// real hit out of the top `k`.
///
/// # Errors
///
/// Currently never returns `Err`: a failing layer is treated as having
/// produced no results (best-effort search). Use `search_hot` /
/// `search_cold` when the per-layer error matters.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<HybridSearchResult>, SynaError> {
    // Deliberately best-effort: one layer failing must not hide the hits of
    // the other layer.
    let hot_results = self.hot_index.search(query, k).unwrap_or_default();
    let cold_results = self.cold_index.search(query, k).unwrap_or_default();

    let mut combined: HashMap<String, (f32, ResultSource)> =
        HashMap::with_capacity(hot_results.len() + cold_results.len());

    for res in hot_results {
        combined.insert(res.key, (res.score, ResultSource::Hot));
    }

    for res in cold_results {
        let cold_score = res.score;
        combined
            .entry(res.key)
            .and_modify(|(score, source)| {
                // Keep the better (smaller) score. A NaN hot score carries no
                // ranking information, so any cold score replaces it.
                if score.is_nan() || cold_score < *score {
                    *score = cold_score;
                    *source = ResultSource::Cold;
                }
            })
            .or_insert((cold_score, ResultSource::Cold));
    }

    let mut results: Vec<HybridSearchResult> = combined
        .into_iter()
        .map(|(key, (score, source))| HybridSearchResult { key, score, source })
        .collect();

    // `HashMap` iteration order is unspecified, so the sort alone must fix
    // the final order: NaN goes last, and ties fall back to the (unique) key.
    // This is a total order, which `sort_by` requires.
    results.sort_by(|a, b| {
        let by_score = match (a.score.is_nan(), b.score.is_nan()) {
            (false, false) => a.score.partial_cmp(&b.score).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        };
        by_score.then_with(|| a.key.cmp(&b.key))
    });
    results.truncate(k);

    Ok(results)
}
/// Searches only the hot index, returning its raw results for `query` and `k`.
///
/// Unlike `search`, an error from the hot index is returned to the caller.
pub fn search_hot(&self, query: &[f32], k: usize) -> Result<Vec<GwiSearchResult>, SynaError> {
self.hot_index.search(query, k)
}
/// Searches only the cold index, returning its raw results for `query` and `k`.
///
/// Unlike `search`, an error from the cold index is returned to the caller.
pub fn search_cold(
&self,
query: &[f32],
k: usize,
) -> Result<Vec<CascadeSearchResult>, SynaError> {
self.cold_index.search(query, k)
}
/// Returns the value reported by the hot index's `len()`.
pub fn hot_count(&self) -> usize {
self.hot_index.len()
}
/// Returns the value reported by the cold index's `len()`.
pub fn cold_count(&self) -> usize {
self.cold_index.len()
}
/// Returns the sum of the hot and cold index lengths.
///
/// This is a plain addition: a key that is present in both layers (for
/// example after `promote_to_cold`, which copies without removing) is
/// counted once per layer.
pub fn len(&self) -> usize {
    let hot = self.hot_index.len();
    let cold = self.cold_index.len();
    hot + cold
}
/// Returns `true` only when both the hot and the cold index report empty.
pub fn is_empty(&self) -> bool {
    // The cold index is consulted only if the hot index is already empty.
    if !self.hot_index.is_empty() {
        return false;
    }
    self.cold_index.is_empty()
}
/// Calls `flush()` on the hot index and returns its result unchanged.
///
/// The cold index is not affected; use `save_cold` for that layer.
pub fn flush_hot(&self) -> Result<(), SynaError> {
self.hot_index.flush()
}
/// Saves the cold index to the cold path supplied at construction time.
///
/// The path used is the lossy UTF-8 string captured by `new`/`open`, so a
/// path containing non-UTF-8 bytes may not round-trip exactly.
pub fn save_cold(&self) -> Result<(), SynaError> {
self.cold_index.save(&self.cold_path)
}
/// Copies every vector currently held by the hot index into the cold index.
///
/// For each key reported by the hot index, the vector is fetched and
/// inserted into the cold index under the same key. Keys for which the hot
/// index returns no vector are skipped. Nothing is removed from the hot
/// index and the cold index is not saved here.
///
/// Returns the number of vectors inserted into the cold index.
///
/// # Errors
///
/// Stops at, and returns, the first error from the hot index's `get` or the
/// cold index's `insert`; vectors copied before the failure stay in the
/// cold index.
pub fn promote_to_cold(&mut self) -> Result<usize, SynaError> {
    let hot_keys = self.hot_index.keys();
    let mut copied: usize = 0;

    for key in hot_keys {
        let vector = match self.hot_index.get(&key)? {
            Some(found) => found,
            // Key listed but no vector available: nothing to copy.
            None => continue,
        };
        self.cold_index.insert(&key, &vector)?;
        copied += 1;
    }

    Ok(copied)
}
}