1use super::{
2 ast::{CompareOp, ComparePredicate, Predicate},
3 coercion::{CoercionId, CoercionSpec, supports_coercion},
4};
5use crate::{
6 db::identity::{EntityName, EntityNameError, IndexName, IndexNameError},
7 model::{entity::EntityModel, field::EntityFieldKind, index::IndexModel},
8 value::{Value, ValueFamily, ValueFamilyExt},
9};
10use std::{
11 collections::{BTreeMap, BTreeSet},
12 fmt,
13};
14
/// Scalar (non-collection) field types understood by predicate validation.
///
/// Per-variant properties (value family, matching `Value` variant, keyability,
/// orderability, ...) are not spelled out on the enum itself; the methods in
/// `impl ScalarType` obtain them from the `scalar_registry!` macro.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ScalarType {
    Account,
    Blob,
    Bool,
    Date,
    Decimal,
    Duration,
    Enum,
    E8s,
    E18s,
    Float32,
    Float64,
    Int,
    Int128,
    IntBig,
    Principal,
    Subaccount,
    Text,
    Timestamp,
    Uint,
    Uint128,
    UintBig,
    Ulid,
    Unit,
}
52
// Callback for `scalar_registry!` (used by `ScalarType::family`): given the scalar in
// `@args` and every registry entry in `@entries`, expands to a `match` mapping each
// `ScalarType` variant to its entry's `$family` expression. Only `$scalar` and `$family`
// are used; the remaining captures exist so the matcher accepts the full entry shape.
macro_rules! scalar_family_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $family, )*
        }
    };
}
61
// Callback for `scalar_registry!` (used by `ScalarType::matches_value`): expands to a
// `matches!` over the `(scalar, value)` pair that is true when the value matches the
// `$value_pat` pattern registered for that scalar. Unused captures only keep the matcher
// aligned with the registry entry shape.
macro_rules! scalar_matches_value_from_registry {
    ( @args $self:expr, $value:expr; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
        matches!(
            ($self, $value),
            $( (ScalarType::$scalar, $value_pat) )|*
        )
    };
}
70
// Test-only callback for `scalar_registry!` (used by `ScalarType::supports_arithmetic`):
// expands to a `match` returning each entry's `supports_arithmetic` flag.
#[cfg(test)]
macro_rules! scalar_supports_arithmetic_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_arithmetic, )*
        }
    };
}
79
// Callback for `scalar_registry!` (used by `ScalarType::is_keyable`): expands to a
// `match` returning each entry's `is_keyable` flag.
macro_rules! scalar_is_keyable_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $is_keyable, )*
        }
    };
}
87
// Test-only callback for `scalar_registry!` (used by `ScalarType::supports_equality`):
// expands to a `match` returning each entry's `supports_equality` flag.
#[cfg(test)]
macro_rules! scalar_supports_equality_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_equality, )*
        }
    };
}
96
// Callback for `scalar_registry!` (used by `ScalarType::supports_ordering`): expands to a
// `match` returning each entry's `supports_ordering` flag.
macro_rules! scalar_supports_ordering_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_ordering, )*
        }
    };
}
104
105impl ScalarType {
106 #[must_use]
107 pub const fn family(&self) -> ValueFamily {
108 scalar_registry!(scalar_family_from_registry, self)
109 }
110
111 #[must_use]
112 pub const fn is_orderable(&self) -> bool {
113 self.supports_ordering()
114 }
115
116 #[must_use]
117 pub const fn matches_value(&self, value: &Value) -> bool {
118 scalar_registry!(scalar_matches_value_from_registry, self, value)
119 }
120
121 #[must_use]
122 #[cfg(test)]
123 pub const fn supports_arithmetic(&self) -> bool {
124 scalar_registry!(scalar_supports_arithmetic_from_registry, self)
125 }
126
127 #[must_use]
128 pub const fn is_keyable(&self) -> bool {
129 scalar_registry!(scalar_is_keyable_from_registry, self)
130 }
131
132 #[must_use]
133 #[cfg(test)]
134 pub const fn supports_equality(&self) -> bool {
135 scalar_registry!(scalar_supports_equality_from_registry, self)
136 }
137
138 #[must_use]
139 pub const fn supports_ordering(&self) -> bool {
140 scalar_registry!(scalar_supports_ordering_from_registry, self)
141 }
142}
143
/// Structural type of an entity field, as used by predicate validation.
///
/// Built from `EntityFieldKind` by `field_type_from_model_kind`; reference kinds are
/// resolved to the type of their key kind.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// A single scalar value.
    Scalar(ScalarType),
    /// A list whose elements have the boxed type.
    List(Box<Self>),
    /// A set whose elements have the boxed type.
    Set(Box<Self>),
    /// A map with the given key and value types.
    Map { key: Box<Self>, value: Box<Self> },
    /// A field kind predicate validation does not model. `ensure_field` rejects it, so
    /// most predicates on such a field fail; `IsNull`/`IsMissing` only check existence.
    Unsupported,
}
162
163impl FieldType {
164 #[must_use]
165 pub const fn family(&self) -> Option<ValueFamily> {
166 match self {
167 Self::Scalar(inner) => Some(inner.family()),
168 Self::List(_) | Self::Set(_) | Self::Map { .. } => Some(ValueFamily::Collection),
169 Self::Unsupported => None,
170 }
171 }
172
173 #[must_use]
174 pub const fn is_text(&self) -> bool {
175 matches!(self, Self::Scalar(ScalarType::Text))
176 }
177
178 #[must_use]
179 pub const fn is_collection(&self) -> bool {
180 matches!(self, Self::List(_) | Self::Set(_) | Self::Map { .. })
181 }
182
183 #[must_use]
184 pub const fn is_list_like(&self) -> bool {
185 matches!(self, Self::List(_) | Self::Set(_))
186 }
187
188 #[must_use]
189 pub const fn is_map(&self) -> bool {
190 matches!(self, Self::Map { .. })
191 }
192
193 #[must_use]
194 pub fn map_types(&self) -> Option<(&Self, &Self)> {
195 match self {
196 Self::Map { key, value } => Some((key.as_ref(), value.as_ref())),
197 _ => {
198 None
200 }
201 }
202 }
203
204 #[must_use]
205 pub const fn is_orderable(&self) -> bool {
206 match self {
207 Self::Scalar(inner) => inner.is_orderable(),
208 _ => false,
209 }
210 }
211
212 #[must_use]
213 pub const fn is_keyable(&self) -> bool {
214 match self {
215 Self::Scalar(inner) => inner.is_keyable(),
216 _ => false,
217 }
218 }
219}
220
221fn validate_index_fields(
222 fields: &BTreeMap<String, FieldType>,
223 indexes: &[&IndexModel],
224) -> Result<(), ValidateError> {
225 let mut seen_names = BTreeSet::new();
226 for index in indexes {
227 if seen_names.contains(index.name) {
228 return Err(ValidateError::DuplicateIndexName {
229 name: index.name.to_string(),
230 });
231 }
232 seen_names.insert(index.name);
233
234 let mut seen = BTreeSet::new();
235 for field in index.fields {
236 if !fields.contains_key(*field) {
237 return Err(ValidateError::IndexFieldUnknown {
238 index: **index,
239 field: (*field).to_string(),
240 });
241 }
242 if seen.contains(*field) {
243 return Err(ValidateError::IndexFieldDuplicate {
244 index: **index,
245 field: (*field).to_string(),
246 });
247 }
248 seen.insert(*field);
249
250 let field_type = fields
251 .get(*field)
252 .expect("index field existence checked above");
253 if matches!(field_type, FieldType::Unsupported) {
256 return Err(ValidateError::IndexFieldUnsupported {
257 index: **index,
258 field: (*field).to_string(),
259 });
260 }
261 }
262 }
263
264 Ok(())
265}
266
/// Field-level schema used to validate predicates: maps each field name to its
/// resolved field type.
///
/// Constructed with [`SchemaInfo::from_entity_model`], which also validates the model.
#[derive(Clone, Debug)]
pub struct SchemaInfo {
    // Resolved type of every entity field, keyed by field name.
    fields: BTreeMap<String, FieldType>,
}
278
279impl SchemaInfo {
280 #[must_use]
281 pub(crate) fn field(&self, name: &str) -> Option<&FieldType> {
282 self.fields.get(name)
283 }
284
285 pub fn from_entity_model(model: &EntityModel) -> Result<Self, ValidateError> {
286 let entity_name = EntityName::try_from_str(model.entity_name).map_err(|err| {
288 ValidateError::InvalidEntityName {
289 name: model.entity_name.to_string(),
290 source: err,
291 }
292 })?;
293
294 if !model
295 .fields
296 .iter()
297 .any(|field| std::ptr::eq(field, model.primary_key))
298 {
299 return Err(ValidateError::InvalidPrimaryKey {
300 field: model.primary_key.name.to_string(),
301 });
302 }
303
304 let mut fields = BTreeMap::new();
305 for field in model.fields {
306 if fields.contains_key(field.name) {
307 return Err(ValidateError::DuplicateField {
308 field: field.name.to_string(),
309 });
310 }
311 let ty = field_type_from_model_kind(&field.kind);
312 fields.insert(field.name.to_string(), ty);
313 }
314
315 let pk_field_type = fields
316 .get(model.primary_key.name)
317 .expect("primary key verified above");
318 if !pk_field_type.is_keyable() {
319 return Err(ValidateError::InvalidPrimaryKeyType {
320 field: model.primary_key.name.to_string(),
321 });
322 }
323
324 validate_index_fields(&fields, model.indexes)?;
325 for index in model.indexes {
326 IndexName::try_from_parts(&entity_name, index.fields).map_err(|err| {
327 ValidateError::InvalidIndexName {
328 index: **index,
329 source: err,
330 }
331 })?;
332 }
333
334 Ok(Self { fields })
335 }
336}
337
/// Errors produced while building a [`SchemaInfo`] from an entity model or while
/// validating a predicate against one.
#[derive(Debug, thiserror::Error)]
pub enum ValidateError {
    /// The model's entity name was rejected by `EntityName::try_from_str`.
    #[error("invalid entity name '{name}': {source}")]
    InvalidEntityName {
        name: String,
        #[source]
        source: EntityNameError,
    },

    /// `IndexName::try_from_parts` could not build a name for this index from the
    /// entity name and the index's fields.
    #[error("invalid index name for '{index}': {source}")]
    InvalidIndexName {
        index: IndexModel,
        #[source]
        source: IndexNameError,
    },

    /// A predicate referenced a field the schema does not contain.
    #[error("unknown field '{field}'")]
    UnknownField { field: String },

    /// A predicate referenced a field whose type is `FieldType::Unsupported` (which has
    /// no value family).
    #[error("unsupported field type for '{field}'")]
    UnsupportedFieldType { field: String },

    /// Two model fields share the same name.
    #[error("duplicate field '{field}'")]
    DuplicateField { field: String },

    /// The model's primary key is not one of its declared fields (compared by identity).
    #[error("primary key '{field}' not present in entity fields")]
    InvalidPrimaryKey { field: String },

    /// The primary key field's type is not keyable.
    // NOTE(review): the message says "unsupported type" although the check performed in
    // `SchemaInfo::from_entity_model` is `is_keyable`; consider rewording.
    #[error("primary key '{field}' has an unsupported type")]
    InvalidPrimaryKeyType { field: String },

    /// An index lists a field that the entity does not declare.
    #[error("index '{index}' references unknown field '{field}'")]
    IndexFieldUnknown { index: IndexModel, field: String },

    /// An index lists a field whose type is `FieldType::Unsupported`.
    #[error("index '{index}' references unsupported field '{field}'")]
    IndexFieldUnsupported { index: IndexModel, field: String },

    /// An index lists the same field more than once.
    #[error("index '{index}' repeats field '{field}'")]
    IndexFieldDuplicate { index: IndexModel, field: String },

    /// Two indexes share the same name.
    #[error("duplicate index name '{name}'")]
    DuplicateIndexName { name: String },

    /// An operator was applied to a field type it does not support; `op` is the
    /// operator's name (or its `Debug` rendering for comparison operators).
    #[error("operator {op} is not valid for field '{field}'")]
    InvalidOperator { field: String, op: String },

    /// The requested coercion is not permitted for this field / literal combination.
    #[error("coercion {coercion:?} is not valid for field '{field}'")]
    InvalidCoercion { field: String, coercion: CoercionId },

    /// The literal's shape or type does not fit the field; `message` says why.
    #[error("invalid literal for field '{field}': {message}")]
    InvalidLiteral { field: String, message: String },
}
391
392pub fn validate(schema: &SchemaInfo, predicate: &Predicate) -> Result<(), ValidateError> {
393 match predicate {
394 Predicate::True | Predicate::False => Ok(()),
395 Predicate::And(children) | Predicate::Or(children) => {
396 for child in children {
397 validate(schema, child)?;
398 }
399 Ok(())
400 }
401 Predicate::Not(inner) => validate(schema, inner),
402 Predicate::Compare(cmp) => validate_compare(schema, cmp),
403 Predicate::IsNull { field } | Predicate::IsMissing { field } => {
404 ensure_field_exists(schema, field).map(|_| ())
406 }
407 Predicate::IsEmpty { field } => {
408 let field_type = ensure_field(schema, field)?;
409 if field_type.is_text() || field_type.is_collection() {
410 Ok(())
411 } else {
412 Err(invalid_operator(field, "is_empty"))
413 }
414 }
415 Predicate::IsNotEmpty { field } => {
416 let field_type = ensure_field(schema, field)?;
417 if field_type.is_text() || field_type.is_collection() {
418 Ok(())
419 } else {
420 Err(invalid_operator(field, "is_not_empty"))
421 }
422 }
423 Predicate::MapContainsKey {
424 field,
425 key,
426 coercion,
427 } => validate_map_key(schema, field, key, coercion),
428 Predicate::MapContainsValue {
429 field,
430 value,
431 coercion,
432 } => validate_map_value(schema, field, value, coercion),
433 Predicate::MapContainsEntry {
434 field,
435 key,
436 value,
437 coercion,
438 } => validate_map_entry(schema, field, key, value, coercion),
439 Predicate::TextContains { field, value } => {
440 validate_text_contains(schema, field, value, "text_contains")
441 }
442 Predicate::TextContainsCi { field, value } => {
443 validate_text_contains(schema, field, value, "text_contains_ci")
444 }
445 }
446}
447
448pub fn validate_model(model: &EntityModel, predicate: &Predicate) -> Result<(), ValidateError> {
449 let schema = SchemaInfo::from_entity_model(model)?;
450 validate(&schema, predicate)
451}
452
453fn validate_compare(schema: &SchemaInfo, cmp: &ComparePredicate) -> Result<(), ValidateError> {
454 let field_type = ensure_field(schema, &cmp.field)?;
455
456 match cmp.op {
457 CompareOp::Eq | CompareOp::Ne => {
458 validate_eq_ne(&cmp.field, field_type, &cmp.value, &cmp.coercion)
459 }
460 CompareOp::Lt | CompareOp::Lte | CompareOp::Gt | CompareOp::Gte => {
461 validate_ordering(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
462 }
463 CompareOp::In | CompareOp::NotIn => {
464 validate_in(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
465 }
466 CompareOp::Contains => validate_contains(&cmp.field, field_type, &cmp.value, &cmp.coercion),
467 CompareOp::StartsWith | CompareOp::EndsWith => {
468 validate_text_compare(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
469 }
470 }
471}
472
473fn validate_eq_ne(
474 field: &str,
475 field_type: &FieldType,
476 value: &Value,
477 coercion: &CoercionSpec,
478) -> Result<(), ValidateError> {
479 if field_type.is_list_like() {
480 ensure_list_literal(field, value, field_type)?;
481 } else if field_type.is_map() {
482 ensure_map_literal(field, value, field_type)?;
483 } else {
484 ensure_scalar_literal(field, value)?;
485 }
486
487 ensure_coercion(field, field_type, value, coercion)
488}
489
490fn validate_ordering(
491 field: &str,
492 field_type: &FieldType,
493 value: &Value,
494 coercion: &CoercionSpec,
495 op: CompareOp,
496) -> Result<(), ValidateError> {
497 if matches!(coercion.id, CoercionId::CollectionElement) {
498 return Err(ValidateError::InvalidCoercion {
499 field: field.to_string(),
500 coercion: coercion.id,
501 });
502 }
503
504 if !field_type.is_orderable() {
505 return Err(invalid_operator(field, format!("{op:?}")));
506 }
507
508 ensure_scalar_literal(field, value)?;
509
510 ensure_coercion(field, field_type, value, coercion)
511}
512
513fn validate_in(
515 field: &str,
516 field_type: &FieldType,
517 value: &Value,
518 coercion: &CoercionSpec,
519 op: CompareOp,
520) -> Result<(), ValidateError> {
521 if field_type.is_collection() {
522 return Err(invalid_operator(field, format!("{op:?}")));
523 }
524
525 let Value::List(items) = value else {
526 return Err(invalid_literal(field, "expected list literal"));
527 };
528
529 for item in items {
530 ensure_coercion(field, field_type, item, coercion)?;
531 }
532
533 Ok(())
534}
535
536fn validate_contains(
538 field: &str,
539 field_type: &FieldType,
540 value: &Value,
541 coercion: &CoercionSpec,
542) -> Result<(), ValidateError> {
543 if field_type.is_text() {
544 return Err(invalid_operator(
546 field,
547 format!("{:?}", CompareOp::Contains),
548 ));
549 }
550
551 let element_type = match field_type {
552 FieldType::List(inner) | FieldType::Set(inner) => inner.as_ref(),
553 _ => {
554 return Err(invalid_operator(
555 field,
556 format!("{:?}", CompareOp::Contains),
557 ));
558 }
559 };
560
561 if matches!(coercion.id, CoercionId::TextCasefold) {
562 return Err(ValidateError::InvalidCoercion {
564 field: field.to_string(),
565 coercion: coercion.id,
566 });
567 }
568
569 ensure_coercion(field, element_type, value, coercion)
570}
571
572fn validate_text_compare(
574 field: &str,
575 field_type: &FieldType,
576 value: &Value,
577 coercion: &CoercionSpec,
578 op: CompareOp,
579) -> Result<(), ValidateError> {
580 if !field_type.is_text() {
581 return Err(invalid_operator(field, format!("{op:?}")));
582 }
583
584 ensure_text_literal(field, value)?;
585
586 ensure_coercion(field, field_type, value, coercion)
587}
588
589fn ensure_map_types<'a>(
591 schema: &'a SchemaInfo,
592 field: &str,
593 op: &str,
594) -> Result<(&'a FieldType, &'a FieldType), ValidateError> {
595 let field_type = ensure_field(schema, field)?;
596 field_type
597 .map_types()
598 .ok_or_else(|| invalid_operator(field, op))
599}
600
601fn validate_map_key(
602 schema: &SchemaInfo,
603 field: &str,
604 key: &Value,
605 coercion: &CoercionSpec,
606) -> Result<(), ValidateError> {
607 ensure_no_text_casefold(field, coercion)?;
608
609 let (key_type, _) = ensure_map_types(schema, field, "map_contains_key")?;
610
611 ensure_coercion(field, key_type, key, coercion)
612}
613
614fn validate_map_value(
615 schema: &SchemaInfo,
616 field: &str,
617 value: &Value,
618 coercion: &CoercionSpec,
619) -> Result<(), ValidateError> {
620 ensure_no_text_casefold(field, coercion)?;
621
622 let (_, value_type) = ensure_map_types(schema, field, "map_contains_value")?;
623
624 ensure_coercion(field, value_type, value, coercion)
625}
626
627fn validate_map_entry(
628 schema: &SchemaInfo,
629 field: &str,
630 key: &Value,
631 value: &Value,
632 coercion: &CoercionSpec,
633) -> Result<(), ValidateError> {
634 ensure_no_text_casefold(field, coercion)?;
635
636 let (key_type, value_type) = ensure_map_types(schema, field, "map_contains_entry")?;
637
638 ensure_coercion(field, key_type, key, coercion)?;
639 ensure_coercion(field, value_type, value, coercion)?;
640
641 Ok(())
642}
643
644fn validate_text_contains(
646 schema: &SchemaInfo,
647 field: &str,
648 value: &Value,
649 op: &str,
650) -> Result<(), ValidateError> {
651 let field_type = ensure_field(schema, field)?;
652 if !field_type.is_text() {
653 return Err(invalid_operator(field, op));
654 }
655
656 ensure_text_literal(field, value)?;
657
658 Ok(())
659}
660
661fn ensure_field<'a>(schema: &'a SchemaInfo, field: &str) -> Result<&'a FieldType, ValidateError> {
662 let field_type = schema
663 .field(field)
664 .ok_or_else(|| ValidateError::UnknownField {
665 field: field.to_string(),
666 })?;
667
668 if matches!(field_type, FieldType::Unsupported) {
669 return Err(ValidateError::UnsupportedFieldType {
670 field: field.to_string(),
671 });
672 }
673
674 Ok(field_type)
675}
676
677fn ensure_field_exists<'a>(
678 schema: &'a SchemaInfo,
679 field: &str,
680) -> Result<&'a FieldType, ValidateError> {
681 schema
682 .field(field)
683 .ok_or_else(|| ValidateError::UnknownField {
684 field: field.to_string(),
685 })
686}
687
688fn invalid_operator(field: &str, op: impl fmt::Display) -> ValidateError {
689 ValidateError::InvalidOperator {
690 field: field.to_string(),
691 op: op.to_string(),
692 }
693}
694
695fn invalid_literal(field: &str, msg: &str) -> ValidateError {
696 ValidateError::InvalidLiteral {
697 field: field.to_string(),
698 message: msg.to_string(),
699 }
700}
701
702fn ensure_no_text_casefold(field: &str, coercion: &CoercionSpec) -> Result<(), ValidateError> {
704 if matches!(coercion.id, CoercionId::TextCasefold) {
705 return Err(ValidateError::InvalidCoercion {
706 field: field.to_string(),
707 coercion: coercion.id,
708 });
709 }
710
711 Ok(())
712}
713
714fn ensure_text_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
716 if !matches!(value, Value::Text(_)) {
717 return Err(invalid_literal(field, "expected text literal"));
718 }
719
720 Ok(())
721}
722
723fn ensure_scalar_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
725 if matches!(value, Value::List(_)) {
726 return Err(invalid_literal(field, "expected scalar literal"));
727 }
728
729 Ok(())
730}
731
732fn ensure_coercion(
733 field: &str,
734 field_type: &FieldType,
735 literal: &Value,
736 coercion: &CoercionSpec,
737) -> Result<(), ValidateError> {
738 let left_family = field_type
739 .family()
740 .ok_or_else(|| ValidateError::UnsupportedFieldType {
741 field: field.to_string(),
742 })?;
743 let right_family = literal.family();
744
745 if matches!(coercion.id, CoercionId::TextCasefold) && !field_type.is_text() {
746 return Err(ValidateError::InvalidCoercion {
748 field: field.to_string(),
749 coercion: coercion.id,
750 });
751 }
752
753 if !supports_coercion(left_family, right_family, coercion.id) {
754 return Err(ValidateError::InvalidCoercion {
755 field: field.to_string(),
756 coercion: coercion.id,
757 });
758 }
759
760 if matches!(
761 coercion.id,
762 CoercionId::Strict | CoercionId::CollectionElement
763 ) && !literal_matches_type(literal, field_type)
764 {
765 return Err(invalid_literal(
766 field,
767 "literal type does not match field type",
768 ));
769 }
770
771 Ok(())
772}
773
774fn ensure_list_literal(
775 field: &str,
776 literal: &Value,
777 field_type: &FieldType,
778) -> Result<(), ValidateError> {
779 if !literal_matches_type(literal, field_type) {
780 return Err(invalid_literal(
781 field,
782 "list literal does not match field element type",
783 ));
784 }
785
786 Ok(())
787}
788
789fn ensure_map_literal(
790 field: &str,
791 literal: &Value,
792 field_type: &FieldType,
793) -> Result<(), ValidateError> {
794 if !literal_matches_type(literal, field_type) {
795 return Err(invalid_literal(
796 field,
797 "map literal does not match field key/value types",
798 ));
799 }
800
801 Ok(())
802}
803
804pub(crate) fn literal_matches_type(literal: &Value, field_type: &FieldType) -> bool {
805 match field_type {
806 FieldType::Scalar(inner) => inner.matches_value(literal),
807 FieldType::List(element) | FieldType::Set(element) => match literal {
808 Value::List(items) => items.iter().all(|item| literal_matches_type(item, element)),
809 _ => false,
810 },
811 FieldType::Map { key, value } => match literal {
812 Value::List(entries) => entries.iter().all(|entry| match entry {
813 Value::List(pair) if pair.len() == 2 => {
814 literal_matches_type(&pair[0], key) && literal_matches_type(&pair[1], value)
815 }
816 _ => false,
817 }),
818 _ => false,
819 },
820 FieldType::Unsupported => {
821 false
823 }
824 }
825}
826
827fn field_type_from_model_kind(kind: &EntityFieldKind) -> FieldType {
828 match kind {
829 EntityFieldKind::Account => FieldType::Scalar(ScalarType::Account),
830 EntityFieldKind::Blob => FieldType::Scalar(ScalarType::Blob),
831 EntityFieldKind::Bool => FieldType::Scalar(ScalarType::Bool),
832 EntityFieldKind::Date => FieldType::Scalar(ScalarType::Date),
833 EntityFieldKind::Decimal => FieldType::Scalar(ScalarType::Decimal),
834 EntityFieldKind::Duration => FieldType::Scalar(ScalarType::Duration),
835 EntityFieldKind::Enum => FieldType::Scalar(ScalarType::Enum),
836 EntityFieldKind::E8s => FieldType::Scalar(ScalarType::E8s),
837 EntityFieldKind::E18s => FieldType::Scalar(ScalarType::E18s),
838 EntityFieldKind::Float32 => FieldType::Scalar(ScalarType::Float32),
839 EntityFieldKind::Float64 => FieldType::Scalar(ScalarType::Float64),
840 EntityFieldKind::Int => FieldType::Scalar(ScalarType::Int),
841 EntityFieldKind::Int128 => FieldType::Scalar(ScalarType::Int128),
842 EntityFieldKind::IntBig => FieldType::Scalar(ScalarType::IntBig),
843 EntityFieldKind::Principal => FieldType::Scalar(ScalarType::Principal),
844 EntityFieldKind::Subaccount => FieldType::Scalar(ScalarType::Subaccount),
845 EntityFieldKind::Text => FieldType::Scalar(ScalarType::Text),
846 EntityFieldKind::Timestamp => FieldType::Scalar(ScalarType::Timestamp),
847 EntityFieldKind::Uint => FieldType::Scalar(ScalarType::Uint),
848 EntityFieldKind::Uint128 => FieldType::Scalar(ScalarType::Uint128),
849 EntityFieldKind::UintBig => FieldType::Scalar(ScalarType::UintBig),
850 EntityFieldKind::Ulid => FieldType::Scalar(ScalarType::Ulid),
851 EntityFieldKind::Unit => FieldType::Scalar(ScalarType::Unit),
852 EntityFieldKind::Ref { key_kind, .. } => field_type_from_model_kind(key_kind),
853 EntityFieldKind::List(inner) => {
854 FieldType::List(Box::new(field_type_from_model_kind(inner)))
855 }
856 EntityFieldKind::Set(inner) => FieldType::Set(Box::new(field_type_from_model_kind(inner))),
857 EntityFieldKind::Map { key, value } => FieldType::Map {
858 key: Box::new(field_type_from_model_kind(key)),
859 value: Box::new(field_type_from_model_kind(value)),
860 },
861 EntityFieldKind::Unsupported => FieldType::Unsupported,
862 }
863}
864
865impl fmt::Display for FieldType {
866 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
867 match self {
868 Self::Scalar(inner) => write!(f, "{inner:?}"),
869 Self::List(inner) => write!(f, "List<{inner}>"),
870 Self::Set(inner) => write!(f, "Set<{inner}>"),
871 Self::Map { key, value } => write!(f, "Map<{key}, {value}>"),
872 Self::Unsupported => write!(f, "Unsupported"),
873 }
874 }
875}
876
877#[cfg(test)]
882mod tests {
883 use super::{
885 CompareOp, FieldType, ScalarType, ValidateError, validate_model, validate_ordering,
886 };
887 use crate::{
888 db::query::{
889 FieldRef,
890 predicate::{CoercionId, CoercionSpec, Predicate},
891 },
892 db::store::StorageKey,
893 model::field::{EntityFieldKind, EntityFieldModel},
894 test_fixtures::InvalidEntityModelBuilder,
895 traits::EntitySchema,
896 types::{
897 Account, Date, Decimal, Duration, E8s, E18s, Float32, Float64, Int, Int128, Nat,
898 Nat128, Principal, Subaccount, Timestamp, Ulid,
899 },
900 value::{Value, ValueEnum, ValueFamily},
901 };
902 use std::collections::BTreeSet;
903
904 fn registry_scalars() -> Vec<ScalarType> {
906 macro_rules! collect_scalars {
907 ( @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
908 vec![ $( ScalarType::$scalar ),* ]
909 };
910 ( @args $($ignore:tt)*; @entries $( ($scalar:ident, $family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr) ),* $(,)? ) => {
911 vec![ $( ScalarType::$scalar ),* ]
912 };
913 }
914
915 let scalars = scalar_registry!(collect_scalars);
916
917 scalars
918 }
919
920 const SCALAR_TYPE_VARIANT_COUNT: usize = 23;
922
923 fn scalar_index(scalar: ScalarType) -> usize {
925 match scalar {
926 ScalarType::Account => 0,
927 ScalarType::Blob => 1,
928 ScalarType::Bool => 2,
929 ScalarType::Date => 3,
930 ScalarType::Decimal => 4,
931 ScalarType::Duration => 5,
932 ScalarType::Enum => 6,
933 ScalarType::E8s => 7,
934 ScalarType::E18s => 8,
935 ScalarType::Float32 => 9,
936 ScalarType::Float64 => 10,
937 ScalarType::Int => 11,
938 ScalarType::Int128 => 12,
939 ScalarType::IntBig => 13,
940 ScalarType::Principal => 14,
941 ScalarType::Subaccount => 15,
942 ScalarType::Text => 16,
943 ScalarType::Timestamp => 17,
944 ScalarType::Uint => 18,
945 ScalarType::Uint128 => 19,
946 ScalarType::UintBig => 20,
947 ScalarType::Ulid => 21,
948 ScalarType::Unit => 22,
949 }
950 }
951
952 fn scalar_from_index(index: usize) -> Option<ScalarType> {
954 let scalar = match index {
955 0 => ScalarType::Account,
956 1 => ScalarType::Blob,
957 2 => ScalarType::Bool,
958 3 => ScalarType::Date,
959 4 => ScalarType::Decimal,
960 5 => ScalarType::Duration,
961 6 => ScalarType::Enum,
962 7 => ScalarType::E8s,
963 8 => ScalarType::E18s,
964 9 => ScalarType::Float32,
965 10 => ScalarType::Float64,
966 11 => ScalarType::Int,
967 12 => ScalarType::Int128,
968 13 => ScalarType::IntBig,
969 14 => ScalarType::Principal,
970 15 => ScalarType::Subaccount,
971 16 => ScalarType::Text,
972 17 => ScalarType::Timestamp,
973 18 => ScalarType::Uint,
974 19 => ScalarType::Uint128,
975 20 => ScalarType::UintBig,
976 21 => ScalarType::Ulid,
977 22 => ScalarType::Unit,
978 _ => return None,
979 };
980
981 Some(scalar)
982 }
983
984 fn legacy_family(scalar: ScalarType) -> ValueFamily {
986 match scalar {
987 ScalarType::Text => ValueFamily::Textual,
988 ScalarType::Ulid | ScalarType::Principal | ScalarType::Account => {
989 ValueFamily::Identifier
990 }
991 ScalarType::Enum => ValueFamily::Enum,
992 ScalarType::Blob | ScalarType::Subaccount => ValueFamily::Blob,
993 ScalarType::Bool => ValueFamily::Bool,
994 ScalarType::Unit => ValueFamily::Unit,
995 ScalarType::Date
996 | ScalarType::Decimal
997 | ScalarType::Duration
998 | ScalarType::E8s
999 | ScalarType::E18s
1000 | ScalarType::Float32
1001 | ScalarType::Float64
1002 | ScalarType::Int
1003 | ScalarType::Int128
1004 | ScalarType::IntBig
1005 | ScalarType::Timestamp
1006 | ScalarType::Uint
1007 | ScalarType::Uint128
1008 | ScalarType::UintBig => ValueFamily::Numeric,
1009 }
1010 }
1011
1012 fn legacy_matches_value(scalar: ScalarType, value: &Value) -> bool {
1014 let matches_value = matches!(
1015 (scalar, value),
1016 (ScalarType::Account, Value::Account(_))
1017 | (ScalarType::Blob, Value::Blob(_))
1018 | (ScalarType::Bool, Value::Bool(_))
1019 | (ScalarType::Date, Value::Date(_))
1020 | (ScalarType::Decimal, Value::Decimal(_))
1021 | (ScalarType::Duration, Value::Duration(_))
1022 | (ScalarType::Enum, Value::Enum(_))
1023 | (ScalarType::E8s, Value::E8s(_))
1024 | (ScalarType::E18s, Value::E18s(_))
1025 | (ScalarType::Float32, Value::Float32(_))
1026 | (ScalarType::Float64, Value::Float64(_))
1027 | (ScalarType::Int, Value::Int(_))
1028 | (ScalarType::Int128, Value::Int128(_))
1029 | (ScalarType::IntBig, Value::IntBig(_))
1030 | (ScalarType::Principal, Value::Principal(_))
1031 | (ScalarType::Subaccount, Value::Subaccount(_))
1032 | (ScalarType::Text, Value::Text(_))
1033 | (ScalarType::Timestamp, Value::Timestamp(_))
1034 | (ScalarType::Uint, Value::Uint(_))
1035 | (ScalarType::Uint128, Value::Uint128(_))
1036 | (ScalarType::UintBig, Value::UintBig(_))
1037 | (ScalarType::Ulid, Value::Ulid(_))
1038 | (ScalarType::Unit, Value::Unit)
1039 );
1040
1041 matches_value
1042 }
1043
1044 fn legacy_supports_arithmetic(scalar: ScalarType) -> bool {
1046 matches!(
1047 scalar,
1048 ScalarType::Decimal
1049 | ScalarType::E8s
1050 | ScalarType::E18s
1051 | ScalarType::Int
1052 | ScalarType::Int128
1053 | ScalarType::IntBig
1054 | ScalarType::Uint
1055 | ScalarType::Uint128
1056 | ScalarType::UintBig
1057 )
1058 }
1059
1060 fn legacy_supports_ordering(scalar: ScalarType) -> bool {
1062 !matches!(scalar, ScalarType::Blob | ScalarType::Unit)
1063 }
1064
1065 fn legacy_supports_equality(_scalar: ScalarType) -> bool {
1067 true
1068 }
1069
1070 fn legacy_is_keyable(scalar: ScalarType) -> bool {
1072 matches!(
1073 scalar,
1074 ScalarType::Account
1075 | ScalarType::Int
1076 | ScalarType::Principal
1077 | ScalarType::Subaccount
1078 | ScalarType::Timestamp
1079 | ScalarType::Uint
1080 | ScalarType::Ulid
1081 | ScalarType::Unit
1082 )
1083 }
1084
1085 fn sample_value_for_scalar(scalar: ScalarType) -> Value {
1087 match scalar {
1088 ScalarType::Account => Value::Account(Account::dummy(1)),
1089 ScalarType::Blob => Value::Blob(vec![0u8, 1u8]),
1090 ScalarType::Bool => Value::Bool(true),
1091 ScalarType::Date => Value::Date(Date::EPOCH),
1092 ScalarType::Decimal => Value::Decimal(Decimal::ZERO),
1093 ScalarType::Duration => Value::Duration(Duration::ZERO),
1094 ScalarType::Enum => Value::Enum(ValueEnum::loose("example")),
1095 ScalarType::E8s => Value::E8s(E8s::from_atomic(0)),
1096 ScalarType::E18s => Value::E18s(E18s::from_atomic(0)),
1097 ScalarType::Float32 => {
1098 Value::Float32(Float32::try_new(0.0).expect("Float32 sample should be finite"))
1099 }
1100 ScalarType::Float64 => {
1101 Value::Float64(Float64::try_new(0.0).expect("Float64 sample should be finite"))
1102 }
1103 ScalarType::Int => Value::Int(0),
1104 ScalarType::Int128 => Value::Int128(Int128::from(0i128)),
1105 ScalarType::IntBig => Value::IntBig(Int::from(0i32)),
1106 ScalarType::Principal => Value::Principal(Principal::anonymous()),
1107 ScalarType::Subaccount => Value::Subaccount(Subaccount::dummy(2)),
1108 ScalarType::Text => Value::Text("text".to_string()),
1109 ScalarType::Timestamp => Value::Timestamp(Timestamp::EPOCH),
1110 ScalarType::Uint => Value::Uint(0),
1111 ScalarType::Uint128 => Value::Uint128(Nat128::from(0u128)),
1112 ScalarType::UintBig => Value::UintBig(Nat::from(0u64)),
1113 ScalarType::Ulid => Value::Ulid(Ulid::nil()),
1114 ScalarType::Unit => Value::Unit,
1115 }
1116 }
1117
1118 fn mismatching_value_for_scalar(scalar: ScalarType) -> Value {
1120 match scalar {
1121 ScalarType::Unit => Value::Bool(false),
1122 _ => Value::Unit,
1123 }
1124 }
1125
1126 fn field(name: &'static str, kind: EntityFieldKind) -> EntityFieldModel {
1127 EntityFieldModel { name, kind }
1128 }
1129
    // Scalar-only test entity: a Ulid primary key ("id") plus one Text, Uint,
    // Timestamp and Bool field, with no secondary indexes. Used by the
    // validate_model scalar/text predicate tests below.
    crate::test_entity_schema! {
        ScalarPredicateEntity,
        id = Ulid,
        path = "predicate_validate::ScalarEntity",
        entity_name = "ScalarEntity",
        primary_key = "id",
        pk_index = 0,
        fields = [
            ("id", EntityFieldKind::Ulid),
            ("email", EntityFieldKind::Text),
            ("age", EntityFieldKind::Uint),
            ("created_at", EntityFieldKind::Timestamp),
            ("active", EntityFieldKind::Bool),
        ],
        indexes = [],
    }
1146
    // Collection test entity: a Ulid primary key plus one field of each
    // collection kind -- a list of Text ("tags"), a set of Principal
    // ("principals"), and a Text -> Uint map ("attributes"). Used by the
    // collection and map-contains predicate tests below.
    crate::test_entity_schema! {
        CollectionPredicateEntity,
        id = Ulid,
        path = "predicate_validate::CollectionEntity",
        entity_name = "CollectionEntity",
        primary_key = "id",
        pk_index = 0,
        fields = [
            ("id", EntityFieldKind::Ulid),
            ("tags", EntityFieldKind::List(&EntityFieldKind::Text)),
            ("principals", EntityFieldKind::Set(&EntityFieldKind::Principal)),
            (
                "attributes",
                EntityFieldKind::Map {
                    key: &EntityFieldKind::Text,
                    value: &EntityFieldKind::Uint,
                }
            ),
        ],
        indexes = [],
    }
1168
1169 #[test]
1170 fn validate_model_accepts_scalars_and_coercions() {
1171 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1172
1173 let predicate = Predicate::And(vec![
1174 FieldRef::new("id").eq(Ulid::nil()),
1175 FieldRef::new("email").text_eq_ci("User@example.com"),
1176 FieldRef::new("age").lt(30u32),
1177 ]);
1178
1179 assert!(validate_model(model, &predicate).is_ok());
1180 }
1181
1182 #[test]
1183 fn validate_model_accepts_collections_and_map_contains() {
1184 let model = <CollectionPredicateEntity as EntitySchema>::MODEL;
1185
1186 let predicate = Predicate::And(vec![
1187 FieldRef::new("tags").is_empty(),
1188 FieldRef::new("principals").is_not_empty(),
1189 FieldRef::new("attributes").map_contains_entry("k", 1u64, CoercionId::Strict),
1190 ]);
1191
1192 assert!(validate_model(model, &predicate).is_ok());
1193
1194 let bad =
1195 FieldRef::new("attributes").map_contains_entry("k", 1u64, CoercionId::TextCasefold);
1196
1197 assert!(matches!(
1198 validate_model(model, &bad),
1199 Err(ValidateError::InvalidCoercion { .. })
1200 ));
1201 }
1202
1203 #[test]
1204 fn validate_model_rejects_unsupported_fields() {
1205 let model = InvalidEntityModelBuilder::from_fields(
1206 vec![
1207 field("id", EntityFieldKind::Ulid),
1208 field("broken", EntityFieldKind::Unsupported),
1209 ],
1210 0,
1211 );
1212
1213 let predicate = FieldRef::new("broken").eq(1u64);
1214
1215 assert!(matches!(
1216 validate_model(&model, &predicate),
1217 Err(ValidateError::UnsupportedFieldType { field }) if field == "broken"
1218 ));
1219 }
1220
1221 #[test]
1222 fn validate_model_accepts_text_contains() {
1223 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1224
1225 let predicate = FieldRef::new("email").text_contains("example");
1226 assert!(validate_model(model, &predicate).is_ok());
1227
1228 let predicate = FieldRef::new("email").text_contains_ci("EXAMPLE");
1229 assert!(validate_model(model, &predicate).is_ok());
1230 }
1231
1232 #[test]
1233 fn validate_model_rejects_text_contains_on_non_text() {
1234 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1235
1236 let predicate = FieldRef::new("age").text_contains("1");
1237 assert!(matches!(
1238 validate_model(model, &predicate),
1239 Err(ValidateError::InvalidOperator { field, op })
1240 if field == "age" && op == "text_contains"
1241 ));
1242 }
1243
1244 #[test]
1245 fn scalar_registry_covers_all_variants_exactly_once() {
1246 let scalars = registry_scalars();
1247 let mut names = BTreeSet::new();
1248 let mut seen = [false; SCALAR_TYPE_VARIANT_COUNT];
1249
1250 for scalar in scalars {
1251 let index = scalar_index(scalar.clone());
1252 assert!(!seen[index], "duplicate scalar entry: {scalar:?}");
1253 seen[index] = true;
1254
1255 let name = format!("{scalar:?}");
1256 assert!(names.insert(name.clone()), "duplicate scalar entry: {name}");
1257 }
1258
1259 let mut missing = Vec::new();
1260 for (index, was_seen) in seen.iter().enumerate() {
1261 if !*was_seen {
1262 let scalar = scalar_from_index(index).expect("index is in range");
1263 missing.push(format!("{scalar:?}"));
1264 }
1265 }
1266
1267 assert!(missing.is_empty(), "missing scalar entries: {missing:?}");
1268 assert_eq!(names.len(), SCALAR_TYPE_VARIANT_COUNT);
1269 }
1270
1271 #[test]
1272 fn scalar_registry_preserves_family_and_matching() {
1273 for scalar in registry_scalars() {
1274 let expected_family = legacy_family(scalar.clone());
1275 let matching = sample_value_for_scalar(scalar.clone());
1276 let mismatching = mismatching_value_for_scalar(scalar.clone());
1277 let expected_match = legacy_matches_value(scalar.clone(), &matching);
1278 let expected_mismatch = legacy_matches_value(scalar.clone(), &mismatching);
1279
1280 assert_eq!(scalar.family(), expected_family);
1281 assert_eq!(scalar.matches_value(&matching), expected_match);
1282 assert_eq!(scalar.matches_value(&mismatching), expected_mismatch);
1283 }
1284 }
1285
1286 #[test]
1287 fn scalar_registry_preserves_arithmetic_support() {
1288 for scalar in registry_scalars() {
1289 let expected = legacy_supports_arithmetic(scalar.clone());
1290
1291 assert_eq!(scalar.supports_arithmetic(), expected);
1292 }
1293 }
1294
1295 #[test]
1296 fn scalar_registry_preserves_keyability() {
1297 for scalar in registry_scalars() {
1298 let expected = legacy_is_keyable(scalar.clone());
1299
1300 assert_eq!(scalar.is_keyable(), expected);
1301 }
1302 }
1303
1304 #[test]
1305 fn scalar_keyability_matches_storage_key_conversion_paths() {
1306 for scalar in registry_scalars() {
1307 let value = sample_value_for_scalar(scalar.clone());
1308 let scalar_keyable = scalar.is_keyable();
1309 let value_keyable = value.as_storage_key().is_some();
1310 let storage_keyable = StorageKey::try_from_value(&value).is_ok();
1311
1312 assert_eq!(
1313 value_keyable, scalar_keyable,
1314 "Value::as_storage_key drift for scalar {scalar:?}"
1315 );
1316 assert_eq!(
1317 storage_keyable, scalar_keyable,
1318 "StorageKey::try_from_value drift for scalar {scalar:?}"
1319 );
1320 assert_eq!(
1321 value.as_storage_key(),
1322 StorageKey::try_from_value(&value).ok()
1323 );
1324 }
1325 }
1326
1327 #[test]
1328 fn scalar_registry_preserves_ordering_support() {
1329 for scalar in registry_scalars() {
1330 let expected = legacy_supports_ordering(scalar.clone());
1331
1332 assert_eq!(scalar.supports_ordering(), expected);
1333 assert_eq!(scalar.is_orderable(), expected);
1334 }
1335 }
1336
1337 #[test]
1338 fn scalar_registry_preserves_equality_support() {
1339 for scalar in registry_scalars() {
1340 let expected = legacy_supports_equality(scalar.clone());
1341
1342 assert_eq!(scalar.supports_equality(), expected);
1343 }
1344 }
1345
1346 #[test]
1347 fn validate_ordering_matches_legacy_support() {
1348 let coercion = CoercionSpec::default();
1349 let op = CompareOp::Lt;
1350
1351 for scalar in registry_scalars() {
1352 let field_type = FieldType::Scalar(scalar.clone());
1353 let value = sample_value_for_scalar(scalar.clone());
1354 let result = validate_ordering("field", &field_type, &value, &coercion, op);
1355 let is_orderable = legacy_supports_ordering(scalar);
1356
1357 if is_orderable {
1358 assert!(result.is_ok(), "scalar should be orderable: {field_type:?}");
1359 } else {
1360 assert!(matches!(result, Err(ValidateError::InvalidOperator { .. })));
1361 }
1362 }
1363 }
1364}