1use crate::Metadata;
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use std::num::NonZero;
5use thiserror::Error;
6use validator::Validate;
7
8#[derive(Debug, Error)]
9pub enum HnswParametersFromSegmentError {
10 #[error("Invalid metadata: {0}")]
11 InvalidMetadata(#[from] serde_json::Error),
12 #[error("Invalid parameters: {0}")]
13 InvalidParameters(#[from] validator::ValidationErrors),
14 #[error("Incompatible vector index types")]
15 IncompatibleVectorIndexTypes,
16}
17
18impl ChromaError for HnswParametersFromSegmentError {
19 fn code(&self) -> ErrorCodes {
20 match self {
21 HnswParametersFromSegmentError::InvalidMetadata(_) => ErrorCodes::InvalidArgument,
22 HnswParametersFromSegmentError::InvalidParameters(_) => ErrorCodes::InvalidArgument,
23 HnswParametersFromSegmentError::IncompatibleVectorIndexTypes => {
24 ErrorCodes::InvalidArgument
25 }
26 }
27 }
28}
29
30#[derive(Default, Debug, PartialEq, Serialize, Deserialize, Clone)]
31#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
32pub enum Space {
33 #[default]
34 #[serde(rename = "l2")]
35 L2,
36 #[serde(rename = "cosine")]
37 Cosine,
38 #[serde(rename = "ip")]
39 Ip,
40}
41
/// Value used for `ef_construction` when none is supplied.
pub fn default_construction_ef() -> usize {
    const DEFAULT_CONSTRUCTION_EF: usize = 100;
    DEFAULT_CONSTRUCTION_EF
}
45
/// Value used for `ef_search` when none is supplied.
pub fn default_search_ef() -> usize {
    const DEFAULT_SEARCH_EF: usize = 100;
    DEFAULT_SEARCH_EF
}
49
/// Value used for `max_neighbors` (legacy key `hnsw:M`) when none is supplied.
pub fn default_m() -> usize {
    const DEFAULT_MAX_NEIGHBORS: usize = 16;
    DEFAULT_MAX_NEIGHBORS
}
53
/// Value used for `num_threads` when none is supplied.
///
/// Returns the parallelism reported by
/// [`std::thread::available_parallelism`], or `1` when that query fails.
pub fn default_num_threads() -> usize {
    std::thread::available_parallelism()
        .ok()
        // `NonZero::new(1)` is always `Some`, so this supplies the fallback.
        .or(NonZero::new(1))
        .map_or(1, |threads| threads.get())
}
59
/// Value used for `resize_factor` when none is supplied.
pub fn default_resize_factor() -> f64 {
    const DEFAULT_RESIZE_FACTOR: f64 = 1.2;
    DEFAULT_RESIZE_FACTOR
}
63
/// Value used for `sync_threshold` when none is supplied.
pub fn default_sync_threshold() -> usize {
    const DEFAULT_SYNC_THRESHOLD: usize = 1000;
    DEFAULT_SYNC_THRESHOLD
}
67
/// Value used for `batch_size` when none is supplied.
pub fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 100;
    DEFAULT_BATCH_SIZE
}
71
72pub fn default_space() -> Space {
73 Space::L2
74}
75
76#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
77#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
78#[serde(deny_unknown_fields)]
79pub struct InternalHnswConfiguration {
80 #[serde(default = "default_space")]
81 pub space: Space,
82 #[serde(default = "default_construction_ef")]
83 pub ef_construction: usize,
84 #[serde(default = "default_search_ef")]
85 pub ef_search: usize,
86 #[serde(default = "default_m")]
87 pub max_neighbors: usize,
88 #[serde(default = "default_num_threads")]
89 #[serde(skip_serializing)]
90 pub num_threads: usize,
91 #[serde(default = "default_resize_factor")]
92 pub resize_factor: f64,
93 #[validate(range(min = 2))]
94 #[serde(default = "default_sync_threshold")]
95 pub sync_threshold: usize,
96 #[validate(range(min = 2))]
97 #[serde(default = "default_batch_size")]
98 #[serde(skip_serializing)]
99 pub batch_size: usize,
100}
101
102impl Default for InternalHnswConfiguration {
103 fn default() -> Self {
104 serde_json::from_str("{}").unwrap()
105 }
106}
107
108#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Validate)]
109#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
110#[serde(deny_unknown_fields)]
111#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
112pub struct HnswConfiguration {
113 pub space: Option<Space>,
114 pub ef_construction: Option<usize>,
115 pub ef_search: Option<usize>,
116 pub max_neighbors: Option<usize>,
117 #[serde(skip_serializing)]
118 pub num_threads: Option<usize>,
119 pub resize_factor: Option<f64>,
120 #[validate(range(min = 2))]
121 pub sync_threshold: Option<usize>,
122 #[validate(range(min = 2))]
123 #[serde(skip_serializing)]
124 pub batch_size: Option<usize>,
125}
126
127impl From<InternalHnswConfiguration> for HnswConfiguration {
128 fn from(config: InternalHnswConfiguration) -> Self {
129 Self {
130 space: Some(config.space),
131 ef_construction: Some(config.ef_construction),
132 ef_search: Some(config.ef_search),
133 max_neighbors: Some(config.max_neighbors),
134 num_threads: Some(config.num_threads),
135 resize_factor: Some(config.resize_factor),
136 sync_threshold: Some(config.sync_threshold),
137 batch_size: Some(config.batch_size),
138 }
139 }
140}
141
142impl From<HnswConfiguration> for InternalHnswConfiguration {
143 fn from(config: HnswConfiguration) -> Self {
144 Self {
145 space: config.space.unwrap_or(default_space()),
146 ef_construction: config.ef_construction.unwrap_or(default_construction_ef()),
147 ef_search: config.ef_search.unwrap_or(default_search_ef()),
148 max_neighbors: config.max_neighbors.unwrap_or(default_m()),
149 num_threads: config.num_threads.unwrap_or(default_num_threads()),
150 resize_factor: config.resize_factor.unwrap_or(default_resize_factor()),
151 sync_threshold: config.sync_threshold.unwrap_or(default_sync_threshold()),
152 batch_size: config.batch_size.unwrap_or(default_batch_size()),
153 }
154 }
155}
156
157impl Default for HnswConfiguration {
158 fn default() -> Self {
159 serde_json::from_str("{}").unwrap()
160 }
161}
162
163impl InternalHnswConfiguration {
164 pub(crate) fn from_legacy_segment_metadata(
165 segment_metadata: &Option<Metadata>,
166 ) -> Result<Self, HnswParametersFromSegmentError> {
167 if let Some(metadata) = segment_metadata {
168 #[derive(Deserialize)]
169 #[serde(deny_unknown_fields)]
170 struct LegacyMetadataLocalHnswParameters {
171 #[serde(rename = "hnsw:space", default)]
172 pub space: Space,
173 #[serde(rename = "hnsw:construction_ef", default = "default_construction_ef")]
174 pub construction_ef: usize,
175 #[serde(rename = "hnsw:search_ef", default = "default_search_ef")]
176 pub search_ef: usize,
177 #[serde(rename = "hnsw:M", default = "default_m")]
178 pub m: usize,
179 #[serde(rename = "hnsw:num_threads", default = "default_num_threads")]
180 pub num_threads: usize,
181 #[serde(rename = "hnsw:resize_factor", default = "default_resize_factor")]
182 pub resize_factor: f64,
183 #[serde(rename = "hnsw:sync_threshold", default = "default_sync_threshold")]
184 pub sync_threshold: usize,
185 #[serde(rename = "hnsw:batch_size", default = "default_batch_size")]
186 pub batch_size: usize,
187 }
188
189 let filtered_metadata = metadata
190 .clone()
191 .into_iter()
192 .filter(|(k, _)| k.starts_with("hnsw:"))
193 .collect::<Metadata>();
194
195 let metadata_str = serde_json::to_string(&filtered_metadata)?;
196 let parsed = serde_json::from_str::<LegacyMetadataLocalHnswParameters>(&metadata_str)?;
197 let params = InternalHnswConfiguration {
198 space: parsed.space,
199 ef_construction: parsed.construction_ef,
200 ef_search: parsed.search_ef,
201 max_neighbors: parsed.m,
202 num_threads: parsed.num_threads,
203 resize_factor: parsed.resize_factor,
204 sync_threshold: parsed.sync_threshold,
205 batch_size: parsed.batch_size,
206 };
207 params.validate()?;
208 Ok(params)
209 } else {
210 Ok(InternalHnswConfiguration::default())
211 }
212 }
213}
214
215#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, Validate)]
216#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
217#[serde(deny_unknown_fields)]
218#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
219pub struct UpdateHnswConfiguration {
220 pub ef_search: Option<usize>,
221 pub max_neighbors: Option<usize>,
222 pub num_threads: Option<usize>,
223 pub resize_factor: Option<f64>,
224 #[validate(range(min = 2))]
225 pub sync_threshold: Option<usize>,
226 #[validate(range(min = 2))]
227 pub batch_size: Option<usize>,
228}