1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Default for [`InternalSpannConfiguration::search_nprobe`], applied when the
/// field is absent from serialized input or left unset by the caller.
pub fn default_search_nprobe() -> u32 {
    64_u32
}

/// Default for [`InternalSpannConfiguration::search_rng_factor`].
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}

/// Default for [`InternalSpannConfiguration::search_rng_epsilon`].
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Default for [`InternalSpannConfiguration::write_nprobe`].
pub fn default_write_nprobe() -> u32 {
    32_u32
}

/// Default for [`InternalSpannConfiguration::nreplica_count`]; equal to the
/// field's validated upper bound.
pub fn default_nreplica_count() -> u32 {
    8_u32
}

/// Default for [`InternalSpannConfiguration::write_rng_factor`].
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}

/// Default for [`InternalSpannConfiguration::write_rng_epsilon`].
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Default for [`InternalSpannConfiguration::split_threshold`].
pub fn default_split_threshold() -> u32 {
    50_u32
}

/// Default for [`InternalSpannConfiguration::num_samples_kmeans`].
pub fn default_num_samples_kmeans() -> usize {
    1_000_usize
}

/// Default for [`InternalSpannConfiguration::initial_lambda`].
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}

/// Default for [`InternalSpannConfiguration::reassign_neighbor_count`].
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}

/// Default for [`InternalSpannConfiguration::merge_threshold`].
pub fn default_merge_threshold() -> u32 {
    25_u32
}

/// Default for [`InternalSpannConfiguration::num_centers_to_merge_to`].
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Default for [`InternalSpannConfiguration::ef_construction`].
pub fn default_construction_ef_spann() -> usize {
    200_usize
}

/// Default for [`InternalSpannConfiguration::ef_search`].
pub fn default_search_ef_spann() -> usize {
    200_usize
}

/// Default for [`InternalSpannConfiguration::max_neighbors`].
pub fn default_m_spann() -> usize {
    64_usize
}
70
/// Default quantization flag (disabled).
///
/// Not referenced by the configuration structs in this module; presumably
/// consumed elsewhere in the crate.
pub fn default_quantize() -> bool {
    bool::default()
}

/// Default center-drift threshold.
///
/// Not referenced by the configuration structs in this module; presumably
/// consumed elsewhere in the crate.
pub fn default_center_drift_threshold() -> f32 {
    0.125_f32
}
78
/// Default distance space for SPANN configurations (`Space::L2`).
///
/// Used as the serde default for `InternalSpannConfiguration::space` and when a
/// `SpannConfiguration` leaves `space` unset.
fn default_space_spann() -> Space {
    Space::L2
}
82
/// Errors raised while building distributed SPANN parameters.
///
/// NOTE(review): the type name suggests the input is a segment's metadata;
/// confirm against the caller.
#[derive(Debug, Error)]
pub enum DistributedSpannParametersFromSegmentError {
    /// The input could not be (de)serialized as JSON.
    #[error("Invalid metadata: {0}")]
    InvalidMetadata(#[from] serde_json::Error),
    /// The input parsed, but at least one `#[validate]` constraint failed.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(#[from] validator::ValidationErrors),
}
90
91impl ChromaError for DistributedSpannParametersFromSegmentError {
92 fn code(&self) -> ErrorCodes {
93 match self {
94 DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
95 ErrorCodes::InvalidArgument
96 }
97 DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
98 ErrorCodes::InvalidArgument
99 }
100 }
101 }
102}
103
104#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
105#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
106pub struct InternalSpannConfiguration {
107 #[serde(default = "default_search_nprobe")]
108 pub search_nprobe: u32,
109 #[serde(default = "default_search_rng_factor")]
110 pub search_rng_factor: f32,
111 #[serde(default = "default_search_rng_epsilon")]
112 pub search_rng_epsilon: f32,
113 #[serde(default = "default_write_nprobe")]
114 #[validate(range(max = 128))]
115 pub write_nprobe: u32,
116 #[serde(default = "default_nreplica_count")]
117 #[validate(range(max = 8))]
118 pub nreplica_count: u32,
119 #[serde(default = "default_write_rng_factor")]
120 pub write_rng_factor: f32,
121 #[serde(default = "default_write_rng_epsilon")]
122 pub write_rng_epsilon: f32,
123 #[serde(default = "default_split_threshold")]
124 #[validate(range(min = 25, max = 200))]
125 pub split_threshold: u32,
126 #[serde(default = "default_num_samples_kmeans")]
127 pub num_samples_kmeans: usize,
128 #[serde(default = "default_initial_lambda")]
129 pub initial_lambda: f32,
130 #[serde(default = "default_reassign_neighbor_count")]
131 #[validate(range(max = 64))]
132 pub reassign_neighbor_count: u32,
133 #[serde(default = "default_merge_threshold")]
134 #[validate(range(min = 12, max = 100))]
135 pub merge_threshold: u32,
136 #[serde(default = "default_num_centers_to_merge_to")]
137 #[validate(range(max = 8))]
138 pub num_centers_to_merge_to: u32,
139 #[serde(default = "default_space_spann")]
140 pub space: Space,
141 #[serde(default = "default_construction_ef_spann")]
142 #[validate(range(max = 200))]
143 pub ef_construction: usize,
144 #[serde(default = "default_search_ef_spann")]
145 #[validate(range(max = 200))]
146 pub ef_search: usize,
147 #[serde(default = "default_m_spann")]
148 #[validate(range(max = 64))]
149 pub max_neighbors: usize,
150}
151
152impl Default for InternalSpannConfiguration {
153 fn default() -> Self {
154 serde_json::from_str("{}").unwrap()
155 }
156}
157
158impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
159 fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
160 InternalSpannConfiguration {
161 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
162 search_rng_factor: config
163 .search_rng_factor
164 .unwrap_or(default_search_rng_factor()),
165 search_rng_epsilon: config
166 .search_rng_epsilon
167 .unwrap_or(default_search_rng_epsilon()),
168 nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
169 write_rng_factor: config
170 .write_rng_factor
171 .unwrap_or(default_write_rng_factor()),
172 write_rng_epsilon: config
173 .write_rng_epsilon
174 .unwrap_or(default_write_rng_epsilon()),
175 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
176 num_samples_kmeans: config
177 .num_samples_kmeans
178 .unwrap_or(default_num_samples_kmeans()),
179 initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
180 reassign_neighbor_count: config
181 .reassign_neighbor_count
182 .unwrap_or(default_reassign_neighbor_count()),
183 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
184 num_centers_to_merge_to: config
185 .num_centers_to_merge_to
186 .unwrap_or(default_num_centers_to_merge_to()),
187 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
188 ef_construction: config
189 .ef_construction
190 .unwrap_or(default_construction_ef_spann()),
191 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
192 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
193 space: space.unwrap_or(&default_space()).clone(),
194 }
195 }
196}
197
198#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
199#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
200#[serde(deny_unknown_fields)]
201pub struct SpannConfiguration {
202 pub search_nprobe: Option<u32>,
203 pub write_nprobe: Option<u32>,
204 pub space: Option<Space>,
205 pub ef_construction: Option<usize>,
206 pub ef_search: Option<usize>,
207 pub max_neighbors: Option<usize>,
208 pub reassign_neighbor_count: Option<u32>,
209 pub split_threshold: Option<u32>,
210 pub merge_threshold: Option<u32>,
211}
212
213impl From<InternalSpannConfiguration> for SpannConfiguration {
214 fn from(config: InternalSpannConfiguration) -> Self {
215 Self {
216 search_nprobe: Some(config.search_nprobe),
217 write_nprobe: Some(config.write_nprobe),
218 space: Some(config.space),
219 ef_construction: Some(config.ef_construction),
220 ef_search: Some(config.ef_search),
221 max_neighbors: Some(config.max_neighbors),
222 reassign_neighbor_count: Some(config.reassign_neighbor_count),
223 split_threshold: Some(config.split_threshold),
224 merge_threshold: Some(config.merge_threshold),
225 }
226 }
227}
228
229impl From<SpannConfiguration> for InternalSpannConfiguration {
230 fn from(config: SpannConfiguration) -> Self {
231 Self {
232 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
233 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
234 space: config.space.unwrap_or(default_space_spann()),
235 ef_construction: config
236 .ef_construction
237 .unwrap_or(default_construction_ef_spann()),
238 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
239 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
240 reassign_neighbor_count: config
241 .reassign_neighbor_count
242 .unwrap_or(default_reassign_neighbor_count()),
243 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
244 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
245 ..Default::default()
246 }
247 }
248}
249
250impl Default for SpannConfiguration {
251 fn default() -> Self {
252 InternalSpannConfiguration::default().into()
253 }
254}
255
/// SPANN parameters accepted in an update request.
///
/// Every field is optional. NOTE(review): `None` presumably means "leave
/// unchanged"; confirm against the code that applies updates. Unknown fields
/// are rejected during deserialization.
#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct UpdateSpannConfiguration {
    /// Requested `search_nprobe`; must be at most 128 when present.
    #[validate(range(max = 128))]
    pub search_nprobe: Option<u32>,
    /// Requested `ef_search`; must be at most 200 when present.
    #[validate(range(max = 200))]
    pub ef_search: Option<usize>,
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field set to its `default_*` function.
    ///
    /// This is built independently of `Default`, so the tests below also check
    /// the `Default` impl.
    fn all_defaults() -> InternalSpannConfiguration {
        InternalSpannConfiguration {
            search_nprobe: default_search_nprobe(),
            search_rng_factor: default_search_rng_factor(),
            search_rng_epsilon: default_search_rng_epsilon(),
            write_nprobe: default_write_nprobe(),
            nreplica_count: default_nreplica_count(),
            write_rng_factor: default_write_rng_factor(),
            write_rng_epsilon: default_write_rng_epsilon(),
            split_threshold: default_split_threshold(),
            num_samples_kmeans: default_num_samples_kmeans(),
            initial_lambda: default_initial_lambda(),
            reassign_neighbor_count: default_reassign_neighbor_count(),
            merge_threshold: default_merge_threshold(),
            num_centers_to_merge_to: default_num_centers_to_merge_to(),
            space: default_space_spann(),
            ef_construction: default_construction_ef_spann(),
            ef_search: default_search_ef_spann(),
            max_neighbors: default_m_spann(),
        }
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let user = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let expected = InternalSpannConfiguration {
            search_nprobe: 100,
            write_nprobe: 50,
            space: Space::Cosine,
            ef_construction: 150,
            ef_search: 180,
            max_neighbors: 32,
            reassign_neighbor_count: 48,
            split_threshold: 75,
            merge_threshold: 50,
            // Fields that `SpannConfiguration` does not expose keep their defaults.
            ..all_defaults()
        };

        assert_eq!(InternalSpannConfiguration::from(user), expected);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let user = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        assert_eq!(InternalSpannConfiguration::from(user), all_defaults());
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let user = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let expected = InternalSpannConfiguration {
            search_nprobe: 80,
            space: Space::Ip,
            ef_search: 160,
            max_neighbors: 48,
            split_threshold: 100,
            // Unset exposed fields and all internal-only fields use defaults.
            ..all_defaults()
        };

        assert_eq!(InternalSpannConfiguration::from(user), expected);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        assert_eq!(InternalSpannConfiguration::default(), all_defaults());
    }

    #[test]
    fn test_spann_configuration_default() {
        let round_tripped: InternalSpannConfiguration = SpannConfiguration::default().into();
        assert_eq!(round_tripped, all_defaults());
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        let expected = InternalSpannConfiguration {
            search_nprobe: 120,
            search_rng_factor: 2.5,
            search_rng_epsilon: 15.0,
            write_nprobe: 60,
            write_rng_factor: 1.5,
            write_rng_epsilon: 8.0,
            split_threshold: 80,
            num_samples_kmeans: 1500,
            initial_lambda: 150.0,
            reassign_neighbor_count: 32,
            merge_threshold: 30,
            num_centers_to_merge_to: 6,
            space: Space::L2,
            ef_construction: 180,
            ef_search: 200,
            max_neighbors: 56,
            // Absent from the JSON, so the serde default must be applied.
            nreplica_count: default_nreplica_count(),
        };

        assert_eq!(parsed, expected);
    }
}