1use serde::{Deserialize, Serialize};
4
5#[derive(
7 Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
8)]
9pub enum PruningStrategy {
10 Alpha,
12 Robust,
14 Hybrid,
16}
17
18#[derive(
20 Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
21)]
22pub enum SearchMode {
23 InMemory,
25 Streaming,
27 Cached,
29}
30
/// Configuration for a DiskANN-style index.
///
/// Use one of the preset constructors (`default_config`, `memory_optimized`,
/// `speed_optimized`, `billion_scale`) and call `validate` before use.
#[derive(Debug, Clone, Serialize, Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct DiskAnnConfig {
    /// Number of components per vector; must be non-zero.
    /// `estimate_memory_usage` assumes 4 bytes per component.
    pub dimension: usize,

    /// Graph degree bound; must be non-zero. `estimate_memory_usage` budgets
    /// 4 bytes per slot for each of the `max_degree` neighbour slots of every vector.
    pub max_degree: usize,

    /// Beam width for index construction (per the field name); must be non-zero.
    pub build_beam_width: usize,

    /// Beam width for queries (per the field name); must be non-zero.
    pub search_beam_width: usize,

    /// Pruning parameter; `validate` requires it to be strictly greater than zero.
    /// All presets use 1.2. NOTE(review): presumably the Vamana alpha multiplier — confirm.
    pub alpha: f32,

    /// Neighbour-pruning rule; see `PruningStrategy`.
    pub pruning_strategy: PruningStrategy,

    /// Query-time access mode; see `SearchMode`.
    pub search_mode: SearchMode,

    /// Upper bound on vectors counted as held in memory by `estimate_memory_usage`;
    /// `None` means every vector is counted.
    pub max_vectors_in_memory: Option<usize>,

    /// Enables product-quantization compression. When `true`, `validate` requires both
    /// `pq_subvectors` and `pq_bits` to be set.
    pub use_pq_compression: bool,

    /// Number of PQ subvectors; when compression is enabled it must divide `dimension`.
    pub pq_subvectors: Option<usize>,

    /// Bits per PQ code; when compression is enabled it must be in `1..=16`.
    /// `estimate_memory_usage` rounds this up to whole bytes.
    pub pq_bits: Option<u8>,

    /// Incremental-update switch.
    /// NOTE(review): not read anywhere in this file — confirm where it is honoured.
    pub enable_incremental_updates: bool,

    /// Number of graph entry points; must be non-zero.
    pub num_entry_points: usize,

    /// I/O buffer size in bytes; must be non-zero. Added as-is to the byte total in
    /// `estimate_memory_usage`.
    pub io_buffer_size: usize,
}
76
77impl DiskAnnConfig {
78 pub fn default_config(dimension: usize) -> Self {
80 Self {
81 dimension,
82 max_degree: 64,
83 build_beam_width: 100,
84 search_beam_width: 75,
85 alpha: 1.2,
86 pruning_strategy: PruningStrategy::Robust,
87 search_mode: SearchMode::Cached,
88 max_vectors_in_memory: Some(100_000),
89 use_pq_compression: false,
90 pq_subvectors: None,
91 pq_bits: None,
92 enable_incremental_updates: false,
93 num_entry_points: 1,
94 io_buffer_size: 1 << 20, }
96 }
97
98 pub fn memory_optimized(dimension: usize) -> Self {
100 Self {
101 dimension,
102 max_degree: 32,
103 build_beam_width: 75,
104 search_beam_width: 50,
105 alpha: 1.2,
106 pruning_strategy: PruningStrategy::Robust,
107 search_mode: SearchMode::Streaming,
108 max_vectors_in_memory: Some(10_000),
109 use_pq_compression: true,
110 pq_subvectors: Some(dimension / 16),
111 pq_bits: Some(8),
112 enable_incremental_updates: false,
113 num_entry_points: 1,
114 io_buffer_size: 512 * 1024, }
116 }
117
118 pub fn speed_optimized(dimension: usize) -> Self {
120 Self {
121 dimension,
122 max_degree: 96,
123 build_beam_width: 150,
124 search_beam_width: 100,
125 alpha: 1.2,
126 pruning_strategy: PruningStrategy::Alpha,
127 search_mode: SearchMode::InMemory,
128 max_vectors_in_memory: Some(1_000_000),
129 use_pq_compression: false,
130 pq_subvectors: None,
131 pq_bits: None,
132 enable_incremental_updates: true,
133 num_entry_points: 4,
134 io_buffer_size: 4 << 20, }
136 }
137
138 pub fn billion_scale(dimension: usize) -> Self {
140 Self {
141 dimension,
142 max_degree: 64,
143 build_beam_width: 100,
144 search_beam_width: 64,
145 alpha: 1.2,
146 pruning_strategy: PruningStrategy::Robust,
147 search_mode: SearchMode::Streaming,
148 max_vectors_in_memory: Some(50_000),
149 use_pq_compression: true,
150 pq_subvectors: Some(dimension / 8),
151 pq_bits: Some(8),
152 enable_incremental_updates: false,
153 num_entry_points: 8,
154 io_buffer_size: 2 << 20, }
156 }
157
158 pub fn validate(&self) -> Result<(), String> {
160 if self.dimension == 0 {
161 return Err("Dimension must be greater than 0".to_string());
162 }
163
164 if self.max_degree == 0 {
165 return Err("Max degree must be greater than 0".to_string());
166 }
167
168 if self.build_beam_width == 0 {
169 return Err("Build beam width must be greater than 0".to_string());
170 }
171
172 if self.search_beam_width == 0 {
173 return Err("Search beam width must be greater than 0".to_string());
174 }
175
176 if self.alpha <= 0.0 {
177 return Err("Alpha must be positive".to_string());
178 }
179
180 if self.use_pq_compression {
181 if self.pq_subvectors.is_none() {
182 return Err(
183 "PQ subvectors must be specified when compression is enabled".to_string(),
184 );
185 }
186 if self.pq_bits.is_none() {
187 return Err("PQ bits must be specified when compression is enabled".to_string());
188 }
189
190 let pq_subvectors = self
191 .pq_subvectors
192 .expect("pq_subvectors validated as Some above");
193 if self.dimension % pq_subvectors != 0 {
194 return Err(format!(
195 "Dimension {} must be divisible by PQ subvectors {}",
196 self.dimension, pq_subvectors
197 ));
198 }
199
200 let pq_bits = self.pq_bits.expect("pq_bits validated as Some above");
201 if pq_bits == 0 || pq_bits > 16 {
202 return Err("PQ bits must be between 1 and 16".to_string());
203 }
204 }
205
206 if self.num_entry_points == 0 {
207 return Err("Number of entry points must be greater than 0".to_string());
208 }
209
210 if self.io_buffer_size == 0 {
211 return Err("IO buffer size must be greater than 0".to_string());
212 }
213
214 Ok(())
215 }
216
217 pub fn estimate_memory_usage(&self, num_vectors: usize) -> usize {
219 let graph_memory = num_vectors * (4 + self.max_degree * 4);
221
222 let vector_memory = if self.use_pq_compression {
224 let pq_subvectors = self.pq_subvectors.unwrap_or(self.dimension / 8);
225 let pq_bits = self.pq_bits.unwrap_or(8);
226 let bytes_per_code = (pq_bits as usize + 7) / 8;
227 num_vectors * pq_subvectors * bytes_per_code
228 } else {
229 num_vectors * self.dimension * 4 };
231
232 let inmem_vectors = self
234 .max_vectors_in_memory
235 .unwrap_or(num_vectors)
236 .min(num_vectors);
237 let inmem_memory = inmem_vectors * self.dimension * 4;
238
239 graph_memory + vector_memory + inmem_memory + self.io_buffer_size
240 }
241}
242
243impl Default for DiskAnnConfig {
244 fn default() -> Self {
245 Self::default_config(128)
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// The balanced preset keeps the requested dimension and passes validation.
    #[test]
    fn test_default_config() {
        let cfg = DiskAnnConfig::default_config(128);
        assert_eq!(cfg.dimension, 128);
        assert_eq!(cfg.max_degree, 64);
        assert!(cfg.validate().is_ok());
    }

    /// The low-memory preset turns on PQ and streams.
    #[test]
    fn test_memory_optimized() {
        let cfg = DiskAnnConfig::memory_optimized(256);
        assert_eq!(cfg.dimension, 256);
        assert!(cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::Streaming);
        assert!(cfg.validate().is_ok());
    }

    /// The throughput preset stays uncompressed and memory-resident.
    #[test]
    fn test_speed_optimized() {
        let cfg = DiskAnnConfig::speed_optimized(512);
        assert_eq!(cfg.dimension, 512);
        assert!(!cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::InMemory);
        assert!(cfg.validate().is_ok());
    }

    /// The very-large-dataset preset turns on PQ and streams.
    #[test]
    fn test_billion_scale() {
        let cfg = DiskAnnConfig::billion_scale(768);
        assert_eq!(cfg.dimension, 768);
        assert!(cfg.use_pq_compression);
        assert_eq!(cfg.search_mode, SearchMode::Streaming);
        assert!(cfg.validate().is_ok());
    }

    /// Zeroed required fields and PQ without parameters are rejected.
    #[test]
    fn test_validation() {
        assert!(DiskAnnConfig::default_config(128).validate().is_ok());

        let zero_dimension = DiskAnnConfig {
            dimension: 0,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(zero_dimension.validate().is_err());

        let zero_degree = DiskAnnConfig {
            max_degree: 0,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(zero_degree.validate().is_err());

        // Compression switched on while the PQ parameters are still `None`.
        let pq_without_params = DiskAnnConfig {
            use_pq_compression: true,
            ..DiskAnnConfig::default_config(128)
        };
        assert!(pq_without_params.validate().is_err());
    }

    /// PQ compression must shrink the estimate for the same vector count.
    #[test]
    fn test_memory_estimation() {
        const NUM_VECTORS: usize = 1_000_000;

        let full_precision = DiskAnnConfig::default_config(128).estimate_memory_usage(NUM_VECTORS);
        assert!(full_precision > 0);

        let compressed = DiskAnnConfig::memory_optimized(128).estimate_memory_usage(NUM_VECTORS);
        assert!(compressed < full_precision);
    }

    /// Subvector divisibility and the PQ bit range are enforced.
    #[test]
    fn test_pq_validation() {
        let with_pq = |subvectors: usize, bits: u8| DiskAnnConfig {
            use_pq_compression: true,
            pq_subvectors: Some(subvectors),
            pq_bits: Some(bits),
            ..DiskAnnConfig::default_config(128)
        };

        assert!(with_pq(16, 8).validate().is_ok());
        // 128 is not a multiple of 15.
        assert!(with_pq(15, 8).validate().is_err());
        assert!(with_pq(16, 0).validate().is_err());
        assert!(with_pq(16, 20).validate().is_err());
    }
}