1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Fallback for [`InternalSpannConfiguration::search_nprobe`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_search_nprobe() -> u32 {
    64_u32
}
10
/// Fallback for [`InternalSpannConfiguration::search_rng_factor`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}
14
/// Fallback for [`InternalSpannConfiguration::search_rng_epsilon`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Fallback for [`InternalSpannConfiguration::write_nprobe`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_write_nprobe() -> u32 {
    32_u32
}
22
/// Fallback for [`InternalSpannConfiguration::nreplica_count`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_nreplica_count() -> u32 {
    8_u32
}
26
/// Fallback for [`InternalSpannConfiguration::write_rng_factor`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}
30
/// Fallback for [`InternalSpannConfiguration::write_rng_epsilon`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Fallback for [`InternalSpannConfiguration::split_threshold`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_split_threshold() -> u32 {
    50_u32
}
38
/// Fallback for [`InternalSpannConfiguration::num_samples_kmeans`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_num_samples_kmeans() -> usize {
    1_000_usize
}
42
/// Fallback for [`InternalSpannConfiguration::initial_lambda`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}
46
/// Fallback for [`InternalSpannConfiguration::reassign_neighbor_count`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}
50
/// Fallback for [`InternalSpannConfiguration::merge_threshold`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_merge_threshold() -> u32 {
    25_u32
}
54
/// Fallback for [`InternalSpannConfiguration::num_centers_to_merge_to`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` when converting from a `SpannIndexConfig`.
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Fallback for [`InternalSpannConfiguration::ef_construction`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_construction_ef_spann() -> usize {
    200_usize
}
62
/// Fallback for [`InternalSpannConfiguration::ef_search`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_search_ef_spann() -> usize {
    200_usize
}
66
/// Fallback for [`InternalSpannConfiguration::max_neighbors`].
///
/// Used as the serde default when the field is absent and as the substitute
/// for `None` in the `From` conversions in this file.
pub fn default_m_spann() -> usize {
    64_usize
}
70
71fn default_space_spann() -> Space {
72 Space::L2
73}
74
75#[derive(Debug, Error)]
76pub enum DistributedSpannParametersFromSegmentError {
77 #[error("Invalid metadata: {0}")]
78 InvalidMetadata(#[from] serde_json::Error),
79 #[error("Invalid parameters: {0}")]
80 InvalidParameters(#[from] validator::ValidationErrors),
81}
82
83impl ChromaError for DistributedSpannParametersFromSegmentError {
84 fn code(&self) -> ErrorCodes {
85 match self {
86 DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
87 ErrorCodes::InvalidArgument
88 }
89 DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
90 ErrorCodes::InvalidArgument
91 }
92 }
93 }
94}
95
96#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
97#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
98pub struct InternalSpannConfiguration {
99 #[serde(default = "default_search_nprobe")]
100 pub search_nprobe: u32,
101 #[serde(default = "default_search_rng_factor")]
102 pub search_rng_factor: f32,
103 #[serde(default = "default_search_rng_epsilon")]
104 pub search_rng_epsilon: f32,
105 #[serde(default = "default_write_nprobe")]
106 #[validate(range(max = 128))]
107 pub write_nprobe: u32,
108 #[serde(default = "default_nreplica_count")]
109 #[validate(range(max = 8))]
110 pub nreplica_count: u32,
111 #[serde(default = "default_write_rng_factor")]
112 pub write_rng_factor: f32,
113 #[serde(default = "default_write_rng_epsilon")]
114 pub write_rng_epsilon: f32,
115 #[serde(default = "default_split_threshold")]
116 #[validate(range(min = 25, max = 200))]
117 pub split_threshold: u32,
118 #[serde(default = "default_num_samples_kmeans")]
119 pub num_samples_kmeans: usize,
120 #[serde(default = "default_initial_lambda")]
121 pub initial_lambda: f32,
122 #[serde(default = "default_reassign_neighbor_count")]
123 #[validate(range(max = 64))]
124 pub reassign_neighbor_count: u32,
125 #[serde(default = "default_merge_threshold")]
126 #[validate(range(min = 12, max = 100))]
127 pub merge_threshold: u32,
128 #[serde(default = "default_num_centers_to_merge_to")]
129 #[validate(range(max = 8))]
130 pub num_centers_to_merge_to: u32,
131 #[serde(default = "default_space_spann")]
132 pub space: Space,
133 #[serde(default = "default_construction_ef_spann")]
134 #[validate(range(max = 200))]
135 pub ef_construction: usize,
136 #[serde(default = "default_search_ef_spann")]
137 #[validate(range(max = 200))]
138 pub ef_search: usize,
139 #[serde(default = "default_m_spann")]
140 #[validate(range(max = 64))]
141 pub max_neighbors: usize,
142}
143
144impl Default for InternalSpannConfiguration {
145 fn default() -> Self {
146 serde_json::from_str("{}").unwrap()
147 }
148}
149
150impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
151 fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
152 InternalSpannConfiguration {
153 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
154 search_rng_factor: config
155 .search_rng_factor
156 .unwrap_or(default_search_rng_factor()),
157 search_rng_epsilon: config
158 .search_rng_epsilon
159 .unwrap_or(default_search_rng_epsilon()),
160 nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
161 write_rng_factor: config
162 .write_rng_factor
163 .unwrap_or(default_write_rng_factor()),
164 write_rng_epsilon: config
165 .write_rng_epsilon
166 .unwrap_or(default_write_rng_epsilon()),
167 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
168 num_samples_kmeans: config
169 .num_samples_kmeans
170 .unwrap_or(default_num_samples_kmeans()),
171 initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
172 reassign_neighbor_count: config
173 .reassign_neighbor_count
174 .unwrap_or(default_reassign_neighbor_count()),
175 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
176 num_centers_to_merge_to: config
177 .num_centers_to_merge_to
178 .unwrap_or(default_num_centers_to_merge_to()),
179 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
180 ef_construction: config
181 .ef_construction
182 .unwrap_or(default_construction_ef_spann()),
183 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
184 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
185 space: space.unwrap_or(&default_space()).clone(),
186 }
187 }
188}
189
190#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
191#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
192#[serde(deny_unknown_fields)]
193pub struct SpannConfiguration {
194 pub search_nprobe: Option<u32>,
195 pub write_nprobe: Option<u32>,
196 pub space: Option<Space>,
197 pub ef_construction: Option<usize>,
198 pub ef_search: Option<usize>,
199 pub max_neighbors: Option<usize>,
200 pub reassign_neighbor_count: Option<u32>,
201 pub split_threshold: Option<u32>,
202 pub merge_threshold: Option<u32>,
203}
204
205impl From<InternalSpannConfiguration> for SpannConfiguration {
206 fn from(config: InternalSpannConfiguration) -> Self {
207 Self {
208 search_nprobe: Some(config.search_nprobe),
209 write_nprobe: Some(config.write_nprobe),
210 space: Some(config.space),
211 ef_construction: Some(config.ef_construction),
212 ef_search: Some(config.ef_search),
213 max_neighbors: Some(config.max_neighbors),
214 reassign_neighbor_count: Some(config.reassign_neighbor_count),
215 split_threshold: Some(config.split_threshold),
216 merge_threshold: Some(config.merge_threshold),
217 }
218 }
219}
220
221impl From<SpannConfiguration> for InternalSpannConfiguration {
222 fn from(config: SpannConfiguration) -> Self {
223 Self {
224 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
225 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
226 space: config.space.unwrap_or(default_space_spann()),
227 ef_construction: config
228 .ef_construction
229 .unwrap_or(default_construction_ef_spann()),
230 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
231 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
232 reassign_neighbor_count: config
233 .reassign_neighbor_count
234 .unwrap_or(default_reassign_neighbor_count()),
235 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
236 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
237 ..Default::default()
238 }
239 }
240}
241
242impl Default for SpannConfiguration {
243 fn default() -> Self {
244 InternalSpannConfiguration::default().into()
245 }
246}
247
248#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
249#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
250#[serde(deny_unknown_fields)]
251#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
252pub struct UpdateSpannConfiguration {
253 #[validate(range(max = 128))]
254 pub search_nprobe: Option<u32>,
255 #[validate(range(max = 200))]
256 pub ef_search: Option<usize>,
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every field exposed through `SpannConfiguration` holds
    /// its `default_*` value.
    #[track_caller]
    fn assert_exposed_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_nprobe, default_search_nprobe());
        assert_eq!(config.write_nprobe, default_write_nprobe());
        assert_eq!(config.space, default_space_spann());
        assert_eq!(config.ef_construction, default_construction_ef_spann());
        assert_eq!(config.ef_search, default_search_ef_spann());
        assert_eq!(config.max_neighbors, default_m_spann());
        assert_eq!(config.reassign_neighbor_count, default_reassign_neighbor_count());
        assert_eq!(config.split_threshold, default_split_threshold());
        assert_eq!(config.merge_threshold, default_merge_threshold());
    }

    /// Asserts that every field `SpannConfiguration` does not expose holds
    /// its `default_*` value.
    #[track_caller]
    fn assert_hidden_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_rng_factor, default_search_rng_factor());
        assert_eq!(config.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(config.nreplica_count, default_nreplica_count());
        assert_eq!(config.write_rng_factor, default_write_rng_factor());
        assert_eq!(config.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(config.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(config.initial_lambda, default_initial_lambda());
        assert_eq!(config.num_centers_to_merge_to, default_num_centers_to_merge_to());
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let user_config = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let resolved = InternalSpannConfiguration::from(user_config);

        // Explicit values must pass through untouched.
        assert_eq!(resolved.search_nprobe, 100);
        assert_eq!(resolved.write_nprobe, 50);
        assert_eq!(resolved.space, Space::Cosine);
        assert_eq!(resolved.ef_construction, 150);
        assert_eq!(resolved.ef_search, 180);
        assert_eq!(resolved.max_neighbors, 32);
        assert_eq!(resolved.reassign_neighbor_count, 48);
        assert_eq!(resolved.split_threshold, 75);
        assert_eq!(resolved.merge_threshold, 50);
        assert_hidden_fields_are_default(&resolved);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let user_config = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        let resolved = InternalSpannConfiguration::from(user_config);

        assert_exposed_fields_are_default(&resolved);
        assert_hidden_fields_are_default(&resolved);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let user_config = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let resolved = InternalSpannConfiguration::from(user_config);

        // Fields that were provided.
        assert_eq!(resolved.search_nprobe, 80);
        assert_eq!(resolved.space, Space::Ip);
        assert_eq!(resolved.ef_search, 160);
        assert_eq!(resolved.max_neighbors, 48);
        assert_eq!(resolved.split_threshold, 100);

        // Fields that were left as `None`.
        assert_eq!(resolved.write_nprobe, default_write_nprobe());
        assert_eq!(resolved.ef_construction, default_construction_ef_spann());
        assert_eq!(resolved.reassign_neighbor_count, default_reassign_neighbor_count());
        assert_eq!(resolved.merge_threshold, default_merge_threshold());

        assert_hidden_fields_are_default(&resolved);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        let resolved = InternalSpannConfiguration::default();

        assert_exposed_fields_are_default(&resolved);
        assert_hidden_fields_are_default(&resolved);
    }

    #[test]
    fn test_spann_configuration_default() {
        let resolved = InternalSpannConfiguration::from(SpannConfiguration::default());

        assert_exposed_fields_are_default(&resolved);
        assert_hidden_fields_are_default(&resolved);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        // Every field except `nreplica_count` is present.
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        assert_eq!(parsed.search_nprobe, 120);
        assert_eq!(parsed.search_rng_factor, 2.5);
        assert_eq!(parsed.search_rng_epsilon, 15.0);
        assert_eq!(parsed.write_nprobe, 60);
        assert_eq!(parsed.write_rng_factor, 1.5);
        assert_eq!(parsed.write_rng_epsilon, 8.0);
        assert_eq!(parsed.split_threshold, 80);
        assert_eq!(parsed.num_samples_kmeans, 1500);
        assert_eq!(parsed.initial_lambda, 150.0);
        assert_eq!(parsed.reassign_neighbor_count, 32);
        assert_eq!(parsed.merge_threshold, 30);
        assert_eq!(parsed.num_centers_to_merge_to, 6);
        assert_eq!(parsed.space, Space::L2);
        assert_eq!(parsed.ef_construction, 180);
        assert_eq!(parsed.ef_search, 200);
        assert_eq!(parsed.max_neighbors, 56);
        // The omitted field must fall back to its serde default.
        assert_eq!(parsed.nreplica_count, default_nreplica_count());
    }
}