use branches::mark_unlikely;
use crate::turso_assert;
use std::collections::BTreeSet;
/// Operating mode of a [`RowSet`], selected by the first query made against it.
///
/// The two query modes are mutually exclusive: `RowSet::test` refuses to run in
/// `Smallest` mode and `RowSet::smallest` refuses to run in `Test` mode.
#[derive(Debug)]
pub enum RowSetMode {
/// Membership-test mode, entered on the first call to `RowSet::test`.
Test {
/// Rowids that have been merged in and are visible to membership tests.
set: BTreeSet<i64>,
/// Batch number recorded by the most recent merge; initialised to 0.
batch_number: i32,
},
/// Extraction mode, entered on the first call to `RowSet::smallest`.
Smallest {
/// Deduplicated rowids in descending order, so the smallest is popped from the end.
sorted_vec: Vec<i64>,
},
/// No query has been made yet; rowids (if any) are only staged in `RowSet::fresh`.
Unset,
}
/// A set of `i64` rowids that supports either batched membership tests or
/// ascending extraction of its elements, but not both on the same instance.
#[derive(Debug)]
pub struct RowSet {
/// Rowids that were inserted but not yet moved into the mode-specific storage.
fresh: Vec<i64>,
/// Current operating mode together with the storage that mode uses.
mode: RowSetMode,
}
impl Default for RowSet {
fn default() -> Self {
Self::new()
}
}
impl RowSet {
    /// Creates an empty row set with no mode selected yet.
    pub fn new() -> Self {
        Self {
            fresh: Vec::new(),
            mode: RowSetMode::Unset,
        }
    }

    /// Stages `rowid` for inclusion in the set.
    ///
    /// The value only becomes visible to [`RowSet::test`] once a later `test`
    /// call merges the staged rowids (see there).
    ///
    /// Asserts (via `turso_assert!`) that [`RowSet::smallest`] has not been
    /// called on this set.
    pub fn insert(&mut self, rowid: i64) {
        turso_assert!(
            !matches!(self.mode, RowSetMode::Smallest { .. }),
            "cannot insert after smallest() has been used"
        );
        self.fresh.push(rowid);
    }

    /// Returns whether `rowid` is among the rowids merged so far.
    ///
    /// The first call switches the set into test mode with a recorded batch
    /// number of 0. Whenever `batch` differs from the recorded batch number,
    /// every staged rowid is merged into the searchable set before the lookup
    /// and `batch` becomes the recorded batch number. A call whose `batch`
    /// equals the recorded one performs the lookup without merging, so rowids
    /// inserted since the last merge are not found.
    ///
    /// Asserts (via `turso_assert!`) that [`RowSet::smallest`] has not been
    /// called on this set.
    pub fn test(&mut self, rowid: i64, batch: i32) -> bool {
        turso_assert!(
            !matches!(self.mode, RowSetMode::Smallest { .. }),
            "cannot call test() after smallest() has started"
        );
        if matches!(self.mode, RowSetMode::Unset) {
            self.mode = RowSetMode::Test {
                set: BTreeSet::new(),
                batch_number: 0,
            };
        }
        // The assertion and the initialisation above leave `Test` as the only possible mode.
        let RowSetMode::Test { set, batch_number } = &mut self.mode else {
            mark_unlikely();
            unreachable!()
        };
        if batch != *batch_number {
            // `drain` empties `fresh` but keeps its allocation for later inserts.
            set.extend(self.fresh.drain(..));
            *batch_number = batch;
        }
        set.contains(&rowid)
    }

    /// Removes and returns the smallest remaining rowid, or `None` once the
    /// set is exhausted.
    ///
    /// The first call switches the set into extraction mode: the staged rowids
    /// are sorted, deduplicated and stored in descending order so that each
    /// extraction is a pop from the end of the vector.
    ///
    /// Asserts (via `turso_assert!`) that [`RowSet::test`] has not been called
    /// on this set.
    pub fn smallest(&mut self) -> Option<i64> {
        turso_assert!(
            !matches!(self.mode, RowSetMode::Test { .. }),
            "cannot call smallest() after test() has been used"
        );
        if matches!(self.mode, RowSetMode::Unset) {
            // Inserting is rejected once in `Smallest` mode, so take over the
            // staging buffer instead of allocating a copy of it.
            let mut sorted_vec = std::mem::take(&mut self.fresh);
            sorted_vec.sort_unstable();
            sorted_vec.dedup();
            sorted_vec.reverse();
            self.mode = RowSetMode::Smallest { sorted_vec };
        }
        // The assertion and the initialisation above leave `Smallest` as the only possible mode.
        let RowSetMode::Smallest { sorted_vec } = &mut self.mode else {
            mark_unlikely();
            unreachable!()
        };
        sorted_vec.pop()
    }

    /// Returns `true` when the set holds no rowids, neither staged ones nor
    /// ones already moved into the current mode's storage.
    pub fn is_empty(&self) -> bool {
        if !self.fresh.is_empty() {
            return false;
        }
        match &self.mode {
            RowSetMode::Test { set, .. } => set.is_empty(),
            RowSetMode::Smallest { sorted_vec } => sorted_vec.is_empty(),
            RowSetMode::Unset => true,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use rand_chacha::{
rand_core::{RngCore, SeedableRng},
ChaCha8Rng,
};
/// Returns the seed used by the randomized tests.
///
/// When the `SEED` environment variable is set (and valid Unicode) it must
/// parse as a `u64`; otherwise this panics. When it is not available, the
/// current Unix time in milliseconds is used.
fn get_seed() -> u64 {
    match std::env::var("SEED") {
        // Parse explicitly as `u64`: letting inference pick the type of the
        // clock fallback (`u128`) would accept out-of-range seeds and then
        // silently truncate them.
        Ok(raw) => raw
            .parse::<u64>()
            .expect("Failed to parse SEED environment variable as u64"),
        // Only read the clock when no seed was supplied.
        Err(_) => std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_millis() as u64,
    }
}
#[test]
fn test_empty_rowset() {
    // A freshly constructed set has nothing staged and no mode selected.
    assert!(RowSet::new().is_empty());
}
#[test]
fn test_insert_and_test() {
    let rowids = [10, 20, 30];
    let mut rowset = RowSet::new();
    for rowid in rowids {
        rowset.insert(rowid);
    }
    // Batch 0 equals the initial batch number, so nothing has been merged yet.
    for rowid in rowids {
        assert!(!rowset.test(rowid, 0));
    }
    // Switching to batch 1 merges the staged rowids.
    for rowid in rowids {
        assert!(rowset.test(rowid, 1));
    }
    assert!(!rowset.test(40, 1));
}
#[test]
fn test_batch_consolidation() {
    let mut rowset = RowSet::new();
    rowset.insert(10);
    rowset.insert(20);
    // Same batch as the initial one: the staged rowids stay invisible.
    assert!(!rowset.test(10, 0));

    rowset.insert(30);
    rowset.insert(40);
    // The first test of batch 1 merges everything staged so far.
    for rowid in [10, 20, 30, 40] {
        assert!(rowset.test(rowid, 1));
    }
    assert!(!rowset.test(50, 1));

    // A rowid inserted during batch 1 only shows up once the batch changes.
    rowset.insert(50);
    assert!(rowset.test(10, 1));
    assert!(!rowset.test(50, 1));
    assert!(rowset.test(50, 2));
}
#[test]
fn test_smallest_extraction() {
    let mut rowset = RowSet::new();
    for rowid in [30, 10, 50, 20, 40] {
        rowset.insert(rowid);
    }
    // Rowids come back in ascending order regardless of insertion order.
    for expected in [10, 20, 30, 40, 50] {
        assert_eq!(rowset.smallest(), Some(expected));
    }
    assert_eq!(rowset.smallest(), None);
    assert!(rowset.is_empty());
}
#[test]
fn test_smallest_with_duplicates() {
    let mut rowset = RowSet::new();
    for rowid in [10, 20, 10, 30, 20] {
        rowset.insert(rowid);
    }
    // Each distinct rowid is yielded exactly once.
    for expected in [10, 20, 30] {
        assert_eq!(rowset.smallest(), Some(expected));
    }
    assert_eq!(rowset.smallest(), None);
}
#[test]
fn test_insert_after_smallest_panics() {
    let mut rowset = RowSet::new();
    rowset.insert(10);
    rowset.smallest();
    // Inserting once extraction has started must trip the assertion.
    let late_insert = std::panic::AssertUnwindSafe(|| rowset.insert(20));
    assert!(std::panic::catch_unwind(late_insert).is_err());
}
#[test]
fn test_test_after_smallest_panics() {
    let mut rowset = RowSet::new();
    rowset.insert(10);
    rowset.smallest();
    // Membership tests are rejected once extraction has started.
    let late_test = std::panic::AssertUnwindSafe(|| {
        let _ = rowset.test(10, 1);
    });
    assert!(std::panic::catch_unwind(late_test).is_err());
}
#[test]
fn test_smallest_after_test_panics() {
    let mut rowset = RowSet::new();
    rowset.insert(10);
    rowset.test(10, 1);
    // Extraction is rejected once the set is in membership-test mode.
    let late_smallest = std::panic::AssertUnwindSafe(|| {
        let _ = rowset.smallest();
    });
    assert!(std::panic::catch_unwind(late_smallest).is_err());
}
#[test]
fn test_batch_zero_allows_smallest() {
    let mut rowset = RowSet::new();
    for rowid in [10, 20, 30, 5, 15] {
        rowset.insert(rowid);
    }
    // With no test() call made, extraction is allowed and yields ascending order.
    for expected in [5, 10, 15, 20, 30] {
        assert_eq!(rowset.smallest(), Some(expected));
    }
    assert_eq!(rowset.smallest(), None);
}
#[test]
fn test_empty_smallest() {
    let mut rowset = RowSet::new();
    // Extracting from a set that never received a rowid yields nothing.
    let first = rowset.smallest();
    assert_eq!(first, None);
    assert!(rowset.is_empty());
}
#[test]
fn test_batch_zero_semantics() {
    let rowids = [10, 20];
    let mut rowset = RowSet::new();
    for rowid in rowids {
        rowset.insert(rowid);
    }
    // Batch 0 never triggers a merge on a fresh set.
    for rowid in rowids {
        assert!(!rowset.test(rowid, 0));
    }
    // Any other batch number does.
    for rowid in rowids {
        assert!(rowset.test(rowid, 1));
    }
}
#[test]
fn test_batch_final_semantics() {
    let mut rowset = RowSet::new();
    rowset.insert(10);
    assert!(rowset.test(10, 1));

    rowset.insert(20);
    // Batch -1 differs from batch 1, so the newly staged rowid is merged too.
    for rowid in [10, 20] {
        assert!(rowset.test(rowid, -1));
    }
    // Repeated lookups of an absent rowid keep failing.
    for _ in 0..2 {
        assert!(!rowset.test(30, -1));
    }
}
#[test]
fn test_negative_values() {
    let rowids = [-10, -5, 0, 5, 10];
    let mut rowset = RowSet::new();
    for rowid in rowids {
        rowset.insert(rowid);
    }
    // Every rowid is found in the merging batch and in the batch after it.
    for batch in [1, 2] {
        for rowid in rowids {
            assert!(rowset.test(rowid, batch));
        }
    }
}
#[test]
fn test_large_values() {
    // Both ends of the i64 range, plus their immediate neighbours.
    let extremes = [i64::MAX, i64::MAX - 1, i64::MIN, i64::MIN + 1];
    let mut rowset = RowSet::new();
    for rowid in extremes {
        rowset.insert(rowid);
    }
    for batch in [1, 2] {
        for rowid in extremes {
            assert!(rowset.test(rowid, batch));
        }
    }
}
/// Inserts random rowids and checks that `smallest()` yields exactly the
/// distinct inserted values in ascending order.
#[test]
fn fuzz_basic_operations() {
    let seed = get_seed();
    // Report the seed so a failing run can be replayed through the SEED variable.
    println!("Fuzz seed: {seed}");
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let attempts = 10;
    for _ in 0..attempts {
        let mut rowset = RowSet::new();
        let mut inserted = std::collections::BTreeSet::new();
        let num_inserts = 100 + (rng.next_u64() % 900) as usize;
        for _ in 0..num_inserts {
            let value = rng.next_u64() as i64;
            rowset.insert(value);
            inserted.insert(value);
        }
        let mut extracted = Vec::new();
        while let Some(value) = rowset.smallest() {
            extracted.push(value);
        }
        // A BTreeSet iterates in ascending order, so no extra sort is needed.
        let expected: Vec<i64> = inserted.iter().copied().collect();
        assert_eq!(extracted.len(), expected.len());
        assert_eq!(extracted, expected);
    }
}
/// Inserts several groups of random rowids up front, then tests every value
/// under its group's batch number: batch 0 must not see anything, every other
/// batch must find its values.
#[test]
fn fuzz_batch_operations() {
    let seed = get_seed();
    // Report the seed so a failing run can be replayed through the SEED variable.
    println!("Fuzz seed: {seed}");
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let attempts = 10;
    for _ in 0..attempts {
        let mut rowset = RowSet::new();
        let mut batches: Vec<(i32, Vec<i64>)> = Vec::new();
        let num_batches = 5 + (rng.next_u64() % 10) as usize;
        for batch_idx in 0..num_batches {
            // First group uses batch 0, the last one -1, the rest their own index.
            let batch = if batch_idx == 0 {
                0
            } else if batch_idx == num_batches - 1 {
                -1
            } else {
                batch_idx as i32
            };
            let num_values = 10 + (rng.next_u64() % 90) as usize;
            let mut batch_values = Vec::with_capacity(num_values);
            for _ in 0..num_values {
                let value = rng.next_u64() as i64;
                rowset.insert(value);
                batch_values.push(value);
            }
            batches.push((batch, batch_values));
        }
        for (batch, values) in &batches {
            for &value in values {
                if *batch == 0 {
                    // Batch 0 equals the initial batch number: nothing is merged yet.
                    assert!(!rowset.test(value, *batch));
                } else {
                    assert!(
                        rowset.test(value, *batch),
                        "Value {value} should be found in batch {batch}",
                    );
                }
            }
        }
    }
}
/// Randomly interleaves inserts with membership tests, using a new batch
/// number for every test so that each one merges the pending inserts first.
#[test]
fn fuzz_mixed_operations() {
    let seed = get_seed();
    // Report the seed so a failing run can be replayed through the SEED variable.
    println!("Fuzz seed: {seed}");
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let attempts = 3;
    for _ in 0..attempts {
        let mut rowset = RowSet::new();
        let mut all_values = std::collections::BTreeSet::new();
        let mut next_batch = 1;
        let num_ops = 20 + (rng.next_u64() % 30) as usize;
        for _ in 0..num_ops {
            if rng.next_u64() % 2 == 0 {
                let value = rng.next_u64() as i64;
                rowset.insert(value);
                all_values.insert(value);
                continue;
            }
            if all_values.is_empty() {
                continue;
            }
            // Pick the idx-th smallest inserted value without copying the whole set.
            let idx = (rng.next_u64() % all_values.len() as u64) as usize;
            let value = *all_values
                .iter()
                .nth(idx)
                .expect("idx is smaller than the set's length");
            let found = rowset.test(value, next_batch);
            assert!(found, "Value {value} should be found in batch {next_batch}",);
            next_batch += 1;
        }
        // A batch number that has not been used yet forces one last merge.
        let final_batch = next_batch;
        for &value in &all_values {
            assert!(
                rowset.test(value, final_batch),
                "Value {value} should be found in batch {final_batch}",
            );
        }
    }
}
/// Larger randomized run: spreads a few thousand inserts over many batches,
/// spot-checks each positive batch right after its inserts, and finally
/// verifies every inserted value under the closing batch number.
#[test]
fn fuzz_long() {
    let seed = get_seed();
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    println!("Fuzz seed: {seed}");
    for attempt in 0..2 {
        let mut rowset = RowSet::new();
        let mut reference = std::collections::BTreeSet::new();
        let mut batches: Vec<(i32, Vec<i64>)> = Vec::new();
        let num_batches = 10 + (rng.next_u64() % 40) as usize;
        let total_inserts = 1000 + (rng.next_u64() % 9000) as usize;
        let inserts_per_batch = (total_inserts / num_batches).max(1);
        let last_idx = num_batches - 1;
        // Running total of the values generated by the batches completed so far.
        let mut already_inserted = 0usize;

        for batch_idx in 0..num_batches {
            let is_last = batch_idx == last_idx;
            // First batch is 0, the last one is -1, the rest use their own index.
            let batch = match batch_idx {
                0 => 0,
                _ if is_last => -1,
                other => other as i32,
            };

            let remaining = total_inserts.saturating_sub(already_inserted);
            let batch_inserts = if is_last {
                // The last batch takes whatever is left of the budget.
                remaining
            } else {
                let max_for_this_batch = remaining.min(inserts_per_batch * 2);
                let spread = (max_for_this_batch.saturating_sub(inserts_per_batch) + 1) as u64;
                inserts_per_batch + (rng.next_u64() % spread) as usize
            };

            let batch_values: Vec<i64> = (0..batch_inserts)
                .map(|_| {
                    let value = rng.next_u64() as i64;
                    rowset.insert(value);
                    reference.insert(value);
                    value
                })
                .collect();

            if batch > 0 {
                // Spot-check roughly a tenth of this batch (at least one value).
                let test_count = (batch_values.len() / 10).max(1);
                for _ in 0..test_count {
                    let idx = (rng.next_u64() % batch_values.len() as u64) as usize;
                    let value = batch_values[idx];
                    assert!(
                        rowset.test(value, batch),
                        "Attempt {attempt}, batch {batch}, value {value} should be found",
                    );
                }
            }

            already_inserted += batch_values.len();
            batches.push((batch, batch_values));
        }

        if let Some(last_batch) = batches.last().map(|(batch, _)| *batch) {
            if !reference.is_empty() {
                let final_batch = match last_batch {
                    -1 => -1,
                    other => other + 1,
                };
                for &value in &reference {
                    assert!(
                        rowset.test(value, final_batch),
                        "Attempt {attempt}, value {value} should be found in batch {final_batch}",
                    );
                }
            }
        }
    }
}
}