1use crate::hnsw_configuration::Space;
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use validator::Validate;
6
/// Value used for `InternalSpannConfiguration::search_nprobe` when the key is absent
/// during deserialization, and when `SpannConfiguration::search_nprobe` is `None`.
pub fn default_search_nprobe() -> u32 {
    64_u32
}
10
/// Serde default for `InternalSpannConfiguration::search_rng_factor`.
pub fn default_search_rng_factor() -> f32 {
    1.0_f32
}
14
/// Serde default for `InternalSpannConfiguration::search_rng_epsilon`.
pub fn default_search_rng_epsilon() -> f32 {
    10.0_f32
}
18
/// Value used for `InternalSpannConfiguration::write_nprobe` when the key is absent
/// and when `SpannConfiguration::write_nprobe` is `None`.
///
/// Must stay within the `#[validate(range(max = 128))]` bound on that field.
pub fn default_write_nprobe() -> u32 {
    32_u32
}
22
/// Serde default for `InternalSpannConfiguration::nreplica_count`.
///
/// Sits exactly at the `#[validate(range(max = 8))]` upper bound on that field.
pub fn default_nreplica_count() -> u32 {
    8_u32
}
26
/// Serde default for `InternalSpannConfiguration::write_rng_factor`.
pub fn default_write_rng_factor() -> f32 {
    1.0_f32
}
30
/// Serde default for `InternalSpannConfiguration::write_rng_epsilon`.
pub fn default_write_rng_epsilon() -> f32 {
    5.0_f32
}
34
/// Value used for `InternalSpannConfiguration::split_threshold` when the key is absent
/// and when `SpannConfiguration::split_threshold` is `None`.
///
/// Must stay within the `#[validate(range(min = 25, max = 200))]` bound on that field.
pub fn default_split_threshold() -> u32 {
    50_u32
}
38
/// Serde default for `InternalSpannConfiguration::num_samples_kmeans`.
pub fn default_num_samples_kmeans() -> usize {
    1_000_usize
}
42
/// Serde default for `InternalSpannConfiguration::initial_lambda`.
pub fn default_initial_lambda() -> f32 {
    100.0_f32
}
46
/// Value used for `InternalSpannConfiguration::reassign_neighbor_count` when the key is
/// absent and when `SpannConfiguration::reassign_neighbor_count` is `None`.
///
/// Sits exactly at the `#[validate(range(max = 64))]` upper bound on that field.
pub fn default_reassign_neighbor_count() -> u32 {
    64_u32
}
50
/// Value used for `InternalSpannConfiguration::merge_threshold` when the key is absent
/// and when `SpannConfiguration::merge_threshold` is `None`.
///
/// Must stay within the `#[validate(range(min = 12, max = 100))]` bound on that field.
pub fn default_merge_threshold() -> u32 {
    25_u32
}
54
/// Serde default for `InternalSpannConfiguration::num_centers_to_merge_to`.
///
/// Sits exactly at the `#[validate(range(max = 8))]` upper bound on that field.
pub fn default_num_centers_to_merge_to() -> u32 {
    8_u32
}
58
/// Value used for `InternalSpannConfiguration::ef_construction` when the key is absent
/// and when `SpannConfiguration::ef_construction` is `None`.
///
/// Sits exactly at the `#[validate(range(max = 200))]` upper bound on that field.
pub fn default_construction_ef_spann() -> usize {
    200_usize
}
62
/// Value used for `InternalSpannConfiguration::ef_search` when the key is absent
/// and when `SpannConfiguration::ef_search` is `None`.
///
/// Sits exactly at the `#[validate(range(max = 200))]` upper bound on that field.
pub fn default_search_ef_spann() -> usize {
    200_usize
}
66
/// Value used for `InternalSpannConfiguration::max_neighbors` when the key is absent
/// and when `SpannConfiguration::max_neighbors` is `None`.
///
/// Sits exactly at the `#[validate(range(max = 64))]` upper bound on that field.
pub fn default_m_spann() -> usize {
    64_usize
}
70
71fn default_space_spann() -> Space {
72 Space::L2
73}
74
75#[derive(Debug, Error)]
76pub enum DistributedSpannParametersFromSegmentError {
77 #[error("Invalid metadata: {0}")]
78 InvalidMetadata(#[from] serde_json::Error),
79 #[error("Invalid parameters: {0}")]
80 InvalidParameters(#[from] validator::ValidationErrors),
81}
82
83impl ChromaError for DistributedSpannParametersFromSegmentError {
84 fn code(&self) -> ErrorCodes {
85 match self {
86 DistributedSpannParametersFromSegmentError::InvalidMetadata(_) => {
87 ErrorCodes::InvalidArgument
88 }
89 DistributedSpannParametersFromSegmentError::InvalidParameters(_) => {
90 ErrorCodes::InvalidArgument
91 }
92 }
93 }
94}
95
96#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
97#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
98pub struct InternalSpannConfiguration {
99 #[serde(default = "default_search_nprobe")]
100 pub search_nprobe: u32,
101 #[serde(default = "default_search_rng_factor")]
102 pub search_rng_factor: f32,
103 #[serde(default = "default_search_rng_epsilon")]
104 pub search_rng_epsilon: f32,
105 #[serde(default = "default_write_nprobe")]
106 #[validate(range(max = 128))]
107 pub write_nprobe: u32,
108 #[serde(default = "default_nreplica_count")]
109 #[validate(range(max = 8))]
110 pub nreplica_count: u32,
111 #[serde(default = "default_write_rng_factor")]
112 pub write_rng_factor: f32,
113 #[serde(default = "default_write_rng_epsilon")]
114 pub write_rng_epsilon: f32,
115 #[serde(default = "default_split_threshold")]
116 #[validate(range(min = 25, max = 200))]
117 pub split_threshold: u32,
118 #[serde(default = "default_num_samples_kmeans")]
119 pub num_samples_kmeans: usize,
120 #[serde(default = "default_initial_lambda")]
121 pub initial_lambda: f32,
122 #[serde(default = "default_reassign_neighbor_count")]
123 #[validate(range(max = 64))]
124 pub reassign_neighbor_count: u32,
125 #[serde(default = "default_merge_threshold")]
126 #[validate(range(min = 12, max = 100))]
127 pub merge_threshold: u32,
128 #[serde(default = "default_num_centers_to_merge_to")]
129 #[validate(range(max = 8))]
130 pub num_centers_to_merge_to: u32,
131 #[serde(default = "default_space_spann")]
132 pub space: Space,
133 #[serde(default = "default_construction_ef_spann")]
134 #[validate(range(max = 200))]
135 pub ef_construction: usize,
136 #[serde(default = "default_search_ef_spann")]
137 #[validate(range(max = 200))]
138 pub ef_search: usize,
139 #[serde(default = "default_m_spann")]
140 #[validate(range(max = 64))]
141 pub max_neighbors: usize,
142}
143
144impl Default for InternalSpannConfiguration {
145 fn default() -> Self {
146 serde_json::from_str("{}").unwrap()
147 }
148}
149
150#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
151#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
152#[serde(deny_unknown_fields)]
153pub struct SpannConfiguration {
154 pub search_nprobe: Option<u32>,
155 pub write_nprobe: Option<u32>,
156 pub space: Option<Space>,
157 pub ef_construction: Option<usize>,
158 pub ef_search: Option<usize>,
159 pub max_neighbors: Option<usize>,
160 pub reassign_neighbor_count: Option<u32>,
161 pub split_threshold: Option<u32>,
162 pub merge_threshold: Option<u32>,
163}
164
165impl From<InternalSpannConfiguration> for SpannConfiguration {
166 fn from(config: InternalSpannConfiguration) -> Self {
167 Self {
168 search_nprobe: Some(config.search_nprobe),
169 write_nprobe: Some(config.write_nprobe),
170 space: Some(config.space),
171 ef_construction: Some(config.ef_construction),
172 ef_search: Some(config.ef_search),
173 max_neighbors: Some(config.max_neighbors),
174 reassign_neighbor_count: Some(config.reassign_neighbor_count),
175 split_threshold: Some(config.split_threshold),
176 merge_threshold: Some(config.merge_threshold),
177 }
178 }
179}
180
181impl From<SpannConfiguration> for InternalSpannConfiguration {
182 fn from(config: SpannConfiguration) -> Self {
183 Self {
184 search_nprobe: config.search_nprobe.unwrap_or(default_search_nprobe()),
185 write_nprobe: config.write_nprobe.unwrap_or(default_write_nprobe()),
186 space: config.space.unwrap_or(default_space_spann()),
187 ef_construction: config
188 .ef_construction
189 .unwrap_or(default_construction_ef_spann()),
190 ef_search: config.ef_search.unwrap_or(default_search_ef_spann()),
191 max_neighbors: config.max_neighbors.unwrap_or(default_m_spann()),
192 reassign_neighbor_count: config
193 .reassign_neighbor_count
194 .unwrap_or(default_reassign_neighbor_count()),
195 split_threshold: config.split_threshold.unwrap_or(default_split_threshold()),
196 merge_threshold: config.merge_threshold.unwrap_or(default_merge_threshold()),
197 ..Default::default()
198 }
199 }
200}
201
202impl Default for SpannConfiguration {
203 fn default() -> Self {
204 InternalSpannConfiguration::default().into()
205 }
206}
207
208#[derive(Clone, Default, Debug, Serialize, Deserialize, Validate, PartialEq)]
209#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
210#[serde(deny_unknown_fields)]
211#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
212pub struct UpdateSpannConfiguration {
213 pub search_nprobe: Option<u32>,
214 pub ef_search: Option<usize>,
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every field `SpannConfiguration` cannot carry still holds its default.
    /// After a conversion from `SpannConfiguration`, `Default` is their only possible source.
    #[track_caller]
    fn assert_internal_only_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_rng_factor, default_search_rng_factor());
        assert_eq!(config.search_rng_epsilon, default_search_rng_epsilon());
        assert_eq!(config.nreplica_count, default_nreplica_count());
        assert_eq!(config.write_rng_factor, default_write_rng_factor());
        assert_eq!(config.write_rng_epsilon, default_write_rng_epsilon());
        assert_eq!(config.num_samples_kmeans, default_num_samples_kmeans());
        assert_eq!(config.initial_lambda, default_initial_lambda());
        assert_eq!(
            config.num_centers_to_merge_to,
            default_num_centers_to_merge_to()
        );
    }

    /// Asserts that every field of `config` equals its `default_*` value.
    #[track_caller]
    fn assert_all_fields_are_default(config: &InternalSpannConfiguration) {
        assert_eq!(config.search_nprobe, default_search_nprobe());
        assert_eq!(config.write_nprobe, default_write_nprobe());
        assert_eq!(config.space, default_space_spann());
        assert_eq!(config.ef_construction, default_construction_ef_spann());
        assert_eq!(config.ef_search, default_search_ef_spann());
        assert_eq!(config.max_neighbors, default_m_spann());
        assert_eq!(
            config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(config.split_threshold, default_split_threshold());
        assert_eq!(config.merge_threshold, default_merge_threshold());
        assert_internal_only_fields_are_default(config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration() {
        let spann_config = SpannConfiguration {
            search_nprobe: Some(100),
            write_nprobe: Some(50),
            space: Some(Space::Cosine),
            ef_construction: Some(150),
            ef_search: Some(180),
            max_neighbors: Some(32),
            reassign_neighbor_count: Some(48),
            split_threshold: Some(75),
            merge_threshold: Some(50),
        };

        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_eq!(internal_config.search_nprobe, 100);
        assert_eq!(internal_config.write_nprobe, 50);
        assert_eq!(internal_config.space, Space::Cosine);
        assert_eq!(internal_config.ef_construction, 150);
        assert_eq!(internal_config.ef_search, 180);
        assert_eq!(internal_config.max_neighbors, 32);
        assert_eq!(internal_config.reassign_neighbor_count, 48);
        assert_eq!(internal_config.split_threshold, 75);
        assert_eq!(internal_config.merge_threshold, 50);
        assert_internal_only_fields_are_default(&internal_config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_with_none_values() {
        let spann_config = SpannConfiguration {
            search_nprobe: None,
            write_nprobe: None,
            space: None,
            ef_construction: None,
            ef_search: None,
            max_neighbors: None,
            reassign_neighbor_count: None,
            split_threshold: None,
            merge_threshold: None,
        };

        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_all_fields_are_default(&internal_config);
    }

    #[test]
    fn test_spann_configuration_to_internal_spann_configuration_mixed_values() {
        let spann_config = SpannConfiguration {
            search_nprobe: Some(80),
            write_nprobe: None,
            space: Some(Space::Ip),
            ef_construction: None,
            ef_search: Some(160),
            max_neighbors: Some(48),
            reassign_neighbor_count: None,
            split_threshold: Some(100),
            merge_threshold: None,
        };

        let internal_config: InternalSpannConfiguration = spann_config.into();

        assert_eq!(internal_config.search_nprobe, 80);
        assert_eq!(internal_config.write_nprobe, default_write_nprobe());
        assert_eq!(internal_config.space, Space::Ip);
        assert_eq!(
            internal_config.ef_construction,
            default_construction_ef_spann()
        );
        assert_eq!(internal_config.ef_search, 160);
        assert_eq!(internal_config.max_neighbors, 48);
        assert_eq!(
            internal_config.reassign_neighbor_count,
            default_reassign_neighbor_count()
        );
        assert_eq!(internal_config.split_threshold, 100);
        assert_eq!(internal_config.merge_threshold, default_merge_threshold());
        assert_internal_only_fields_are_default(&internal_config);
    }

    #[test]
    fn test_internal_spann_configuration_default() {
        assert_all_fields_are_default(&InternalSpannConfiguration::default());
    }

    #[test]
    fn test_spann_configuration_default() {
        let internal_config: InternalSpannConfiguration = SpannConfiguration::default().into();
        assert_all_fields_are_default(&internal_config);
    }

    #[test]
    fn test_deserialize_json_without_nreplica_count() {
        let json_without_nreplica = r#"{
            "search_nprobe": 120,
            "search_rng_factor": 2.5,
            "search_rng_epsilon": 15.0,
            "write_nprobe": 60,
            "write_rng_factor": 1.5,
            "write_rng_epsilon": 8.0,
            "split_threshold": 80,
            "num_samples_kmeans": 1500,
            "initial_lambda": 150.0,
            "reassign_neighbor_count": 32,
            "merge_threshold": 30,
            "num_centers_to_merge_to": 6,
            "space": "l2",
            "ef_construction": 180,
            "ef_search": 200,
            "max_neighbors": 56
        }"#;

        let internal_config: InternalSpannConfiguration =
            serde_json::from_str(json_without_nreplica).unwrap();

        assert_eq!(internal_config.search_nprobe, 120);
        assert_eq!(internal_config.search_rng_factor, 2.5);
        assert_eq!(internal_config.search_rng_epsilon, 15.0);
        assert_eq!(internal_config.write_nprobe, 60);
        assert_eq!(internal_config.write_rng_factor, 1.5);
        assert_eq!(internal_config.write_rng_epsilon, 8.0);
        assert_eq!(internal_config.split_threshold, 80);
        assert_eq!(internal_config.num_samples_kmeans, 1500);
        assert_eq!(internal_config.initial_lambda, 150.0);
        assert_eq!(internal_config.reassign_neighbor_count, 32);
        assert_eq!(internal_config.merge_threshold, 30);
        assert_eq!(internal_config.num_centers_to_merge_to, 6);
        assert_eq!(internal_config.space, Space::L2);
        assert_eq!(internal_config.ef_construction, 180);
        assert_eq!(internal_config.ef_search, 200);
        assert_eq!(internal_config.max_neighbors, 56);
        assert_eq!(internal_config.nreplica_count, default_nreplica_count());
    }

    #[test]
    fn test_default_configurations_pass_validation() {
        // Every default must sit inside its declared `#[validate]` bounds.
        assert!(InternalSpannConfiguration::default().validate().is_ok());
        assert!(SpannConfiguration::default().validate().is_ok());
    }

    #[test]
    fn test_internal_spann_configuration_rejects_out_of_range_values() {
        let split_below_min = InternalSpannConfiguration {
            split_threshold: 10,
            ..Default::default()
        };
        assert!(split_below_min.validate().is_err());

        let replicas_above_max = InternalSpannConfiguration {
            nreplica_count: 9,
            ..Default::default()
        };
        assert!(replicas_above_max.validate().is_err());
    }

    #[test]
    fn test_spann_configuration_rejects_unknown_fields() {
        let result = serde_json::from_str::<SpannConfiguration>(r#"{"not_a_field": 1}"#);
        assert!(result.is_err());
    }
}