1use super::{
2 ast::{CompareOp, ComparePredicate, Predicate, UnsupportedQueryFeature},
3 coercion::{CoercionId, CoercionSpec, supports_coercion},
4};
5use crate::{
6 db::identity::{EntityName, EntityNameError, IndexName, IndexNameError},
7 model::{entity::EntityModel, field::EntityFieldKind, index::IndexModel},
8 traits::FieldValueKind,
9 value::{CoercionFamily, CoercionFamilyExt, Value},
10};
11use std::{
12 collections::{BTreeMap, BTreeSet},
13 fmt,
14};
15
/// Scalar (non-collection) field types understood by predicate validation.
///
/// Each variant is produced from the `EntityFieldKind` variant of the same
/// name by `field_type_from_model_kind`. Per-type capabilities (coercion
/// family, literal matching, ordering, keyability, numeric coercion) are
/// looked up through the `scalar_registry!` macro in `impl ScalarType`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ScalarType {
    Account,
    Blob,
    Bool,
    Date,
    Decimal,
    Duration,
    Enum,
    E8s,
    E18s,
    Float32,
    Float64,
    Int,
    Int128,
    IntBig,
    Principal,
    Subaccount,
    Text,
    Timestamp,
    Uint,
    Uint128,
    UintBig,
    Ulid,
    Unit,
}
53
/// Registry callback: expands to a `match` mapping each `ScalarType` variant
/// to the `$coercion_family` expression recorded in its registry entry.
///
/// Invoked by `ScalarType::coercion_family` as
/// `scalar_registry!(scalar_coercion_family_from_registry, self)`. The
/// `@entries` list is expected to come from `scalar_registry!` (defined
/// outside this view — confirm the entry shape there). Every entry field other
/// than the scalar name and `$coercion_family` is matched but unused.
macro_rules! scalar_coercion_family_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $coercion_family, )*
        }
    };
}
62
/// Registry callback: expands to a `matches!` that is true when the
/// `(scalar, value)` pair matches one of the registry's
/// `(ScalarType::$scalar, $value_pat)` combinations.
///
/// Invoked by `ScalarType::matches_value` as
/// `scalar_registry!(scalar_matches_value_from_registry, self, value)`.
/// Only the scalar name and `$value_pat` of each entry are used.
macro_rules! scalar_matches_value_from_registry {
    ( @args $self:expr, $value:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        matches!(
            ($self, $value),
            $( (ScalarType::$scalar, $value_pat) )|*
        )
    };
}
71
/// Registry callback: expands to a `match` returning each scalar's
/// `supports_numeric_coercion` flag.
///
/// Invoked by `ScalarType::supports_numeric_coercion`; `ensure_coercion`
/// consults the result when validating `NumericWiden` coercions.
macro_rules! scalar_supports_numeric_coercion_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_numeric_coercion, )*
        }
    };
}
79
/// Registry callback: expands to a `match` returning each scalar's
/// `is_keyable` flag.
///
/// Invoked by `ScalarType::is_keyable`; `SchemaInfo::from_entity_model`
/// rejects primary keys whose type is not keyable.
macro_rules! scalar_is_keyable_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $is_keyable, )*
        }
    };
}
87
/// Registry callback: expands to a `match` returning each scalar's
/// `supports_ordering` flag.
///
/// Invoked by `ScalarType::supports_ordering`, which backs the
/// `Lt`/`Lte`/`Gt`/`Gte` operator checks.
macro_rules! scalar_supports_ordering_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_ordering, )*
        }
    };
}
95
96impl ScalarType {
97 #[must_use]
98 pub const fn coercion_family(&self) -> CoercionFamily {
99 scalar_registry!(scalar_coercion_family_from_registry, self)
100 }
101
102 #[must_use]
103 pub const fn is_orderable(&self) -> bool {
104 self.supports_ordering()
107 }
108
109 #[must_use]
110 pub const fn matches_value(&self, value: &Value) -> bool {
111 scalar_registry!(scalar_matches_value_from_registry, self, value)
112 }
113
114 #[must_use]
115 pub const fn supports_numeric_coercion(&self) -> bool {
116 scalar_registry!(scalar_supports_numeric_coercion_from_registry, self)
117 }
118
119 #[must_use]
120 pub const fn is_keyable(&self) -> bool {
121 scalar_registry!(scalar_is_keyable_from_registry, self)
122 }
123
124 #[must_use]
125 pub const fn supports_ordering(&self) -> bool {
126 scalar_registry!(scalar_supports_ordering_from_registry, self)
127 }
128}
129
/// Resolved type of an entity field, as seen by predicate validation.
///
/// Built from `EntityFieldKind` by `field_type_from_model_kind`; reference
/// kinds collapse to the type of their key.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// A single scalar value.
    Scalar(ScalarType),
    /// List whose elements have the boxed type.
    List(Box<Self>),
    /// Set whose elements have the boxed type.
    Set(Box<Self>),
    /// Map from the `key` type to the `value` type. `value_kind` reports maps
    /// as non-queryable.
    Map { key: Box<Self>, value: Box<Self> },
    /// Opaque structured value; `queryable` is copied from the model.
    Structured { queryable: bool },
}
148
149impl FieldType {
150 #[must_use]
151 pub const fn value_kind(&self) -> FieldValueKind {
152 match self {
153 Self::Scalar(_) => FieldValueKind::Atomic,
154 Self::List(_) | Self::Set(_) => FieldValueKind::Structured { queryable: true },
155 Self::Map { .. } => FieldValueKind::Structured { queryable: false },
156 Self::Structured { queryable } => FieldValueKind::Structured {
157 queryable: *queryable,
158 },
159 }
160 }
161
162 #[must_use]
163 pub const fn coercion_family(&self) -> Option<CoercionFamily> {
164 match self {
165 Self::Scalar(inner) => Some(inner.coercion_family()),
166 Self::List(_) | Self::Set(_) | Self::Map { .. } => Some(CoercionFamily::Collection),
167 Self::Structured { .. } => None,
168 }
169 }
170
171 #[must_use]
172 pub const fn is_text(&self) -> bool {
173 matches!(self, Self::Scalar(ScalarType::Text))
174 }
175
176 #[must_use]
177 pub const fn is_collection(&self) -> bool {
178 matches!(self, Self::List(_) | Self::Set(_) | Self::Map { .. })
179 }
180
181 #[must_use]
182 pub const fn is_list_like(&self) -> bool {
183 matches!(self, Self::List(_) | Self::Set(_))
184 }
185
186 #[must_use]
187 pub const fn is_map(&self) -> bool {
188 matches!(self, Self::Map { .. })
189 }
190
191 #[must_use]
192 pub fn map_types(&self) -> Option<(&Self, &Self)> {
193 match self {
194 Self::Map { key, value } => Some((key.as_ref(), value.as_ref())),
195 _ => {
196 None
198 }
199 }
200 }
201
202 #[must_use]
203 pub const fn is_orderable(&self) -> bool {
204 match self {
205 Self::Scalar(inner) => inner.is_orderable(),
206 _ => false,
207 }
208 }
209
210 #[must_use]
211 pub const fn is_keyable(&self) -> bool {
212 match self {
213 Self::Scalar(inner) => inner.is_keyable(),
214 _ => false,
215 }
216 }
217
218 #[must_use]
219 pub const fn supports_numeric_coercion(&self) -> bool {
220 match self {
221 Self::Scalar(inner) => inner.supports_numeric_coercion(),
222 _ => false,
223 }
224 }
225}
226
227fn validate_index_fields(
228 fields: &BTreeMap<String, FieldType>,
229 indexes: &[&IndexModel],
230) -> Result<(), ValidateError> {
231 let mut seen_names = BTreeSet::new();
232 for index in indexes {
233 if seen_names.contains(index.name) {
234 return Err(ValidateError::DuplicateIndexName {
235 name: index.name.to_string(),
236 });
237 }
238 seen_names.insert(index.name);
239
240 let mut seen = BTreeSet::new();
241 for field in index.fields {
242 if !fields.contains_key(*field) {
243 return Err(ValidateError::IndexFieldUnknown {
244 index: **index,
245 field: (*field).to_string(),
246 });
247 }
248 if seen.contains(*field) {
249 return Err(ValidateError::IndexFieldDuplicate {
250 index: **index,
251 field: (*field).to_string(),
252 });
253 }
254 seen.insert(*field);
255
256 let field_type = fields
257 .get(*field)
258 .expect("index field existence checked above");
259 if matches!(field_type, FieldType::Map { .. }) {
262 return Err(ValidateError::IndexFieldMapNotQueryable {
263 index: **index,
264 field: (*field).to_string(),
265 });
266 }
267 if !field_type.value_kind().is_queryable() {
268 return Err(ValidateError::IndexFieldNotQueryable {
269 index: **index,
270 field: (*field).to_string(),
271 });
272 }
273 }
274 }
275
276 Ok(())
277}
278
279#[derive(Clone, Debug)]
287pub struct SchemaInfo {
288 fields: BTreeMap<String, FieldType>,
289}
290
291impl SchemaInfo {
292 #[must_use]
293 pub(crate) fn field(&self, name: &str) -> Option<&FieldType> {
294 self.fields.get(name)
295 }
296
297 pub fn from_entity_model(model: &EntityModel) -> Result<Self, ValidateError> {
298 let entity_name = EntityName::try_from_str(model.entity_name).map_err(|err| {
300 ValidateError::InvalidEntityName {
301 name: model.entity_name.to_string(),
302 source: err,
303 }
304 })?;
305
306 if !model
307 .fields
308 .iter()
309 .any(|field| std::ptr::eq(field, model.primary_key))
310 {
311 return Err(ValidateError::InvalidPrimaryKey {
312 field: model.primary_key.name.to_string(),
313 });
314 }
315
316 let mut fields = BTreeMap::new();
317 for field in model.fields {
318 if fields.contains_key(field.name) {
319 return Err(ValidateError::DuplicateField {
320 field: field.name.to_string(),
321 });
322 }
323 let ty = field_type_from_model_kind(&field.kind);
324 fields.insert(field.name.to_string(), ty);
325 }
326
327 let pk_field_type = fields
328 .get(model.primary_key.name)
329 .expect("primary key verified above");
330 if !pk_field_type.is_keyable() {
331 return Err(ValidateError::InvalidPrimaryKeyType {
332 field: model.primary_key.name.to_string(),
333 });
334 }
335
336 validate_index_fields(&fields, model.indexes)?;
337 for index in model.indexes {
338 IndexName::try_from_parts(&entity_name, index.fields).map_err(|err| {
339 ValidateError::InvalidIndexName {
340 index: **index,
341 source: err,
342 }
343 })?;
344 }
345
346 Ok(Self { fields })
347 }
348}
349
/// Errors raised while validating an entity schema or a query predicate.
///
/// Schema-level variants come from [`SchemaInfo::from_entity_model`];
/// predicate-level variants come from [`validate`] and its helpers.
#[derive(Debug, thiserror::Error)]
pub enum ValidateError {
    /// The model's entity name was rejected by `EntityName::try_from_str`.
    // NOTE(review): the message embeds `{source}` and the field is also exposed
    // via `#[source]`, so chain-walking reporters may print the cause twice.
    #[error("invalid entity name '{name}': {source}")]
    InvalidEntityName {
        name: String,
        #[source]
        source: EntityNameError,
    },

    /// `IndexName::try_from_parts` could not build a name for this index.
    // NOTE(review): same message/source duplication as `InvalidEntityName`.
    #[error("invalid index name for '{index}': {source}")]
    InvalidIndexName {
        index: IndexModel,
        #[source]
        source: IndexNameError,
    },

    /// A predicate referenced a field the schema does not declare.
    #[error("unknown field '{field}'")]
    UnknownField { field: String },

    /// A predicate referenced a field whose value kind is not queryable, or
    /// whose type has no coercion family.
    #[error("field '{field}' is not queryable")]
    NonQueryableFieldType { field: String },

    /// The entity model declares the same field name twice.
    #[error("duplicate field '{field}'")]
    DuplicateField { field: String },

    /// The predicate uses a feature rejected by
    /// `reject_unsupported_query_features` or by the map-field check in
    /// `ensure_field`.
    #[error("{0}")]
    UnsupportedQueryFeature(#[from] UnsupportedQueryFeature),

    /// The primary key is not one of the model's field entries (checked by
    /// address identity).
    #[error("primary key '{field}' not present in entity fields")]
    InvalidPrimaryKey { field: String },

    /// The primary key's resolved type is not keyable.
    #[error("primary key '{field}' has a non-keyable type")]
    InvalidPrimaryKeyType { field: String },

    /// An index lists a field the entity does not declare.
    #[error("index '{index}' references unknown field '{field}'")]
    IndexFieldUnknown { index: IndexModel, field: String },

    /// An index lists a field whose value kind is not queryable.
    #[error("index '{index}' references non-queryable field '{field}'")]
    IndexFieldNotQueryable { index: IndexModel, field: String },

    /// An index lists a map field; maps are reported separately from other
    /// non-queryable fields.
    #[error(
        "index '{index}' references map field '{field}'; map fields are not queryable in icydb 0.7"
    )]
    IndexFieldMapNotQueryable { index: IndexModel, field: String },

    /// An index lists the same field more than once.
    #[error("index '{index}' repeats field '{field}'")]
    IndexFieldDuplicate { index: IndexModel, field: String },

    /// Two indexes on the same entity share a name.
    #[error("duplicate index name '{name}'")]
    DuplicateIndexName { name: String },

    /// The operator is not applicable to the field's type.
    #[error("operator {op} is not valid for field '{field}'")]
    InvalidOperator { field: String, op: String },

    /// The requested coercion is not applicable to the field/literal pair.
    #[error("coercion {coercion:?} is not valid for field '{field}'")]
    InvalidCoercion { field: String, coercion: CoercionId },

    /// The literal has the wrong shape or type for the field.
    #[error("invalid literal for field '{field}': {message}")]
    InvalidLiteral { field: String, message: String },
}
411
412pub fn reject_unsupported_query_features(
414 predicate: &Predicate,
415) -> Result<(), UnsupportedQueryFeature> {
416 match predicate {
417 Predicate::True
418 | Predicate::False
419 | Predicate::Compare(_)
420 | Predicate::IsNull { .. }
421 | Predicate::IsMissing { .. }
422 | Predicate::IsEmpty { .. }
423 | Predicate::IsNotEmpty { .. }
424 | Predicate::TextContains { .. }
425 | Predicate::TextContainsCi { .. } => Ok(()),
426 Predicate::And(children) | Predicate::Or(children) => {
427 for child in children {
428 reject_unsupported_query_features(child)?;
429 }
430
431 Ok(())
432 }
433 Predicate::Not(inner) => reject_unsupported_query_features(inner),
434 Predicate::MapContainsKey { field, .. }
435 | Predicate::MapContainsValue { field, .. }
436 | Predicate::MapContainsEntry { field, .. } => Err(UnsupportedQueryFeature::MapPredicate {
437 field: field.clone(),
438 }),
439 }
440}
441
442pub fn validate(schema: &SchemaInfo, predicate: &Predicate) -> Result<(), ValidateError> {
443 reject_unsupported_query_features(predicate)?;
444
445 match predicate {
446 Predicate::True | Predicate::False => Ok(()),
447 Predicate::And(children) | Predicate::Or(children) => {
448 for child in children {
449 validate(schema, child)?;
450 }
451 Ok(())
452 }
453 Predicate::Not(inner) => validate(schema, inner),
454 Predicate::Compare(cmp) => validate_compare(schema, cmp),
455 Predicate::IsNull { field } | Predicate::IsMissing { field } => {
456 let _field_type = ensure_field(schema, field)?;
457 Ok(())
458 }
459 Predicate::IsEmpty { field } => {
460 let field_type = ensure_field(schema, field)?;
461 if field_type.is_text() || field_type.is_collection() {
462 Ok(())
463 } else {
464 Err(invalid_operator(field, "is_empty"))
465 }
466 }
467 Predicate::IsNotEmpty { field } => {
468 let field_type = ensure_field(schema, field)?;
469 if field_type.is_text() || field_type.is_collection() {
470 Ok(())
471 } else {
472 Err(invalid_operator(field, "is_not_empty"))
473 }
474 }
475 Predicate::MapContainsKey {
476 field,
477 key,
478 coercion,
479 } => validate_map_key(schema, field, key, coercion),
480 Predicate::MapContainsValue {
481 field,
482 value,
483 coercion,
484 } => validate_map_value(schema, field, value, coercion),
485 Predicate::MapContainsEntry {
486 field,
487 key,
488 value,
489 coercion,
490 } => validate_map_entry(schema, field, key, value, coercion),
491 Predicate::TextContains { field, value } => {
492 validate_text_contains(schema, field, value, "text_contains")
493 }
494 Predicate::TextContainsCi { field, value } => {
495 validate_text_contains(schema, field, value, "text_contains_ci")
496 }
497 }
498}
499
500pub fn validate_model(model: &EntityModel, predicate: &Predicate) -> Result<(), ValidateError> {
501 let schema = SchemaInfo::from_entity_model(model)?;
502 validate(&schema, predicate)
503}
504
505fn validate_compare(schema: &SchemaInfo, cmp: &ComparePredicate) -> Result<(), ValidateError> {
506 let field_type = ensure_field(schema, &cmp.field)?;
507
508 match cmp.op {
509 CompareOp::Eq | CompareOp::Ne => {
510 validate_eq_ne(&cmp.field, field_type, &cmp.value, &cmp.coercion)
511 }
512 CompareOp::Lt | CompareOp::Lte | CompareOp::Gt | CompareOp::Gte => {
513 validate_ordering(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
514 }
515 CompareOp::In | CompareOp::NotIn => {
516 validate_in(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
517 }
518 CompareOp::Contains => validate_contains(&cmp.field, field_type, &cmp.value, &cmp.coercion),
519 CompareOp::StartsWith | CompareOp::EndsWith => {
520 validate_text_compare(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
521 }
522 }
523}
524
525fn validate_eq_ne(
526 field: &str,
527 field_type: &FieldType,
528 value: &Value,
529 coercion: &CoercionSpec,
530) -> Result<(), ValidateError> {
531 if field_type.is_list_like() {
532 ensure_list_literal(field, value, field_type)?;
533 } else if field_type.is_map() {
534 ensure_map_literal(field, value, field_type)?;
535 } else {
536 ensure_scalar_literal(field, value)?;
537 }
538
539 ensure_coercion(field, field_type, value, coercion)
540}
541
542fn validate_ordering(
543 field: &str,
544 field_type: &FieldType,
545 value: &Value,
546 coercion: &CoercionSpec,
547 op: CompareOp,
548) -> Result<(), ValidateError> {
549 if matches!(coercion.id, CoercionId::CollectionElement) {
550 return Err(ValidateError::InvalidCoercion {
551 field: field.to_string(),
552 coercion: coercion.id,
553 });
554 }
555
556 if !field_type.is_orderable() {
557 return Err(invalid_operator(field, format!("{op:?}")));
558 }
559
560 ensure_scalar_literal(field, value)?;
561
562 ensure_coercion(field, field_type, value, coercion)
563}
564
565fn validate_in(
567 field: &str,
568 field_type: &FieldType,
569 value: &Value,
570 coercion: &CoercionSpec,
571 op: CompareOp,
572) -> Result<(), ValidateError> {
573 if field_type.is_collection() {
574 return Err(invalid_operator(field, format!("{op:?}")));
575 }
576
577 let Value::List(items) = value else {
578 return Err(invalid_literal(field, "expected list literal"));
579 };
580
581 for item in items {
582 ensure_coercion(field, field_type, item, coercion)?;
583 }
584
585 Ok(())
586}
587
588fn validate_contains(
590 field: &str,
591 field_type: &FieldType,
592 value: &Value,
593 coercion: &CoercionSpec,
594) -> Result<(), ValidateError> {
595 if field_type.is_text() {
596 return Err(invalid_operator(
598 field,
599 format!("{:?}", CompareOp::Contains),
600 ));
601 }
602
603 let element_type = match field_type {
604 FieldType::List(inner) | FieldType::Set(inner) => inner.as_ref(),
605 _ => {
606 return Err(invalid_operator(
607 field,
608 format!("{:?}", CompareOp::Contains),
609 ));
610 }
611 };
612
613 if matches!(coercion.id, CoercionId::TextCasefold) {
614 return Err(ValidateError::InvalidCoercion {
616 field: field.to_string(),
617 coercion: coercion.id,
618 });
619 }
620
621 ensure_coercion(field, element_type, value, coercion)
622}
623
624fn validate_text_compare(
626 field: &str,
627 field_type: &FieldType,
628 value: &Value,
629 coercion: &CoercionSpec,
630 op: CompareOp,
631) -> Result<(), ValidateError> {
632 if !field_type.is_text() {
633 return Err(invalid_operator(field, format!("{op:?}")));
634 }
635
636 ensure_text_literal(field, value)?;
637
638 ensure_coercion(field, field_type, value, coercion)
639}
640
641fn ensure_map_types<'a>(
643 schema: &'a SchemaInfo,
644 field: &str,
645 op: &str,
646) -> Result<(&'a FieldType, &'a FieldType), ValidateError> {
647 let field_type = ensure_field(schema, field)?;
648 field_type
649 .map_types()
650 .ok_or_else(|| invalid_operator(field, op))
651}
652
653fn validate_map_key(
654 schema: &SchemaInfo,
655 field: &str,
656 key: &Value,
657 coercion: &CoercionSpec,
658) -> Result<(), ValidateError> {
659 ensure_no_text_casefold(field, coercion)?;
660
661 let (key_type, _) = ensure_map_types(schema, field, "map_contains_key")?;
662
663 ensure_coercion(field, key_type, key, coercion)
664}
665
666fn validate_map_value(
667 schema: &SchemaInfo,
668 field: &str,
669 value: &Value,
670 coercion: &CoercionSpec,
671) -> Result<(), ValidateError> {
672 ensure_no_text_casefold(field, coercion)?;
673
674 let (_, value_type) = ensure_map_types(schema, field, "map_contains_value")?;
675
676 ensure_coercion(field, value_type, value, coercion)
677}
678
679fn validate_map_entry(
680 schema: &SchemaInfo,
681 field: &str,
682 key: &Value,
683 value: &Value,
684 coercion: &CoercionSpec,
685) -> Result<(), ValidateError> {
686 ensure_no_text_casefold(field, coercion)?;
687
688 let (key_type, value_type) = ensure_map_types(schema, field, "map_contains_entry")?;
689
690 ensure_coercion(field, key_type, key, coercion)?;
691 ensure_coercion(field, value_type, value, coercion)?;
692
693 Ok(())
694}
695
696fn validate_text_contains(
698 schema: &SchemaInfo,
699 field: &str,
700 value: &Value,
701 op: &str,
702) -> Result<(), ValidateError> {
703 let field_type = ensure_field(schema, field)?;
704 if !field_type.is_text() {
705 return Err(invalid_operator(field, op));
706 }
707
708 ensure_text_literal(field, value)?;
709
710 Ok(())
711}
712
713fn ensure_field<'a>(schema: &'a SchemaInfo, field: &str) -> Result<&'a FieldType, ValidateError> {
714 let field_type = schema
715 .field(field)
716 .ok_or_else(|| ValidateError::UnknownField {
717 field: field.to_string(),
718 })?;
719
720 if matches!(field_type, FieldType::Map { .. }) {
721 return Err(UnsupportedQueryFeature::MapPredicate {
722 field: field.to_string(),
723 }
724 .into());
725 }
726
727 if !field_type.value_kind().is_queryable() {
728 return Err(ValidateError::NonQueryableFieldType {
729 field: field.to_string(),
730 });
731 }
732
733 Ok(field_type)
734}
735
736fn invalid_operator(field: &str, op: impl fmt::Display) -> ValidateError {
737 ValidateError::InvalidOperator {
738 field: field.to_string(),
739 op: op.to_string(),
740 }
741}
742
743fn invalid_literal(field: &str, msg: &str) -> ValidateError {
744 ValidateError::InvalidLiteral {
745 field: field.to_string(),
746 message: msg.to_string(),
747 }
748}
749
750fn ensure_no_text_casefold(field: &str, coercion: &CoercionSpec) -> Result<(), ValidateError> {
752 if matches!(coercion.id, CoercionId::TextCasefold) {
753 return Err(ValidateError::InvalidCoercion {
754 field: field.to_string(),
755 coercion: coercion.id,
756 });
757 }
758
759 Ok(())
760}
761
762fn ensure_text_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
764 if !matches!(value, Value::Text(_)) {
765 return Err(invalid_literal(field, "expected text literal"));
766 }
767
768 Ok(())
769}
770
771fn ensure_scalar_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
773 if matches!(value, Value::List(_)) {
774 return Err(invalid_literal(field, "expected scalar literal"));
775 }
776
777 Ok(())
778}
779
780fn ensure_coercion(
781 field: &str,
782 field_type: &FieldType,
783 literal: &Value,
784 coercion: &CoercionSpec,
785) -> Result<(), ValidateError> {
786 if matches!(coercion.id, CoercionId::TextCasefold) && !field_type.is_text() {
787 return Err(ValidateError::InvalidCoercion {
789 field: field.to_string(),
790 coercion: coercion.id,
791 });
792 }
793
794 if matches!(coercion.id, CoercionId::NumericWiden)
799 && (!field_type.supports_numeric_coercion() || !literal.supports_numeric_coercion())
800 {
801 return Err(ValidateError::InvalidCoercion {
802 field: field.to_string(),
803 coercion: coercion.id,
804 });
805 }
806
807 if !matches!(coercion.id, CoercionId::NumericWiden) {
808 let left_family =
809 field_type
810 .coercion_family()
811 .ok_or_else(|| ValidateError::NonQueryableFieldType {
812 field: field.to_string(),
813 })?;
814 let right_family = literal.coercion_family();
815
816 if !supports_coercion(left_family, right_family, coercion.id) {
817 return Err(ValidateError::InvalidCoercion {
818 field: field.to_string(),
819 coercion: coercion.id,
820 });
821 }
822 }
823
824 if matches!(
825 coercion.id,
826 CoercionId::Strict | CoercionId::CollectionElement
827 ) && !literal_matches_type(literal, field_type)
828 {
829 return Err(invalid_literal(
830 field,
831 "literal type does not match field type",
832 ));
833 }
834
835 Ok(())
836}
837
838fn ensure_list_literal(
839 field: &str,
840 literal: &Value,
841 field_type: &FieldType,
842) -> Result<(), ValidateError> {
843 if !literal_matches_type(literal, field_type) {
844 return Err(invalid_literal(
845 field,
846 "list literal does not match field element type",
847 ));
848 }
849
850 Ok(())
851}
852
853fn ensure_map_literal(
854 field: &str,
855 literal: &Value,
856 field_type: &FieldType,
857) -> Result<(), ValidateError> {
858 if !literal_matches_type(literal, field_type) {
859 return Err(invalid_literal(
860 field,
861 "map literal does not match field key/value types",
862 ));
863 }
864
865 Ok(())
866}
867
868pub(crate) fn literal_matches_type(literal: &Value, field_type: &FieldType) -> bool {
869 match field_type {
870 FieldType::Scalar(inner) => inner.matches_value(literal),
871 FieldType::List(element) | FieldType::Set(element) => match literal {
872 Value::List(items) => items.iter().all(|item| literal_matches_type(item, element)),
873 _ => false,
874 },
875 FieldType::Map { key, value } => match literal {
876 Value::Map(entries) => {
877 if Value::validate_map_entries(entries.as_slice()).is_err() {
878 return false;
879 }
880
881 entries.iter().all(|(entry_key, entry_value)| {
882 literal_matches_type(entry_key, key) && literal_matches_type(entry_value, value)
883 })
884 }
885 _ => false,
886 },
887 FieldType::Structured { .. } => {
888 false
890 }
891 }
892}
893
894fn field_type_from_model_kind(kind: &EntityFieldKind) -> FieldType {
895 match kind {
896 EntityFieldKind::Account => FieldType::Scalar(ScalarType::Account),
897 EntityFieldKind::Blob => FieldType::Scalar(ScalarType::Blob),
898 EntityFieldKind::Bool => FieldType::Scalar(ScalarType::Bool),
899 EntityFieldKind::Date => FieldType::Scalar(ScalarType::Date),
900 EntityFieldKind::Decimal => FieldType::Scalar(ScalarType::Decimal),
901 EntityFieldKind::Duration => FieldType::Scalar(ScalarType::Duration),
902 EntityFieldKind::Enum => FieldType::Scalar(ScalarType::Enum),
903 EntityFieldKind::E8s => FieldType::Scalar(ScalarType::E8s),
904 EntityFieldKind::E18s => FieldType::Scalar(ScalarType::E18s),
905 EntityFieldKind::Float32 => FieldType::Scalar(ScalarType::Float32),
906 EntityFieldKind::Float64 => FieldType::Scalar(ScalarType::Float64),
907 EntityFieldKind::Int => FieldType::Scalar(ScalarType::Int),
908 EntityFieldKind::Int128 => FieldType::Scalar(ScalarType::Int128),
909 EntityFieldKind::IntBig => FieldType::Scalar(ScalarType::IntBig),
910 EntityFieldKind::Principal => FieldType::Scalar(ScalarType::Principal),
911 EntityFieldKind::Subaccount => FieldType::Scalar(ScalarType::Subaccount),
912 EntityFieldKind::Text => FieldType::Scalar(ScalarType::Text),
913 EntityFieldKind::Timestamp => FieldType::Scalar(ScalarType::Timestamp),
914 EntityFieldKind::Uint => FieldType::Scalar(ScalarType::Uint),
915 EntityFieldKind::Uint128 => FieldType::Scalar(ScalarType::Uint128),
916 EntityFieldKind::UintBig => FieldType::Scalar(ScalarType::UintBig),
917 EntityFieldKind::Ulid => FieldType::Scalar(ScalarType::Ulid),
918 EntityFieldKind::Unit => FieldType::Scalar(ScalarType::Unit),
919 EntityFieldKind::Ref { key_kind, .. } => field_type_from_model_kind(key_kind),
920 EntityFieldKind::List(inner) => {
921 FieldType::List(Box::new(field_type_from_model_kind(inner)))
922 }
923 EntityFieldKind::Set(inner) => FieldType::Set(Box::new(field_type_from_model_kind(inner))),
924 EntityFieldKind::Map { key, value } => FieldType::Map {
925 key: Box::new(field_type_from_model_kind(key)),
926 value: Box::new(field_type_from_model_kind(value)),
927 },
928 EntityFieldKind::Structured { queryable } => FieldType::Structured {
929 queryable: *queryable,
930 },
931 }
932}
933
934impl fmt::Display for FieldType {
935 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
936 match self {
937 Self::Scalar(inner) => write!(f, "{inner:?}"),
938 Self::List(inner) => write!(f, "List<{inner}>"),
939 Self::Set(inner) => write!(f, "Set<{inner}>"),
940 Self::Map { key, value } => write!(f, "Map<{key}, {value}>"),
941 Self::Structured { queryable } => {
942 write!(f, "Structured<queryable={queryable}>")
943 }
944 }
945 }
946}
947
948#[cfg(test)]
953mod tests {
954 use super::{FieldType, ScalarType, ValidateError, ensure_coercion, validate_model};
956 use crate::{
957 db::query::{
958 FieldRef,
959 predicate::{
960 CoercionId, CoercionSpec, CompareOp, ComparePredicate, Predicate,
961 UnsupportedQueryFeature,
962 },
963 },
964 model::field::{EntityFieldKind, EntityFieldModel},
965 test_fixtures::InvalidEntityModelBuilder,
966 traits::{EntitySchema, FieldValue},
967 types::{
968 Account, Date, Decimal, Duration, E8s, E18s, Float32, Float64, Int, Int128, Nat,
969 Nat128, Principal, Subaccount, Timestamp, Ulid,
970 },
971 value::{CoercionFamily, Value, ValueEnum},
972 };
973 use std::collections::BTreeSet;
974
975 fn registry_scalars() -> Vec<ScalarType> {
977 macro_rules! collect_scalars {
978 ( @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
979 vec![ $( ScalarType::$scalar ),* ]
980 };
981 ( @args $($ignore:tt)*; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
982 vec![ $( ScalarType::$scalar ),* ]
983 };
984 }
985
986 let scalars = scalar_registry!(collect_scalars);
987
988 scalars
989 }
990
991 const SCALAR_TYPE_VARIANT_COUNT: usize = 23;
993
994 fn scalar_index(scalar: ScalarType) -> usize {
996 match scalar {
997 ScalarType::Account => 0,
998 ScalarType::Blob => 1,
999 ScalarType::Bool => 2,
1000 ScalarType::Date => 3,
1001 ScalarType::Decimal => 4,
1002 ScalarType::Duration => 5,
1003 ScalarType::Enum => 6,
1004 ScalarType::E8s => 7,
1005 ScalarType::E18s => 8,
1006 ScalarType::Float32 => 9,
1007 ScalarType::Float64 => 10,
1008 ScalarType::Int => 11,
1009 ScalarType::Int128 => 12,
1010 ScalarType::IntBig => 13,
1011 ScalarType::Principal => 14,
1012 ScalarType::Subaccount => 15,
1013 ScalarType::Text => 16,
1014 ScalarType::Timestamp => 17,
1015 ScalarType::Uint => 18,
1016 ScalarType::Uint128 => 19,
1017 ScalarType::UintBig => 20,
1018 ScalarType::Ulid => 21,
1019 ScalarType::Unit => 22,
1020 }
1021 }
1022
1023 fn scalar_from_index(index: usize) -> Option<ScalarType> {
1025 let scalar = match index {
1026 0 => ScalarType::Account,
1027 1 => ScalarType::Blob,
1028 2 => ScalarType::Bool,
1029 3 => ScalarType::Date,
1030 4 => ScalarType::Decimal,
1031 5 => ScalarType::Duration,
1032 6 => ScalarType::Enum,
1033 7 => ScalarType::E8s,
1034 8 => ScalarType::E18s,
1035 9 => ScalarType::Float32,
1036 10 => ScalarType::Float64,
1037 11 => ScalarType::Int,
1038 12 => ScalarType::Int128,
1039 13 => ScalarType::IntBig,
1040 14 => ScalarType::Principal,
1041 15 => ScalarType::Subaccount,
1042 16 => ScalarType::Text,
1043 17 => ScalarType::Timestamp,
1044 18 => ScalarType::Uint,
1045 19 => ScalarType::Uint128,
1046 20 => ScalarType::UintBig,
1047 21 => ScalarType::Ulid,
1048 22 => ScalarType::Unit,
1049 _ => return None,
1050 };
1051
1052 Some(scalar)
1053 }
1054
1055 fn sample_value_for_scalar(scalar: ScalarType) -> Value {
1057 match scalar {
1058 ScalarType::Account => Value::Account(Account::dummy(1)),
1059 ScalarType::Blob => Value::Blob(vec![0u8, 1u8]),
1060 ScalarType::Bool => Value::Bool(true),
1061 ScalarType::Date => Value::Date(Date::EPOCH),
1062 ScalarType::Decimal => Value::Decimal(Decimal::ZERO),
1063 ScalarType::Duration => Value::Duration(Duration::ZERO),
1064 ScalarType::Enum => Value::Enum(ValueEnum::loose("example")),
1065 ScalarType::E8s => Value::E8s(E8s::from_atomic(0)),
1066 ScalarType::E18s => Value::E18s(E18s::from_atomic(0)),
1067 ScalarType::Float32 => {
1068 Value::Float32(Float32::try_new(0.0).expect("Float32 sample should be finite"))
1069 }
1070 ScalarType::Float64 => {
1071 Value::Float64(Float64::try_new(0.0).expect("Float64 sample should be finite"))
1072 }
1073 ScalarType::Int => Value::Int(0),
1074 ScalarType::Int128 => Value::Int128(Int128::from(0i128)),
1075 ScalarType::IntBig => Value::IntBig(Int::from(0i32)),
1076 ScalarType::Principal => Value::Principal(Principal::anonymous()),
1077 ScalarType::Subaccount => Value::Subaccount(Subaccount::dummy(2)),
1078 ScalarType::Text => Value::Text("text".to_string()),
1079 ScalarType::Timestamp => Value::Timestamp(Timestamp::EPOCH),
1080 ScalarType::Uint => Value::Uint(0),
1081 ScalarType::Uint128 => Value::Uint128(Nat128::from(0u128)),
1082 ScalarType::UintBig => Value::UintBig(Nat::from(0u64)),
1083 ScalarType::Ulid => Value::Ulid(Ulid::nil()),
1084 ScalarType::Unit => Value::Unit,
1085 }
1086 }
1087
1088 fn field(name: &'static str, kind: EntityFieldKind) -> EntityFieldModel {
1089 EntityFieldModel { name, kind }
1090 }
1091
// Scalar-only fixture entity: a ULID key plus text, unsigned-int, timestamp
// and bool fields, used by the scalar / text-operator validation tests.
// NOTE(review): `pk_index = 0` presumably indexes into `fields` (the `id`
// entry, matching `primary_key = "id"`) - keep `id` first; confirm against
// the `test_entity_schema!` macro.
crate::test_entity_schema! {
    ScalarPredicateEntity,
    id = Ulid,
    path = "predicate_validate::ScalarEntity",
    entity_name = "ScalarEntity",
    primary_key = "id",
    pk_index = 0,
    fields = [
        ("id", EntityFieldKind::Ulid),
        ("email", EntityFieldKind::Text),
        ("age", EntityFieldKind::Uint),
        ("created_at", EntityFieldKind::Timestamp),
        ("active", EntityFieldKind::Bool),
    ],
    indexes = [],
}
1108
// Collection fixture entity: one list, one set and one map field.
// The set (`principals`) backs the accepted `Contains` test, and the map
// (`attributes`) backs the map-predicate rejection test.
// NOTE(review): `pk_index = 0` presumably refers to the `id` entry in
// `fields`; confirm against the `test_entity_schema!` macro.
crate::test_entity_schema! {
    CollectionPredicateEntity,
    id = Ulid,
    path = "predicate_validate::CollectionEntity",
    entity_name = "CollectionEntity",
    primary_key = "id",
    pk_index = 0,
    fields = [
        ("id", EntityFieldKind::Ulid),
        ("tags", EntityFieldKind::List(&EntityFieldKind::Text)),
        ("principals", EntityFieldKind::Set(&EntityFieldKind::Principal)),
        (
            "attributes",
            EntityFieldKind::Map {
                key: &EntityFieldKind::Text,
                value: &EntityFieldKind::Uint,
            }
        ),
    ],
    indexes = [],
}
1130
// Numeric-coercion fixture entity. Its fields span both sides of the
// numeric-widen registry split exercised by the validation tests:
// `date`, `int_big` and `uint_big` are expected to reject `NumericWiden`,
// while `int_small`, `uint_small`, `decimal` and `e8s` are expected to accept it.
// NOTE(review): `pk_index = 0` presumably refers to the `id` entry in
// `fields`; confirm against the `test_entity_schema!` macro.
crate::test_entity_schema! {
    NumericCoercionPredicateEntity,
    id = Ulid,
    path = "predicate_validate::NumericCoercionEntity",
    entity_name = "NumericCoercionEntity",
    primary_key = "id",
    pk_index = 0,
    fields = [
        ("id", EntityFieldKind::Ulid),
        ("date", EntityFieldKind::Date),
        ("int_big", EntityFieldKind::IntBig),
        ("uint_big", EntityFieldKind::UintBig),
        ("int_small", EntityFieldKind::Int),
        ("uint_small", EntityFieldKind::Uint),
        ("decimal", EntityFieldKind::Decimal),
        ("e8s", EntityFieldKind::E8s),
    ],
    indexes = [],
}
1150
#[test]
fn validate_model_accepts_scalars_and_coercions() {
    // One clause per coercion style: strict equality on the ULID key,
    // case-insensitive text equality, and an ordering comparison on `age`.
    let clauses = vec![
        FieldRef::new("id").eq(Ulid::nil()),
        FieldRef::new("email").text_eq_ci("User@example.com"),
        FieldRef::new("age").lt(30u32),
    ];
    let predicate = Predicate::And(clauses);
    let model = <ScalarPredicateEntity as EntitySchema>::MODEL;

    let outcome = validate_model(model, &predicate);
    assert!(outcome.is_ok());
}
1163
#[test]
fn validate_model_rejects_map_predicates() {
    /// True when validation failed because `attributes` is a map field.
    fn rejected_as_map_on_attributes<T>(outcome: Result<T, ValidateError>) -> bool {
        matches!(
            outcome,
            Err(ValidateError::UnsupportedQueryFeature(
                UnsupportedQueryFeature::MapPredicate { field }
            )) if field == "attributes"
        )
    }

    let model = <CollectionPredicateEntity as EntitySchema>::MODEL;

    // The builder API refuses to construct a map-entry predicate at all.
    let built = FieldRef::new("attributes").map_contains_entry("k", 1u64, CoercionId::Strict);
    assert!(matches!(
        built,
        Err(UnsupportedQueryFeature::MapPredicate { field }) if field == "attributes"
    ));

    // A hand-assembled map-entry predicate is rejected by validation.
    let entry = Predicate::MapContainsEntry {
        field: "attributes".to_string(),
        key: Value::Text("k".to_string()),
        value: Value::Uint(1),
        coercion: CoercionSpec::new(CoercionId::Strict),
    };
    assert!(rejected_as_map_on_attributes(validate_model(model, &entry)));

    // Presence checks on a map field are rejected the same way.
    let missing = Predicate::IsMissing {
        field: "attributes".to_string(),
    };
    assert!(rejected_as_map_on_attributes(validate_model(model, &missing)));
}
1196
#[test]
fn validate_model_accepts_deterministic_set_predicates() {
    // Strict `Contains` against a set-typed field must validate.
    let membership = ComparePredicate::with_coercion(
        "principals",
        CompareOp::Contains,
        Principal::anonymous().to_value(),
        CoercionId::Strict,
    );
    let predicate = Predicate::Compare(membership);
    let model = <CollectionPredicateEntity as EntitySchema>::MODEL;

    let outcome = validate_model(model, &predicate);
    assert!(outcome.is_ok());
}
1210
#[test]
fn validate_model_rejects_non_queryable_fields() {
    // A structured field flagged `queryable: false` must be refused as a predicate target.
    let fields = vec![
        field("id", EntityFieldKind::Ulid),
        field("broken", EntityFieldKind::Structured { queryable: false }),
    ];
    let model = InvalidEntityModelBuilder::from_fields(fields, 0);
    let predicate = FieldRef::new("broken").eq(1u64);

    let outcome = validate_model(&model, &predicate);
    let rejected = matches!(
        outcome,
        Err(ValidateError::NonQueryableFieldType { field }) if field == "broken"
    );
    assert!(rejected);
}
1228
#[test]
fn validate_model_accepts_text_contains() {
    let model = <ScalarPredicateEntity as EntitySchema>::MODEL;

    // Both the case-sensitive and the case-insensitive substring operators
    // are valid on a text field.
    for predicate in [
        FieldRef::new("email").text_contains("example"),
        FieldRef::new("email").text_contains_ci("EXAMPLE"),
    ] {
        assert!(validate_model(model, &predicate).is_ok());
    }
}
1239
#[test]
fn validate_model_rejects_text_contains_on_non_text() {
    // `age` is an unsigned-int field, so a substring operator is invalid for it.
    let predicate = FieldRef::new("age").text_contains("1");
    let model = <ScalarPredicateEntity as EntitySchema>::MODEL;

    let outcome = validate_model(model, &predicate);
    let rejected = matches!(
        outcome,
        Err(ValidateError::InvalidOperator { field, op })
            if field == "age" && op == "text_contains"
    );
    assert!(rejected);
}
1251
#[test]
fn validate_model_rejects_numeric_widen_for_registry_exclusions() {
    let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;

    // Each field's scalar is excluded from numeric widening by the registry,
    // so an ordering comparison (which widens) must fail with `InvalidCoercion`.
    let cases = [
        ("date", FieldRef::new("date").lt(1i64)),
        ("int_big", FieldRef::new("int_big").lt(Int::from(1i32))),
        ("uint_big", FieldRef::new("uint_big").lt(Nat::from(1u64))),
    ];

    for (expected_field, predicate) in cases {
        assert!(matches!(
            validate_model(model, &predicate),
            Err(ValidateError::InvalidCoercion { field, coercion })
                if field == expected_field && coercion == CoercionId::NumericWiden
        ));
    }
}
1277
#[test]
fn validate_model_accepts_numeric_widen_for_registry_allowed_scalars() {
    // These fields compare against literals of a different numeric type than
    // the field's own scalar; the registry allows widening for all of them.
    let clauses = Vec::from([
        FieldRef::new("int_small").lt(9u64),
        FieldRef::new("uint_small").lt(9i64),
        FieldRef::new("decimal").lt(9u64),
        FieldRef::new("e8s").lt(9u64),
    ]);
    let predicate = Predicate::And(clauses);

    let outcome = validate_model(
        <NumericCoercionPredicateEntity as EntitySchema>::MODEL,
        &predicate,
    );
    assert!(outcome.is_ok());
}
1290
#[test]
fn numeric_widen_authority_tracks_registry_flags() {
    // `ensure_coercion` must accept `NumericWiden` for a scalar exactly when
    // the registry's `supports_numeric_coercion` flag is set for it.
    let widen = CoercionSpec::new(CoercionId::NumericWiden);

    for scalar in registry_scalars() {
        let expected = scalar.supports_numeric_coercion();
        let field_type = FieldType::Scalar(scalar.clone());
        let literal = sample_value_for_scalar(scalar.clone());
        let actual = ensure_coercion("value", &field_type, &literal, &widen).is_ok();

        assert_eq!(
            actual, expected,
            "numeric widen drift for scalar {scalar:?}: expected {expected}, got {actual}"
        );
    }
}
1311
#[test]
fn numeric_widen_is_not_inferred_from_coercion_family() {
    // Within the numeric coercion family, widening support must follow the
    // per-scalar registry flag, and at least one member must lack it. That
    // second condition proves the family alone does not grant widening.
    let widen = CoercionSpec::new(CoercionId::NumericWiden);
    let mut numeric_without_widen = Vec::new();

    for scalar in registry_scalars() {
        if scalar.coercion_family() == CoercionFamily::Numeric {
            let field_type = FieldType::Scalar(scalar.clone());
            let literal = sample_value_for_scalar(scalar.clone());
            let numeric_widen_allowed =
                ensure_coercion("value", &field_type, &literal, &widen).is_ok();

            assert_eq!(
                numeric_widen_allowed,
                scalar.supports_numeric_coercion(),
                "numeric family must not imply numeric widen for scalar {scalar:?}"
            );

            if !scalar.supports_numeric_coercion() {
                numeric_without_widen.push(scalar);
            }
        }
    }

    assert!(
        !numeric_without_widen.is_empty(),
        "expected at least one numeric-family scalar without numeric widen support"
    );
}
1348
#[test]
fn scalar_registry_covers_all_variants_exactly_once() {
    // Track coverage two ways: by variant index (catches gaps) and by debug
    // name (catches duplicates independently of `scalar_index`).
    let mut seen = [false; SCALAR_TYPE_VARIANT_COUNT];
    let mut names = BTreeSet::new();

    for scalar in registry_scalars() {
        let slot = &mut seen[scalar_index(scalar.clone())];
        assert!(!*slot, "duplicate scalar entry: {scalar:?}");
        *slot = true;

        let name = format!("{scalar:?}");
        let first_occurrence = names.insert(name.clone());
        assert!(first_occurrence, "duplicate scalar entry: {name}");
    }

    let missing: Vec<String> = seen
        .iter()
        .enumerate()
        .filter(|&(_, &was_seen)| !was_seen)
        .map(|(index, _)| {
            let scalar = scalar_from_index(index).expect("index is in range");
            format!("{scalar:?}")
        })
        .collect();

    assert!(missing.is_empty(), "missing scalar entries: {missing:?}");
    assert_eq!(names.len(), SCALAR_TYPE_VARIANT_COUNT);
}
1375
#[test]
fn scalar_keyability_matches_value_storage_key() {
    // The registry's keyability flag must agree with whether a sample value of
    // that scalar actually converts via `Value::as_storage_key`.
    // NOTE(review): the registry also carries `is_storage_key_encodable`;
    // confirm `is_keyable` is the intended flag to compare here.
    for scalar in registry_scalars() {
        let value_keyable = sample_value_for_scalar(scalar.clone())
            .as_storage_key()
            .is_some();
        let scalar_keyable = scalar.is_keyable();

        assert_eq!(
            value_keyable, scalar_keyable,
            "Value::as_storage_key drift for scalar {scalar:?}"
        );
    }
}
1389}