1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Default for the `search_nprobe` setting: `64`.
pub fn default_search_nprobe() -> u32 {
    64_u32
}
10
/// Default for the `search_rng_factor` setting: `1.0`.
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}
14
/// Default for the `search_rng_epsilon` setting: `10.0`.
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Default for the `write_nprobe` setting: `32`.
pub fn default_write_nprobe() -> u32 {
    32_u32
}
22
/// Default for the `nreplica_count` setting: `8`.
pub fn default_nreplica_count() -> u32 {
    8_u32
}
26
/// Default for the `write_rng_factor` setting: `1.0`.
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}
30
/// Default for the `write_rng_epsilon` setting: `5.0`.
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Default for the `split_threshold` setting: `50`.
pub fn default_split_threshold() -> u32 {
    50_u32
}
38
/// Default for the `num_samples_kmeans` setting: `1000`.
pub fn default_num_samples_kmeans() -> usize {
    1000_usize
}
42
/// Default for the `initial_lambda` setting: `100.0`.
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}
46
/// Default for the `reassign_neighbor_count` setting: `64`.
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}
50
/// Default for the `merge_threshold` setting: `25`.
pub fn default_merge_threshold() -> u32 {
    25_u32
}
54
/// Default for the `num_centers_to_merge_to` setting: `8`.
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Default for the `ef_construction` setting of the SPANN configuration: `200`.
pub fn default_construction_ef_spann() -> usize {
    200_usize
}
62
/// Default for the `ef_search` setting of the SPANN configuration: `200`.
pub fn default_search_ef_spann() -> usize {
    200_usize
}
66
/// Default for the `max_neighbors` setting of the SPANN configuration: `64`.
pub fn default_m_spann() -> usize {
    64_usize
}
70
/// Default center-drift threshold: `0.125`.
pub fn default_center_drift_threshold() -> f32 {
    0.125_f32
}
74
75fn default_space_spann() -> Space {
76 Space::L2
77}
78
79#[derive(Debug, Error)]
80pub enum DistributedSpannParametersFromSegmentError {
81 #[error("Invalid metadata: {0}")]
82 InvalidMetadata(#[from] serde_json::Error),
83 #[error("Invalid parameters: {0}")]
84 InvalidParameters(#[from] validator::ValidationErrors),
85}
86
87impl ChromaError for DistributedSpannParametersFromSegmentError {
88 fn code(&self) -> ErrorCodes {
89 match self {
90 DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
91 ErrorCodes::InvalidArgument
92 }
93 DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
94 ErrorCodes::InvalidArgument
95 }
96 }
97 }
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
101#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
102pub struct InternalSpannConfiguration {
103 #[serde(default = "default_search_nprobe")]
104 pub search_nprobe: u32,
105 #[serde(default = "default_search_rng_factor")]
106 pub search_rng_factor: f32,
107 #[serde(default = "default_search_rng_epsilon")]
108 pub search_rng_epsilon: f32,
109 #[serde(default = "default_write_nprobe")]
110 #[validate(range(max = 128))]
111 pub write_nprobe: u32,
112 #[serde(default = "default_nreplica_count")]
113 #[validate(range(max = 8))]
114 pub nreplica_count: u32,
115 #[serde(default = "default_write_rng_factor")]
116 pub write_rng_factor: f32,
117 #[serde(default = "default_write_rng_epsilon")]
118 pub write_rng_epsilon: f32,
119 #[serde(default = "default_split_threshold")]
120 #[validate(range(min = 25, max = 200))]
121 pub split_threshold: u32,
122 #[serde(default = "default_num_samples_kmeans")]
123 pub num_samples_kmeans: usize,
124 #[serde(default = "default_initial_lambda")]
125 pub initial_lambda: f32,
126 #[serde(default = "default_reassign_neighbor_count")]
127 #[validate(range(max = 64))]
128 pub reassign_neighbor_count: u32,
129 #[serde(default = "default_merge_threshold")]
130 #[validate(range(min = 12, max = 100))]
131 pub merge_threshold: u32,
132 #[serde(default = "default_num_centers_to_merge_to")]
133 #[validate(range(max = 8))]
134 pub num_centers_to_merge_to: u32,
135 #[serde(default = "default_space_spann")]
136 pub space: Space,
137 #[serde(default = "default_construction_ef_spann")]
138 #[validate(range(max = 200))]
139 pub ef_construction: usize,
140 #[serde(default = "default_search_ef_spann")]
141 #[validate(range(max = 200))]
142 pub ef_search: usize,
143 #[serde(default = "default_m_spann")]
144 #[validate(range(max = 64))]
145 pub max_neighbors: usize,
146}
147
148impl Default for InternalSpannConfiguration {
149 fn default() -> Self {
150 serde_json::from_str("{}").unwrap()
151 }
152}
153
154impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
155 fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
156 InternalSpannConfiguration {
157 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
158 search_rng_factor: config
159 .search_rng_factor
160 .unwrap_or(default_search_rng_factor()),
161 search_rng_epsilon: config
162 .search_rng_epsilon
163 .unwrap_or(default_search_rng_epsilon()),
164 nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
165 write_rng_factor: config
166 .write_rng_factor
167 .unwrap_or(default_write_rng_factor()),
168 write_rng_epsilon: config
169 .write_rng_epsilon
170 .unwrap_or(default_write_rng_epsilon()),
171 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
172 num_samples_kmeans: config
173 .num_samples_kmeans
174 .unwrap_or(default_num_samples_kmeans()),
175 initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
176 reassign_neighbor_count: config
177 .reassign_neighbor_count
178 .unwrap_or(default_reassign_neighbor_count()),
179 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
180 num_centers_to_merge_to: config
181 .num_centers_to_merge_to
182 .unwrap_or(default_num_centers_to_merge_to()),
183 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
184 ef_construction: config
185 .ef_construction
186 .unwrap_or(default_construction_ef_spann()),
187 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
188 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
189 space: space.unwrap_or(&default_space()).clone(),
190 }
191 }
192}
193
194#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
195#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
196#[serde(deny_unknown_fields)]
197pub struct SpannConfiguration {
198 pub search_nprobe: Option<u32>,
199 pub write_nprobe: Option<u32>,
200 pub space: Option<Space>,
201 pub ef_construction: Option<usize>,
202 pub ef_search: Option<usize>,
203 pub max_neighbors: Option<usize>,
204 pub reassign_neighbor_count: Option<u32>,
205 pub split_threshold: Option<u32>,
206 pub merge_threshold: Option<u32>,
207}
208
209impl From<InternalSpannConfiguration> for SpannConfiguration {
210 fn from(config: InternalSpannConfiguration) -> Self {
211 Self {
212 search_nprobe: Some(config.search_nprobe),
213 write_nprobe: Some(config.write_nprobe),
214 space: Some(config.space),
215 ef_construction: Some(config.ef_construction),
216 ef_search: Some(config.ef_search),
217 max_neighbors: Some(config.max_neighbors),
218 reassign_neighbor_count: Some(config.reassign_neighbor_count),
219 split_threshold: Some(config.split_threshold),
220 merge_threshold: Some(config.merge_threshold),
221 }
222 }
223}
224
225impl From<SpannConfiguration> for InternalSpannConfiguration {
226 fn from(config: SpannConfiguration) -> Self {
227 Self {
228 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
229 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
230 space: config.space.unwrap_or(default_space_spann()),
231 ef_construction: config
232 .ef_construction
233 .unwrap_or(default_construction_ef_spann()),
234 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
235 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
236 reassign_neighbor_count: config
237 .reassign_neighbor_count
238 .unwrap_or(default_reassign_neighbor_count()),
239 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
240 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
241 ..Default::default()
242 }
243 }
244}
245
246impl Default for SpannConfiguration {
247 fn default() -> Self {
248 InternalSpannConfiguration::default().into()
249 }
250}
251
252#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
253#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
254#[serde(deny_unknown_fields)]
255#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
256pub struct UpdateSpannConfiguration {
257 #[validate(range(max = 128))]
258 pub search_nprobe: Option<u32>,
259 #[validate(range(max = 200))]
260 pub ef_search: Option<usize>,
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the eight fields `SpannConfiguration` does not carry
    /// hold the values of their `default_*` functions.
    fn assert_hidden_fields_are_default(cfg: &InternalSpannConfiguration) {
        assert_eq!(cfg.search_rng_factor, default_search_rng_factor());
        assert_eq!(cfg.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(cfg.nreplica_count, default_nreplica_count());
        assert_eq!(cfg.write_rng_factor, default_write_rng_factor());
        assert_eq!(cfg.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(cfg.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(cfg.initial_lambda, default_initial_lambda());
        assert_eq!(
            cfg.num_centers_to_merge_to,
            default_num_centers_to_merge_to()
        );
    }

    /// Asserts that every field of `cfg` holds the value of its `default_*` function.
    fn assert_all_fields_are_default(cfg: &InternalSpannConfiguration) {
        assert_eq!(cfg.search_nprobe, default_search_nprobe());
        assert_eq!(cfg.write_nprobe, default_write_nprobe());
        assert_eq!(cfg.space, default_space_spann());
        assert_eq!(cfg.ef_construction, default_construction_ef_spann());
        assert_eq!(cfg.ef_search, default_search_ef_spann());
        assert_eq!(cfg.max_neighbors, default_m_spann());
        assert_eq!(
            cfg.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(cfg.split_threshold, default_split_threshold());
        assert_eq!(cfg.merge_threshold, default_merge_threshold());
        assert_hidden_fields_are_default(cfg);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let user_config = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let converted = InternalSpannConfiguration::from(user_config);

        // Explicitly provided values are carried over unchanged.
        assert_eq!(converted.search_nprobe, 100);
        assert_eq!(converted.write_nprobe, 50);
        assert_eq!(converted.space, Space::Cosine);
        assert_eq!(converted.ef_construction, 150);
        assert_eq!(converted.ef_search, 180);
        assert_eq!(converted.max_neighbors, 32);
        assert_eq!(converted.reassign_neighbor_count, 48);
        assert_eq!(converted.split_threshold, 75);
        assert_eq!(converted.merge_threshold, 50);
        assert_hidden_fields_are_default(&converted);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let user_config = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        let converted = InternalSpannConfiguration::from(user_config);

        assert_all_fields_are_default(&converted);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let user_config = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let converted = InternalSpannConfiguration::from(user_config);

        // Provided values win; omitted ones fall back to their defaults.
        assert_eq!(converted.search_nprobe, 80);
        assert_eq!(converted.write_nprobe, default_write_nprobe());
        assert_eq!(converted.space, Space::Ip);
        assert_eq!(converted.ef_construction, default_construction_ef_spann());
        assert_eq!(converted.ef_search, 160);
        assert_eq!(converted.max_neighbors, 48);
        assert_eq!(
            converted.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(converted.split_threshold, 100);
        assert_eq!(converted.merge_threshold, default_merge_threshold());
        assert_hidden_fields_are_default(&converted);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        let defaults = InternalSpannConfiguration::default();

        assert_all_fields_are_default(&defaults);
    }

    #[test]
    fn test_spann_configuration_default() {
        let converted = InternalSpannConfiguration::from(SpannConfiguration::default());

        assert_all_fields_are_default(&converted);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        // Every key except `nreplica_count` is present.
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed =
            serde_json::from_str::<InternalSpannConfiguration>(json_without_nreplica).unwrap();

        assert_eq!(parsed.search_nprobe, 120);
        assert_eq!(parsed.search_rng_factor, 2.5);
        assert_eq!(parsed.search_rng_epsilon, 15.0);
        assert_eq!(parsed.write_nprobe, 60);
        assert_eq!(parsed.write_rng_factor, 1.5);
        assert_eq!(parsed.write_rng_epsilon, 8.0);
        assert_eq!(parsed.split_threshold, 80);
        assert_eq!(parsed.num_samples_kmeans, 1500);
        assert_eq!(parsed.initial_lambda, 150.0);
        assert_eq!(parsed.reassign_neighbor_count, 32);
        assert_eq!(parsed.merge_threshold, 30);
        assert_eq!(parsed.num_centers_to_merge_to, 6);
        assert_eq!(parsed.space, Space::L2);
        assert_eq!(parsed.ef_construction, 180);
        assert_eq!(parsed.ef_search, 200);
        assert_eq!(parsed.max_neighbors, 56);
        // The missing key falls back to its serde default.
        assert_eq!(parsed.nreplica_count, default_nreplica_count());
    }
}