// chroma_types/collection.rs

1use std::str::FromStr;
2use std::time::{Duration, SystemTime};
3
4use super::{Metadata, MetadataValueConversionError};
5use crate::{
6    chroma_proto, test_segment, CollectionConfiguration, InternalCollectionConfiguration, Schema,
7    SchemaError, Segment, SegmentScope, UpdateCollectionConfiguration, UpdateMetadata,
8};
9use chroma_error::{ChromaError, ErrorCodes};
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12use uuid::Uuid;
13
14#[cfg(feature = "pyo3")]
15use pyo3::{exceptions::PyValueError, types::PyAnyMethods};
16
17// CollectionUuid is a wrapper around Uuid to provide a type for the collection id.
18#[derive(
19    Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
20)]
21#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
22pub struct CollectionUuid(pub Uuid);
23
24/// DatabaseUuid is a wrapper around Uuid to provide a type for the database id.
25#[derive(
26    Copy, Clone, Debug, Default, Deserialize, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize,
27)]
28#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
29pub struct DatabaseUuid(pub Uuid);
30
31impl DatabaseUuid {
32    pub fn new() -> Self {
33        DatabaseUuid(Uuid::new_v4())
34    }
35}
36
37impl CollectionUuid {
38    pub fn new() -> Self {
39        CollectionUuid(Uuid::new_v4())
40    }
41
42    pub fn storage_prefix_for_log(&self) -> String {
43        format!("logs/{}", self)
44    }
45}
46
47impl std::str::FromStr for CollectionUuid {
48    type Err = uuid::Error;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match Uuid::parse_str(s) {
52            Ok(uuid) => Ok(CollectionUuid(uuid)),
53            Err(err) => Err(err),
54        }
55    }
56}
57
58impl std::str::FromStr for DatabaseUuid {
59    type Err = uuid::Error;
60
61    fn from_str(s: &str) -> Result<Self, Self::Err> {
62        match Uuid::parse_str(s) {
63            Ok(uuid) => Ok(DatabaseUuid(uuid)),
64            Err(err) => Err(err),
65        }
66    }
67}
68
69impl std::fmt::Display for DatabaseUuid {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(f, "{}", self.0)
72    }
73}
74
75impl std::fmt::Display for CollectionUuid {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        write!(f, "{}", self.0)
78    }
79}
80
81fn serialize_internal_collection_configuration<S: serde::Serializer>(
82    config: &InternalCollectionConfiguration,
83    serializer: S,
84) -> Result<S::Ok, S::Error> {
85    let collection_config: CollectionConfiguration = config.clone().into();
86    collection_config.serialize(serializer)
87}
88
89fn deserialize_internal_collection_configuration<'de, D: serde::Deserializer<'de>>(
90    deserializer: D,
91) -> Result<InternalCollectionConfiguration, D::Error> {
92    let collection_config = CollectionConfiguration::deserialize(deserializer)?;
93    collection_config
94        .try_into()
95        .map_err(serde::de::Error::custom)
96}
97
98#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
99#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
100#[cfg_attr(feature = "pyo3", pyo3::pyclass)]
101pub struct Collection {
102    #[serde(rename = "id")]
103    pub collection_id: CollectionUuid,
104    pub name: String,
105    #[serde(
106        serialize_with = "serialize_internal_collection_configuration",
107        deserialize_with = "deserialize_internal_collection_configuration",
108        rename = "configuration_json"
109    )]
110    #[cfg_attr(feature = "utoipa", schema(value_type = CollectionConfiguration))]
111    pub config: InternalCollectionConfiguration,
112    pub schema: Option<Schema>,
113    pub metadata: Option<Metadata>,
114    pub dimension: Option<i32>,
115    pub tenant: String,
116    pub database: String,
117    pub log_position: i64,
118    pub version: i32,
119    #[serde(skip)]
120    pub total_records_post_compaction: u64,
121    #[serde(skip)]
122    pub size_bytes_post_compaction: u64,
123    #[serde(skip)]
124    pub last_compaction_time_secs: u64,
125    #[serde(skip)]
126    pub version_file_path: Option<String>,
127    #[serde(skip)]
128    pub root_collection_id: Option<CollectionUuid>,
129    #[serde(skip)]
130    pub lineage_file_path: Option<String>,
131    #[serde(skip, default = "SystemTime::now")]
132    pub updated_at: SystemTime,
133    #[serde(skip)]
134    pub database_id: DatabaseUuid,
135    /// Number of consecutive compaction failures for this collection.
136    /// Used by the scheduler to track and skip collections that repeatedly fail compaction.
137    #[serde(skip)]
138    pub compaction_failure_count: i32,
139}
140
141impl Default for Collection {
142    fn default() -> Self {
143        Self {
144            collection_id: CollectionUuid::new(),
145            name: "".to_string(),
146            config: InternalCollectionConfiguration::default_hnsw(),
147            schema: None,
148            metadata: None,
149            dimension: None,
150            tenant: "".to_string(),
151            database: "".to_string(),
152            log_position: 0,
153            version: 0,
154            total_records_post_compaction: 0,
155            size_bytes_post_compaction: 0,
156            last_compaction_time_secs: 0,
157            version_file_path: None,
158            root_collection_id: None,
159            lineage_file_path: None,
160            updated_at: SystemTime::now(),
161            database_id: DatabaseUuid::new(),
162            compaction_failure_count: 0,
163        }
164    }
165}
166
/// Parses `json` into a Python object using the Python standard-library `json.loads`.
#[cfg(feature = "pyo3")]
fn py_json_loads<'py>(
    py: pyo3::Python<'py>,
    json: String,
) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
    pyo3::prelude::PyModule::import(py, "json")?
        .getattr("loads")?
        .call1((json,))
}

#[cfg(feature = "pyo3")]
#[pyo3::pymethods]
impl Collection {
    /// The collection id as a Python `uuid.UUID`.
    #[getter]
    fn id<'py>(&self, py: pyo3::Python<'py>) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let res = pyo3::prelude::PyModule::import(py, "uuid")?
            .getattr("UUID")?
            .call1((self.collection_id.to_string(),))?;
        Ok(res)
    }

    /// The configuration, converted to the public `CollectionConfiguration` and returned
    /// as the object produced by `json.loads`.
    ///
    /// Raises `ValueError` if the configuration cannot be serialized to JSON.
    #[getter]
    fn configuration<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<pyo3::Bound<'py, pyo3::PyAny>> {
        let config: crate::CollectionConfiguration = self.config.clone().into();
        // Report serialization failures as ValueError, matching `schema`, instead of
        // panicking (which pyo3 would surface as a PanicException).
        let config_json_str = serde_json::to_string(&config)
            .map_err(|err| PyValueError::new_err(err.to_string()))?;
        py_json_loads(py, config_json_str)
    }

    /// The schema as the object produced by `json.loads`, or `None` if no schema is set.
    ///
    /// Raises `ValueError` if the schema cannot be serialized to JSON.
    #[getter]
    fn schema<'py>(
        &self,
        py: pyo3::Python<'py>,
    ) -> pyo3::PyResult<Option<pyo3::Bound<'py, pyo3::PyAny>>> {
        match self.schema.as_ref() {
            Some(schema) => {
                let schema_json = serde_json::to_string(schema)
                    .map_err(|err| PyValueError::new_err(err.to_string()))?;
                py_json_loads(py, schema_json).map(Some)
            }
            None => Ok(None),
        }
    }

    /// The collection name.
    #[getter]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// A copy of the collection metadata, if any.
    #[getter]
    pub fn metadata(&self) -> Option<Metadata> {
        self.metadata.clone()
    }

    /// The collection dimension, if set.
    #[getter]
    pub fn dimension(&self) -> Option<i32> {
        self.dimension
    }

    /// The tenant name.
    #[getter]
    pub fn tenant(&self) -> &str {
        &self.tenant
    }

    /// The database name.
    #[getter]
    pub fn database(&self) -> &str {
        &self.database
    }
}
234
235impl Collection {
236    /// Reconcile the collection schema and configuration when serving read requests.
237    ///
238    /// The read path needs to tolerate collections that only have a configuration persisted.
239    /// This helper hydrates `schema` from the stored configuration when needed, or regenerates
240    /// the configuration from the existing schema to keep both representations consistent.
241    pub fn reconcile_schema_for_read(&mut self) -> Result<(), SchemaError> {
242        if let Some(schema) = self.schema.as_ref() {
243            self.config = InternalCollectionConfiguration::try_from(schema)
244                .map_err(|reason| SchemaError::InvalidSchema { reason })?;
245        } else {
246            self.schema = Some(Schema::try_from(&self.config)?);
247        }
248
249        Ok(())
250    }
251
252    pub fn test_collection(dim: i32) -> Self {
253        Collection {
254            name: "test_collection".to_string(),
255            dimension: Some(dim),
256            tenant: "default_tenant".to_string(),
257            database: "default_database".to_string(),
258            database_id: DatabaseUuid::new(),
259            ..Default::default()
260        }
261    }
262}
263
264#[derive(Error, Debug)]
265pub enum CollectionConversionError {
266    #[error("Invalid config: {0}")]
267    InvalidConfig(#[from] serde_json::Error),
268    #[error("Invalid UUID")]
269    InvalidUuid,
270    #[error(transparent)]
271    MetadataValueConversionError(#[from] MetadataValueConversionError),
272    #[error("Missing Database Id")]
273    MissingDatabaseId,
274}
275
276impl ChromaError for CollectionConversionError {
277    fn code(&self) -> ErrorCodes {
278        match self {
279            CollectionConversionError::InvalidConfig(_) => ErrorCodes::InvalidArgument,
280            CollectionConversionError::InvalidUuid => ErrorCodes::InvalidArgument,
281            CollectionConversionError::MetadataValueConversionError(e) => e.code(),
282            CollectionConversionError::MissingDatabaseId => ErrorCodes::Internal,
283        }
284    }
285}
286
287impl TryFrom<chroma_proto::Collection> for Collection {
288    type Error = CollectionConversionError;
289
290    fn try_from(proto_collection: chroma_proto::Collection) -> Result<Self, Self::Error> {
291        let collection_id = CollectionUuid::from_str(&proto_collection.id)
292            .map_err(|_| CollectionConversionError::InvalidUuid)?;
293        let collection_metadata: Option<Metadata> = match proto_collection.metadata {
294            Some(proto_metadata) => match proto_metadata.try_into() {
295                Ok(metadata) => Some(metadata),
296                Err(e) => return Err(CollectionConversionError::MetadataValueConversionError(e)),
297            },
298            None => None,
299        };
300        // TODO(@codetheweb): this be updated to error with "missing field" once all SysDb deployments are up-to-date
301        let updated_at = match proto_collection.updated_at {
302            Some(updated_at) => {
303                SystemTime::UNIX_EPOCH
304                    + Duration::new(updated_at.seconds as u64, updated_at.nanos as u32)
305            }
306            None => SystemTime::now(),
307        };
308        let database_id = match proto_collection.database_id {
309            Some(db_id) => DatabaseUuid::from_str(&db_id)
310                .map_err(|_| CollectionConversionError::InvalidUuid)?,
311            None => {
312                return Err(CollectionConversionError::MissingDatabaseId);
313            }
314        };
315        let schema = match proto_collection.schema_str {
316            Some(schema_str) if !schema_str.is_empty() => Some(serde_json::from_str(&schema_str)?),
317            _ => None,
318        };
319
320        Ok(Collection {
321            collection_id,
322            name: proto_collection.name,
323            config: serde_json::from_str(&proto_collection.configuration_json_str)?,
324            schema,
325            metadata: collection_metadata,
326            dimension: proto_collection.dimension,
327            tenant: proto_collection.tenant,
328            database: proto_collection.database,
329            log_position: proto_collection.log_position,
330            version: proto_collection.version,
331            total_records_post_compaction: proto_collection.total_records_post_compaction,
332            size_bytes_post_compaction: proto_collection.size_bytes_post_compaction,
333            last_compaction_time_secs: proto_collection.last_compaction_time_secs,
334            version_file_path: proto_collection.version_file_path,
335            root_collection_id: proto_collection
336                .root_collection_id
337                .map(|uuid| CollectionUuid(Uuid::try_parse(&uuid).unwrap())),
338            lineage_file_path: proto_collection.lineage_file_path,
339            updated_at,
340            database_id,
341            compaction_failure_count: proto_collection.compaction_failure_count,
342        })
343    }
344}
345
346#[derive(Error, Debug)]
347pub enum CollectionToProtoError {
348    #[error("Could not serialize config: {0}")]
349    ConfigSerialization(#[from] serde_json::Error),
350}
351
352impl ChromaError for CollectionToProtoError {
353    fn code(&self) -> ErrorCodes {
354        match self {
355            CollectionToProtoError::ConfigSerialization(_) => ErrorCodes::Internal,
356        }
357    }
358}
359
360impl TryFrom<Collection> for chroma_proto::Collection {
361    type Error = CollectionToProtoError;
362
363    fn try_from(value: Collection) -> Result<Self, Self::Error> {
364        Ok(Self {
365            id: value.collection_id.0.to_string(),
366            name: value.name,
367            configuration_json_str: serde_json::to_string(&value.config)?,
368            schema_str: value
369                .schema
370                .map(|s| serde_json::to_string(&s))
371                .transpose()?,
372            metadata: value.metadata.map(Into::into),
373            dimension: value.dimension,
374            tenant: value.tenant,
375            database: value.database,
376            log_position: value.log_position,
377            version: value.version,
378            total_records_post_compaction: value.total_records_post_compaction,
379            size_bytes_post_compaction: value.size_bytes_post_compaction,
380            last_compaction_time_secs: value.last_compaction_time_secs,
381            version_file_path: value.version_file_path,
382            root_collection_id: value.root_collection_id.map(|uuid| uuid.0.to_string()),
383            lineage_file_path: value.lineage_file_path,
384            updated_at: Some(value.updated_at.into()),
385            database_id: Some(value.database_id.0.to_string()),
386            compaction_failure_count: value.compaction_failure_count,
387        })
388    }
389}
390
391#[derive(Clone, Debug)]
392pub struct CollectionAndSegments {
393    pub collection: Collection,
394    pub metadata_segment: Segment,
395    pub record_segment: Segment,
396    pub vector_segment: Segment,
397}
398
399impl CollectionAndSegments {
400    // If dimension is not set and vector segment has no files,
401    // we assume this is an uninitialized collection
402    pub fn is_uninitialized(&self) -> bool {
403        self.collection.dimension.is_none() && self.vector_segment.file_path.is_empty()
404    }
405
406    pub fn test(dim: i32) -> Self {
407        let collection = Collection::test_collection(dim);
408        let collection_uuid = collection.collection_id;
409        Self {
410            collection,
411            metadata_segment: test_segment(collection_uuid, SegmentScope::METADATA),
412            record_segment: test_segment(collection_uuid, SegmentScope::RECORD),
413            vector_segment: test_segment(collection_uuid, SegmentScope::VECTOR),
414        }
415    }
416}
417
418#[derive(Deserialize, Serialize, Debug, Clone)]
419#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
420pub struct CreateCollectionPayload {
421    pub name: String,
422    pub schema: Option<Schema>,
423    pub configuration: Option<CollectionConfiguration>,
424    pub metadata: Option<Metadata>,
425    #[serde(default)]
426    pub get_or_create: bool,
427}
428
429#[derive(Deserialize, Serialize, Debug, Clone)]
430#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
431pub struct UpdateCollectionPayload {
432    pub new_name: Option<String>,
433    pub new_metadata: Option<UpdateMetadata>,
434    pub new_configuration: Option<UpdateCollectionConfiguration>,
435}
436
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a well-formed protobuf collection whose collection, root, and database ids
    /// are all the nil UUID. Tests override individual fields to exercise error paths.
    fn valid_proto_collection() -> chroma_proto::Collection {
        // Create a valid Schema and serialize it
        let schema = Schema::new_default(crate::KnnIndex::Spann);
        let schema_str = serde_json::to_string(&schema).unwrap();

        chroma_proto::Collection {
            id: "00000000-0000-0000-0000-000000000000".to_string(),
            name: "foo".to_string(),
            configuration_json_str: "{\"a\": \"param\", \"b\": \"param2\", \"3\": true}"
                .to_string(),
            schema_str: Some(schema_str),
            metadata: None,
            dimension: None,
            tenant: "baz".to_string(),
            database: "qux".to_string(),
            log_position: 0,
            version: 0,
            total_records_post_compaction: 0,
            size_bytes_post_compaction: 0,
            last_compaction_time_secs: 0,
            version_file_path: Some("version_file_path".to_string()),
            root_collection_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
            lineage_file_path: Some("lineage_file_path".to_string()),
            updated_at: Some(prost_types::Timestamp {
                seconds: 1,
                nanos: 1,
            }),
            database_id: Some("00000000-0000-0000-0000-000000000000".to_string()),
            compaction_failure_count: 0,
        }
    }

    #[test]
    fn test_collection_try_from() {
        let converted_collection: Collection = valid_proto_collection().try_into().unwrap();
        assert_eq!(
            converted_collection.collection_id,
            CollectionUuid(Uuid::nil())
        );
        assert_eq!(converted_collection.name, "foo".to_string());
        assert_eq!(converted_collection.metadata, None);
        assert_eq!(converted_collection.dimension, None);
        assert_eq!(converted_collection.tenant, "baz".to_string());
        assert_eq!(converted_collection.database, "qux".to_string());
        assert_eq!(converted_collection.total_records_post_compaction, 0);
        assert_eq!(converted_collection.size_bytes_post_compaction, 0);
        assert_eq!(converted_collection.last_compaction_time_secs, 0);
        assert_eq!(
            converted_collection.version_file_path,
            Some("version_file_path".to_string())
        );
        assert_eq!(
            converted_collection.root_collection_id,
            Some(CollectionUuid(Uuid::nil()))
        );
        assert_eq!(
            converted_collection.lineage_file_path,
            Some("lineage_file_path".to_string())
        );
        assert_eq!(
            converted_collection.updated_at,
            SystemTime::UNIX_EPOCH + Duration::new(1, 1)
        );
        assert_eq!(converted_collection.database_id, DatabaseUuid(Uuid::nil()));
    }

    #[test]
    fn test_collection_try_from_rejects_invalid_collection_id() {
        let proto_collection = chroma_proto::Collection {
            id: "not-a-uuid".to_string(),
            ..valid_proto_collection()
        };
        assert!(matches!(
            Collection::try_from(proto_collection),
            Err(CollectionConversionError::InvalidUuid)
        ));
    }

    #[test]
    fn test_collection_try_from_requires_database_id() {
        let proto_collection = chroma_proto::Collection {
            database_id: None,
            ..valid_proto_collection()
        };
        assert!(matches!(
            Collection::try_from(proto_collection),
            Err(CollectionConversionError::MissingDatabaseId)
        ));
    }

    #[test]
    fn test_collection_try_from_rejects_invalid_database_id() {
        let proto_collection = chroma_proto::Collection {
            database_id: Some("not-a-uuid".to_string()),
            ..valid_proto_collection()
        };
        assert!(matches!(
            Collection::try_from(proto_collection),
            Err(CollectionConversionError::InvalidUuid)
        ));
    }

    #[test]
    fn storage_prefix_for_log_format() {
        let collection_id = Uuid::parse_str("34e72052-5e60-47cb-be88-19a9715b7026")
            .map(CollectionUuid)
            .unwrap();
        let prefix = collection_id.storage_prefix_for_log();
        assert_eq!("logs/34e72052-5e60-47cb-be88-19a9715b7026", prefix);
    }
}