1use serde::{Deserialize, Serialize};
4
5#[derive(
7 Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
8)]
9pub enum PruningStrategy {
10 Alpha,
12 Robust,
14 Hybrid,
16}
17
18#[derive(
20 Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
21)]
22pub enum SearchMode {
23 InMemory,
25 Streaming,
27 Cached,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
33pub struct DiskAnnConfig {
34 pub dimension: usize,
36
37 pub max_degree: usize,
39
40 pub build_beam_width: usize,
42
43 pub search_beam_width: usize,
45
46 pub alpha: f32,
48
49 pub pruning_strategy: PruningStrategy,
51
52 pub search_mode: SearchMode,
54
55 pub max_vectors_in_memory: Option<usize>,
57
58 pub use_pq_compression: bool,
60
61 pub pq_subvectors: Option<usize>,
63
64 pub pq_bits: Option<u8>,
66
67 pub enable_incremental_updates: bool,
69
70 pub num_entry_points: usize,
72
73 pub io_buffer_size: usize,
75}
76
77impl DiskAnnConfig {
78 pub fn default_config(dimension: usize) -> Self {
80 Self {
81 dimension,
82 max_degree: 64,
83 build_beam_width: 100,
84 search_beam_width: 75,
85 alpha: 1.2,
86 pruning_strategy: PruningStrategy::Robust,
87 search_mode: SearchMode::Cached,
88 max_vectors_in_memory: Some(100_000),
89 use_pq_compression: false,
90 pq_subvectors: None,
91 pq_bits: None,
92 enable_incremental_updates: false,
93 num_entry_points: 1,
94 io_buffer_size: 1 << 20, }
96 }
97
98 pub fn memory_optimized(dimension: usize) -> Self {
100 Self {
101 dimension,
102 max_degree: 32,
103 build_beam_width: 75,
104 search_beam_width: 50,
105 alpha: 1.2,
106 pruning_strategy: PruningStrategy::Robust,
107 search_mode: SearchMode::Streaming,
108 max_vectors_in_memory: Some(10_000),
109 use_pq_compression: true,
110 pq_subvectors: Some(dimension / 16),
111 pq_bits: Some(8),
112 enable_incremental_updates: false,
113 num_entry_points: 1,
114 io_buffer_size: 512 * 1024, }
116 }
117
118 pub fn speed_optimized(dimension: usize) -> Self {
120 Self {
121 dimension,
122 max_degree: 96,
123 build_beam_width: 150,
124 search_beam_width: 100,
125 alpha: 1.2,
126 pruning_strategy: PruningStrategy::Alpha,
127 search_mode: SearchMode::InMemory,
128 max_vectors_in_memory: Some(1_000_000),
129 use_pq_compression: false,
130 pq_subvectors: None,
131 pq_bits: None,
132 enable_incremental_updates: true,
133 num_entry_points: 4,
134 io_buffer_size: 4 << 20, }
136 }
137
138 pub fn billion_scale(dimension: usize) -> Self {
140 Self {
141 dimension,
142 max_degree: 64,
143 build_beam_width: 100,
144 search_beam_width: 64,
145 alpha: 1.2,
146 pruning_strategy: PruningStrategy::Robust,
147 search_mode: SearchMode::Streaming,
148 max_vectors_in_memory: Some(50_000),
149 use_pq_compression: true,
150 pq_subvectors: Some(dimension / 8),
151 pq_bits: Some(8),
152 enable_incremental_updates: false,
153 num_entry_points: 8,
154 io_buffer_size: 2 << 20, }
156 }
157
158 pub fn validate(&self) -> Result<(), String> {
160 if self.dimension == 0 {
161 return Err("Dimension must be greater than 0".to_string());
162 }
163
164 if self.max_degree == 0 {
165 return Err("Max degree must be greater than 0".to_string());
166 }
167
168 if self.build_beam_width == 0 {
169 return Err("Build beam width must be greater than 0".to_string());
170 }
171
172 if self.search_beam_width == 0 {
173 return Err("Search beam width must be greater than 0".to_string());
174 }
175
176 if self.alpha <= 0.0 {
177 return Err("Alpha must be positive".to_string());
178 }
179
180 if self.use_pq_compression {
181 if self.pq_subvectors.is_none() {
182 return Err(
183 "PQ subvectors must be specified when compression is enabled".to_string(),
184 );
185 }
186 if self.pq_bits.is_none() {
187 return Err("PQ bits must be specified when compression is enabled".to_string());
188 }
189
190 let pq_subvectors = self.pq_subvectors.unwrap();
191 if self.dimension % pq_subvectors != 0 {
192 return Err(format!(
193 "Dimension {} must be divisible by PQ subvectors {}",
194 self.dimension, pq_subvectors
195 ));
196 }
197
198 let pq_bits = self.pq_bits.unwrap();
199 if pq_bits == 0 || pq_bits > 16 {
200 return Err("PQ bits must be between 1 and 16".to_string());
201 }
202 }
203
204 if self.num_entry_points == 0 {
205 return Err("Number of entry points must be greater than 0".to_string());
206 }
207
208 if self.io_buffer_size == 0 {
209 return Err("IO buffer size must be greater than 0".to_string());
210 }
211
212 Ok(())
213 }
214
215 pub fn estimate_memory_usage(&self, num_vectors: usize) -> usize {
217 let graph_memory = num_vectors * (4 + self.max_degree * 4);
219
220 let vector_memory = if self.use_pq_compression {
222 let pq_subvectors = self.pq_subvectors.unwrap_or(self.dimension / 8);
223 let pq_bits = self.pq_bits.unwrap_or(8);
224 let bytes_per_code = (pq_bits as usize + 7) / 8;
225 num_vectors * pq_subvectors * bytes_per_code
226 } else {
227 num_vectors * self.dimension * 4 };
229
230 let inmem_vectors = self
232 .max_vectors_in_memory
233 .unwrap_or(num_vectors)
234 .min(num_vectors);
235 let inmem_memory = inmem_vectors * self.dimension * 4;
236
237 graph_memory + vector_memory + inmem_memory + self.io_buffer_size
238 }
239}
240
241impl Default for DiskAnnConfig {
242 fn default() -> Self {
243 Self::default_config(128)
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// The general-purpose preset exposes the expected values and validates.
    #[test]
    fn test_default_config() {
        let cfg = DiskAnnConfig::default_config(128);
        assert_eq!((cfg.dimension, cfg.max_degree), (128, 64));
        assert!(cfg.validate().is_ok());
    }

    /// The low-memory preset enables PQ compression and streaming search.
    #[test]
    fn test_memory_optimized() {
        let cfg = DiskAnnConfig::memory_optimized(256);
        assert_eq!(cfg.dimension, 256);
        assert!(cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::Streaming);
        assert!(cfg.validate().is_ok());
    }

    /// The high-throughput preset disables PQ compression and searches in memory.
    #[test]
    fn test_speed_optimized() {
        let cfg = DiskAnnConfig::speed_optimized(512);
        assert_eq!(cfg.dimension, 512);
        assert!(!cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::InMemory);
        assert!(cfg.validate().is_ok());
    }

    /// The very-large-dataset preset enables PQ compression and streaming search.
    #[test]
    fn test_billion_scale() {
        let cfg = DiskAnnConfig::billion_scale(768);
        assert_eq!(cfg.dimension, 768);
        assert!(cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::Streaming);
        assert!(cfg.validate().is_ok());
    }

    /// Each single-field corruption of a valid config is rejected.
    #[test]
    fn test_validation() {
        assert!(DiskAnnConfig::default_config(128).validate().is_ok());

        let zero_dimension = DiskAnnConfig {
            dimension: 0,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(zero_dimension.validate().is_err());

        let zero_degree = DiskAnnConfig {
            max_degree: 0,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(zero_degree.validate().is_err());

        // Compression switched on without the PQ parameters it requires.
        let pq_without_params = DiskAnnConfig {
            use_pq_compression: true,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(pq_without_params.validate().is_err());
    }

    /// The PQ-compressed preset is estimated to need less memory than the default.
    #[test]
    fn test_memory_estimation() {
        const NUM_VECTORS: usize = 1_000_000;

        let memory = DiskAnnConfig::default_config(128).estimate_memory_usage(NUM_VECTORS);
        assert!(memory > 0);

        let pq_memory = DiskAnnConfig::memory_optimized(128).estimate_memory_usage(NUM_VECTORS);
        assert!(pq_memory < memory);
    }

    /// PQ parameters must divide the dimension and stay within the bit range.
    #[test]
    fn test_pq_validation() {
        let with_pq = |subvectors: usize, bits: u8| DiskAnnConfig {
            use_pq_compression: true,
            pq_subvectors: Some(subvectors),
            pq_bits: Some(bits),
            ..DiskAnnConfig::default_config(128)
        };

        assert!(with_pq(16, 8).validate().is_ok());

        // 128 is not divisible by 15.
        assert!(with_pq(15, 8).validate().is_err());

        // Bit widths outside 1..=16.
        assert!(with_pq(16, 0).validate().is_err());
        assert!(with_pq(16, 20).validate().is_err());
    }
}