chroma_types/
collection.rs

1use std::str::FromStr;
2
3use super::{Metadata, MetadataValueConversionError};
4use crate::{
5    chroma_proto, test_segment, CollectionConfiguration, InternalCollectionConfiguration, Schema,
6    SchemaError, Segment, SegmentScope, UpdateCollectionConfiguration, UpdateMetadata,
7};
8use chroma_error::{ChromaError, ErrorCodes};
9use serde::{Deserialize, Serialize};
10use std::time::{Duration, SystemTime};
11use thiserror::Error;
12use uuid::Uuid;
13
14#[cfg(feature = "pyo3")]
15use pyo3::types::PyAnyMethods;
16
17/// CollectionUuid is a wrapper around Uuid to provide a type for the collection id.
18#[derive(
19    Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
20)]
21#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
22pub struct CollectionUuid(pub Uuid);
23
24/// DatabaseUuid is a wrapper around Uuid to provide a type for the database id.
25#[derive(
26    Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
27)]
28#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
29pub struct DatabaseUuid(pub Uuid);
30
31impl DatabaseUuid {
32    pub fn new() -> Self {
33        DatabaseUuid(Uuid::new_v4())
34    }
35}
36
37impl CollectionUuid {
38    pub fn new() -> Self {
39        CollectionUuid(Uuid::new_v4())
40    }
41
42    pub fn storage_prefix_for_log(&self) -> String {
43        format!("logs/{}", self)
44    }
45}
46
47impl std::str::FromStr for CollectionUuid {
48    type Err = uuid::Error;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match Uuid::parse_str(s) {
52            Ok(uuid) => Ok(CollectionUuid(uuid)),
53            Err(err) => Err(err),
54        }
55    }
56}
57
58impl std::str::FromStr for DatabaseUuid {
59    type Err = uuid::Error;
60
61    fn from_str(s: &str) -> Result<Self, Self::Err> {
62        match Uuid::parse_str(s) {
63            Ok(uuid) => Ok(DatabaseUuid(uuid)),
64            Err(err) => Err(err),
65        }
66    }
67}
68
69impl std::fmt::Display for DatabaseUuid {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "{}", self.0)
72    }
73}
74
75impl std::fmt::Display for CollectionUuid {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        write!(f, "{}", self.0)
78    }
79}
80
81fn serialize_internal_collection_configuration<S: serde::Serializer>(
82    config: &InternalCollectionConfiguration,
83    serializer: S,
84) -> Result<S::Ok, S::Error> {
85    let collection_config: CollectionConfiguration = config.clone().into();
86    collection_config.serialize(serializer)
87}
88
89fn deserialize_internal_collection_configuration<'de, D: serde::Deserializer<'de>>(
90    deserializer: D,
91) -> Result<InternalCollectionConfiguration, D::Error> {
92    let collection_config = CollectionConfiguration::deserialize(deserializer)?;
93    collection_config
94        .try_into()
95        .map_err(serde::de::Error::custom)
96}
97
98#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
99#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
100#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
101pub struct Collection {
102    #[serde(rename = "id")]
103    pub collection_id: CollectionUuid,
104    pub name: String,
105    #[serde(
106        serialize_with = "serialize_internal_collection_configuration",
107        deserialize_with = "deserialize_internal_collection_configuration",
108        rename = "configuration_json"
109    )]
110    #[cfg_attr(feature = "utoipa", schema(value_type = CollectionConfiguration))]
111    pub config: InternalCollectionConfiguration,
112    pub schema: Option<Schema>,
113    pub metadata: Option<Metadata>,
114    pub dimension: Option<i32>,
115    pub tenant: String,
116    pub database: String,
117    pub log_position: i64,
118    pub version: i32,
119    #[serde(skip)]
120    pub total_records_post_compaction: u64,
121    #[serde(skip)]
122    pub size_bytes_post_compaction: u64,
123    #[serde(skip)]
124    pub last_compaction_time_secs: u64,
125    #[serde(skip)]
126    pub version_file_path: Option<String>,
127    #[serde(skip)]
128    pub root_collection_id: Option<CollectionUuid>,
129    #[serde(skip)]
130    pub lineage_file_path: Option<String>,
131    #[serde(skip, default = "SystemTime::now")]
132    pub updated_at: SystemTime,
133    #[serde(skip)]
134    pub database_id: DatabaseUuid,
135}
136
137impl Default for Collection {
138    fn default() -> Self {
139        Self {
140            collection_id: CollectionUuid::new(),
141            name: "".to_string(),
142            config: InternalCollectionConfiguration::default_hnsw(),
143            schema: None,
144            metadata: None,
145            dimension: None,
146            tenant: "".to_string(),
147            database: "".to_string(),
148            log_position: 0,
149            version: 0,
150            total_records_post_compaction: 0,
151            size_bytes_post_compaction: 0,
152            last_compaction_time_secs: 0,
153            version_file_path: None,
154            root_collection_id: None,
155            lineage_file_path: None,
156            updated_at: SystemTime::now(),
157            database_id: DatabaseUuid::new(),
158        }
159    }
160}
161
#[cfg(feature = "pyo3")]
#[pyo3::pymethods]
impl Collection {
    /// The collection id as a Python `uuid.UUID`.
    #[getter]
    fn id<'py>(&self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let res = pyo3::prelude::PyModule::import(py, "uuid")?
            .getattr("UUID")?
            .call1((self.collection_id.to_string(),))?;
        Ok(res)
    }

    /// The collection configuration in its public form, decoded via Python's `json.loads`.
    #[getter]
    fn configuration<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let config: crate::CollectionConfiguration = self.config.clone().into();
        // Surface a serialization failure as a Python exception instead of panicking.
        let config_json_str = serde_json::to_string(&config)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        let res = pyo3::prelude::PyModule::import(py, "json")?
            .getattr("loads")?
            .call1((config_json_str,))?;
        Ok(res)
    }

    #[getter]
    pub fn name(&self) -> &str {
        &self.name
    }

    #[getter]
    pub fn metadata(&self) -> Option<Metadata> {
        self.metadata.clone()
    }

    #[getter]
    pub fn dimension(&self) -> Option<i32> {
        self.dimension
    }

    #[getter]
    pub fn tenant(&self) -> &str {
        &self.tenant
    }

    #[getter]
    pub fn database(&self) -> &str {
        &self.database
    }
}
211
212impl Collection {
213    /// Reconcile the collection schema and configuration, ensuring both are consistent.
214    pub fn reconcile_schema_with_config(&mut self) -> Result<(), SchemaError> {
215        let reconciled_schema =
216            Schema::reconcile_schema_and_config(self.schema.clone(), Some(self.config.clone()))?;
217
218        self.config = InternalCollectionConfiguration::try_from(&reconciled_schema)
219            .map_err(|reason| SchemaError::InvalidSchema { reason })?;
220        self.schema = Some(reconciled_schema);
221
222        Ok(())
223    }
224
225    pub fn test_collection(dim: i32) -> Self {
226        Collection {
227            name: "test_collection".to_string(),
228            dimension: Some(dim),
229            tenant: "default_tenant".to_string(),
230            database: "default_database".to_string(),
231            database_id: DatabaseUuid::new(),
232            ..Default::default()
233        }
234    }
235}
236
237#[derive(Error, Debug)]
238pub enum CollectionConversionError {
239    #[error("Invalid config: {0}")]
240    InvalidConfig(#[from] serde_json::Error),
241    #[error("Invalid UUID")]
242    InvalidUuid,
243    #[error(transparent)]
244    MetadataValueConversionError(#[from] MetadataValueConversionError),
245    #[error("Missing Database Id")]
246    MissingDatabaseId,
247}
248
249impl ChromaError for CollectionConversionError {
250    fn code(&self) -> ErrorCodes {
251        match self {
252            CollectionConversionError::InvalidConfig(_) => ErrorCodes::InvalidArgument,
253            CollectionConversionError::InvalidUuid => ErrorCodes::InvalidArgument,
254            CollectionConversionError::MetadataValueConversionError(e) => e.code(),
255            CollectionConversionError::MissingDatabaseId => ErrorCodes::Internal,
256        }
257    }
258}
259
260impl TryFrom<chroma_proto::Collection> for Collection {
261    type Error = CollectionConversionError;
262
263    fn try_from(proto_collection: chroma_proto::Collection) -> Result<Self, Self::Error> {
264        let collection_id = CollectionUuid::from_str(&proto_collection.id)
265            .map_err(|_| CollectionConversionError::InvalidUuid)?;
266        let collection_metadata: Option<Metadata> = match proto_collection.metadata {
267            Some(proto_metadata) => match proto_metadata.try_into() {
268                Ok(metadata) => Some(metadata),
269                Err(e) => return Err(CollectionConversionError::MetadataValueConversionError(e)),
270            },
271            None => None,
272        };
273        // TODO(@codetheweb): this be updated to error with "missing field" once all SysDb deployments are up-to-date
274        let updated_at = match proto_collection.updated_at {
275            Some(updated_at) => {
276                SystemTime::UNIX_EPOCH
277                    + Duration::new(updated_at.seconds as u64, updated_at.nanos as u32)
278            }
279            None => SystemTime::now(),
280        };
281        let database_id = match proto_collection.database_id {
282            Some(db_id) => DatabaseUuid::from_str(&db_id)
283                .map_err(|_| CollectionConversionError::InvalidUuid)?,
284            None => {
285                return Err(CollectionConversionError::MissingDatabaseId);
286            }
287        };
288        let schema = match proto_collection.schema_str {
289            Some(schema_str) if !schema_str.is_empty() => Some(serde_json::from_str(&schema_str)?),
290            _ => None,
291        };
292
293        Ok(Collection {
294            collection_id,
295            name: proto_collection.name,
296            config: serde_json::from_str(&proto_collection.configuration_json_str)?,
297            schema,
298            metadata: collection_metadata,
299            dimension: proto_collection.dimension,
300            tenant: proto_collection.tenant,
301            database: proto_collection.database,
302            log_position: proto_collection.log_position,
303            version: proto_collection.version,
304            total_records_post_compaction: proto_collection.total_records_post_compaction,
305            size_bytes_post_compaction: proto_collection.size_bytes_post_compaction,
306            last_compaction_time_secs: proto_collection.last_compaction_time_secs,
307            version_file_path: proto_collection.version_file_path,
308            root_collection_id: proto_collection
309                .root_collection_id
310                .map(|uuid| CollectionUuid(Uuid::try_parse(&uuid).unwrap())),
311            lineage_file_path: proto_collection.lineage_file_path,
312            updated_at,
313            database_id,
314        })
315    }
316}
317
318#[derive(Error, Debug)]
319pub enum CollectionToProtoError {
320    #[error("Could not serialize config: {0}")]
321    ConfigSerialization(#[from] serde_json::Error),
322}
323
324impl ChromaError for CollectionToProtoError {
325    fn code(&self) -> ErrorCodes {
326        match self {
327            CollectionToProtoError::ConfigSerialization(_) => ErrorCodes::Internal,
328        }
329    }
330}
331
332impl TryFrom<Collection> for chroma_proto::Collection {
333    type Error = CollectionToProtoError;
334
335    fn try_from(value: Collection) -> Result<Self, Self::Error> {
336        Ok(Self {
337            id: value.collection_id.0.to_string(),
338            name: value.name,
339            configuration_json_str: serde_json::to_string(&value.config)?,
340            schema_str: value
341                .schema
342                .map(|s| serde_json::to_string(&s))
343                .transpose()?,
344            metadata: value.metadata.map(Into::into),
345            dimension: value.dimension,
346            tenant: value.tenant,
347            database: value.database,
348            log_position: value.log_position,
349            version: value.version,
350            total_records_post_compaction: value.total_records_post_compaction,
351            size_bytes_post_compaction: value.size_bytes_post_compaction,
352            last_compaction_time_secs: value.last_compaction_time_secs,
353            version_file_path: value.version_file_path,
354            root_collection_id: value.root_collection_id.map(|uuid| uuid.0.to_string()),
355            lineage_file_path: value.lineage_file_path,
356            updated_at: Some(value.updated_at.into()),
357            database_id: Some(value.database_id.0.to_string()),
358        })
359    }
360}
361
362#[derive(Clone, Debug)]
363pub struct CollectionAndSegments {
364    pub collection: Collection,
365    pub metadata_segment: Segment,
366    pub record_segment: Segment,
367    pub vector_segment: Segment,
368}
369
370impl CollectionAndSegments {
371    // If dimension is not set and vector segment has no files,
372    // we assume this is an uninitialized collection
373    pub fn is_uninitialized(&self) -> bool {
374        self.collection.dimension.is_none() && self.vector_segment.file_path.is_empty()
375    }
376
377    pub fn test(dim: i32) -> Self {
378        let collection = Collection::test_collection(dim);
379        let collection_uuid = collection.collection_id;
380        Self {
381            collection,
382            metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA),
383            record_segment: test_segment(collection_uuid, SegmentScope::RECORD),
384            vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR),
385        }
386    }
387}
388
389#[derive(Deserialize, Serialize, Debug, Clone)]
390#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
391pub struct CreateCollectionPayload {
392    pub name: String,
393    pub schema: Option<Schema>,
394    pub configuration: Option<CollectionConfiguration>,
395    pub metadata: Option<Metadata>,
396    #[serde(default)]
397    pub get_or_create: bool,
398}
399
400#[derive(Deserialize, Serialize, Debug, Clone)]
401#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
402pub struct UpdateCollectionPayload {
403    pub new_name: Option<String>,
404    pub new_metadata: Option<UpdateMetadata>,
405    pub new_configuration: Option<UpdateCollectionConfiguration>,
406}
407
#[cfg(test)]
mod test {
    use super::*;

    // The nil UUID in hyphenated form; used for every id field of the proto fixture.
    const NIL_UUID: &str = "00000000-0000-0000-0000-000000000000";

    #[test]
    fn test_collection_try_from() {
        // The proto carries the schema as JSON, so serialize a real default schema.
        let schema_str = serde_json::to_string(&Schema::new_default(crate::KnnIndex::Spann))
            .expect("default schema should serialize");

        let proto_collection = chroma_proto::Collection {
            id: NIL_UUID.to_string(),
            name: "foo".to_string(),
            configuration_json_str: r#"{"a": "param", "b": "param2", "3": true}"#.to_string(),
            schema_str: Some(schema_str),
            metadata: None,
            dimension: None,
            tenant: "baz".to_string(),
            database: "qux".to_string(),
            log_position: 0,
            version: 0,
            total_records_post_compaction: 0,
            size_bytes_post_compaction: 0,
            last_compaction_time_secs: 0,
            version_file_path: Some("version_file_path".to_string()),
            root_collection_id: Some(NIL_UUID.to_string()),
            lineage_file_path: Some("lineage_file_path".to_string()),
            updated_at: Some(prost_types::Timestamp {
                seconds: 1,
                nanos: 1,
            }),
            database_id: Some(NIL_UUID.to_string()),
        };

        let converted: Collection = proto_collection
            .try_into()
            .expect("a well-formed proto collection should convert");

        let nil_collection_id = CollectionUuid(Uuid::nil());
        assert_eq!(converted.collection_id, nil_collection_id);
        assert_eq!(converted.name, "foo");
        assert_eq!(converted.metadata, None);
        assert_eq!(converted.dimension, None);
        assert_eq!(converted.tenant, "baz");
        assert_eq!(converted.database, "qux");
        assert_eq!(converted.total_records_post_compaction, 0);
        assert_eq!(converted.size_bytes_post_compaction, 0);
        assert_eq!(converted.last_compaction_time_secs, 0);
        assert_eq!(
            converted.version_file_path.as_deref(),
            Some("version_file_path")
        );
        assert_eq!(converted.root_collection_id, Some(nil_collection_id));
        assert_eq!(
            converted.lineage_file_path.as_deref(),
            Some("lineage_file_path")
        );
        assert_eq!(
            converted.updated_at,
            SystemTime::UNIX_EPOCH + Duration::new(1, 1)
        );
        assert_eq!(converted.database_id, DatabaseUuid(Uuid::nil()));
    }

    #[test]
    fn storage_prefix_for_log_format() {
        let uuid = Uuid::parse_str("34e72052-5e60-47cb-be88-19a9715b7026").unwrap();
        let prefix = CollectionUuid(uuid).storage_prefix_for_log();
        assert_eq!("logs/34e72052-5e60-47cb-be88-19a9715b7026", prefix);
    }
}