1use super::{
2 ast::{CompareOp, ComparePredicate, Predicate, UnsupportedQueryFeature},
3 coercion::{CoercionId, CoercionSpec, supports_coercion},
4};
5use crate::{
6 db::identity::{EntityName, EntityNameError, IndexName, IndexNameError},
7 model::{entity::EntityModel, field::EntityFieldKind, index::IndexModel},
8 traits::FieldValueKind,
9 value::{CoercionFamily, CoercionFamilyExt, Value},
10};
11use std::{
12 collections::{BTreeMap, BTreeSet},
13 fmt,
14};
15
/// Scalar (non-collection) field types understood by predicate validation.
///
/// Per-variant capabilities (coercion family, accepted `Value` shapes, numeric
/// coercion, ordering, keyability) are not encoded here. They are looked up through
/// the `scalar_registry!` table by the methods in `impl ScalarType`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ScalarType {
    Account,
    Blob,
    Bool,
    Date,
    Decimal,
    Duration,
    Enum,
    E8s,
    E18s,
    Float32,
    Float64,
    Int,
    Int128,
    IntBig,
    Principal,
    Subaccount,
    Text,
    Timestamp,
    Uint,
    Uint128,
    UintBig,
    Ulid,
    Unit,
}
53
// Registry callback used by `ScalarType::coercion_family` via `scalar_registry!`.
//
// Expands to a `match` that maps each `ScalarType::$scalar` to the entry's
// `$coercion_family`. The other entry fields are bound but unused; they are listed
// only so the pattern accepts the full registry entry shape.
macro_rules! scalar_coercion_family_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $coercion_family, )*
        }
    };
}
62
// Registry callback used by `ScalarType::matches_value` via `scalar_registry!`.
//
// Expands to a `matches!` over the `(scalar, value)` pair. It is true exactly when
// some entry's `ScalarType::$scalar` is paired with a value matching that entry's
// `$value_pat`. All other entry fields are bound only to accept the entry shape.
macro_rules! scalar_matches_value_from_registry {
    ( @args $self:expr, $value:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        matches!(
            ($self, $value),
            $( (ScalarType::$scalar, $value_pat) )|*
        )
    };
}
71
// Registry callback used by `ScalarType::supports_numeric_coercion` via
// `scalar_registry!`.
//
// Expands to a `match` returning each entry's `supports_numeric_coercion` flag.
// Unused entry fields are bound only to accept the entry shape.
macro_rules! scalar_supports_numeric_coercion_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_numeric_coercion, )*
        }
    };
}
79
// Registry callback used by `ScalarType::is_keyable` via `scalar_registry!`.
//
// Expands to a `match` returning each entry's `is_keyable` flag. Unused entry
// fields are bound only to accept the entry shape.
macro_rules! scalar_is_keyable_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $is_keyable, )*
        }
    };
}
87
// Registry callback used by `ScalarType::supports_ordering` via `scalar_registry!`.
//
// Expands to a `match` returning each entry's `supports_ordering` flag. Unused
// entry fields are bound only to accept the entry shape.
macro_rules! scalar_supports_ordering_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_ordering, )*
        }
    };
}
95
96impl ScalarType {
97 #[must_use]
98 pub const fn coercion_family(&self) -> CoercionFamily {
99 scalar_registry!(scalar_coercion_family_from_registry, self)
100 }
101
102 #[must_use]
103 pub const fn is_orderable(&self) -> bool {
104 self.supports_ordering()
107 }
108
109 #[must_use]
110 pub const fn matches_value(&self, value: &Value) -> bool {
111 scalar_registry!(scalar_matches_value_from_registry, self, value)
112 }
113
114 #[must_use]
115 pub const fn supports_numeric_coercion(&self) -> bool {
116 scalar_registry!(scalar_supports_numeric_coercion_from_registry, self)
117 }
118
119 #[must_use]
120 pub const fn is_keyable(&self) -> bool {
121 scalar_registry!(scalar_is_keyable_from_registry, self)
122 }
123
124 #[must_use]
125 pub const fn supports_ordering(&self) -> bool {
126 scalar_registry!(scalar_supports_ordering_from_registry, self)
127 }
128}
129
/// Validation-time view of an entity field's type.
///
/// Produced from `EntityFieldKind` by `field_type_from_model_kind`. Relation fields
/// are represented by their key type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// A single scalar value.
    Scalar(ScalarType),
    /// List of the boxed element type.
    List(Box<Self>),
    /// Set of the boxed element type.
    Set(Box<Self>),
    /// Map with boxed key and value types.
    Map { key: Box<Self>, value: Box<Self> },
    /// Opaque structured value; `queryable` is copied from the model's field kind.
    Structured { queryable: bool },
}
148
149impl FieldType {
150 #[must_use]
151 pub const fn value_kind(&self) -> FieldValueKind {
152 match self {
153 Self::Scalar(_) => FieldValueKind::Atomic,
154 Self::List(_) | Self::Set(_) => FieldValueKind::Structured { queryable: true },
155 Self::Map { .. } => FieldValueKind::Structured { queryable: false },
156 Self::Structured { queryable } => FieldValueKind::Structured {
157 queryable: *queryable,
158 },
159 }
160 }
161
162 #[must_use]
163 pub const fn coercion_family(&self) -> Option<CoercionFamily> {
164 match self {
165 Self::Scalar(inner) => Some(inner.coercion_family()),
166 Self::List(_) | Self::Set(_) | Self::Map { .. } => Some(CoercionFamily::Collection),
167 Self::Structured { .. } => None,
168 }
169 }
170
171 #[must_use]
172 pub const fn is_text(&self) -> bool {
173 matches!(self, Self::Scalar(ScalarType::Text))
174 }
175
176 #[must_use]
177 pub const fn is_collection(&self) -> bool {
178 matches!(self, Self::List(_) | Self::Set(_) | Self::Map { .. })
179 }
180
181 #[must_use]
182 pub const fn is_list_like(&self) -> bool {
183 matches!(self, Self::List(_) | Self::Set(_))
184 }
185
186 #[must_use]
187 pub const fn is_map(&self) -> bool {
188 matches!(self, Self::Map { .. })
189 }
190
191 #[must_use]
192 pub fn map_types(&self) -> Option<(&Self, &Self)> {
193 match self {
194 Self::Map { key, value } => Some((key.as_ref(), value.as_ref())),
195 _ => {
196 None
198 }
199 }
200 }
201
202 #[must_use]
203 pub const fn is_orderable(&self) -> bool {
204 match self {
205 Self::Scalar(inner) => inner.is_orderable(),
206 _ => false,
207 }
208 }
209
210 #[must_use]
211 pub const fn is_keyable(&self) -> bool {
212 match self {
213 Self::Scalar(inner) => inner.is_keyable(),
214 _ => false,
215 }
216 }
217
218 #[must_use]
219 pub const fn supports_numeric_coercion(&self) -> bool {
220 match self {
221 Self::Scalar(inner) => inner.supports_numeric_coercion(),
222 _ => false,
223 }
224 }
225}
226
227fn validate_index_fields(
228 fields: &BTreeMap<String, FieldType>,
229 indexes: &[&IndexModel],
230) -> Result<(), ValidateError> {
231 let mut seen_names = BTreeSet::new();
232 for index in indexes {
233 if seen_names.contains(index.name) {
234 return Err(ValidateError::DuplicateIndexName {
235 name: index.name.to_string(),
236 });
237 }
238 seen_names.insert(index.name);
239
240 let mut seen = BTreeSet::new();
241 for field in index.fields {
242 if !fields.contains_key(*field) {
243 return Err(ValidateError::IndexFieldUnknown {
244 index: **index,
245 field: (*field).to_string(),
246 });
247 }
248 if seen.contains(*field) {
249 return Err(ValidateError::IndexFieldDuplicate {
250 index: **index,
251 field: (*field).to_string(),
252 });
253 }
254 seen.insert(*field);
255
256 let field_type = fields
257 .get(*field)
258 .expect("index field existence checked above");
259 if matches!(field_type, FieldType::Map { .. }) {
262 return Err(ValidateError::IndexFieldMapNotQueryable {
263 index: **index,
264 field: (*field).to_string(),
265 });
266 }
267 if !field_type.value_kind().is_queryable() {
268 return Err(ValidateError::IndexFieldNotQueryable {
269 index: **index,
270 field: (*field).to_string(),
271 });
272 }
273 }
274 }
275
276 Ok(())
277}
278
279#[derive(Clone, Debug)]
287pub struct SchemaInfo {
288 fields: BTreeMap<String, FieldType>,
289}
290
291impl SchemaInfo {
292 #[must_use]
293 pub(crate) fn field(&self, name: &str) -> Option<&FieldType> {
294 self.fields.get(name)
295 }
296
297 pub fn from_entity_model(model: &EntityModel) -> Result<Self, ValidateError> {
298 let entity_name = EntityName::try_from_str(model.entity_name).map_err(|err| {
300 ValidateError::InvalidEntityName {
301 name: model.entity_name.to_string(),
302 source: err,
303 }
304 })?;
305
306 if !model
307 .fields
308 .iter()
309 .any(|field| std::ptr::eq(field, model.primary_key))
310 {
311 return Err(ValidateError::InvalidPrimaryKey {
312 field: model.primary_key.name.to_string(),
313 });
314 }
315
316 let mut fields = BTreeMap::new();
317 for field in model.fields {
318 if fields.contains_key(field.name) {
319 return Err(ValidateError::DuplicateField {
320 field: field.name.to_string(),
321 });
322 }
323 let ty = field_type_from_model_kind(&field.kind);
324 fields.insert(field.name.to_string(), ty);
325 }
326
327 let pk_field_type = fields
328 .get(model.primary_key.name)
329 .expect("primary key verified above");
330 if !pk_field_type.is_keyable() {
331 return Err(ValidateError::InvalidPrimaryKeyType {
332 field: model.primary_key.name.to_string(),
333 });
334 }
335
336 validate_index_fields(&fields, model.indexes)?;
337 for index in model.indexes {
338 IndexName::try_from_parts(&entity_name, index.fields).map_err(|err| {
339 ValidateError::InvalidIndexName {
340 index: **index,
341 source: err,
342 }
343 })?;
344 }
345
346 Ok(Self { fields })
347 }
348}
349
350#[derive(Debug, thiserror::Error)]
352pub enum ValidateError {
353 #[error("invalid entity name '{name}': {source}")]
354 InvalidEntityName {
355 name: String,
356 #[source]
357 source: EntityNameError,
358 },
359
360 #[error("invalid index name for '{index}': {source}")]
361 InvalidIndexName {
362 index: IndexModel,
363 #[source]
364 source: IndexNameError,
365 },
366
367 #[error("unknown field '{field}'")]
368 UnknownField { field: String },
369
370 #[error("field '{field}' is not queryable")]
371 NonQueryableFieldType { field: String },
372
373 #[error("duplicate field '{field}'")]
374 DuplicateField { field: String },
375
376 #[error("{0}")]
377 UnsupportedQueryFeature(#[from] UnsupportedQueryFeature),
378
379 #[error("primary key '{field}' not present in entity fields")]
380 InvalidPrimaryKey { field: String },
381
382 #[error("primary key '{field}' has a non-keyable type")]
383 InvalidPrimaryKeyType { field: String },
384
385 #[error("index '{index}' references unknown field '{field}'")]
386 IndexFieldUnknown { index: IndexModel, field: String },
387
388 #[error("index '{index}' references non-queryable field '{field}'")]
389 IndexFieldNotQueryable { index: IndexModel, field: String },
390
391 #[error(
392 "index '{index}' references map field '{field}'; map fields are not queryable in icydb 0.7"
393 )]
394 IndexFieldMapNotQueryable { index: IndexModel, field: String },
395
396 #[error("index '{index}' repeats field '{field}'")]
397 IndexFieldDuplicate { index: IndexModel, field: String },
398
399 #[error("duplicate index name '{name}'")]
400 DuplicateIndexName { name: String },
401
402 #[error("operator {op} is not valid for field '{field}'")]
403 InvalidOperator { field: String, op: String },
404
405 #[error("coercion {coercion:?} is not valid for field '{field}'")]
406 InvalidCoercion { field: String, coercion: CoercionId },
407
408 #[error("invalid literal for field '{field}': {message}")]
409 InvalidLiteral { field: String, message: String },
410}
411
412pub fn reject_unsupported_query_features(
414 predicate: &Predicate,
415) -> Result<(), UnsupportedQueryFeature> {
416 match predicate {
417 Predicate::True
418 | Predicate::False
419 | Predicate::Compare(_)
420 | Predicate::IsNull { .. }
421 | Predicate::IsMissing { .. }
422 | Predicate::IsEmpty { .. }
423 | Predicate::IsNotEmpty { .. }
424 | Predicate::TextContains { .. }
425 | Predicate::TextContainsCi { .. } => Ok(()),
426 Predicate::And(children) | Predicate::Or(children) => {
427 for child in children {
428 reject_unsupported_query_features(child)?;
429 }
430
431 Ok(())
432 }
433 Predicate::Not(inner) => reject_unsupported_query_features(inner),
434 Predicate::MapContainsKey { field, .. }
435 | Predicate::MapContainsValue { field, .. }
436 | Predicate::MapContainsEntry { field, .. } => Err(UnsupportedQueryFeature::MapPredicate {
437 field: field.clone(),
438 }),
439 }
440}
441
442pub fn validate(schema: &SchemaInfo, predicate: &Predicate) -> Result<(), ValidateError> {
443 reject_unsupported_query_features(predicate)?;
444
445 match predicate {
446 Predicate::True | Predicate::False => Ok(()),
447 Predicate::And(children) | Predicate::Or(children) => {
448 for child in children {
449 validate(schema, child)?;
450 }
451 Ok(())
452 }
453 Predicate::Not(inner) => validate(schema, inner),
454 Predicate::Compare(cmp) => validate_compare(schema, cmp),
455 Predicate::IsNull { field } | Predicate::IsMissing { field } => {
456 let _field_type = ensure_field(schema, field)?;
457 Ok(())
458 }
459 Predicate::IsEmpty { field } => {
460 let field_type = ensure_field(schema, field)?;
461 if field_type.is_text() || field_type.is_collection() {
462 Ok(())
463 } else {
464 Err(invalid_operator(field, "is_empty"))
465 }
466 }
467 Predicate::IsNotEmpty { field } => {
468 let field_type = ensure_field(schema, field)?;
469 if field_type.is_text() || field_type.is_collection() {
470 Ok(())
471 } else {
472 Err(invalid_operator(field, "is_not_empty"))
473 }
474 }
475 Predicate::MapContainsKey {
476 field,
477 key,
478 coercion,
479 } => validate_map_key(schema, field, key, coercion),
480 Predicate::MapContainsValue {
481 field,
482 value,
483 coercion,
484 } => validate_map_value(schema, field, value, coercion),
485 Predicate::MapContainsEntry {
486 field,
487 key,
488 value,
489 coercion,
490 } => validate_map_entry(schema, field, key, value, coercion),
491 Predicate::TextContains { field, value } => {
492 validate_text_contains(schema, field, value, "text_contains")
493 }
494 Predicate::TextContainsCi { field, value } => {
495 validate_text_contains(schema, field, value, "text_contains_ci")
496 }
497 }
498}
499
500pub fn validate_model(model: &EntityModel, predicate: &Predicate) -> Result<(), ValidateError> {
501 let schema = SchemaInfo::from_entity_model(model)?;
502 validate(&schema, predicate)
503}
504
505fn validate_compare(schema: &SchemaInfo, cmp: &ComparePredicate) -> Result<(), ValidateError> {
506 let field_type = ensure_field(schema, &cmp.field)?;
507
508 match cmp.op {
509 CompareOp::Eq | CompareOp::Ne => {
510 validate_eq_ne(&cmp.field, field_type, &cmp.value, &cmp.coercion)
511 }
512 CompareOp::Lt | CompareOp::Lte | CompareOp::Gt | CompareOp::Gte => {
513 validate_ordering(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
514 }
515 CompareOp::In | CompareOp::NotIn => {
516 validate_in(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
517 }
518 CompareOp::Contains => validate_contains(&cmp.field, field_type, &cmp.value, &cmp.coercion),
519 CompareOp::StartsWith | CompareOp::EndsWith => {
520 validate_text_compare(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
521 }
522 }
523}
524
525fn validate_eq_ne(
526 field: &str,
527 field_type: &FieldType,
528 value: &Value,
529 coercion: &CoercionSpec,
530) -> Result<(), ValidateError> {
531 if field_type.is_list_like() {
532 ensure_list_literal(field, value, field_type)?;
533 } else if field_type.is_map() {
534 ensure_map_literal(field, value, field_type)?;
535 } else {
536 ensure_scalar_literal(field, value)?;
537 }
538
539 ensure_coercion(field, field_type, value, coercion)
540}
541
542fn validate_ordering(
543 field: &str,
544 field_type: &FieldType,
545 value: &Value,
546 coercion: &CoercionSpec,
547 op: CompareOp,
548) -> Result<(), ValidateError> {
549 if matches!(coercion.id, CoercionId::CollectionElement) {
550 return Err(ValidateError::InvalidCoercion {
551 field: field.to_string(),
552 coercion: coercion.id,
553 });
554 }
555
556 if !field_type.is_orderable() {
557 return Err(invalid_operator(field, format!("{op:?}")));
558 }
559
560 ensure_scalar_literal(field, value)?;
561
562 ensure_coercion(field, field_type, value, coercion)
563}
564
565fn validate_in(
567 field: &str,
568 field_type: &FieldType,
569 value: &Value,
570 coercion: &CoercionSpec,
571 op: CompareOp,
572) -> Result<(), ValidateError> {
573 if field_type.is_collection() {
574 return Err(invalid_operator(field, format!("{op:?}")));
575 }
576
577 let Value::List(items) = value else {
578 return Err(invalid_literal(field, "expected list literal"));
579 };
580
581 for item in items {
582 ensure_coercion(field, field_type, item, coercion)?;
583 }
584
585 Ok(())
586}
587
588fn validate_contains(
590 field: &str,
591 field_type: &FieldType,
592 value: &Value,
593 coercion: &CoercionSpec,
594) -> Result<(), ValidateError> {
595 if field_type.is_text() {
596 return Err(invalid_operator(
598 field,
599 format!("{:?}", CompareOp::Contains),
600 ));
601 }
602
603 let element_type = match field_type {
604 FieldType::List(inner) | FieldType::Set(inner) => inner.as_ref(),
605 _ => {
606 return Err(invalid_operator(
607 field,
608 format!("{:?}", CompareOp::Contains),
609 ));
610 }
611 };
612
613 if matches!(coercion.id, CoercionId::TextCasefold) {
614 return Err(ValidateError::InvalidCoercion {
616 field: field.to_string(),
617 coercion: coercion.id,
618 });
619 }
620
621 ensure_coercion(field, element_type, value, coercion)
622}
623
624fn validate_text_compare(
626 field: &str,
627 field_type: &FieldType,
628 value: &Value,
629 coercion: &CoercionSpec,
630 op: CompareOp,
631) -> Result<(), ValidateError> {
632 if !field_type.is_text() {
633 return Err(invalid_operator(field, format!("{op:?}")));
634 }
635
636 ensure_text_literal(field, value)?;
637
638 ensure_coercion(field, field_type, value, coercion)
639}
640
641fn ensure_map_types<'a>(
643 schema: &'a SchemaInfo,
644 field: &str,
645 op: &str,
646) -> Result<(&'a FieldType, &'a FieldType), ValidateError> {
647 let field_type = ensure_field(schema, field)?;
648 field_type
649 .map_types()
650 .ok_or_else(|| invalid_operator(field, op))
651}
652
653fn validate_map_key(
654 schema: &SchemaInfo,
655 field: &str,
656 key: &Value,
657 coercion: &CoercionSpec,
658) -> Result<(), ValidateError> {
659 ensure_no_text_casefold(field, coercion)?;
660
661 let (key_type, _) = ensure_map_types(schema, field, "map_contains_key")?;
662
663 ensure_coercion(field, key_type, key, coercion)
664}
665
666fn validate_map_value(
667 schema: &SchemaInfo,
668 field: &str,
669 value: &Value,
670 coercion: &CoercionSpec,
671) -> Result<(), ValidateError> {
672 ensure_no_text_casefold(field, coercion)?;
673
674 let (_, value_type) = ensure_map_types(schema, field, "map_contains_value")?;
675
676 ensure_coercion(field, value_type, value, coercion)
677}
678
679fn validate_map_entry(
680 schema: &SchemaInfo,
681 field: &str,
682 key: &Value,
683 value: &Value,
684 coercion: &CoercionSpec,
685) -> Result<(), ValidateError> {
686 ensure_no_text_casefold(field, coercion)?;
687
688 let (key_type, value_type) = ensure_map_types(schema, field, "map_contains_entry")?;
689
690 ensure_coercion(field, key_type, key, coercion)?;
691 ensure_coercion(field, value_type, value, coercion)?;
692
693 Ok(())
694}
695
696fn validate_text_contains(
698 schema: &SchemaInfo,
699 field: &str,
700 value: &Value,
701 op: &str,
702) -> Result<(), ValidateError> {
703 let field_type = ensure_field(schema, field)?;
704 if !field_type.is_text() {
705 return Err(invalid_operator(field, op));
706 }
707
708 ensure_text_literal(field, value)?;
709
710 Ok(())
711}
712
713fn ensure_field<'a>(schema: &'a SchemaInfo, field: &str) -> Result<&'a FieldType, ValidateError> {
714 let field_type = schema
715 .field(field)
716 .ok_or_else(|| ValidateError::UnknownField {
717 field: field.to_string(),
718 })?;
719
720 if matches!(field_type, FieldType::Map { .. }) {
721 return Err(UnsupportedQueryFeature::MapPredicate {
722 field: field.to_string(),
723 }
724 .into());
725 }
726
727 if !field_type.value_kind().is_queryable() {
728 return Err(ValidateError::NonQueryableFieldType {
729 field: field.to_string(),
730 });
731 }
732
733 Ok(field_type)
734}
735
736fn invalid_operator(field: &str, op: impl fmt::Display) -> ValidateError {
737 ValidateError::InvalidOperator {
738 field: field.to_string(),
739 op: op.to_string(),
740 }
741}
742
743fn invalid_literal(field: &str, msg: &str) -> ValidateError {
744 ValidateError::InvalidLiteral {
745 field: field.to_string(),
746 message: msg.to_string(),
747 }
748}
749
750fn ensure_no_text_casefold(field: &str, coercion: &CoercionSpec) -> Result<(), ValidateError> {
752 if matches!(coercion.id, CoercionId::TextCasefold) {
753 return Err(ValidateError::InvalidCoercion {
754 field: field.to_string(),
755 coercion: coercion.id,
756 });
757 }
758
759 Ok(())
760}
761
762fn ensure_text_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
764 if !matches!(value, Value::Text(_)) {
765 return Err(invalid_literal(field, "expected text literal"));
766 }
767
768 Ok(())
769}
770
771fn ensure_scalar_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
773 if matches!(value, Value::List(_)) {
774 return Err(invalid_literal(field, "expected scalar literal"));
775 }
776
777 Ok(())
778}
779
780fn ensure_coercion(
781 field: &str,
782 field_type: &FieldType,
783 literal: &Value,
784 coercion: &CoercionSpec,
785) -> Result<(), ValidateError> {
786 if matches!(coercion.id, CoercionId::TextCasefold) && !field_type.is_text() {
787 return Err(ValidateError::InvalidCoercion {
789 field: field.to_string(),
790 coercion: coercion.id,
791 });
792 }
793
794 if matches!(coercion.id, CoercionId::NumericWiden)
799 && (!field_type.supports_numeric_coercion() || !literal.supports_numeric_coercion())
800 {
801 return Err(ValidateError::InvalidCoercion {
802 field: field.to_string(),
803 coercion: coercion.id,
804 });
805 }
806
807 if !matches!(coercion.id, CoercionId::NumericWiden) {
808 let left_family =
809 field_type
810 .coercion_family()
811 .ok_or_else(|| ValidateError::NonQueryableFieldType {
812 field: field.to_string(),
813 })?;
814 let right_family = literal.coercion_family();
815
816 if !supports_coercion(left_family, right_family, coercion.id) {
817 return Err(ValidateError::InvalidCoercion {
818 field: field.to_string(),
819 coercion: coercion.id,
820 });
821 }
822 }
823
824 if matches!(
825 coercion.id,
826 CoercionId::Strict | CoercionId::CollectionElement
827 ) && !literal_matches_type(literal, field_type)
828 {
829 return Err(invalid_literal(
830 field,
831 "literal type does not match field type",
832 ));
833 }
834
835 Ok(())
836}
837
838fn ensure_list_literal(
839 field: &str,
840 literal: &Value,
841 field_type: &FieldType,
842) -> Result<(), ValidateError> {
843 if !literal_matches_type(literal, field_type) {
844 return Err(invalid_literal(
845 field,
846 "list literal does not match field element type",
847 ));
848 }
849
850 Ok(())
851}
852
853fn ensure_map_literal(
854 field: &str,
855 literal: &Value,
856 field_type: &FieldType,
857) -> Result<(), ValidateError> {
858 if !literal_matches_type(literal, field_type) {
859 return Err(invalid_literal(
860 field,
861 "map literal does not match field key/value types",
862 ));
863 }
864
865 Ok(())
866}
867
868pub(crate) fn literal_matches_type(literal: &Value, field_type: &FieldType) -> bool {
869 match field_type {
870 FieldType::Scalar(inner) => inner.matches_value(literal),
871 FieldType::List(element) | FieldType::Set(element) => match literal {
872 Value::List(items) => items.iter().all(|item| literal_matches_type(item, element)),
873 _ => false,
874 },
875 FieldType::Map { key, value } => match literal {
876 Value::Map(entries) => {
877 if Value::validate_map_entries(entries.as_slice()).is_err() {
878 return false;
879 }
880
881 entries.iter().all(|(entry_key, entry_value)| {
882 literal_matches_type(entry_key, key) && literal_matches_type(entry_value, value)
883 })
884 }
885 _ => false,
886 },
887 FieldType::Structured { .. } => {
888 false
890 }
891 }
892}
893
894fn field_type_from_model_kind(kind: &EntityFieldKind) -> FieldType {
895 match kind {
896 EntityFieldKind::Account => FieldType::Scalar(ScalarType::Account),
897 EntityFieldKind::Blob => FieldType::Scalar(ScalarType::Blob),
898 EntityFieldKind::Bool => FieldType::Scalar(ScalarType::Bool),
899 EntityFieldKind::Date => FieldType::Scalar(ScalarType::Date),
900 EntityFieldKind::Decimal => FieldType::Scalar(ScalarType::Decimal),
901 EntityFieldKind::Duration => FieldType::Scalar(ScalarType::Duration),
902 EntityFieldKind::Enum => FieldType::Scalar(ScalarType::Enum),
903 EntityFieldKind::E8s => FieldType::Scalar(ScalarType::E8s),
904 EntityFieldKind::E18s => FieldType::Scalar(ScalarType::E18s),
905 EntityFieldKind::Float32 => FieldType::Scalar(ScalarType::Float32),
906 EntityFieldKind::Float64 => FieldType::Scalar(ScalarType::Float64),
907 EntityFieldKind::Int => FieldType::Scalar(ScalarType::Int),
908 EntityFieldKind::Int128 => FieldType::Scalar(ScalarType::Int128),
909 EntityFieldKind::IntBig => FieldType::Scalar(ScalarType::IntBig),
910 EntityFieldKind::Principal => FieldType::Scalar(ScalarType::Principal),
911 EntityFieldKind::Subaccount => FieldType::Scalar(ScalarType::Subaccount),
912 EntityFieldKind::Text => FieldType::Scalar(ScalarType::Text),
913 EntityFieldKind::Timestamp => FieldType::Scalar(ScalarType::Timestamp),
914 EntityFieldKind::Uint => FieldType::Scalar(ScalarType::Uint),
915 EntityFieldKind::Uint128 => FieldType::Scalar(ScalarType::Uint128),
916 EntityFieldKind::UintBig => FieldType::Scalar(ScalarType::UintBig),
917 EntityFieldKind::Ulid => FieldType::Scalar(ScalarType::Ulid),
918 EntityFieldKind::Unit => FieldType::Scalar(ScalarType::Unit),
919 EntityFieldKind::Relation { key_kind, .. } => field_type_from_model_kind(key_kind),
920 EntityFieldKind::List(inner) => {
921 FieldType::List(Box::new(field_type_from_model_kind(inner)))
922 }
923 EntityFieldKind::Set(inner) => FieldType::Set(Box::new(field_type_from_model_kind(inner))),
924 EntityFieldKind::Map { key, value } => FieldType::Map {
925 key: Box::new(field_type_from_model_kind(key)),
926 value: Box::new(field_type_from_model_kind(value)),
927 },
928 EntityFieldKind::Structured { queryable } => FieldType::Structured {
929 queryable: *queryable,
930 },
931 }
932}
933
934impl fmt::Display for FieldType {
935 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
936 match self {
937 Self::Scalar(inner) => write!(f, "{inner:?}"),
938 Self::List(inner) => write!(f, "List<{inner}>"),
939 Self::Set(inner) => write!(f, "Set<{inner}>"),
940 Self::Map { key, value } => write!(f, "Map<{key}, {value}>"),
941 Self::Structured { queryable } => {
942 write!(f, "Structured<queryable={queryable}>")
943 }
944 }
945 }
946}
947
948#[cfg(test)]
953mod tests {
954 use super::{FieldType, ScalarType, ValidateError, ensure_coercion, validate_model};
956 use crate::{
957 db::query::{
958 FieldRef,
959 predicate::{CoercionId, CoercionSpec, CompareOp, ComparePredicate, Predicate},
960 },
961 model::field::{EntityFieldKind, EntityFieldModel},
962 test_fixtures::InvalidEntityModelBuilder,
963 traits::{EntitySchema, FieldValue},
964 types::{
965 Account, Date, Decimal, Duration, E8s, E18s, Float32, Float64, Int, Int128, Nat,
966 Nat128, Principal, Subaccount, Timestamp, Ulid,
967 },
968 value::{CoercionFamily, Value, ValueEnum},
969 };
970 use std::collections::BTreeSet;
971
972 fn registry_scalars() -> Vec<ScalarType> {
974 macro_rules! collect_scalars {
975 ( @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
976 vec![ $( ScalarType::$scalar ),* ]
977 };
978 ( @args $($ignore:tt)*; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
979 vec![ $( ScalarType::$scalar ),* ]
980 };
981 }
982
983 let scalars = scalar_registry!(collect_scalars);
984
985 scalars
986 }
987
988 const SCALAR_TYPE_VARIANT_COUNT: usize = 23;
990
991 fn scalar_index(scalar: ScalarType) -> usize {
993 match scalar {
994 ScalarType::Account => 0,
995 ScalarType::Blob => 1,
996 ScalarType::Bool => 2,
997 ScalarType::Date => 3,
998 ScalarType::Decimal => 4,
999 ScalarType::Duration => 5,
1000 ScalarType::Enum => 6,
1001 ScalarType::E8s => 7,
1002 ScalarType::E18s => 8,
1003 ScalarType::Float32 => 9,
1004 ScalarType::Float64 => 10,
1005 ScalarType::Int => 11,
1006 ScalarType::Int128 => 12,
1007 ScalarType::IntBig => 13,
1008 ScalarType::Principal => 14,
1009 ScalarType::Subaccount => 15,
1010 ScalarType::Text => 16,
1011 ScalarType::Timestamp => 17,
1012 ScalarType::Uint => 18,
1013 ScalarType::Uint128 => 19,
1014 ScalarType::UintBig => 20,
1015 ScalarType::Ulid => 21,
1016 ScalarType::Unit => 22,
1017 }
1018 }
1019
1020 fn scalar_from_index(index: usize) -> Option<ScalarType> {
1022 let scalar = match index {
1023 0 => ScalarType::Account,
1024 1 => ScalarType::Blob,
1025 2 => ScalarType::Bool,
1026 3 => ScalarType::Date,
1027 4 => ScalarType::Decimal,
1028 5 => ScalarType::Duration,
1029 6 => ScalarType::Enum,
1030 7 => ScalarType::E8s,
1031 8 => ScalarType::E18s,
1032 9 => ScalarType::Float32,
1033 10 => ScalarType::Float64,
1034 11 => ScalarType::Int,
1035 12 => ScalarType::Int128,
1036 13 => ScalarType::IntBig,
1037 14 => ScalarType::Principal,
1038 15 => ScalarType::Subaccount,
1039 16 => ScalarType::Text,
1040 17 => ScalarType::Timestamp,
1041 18 => ScalarType::Uint,
1042 19 => ScalarType::Uint128,
1043 20 => ScalarType::UintBig,
1044 21 => ScalarType::Ulid,
1045 22 => ScalarType::Unit,
1046 _ => return None,
1047 };
1048
1049 Some(scalar)
1050 }
1051
1052 fn sample_value_for_scalar(scalar: ScalarType) -> Value {
1054 match scalar {
1055 ScalarType::Account => Value::Account(Account::dummy(1)),
1056 ScalarType::Blob => Value::Blob(vec![0u8, 1u8]),
1057 ScalarType::Bool => Value::Bool(true),
1058 ScalarType::Date => Value::Date(Date::EPOCH),
1059 ScalarType::Decimal => Value::Decimal(Decimal::ZERO),
1060 ScalarType::Duration => Value::Duration(Duration::ZERO),
1061 ScalarType::Enum => Value::Enum(ValueEnum::loose("example")),
1062 ScalarType::E8s => Value::E8s(E8s::from_atomic(0)),
1063 ScalarType::E18s => Value::E18s(E18s::from_atomic(0)),
1064 ScalarType::Float32 => {
1065 Value::Float32(Float32::try_new(0.0).expect("Float32 sample should be finite"))
1066 }
1067 ScalarType::Float64 => {
1068 Value::Float64(Float64::try_new(0.0).expect("Float64 sample should be finite"))
1069 }
1070 ScalarType::Int => Value::Int(0),
1071 ScalarType::Int128 => Value::Int128(Int128::from(0i128)),
1072 ScalarType::IntBig => Value::IntBig(Int::from(0i32)),
1073 ScalarType::Principal => Value::Principal(Principal::anonymous()),
1074 ScalarType::Subaccount => Value::Subaccount(Subaccount::dummy(2)),
1075 ScalarType::Text => Value::Text("text".to_string()),
1076 ScalarType::Timestamp => Value::Timestamp(Timestamp::EPOCH),
1077 ScalarType::Uint => Value::Uint(0),
1078 ScalarType::Uint128 => Value::Uint128(Nat128::from(0u128)),
1079 ScalarType::UintBig => Value::UintBig(Nat::from(0u64)),
1080 ScalarType::Ulid => Value::Ulid(Ulid::nil()),
1081 ScalarType::Unit => Value::Unit,
1082 }
1083 }
1084
1085 fn field(name: &'static str, kind: EntityFieldKind) -> EntityFieldModel {
1086 EntityFieldModel { name, kind }
1087 }
1088
// Scalar-only test entity: a ULID primary key plus text, uint, timestamp and
// bool fields. The scalar, text-equality and text-contains validation tests in
// this module use its `MODEL`.
crate::test_entity_schema! {
    ScalarPredicateEntity,
    id = Ulid,
    path = "predicate_validate::ScalarEntity",
    entity_name = "ScalarEntity",
    primary_key = "id",
    // NOTE(review): presumably the position of `primary_key` within `fields`.
    pk_index = 0,
    fields = [
        ("id", EntityFieldKind::Ulid),
        ("email", EntityFieldKind::Text),
        ("age", EntityFieldKind::Uint),
        ("created_at", EntityFieldKind::Timestamp),
        ("active", EntityFieldKind::Bool),
    ],
    indexes = [],
}
1105
// Test entity with collection-typed fields: a `List` of text, a `Set` of
// principals, and a `Map` from text to uint. The set-membership (`Contains`)
// validation test targets its `principals` field.
crate::test_entity_schema! {
    CollectionPredicateEntity,
    id = Ulid,
    path = "predicate_validate::CollectionEntity",
    entity_name = "CollectionEntity",
    primary_key = "id",
    // NOTE(review): presumably the position of `primary_key` within `fields`.
    pk_index = 0,
    fields = [
        ("id", EntityFieldKind::Ulid),
        ("tags", EntityFieldKind::List(&EntityFieldKind::Text)),
        ("principals", EntityFieldKind::Set(&EntityFieldKind::Principal)),
        (
            "attributes",
            EntityFieldKind::Map {
                key: &EntityFieldKind::Text,
                value: &EntityFieldKind::Uint,
            }
        ),
    ],
    indexes = [],
}
1127
// Test entity for the `NumericWiden` tests. The tests expect the `date`,
// `int_big` and `uint_big` fields to reject numeric widening. They expect
// `int_small`, `uint_small`, `decimal` and `e8s` to accept it.
crate::test_entity_schema! {
    NumericCoercionPredicateEntity,
    id = Ulid,
    path = "predicate_validate::NumericCoercionEntity",
    entity_name = "NumericCoercionEntity",
    primary_key = "id",
    // NOTE(review): presumably the position of `primary_key` within `fields`.
    pk_index = 0,
    fields = [
        ("id", EntityFieldKind::Ulid),
        ("date", EntityFieldKind::Date),
        ("int_big", EntityFieldKind::IntBig),
        ("uint_big", EntityFieldKind::UintBig),
        ("int_small", EntityFieldKind::Int),
        ("uint_small", EntityFieldKind::Uint),
        ("decimal", EntityFieldKind::Decimal),
        ("e8s", EntityFieldKind::E8s),
    ],
    indexes = [],
}
1147
/// An `And` of three clauses on scalar fields must validate: `eq` on `id`,
/// case-insensitive `text_eq_ci` on `email`, and `lt` on the uint `age`.
#[test]
fn validate_model_accepts_scalars_and_coercions() {
    let model = <ScalarPredicateEntity as EntitySchema>::MODEL;

    let id_is_nil = FieldRef::new("id").eq(Ulid::nil());
    let email_matches = FieldRef::new("email").text_eq_ci("User@example.com");
    let age_below_thirty = FieldRef::new("age").lt(30u32);
    let predicate = Predicate::And(vec![id_is_nil, email_matches, age_below_thirty]);

    let outcome = validate_model(model, &predicate);
    assert!(outcome.is_ok());
}
1160
/// A strict `Contains` comparison against the `principals` set field is accepted.
#[test]
fn validate_model_accepts_deterministic_set_predicates() {
    let model = <CollectionPredicateEntity as EntitySchema>::MODEL;

    let needle = Principal::anonymous().to_value();
    let membership = ComparePredicate::with_coercion(
        "principals",
        CompareOp::Contains,
        needle,
        CoercionId::Strict,
    );
    let predicate = Predicate::Compare(membership);

    assert!(validate_model(model, &predicate).is_ok());
}
1174
/// A field declared `Structured { queryable: false }` must be rejected with
/// `NonQueryableFieldType` naming that field.
#[test]
fn validate_model_rejects_non_queryable_fields() {
    let fields = vec![
        field("id", EntityFieldKind::Ulid),
        field("broken", EntityFieldKind::Structured { queryable: false }),
    ];
    let model = InvalidEntityModelBuilder::from_fields(fields, 0);
    let predicate = FieldRef::new("broken").eq(1u64);

    let outcome = validate_model(&model, &predicate);
    assert!(matches!(
        outcome,
        Err(ValidateError::NonQueryableFieldType { field }) if field == "broken"
    ));
}
1192
/// Both the case-sensitive and the case-insensitive substring predicates
/// validate on the text field `email`.
#[test]
fn validate_model_accepts_text_contains() {
    let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
    let expect_valid = |predicate: &Predicate| assert!(validate_model(model, predicate).is_ok());

    expect_valid(&FieldRef::new("email").text_contains("example"));
    expect_valid(&FieldRef::new("email").text_contains_ci("EXAMPLE"));
}
1203
/// `text_contains` on the uint field `age` is an invalid operator. The error
/// must carry the field name and the operator name.
#[test]
fn validate_model_rejects_text_contains_on_non_text() {
    let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
    let predicate = FieldRef::new("age").text_contains("1");

    let rejected = matches!(
        validate_model(model, &predicate),
        Err(ValidateError::InvalidOperator { field, op })
            if field == "age" && op == "text_contains"
    );
    assert!(rejected);
}
1215
/// `NumericWiden` must be refused with `InvalidCoercion` for the `date`,
/// `int_big` and `uint_big` fields. These scalars are excluded from numeric
/// coercion by the registry.
#[test]
fn validate_model_rejects_numeric_widen_for_registry_exclusions() {
    let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;

    // Validates `predicate` and asserts it failed with a numeric-widen
    // coercion error attributed to `expected_field`.
    let expect_widen_rejected = |expected_field: &str, predicate: Predicate| {
        assert!(matches!(
            validate_model(model, &predicate),
            Err(ValidateError::InvalidCoercion { field, coercion })
                if field == expected_field && coercion == CoercionId::NumericWiden
        ));
    };

    expect_widen_rejected("date", FieldRef::new("date").lt(1i64));
    expect_widen_rejected("int_big", FieldRef::new("int_big").lt(Int::from(1i32)));
    expect_widen_rejected("uint_big", FieldRef::new("uint_big").lt(Nat::from(1u64)));
}
1241
/// Integer literals of a different width or signedness, compared against the
/// `int_small`, `uint_small`, `decimal` and `e8s` fields, validate. The registry
/// allows numeric widening for these scalars.
#[test]
fn validate_model_accepts_numeric_widen_for_registry_allowed_scalars() {
    let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;

    let clauses = vec![
        FieldRef::new("int_small").lt(9u64),
        FieldRef::new("uint_small").lt(9i64),
        FieldRef::new("decimal").lt(9u64),
        FieldRef::new("e8s").lt(9u64),
    ];
    let predicate = Predicate::And(clauses);

    let outcome = validate_model(model, &predicate);
    assert!(outcome.is_ok());
}
1254
/// For every registry scalar, `ensure_coercion` with `NumericWiden` succeeds
/// exactly when `ScalarType::supports_numeric_coercion` reports true.
#[test]
fn numeric_widen_authority_tracks_registry_flags() {
    for scalar in registry_scalars() {
        let ty = FieldType::Scalar(scalar.clone());
        let sample = sample_value_for_scalar(scalar.clone());
        let expected = scalar.supports_numeric_coercion();

        let spec = CoercionSpec::new(CoercionId::NumericWiden);
        let actual = ensure_coercion("value", &ty, &sample, &spec).is_ok();

        assert_eq!(
            actual, expected,
            "numeric widen drift for scalar {scalar:?}: expected {expected}, got {actual}"
        );
    }
}
1275
/// Membership in `CoercionFamily::Numeric` must not by itself grant
/// `NumericWiden`; the per-scalar registry flag decides. At least one
/// numeric-family scalar is expected to lack the flag.
#[test]
fn numeric_widen_is_not_inferred_from_coercion_family() {
    let mut widen_free_numeric = 0usize;

    for scalar in registry_scalars() {
        if scalar.coercion_family() == CoercionFamily::Numeric {
            let ty = FieldType::Scalar(scalar.clone());
            let sample = sample_value_for_scalar(scalar.clone());
            let spec = CoercionSpec::new(CoercionId::NumericWiden);
            let widen_accepted = ensure_coercion("value", &ty, &sample, &spec).is_ok();
            let registry_flag = scalar.supports_numeric_coercion();

            assert_eq!(
                widen_accepted, registry_flag,
                "numeric family must not imply numeric widen for scalar {scalar:?}"
            );

            if !registry_flag {
                widen_free_numeric = widen_free_numeric.saturating_add(1);
            }
        }
    }

    assert!(
        widen_free_numeric > 0,
        "expected at least one numeric-family scalar without numeric widen support"
    );
}
1312
/// The scalar registry must list every `ScalarType` variant exactly once.
///
/// Duplicates are checked both by variant index and by `Debug` name. Every
/// index that is never visited is reported by its `Debug` name.
#[test]
fn scalar_registry_covers_all_variants_exactly_once() {
    let mut visited = [false; SCALAR_TYPE_VARIANT_COUNT];
    let mut debug_names = BTreeSet::new();

    for scalar in registry_scalars() {
        let slot = scalar_index(scalar.clone());
        assert!(!visited[slot], "duplicate scalar entry: {scalar:?}");
        visited[slot] = true;

        let name = format!("{scalar:?}");
        assert!(debug_names.insert(name.clone()), "duplicate scalar entry: {name}");
    }

    let missing: Vec<String> = visited
        .iter()
        .enumerate()
        .filter(|&(_, &was_seen)| !was_seen)
        .map(|(slot, _)| {
            let scalar = scalar_from_index(slot).expect("index is in range");
            format!("{scalar:?}")
        })
        .collect();

    assert!(missing.is_empty(), "missing scalar entries: {missing:?}");
    assert_eq!(debug_names.len(), SCALAR_TYPE_VARIANT_COUNT);
}
1339
/// `ScalarType::is_keyable` must agree with whether the scalar's sample value
/// yields a storage key through `Value::as_storage_key`.
#[test]
fn scalar_keyability_matches_value_storage_key() {
    for scalar in registry_scalars() {
        let sample = sample_value_for_scalar(scalar.clone());
        let scalar_keyable = scalar.is_keyable();
        let value_keyable = sample.as_storage_key().is_some();

        assert_eq!(
            value_keyable, scalar_keyable,
            "Value::as_storage_key drift for scalar {scalar:?}"
        );
    }
}
1353}