// chroma_types/spann_configuration.rs

1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Default for `search_nprobe` (64).
///
/// Serde default for [`InternalSpannConfiguration::search_nprobe`] and the fallback used when
/// an optional configuration leaves the field unset.
pub fn default_search_nprobe() -> u32 {
    const SEARCH_NPROBE: u32 = 64;
    SEARCH_NPROBE
}
10
/// Default for `search_rng_factor` (1.0).
///
/// Serde default for [`InternalSpannConfiguration::search_rng_factor`] and the fallback used
/// when an optional configuration leaves the field unset.
pub fn default_search_rng_factor() -> f32 {
    const SEARCH_RNG_FACTOR: f32 = 1.0;
    SEARCH_RNG_FACTOR
}
14
/// Default for `search_rng_epsilon` (10.0).
///
/// Serde default for [`InternalSpannConfiguration::search_rng_epsilon`] and the fallback used
/// when an optional configuration leaves the field unset.
pub fn default_search_rng_epsilon() -> f32 {
    const SEARCH_RNG_EPSILON: f32 = 10.0;
    SEARCH_RNG_EPSILON
}
18
/// Default for `write_nprobe` (32).
///
/// Serde default for [`InternalSpannConfiguration::write_nprobe`] and the fallback used when an
/// optional configuration leaves the field unset.
pub fn default_write_nprobe() -> u32 {
    const WRITE_NPROBE: u32 = 32;
    WRITE_NPROBE
}
22
/// Default for `nreplica_count` (8).
///
/// Serde default for [`InternalSpannConfiguration::nreplica_count`], which lets JSON that omits
/// the field still deserialize, and the fallback used when an optional configuration leaves it
/// unset.
pub fn default_nreplica_count() -> u32 {
    const NREPLICA_COUNT: u32 = 8;
    NREPLICA_COUNT
}
26
/// Default for `write_rng_factor` (1.0).
///
/// Serde default for [`InternalSpannConfiguration::write_rng_factor`] and the fallback used
/// when an optional configuration leaves the field unset.
pub fn default_write_rng_factor() -> f32 {
    const WRITE_RNG_FACTOR: f32 = 1.0;
    WRITE_RNG_FACTOR
}
30
/// Default for `write_rng_epsilon` (5.0).
///
/// Serde default for [`InternalSpannConfiguration::write_rng_epsilon`] and the fallback used
/// when an optional configuration leaves the field unset.
pub fn default_write_rng_epsilon() -> f32 {
    const WRITE_RNG_EPSILON: f32 = 5.0;
    WRITE_RNG_EPSILON
}
34
/// Default for `split_threshold` (50).
///
/// Serde default for [`InternalSpannConfiguration::split_threshold`] and the fallback used when
/// an optional configuration leaves the field unset.
pub fn default_split_threshold() -> u32 {
    const SPLIT_THRESHOLD: u32 = 50;
    SPLIT_THRESHOLD
}
38
/// Default for `num_samples_kmeans` (1000).
///
/// Serde default for [`InternalSpannConfiguration::num_samples_kmeans`] and the fallback used
/// when an optional configuration leaves the field unset.
pub fn default_num_samples_kmeans() -> usize {
    const NUM_SAMPLES_KMEANS: usize = 1000;
    NUM_SAMPLES_KMEANS
}
42
/// Default for `initial_lambda` (100.0).
///
/// Serde default for [`InternalSpannConfiguration::initial_lambda`] and the fallback used when
/// an optional configuration leaves the field unset.
pub fn default_initial_lambda() -> f32 {
    const INITIAL_LAMBDA: f32 = 100.0;
    INITIAL_LAMBDA
}
46
/// Default for `reassign_neighbor_count` (64).
///
/// Serde default for [`InternalSpannConfiguration::reassign_neighbor_count`] and the fallback
/// used when an optional configuration leaves the field unset.
pub fn default_reassign_neighbor_count() -> u32 {
    const REASSIGN_NEIGHBOR_COUNT: u32 = 64;
    REASSIGN_NEIGHBOR_COUNT
}
50
/// Default for `merge_threshold` (25).
///
/// Serde default for [`InternalSpannConfiguration::merge_threshold`] and the fallback used when
/// an optional configuration leaves the field unset.
pub fn default_merge_threshold() -> u32 {
    const MERGE_THRESHOLD: u32 = 25;
    MERGE_THRESHOLD
}
54
/// Default for `num_centers_to_merge_to` (8).
///
/// Serde default for [`InternalSpannConfiguration::num_centers_to_merge_to`] and the fallback
/// used when an optional configuration leaves the field unset.
pub fn default_num_centers_to_merge_to() -> u32 {
    const NUM_CENTERS_TO_MERGE_TO: u32 = 8;
    NUM_CENTERS_TO_MERGE_TO
}
58
/// Default for `ef_construction` (200).
///
/// Serde default for [`InternalSpannConfiguration::ef_construction`] and the fallback used when
/// an optional configuration leaves the field unset.
pub fn default_construction_ef_spann() -> usize {
    const EF_CONSTRUCTION: usize = 200;
    EF_CONSTRUCTION
}
62
/// Default for `ef_search` (200).
///
/// Serde default for [`InternalSpannConfiguration::ef_search`] and the fallback used when an
/// optional configuration leaves the field unset.
pub fn default_search_ef_spann() -> usize {
    const EF_SEARCH: usize = 200;
    EF_SEARCH
}
66
/// Default for `max_neighbors` (64).
///
/// Serde default for [`InternalSpannConfiguration::max_neighbors`] and the fallback used when
/// an optional configuration leaves the field unset.
pub fn default_m_spann() -> usize {
    const MAX_NEIGHBORS: usize = 64;
    MAX_NEIGHBORS
}
70
/// Default center-drift threshold (0.125).
///
/// NOTE(review): no configuration struct in this file references this function; presumably it
/// is consumed elsewhere in the crate — confirm before changing or removing it.
pub fn default_center_drift_threshold() -> f32 {
    const CENTER_DRIFT_THRESHOLD: f32 = 0.125;
    CENTER_DRIFT_THRESHOLD
}
74
75fn default_space_spann() -> Space {
76    Space::L2
77}
78
/// Errors raised while deriving distributed SPANN parameters from a segment.
///
/// NOTE(review): the code that produces this error is not in this file; going by the variants,
/// it presumably deserializes segment metadata with `serde_json` and then runs `validate()` —
/// confirm against the caller.
#[derive(Debug, Error)]
pub enum DistributedSpannParametersFromSegmentError {
    /// `serde_json` failed to (de)serialize the metadata.
    #[error("Invalid metadata: {0}")]
    InvalidMetadata(#[from] serde_json::Error),
    /// The parsed parameters failed `validator` validation.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(#[from] validator::ValidationErrors),
}
86
87impl ChromaError for DistributedSpannParametersFromSegmentError {
88    fn code(&self) -> ErrorCodes {
89        match self {
90            DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
91                ErrorCodes::InvalidArgument
92            }
93            DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
94                ErrorCodes::InvalidArgument
95            }
96        }
97    }
98}
99
/// Fully-resolved SPANN index configuration used internally.
///
/// Every field carries a serde default (one of the `default_*` functions in this module), so
/// any subset of fields, including `{}`, deserializes. Unlike [`SpannConfiguration`], this type
/// does not set `deny_unknown_fields`, so serde ignores unrecognized keys.
///
/// The `#[validate(...)]` ranges are enforced only when `Validate::validate` is called.
/// Deserialization alone does not check them.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct InternalSpannConfiguration {
    // NOTE(review): unlike `write_nprobe`, `search_nprobe` has no `validate` range here, while
    // `UpdateSpannConfiguration::search_nprobe` is capped at 128 — confirm this is intended.
    #[serde(default = "default_search_nprobe")]
    pub search_nprobe: u32,
    #[serde(default = "default_search_rng_factor")]
    pub search_rng_factor: f32,
    #[serde(default = "default_search_rng_epsilon")]
    pub search_rng_epsilon: f32,
    #[serde(default = "default_write_nprobe")]
    #[validate(range(max = 128))]
    pub write_nprobe: u32,
    // The serde default lets JSON that omits `nreplica_count` still deserialize
    // (exercised by `test_deserialize_json_without_nreplica_count`).
    #[serde(default = "default_nreplica_count")]
    #[validate(range(max = 8))]
    pub nreplica_count: u32,
    #[serde(default = "default_write_rng_factor")]
    pub write_rng_factor: f32,
    #[serde(default = "default_write_rng_epsilon")]
    pub write_rng_epsilon: f32,
    #[serde(default = "default_split_threshold")]
    #[validate(range(min = 25, max = 200))]
    pub split_threshold: u32,
    #[serde(default = "default_num_samples_kmeans")]
    pub num_samples_kmeans: usize,
    #[serde(default = "default_initial_lambda")]
    pub initial_lambda: f32,
    #[serde(default = "default_reassign_neighbor_count")]
    #[validate(range(max = 64))]
    pub reassign_neighbor_count: u32,
    #[serde(default = "default_merge_threshold")]
    #[validate(range(min = 12, max = 100))]
    pub merge_threshold: u32,
    #[serde(default = "default_num_centers_to_merge_to")]
    #[validate(range(max = 8))]
    pub num_centers_to_merge_to: u32,
    #[serde(default = "default_space_spann")]
    pub space: Space,
    #[serde(default = "default_construction_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_construction: usize,
    #[serde(default = "default_search_ef_spann")]
    #[validate(range(max = 200))]
    pub ef_search: usize,
    #[serde(default = "default_m_spann")]
    #[validate(range(max = 64))]
    pub max_neighbors: usize,
}
147
148impl Default for InternalSpannConfiguration {
149    fn default() -> Self {
150        serde_json::from_str("{}").unwrap()
151    }
152}
153
154impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
155    fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
156        InternalSpannConfiguration {
157            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
158            search_rng_factor: config
159                .search_rng_factor
160                .unwrap_or(default_search_rng_factor()),
161            search_rng_epsilon: config
162                .search_rng_epsilon
163                .unwrap_or(default_search_rng_epsilon()),
164            nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
165            write_rng_factor: config
166                .write_rng_factor
167                .unwrap_or(default_write_rng_factor()),
168            write_rng_epsilon: config
169                .write_rng_epsilon
170                .unwrap_or(default_write_rng_epsilon()),
171            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
172            num_samples_kmeans: config
173                .num_samples_kmeans
174                .unwrap_or(default_num_samples_kmeans()),
175            initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
176            reassign_neighbor_count: config
177                .reassign_neighbor_count
178                .unwrap_or(default_reassign_neighbor_count()),
179            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
180            num_centers_to_merge_to: config
181                .num_centers_to_merge_to
182                .unwrap_or(default_num_centers_to_merge_to()),
183            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
184            ef_construction: config
185                .ef_construction
186                .unwrap_or(default_construction_ef_spann()),
187            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
188            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
189            space: space.unwrap_or(&default_space()).clone(),
190        }
191    }
192}
193
194#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
195#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
196#[serde(deny_unknown_fields)]
197pub struct SpannConfiguration {
198    pub search_nprobe: Option<u32>,
199    pub write_nprobe: Option<u32>,
200    pub space: Option<Space>,
201    pub ef_construction: Option<usize>,
202    pub ef_search: Option<usize>,
203    pub max_neighbors: Option<usize>,
204    pub reassign_neighbor_count: Option<u32>,
205    pub split_threshold: Option<u32>,
206    pub merge_threshold: Option<u32>,
207}
208
209impl From<InternalSpannConfiguration> for SpannConfiguration {
210    fn from(config: InternalSpannConfiguration) -> Self {
211        Self {
212            search_nprobe: Some(config.search_nprobe),
213            write_nprobe: Some(config.write_nprobe),
214            space: Some(config.space),
215            ef_construction: Some(config.ef_construction),
216            ef_search: Some(config.ef_search),
217            max_neighbors: Some(config.max_neighbors),
218            reassign_neighbor_count: Some(config.reassign_neighbor_count),
219            split_threshold: Some(config.split_threshold),
220            merge_threshold: Some(config.merge_threshold),
221        }
222    }
223}
224
225impl From<SpannConfiguration> for InternalSpannConfiguration {
226    fn from(config: SpannConfiguration) -> Self {
227        Self {
228            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
229            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
230            space: config.space.unwrap_or(default_space_spann()),
231            ef_construction: config
232                .ef_construction
233                .unwrap_or(default_construction_ef_spann()),
234            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
235            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
236            reassign_neighbor_count: config
237                .reassign_neighbor_count
238                .unwrap_or(default_reassign_neighbor_count()),
239            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
240            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
241            ..Default::default()
242        }
243    }
244}
245
246impl Default for SpannConfiguration {
247    fn default() -> Self {
248        InternalSpannConfiguration::default().into()
249    }
250}
251
/// Optional SPANN parameters accepted when updating a configuration.
///
/// `Default` leaves both fields `None`. When `validate()` is called, `None` fields are skipped
/// and set fields must satisfy the ranges below.
///
/// NOTE(review): presumably only these two parameters may change after creation — confirm
/// against the caller. The fields carry no `#[pyo3(get, set)]`, so they are presumably not
/// exposed as Python attributes — confirm this is intended.
#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(deny_unknown_fields)]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct UpdateSpannConfiguration {
    #[validate(range(max = 128))]
    pub search_nprobe: Option<u32>,
    #[validate(range(max = 200))]
    pub ef_search: Option<usize>,
}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the parameters `SpannConfiguration` does not expose hold their defaults.
    /// Every conversion from the user-facing type must leave them untouched.
    fn assert_internal_only_defaults(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_rng_factor, default_search_rng_factor());
        assert_eq!(config.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(config.nreplica_count, default_nreplica_count());
        assert_eq!(config.write_rng_factor, default_write_rng_factor());
        assert_eq!(config.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(config.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(config.initial_lambda, default_initial_lambda());
        assert_eq!(config.num_centers_to_merge_to, default_num_centers_to_merge_to());
    }

    /// Asserts that every field of `config` equals its `default_*` value.
    fn assert_all_defaults(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_nprobe, default_search_nprobe());
        assert_eq!(config.write_nprobe, default_write_nprobe());
        assert_eq!(config.space, default_space_spann());
        assert_eq!(config.ef_construction, default_construction_ef_spann());
        assert_eq!(config.ef_search, default_search_ef_spann());
        assert_eq!(config.max_neighbors, default_m_spann());
        assert_eq!(config.reassign_neighbor_count, default_reassign_neighbor_count());
        assert_eq!(config.split_threshold, default_split_threshold());
        assert_eq!(config.merge_threshold, default_merge_threshold());
        assert_internal_only_defaults(config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let internal_config: InternalSpannConfiguration = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        }
        .into();

        assert_eq!(internal_config.search_nprobe, 100);
        assert_eq!(internal_config.write_nprobe, 50);
        assert_eq!(internal_config.space, Space::Cosine);
        assert_eq!(internal_config.ef_construction, 150);
        assert_eq!(internal_config.ef_search, 180);
        assert_eq!(internal_config.max_neighbors, 32);
        assert_eq!(internal_config.reassign_neighbor_count, 48);
        assert_eq!(internal_config.split_threshold, 75);
        assert_eq!(internal_config.merge_threshold, 50);
        assert_internal_only_defaults(&internal_config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let internal_config: InternalSpannConfiguration = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        }
        .into();

        assert_all_defaults(&internal_config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let internal_config: InternalSpannConfiguration = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        }
        .into();

        // Explicitly provided values are kept.
        assert_eq!(internal_config.search_nprobe, 80);
        assert_eq!(internal_config.space, Space::Ip);
        assert_eq!(internal_config.ef_search, 160);
        assert_eq!(internal_config.max_neighbors, 48);
        assert_eq!(internal_config.split_threshold, 100);

        // Omitted values fall back to their defaults.
        assert_eq!(internal_config.write_nprobe, default_write_nprobe());
        assert_eq!(internal_config.ef_construction, default_construction_ef_spann());
        assert_eq!(
            internal_config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(internal_config.merge_threshold, default_merge_threshold());

        assert_internal_only_defaults(&internal_config);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        assert_all_defaults(&InternalSpannConfiguration::default());
    }

    #[test]
    fn test_spann_configuration_default() {
        let internal_config = InternalSpannConfiguration::from(SpannConfiguration::default());
        assert_all_defaults(&internal_config);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let config: InternalSpannConfiguration = serde_json::from_str(json_without_nreplica)
            .expect("JSON without nreplica_count must deserialize via the serde default");

        assert_eq!(config.search_nprobe, 120);
        assert_eq!(config.search_rng_factor, 2.5);
        assert_eq!(config.search_rng_epsilon, 15.0);
        assert_eq!(config.write_nprobe, 60);
        assert_eq!(config.write_rng_factor, 1.5);
        assert_eq!(config.write_rng_epsilon, 8.0);
        assert_eq!(config.split_threshold, 80);
        assert_eq!(config.num_samples_kmeans, 1500);
        assert_eq!(config.initial_lambda, 150.0);
        assert_eq!(config.reassign_neighbor_count, 32);
        assert_eq!(config.merge_threshold, 30);
        assert_eq!(config.num_centers_to_merge_to, 6);
        assert_eq!(config.space, Space::L2);
        assert_eq!(config.ef_construction, 180);
        assert_eq!(config.ef_search, 200);
        assert_eq!(config.max_neighbors, 56);
        // The only omitted key picks up its serde default.
        assert_eq!(config.nreplica_count, default_nreplica_count());
    }
}