1use std::str::FromStr;
2
3use super::{Metadata, MetadataValueConversionError};
4use crate::{
5 chroma_proto, test_segment, CollectionConfiguration, InternalCollectionConfiguration,
6 InternalSchema, Segment, SegmentScope, UpdateCollectionConfiguration, UpdateMetadata,
7};
8use chroma_error::{ChromaError, ErrorCodes};
9use serde::{Deserialize, Serialize};
10use std::time::{Duration, SystemTime};
11use thiserror::Error;
12use uuid::Uuid;
13
14#[cfg(feature = "pyo3")]
15use pyo3::types::PyAnyMethods;
16
17#[derive(
19 Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
20)]
21#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
22pub struct CollectionUuid(pub Uuid);
23
24#[derive(
26 Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
27)]
28#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
29pub struct DatabaseUuid(pub Uuid);
30
31impl DatabaseUuid {
32 pub fn new() -> Self {
33 DatabaseUuid(Uuid::new_v4())
34 }
35}
36
37impl CollectionUuid {
38 pub fn new() -> Self {
39 CollectionUuid(Uuid::new_v4())
40 }
41
42 pub fn storage_prefix_for_log(&self) -> String {
43 format!("logs/{}", self)
44 }
45}
46
47impl std::str::FromStr for CollectionUuid {
48 type Err = uuid::Error;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match Uuid::parse_str(s) {
52 Ok(uuid) => Ok(CollectionUuid(uuid)),
53 Err(err) => Err(err),
54 }
55 }
56}
57
58impl std::str::FromStr for DatabaseUuid {
59 type Err = uuid::Error;
60
61 fn from_str(s: &str) -> Result<Self, Self::Err> {
62 match Uuid::parse_str(s) {
63 Ok(uuid) => Ok(DatabaseUuid(uuid)),
64 Err(err) => Err(err),
65 }
66 }
67}
68
69impl std::fmt::Display for DatabaseUuid {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(f, "{}", self.0)
72 }
73}
74
75impl std::fmt::Display for CollectionUuid {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(f, "{}", self.0)
78 }
79}
80
81fn serialize_internal_collection_configuration<S: serde::Serializer>(
82 config: &InternalCollectionConfiguration,
83 serializer: S,
84) -> Result<S::Ok, S::Error> {
85 let collection_config: CollectionConfiguration = config.clone().into();
86 collection_config.serialize(serializer)
87}
88
89fn deserialize_internal_collection_configuration<'de, D: serde::Deserializer<'de>>(
90 deserializer: D,
91) -> Result<InternalCollectionConfiguration, D::Error> {
92 let collection_config = CollectionConfiguration::deserialize(deserializer)?;
93 collection_config
94 .try_into()
95 .map_err(serde::de::Error::custom)
96}
97
// Rust-side representation of a collection.
//
// Converted to and from `chroma_proto::Collection` by the `TryFrom` impls in this file,
// and exposed as a Python class when the `pyo3` feature is enabled. Fields marked
// `#[serde(skip)]` are omitted from (de)serialization; on deserialize they take their
// type's `Default` value, except `updated_at`, which is set to `SystemTime::now()`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct Collection {
    // Serialized under the key `"id"`.
    #[serde(rename = "id")]
    pub collection_id: CollectionUuid,
    pub name: String,
    // Serialized under `"configuration_json"` by converting to/from the public
    // `CollectionConfiguration` (via the two serde helper functions defined above);
    // the OpenAPI schema advertises that public type.
    #[serde(
        serialize_with = "serialize_internal_collection_configuration",
        deserialize_with = "deserialize_internal_collection_configuration",
        rename = "configuration_json"
    )]
    #[cfg_attr(feature = "utoipa", schema(value_type = CollectionConfiguration))]
    pub config: InternalCollectionConfiguration,
    pub schema: Option<InternalSchema>,
    pub metadata: Option<Metadata>,
    // NOTE(review): presumably `None` until the embedding dimension is known — confirm.
    pub dimension: Option<i32>,
    pub tenant: String,
    pub database: String,
    pub log_position: i64,
    pub version: i32,
    // All fields below are excluded from the serde representation.
    #[serde(skip)]
    pub total_records_post_compaction: u64,
    #[serde(skip)]
    pub size_bytes_post_compaction: u64,
    #[serde(skip)]
    pub last_compaction_time_secs: u64,
    #[serde(skip)]
    pub version_file_path: Option<String>,
    #[serde(skip)]
    pub root_collection_id: Option<CollectionUuid>,
    #[serde(skip)]
    pub lineage_file_path: Option<String>,
    // Not serialized; deserialization fills it with the current time.
    #[serde(skip, default = "SystemTime::now")]
    pub updated_at: SystemTime,
    // Not serialized; deserialization fills it with `DatabaseUuid::default()` (nil).
    #[serde(skip)]
    pub database_id: DatabaseUuid,
}
136
137impl Default for Collection {
138 fn default() -> Self {
139 Self {
140 collection_id: CollectionUuid::new(),
141 name: "".to_string(),
142 config: InternalCollectionConfiguration::default_hnsw(),
143 schema: None,
144 metadata: None,
145 dimension: None,
146 tenant: "".to_string(),
147 database: "".to_string(),
148 log_position: 0,
149 version: 0,
150 total_records_post_compaction: 0,
151 size_bytes_post_compaction: 0,
152 last_compaction_time_secs: 0,
153 version_file_path: None,
154 root_collection_id: None,
155 lineage_file_path: None,
156 updated_at: SystemTime::now(),
157 database_id: DatabaseUuid::new(),
158 }
159 }
160}
161
#[cfg(feature = "pyo3")]
#[pyo3::pymethods]
impl Collection {
    // Python `id` property: the collection id as a Python `uuid.UUID` object.
    #[getter]
    fn id<'py>(&self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let res = pyo3::prelude::PyModule::import(py, "uuid")?
            .getattr("UUID")?
            .call1((self.collection_id.to_string(),))?;
        Ok(res)
    }

    // Python `configuration` property: the public `CollectionConfiguration`,
    // serialized to JSON and decoded with Python's `json.loads`.
    #[getter]
    fn configuration<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let config: crate::CollectionConfiguration = self.config.clone().into();
        // Report a serialization failure as a Python `RuntimeError` instead of
        // panicking; a Rust panic here would surface as pyo3's `PanicException`,
        // which ordinary `except Exception` handlers do not catch.
        let config_json_str = serde_json::to_string(&config)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        let res = pyo3::prelude::PyModule::import(py, "json")?
            .getattr("loads")?
            .call1((config_json_str,))?;
        Ok(res)
    }

    // Python `name` property.
    #[getter]
    pub fn name(&self) -> &str {
        &self.name
    }

    // Python `metadata` property (a clone of the stored metadata, if any).
    #[getter]
    pub fn metadata(&self) -> Option<Metadata> {
        self.metadata.clone()
    }

    // Python `dimension` property.
    #[getter]
    pub fn dimension(&self) -> Option<i32> {
        self.dimension
    }

    // Python `tenant` property.
    #[getter]
    pub fn tenant(&self) -> &str {
        &self.tenant
    }

    // Python `database` property.
    #[getter]
    pub fn database(&self) -> &str {
        &self.database
    }
}
211
212impl Collection {
213 pub fn test_collection(dim: i32) -> Self {
214 Collection {
215 name: "test_collection".to_string(),
216 dimension: Some(dim),
217 tenant: "default_tenant".to_string(),
218 database: "default_database".to_string(),
219 database_id: DatabaseUuid::new(),
220 ..Default::default()
221 }
222 }
223}
224
/// Errors raised while converting a [`chroma_proto::Collection`] into a [`Collection`].
#[derive(Error, Debug)]
pub enum CollectionConversionError {
    /// The configuration or schema JSON could not be deserialized.
    #[error("Invalid config: {0}")]
    InvalidConfig(#[from] serde_json::Error),
    /// A collection or database id string was not a valid UUID.
    #[error("Invalid UUID")]
    InvalidUuid,
    /// A metadata value could not be converted from its protobuf form.
    #[error(transparent)]
    MetadataValueConversionError(#[from] MetadataValueConversionError),
    /// The protobuf collection carried no `database_id`.
    #[error("Missing Database Id")]
    MissingDatabaseId,
}
236
237impl ChromaError for CollectionConversionError {
238 fn code(&self) -> ErrorCodes {
239 match self {
240 CollectionConversionError::InvalidConfig(_) => ErrorCodes::InvalidArgument,
241 CollectionConversionError::InvalidUuid => ErrorCodes::InvalidArgument,
242 CollectionConversionError::MetadataValueConversionError(e) => e.code(),
243 CollectionConversionError::MissingDatabaseId => ErrorCodes::Internal,
244 }
245 }
246}
247
248impl TryFrom<chroma_proto::Collection> for Collection {
249 type Error = CollectionConversionError;
250
251 fn try_from(proto_collection: chroma_proto::Collection) -> Result<Self, Self::Error> {
252 let collection_id = CollectionUuid::from_str(&proto_collection.id)
253 .map_err(|_| CollectionConversionError::InvalidUuid)?;
254 let collection_metadata: Option<Metadata> = match proto_collection.metadata {
255 Some(proto_metadata) => match proto_metadata.try_into() {
256 Ok(metadata) => Some(metadata),
257 Err(e) => return Err(CollectionConversionError::MetadataValueConversionError(e)),
258 },
259 None => None,
260 };
261 let updated_at = match proto_collection.updated_at {
263 Some(updated_at) => {
264 SystemTime::UNIX_EPOCH
265 + Duration::new(updated_at.seconds as u64, updated_at.nanos as u32)
266 }
267 None => SystemTime::now(),
268 };
269 let database_id = match proto_collection.database_id {
270 Some(db_id) => DatabaseUuid::from_str(&db_id)
271 .map_err(|_| CollectionConversionError::InvalidUuid)?,
272 None => {
273 return Err(CollectionConversionError::MissingDatabaseId);
274 }
275 };
276 let schema = match proto_collection.schema_str {
277 Some(schema_str) if !schema_str.is_empty() => Some(serde_json::from_str(&schema_str)?),
278 _ => None,
279 };
280
281 Ok(Collection {
282 collection_id,
283 name: proto_collection.name,
284 config: serde_json::from_str(&proto_collection.configuration_json_str)?,
285 schema,
286 metadata: collection_metadata,
287 dimension: proto_collection.dimension,
288 tenant: proto_collection.tenant,
289 database: proto_collection.database,
290 log_position: proto_collection.log_position,
291 version: proto_collection.version,
292 total_records_post_compaction: proto_collection.total_records_post_compaction,
293 size_bytes_post_compaction: proto_collection.size_bytes_post_compaction,
294 last_compaction_time_secs: proto_collection.last_compaction_time_secs,
295 version_file_path: proto_collection.version_file_path,
296 root_collection_id: proto_collection
297 .root_collection_id
298 .map(|uuid| CollectionUuid(Uuid::try_parse(&uuid).unwrap())),
299 lineage_file_path: proto_collection.lineage_file_path,
300 updated_at,
301 database_id,
302 })
303 }
304}
305
/// Errors raised while converting a [`Collection`] into a [`chroma_proto::Collection`].
#[derive(Error, Debug)]
pub enum CollectionToProtoError {
    /// JSON encoding of the configuration or schema failed.
    #[error("Could not serialize config: {0}")]
    ConfigSerialization(#[from] serde_json::Error),
}
311
312impl ChromaError for CollectionToProtoError {
313 fn code(&self) -> ErrorCodes {
314 match self {
315 CollectionToProtoError::ConfigSerialization(_) => ErrorCodes::Internal,
316 }
317 }
318}
319
320impl TryFrom<Collection> for chroma_proto::Collection {
321 type Error = CollectionToProtoError;
322
323 fn try_from(value: Collection) -> Result<Self, Self::Error> {
324 Ok(Self {
325 id: value.collection_id.0.to_string(),
326 name: value.name,
327 configuration_json_str: serde_json::to_string(&value.config)?,
328 schema_str: value
329 .schema
330 .map(|s| serde_json::to_string(&s))
331 .transpose()?,
332 metadata: value.metadata.map(Into::into),
333 dimension: value.dimension,
334 tenant: value.tenant,
335 database: value.database,
336 log_position: value.log_position,
337 version: value.version,
338 total_records_post_compaction: value.total_records_post_compaction,
339 size_bytes_post_compaction: value.size_bytes_post_compaction,
340 last_compaction_time_secs: value.last_compaction_time_secs,
341 version_file_path: value.version_file_path,
342 root_collection_id: value.root_collection_id.map(|uuid| uuid.0.to_string()),
343 lineage_file_path: value.lineage_file_path,
344 updated_at: Some(value.updated_at.into()),
345 database_id: Some(value.database_id.0.to_string()),
346 })
347 }
348}
349
/// A collection bundled with its metadata, record, and vector segments.
#[derive(Clone, Debug)]
pub struct CollectionAndSegments {
    pub collection: Collection,
    pub metadata_segment: Segment,
    pub record_segment: Segment,
    pub vector_segment: Segment,
}
357
358impl CollectionAndSegments {
359 pub fn is_uninitialized(&self) -> bool {
362 self.collection.dimension.is_none() && self.vector_segment.file_path.is_empty()
363 }
364
365 pub fn test(dim: i32) -> Self {
366 let collection = Collection::test_collection(dim);
367 let collection_uuid = collection.collection_id;
368 Self {
369 collection,
370 metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA),
371 record_segment: test_segment(collection_uuid, SegmentScope::RECORD),
372 vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR),
373 }
374 }
375}
376
// Payload for creating a collection. Omitted optional fields deserialize as `None`,
// and `get_or_create` defaults to `false` when absent.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CreateCollectionPayload {
    pub name: String,
    pub schema: Option<InternalSchema>,
    pub configuration: Option<CollectionConfiguration>,
    pub metadata: Option<Metadata>,
    // NOTE(review): presumably returns the existing collection instead of failing
    // when the name is taken — confirm with the handler.
    #[serde(default)]
    pub get_or_create: bool,
}
387
// Payload for updating a collection. Each field is optional; omitted fields
// deserialize as `None` (presumably meaning "leave unchanged" — confirm with handler).
#[derive(Deserialize, Serialize, Debug, Clone)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct UpdateCollectionPayload {
    pub new_name: Option<String>,
    pub new_metadata: Option<UpdateMetadata>,
    pub new_configuration: Option<UpdateCollectionConfiguration>,
}
395
#[cfg(test)]
mod test {
    use super::*;

    /// Hyphenated nil UUID, used for every id in the conversion fixture.
    const NIL_UUID: &str = "00000000-0000-0000-0000-000000000000";

    /// Converts a fully populated protobuf collection and checks the mapped fields.
    #[test]
    fn test_collection_try_from() {
        let schema_str =
            serde_json::to_string(&InternalSchema::new_default(crate::KnnIndex::Spann)).unwrap();

        let proto_collection = chroma_proto::Collection {
            id: NIL_UUID.to_string(),
            name: "foo".to_string(),
            configuration_json_str: r#"{"a": "param", "b": "param2", "3": true}"#.to_string(),
            schema_str: Some(schema_str),
            metadata: None,
            dimension: None,
            tenant: "baz".to_string(),
            database: "qux".to_string(),
            log_position: 0,
            version: 0,
            total_records_post_compaction: 0,
            size_bytes_post_compaction: 0,
            last_compaction_time_secs: 0,
            version_file_path: Some("version_file_path".to_string()),
            root_collection_id: Some(NIL_UUID.to_string()),
            lineage_file_path: Some("lineage_file_path".to_string()),
            updated_at: Some(prost_types::Timestamp {
                seconds: 1,
                nanos: 1,
            }),
            database_id: Some(NIL_UUID.to_string()),
        };

        let converted: Collection = proto_collection.try_into().unwrap();

        assert_eq!(converted.collection_id, CollectionUuid(Uuid::nil()));
        assert_eq!(converted.name, "foo");
        assert_eq!(converted.metadata, None);
        assert_eq!(converted.dimension, None);
        assert_eq!(converted.tenant, "baz");
        assert_eq!(converted.database, "qux");
        assert_eq!(converted.total_records_post_compaction, 0);
        assert_eq!(converted.size_bytes_post_compaction, 0);
        assert_eq!(converted.last_compaction_time_secs, 0);
        assert_eq!(
            converted.version_file_path.as_deref(),
            Some("version_file_path")
        );
        assert_eq!(
            converted.root_collection_id,
            Some(CollectionUuid(Uuid::nil()))
        );
        assert_eq!(
            converted.lineage_file_path.as_deref(),
            Some("lineage_file_path")
        );
        assert_eq!(
            converted.updated_at,
            SystemTime::UNIX_EPOCH + Duration::new(1, 1)
        );
        assert_eq!(converted.database_id, DatabaseUuid(Uuid::nil()));
    }

    /// The log prefix is `logs/` followed by the hyphenated collection id.
    #[test]
    fn storage_prefix_for_log_format() {
        let collection_id: CollectionUuid =
            "34e72052-5e60-47cb-be88-19a9715b7026".parse().unwrap();
        assert_eq!(
            collection_id.storage_prefix_for_log(),
            "logs/34e72052-5e60-47cb-be88-19a9715b7026"
        );
    }
}