1use super::{
2 ast::{CompareOp, ComparePredicate, Predicate},
3 coercion::{CoercionId, CoercionSpec, supports_coercion},
4};
5use crate::{
6 db::identity::{EntityName, EntityNameError, IndexName, IndexNameError},
7 model::{entity::EntityModel, field::EntityFieldKind, index::IndexModel},
8 value::{CoercionFamily, CoercionFamilyExt, Value},
9};
10use std::{
11 collections::{BTreeMap, BTreeSet},
12 fmt,
13};
14
/// Scalar (non-collection) value types a schema field can have.
///
/// Produced from `EntityFieldKind` by `field_type_from_model_kind` and carried inside
/// `FieldType::Scalar`. Per-variant capabilities (coercion family, value matching,
/// numeric coercion, keyability, ordering, ...) are answered by the methods on this
/// type, which all dispatch through the `scalar_registry!` table.
///
/// The derived `Debug` name is also what `FieldType`'s `Display` impl prints for scalars.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ScalarType {
    Account,
    Blob,
    Bool,
    Date,
    Decimal,
    Duration,
    Enum,
    E8s,
    E18s,
    Float32,
    Float64,
    Int,
    Int128,
    IntBig,
    Principal,
    Subaccount,
    Text,
    Timestamp,
    Uint,
    Uint128,
    UintBig,
    Ulid,
    Unit,
}
52
// Registry callback used by `ScalarType::coercion_family`, invoked as
// `scalar_registry!(scalar_coercion_family_from_registry, self)`.
//
// `scalar_registry!` (defined elsewhere) presumably forwards the caller's
// arguments after `@args` and its per-scalar table after `@entries`. Every entry
// column has to be spelled out for the pattern to match, but this callback reads
// only `$scalar` and `$coercion_family`. It expands to a `match` that maps each
// `ScalarType` variant to its registered coercion family.
macro_rules! scalar_coercion_family_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $coercion_family, )*
        }
    };
}
61
// Registry callback used by `ScalarType::matches_value`, invoked as
// `scalar_registry!(scalar_matches_value_from_registry, self, value)`.
//
// Expands to a single `matches!` over the `(scalar, value)` pair. It ORs together
// one `(ScalarType::$scalar, $value_pat)` alternative per registry entry, so a
// value matches only the `Value` pattern registered for its own scalar type.
macro_rules! scalar_matches_value_from_registry {
    ( @args $self:expr, $value:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        matches!(
            ($self, $value),
            $( (ScalarType::$scalar, $value_pat) )|*
        )
    };
}
70
// Registry callback used by `ScalarType::supports_numeric_coercion`, invoked as
// `scalar_registry!(scalar_supports_numeric_coercion_from_registry, self)`.
//
// Reads only the `supports_numeric_coercion` column and expands to a `match`
// that returns that flag for each `ScalarType` variant.
macro_rules! scalar_supports_numeric_coercion_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_numeric_coercion, )*
        }
    };
}
78
// Registry callback used by the test-only `ScalarType::supports_arithmetic`.
// It is compiled only under `cfg(test)`, matching that method.
//
// Reads only the `supports_arithmetic` column and expands to a `match` that
// returns that flag for each `ScalarType` variant.
#[cfg(test)]
macro_rules! scalar_supports_arithmetic_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_arithmetic, )*
        }
    };
}
87
// Registry callback used by `ScalarType::is_keyable`, invoked as
// `scalar_registry!(scalar_is_keyable_from_registry, self)`.
//
// Reads only the `is_keyable` column and expands to a `match` that returns that
// flag for each `ScalarType` variant.
macro_rules! scalar_is_keyable_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $is_keyable, )*
        }
    };
}
95
// Registry callback used by the test-only `ScalarType::supports_equality`.
// It is compiled only under `cfg(test)`, matching that method.
//
// Reads only the `supports_equality` column and expands to a `match` that
// returns that flag for each `ScalarType` variant.
#[cfg(test)]
macro_rules! scalar_supports_equality_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_equality, )*
        }
    };
}
104
// Registry callback used by `ScalarType::supports_ordering`, and therefore also
// by `ScalarType::is_orderable`, which delegates to it.
//
// Reads only the `supports_ordering` column and expands to a `match` that
// returns that flag for each `ScalarType` variant.
macro_rules! scalar_supports_ordering_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        match $self {
            $( ScalarType::$scalar => $supports_ordering, )*
        }
    };
}
112
113impl ScalarType {
114 #[must_use]
115 pub const fn coercion_family(&self) -> CoercionFamily {
116 scalar_registry!(scalar_coercion_family_from_registry, self)
117 }
118
119 #[must_use]
120 pub const fn is_orderable(&self) -> bool {
121 self.supports_ordering()
124 }
125
126 #[must_use]
127 pub const fn matches_value(&self, value: &Value) -> bool {
128 scalar_registry!(scalar_matches_value_from_registry, self, value)
129 }
130
131 #[must_use]
132 pub const fn supports_numeric_coercion(&self) -> bool {
133 scalar_registry!(scalar_supports_numeric_coercion_from_registry, self)
134 }
135
136 #[must_use]
137 #[cfg(test)]
138 #[expect(dead_code)]
139 pub const fn supports_arithmetic(&self) -> bool {
140 scalar_registry!(scalar_supports_arithmetic_from_registry, self)
141 }
142
143 #[must_use]
144 pub const fn is_keyable(&self) -> bool {
145 scalar_registry!(scalar_is_keyable_from_registry, self)
146 }
147
148 #[must_use]
149 #[cfg(test)]
150 #[expect(dead_code)]
151 pub const fn supports_equality(&self) -> bool {
152 scalar_registry!(scalar_supports_equality_from_registry, self)
153 }
154
155 #[must_use]
156 pub const fn supports_ordering(&self) -> bool {
157 scalar_registry!(scalar_supports_ordering_from_registry, self)
158 }
159}
160
/// Type of a schema field as seen by predicate validation.
///
/// Collection variants nest arbitrarily (e.g. a list of maps). Built from the entity
/// model by `field_type_from_model_kind`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum FieldType {
    /// A single scalar value.
    Scalar(ScalarType),
    /// A list whose elements have the inner type.
    List(Box<Self>),
    /// A set whose elements have the inner type.
    Set(Box<Self>),
    /// A map with the given key and value types.
    Map { key: Box<Self>, value: Box<Self> },
    /// Model kind `EntityFieldKind::Unsupported`. `ensure_field` rejects such fields,
    /// so most predicates cannot reference them.
    Unsupported,
}
179
180impl FieldType {
181 #[must_use]
182 pub const fn coercion_family(&self) -> Option<CoercionFamily> {
183 match self {
184 Self::Scalar(inner) => Some(inner.coercion_family()),
185 Self::List(_) | Self::Set(_) | Self::Map { .. } => Some(CoercionFamily::Collection),
186 Self::Unsupported => None,
187 }
188 }
189
190 #[must_use]
191 pub const fn is_text(&self) -> bool {
192 matches!(self, Self::Scalar(ScalarType::Text))
193 }
194
195 #[must_use]
196 pub const fn is_collection(&self) -> bool {
197 matches!(self, Self::List(_) | Self::Set(_) | Self::Map { .. })
198 }
199
200 #[must_use]
201 pub const fn is_list_like(&self) -> bool {
202 matches!(self, Self::List(_) | Self::Set(_))
203 }
204
205 #[must_use]
206 pub const fn is_map(&self) -> bool {
207 matches!(self, Self::Map { .. })
208 }
209
210 #[must_use]
211 pub fn map_types(&self) -> Option<(&Self, &Self)> {
212 match self {
213 Self::Map { key, value } => Some((key.as_ref(), value.as_ref())),
214 _ => {
215 None
217 }
218 }
219 }
220
221 #[must_use]
222 pub const fn is_orderable(&self) -> bool {
223 match self {
224 Self::Scalar(inner) => inner.is_orderable(),
225 _ => false,
226 }
227 }
228
229 #[must_use]
230 pub const fn is_keyable(&self) -> bool {
231 match self {
232 Self::Scalar(inner) => inner.is_keyable(),
233 _ => false,
234 }
235 }
236
237 #[must_use]
238 pub const fn supports_numeric_coercion(&self) -> bool {
239 match self {
240 Self::Scalar(inner) => inner.supports_numeric_coercion(),
241 _ => false,
242 }
243 }
244}
245
246fn validate_index_fields(
247 fields: &BTreeMap<String, FieldType>,
248 indexes: &[&IndexModel],
249) -> Result<(), ValidateError> {
250 let mut seen_names = BTreeSet::new();
251 for index in indexes {
252 if seen_names.contains(index.name) {
253 return Err(ValidateError::DuplicateIndexName {
254 name: index.name.to_string(),
255 });
256 }
257 seen_names.insert(index.name);
258
259 let mut seen = BTreeSet::new();
260 for field in index.fields {
261 if !fields.contains_key(*field) {
262 return Err(ValidateError::IndexFieldUnknown {
263 index: **index,
264 field: (*field).to_string(),
265 });
266 }
267 if seen.contains(*field) {
268 return Err(ValidateError::IndexFieldDuplicate {
269 index: **index,
270 field: (*field).to_string(),
271 });
272 }
273 seen.insert(*field);
274
275 let field_type = fields
276 .get(*field)
277 .expect("index field existence checked above");
278 if matches!(field_type, FieldType::Unsupported) {
281 return Err(ValidateError::IndexFieldUnsupported {
282 index: **index,
283 field: (*field).to_string(),
284 });
285 }
286 }
287 }
288
289 Ok(())
290}
291
/// Validated view of an entity's field types, keyed by field name.
///
/// Built by `SchemaInfo::from_entity_model`, which checks the entity name, primary key,
/// duplicate fields, and indexes before exposing the field map.
#[derive(Clone, Debug)]
pub struct SchemaInfo {
    /// Field name -> resolved field type.
    fields: BTreeMap<String, FieldType>,
}
303
304impl SchemaInfo {
305 #[must_use]
306 pub(crate) fn field(&self, name: &str) -> Option<&FieldType> {
307 self.fields.get(name)
308 }
309
310 pub fn from_entity_model(model: &EntityModel) -> Result<Self, ValidateError> {
311 let entity_name = EntityName::try_from_str(model.entity_name).map_err(|err| {
313 ValidateError::InvalidEntityName {
314 name: model.entity_name.to_string(),
315 source: err,
316 }
317 })?;
318
319 if !model
320 .fields
321 .iter()
322 .any(|field| std::ptr::eq(field, model.primary_key))
323 {
324 return Err(ValidateError::InvalidPrimaryKey {
325 field: model.primary_key.name.to_string(),
326 });
327 }
328
329 let mut fields = BTreeMap::new();
330 for field in model.fields {
331 if fields.contains_key(field.name) {
332 return Err(ValidateError::DuplicateField {
333 field: field.name.to_string(),
334 });
335 }
336 let ty = field_type_from_model_kind(&field.kind);
337 fields.insert(field.name.to_string(), ty);
338 }
339
340 let pk_field_type = fields
341 .get(model.primary_key.name)
342 .expect("primary key verified above");
343 if !pk_field_type.is_keyable() {
344 return Err(ValidateError::InvalidPrimaryKeyType {
345 field: model.primary_key.name.to_string(),
346 });
347 }
348
349 validate_index_fields(&fields, model.indexes)?;
350 for index in model.indexes {
351 IndexName::try_from_parts(&entity_name, index.fields).map_err(|err| {
352 ValidateError::InvalidIndexName {
353 index: **index,
354 source: err,
355 }
356 })?;
357 }
358
359 Ok(Self { fields })
360 }
361}
362
/// Errors raised while building a `SchemaInfo` or validating a `Predicate` against it.
#[derive(Debug, thiserror::Error)]
pub enum ValidateError {
    /// The model's `entity_name` was rejected by `EntityName::try_from_str`.
    #[error("invalid entity name '{name}': {source}")]
    InvalidEntityName {
        name: String,
        #[source]
        source: EntityNameError,
    },

    /// `IndexName::try_from_parts` rejected the entity name combined with this index's fields.
    #[error("invalid index name for '{index}': {source}")]
    InvalidIndexName {
        index: IndexModel,
        #[source]
        source: IndexNameError,
    },

    /// A predicate referenced a field the schema does not declare.
    #[error("unknown field '{field}'")]
    UnknownField { field: String },

    /// A predicate referenced a field of type `FieldType::Unsupported`, or one with no
    /// coercion family.
    #[error("unsupported field type for '{field}'")]
    UnsupportedFieldType { field: String },

    /// The entity model declares the same field name more than once.
    #[error("duplicate field '{field}'")]
    DuplicateField { field: String },

    /// The model's primary key is not one of its declared fields (checked by pointer identity).
    #[error("primary key '{field}' not present in entity fields")]
    InvalidPrimaryKey { field: String },

    /// The primary-key field's type is not keyable.
    #[error("primary key '{field}' has an unsupported type")]
    InvalidPrimaryKeyType { field: String },

    /// An index lists a field the entity does not declare.
    #[error("index '{index}' references unknown field '{field}'")]
    IndexFieldUnknown { index: IndexModel, field: String },

    /// An index lists a field whose type is `FieldType::Unsupported`.
    #[error("index '{index}' references unsupported field '{field}'")]
    IndexFieldUnsupported { index: IndexModel, field: String },

    /// An index lists the same field more than once.
    #[error("index '{index}' repeats field '{field}'")]
    IndexFieldDuplicate { index: IndexModel, field: String },

    /// Two indexes on the entity share a name.
    #[error("duplicate index name '{name}'")]
    DuplicateIndexName { name: String },

    /// The operator is not applicable to the field's type. `op` holds the operator's
    /// printed name.
    #[error("operator {op} is not valid for field '{field}'")]
    InvalidOperator { field: String, op: String },

    /// The requested coercion is not allowed for this field/literal combination.
    #[error("coercion {coercion:?} is not valid for field '{field}'")]
    InvalidCoercion { field: String, coercion: CoercionId },

    /// The literal has the wrong shape or type for the field; `message` explains which.
    #[error("invalid literal for field '{field}': {message}")]
    InvalidLiteral { field: String, message: String },
}
416
417pub fn validate(schema: &SchemaInfo, predicate: &Predicate) -> Result<(), ValidateError> {
418 match predicate {
419 Predicate::True | Predicate::False => Ok(()),
420 Predicate::And(children) | Predicate::Or(children) => {
421 for child in children {
422 validate(schema, child)?;
423 }
424 Ok(())
425 }
426 Predicate::Not(inner) => validate(schema, inner),
427 Predicate::Compare(cmp) => validate_compare(schema, cmp),
428 Predicate::IsNull { field } | Predicate::IsMissing { field } => {
429 ensure_field_exists(schema, field).map(|_| ())
431 }
432 Predicate::IsEmpty { field } => {
433 let field_type = ensure_field(schema, field)?;
434 if field_type.is_text() || field_type.is_collection() {
435 Ok(())
436 } else {
437 Err(invalid_operator(field, "is_empty"))
438 }
439 }
440 Predicate::IsNotEmpty { field } => {
441 let field_type = ensure_field(schema, field)?;
442 if field_type.is_text() || field_type.is_collection() {
443 Ok(())
444 } else {
445 Err(invalid_operator(field, "is_not_empty"))
446 }
447 }
448 Predicate::MapContainsKey {
449 field,
450 key,
451 coercion,
452 } => validate_map_key(schema, field, key, coercion),
453 Predicate::MapContainsValue {
454 field,
455 value,
456 coercion,
457 } => validate_map_value(schema, field, value, coercion),
458 Predicate::MapContainsEntry {
459 field,
460 key,
461 value,
462 coercion,
463 } => validate_map_entry(schema, field, key, value, coercion),
464 Predicate::TextContains { field, value } => {
465 validate_text_contains(schema, field, value, "text_contains")
466 }
467 Predicate::TextContainsCi { field, value } => {
468 validate_text_contains(schema, field, value, "text_contains_ci")
469 }
470 }
471}
472
473pub fn validate_model(model: &EntityModel, predicate: &Predicate) -> Result<(), ValidateError> {
474 let schema = SchemaInfo::from_entity_model(model)?;
475 validate(&schema, predicate)
476}
477
478fn validate_compare(schema: &SchemaInfo, cmp: &ComparePredicate) -> Result<(), ValidateError> {
479 let field_type = ensure_field(schema, &cmp.field)?;
480
481 match cmp.op {
482 CompareOp::Eq | CompareOp::Ne => {
483 validate_eq_ne(&cmp.field, field_type, &cmp.value, &cmp.coercion)
484 }
485 CompareOp::Lt | CompareOp::Lte | CompareOp::Gt | CompareOp::Gte => {
486 validate_ordering(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
487 }
488 CompareOp::In | CompareOp::NotIn => {
489 validate_in(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
490 }
491 CompareOp::Contains => validate_contains(&cmp.field, field_type, &cmp.value, &cmp.coercion),
492 CompareOp::StartsWith | CompareOp::EndsWith => {
493 validate_text_compare(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
494 }
495 }
496}
497
498fn validate_eq_ne(
499 field: &str,
500 field_type: &FieldType,
501 value: &Value,
502 coercion: &CoercionSpec,
503) -> Result<(), ValidateError> {
504 if field_type.is_list_like() {
505 ensure_list_literal(field, value, field_type)?;
506 } else if field_type.is_map() {
507 ensure_map_literal(field, value, field_type)?;
508 } else {
509 ensure_scalar_literal(field, value)?;
510 }
511
512 ensure_coercion(field, field_type, value, coercion)
513}
514
515fn validate_ordering(
516 field: &str,
517 field_type: &FieldType,
518 value: &Value,
519 coercion: &CoercionSpec,
520 op: CompareOp,
521) -> Result<(), ValidateError> {
522 if matches!(coercion.id, CoercionId::CollectionElement) {
523 return Err(ValidateError::InvalidCoercion {
524 field: field.to_string(),
525 coercion: coercion.id,
526 });
527 }
528
529 if !field_type.is_orderable() {
530 return Err(invalid_operator(field, format!("{op:?}")));
531 }
532
533 ensure_scalar_literal(field, value)?;
534
535 ensure_coercion(field, field_type, value, coercion)
536}
537
538fn validate_in(
540 field: &str,
541 field_type: &FieldType,
542 value: &Value,
543 coercion: &CoercionSpec,
544 op: CompareOp,
545) -> Result<(), ValidateError> {
546 if field_type.is_collection() {
547 return Err(invalid_operator(field, format!("{op:?}")));
548 }
549
550 let Value::List(items) = value else {
551 return Err(invalid_literal(field, "expected list literal"));
552 };
553
554 for item in items {
555 ensure_coercion(field, field_type, item, coercion)?;
556 }
557
558 Ok(())
559}
560
561fn validate_contains(
563 field: &str,
564 field_type: &FieldType,
565 value: &Value,
566 coercion: &CoercionSpec,
567) -> Result<(), ValidateError> {
568 if field_type.is_text() {
569 return Err(invalid_operator(
571 field,
572 format!("{:?}", CompareOp::Contains),
573 ));
574 }
575
576 let element_type = match field_type {
577 FieldType::List(inner) | FieldType::Set(inner) => inner.as_ref(),
578 _ => {
579 return Err(invalid_operator(
580 field,
581 format!("{:?}", CompareOp::Contains),
582 ));
583 }
584 };
585
586 if matches!(coercion.id, CoercionId::TextCasefold) {
587 return Err(ValidateError::InvalidCoercion {
589 field: field.to_string(),
590 coercion: coercion.id,
591 });
592 }
593
594 ensure_coercion(field, element_type, value, coercion)
595}
596
597fn validate_text_compare(
599 field: &str,
600 field_type: &FieldType,
601 value: &Value,
602 coercion: &CoercionSpec,
603 op: CompareOp,
604) -> Result<(), ValidateError> {
605 if !field_type.is_text() {
606 return Err(invalid_operator(field, format!("{op:?}")));
607 }
608
609 ensure_text_literal(field, value)?;
610
611 ensure_coercion(field, field_type, value, coercion)
612}
613
614fn ensure_map_types<'a>(
616 schema: &'a SchemaInfo,
617 field: &str,
618 op: &str,
619) -> Result<(&'a FieldType, &'a FieldType), ValidateError> {
620 let field_type = ensure_field(schema, field)?;
621 field_type
622 .map_types()
623 .ok_or_else(|| invalid_operator(field, op))
624}
625
626fn validate_map_key(
627 schema: &SchemaInfo,
628 field: &str,
629 key: &Value,
630 coercion: &CoercionSpec,
631) -> Result<(), ValidateError> {
632 ensure_no_text_casefold(field, coercion)?;
633
634 let (key_type, _) = ensure_map_types(schema, field, "map_contains_key")?;
635
636 ensure_coercion(field, key_type, key, coercion)
637}
638
639fn validate_map_value(
640 schema: &SchemaInfo,
641 field: &str,
642 value: &Value,
643 coercion: &CoercionSpec,
644) -> Result<(), ValidateError> {
645 ensure_no_text_casefold(field, coercion)?;
646
647 let (_, value_type) = ensure_map_types(schema, field, "map_contains_value")?;
648
649 ensure_coercion(field, value_type, value, coercion)
650}
651
652fn validate_map_entry(
653 schema: &SchemaInfo,
654 field: &str,
655 key: &Value,
656 value: &Value,
657 coercion: &CoercionSpec,
658) -> Result<(), ValidateError> {
659 ensure_no_text_casefold(field, coercion)?;
660
661 let (key_type, value_type) = ensure_map_types(schema, field, "map_contains_entry")?;
662
663 ensure_coercion(field, key_type, key, coercion)?;
664 ensure_coercion(field, value_type, value, coercion)?;
665
666 Ok(())
667}
668
669fn validate_text_contains(
671 schema: &SchemaInfo,
672 field: &str,
673 value: &Value,
674 op: &str,
675) -> Result<(), ValidateError> {
676 let field_type = ensure_field(schema, field)?;
677 if !field_type.is_text() {
678 return Err(invalid_operator(field, op));
679 }
680
681 ensure_text_literal(field, value)?;
682
683 Ok(())
684}
685
686fn ensure_field<'a>(schema: &'a SchemaInfo, field: &str) -> Result<&'a FieldType, ValidateError> {
687 let field_type = schema
688 .field(field)
689 .ok_or_else(|| ValidateError::UnknownField {
690 field: field.to_string(),
691 })?;
692
693 if matches!(field_type, FieldType::Unsupported) {
694 return Err(ValidateError::UnsupportedFieldType {
695 field: field.to_string(),
696 });
697 }
698
699 Ok(field_type)
700}
701
702fn ensure_field_exists<'a>(
703 schema: &'a SchemaInfo,
704 field: &str,
705) -> Result<&'a FieldType, ValidateError> {
706 schema
707 .field(field)
708 .ok_or_else(|| ValidateError::UnknownField {
709 field: field.to_string(),
710 })
711}
712
713fn invalid_operator(field: &str, op: impl fmt::Display) -> ValidateError {
714 ValidateError::InvalidOperator {
715 field: field.to_string(),
716 op: op.to_string(),
717 }
718}
719
720fn invalid_literal(field: &str, msg: &str) -> ValidateError {
721 ValidateError::InvalidLiteral {
722 field: field.to_string(),
723 message: msg.to_string(),
724 }
725}
726
727fn ensure_no_text_casefold(field: &str, coercion: &CoercionSpec) -> Result<(), ValidateError> {
729 if matches!(coercion.id, CoercionId::TextCasefold) {
730 return Err(ValidateError::InvalidCoercion {
731 field: field.to_string(),
732 coercion: coercion.id,
733 });
734 }
735
736 Ok(())
737}
738
739fn ensure_text_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
741 if !matches!(value, Value::Text(_)) {
742 return Err(invalid_literal(field, "expected text literal"));
743 }
744
745 Ok(())
746}
747
748fn ensure_scalar_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
750 if matches!(value, Value::List(_)) {
751 return Err(invalid_literal(field, "expected scalar literal"));
752 }
753
754 Ok(())
755}
756
757fn ensure_coercion(
758 field: &str,
759 field_type: &FieldType,
760 literal: &Value,
761 coercion: &CoercionSpec,
762) -> Result<(), ValidateError> {
763 if matches!(coercion.id, CoercionId::TextCasefold) && !field_type.is_text() {
764 return Err(ValidateError::InvalidCoercion {
766 field: field.to_string(),
767 coercion: coercion.id,
768 });
769 }
770
771 if matches!(coercion.id, CoercionId::NumericWiden)
776 && (!field_type.supports_numeric_coercion() || !literal.supports_numeric_coercion())
777 {
778 return Err(ValidateError::InvalidCoercion {
779 field: field.to_string(),
780 coercion: coercion.id,
781 });
782 }
783
784 if !matches!(coercion.id, CoercionId::NumericWiden) {
785 let left_family =
786 field_type
787 .coercion_family()
788 .ok_or_else(|| ValidateError::UnsupportedFieldType {
789 field: field.to_string(),
790 })?;
791 let right_family = literal.coercion_family();
792
793 if !supports_coercion(left_family, right_family, coercion.id) {
794 return Err(ValidateError::InvalidCoercion {
795 field: field.to_string(),
796 coercion: coercion.id,
797 });
798 }
799 }
800
801 if matches!(
802 coercion.id,
803 CoercionId::Strict | CoercionId::CollectionElement
804 ) && !literal_matches_type(literal, field_type)
805 {
806 return Err(invalid_literal(
807 field,
808 "literal type does not match field type",
809 ));
810 }
811
812 Ok(())
813}
814
815fn ensure_list_literal(
816 field: &str,
817 literal: &Value,
818 field_type: &FieldType,
819) -> Result<(), ValidateError> {
820 if !literal_matches_type(literal, field_type) {
821 return Err(invalid_literal(
822 field,
823 "list literal does not match field element type",
824 ));
825 }
826
827 Ok(())
828}
829
830fn ensure_map_literal(
831 field: &str,
832 literal: &Value,
833 field_type: &FieldType,
834) -> Result<(), ValidateError> {
835 if !literal_matches_type(literal, field_type) {
836 return Err(invalid_literal(
837 field,
838 "map literal does not match field key/value types",
839 ));
840 }
841
842 Ok(())
843}
844
845pub(crate) fn literal_matches_type(literal: &Value, field_type: &FieldType) -> bool {
846 match field_type {
847 FieldType::Scalar(inner) => inner.matches_value(literal),
848 FieldType::List(element) | FieldType::Set(element) => match literal {
849 Value::List(items) => items.iter().all(|item| literal_matches_type(item, element)),
850 _ => false,
851 },
852 FieldType::Map { key, value } => match literal {
853 Value::List(entries) => entries.iter().all(|entry| match entry {
854 Value::List(pair) if pair.len() == 2 => {
855 literal_matches_type(&pair[0], key) && literal_matches_type(&pair[1], value)
856 }
857 _ => false,
858 }),
859 _ => false,
860 },
861 FieldType::Unsupported => {
862 false
864 }
865 }
866}
867
868fn field_type_from_model_kind(kind: &EntityFieldKind) -> FieldType {
869 match kind {
870 EntityFieldKind::Account => FieldType::Scalar(ScalarType::Account),
871 EntityFieldKind::Blob => FieldType::Scalar(ScalarType::Blob),
872 EntityFieldKind::Bool => FieldType::Scalar(ScalarType::Bool),
873 EntityFieldKind::Date => FieldType::Scalar(ScalarType::Date),
874 EntityFieldKind::Decimal => FieldType::Scalar(ScalarType::Decimal),
875 EntityFieldKind::Duration => FieldType::Scalar(ScalarType::Duration),
876 EntityFieldKind::Enum => FieldType::Scalar(ScalarType::Enum),
877 EntityFieldKind::E8s => FieldType::Scalar(ScalarType::E8s),
878 EntityFieldKind::E18s => FieldType::Scalar(ScalarType::E18s),
879 EntityFieldKind::Float32 => FieldType::Scalar(ScalarType::Float32),
880 EntityFieldKind::Float64 => FieldType::Scalar(ScalarType::Float64),
881 EntityFieldKind::Int => FieldType::Scalar(ScalarType::Int),
882 EntityFieldKind::Int128 => FieldType::Scalar(ScalarType::Int128),
883 EntityFieldKind::IntBig => FieldType::Scalar(ScalarType::IntBig),
884 EntityFieldKind::Principal => FieldType::Scalar(ScalarType::Principal),
885 EntityFieldKind::Subaccount => FieldType::Scalar(ScalarType::Subaccount),
886 EntityFieldKind::Text => FieldType::Scalar(ScalarType::Text),
887 EntityFieldKind::Timestamp => FieldType::Scalar(ScalarType::Timestamp),
888 EntityFieldKind::Uint => FieldType::Scalar(ScalarType::Uint),
889 EntityFieldKind::Uint128 => FieldType::Scalar(ScalarType::Uint128),
890 EntityFieldKind::UintBig => FieldType::Scalar(ScalarType::UintBig),
891 EntityFieldKind::Ulid => FieldType::Scalar(ScalarType::Ulid),
892 EntityFieldKind::Unit => FieldType::Scalar(ScalarType::Unit),
893 EntityFieldKind::Ref { key_kind, .. } => field_type_from_model_kind(key_kind),
894 EntityFieldKind::List(inner) => {
895 FieldType::List(Box::new(field_type_from_model_kind(inner)))
896 }
897 EntityFieldKind::Set(inner) => FieldType::Set(Box::new(field_type_from_model_kind(inner))),
898 EntityFieldKind::Map { key, value } => FieldType::Map {
899 key: Box::new(field_type_from_model_kind(key)),
900 value: Box::new(field_type_from_model_kind(value)),
901 },
902 EntityFieldKind::Unsupported => FieldType::Unsupported,
903 }
904}
905
906impl fmt::Display for FieldType {
907 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
908 match self {
909 Self::Scalar(inner) => write!(f, "{inner:?}"),
910 Self::List(inner) => write!(f, "List<{inner}>"),
911 Self::Set(inner) => write!(f, "Set<{inner}>"),
912 Self::Map { key, value } => write!(f, "Map<{key}, {value}>"),
913 Self::Unsupported => write!(f, "Unsupported"),
914 }
915 }
916}
917
918#[cfg(test)]
923mod tests {
924 use super::{FieldType, ScalarType, ValidateError, ensure_coercion, validate_model};
926 use crate::{
927 db::query::{
928 FieldRef,
929 predicate::{CoercionId, CoercionSpec, Predicate},
930 },
931 model::field::{EntityFieldKind, EntityFieldModel},
932 test_fixtures::InvalidEntityModelBuilder,
933 traits::EntitySchema,
934 types::{
935 Account, Date, Decimal, Duration, E8s, E18s, Float32, Float64, Int, Int128, Nat,
936 Nat128, Principal, Subaccount, Timestamp, Ulid,
937 },
938 value::{CoercionFamily, Value, ValueEnum},
939 };
940 use std::collections::BTreeSet;
941
942 fn registry_scalars() -> Vec<ScalarType> {
944 macro_rules! collect_scalars {
945 ( @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
946 vec![ $( ScalarType::$scalar ),* ]
947 };
948 ( @args $($ignore:tt)*; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
949 vec![ $( ScalarType::$scalar ),* ]
950 };
951 }
952
953 let scalars = scalar_registry!(collect_scalars);
954
955 scalars
956 }
957
958 const SCALAR_TYPE_VARIANT_COUNT: usize = 23;
960
961 fn scalar_index(scalar: ScalarType) -> usize {
963 match scalar {
964 ScalarType::Account => 0,
965 ScalarType::Blob => 1,
966 ScalarType::Bool => 2,
967 ScalarType::Date => 3,
968 ScalarType::Decimal => 4,
969 ScalarType::Duration => 5,
970 ScalarType::Enum => 6,
971 ScalarType::E8s => 7,
972 ScalarType::E18s => 8,
973 ScalarType::Float32 => 9,
974 ScalarType::Float64 => 10,
975 ScalarType::Int => 11,
976 ScalarType::Int128 => 12,
977 ScalarType::IntBig => 13,
978 ScalarType::Principal => 14,
979 ScalarType::Subaccount => 15,
980 ScalarType::Text => 16,
981 ScalarType::Timestamp => 17,
982 ScalarType::Uint => 18,
983 ScalarType::Uint128 => 19,
984 ScalarType::UintBig => 20,
985 ScalarType::Ulid => 21,
986 ScalarType::Unit => 22,
987 }
988 }
989
990 fn scalar_from_index(index: usize) -> Option<ScalarType> {
992 let scalar = match index {
993 0 => ScalarType::Account,
994 1 => ScalarType::Blob,
995 2 => ScalarType::Bool,
996 3 => ScalarType::Date,
997 4 => ScalarType::Decimal,
998 5 => ScalarType::Duration,
999 6 => ScalarType::Enum,
1000 7 => ScalarType::E8s,
1001 8 => ScalarType::E18s,
1002 9 => ScalarType::Float32,
1003 10 => ScalarType::Float64,
1004 11 => ScalarType::Int,
1005 12 => ScalarType::Int128,
1006 13 => ScalarType::IntBig,
1007 14 => ScalarType::Principal,
1008 15 => ScalarType::Subaccount,
1009 16 => ScalarType::Text,
1010 17 => ScalarType::Timestamp,
1011 18 => ScalarType::Uint,
1012 19 => ScalarType::Uint128,
1013 20 => ScalarType::UintBig,
1014 21 => ScalarType::Ulid,
1015 22 => ScalarType::Unit,
1016 _ => return None,
1017 };
1018
1019 Some(scalar)
1020 }
1021
1022 fn sample_value_for_scalar(scalar: ScalarType) -> Value {
1024 match scalar {
1025 ScalarType::Account => Value::Account(Account::dummy(1)),
1026 ScalarType::Blob => Value::Blob(vec![0u8, 1u8]),
1027 ScalarType::Bool => Value::Bool(true),
1028 ScalarType::Date => Value::Date(Date::EPOCH),
1029 ScalarType::Decimal => Value::Decimal(Decimal::ZERO),
1030 ScalarType::Duration => Value::Duration(Duration::ZERO),
1031 ScalarType::Enum => Value::Enum(ValueEnum::loose("example")),
1032 ScalarType::E8s => Value::E8s(E8s::from_atomic(0)),
1033 ScalarType::E18s => Value::E18s(E18s::from_atomic(0)),
1034 ScalarType::Float32 => {
1035 Value::Float32(Float32::try_new(0.0).expect("Float32 sample should be finite"))
1036 }
1037 ScalarType::Float64 => {
1038 Value::Float64(Float64::try_new(0.0).expect("Float64 sample should be finite"))
1039 }
1040 ScalarType::Int => Value::Int(0),
1041 ScalarType::Int128 => Value::Int128(Int128::from(0i128)),
1042 ScalarType::IntBig => Value::IntBig(Int::from(0i32)),
1043 ScalarType::Principal => Value::Principal(Principal::anonymous()),
1044 ScalarType::Subaccount => Value::Subaccount(Subaccount::dummy(2)),
1045 ScalarType::Text => Value::Text("text".to_string()),
1046 ScalarType::Timestamp => Value::Timestamp(Timestamp::EPOCH),
1047 ScalarType::Uint => Value::Uint(0),
1048 ScalarType::Uint128 => Value::Uint128(Nat128::from(0u128)),
1049 ScalarType::UintBig => Value::UintBig(Nat::from(0u64)),
1050 ScalarType::Ulid => Value::Ulid(Ulid::nil()),
1051 ScalarType::Unit => Value::Unit,
1052 }
1053 }
1054
1055 fn field(name: &'static str, kind: EntityFieldKind) -> EntityFieldModel {
1056 EntityFieldModel { name, kind }
1057 }
1058
    // Test entity whose fields are all scalar kinds (Ulid, Text, Uint,
    // Timestamp, Bool), with `id` as the primary key at field index 0.
    // Used by the scalar/coercion and text-operator validation tests.
    // Note: the Rust type is `ScalarPredicateEntity`, but its registered
    // path / entity name strings say "ScalarEntity".
    crate::test_entity_schema! {
        ScalarPredicateEntity,
        id = Ulid,
        path = "predicate_validate::ScalarEntity",
        entity_name = "ScalarEntity",
        primary_key = "id",
        pk_index = 0,
        fields = [
            ("id", EntityFieldKind::Ulid),
            ("email", EntityFieldKind::Text),
            ("age", EntityFieldKind::Uint),
            ("created_at", EntityFieldKind::Timestamp),
            ("active", EntityFieldKind::Bool),
        ],
        // No secondary indexes; validation here does not depend on them.
        indexes = [],
    }
1075
    // Test entity exercising collection field kinds: a List of Text, a Set of
    // Principal, and a Map from Text keys to Uint values. Used by the
    // `is_empty` / `is_not_empty` / `map_contains_entry` validation test.
    crate::test_entity_schema! {
        CollectionPredicateEntity,
        id = Ulid,
        path = "predicate_validate::CollectionEntity",
        entity_name = "CollectionEntity",
        primary_key = "id",
        pk_index = 0,
        fields = [
            ("id", EntityFieldKind::Ulid),
            ("tags", EntityFieldKind::List(&EntityFieldKind::Text)),
            ("principals", EntityFieldKind::Set(&EntityFieldKind::Principal)),
            (
                "attributes",
                EntityFieldKind::Map {
                    key: &EntityFieldKind::Text,
                    value: &EntityFieldKind::Uint,
                }
            ),
        ],
        indexes = [],
    }
1097
    // Test entity mixing scalars for the NumericWiden coercion tests:
    // per those tests, `date`, `int_big` and `uint_big` must reject
    // NumericWiden, while `int_small`, `uint_small`, `decimal` and `e8s`
    // must accept it.
    crate::test_entity_schema! {
        NumericCoercionPredicateEntity,
        id = Ulid,
        path = "predicate_validate::NumericCoercionEntity",
        entity_name = "NumericCoercionEntity",
        primary_key = "id",
        pk_index = 0,
        fields = [
            ("id", EntityFieldKind::Ulid),
            ("date", EntityFieldKind::Date),
            ("int_big", EntityFieldKind::IntBig),
            ("uint_big", EntityFieldKind::UintBig),
            ("int_small", EntityFieldKind::Int),
            ("uint_small", EntityFieldKind::Uint),
            ("decimal", EntityFieldKind::Decimal),
            ("e8s", EntityFieldKind::E8s),
        ],
        indexes = [],
    }
1117
1118 #[test]
1119 fn validate_model_accepts_scalars_and_coercions() {
1120 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1121
1122 let predicate = Predicate::And(vec![
1123 FieldRef::new("id").eq(Ulid::nil()),
1124 FieldRef::new("email").text_eq_ci("User@example.com"),
1125 FieldRef::new("age").lt(30u32),
1126 ]);
1127
1128 assert!(validate_model(model, &predicate).is_ok());
1129 }
1130
1131 #[test]
1132 fn validate_model_accepts_collections_and_map_contains() {
1133 let model = <CollectionPredicateEntity as EntitySchema>::MODEL;
1134
1135 let predicate = Predicate::And(vec![
1136 FieldRef::new("tags").is_empty(),
1137 FieldRef::new("principals").is_not_empty(),
1138 FieldRef::new("attributes").map_contains_entry("k", 1u64, CoercionId::Strict),
1139 ]);
1140
1141 assert!(validate_model(model, &predicate).is_ok());
1142
1143 let bad =
1144 FieldRef::new("attributes").map_contains_entry("k", 1u64, CoercionId::TextCasefold);
1145
1146 assert!(matches!(
1147 validate_model(model, &bad),
1148 Err(ValidateError::InvalidCoercion { .. })
1149 ));
1150 }
1151
1152 #[test]
1153 fn validate_model_rejects_unsupported_fields() {
1154 let model = InvalidEntityModelBuilder::from_fields(
1155 vec![
1156 field("id", EntityFieldKind::Ulid),
1157 field("broken", EntityFieldKind::Unsupported),
1158 ],
1159 0,
1160 );
1161
1162 let predicate = FieldRef::new("broken").eq(1u64);
1163
1164 assert!(matches!(
1165 validate_model(&model, &predicate),
1166 Err(ValidateError::UnsupportedFieldType { field }) if field == "broken"
1167 ));
1168 }
1169
1170 #[test]
1171 fn validate_model_accepts_text_contains() {
1172 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1173
1174 let predicate = FieldRef::new("email").text_contains("example");
1175 assert!(validate_model(model, &predicate).is_ok());
1176
1177 let predicate = FieldRef::new("email").text_contains_ci("EXAMPLE");
1178 assert!(validate_model(model, &predicate).is_ok());
1179 }
1180
1181 #[test]
1182 fn validate_model_rejects_text_contains_on_non_text() {
1183 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1184
1185 let predicate = FieldRef::new("age").text_contains("1");
1186 assert!(matches!(
1187 validate_model(model, &predicate),
1188 Err(ValidateError::InvalidOperator { field, op })
1189 if field == "age" && op == "text_contains"
1190 ));
1191 }
1192
1193 #[test]
1194 fn validate_model_rejects_numeric_widen_for_registry_exclusions() {
1195 let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;
1196
1197 let date_pred = FieldRef::new("date").lt(1i64);
1198 assert!(matches!(
1199 validate_model(model, &date_pred),
1200 Err(ValidateError::InvalidCoercion { field, coercion })
1201 if field == "date" && coercion == CoercionId::NumericWiden
1202 ));
1203
1204 let int_big_pred = FieldRef::new("int_big").lt(Int::from(1i32));
1205 assert!(matches!(
1206 validate_model(model, &int_big_pred),
1207 Err(ValidateError::InvalidCoercion { field, coercion })
1208 if field == "int_big" && coercion == CoercionId::NumericWiden
1209 ));
1210
1211 let uint_big_pred = FieldRef::new("uint_big").lt(Nat::from(1u64));
1212 assert!(matches!(
1213 validate_model(model, &uint_big_pred),
1214 Err(ValidateError::InvalidCoercion { field, coercion })
1215 if field == "uint_big" && coercion == CoercionId::NumericWiden
1216 ));
1217 }
1218
1219 #[test]
1220 fn validate_model_accepts_numeric_widen_for_registry_allowed_scalars() {
1221 let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;
1222 let predicate = Predicate::And(vec![
1223 FieldRef::new("int_small").lt(9u64),
1224 FieldRef::new("uint_small").lt(9i64),
1225 FieldRef::new("decimal").lt(9u64),
1226 FieldRef::new("e8s").lt(9u64),
1227 ]);
1228
1229 assert!(validate_model(model, &predicate).is_ok());
1230 }
1231
1232 #[test]
1233 fn numeric_widen_authority_tracks_registry_flags() {
1234 for scalar in registry_scalars() {
1235 let field_type = FieldType::Scalar(scalar.clone());
1236 let literal = sample_value_for_scalar(scalar.clone());
1237 let expected = scalar.supports_numeric_coercion();
1238 let actual = ensure_coercion(
1239 "value",
1240 &field_type,
1241 &literal,
1242 &CoercionSpec::new(CoercionId::NumericWiden),
1243 )
1244 .is_ok();
1245
1246 assert_eq!(
1247 actual, expected,
1248 "numeric widen drift for scalar {scalar:?}: expected {expected}, got {actual}"
1249 );
1250 }
1251 }
1252
1253 #[test]
1254 fn numeric_widen_is_not_inferred_from_coercion_family() {
1255 let mut numeric_family_with_no_numeric_widen = 0usize;
1256
1257 for scalar in registry_scalars() {
1258 if scalar.coercion_family() != CoercionFamily::Numeric {
1259 continue;
1260 }
1261
1262 let field_type = FieldType::Scalar(scalar.clone());
1263 let literal = sample_value_for_scalar(scalar.clone());
1264 let numeric_widen_allowed = ensure_coercion(
1265 "value",
1266 &field_type,
1267 &literal,
1268 &CoercionSpec::new(CoercionId::NumericWiden),
1269 )
1270 .is_ok();
1271
1272 assert_eq!(
1273 numeric_widen_allowed,
1274 scalar.supports_numeric_coercion(),
1275 "numeric family must not imply numeric widen for scalar {scalar:?}"
1276 );
1277
1278 if !scalar.supports_numeric_coercion() {
1279 numeric_family_with_no_numeric_widen =
1280 numeric_family_with_no_numeric_widen.saturating_add(1);
1281 }
1282 }
1283
1284 assert!(
1285 numeric_family_with_no_numeric_widen > 0,
1286 "expected at least one numeric-family scalar without numeric widen support"
1287 );
1288 }
1289
1290 #[test]
1291 fn scalar_registry_covers_all_variants_exactly_once() {
1292 let scalars = registry_scalars();
1293 let mut names = BTreeSet::new();
1294 let mut seen = [false; SCALAR_TYPE_VARIANT_COUNT];
1295
1296 for scalar in scalars {
1297 let index = scalar_index(scalar.clone());
1298 assert!(!seen[index], "duplicate scalar entry: {scalar:?}");
1299 seen[index] = true;
1300
1301 let name = format!("{scalar:?}");
1302 assert!(names.insert(name.clone()), "duplicate scalar entry: {name}");
1303 }
1304
1305 let mut missing = Vec::new();
1306 for (index, was_seen) in seen.iter().enumerate() {
1307 if !*was_seen {
1308 let scalar = scalar_from_index(index).expect("index is in range");
1309 missing.push(format!("{scalar:?}"));
1310 }
1311 }
1312
1313 assert!(missing.is_empty(), "missing scalar entries: {missing:?}");
1314 assert_eq!(names.len(), SCALAR_TYPE_VARIANT_COUNT);
1315 }
1316
1317 #[test]
1318 fn scalar_keyability_matches_value_storage_key() {
1319 for scalar in registry_scalars() {
1320 let value = sample_value_for_scalar(scalar.clone());
1321 let scalar_keyable = scalar.is_keyable();
1322 let value_keyable = value.as_storage_key().is_some();
1323
1324 assert_eq!(
1325 value_keyable, scalar_keyable,
1326 "Value::as_storage_key drift for scalar {scalar:?}"
1327 );
1328 }
1329 }
1330}