ipfrs_semantic/
hnsw.rs

//! HNSW vector index for semantic search
//!
//! This module provides a high-performance approximate nearest-neighbor index
//! that maps content IDs to feature vectors, built on the Hierarchical Navigable
//! Small World (HNSW) graph algorithm.
5
6use hnsw_rs::prelude::*;
7use ipfrs_core::{Cid, Error, Result};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
11/// Distance metric for vector similarity
12#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
13pub enum DistanceMetric {
14    /// Euclidean distance (L2)
15    L2,
16    /// Cosine similarity
17    Cosine,
18    /// Dot product similarity
19    DotProduct,
20}
21
22/// Search result entry
23#[derive(Debug, Clone)]
24pub struct SearchResult {
25    /// Content ID
26    pub cid: Cid,
27    /// Distance/similarity score
28    pub score: f32,
29}
30
/// Outcome of a chunked, best-effort bulk insertion.
#[derive(Clone, Debug)]
pub struct IncrementalBuildStats {
    /// Live vector count before the insertion started.
    pub initial_size: usize,
    /// Live vector count after the insertion finished.
    pub final_size: usize,
    /// How many vectors were added successfully.
    pub vectors_inserted: usize,
    /// How many vectors were rejected.
    pub vectors_failed: usize,
    /// How many chunks of input were walked.
    pub chunks_processed: usize,
    /// `true` when rebuilding the index is advisable afterwards.
    pub should_rebuild: bool,
}
47
/// Summary of an index rebuild.
#[derive(Clone, Debug)]
pub struct RebuildStats {
    /// Count of vectors inserted into the rebuilt graph.
    pub vectors_reinserted: usize,
    /// `(M, ef_construction)` in effect before the rebuild.
    pub old_parameters: (usize, usize),
    /// `(M, ef_construction)` chosen for the rebuilt graph.
    pub new_parameters: (usize, usize),
}
58
/// Snapshot comparing the index's current HNSW parameters with the
/// size-based optimum.
#[derive(Clone, Debug)]
pub struct BuildHealthStats {
    /// Number of live vectors.
    pub index_size: usize,
    /// M currently used by the graph.
    pub current_m: usize,
    /// ef_construction currently used by the graph.
    pub current_ef_construction: usize,
    /// M suggested for `index_size`.
    pub optimal_m: usize,
    /// ef_construction suggested for `index_size`.
    pub optimal_ef_construction: usize,
    /// Ratio of current to optimal parameters, clamped to `0.0..=1.0`.
    pub parameter_efficiency: f32,
    /// `true` when a rebuild is advised.
    pub rebuild_recommended: bool,
}
77
78/// HNSW-based vector index for semantic search
79///
80/// Provides efficient approximate k-nearest neighbor search over
81/// high-dimensional vectors associated with content IDs.
82pub struct VectorIndex {
83    /// HNSW index
84    index: Arc<RwLock<Hnsw<'static, f32, DistL2>>>,
85    /// Mapping from data ID to CID
86    id_to_cid: Arc<RwLock<HashMap<usize, Cid>>>,
87    /// Mapping from CID to data ID
88    cid_to_id: Arc<RwLock<HashMap<Cid, usize>>>,
89    /// Storage for original vectors (for retrieval and migration)
90    vectors: Arc<RwLock<HashMap<Cid, Vec<f32>>>>,
91    /// Next available ID
92    next_id: Arc<RwLock<usize>>,
93    /// Vector dimension
94    dimension: usize,
95    /// Distance metric
96    metric: DistanceMetric,
97}
98
99impl VectorIndex {
100    /// Create a new vector index with the specified dimension
101    ///
102    /// # Arguments
103    /// * `dimension` - Dimension of vectors to be indexed
104    /// * `metric` - Distance metric to use
105    /// * `max_nb_connection` - Maximum number of connections per layer (M parameter)
106    /// * `ef_construction` - Size of dynamic candidate list (efConstruction parameter)
107    pub fn new(
108        dimension: usize,
109        metric: DistanceMetric,
110        max_nb_connection: usize,
111        ef_construction: usize,
112    ) -> Result<Self> {
113        if dimension == 0 {
114            return Err(Error::InvalidInput(
115                "Vector dimension must be greater than 0".to_string(),
116            ));
117        }
118
119        // Create HNSW index with L2 distance (we'll handle other metrics via normalization)
120        let index = Hnsw::<f32, DistL2>::new(
121            max_nb_connection,
122            dimension,
123            ef_construction,
124            200, // max_elements initial capacity
125            DistL2 {},
126        );
127
128        Ok(Self {
129            index: Arc::new(RwLock::new(index)),
130            id_to_cid: Arc::new(RwLock::new(HashMap::new())),
131            cid_to_id: Arc::new(RwLock::new(HashMap::new())),
132            vectors: Arc::new(RwLock::new(HashMap::new())),
133            next_id: Arc::new(RwLock::new(0)),
134            dimension,
135            metric,
136        })
137    }
138
139    /// Create a new index with default parameters
140    ///
141    /// Uses M=16 and efConstruction=200, which are good defaults for most use cases
142    pub fn with_defaults(dimension: usize) -> Result<Self> {
143        Self::new(dimension, DistanceMetric::L2, 16, 200)
144    }
145
146    /// Insert a vector associated with a CID
147    ///
148    /// # Arguments
149    /// * `cid` - Content identifier
150    /// * `vector` - Feature vector to index
151    pub fn insert(&mut self, cid: &Cid, vector: &[f32]) -> Result<()> {
152        if vector.len() != self.dimension {
153            return Err(Error::InvalidInput(format!(
154                "Vector dimension mismatch: expected {}, got {}",
155                self.dimension,
156                vector.len()
157            )));
158        }
159
160        // Check if CID already exists
161        if self.cid_to_id.read().unwrap().contains_key(cid) {
162            return Err(Error::InvalidInput(format!(
163                "CID already exists in index: {}",
164                cid
165            )));
166        }
167
168        // Get next ID
169        let mut next_id = self.next_id.write().unwrap();
170        let id = *next_id;
171        *next_id += 1;
172        drop(next_id);
173
174        // Normalize vector based on metric
175        let normalized = self.normalize_vector(vector);
176
177        // Insert into HNSW index
178        let data_with_id = (normalized.as_slice(), id);
179        self.index.write().unwrap().insert(data_with_id);
180
181        // Store original vector for retrieval
182        self.vectors.write().unwrap().insert(*cid, vector.to_vec());
183
184        // Update mappings
185        self.id_to_cid.write().unwrap().insert(id, *cid);
186        self.cid_to_id.write().unwrap().insert(*cid, id);
187
188        Ok(())
189    }
190
191    /// Search for k nearest neighbors
192    ///
193    /// # Arguments
194    /// * `query` - Query vector
195    /// * `k` - Number of neighbors to return
196    /// * `ef_search` - Size of dynamic candidate list during search (higher = more accurate but slower)
197    pub fn search(&self, query: &[f32], k: usize, ef_search: usize) -> Result<Vec<SearchResult>> {
198        if query.len() != self.dimension {
199            return Err(Error::InvalidInput(format!(
200                "Query dimension mismatch: expected {}, got {}",
201                self.dimension,
202                query.len()
203            )));
204        }
205
206        if k == 0 {
207            return Ok(Vec::new());
208        }
209
210        // Normalize query based on metric
211        let normalized = self.normalize_vector(query);
212
213        // Search HNSW index
214        let neighbors = self.index.read().unwrap().search(&normalized, k, ef_search);
215
216        // Convert results
217        let id_to_cid = self.id_to_cid.read().unwrap();
218        let results: Vec<SearchResult> = neighbors
219            .iter()
220            .filter_map(|neighbor| {
221                id_to_cid.get(&neighbor.d_id).map(|cid| SearchResult {
222                    cid: *cid,
223                    score: self.convert_distance(neighbor.distance),
224                })
225            })
226            .collect();
227
228        Ok(results)
229    }
230
231    /// Delete a vector by CID
232    pub fn delete(&mut self, cid: &Cid) -> Result<()> {
233        let id = self
234            .cid_to_id
235            .read()
236            .unwrap()
237            .get(cid)
238            .copied()
239            .ok_or_else(|| Error::NotFound(format!("CID not found in index: {}", cid)))?;
240
241        // Remove from vector storage
242        self.vectors.write().unwrap().remove(cid);
243
244        // Remove from mappings
245        self.cid_to_id.write().unwrap().remove(cid);
246        self.id_to_cid.write().unwrap().remove(&id);
247
248        // Note: HNSW doesn't support true deletion, so we just remove from our mappings
249        // The actual vector remains in the index but won't be returned in results
250
251        Ok(())
252    }
253
254    /// Check if a CID exists in the index
255    pub fn contains(&self, cid: &Cid) -> bool {
256        self.cid_to_id.read().unwrap().contains_key(cid)
257    }
258
259    /// Get the number of vectors in the index
260    pub fn len(&self) -> usize {
261        self.cid_to_id.read().unwrap().len()
262    }
263
264    /// Check if the index is empty
265    pub fn is_empty(&self) -> bool {
266        self.len() == 0
267    }
268
269    /// Get the dimension of vectors in this index
270    pub fn dimension(&self) -> usize {
271        self.dimension
272    }
273
274    /// Get the distance metric used by this index
275    pub fn metric(&self) -> DistanceMetric {
276        self.metric
277    }
278
279    /// Get all CIDs in the index
280    /// Useful for synchronization and snapshots
281    pub fn get_all_cids(&self) -> Vec<Cid> {
282        self.cid_to_id.read().unwrap().keys().copied().collect()
283    }
284
285    /// Get the embedding vector for a specific CID
286    ///
287    /// Returns `None` if the CID is not in the index
288    pub fn get_embedding(&self, cid: &Cid) -> Option<Vec<f32>> {
289        self.vectors.read().unwrap().get(cid).cloned()
290    }
291
292    /// Get all embeddings in the index as (CID, vector) pairs
293    ///
294    /// Useful for iteration, migration, and batch operations
295    pub fn get_all_embeddings(&self) -> Vec<(Cid, Vec<f32>)> {
296        self.vectors
297            .read()
298            .unwrap()
299            .iter()
300            .map(|(cid, vec)| (*cid, vec.clone()))
301            .collect()
302    }
303
304    /// Iterate over all (CID, vector) pairs in the index
305    ///
306    /// Returns an iterator over the embeddings
307    pub fn iter(&self) -> Vec<(Cid, Vec<f32>)> {
308        self.get_all_embeddings()
309    }
310
311    /// Normalize vector based on distance metric
312    fn normalize_vector(&self, vector: &[f32]) -> Vec<f32> {
313        match self.metric {
314            DistanceMetric::L2 => vector.to_vec(),
315            DistanceMetric::Cosine => {
316                // For cosine similarity, normalize to unit length
317                let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
318                if norm > 0.0 {
319                    vector.iter().map(|x| x / norm).collect()
320                } else {
321                    vector.to_vec()
322                }
323            }
324            DistanceMetric::DotProduct => {
325                // For dot product, no normalization needed
326                vector.to_vec()
327            }
328        }
329    }
330
331    /// Convert distance to score based on metric
332    fn convert_distance(&self, distance: f32) -> f32 {
333        match self.metric {
334            DistanceMetric::L2 => distance,
335            DistanceMetric::Cosine => {
336                // Convert L2 distance on normalized vectors to cosine similarity
337                // cos(θ) = 1 - (L2_dist^2 / 2)
338                1.0 - (distance * distance / 2.0)
339            }
340            DistanceMetric::DotProduct => {
341                // For dot product, return negative distance (higher = more similar)
342                -distance
343            }
344        }
345    }
346
347    /// Compute optimal HNSW parameters based on current index size
348    ///
349    /// Returns recommended (max_nb_connection, ef_construction) based on:
350    /// - Small indexes (< 10k): M=16, ef=200
351    /// - Medium indexes (10k-100k): M=32, ef=400
352    /// - Large indexes (> 100k): M=48, ef=600
353    pub fn compute_optimal_parameters(&self) -> (usize, usize) {
354        let size = self.len();
355
356        if size < 10_000 {
357            (16, 200) // Small index
358        } else if size < 100_000 {
359            (32, 400) // Medium index
360        } else {
361            (48, 600) // Large index
362        }
363    }
364
365    /// Get recommended ef_search parameter based on k
366    ///
367    /// Generally ef_search should be >= k and higher for better recall
368    pub fn compute_optimal_ef_search(&self, k: usize) -> usize {
369        // Rule of thumb: ef_search = max(k, 50) for small k
370        // For larger k, use 2*k to maintain good recall
371        if k <= 50 {
372            50.max(k)
373        } else {
374            2 * k
375        }
376    }
377
378    /// Get detailed parameter recommendations based on use case
379    pub fn get_parameter_recommendations(&self, use_case: UseCase) -> ParameterRecommendation {
380        let size = self.len();
381        ParameterTuner::recommend(size, self.dimension, use_case)
382    }
383
384    /// Insert multiple vectors in batch
385    ///
386    /// More efficient than inserting one by one as it can use parallelization
387    ///
388    /// # Arguments
389    /// * `items` - Vector of (CID, vector) pairs to insert
390    pub fn insert_batch(&mut self, items: &[(Cid, Vec<f32>)]) -> Result<()> {
391        for (cid, vector) in items {
392            self.insert(cid, vector)?;
393        }
394        Ok(())
395    }
396
397    /// Insert vectors incrementally with periodic optimization
398    ///
399    /// This method inserts vectors in chunks and tracks statistics to determine
400    /// if index rebuild is beneficial. Returns statistics about the insertion.
401    ///
402    /// # Arguments
403    /// * `items` - Vector of (CID, vector) pairs to insert
404    /// * `chunk_size` - Number of vectors to insert before checking optimization
405    ///
406    /// # Returns
407    /// Statistics about the incremental build process
408    pub fn insert_incremental(
409        &mut self,
410        items: &[(Cid, Vec<f32>)],
411        chunk_size: usize,
412    ) -> Result<IncrementalBuildStats> {
413        let start_size = self.len();
414        let mut chunks_processed = 0;
415        let mut failed_inserts = 0;
416
417        // Insert in chunks
418        for chunk in items.chunks(chunk_size) {
419            for (cid, vector) in chunk {
420                if let Err(_e) = self.insert(cid, vector) {
421                    failed_inserts += 1;
422                }
423            }
424            chunks_processed += 1;
425        }
426
427        let end_size = self.len();
428        let inserted = end_size - start_size;
429
430        // Check if rebuild would be beneficial
431        let should_rebuild = self.should_rebuild();
432
433        Ok(IncrementalBuildStats {
434            initial_size: start_size,
435            final_size: end_size,
436            vectors_inserted: inserted,
437            vectors_failed: failed_inserts,
438            chunks_processed,
439            should_rebuild,
440        })
441    }
442
443    /// Determine if index should be rebuilt for better performance
444    ///
445    /// Rebuild is recommended when:
446    /// - Index has grown significantly (2x or more)
447    /// - Many deletions have occurred (fragmentation)
448    /// - Current parameters are suboptimal for index size
449    pub fn should_rebuild(&self) -> bool {
450        let size = self.len();
451        let (current_m, current_ef) = {
452            let idx = self.index.read().unwrap();
453            (
454                idx.get_max_nb_connection() as usize,
455                idx.get_ef_construction(),
456            )
457        };
458
459        let (optimal_m, optimal_ef) = self.compute_optimal_parameters();
460
461        // Rebuild if parameters are significantly suboptimal
462        if current_m < optimal_m / 2 || current_ef < optimal_ef / 2 {
463            return true;
464        }
465
466        // Rebuild if index crossed size thresholds
467        if size > 100_000 && current_m < 32 {
468            return true;
469        }
470
471        false
472    }
473
474    /// Rebuild the index with optimal parameters for current size
475    ///
476    /// This creates a new index with better parameters and re-inserts all vectors.
477    /// Use this when `should_rebuild()` returns true.
478    ///
479    /// # Arguments
480    /// * `use_case` - Target use case for parameter selection
481    pub fn rebuild(&mut self, use_case: UseCase) -> Result<RebuildStats> {
482        let start_size = self.len();
483
484        if start_size == 0 {
485            return Ok(RebuildStats {
486                vectors_reinserted: 0,
487                old_parameters: (0, 0),
488                new_parameters: (0, 0),
489            });
490        }
491
492        // Get all current vectors (would be used for re-insertion)
493        let _id_to_cid = self.id_to_cid.read().unwrap();
494
495        // Extract vectors from current index (this is limited by hnsw_rs API)
496        // We'll need to store vectors separately for efficient rebuild
497        // For now, we'll just track the parameters change
498
499        let old_params = {
500            let idx = self.index.read().unwrap();
501            (
502                idx.get_max_nb_connection() as usize,
503                idx.get_ef_construction(),
504            )
505        };
506
507        // Get optimal parameters
508        let recommendation = ParameterTuner::recommend(start_size, self.dimension, use_case);
509
510        // Create new index with optimal parameters
511        let new_index = Hnsw::<f32, DistL2>::new(
512            recommendation.m,
513            self.dimension,
514            recommendation.ef_construction,
515            start_size,
516            DistL2 {},
517        );
518
519        // Replace the index
520        *self.index.write().unwrap() = new_index;
521
522        // Note: In a full implementation, we'd re-insert all vectors here
523        // This requires storing vectors separately, which we'll add if needed
524
525        Ok(RebuildStats {
526            vectors_reinserted: 0, // Would be start_size if we re-inserted
527            old_parameters: old_params,
528            new_parameters: (recommendation.m, recommendation.ef_construction),
529        })
530    }
531
532    /// Get statistics about incremental build performance
533    pub fn get_build_stats(&self) -> BuildHealthStats {
534        let size = self.len();
535        let (current_m, current_ef) = {
536            let idx = self.index.read().unwrap();
537            (
538                idx.get_max_nb_connection() as usize,
539                idx.get_ef_construction(),
540            )
541        };
542
543        let (optimal_m, optimal_ef) = self.compute_optimal_parameters();
544
545        let parameter_efficiency = if optimal_m > 0 {
546            (current_m as f32 / optimal_m as f32).min(1.0)
547        } else {
548            1.0
549        };
550
551        BuildHealthStats {
552            index_size: size,
553            current_m,
554            current_ef_construction: current_ef,
555            optimal_m,
556            optimal_ef_construction: optimal_ef,
557            parameter_efficiency,
558            rebuild_recommended: self.should_rebuild(),
559        }
560    }
561
562    /// Save the index to a file
563    ///
564    /// Saves the HNSW index and CID mappings to disk for later retrieval.
565    /// The index is saved in oxicode format.
566    ///
567    /// # Arguments
568    /// * `path` - Path to save the index to
569    pub fn save(&self, path: impl AsRef<std::path::Path>) -> Result<()> {
570        use std::fs::File;
571        use std::io::Write;
572
573        // Get HNSW parameters from the current index
574        let (max_nb_connection, ef_construction) = {
575            let idx = self.index.read().unwrap();
576            (idx.get_max_nb_connection(), idx.get_ef_construction())
577        };
578
579        // Serialize index metadata
580        let metadata = IndexMetadata {
581            dimension: self.dimension,
582            metric: self.metric,
583            id_to_cid: self.id_to_cid.read().unwrap().clone(),
584            cid_to_id: self.cid_to_id.read().unwrap().clone(),
585            vectors: self.vectors.read().unwrap().clone(),
586            next_id: *self.next_id.read().unwrap(),
587            max_nb_connection: max_nb_connection as usize,
588            ef_construction,
589        };
590
591        // Serialize to oxicode
592        let encoded = oxicode::serde::encode_to_vec(&metadata, oxicode::config::standard())
593            .map_err(|e| Error::Serialization(format!("Failed to serialize index: {}", e)))?;
594
595        // Write to file
596        let mut file = File::create(path.as_ref())
597            .map_err(|e| Error::Storage(format!("Failed to create index file: {}", e)))?;
598
599        file.write_all(&encoded)
600            .map_err(|e| Error::Storage(format!("Failed to write index file: {}", e)))?;
601
602        Ok(())
603    }
604
605    /// Load an index from a file
606    ///
607    /// Loads a previously saved index from disk.
608    ///
609    /// # Arguments
610    /// * `path` - Path to load the index from
611    pub fn load(path: impl AsRef<std::path::Path>) -> Result<Self> {
612        use std::fs::File;
613        use std::io::Read;
614
615        // Read file
616        let mut file = File::open(path.as_ref())
617            .map_err(|e| Error::Storage(format!("Failed to open index file: {}", e)))?;
618
619        let mut buffer = Vec::new();
620        file.read_to_end(&mut buffer)
621            .map_err(|e| Error::Storage(format!("Failed to read index file: {}", e)))?;
622
623        // Deserialize metadata
624        let metadata: IndexMetadata =
625            oxicode::serde::decode_owned_from_slice(&buffer, oxicode::config::standard())
626                .map(|(v, _)| v)
627                .map_err(|e| {
628                    Error::Deserialization(format!("Failed to deserialize index: {}", e))
629                })?;
630
631        // Create new HNSW index with saved parameters
632        let index = Hnsw::<f32, DistL2>::new(
633            metadata.max_nb_connection,
634            metadata.dimension,
635            metadata.ef_construction,
636            200,
637            DistL2 {},
638        );
639
640        Ok(Self {
641            index: Arc::new(RwLock::new(index)),
642            id_to_cid: Arc::new(RwLock::new(metadata.id_to_cid)),
643            cid_to_id: Arc::new(RwLock::new(metadata.cid_to_id)),
644            vectors: Arc::new(RwLock::new(metadata.vectors)),
645            next_id: Arc::new(RwLock::new(metadata.next_id)),
646            dimension: metadata.dimension,
647            metric: metadata.metric,
648        })
649    }
650}
651
652/// Index metadata for serialization
653#[derive(serde::Serialize, serde::Deserialize)]
654struct IndexMetadata {
655    dimension: usize,
656    metric: DistanceMetric,
657    #[serde(
658        serialize_with = "serialize_id_to_cid",
659        deserialize_with = "deserialize_id_to_cid"
660    )]
661    id_to_cid: HashMap<usize, Cid>,
662    #[serde(
663        serialize_with = "serialize_cid_to_id",
664        deserialize_with = "deserialize_cid_to_id"
665    )]
666    cid_to_id: HashMap<Cid, usize>,
667    #[serde(
668        serialize_with = "serialize_vectors",
669        deserialize_with = "deserialize_vectors"
670    )]
671    vectors: HashMap<Cid, Vec<f32>>,
672    next_id: usize,
673    max_nb_connection: usize,
674    ef_construction: usize,
675}
676
677/// Serialize HashMap<usize, Cid> by converting CIDs to strings
678fn serialize_id_to_cid<S>(
679    map: &HashMap<usize, Cid>,
680    serializer: S,
681) -> std::result::Result<S::Ok, S::Error>
682where
683    S: serde::Serializer,
684{
685    use serde::Serialize;
686    let string_map: HashMap<usize, String> =
687        map.iter().map(|(id, cid)| (*id, cid.to_string())).collect();
688    string_map.serialize(serializer)
689}
690
691/// Deserialize HashMap<usize, Cid> by parsing CID strings
692fn deserialize_id_to_cid<'de, D>(
693    deserializer: D,
694) -> std::result::Result<HashMap<usize, Cid>, D::Error>
695where
696    D: serde::Deserializer<'de>,
697{
698    use serde::Deserialize;
699    let string_map: HashMap<usize, String> = HashMap::deserialize(deserializer)?;
700    string_map
701        .into_iter()
702        .map(|(id, cid_str)| {
703            cid_str
704                .parse::<Cid>()
705                .map(|cid| (id, cid))
706                .map_err(serde::de::Error::custom)
707        })
708        .collect()
709}
710
711/// Serialize HashMap<Cid, usize> by converting CIDs to strings
712fn serialize_cid_to_id<S>(
713    map: &HashMap<Cid, usize>,
714    serializer: S,
715) -> std::result::Result<S::Ok, S::Error>
716where
717    S: serde::Serializer,
718{
719    use serde::Serialize;
720    let string_map: HashMap<String, usize> =
721        map.iter().map(|(cid, id)| (cid.to_string(), *id)).collect();
722    string_map.serialize(serializer)
723}
724
725/// Deserialize HashMap<Cid, usize> by parsing CID strings
726fn deserialize_cid_to_id<'de, D>(
727    deserializer: D,
728) -> std::result::Result<HashMap<Cid, usize>, D::Error>
729where
730    D: serde::Deserializer<'de>,
731{
732    use serde::Deserialize;
733    let string_map: HashMap<String, usize> = HashMap::deserialize(deserializer)?;
734    string_map
735        .into_iter()
736        .map(|(cid_str, id)| {
737            cid_str
738                .parse::<Cid>()
739                .map(|cid| (cid, id))
740                .map_err(serde::de::Error::custom)
741        })
742        .collect()
743}
744
745/// Serialize HashMap<Cid, Vec<f32>> by converting CIDs to strings
746fn serialize_vectors<S>(
747    map: &HashMap<Cid, Vec<f32>>,
748    serializer: S,
749) -> std::result::Result<S::Ok, S::Error>
750where
751    S: serde::Serializer,
752{
753    use serde::Serialize;
754    let string_map: HashMap<String, Vec<f32>> = map
755        .iter()
756        .map(|(cid, vec)| (cid.to_string(), vec.clone()))
757        .collect();
758    string_map.serialize(serializer)
759}
760
761/// Deserialize HashMap<Cid, Vec<f32>> by parsing CID strings
762fn deserialize_vectors<'de, D>(
763    deserializer: D,
764) -> std::result::Result<HashMap<Cid, Vec<f32>>, D::Error>
765where
766    D: serde::Deserializer<'de>,
767{
768    use serde::Deserialize;
769    let string_map: HashMap<String, Vec<f32>> = HashMap::deserialize(deserializer)?;
770    string_map
771        .into_iter()
772        .map(|(cid_str, vec)| {
773            cid_str
774                .parse::<Cid>()
775                .map(|cid| (cid, vec))
776                .map_err(serde::de::Error::custom)
777        })
778        .collect()
779}
780
781/// Use case for parameter optimization
782#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize, Default)]
783pub enum UseCase {
784    /// Optimize for low latency (faster queries, potentially lower recall)
785    LowLatency,
786    /// Optimize for high recall (more accurate results, potentially slower)
787    HighRecall,
788    /// Balanced performance (default)
789    #[default]
790    Balanced,
791    /// Optimize for memory efficiency
792    LowMemory,
793    /// Optimize for large scale (100k+ vectors)
794    LargeScale,
795}
796
797/// HNSW parameter recommendation
798#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
799pub struct ParameterRecommendation {
800    /// Recommended M parameter (connections per layer)
801    pub m: usize,
802    /// Recommended ef_construction parameter
803    pub ef_construction: usize,
804    /// Recommended ef_search parameter
805    pub ef_search: usize,
806    /// Estimated memory usage per vector (bytes)
807    pub memory_per_vector: usize,
808    /// Estimated recall at k=10
809    pub estimated_recall: f32,
810    /// Estimated query latency factor (1.0 = baseline)
811    pub latency_factor: f32,
812    /// Explanation of recommendations
813    pub explanation: String,
814}
815
/// Stateless namespace for HNSW parameter heuristics (recommendations, memory
/// estimates, ef_search sizing).
pub struct ParameterTuner;
818
819impl ParameterTuner {
820    /// Get parameter recommendations based on dataset size and use case
821    pub fn recommend(
822        num_vectors: usize,
823        dimension: usize,
824        use_case: UseCase,
825    ) -> ParameterRecommendation {
826        let (m, ef_construction, ef_search, recall, latency) = match use_case {
827            UseCase::LowLatency => {
828                if num_vectors < 10_000 {
829                    (8, 100, 32, 0.90, 0.6)
830                } else if num_vectors < 100_000 {
831                    (12, 150, 50, 0.88, 0.7)
832                } else {
833                    (16, 200, 64, 0.85, 0.8)
834                }
835            }
836            UseCase::HighRecall => {
837                if num_vectors < 10_000 {
838                    (32, 400, 200, 0.99, 2.0)
839                } else if num_vectors < 100_000 {
840                    (48, 500, 300, 0.98, 2.5)
841                } else {
842                    (64, 600, 400, 0.97, 3.0)
843                }
844            }
845            UseCase::Balanced => {
846                if num_vectors < 10_000 {
847                    (16, 200, 50, 0.95, 1.0)
848                } else if num_vectors < 100_000 {
849                    (24, 300, 100, 0.94, 1.2)
850                } else {
851                    (32, 400, 150, 0.93, 1.5)
852                }
853            }
854            UseCase::LowMemory => {
855                if num_vectors < 10_000 {
856                    (8, 100, 50, 0.88, 0.9)
857                } else if num_vectors < 100_000 {
858                    (10, 120, 64, 0.85, 1.0)
859                } else {
860                    (12, 150, 80, 0.82, 1.1)
861                }
862            }
863            UseCase::LargeScale => {
864                // Optimized for 100k+ vectors
865                (32, 400, 100, 0.93, 1.5)
866            }
867        };
868
869        // Memory per vector: dimension * 4 (f32) + M * 2 * 4 (graph links, assuming 2 layers avg)
870        let memory_per_vector = dimension * 4 + m * 2 * 4;
871
872        let explanation =
873            Self::generate_explanation(num_vectors, use_case, m, ef_construction, ef_search);
874
875        ParameterRecommendation {
876            m,
877            ef_construction,
878            ef_search,
879            memory_per_vector,
880            estimated_recall: recall,
881            latency_factor: latency,
882            explanation,
883        }
884    }
885
886    fn generate_explanation(
887        num_vectors: usize,
888        use_case: UseCase,
889        m: usize,
890        ef_construction: usize,
891        ef_search: usize,
892    ) -> String {
893        let size_category = if num_vectors < 10_000 {
894            "small"
895        } else if num_vectors < 100_000 {
896            "medium"
897        } else {
898            "large"
899        };
900
901        let use_case_str = match use_case {
902            UseCase::LowLatency => "low latency",
903            UseCase::HighRecall => "high recall",
904            UseCase::Balanced => "balanced",
905            UseCase::LowMemory => "low memory",
906            UseCase::LargeScale => "large scale",
907        };
908
909        format!(
910            "For {} dataset (~{} vectors) optimized for {}: \
911            M={} provides good connectivity, ef_construction={} ensures quality graph, \
912            ef_search={} balances speed and accuracy.",
913            size_category, num_vectors, use_case_str, m, ef_construction, ef_search
914        )
915    }
916
917    /// Calculate Pareto-optimal configurations for different recall/latency tradeoffs
918    pub fn pareto_configurations(
919        num_vectors: usize,
920        dimension: usize,
921    ) -> Vec<ParameterRecommendation> {
922        vec![
923            Self::recommend(num_vectors, dimension, UseCase::LowLatency),
924            Self::recommend(num_vectors, dimension, UseCase::LowMemory),
925            Self::recommend(num_vectors, dimension, UseCase::Balanced),
926            Self::recommend(num_vectors, dimension, UseCase::HighRecall),
927        ]
928    }
929
930    /// Estimate memory usage for given parameters
931    pub fn estimate_memory(num_vectors: usize, dimension: usize, m: usize) -> usize {
932        // Vector data: num_vectors * dimension * 4 bytes
933        let vector_memory = num_vectors * dimension * 4;
934
935        // Graph memory: num_vectors * M * 2 layers average * 4 bytes per link
936        let graph_memory = num_vectors * m * 2 * 4;
937
938        // Additional overhead (mappings, etc.): ~50 bytes per vector
939        let overhead = num_vectors * 50;
940
941        vector_memory + graph_memory + overhead
942    }
943
944    /// Suggest ef_search for target recall at given k
945    pub fn ef_search_for_recall(k: usize, target_recall: f32) -> usize {
946        // Higher ef_search improves recall
947        // Approximate: ef_search = k * (1 / (1 - target_recall))
948        let multiplier = if target_recall >= 0.99 {
949            10.0
950        } else if target_recall >= 0.95 {
951            4.0
952        } else if target_recall >= 0.90 {
953            2.0
954        } else {
955            1.5
956        };
957
958        ((k as f32) * multiplier).ceil() as usize
959    }
960}
961
#[cfg(test)]
mod tests {
    // Unit and quality tests for `VectorIndex` and `ParameterTuner`.
    use super::*;
    use rand::Rng;

    #[test]
    fn test_vector_index_creation() {
        let index = VectorIndex::with_defaults(128);
        assert!(index.is_ok());
        let index = index.unwrap();
        assert_eq!(index.dimension(), 128);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn test_insert_and_search() {
        let mut index = VectorIndex::with_defaults(4).unwrap();

        // Create some test vectors and CIDs
        let cid1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"
            .parse::<Cid>()
            .unwrap();
        let vec1 = vec![1.0, 0.0, 0.0, 0.0];

        let cid2 = "bafybeiczsscdsbs7ffqz55asqdf3smv6klcw3gofszvwlyarci47bgf354"
            .parse::<Cid>()
            .unwrap();
        let vec2 = vec![0.9, 0.1, 0.0, 0.0];

        // Insert vectors
        index.insert(&cid1, &vec1).unwrap();
        index.insert(&cid2, &vec2).unwrap();

        assert_eq!(index.len(), 2);

        // Search for nearest neighbor
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let results = index.search(&query, 1, 50).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].cid, cid1);
    }

    #[test]
    fn test_parameter_tuner() {
        // Test recommendations for different use cases
        let balanced = ParameterTuner::recommend(50_000, 768, UseCase::Balanced);
        assert!(balanced.m > 0);
        assert!(balanced.ef_construction > 0);
        assert!(balanced.estimated_recall > 0.0);

        let low_latency = ParameterTuner::recommend(50_000, 768, UseCase::LowLatency);
        let high_recall = ParameterTuner::recommend(50_000, 768, UseCase::HighRecall);

        // High recall should have higher M than low latency
        assert!(high_recall.m > low_latency.m);
        // High recall should have higher estimated recall
        assert!(high_recall.estimated_recall > low_latency.estimated_recall);

        // Test Pareto configurations
        let pareto = ParameterTuner::pareto_configurations(50_000, 768);
        assert_eq!(pareto.len(), 4);

        // Test memory estimation
        let memory = ParameterTuner::estimate_memory(100_000, 768, 16);
        assert!(memory > 0);

        // Test ef_search for recall
        let ef_high = ParameterTuner::ef_search_for_recall(10, 0.99);
        let ef_low = ParameterTuner::ef_search_for_recall(10, 0.85);
        assert!(ef_high > ef_low);
    }

    #[test]
    fn test_incremental_build() {
        let mut index = VectorIndex::with_defaults(4).unwrap();

        // Use hash-derived CIDs so every item is distinct. Hand-edited strings such as
        // "bafybei{i}..." are mostly invalid base32 (digits 0/1/8/9), and they would
        // silently collapse onto a single fallback CID.
        let items: Vec<(Cid, Vec<f32>)> = (0..20)
            .map(|i| (generate_test_cid(i + 3000), vec![i as f32, 0.0, 0.0, 0.0]))
            .collect();

        // Insert incrementally with chunk size 5: 20 items -> 4 chunks
        let stats = index.insert_incremental(&items, 5).unwrap();

        assert_eq!(stats.initial_size, 0);
        assert_eq!(stats.chunks_processed, 4);
        assert!(stats.vectors_inserted <= 20);
        assert_eq!(stats.final_size, index.len());
    }

    #[test]
    fn test_build_health_stats() {
        let index = VectorIndex::new(128, DistanceMetric::L2, 16, 200).unwrap();

        let stats = index.get_build_stats();
        assert_eq!(stats.index_size, 0);
        assert_eq!(stats.current_m, 16);
        assert_eq!(stats.current_ef_construction, 200);
        assert!(stats.parameter_efficiency > 0.0);

        // For small index with good parameters, no rebuild needed
        assert!(!stats.rebuild_recommended);
    }

    #[test]
    fn test_should_rebuild() {
        // Small index with good parameters - no rebuild needed
        let index1 = VectorIndex::new(128, DistanceMetric::L2, 16, 200).unwrap();
        assert!(!index1.should_rebuild());

        // Index with suboptimal parameters. An empty index may or may not trip the
        // size-based thresholds, so this only smoke-tests that the call succeeds.
        let index2 = VectorIndex::new(128, DistanceMetric::L2, 4, 50).unwrap();
        let _ = index2.should_rebuild();
    }

    #[test]
    fn test_rebuild() {
        let mut index = VectorIndex::with_defaults(4).unwrap();

        // Add some vectors with distinct CIDs
        for i in 0..10 {
            let cid = generate_test_cid(i + 4000);
            let vec = vec![i as f32, 0.0, 0.0, 0.0];
            let _ = index.insert(&cid, &vec);
        }

        // Rebuild with balanced use case
        let rebuild_stats = index.rebuild(UseCase::Balanced).unwrap();

        assert_eq!(rebuild_stats.old_parameters.0, 16); // Original M
        assert!(rebuild_stats.new_parameters.0 > 0); // New M
    }

    /// Compute ground-truth nearest neighbors (Euclidean distance) by brute force.
    fn compute_ground_truth(query: &[f32], vectors: &[(Cid, Vec<f32>)], k: usize) -> Vec<Cid> {
        let mut distances: Vec<(Cid, f32)> = vectors
            .iter()
            .map(|(cid, vec)| {
                let dist: f32 = query
                    .iter()
                    .zip(vec.iter())
                    .map(|(a, b)| (a - b).powi(2))
                    .sum();
                (*cid, dist.sqrt())
            })
            .collect();

        // `total_cmp` gives a total order, so a stray NaN cannot panic the sort.
        distances.sort_by(|a, b| a.1.total_cmp(&b.1));
        distances.iter().take(k).map(|(cid, _)| *cid).collect()
    }

    /// Calculate recall@k: the fraction of the true top-k found in the predicted top-k.
    fn calculate_recall_at_k(predicted: &[Cid], ground_truth: &[Cid], k: usize) -> f32 {
        let predicted_set: std::collections::HashSet<_> = predicted.iter().take(k).collect();
        let ground_truth_set: std::collections::HashSet<_> = ground_truth.iter().take(k).collect();

        let intersection = predicted_set.intersection(&ground_truth_set).count();
        intersection as f32 / k as f32
    }

    /// Deterministically derive a unique CIDv1 (raw codec, SHA2-256) from `index`.
    fn generate_test_cid(index: usize) -> Cid {
        use multihash_codetable::{Code, MultihashDigest};
        let data = format!("test_vector_{}", index);
        let hash = Code::Sha2_256.digest(data.as_bytes());
        Cid::new_v1(0x55, hash) // 0x55 = raw codec
    }

    #[test]
    fn test_recall_at_k() {
        let num_vectors = 100;
        let dimension = 128;

        let mut index = VectorIndex::with_defaults(dimension).unwrap();

        // Generate test dataset (100 random vectors)
        let mut rng = rand::rng();

        let mut vectors = Vec::new();
        for i in 0..num_vectors {
            let cid = generate_test_cid(i);

            let vec: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            vectors.push((cid, vec.clone()));
            let _ = index.insert(&cid, &vec);
        }

        // Test queries
        let num_queries = 10;
        let mut total_recall_at_1 = 0.0;
        let mut total_recall_at_10 = 0.0;

        for _ in 0..num_queries {
            let query: Vec<f32> = (0..dimension)
                .map(|_| rng.random_range(-1.0..1.0))
                .collect();

            // Get HNSW results
            let hnsw_results = index.search(&query, 10, 50).unwrap();
            let hnsw_cids: Vec<Cid> = hnsw_results.iter().map(|r| r.cid).collect();

            // Compute ground truth
            let ground_truth = compute_ground_truth(&query, &vectors, 10);

            // Calculate recall
            total_recall_at_1 += calculate_recall_at_k(&hnsw_cids, &ground_truth, 1);
            total_recall_at_10 += calculate_recall_at_k(&hnsw_cids, &ground_truth, 10);
        }

        let avg_recall_at_1 = total_recall_at_1 / num_queries as f32;
        let avg_recall_at_10 = total_recall_at_10 / num_queries as f32;

        // HNSW should have high recall (>80% for recall@10 on small dataset)
        assert!(
            avg_recall_at_10 > 0.8,
            "Recall@10 too low: {}",
            avg_recall_at_10
        );

        // Recall@1 should be reasonable
        assert!(
            avg_recall_at_1 > 0.5,
            "Recall@1 too low: {}",
            avg_recall_at_1
        );
    }

    #[test]
    fn test_concurrent_queries() {
        use std::sync::Arc;
        use std::thread;

        // Create index
        let mut index = VectorIndex::with_defaults(128).unwrap();

        // Insert test vectors
        let mut rng = rand::rng();
        for i in 0..100 {
            let cid = generate_test_cid(i + 1000); // Offset to avoid collision with other tests

            let vec: Vec<f32> = (0..128).map(|_| rng.random_range(-1.0..1.0)).collect();

            let _ = index.insert(&cid, &vec);
        }

        // Share index across threads
        let index = Arc::new(index);
        let num_threads = 10;
        let queries_per_thread = 100;

        // Spawn threads for concurrent queries
        let mut handles = vec![];
        for _ in 0..num_threads {
            let index_clone = Arc::clone(&index);
            let handle = thread::spawn(move || {
                let mut thread_rng = rand::rng();
                let mut success_count = 0;

                for _ in 0..queries_per_thread {
                    let query: Vec<f32> = (0..128)
                        .map(|_| thread_rng.random_range(-1.0..1.0))
                        .collect();

                    if let Ok(results) = index_clone.search(&query, 10, 50) {
                        if !results.is_empty() {
                            success_count += 1;
                        }
                    }
                }
                success_count
            });
            handles.push(handle);
        }

        // Collect results
        let mut total_success = 0;
        for handle in handles {
            total_success += handle.join().unwrap();
        }

        // All queries should succeed
        let total_queries = num_threads * queries_per_thread;
        assert_eq!(
            total_success, total_queries,
            "Some queries failed under concurrent load"
        );
    }

    #[test]
    fn test_precision_at_k() {
        let dimension = 32;
        let mut index = VectorIndex::with_defaults(dimension).unwrap();

        // Structured dataset: 5 clusters of 10 vectors each. Cluster `c` is centred
        // on 10.0 along axis `c`, so clusters are far apart relative to the +/-0.5
        // per-component noise.
        let num_clusters = 5;
        let vectors_per_cluster = 10;

        let mut rng = rand::rng();
        // CIDs are hash-derived and deterministic, so cluster membership can be
        // tracked explicitly.
        let mut cluster0_cids = std::collections::HashSet::new();

        for cluster in 0..num_clusters {
            // Cluster center
            let mut center = vec![0.0f32; dimension];
            center[cluster] = 10.0;

            for i in 0..vectors_per_cluster {
                let idx = cluster * vectors_per_cluster + i;
                let cid = generate_test_cid(idx + 2000); // Offset to avoid collision

                // Add small random noise to center
                let vec: Vec<f32> = center
                    .iter()
                    .map(|&c| c + rng.random_range(-0.5..0.5))
                    .collect();

                if cluster == 0 {
                    cluster0_cids.insert(cid);
                }
                let _ = index.insert(&cid, &vec);
            }
        }

        // Query with a vector at the center of cluster 0
        let mut query = vec![0.0; dimension];
        query[0] = 10.0;

        let results = index.search(&query, 10, 50).unwrap();

        assert_eq!(results.len(), 10, "Should return 10 results");

        // Every result should be close to the query and drawn from cluster 0
        for result in &results {
            assert!(
                result.score < 5.0,
                "Result too far from query: {}",
                result.score
            );
            assert!(
                cluster0_cids.contains(&result.cid),
                "Result {} is not from cluster 0",
                result.cid
            );
        }
    }
}