Skip to main content

chroma_types/
validators.rs

1use crate::{
2    execution::plan::SearchPayload,
3    operator::{Aggregate, GroupBy, QueryVector, Rank},
4    CollectionMetadataUpdate, Metadata, MetadataValue, Schema, UpdateMetadata, UpdateMetadataValue,
5    DOCUMENT_KEY, EMBEDDING_KEY,
6};
7use regex::Regex;
8use std::collections::HashMap;
9use std::str::FromStr;
10use std::{net::IpAddr, sync::LazyLock};
11use validator::ValidationError;
12
13static ALNUM_RE: LazyLock<Regex> = LazyLock::new(|| {
14    Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{1, 510}[a-zA-Z0-9]$")
15        .expect("The alphanumeric regex should be valid")
16});
17
18static DP_RE: LazyLock<Regex> =
19    LazyLock::new(|| Regex::new(r"\.\.").expect("The double period regex should be valid"));
20
21pub(crate) fn validate_non_empty_collection_update_metadata(
22    update: &CollectionMetadataUpdate,
23) -> Result<(), ValidationError> {
24    match update {
25        CollectionMetadataUpdate::UpdateMetadata(metadata) => {
26            validate_non_empty_metadata(metadata)?;
27            validate_update_metadata(metadata)
28        }
29        CollectionMetadataUpdate::ResetMetadata => Ok(()),
30    }
31}
32
33pub(crate) fn validate_non_empty_metadata<V>(
34    metadata: &HashMap<String, V>,
35) -> Result<(), ValidationError> {
36    if metadata.is_empty() {
37        Err(ValidationError::new("metadata").with_message("Metadata cannot be empty".into()))
38    } else {
39        Ok(())
40    }
41}
42
43pub fn validate_name(name: impl AsRef<str>) -> Result<(), ValidationError> {
44    let name_str = name.as_ref();
45
46    // A topology is a valid name.  A database name prefixed with a topology is a valid name.  The
47    // conjunction must be separated by a single `+` and not exceed the database name limits.
48    // Thus, we recurse after validating no more plusses.
49    if let Some((topo, name)) = name_str.split_once('+') {
50        if name_str.len() > 512 {
51            return Err(ValidationError::new("name").with_message(
52                format!(
53                    "Expected a name containing 3-512 characters. Got: {}",
54                    name_str.len()
55                )
56                .into(),
57            ));
58        }
59        if name.chars().any(|c| c == '+') {
60            return Err(ValidationError::new("name").with_message(
61                "Expected a name to contain at most one topology:  Got two `+` characters.".into(),
62            ));
63        }
64        assert!(
65            !topo.chars().any(|c| c == '+'),
66            "split once should not bypass the split character"
67        );
68        validate_name(topo)?;
69        validate_name(name)?;
70        return Ok(());
71    }
72
73    if !ALNUM_RE.is_match(name_str) {
74        return Err(ValidationError::new("name").with_message(format!("Expected a name containing 3-512 characters from [a-zA-Z0-9._-], starting and ending with a character in [a-zA-Z0-9]. Got: {name_str}").into()));
75    }
76
77    if DP_RE.is_match(name_str) {
78        return Err(ValidationError::new("name").with_message(
79            format!(
80            "Expected a name that does not contains two consecutive periods (..). Got {name_str}"
81        )
82            .into(),
83        ));
84    }
85    if IpAddr::from_str(name_str).is_ok() {
86        return Err(ValidationError::new("name").with_message(
87            format!("Expected a name that is not a valid ip address. Got {name_str}").into(),
88        ));
89    }
90    Ok(())
91}
92
93/// Validate a single metadata key
94fn validate_metadata_key(key: &str) -> Result<(), ValidationError> {
95    if key.is_empty() {
96        Err(ValidationError::new("metadata_key")
97            .with_message("Metadata key cannot be empty".into()))
98    } else if key.starts_with('#') || key.starts_with('$') {
99        Err(ValidationError::new("metadata_key")
100            .with_message(format!("Metadata key cannot start with '#' or '$': {}", key).into()))
101    } else {
102        Ok(())
103    }
104}
105
106/// Validate metadata
107pub fn validate_metadata(metadata: &Metadata) -> Result<(), ValidationError> {
108    for (key, value) in metadata {
109        validate_metadata_key(key)?;
110
111        if let MetadataValue::SparseVector(sv) = value {
112            sv.validate().map_err(|e| {
113                ValidationError::new("sparse_vector")
114                    .with_message(format!("Invalid sparse vector: {}", e).into())
115            })?;
116        }
117    }
118    Ok(())
119}
120
121/// Validate update metadata
122pub fn validate_update_metadata(metadata: &UpdateMetadata) -> Result<(), ValidationError> {
123    for (key, value) in metadata {
124        validate_metadata_key(key)?;
125
126        if let UpdateMetadataValue::SparseVector(sv) = value {
127            sv.validate().map_err(|e| {
128                ValidationError::new("sparse_vector")
129                    .with_message(format!("Invalid sparse vector: {}", e).into())
130            })?;
131        }
132    }
133    Ok(())
134}
135
136/// Validate optional vector of optional metadata
137pub fn validate_metadata_vec(metadatas: &[Option<Metadata>]) -> Result<(), ValidationError> {
138    for (i, metadata_opt) in metadatas.iter().enumerate() {
139        if let Some(metadata) = metadata_opt {
140            validate_metadata(metadata).map_err(|_| {
141                ValidationError::new("metadata")
142                    .with_message(format!("Invalid metadata at index {}", i).into())
143            })?;
144        }
145    }
146    Ok(())
147}
148
149/// Validate optional vector of optional update metadata
150pub fn validate_update_metadata_vec(
151    metadatas: &[Option<UpdateMetadata>],
152) -> Result<(), ValidationError> {
153    for (i, metadata_opt) in metadatas.iter().enumerate() {
154        if let Some(metadata) = metadata_opt {
155            validate_update_metadata(metadata).map_err(|_| {
156                ValidationError::new("metadata")
157                    .with_message(format!("Invalid metadata at index {}", i).into())
158            })?;
159        }
160    }
161    Ok(())
162}
163
164/// Validate optional metadata (for CreateCollectionRequest)
165pub fn validate_optional_metadata(metadata: &Metadata) -> Result<(), ValidationError> {
166    // First check it's not empty
167    validate_non_empty_metadata(metadata)?;
168    // Then validate keys and sparse vectors
169    validate_metadata(metadata)?;
170    Ok(())
171}
172
173/// Validate rank operator for sparse vectors
174pub fn validate_rank(rank: &Rank) -> Result<(), ValidationError> {
175    for knn in rank.knn_queries() {
176        if let QueryVector::Sparse(sv) = &knn.query {
177            sv.validate().map_err(|e| {
178                ValidationError::new("sparse_vector")
179                    .with_message(format!("Invalid sparse vector in KNN query: {}", e).into())
180            })?;
181        }
182    }
183    Ok(())
184}
185
186/// Validate group_by operator
187pub fn validate_group_by(group_by: &GroupBy) -> Result<(), ValidationError> {
188    let has_keys = !group_by.keys.is_empty();
189    let has_aggregate = group_by.aggregate.is_some();
190
191    if has_keys != has_aggregate {
192        return Err(ValidationError::new("group_by").with_message(
193            "group_by keys and aggregate must both be specified or both be omitted".into(),
194        ));
195    }
196
197    // Validate group_by keys: only metadata fields are allowed
198    for key in &group_by.keys {
199        match key {
200            crate::operator::Key::MetadataField(_) => {}
201            _ => {
202                return Err(ValidationError::new("group_by").with_message(
203                    "group_by keys must be metadata fields (cannot use #score, #document, #embedding, or #metadata)".into(),
204                ));
205            }
206        }
207    }
208
209    match &group_by.aggregate {
210        Some(Aggregate::MinK { keys, k }) | Some(Aggregate::MaxK { keys, k }) => {
211            if keys.is_empty() {
212                return Err(ValidationError::new("group_by")
213                    .with_message("aggregate keys must not be empty".into()));
214            }
215            if *k == 0 {
216                return Err(ValidationError::new("group_by")
217                    .with_message("aggregate k must be greater than 0".into()));
218            }
219            // Validate aggregate keys: only metadata fields and score are allowed
220            for key in keys {
221                match key {
222                    crate::operator::Key::MetadataField(_) | crate::operator::Key::Score => {}
223                    _ => {
224                        return Err(ValidationError::new("group_by").with_message(
225                            "aggregate keys must be metadata fields or #score (cannot use #document, #embedding, or #metadata)".into(),
226                        ));
227                    }
228                }
229            }
230        }
231        None => {}
232    }
233
234    Ok(())
235}
236
237/// Validate SearchPayload
238pub fn validate_search_payload(payload: &SearchPayload) -> Result<(), ValidationError> {
239    if !payload.group_by.keys.is_empty() && payload.rank.expr.is_none() {
240        return Err(ValidationError::new("group_by")
241            .with_message("group_by requires rank expression to be specified".into()));
242    }
243    Ok(())
244}
245
/// Validate a user-supplied collection schema.
///
/// Checks run in a fixed order and the first violation is returned, so the order determines which
/// error a caller sees:
/// 1. `source_attached_function_id` must be unset.
/// 2. Value-type defaults:
///    - no default `float_list` vector index may be enabled;
///    - a default vector index may not carry both `hnsw` and `spann` config;
///    - no default sparse vector index may be enabled;
///    - no default full text search index may be enabled.
/// 3. For each entry in `schema.keys`:
///    - keys starting with `#` are rejected unless they are `DOCUMENT_KEY` or `EMBEDDING_KEY`;
///    - `DOCUMENT_KEY` may not carry `boolean`, `float`, `int`, `float_list` or `sparse_vector`
///      value types;
///    - `EMBEDDING_KEY` may not carry `boolean`, `float`, `int`, `string` or `sparse_vector`
///      value types;
///    - a vector index may only be enabled on `EMBEDDING_KEY`, and its `source_key` (if any) must
///      be `DOCUMENT_KEY`;
///    - at most one key may enable a sparse vector index;
///    - an enabled sparse index with a `source_key` also needs an `embedding_function`;
///    - a sparse index `source_key` may not start with `#` unless it is `DOCUMENT_KEY`;
///    - a full text search index may only be enabled on `DOCUMENT_KEY`;
///    - a string inverted index may not be enabled on `DOCUMENT_KEY`.
/// 4. If a CMEK is configured, `cmek.validate_pattern()` must succeed.
///
/// # Errors
/// Returns a `ValidationError` with code `"schema"` describing the first rule that failed.
pub fn validate_schema(schema: &Schema) -> Result<(), ValidationError> {
    // Prevent users from setting source_attached_function_id - only the system can set this
    if schema.source_attached_function_id.is_some() {
        return Err(ValidationError::new("schema").with_message(
            "Cannot set source_attached_function_id. This field is reserved for system use.".into(),
        ));
    }

    // Keys with an enabled sparse vector index. This is a list rather than a counter so the
    // "at most one" error can name the conflicting keys.
    let mut sparse_index_keys = Vec::new();
    // --- Value-type defaults: the indexes below may only be enabled on specific keys. ---
    if schema
        .defaults
        .float_list
        .as_ref()
        .is_some_and(|vt| vt.vector_index.as_ref().is_some_and(|it| it.enabled))
    {
        return Err(ValidationError::new("schema").with_message("Vector index cannot be enabled by default. It can only be enabled on #embedding field.".into()));
    }
    // NOTE(review): this hnsw/spann conflict check is applied only to the defaults. A per-key
    // vector index (in the loop below) carrying both configs is not rejected here; confirm
    // whether that is enforced elsewhere.
    if schema.defaults.float_list.as_ref().is_some_and(|vt| {
        vt.vector_index
            .as_ref()
            .is_some_and(|it| it.config.hnsw.is_some() && it.config.spann.is_some())
    }) {
        return Err(ValidationError::new("schema").with_message(
            "Both spann and hnsw config cannot be present at the same time.".into(),
        ));
    }
    if schema
        .defaults
        .sparse_vector
        .as_ref()
        .is_some_and(|vt| vt.sparse_vector_index.as_ref().is_some_and(|it| it.enabled))
    {
        return Err(ValidationError::new("schema").with_message("Sparse vector index cannot be enabled by default. Please enable sparse vector index on specific keys. At most one sparse vector index is allowed for the collection.".into()));
    }
    if schema
        .defaults
        .string
        .as_ref()
        .is_some_and(|vt| vt.fts_index.as_ref().is_some_and(|it| it.enabled))
    {
        return Err(ValidationError::new("schema").with_message("Full text search / regular expression index cannot be enabled by default. It can only be enabled on #document field.".into()));
    }
    // --- Per-key overrides. ---
    for (key, config) in &schema.keys {
        // Validate that keys cannot start with # (except system keys)
        if key.starts_with('#') && key != DOCUMENT_KEY && key != EMBEDDING_KEY {
            return Err(ValidationError::new("schema").with_message(
                format!("key cannot begin with '#'. Keys starting with '#' are reserved for system use: {key}")
                    .into(),
            ));
        }

        // The document key may only be configured through its `string` value type.
        if key == DOCUMENT_KEY
            && (config.boolean.is_some()
                || config.float.is_some()
                || config.int.is_some()
                || config.float_list.is_some()
                || config.sparse_vector.is_some())
        {
            return Err(ValidationError::new("schema").with_message(
                format!("Document field cannot have any value types other than string: {key}")
                    .into(),
            ));
        }
        // The embedding key may only be configured through its `float_list` value type.
        if key == EMBEDDING_KEY
            && (config.boolean.is_some()
                || config.float.is_some()
                || config.int.is_some()
                || config.string.is_some()
                || config.sparse_vector.is_some())
        {
            return Err(ValidationError::new("schema").with_message(
                format!("Embedding field cannot have any value types other than float_list: {key}")
                    .into(),
            ));
        }
        if let Some(vit) = config
            .float_list
            .as_ref()
            .and_then(|vt| vt.vector_index.as_ref())
        {
            if vit.enabled && key != EMBEDDING_KEY {
                return Err(ValidationError::new("schema").with_message(
                    format!("Vector index can only be enabled on #embedding field: {key}").into(),
                ));
            }
            // The closure's `key` shadows the loop variable: it is the index's source key.
            if vit
                .config
                .source_key
                .as_ref()
                .is_some_and(|key| key != DOCUMENT_KEY)
            {
                return Err(ValidationError::new("schema")
                    .with_message("Vector index can only source from #document".into()));
            }
        }
        if let Some(svit) = config
            .sparse_vector
            .as_ref()
            .and_then(|vt| vt.sparse_vector_index.as_ref())
        {
            if svit.enabled {
                sparse_index_keys.push(key);
                // Reject as soon as a second enabled sparse index is seen.
                if sparse_index_keys.len() > 1 {
                    return Err(ValidationError::new("schema").with_message(
                        format!("At most one sparse vector index is allowed for the collection: {sparse_index_keys:?}")
                            .into(),
                    ));
                }
                if svit.config.source_key.is_some() && svit.config.embedding_function.is_none() {
                    return Err(ValidationError::new("schema").with_message(
                        "If source_key is provided then embedding_function must also be provided since there is no default embedding function.".into(),
                    ));
                }
            }
            // Validate source_key for sparse vector index
            // (runs whether or not the index is enabled).
            if let Some(source_key) = &svit.config.source_key {
                if source_key.starts_with('#') && source_key != DOCUMENT_KEY {
                    return Err(ValidationError::new("schema").with_message(
                        "source_key cannot begin with '#'. The only valid key starting with '#' is Key.DOCUMENT or '#document'.".into(),
                    ));
                }
            }
        }
        if config
            .string
            .as_ref()
            .is_some_and(|vt| vt.fts_index.as_ref().is_some_and(|it| it.enabled))
            && key != DOCUMENT_KEY
        {
            return Err(ValidationError::new("schema").with_message(format!("Full text search / regular expression index can only be enabled on #document field: {key}").into()));
        }
        if config.string.as_ref().is_some_and(|vt| {
            vt.string_inverted_index
                .as_ref()
                .is_some_and(|it| it.enabled)
        }) && key == DOCUMENT_KEY
        {
            return Err(ValidationError::new("schema").with_message(
                format!("String inverted index can not be enabled on #document key: {key}").into(),
            ));
        }
    }
    // Customer-managed key, if configured, must match the pattern its own type defines.
    if let Some(cmek) = &schema.cmek {
        if !cmek.validate_pattern() {
            return Err(ValidationError::new("schema")
                .with_message(format!("CMEK does not match expected pattern: {cmek:?}").into()));
        }
    }
    Ok(())
}
397
#[cfg(test)]
mod tests {
    // Unit tests for the validators defined in this module.
    use super::*;
    use crate::operator::Key;
    use crate::{MetadataValue, SparseVector};

    #[test]
    fn valid_simple_name() {
        assert!(validate_name("abc").is_ok());
        assert!(validate_name("my_collection").is_ok());
        assert!(validate_name("my-collection").is_ok());
        assert!(validate_name("my.collection").is_ok());
        assert!(validate_name("MyCollection123").is_ok());
    }

    #[test]
    fn invalid_simple_name_too_short() {
        assert!(validate_name("ab").is_err());
        assert!(validate_name("a").is_err());
        assert!(validate_name("").is_err());
    }

    #[test]
    fn invalid_simple_name_bad_start_or_end() {
        assert!(validate_name("_abc").is_err());
        assert!(validate_name("-abc").is_err());
        assert!(validate_name(".abc").is_err());
        assert!(validate_name("abc_").is_err());
        assert!(validate_name("abc-").is_err());
        assert!(validate_name("abc.").is_err());
    }

    #[test]
    fn invalid_simple_name_double_period() {
        assert!(validate_name("abc..def").is_err());
        assert!(validate_name("my..collection").is_err());
    }

    #[test]
    fn invalid_simple_name_ip_address() {
        assert!(validate_name("192.168.0.1").is_err());
        assert!(validate_name("127.0.0.1").is_err());
    }

    #[test]
    fn valid_topology_prefixed_name() {
        assert!(validate_name("topo+name").is_ok());
        assert!(validate_name("my_topology+my_collection").is_ok());
        assert!(validate_name("region1+database").is_ok());
        assert!(validate_name("abc+def").is_ok());
    }

    #[test]
    fn invalid_topology_prefixed_name_multiple_plus() {
        assert!(validate_name("a+b+c").is_err());
        assert!(validate_name("topo+name+extra").is_err());
        assert!(validate_name("one+two+three").is_err());
    }

    #[test]
    fn invalid_topology_prefixed_name_invalid_topology() {
        // Topology part must be valid (3+ chars, start/end with alnum)
        assert!(validate_name("ab+valid").is_err());
        assert!(validate_name("_bad+valid").is_err());
        assert!(validate_name("bad_+valid").is_err());
    }

    #[test]
    fn invalid_topology_prefixed_name_invalid_name() {
        // Name part must be valid (3+ chars, start/end with alnum)
        assert!(validate_name("valid+ab").is_err());
        assert!(validate_name("valid+_bad").is_err());
        assert!(validate_name("valid+bad_").is_err());
    }

    #[test]
    fn invalid_topology_prefixed_name_too_long() {
        // Total length exceeds 512
        let long_topo = "a".repeat(256);
        let long_name = "b".repeat(258);
        let too_long = format!("{}+{}", long_topo, long_name);
        assert!(too_long.len() > 512);
        assert!(validate_name(&too_long).is_err());
    }

    #[test]
    fn valid_topology_prefixed_name_at_limit() {
        // Total length exactly 512
        let topo = "a".repeat(255);
        let name = "b".repeat(256);
        let at_limit = format!("{}+{}", topo, name);
        assert_eq!(at_limit.len(), 512);
        assert!(validate_name(&at_limit).is_ok());
    }

    #[test]
    fn test_metadata_validation() {
        // Valid metadata
        let mut metadata = Metadata::new();
        metadata.insert("valid_key".to_string(), MetadataValue::Int(42));
        let sparse = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
        metadata.insert("embedding".to_string(), MetadataValue::SparseVector(sparse));
        assert!(validate_metadata(&metadata).is_ok());

        // Invalid reserved-looking key (`#embedding` starts with #)
        let mut metadata = Metadata::new();
        metadata.insert("#embedding".to_string(), MetadataValue::Int(42));
        assert!(validate_metadata(&metadata).is_err());

        // Invalid key starting with #
        let mut metadata = Metadata::new();
        metadata.insert("#invalid".to_string(), MetadataValue::Int(42));
        assert!(validate_metadata(&metadata).is_err());

        // Invalid key starting with $
        let mut metadata = Metadata::new();
        metadata.insert("$invalid".to_string(), MetadataValue::Int(42));
        assert!(validate_metadata(&metadata).is_err());

        // Invalid empty key
        let mut metadata = Metadata::new();
        metadata.insert("".to_string(), MetadataValue::Int(42));
        assert!(validate_metadata(&metadata).is_err());

        // Invalid sparse vector (length mismatch)
        let mut metadata = Metadata::new();
        let invalid_sparse = SparseVector {
            indices: vec![1, 2],
            values: vec![0.1, 0.2, 0.3],
            tokens: None,
        };
        metadata.insert(
            "embedding".to_string(),
            MetadataValue::SparseVector(invalid_sparse),
        );
        assert!(validate_metadata(&metadata).is_err());
    }

    #[test]
    fn test_validate_non_empty_metadata() {
        assert!(validate_non_empty_metadata(&Metadata::new()).is_err());

        let mut metadata = Metadata::new();
        metadata.insert("key".to_string(), MetadataValue::Int(1));
        assert!(validate_non_empty_metadata(&metadata).is_ok());
    }

    #[test]
    fn test_validate_optional_metadata() {
        // Empty metadata is rejected even though it has no invalid keys.
        assert!(validate_optional_metadata(&Metadata::new()).is_err());

        let mut metadata = Metadata::new();
        metadata.insert("key".to_string(), MetadataValue::Int(1));
        assert!(validate_optional_metadata(&metadata).is_ok());

        // Non-empty metadata still goes through key validation.
        let mut metadata = Metadata::new();
        metadata.insert("#key".to_string(), MetadataValue::Int(1));
        assert!(validate_optional_metadata(&metadata).is_err());
    }

    #[test]
    fn test_validate_metadata_vec() {
        assert!(validate_metadata_vec(&[]).is_ok());

        let mut valid = Metadata::new();
        valid.insert("key".to_string(), MetadataValue::Int(1));
        assert!(validate_metadata_vec(&[None, Some(valid)]).is_ok());

        // The error names the index of the first invalid entry; `None` entries are skipped.
        let mut valid = Metadata::new();
        valid.insert("key".to_string(), MetadataValue::Int(1));
        let mut invalid = Metadata::new();
        invalid.insert("$key".to_string(), MetadataValue::Int(1));
        let err = validate_metadata_vec(&[Some(valid), None, Some(invalid)])
            .expect_err("metadata with a reserved key prefix should be rejected");
        let message = err.message.expect("the error should carry a message");
        assert!(message.starts_with("Invalid metadata at index 2"));
    }

    #[test]
    fn test_validate_non_empty_collection_update_metadata() {
        assert!(validate_non_empty_collection_update_metadata(
            &CollectionMetadataUpdate::ResetMetadata
        )
        .is_ok());
        assert!(validate_non_empty_collection_update_metadata(
            &CollectionMetadataUpdate::UpdateMetadata(UpdateMetadata::new())
        )
        .is_err());
    }

    #[test]
    fn test_validate_group_by() {
        // Valid: both keys and aggregate present
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 3,
            }),
        };
        assert!(validate_group_by(&group_by).is_ok());

        // Valid: both empty
        let group_by = GroupBy {
            keys: vec![],
            aggregate: None,
        };
        assert!(validate_group_by(&group_by).is_ok());

        // Invalid: keys present, aggregate missing
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: None,
        };
        assert!(validate_group_by(&group_by).is_err());

        // Invalid: aggregate present, keys missing
        let group_by = GroupBy {
            keys: vec![],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 3,
            }),
        };
        assert!(validate_group_by(&group_by).is_err());

        // Invalid: aggregate k = 0
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 0,
            }),
        };
        assert!(validate_group_by(&group_by).is_err());

        // Invalid: aggregate keys empty
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MaxK { keys: vec![], k: 5 }),
        };
        assert!(validate_group_by(&group_by).is_err());

        // Invalid: group_by key must be metadata field (not #score)
        let group_by = GroupBy {
            keys: vec![Key::Score],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Score],
                k: 3,
            }),
        };
        assert!(validate_group_by(&group_by).is_err());

        // Invalid: aggregate key cannot be #document
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::Document],
                k: 3,
            }),
        };
        assert!(validate_group_by(&group_by).is_err());

        // Valid: aggregate key can be metadata field
        let group_by = GroupBy {
            keys: vec![Key::field("category")],
            aggregate: Some(Aggregate::MinK {
                keys: vec![Key::field("date"), Key::Score],
                k: 3,
            }),
        };
        assert!(validate_group_by(&group_by).is_ok());
    }

    #[test]
    fn test_validate_search_payload() {
        use crate::operator::{QueryVector, RankExpr};

        // Valid: group_by with rank expression
        let payload = SearchPayload {
            rank: Rank {
                expr: Some(RankExpr::Knn {
                    query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
                    key: Key::Embedding,
                    limit: 100,
                    default: None,
                    return_rank: false,
                }),
            },
            group_by: GroupBy {
                keys: vec![Key::field("category")],
                aggregate: Some(Aggregate::MinK {
                    keys: vec![Key::Score],
                    k: 3,
                }),
            },
            ..Default::default()
        };
        assert!(validate_search_payload(&payload).is_ok());

        // Valid: no group_by, no rank
        let payload = SearchPayload::default();
        assert!(validate_search_payload(&payload).is_ok());

        // Valid: rank without group_by
        let payload = SearchPayload {
            rank: Rank {
                expr: Some(RankExpr::Knn {
                    query: QueryVector::Dense(vec![0.1, 0.2, 0.3]),
                    key: Key::Embedding,
                    limit: 100,
                    default: None,
                    return_rank: false,
                }),
            },
            ..Default::default()
        };
        assert!(validate_search_payload(&payload).is_ok());

        // Invalid: group_by without rank expression
        let payload = SearchPayload {
            rank: Rank { expr: None },
            group_by: GroupBy {
                keys: vec![Key::field("category")],
                aggregate: Some(Aggregate::MinK {
                    keys: vec![Key::Score],
                    k: 3,
                }),
            },
            ..Default::default()
        };
        assert!(validate_search_payload(&payload).is_err());
    }
}