1use crate::db::predicate::{ComparePredicate, Predicate};
8
9#[must_use]
17pub fn normalize_identifier_to_scope(identifier: String, entity_scope: &[String]) -> String {
18 let Some((qualifier, leaf)) = split_qualified_identifier(identifier.as_str()) else {
19 return identifier;
20 };
21 if !entity_scope
22 .iter()
23 .any(|candidate| identifiers_tail_match(candidate.as_str(), qualifier))
24 {
25 return identifier;
26 }
27
28 leaf.to_string()
29}
30
/// Splits `identifier` at its final `.` into `(qualifier, leaf)`.
///
/// Returns `None` when the identifier contains no `.`, or when either side of the
/// final `.` is empty (for example `".email"` or `"users."`). Any earlier dots stay
/// inside the qualifier, so `"public.users.email"` yields `("public.users", "email")`.
#[must_use]
pub fn split_qualified_identifier(identifier: &str) -> Option<(&str, &str)> {
    match identifier.rsplit_once('.') {
        Some((qualifier, leaf)) if !qualifier.is_empty() && !leaf.is_empty() => {
            Some((qualifier, leaf))
        }
        _ => None,
    }
}
41
/// Returns the text after the final `.` in `identifier`, or the whole identifier
/// when it contains no `.`.
///
/// The result is always `Some`. It is `Some("")` for an empty identifier and for
/// one that ends in `.`.
#[must_use]
pub fn identifier_last_segment(identifier: &str) -> Option<&str> {
    let segment = match identifier.rsplit_once('.') {
        Some((_, tail)) => tail,
        // No separator: the identifier is its own last segment.
        None => identifier,
    };

    Some(segment)
}
47
48#[must_use]
50pub fn identifiers_tail_match(left: &str, right: &str) -> bool {
51 if left.eq_ignore_ascii_case(right) {
52 return true;
53 }
54
55 let left_last = identifier_last_segment(left);
56 let right_last = identifier_last_segment(right);
57 match (left_last, right_last) {
58 (Some(l), Some(r)) => l.eq_ignore_ascii_case(r),
59 _ => false,
60 }
61}
62
63pub(crate) fn rewrite_field_identifiers<F>(predicate: Predicate, map_field: F) -> Predicate
70where
71 F: FnMut(String) -> String,
72{
73 let mut map_field = map_field;
74
75 rewrite_field_identifiers_inner(predicate, &mut map_field)
76}
77
78fn rewrite_field_identifiers_inner<F>(predicate: Predicate, map_field: &mut F) -> Predicate
80where
81 F: FnMut(String) -> String,
82{
83 match predicate {
84 Predicate::True => Predicate::True,
85 Predicate::False => Predicate::False,
86 Predicate::And(children) => Predicate::And(
87 children
88 .into_iter()
89 .map(|child| rewrite_field_identifiers_inner(child, map_field))
90 .collect(),
91 ),
92 Predicate::Or(children) => Predicate::Or(
93 children
94 .into_iter()
95 .map(|child| rewrite_field_identifiers_inner(child, map_field))
96 .collect(),
97 ),
98 Predicate::Not(inner) => {
99 Predicate::Not(Box::new(rewrite_field_identifiers_inner(*inner, map_field)))
100 }
101 Predicate::Compare(compare) => {
102 Predicate::Compare(rewrite_compare_field(compare, map_field))
103 }
104 Predicate::IsNull { field } => Predicate::IsNull {
105 field: map_field(field),
106 },
107 Predicate::IsNotNull { field } => Predicate::IsNotNull {
108 field: map_field(field),
109 },
110 Predicate::IsMissing { field } => Predicate::IsMissing {
111 field: map_field(field),
112 },
113 Predicate::IsEmpty { field } => Predicate::IsEmpty {
114 field: map_field(field),
115 },
116 Predicate::IsNotEmpty { field } => Predicate::IsNotEmpty {
117 field: map_field(field),
118 },
119 Predicate::TextContains { field, value } => Predicate::TextContains {
120 field: map_field(field),
121 value,
122 },
123 Predicate::TextContainsCi { field, value } => Predicate::TextContainsCi {
124 field: map_field(field),
125 value,
126 },
127 }
128}
129
130fn rewrite_compare_field<F>(compare: ComparePredicate, map_field: &mut F) -> ComparePredicate
132where
133 F: FnMut(String) -> String,
134{
135 ComparePredicate {
136 field: map_field(compare.field),
137 op: compare.op,
138 value: compare.value,
139 coercion: compare.coercion,
140 }
141}
142
#[cfg(test)]
mod tests {
    use crate::{
        db::{
            predicate::{CoercionId, CompareOp, ComparePredicate, Predicate},
            sql::identifier::{identifiers_tail_match, normalize_identifier_to_scope},
        },
        value::Value,
    };

    // Test mapper: drops a leading `users.` qualifier and leaves any other
    // identifier untouched.
    fn strip_users_prefix(identifier: String) -> String {
        match identifier.strip_prefix("users.") {
            Some(field) => field.to_string(),
            None => identifier,
        }
    }

    // Builds the nested fixture predicate with `qualifier` prepended to every field
    // name, so the same shape serves as both rewrite input and expected output.
    fn nested_fixture(qualifier: &str) -> Predicate {
        let age_is_21 = Predicate::Compare(ComparePredicate::eq(
            format!("{qualifier}age"),
            Value::Int(21),
        ));
        let deleted_at_is_null = Predicate::IsNull {
            field: format!("{qualifier}deleted_at"),
        };
        let email_lacks_example = Predicate::Not(Box::new(Predicate::TextContainsCi {
            field: format!("{qualifier}email"),
            value: Value::Text("EXAMPLE".to_string()),
        }));

        Predicate::And(vec![
            age_is_21,
            Predicate::Or(vec![deleted_at_is_null, email_lacks_example]),
        ])
    }

    #[test]
    fn identifiers_tail_match_accepts_schema_qualified_forms() {
        let matching_pairs = [
            ("public.FixtureUser", "FixtureUser"),
            ("fixtureorder", "FixtureOrder"),
        ];
        for (left, right) in matching_pairs {
            assert!(identifiers_tail_match(left, right));
        }

        assert!(!identifiers_tail_match("FixtureUser", "FixtureOrder"));
    }

    #[test]
    fn normalize_identifier_to_scope_strips_matching_qualifier() {
        let scope = ["public.FixtureUser", "FixtureUser"].map(String::from);

        for qualified in ["FixtureUser.email", "public.FixtureUser.email"] {
            assert_eq!(
                normalize_identifier_to_scope(qualified.to_string(), &scope),
                "email"
            );
        }
    }

    #[test]
    fn normalize_identifier_to_scope_preserves_non_matching_qualifier() {
        let scope = [String::from("FixtureUser")];
        let normalized = normalize_identifier_to_scope("FixtureOrder.email".to_string(), &scope);

        assert_eq!(normalized, "FixtureOrder.email");
    }

    #[test]
    fn rewrite_field_identifiers_updates_nested_predicate_fields() {
        let rewritten =
            super::rewrite_field_identifiers(nested_fixture("users."), strip_users_prefix);

        assert_eq!(rewritten, nested_fixture(""));
    }

    #[test]
    fn rewrite_field_identifiers_preserves_compare_semantics() {
        let original = Predicate::Compare(ComparePredicate::with_coercion(
            "users.email",
            CompareOp::StartsWith,
            Value::Text("Ada".to_string()),
            CoercionId::TextCasefold,
        ));

        match super::rewrite_field_identifiers(original, strip_users_prefix) {
            Predicate::Compare(compare) => {
                // Only the field changes; operator, value and coercion must survive.
                assert_eq!(compare.field, "email");
                assert_eq!(compare.op, CompareOp::StartsWith);
                assert_eq!(compare.value, Value::Text("Ada".to_string()));
                assert_eq!(compare.coercion.id, CoercionId::TextCasefold);
            }
            _ => panic!("rewritten predicate should remain compare"),
        }
    }
}