use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
use roaring::RoaringBitmap;
use selene_core::{DbString, DurationOrderKey, Value, duration_order_key};
use smallvec::SmallVec;
use crate::typed_index::{NotNanError, NotNanF32, NotNanF64, TypedIndexKind, TypedIndexValueError};
/// Tuple key stored in a [`CompositeTypedIndex`]: one component per indexed column, in
/// column order. Up to four components are stored inline before the `SmallVec` spills.
pub type CompositeKey = SmallVec<[CompositeKeyComponent; 4]>;
/// One typed component of a [`CompositeKey`].
///
/// Each variant corresponds to one [`TypedIndexKind`] and is built from exactly one
/// [`Value`] variant by `component_from_value`. Components of the same variant order by
/// their payload; components of different variants order by their variant tag, which
/// follows declaration order below.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CompositeKeyComponent {
/// From `Value::Bool`.
Bool(bool),
/// From `Value::Int`.
I64(i64),
/// From `Value::Uint`.
U64(u64),
/// From `Value::Int128`.
I128(i128),
/// From `Value::Uint128`.
U128(u128),
/// From `Value::Decimal`.
Decimal(rust_decimal::Decimal),
/// From `Value::Float32`; NaN is rejected before construction.
F32(NotNanF32),
/// From `Value::Float`; NaN is rejected before construction.
F64(NotNanF64),
/// From `Value::String`.
String(DbString),
/// From `Value::Date`.
Date(jiff::civil::Date),
/// From `Value::LocalDateTime`.
LocalDateTime(jiff::civil::DateTime),
/// From `Value::ZonedDateTime` (cloned).
ZonedDateTime(jiff::Zoned),
/// From `Value::LocalTime`.
LocalTime(jiff::civil::Time),
/// From `Value::ZonedTime` (cloned). NOTE(review): a zoned time-of-day is carried as a
/// full `jiff::Zoned`; confirm upstream how its date part is fixed.
ZonedTime(jiff::Zoned),
/// From `Value::Duration`, reduced to its order key via `duration_order_key`.
Duration(DurationOrderKey),
/// From `Value::Uuid`.
Uuid(uuid::Uuid),
}
impl Ord for CompositeKeyComponent {
    /// Total order over components.
    ///
    /// The variant tag is compared first, so all components of one kind sort before
    /// all components of any later-declared kind. Only when both sides share a variant
    /// are the payloads compared.
    fn cmp(&self, rhs: &Self) -> std::cmp::Ordering {
        use CompositeKeyComponent as C;

        let tag_order = self.discriminant().cmp(&rhs.discriminant());
        if tag_order.is_ne() {
            return tag_order;
        }

        match (self, rhs) {
            (C::Bool(a), C::Bool(b)) => a.cmp(b),
            (C::I64(a), C::I64(b)) => a.cmp(b),
            (C::U64(a), C::U64(b)) => a.cmp(b),
            (C::I128(a), C::I128(b)) => a.cmp(b),
            (C::U128(a), C::U128(b)) => a.cmp(b),
            (C::Decimal(a), C::Decimal(b)) => a.cmp(b),
            (C::F32(a), C::F32(b)) => a.cmp(b),
            (C::F64(a), C::F64(b)) => a.cmp(b),
            (C::String(a), C::String(b)) => a.cmp(b),
            (C::Date(a), C::Date(b)) => a.cmp(b),
            (C::LocalDateTime(a), C::LocalDateTime(b)) => a.cmp(b),
            (C::ZonedDateTime(a), C::ZonedDateTime(b)) => a.cmp(b),
            (C::LocalTime(a), C::LocalTime(b)) => a.cmp(b),
            (C::ZonedTime(a), C::ZonedTime(b)) => a.cmp(b),
            (C::Duration(a), C::Duration(b)) => a.cmp(b),
            (C::Uuid(a), C::Uuid(b)) => a.cmp(b),
            // Equal tags always hit one of the arms above; this keeps the match total.
            _ => tag_order,
        }
    }
}
impl PartialOrd for CompositeKeyComponent {
    /// Delegates to the total order of [`Ord::cmp`]; never returns `None`.
    fn partial_cmp(&self, rhs: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, rhs);
        Some(ordering)
    }
}
impl Hash for CompositeKeyComponent {
fn hash<H: Hasher>(&self, state: &mut H) {
self.discriminant().hash(state);
match self {
Self::Bool(value) => value.hash(state),
Self::I64(value) => value.hash(state),
Self::U64(value) => value.hash(state),
Self::I128(value) => value.hash(state),
Self::U128(value) => value.hash(state),
Self::Decimal(value) => value.hash(state),
Self::F32(value) => value.hash(state),
Self::F64(value) => value.hash(state),
Self::String(value) => value.hash(state),
Self::Date(value) => value.hash(state),
Self::LocalDateTime(value) => value.hash(state),
Self::ZonedDateTime(value) => value.hash(state),
Self::LocalTime(value) => value.hash(state),
Self::ZonedTime(value) => value.hash(state),
Self::Duration(value) => value.hash(state),
Self::Uuid(value) => value.hash(state),
}
}
}
impl CompositeKeyComponent {
const fn discriminant(&self) -> u8 {
match self {
Self::Bool(_) => 0,
Self::I64(_) => 1,
Self::U64(_) => 2,
Self::I128(_) => 3,
Self::U128(_) => 4,
Self::Decimal(_) => 5,
Self::F32(_) => 6,
Self::F64(_) => 7,
Self::String(_) => 8,
Self::Date(_) => 9,
Self::LocalDateTime(_) => 10,
Self::ZonedDateTime(_) => 11,
Self::LocalTime(_) => 12,
Self::ZonedTime(_) => 13,
Self::Duration(_) => 14,
Self::Uuid(_) => 15,
}
}
}
#[derive(Clone, Debug)]
pub struct CompositeTypedIndex {
kinds: SmallVec<[TypedIndexKind; 4]>,
entries: BTreeMap<CompositeKey, RoaringBitmap>,
}
impl CompositeTypedIndex {
#[must_use]
pub fn new(kinds: SmallVec<[TypedIndexKind; 4]>) -> Self {
Self {
kinds,
entries: BTreeMap::new(),
}
}
#[must_use]
pub fn kinds(&self) -> &[TypedIndexKind] {
&self.kinds
}
#[must_use]
pub fn cardinality(&self) -> u64 {
self.entries.values().map(RoaringBitmap::len).sum()
}
#[must_use]
pub fn distinct_keys(&self) -> u64 {
self.entries.len() as u64
}
pub fn entries(&self) -> impl Iterator<Item = (&CompositeKey, &RoaringBitmap)> {
self.entries.iter()
}
#[must_use]
pub(crate) fn buckets_eq(&self, reference: &Self) -> bool {
self.kinds == reference.kinds && self.entries == reference.entries
}
#[must_use]
pub(crate) fn has_empty_bucket(&self) -> bool {
self.entries.values().any(RoaringBitmap::is_empty)
}
pub fn insert(&mut self, values: &[&Value], row: u32) -> Result<(), CompositeIndexValueError> {
let key = self.key_from_values(values)?;
self.entries.entry(key).or_default().insert(row);
Ok(())
}
pub fn remove(&mut self, values: &[&Value], row: u32) -> Result<(), CompositeIndexValueError> {
let key = self.key_from_values(values)?;
if let Some(bitmap) = self.entries.get_mut(&key) {
bitmap.remove(row);
if bitmap.is_empty() {
self.entries.remove(&key);
}
}
Ok(())
}
#[must_use]
pub fn lookup_key(&self, key: &CompositeKey) -> Option<&RoaringBitmap> {
self.entries.get(key)
}
pub fn key_from_values(
&self,
values: &[&Value],
) -> Result<CompositeKey, CompositeIndexValueError> {
composite_key_from_values(&self.kinds, values)
}
pub fn values_share_key(&self, lhs: &[&Value], rhs: &[&Value]) -> bool {
match (self.key_from_values(lhs), self.key_from_values(rhs)) {
(Ok(lhs_key), Ok(rhs_key)) => lhs_key == rhs_key,
(Err(_), Err(_)) => true,
_ => false,
}
}
}
/// Reasons a tuple of values cannot be turned into a [`CompositeKey`].
#[derive(Debug)]
#[non_exhaustive]
pub enum CompositeIndexValueError {
/// The number of values differs from the number of indexed components.
ArityMismatch {
/// Number of components the index expects (`kinds.len()`).
expected: usize,
/// Number of values supplied.
observed: usize,
},
/// The value at `index` could not be converted to its component (kind mismatch, or NaN
/// for a float kind).
Component {
/// Zero-based position of the first offending value.
index: usize,
/// Kind reported by the underlying `TypedIndexValueError`.
expected_kind: TypedIndexKind,
/// Observed-value description reported by the underlying `TypedIndexValueError`.
observed: &'static str,
},
}
pub(crate) fn composite_key_from_values(
kinds: &[TypedIndexKind],
values: &[&Value],
) -> Result<CompositeKey, CompositeIndexValueError> {
if kinds.len() != values.len() {
return Err(CompositeIndexValueError::ArityMismatch {
expected: kinds.len(),
observed: values.len(),
});
}
kinds
.iter()
.zip(values)
.enumerate()
.map(|(index, (kind, value))| {
component_from_value(*kind, value).map_err(|source| {
CompositeIndexValueError::Component {
index,
expected_kind: source.expected_kind(),
observed: source.observed(),
}
})
})
.collect()
}
fn component_from_value(
kind: TypedIndexKind,
value: &Value,
) -> Result<CompositeKeyComponent, TypedIndexValueError> {
match (kind, value) {
(TypedIndexKind::Bool, Value::Bool(value)) => Ok(CompositeKeyComponent::Bool(*value)),
(TypedIndexKind::I64, Value::Int(value)) => Ok(CompositeKeyComponent::I64(*value)),
(TypedIndexKind::U64, Value::Uint(value)) => Ok(CompositeKeyComponent::U64(*value)),
(TypedIndexKind::I128, Value::Int128(value)) => Ok(CompositeKeyComponent::I128(*value)),
(TypedIndexKind::U128, Value::Uint128(value)) => Ok(CompositeKeyComponent::U128(*value)),
(TypedIndexKind::Decimal, Value::Decimal(value)) => {
Ok(CompositeKeyComponent::Decimal(*value))
}
(TypedIndexKind::F32, Value::Float32(value)) => NotNanF32::new(*value)
.map(CompositeKeyComponent::F32)
.map_err(|NotNanError| TypedIndexValueError::NaN {
expected_kind: TypedIndexKind::F32,
}),
(TypedIndexKind::F64, Value::Float(value)) => NotNanF64::new(*value)
.map(CompositeKeyComponent::F64)
.map_err(|NotNanError| TypedIndexValueError::NaN {
expected_kind: TypedIndexKind::F64,
}),
(TypedIndexKind::String, Value::String(value)) => {
Ok(CompositeKeyComponent::String(value.clone()))
}
(TypedIndexKind::Date, Value::Date(value)) => Ok(CompositeKeyComponent::Date(*value)),
(TypedIndexKind::LocalDateTime, Value::LocalDateTime(value)) => {
Ok(CompositeKeyComponent::LocalDateTime(*value))
}
(TypedIndexKind::ZonedDateTime, Value::ZonedDateTime(value)) => {
Ok(CompositeKeyComponent::ZonedDateTime((**value).clone()))
}
(TypedIndexKind::LocalTime, Value::LocalTime(value)) => {
Ok(CompositeKeyComponent::LocalTime(*value))
}
(TypedIndexKind::ZonedTime, Value::ZonedTime(value)) => {
Ok(CompositeKeyComponent::ZonedTime((**value).clone()))
}
(TypedIndexKind::Duration, Value::Duration(value)) => {
Ok(CompositeKeyComponent::Duration(duration_order_key(value)))
}
(TypedIndexKind::Uuid, Value::Uuid(value)) => Ok(CompositeKeyComponent::Uuid(*value)),
(expected_kind, value) => Err(TypedIndexValueError::KindMismatch {
expected_kind,
observed: crate::typed_index::observed_value_kind(value),
}),
}
}
/// Unit tests for component conversion, composite key construction, and bucket bookkeeping.
#[cfg(test)]
mod tests {
use selene_core::db_string;
use smallvec::smallvec;
use super::*;
/// Parses a decimal literal for test fixtures; panics on malformed input.
fn decimal(value: &str) -> rust_decimal::Decimal {
value.parse().expect("test decimal parses")
}
/// A `String` kind admits `Value::String` and preserves its text.
#[test]
fn component_from_value_string_kind() {
let probe = "component_admit.string.unique-1";
let value = Value::String(db_string(probe).unwrap());
let component =
component_from_value(TypedIndexKind::String, &value).expect("string component coerces");
let CompositeKeyComponent::String(db_string) = component else {
panic!("expected String component, got {component:?}");
};
assert_eq!(db_string.as_str(), probe);
}
/// A `Bool` kind admits `Value::Bool` unchanged.
#[test]
fn component_from_value_bool_kind() {
let value = Value::Bool(true);
let component =
component_from_value(TypedIndexKind::Bool, &value).expect("bool component coerces");
assert_eq!(component, CompositeKeyComponent::Bool(true));
}
/// A `U64` kind admits `Value::Uint` unchanged.
#[test]
fn component_from_value_u64_kind() {
let value = Value::Uint(42);
let component =
component_from_value(TypedIndexKind::U64, &value).expect("u64 component coerces");
assert_eq!(component, CompositeKeyComponent::U64(42));
}
/// The 128-bit integer and decimal kinds round-trip their payloads, including values
/// near the type bounds.
#[test]
fn component_from_value_exact_numeric_kinds() {
let signed = Value::Int128(i128::MIN + 42);
let unsigned = Value::Uint128(u128::MAX - 42);
let amount = Value::Decimal(decimal("42.25"));
assert_eq!(
component_from_value(TypedIndexKind::I128, &signed).expect("i128 component coerces"),
CompositeKeyComponent::I128(i128::MIN + 42)
);
assert_eq!(
component_from_value(TypedIndexKind::U128, &unsigned).expect("u128 component coerces"),
CompositeKeyComponent::U128(u128::MAX - 42)
);
assert_eq!(
component_from_value(TypedIndexKind::Decimal, &amount)
.expect("decimal component coerces"),
CompositeKeyComponent::Decimal(decimal("42.25"))
);
}
/// A non-NaN `Value::Float32` becomes an `F32` component.
#[test]
fn component_from_value_float32_kind() {
let value = Value::Float32(1.25_f32);
let component =
component_from_value(TypedIndexKind::F32, &value).expect("f32 component coerces");
assert_eq!(
component,
CompositeKeyComponent::F32(NotNanF32::new(1.25_f32).unwrap())
);
}
/// A `Duration` kind stores the same order key `duration_order_key` produces.
#[test]
fn component_from_value_duration_kind() {
let value = Value::Duration(Box::new("PT1H2S".parse().unwrap()));
let component = component_from_value(TypedIndexKind::Duration, &value)
.expect("duration component coerces");
assert_eq!(
component,
CompositeKeyComponent::Duration(selene_core::duration_order_key(match &value {
Value::Duration(value) => value,
_ => unreachable!("test value is duration"),
}))
);
}
/// A mismatch in a later component rejects the whole tuple and reports that position.
#[test]
fn composite_key_rejects_when_later_component_kind_mismatches() {
let kinds: SmallVec<[TypedIndexKind; 4]> =
smallvec![TypedIndexKind::String, TypedIndexKind::I64];
let location = Value::String(db_string("composite_admit.left_to_right.loc").unwrap());
let bad = Value::String(db_string("composite_admit.left_to_right.bad").unwrap());
let refs: Vec<&Value> = vec![&location, &bad];
let err = composite_key_from_values(&kinds, &refs)
.expect_err("tuple kind mismatch on later component rejects whole tuple");
// Slot 0 is valid; the error must point at slot 1, which expects I64 but got a string.
assert!(matches!(
err,
CompositeIndexValueError::Component {
index: 1,
expected_kind: TypedIndexKind::I64,
observed: "String",
}
));
}
/// A string component in a non-leading position is accepted, and the key has full arity.
#[test]
fn composite_key_from_values_admits_string_component() {
let kinds: SmallVec<[TypedIndexKind; 4]> =
smallvec![TypedIndexKind::I64, TypedIndexKind::String];
let ts = Value::Int(7);
let location = Value::String(db_string("composite_admit.string.unique-1").unwrap());
let refs: Vec<&Value> = vec![&ts, &location];
let key = composite_key_from_values(&kinds, &refs).expect("string component coerces");
assert_eq!(key.len(), 2);
}
/// Separately constructed but equal string values produce the same key.
#[test]
fn values_share_key_matches_equal_string_components() {
let index =
CompositeTypedIndex::new(smallvec![TypedIndexKind::I64, TypedIndexKind::String]);
let ts_lhs = Value::Int(1);
let ts_rhs = Value::Int(1);
let loc_lhs =
Value::String(db_string("values_share_key.composite.string.unique-1").unwrap());
let loc_rhs =
Value::String(db_string("values_share_key.composite.string.unique-1").unwrap());
let lhs: Vec<&Value> = vec![&ts_lhs, &loc_lhs];
let rhs: Vec<&Value> = vec![&ts_rhs, &loc_rhs];
assert!(index.values_share_key(&lhs, &rhs));
}
/// `distinct_keys` counts buckets, `cardinality` counts rows, and an emptied bucket is dropped.
#[test]
fn distinct_keys_counts_composite_buckets_not_rows() {
let mut index =
CompositeTypedIndex::new(smallvec![TypedIndexKind::I64, TypedIndexKind::String]);
assert_eq!(index.distinct_keys(), 0, "empty index");
let k1 = db_string("k1").unwrap();
let v_k1 = Value::String(k1);
let one = Value::Int(1);
let two = Value::Int(2);
// Two rows share (1, "k1"); one row holds (2, "k1").
index.insert(&[&one, &v_k1], 0).unwrap();
index.insert(&[&one, &v_k1], 1).unwrap();
index.insert(&[&two, &v_k1], 2).unwrap();
assert_eq!(index.cardinality(), 3);
assert_eq!(index.distinct_keys(), 2);
// The (1, "k1") bucket still holds row 1, so it survives.
index.remove(&[&one, &v_k1], 0).unwrap();
assert_eq!(index.cardinality(), 2);
assert_eq!(index.distinct_keys(), 2);
// Removing the last row of (1, "k1") drops that bucket.
index.remove(&[&one, &v_k1], 1).unwrap();
assert_eq!(index.cardinality(), 1);
assert_eq!(index.distinct_keys(), 1);
}
/// Tuples that differ only in their string component do not share a key.
#[test]
fn values_share_key_returns_false_for_distinct_strings() {
let index =
CompositeTypedIndex::new(smallvec![TypedIndexKind::I64, TypedIndexKind::String]);
let ts_lhs = Value::Int(1);
let ts_rhs = Value::Int(1);
let loc_lhs = Value::String(db_string("values_share_key.composite.lhs-unique").unwrap());
let loc_rhs = Value::String(db_string("values_share_key.composite.rhs-unique").unwrap());
let lhs: Vec<&Value> = vec![&ts_lhs, &loc_lhs];
let rhs: Vec<&Value> = vec![&ts_rhs, &loc_rhs];
assert!(!index.values_share_key(&lhs, &rhs));
}
}