chroma_types/
spann_configuration.rs

1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Serde default for `InternalSpannConfiguration::search_nprobe`.
pub fn default_search_nprobe() -> u32 {
    64_u32
}
10
/// Serde default for `InternalSpannConfiguration::search_rng_factor`.
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}
14
/// Serde default for `InternalSpannConfiguration::search_rng_epsilon`.
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Serde default for `InternalSpannConfiguration::write_nprobe`.
pub fn default_write_nprobe() -> u32 {
    32_u32
}
22
/// Serde default for `InternalSpannConfiguration::nreplica_count`.
pub fn default_nreplica_count() -> u32 {
    8_u32
}
26
/// Serde default for `InternalSpannConfiguration::write_rng_factor`.
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}
30
/// Serde default for `InternalSpannConfiguration::write_rng_epsilon`.
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Serde default for `InternalSpannConfiguration::split_threshold`.
pub fn default_split_threshold() -> u32 {
    50_u32
}
38
/// Serde default for `InternalSpannConfiguration::num_samples_kmeans`.
pub fn default_num_samples_kmeans() -> usize {
    1000_usize
}
42
/// Serde default for `InternalSpannConfiguration::initial_lambda`.
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}
46
/// Serde default for `InternalSpannConfiguration::reassign_neighbor_count`.
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}
50
/// Serde default for `InternalSpannConfiguration::merge_threshold`.
pub fn default_merge_threshold() -> u32 {
    25_u32
}
54
/// Serde default for `InternalSpannConfiguration::num_centers_to_merge_to`.
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Serde default for `InternalSpannConfiguration::ef_construction`.
pub fn default_construction_ef_spann() -> usize {
    200_usize
}
62
/// Serde default for `InternalSpannConfiguration::ef_search`.
pub fn default_search_ef_spann() -> usize {
    200_usize
}
66
/// Serde default for `InternalSpannConfiguration::max_neighbors`.
pub fn default_m_spann() -> usize {
    64_usize
}
70
71fn default_space_spann() -> Space {
72    Space::L2
73}
74
/// Error type with one variant per wrapped failure: a `serde_json::Error` or a
/// `validator::ValidationErrors`. Both variants are `#[from]` conversions, so
/// `?` can lift either source error into this type.
///
/// NOTE(review): judging by the name this is returned when SPANN parameters are
/// read from a segment; the call site is not in this file — confirm.
#[derive(Debug, Error)]
pub enum DistributedSpannParametersFromSegmentError {
    /// Wraps a `serde_json::Error`.
    #[error("Invalid metadata: {0}")]
    InvalidMetadata(#[from] serde_json::Error),
    /// Wraps the `validator::ValidationErrors` produced by a failed validation.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(#[from] validator::ValidationErrors),
}
82
83impl ChromaError for DistributedSpannParametersFromSegmentError {
84    fn code(&self) -> ErrorCodes {
85        match self {
86            DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
87                ErrorCodes::InvalidArgument
88            }
89            DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
90                ErrorCodes::InvalidArgument
91            }
92        }
93    }
94}
95
// Full set of SPANN parameters, including those that `SpannConfiguration`
// does not expose.
//
// Every field carries `#[serde(default = "...")]`, so a key missing from the
// serialized form falls back to the matching `default_*` function in this
// file. The `#[validate(range(...))]` bounds come from the `validator` derive
// and are only enforced when `validate()` is called, not during
// deserialization.
//
// Plain `//` comments are used on purpose: with the `utoipa` feature enabled,
// `///` doc comments could end up in the generated schema.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct InternalSpannConfiguration {
    // Default 64. No range is declared here.
    // NOTE(review): `UpdateSpannConfiguration::search_nprobe` is capped at 128
    // while this field is unbounded — confirm that the asymmetry is intended.
    #[serde(default = "default_search_nprobe")]
    pub search_nprobe: u32,
    // Default 1.0. No range is declared.
    #[serde(default = "default_search_rng_factor")]
    pub search_rng_factor: f32,
    // Default 10.0. No range is declared.
    #[serde(default = "default_search_rng_epsilon")]
    pub search_rng_epsilon: f32,
    // Default 32; validated to be at most 128.
    #[serde(default = "default_write_nprobe")]
    #[validate(range(max = 128))]
    pub write_nprobe: u32,
    // Default 8; validated to be at most 8.
    #[serde(default = "default_nreplica_count")]
    #[validate(range(max = 8))]
    pub nreplica_count: u32,
    // Default 1.0. No range is declared.
    #[serde(default = "default_write_rng_factor")]
    pub write_rng_factor: f32,
    // Default 5.0. No range is declared.
    #[serde(default = "default_write_rng_epsilon")]
    pub write_rng_epsilon: f32,
    // Default 50; validated to lie in 25..=200.
    #[serde(default = "default_split_threshold")]
    #[validate(range(min = 25, max = 200))]
    pub split_threshold: u32,
    // Default 1000. No range is declared.
    #[serde(default = "default_num_samples_kmeans")]
    pub num_samples_kmeans: usize,
    // Default 100.0. No range is declared.
    #[serde(default = "default_initial_lambda")]
    pub initial_lambda: f32,
    // Default 64; validated to be at most 64.
    #[serde(default = "default_reassign_neighbor_count")]
    #[validate(range(max = 64))]
    pub reassign_neighbor_count: u32,
    // Default 25; validated to lie in 12..=100.
    #[serde(default = "default_merge_threshold")]
    #[validate(range(min = 12, max = 100))]
    pub merge_threshold: u32,
    // Default 8; validated to be at most 8.
    #[serde(default = "default_num_centers_to_merge_to")]
    #[validate(range(max = 8))]
    pub num_centers_to_merge_to: u32,
    // Default `Space::L2` (via `default_space_spann`).
    #[serde(default = "default_space_spann")]
    pub space: Space,
    // Default 200; validated to be at most 200.
    #[serde(default = "default_construction_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_construction: usize,
    // Default 200; validated to be at most 200.
    #[serde(default = "default_search_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_search: usize,
    // Default 64; validated to be at most 64.
    #[serde(default = "default_m_spann")]
    #[validate(range(max = 64))]
    pub max_neighbors: usize,
}
143
144impl Default for InternalSpannConfiguration {
145    fn default() -> Self {
146        serde_json::from_str("{}").unwrap()
147    }
148}
149
150impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
151    fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
152        InternalSpannConfiguration {
153            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
154            search_rng_factor: config
155                .search_rng_factor
156                .unwrap_or(default_search_rng_factor()),
157            search_rng_epsilon: config
158                .search_rng_epsilon
159                .unwrap_or(default_search_rng_epsilon()),
160            nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
161            write_rng_factor: config
162                .write_rng_factor
163                .unwrap_or(default_write_rng_factor()),
164            write_rng_epsilon: config
165                .write_rng_epsilon
166                .unwrap_or(default_write_rng_epsilon()),
167            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
168            num_samples_kmeans: config
169                .num_samples_kmeans
170                .unwrap_or(default_num_samples_kmeans()),
171            initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
172            reassign_neighbor_count: config
173                .reassign_neighbor_count
174                .unwrap_or(default_reassign_neighbor_count()),
175            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
176            num_centers_to_merge_to: config
177                .num_centers_to_merge_to
178                .unwrap_or(default_num_centers_to_merge_to()),
179            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
180            ef_construction: config
181                .ef_construction
182                .unwrap_or(default_construction_ef_spann()),
183            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
184            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
185            space: space.unwrap_or(&default_space()).clone(),
186        }
187    }
188}
189
// Subset of `InternalSpannConfiguration` in which every field is optional.
// `deny_unknown_fields` makes deserialization reject any key not listed here.
//
// NOTE(review): `Validate` is derived but no field carries a `#[validate]`
// attribute, so range checks presumably happen after conversion to
// `InternalSpannConfiguration` — confirm against callers.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
pub struct SpannConfiguration {
    // `None` means "use the default" when converted into
    // `InternalSpannConfiguration`; the same holds for every field below.
    pub search_nprobe: Option<u32>,
    pub write_nprobe: Option<u32>,
    pub space: Option<Space>,
    pub ef_construction: Option<usize>,
    pub ef_search: Option<usize>,
    pub max_neighbors: Option<usize>,
    pub reassign_neighbor_count: Option<u32>,
    pub split_threshold: Option<u32>,
    pub merge_threshold: Option<u32>,
}
204
205impl From<InternalSpannConfiguration> for SpannConfiguration {
206    fn from(config: InternalSpannConfiguration) -> Self {
207        Self {
208            search_nprobe: Some(config.search_nprobe),
209            write_nprobe: Some(config.write_nprobe),
210            space: Some(config.space),
211            ef_construction: Some(config.ef_construction),
212            ef_search: Some(config.ef_search),
213            max_neighbors: Some(config.max_neighbors),
214            reassign_neighbor_count: Some(config.reassign_neighbor_count),
215            split_threshold: Some(config.split_threshold),
216            merge_threshold: Some(config.merge_threshold),
217        }
218    }
219}
220
221impl From<SpannConfiguration> for InternalSpannConfiguration {
222    fn from(config: SpannConfiguration) -> Self {
223        Self {
224            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
225            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
226            space: config.space.unwrap_or(default_space_spann()),
227            ef_construction: config
228                .ef_construction
229                .unwrap_or(default_construction_ef_spann()),
230            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
231            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
232            reassign_neighbor_count: config
233                .reassign_neighbor_count
234                .unwrap_or(default_reassign_neighbor_count()),
235            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
236            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
237            ..Default::default()
238        }
239    }
240}
241
242impl Default for SpannConfiguration {
243    fn default() -> Self {
244        InternalSpannConfiguration::default().into()
245    }
246}
247
// Carries the two SPANN fields that this type allows to be set:
// `search_nprobe` and `ef_search`. Both are optional, and the derived
// `Default` leaves both as `None`. `deny_unknown_fields` makes
// deserialization reject any other key.
//
// Plain `//` comments are used on purpose: with the `utoipa` or `pyo3`
// features enabled, `///` doc comments could surface in generated schemas or
// Python docstrings.
#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct UpdateSpannConfiguration {
    // Validated to be at most 128 when present.
    #[validate(range(max = 128))]
    pub search_nprobe: Option<u32>,
    // Validated to be at most 200 when present.
    #[validate(range(max = 200))]
    pub ef_search: Option<usize>,
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the fields which `SpannConfiguration` does not expose all
    /// hold their default values.
    ///
    /// `#[track_caller]` makes a failing assertion report the calling test's
    /// line rather than a line inside this helper.
    #[track_caller]
    fn assert_internal_only_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_rng_factor, default_search_rng_factor());
        assert_eq!(config.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(config.nreplica_count, default_nreplica_count());
        assert_eq!(config.write_rng_factor, default_write_rng_factor());
        assert_eq!(config.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(config.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(config.initial_lambda, default_initial_lambda());
        assert_eq!(
            config.num_centers_to_merge_to,
            default_num_centers_to_merge_to()
        );
    }

    /// Asserts that every field of `config` holds its default value.
    #[track_caller]
    fn assert_all_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_nprobe, default_search_nprobe());
        assert_eq!(config.write_nprobe, default_write_nprobe());
        assert_eq!(config.space, default_space_spann());
        assert_eq!(config.ef_construction, default_construction_ef_spann());
        assert_eq!(config.ef_search, default_search_ef_spann());
        assert_eq!(config.max_neighbors, default_m_spann());
        assert_eq!(
            config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(config.split_threshold, default_split_threshold());
        assert_eq!(config.merge_threshold, default_merge_threshold());
        assert_internal_only_fields_are_default(config);
    }

    // Fully specified public config: every given value is carried over and the
    // internal-only fields take their defaults.
    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let public_config = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let converted = InternalSpannConfiguration::from(public_config);

        assert_eq!(converted.search_nprobe, 100);
        assert_eq!(converted.write_nprobe, 50);
        assert_eq!(converted.space, Space::Cosine);
        assert_eq!(converted.ef_construction, 150);
        assert_eq!(converted.ef_search, 180);
        assert_eq!(converted.max_neighbors, 32);
        assert_eq!(converted.reassign_neighbor_count, 48);
        assert_eq!(converted.split_threshold, 75);
        assert_eq!(converted.merge_threshold, 50);
        assert_internal_only_fields_are_default(&converted);
    }

    // All-`None` public config: every field falls back to its default.
    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let public_config = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        let converted = InternalSpannConfiguration::from(public_config);

        assert_all_fields_are_default(&converted);
    }

    // Mix of `Some` and `None`: given values win, the rest are defaulted.
    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let public_config = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let converted = InternalSpannConfiguration::from(public_config);

        assert_eq!(converted.search_nprobe, 80);
        assert_eq!(converted.write_nprobe, default_write_nprobe());
        assert_eq!(converted.space, Space::Ip);
        assert_eq!(converted.ef_construction, default_construction_ef_spann());
        assert_eq!(converted.ef_search, 160);
        assert_eq!(converted.max_neighbors, 48);
        assert_eq!(
            converted.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(converted.split_threshold, 100);
        assert_eq!(converted.merge_threshold, default_merge_threshold());
        assert_internal_only_fields_are_default(&converted);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        let defaults = InternalSpannConfiguration::default();

        assert_all_fields_are_default(&defaults);
    }

    // Round trip: default public config -> internal config is all defaults.
    #[test]
    fn test_spann_configuration_default() {
        let converted = InternalSpannConfiguration::from(SpannConfiguration::default());

        assert_all_fields_are_default(&converted);
    }

    // A serialized config without the `nreplica_count` key must still
    // deserialize, with that field taking its default.
    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        // Whole-struct comparison covers all seventeen fields at once.
        let expected = InternalSpannConfiguration {
            search_nprobe: 120,
            search_rng_factor: 2.5,
            search_rng_epsilon: 15.0,
            write_nprobe: 60,
            nreplica_count: default_nreplica_count(),
            write_rng_factor: 1.5,
            write_rng_epsilon: 8.0,
            split_threshold: 80,
            num_samples_kmeans: 1500,
            initial_lambda: 150.0,
            reassign_neighbor_count: 32,
            merge_threshold: 30,
            num_centers_to_merge_to: 6,
            space: Space::L2,
            ef_construction: 180,
            ef_search: 200,
            max_neighbors: 56,
        };
        assert_eq!(parsed, expected);
    }
}
556}