1use std::str::FromStr;
2use std::time::{Duration, SystemTime};
3
4use super::{Metadata, MetadataValueConversionError};
5use crate::{
6 chroma_proto, test_segment, CollectionConfiguration, InternalCollectionConfiguration, Schema,
7 SchemaError, Segment, SegmentScope, UpdateCollectionConfiguration, UpdateMetadata,
8};
9use chroma_error::{ChromaError, ErrorCodes};
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12use uuid::Uuid;
13
14#[cfg(feature = "pyo3")]
15use pyo3::{exceptions::PyValueError, types::PyAnyMethods};
16
/// Identifier of a collection: a newtype over [`Uuid`].
///
/// The derived `Default` wraps `Uuid::default()` (the nil UUID). Use
/// [`CollectionUuid::new`] to obtain a fresh random id instead.
#[derive(
    Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CollectionUuid(pub Uuid);
23
/// Identifier of a database: a newtype over [`Uuid`].
///
/// The derived `Default` wraps `Uuid::default()` (the nil UUID). Use
/// [`DatabaseUuid::new`] to obtain a fresh random id instead.
#[derive(
    Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct DatabaseUuid(pub Uuid);
30
31impl DatabaseUuid {
32 pub fn new() -> Self {
33 DatabaseUuid(Uuid::new_v4())
34 }
35}
36
37impl CollectionUuid {
38 pub fn new() -> Self {
39 CollectionUuid(Uuid::new_v4())
40 }
41
42 pub fn storage_prefix_for_log(&self) -> String {
43 format!("logs/{}", self)
44 }
45}
46
47impl std::str::FromStr for CollectionUuid {
48 type Err = uuid::Error;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match Uuid::parse_str(s) {
52 Ok(uuid) => Ok(CollectionUuid(uuid)),
53 Err(err) => Err(err),
54 }
55 }
56}
57
58impl std::str::FromStr for DatabaseUuid {
59 type Err = uuid::Error;
60
61 fn from_str(s: &str) -> Result<Self, Self::Err> {
62 match Uuid::parse_str(s) {
63 Ok(uuid) => Ok(DatabaseUuid(uuid)),
64 Err(err) => Err(err),
65 }
66 }
67}
68
69impl std::fmt::Display for DatabaseUuid {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(f, "{}", self.0)
72 }
73}
74
75impl std::fmt::Display for CollectionUuid {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 write!(f, "{}", self.0)
78 }
79}
80
81fn serialize_internal_collection_configuration<S: serde::Serializer>(
82 config: &InternalCollectionConfiguration,
83 serializer: S,
84) -> Result<S::Ok, S::Error> {
85 let collection_config: CollectionConfiguration = config.clone().into();
86 collection_config.serialize(serializer)
87}
88
89fn deserialize_internal_collection_configuration<'de, D: serde::Deserializer<'de>>(
90 deserializer: D,
91) -> Result<InternalCollectionConfiguration, D::Error> {
92 let collection_config = CollectionConfiguration::deserialize(deserializer)?;
93 collection_config
94 .try_into()
95 .map_err(serde::de::Error::custom)
96}
97
/// Domain representation of a collection, convertible to and from
/// [`chroma_proto::Collection`].
///
/// Serde serialization is limited to the client-visible fields:
/// - `collection_id` is written as `id`.
/// - `config` is written as a [`CollectionConfiguration`] under `configuration_json`.
///
/// Fields marked `#[serde(skip)]` are never serialized. On deserialization they take
/// their type's `Default` value, except `updated_at`, which becomes `SystemTime::now()`.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
pub struct Collection {
    /// Collection id (serialized as `id`).
    #[serde(rename = "id")]
    pub collection_id: CollectionUuid,
    /// Collection name.
    pub name: String,
    /// Internal configuration, (de)serialized through its public
    /// `CollectionConfiguration` form under the key `configuration_json`.
    #[serde(
        serialize_with = "serialize_internal_collection_configuration",
        deserialize_with = "deserialize_internal_collection_configuration",
        rename = "configuration_json"
    )]
    #[cfg_attr(feature = "utoipa", schema(value_type = CollectionConfiguration))]
    pub config: InternalCollectionConfiguration,
    /// Optional schema. `reconcile_schema_for_read` keeps it consistent with `config`.
    pub schema: Option<Schema>,
    /// Optional collection metadata.
    pub metadata: Option<Metadata>,
    /// Vector dimension, if one has been set.
    pub dimension: Option<i32>,
    /// Name of the tenant that owns the collection.
    pub tenant: String,
    /// Name of the database that owns the collection.
    pub database: String,
    /// Log position carried over from the proto field of the same name.
    pub log_position: i64,
    /// Collection version carried over from the proto field of the same name.
    pub version: i32,
    /// Record count as of the last compaction (internal; not serialized).
    #[serde(skip)]
    pub total_records_post_compaction: u64,
    /// Size in bytes as of the last compaction (internal; not serialized).
    #[serde(skip)]
    pub size_bytes_post_compaction: u64,
    /// Time of the last compaction, in seconds (internal; not serialized).
    /// NOTE(review): presumably seconds since the Unix epoch — confirm.
    #[serde(skip)]
    pub last_compaction_time_secs: u64,
    /// Path of the version file, if any (internal; not serialized).
    #[serde(skip)]
    pub version_file_path: Option<String>,
    /// Id of the root collection, if any (internal; not serialized).
    #[serde(skip)]
    pub root_collection_id: Option<CollectionUuid>,
    /// Path of the lineage file, if any (internal; not serialized).
    #[serde(skip)]
    pub lineage_file_path: Option<String>,
    /// Last update time. Not serialized, so it is set to `SystemTime::now()` on deserialize.
    #[serde(skip, default = "SystemTime::now")]
    pub updated_at: SystemTime,
    /// Id of the owning database (internal; not serialized).
    #[serde(skip)]
    pub database_id: DatabaseUuid,
    /// Number of recorded compaction failures (internal; not serialized).
    #[serde(skip)]
    pub compaction_failure_count: i32,
}
140
141impl Default for Collection {
142 fn default() -> Self {
143 Self {
144 collection_id: CollectionUuid::new(),
145 name: "".to_string(),
146 config: InternalCollectionConfiguration::default_hnsw(),
147 schema: None,
148 metadata: None,
149 dimension: None,
150 tenant: "".to_string(),
151 database: "".to_string(),
152 log_position: 0,
153 version: 0,
154 total_records_post_compaction: 0,
155 size_bytes_post_compaction: 0,
156 last_compaction_time_secs: 0,
157 version_file_path: None,
158 root_collection_id: None,
159 lineage_file_path: None,
160 updated_at: SystemTime::now(),
161 database_id: DatabaseUuid::new(),
162 compaction_failure_count: 0,
163 }
164 }
165}
166
#[cfg(feature = "pyo3")]
#[pyo3::pymethods]
impl Collection {
    /// Returns the collection id as a Python `uuid.UUID` built from its string form.
    #[getter]
    fn id<'py>(&self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let res = pyo3::prelude::PyModule::import(py, "uuid")?
            .getattr("UUID")?
            .call1((self.collection_id.to_string(),))?;
        Ok(res)
    }

    /// Returns the public collection configuration as the Python object produced by
    /// `json.loads` on its JSON serialization.
    ///
    /// Raises `ValueError` if the configuration cannot be serialized to JSON. Any Python
    /// error raised while importing `json` or calling `json.loads` is propagated.
    #[getter]
    fn configuration<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let config: crate::CollectionConfiguration = self.config.clone().into();
        // Report serialization failures as a Python exception instead of panicking across
        // the FFI boundary, matching how the `schema` getter handles the same failure.
        let config_json_str = serde_json::to_string(&config)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        let res = pyo3::prelude::PyModule::import(py, "json")?
            .getattr("loads")?
            .call1((config_json_str,))?;
        Ok(res)
    }

    /// Returns the schema as the Python object produced by `json.loads`, or `None` when
    /// the collection has no schema.
    ///
    /// Raises `ValueError` if the schema cannot be serialized to JSON.
    #[getter]
    fn schema<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<Option<pyo3::Bound<'py, pyo3::PyAny>>> {
        match self.schema.as_ref() {
            Some(schema) => {
                let schema_json = serde_json::to_string(schema)
                    .map_err(|err| PyValueError::new_err(err.to_string()))?;
                let res = pyo3::prelude::PyModule::import(py, "json")?
                    .getattr("loads")?
                    .call1((schema_json,))?;
                Ok(Some(res))
            }
            None => Ok(None),
        }
    }

    /// Collection name.
    #[getter]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// A clone of the collection metadata, if any.
    #[getter]
    pub fn metadata(&self) -> Option<Metadata> {
        self.metadata.clone()
    }

    /// Vector dimension, if set.
    #[getter]
    pub fn dimension(&self) -> Option<i32> {
        self.dimension
    }

    /// Name of the owning tenant.
    #[getter]
    pub fn tenant(&self) -> &str {
        &self.tenant
    }

    /// Name of the owning database.
    #[getter]
    pub fn database(&self) -> &str {
        &self.database
    }
}
234
235impl Collection {
236 pub fn reconcile_schema_for_read(&mut self) -> Result<(), SchemaError> {
242 if let Some(schema) = self.schema.as_ref() {
243 self.config = InternalCollectionConfiguration::try_from(schema)
244 .map_err(|reason| SchemaError::InvalidSchema { reason })?;
245 } else {
246 self.schema = Some(Schema::try_from(&self.config)?);
247 }
248
249 Ok(())
250 }
251
252 pub fn test_collection(dim: i32) -> Self {
253 Collection {
254 name: "test_collection".to_string(),
255 dimension: Some(dim),
256 tenant: "default_tenant".to_string(),
257 database: "default_database".to_string(),
258 database_id: DatabaseUuid::new(),
259 ..Default::default()
260 }
261 }
262}
263
/// Errors produced when converting a [`chroma_proto::Collection`] into a [`Collection`].
#[derive(Error, Debug)]
pub enum CollectionConversionError {
    /// The configuration or schema JSON failed to deserialize.
    #[error("Invalid config: {0}")]
    InvalidConfig(#[from] serde_json::Error),
    /// An id field in the proto message was not a valid UUID string.
    #[error("Invalid UUID")]
    InvalidUuid,
    /// The proto metadata could not be converted into [`Metadata`].
    #[error(transparent)]
    MetadataValueConversionError(#[from] MetadataValueConversionError),
    /// The proto message did not carry a `database_id`.
    #[error("Missing Database Id")]
    MissingDatabaseId,
}
275
276impl ChromaError for CollectionConversionError {
277 fn code(&self) -> ErrorCodes {
278 match self {
279 CollectionConversionError::InvalidConfig(_) => ErrorCodes::InvalidArgument,
280 CollectionConversionError::InvalidUuid => ErrorCodes::InvalidArgument,
281 CollectionConversionError::MetadataValueConversionError(e) => e.code(),
282 CollectionConversionError::MissingDatabaseId => ErrorCodes::Internal,
283 }
284 }
285}
286
287impl TryFrom<chroma_proto::Collection> for Collection {
288 type Error = CollectionConversionError;
289
290 fn try_from(proto_collection: chroma_proto::Collection) -> Result<Self, Self::Error> {
291 let collection_id = CollectionUuid::from_str(&proto_collection.id)
292 .map_err(|_| CollectionConversionError::InvalidUuid)?;
293 let collection_metadata: Option<Metadata> = match proto_collection.metadata {
294 Some(proto_metadata) => match proto_metadata.try_into() {
295 Ok(metadata) => Some(metadata),
296 Err(e) => return Err(CollectionConversionError::MetadataValueConversionError(e)),
297 },
298 None => None,
299 };
300 let updated_at = match proto_collection.updated_at {
302 Some(updated_at) => {
303 SystemTime::UNIX_EPOCH
304 + Duration::new(updated_at.seconds as u64, updated_at.nanos as u32)
305 }
306 None => SystemTime::now(),
307 };
308 let database_id = match proto_collection.database_id {
309 Some(db_id) => DatabaseUuid::from_str(&db_id)
310 .map_err(|_| CollectionConversionError::InvalidUuid)?,
311 None => {
312 return Err(CollectionConversionError::MissingDatabaseId);
313 }
314 };
315 let schema = match proto_collection.schema_str {
316 Some(schema_str) if !schema_str.is_empty() => Some(serde_json::from_str(&schema_str)?),
317 _ => None,
318 };
319
320 Ok(Collection {
321 collection_id,
322 name: proto_collection.name,
323 config: serde_json::from_str(&proto_collection.configuration_json_str)?,
324 schema,
325 metadata: collection_metadata,
326 dimension: proto_collection.dimension,
327 tenant: proto_collection.tenant,
328 database: proto_collection.database,
329 log_position: proto_collection.log_position,
330 version: proto_collection.version,
331 total_records_post_compaction: proto_collection.total_records_post_compaction,
332 size_bytes_post_compaction: proto_collection.size_bytes_post_compaction,
333 last_compaction_time_secs: proto_collection.last_compaction_time_secs,
334 version_file_path: proto_collection.version_file_path,
335 root_collection_id: proto_collection
336 .root_collection_id
337 .map(|uuid| CollectionUuid(Uuid::try_parse(&uuid).unwrap())),
338 lineage_file_path: proto_collection.lineage_file_path,
339 updated_at,
340 database_id,
341 compaction_failure_count: proto_collection.compaction_failure_count,
342 })
343 }
344}
345
/// Errors produced when converting a [`Collection`] into a [`chroma_proto::Collection`].
#[derive(Error, Debug)]
pub enum CollectionToProtoError {
    /// The configuration or schema could not be serialized to JSON.
    #[error("Could not serialize config: {0}")]
    ConfigSerialization(#[from] serde_json::Error),
}
351
352impl ChromaError for CollectionToProtoError {
353 fn code(&self) -> ErrorCodes {
354 match self {
355 CollectionToProtoError::ConfigSerialization(_) => ErrorCodes::Internal,
356 }
357 }
358}
359
360impl TryFrom<Collection> for chroma_proto::Collection {
361 type Error = CollectionToProtoError;
362
363 fn try_from(value: Collection) -> Result<Self, Self::Error> {
364 Ok(Self {
365 id: value.collection_id.0.to_string(),
366 name: value.name,
367 configuration_json_str: serde_json::to_string(&value.config)?,
368 schema_str: value
369 .schema
370 .map(|s| serde_json::to_string(&s))
371 .transpose()?,
372 metadata: value.metadata.map(Into::into),
373 dimension: value.dimension,
374 tenant: value.tenant,
375 database: value.database,
376 log_position: value.log_position,
377 version: value.version,
378 total_records_post_compaction: value.total_records_post_compaction,
379 size_bytes_post_compaction: value.size_bytes_post_compaction,
380 last_compaction_time_secs: value.last_compaction_time_secs,
381 version_file_path: value.version_file_path,
382 root_collection_id: value.root_collection_id.map(|uuid| uuid.0.to_string()),
383 lineage_file_path: value.lineage_file_path,
384 updated_at: Some(value.updated_at.into()),
385 database_id: Some(value.database_id.0.to_string()),
386 compaction_failure_count: value.compaction_failure_count,
387 })
388 }
389}
390
/// A collection bundled with its metadata, record, and vector segments.
#[derive(Clone, Debug)]
pub struct CollectionAndSegments {
    /// The collection the segments belong to.
    pub collection: Collection,
    /// Segment with `SegmentScope::METADATA` in test fixtures built by `test`.
    pub metadata_segment: Segment,
    /// Segment with `SegmentScope::RECORD` in test fixtures built by `test`.
    pub record_segment: Segment,
    /// Segment with `SegmentScope::VECTOR` in test fixtures built by `test`.
    pub vector_segment: Segment,
}
398
399impl CollectionAndSegments {
400 pub fn is_uninitialized(&self) -> bool {
403 self.collection.dimension.is_none() && self.vector_segment.file_path.is_empty()
404 }
405
406 pub fn test(dim: i32) -> Self {
407 let collection = Collection::test_collection(dim);
408 let collection_uuid = collection.collection_id;
409 Self {
410 collection,
411 metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA),
412 record_segment: test_segment(collection_uuid, SegmentScope::RECORD),
413 vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR),
414 }
415 }
416}
417
/// Payload of a create-collection request.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct CreateCollectionPayload {
    /// Name of the collection to create.
    pub name: String,
    /// Optional schema for the new collection.
    pub schema: Option<Schema>,
    /// Optional configuration for the new collection.
    pub configuration: Option<CollectionConfiguration>,
    /// Optional metadata for the new collection.
    pub metadata: Option<Metadata>,
    /// Defaults to `false` when absent from the payload.
    /// NOTE(review): presumably returns an existing collection instead of failing when
    /// the name is taken — confirm against the handler.
    #[serde(default)]
    pub get_or_create: bool,
}
428
/// Payload of an update-collection request.
///
/// NOTE(review): presumably each `None` field leaves the corresponding property
/// unchanged — confirm against the handler.
#[derive(Deserialize, Serialize, Debug, Clone)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct UpdateCollectionPayload {
    /// New collection name, if renaming.
    pub new_name: Option<String>,
    /// Metadata update to apply, if any.
    pub new_metadata: Option<UpdateMetadata>,
    /// Configuration update to apply, if any.
    pub new_configuration: Option<UpdateCollectionConfiguration>,
}
436
#[cfg(test)]
mod test {
    use super::*;

    /// Converts a populated proto collection into a `Collection` and checks the
    /// converted ids, names, counters, paths, and timestamp.
    #[test]
    fn test_collection_try_from() {
        // A real default schema, serialized so it goes through the `schema_str` path.
        let schema = Schema::new_default(crate::KnnIndex::Spann);
        let schema_str = serde_json::to_string(&schema).unwrap();

        let proto_collection = chroma_proto::Collection {
            id: "00000000-0000-0000-0000-000000000000".to_string(),
            name: "foo".to_string(),
            configuration_json_str: "{\"a\": \"param\", \"b\": \"param2\", \"3\": true}"
                .to_string(),
            schema_str: Some(schema_str),
            metadata: None,
            dimension: None,
            tenant: "baz".to_string(),
            database: "qux".to_string(),
            log_position: 0,
            version: 0,
            total_records_post_compaction: 0,
            size_bytes_post_compaction: 0,
            last_compaction_time_secs: 0,
            version_file_path: Some("version_file_path".to_string()),
            root_collection_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
            lineage_file_path: Some("lineage_file_path".to_string()),
            updated_at: Some(prost_types::Timestamp {
                seconds: 1,
                nanos: 1,
            }),
            database_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
            compaction_failure_count: 0,
        };
        let converted_collection: Collection = proto_collection.try_into().unwrap();
        assert_eq!(
            converted_collection.collection_id,
            CollectionUuid(Uuid::nil())
        );
        assert_eq!(converted_collection.name, "foo".to_string());
        assert_eq!(converted_collection.metadata, None);
        assert_eq!(converted_collection.dimension, None);
        assert_eq!(converted_collection.tenant, "baz".to_string());
        assert_eq!(converted_collection.database, "qux".to_string());
        assert_eq!(converted_collection.total_records_post_compaction, 0);
        assert_eq!(converted_collection.size_bytes_post_compaction, 0);
        assert_eq!(converted_collection.last_compaction_time_secs, 0);
        assert_eq!(
            converted_collection.version_file_path,
            Some("version_file_path".to_string())
        );
        assert_eq!(
            converted_collection.root_collection_id,
            Some(CollectionUuid(Uuid::nil()))
        );
        assert_eq!(
            converted_collection.lineage_file_path,
            Some("lineage_file_path".to_string())
        );
        // The proto timestamp (1 s, 1 ns) should map to 1 s + 1 ns after the epoch.
        assert_eq!(
            converted_collection.updated_at,
            SystemTime::UNIX_EPOCH + Duration::new(1, 1)
        );
        assert_eq!(converted_collection.database_id, DatabaseUuid(Uuid::nil()));
    }

    /// The log prefix is `logs/` followed by the hyphenated lowercase collection UUID.
    #[test]
    fn storage_prefix_for_log_format() {
        let collection_id = Uuid::parse_str("34e72052-5e60-47cb-be88-19a9715b7026")
            .map(CollectionUuid)
            .unwrap();
        let prefix = collection_id.storage_prefix_for_log();
        assert_eq!("logs/34e72052-5e60-47cb-be88-19a9715b7026", prefix);
    }
}