chroma_types/
hnsw_configuration.rs

1use crate::Metadata;
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use std::num::NonZero;
5use thiserror::Error;
6use validator::Validate;
7
8#[derive(Debug, Error)]
9pub enum HnswParametersFromSegmentError {
10    #[error("Invalid metadata: {0}")]
11    InvalidMetadata(#[from] serde_json::Error),
12    #[error("Invalid parameters: {0}")]
13    InvalidParameters(#[from] validator::ValidationErrors),
14    #[error("Incompatible vector index types")]
15    IncompatibleVectorIndexTypes,
16}
17
18impl ChromaError for HnswParametersFromSegmentError {
19    fn code(&self) -> ErrorCodes {
20        match self {
21            HnswParametersFromSegmentError::InvalidMetadata(_) => ErrorCodes::InvalidArgument,
22            HnswParametersFromSegmentError::InvalidParameters(_) => ErrorCodes::InvalidArgument,
23            HnswParametersFromSegmentError::IncompatibleVectorIndexTypes => {
24                ErrorCodes::InvalidArgument
25            }
26        }
27    }
28}
29
30#[derive(Default, Debug, PartialEq, Serialize, Deserialize, Clone)]
31#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
32pub enum Space {
33    #[default]
34    #[serde(rename = "l2")]
35    L2,
36    #[serde(rename = "cosine")]
37    Cosine,
38    #[serde(rename = "ip")]
39    Ip,
40}
41
/// Default for the HNSW construction-time `ef` (`ef_construction`).
pub fn default_construction_ef() -> usize {
    const DEFAULT_CONSTRUCTION_EF: usize = 100;
    DEFAULT_CONSTRUCTION_EF
}
45
/// Default for the HNSW search-time `ef` (`ef_search`).
pub fn default_search_ef() -> usize {
    const DEFAULT_SEARCH_EF: usize = 100;
    DEFAULT_SEARCH_EF
}
49
/// Default for the HNSW `M` parameter (`max_neighbors`).
pub fn default_m() -> usize {
    const DEFAULT_M: usize = 16;
    DEFAULT_M
}
53
/// Default thread count: the parallelism reported by the standard library.
///
/// Falls back to `1` when `available_parallelism` cannot determine a value.
pub fn default_num_threads() -> usize {
    std::thread::available_parallelism().map_or(1, NonZero::get)
}
59
/// Default index resize factor.
pub fn default_resize_factor() -> f64 {
    const DEFAULT_RESIZE_FACTOR: f64 = 1.2;
    DEFAULT_RESIZE_FACTOR
}
63
/// Default sync threshold.
pub fn default_sync_threshold() -> usize {
    const DEFAULT_SYNC_THRESHOLD: usize = 1000;
    DEFAULT_SYNC_THRESHOLD
}
67
/// Default batch size.
pub fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
71
72pub fn default_space() -> Space {
73    Space::L2
74}
75
76#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
77#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
78#[serde(deny_unknown_fields)]
79pub struct InternalHnswConfiguration {
80    #[serde(default = "default_space")]
81    pub space: Space,
82    #[serde(default = "default_construction_ef")]
83    pub ef_construction: usize,
84    #[serde(default = "default_search_ef")]
85    pub ef_search: usize,
86    #[serde(default = "default_m")]
87    pub max_neighbors: usize,
88    #[serde(default = "default_num_threads")]
89    #[serde(skip_serializing)]
90    pub num_threads: usize,
91    #[serde(default = "default_resize_factor")]
92    pub resize_factor: f64,
93    #[validate(range(min = 2))]
94    #[serde(default = "default_sync_threshold")]
95    pub sync_threshold: usize,
96    #[validate(range(min = 2))]
97    #[serde(default = "default_batch_size")]
98    #[serde(skip_serializing)]
99    pub batch_size: usize,
100}
101
102impl Default for InternalHnswConfiguration {
103    fn default() -> Self {
104        serde_json::from_str("{}").unwrap()
105    }
106}
107
108#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
109#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
110#[serde(deny_unknown_fields)]
111#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
112pub struct HnswConfiguration {
113    pub space: Option<Space>,
114    pub ef_construction: Option<usize>,
115    pub ef_search: Option<usize>,
116    pub max_neighbors: Option<usize>,
117    #[serde(skip_serializing)]
118    pub num_threads: Option<usize>,
119    pub resize_factor: Option<f64>,
120    #[validate(range(min = 2))]
121    pub sync_threshold: Option<usize>,
122    #[validate(range(min = 2))]
123    #[serde(skip_serializing)]
124    pub batch_size: Option<usize>,
125}
126
127impl From<InternalHnswConfiguration> for HnswConfiguration {
128    fn from(config: InternalHnswConfiguration) -> Self {
129        Self {
130            space: Some(config.space),
131            ef_construction: Some(config.ef_construction),
132            ef_search: Some(config.ef_search),
133            max_neighbors: Some(config.max_neighbors),
134            num_threads: Some(config.num_threads),
135            resize_factor: Some(config.resize_factor),
136            sync_threshold: Some(config.sync_threshold),
137            batch_size: Some(config.batch_size),
138        }
139    }
140}
141
142impl From<HnswConfiguration> for InternalHnswConfiguration {
143    fn from(config: HnswConfiguration) -> Self {
144        Self {
145            space: config.space.unwrap_or(default_space()),
146            ef_construction: config.ef_construction.unwrap_or(default_construction_ef()),
147            ef_search: config.ef_search.unwrap_or(default_search_ef()),
148            max_neighbors: config.max_neighbors.unwrap_or(default_m()),
149            num_threads: config.num_threads.unwrap_or(default_num_threads()),
150            resize_factor: config.resize_factor.unwrap_or(default_resize_factor()),
151            sync_threshold: config.sync_threshold.unwrap_or(default_sync_threshold()),
152            batch_size: config.batch_size.unwrap_or(default_batch_size()),
153        }
154    }
155}
156
157impl Default for HnswConfiguration {
158    fn default() -> Self {
159        serde_json::from_str("{}").unwrap()
160    }
161}
162
163impl InternalHnswConfiguration {
164    pub(crate) fn from_legacy_segment_metadata(
165        segment_metadata: &Option<Metadata>,
166    ) -> Result<Self, HnswParametersFromSegmentError> {
167        if let Some(metadata) = segment_metadata {
168            #[derive(Deserialize)]
169            #[serde(deny_unknown_fields)]
170            struct LegacyMetadataLocalHnswParameters {
171                #[serde(rename = "hnsw:space", default)]
172                pub space: Space,
173                #[serde(rename = "hnsw:construction_ef", default = "default_construction_ef")]
174                pub construction_ef: usize,
175                #[serde(rename = "hnsw:search_ef", default = "default_search_ef")]
176                pub search_ef: usize,
177                #[serde(rename = "hnsw:M", default = "default_m")]
178                pub m: usize,
179                #[serde(rename = "hnsw:num_threads", default = "default_num_threads")]
180                pub num_threads: usize,
181                #[serde(rename = "hnsw:resize_factor", default = "default_resize_factor")]
182                pub resize_factor: f64,
183                #[serde(rename = "hnsw:sync_threshold", default = "default_sync_threshold")]
184                pub sync_threshold: usize,
185                #[serde(rename = "hnsw:batch_size", default = "default_batch_size")]
186                pub batch_size: usize,
187            }
188
189            let filtered_metadata = metadata
190                .clone()
191                .into_iter()
192                .filter(|(k, _)| k.starts_with("hnsw:"))
193                .collect::<Metadata>();
194
195            let metadata_str = serde_json::to_string(&filtered_metadata)?;
196            let parsed = serde_json::from_str::<LegacyMetadataLocalHnswParameters>(&metadata_str)?;
197            let params = InternalHnswConfiguration {
198                space: parsed.space,
199                ef_construction: parsed.construction_ef,
200                ef_search: parsed.search_ef,
201                max_neighbors: parsed.m,
202                num_threads: parsed.num_threads,
203                resize_factor: parsed.resize_factor,
204                sync_threshold: parsed.sync_threshold,
205                batch_size: parsed.batch_size,
206            };
207            params.validate()?;
208            Ok(params)
209        } else {
210            Ok(InternalHnswConfiguration::default())
211        }
212    }
213}
214
215#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, Validate)]
216#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
217#[serde(deny_unknown_fields)]
218#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
219pub struct UpdateHnswConfiguration {
220    pub ef_search: Option<usize>,
221    pub max_neighbors: Option<usize>,
222    pub num_threads: Option<usize>,
223    pub resize_factor: Option<f64>,
224    #[validate(range(min = 2))]
225    pub sync_threshold: Option<usize>,
226    #[validate(range(min = 2))]
227    pub batch_size: Option<usize>,
228}