1use super::{
2 ast::{CompareOp, ComparePredicate, Predicate, UnsupportedQueryFeature},
3 coercion::{CoercionId, CoercionSpec, supports_coercion},
4};
5use crate::{
6 db::identity::{EntityName, EntityNameError, IndexName, IndexNameError},
7 model::{entity::EntityModel, field::EntityFieldKind, index::IndexModel},
8 traits::FieldValueKind,
9 value::{CoercionFamily, CoercionFamilyExt, Value},
10};
11use std::{
12 collections::{BTreeMap, BTreeSet},
13 fmt,
14};
15
/// Scalar (non-collection) field types understood by predicate validation.
///
/// Each variant is produced from the scalar `EntityFieldKind` of the same name
/// by `field_type_from_model_kind`. Per-variant capabilities (coercion family,
/// value matching, ordering, keyability, numeric coercion) are looked up through
/// the `scalar_registry!` table rather than being hand-written per method.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ScalarType {
    Account,
    Blob,
    Bool,
    Date,
    Decimal,
    Duration,
    Enum,
    E8s,
    E18s,
    Float32,
    Float64,
    Int,
    Int128,
    IntBig,
    Principal,
    Subaccount,
    Text,
    Timestamp,
    Uint,
    Uint128,
    UintBig,
    Ulid,
    Unit,
}
53
/// Registry callback backing `ScalarType::coercion_family`.
///
/// Expands to a `match` on `$self` with one arm per registry row, yielding that
/// row's coercion-family column. The match only compiles if the rows cover
/// every `ScalarType` variant. Columns other than `$scalar` and
/// `$coercion_family` are matched only so the row shape is accepted.
///
/// Invoked as `scalar_registry!(scalar_coercion_family_from_registry, self)`.
/// The `@entries` rows are presumably injected by `scalar_registry!`, which is
/// defined elsewhere. TODO confirm against its definition.
macro_rules! scalar_coercion_family_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        // One arm per registry row: variant -> coercion family.
        match $self {
            $( ScalarType::$scalar => $coercion_family, )*
        }
    };
}
62
/// Registry callback backing `ScalarType::matches_value`.
///
/// Expands to a `matches!` over the pair `($self, $value)`. It is `true` when
/// some registry row pairs this scalar variant with a `Value` pattern that
/// `$value` matches. Only `$scalar` and `$value_pat` are used; the other
/// columns are matched so the row shape is accepted.
///
/// Invoked as `scalar_registry!(scalar_matches_value_from_registry, self, value)`.
/// The rows are presumably supplied by `scalar_registry!`. TODO confirm.
macro_rules! scalar_matches_value_from_registry {
    ( @args $self:expr, $value:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        // Or-pattern of `(ScalarType::X, <value pattern for X>)` across all rows.
        matches!(
            ($self, $value),
            $( (ScalarType::$scalar, $value_pat) )|*
        )
    };
}
71
/// Registry callback backing `ScalarType::supports_numeric_coercion`.
///
/// Expands to a `match` on `$self` returning each row's
/// `supports_numeric_coercion` column. The other columns are matched only for
/// shape. Invoked via `scalar_registry!`, which presumably supplies the
/// `@entries` rows. TODO confirm.
macro_rules! scalar_supports_numeric_coercion_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        // One arm per registry row: variant -> numeric-coercion flag.
        match $self {
            $( ScalarType::$scalar => $supports_numeric_coercion, )*
        }
    };
}
79
/// Registry callback backing `ScalarType::is_keyable`.
///
/// Expands to a `match` on `$self` returning each row's `is_keyable` column.
/// The other columns are matched only for shape. Invoked via
/// `scalar_registry!`, which presumably supplies the `@entries` rows.
/// TODO confirm.
macro_rules! scalar_is_keyable_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        // One arm per registry row: variant -> keyable flag.
        match $self {
            $( ScalarType::$scalar => $is_keyable, )*
        }
    };
}
87
/// Registry callback backing `ScalarType::supports_ordering`.
///
/// Expands to a `match` on `$self` returning each row's `supports_ordering`
/// column. The other columns are matched only for shape. Invoked via
/// `scalar_registry!`, which presumably supplies the `@entries` rows.
/// TODO confirm.
macro_rules! scalar_supports_ordering_from_registry {
    ( @args $self:expr; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
        // One arm per registry row: variant -> ordering flag.
        match $self {
            $( ScalarType::$scalar => $supports_ordering, )*
        }
    };
}
95
96impl ScalarType {
97 #[must_use]
98 pub const fn coercion_family(&self) -> CoercionFamily {
99 scalar_registry!(scalar_coercion_family_from_registry, self)
100 }
101
102 #[must_use]
103 pub const fn is_orderable(&self) -> bool {
104 self.supports_ordering()
107 }
108
109 #[must_use]
110 pub const fn matches_value(&self, value: &Value) -> bool {
111 scalar_registry!(scalar_matches_value_from_registry, self, value)
112 }
113
114 #[must_use]
115 pub const fn supports_numeric_coercion(&self) -> bool {
116 scalar_registry!(scalar_supports_numeric_coercion_from_registry, self)
117 }
118
119 #[must_use]
120 pub const fn is_keyable(&self) -> bool {
121 scalar_registry!(scalar_is_keyable_from_registry, self)
122 }
123
124 #[must_use]
125 pub const fn supports_ordering(&self) -> bool {
126 scalar_registry!(scalar_supports_ordering_from_registry, self)
127 }
128}
129
130#[derive(Clone, Debug, Eq, PartialEq)]
141pub(crate) enum FieldType {
142 Scalar(ScalarType),
143 List(Box<Self>),
144 Set(Box<Self>),
145 Map { key: Box<Self>, value: Box<Self> },
146 Structured { queryable: bool },
147}
148
149impl FieldType {
150 #[must_use]
151 pub const fn value_kind(&self) -> FieldValueKind {
152 match self {
153 Self::Scalar(_) => FieldValueKind::Atomic,
154 Self::List(_) | Self::Set(_) => FieldValueKind::Structured { queryable: true },
155 Self::Map { .. } => FieldValueKind::Structured { queryable: false },
156 Self::Structured { queryable } => FieldValueKind::Structured {
157 queryable: *queryable,
158 },
159 }
160 }
161
162 #[must_use]
163 pub const fn coercion_family(&self) -> Option<CoercionFamily> {
164 match self {
165 Self::Scalar(inner) => Some(inner.coercion_family()),
166 Self::List(_) | Self::Set(_) | Self::Map { .. } => Some(CoercionFamily::Collection),
167 Self::Structured { .. } => None,
168 }
169 }
170
171 #[must_use]
172 pub const fn is_text(&self) -> bool {
173 matches!(self, Self::Scalar(ScalarType::Text))
174 }
175
176 #[must_use]
177 pub const fn is_collection(&self) -> bool {
178 matches!(self, Self::List(_) | Self::Set(_) | Self::Map { .. })
179 }
180
181 #[must_use]
182 pub const fn is_list_like(&self) -> bool {
183 matches!(self, Self::List(_) | Self::Set(_))
184 }
185
186 #[must_use]
187 pub const fn is_orderable(&self) -> bool {
188 match self {
189 Self::Scalar(inner) => inner.is_orderable(),
190 _ => false,
191 }
192 }
193
194 #[must_use]
195 pub const fn is_keyable(&self) -> bool {
196 match self {
197 Self::Scalar(inner) => inner.is_keyable(),
198 _ => false,
199 }
200 }
201
202 #[must_use]
203 pub const fn supports_numeric_coercion(&self) -> bool {
204 match self {
205 Self::Scalar(inner) => inner.supports_numeric_coercion(),
206 _ => false,
207 }
208 }
209}
210
211fn validate_index_fields(
212 fields: &BTreeMap<String, FieldType>,
213 indexes: &[&IndexModel],
214) -> Result<(), ValidateError> {
215 let mut seen_names = BTreeSet::new();
216 for index in indexes {
217 if seen_names.contains(index.name) {
218 return Err(ValidateError::DuplicateIndexName {
219 name: index.name.to_string(),
220 });
221 }
222 seen_names.insert(index.name);
223
224 let mut seen = BTreeSet::new();
225 for field in index.fields {
226 if !fields.contains_key(*field) {
227 return Err(ValidateError::IndexFieldUnknown {
228 index: **index,
229 field: (*field).to_string(),
230 });
231 }
232 if seen.contains(*field) {
233 return Err(ValidateError::IndexFieldDuplicate {
234 index: **index,
235 field: (*field).to_string(),
236 });
237 }
238 seen.insert(*field);
239
240 let field_type = fields
241 .get(*field)
242 .expect("index field existence checked above");
243 if matches!(field_type, FieldType::Map { .. }) {
246 return Err(ValidateError::IndexFieldMapNotQueryable {
247 index: **index,
248 field: (*field).to_string(),
249 });
250 }
251 if !field_type.value_kind().is_queryable() {
252 return Err(ValidateError::IndexFieldNotQueryable {
253 index: **index,
254 field: (*field).to_string(),
255 });
256 }
257 }
258 }
259
260 Ok(())
261}
262
263#[derive(Clone, Debug)]
271pub struct SchemaInfo {
272 fields: BTreeMap<String, FieldType>,
273}
274
275impl SchemaInfo {
276 #[must_use]
277 pub(crate) fn field(&self, name: &str) -> Option<&FieldType> {
278 self.fields.get(name)
279 }
280
281 pub fn from_entity_model(model: &EntityModel) -> Result<Self, ValidateError> {
282 let entity_name = EntityName::try_from_str(model.entity_name).map_err(|err| {
284 ValidateError::InvalidEntityName {
285 name: model.entity_name.to_string(),
286 source: err,
287 }
288 })?;
289
290 if !model
291 .fields
292 .iter()
293 .any(|field| std::ptr::eq(field, model.primary_key))
294 {
295 return Err(ValidateError::InvalidPrimaryKey {
296 field: model.primary_key.name.to_string(),
297 });
298 }
299
300 let mut fields = BTreeMap::new();
301 for field in model.fields {
302 if fields.contains_key(field.name) {
303 return Err(ValidateError::DuplicateField {
304 field: field.name.to_string(),
305 });
306 }
307 let ty = field_type_from_model_kind(&field.kind);
308 fields.insert(field.name.to_string(), ty);
309 }
310
311 let pk_field_type = fields
312 .get(model.primary_key.name)
313 .expect("primary key verified above");
314 if !pk_field_type.is_keyable() {
315 return Err(ValidateError::InvalidPrimaryKeyType {
316 field: model.primary_key.name.to_string(),
317 });
318 }
319
320 validate_index_fields(&fields, model.indexes)?;
321 for index in model.indexes {
322 IndexName::try_from_parts(&entity_name, index.fields).map_err(|err| {
323 ValidateError::InvalidIndexName {
324 index: **index,
325 source: err,
326 }
327 })?;
328 }
329
330 Ok(Self { fields })
331 }
332}
333
334#[derive(Debug, thiserror::Error)]
336pub enum ValidateError {
337 #[error("invalid entity name '{name}': {source}")]
338 InvalidEntityName {
339 name: String,
340 #[source]
341 source: EntityNameError,
342 },
343
344 #[error("invalid index name for '{index}': {source}")]
345 InvalidIndexName {
346 index: IndexModel,
347 #[source]
348 source: IndexNameError,
349 },
350
351 #[error("unknown field '{field}'")]
352 UnknownField { field: String },
353
354 #[error("field '{field}' is not queryable")]
355 NonQueryableFieldType { field: String },
356
357 #[error("duplicate field '{field}'")]
358 DuplicateField { field: String },
359
360 #[error("{0}")]
361 UnsupportedQueryFeature(#[from] UnsupportedQueryFeature),
362
363 #[error("primary key '{field}' not present in entity fields")]
364 InvalidPrimaryKey { field: String },
365
366 #[error("primary key '{field}' has a non-keyable type")]
367 InvalidPrimaryKeyType { field: String },
368
369 #[error("index '{index}' references unknown field '{field}'")]
370 IndexFieldUnknown { index: IndexModel, field: String },
371
372 #[error("index '{index}' references non-queryable field '{field}'")]
373 IndexFieldNotQueryable { index: IndexModel, field: String },
374
375 #[error(
376 "index '{index}' references map field '{field}'; map fields are not queryable in icydb 0.7"
377 )]
378 IndexFieldMapNotQueryable { index: IndexModel, field: String },
379
380 #[error("index '{index}' repeats field '{field}'")]
381 IndexFieldDuplicate { index: IndexModel, field: String },
382
383 #[error("duplicate index name '{name}'")]
384 DuplicateIndexName { name: String },
385
386 #[error("operator {op} is not valid for field '{field}'")]
387 InvalidOperator { field: String, op: String },
388
389 #[error("coercion {coercion:?} is not valid for field '{field}'")]
390 InvalidCoercion { field: String, coercion: CoercionId },
391
392 #[error("invalid literal for field '{field}': {message}")]
393 InvalidLiteral { field: String, message: String },
394}
395
396pub fn reject_unsupported_query_features(
398 predicate: &Predicate,
399) -> Result<(), UnsupportedQueryFeature> {
400 match predicate {
401 Predicate::True
402 | Predicate::False
403 | Predicate::Compare(_)
404 | Predicate::IsNull { .. }
405 | Predicate::IsMissing { .. }
406 | Predicate::IsEmpty { .. }
407 | Predicate::IsNotEmpty { .. }
408 | Predicate::TextContains { .. }
409 | Predicate::TextContainsCi { .. } => Ok(()),
410 Predicate::And(children) | Predicate::Or(children) => {
411 for child in children {
412 reject_unsupported_query_features(child)?;
413 }
414
415 Ok(())
416 }
417 Predicate::Not(inner) => reject_unsupported_query_features(inner),
418 }
419}
420
421pub fn validate(schema: &SchemaInfo, predicate: &Predicate) -> Result<(), ValidateError> {
422 reject_unsupported_query_features(predicate)?;
423
424 match predicate {
425 Predicate::True | Predicate::False => Ok(()),
426 Predicate::And(children) | Predicate::Or(children) => {
427 for child in children {
428 validate(schema, child)?;
429 }
430 Ok(())
431 }
432 Predicate::Not(inner) => validate(schema, inner),
433 Predicate::Compare(cmp) => validate_compare(schema, cmp),
434 Predicate::IsNull { field } | Predicate::IsMissing { field } => {
435 let _field_type = ensure_field(schema, field)?;
436 Ok(())
437 }
438 Predicate::IsEmpty { field } => {
439 let field_type = ensure_field(schema, field)?;
440 if field_type.is_text() || field_type.is_collection() {
441 Ok(())
442 } else {
443 Err(invalid_operator(field, "is_empty"))
444 }
445 }
446 Predicate::IsNotEmpty { field } => {
447 let field_type = ensure_field(schema, field)?;
448 if field_type.is_text() || field_type.is_collection() {
449 Ok(())
450 } else {
451 Err(invalid_operator(field, "is_not_empty"))
452 }
453 }
454 Predicate::TextContains { field, value } => {
455 validate_text_contains(schema, field, value, "text_contains")
456 }
457 Predicate::TextContainsCi { field, value } => {
458 validate_text_contains(schema, field, value, "text_contains_ci")
459 }
460 }
461}
462
463pub fn validate_model(model: &EntityModel, predicate: &Predicate) -> Result<(), ValidateError> {
464 let schema = SchemaInfo::from_entity_model(model)?;
465 validate(&schema, predicate)
466}
467
468fn validate_compare(schema: &SchemaInfo, cmp: &ComparePredicate) -> Result<(), ValidateError> {
469 let field_type = ensure_field(schema, &cmp.field)?;
470
471 match cmp.op {
472 CompareOp::Eq | CompareOp::Ne => {
473 validate_eq_ne(&cmp.field, field_type, &cmp.value, &cmp.coercion)
474 }
475 CompareOp::Lt | CompareOp::Lte | CompareOp::Gt | CompareOp::Gte => {
476 validate_ordering(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
477 }
478 CompareOp::In | CompareOp::NotIn => {
479 validate_in(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
480 }
481 CompareOp::Contains => validate_contains(&cmp.field, field_type, &cmp.value, &cmp.coercion),
482 CompareOp::StartsWith | CompareOp::EndsWith => {
483 validate_text_compare(&cmp.field, field_type, &cmp.value, &cmp.coercion, cmp.op)
484 }
485 }
486}
487
488fn validate_eq_ne(
489 field: &str,
490 field_type: &FieldType,
491 value: &Value,
492 coercion: &CoercionSpec,
493) -> Result<(), ValidateError> {
494 if field_type.is_list_like() {
495 ensure_list_literal(field, value, field_type)?;
496 } else {
497 ensure_scalar_literal(field, value)?;
498 }
499
500 ensure_coercion(field, field_type, value, coercion)
501}
502
503fn validate_ordering(
504 field: &str,
505 field_type: &FieldType,
506 value: &Value,
507 coercion: &CoercionSpec,
508 op: CompareOp,
509) -> Result<(), ValidateError> {
510 if matches!(coercion.id, CoercionId::CollectionElement) {
511 return Err(ValidateError::InvalidCoercion {
512 field: field.to_string(),
513 coercion: coercion.id,
514 });
515 }
516
517 if !field_type.is_orderable() {
518 return Err(invalid_operator(field, format!("{op:?}")));
519 }
520
521 ensure_scalar_literal(field, value)?;
522
523 ensure_coercion(field, field_type, value, coercion)
524}
525
526fn validate_in(
528 field: &str,
529 field_type: &FieldType,
530 value: &Value,
531 coercion: &CoercionSpec,
532 op: CompareOp,
533) -> Result<(), ValidateError> {
534 if field_type.is_collection() {
535 return Err(invalid_operator(field, format!("{op:?}")));
536 }
537
538 let Value::List(items) = value else {
539 return Err(invalid_literal(field, "expected list literal"));
540 };
541
542 for item in items {
543 ensure_coercion(field, field_type, item, coercion)?;
544 }
545
546 Ok(())
547}
548
549fn validate_contains(
551 field: &str,
552 field_type: &FieldType,
553 value: &Value,
554 coercion: &CoercionSpec,
555) -> Result<(), ValidateError> {
556 if field_type.is_text() {
557 return Err(invalid_operator(
559 field,
560 format!("{:?}", CompareOp::Contains),
561 ));
562 }
563
564 let element_type = match field_type {
565 FieldType::List(inner) | FieldType::Set(inner) => inner.as_ref(),
566 _ => {
567 return Err(invalid_operator(
568 field,
569 format!("{:?}", CompareOp::Contains),
570 ));
571 }
572 };
573
574 if matches!(coercion.id, CoercionId::TextCasefold) {
575 return Err(ValidateError::InvalidCoercion {
577 field: field.to_string(),
578 coercion: coercion.id,
579 });
580 }
581
582 ensure_coercion(field, element_type, value, coercion)
583}
584
585fn validate_text_compare(
587 field: &str,
588 field_type: &FieldType,
589 value: &Value,
590 coercion: &CoercionSpec,
591 op: CompareOp,
592) -> Result<(), ValidateError> {
593 if !field_type.is_text() {
594 return Err(invalid_operator(field, format!("{op:?}")));
595 }
596
597 ensure_text_literal(field, value)?;
598
599 ensure_coercion(field, field_type, value, coercion)
600}
601
602fn validate_text_contains(
604 schema: &SchemaInfo,
605 field: &str,
606 value: &Value,
607 op: &str,
608) -> Result<(), ValidateError> {
609 let field_type = ensure_field(schema, field)?;
610 if !field_type.is_text() {
611 return Err(invalid_operator(field, op));
612 }
613
614 ensure_text_literal(field, value)?;
615
616 Ok(())
617}
618
619fn ensure_field<'a>(schema: &'a SchemaInfo, field: &str) -> Result<&'a FieldType, ValidateError> {
620 let field_type = schema
621 .field(field)
622 .ok_or_else(|| ValidateError::UnknownField {
623 field: field.to_string(),
624 })?;
625
626 if matches!(field_type, FieldType::Map { .. }) {
627 return Err(UnsupportedQueryFeature::MapPredicate {
628 field: field.to_string(),
629 }
630 .into());
631 }
632
633 if !field_type.value_kind().is_queryable() {
634 return Err(ValidateError::NonQueryableFieldType {
635 field: field.to_string(),
636 });
637 }
638
639 Ok(field_type)
640}
641
642fn invalid_operator(field: &str, op: impl fmt::Display) -> ValidateError {
643 ValidateError::InvalidOperator {
644 field: field.to_string(),
645 op: op.to_string(),
646 }
647}
648
649fn invalid_literal(field: &str, msg: &str) -> ValidateError {
650 ValidateError::InvalidLiteral {
651 field: field.to_string(),
652 message: msg.to_string(),
653 }
654}
655
656fn ensure_text_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
658 if !matches!(value, Value::Text(_)) {
659 return Err(invalid_literal(field, "expected text literal"));
660 }
661
662 Ok(())
663}
664
665fn ensure_scalar_literal(field: &str, value: &Value) -> Result<(), ValidateError> {
667 if matches!(value, Value::List(_)) {
668 return Err(invalid_literal(field, "expected scalar literal"));
669 }
670
671 Ok(())
672}
673
674fn ensure_coercion(
675 field: &str,
676 field_type: &FieldType,
677 literal: &Value,
678 coercion: &CoercionSpec,
679) -> Result<(), ValidateError> {
680 if matches!(coercion.id, CoercionId::TextCasefold) && !field_type.is_text() {
681 return Err(ValidateError::InvalidCoercion {
683 field: field.to_string(),
684 coercion: coercion.id,
685 });
686 }
687
688 if matches!(coercion.id, CoercionId::NumericWiden)
693 && (!field_type.supports_numeric_coercion() || !literal.supports_numeric_coercion())
694 {
695 return Err(ValidateError::InvalidCoercion {
696 field: field.to_string(),
697 coercion: coercion.id,
698 });
699 }
700
701 if !matches!(coercion.id, CoercionId::NumericWiden) {
702 let left_family =
703 field_type
704 .coercion_family()
705 .ok_or_else(|| ValidateError::NonQueryableFieldType {
706 field: field.to_string(),
707 })?;
708 let right_family = literal.coercion_family();
709
710 if !supports_coercion(left_family, right_family, coercion.id) {
711 return Err(ValidateError::InvalidCoercion {
712 field: field.to_string(),
713 coercion: coercion.id,
714 });
715 }
716 }
717
718 if matches!(
719 coercion.id,
720 CoercionId::Strict | CoercionId::CollectionElement
721 ) && !literal_matches_type(literal, field_type)
722 {
723 return Err(invalid_literal(
724 field,
725 "literal type does not match field type",
726 ));
727 }
728
729 Ok(())
730}
731
732fn ensure_list_literal(
733 field: &str,
734 literal: &Value,
735 field_type: &FieldType,
736) -> Result<(), ValidateError> {
737 if !literal_matches_type(literal, field_type) {
738 return Err(invalid_literal(
739 field,
740 "list literal does not match field element type",
741 ));
742 }
743
744 Ok(())
745}
746
747pub(crate) fn literal_matches_type(literal: &Value, field_type: &FieldType) -> bool {
748 match field_type {
749 FieldType::Scalar(inner) => inner.matches_value(literal),
750 FieldType::List(element) | FieldType::Set(element) => match literal {
751 Value::List(items) => items.iter().all(|item| literal_matches_type(item, element)),
752 _ => false,
753 },
754 FieldType::Map { key, value } => match literal {
755 Value::Map(entries) => {
756 if Value::validate_map_entries(entries.as_slice()).is_err() {
757 return false;
758 }
759
760 entries.iter().all(|(entry_key, entry_value)| {
761 literal_matches_type(entry_key, key) && literal_matches_type(entry_value, value)
762 })
763 }
764 _ => false,
765 },
766 FieldType::Structured { .. } => {
767 false
769 }
770 }
771}
772
773fn field_type_from_model_kind(kind: &EntityFieldKind) -> FieldType {
774 match kind {
775 EntityFieldKind::Account => FieldType::Scalar(ScalarType::Account),
776 EntityFieldKind::Blob => FieldType::Scalar(ScalarType::Blob),
777 EntityFieldKind::Bool => FieldType::Scalar(ScalarType::Bool),
778 EntityFieldKind::Date => FieldType::Scalar(ScalarType::Date),
779 EntityFieldKind::Decimal => FieldType::Scalar(ScalarType::Decimal),
780 EntityFieldKind::Duration => FieldType::Scalar(ScalarType::Duration),
781 EntityFieldKind::Enum => FieldType::Scalar(ScalarType::Enum),
782 EntityFieldKind::E8s => FieldType::Scalar(ScalarType::E8s),
783 EntityFieldKind::E18s => FieldType::Scalar(ScalarType::E18s),
784 EntityFieldKind::Float32 => FieldType::Scalar(ScalarType::Float32),
785 EntityFieldKind::Float64 => FieldType::Scalar(ScalarType::Float64),
786 EntityFieldKind::Int => FieldType::Scalar(ScalarType::Int),
787 EntityFieldKind::Int128 => FieldType::Scalar(ScalarType::Int128),
788 EntityFieldKind::IntBig => FieldType::Scalar(ScalarType::IntBig),
789 EntityFieldKind::Principal => FieldType::Scalar(ScalarType::Principal),
790 EntityFieldKind::Subaccount => FieldType::Scalar(ScalarType::Subaccount),
791 EntityFieldKind::Text => FieldType::Scalar(ScalarType::Text),
792 EntityFieldKind::Timestamp => FieldType::Scalar(ScalarType::Timestamp),
793 EntityFieldKind::Uint => FieldType::Scalar(ScalarType::Uint),
794 EntityFieldKind::Uint128 => FieldType::Scalar(ScalarType::Uint128),
795 EntityFieldKind::UintBig => FieldType::Scalar(ScalarType::UintBig),
796 EntityFieldKind::Ulid => FieldType::Scalar(ScalarType::Ulid),
797 EntityFieldKind::Unit => FieldType::Scalar(ScalarType::Unit),
798 EntityFieldKind::Relation { key_kind, .. } => field_type_from_model_kind(key_kind),
799 EntityFieldKind::List(inner) => {
800 FieldType::List(Box::new(field_type_from_model_kind(inner)))
801 }
802 EntityFieldKind::Set(inner) => FieldType::Set(Box::new(field_type_from_model_kind(inner))),
803 EntityFieldKind::Map { key, value } => FieldType::Map {
804 key: Box::new(field_type_from_model_kind(key)),
805 value: Box::new(field_type_from_model_kind(value)),
806 },
807 EntityFieldKind::Structured { queryable } => FieldType::Structured {
808 queryable: *queryable,
809 },
810 }
811}
812
813impl fmt::Display for FieldType {
814 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
815 match self {
816 Self::Scalar(inner) => write!(f, "{inner:?}"),
817 Self::List(inner) => write!(f, "List<{inner}>"),
818 Self::Set(inner) => write!(f, "Set<{inner}>"),
819 Self::Map { key, value } => write!(f, "Map<{key}, {value}>"),
820 Self::Structured { queryable } => {
821 write!(f, "Structured<queryable={queryable}>")
822 }
823 }
824 }
825}
826
827#[cfg(test)]
832mod tests {
833 use super::{FieldType, ScalarType, ValidateError, ensure_coercion, validate_model};
835 use crate::{
836 db::query::{
837 FieldRef,
838 predicate::{CoercionId, CoercionSpec, CompareOp, ComparePredicate, Predicate},
839 },
840 model::field::{EntityFieldKind, EntityFieldModel},
841 test_fixtures::InvalidEntityModelBuilder,
842 traits::{EntitySchema, FieldValue},
843 types::{
844 Account, Date, Decimal, Duration, E8s, E18s, Float32, Float64, Int, Int128, Nat,
845 Nat128, Principal, Subaccount, Timestamp, Ulid,
846 },
847 value::{CoercionFamily, Value, ValueEnum},
848 };
849 use std::collections::BTreeSet;
850
851 fn registry_scalars() -> Vec<ScalarType> {
853 macro_rules! collect_scalars {
854 ( @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
855 vec![ $( ScalarType::$scalar ),* ]
856 };
857 ( @args $($ignore:tt)*; @entries $( ($scalar:ident, $coercion_family:expr, $value_pat:pat, is_numeric_value = $is_numeric:expr, supports_numeric_coercion = $supports_numeric_coercion:expr, supports_arithmetic = $supports_arithmetic:expr, supports_equality = $supports_equality:expr, supports_ordering = $supports_ordering:expr, is_keyable = $is_keyable:expr, is_storage_key_encodable = $is_storage_key_encodable:expr) ),* $(,)? ) => {
858 vec![ $( ScalarType::$scalar ),* ]
859 };
860 }
861
862 let scalars = scalar_registry!(collect_scalars);
863
864 scalars
865 }
866
867 const SCALAR_TYPE_VARIANT_COUNT: usize = 23;
869
870 fn scalar_index(scalar: ScalarType) -> usize {
872 match scalar {
873 ScalarType::Account => 0,
874 ScalarType::Blob => 1,
875 ScalarType::Bool => 2,
876 ScalarType::Date => 3,
877 ScalarType::Decimal => 4,
878 ScalarType::Duration => 5,
879 ScalarType::Enum => 6,
880 ScalarType::E8s => 7,
881 ScalarType::E18s => 8,
882 ScalarType::Float32 => 9,
883 ScalarType::Float64 => 10,
884 ScalarType::Int => 11,
885 ScalarType::Int128 => 12,
886 ScalarType::IntBig => 13,
887 ScalarType::Principal => 14,
888 ScalarType::Subaccount => 15,
889 ScalarType::Text => 16,
890 ScalarType::Timestamp => 17,
891 ScalarType::Uint => 18,
892 ScalarType::Uint128 => 19,
893 ScalarType::UintBig => 20,
894 ScalarType::Ulid => 21,
895 ScalarType::Unit => 22,
896 }
897 }
898
899 fn scalar_from_index(index: usize) -> Option<ScalarType> {
901 let scalar = match index {
902 0 => ScalarType::Account,
903 1 => ScalarType::Blob,
904 2 => ScalarType::Bool,
905 3 => ScalarType::Date,
906 4 => ScalarType::Decimal,
907 5 => ScalarType::Duration,
908 6 => ScalarType::Enum,
909 7 => ScalarType::E8s,
910 8 => ScalarType::E18s,
911 9 => ScalarType::Float32,
912 10 => ScalarType::Float64,
913 11 => ScalarType::Int,
914 12 => ScalarType::Int128,
915 13 => ScalarType::IntBig,
916 14 => ScalarType::Principal,
917 15 => ScalarType::Subaccount,
918 16 => ScalarType::Text,
919 17 => ScalarType::Timestamp,
920 18 => ScalarType::Uint,
921 19 => ScalarType::Uint128,
922 20 => ScalarType::UintBig,
923 21 => ScalarType::Ulid,
924 22 => ScalarType::Unit,
925 _ => return None,
926 };
927
928 Some(scalar)
929 }
930
931 fn sample_value_for_scalar(scalar: ScalarType) -> Value {
933 match scalar {
934 ScalarType::Account => Value::Account(Account::dummy(1)),
935 ScalarType::Blob => Value::Blob(vec![0u8, 1u8]),
936 ScalarType::Bool => Value::Bool(true),
937 ScalarType::Date => Value::Date(Date::EPOCH),
938 ScalarType::Decimal => Value::Decimal(Decimal::ZERO),
939 ScalarType::Duration => Value::Duration(Duration::ZERO),
940 ScalarType::Enum => Value::Enum(ValueEnum::loose("example")),
941 ScalarType::E8s => Value::E8s(E8s::from_atomic(0)),
942 ScalarType::E18s => Value::E18s(E18s::from_atomic(0)),
943 ScalarType::Float32 => {
944 Value::Float32(Float32::try_new(0.0).expect("Float32 sample should be finite"))
945 }
946 ScalarType::Float64 => {
947 Value::Float64(Float64::try_new(0.0).expect("Float64 sample should be finite"))
948 }
949 ScalarType::Int => Value::Int(0),
950 ScalarType::Int128 => Value::Int128(Int128::from(0i128)),
951 ScalarType::IntBig => Value::IntBig(Int::from(0i32)),
952 ScalarType::Principal => Value::Principal(Principal::anonymous()),
953 ScalarType::Subaccount => Value::Subaccount(Subaccount::dummy(2)),
954 ScalarType::Text => Value::Text("text".to_string()),
955 ScalarType::Timestamp => Value::Timestamp(Timestamp::EPOCH),
956 ScalarType::Uint => Value::Uint(0),
957 ScalarType::Uint128 => Value::Uint128(Nat128::from(0u128)),
958 ScalarType::UintBig => Value::UintBig(Nat::from(0u64)),
959 ScalarType::Ulid => Value::Ulid(Ulid::nil()),
960 ScalarType::Unit => Value::Unit,
961 }
962 }
963
964 fn field(name: &'static str, kind: EntityFieldKind) -> EntityFieldModel {
965 EntityFieldModel { name, kind }
966 }
967
968 crate::test_entity_schema! {
969 ScalarPredicateEntity,
970 id = Ulid,
971 path = "predicate_validate::ScalarEntity",
972 entity_name = "ScalarEntity",
973 primary_key = "id",
974 pk_index = 0,
975 fields = [
976 ("id", EntityFieldKind::Ulid),
977 ("email", EntityFieldKind::Text),
978 ("age", EntityFieldKind::Uint),
979 ("created_at", EntityFieldKind::Timestamp),
980 ("active", EntityFieldKind::Bool),
981 ],
982 indexes = [],
983 }
984
985 crate::test_entity_schema! {
986 CollectionPredicateEntity,
987 id = Ulid,
988 path = "predicate_validate::CollectionEntity",
989 entity_name = "CollectionEntity",
990 primary_key = "id",
991 pk_index = 0,
992 fields = [
993 ("id", EntityFieldKind::Ulid),
994 ("tags", EntityFieldKind::List(&EntityFieldKind::Text)),
995 ("principals", EntityFieldKind::Set(&EntityFieldKind::Principal)),
996 (
997 "attributes",
998 EntityFieldKind::Map {
999 key: &EntityFieldKind::Text,
1000 value: &EntityFieldKind::Uint,
1001 }
1002 ),
1003 ],
1004 indexes = [],
1005 }
1006
1007 crate::test_entity_schema! {
1008 NumericCoercionPredicateEntity,
1009 id = Ulid,
1010 path = "predicate_validate::NumericCoercionEntity",
1011 entity_name = "NumericCoercionEntity",
1012 primary_key = "id",
1013 pk_index = 0,
1014 fields = [
1015 ("id", EntityFieldKind::Ulid),
1016 ("date", EntityFieldKind::Date),
1017 ("int_big", EntityFieldKind::IntBig),
1018 ("uint_big", EntityFieldKind::UintBig),
1019 ("int_small", EntityFieldKind::Int),
1020 ("uint_small", EntityFieldKind::Uint),
1021 ("decimal", EntityFieldKind::Decimal),
1022 ("e8s", EntityFieldKind::E8s),
1023 ],
1024 indexes = [],
1025 }
1026
1027 #[test]
1028 fn validate_model_accepts_scalars_and_coercions() {
1029 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1030
1031 let predicate = Predicate::And(vec![
1032 FieldRef::new("id").eq(Ulid::nil()),
1033 FieldRef::new("email").text_eq_ci("User@example.com"),
1034 FieldRef::new("age").lt(30u32),
1035 ]);
1036
1037 assert!(validate_model(model, &predicate).is_ok());
1038 }
1039
1040 #[test]
1041 fn validate_model_accepts_deterministic_set_predicates() {
1042 let model = <CollectionPredicateEntity as EntitySchema>::MODEL;
1043
1044 let predicate = Predicate::Compare(ComparePredicate::with_coercion(
1045 "principals",
1046 CompareOp::Contains,
1047 Principal::anonymous().to_value(),
1048 CoercionId::Strict,
1049 ));
1050
1051 assert!(validate_model(model, &predicate).is_ok());
1052 }
1053
1054 #[test]
1055 fn validate_model_rejects_non_queryable_fields() {
1056 let model = InvalidEntityModelBuilder::from_fields(
1057 vec![
1058 field("id", EntityFieldKind::Ulid),
1059 field("broken", EntityFieldKind::Structured { queryable: false }),
1060 ],
1061 0,
1062 );
1063
1064 let predicate = FieldRef::new("broken").eq(1u64);
1065
1066 assert!(matches!(
1067 validate_model(&model, &predicate),
1068 Err(ValidateError::NonQueryableFieldType { field }) if field == "broken"
1069 ));
1070 }
1071
1072 #[test]
1073 fn validate_model_accepts_text_contains() {
1074 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1075
1076 let predicate = FieldRef::new("email").text_contains("example");
1077 assert!(validate_model(model, &predicate).is_ok());
1078
1079 let predicate = FieldRef::new("email").text_contains_ci("EXAMPLE");
1080 assert!(validate_model(model, &predicate).is_ok());
1081 }
1082
1083 #[test]
1084 fn validate_model_rejects_text_contains_on_non_text() {
1085 let model = <ScalarPredicateEntity as EntitySchema>::MODEL;
1086
1087 let predicate = FieldRef::new("age").text_contains("1");
1088 assert!(matches!(
1089 validate_model(model, &predicate),
1090 Err(ValidateError::InvalidOperator { field, op })
1091 if field == "age" && op == "text_contains"
1092 ));
1093 }
1094
1095 #[test]
1096 fn validate_model_rejects_numeric_widen_for_registry_exclusions() {
1097 let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;
1098
1099 let date_pred = FieldRef::new("date").lt(1i64);
1100 assert!(matches!(
1101 validate_model(model, &date_pred),
1102 Err(ValidateError::InvalidCoercion { field, coercion })
1103 if field == "date" && coercion == CoercionId::NumericWiden
1104 ));
1105
1106 let int_big_pred = FieldRef::new("int_big").lt(Int::from(1i32));
1107 assert!(matches!(
1108 validate_model(model, &int_big_pred),
1109 Err(ValidateError::InvalidCoercion { field, coercion })
1110 if field == "int_big" && coercion == CoercionId::NumericWiden
1111 ));
1112
1113 let uint_big_pred = FieldRef::new("uint_big").lt(Nat::from(1u64));
1114 assert!(matches!(
1115 validate_model(model, &uint_big_pred),
1116 Err(ValidateError::InvalidCoercion { field, coercion })
1117 if field == "uint_big" && coercion == CoercionId::NumericWiden
1118 ));
1119 }
1120
1121 #[test]
1122 fn validate_model_accepts_numeric_widen_for_registry_allowed_scalars() {
1123 let model = <NumericCoercionPredicateEntity as EntitySchema>::MODEL;
1124 let predicate = Predicate::And(vec![
1125 FieldRef::new("int_small").lt(9u64),
1126 FieldRef::new("uint_small").lt(9i64),
1127 FieldRef::new("decimal").lt(9u64),
1128 FieldRef::new("e8s").lt(9u64),
1129 ]);
1130
1131 assert!(validate_model(model, &predicate).is_ok());
1132 }
1133
1134 #[test]
1135 fn numeric_widen_authority_tracks_registry_flags() {
1136 for scalar in registry_scalars() {
1137 let field_type = FieldType::Scalar(scalar.clone());
1138 let literal = sample_value_for_scalar(scalar.clone());
1139 let expected = scalar.supports_numeric_coercion();
1140 let actual = ensure_coercion(
1141 "value",
1142 &field_type,
1143 &literal,
1144 &CoercionSpec::new(CoercionId::NumericWiden),
1145 )
1146 .is_ok();
1147
1148 assert_eq!(
1149 actual, expected,
1150 "numeric widen drift for scalar {scalar:?}: expected {expected}, got {actual}"
1151 );
1152 }
1153 }
1154
1155 #[test]
1156 fn numeric_widen_is_not_inferred_from_coercion_family() {
1157 let mut numeric_family_with_no_numeric_widen = 0usize;
1158
1159 for scalar in registry_scalars() {
1160 if scalar.coercion_family() != CoercionFamily::Numeric {
1161 continue;
1162 }
1163
1164 let field_type = FieldType::Scalar(scalar.clone());
1165 let literal = sample_value_for_scalar(scalar.clone());
1166 let numeric_widen_allowed = ensure_coercion(
1167 "value",
1168 &field_type,
1169 &literal,
1170 &CoercionSpec::new(CoercionId::NumericWiden),
1171 )
1172 .is_ok();
1173
1174 assert_eq!(
1175 numeric_widen_allowed,
1176 scalar.supports_numeric_coercion(),
1177 "numeric family must not imply numeric widen for scalar {scalar:?}"
1178 );
1179
1180 if !scalar.supports_numeric_coercion() {
1181 numeric_family_with_no_numeric_widen =
1182 numeric_family_with_no_numeric_widen.saturating_add(1);
1183 }
1184 }
1185
1186 assert!(
1187 numeric_family_with_no_numeric_widen > 0,
1188 "expected at least one numeric-family scalar without numeric widen support"
1189 );
1190 }
1191
1192 #[test]
1193 fn scalar_registry_covers_all_variants_exactly_once() {
1194 let scalars = registry_scalars();
1195 let mut names = BTreeSet::new();
1196 let mut seen = [false; SCALAR_TYPE_VARIANT_COUNT];
1197
1198 for scalar in scalars {
1199 let index = scalar_index(scalar.clone());
1200 assert!(!seen[index], "duplicate scalar entry: {scalar:?}");
1201 seen[index] = true;
1202
1203 let name = format!("{scalar:?}");
1204 assert!(names.insert(name.clone()), "duplicate scalar entry: {name}");
1205 }
1206
1207 let mut missing = Vec::new();
1208 for (index, was_seen) in seen.iter().enumerate() {
1209 if !*was_seen {
1210 let scalar = scalar_from_index(index).expect("index is in range");
1211 missing.push(format!("{scalar:?}"));
1212 }
1213 }
1214
1215 assert!(missing.is_empty(), "missing scalar entries: {missing:?}");
1216 assert_eq!(names.len(), SCALAR_TYPE_VARIANT_COUNT);
1217 }
1218
1219 #[test]
1220 fn scalar_keyability_matches_value_storage_key() {
1221 for scalar in registry_scalars() {
1222 let value = sample_value_for_scalar(scalar.clone());
1223 let scalar_keyable = scalar.is_keyable();
1224 let value_keyable = value.as_storage_key().is_some();
1225
1226 assert_eq!(
1227 value_keyable, scalar_keyable,
1228 "Value::as_storage_key drift for scalar {scalar:?}"
1229 );
1230 }
1231 }
1232}