1use std::collections::BTreeMap;
12
13use osproxy_core::FieldName;
14use serde_json::value::RawValue;
15use serde_json::{Map, Value};
16
17use crate::error::RewriteError;
18
19pub fn wrap_query(body: &[u8], filter: &[(FieldName, Value)]) -> Result<Vec<u8>, RewriteError> {
57 let mut top = parse_top(body)?;
63
64 if !filter.is_empty() {
69 reject_unfilterable(&top)?;
70 }
71
72 let client_query = top.remove("query");
75 let query = build_filtered_query(client_query.as_deref(), filter)?;
76 top.insert("query".to_owned(), query);
77
78 serde_json::to_vec(&top).map_err(|_| RewriteError::InvalidJson)
81}
82
83fn build_filtered_query(
86 client_query: Option<&RawValue>,
87 filter: &[(FieldName, Value)],
88) -> Result<Box<RawValue>, RewriteError> {
89 let mut q = Vec::with_capacity(64 + client_query.map_or(0, |q| q.get().len()));
90 q.extend_from_slice(br#"{"bool":{"must":"#);
91 match client_query {
92 Some(raw) => {
94 q.push(b'[');
95 q.extend_from_slice(raw.get().as_bytes());
96 q.push(b']');
97 }
98 None => q.extend_from_slice(b"[]"),
99 }
100 q.extend_from_slice(br#","filter":["#);
101 for (i, (name, value)) in filter.iter().enumerate() {
102 if i > 0 {
103 q.push(b',');
104 }
105 q.extend_from_slice(br#"{"term":"#);
106 let mut term = Map::with_capacity(1);
109 term.insert(name.as_str().to_owned(), value.clone());
110 serde_json::to_writer(&mut q, &term).map_err(|_| RewriteError::InvalidJson)?;
111 q.push(b'}');
112 }
113 q.extend_from_slice(b"]}}");
114
115 let s = String::from_utf8(q).map_err(|_| RewriteError::InvalidJson)?;
116 RawValue::from_string(s).map_err(|_| RewriteError::InvalidJson)
117}
118
119fn parse_top(body: &[u8]) -> Result<BTreeMap<String, Box<RawValue>>, RewriteError> {
124 if body.iter().all(u8::is_ascii_whitespace) {
125 return Ok(BTreeMap::new());
126 }
127 match serde_json::from_slice::<BTreeMap<String, Box<RawValue>>>(body) {
128 Ok(map) => Ok(map),
129 Err(_) => match serde_json::from_slice::<&RawValue>(body) {
133 Ok(_) => Err(RewriteError::NotAnObject),
134 Err(_) => Err(RewriteError::InvalidJson),
135 },
136 }
137}
138
139fn reject_unfilterable(top: &BTreeMap<String, Box<RawValue>>) -> Result<(), RewriteError> {
145 if top.contains_key("suggest") {
146 return Err(RewriteError::Unfilterable {
147 construct: "suggest",
148 });
149 }
150 for key in ["aggs", "aggregations"] {
151 if let Some(raw) = top.get(key) {
152 let aggs: Value = serde_json::from_slice(raw.get().as_bytes())
153 .map_err(|_| RewriteError::InvalidJson)?;
154 if contains_global_agg(&aggs) {
155 return Err(RewriteError::Unfilterable {
156 construct: "global aggregation",
157 });
158 }
159 }
160 }
161 Ok(())
162}
163
164fn contains_global_agg(aggs: &Value) -> bool {
169 let Some(obj) = aggs.as_object() else {
170 return false;
171 };
172 obj.values().any(|agg| {
173 agg.as_object().is_some_and(|agg| {
174 agg.contains_key("global")
175 || ["aggs", "aggregations"]
176 .iter()
177 .filter_map(|k| agg.get(*k))
178 .any(contains_global_agg)
179 })
180 })
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-term partition filter shared by most cases.
    fn filter() -> Vec<(FieldName, Value)> {
        vec![(FieldName::from("_tenant"), Value::from("acme"))]
    }

    /// Rewrites `body` under `terms` and parses the output back into a document.
    fn rewrite(body: &[u8], terms: &[(FieldName, Value)]) -> Value {
        let out = wrap_query(body, terms).unwrap();
        serde_json::from_slice(&out).unwrap()
    }

    #[test]
    fn client_query_is_nested_under_must_with_filter_sibling() {
        let doc = rewrite(br#"{"query":{"match":{"msg":"hi"}}}"#, &filter());
        let wrapped = &doc["query"]["bool"];
        assert_eq!(wrapped["must"][0]["match"]["msg"], "hi");
        assert_eq!(wrapped["filter"][0]["term"]["_tenant"], "acme");
    }

    #[test]
    fn absent_query_becomes_filtered_match_all() {
        let doc = rewrite(br#"{"size":5}"#, &filter());
        let wrapped = &doc["query"]["bool"];
        assert_eq!(wrapped["must"].as_array().map(Vec::len), Some(0));
        assert_eq!(wrapped["filter"][0]["term"]["_tenant"], "acme");
        assert_eq!(doc["size"], 5);
    }

    #[test]
    fn empty_body_is_a_filtered_match_all() {
        let doc = rewrite(b"", &filter());
        assert_eq!(doc["query"]["bool"]["filter"][0]["term"]["_tenant"], "acme");
    }

    #[test]
    fn multiple_filter_terms_are_all_applied() {
        let terms = [
            (FieldName::from("_tenant"), Value::from("acme")),
            (FieldName::from("_region"), Value::from("eu")),
        ];
        let doc = rewrite(b"{}", &terms);
        let applied = doc["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(applied.len(), 2);
    }

    #[test]
    fn a_nested_query_key_is_not_confused_with_the_top_level_one() {
        // An aggregation that merely mentions "query" must not be mistaken for the query clause.
        let doc = rewrite(
            br#"{"query":{"match":{"msg":"hi"}},"aggs":{"q":{"terms":{"field":"query"}}}}"#,
            &filter(),
        );
        assert_eq!(doc["query"]["bool"]["must"][0]["match"]["msg"], "hi");
        assert_eq!(doc["aggs"]["q"]["terms"]["field"], "query");
    }

    #[test]
    fn complex_sibling_subtrees_survive_verbatim() {
        let body = br#"{"size":5,"sort":[{"ts":"desc"},"_score"],"_source":["a","b"],"query":{"term":{"k":"v"}}}"#;
        let doc = rewrite(body, &filter());
        assert_eq!(doc["size"], 5);
        assert_eq!(doc["sort"][0]["ts"], "desc");
        assert_eq!(doc["sort"][1], "_score");
        assert_eq!(doc["_source"][1], "b");
        let wrapped = &doc["query"]["bool"];
        assert_eq!(wrapped["must"][0]["term"]["k"], "v");
        assert_eq!(wrapped["filter"][0]["term"]["_tenant"], "acme");
    }

    #[test]
    fn escaped_and_unicode_content_in_the_client_query_is_preserved() {
        let body = "{\"query\":{\"match\":{\"msg\":\"a\\\"b\\\\c\\té \u{4e2d}\"}}}";
        let doc = rewrite(body.as_bytes(), &filter());
        assert_eq!(
            doc["query"]["bool"]["must"][0]["match"]["msg"],
            "a\"b\\c\t\u{e9} \u{4e2d}"
        );
    }

    #[test]
    fn a_non_string_filter_value_is_embedded_correctly() {
        let terms = [
            (FieldName::from("_active"), Value::from(true)),
            (FieldName::from("_shard"), Value::from(7)),
        ];
        let doc = rewrite(br#"{"query":{"match_all":{}}}"#, &terms);
        let applied = &doc["query"]["bool"]["filter"];
        assert_eq!(applied[0]["term"]["_active"], true);
        assert_eq!(applied[1]["term"]["_shard"], 7);
    }

    #[test]
    fn a_global_aggregation_is_rejected_under_a_partition_filter() {
        // A `global` aggregation nested below an ordinary bucket aggregation.
        let nested = br#"{"size":0,"aggs":{"outer":{"terms":{"field":"k"},"aggs":{"leak":{"global":{},"aggs":{"hits":{"top_hits":{"size":50}}}}}}}}"#;
        assert_eq!(
            wrap_query(nested, &filter()).unwrap_err(),
            RewriteError::Unfilterable {
                construct: "global aggregation"
            }
        );
        // The long-form `aggregations` key is inspected too.
        let top_level = br#"{"aggregations":{"g":{"global":{}}}}"#;
        assert!(matches!(
            wrap_query(top_level, &filter()).unwrap_err(),
            RewriteError::Unfilterable { .. }
        ));
    }

    #[test]
    fn a_suggest_block_is_rejected_under_a_partition_filter() {
        let body = br#"{"suggest":{"s":{"text":"x","term":{"field":"msg"}}}}"#;
        let err = wrap_query(body, &filter()).unwrap_err();
        assert_eq!(
            err,
            RewriteError::Unfilterable {
                construct: "suggest"
            }
        );
    }

    #[test]
    fn ordinary_query_scoped_aggregations_are_allowed() {
        let doc = rewrite(br#"{"aggs":{"by_k":{"terms":{"field":"k"}}}}"#, &filter());
        assert_eq!(doc["aggs"]["by_k"]["terms"]["field"], "k");
        assert_eq!(doc["query"]["bool"]["filter"][0]["term"]["_tenant"], "acme");
    }

    #[test]
    fn unfilterable_constructs_are_allowed_without_a_partition_filter() {
        // With no partition filter there is nothing to escape, so nothing is rejected.
        let doc = rewrite(
            br#"{"aggs":{"g":{"global":{}}},"suggest":{"s":{"text":"x"}}}"#,
            &[],
        );
        assert_eq!(doc["aggs"]["g"]["global"], serde_json::json!({}));
    }

    #[test]
    fn non_object_body_is_rejected() {
        let cases: [(&[u8], RewriteError); 2] = [
            (b"[1,2,3]", RewriteError::NotAnObject),
            (b"not json", RewriteError::InvalidJson),
        ];
        for (body, expected) in cases {
            assert_eq!(wrap_query(body, &filter()).unwrap_err(), expected);
        }
    }
}