1use std::cmp::Ordering;
19
20use roaring::RoaringBitmap;
21use serde::{Deserialize, Serialize};
22use serde_json::Value;
23
24use crate::descriptor::{FieldType, FilterableField};
25use crate::error::{CoreError, Result};
26
/// A literal operand of a secondary-index predicate.
///
/// The variant must agree with the [`FieldType`] of the queried field: a
/// `Keyword` against a numeric field (or vice versa) cannot be encoded and
/// therefore matches nothing.
#[derive(Debug, Clone, PartialEq)]
pub enum SecValue {
    /// A string, compared by its UTF-8 bytes against keyword fields.
    Keyword(String),
    /// A number, compared in numeric order against numeric fields.
    /// NaN cannot be encoded and never matches.
    Numeric(f64),
}
35
/// A filter on a single payload field.
///
/// It can be answered from an index (`SecIndex::query`) or checked against one
/// raw payload (`payload_matches`). Both paths compare values in the same
/// encoded byte order: numeric order for numbers, byte order for keywords.
#[derive(Debug, Clone, PartialEq)]
pub enum SecPredicate {
    /// Matches rows whose field equals `value`.
    Eq {
        /// Dot-separated path of the field inside the JSON payload.
        field: String,
        /// The value the field must equal.
        value: SecValue,
    },
    /// Matches rows whose field equals any of `values`. An empty list matches nothing.
    In {
        /// Dot-separated path of the field inside the JSON payload.
        field: String,
        /// Candidate values; the result is the union of their matches.
        values: Vec<SecValue>,
    },
    /// Matches rows whose field lies between `lo` and `hi`.
    Range {
        /// Dot-separated path of the field inside the JSON payload.
        field: String,
        /// Lower bound, or `None` for no lower bound.
        lo: Option<SecValue>,
        /// Upper bound, or `None` for no upper bound.
        hi: Option<SecValue>,
        /// Whether a value equal to `lo` matches (irrelevant when `lo` is `None`).
        lo_inclusive: bool,
        /// Whether a value equal to `hi` matches (irrelevant when `hi` is `None`).
        hi_inclusive: bool,
    },
}
68
69impl SecPredicate {
70 #[must_use]
72 pub fn field(&self) -> &str {
73 match self {
74 SecPredicate::Eq { field, .. }
75 | SecPredicate::In { field, .. }
76 | SecPredicate::Range { field, .. } => field,
77 }
78 }
79}
80
/// Encodes `x` as 8 big-endian bytes whose byte-wise (lexicographic) order
/// matches the numeric order of the inputs. This lets encoded keys be compared
/// and range-scanned as plain byte strings.
///
/// Non-negative values get their sign bit set, which places them above every
/// negative value. Negative values have all bits inverted, so larger magnitudes
/// sort lower. `-0.0` is folded into `+0.0` first: the two are numerically
/// equal, so they must share one key. Otherwise an equality query on `0.0`
/// would miss `-0.0`, and "`< 0.0`" would wrongly include it.
///
/// Callers reject NaN before encoding.
///
/// NOTE: an index serialized before this normalization may still hold a
/// separate `-0.0` key.
fn encode_f64(x: f64) -> [u8; 8] {
    // `-0.0 == 0.0` is true, so both zeros take the `+0.0` bit pattern.
    let x = if x == 0.0 { 0.0 } else { x };
    let bits = x.to_bits();
    let ordered = if bits >> 63 == 0 {
        bits | (1 << 63)
    } else {
        !bits
    };
    ordered.to_be_bytes()
}
93
94fn field_value<'a>(payload: &'a Value, path: &str) -> Option<&'a Value> {
96 let mut current = payload;
97 for part in path.split('.') {
98 current = current.get(part)?;
99 }
100 Some(current)
101}
102
103fn encode_field_value(field_type: FieldType, value: &Value) -> Option<Vec<u8>> {
107 match field_type {
108 FieldType::Keyword => match value {
109 Value::String(s) => Some(s.as_bytes().to_vec()),
110 _ => None,
111 },
112 FieldType::Numeric => match value {
113 Value::Number(n) => {
114 let x = n.as_f64()?;
115 (!x.is_nan()).then(|| encode_f64(x).to_vec())
116 }
117 _ => None,
118 },
119 }
120}
121
122fn encode_sec_value(field_type: FieldType, value: &SecValue) -> Option<Vec<u8>> {
125 match (field_type, value) {
126 (FieldType::Keyword, SecValue::Keyword(s)) => Some(s.as_bytes().to_vec()),
127 (FieldType::Numeric, SecValue::Numeric(x)) => {
128 (!x.is_nan()).then(|| encode_f64(*x).to_vec())
129 }
130 _ => None,
131 }
132}
133
/// Posting lists for one indexed field.
///
/// `keys[i]` is an encoded field value and `bitmaps[i]` is the serialized
/// [`RoaringBitmap`] of the rows holding it. `keys` is built from a `BTreeMap`,
/// so it is sorted ascending, and lookups binary-search it.
///
/// The struct is serialized with postcard as part of `SecIndex`. Keep the field
/// order stable: it is part of the stored format.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct FieldIndex {
    /// Dot-separated payload path this index covers.
    path: String,
    /// How values of this field are encoded into `keys`.
    field_type: FieldType,
    /// Encoded values, sorted ascending.
    keys: Vec<Vec<u8>>,
    /// Serialized row bitmaps, parallel to `keys`.
    bitmaps: Vec<Vec<u8>>,
}
144
145impl FieldIndex {
146 fn bitmap_at(&self, i: usize) -> Result<RoaringBitmap> {
147 let bytes = self.bitmaps.get(i).ok_or_else(|| {
148 CoreError::MalformedPage("secondary-index bitmap out of range".into())
149 })?;
150 Ok(RoaringBitmap::deserialize_from(&bytes[..])?)
151 }
152
153 fn equals(&self, value: &SecValue) -> Result<RoaringBitmap> {
154 let Some(key) = encode_sec_value(self.field_type, value) else {
155 return Ok(RoaringBitmap::new());
156 };
157 match self.keys.binary_search(&key) {
158 Ok(i) => self.bitmap_at(i),
159 Err(_) => Ok(RoaringBitmap::new()),
160 }
161 }
162
163 fn range(
164 &self,
165 lo: Option<&SecValue>,
166 hi: Option<&SecValue>,
167 lo_inclusive: bool,
168 hi_inclusive: bool,
169 ) -> Result<RoaringBitmap> {
170 let lo_key = match lo {
173 Some(v) => match encode_sec_value(self.field_type, v) {
174 Some(k) => Some(k),
175 None => return Ok(RoaringBitmap::new()),
176 },
177 None => None,
178 };
179 let hi_key = match hi {
180 Some(v) => match encode_sec_value(self.field_type, v) {
181 Some(k) => Some(k),
182 None => return Ok(RoaringBitmap::new()),
183 },
184 None => None,
185 };
186 let mut out = RoaringBitmap::new();
187 for (i, key) in self.keys.iter().enumerate() {
188 if let Some(l) = &lo_key {
189 let c = key.as_slice().cmp(l.as_slice());
190 if c == Ordering::Less || (c == Ordering::Equal && !lo_inclusive) {
191 continue;
192 }
193 }
194 if let Some(h) = &hi_key {
195 let c = key.as_slice().cmp(h.as_slice());
196 if c == Ordering::Greater || (c == Ordering::Equal && !hi_inclusive) {
197 continue;
198 }
199 }
200 out |= self.bitmap_at(i)?;
201 }
202 Ok(out)
203 }
204}
205
/// Secondary index over a set of JSON payloads.
///
/// Holds one `FieldIndex` per filterable field, in the order the fields were
/// given to `SecIndex::build`. Row ids are payload positions at build time.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub(crate) struct SecIndex {
    fields: Vec<FieldIndex>,
}
211
212impl SecIndex {
213 pub(crate) fn build(filterable: &[FilterableField], payloads: &[&[u8]]) -> Result<Self> {
216 let mut maps: Vec<std::collections::BTreeMap<Vec<u8>, RoaringBitmap>> =
218 vec![std::collections::BTreeMap::new(); filterable.len()];
219 for (row, payload) in payloads.iter().enumerate() {
220 let Ok(value) = serde_json::from_slice::<Value>(payload) else {
221 continue; };
223 for (i, field) in filterable.iter().enumerate() {
224 if let Some(fv) = field_value(&value, &field.path)
225 && let Some(key) = encode_field_value(field.field_type, fv)
226 {
227 maps[i].entry(key).or_default().insert(row as u32);
228 }
229 }
230 }
231 let mut fields = Vec::with_capacity(filterable.len());
232 for (field, map) in filterable.iter().zip(maps) {
233 let mut keys = Vec::with_capacity(map.len());
234 let mut bitmaps = Vec::with_capacity(map.len());
235 for (key, bitmap) in map {
236 let mut buf = Vec::with_capacity(bitmap.serialized_size());
237 bitmap.serialize_into(&mut buf)?;
238 keys.push(key);
239 bitmaps.push(buf);
240 }
241 fields.push(FieldIndex {
242 path: field.path.clone(),
243 field_type: field.field_type,
244 keys,
245 bitmaps,
246 });
247 }
248 Ok(Self { fields })
249 }
250
251 pub(crate) fn encode(&self) -> Result<Vec<u8>> {
253 Ok(postcard::to_allocvec(self)?)
254 }
255
256 pub(crate) fn decode(bytes: &[u8]) -> Result<Self> {
258 Ok(postcard::from_bytes(bytes)?)
259 }
260
261 pub(crate) fn query(&self, predicate: &SecPredicate) -> Result<Option<RoaringBitmap>> {
264 let Some(field) = self.fields.iter().find(|f| f.path == predicate.field()) else {
265 return Ok(None);
266 };
267 let bitmap = match predicate {
268 SecPredicate::Eq { value, .. } => field.equals(value)?,
269 SecPredicate::In { values, .. } => {
270 let mut out = RoaringBitmap::new();
271 for value in values {
272 out |= field.equals(value)?;
273 }
274 out
275 }
276 SecPredicate::Range {
277 lo,
278 hi,
279 lo_inclusive,
280 hi_inclusive,
281 ..
282 } => field.range(lo.as_ref(), hi.as_ref(), *lo_inclusive, *hi_inclusive)?,
283 };
284 Ok(Some(bitmap))
285 }
286}
287
288pub(crate) fn payload_matches(
292 predicate: &SecPredicate,
293 field_type: FieldType,
294 payload: &[u8],
295) -> bool {
296 let Ok(value) = serde_json::from_slice::<Value>(payload) else {
297 return false;
298 };
299 let Some(fv) = field_value(&value, predicate.field()) else {
300 return false;
301 };
302 let Some(key) = encode_field_value(field_type, fv) else {
303 return false;
304 };
305 match predicate {
306 SecPredicate::Eq { value, .. } => {
307 encode_sec_value(field_type, value).is_some_and(|k| k == key)
308 }
309 SecPredicate::In { values, .. } => values
310 .iter()
311 .any(|v| encode_sec_value(field_type, v).is_some_and(|k| k == key)),
312 SecPredicate::Range {
313 lo,
314 hi,
315 lo_inclusive,
316 hi_inclusive,
317 ..
318 } => {
319 let lo_ok = match lo {
320 Some(v) => encode_sec_value(field_type, v).is_some_and(|l| {
321 let c = key.as_slice().cmp(l.as_slice());
322 c == Ordering::Greater || (c == Ordering::Equal && *lo_inclusive)
323 }),
324 None => true,
325 };
326 let hi_ok = match hi {
327 Some(v) => encode_sec_value(field_type, v).is_some_and(|h| {
328 let c = key.as_slice().cmp(h.as_slice());
329 c == Ordering::Less || (c == Ordering::Equal && *hi_inclusive)
330 }),
331 None => true,
332 };
333 lo_ok && hi_ok
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Schema shared by most tests: one keyword field and one numeric field.
    fn fields() -> Vec<FilterableField> {
        vec![
            FilterableField::keyword("city"),
            FilterableField::numeric("age"),
        ]
    }

    /// Four rows. Row 3 has no `age`, so it can only show up in `city` results.
    fn payloads() -> Vec<Vec<u8>> {
        [
            json!({"city": "paris", "age": 30}),
            json!({"city": "lyon", "age": 25}),
            json!({"city": "paris", "age": 40}),
            json!({"city": "paris"}),
        ]
        .iter()
        .map(|doc| doc.to_string().into_bytes())
        .collect()
    }

    /// Index over `payloads()`, round-tripped through `encode`/`decode` so the
    /// queries also exercise the serialized form.
    fn built() -> SecIndex {
        let owned = payloads();
        let slices: Vec<&[u8]> = owned.iter().map(Vec::as_slice).collect();
        let fresh = SecIndex::build(&fields(), &slices).unwrap();
        SecIndex::decode(&fresh.encode().unwrap()).unwrap()
    }

    /// Runs `predicate` (which must target an indexed field) and lists the matching rows.
    fn query_rows(idx: &SecIndex, predicate: SecPredicate) -> Vec<u32> {
        idx.query(&predicate)
            .unwrap()
            .expect("predicate field should be indexed")
            .iter()
            .collect()
    }

    #[test]
    fn equality_on_keyword_and_numeric() {
        let idx = built();
        let in_paris = SecPredicate::Eq {
            field: "city".into(),
            value: SecValue::Keyword("paris".into()),
        };
        assert_eq!(query_rows(&idx, in_paris), vec![0, 2, 3]);
        let aged_25 = SecPredicate::Eq {
            field: "age".into(),
            value: SecValue::Numeric(25.0),
        };
        assert_eq!(query_rows(&idx, aged_25), vec![1]);
    }

    #[test]
    fn numeric_range_is_order_preserving() {
        let idx = built();
        // [25, 40): includes 25 and 30, excludes 40 and the row without an age.
        let window = SecPredicate::Range {
            field: "age".into(),
            lo: Some(SecValue::Numeric(25.0)),
            hi: Some(SecValue::Numeric(40.0)),
            lo_inclusive: true,
            hi_inclusive: false,
        };
        assert_eq!(query_rows(&idx, window), vec![0, 1]);
    }

    #[test]
    fn in_unions_values_and_unknown_field_is_none() {
        let idx = built();
        let either_city = SecPredicate::In {
            field: "city".into(),
            values: vec![
                SecValue::Keyword("lyon".into()),
                SecValue::Keyword("paris".into()),
            ],
        };
        assert_eq!(query_rows(&idx, either_city), vec![0, 1, 2, 3]);
        let unindexed = SecPredicate::Eq {
            field: "country".into(),
            value: SecValue::Keyword("fr".into()),
        };
        assert!(idx.query(&unindexed).unwrap().is_none());
    }

    #[test]
    fn negative_numbers_order_correctly() {
        let owned: Vec<Vec<u8>> = [-5.0, 0.0, -100.0, 7.0]
            .iter()
            .map(|x| json!({ "t": x }).to_string().into_bytes())
            .collect();
        let slices: Vec<&[u8]> = owned.iter().map(Vec::as_slice).collect();
        let idx = SecIndex::build(&[FilterableField::numeric("t")], &slices).unwrap();
        // Strictly below zero: rows 0 (-5) and 2 (-100).
        let negatives = SecPredicate::Range {
            field: "t".into(),
            lo: None,
            hi: Some(SecValue::Numeric(0.0)),
            lo_inclusive: true,
            hi_inclusive: false,
        };
        assert_eq!(query_rows(&idx, negatives), vec![0, 2]);
    }

    #[test]
    fn payload_matches_agrees_with_the_index() {
        let doc = json!({"city": "paris", "age": 30}).to_string().into_bytes();
        let paris = SecPredicate::Eq {
            field: "city".into(),
            value: SecValue::Keyword("paris".into()),
        };
        assert!(payload_matches(&paris, FieldType::Keyword, &doc));
        let adult = SecPredicate::Range {
            field: "age".into(),
            lo: Some(SecValue::Numeric(18.0)),
            hi: None,
            lo_inclusive: true,
            hi_inclusive: true,
        };
        assert!(payload_matches(&adult, FieldType::Numeric, &doc));
        let lyon = SecPredicate::Eq {
            field: "city".into(),
            value: SecValue::Keyword("lyon".into()),
        };
        assert!(!payload_matches(&lyon, FieldType::Keyword, &doc));
    }
}