1use crate::{HnswIndexConfig, Metadata};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use std::num::NonZero;
5use thiserror::Error;
6use validator::Validate;
7
8#[derive(Debug, Error)]
9pub enum HnswParametersFromSegmentError {
10 #[error("Invalid metadata: {0}")]
11 InvalidMetadata(#[from] serde_json::Error),
12 #[error("Invalid parameters: {0}")]
13 InvalidParameters(#[from] validator::ValidationErrors),
14 #[error("Incompatible vector index types")]
15 IncompatibleVectorIndexTypes,
16}
17
18impl ChromaError for HnswParametersFromSegmentError {
19 fn code(&self) -> ErrorCodes {
20 match self {
21 HnswParametersFromSegmentError::InvalidMetadata(_) => ErrorCodes::InvalidArgument,
22 HnswParametersFromSegmentError::InvalidParameters(_) => ErrorCodes::InvalidArgument,
23 HnswParametersFromSegmentError::IncompatibleVectorIndexTypes => {
24 ErrorCodes::InvalidArgument
25 }
26 }
27 }
28}
29
30#[derive(Default, Debug, PartialEq, Serialize, Deserialize, Clone)]
31#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
32pub enum Space {
33 #[default]
34 #[serde(rename = "l2")]
35 L2,
36 #[serde(rename = "cosine")]
37 Cosine,
38 #[serde(rename = "ip")]
39 Ip,
40}
41
/// Returns the fallback value for the `ef_construction` setting (100).
///
/// Referenced by name from `#[serde(default = "...")]` attributes in this file.
pub fn default_construction_ef() -> usize {
    const DEFAULT_CONSTRUCTION_EF: usize = 100;
    DEFAULT_CONSTRUCTION_EF
}
45
/// Returns the fallback value for the `ef_search` setting (100).
///
/// Referenced by name from `#[serde(default = "...")]` attributes in this file.
pub fn default_search_ef() -> usize {
    const DEFAULT_SEARCH_EF: usize = 100;
    DEFAULT_SEARCH_EF
}
49
/// Returns the fallback value for the `max_neighbors` setting (16).
///
/// Referenced by name from `#[serde(default = "...")]` attributes in this file.
pub fn default_m() -> usize {
    const DEFAULT_MAX_NEIGHBORS: usize = 16;
    DEFAULT_MAX_NEIGHBORS
}
53
/// Returns the fallback value for the `num_threads` setting.
///
/// This is the count reported by `std::thread::available_parallelism()`, or 1 when that
/// query returns an error.
pub fn default_num_threads() -> usize {
    std::thread::available_parallelism().map_or(1, NonZero::get)
}
59
/// Returns the fallback value for the `resize_factor` setting (1.2).
///
/// Referenced by name from `#[serde(default = "...")]` attributes in this file.
pub fn default_resize_factor() -> f64 {
    const DEFAULT_RESIZE_FACTOR: f64 = 1.2;
    DEFAULT_RESIZE_FACTOR
}
63
/// Returns the fallback value for the `sync_threshold` setting (1000).
///
/// Referenced by name from `#[serde(default = "...")]` attributes in this file.
pub fn default_sync_threshold() -> usize {
    const DEFAULT_SYNC_THRESHOLD: usize = 1000;
    DEFAULT_SYNC_THRESHOLD
}
67
/// Returns the fallback value for the `batch_size` setting (100).
///
/// Referenced by name from `#[serde(default = "...")]` attributes in this file.
pub fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
71
72pub fn default_space() -> Space {
73 Space::L2
74}
75
76#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
77#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
78#[serde(deny_unknown_fields)]
79pub struct InternalHnswConfiguration {
80 #[serde(default = "default_space")]
81 pub space: Space,
82 #[serde(default = "default_construction_ef")]
83 pub ef_construction: usize,
84 #[serde(default = "default_search_ef")]
85 pub ef_search: usize,
86 #[serde(default = "default_m")]
87 pub max_neighbors: usize,
88 #[serde(default = "default_num_threads")]
89 #[serde(skip_serializing)]
90 pub num_threads: usize,
91 #[serde(default = "default_resize_factor")]
92 pub resize_factor: f64,
93 #[validate(range(min = 2))]
94 #[serde(default = "default_sync_threshold")]
95 pub sync_threshold: usize,
96 #[validate(range(min = 2))]
97 #[serde(default = "default_batch_size")]
98 #[serde(skip_serializing)]
99 pub batch_size: usize,
100}
101
102impl Default for InternalHnswConfiguration {
103 fn default() -> Self {
104 serde_json::from_str("{}").unwrap()
105 }
106}
107
108impl From<(Option<&Space>, Option<&HnswIndexConfig>)> for InternalHnswConfiguration {
109 fn from((space, config): (Option<&Space>, Option<&HnswIndexConfig>)) -> Self {
110 let mut internal = InternalHnswConfiguration::default();
111
112 if let Some(space) = space {
113 internal.space = space.clone();
114 }
115
116 if let Some(config) = config {
117 if let Some(ef_construction) = config.ef_construction {
118 internal.ef_construction = ef_construction;
119 }
120 if let Some(max_neighbors) = config.max_neighbors {
121 internal.max_neighbors = max_neighbors;
122 }
123 if let Some(ef_search) = config.ef_search {
124 internal.ef_search = ef_search;
125 }
126 if let Some(num_threads) = config.num_threads {
127 internal.num_threads = num_threads;
128 }
129 if let Some(batch_size) = config.batch_size {
130 internal.batch_size = batch_size;
131 }
132 if let Some(sync_threshold) = config.sync_threshold {
133 internal.sync_threshold = sync_threshold;
134 }
135 if let Some(resize_factor) = config.resize_factor {
136 internal.resize_factor = resize_factor;
137 }
138 }
139
140 internal
141 }
142}
143
144#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
145#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
146#[serde(deny_unknown_fields)]
147#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
148pub struct HnswConfiguration {
149 pub space: Option<Space>,
150 pub ef_construction: Option<usize>,
151 pub ef_search: Option<usize>,
152 pub max_neighbors: Option<usize>,
153 #[serde(skip_serializing)]
154 pub num_threads: Option<usize>,
155 pub resize_factor: Option<f64>,
156 #[validate(range(min = 2))]
157 pub sync_threshold: Option<usize>,
158 #[validate(range(min = 2))]
159 #[serde(skip_serializing)]
160 pub batch_size: Option<usize>,
161}
162
163impl From<InternalHnswConfiguration> for HnswConfiguration {
164 fn from(config: InternalHnswConfiguration) -> Self {
165 Self {
166 space: Some(config.space),
167 ef_construction: Some(config.ef_construction),
168 ef_search: Some(config.ef_search),
169 max_neighbors: Some(config.max_neighbors),
170 num_threads: Some(config.num_threads),
171 resize_factor: Some(config.resize_factor),
172 sync_threshold: Some(config.sync_threshold),
173 batch_size: Some(config.batch_size),
174 }
175 }
176}
177
178impl From<HnswConfiguration> for InternalHnswConfiguration {
179 fn from(config: HnswConfiguration) -> Self {
180 Self {
181 space: config.space.unwrap_or(default_space()),
182 ef_construction: config.ef_construction.unwrap_or(default_construction_ef()),
183 ef_search: config.ef_search.unwrap_or(default_search_ef()),
184 max_neighbors: config.max_neighbors.unwrap_or(default_m()),
185 num_threads: config.num_threads.unwrap_or(default_num_threads()),
186 resize_factor: config.resize_factor.unwrap_or(default_resize_factor()),
187 sync_threshold: config.sync_threshold.unwrap_or(default_sync_threshold()),
188 batch_size: config.batch_size.unwrap_or(default_batch_size()),
189 }
190 }
191}
192
193impl Default for HnswConfiguration {
194 fn default() -> Self {
195 serde_json::from_str("{}").unwrap()
196 }
197}
198
199impl InternalHnswConfiguration {
200 pub(crate) fn from_legacy_segment_metadata(
201 segment_metadata: &Option<Metadata>,
202 ) -> Result<Self, HnswParametersFromSegmentError> {
203 if let Some(metadata) = segment_metadata {
204 #[derive(Deserialize)]
205 #[serde(deny_unknown_fields)]
206 struct LegacyMetadataLocalHnswParameters {
207 #[serde(rename = "hnsw:space", default)]
208 pub space: Space,
209 #[serde(rename = "hnsw:construction_ef", default = "default_construction_ef")]
210 pub construction_ef: usize,
211 #[serde(rename = "hnsw:search_ef", default = "default_search_ef")]
212 pub search_ef: usize,
213 #[serde(rename = "hnsw:M", default = "default_m")]
214 pub m: usize,
215 #[serde(rename = "hnsw:num_threads", default = "default_num_threads")]
216 pub num_threads: usize,
217 #[serde(rename = "hnsw:resize_factor", default = "default_resize_factor")]
218 pub resize_factor: f64,
219 #[serde(rename = "hnsw:sync_threshold", default = "default_sync_threshold")]
220 pub sync_threshold: usize,
221 #[serde(rename = "hnsw:batch_size", default = "default_batch_size")]
222 pub batch_size: usize,
223 }
224
225 let filtered_metadata = metadata
226 .clone()
227 .into_iter()
228 .filter(|(k, _)| k.starts_with("hnsw:"))
229 .collect::<Metadata>();
230
231 let metadata_str = serde_json::to_string(&filtered_metadata)?;
232 let parsed = serde_json::from_str::<LegacyMetadataLocalHnswParameters>(&metadata_str)?;
233 let params = InternalHnswConfiguration {
234 space: parsed.space,
235 ef_construction: parsed.construction_ef,
236 ef_search: parsed.search_ef,
237 max_neighbors: parsed.m,
238 num_threads: parsed.num_threads,
239 resize_factor: parsed.resize_factor,
240 sync_threshold: parsed.sync_threshold,
241 batch_size: parsed.batch_size,
242 };
243 params.validate()?;
244 Ok(params)
245 } else {
246 Ok(InternalHnswConfiguration::default())
247 }
248 }
249}
250
251#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, Validate)]
252#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
253#[serde(deny_unknown_fields)]
254#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
255pub struct UpdateHnswConfiguration {
256 pub ef_search: Option<usize>,
257 pub max_neighbors: Option<usize>,
258 pub num_threads: Option<usize>,
259 pub resize_factor: Option<f64>,
260 #[validate(range(min = 2))]
261 pub sync_threshold: Option<usize>,
262 #[validate(range(min = 2))]
263 pub batch_size: Option<usize>,
264}