use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};

use crate::hnsw::{HnswConfig, HnswIndex};
use crate::optimizer::{OptimizerConfig, QueryOptimizer, SearchStrategy};
use crate::search::VectorSearchIndex;
use crate::types::{DistanceMetric, SearchConfig, SearchResult};
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct AdaptiveConfig {
55 pub metric: DistanceMetric,
57 pub normalize: bool,
59 pub min_recall: f32,
61 pub auto_upgrade: bool,
63 pub latency_threshold_ms: u64,
65 pub stats_window: usize,
67}
68
69impl Default for AdaptiveConfig {
70 fn default() -> Self {
71 Self {
72 metric: DistanceMetric::Cosine,
73 normalize: true,
74 min_recall: 0.95,
75 auto_upgrade: true,
76 latency_threshold_ms: 10, stats_window: 100,
78 }
79 }
80}
81
82impl AdaptiveConfig {
83 pub fn high_accuracy() -> Self {
85 Self {
86 min_recall: 0.99,
87 latency_threshold_ms: 50, ..Default::default()
89 }
90 }
91
92 pub fn low_latency() -> Self {
94 Self {
95 min_recall: 0.90,
96 latency_threshold_ms: 5, auto_upgrade: true,
98 ..Default::default()
99 }
100 }
101}
102
103enum IndexImpl {
105 BruteForce(VectorSearchIndex),
106 Hnsw(HnswIndex),
107}
108
109pub struct AdaptiveIndex {
111 config: AdaptiveConfig,
112 optimizer: QueryOptimizer,
113 index: Option<IndexImpl>,
114 num_vectors: usize,
115 dimensions: usize,
116 recent_latencies: Vec<Duration>,
118 total_searches: usize,
119 embeddings_cache: HashMap<String, Vec<f32>>,
120}
121
122impl AdaptiveIndex {
123 pub fn new(config: AdaptiveConfig) -> Self {
125 let optimizer_config = OptimizerConfig {
126 min_recall: config.min_recall,
127 ..OptimizerConfig::default()
128 };
129
130 Self {
131 config,
132 optimizer: QueryOptimizer::new(optimizer_config),
133 index: None,
134 num_vectors: 0,
135 dimensions: 0,
136 recent_latencies: Vec::new(),
137 total_searches: 0,
138 embeddings_cache: HashMap::new(),
139 }
140 }
141
142 pub fn build(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
144 if embeddings.is_empty() {
145 return Err(anyhow!("Cannot build index from empty embeddings"));
146 }
147
148 self.num_vectors = embeddings.len();
149 self.dimensions = embeddings.values().next().unwrap().len();
150 self.embeddings_cache = embeddings.clone();
151
152 info!(
153 "Building adaptive index with {} vectors, {} dimensions",
154 self.num_vectors, self.dimensions
155 );
156
157 let strategy = self
159 .optimizer
160 .recommend_strategy(self.num_vectors, self.config.min_recall);
161
162 self.build_with_strategy(embeddings, strategy)?;
163
164 Ok(())
165 }
166
167 fn build_with_strategy(
169 &mut self,
170 embeddings: &HashMap<String, Vec<f32>>,
171 strategy: SearchStrategy,
172 ) -> Result<()> {
173 info!("Building index with strategy: {:?}", strategy);
174
175 match strategy {
176 SearchStrategy::BruteForce => {
177 let mut index = VectorSearchIndex::new(SearchConfig {
178 metric: self.config.metric,
179 normalize: self.config.normalize,
180 parallel: true,
181 });
182 index.build(embeddings)?;
183 self.index = Some(IndexImpl::BruteForce(index));
184 }
185 SearchStrategy::Hnsw => {
186 let mut index = HnswIndex::new(HnswConfig::default());
187 index.build(embeddings)?;
188 self.index = Some(IndexImpl::Hnsw(index));
189 }
190 _ => {
191 warn!("Strategy {:?} not yet implemented, using HNSW", strategy);
193 let mut index = HnswIndex::new(HnswConfig::default());
194 index.build(embeddings)?;
195 self.index = Some(IndexImpl::Hnsw(index));
196 }
197 }
198
199 Ok(())
200 }
201
202 pub fn search(&mut self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
204 let index = self
205 .index
206 .as_ref()
207 .ok_or_else(|| anyhow!("Index not built"))?;
208
209 let start = Instant::now();
210
211 let results = match index {
212 IndexImpl::BruteForce(idx) => idx.search(query, k)?,
213 IndexImpl::Hnsw(idx) => idx.search(query, k)?,
214 };
215
216 let elapsed = start.elapsed();
217
218 self.track_search_latency(elapsed);
220
221 if self.config.auto_upgrade {
223 self.check_and_upgrade()?;
224 }
225
226 Ok(results)
227 }
228
229 pub fn add_vector(&mut self, entity_id: String, embedding: Vec<f32>) -> Result<()> {
231 self.embeddings_cache
233 .insert(entity_id.clone(), embedding.clone());
234 self.num_vectors += 1;
235
236 if let Some(index) = &mut self.index {
238 match index {
239 IndexImpl::BruteForce(idx) => {
240 idx.add_vector(entity_id, embedding)?;
241 }
242 IndexImpl::Hnsw(_) => {
243 if self.config.auto_upgrade {
246 debug!("HNSW doesn't support incremental updates, checking for rebuild");
247 }
248 }
249 }
250 }
251
252 if self.config.auto_upgrade {
254 self.check_and_upgrade()?;
255 }
256
257 Ok(())
258 }
259
260 pub fn add_vectors(&mut self, embeddings: &HashMap<String, Vec<f32>>) -> Result<()> {
262 for (id, emb) in embeddings {
263 self.embeddings_cache.insert(id.clone(), emb.clone());
264 }
265 self.num_vectors += embeddings.len();
266
267 if let Some(index) = &mut self.index {
268 match index {
269 IndexImpl::BruteForce(idx) => {
270 idx.add_vectors(embeddings)?;
271 }
272 IndexImpl::Hnsw(_) => {
273 if self.config.auto_upgrade {
275 debug!("HNSW batch insert requires rebuild");
276 }
277 }
278 }
279 }
280
281 if self.config.auto_upgrade {
282 self.check_and_upgrade()?;
283 }
284
285 Ok(())
286 }
287
288 fn track_search_latency(&mut self, duration: Duration) {
290 self.total_searches += 1;
291 self.recent_latencies.push(duration);
292
293 if self.recent_latencies.len() > self.config.stats_window {
295 self.recent_latencies.remove(0);
296 }
297 }
298
299 fn check_and_upgrade(&mut self) -> Result<()> {
301 let current_strategy = self.current_strategy();
303 let recommended_strategy = self
304 .optimizer
305 .recommend_strategy(self.num_vectors, self.config.min_recall);
306
307 if current_strategy != recommended_strategy {
309 info!(
310 "Dataset size changed, upgrading from {:?} to {:?}",
311 current_strategy, recommended_strategy
312 );
313 self.build_with_strategy(&self.embeddings_cache.clone(), recommended_strategy)?;
314 return Ok(());
315 }
316
317 if !self.recent_latencies.is_empty() {
319 let avg_latency =
320 self.recent_latencies.iter().sum::<Duration>() / self.recent_latencies.len() as u32;
321
322 if avg_latency.as_millis() as u64 > self.config.latency_threshold_ms {
323 warn!(
324 "Average latency {}ms exceeds threshold {}ms",
325 avg_latency.as_millis(),
326 self.config.latency_threshold_ms
327 );
328
329 if current_strategy == SearchStrategy::BruteForce && self.num_vectors > 1000 {
331 info!("Upgrading to HNSW due to high latency");
332 self.build_with_strategy(&self.embeddings_cache.clone(), SearchStrategy::Hnsw)?;
333 }
334 }
335 }
336
337 Ok(())
338 }
339
340 pub fn current_strategy(&self) -> SearchStrategy {
342 match &self.index {
343 Some(IndexImpl::BruteForce(_)) => SearchStrategy::BruteForce,
344 Some(IndexImpl::Hnsw(_)) => SearchStrategy::Hnsw,
345 None => SearchStrategy::BruteForce, }
347 }
348
349 pub fn stats(&self) -> AdaptiveStats {
351 let avg_latency = if !self.recent_latencies.is_empty() {
352 self.recent_latencies.iter().sum::<Duration>() / self.recent_latencies.len() as u32
353 } else {
354 Duration::ZERO
355 };
356
357 let p95_latency = if !self.recent_latencies.is_empty() {
358 let mut sorted = self.recent_latencies.clone();
359 sorted.sort();
360 let p95_idx = (sorted.len() as f32 * 0.95) as usize;
361 sorted.get(p95_idx).copied().unwrap_or(Duration::ZERO)
362 } else {
363 Duration::ZERO
364 };
365
366 AdaptiveStats {
367 num_vectors: self.num_vectors,
368 dimensions: self.dimensions,
369 current_strategy: self.current_strategy(),
370 total_searches: self.total_searches,
371 avg_latency_ms: avg_latency.as_secs_f64() * 1000.0,
372 p95_latency_ms: p95_latency.as_secs_f64() * 1000.0,
373 }
374 }
375
376 #[inline]
378 pub fn len(&self) -> usize {
379 self.num_vectors
380 }
381
382 #[inline]
384 pub fn is_empty(&self) -> bool {
385 self.num_vectors == 0
386 }
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize)]
391pub struct AdaptiveStats {
392 pub num_vectors: usize,
393 pub dimensions: usize,
394 pub current_strategy: SearchStrategy,
395 pub total_searches: usize,
396 pub avg_latency_ms: f64,
397 pub p95_latency_ms: f64,
398}
399
400#[cfg(test)]
401mod tests {
402 use super::*;
403
404 fn create_test_embeddings(count: usize, dim: usize) -> HashMap<String, Vec<f32>> {
405 let mut embeddings = HashMap::new();
406 for i in 0..count {
407 let vec: Vec<f32> = (0..dim).map(|j| (i + j) as f32 * 0.1).collect();
408 embeddings.insert(format!("doc_{}", i), vec);
409 }
410 embeddings
411 }
412
413 #[test]
414 fn test_adaptive_index_small_dataset() {
415 let embeddings = create_test_embeddings(100, 3);
416 let mut index = AdaptiveIndex::new(AdaptiveConfig::default());
417
418 index.build(&embeddings).unwrap();
419
420 assert_eq!(index.current_strategy(), SearchStrategy::BruteForce);
422
423 let query = vec![0.1, 0.2, 0.3];
424 let results = index.search(&query, 5).unwrap();
425
426 assert!(results.len() <= 5);
427 assert!(!results.is_empty());
428 }
429
430 #[test]
431 #[ignore = "slow HNSW construction benchmark - run with --ignored"]
432 fn test_adaptive_index_medium_dataset() {
433 let embeddings = create_test_embeddings(11000, 3);
435 let mut index = AdaptiveIndex::new(AdaptiveConfig::default());
436
437 index.build(&embeddings).unwrap();
438
439 assert_eq!(index.current_strategy(), SearchStrategy::Hnsw);
441
442 let query = vec![0.1, 0.2, 0.3];
443 let results = index.search(&query, 10).unwrap();
444
445 assert!(results.len() <= 10);
446 }
447
448 #[test]
449 fn test_adaptive_index_incremental_add() {
450 let embeddings = create_test_embeddings(50, 3);
451 let mut index = AdaptiveIndex::new(AdaptiveConfig::default());
452
453 index.build(&embeddings).unwrap();
454
455 assert_eq!(index.len(), 50);
456
457 index
459 .add_vector("new_doc".to_string(), vec![0.9, 0.9, 0.9])
460 .unwrap();
461
462 assert_eq!(index.len(), 51);
463 }
464
465 #[test]
466 fn test_adaptive_stats() {
467 let embeddings = create_test_embeddings(100, 3);
468 let mut index = AdaptiveIndex::new(AdaptiveConfig::default());
469
470 index.build(&embeddings).unwrap();
471
472 let query = vec![0.1, 0.2, 0.3];
474 for _ in 0..10 {
475 let _ = index.search(&query, 5);
476 }
477
478 let stats = index.stats();
479 assert_eq!(stats.num_vectors, 100);
480 assert_eq!(stats.dimensions, 3);
481 assert_eq!(stats.total_searches, 10);
482 assert!(stats.avg_latency_ms >= 0.0);
483 }
484
485 #[test]
486 fn test_adaptive_config_presets() {
487 let high_acc = AdaptiveConfig::high_accuracy();
488 assert_eq!(high_acc.min_recall, 0.99);
489
490 let low_lat = AdaptiveConfig::low_latency();
491 assert_eq!(low_lat.latency_threshold_ms, 5);
492 }
493
494 #[test]
495 fn test_adaptive_index_empty() {
496 let index = AdaptiveIndex::new(AdaptiveConfig::default());
497 assert!(index.is_empty());
498 assert_eq!(index.len(), 0);
499 }
500}