chroma_types/
spann_configuration.rs

1use crate::{default_space, hnsw_configuration::Space, SpannIndexConfig};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Fallback for `InternalSpannConfiguration::search_nprobe` when no value is supplied.
pub fn default_search_nprobe() -> u32 {
    64_u32
}
10
/// Fallback for `InternalSpannConfiguration::search_rng_factor` when no value is supplied.
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}
14
/// Fallback for `InternalSpannConfiguration::search_rng_epsilon` when no value is supplied.
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Fallback for `InternalSpannConfiguration::write_nprobe` when no value is supplied.
pub fn default_write_nprobe() -> u32 {
    32_u32
}
22
/// Fallback for `InternalSpannConfiguration::nreplica_count` when no value is supplied.
pub fn default_nreplica_count() -> u32 {
    8_u32
}
26
/// Fallback for `InternalSpannConfiguration::write_rng_factor` when no value is supplied.
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}
30
/// Fallback for `InternalSpannConfiguration::write_rng_epsilon` when no value is supplied.
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Fallback for `InternalSpannConfiguration::split_threshold` when no value is supplied.
pub fn default_split_threshold() -> u32 {
    50_u32
}
38
/// Fallback for `InternalSpannConfiguration::num_samples_kmeans` when no value is supplied.
pub fn default_num_samples_kmeans() -> usize {
    1_000_usize
}
42
/// Fallback for `InternalSpannConfiguration::initial_lambda` when no value is supplied.
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}
46
/// Fallback for `InternalSpannConfiguration::reassign_neighbor_count` when no value is supplied.
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}
50
/// Fallback for `InternalSpannConfiguration::merge_threshold` when no value is supplied.
pub fn default_merge_threshold() -> u32 {
    25_u32
}
54
/// Fallback for `InternalSpannConfiguration::num_centers_to_merge_to` when no value is supplied.
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Fallback for `InternalSpannConfiguration::ef_construction` when no value is supplied.
pub fn default_construction_ef_spann() -> usize {
    200_usize
}
62
/// Fallback for `InternalSpannConfiguration::ef_search` when no value is supplied.
pub fn default_search_ef_spann() -> usize {
    200_usize
}
66
/// Fallback for `InternalSpannConfiguration::max_neighbors` when no value is supplied.
pub fn default_m_spann() -> usize {
    64_usize
}
70
71fn default_space_spann() -> Space {
72    Space::L2
73}
74
75#[derive(Debug, Error)]
76pub enum DistributedSpannParametersFromSegmentError {
77    #[error("Invalid metadata: {0}")]
78    InvalidMetadata(#[from] serde_json::Error),
79    #[error("Invalid parameters: {0}")]
80    InvalidParameters(#[from] validator::ValidationErrors),
81}
82
83impl ChromaError for DistributedSpannParametersFromSegmentError {
84    fn code(&self) -> ErrorCodes {
85        match self {
86            DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
87                ErrorCodes::InvalidArgument
88            }
89            DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
90                ErrorCodes::InvalidArgument
91            }
92        }
93    }
94}
95
96#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
97#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
98pub struct InternalSpannConfiguration {
99    #[serde(default = "default_search_nprobe")]
100    pub search_nprobe: u32,
101    #[serde(default = "default_search_rng_factor")]
102    pub search_rng_factor: f32,
103    #[serde(default = "default_search_rng_epsilon")]
104    pub search_rng_epsilon: f32,
105    #[serde(default = "default_write_nprobe")]
106    #[validate(range(max = 128))]
107    pub write_nprobe: u32,
108    #[serde(default = "default_nreplica_count")]
109    #[validate(range(max = 8))]
110    pub nreplica_count: u32,
111    #[serde(default = "default_write_rng_factor")]
112    pub write_rng_factor: f32,
113    #[serde(default = "default_write_rng_epsilon")]
114    pub write_rng_epsilon: f32,
115    #[serde(default = "default_split_threshold")]
116    #[validate(range(min = 25, max = 200))]
117    pub split_threshold: u32,
118    #[serde(default = "default_num_samples_kmeans")]
119    pub num_samples_kmeans: usize,
120    #[serde(default = "default_initial_lambda")]
121    pub initial_lambda: f32,
122    #[serde(default = "default_reassign_neighbor_count")]
123    #[validate(range(max = 64))]
124    pub reassign_neighbor_count: u32,
125    #[serde(default = "default_merge_threshold")]
126    #[validate(range(min = 12, max = 100))]
127    pub merge_threshold: u32,
128    #[serde(default = "default_num_centers_to_merge_to")]
129    #[validate(range(max = 8))]
130    pub num_centers_to_merge_to: u32,
131    #[serde(default = "default_space_spann")]
132    pub space: Space,
133    #[serde(default = "default_construction_ef_spann")]
134    #[validate(range(max = 200))]
135    pub ef_construction: usize,
136    #[serde(default = "default_search_ef_spann")]
137    #[validate(range(max = 200))]
138    pub ef_search: usize,
139    #[serde(default = "default_m_spann")]
140    #[validate(range(max = 64))]
141    pub max_neighbors: usize,
142}
143
144impl Default for InternalSpannConfiguration {
145    fn default() -> Self {
146        serde_json::from_str("{}").unwrap()
147    }
148}
149
150impl From<(Option<&Space>, &SpannIndexConfig)> for InternalSpannConfiguration {
151    fn from((space, config): (Option<&Space>, &SpannIndexConfig)) -> Self {
152        InternalSpannConfiguration {
153            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
154            search_rng_factor: config
155                .search_rng_factor
156                .unwrap_or(default_search_rng_factor()),
157            search_rng_epsilon: config
158                .search_rng_epsilon
159                .unwrap_or(default_search_rng_epsilon()),
160            nreplica_count: config.nreplica_count.unwrap_or(default_nreplica_count()),
161            write_rng_factor: config
162                .write_rng_factor
163                .unwrap_or(default_write_rng_factor()),
164            write_rng_epsilon: config
165                .write_rng_epsilon
166                .unwrap_or(default_write_rng_epsilon()),
167            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
168            num_samples_kmeans: config
169                .num_samples_kmeans
170                .unwrap_or(default_num_samples_kmeans()),
171            initial_lambda: config.initial_lambda.unwrap_or(default_initial_lambda()),
172            reassign_neighbor_count: config
173                .reassign_neighbor_count
174                .unwrap_or(default_reassign_neighbor_count()),
175            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
176            num_centers_to_merge_to: config
177                .num_centers_to_merge_to
178                .unwrap_or(default_num_centers_to_merge_to()),
179            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
180            ef_construction: config
181                .ef_construction
182                .unwrap_or(default_construction_ef_spann()),
183            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
184            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
185            space: space.unwrap_or(&default_space()).clone(),
186        }
187    }
188}
189
190#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
191#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
192#[serde(deny_unknown_fields)]
193pub struct SpannConfiguration {
194    pub search_nprobe: Option<u32>,
195    pub write_nprobe: Option<u32>,
196    pub space: Option<Space>,
197    pub ef_construction: Option<usize>,
198    pub ef_search: Option<usize>,
199    pub max_neighbors: Option<usize>,
200    pub reassign_neighbor_count: Option<u32>,
201    pub split_threshold: Option<u32>,
202    pub merge_threshold: Option<u32>,
203}
204
205impl From<InternalSpannConfiguration> for SpannConfiguration {
206    fn from(config: InternalSpannConfiguration) -> Self {
207        Self {
208            search_nprobe: Some(config.search_nprobe),
209            write_nprobe: Some(config.write_nprobe),
210            space: Some(config.space),
211            ef_construction: Some(config.ef_construction),
212            ef_search: Some(config.ef_search),
213            max_neighbors: Some(config.max_neighbors),
214            reassign_neighbor_count: Some(config.reassign_neighbor_count),
215            split_threshold: Some(config.split_threshold),
216            merge_threshold: Some(config.merge_threshold),
217        }
218    }
219}
220
221impl From<SpannConfiguration> for InternalSpannConfiguration {
222    fn from(config: SpannConfiguration) -> Self {
223        Self {
224            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
225            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
226            space: config.space.unwrap_or(default_space_spann()),
227            ef_construction: config
228                .ef_construction
229                .unwrap_or(default_construction_ef_spann()),
230            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
231            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
232            reassign_neighbor_count: config
233                .reassign_neighbor_count
234                .unwrap_or(default_reassign_neighbor_count()),
235            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
236            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
237            ..Default::default()
238        }
239    }
240}
241
242impl Default for SpannConfiguration {
243    fn default() -> Self {
244        InternalSpannConfiguration::default().into()
245    }
246}
247
248#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
249#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
250#[serde(deny_unknown_fields)]
251#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
252pub struct UpdateSpannConfiguration {
253    pub search_nprobe: Option<u32>,
254    pub ef_search: Option<usize>,
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the nine fields that `SpannConfiguration` shares with
    /// `InternalSpannConfiguration` against their `default_*` functions.
    ///
    /// `#[track_caller]` makes a failing assertion report the calling test's line.
    #[track_caller]
    fn assert_shared_fields_are_default(cfg: &InternalSpannConfiguration) {
        assert_eq!(cfg.search_nprobe, default_search_nprobe());
        assert_eq!(cfg.write_nprobe, default_write_nprobe());
        assert_eq!(cfg.space, default_space_spann());
        assert_eq!(cfg.ef_construction, default_construction_ef_spann());
        assert_eq!(cfg.ef_search, default_search_ef_spann());
        assert_eq!(cfg.max_neighbors, default_m_spann());
        assert_eq!(cfg.reassign_neighbor_count, default_reassign_neighbor_count());
        assert_eq!(cfg.split_threshold, default_split_threshold());
        assert_eq!(cfg.merge_threshold, default_merge_threshold());
    }

    /// Checks the eight fields that exist only on `InternalSpannConfiguration`
    /// against their `default_*` functions.
    #[track_caller]
    fn assert_internal_only_fields_are_default(cfg: &InternalSpannConfiguration) {
        assert_eq!(cfg.search_rng_factor, default_search_rng_factor());
        assert_eq!(cfg.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(cfg.nreplica_count, default_nreplica_count());
        assert_eq!(cfg.write_rng_factor, default_write_rng_factor());
        assert_eq!(cfg.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(cfg.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(cfg.initial_lambda, default_initial_lambda());
        assert_eq!(cfg.num_centers_to_merge_to, default_num_centers_to_merge_to());
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        // Every shared field is set, each to a value different from its default.
        let fully_specified = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let converted = InternalSpannConfiguration::from(fully_specified);

        assert_eq!(converted.search_nprobe, 100);
        assert_eq!(converted.write_nprobe, 50);
        assert_eq!(converted.space, Space::Cosine);
        assert_eq!(converted.ef_construction, 150);
        assert_eq!(converted.ef_search, 180);
        assert_eq!(converted.max_neighbors, 32);
        assert_eq!(converted.reassign_neighbor_count, 48);
        assert_eq!(converted.split_threshold, 75);
        assert_eq!(converted.merge_threshold, 50);
        assert_internal_only_fields_are_default(&converted);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        // Spelled out field by field: `SpannConfiguration::default()` would fill
        // every field with `Some(..)`, which is not the case under test.
        let all_unset = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        let converted = InternalSpannConfiguration::from(all_unset);

        assert_shared_fields_are_default(&converted);
        assert_internal_only_fields_are_default(&converted);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let partially_specified = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let converted = InternalSpannConfiguration::from(partially_specified);

        // Fields that were set keep the supplied value.
        assert_eq!(converted.search_nprobe, 80);
        assert_eq!(converted.space, Space::Ip);
        assert_eq!(converted.ef_search, 160);
        assert_eq!(converted.max_neighbors, 48);
        assert_eq!(converted.split_threshold, 100);

        // Fields that were left unset fall back to their defaults.
        assert_eq!(converted.write_nprobe, default_write_nprobe());
        assert_eq!(converted.ef_construction, default_construction_ef_spann());
        assert_eq!(converted.reassign_neighbor_count, default_reassign_neighbor_count());
        assert_eq!(converted.merge_threshold, default_merge_threshold());

        assert_internal_only_fields_are_default(&converted);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        let defaults = InternalSpannConfiguration::default();

        assert_shared_fields_are_default(&defaults);
        assert_internal_only_fields_are_default(&defaults);
    }

    #[test]
    fn test_spann_configuration_default() {
        // Round trip: default partial config -> internal config.
        let converted = InternalSpannConfiguration::from(SpannConfiguration::default());

        assert_shared_fields_are_default(&converted);
        assert_internal_only_fields_are_default(&converted);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        // All keys except `nreplica_count` are present.
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let parsed: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        assert_eq!(parsed.search_nprobe, 120);
        assert_eq!(parsed.search_rng_factor, 2.5);
        assert_eq!(parsed.search_rng_epsilon, 15.0);
        assert_eq!(parsed.write_nprobe, 60);
        assert_eq!(parsed.write_rng_factor, 1.5);
        assert_eq!(parsed.write_rng_epsilon, 8.0);
        assert_eq!(parsed.split_threshold, 80);
        assert_eq!(parsed.num_samples_kmeans, 1500);
        assert_eq!(parsed.initial_lambda, 150.0);
        assert_eq!(parsed.reassign_neighbor_count, 32);
        assert_eq!(parsed.merge_threshold, 30);
        assert_eq!(parsed.num_centers_to_merge_to, 6);
        assert_eq!(parsed.space, Space::L2);
        assert_eq!(parsed.ef_construction, 180);
        assert_eq!(parsed.ef_search, 200);
        assert_eq!(parsed.max_neighbors, 56);

        // The one omitted key is filled in by its serde default.
        assert_eq!(parsed.nreplica_count, default_nreplica_count());
    }
}