1use crate::db::predicate::{ComparePredicate, Predicate};
8
9#[must_use]
17pub fn normalize_identifier_to_scope(identifier: String, entity_scope: &[String]) -> String {
18 let Some((qualifier, leaf)) = split_qualified_identifier(identifier.as_str()) else {
19 return identifier;
20 };
21 if !entity_scope
22 .iter()
23 .any(|candidate| identifiers_tail_match(candidate.as_str(), qualifier))
24 {
25 return identifier;
26 }
27
28 leaf.to_string()
29}
30
/// Splits `identifier` at its last `.` into `(qualifier, leaf)`.
///
/// Returns `None` when there is no `.`, or when either side of the final `.`
/// would be empty (e.g. `.email` or `users.`). The qualifier itself may still
/// contain dots: `public.users.email` yields `("public.users", "email")`.
#[must_use]
pub fn split_qualified_identifier(identifier: &str) -> Option<(&str, &str)> {
    identifier
        .rsplit_once('.')
        .filter(|(qualifier, leaf)| !qualifier.is_empty() && !leaf.is_empty())
}
41
/// Returns the text after the last `.` in `identifier`, or the whole string if it has no `.`.
///
/// This currently always returns `Some`. The segment may be empty, e.g. for
/// `""` or for an identifier with a trailing dot such as `users.`.
#[must_use]
pub fn identifier_last_segment(identifier: &str) -> Option<&str> {
    let tail = identifier
        .rsplit_once('.')
        .map_or(identifier, |(_, after_dot)| after_dot);
    Some(tail)
}
47
48#[must_use]
50pub fn identifiers_tail_match(left: &str, right: &str) -> bool {
51 if left.eq_ignore_ascii_case(right) {
52 return true;
53 }
54
55 let left_last = identifier_last_segment(left);
56 let right_last = identifier_last_segment(right);
57 match (left_last, right_last) {
58 (Some(l), Some(r)) => l.eq_ignore_ascii_case(r),
59 _ => false,
60 }
61}
62
63pub(crate) fn rewrite_field_identifiers<F>(predicate: Predicate, map_field: F) -> Predicate
70where
71 F: FnMut(String) -> String,
72{
73 let mut map_field = map_field;
74
75 rewrite_field_identifiers_inner(predicate, &mut map_field)
76}
77
78fn rewrite_field_identifiers_inner<F>(predicate: Predicate, map_field: &mut F) -> Predicate
80where
81 F: FnMut(String) -> String,
82{
83 match predicate {
84 Predicate::True => Predicate::True,
85 Predicate::False => Predicate::False,
86 Predicate::And(children) => Predicate::And(
87 children
88 .into_iter()
89 .map(|child| rewrite_field_identifiers_inner(child, map_field))
90 .collect(),
91 ),
92 Predicate::Or(children) => Predicate::Or(
93 children
94 .into_iter()
95 .map(|child| rewrite_field_identifiers_inner(child, map_field))
96 .collect(),
97 ),
98 Predicate::Not(inner) => {
99 Predicate::Not(Box::new(rewrite_field_identifiers_inner(*inner, map_field)))
100 }
101 Predicate::Compare(compare) => {
102 Predicate::Compare(rewrite_compare_field(compare, map_field))
103 }
104 Predicate::CompareFields(compare) => {
105 Predicate::CompareFields(rewrite_compare_fields(compare, map_field))
106 }
107 Predicate::IsNull { field } => Predicate::IsNull {
108 field: map_field(field),
109 },
110 Predicate::IsNotNull { field } => Predicate::IsNotNull {
111 field: map_field(field),
112 },
113 Predicate::IsMissing { field } => Predicate::IsMissing {
114 field: map_field(field),
115 },
116 Predicate::IsEmpty { field } => Predicate::IsEmpty {
117 field: map_field(field),
118 },
119 Predicate::IsNotEmpty { field } => Predicate::IsNotEmpty {
120 field: map_field(field),
121 },
122 Predicate::TextContains { field, value } => Predicate::TextContains {
123 field: map_field(field),
124 value,
125 },
126 Predicate::TextContainsCi { field, value } => Predicate::TextContainsCi {
127 field: map_field(field),
128 value,
129 },
130 }
131}
132
133fn rewrite_compare_field<F>(compare: ComparePredicate, map_field: &mut F) -> ComparePredicate
135where
136 F: FnMut(String) -> String,
137{
138 ComparePredicate {
139 field: map_field(compare.field),
140 op: compare.op,
141 value: compare.value,
142 coercion: compare.coercion,
143 }
144}
145
146fn rewrite_compare_fields<F>(
147 compare: crate::db::predicate::CompareFieldsPredicate,
148 map_field: &mut F,
149) -> crate::db::predicate::CompareFieldsPredicate
150where
151 F: FnMut(String) -> String,
152{
153 crate::db::predicate::CompareFieldsPredicate::with_coercion(
154 map_field(compare.left_field().to_string()),
155 compare.op(),
156 map_field(compare.right_field().to_string()),
157 compare.coercion().id,
158 )
159}
160
#[cfg(test)]
mod tests {
    use crate::{
        db::{
            predicate::{CoercionId, CompareOp, ComparePredicate, Predicate},
            sql::identifier::{identifiers_tail_match, normalize_identifier_to_scope},
        },
        value::Value,
    };

    /// Test mapper: drops a leading `users.` qualifier and leaves other names alone.
    fn strip_users_prefix(identifier: String) -> String {
        match identifier.strip_prefix("users.") {
            Some(field) => field.to_string(),
            None => identifier,
        }
    }

    #[test]
    fn identifiers_tail_match_accepts_schema_qualified_forms() {
        let cases = [
            ("public.FixtureUser", "FixtureUser", true),
            ("fixtureorder", "FixtureOrder", true),
            ("FixtureUser", "FixtureOrder", false),
        ];

        for &(left, right, expected) in &cases {
            assert_eq!(
                identifiers_tail_match(left, right),
                expected,
                "{left} vs {right}"
            );
        }
    }

    #[test]
    fn normalize_identifier_to_scope_strips_matching_qualifier() {
        let scope = ["public.FixtureUser".to_string(), "FixtureUser".to_string()];

        for qualified in ["FixtureUser.email", "public.FixtureUser.email"] {
            assert_eq!(
                normalize_identifier_to_scope(qualified.to_string(), &scope),
                "email".to_string()
            );
        }
    }

    #[test]
    fn normalize_identifier_to_scope_preserves_non_matching_qualifier() {
        let scope = ["FixtureUser".to_string()];
        let normalized = normalize_identifier_to_scope("FixtureOrder.email".to_string(), &scope);

        assert_eq!(normalized, "FixtureOrder.email".to_string());
    }

    #[test]
    fn rewrite_field_identifiers_updates_nested_predicate_fields() {
        let age_eq = |field: &str| {
            Predicate::Compare(ComparePredicate::eq(field.to_string(), Value::Int(21)))
        };
        let deleted_is_null = |field: &str| Predicate::IsNull {
            field: field.to_string(),
        };
        let email_not_contains = |field: &str| {
            Predicate::Not(Box::new(Predicate::TextContainsCi {
                field: field.to_string(),
                value: Value::Text("EXAMPLE".to_string()),
            }))
        };

        let predicate = Predicate::And(vec![
            age_eq("users.age"),
            Predicate::Or(vec![
                deleted_is_null("users.deleted_at"),
                email_not_contains("users.email"),
            ]),
        ]);
        let expected = Predicate::And(vec![
            age_eq("age"),
            Predicate::Or(vec![
                deleted_is_null("deleted_at"),
                email_not_contains("email"),
            ]),
        ]);

        let rewritten = super::rewrite_field_identifiers(predicate, strip_users_prefix);
        assert_eq!(rewritten, expected);
    }

    #[test]
    fn rewrite_field_identifiers_preserves_compare_semantics() {
        let predicate = Predicate::Compare(ComparePredicate::with_coercion(
            "users.email",
            CompareOp::StartsWith,
            Value::Text("Ada".to_string()),
            CoercionId::TextCasefold,
        ));

        let compare = match super::rewrite_field_identifiers(predicate, strip_users_prefix) {
            Predicate::Compare(compare) => compare,
            _ => panic!("rewritten predicate should remain compare"),
        };

        assert_eq!(compare.field, "email".to_string());
        assert_eq!(compare.op, CompareOp::StartsWith);
        assert_eq!(compare.value, Value::Text("Ada".to_string()));
        assert_eq!(compare.coercion.id, CoercionId::TextCasefold);
    }
}