Skip to main content

hermes_core/structures/vector/ivf/
soar.rs

//! SOAR: Spilling with Orthogonality-Amplified Residuals
//!
//! Implementation of Google's SOAR algorithm for improved IVF recall:
//! - Assigns vectors to multiple clusters (primary + secondary)
//! - Secondary clusters are chosen so their residuals are orthogonal to the primary residual
//! - When a query is nearly parallel to the primary residual (where the primary assignment's
//!   error is highest), the secondary assignment has low error
//!
//! Reference: "SOAR: New algorithms for even faster vector search with ScaNN"
//! <https://research.google/blog/soar-new-algorithms-for-even-faster-vector-search-with-scann/>
10
11use serde::{Deserialize, Serialize};
12
13/// Configuration for SOAR (Spilling with Orthogonality-Amplified Residuals)
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct SoarConfig {
16    /// Number of secondary cluster assignments (typically 1-2)
17    pub num_secondary: usize,
18    /// Use selective spilling (only spill vectors near cluster boundaries)
19    pub selective: bool,
20    /// Threshold for selective spilling (residual norm must exceed this)
21    pub spill_threshold: f32,
22}
23
24impl Default for SoarConfig {
25    fn default() -> Self {
26        Self {
27            num_secondary: 1,
28            selective: true,
29            spill_threshold: 0.5,
30        }
31    }
32}
33
34impl SoarConfig {
35    /// Create SOAR config with 1 secondary assignment
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Create SOAR config with specified number of secondary assignments
41    pub fn with_secondary(num_secondary: usize) -> Self {
42        Self {
43            num_secondary,
44            ..Default::default()
45        }
46    }
47
48    /// Enable/disable selective spilling
49    pub fn selective(mut self, enabled: bool) -> Self {
50        self.selective = enabled;
51        self
52    }
53
54    /// Set spill threshold for selective spilling
55    pub fn threshold(mut self, threshold: f32) -> Self {
56        self.spill_threshold = threshold;
57        self
58    }
59
60    /// Full spilling (no selectivity) - assigns all vectors to secondary clusters
61    pub fn full() -> Self {
62        Self {
63            num_secondary: 1,
64            selective: false,
65            spill_threshold: 0.0,
66        }
67    }
68
69    /// Aggressive spilling with 2 secondary clusters
70    pub fn aggressive() -> Self {
71        Self {
72            num_secondary: 2,
73            selective: false,
74            spill_threshold: 0.0,
75        }
76    }
77}
78
/// Multi-cluster assignment result from SOAR.
///
/// Holds the primary cluster id plus zero or more secondary (spilled) cluster ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultiAssignment {
    /// Primary cluster (nearest centroid)
    pub primary_cluster: u32,
    /// Secondary clusters (orthogonal residuals)
    pub secondary_clusters: Vec<u32>,
}

impl MultiAssignment {
    /// Create an assignment with only a primary cluster (not spilled).
    #[must_use]
    pub fn primary_only(cluster: u32) -> Self {
        Self {
            primary_cluster: cluster,
            secondary_clusters: Vec::new(),
        }
    }

    /// Iterate over all assigned clusters: the primary first, then the
    /// secondaries in their stored order.
    pub fn all_clusters(&self) -> impl Iterator<Item = u32> + '_ {
        std::iter::once(self.primary_cluster).chain(self.secondary_clusters.iter().copied())
    }

    /// Total number of cluster assignments (`1 + secondary_clusters.len()`).
    pub fn num_assignments(&self) -> usize {
        1 + self.secondary_clusters.len()
    }

    /// Returns `true` if this assignment has at least one secondary cluster.
    pub fn is_spilled(&self) -> bool {
        !self.secondary_clusters.is_empty()
    }
}
112
113/// Statistics for SOAR assignments
114#[allow(dead_code)]
115#[derive(Debug, Clone, Default)]
116pub struct SoarStats {
117    /// Total vectors assigned
118    pub total_vectors: usize,
119    /// Vectors with secondary assignments (spilled)
120    pub spilled_vectors: usize,
121    /// Total cluster assignments (including secondary)
122    pub total_assignments: usize,
123}
124
125#[allow(dead_code)]
126impl SoarStats {
127    pub fn new() -> Self {
128        Self::default()
129    }
130
131    /// Record an assignment
132    pub fn record(&mut self, assignment: &MultiAssignment) {
133        self.total_vectors += 1;
134        self.total_assignments += assignment.num_assignments();
135        if assignment.is_spilled() {
136            self.spilled_vectors += 1;
137        }
138    }
139
140    /// Spill ratio (fraction of vectors with secondary assignments)
141    pub fn spill_ratio(&self) -> f32 {
142        if self.total_vectors == 0 {
143            0.0
144        } else {
145            self.spilled_vectors as f32 / self.total_vectors as f32
146        }
147    }
148
149    /// Average assignments per vector
150    pub fn avg_assignments(&self) -> f32 {
151        if self.total_vectors == 0 {
152            0.0
153        } else {
154            self.total_assignments as f32 / self.total_vectors as f32
155        }
156    }
157
158    /// Storage overhead factor (1.0 = no overhead, 2.0 = 2x storage)
159    pub fn storage_factor(&self) -> f32 {
160        self.avg_assignments()
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_soar_config_default() {
        let config = SoarConfig::default();
        assert_eq!(config.num_secondary, 1);
        assert!(config.selective);
        assert_eq!(config.spill_threshold, 0.5);
    }

    #[test]
    fn test_soar_config_builders() {
        let config = SoarConfig::with_secondary(3).selective(false).threshold(0.25);
        assert_eq!(config.num_secondary, 3);
        assert!(!config.selective);
        assert_eq!(config.spill_threshold, 0.25);

        // `new` yields the same values as `Default`.
        let config = SoarConfig::new();
        assert_eq!(config.num_secondary, 1);
        assert!(config.selective);
        assert_eq!(config.spill_threshold, 0.5);
    }

    #[test]
    fn test_soar_config_presets() {
        let full = SoarConfig::full();
        assert_eq!(full.num_secondary, 1);
        assert!(!full.selective);
        assert_eq!(full.spill_threshold, 0.0);

        let aggressive = SoarConfig::aggressive();
        assert_eq!(aggressive.num_secondary, 2);
        assert!(!aggressive.selective);
        assert_eq!(aggressive.spill_threshold, 0.0);
    }

    #[test]
    fn test_multi_assignment() {
        let assignment = MultiAssignment {
            primary_cluster: 5,
            secondary_clusters: vec![2, 7],
        };

        assert_eq!(assignment.num_assignments(), 3);
        assert!(assignment.is_spilled());

        let all: Vec<u32> = assignment.all_clusters().collect();
        assert_eq!(all, vec![5, 2, 7]);
    }

    #[test]
    fn test_primary_only_assignment() {
        let assignment = MultiAssignment::primary_only(4);
        assert_eq!(assignment.num_assignments(), 1);
        assert!(!assignment.is_spilled());

        let all: Vec<u32> = assignment.all_clusters().collect();
        assert_eq!(all, vec![4]);
    }

    #[test]
    fn test_soar_stats() {
        let mut stats = SoarStats::new();

        // Primary only assignment
        stats.record(&MultiAssignment::primary_only(0));

        // Spilled assignment
        stats.record(&MultiAssignment {
            primary_cluster: 1,
            secondary_clusters: vec![2],
        });

        assert_eq!(stats.total_vectors, 2);
        assert_eq!(stats.spilled_vectors, 1);
        assert_eq!(stats.total_assignments, 3);
        assert_eq!(stats.spill_ratio(), 0.5);
        assert_eq!(stats.avg_assignments(), 1.5);
        assert_eq!(stats.storage_factor(), 1.5);
    }

    #[test]
    fn test_soar_stats_empty() {
        // No recorded vectors: ratios must hit the zero guard, not divide by zero.
        let stats = SoarStats::new();
        assert_eq!(stats.total_vectors, 0);
        assert_eq!(stats.spill_ratio(), 0.0);
        assert_eq!(stats.avg_assignments(), 0.0);
        assert_eq!(stats.storage_factor(), 0.0);
    }
}