chroma_types/
spann_configuration.rs

1use crate::hnsw_configuration::Space;
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Default `search_nprobe` for [`InternalSpannConfiguration`].
pub fn default_search_nprobe() -> u32 {
    64_u32
}

/// Default `search_rng_factor` for [`InternalSpannConfiguration`].
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}

/// Default `search_rng_epsilon` for [`InternalSpannConfiguration`].
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Default `write_nprobe` for [`InternalSpannConfiguration`].
pub fn default_write_nprobe() -> u32 {
    32_u32
}

/// Default `nreplica_count` for [`InternalSpannConfiguration`].
pub fn default_nreplica_count() -> u32 {
    8_u32
}

/// Default `write_rng_factor` for [`InternalSpannConfiguration`].
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}

/// Default `write_rng_epsilon` for [`InternalSpannConfiguration`].
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Default `split_threshold` for [`InternalSpannConfiguration`].
pub fn default_split_threshold() -> u32 {
    50_u32
}

/// Default `num_samples_kmeans` for [`InternalSpannConfiguration`].
pub fn default_num_samples_kmeans() -> usize {
    1_000_usize
}

/// Default `initial_lambda` for [`InternalSpannConfiguration`].
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}

/// Default `reassign_neighbor_count` for [`InternalSpannConfiguration`].
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}

/// Default `merge_threshold` for [`InternalSpannConfiguration`].
pub fn default_merge_threshold() -> u32 {
    25_u32
}

/// Default `num_centers_to_merge_to` for [`InternalSpannConfiguration`].
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Default `ef_construction` for [`InternalSpannConfiguration`].
pub fn default_construction_ef_spann() -> usize {
    200_usize
}

/// Default `ef_search` for [`InternalSpannConfiguration`].
pub fn default_search_ef_spann() -> usize {
    200_usize
}

/// Default `max_neighbors` for [`InternalSpannConfiguration`].
pub fn default_m_spann() -> usize {
    64_usize
}
70
71fn default_space_spann() -> Space {
72    Space::L2
73}
74
75#[derive(Debug, Error)]
76pub enum DistributedSpannParametersFromSegmentError {
77    #[error("Invalid metadata: {0}")]
78    InvalidMetadata(#[from] serde_json::Error),
79    #[error("Invalid parameters: {0}")]
80    InvalidParameters(#[from] validator::ValidationErrors),
81}
82
83impl ChromaError for DistributedSpannParametersFromSegmentError {
84    fn code(&self) -> ErrorCodes {
85        match self {
86            DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
87                ErrorCodes::InvalidArgument
88            }
89            DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
90                ErrorCodes::InvalidArgument
91            }
92        }
93    }
94}
95
96#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
97#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
98pub struct InternalSpannConfiguration {
99    #[serde(default = "default_search_nprobe")]
100    pub search_nprobe: u32,
101    #[serde(default = "default_search_rng_factor")]
102    pub search_rng_factor: f32,
103    #[serde(default = "default_search_rng_epsilon")]
104    pub search_rng_epsilon: f32,
105    #[serde(default = "default_write_nprobe")]
106    #[validate(range(max = 128))]
107    pub write_nprobe: u32,
108    #[serde(default = "default_nreplica_count")]
109    #[validate(range(max = 8))]
110    pub nreplica_count: u32,
111    #[serde(default = "default_write_rng_factor")]
112    pub write_rng_factor: f32,
113    #[serde(default = "default_write_rng_epsilon")]
114    pub write_rng_epsilon: f32,
115    #[serde(default = "default_split_threshold")]
116    #[validate(range(min = 25, max = 200))]
117    pub split_threshold: u32,
118    #[serde(default = "default_num_samples_kmeans")]
119    pub num_samples_kmeans: usize,
120    #[serde(default = "default_initial_lambda")]
121    pub initial_lambda: f32,
122    #[serde(default = "default_reassign_neighbor_count")]
123    #[validate(range(max = 64))]
124    pub reassign_neighbor_count: u32,
125    #[serde(default = "default_merge_threshold")]
126    #[validate(range(min = 12, max = 100))]
127    pub merge_threshold: u32,
128    #[serde(default = "default_num_centers_to_merge_to")]
129    #[validate(range(max = 8))]
130    pub num_centers_to_merge_to: u32,
131    #[serde(default = "default_space_spann")]
132    pub space: Space,
133    #[serde(default = "default_construction_ef_spann")]
134    #[validate(range(max = 200))]
135    pub ef_construction: usize,
136    #[serde(default = "default_search_ef_spann")]
137    #[validate(range(max = 200))]
138    pub ef_search: usize,
139    #[serde(default = "default_m_spann")]
140    #[validate(range(max = 64))]
141    pub max_neighbors: usize,
142}
143
144impl Default for InternalSpannConfiguration {
145    fn default() -> Self {
146        serde_json::from_str("{}").unwrap()
147    }
148}
149
150#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
151#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
152#[serde(deny_unknown_fields)]
153pub struct SpannConfiguration {
154    pub search_nprobe: Option<u32>,
155    pub write_nprobe: Option<u32>,
156    pub space: Option<Space>,
157    pub ef_construction: Option<usize>,
158    pub ef_search: Option<usize>,
159    pub max_neighbors: Option<usize>,
160    pub reassign_neighbor_count: Option<u32>,
161    pub split_threshold: Option<u32>,
162    pub merge_threshold: Option<u32>,
163}
164
165impl From<InternalSpannConfiguration> for SpannConfiguration {
166    fn from(config: InternalSpannConfiguration) -> Self {
167        Self {
168            search_nprobe: Some(config.search_nprobe),
169            write_nprobe: Some(config.write_nprobe),
170            space: Some(config.space),
171            ef_construction: Some(config.ef_construction),
172            ef_search: Some(config.ef_search),
173            max_neighbors: Some(config.max_neighbors),
174            reassign_neighbor_count: Some(config.reassign_neighbor_count),
175            split_threshold: Some(config.split_threshold),
176            merge_threshold: Some(config.merge_threshold),
177        }
178    }
179}
180
181impl From<SpannConfiguration> for InternalSpannConfiguration {
182    fn from(config: SpannConfiguration) -> Self {
183        Self {
184            search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
185            write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
186            space: config.space.unwrap_or(default_space_spann()),
187            ef_construction: config
188                .ef_construction
189                .unwrap_or(default_construction_ef_spann()),
190            ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
191            max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
192            reassign_neighbor_count: config
193                .reassign_neighbor_count
194                .unwrap_or(default_reassign_neighbor_count()),
195            split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
196            merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
197            ..Default::default()
198        }
199    }
200}
201
202impl Default for SpannConfiguration {
203    fn default() -> Self {
204        InternalSpannConfiguration::default().into()
205    }
206}
207
208#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
209#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
210#[serde(deny_unknown_fields)]
211#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
212pub struct UpdateSpannConfiguration {
213    pub search_nprobe: Option<u32>,
214    pub ef_search: Option<usize>,
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the fields `SpannConfiguration` does not expose still hold
    /// their default values.
    fn assert_internal_only_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_rng_factor, default_search_rng_factor());
        assert_eq!(config.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(config.nreplica_count, default_nreplica_count());
        assert_eq!(config.write_rng_factor, default_write_rng_factor());
        assert_eq!(config.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(config.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(config.initial_lambda, default_initial_lambda());
        assert_eq!(
            config.num_centers_to_merge_to,
            default_num_centers_to_merge_to()
        );
    }

    /// Asserts that every field of `config` holds its default value.
    fn assert_all_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_nprobe, default_search_nprobe());
        assert_eq!(config.write_nprobe, default_write_nprobe());
        assert_eq!(config.space, default_space_spann());
        assert_eq!(config.ef_construction, default_construction_ef_spann());
        assert_eq!(config.ef_search, default_search_ef_spann());
        assert_eq!(config.max_neighbors, default_m_spann());
        assert_eq!(
            config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(config.split_threshold, default_split_threshold());
        assert_eq!(config.merge_threshold, default_merge_threshold());
        assert_internal_only_fields_are_default(config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let spann_config = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_eq!(internal_config.search_nprobe, 100);
        assert_eq!(internal_config.write_nprobe, 50);
        assert_eq!(internal_config.space, Space::Cosine);
        assert_eq!(internal_config.ef_construction, 150);
        assert_eq!(internal_config.ef_search, 180);
        assert_eq!(internal_config.max_neighbors, 32);
        assert_eq!(internal_config.reassign_neighbor_count, 48);
        assert_eq!(internal_config.split_threshold, 75);
        assert_eq!(internal_config.merge_threshold, 50);
        assert_internal_only_fields_are_default(&internal_config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let spann_config = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_all_fields_are_default(&internal_config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let spann_config = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_eq!(internal_config.search_nprobe, 80);
        assert_eq!(internal_config.write_nprobe, default_write_nprobe());
        assert_eq!(internal_config.space, Space::Ip);
        assert_eq!(
            internal_config.ef_construction,
            default_construction_ef_spann()
        );
        assert_eq!(internal_config.ef_search, 160);
        assert_eq!(internal_config.max_neighbors, 48);
        assert_eq!(
            internal_config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(internal_config.split_threshold, 100);
        assert_eq!(internal_config.merge_threshold, default_merge_threshold());
        assert_internal_only_fields_are_default(&internal_config);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        let internal_config = InternalSpannConfiguration::default();

        assert_all_fields_are_default(&internal_config);
    }

    #[test]
    fn test_spann_configuration_default() {
        let spann_config = SpannConfiguration::default();
        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_all_fields_are_default(&internal_config);
    }

    #[test]
    fn test_empty_json_deserializes_to_default() {
        // Guards against `Default` and the serde field defaults drifting apart.
        let parsed: InternalSpannConfiguration = serde_json::from_str("{}").unwrap();

        assert_eq!(parsed, InternalSpannConfiguration::default());
    }

    #[test]
    fn test_default_configurations_pass_validation() {
        assert!(InternalSpannConfiguration::default().validate().is_ok());
        assert!(SpannConfiguration::default().validate().is_ok());
    }

    #[test]
    fn test_spann_configuration_rejects_unknown_fields() {
        // `nreplica_count` exists only on the internal configuration.
        let parsed = serde_json::from_str::<SpannConfiguration>(r#"{"nreplica_count": 4}"#);

        assert!(parsed.is_err());
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let internal_config: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        assert_eq!(internal_config.search_nprobe, 120);
        assert_eq!(internal_config.search_rng_factor, 2.5);
        assert_eq!(internal_config.search_rng_epsilon, 15.0);
        assert_eq!(internal_config.write_nprobe, 60);
        assert_eq!(internal_config.write_rng_factor, 1.5);
        assert_eq!(internal_config.write_rng_epsilon, 8.0);
        assert_eq!(internal_config.split_threshold, 80);
        assert_eq!(internal_config.num_samples_kmeans, 1500);
        assert_eq!(internal_config.initial_lambda, 150.0);
        assert_eq!(internal_config.reassign_neighbor_count, 32);
        assert_eq!(internal_config.merge_threshold, 30);
        assert_eq!(internal_config.num_centers_to_merge_to, 6);
        assert_eq!(internal_config.space, Space::L2);
        assert_eq!(internal_config.ef_construction, 180);
        assert_eq!(internal_config.ef_search, 200);
        assert_eq!(internal_config.max_neighbors, 56);
        assert_eq!(internal_config.nreplica_count, default_nreplica_count());
    }
}