// chroma_types/spann_configuration.rs

1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Fallback value for `search_nprobe`.
const DEFAULT_SEARCH_NPROBE: u32 = 64;
/// Fallback value for `search_rng_factor`.
const DEFAULT_SEARCH_RNG_FACTOR: f32 = 1.0;
/// Fallback value for `search_rng_epsilon`.
const DEFAULT_SEARCH_RNG_EPSILON: f32 = 10.0;

/// Default `search_nprobe` (64).
///
/// Used as the serde default on `InternalSpannConfiguration` and as the `None` fallback
/// when converting user-facing configurations.
pub fn default_search_nprobe() -> u32 {
    DEFAULT_SEARCH_NPROBE
}

/// Default `search_rng_factor` (1.0).
pub fn default_search_rng_factor() -> f32 {
    DEFAULT_SEARCH_RNG_FACTOR
}

/// Default `search_rng_epsilon` (10.0).
pub fn default_search_rng_epsilon() -> f32 {
    DEFAULT_SEARCH_RNG_EPSILON
}
18
/// Fallback value for `write_nprobe`.
const DEFAULT_WRITE_NPROBE: u32 = 32;
/// Fallback value for `nreplica_count`.
const DEFAULT_NREPLICA_COUNT: u32 = 8;
/// Fallback value for `write_rng_factor`.
const DEFAULT_WRITE_RNG_FACTOR: f32 = 1.0;
/// Fallback value for `write_rng_epsilon`.
const DEFAULT_WRITE_RNG_EPSILON: f32 = 5.0;

/// Default `write_nprobe` (32).
pub fn default_write_nprobe() -> u32 {
    DEFAULT_WRITE_NPROBE
}

/// Default `nreplica_count` (8).
///
/// This is also the field's validated maximum on `InternalSpannConfiguration`.
pub fn default_nreplica_count() -> u32 {
    DEFAULT_NREPLICA_COUNT
}

/// Default `write_rng_factor` (1.0).
pub fn default_write_rng_factor() -> f32 {
    DEFAULT_WRITE_RNG_FACTOR
}

/// Default `write_rng_epsilon` (5.0).
pub fn default_write_rng_epsilon() -> f32 {
    DEFAULT_WRITE_RNG_EPSILON
}
34
/// Fallback value for `split_threshold`.
const DEFAULT_SPLIT_THRESHOLD: u32 = 50;
/// Fallback value for `num_samples_kmeans`.
const DEFAULT_NUM_SAMPLES_KMEANS: usize = 1000;
/// Fallback value for `initial_lambda`.
const DEFAULT_INITIAL_LAMBDA: f32 = 100.0;
/// Fallback value for `reassign_neighbor_count`.
const DEFAULT_REASSIGN_NEIGHBOR_COUNT: u32 = 64;
/// Fallback value for `merge_threshold`.
const DEFAULT_MERGE_THRESHOLD: u32 = 25;
/// Fallback value for `num_centers_to_merge_to`.
const DEFAULT_NUM_CENTERS_TO_MERGE_TO: u32 = 8;

/// Default `split_threshold` (50). It lies inside the validated range 25..=200.
pub fn default_split_threshold() -> u32 {
    DEFAULT_SPLIT_THRESHOLD
}

/// Default `num_samples_kmeans` (1000).
pub fn default_num_samples_kmeans() -> usize {
    DEFAULT_NUM_SAMPLES_KMEANS
}

/// Default `initial_lambda` (100.0).
pub fn default_initial_lambda() -> f32 {
    DEFAULT_INITIAL_LAMBDA
}

/// Default `reassign_neighbor_count` (64). This equals its validated maximum.
pub fn default_reassign_neighbor_count() -> u32 {
    DEFAULT_REASSIGN_NEIGHBOR_COUNT
}

/// Default `merge_threshold` (25). It lies inside the validated range 12..=100.
pub fn default_merge_threshold() -> u32 {
    DEFAULT_MERGE_THRESHOLD
}

/// Default `num_centers_to_merge_to` (8). This equals its validated maximum.
pub fn default_num_centers_to_merge_to() -> u32 {
    DEFAULT_NUM_CENTERS_TO_MERGE_TO
}
58
/// Fallback value for `ef_construction`.
const DEFAULT_CONSTRUCTION_EF_SPANN: usize = 200;
/// Fallback value for `ef_search`.
const DEFAULT_SEARCH_EF_SPANN: usize = 200;
/// Fallback value for `max_neighbors`.
const DEFAULT_M_SPANN: usize = 64;

/// Default `ef_construction` (200). This equals its validated maximum.
pub fn default_construction_ef_spann() -> usize {
    DEFAULT_CONSTRUCTION_EF_SPANN
}

/// Default `ef_search` (200). This equals its validated maximum.
pub fn default_search_ef_spann() -> usize {
    DEFAULT_SEARCH_EF_SPANN
}

/// Default `max_neighbors` (64). This equals its validated maximum.
pub fn default_m_spann() -> usize {
    DEFAULT_M_SPANN
}
70
/// Fallback value returned by `default_quantize`.
const DEFAULT_QUANTIZE: bool = false;
/// Fallback value returned by `default_center_drift_threshold`.
const DEFAULT_CENTER_DRIFT_THRESHOLD: f32 = 0.125;

/// Default quantization flag (`false`).
///
/// NOTE(review): not referenced by any struct in this module; presumably consumed elsewhere.
pub fn default_quantize() -> bool {
    DEFAULT_QUANTIZE
}

/// Default center-drift threshold (0.125).
///
/// NOTE(review): not referenced by any struct in this module; presumably consumed elsewhere.
pub fn default_center_drift_threshold() -> f32 {
    DEFAULT_CENTER_DRIFT_THRESHOLD
}
78
/// Distance space used when none is configured: [`Space::L2`].
///
/// Wired in as the serde default for `InternalSpannConfiguration::space`. It is also the
/// `None` fallback when converting a `SpannConfiguration` into the internal form.
fn default_space_spann() -> Space {
    Space::L2
}
82
/// Errors produced while obtaining SPANN parameters.
///
/// NOTE(review): judging by the name, these arise when reading parameters from a segment's
/// stored metadata. Confirm against the callers.
#[derive(Debug, Error)]
pub enum DistributedSpannParametersFromSegmentError {
    /// Wraps a `serde_json` failure, e.g. when the metadata could not be parsed.
    #[error("Invalid metadata: {0}")]
    InvalidMetadata(#[from] serde_json::Error),
    /// Wraps `validator` errors: the parsed parameters violated a `#[validate]` bound.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(#[from] validator::ValidationErrors),
}
90
91impl ChromaError for DistributedSpannParametersFromSegmentError {
92    fn code(&self) -> ErrorCodes {
93        match self {
94            DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
95                ErrorCodes::InvalidArgument
96            }
97            DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
98                ErrorCodes::InvalidArgument
99            }
100        }
101    }
102}
103
/// Fully resolved SPANN index parameters, including knobs that `SpannConfiguration` does not
/// expose.
///
/// Serde behavior:
/// - Every field carries `#[serde(default = ...)]`, so deserializing a partial (or empty) JSON
///   object fills in missing fields from this module's `default_*` functions.
/// - Unknown fields are ignored, because there is no `deny_unknown_fields`.
///
/// The `#[validate(...)]` bounds are inclusive. They are enforced only when `validate()` is
/// called; deserialization alone does not check them.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct InternalSpannConfiguration {
    // `search_*` knobs (presumably read on the query path; confirm).
    // NOTE(review): `search_nprobe` is unbounded here, whereas `UpdateSpannConfiguration`
    // caps it at 128.
    #[serde(default = "default_search_nprobe")]
    pub search_nprobe: u32,
    #[serde(default = "default_search_rng_factor")]
    pub search_rng_factor: f32,
    #[serde(default = "default_search_rng_epsilon")]
    pub search_rng_epsilon: f32,
    // `write_*` and replica knobs (presumably read on the insert path; confirm).
    #[serde(default = "default_write_nprobe")]
    #[validate(range(max = 128))]
    pub write_nprobe: u32,
    #[serde(default = "default_nreplica_count")]
    #[validate(range(max = 8))]
    pub nreplica_count: u32,
    #[serde(default = "default_write_rng_factor")]
    pub write_rng_factor: f32,
    #[serde(default = "default_write_rng_epsilon")]
    pub write_rng_epsilon: f32,
    // Split/merge knobs. The default split (50) and merge (25) thresholds both sit inside
    // their validated ranges.
    #[serde(default = "default_split_threshold")]
    #[validate(range(min = 25, max = 200))]
    pub split_threshold: u32,
    // `num_samples_kmeans` and `initial_lambda` carry no validation bounds.
    #[serde(default = "default_num_samples_kmeans")]
    pub num_samples_kmeans: usize,
    #[serde(default = "default_initial_lambda")]
    pub initial_lambda: f32,
    #[serde(default = "default_reassign_neighbor_count")]
    #[validate(range(max = 64))]
    pub reassign_neighbor_count: u32,
    #[serde(default = "default_merge_threshold")]
    #[validate(range(min = 12, max = 100))]
    pub merge_threshold: u32,
    #[serde(default = "default_num_centers_to_merge_to")]
    #[validate(range(max = 8))]
    pub num_centers_to_merge_to: u32,
    // Distance space; defaults to `Space::L2` via `default_space_spann`.
    #[serde(default = "default_space_spann")]
    pub space: Space,
    // `ef_*` / `max_neighbors` look like HNSW-style graph parameters (confirm). Each one's
    // default equals its validated maximum (200 / 200 / 64).
    #[serde(default = "default_construction_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_construction: usize,
    #[serde(default = "default_search_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_search: usize,
    #[serde(default = "default_m_spann")]
    #[validate(range(max = 64))]
    pub max_neighbors: usize,
}
151
152impl Default for InternalSpannConfiguration {
153    fn default() -> Self {
154        serde_json::from_str("{}").unwrap()
155    }
156}
157
158impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
159    fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
160        InternalSpannConfiguration {
161            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
162            search_rng_factor: config
163                .search_rng_factor
164                .unwrap_or(default_search_rng_factor()),
165            search_rng_epsilon: config
166                .search_rng_epsilon
167                .unwrap_or(default_search_rng_epsilon()),
168            nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
169            write_rng_factor: config
170                .write_rng_factor
171                .unwrap_or(default_write_rng_factor()),
172            write_rng_epsilon: config
173                .write_rng_epsilon
174                .unwrap_or(default_write_rng_epsilon()),
175            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
176            num_samples_kmeans: config
177                .num_samples_kmeans
178                .unwrap_or(default_num_samples_kmeans()),
179            initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
180            reassign_neighbor_count: config
181                .reassign_neighbor_count
182                .unwrap_or(default_reassign_neighbor_count()),
183            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
184            num_centers_to_merge_to: config
185                .num_centers_to_merge_to
186                .unwrap_or(default_num_centers_to_merge_to()),
187            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
188            ef_construction: config
189                .ef_construction
190                .unwrap_or(default_construction_ef_spann()),
191            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
192            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
193            space: space.unwrap_or(&default_space()).clone(),
194        }
195    }
196}
197
198#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
199#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
200#[serde(deny_unknown_fields)]
201pub struct SpannConfiguration {
202    pub search_nprobe: Option<u32>,
203    pub write_nprobe: Option<u32>,
204    pub space: Option<Space>,
205    pub ef_construction: Option<usize>,
206    pub ef_search: Option<usize>,
207    pub max_neighbors: Option<usize>,
208    pub reassign_neighbor_count: Option<u32>,
209    pub split_threshold: Option<u32>,
210    pub merge_threshold: Option<u32>,
211}
212
213impl From<InternalSpannConfiguration> for SpannConfiguration {
214    fn from(config: InternalSpannConfiguration) -> Self {
215        Self {
216            search_nprobe: Some(config.search_nprobe),
217            write_nprobe: Some(config.write_nprobe),
218            space: Some(config.space),
219            ef_construction: Some(config.ef_construction),
220            ef_search: Some(config.ef_search),
221            max_neighbors: Some(config.max_neighbors),
222            reassign_neighbor_count: Some(config.reassign_neighbor_count),
223            split_threshold: Some(config.split_threshold),
224            merge_threshold: Some(config.merge_threshold),
225        }
226    }
227}
228
229impl From<SpannConfiguration> for InternalSpannConfiguration {
230    fn from(config: SpannConfiguration) -> Self {
231        Self {
232            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
233            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
234            space: config.space.unwrap_or(default_space_spann()),
235            ef_construction: config
236                .ef_construction
237                .unwrap_or(default_construction_ef_spann()),
238            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
239            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
240            reassign_neighbor_count: config
241                .reassign_neighbor_count
242                .unwrap_or(default_reassign_neighbor_count()),
243            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
244            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
245            ..Default::default()
246        }
247    }
248}
249
250impl Default for SpannConfiguration {
251    fn default() -> Self {
252        InternalSpannConfiguration::default().into()
253    }
254}
255
/// Changes to a SPANN configuration.
///
/// Only `search_nprobe` and `ef_search` are carried. Unknown fields are rejected on
/// deserialization.
///
/// NOTE(review): `None` presumably means "leave the current value unchanged". Confirm against
/// the code that applies updates.
#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct UpdateSpannConfiguration {
    // When `Some`, `validate()` requires a value of at most 128.
    #[validate(range(max = 128))]
    pub search_nprobe: Option<u32>,
    // When `Some`, `validate()` requires a value of at most 200, matching
    // `InternalSpannConfiguration::ef_search`.
    #[validate(range(max = 200))]
    pub ef_search: Option<usize>,
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every knob `SpannConfiguration` does not expose still holds its
    /// `default_*` value.
    #[track_caller]
    fn assert_internal_only_defaults(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_rng_factor, default_search_rng_factor());
        assert_eq!(config.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(config.nreplica_count, default_nreplica_count());
        assert_eq!(config.write_rng_factor, default_write_rng_factor());
        assert_eq!(config.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(config.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(config.initial_lambda, default_initial_lambda());
        assert_eq!(
            config.num_centers_to_merge_to,
            default_num_centers_to_merge_to()
        );
    }

    /// Asserts that every field, user-facing and internal-only, holds its `default_*` value.
    #[track_caller]
    fn assert_all_defaults(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_nprobe, default_search_nprobe());
        assert_eq!(config.write_nprobe, default_write_nprobe());
        assert_eq!(config.space, default_space_spann());
        assert_eq!(config.ef_construction, default_construction_ef_spann());
        assert_eq!(config.ef_search, default_search_ef_spann());
        assert_eq!(config.max_neighbors, default_m_spann());
        assert_eq!(
            config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(config.split_threshold, default_split_threshold());
        assert_eq!(config.merge_threshold, default_merge_threshold());
        assert_internal_only_defaults(config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let user_config = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let resolved = InternalSpannConfiguration::from(user_config);

        assert_eq!(resolved.search_nprobe, 100);
        assert_eq!(resolved.write_nprobe, 50);
        assert_eq!(resolved.space, Space::Cosine);
        assert_eq!(resolved.ef_construction, 150);
        assert_eq!(resolved.ef_search, 180);
        assert_eq!(resolved.max_neighbors, 32);
        assert_eq!(resolved.reassign_neighbor_count, 48);
        assert_eq!(resolved.split_threshold, 75);
        assert_eq!(resolved.merge_threshold, 50);
        assert_internal_only_defaults(&resolved);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let user_config = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        assert_all_defaults(&InternalSpannConfiguration::from(user_config));
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let user_config = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let resolved = InternalSpannConfiguration::from(user_config);

        // Explicitly provided values are kept...
        assert_eq!(resolved.search_nprobe, 80);
        assert_eq!(resolved.space, Space::Ip);
        assert_eq!(resolved.ef_search, 160);
        assert_eq!(resolved.max_neighbors, 48);
        assert_eq!(resolved.split_threshold, 100);
        // ...while `None` falls back to the default.
        assert_eq!(resolved.write_nprobe, default_write_nprobe());
        assert_eq!(resolved.ef_construction, default_construction_ef_spann());
        assert_eq!(
            resolved.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(resolved.merge_threshold, default_merge_threshold());
        assert_internal_only_defaults(&resolved);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        assert_all_defaults(&InternalSpannConfiguration::default());
    }

    #[test]
    fn test_spann_configuration_default() {
        let round_tripped: InternalSpannConfiguration = SpannConfiguration::default().into();
        assert_all_defaults(&round_tripped);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        // The derived `PartialEq` compares every field, so this single assertion checks all
        // seventeen values.
        let expected = InternalSpannConfiguration {
            search_nprobe: 120,
            search_rng_factor: 2.5,
            search_rng_epsilon: 15.0,
            write_nprobe: 60,
            write_rng_factor: 1.5,
            write_rng_epsilon: 8.0,
            split_threshold: 80,
            num_samples_kmeans: 1500,
            initial_lambda: 150.0,
            reassign_neighbor_count: 32,
            merge_threshold: 30,
            num_centers_to_merge_to: 6,
            space: Space::L2,
            ef_construction: 180,
            ef_search: 200,
            max_neighbors: 56,
            // Absent from the JSON, so serde must fall back to the default.
            nreplica_count: default_nreplica_count(),
        };
        assert_eq!(parsed, expected);
    }
}