chroma_types/
hnsw_configuration.rs

1use crate::{HnswIndexConfig, Metadata};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use std::num::NonZero;
5use thiserror::Error;
6use validator::Validate;
7
8#[derive(Debug, Error)]
9pub enum HnswParametersFromSegmentError {
10    #[error("Invalid metadata: {0}")]
11    InvalidMetadata(#[from] serde_json::Error),
12    #[error("Invalid parameters: {0}")]
13    InvalidParameters(#[from] validator::ValidationErrors),
14    #[error("Incompatible vector index types")]
15    IncompatibleVectorIndexTypes,
16}
17
18impl ChromaError for HnswParametersFromSegmentError {
19    fn code(&self) -> ErrorCodes {
20        match self {
21            HnswParametersFromSegmentError::InvalidMetadata(_) => ErrorCodes::InvalidArgument,
22            HnswParametersFromSegmentError::InvalidParameters(_) => ErrorCodes::InvalidArgument,
23            HnswParametersFromSegmentError::IncompatibleVectorIndexTypes => {
24                ErrorCodes::InvalidArgument
25            }
26        }
27    }
28}
29
30#[derive(Default, Debug, PartialEq, Serialize, Deserialize, Clone)]
31#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
32pub enum Space {
33    #[default]
34    #[serde(rename = "l2")]
35    L2,
36    #[serde(rename = "cosine")]
37    Cosine,
38    #[serde(rename = "ip")]
39    Ip,
40}
41
/// Default value for `ef_construction` (also used for legacy `hnsw:construction_ef`).
pub fn default_construction_ef() -> usize {
    const DEFAULT_CONSTRUCTION_EF: usize = 100;
    DEFAULT_CONSTRUCTION_EF
}
45
/// Default value for `ef_search` (also used for legacy `hnsw:search_ef`).
pub fn default_search_ef() -> usize {
    const DEFAULT_SEARCH_EF: usize = 100;
    DEFAULT_SEARCH_EF
}
49
/// Default value for `max_neighbors` (legacy metadata key `hnsw:M`).
pub fn default_m() -> usize {
    const DEFAULT_M: usize = 16;
    DEFAULT_M
}
53
/// Default thread count: the parallelism reported by
/// [`std::thread::available_parallelism`], or `1` when that query fails.
pub fn default_num_threads() -> usize {
    std::thread::available_parallelism().map_or(1, NonZero::get)
}
59
/// Default value for `resize_factor` (also used for legacy `hnsw:resize_factor`).
pub fn default_resize_factor() -> f64 {
    const DEFAULT_RESIZE_FACTOR: f64 = 1.2;
    DEFAULT_RESIZE_FACTOR
}
63
/// Default value for `sync_threshold` (also used for legacy `hnsw:sync_threshold`).
pub fn default_sync_threshold() -> usize {
    const DEFAULT_SYNC_THRESHOLD: usize = 1000;
    DEFAULT_SYNC_THRESHOLD
}
67
/// Default value for `batch_size` (also used for legacy `hnsw:batch_size`).
pub fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
71
72pub fn default_space() -> Space {
73    Space::L2
74}
75
76#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
77#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
78#[serde(deny_unknown_fields)]
79pub struct InternalHnswConfiguration {
80    #[serde(default = "default_space")]
81    pub space: Space,
82    #[serde(default = "default_construction_ef")]
83    pub ef_construction: usize,
84    #[serde(default = "default_search_ef")]
85    pub ef_search: usize,
86    #[serde(default = "default_m")]
87    pub max_neighbors: usize,
88    #[serde(default = "default_num_threads")]
89    #[serde(skip_serializing)]
90    pub num_threads: usize,
91    #[serde(default = "default_resize_factor")]
92    pub resize_factor: f64,
93    #[validate(range(min = 2))]
94    #[serde(default = "default_sync_threshold")]
95    pub sync_threshold: usize,
96    #[validate(range(min = 2))]
97    #[serde(default = "default_batch_size")]
98    #[serde(skip_serializing)]
99    pub batch_size: usize,
100}
101
102impl Default for InternalHnswConfiguration {
103    fn default() -> Self {
104        serde_json::from_str("{}").unwrap()
105    }
106}
107
108impl From<(Option<&Space>, Option<&HnswIndexConfig>)> for InternalHnswConfiguration {
109    fn from((space, config): (Option<&Space>, Option<&HnswIndexConfig>)) -> Self {
110        let mut internal = InternalHnswConfiguration::default();
111
112        if let Some(space) = space {
113            internal.space = space.clone();
114        }
115
116        if let Some(config) = config {
117            if let Some(ef_construction) = config.ef_construction {
118                internal.ef_construction = ef_construction;
119            }
120            if let Some(max_neighbors) = config.max_neighbors {
121                internal.max_neighbors = max_neighbors;
122            }
123            if let Some(ef_search) = config.ef_search {
124                internal.ef_search = ef_search;
125            }
126            if let Some(num_threads) = config.num_threads {
127                internal.num_threads = num_threads;
128            }
129            if let Some(batch_size) = config.batch_size {
130                internal.batch_size = batch_size;
131            }
132            if let Some(sync_threshold) = config.sync_threshold {
133                internal.sync_threshold = sync_threshold;
134            }
135            if let Some(resize_factor) = config.resize_factor {
136                internal.resize_factor = resize_factor;
137            }
138        }
139
140        internal
141    }
142}
143
144#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
145#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
146#[serde(deny_unknown_fields)]
147#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
148pub struct HnswConfiguration {
149    pub space: Option<Space>,
150    pub ef_construction: Option<usize>,
151    pub ef_search: Option<usize>,
152    pub max_neighbors: Option<usize>,
153    #[serde(skip_serializing)]
154    pub num_threads: Option<usize>,
155    pub resize_factor: Option<f64>,
156    #[validate(range(min = 2))]
157    pub sync_threshold: Option<usize>,
158    #[validate(range(min = 2))]
159    #[serde(skip_serializing)]
160    pub batch_size: Option<usize>,
161}
162
163impl From<InternalHnswConfiguration> for HnswConfiguration {
164    fn from(config: InternalHnswConfiguration) -> Self {
165        Self {
166            space: Some(config.space),
167            ef_construction: Some(config.ef_construction),
168            ef_search: Some(config.ef_search),
169            max_neighbors: Some(config.max_neighbors),
170            num_threads: Some(config.num_threads),
171            resize_factor: Some(config.resize_factor),
172            sync_threshold: Some(config.sync_threshold),
173            batch_size: Some(config.batch_size),
174        }
175    }
176}
177
178impl From<HnswConfiguration> for InternalHnswConfiguration {
179    fn from(config: HnswConfiguration) -> Self {
180        Self {
181            space: config.space.unwrap_or(default_space()),
182            ef_construction: config.ef_construction.unwrap_or(default_construction_ef()),
183            ef_search: config.ef_search.unwrap_or(default_search_ef()),
184            max_neighbors: config.max_neighbors.unwrap_or(default_m()),
185            num_threads: config.num_threads.unwrap_or(default_num_threads()),
186            resize_factor: config.resize_factor.unwrap_or(default_resize_factor()),
187            sync_threshold: config.sync_threshold.unwrap_or(default_sync_threshold()),
188            batch_size: config.batch_size.unwrap_or(default_batch_size()),
189        }
190    }
191}
192
193impl Default for HnswConfiguration {
194    fn default() -> Self {
195        serde_json::from_str("{}").unwrap()
196    }
197}
198
199impl InternalHnswConfiguration {
200    pub(crate) fn from_legacy_segment_metadata(
201        segment_metadata: &Option<Metadata>,
202    ) -> Result<Self, HnswParametersFromSegmentError> {
203        if let Some(metadata) = segment_metadata {
204            #[derive(Deserialize)]
205            #[serde(deny_unknown_fields)]
206            struct LegacyMetadataLocalHnswParameters {
207                #[serde(rename = "hnsw:space", default)]
208                pub space: Space,
209                #[serde(rename = "hnsw:construction_ef", default = "default_construction_ef")]
210                pub construction_ef: usize,
211                #[serde(rename = "hnsw:search_ef", default = "default_search_ef")]
212                pub search_ef: usize,
213                #[serde(rename = "hnsw:M", default = "default_m")]
214                pub m: usize,
215                #[serde(rename = "hnsw:num_threads", default = "default_num_threads")]
216                pub num_threads: usize,
217                #[serde(rename = "hnsw:resize_factor", default = "default_resize_factor")]
218                pub resize_factor: f64,
219                #[serde(rename = "hnsw:sync_threshold", default = "default_sync_threshold")]
220                pub sync_threshold: usize,
221                #[serde(rename = "hnsw:batch_size", default = "default_batch_size")]
222                pub batch_size: usize,
223            }
224
225            let filtered_metadata = metadata
226                .clone()
227                .into_iter()
228                .filter(|(k, _)| k.starts_with("hnsw:"))
229                .collect::<Metadata>();
230
231            let metadata_str = serde_json::to_string(&filtered_metadata)?;
232            let parsed = serde_json::from_str::<LegacyMetadataLocalHnswParameters>(&metadata_str)?;
233            let params = InternalHnswConfiguration {
234                space: parsed.space,
235                ef_construction: parsed.construction_ef,
236                ef_search: parsed.search_ef,
237                max_neighbors: parsed.m,
238                num_threads: parsed.num_threads,
239                resize_factor: parsed.resize_factor,
240                sync_threshold: parsed.sync_threshold,
241                batch_size: parsed.batch_size,
242            };
243            params.validate()?;
244            Ok(params)
245        } else {
246            Ok(InternalHnswConfiguration::default())
247        }
248    }
249}
250
251#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, Validate)]
252#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
253#[serde(deny_unknown_fields)]
254#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
255pub struct UpdateHnswConfiguration {
256    pub ef_search: Option<usize>,
257    pub max_neighbors: Option<usize>,
258    pub num_threads: Option<usize>,
259    pub resize_factor: Option<f64>,
260    #[validate(range(min = 2))]
261    pub sync_threshold: Option<usize>,
262    #[validate(range(min = 2))]
263    pub batch_size: Option<usize>,
264}