1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Returns the default `search_nprobe` value (64).
///
/// Used as the serde default for `InternalSpannConfiguration::search_nprobe`
/// and as the fallback in the `From` conversions when the field is unset.
pub fn default_search_nprobe() -> u32 {
    64_u32
}
10
/// Returns the default `search_rng_factor` value (1.0).
///
/// Used as the serde default for `InternalSpannConfiguration::search_rng_factor`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}
14
/// Returns the default `search_rng_epsilon` value (10.0).
///
/// Used as the serde default for `InternalSpannConfiguration::search_rng_epsilon`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Returns the default `write_nprobe` value (32).
///
/// Used as the serde default for `InternalSpannConfiguration::write_nprobe`
/// and as the fallback in the `From` conversions when the field is unset.
pub fn default_write_nprobe() -> u32 {
    32_u32
}
22
/// Returns the default `nreplica_count` value (8).
///
/// Used as the serde default for `InternalSpannConfiguration::nreplica_count`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_nreplica_count() -> u32 {
    8_u32
}
26
/// Returns the default `write_rng_factor` value (1.0).
///
/// Used as the serde default for `InternalSpannConfiguration::write_rng_factor`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}
30
/// Returns the default `write_rng_epsilon` value (5.0).
///
/// Used as the serde default for `InternalSpannConfiguration::write_rng_epsilon`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Returns the default `split_threshold` value (50).
///
/// Used as the serde default for `InternalSpannConfiguration::split_threshold`
/// and as the fallback in the `From` conversions when the field is unset.
pub fn default_split_threshold() -> u32 {
    50_u32
}
38
/// Returns the default `num_samples_kmeans` value (1000).
///
/// Used as the serde default for `InternalSpannConfiguration::num_samples_kmeans`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_num_samples_kmeans() -> usize {
    1000_usize
}
42
/// Returns the default `initial_lambda` value (100.0).
///
/// Used as the serde default for `InternalSpannConfiguration::initial_lambda`
/// and as the fallback when a `SpannIndexConfig` leaves the field unset.
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}
46
/// Returns the default `reassign_neighbor_count` value (64).
///
/// Used as the serde default for
/// `InternalSpannConfiguration::reassign_neighbor_count` and as the fallback in
/// the `From` conversions when the field is unset.
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}
50
/// Returns the default `merge_threshold` value (25).
///
/// Used as the serde default for `InternalSpannConfiguration::merge_threshold`
/// and as the fallback in the `From` conversions when the field is unset.
pub fn default_merge_threshold() -> u32 {
    25_u32
}
54
/// Returns the default `num_centers_to_merge_to` value (8).
///
/// Used as the serde default for
/// `InternalSpannConfiguration::num_centers_to_merge_to` and as the fallback
/// when a `SpannIndexConfig` leaves the field unset.
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Returns the default `ef_construction` value for SPANN (200).
///
/// Used as the serde default for `InternalSpannConfiguration::ef_construction`
/// and as the fallback in the `From` conversions when the field is unset.
pub fn default_construction_ef_spann() -> usize {
    200_usize
}
62
/// Returns the default `ef_search` value for SPANN (200).
///
/// Used as the serde default for `InternalSpannConfiguration::ef_search` and as
/// the fallback in the `From` conversions when the field is unset.
pub fn default_search_ef_spann() -> usize {
    200_usize
}
66
/// Returns the default `max_neighbors` value for SPANN (64).
///
/// Used as the serde default for `InternalSpannConfiguration::max_neighbors`
/// and as the fallback in the `From` conversions when the field is unset.
pub fn default_m_spann() -> usize {
    64_usize
}
70
/// Returns the default `space` for SPANN configurations: [`Space::L2`].
///
/// Used as the serde default for `InternalSpannConfiguration::space` and as the
/// fallback when a `SpannConfiguration` leaves `space` unset. Private to this
/// module.
fn default_space_spann() -> Space {
    Space::L2
}
74
/// Error type for building distributed SPANN parameters from a segment.
///
/// NOTE(review): the purpose is inferred from the type name; the code that
/// constructs these errors is not part of this file — confirm against callers.
/// Both variants are reported as `ErrorCodes::InvalidArgument` by the
/// `ChromaError` implementation.
#[derive(Debug, Error)]
pub enum DistributedSpannParametersFromSegmentError {
    /// Wraps a `serde_json::Error`; `#[from]` enables conversion via `?`.
    #[error("Invalid metadata: {0}")]
    InvalidMetadata(#[from] serde_json::Error),
    /// Wraps `validator::ValidationErrors`; `#[from]` enables conversion via `?`.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(#[from] validator::ValidationErrors),
}
82
83impl ChromaError for DistributedSpannParametersFromSegmentError {
84 fn code(&self) -> ErrorCodes {
85 match self {
86 DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
87 ErrorCodes::InvalidArgument
88 }
89 DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
90 ErrorCodes::InvalidArgument
91 }
92 }
93 }
94}
95
// Fully specified SPANN configuration: every parameter holds a concrete value.
//
// During deserialization, any key missing from the input is filled in by the
// `default_*` function named in that field's `#[serde(default = ...)]`
// attribute. Fields carrying `#[validate(range(...))]` are bounds-checked by
// the derived `Validate` implementation; the remaining fields declare no
// validation rule in this struct.
//
// NOTE(review): plain `//` comments are used here on purpose — `///` doc
// comments may be picked up by the `utoipa::ToSchema` derive and alter the
// generated schema; confirm before promoting them to rustdoc.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct InternalSpannConfiguration {
    // Default: 64. No range rule declared.
    #[serde(default = "default_search_nprobe")]
    pub search_nprobe: u32,
    // Default: 1.0. No range rule declared.
    #[serde(default = "default_search_rng_factor")]
    pub search_rng_factor: f32,
    // Default: 10.0. No range rule declared.
    #[serde(default = "default_search_rng_epsilon")]
    pub search_rng_epsilon: f32,
    // Default: 32. Validated to be at most 128.
    #[serde(default = "default_write_nprobe")]
    #[validate(range(max = 128))]
    pub write_nprobe: u32,
    // Default: 8. Validated to be at most 8.
    #[serde(default = "default_nreplica_count")]
    #[validate(range(max = 8))]
    pub nreplica_count: u32,
    // Default: 1.0. No range rule declared.
    #[serde(default = "default_write_rng_factor")]
    pub write_rng_factor: f32,
    // Default: 5.0. No range rule declared.
    #[serde(default = "default_write_rng_epsilon")]
    pub write_rng_epsilon: f32,
    // Default: 50. Validated to lie between 25 and 200.
    #[serde(default = "default_split_threshold")]
    #[validate(range(min = 25, max = 200))]
    pub split_threshold: u32,
    // Default: 1000. No range rule declared.
    #[serde(default = "default_num_samples_kmeans")]
    pub num_samples_kmeans: usize,
    // Default: 100.0. No range rule declared.
    #[serde(default = "default_initial_lambda")]
    pub initial_lambda: f32,
    // Default: 64. Validated to be at most 64.
    #[serde(default = "default_reassign_neighbor_count")]
    #[validate(range(max = 64))]
    pub reassign_neighbor_count: u32,
    // Default: 25. Validated to lie between 12 and 100.
    #[serde(default = "default_merge_threshold")]
    #[validate(range(min = 12, max = 100))]
    pub merge_threshold: u32,
    // Default: 8. Validated to be at most 8.
    #[serde(default = "default_num_centers_to_merge_to")]
    #[validate(range(max = 8))]
    pub num_centers_to_merge_to: u32,
    // Default: `Space::L2` (via `default_space_spann`).
    #[serde(default = "default_space_spann")]
    pub space: Space,
    // Default: 200. Validated to be at most 200.
    #[serde(default = "default_construction_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_construction: usize,
    // Default: 200. Validated to be at most 200.
    #[serde(default = "default_search_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_search: usize,
    // Default: 64. Validated to be at most 64.
    #[serde(default = "default_m_spann")]
    #[validate(range(max = 64))]
    pub max_neighbors: usize,
}
143
144impl Default for InternalSpannConfiguration {
145 fn default() -> Self {
146 serde_json::from_str("{}").unwrap()
147 }
148}
149
150impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
151 fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
152 InternalSpannConfiguration {
153 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
154 search_rng_factor: config
155 .search_rng_factor
156 .unwrap_or(default_search_rng_factor()),
157 search_rng_epsilon: config
158 .search_rng_epsilon
159 .unwrap_or(default_search_rng_epsilon()),
160 nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
161 write_rng_factor: config
162 .write_rng_factor
163 .unwrap_or(default_write_rng_factor()),
164 write_rng_epsilon: config
165 .write_rng_epsilon
166 .unwrap_or(default_write_rng_epsilon()),
167 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
168 num_samples_kmeans: config
169 .num_samples_kmeans
170 .unwrap_or(default_num_samples_kmeans()),
171 initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
172 reassign_neighbor_count: config
173 .reassign_neighbor_count
174 .unwrap_or(default_reassign_neighbor_count()),
175 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
176 num_centers_to_merge_to: config
177 .num_centers_to_merge_to
178 .unwrap_or(default_num_centers_to_merge_to()),
179 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
180 ef_construction: config
181 .ef_construction
182 .unwrap_or(default_construction_ef_spann()),
183 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
184 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
185 space: space.unwrap_or(&default_space()).clone(),
186 }
187 }
188}
189
// SPANN configuration in which every parameter is optional.
//
// It carries a subset of the `InternalSpannConfiguration` fields. Converting
// into `InternalSpannConfiguration` replaces each `None` with the matching
// `default_*` value and fills the fields not present here from
// `InternalSpannConfiguration::default()`.
//
// `deny_unknown_fields` makes deserialization reject keys not listed below.
// `Validate` is derived, but no field-level rules are declared in this struct.
//
// NOTE(review): plain `//` comments are used here on purpose — `///` doc
// comments may be picked up by the `utoipa::ToSchema` derive and alter the
// generated schema; confirm before promoting them to rustdoc.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
pub struct SpannConfiguration {
    // Falls back to `default_search_nprobe()` (64) when `None`.
    pub search_nprobe: Option<u32>,
    // Falls back to `default_write_nprobe()` (32) when `None`.
    pub write_nprobe: Option<u32>,
    // Falls back to `default_space_spann()` (`Space::L2`) when `None`.
    pub space: Option<Space>,
    // Falls back to `default_construction_ef_spann()` (200) when `None`.
    pub ef_construction: Option<usize>,
    // Falls back to `default_search_ef_spann()` (200) when `None`.
    pub ef_search: Option<usize>,
    // Falls back to `default_m_spann()` (64) when `None`.
    pub max_neighbors: Option<usize>,
    // Falls back to `default_reassign_neighbor_count()` (64) when `None`.
    pub reassign_neighbor_count: Option<u32>,
    // Falls back to `default_split_threshold()` (50) when `None`.
    pub split_threshold: Option<u32>,
    // Falls back to `default_merge_threshold()` (25) when `None`.
    pub merge_threshold: Option<u32>,
}
204
205impl From<InternalSpannConfiguration> for SpannConfiguration {
206 fn from(config: InternalSpannConfiguration) -> Self {
207 Self {
208 search_nprobe: Some(config.search_nprobe),
209 write_nprobe: Some(config.write_nprobe),
210 space: Some(config.space),
211 ef_construction: Some(config.ef_construction),
212 ef_search: Some(config.ef_search),
213 max_neighbors: Some(config.max_neighbors),
214 reassign_neighbor_count: Some(config.reassign_neighbor_count),
215 split_threshold: Some(config.split_threshold),
216 merge_threshold: Some(config.merge_threshold),
217 }
218 }
219}
220
221impl From<SpannConfiguration> for InternalSpannConfiguration {
222 fn from(config: SpannConfiguration) -> Self {
223 Self {
224 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
225 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
226 space: config.space.unwrap_or(default_space_spann()),
227 ef_construction: config
228 .ef_construction
229 .unwrap_or(default_construction_ef_spann()),
230 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
231 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
232 reassign_neighbor_count: config
233 .reassign_neighbor_count
234 .unwrap_or(default_reassign_neighbor_count()),
235 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
236 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
237 ..Default::default()
238 }
239 }
240}
241
242impl Default for SpannConfiguration {
243 fn default() -> Self {
244 InternalSpannConfiguration::default().into()
245 }
246}
247
// Partial SPANN configuration limited to `search_nprobe` and `ef_search`.
//
// Both fields are optional, and the derived `Default` leaves both as `None`.
// `deny_unknown_fields` makes deserialization reject any other key. With the
// `pyo3` feature enabled the type is also declared as a `pyclass`.
//
// NOTE(review): how an update is applied to an existing configuration is not
// shown in this file — confirm against the code that consumes this type.
// Plain `//` comments are used on purpose, since `///` doc comments may
// surface in the derived `utoipa` schema or the Python class docstring.
#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct UpdateSpannConfiguration {
    // New `search_nprobe` value, or `None` when not specified.
    pub search_nprobe: Option<u32>,
    // New `ef_search` value, or `None` when not specified.
    pub ef_search: Option<usize>,
}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the fields with no counterpart in `SpannConfiguration`
    /// hold their `default_*` values.
    #[track_caller]
    fn assert_internal_only_defaults(cfg: &InternalSpannConfiguration) {
        assert_eq!(cfg.search_rng_factor, default_search_rng_factor());
        assert_eq!(cfg.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(cfg.nreplica_count, default_nreplica_count());
        assert_eq!(cfg.write_rng_factor, default_write_rng_factor());
        assert_eq!(cfg.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(cfg.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(cfg.initial_lambda, default_initial_lambda());
        assert_eq!(
            cfg.num_centers_to_merge_to,
            default_num_centers_to_merge_to()
        );
    }

    /// Asserts that the fields shared with `SpannConfiguration` hold their
    /// `default_*` values.
    #[track_caller]
    fn assert_shared_defaults(cfg: &InternalSpannConfiguration) {
        assert_eq!(cfg.search_nprobe, default_search_nprobe());
        assert_eq!(cfg.write_nprobe, default_write_nprobe());
        assert_eq!(cfg.space, default_space_spann());
        assert_eq!(cfg.ef_construction, default_construction_ef_spann());
        assert_eq!(cfg.ef_search, default_search_ef_spann());
        assert_eq!(cfg.max_neighbors, default_m_spann());
        assert_eq!(
            cfg.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(cfg.split_threshold, default_split_threshold());
        assert_eq!(cfg.merge_threshold, default_merge_threshold());
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        // Every optional field is set, so each must pass through unchanged.
        let converted = InternalSpannConfiguration::from(SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        });

        assert_eq!(converted.search_nprobe, 100);
        assert_eq!(converted.write_nprobe, 50);
        assert_eq!(converted.space, Space::Cosine);
        assert_eq!(converted.ef_construction, 150);
        assert_eq!(converted.ef_search, 180);
        assert_eq!(converted.max_neighbors, 32);
        assert_eq!(converted.reassign_neighbor_count, 48);
        assert_eq!(converted.split_threshold, 75);
        assert_eq!(converted.merge_threshold, 50);
        assert_internal_only_defaults(&converted);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        // Nothing is set, so every field must come from its default.
        let converted = InternalSpannConfiguration::from(SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        });

        assert_shared_defaults(&converted);
        assert_internal_only_defaults(&converted);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        // A mix of set and unset fields: set ones pass through, unset ones default.
        let converted = InternalSpannConfiguration::from(SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        });

        assert_eq!(converted.search_nprobe, 80);
        assert_eq!(converted.write_nprobe, default_write_nprobe());
        assert_eq!(converted.space, Space::Ip);
        assert_eq!(converted.ef_construction, default_construction_ef_spann());
        assert_eq!(converted.ef_search, 160);
        assert_eq!(converted.max_neighbors, 48);
        assert_eq!(
            converted.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(converted.split_threshold, 100);
        assert_eq!(converted.merge_threshold, default_merge_threshold());
        assert_internal_only_defaults(&converted);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        let defaults = InternalSpannConfiguration::default();

        assert_shared_defaults(&defaults);
        assert_internal_only_defaults(&defaults);
    }

    #[test]
    fn test_spann_configuration_default() {
        // Round trip: default optional form back into the internal form.
        let round_tripped = InternalSpannConfiguration::from(SpannConfiguration::default());

        assert_shared_defaults(&round_tripped);
        assert_internal_only_defaults(&round_tripped);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        // `nreplica_count` is deliberately absent from the payload.
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        // The expected value lists every field, so one comparison covers
        // the whole struct.
        let expected = InternalSpannConfiguration {
            search_nprobe: 120,
            search_rng_factor: 2.5,
            search_rng_epsilon: 15.0,
            write_nprobe: 60,
            nreplica_count: default_nreplica_count(),
            write_rng_factor: 1.5,
            write_rng_epsilon: 8.0,
            split_threshold: 80,
            num_samples_kmeans: 1500,
            initial_lambda: 150.0,
            reassign_neighbor_count: 32,
            merge_threshold: 30,
            num_centers_to_merge_to: 6,
            space: Space::L2,
            ef_construction: 180,
            ef_search: 200,
            max_neighbors: 56,
        };
        assert_eq!(parsed, expected);
    }
}