1use crate::{
2 db::query::predicate::{
3 ast::{CompareOp, ComparePredicate, Predicate},
4 coercion::{CoercionId, CoercionSpec},
5 },
6 value::{Value, ValueEnum},
7};
8
9#[must_use]
28pub fn normalize(predicate: &Predicate) -> Predicate {
29 match predicate {
30 Predicate::True => Predicate::True,
31 Predicate::False => Predicate::False,
32
33 Predicate::And(children) => normalize_and(children),
34 Predicate::Or(children) => normalize_or(children),
35 Predicate::Not(inner) => normalize_not(inner),
36
37 Predicate::Compare(cmp) => Predicate::Compare(normalize_compare(cmp)),
38
39 Predicate::IsNull { field } => Predicate::IsNull {
40 field: field.clone(),
41 },
42 Predicate::IsMissing { field } => Predicate::IsMissing {
43 field: field.clone(),
44 },
45 Predicate::IsEmpty { field } => Predicate::IsEmpty {
46 field: field.clone(),
47 },
48 Predicate::IsNotEmpty { field } => Predicate::IsNotEmpty {
49 field: field.clone(),
50 },
51
52 Predicate::MapContainsKey {
53 field,
54 key,
55 coercion,
56 } => Predicate::MapContainsKey {
57 field: field.clone(),
58 key: key.clone(),
59 coercion: coercion.clone(),
60 },
61
62 Predicate::MapContainsValue {
63 field,
64 value,
65 coercion,
66 } => Predicate::MapContainsValue {
67 field: field.clone(),
68 value: value.clone(),
69 coercion: coercion.clone(),
70 },
71
72 Predicate::MapContainsEntry {
73 field,
74 key,
75 value,
76 coercion,
77 } => Predicate::MapContainsEntry {
78 field: field.clone(),
79 key: key.clone(),
80 value: value.clone(),
81 coercion: coercion.clone(),
82 },
83 Predicate::TextContains { field, value } => Predicate::TextContains {
84 field: field.clone(),
85 value: value.clone(),
86 },
87 Predicate::TextContainsCi { field, value } => Predicate::TextContainsCi {
88 field: field.clone(),
89 value: value.clone(),
90 },
91 }
92}
93
94fn normalize_compare(cmp: &ComparePredicate) -> ComparePredicate {
101 ComparePredicate {
102 field: cmp.field.clone(),
103 op: cmp.op,
104 value: cmp.value.clone(),
105 coercion: cmp.coercion.clone(),
106 }
107}
108
109fn normalize_not(inner: &Predicate) -> Predicate {
116 let normalized = normalize(inner);
117
118 if let Predicate::Not(double) = normalized {
119 return normalize(&double);
120 }
121
122 Predicate::Not(Box::new(normalized))
123}
124
125fn normalize_and(children: &[Predicate]) -> Predicate {
137 let mut out = Vec::new();
138
139 for child in children {
140 let normalized = normalize(child);
141
142 match normalized {
143 Predicate::True => {}
144 Predicate::False => return Predicate::False,
145 Predicate::And(grandchildren) => out.extend(grandchildren),
146 other => out.push(other),
147 }
148 }
149
150 if out.is_empty() {
151 return Predicate::True;
152 }
153
154 out.sort_by_cached_key(sort_key);
155 Predicate::And(out)
156}
157
158fn normalize_or(children: &[Predicate]) -> Predicate {
170 let mut out = Vec::new();
171
172 for child in children {
173 let normalized = normalize(child);
174
175 match normalized {
176 Predicate::False => {}
177 Predicate::True => return Predicate::True,
178 Predicate::Or(grandchildren) => out.extend(grandchildren),
179 other => out.push(other),
180 }
181 }
182
183 if out.is_empty() {
184 return Predicate::False;
185 }
186
187 out.sort_by_cached_key(sort_key);
188 Predicate::Or(out)
189}
190
191fn sort_key(predicate: &Predicate) -> Vec<u8> {
201 let mut out = Vec::new();
202 encode_predicate_key(&mut out, predicate);
203 out
204}
205
// Leading tag byte written by `encode_predicate_key`, one per `Predicate`
// variant. The tag is the first byte of every `sort_key`, so renumbering any
// of these can change how normalized `And` / `Or` children are ordered.
const PRED_TRUE: u8 = 0;
const PRED_FALSE: u8 = 1;
const PRED_AND: u8 = 2;
const PRED_OR: u8 = 3;
const PRED_NOT: u8 = 4;
const PRED_COMPARE: u8 = 5;
const PRED_IS_NULL: u8 = 6;
const PRED_IS_MISSING: u8 = 7;
const PRED_IS_EMPTY: u8 = 8;
const PRED_IS_NOT_EMPTY: u8 = 9;
const PRED_MAP_CONTAINS_KEY: u8 = 10;
const PRED_MAP_CONTAINS_VALUE: u8 = 11;
const PRED_MAP_CONTAINS_ENTRY: u8 = 12;
const PRED_TEXT_CONTAINS: u8 = 13;
const PRED_TEXT_CONTAINS_CI: u8 = 14;
221
/// Append the canonical byte encoding of `predicate` to `out`.
///
/// Every encoding starts with the variant's `PRED_*` tag byte, followed by the
/// variant's fields in a fixed order. Strings, values and nested predicates are
/// written through `push_str` / `push_value` / `push_predicate`, and each of
/// those length-prefixes its payload. As a result, the boundary between
/// adjacent fields is never ambiguous.
///
/// `sort_key` is built from these bytes. Changing the layout, or any tag,
/// therefore changes how normalized `And` / `Or` children are ordered.
fn encode_predicate_key(out: &mut Vec<u8>, predicate: &Predicate) {
    match predicate {
        // Constants: tag only.
        Predicate::True => out.push(PRED_TRUE),
        Predicate::False => out.push(PRED_FALSE),
        // Connectives: child count, then each child as a length-prefixed
        // nested key, in the children's current order.
        Predicate::And(children) => {
            out.push(PRED_AND);
            push_len(out, children.len());
            for child in children {
                push_predicate(out, child);
            }
        }
        Predicate::Or(children) => {
            out.push(PRED_OR);
            push_len(out, children.len());
            for child in children {
                push_predicate(out, child);
            }
        }
        Predicate::Not(inner) => {
            out.push(PRED_NOT);
            push_predicate(out, inner);
        }
        // Field name, one-byte operator tag, operand value, coercion spec.
        Predicate::Compare(cmp) => {
            out.push(PRED_COMPARE);
            push_str(out, &cmp.field);
            out.push(compare_op_tag(cmp.op));
            push_value(out, &cmp.value);
            push_coercion(out, &cmp.coercion);
        }
        // Single-field leaves: tag plus field name.
        Predicate::IsNull { field } => {
            out.push(PRED_IS_NULL);
            push_str(out, field);
        }
        Predicate::IsMissing { field } => {
            out.push(PRED_IS_MISSING);
            push_str(out, field);
        }
        Predicate::IsEmpty { field } => {
            out.push(PRED_IS_EMPTY);
            push_str(out, field);
        }
        Predicate::IsNotEmpty { field } => {
            out.push(PRED_IS_NOT_EMPTY);
            push_str(out, field);
        }
        // Map leaves: field name, then key and/or value, then coercion spec.
        Predicate::MapContainsKey {
            field,
            key,
            coercion,
        } => {
            out.push(PRED_MAP_CONTAINS_KEY);
            push_str(out, field);
            push_value(out, key);
            push_coercion(out, coercion);
        }
        Predicate::MapContainsValue {
            field,
            value,
            coercion,
        } => {
            out.push(PRED_MAP_CONTAINS_VALUE);
            push_str(out, field);
            push_value(out, value);
            push_coercion(out, coercion);
        }
        Predicate::MapContainsEntry {
            field,
            key,
            value,
            coercion,
        } => {
            out.push(PRED_MAP_CONTAINS_ENTRY);
            push_str(out, field);
            push_value(out, key);
            push_value(out, value);
            push_coercion(out, coercion);
        }
        // Text leaves: field name then needle value (no coercion spec).
        Predicate::TextContains { field, value } => {
            out.push(PRED_TEXT_CONTAINS);
            push_str(out, field);
            push_value(out, value);
        }
        Predicate::TextContainsCi { field, value } => {
            out.push(PRED_TEXT_CONTAINS_CI);
            push_str(out, field);
            push_value(out, value);
        }
    }
}
312
// Leading tag byte written by `encode_value_key`, one per `Value` variant.
// Tag 0x00 is deliberately left unassigned. The values take part in the
// canonical predicate key, so renumbering them can change child ordering in
// normalized `And` / `Or` nodes.
const VALUE_ACCOUNT: u8 = 0x01;
const VALUE_BLOB: u8 = 0x02;
const VALUE_BOOL: u8 = 0x03;
const VALUE_DATE: u8 = 0x04;
const VALUE_DECIMAL: u8 = 0x05;
const VALUE_DURATION: u8 = 0x06;
const VALUE_ENUM: u8 = 0x07;
const VALUE_E8S: u8 = 0x08;
const VALUE_E18S: u8 = 0x09;
const VALUE_FLOAT32: u8 = 0x0A;
const VALUE_FLOAT64: u8 = 0x0B;
const VALUE_INT: u8 = 0x0C;
const VALUE_INT128: u8 = 0x0D;
const VALUE_INT_BIG: u8 = 0x0E;
const VALUE_LIST: u8 = 0x0F;
const VALUE_NONE: u8 = 0x10;
const VALUE_PRINCIPAL: u8 = 0x11;
const VALUE_SUBACCOUNT: u8 = 0x12;
const VALUE_TEXT: u8 = 0x13;
const VALUE_TIMESTAMP: u8 = 0x14;
const VALUE_UINT: u8 = 0x15;
const VALUE_UINT128: u8 = 0x16;
const VALUE_UINT_BIG: u8 = 0x17;
const VALUE_ULID: u8 = 0x18;
const VALUE_UNIT: u8 = 0x19;
const VALUE_UNSUPPORTED: u8 = 0x1A;
339
/// Append the canonical byte encoding of `value` to `out`.
///
/// Every encoding begins with the variant's `VALUE_*` tag byte. After the tag:
/// - numeric and temporal scalars are written as the big-endian bytes of the
///   wrapped number;
/// - blobs, text, principals, subaccounts and big integers are length-prefixed
///   via `push_bytes` / `push_str`;
/// - lists write an element count followed by each element as a
///   length-prefixed nested key.
///
/// Optional parts (the account subaccount, and the enum path and payload via
/// `push_enum`) are preceded by a presence byte: `1` = present, `0` = absent.
// NOTE(review): these bytes feed `sort_key`, which only needs determinism. If
// `Int` / `Float*` wrap plain primitives, their big-endian bytes do not sort
// numerically (negatives sort after positives) — don't rely on this encoding
// for range ordering.
#[expect(clippy::too_many_lines)]
fn encode_value_key(out: &mut Vec<u8>, value: &Value) {
    match value {
        // Owner bytes, then presence byte + subaccount bytes.
        Value::Account(v) => {
            out.push(VALUE_ACCOUNT);
            push_bytes(out, v.owner.as_slice());
            match v.subaccount {
                Some(sub) => {
                    out.push(1);
                    push_bytes(out, &sub.to_bytes());
                }
                None => out.push(0),
            }
        }
        Value::Blob(v) => {
            out.push(VALUE_BLOB);
            push_bytes(out, v);
        }
        Value::Bool(v) => {
            out.push(VALUE_BOOL);
            out.push(u8::from(*v));
        }
        Value::Date(v) => {
            out.push(VALUE_DATE);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        // Sign flag, scale, then mantissa.
        // NOTE(review): numerically equal decimals with different scales
        // (e.g. 1.0 vs 1.00) appear to get different keys — confirm intended.
        Value::Decimal(v) => {
            out.push(VALUE_DECIMAL);
            out.push(u8::from(v.is_sign_negative()));
            out.extend_from_slice(&v.scale().to_be_bytes());
            out.extend_from_slice(&v.mantissa().to_be_bytes());
        }
        Value::Duration(v) => {
            out.push(VALUE_DURATION);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        Value::Enum(v) => {
            out.push(VALUE_ENUM);
            push_enum(out, v);
        }
        Value::E8s(v) => {
            out.push(VALUE_E8S);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        Value::E18s(v) => {
            out.push(VALUE_E18S);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        // NOTE(review): presumably the raw IEEE-754 bit pattern, so 0.0 / -0.0
        // and distinct NaN payloads get distinct keys — confirm.
        Value::Float32(v) => {
            out.push(VALUE_FLOAT32);
            out.extend_from_slice(&v.to_be_bytes());
        }
        Value::Float64(v) => {
            out.push(VALUE_FLOAT64);
            out.extend_from_slice(&v.to_be_bytes());
        }
        Value::Int(v) => {
            out.push(VALUE_INT);
            out.extend_from_slice(&v.to_be_bytes());
        }
        Value::Int128(v) => {
            out.push(VALUE_INT128);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        Value::IntBig(v) => {
            out.push(VALUE_INT_BIG);
            push_bytes(out, &v.to_leb128());
        }
        // Element count, then each element as a length-prefixed nested key.
        Value::List(items) => {
            out.push(VALUE_LIST);
            push_len(out, items.len());
            for item in items {
                push_value(out, item);
            }
        }
        Value::None => out.push(VALUE_NONE),
        Value::Principal(v) => {
            out.push(VALUE_PRINCIPAL);
            push_bytes(out, v.as_slice());
        }
        Value::Subaccount(v) => {
            out.push(VALUE_SUBACCOUNT);
            push_bytes(out, &v.to_bytes());
        }
        Value::Text(v) => {
            out.push(VALUE_TEXT);
            push_str(out, v);
        }
        Value::Timestamp(v) => {
            out.push(VALUE_TIMESTAMP);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        Value::Uint(v) => {
            out.push(VALUE_UINT);
            out.extend_from_slice(&v.to_be_bytes());
        }
        Value::Uint128(v) => {
            out.push(VALUE_UINT128);
            out.extend_from_slice(&v.get().to_be_bytes());
        }
        Value::UintBig(v) => {
            out.push(VALUE_UINT_BIG);
            push_bytes(out, &v.to_leb128());
        }
        // Written without a length prefix; presumably fixed-width — confirm.
        Value::Ulid(v) => {
            out.push(VALUE_ULID);
            out.extend_from_slice(&v.to_bytes());
        }
        // Payload-free variants: tag only.
        Value::Unit => out.push(VALUE_UNIT),
        Value::Unsupported => out.push(VALUE_UNSUPPORTED),
    }
}
452
453fn push_predicate(out: &mut Vec<u8>, predicate: &Predicate) {
454 let mut buf = Vec::new();
455 encode_predicate_key(&mut buf, predicate);
456 push_bytes(out, &buf);
457}
458
459fn push_value(out: &mut Vec<u8>, value: &Value) {
460 let mut buf = Vec::new();
461 encode_value_key(&mut buf, value);
462 push_bytes(out, &buf);
463}
464
465fn push_enum(out: &mut Vec<u8>, value: &ValueEnum) {
466 match &value.path {
467 Some(path) => {
468 out.push(1);
469 push_str(out, path);
470 }
471 None => out.push(0),
472 }
473 push_str(out, &value.variant);
474 match &value.payload {
475 Some(payload) => {
476 out.push(1);
477 push_value(out, payload);
478 }
479 None => out.push(0),
480 }
481}
482
483fn push_coercion(out: &mut Vec<u8>, spec: &CoercionSpec) {
484 out.push(coercion_id_tag(spec.id));
485 push_len(out, spec.params.len());
486 for (key, value) in &spec.params {
487 push_str(out, key);
488 push_str(out, value);
489 }
490}
491
492const fn compare_op_tag(op: CompareOp) -> u8 {
493 match op {
494 CompareOp::Eq => 0,
495 CompareOp::Ne => 1,
496 CompareOp::Lt => 2,
497 CompareOp::Lte => 3,
498 CompareOp::Gt => 4,
499 CompareOp::Gte => 5,
500 CompareOp::In => 6,
501 CompareOp::NotIn => 7,
502 CompareOp::AnyIn => 8,
503 CompareOp::AllIn => 9,
504 CompareOp::Contains => 10,
505 CompareOp::StartsWith => 11,
506 CompareOp::EndsWith => 12,
507 }
508}
509
510const fn coercion_id_tag(id: CoercionId) -> u8 {
511 match id {
512 CoercionId::Strict => 0,
513 CoercionId::NumericWiden => 1,
514 CoercionId::IdentifierText => 2,
515 CoercionId::TextCasefold => 3,
516 CoercionId::CollectionElement => 4,
517 }
518}
519
/// Append `len` as an 8-byte big-endian `u64` length header.
///
/// A `len` that does not fit in a `u64` saturates to `u64::MAX`. That cannot
/// happen on targets where `usize` is at most 64 bits.
fn push_len(out: &mut Vec<u8>, len: usize) {
    let width = u64::try_from(len).unwrap_or(u64::MAX);
    out.extend(width.to_be_bytes());
}
524
525fn push_bytes(out: &mut Vec<u8>, bytes: &[u8]) {
526 push_len(out, bytes.len());
527 out.extend_from_slice(bytes);
528}
529
530fn push_str(out: &mut Vec<u8>, s: &str) {
531 push_bytes(out, s.as_bytes());
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Equality comparison of `field` against a text literal — a convenient
    /// leaf whose fields are all plain strings.
    fn text_eq(field: &str, text: &str) -> Predicate {
        Predicate::Compare(ComparePredicate {
            field: field.to_string(),
            op: CompareOp::Eq,
            value: Value::Text(text.to_string()),
            coercion: CoercionSpec::default(),
        })
    }

    /// Wrap `inner` in a single negation.
    fn negate(inner: Predicate) -> Predicate {
        Predicate::Not(Box::new(inner))
    }

    // Normalized trees are compared through `sort_key`, so these tests rely
    // only on this module's encoders, not on `PartialEq`/`Debug` for
    // `Predicate`.

    #[test]
    fn sort_key_distinguishes_list_text_with_delimiters() {
        let left = Predicate::Compare(ComparePredicate {
            field: "field".to_string(),
            op: CompareOp::Eq,
            value: Value::List(vec![Value::Text("a,b".to_string())]),
            coercion: CoercionSpec::default(),
        });
        let right = Predicate::Compare(ComparePredicate {
            field: "field".to_string(),
            op: CompareOp::Eq,
            value: Value::List(vec![
                Value::Text("a".to_string()),
                Value::Text("b".to_string()),
            ]),
            coercion: CoercionSpec::default(),
        });

        assert_ne!(sort_key(&left), sort_key(&right));
    }

    #[test]
    fn and_drops_true_and_short_circuits_on_false() {
        assert!(matches!(normalize(&Predicate::And(vec![])), Predicate::True));
        assert!(matches!(
            normalize(&Predicate::And(vec![text_eq("a", "x"), Predicate::False])),
            Predicate::False
        ));
        assert_eq!(
            sort_key(&normalize(&Predicate::And(vec![
                Predicate::True,
                text_eq("a", "x"),
            ]))),
            sort_key(&Predicate::And(vec![text_eq("a", "x")])),
        );
    }

    #[test]
    fn or_drops_false_and_short_circuits_on_true() {
        assert!(matches!(normalize(&Predicate::Or(vec![])), Predicate::False));
        assert!(matches!(
            normalize(&Predicate::Or(vec![text_eq("a", "x"), Predicate::True])),
            Predicate::True
        ));
        assert_eq!(
            sort_key(&normalize(&Predicate::Or(vec![
                Predicate::False,
                text_eq("a", "x"),
            ]))),
            sort_key(&Predicate::Or(vec![text_eq("a", "x")])),
        );
    }

    #[test]
    fn nested_and_is_flattened_and_order_independent() {
        let nested = Predicate::And(vec![
            text_eq("c", "3"),
            Predicate::And(vec![text_eq("b", "2"), text_eq("a", "1")]),
        ]);
        let flat = Predicate::And(vec![
            text_eq("a", "1"),
            text_eq("b", "2"),
            text_eq("c", "3"),
        ]);
        let shuffled = Predicate::And(vec![
            text_eq("b", "2"),
            text_eq("c", "3"),
            text_eq("a", "1"),
        ]);

        assert!(matches!(
            normalize(&nested),
            Predicate::And(ref children) if children.len() == 3
        ));
        assert_eq!(sort_key(&normalize(&nested)), sort_key(&normalize(&flat)));
        assert_eq!(sort_key(&normalize(&shuffled)), sort_key(&normalize(&flat)));
    }

    #[test]
    fn double_negation_is_removed() {
        let wrapped = negate(negate(text_eq("a", "x")));
        assert_eq!(sort_key(&normalize(&wrapped)), sort_key(&text_eq("a", "x")));

        let single = normalize(&negate(text_eq("a", "x")));
        assert!(matches!(single, Predicate::Not(_)));
    }

    #[test]
    fn normalize_is_idempotent() {
        let input = Predicate::Or(vec![
            Predicate::And(vec![text_eq("b", "2"), Predicate::True, text_eq("a", "1")]),
            negate(negate(text_eq("c", "3"))),
            Predicate::False,
            Predicate::Or(vec![text_eq("d", "4")]),
        ]);

        let once = normalize(&input);
        let twice = normalize(&once);
        assert_eq!(sort_key(&once), sort_key(&twice));
    }
}