vicinity 0.3.1

Approximate Nearest Neighbor Search: HNSW, DiskANN, IVF-PQ, ScaNN, quantization
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
//! FusedANN: Attribute-vector fusion for filtered search.
//!
//! Instead of treating filters as hard constraints evaluated during traversal,
//! FusedANN fuses attribute predicates into the vector space via a Lagrangian-like
//! relaxation. Hard filters become soft penalties, enabling efficient approximate
//! search while preserving top-k semantics.
//!
//! # Algorithm
//!
//! 1. **Attribute Embedding**: Encode categorical/numeric attributes into vectors
//! 2. **Fusion**: Concatenate the scaled content and attribute vectors, then
//!    L2-normalize: `fused = normalize([alpha * content_vec, (1-alpha) * attribute_vec])`
//! 3. **Search**: Standard ANN search in fused space
//! 4. **Refinement**: Post-filter exact matches if needed
//!
//! # Key Insight
//!
//! When a filter is highly selective (only a few vectors match), traditional
//! pre-filtering degrades recall. FusedANN's soft penalties naturally bias search
//! toward matching regions while still exploring nearby non-matches for navigability.
//!
//! # References
//!
//! - Heidari et al. (2025): "FusedANN: Convexified Hybrid ANN via Attribute-Vector
//!   Fusion" - <https://arxiv.org/abs/2509.19767>

use crate::RetrieveError;
use std::collections::HashMap;

/// A concrete attribute value, attached to a stored vector or supplied in a filter.
#[derive(Debug, Clone, PartialEq)]
pub enum AttributeValue {
    /// Exactly one category label.
    Categorical(String),
    /// One numeric scalar.
    Numeric(f32),
    /// A numeric interval described by its two endpoints.
    NumericRange {
        /// Lower endpoint of the interval.
        min: f32,
        /// Upper endpoint of the interval.
        max: f32,
    },
    /// A true/false flag.
    Boolean(bool),
    /// Several category labels carried together.
    MultiCategory(Vec<String>),
}

/// Describes which attributes exist and how each one maps to embedding slots.
#[derive(Debug, Clone)]
pub struct AttributeSchema {
    /// Dimension value supplied at construction time.
    pub dimension: usize,
    /// Attribute definitions, kept in insertion order.
    pub attributes: Vec<AttributeDefinition>,
}

/// One named attribute together with its encoding and fusion weight.
#[derive(Debug, Clone)]
pub struct AttributeDefinition {
    /// Name under which the attribute is looked up.
    pub name: String,
    /// Encoding strategy and its parameters.
    pub attr_type: AttributeType,
    /// Fusion weight (a larger weight makes the attribute count more for filtering).
    pub weight: f32,
}

/// Encoding strategy for a single attribute.
#[derive(Debug, Clone)]
pub enum AttributeType {
    /// One-hot encoding: one slot per known category.
    Categorical {
        /// Category labels that each receive a slot.
        categories: Vec<String>,
    },
    /// A single slot holding a value normalized against `min` and `max`.
    Numeric {
        /// Value that normalizes to the low end.
        min: f32,
        /// Value that normalizes to the high end.
        max: f32,
    },
    /// A single slot for a boolean flag.
    Boolean,
}

impl AttributeSchema {
    /// Build an empty schema carrying the given `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self {
            attributes: Vec::new(),
            dimension,
        }
    }

    /// Append a one-hot encoded categorical attribute.
    pub fn add_categorical(&mut self, name: &str, categories: Vec<String>, weight: f32) {
        self.push_definition(name, AttributeType::Categorical { categories }, weight);
    }

    /// Append a numeric attribute normalized against `min` and `max`.
    pub fn add_numeric(&mut self, name: &str, min: f32, max: f32, weight: f32) {
        self.push_definition(name, AttributeType::Numeric { min, max }, weight);
    }

    /// Append a boolean attribute.
    pub fn add_boolean(&mut self, name: &str, weight: f32) {
        self.push_definition(name, AttributeType::Boolean, weight);
    }

    /// Total number of embedding slots required by all registered attributes.
    ///
    /// A categorical attribute needs one slot per category; numeric and boolean
    /// attributes need one slot each.
    pub fn attribute_embedding_dim(&self) -> usize {
        self.attributes
            .iter()
            .map(|definition| match &definition.attr_type {
                AttributeType::Categorical { categories } => categories.len(),
                AttributeType::Numeric { .. } | AttributeType::Boolean => 1,
            })
            .sum()
    }

    /// Shared tail of the `add_*` methods: store one definition at the end of the list.
    fn push_definition(&mut self, name: &str, attr_type: AttributeType, weight: f32) {
        let definition = AttributeDefinition {
            name: String::from(name),
            attr_type,
            weight,
        };
        self.attributes.push(definition);
    }
}

/// Turns attribute maps into fixed-length embedding vectors according to a schema.
pub struct AttributeEmbedder {
    /// Schema that fixes the slot layout of every embedding.
    schema: AttributeSchema,
    /// Cached `schema.attribute_embedding_dim()`, i.e. the length of every embedding.
    embedding_dim: usize,
}

impl AttributeEmbedder {
    /// Create embedder from schema.
    ///
    /// The embedding length is computed once here from the schema.
    pub fn new(schema: AttributeSchema) -> Self {
        let embedding_dim = schema.attribute_embedding_dim();
        Self {
            schema,
            embedding_dim,
        }
    }

    /// Embed attribute values into vector.
    ///
    /// Slots are laid out in schema order. For each schema attribute:
    ///
    /// - **Categorical**: a `Categorical` value sets the slot of its category to
    ///   `weight`; a `MultiCategory` value sets the slot of each known category to
    ///   `weight / number_of_labels`.
    /// - **Numeric**: a `Numeric` value is normalized with `(val - min) / (max - min)`,
    ///   clamped to `[0, 1]` and multiplied by `weight`.
    /// - **Boolean**: `true` sets the slot to `weight`, `false` to `0.0`.
    ///
    /// Missing attributes, values of a different variant and unknown category
    /// labels leave their slots at `0.0`.
    ///
    /// The returned vector always has length [`Self::embedding_dim`].
    pub fn embed(&self, attributes: &HashMap<String, AttributeValue>) -> Vec<f32> {
        let mut embedding = vec![0.0f32; self.embedding_dim];
        let mut offset = 0;

        for attr_def in &self.schema.attributes {
            let weight = attr_def.weight;
            // Look the value up once; every branch below only inspects it.
            let value = attributes.get(&attr_def.name);

            match &attr_def.attr_type {
                AttributeType::Categorical { categories } => {
                    match value {
                        Some(AttributeValue::Categorical(cat)) => {
                            if let Some(idx) = categories.iter().position(|c| c == cat) {
                                embedding[offset + idx] = weight;
                            }
                        }
                        Some(AttributeValue::MultiCategory(cats)) => {
                            // Spread the weight evenly over the supplied labels.
                            // With an empty list the share is never used.
                            let share = weight / cats.len() as f32;
                            for cat in cats {
                                if let Some(idx) = categories.iter().position(|c| c == cat) {
                                    embedding[offset + idx] = share;
                                }
                            }
                        }
                        _ => {}
                    }
                    offset += categories.len();
                }
                AttributeType::Numeric { min, max } => {
                    if let Some(AttributeValue::Numeric(val)) = value {
                        // Normalize to [0, 1]
                        let normalized = (val - min) / (max - min);
                        // `normalized` is NaN for a NaN input or for 0/0 (a degenerate
                        // range with `min == max == val`). `clamp` would pass the NaN
                        // through, and a single NaN slot poisons the fused vector and
                        // every distance computed from it, so keep the slot at zero.
                        if !normalized.is_nan() {
                            embedding[offset] = normalized.clamp(0.0, 1.0) * weight;
                        }
                    }
                    offset += 1;
                }
                AttributeType::Boolean => {
                    if let Some(AttributeValue::Boolean(b)) = value {
                        embedding[offset] = if *b { weight } else { 0.0 };
                    }
                    offset += 1;
                }
            }
        }

        embedding
    }

    /// Embed a filter query into vector.
    ///
    /// For filter queries, we want to maximize similarity to matching vectors.
    /// The filter is encoded exactly like stored attributes (this delegates to
    /// [`Self::embed`]), so filter values that `embed` does not encode contribute
    /// nothing to the query embedding.
    pub fn embed_filter(&self, filters: &HashMap<String, AttributeValue>) -> Vec<f32> {
        // Filter embedding is the same as data embedding
        // Similarity will be high when attributes match
        self.embed(filters)
    }

    /// Get embedding dimension.
    pub fn embedding_dim(&self) -> usize {
        self.embedding_dim
    }
}

/// Tunable parameters for fusion and search.
#[derive(Debug, Clone)]
pub struct FusedConfig {
    /// Fusion weight: content is scaled by `alpha`, attributes by `1 - alpha`.
    pub alpha: f32,
    /// Attribute-mismatch penalty (Lagrangian multiplier).
    pub lambda: f32,
    /// When set, search results are post-filtered for exact attribute matches.
    pub exact_filter: bool,
    /// Over-fetch multiplier for filtered searches (compensates for soft filtering).
    pub expansion_factor: f32,
}

impl Default for FusedConfig {
    /// Defaults: `alpha = 0.7`, `lambda = 1.0`, exact filtering on,
    /// `expansion_factor = 2.0`.
    fn default() -> Self {
        FusedConfig {
            exact_filter: true,
            expansion_factor: 2.0,
            lambda: 1.0,
            alpha: 0.7,
        }
    }
}

/// A content vector, its attribute embedding, and their precomputed fusion.
#[derive(Debug, Clone)]
pub struct FusedVector {
    /// Content vector, stored as given.
    pub content: Vec<f32>,
    /// Attribute embedding, stored as given.
    pub attributes: Vec<f32>,
    /// Fusion of the two, computed once at construction.
    pub fused: Vec<f32>,
}

impl FusedVector {
    /// Build a fused vector, computing the fusion of `content` and `attributes`
    /// with the given `alpha`.
    pub fn new(content: Vec<f32>, attributes: Vec<f32>, alpha: f32) -> Self {
        // Fields are evaluated in written order, so the borrows taken for the
        // fusion end before `content` and `attributes` are moved in.
        Self {
            fused: Self::compute_fusion(&content, &attributes, alpha),
            content,
            attributes,
        }
    }

    /// Concatenate `alpha * content` with `(1 - alpha) * attributes` and
    /// L2-normalize the result.
    ///
    /// Vectors whose norm is not above `1e-10` are returned unnormalized.
    fn compute_fusion(content: &[f32], attributes: &[f32], alpha: f32) -> Vec<f32> {
        let attribute_scale = 1.0 - alpha;

        let mut fused: Vec<f32> = content
            .iter()
            .map(|&c| c * alpha)
            .chain(attributes.iter().map(|&a| a * attribute_scale))
            .collect();

        let norm = fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            fused.iter_mut().for_each(|component| *component /= norm);
        }

        fused
    }

    /// Length of the fused vector (content length plus attribute length).
    pub fn fused_dim(&self) -> usize {
        self.fused.len()
    }
}

/// FusedANN index.
///
/// Vectors are kept in a flat list and searched exhaustively; the id of a vector
/// is its position in that list.
pub struct FusedIndex {
    /// Fusion and search parameters.
    config: FusedConfig,
    /// Encoder used for stored attributes and for query filters.
    embedder: AttributeEmbedder,
    /// Required length of every content vector.
    content_dim: usize,
    /// Stored fused vectors (for simple brute-force search)
    vectors: Vec<FusedVector>,
    /// Original attributes (for exact filtering)
    original_attributes: Vec<HashMap<String, AttributeValue>>,
}

impl FusedIndex {
    /// Create new, empty FusedANN index.
    pub fn new(embedder: AttributeEmbedder, content_dim: usize, config: FusedConfig) -> Self {
        Self {
            config,
            embedder,
            content_dim,
            vectors: Vec::new(),
            original_attributes: Vec::new(),
        }
    }

    /// Add vector with attributes and return its id.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::DimensionMismatch` when `content.len()` differs from
    /// the index's content dimension.
    pub fn add(
        &mut self,
        content: Vec<f32>,
        attributes: HashMap<String, AttributeValue>,
    ) -> Result<u32, RetrieveError> {
        if content.len() != self.content_dim {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: content.len(),
                doc_dim: self.content_dim,
            });
        }

        let attr_embedding = self.embedder.embed(&attributes);
        let fused = FusedVector::new(content, attr_embedding, self.config.alpha);

        // NOTE(review): the id is the insertion position truncated to `u32`; it
        // would wrap once the index holds more than `u32::MAX` vectors.
        let id = self.vectors.len() as u32;
        self.vectors.push(fused);
        self.original_attributes.push(attributes);

        Ok(id)
    }

    /// Search with optional attribute filter.
    ///
    /// Returns at most `k` `(id, distance)` pairs sorted by ascending fused-space
    /// distance. When a filter is given, the candidate list is over-fetched by
    /// `expansion_factor` and, if `exact_filter` is enabled, candidates whose
    /// stored attributes do not satisfy the filter are dropped.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::DimensionMismatch` when `query_content.len()` differs
    /// from the index's content dimension.
    pub fn search(
        &self,
        query_content: &[f32],
        query_filter: Option<&HashMap<String, AttributeValue>>,
        k: usize,
    ) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if query_content.len() != self.content_dim {
            return Err(RetrieveError::DimensionMismatch {
                query_dim: query_content.len(),
                doc_dim: self.content_dim,
            });
        }

        // Create query fused vector; without a filter the attribute part is all zeros.
        let query_attrs = match query_filter {
            Some(filter) => self.embedder.embed_filter(filter),
            None => vec![0.0; self.embedder.embedding_dim()],
        };

        let query_fused = FusedVector::new(query_content.to_vec(), query_attrs, self.config.alpha);

        // Search limit (expanded to account for filtering).
        // The float-to-usize cast saturates and maps NaN to 0, so an expansion
        // factor below 1.0, negative or NaN would otherwise shrink the candidate
        // list under `k` and silently drop valid hits. Never fetch fewer than `k`.
        let search_k = if query_filter.is_some() {
            ((k as f32 * self.config.expansion_factor) as usize).max(k)
        } else {
            k
        };

        // Compute distances to all vectors
        let mut candidates: Vec<(u32, f32)> = self
            .vectors
            .iter()
            .enumerate()
            .map(|(idx, vec)| {
                let dist = fused_distance(&query_fused.fused, &vec.fused, self.config.lambda);
                (idx as u32, dist)
            })
            .collect();

        // Sort by distance
        candidates.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        candidates.truncate(search_k);

        // Post-filter if exact filtering is enabled
        if self.config.exact_filter {
            if let Some(filter) = query_filter {
                candidates.retain(|(id, _)| {
                    let attrs = &self.original_attributes[*id as usize];
                    check_filter(attrs, filter)
                });
            }
        }

        candidates.truncate(k);
        Ok(candidates)
    }

    /// Get number of vectors in index.
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Check if index is empty.
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }
}

/// Lambda-scaled Euclidean distance between two fused vectors.
///
/// Returns `f32::INFINITY` when the two slices differ in length.
fn fused_distance(query: &[f32], candidate: &[f32], lambda: f32) -> f32 {
    if query.len() != candidate.len() {
        return f32::INFINITY;
    }

    // Accumulate squared component differences left to right.
    let squared_sum = query
        .iter()
        .zip(candidate)
        .fold(0.0f32, |acc, (q, c)| {
            let delta = q - c;
            acc + delta * delta
        });

    // The Euclidean norm is multiplied by lambda as the final step.
    squared_sum.sqrt() * lambda
}

/// Decide whether stored `attributes` satisfy every entry of `filter`.
///
/// Each filter key must be present in `attributes`, and its value must match under
/// one of these pairings (filter value vs. stored value):
///
/// - `Categorical` vs. `Categorical`: equal labels
/// - `Numeric` vs. `Numeric`: absolute difference below `1e-6`
/// - `Boolean` vs. `Boolean`: equal flags
/// - `NumericRange` vs. `Numeric`: stored value within `[min, max]`, bounds included
/// - `Categorical` vs. `MultiCategory`: stored labels contain the filter label
///
/// Any other pairing counts as a mismatch. An empty filter matches everything.
fn check_filter(
    attributes: &HashMap<String, AttributeValue>,
    filter: &HashMap<String, AttributeValue>,
) -> bool {
    filter.iter().all(|(key, wanted)| match attributes.get(key) {
        // Missing required attribute
        None => false,
        Some(actual) => match (wanted, actual) {
            (AttributeValue::Categorical(f), AttributeValue::Categorical(a)) => f == a,
            (AttributeValue::Numeric(f), AttributeValue::Numeric(a)) => (f - a).abs() < 1e-6,
            (AttributeValue::Boolean(f), AttributeValue::Boolean(a)) => f == a,
            (AttributeValue::NumericRange { min, max }, AttributeValue::Numeric(a)) => {
                a >= min && a <= max
            }
            (AttributeValue::Categorical(f), AttributeValue::MultiCategory(cats)) => {
                cats.contains(f)
            }
            _ => false,
        },
    })
}

/// Suggest a fusion weight `alpha` from the estimated filter selectivity.
///
/// The expected number of matches is `estimated_selectivity * total_docs`:
///
/// - more than `10 * k` expected matches: `0.9` (content-focused)
/// - more than `k` expected matches: `0.7` (balanced)
/// - otherwise: `0.5` (attributes weighted heavily)
pub fn recommend_alpha(estimated_selectivity: f32, k: usize, total_docs: usize) -> f32 {
    let wanted = k as f32;
    let expected_matches = estimated_selectivity * total_docs as f32;

    if expected_matches > wanted * 10.0 {
        // Plenty of matches: content can dominate.
        0.9
    } else if expected_matches > wanted {
        // Roughly enough matches: balance content and attributes.
        0.7
    } else {
        // No more matches than requested: lean on the attributes.
        0.5
    }
}

// Unit tests for attribute embedding, fusion, filtered search and alpha selection.
// `unwrap`/`expect` are allowed here: a panic is the intended failure mode in tests.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Schema shared by the tests: a 3-way categorical, one numeric and one boolean
    /// attribute (5 embedding slots in total).
    fn create_test_schema() -> AttributeSchema {
        let mut schema = AttributeSchema::new(64);
        schema.add_categorical("category", vec!["A".into(), "B".into(), "C".into()], 1.0);
        schema.add_numeric("year", 2000.0, 2025.0, 0.5);
        schema.add_boolean("premium", 0.3);
        schema
    }

    /// Each attribute type lands in its own slot with the expected weighted value.
    #[test]
    fn test_attribute_embedding() {
        let schema = create_test_schema();
        let embedder = AttributeEmbedder::new(schema);

        let mut attrs = HashMap::new();
        attrs.insert(
            "category".to_string(),
            AttributeValue::Categorical("B".to_string()),
        );
        attrs.insert("year".to_string(), AttributeValue::Numeric(2020.0));
        attrs.insert("premium".to_string(), AttributeValue::Boolean(true));

        let embedding = embedder.embed(&attrs);

        // Should have dimensions for: 3 categories + 1 numeric + 1 boolean = 5
        assert_eq!(embedder.embedding_dim(), 5);
        assert_eq!(embedding.len(), 5);

        // Category B (index 1) should be 1.0
        assert_eq!(embedding[1], 1.0);

        // Year normalized: (2020 - 2000) / (2025 - 2000) = 0.8, weighted by 0.5
        assert!((embedding[3] - 0.4).abs() < 0.01);

        // Premium true, weighted by 0.3
        assert_eq!(embedding[4], 0.3);
    }

    /// The fused vector is the concatenation of both parts and has unit norm.
    #[test]
    fn test_fused_vector() {
        let content = vec![1.0, 0.0, 0.0];
        let attrs = vec![0.0, 1.0];

        let fused = FusedVector::new(content, attrs, 0.7);

        // Fused dimension = content (3) + attrs (2) = 5
        assert_eq!(fused.fused_dim(), 5);

        // Should be normalized
        let norm: f32 = fused.fused.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    /// Adding vectors grows the index.
    #[test]
    fn test_fused_index_basic() {
        let schema = create_test_schema();
        let embedder = AttributeEmbedder::new(schema);
        let config = FusedConfig::default();

        let mut index = FusedIndex::new(embedder, 4, config);

        // Add vectors
        let mut attrs1 = HashMap::new();
        attrs1.insert(
            "category".to_string(),
            AttributeValue::Categorical("A".to_string()),
        );
        index.add(vec![1.0, 0.0, 0.0, 0.0], attrs1).unwrap();

        let mut attrs2 = HashMap::new();
        attrs2.insert(
            "category".to_string(),
            AttributeValue::Categorical("B".to_string()),
        );
        index.add(vec![0.0, 1.0, 0.0, 0.0], attrs2).unwrap();

        assert_eq!(index.len(), 2);
    }

    /// An unfiltered search returns between 1 and `k` results.
    #[test]
    fn test_fused_search_no_filter() {
        let schema = create_test_schema();
        let embedder = AttributeEmbedder::new(schema);
        let mut index = FusedIndex::new(embedder, 4, FusedConfig::default());

        // Add vectors
        for i in 0..10 {
            let mut attrs = HashMap::new();
            attrs.insert(
                "category".to_string(),
                AttributeValue::Categorical(if i % 2 == 0 { "A" } else { "B" }.to_string()),
            );
            let mut content = vec![0.0; 4];
            content[i % 4] = 1.0;
            index.add(content, attrs).unwrap();
        }

        // Search without filter
        let results = index.search(&[1.0, 0.0, 0.0, 0.0], None, 3).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 3);
    }

    /// With the default config (exact filtering on), every hit satisfies the filter.
    #[test]
    fn test_fused_search_with_filter() {
        let schema = create_test_schema();
        let embedder = AttributeEmbedder::new(schema);
        let mut index = FusedIndex::new(embedder, 4, FusedConfig::default());

        // Add vectors - half A, half B
        for i in 0..10 {
            let mut attrs = HashMap::new();
            let cat = if i < 5 { "A" } else { "B" };
            attrs.insert(
                "category".to_string(),
                AttributeValue::Categorical(cat.to_string()),
            );
            let mut content = vec![0.0; 4];
            content[i % 4] = 1.0;
            index.add(content, attrs).unwrap();
        }

        // Search with filter for category A
        let mut filter = HashMap::new();
        filter.insert(
            "category".to_string(),
            AttributeValue::Categorical("A".to_string()),
        );

        let results = index
            .search(&[1.0, 0.0, 0.0, 0.0], Some(&filter), 3)
            .unwrap();

        // All results should be category A (ids 0..5 were inserted with category A)
        for (id, _) in &results {
            assert!(*id < 5, "ID {} should be in category A", id);
        }
    }

    /// Exact-match, mismatch and range filters behave as expected.
    #[test]
    fn test_check_filter() {
        let mut attrs = HashMap::new();
        attrs.insert(
            "cat".to_string(),
            AttributeValue::Categorical("A".to_string()),
        );
        attrs.insert("year".to_string(), AttributeValue::Numeric(2020.0));

        // Matching filter
        let mut filter = HashMap::new();
        filter.insert(
            "cat".to_string(),
            AttributeValue::Categorical("A".to_string()),
        );
        assert!(check_filter(&attrs, &filter));

        // Non-matching filter
        let mut filter2 = HashMap::new();
        filter2.insert(
            "cat".to_string(),
            AttributeValue::Categorical("B".to_string()),
        );
        assert!(!check_filter(&attrs, &filter2));

        // Range filter
        let mut filter3 = HashMap::new();
        filter3.insert(
            "year".to_string(),
            AttributeValue::NumericRange {
                min: 2015.0,
                max: 2025.0,
            },
        );
        assert!(check_filter(&attrs, &filter3));
    }

    /// The recommended alpha stays within the expected bound for each selectivity tier.
    #[test]
    fn test_recommend_alpha() {
        // Selective filter: expected_matches = 0.01 * 10000 = about 100, k * 10 = 100.
        // 100 is not strictly greater than 100, so the 0.9 branch is skipped.
        let alpha1 = recommend_alpha(0.01, 10, 10000);
        // expected_matches (about 100) > k (10) selects 0.7; only the upper bound
        // is asserted here.
        assert!(alpha1 <= 0.7);

        // Loose filter (many matches)
        let alpha2 = recommend_alpha(0.9, 10, 10000);
        assert!(alpha2 >= 0.7);

        // Very few matches: expected matches do not exceed k, which targets the 0.5
        // case; only the upper bound is asserted here.
        let alpha3 = recommend_alpha(0.001, 10, 10000); // expected = 10, k = 10
        assert!(alpha3 <= 0.7);
    }

    /// A categorical filter matches a stored multi-category value containing the label.
    #[test]
    fn test_multi_category_filter() {
        let mut attrs = HashMap::new();
        attrs.insert(
            "tags".to_string(),
            AttributeValue::MultiCategory(vec!["rust".to_string(), "python".to_string()]),
        );

        let mut filter = HashMap::new();
        filter.insert(
            "tags".to_string(),
            AttributeValue::Categorical("rust".to_string()),
        );

        assert!(check_filter(&attrs, &filter));
    }
}