icydb_core/db/query/predicate/
coercion.rs1use crate::value::{TextMode, Value, ValueFamily};
2use std::{cmp::Ordering, collections::BTreeMap, mem::discriminant};
3
/// Identifies which comparison rule a predicate uses when relating two values.
///
/// The comparison helpers in this module (`compare_eq`, `compare_order`,
/// `compare_text`) dispatch on this id.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CoercionId {
    /// Values must be the same `Value` variant; equality is `==` and ordering uses the
    /// payload's own ordering. For text predicates this selects `TextMode::Cs`.
    Strict,
    /// Values are compared through `Value::cmp_numeric`, which may relate different
    /// variants.
    NumericWiden,
    /// Text values are compared after case folding. For text predicates this selects
    /// `TextMode::Ci`.
    TextCasefold,
    /// Handled exactly like `Strict` by `compare_eq`/`compare_order`; rejected
    /// (`None`) by `compare_text`.
    // NOTE(review): presumably used for per-element checks inside collections — confirm.
    CollectionElement,
}
34
35#[derive(Clone, Debug, Eq, PartialEq)]
46pub struct CoercionSpec {
47 pub id: CoercionId,
48 pub params: BTreeMap<String, String>,
49}
50
51impl CoercionSpec {
52 #[must_use]
53 pub const fn new(id: CoercionId) -> Self {
54 Self {
55 id,
56 params: BTreeMap::new(),
57 }
58 }
59}
60
61impl Default for CoercionSpec {
62 fn default() -> Self {
63 Self::new(CoercionId::Strict)
64 }
65}
66
67#[derive(Clone, Copy, Debug, Eq, PartialEq)]
72pub enum CoercionFamily {
73 Any,
74 Family(ValueFamily),
75}
76
77#[derive(Clone, Copy, Debug, Eq, PartialEq)]
88pub struct CoercionRule {
89 pub left: CoercionFamily,
90 pub right: CoercionFamily,
91 pub id: CoercionId,
92}
93
94pub const COERCION_TABLE: &[CoercionRule] = &[
95 CoercionRule {
96 left: CoercionFamily::Any,
97 right: CoercionFamily::Any,
98 id: CoercionId::Strict,
99 },
100 CoercionRule {
101 left: CoercionFamily::Family(ValueFamily::Numeric),
102 right: CoercionFamily::Family(ValueFamily::Numeric),
103 id: CoercionId::NumericWiden,
104 },
105 CoercionRule {
106 left: CoercionFamily::Family(ValueFamily::Textual),
107 right: CoercionFamily::Family(ValueFamily::Textual),
108 id: CoercionId::TextCasefold,
109 },
110 CoercionRule {
111 left: CoercionFamily::Any,
112 right: CoercionFamily::Any,
113 id: CoercionId::CollectionElement,
114 },
115];
116
117#[must_use]
118pub fn supports_coercion(left: ValueFamily, right: ValueFamily, id: CoercionId) -> bool {
119 COERCION_TABLE.iter().any(|rule| {
120 rule.id == id && family_matches(rule.left, left) && family_matches(rule.right, right)
121 })
122}
123
124fn family_matches(rule: CoercionFamily, value: ValueFamily) -> bool {
125 match rule {
126 CoercionFamily::Any => true,
127 CoercionFamily::Family(expected) => expected == value,
128 }
129}
130
/// The text predicate evaluated by `compare_text`.
///
/// Each variant dispatches to the matching `Value::text_*` method, called on the left
/// operand with the right operand as its argument.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TextOp {
    /// Dispatches to `Value::text_eq`.
    Eq,
    /// Dispatches to `Value::text_contains`.
    Contains,
    /// Dispatches to `Value::text_starts_with`.
    StartsWith,
    /// Dispatches to `Value::text_ends_with`.
    EndsWith,
}
142
143#[must_use]
148pub fn compare_eq(left: &Value, right: &Value, coercion: &CoercionSpec) -> Option<bool> {
149 match coercion.id {
150 CoercionId::Strict | CoercionId::CollectionElement => {
151 same_variant(left, right).then_some(left == right)
152 }
153 CoercionId::NumericWiden => left.cmp_numeric(right).map(|ord| ord == Ordering::Equal),
154 CoercionId::TextCasefold => compare_casefold(left, right),
155 }
156}
157
158#[must_use]
163pub fn compare_order(left: &Value, right: &Value, coercion: &CoercionSpec) -> Option<Ordering> {
164 match coercion.id {
165 CoercionId::Strict | CoercionId::CollectionElement => {
166 if !same_variant(left, right) {
167 return None;
168 }
169 strict_ordering(left, right)
170 }
171 CoercionId::NumericWiden => left.cmp_numeric(right),
172 CoercionId::TextCasefold => {
173 let left = casefold_value(left)?;
174 let right = casefold_value(right)?;
175 Some(left.cmp(&right))
176 }
177 }
178}
179
180#[must_use]
185pub fn compare_text(
186 left: &Value,
187 right: &Value,
188 coercion: &CoercionSpec,
189 op: TextOp,
190) -> Option<bool> {
191 if !matches!(left, Value::Text(_)) || !matches!(right, Value::Text(_)) {
192 return None;
194 }
195
196 let mode = match coercion.id {
197 CoercionId::Strict => TextMode::Cs,
198 CoercionId::TextCasefold => TextMode::Ci,
199 _ => return None,
200 };
201
202 match op {
203 TextOp::Eq => left.text_eq(right, mode),
204 TextOp::Contains => left.text_contains(right, mode),
205 TextOp::StartsWith => left.text_starts_with(right, mode),
206 TextOp::EndsWith => left.text_ends_with(right, mode),
207 }
208}
209
210fn same_variant(left: &Value, right: &Value) -> bool {
211 discriminant(left) == discriminant(right)
212}
213
214fn strict_ordering(left: &Value, right: &Value) -> Option<Ordering> {
219 match (left, right) {
220 (Value::Account(a), Value::Account(b)) => Some(a.cmp(b)),
221 (Value::Bool(a), Value::Bool(b)) => a.partial_cmp(b),
222 (Value::Date(a), Value::Date(b)) => a.partial_cmp(b),
223 (Value::Decimal(a), Value::Decimal(b)) => a.partial_cmp(b),
224 (Value::Duration(a), Value::Duration(b)) => a.partial_cmp(b),
225 (Value::E8s(a), Value::E8s(b)) => a.partial_cmp(b),
226 (Value::E18s(a), Value::E18s(b)) => a.partial_cmp(b),
227 (Value::Enum(a), Value::Enum(b)) => a.partial_cmp(b),
228 (Value::Float32(a), Value::Float32(b)) => a.partial_cmp(b),
229 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
230 (Value::Int(a), Value::Int(b)) => a.partial_cmp(b),
231 (Value::Int128(a), Value::Int128(b)) => a.partial_cmp(b),
232 (Value::IntBig(a), Value::IntBig(b)) => a.partial_cmp(b),
233 (Value::Principal(a), Value::Principal(b)) => a.partial_cmp(b),
234 (Value::Subaccount(a), Value::Subaccount(b)) => a.partial_cmp(b),
235 (Value::Text(a), Value::Text(b)) => a.partial_cmp(b),
236 (Value::Timestamp(a), Value::Timestamp(b)) => a.partial_cmp(b),
237 (Value::Uint(a), Value::Uint(b)) => a.partial_cmp(b),
238 (Value::Uint128(a), Value::Uint128(b)) => a.partial_cmp(b),
239 (Value::UintBig(a), Value::UintBig(b)) => a.partial_cmp(b),
240 (Value::Ulid(a), Value::Ulid(b)) => a.partial_cmp(b),
241 (Value::Unit, Value::Unit) => Some(Ordering::Equal),
242 _ => None,
243 }
244}
245
246fn compare_casefold(left: &Value, right: &Value) -> Option<bool> {
247 let left = casefold_value(left)?;
248 let right = casefold_value(right)?;
249 Some(left == right)
250}
251
252fn casefold_value(value: &Value) -> Option<String> {
255 match value {
256 Value::Text(text) => Some(casefold(text)),
257 _ => None,
259 }
260}
261
/// Lowercases `input` for case-insensitive comparison.
///
/// Pure-ASCII input takes the cheaper `to_ascii_lowercase` path, which gives the same
/// result as `to_lowercase` for ASCII. Other input uses Unicode lowercasing
/// (`str::to_lowercase`). Note this is lowercasing, not full Unicode case folding:
/// for example, `ß` is left unchanged rather than mapped to `ss`.
fn casefold(input: &str) -> String {
    if !input.is_ascii() {
        return input.to_lowercase();
    }
    input.to_ascii_lowercase()
}