use std::ops::RangeBounds;
use crate::config::Config;
use crate::error::Result;
use crate::key::Key;
use crate::map::{Guard, LearnedMap, MapRef};
/// A set of keys, implemented as a thin wrapper over a [`LearnedMap`] whose
/// values are all `()`.
///
/// Most operations take an explicit [`Guard`] obtained from
/// [`LearnedSet::guard`]; [`LearnedSet::pin`] returns a [`SetRef`] whose
/// methods take no guard argument. Range and first/last queries are
/// delegated to the underlying map.
#[derive(Debug)]
pub struct LearnedSet<K: Key> {
/// Backing map; every key maps to `()`.
inner: LearnedMap<K, ()>,
}
/// Convenience view over a [`LearnedSet`], obtained from [`LearnedSet::pin`].
///
/// Its methods mirror those of [`LearnedSet`] but take no explicit
/// [`Guard`] argument; they forward to the wrapped [`MapRef`].
pub struct SetRef<'a, K: Key> {
/// Pinned view of the underlying map (values are always `()`).
inner: MapRef<'a, K, ()>,
}
impl<K: Key> SetRef<'_, K> {
    /// Adds `key` to the set.
    ///
    /// Returns `true` when the key was newly added and `false` when it was
    /// already a member.
    pub fn insert(&self, key: K) -> bool {
        let unit = ();
        self.inner.insert(key, unit)
    }

    /// Removes `key` from the set, returning the underlying map's result
    /// (`true` when the key was present).
    pub fn remove(&self, key: &K) -> bool {
        self.inner.remove(key)
    }

    /// Reports whether `key` is a member of the set.
    pub fn contains(&self, key: &K) -> bool {
        self.inner.contains_key(key)
    }

    /// Number of keys currently in the set.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when the set holds no keys.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Iterates over the keys that fall within `range`, dropping the `()`
    /// values carried by the underlying map.
    pub fn range<R: RangeBounds<K>>(&self, range: R) -> impl Iterator<Item = &K> {
        self.inner.range(range).map(|entry| entry.0)
    }

    /// Returns the first key as reported by the map, or `None` if there is none.
    pub fn first(&self) -> Option<&K> {
        let (key, _) = self.inner.first_key_value()?;
        Some(key)
    }

    /// Returns the last key as reported by the map, or `None` if there is none.
    pub fn last(&self) -> Option<&K> {
        let (key, _) = self.inner.last_key_value()?;
        Some(key)
    }
}
impl<K: Key> LearnedSet<K> {
    /// Creates an empty set backed by `LearnedMap::new()`.
    pub fn new() -> Self {
        let inner = LearnedMap::new();
        Self { inner }
    }

    /// Creates an empty set whose underlying map is built from `config`.
    pub fn with_config(config: Config) -> Self {
        let inner = LearnedMap::with_config(config);
        Self { inner }
    }

    /// Builds a set from `keys`, cloning each key. Duplicate keys are
    /// collapsed by the underlying `bulk_load_dedup`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `LearnedMap::bulk_load_dedup`.
    pub fn bulk_load(keys: &[K]) -> Result<Self> {
        let pairs = keys
            .iter()
            .cloned()
            .map(|key| (key, ()))
            .collect::<Vec<(K, ())>>();
        let inner = LearnedMap::bulk_load_dedup(&pairs)?;
        Ok(Self { inner })
    }

    /// Returns a [`Guard`] to pass to the guard-taking methods below.
    /// References yielded by `range`, `first`, and `last` live as long as
    /// the guard they were obtained with.
    pub fn guard(&self) -> Guard {
        self.inner.guard()
    }

    /// Returns a [`SetRef`] whose methods need no explicit guard.
    pub fn pin(&self) -> SetRef<'_, K> {
        let inner = self.inner.pin();
        SetRef { inner }
    }

    /// Adds `key`; returns `true` when it was not already a member.
    pub fn insert(&self, key: K, guard: &Guard) -> bool {
        let unit = ();
        self.inner.insert(key, unit, guard)
    }

    /// Removes `key`, returning the underlying map's result
    /// (`true` when the key was present).
    pub fn remove(&self, key: &K, guard: &Guard) -> bool {
        self.inner.remove(key, guard)
    }

    /// Reports whether `key` is a member of the set.
    pub fn contains(&self, key: &K, guard: &Guard) -> bool {
        self.inner.contains_key(key, guard)
    }

    /// Number of keys currently in the set.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when the set holds no keys.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Iterates over the keys that fall within `range`; the yielded
    /// references are tied to `guard`.
    pub fn range<'g, R: RangeBounds<K>>(
        &self,
        range: R,
        guard: &'g Guard,
    ) -> impl Iterator<Item = &'g K> {
        self.inner.range(range, guard).map(|entry| entry.0)
    }

    /// Returns the first key as reported by the map, or `None` if there is none.
    pub fn first<'g>(&self, guard: &'g Guard) -> Option<&'g K> {
        let (key, _) = self.inner.first_key_value(guard)?;
        Some(key)
    }

    /// Returns the last key as reported by the map, or `None` if there is none.
    pub fn last<'g>(&self, guard: &'g Guard) -> Option<&'g K> {
        let (key, _) = self.inner.last_key_value(guard)?;
        Some(key)
    }
}
#[cfg(feature = "serde")]
impl<K> serde::Serialize for LearnedSet<K>
where
    K: Key + serde::Serialize,
{
    /// Serializes the set as a sequence of its keys, in the order yielded by
    /// the underlying map's `iter`.
    ///
    /// Because `insert`/`remove` take `&self`, the set may change while it is
    /// being serialized. The keys are therefore snapshotted under a single
    /// guard before anything is written. This guarantees that the length
    /// announced to the serializer equals the number of elements emitted;
    /// length-prefixed formats would otherwise produce corrupt output.
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::SerializeSeq;
        let guard = self.guard();
        // Snapshot first: a separately read `self.len()` can disagree with
        // what the iterator actually yields.
        let keys: Vec<&K> = self.inner.iter(&guard).map(|(k, ())| k).collect();
        let mut seq = serializer.serialize_seq(Some(keys.len()))?;
        for k in keys {
            seq.serialize_element(k)?;
        }
        seq.end()
    }
}
#[cfg(feature = "serde")]
impl<'de, K> serde::Deserialize<'de> for LearnedSet<K>
where
    K: Key + serde::Deserialize<'de>,
{
    /// Deserializes a sequence of keys into a set. Duplicate keys are
    /// collapsed by `bulk_load_dedup`.
    ///
    /// An empty sequence yields [`LearnedSet::new`] without bulk-loading.
    /// Bulk-load errors are reported through `serde::de::Error::custom`.
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        let keys: Vec<K> = Vec::deserialize(deserializer)?;
        if keys.is_empty() {
            return Ok(Self::new());
        }
        // The keys are already owned here, so move them into the pairs
        // instead of going through `bulk_load(&keys)`, which clones each one.
        let pairs: Vec<(K, ())> = keys.into_iter().map(|k| (k, ())).collect();
        LearnedMap::bulk_load_dedup(&pairs)
            .map(|inner| Self { inner })
            .map_err(serde::de::Error::custom)
    }
}
impl<K: Key> Default for LearnedSet<K> {
fn default() -> Self {
Self::new()
}
}
impl<K: Key> FromIterator<K> for LearnedSet<K> {
fn from_iter<I: IntoIterator<Item = K>>(iter: I) -> Self {
let set = Self::new();
let guard = set.guard();
for k in iter {
set.insert(k, &guard);
}
set
}
}
impl<K: Key> Extend<K> for LearnedSet<K> {
    /// Inserts every key from `iter` under a single guard. Keys that are
    /// already members are ignored (their `insert` simply returns `false`).
    fn extend<I: IntoIterator<Item = K>>(&mut self, iter: I) {
        let guard = self.guard();
        iter.into_iter().for_each(|key| {
            self.insert(key, &guard);
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Insert / duplicate insert / contains / remove round-trip with an
    /// explicit guard.
    #[test]
    fn basic_set_ops() {
        let set = LearnedSet::new();
        let g = set.guard();
        assert!(set.insert(1u64, &g));
        assert!(set.insert(2, &g));
        assert!(!set.insert(1, &g));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&1, &g));
        assert!(set.remove(&1, &g));
        assert!(!set.contains(&1, &g));
        assert_eq!(set.len(), 1);
    }

    /// A freshly created set reports itself empty and answers every query
    /// with "nothing".
    #[test]
    fn empty_set() {
        let set: LearnedSet<u64> = LearnedSet::default();
        let g = set.guard();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert!(!set.contains(&1, &g));
        assert_eq!(set.first(&g), None);
        assert_eq!(set.last(&g), None);
        assert_eq!(set.range(.., &g).count(), 0);
    }

    #[test]
    fn from_iterator() {
        let set: LearnedSet<u64> = vec![3, 1, 2].into_iter().collect();
        let g = set.guard();
        assert_eq!(set.len(), 3);
        assert!(set.contains(&1, &g));
        assert!(set.contains(&2, &g));
        assert!(set.contains(&3, &g));
    }

    /// `Extend` adds new keys and ignores ones already present.
    #[test]
    fn extend_skips_duplicates() {
        let mut set = LearnedSet::new();
        set.extend(vec![3u64, 1, 3, 2]);
        let g = set.guard();
        assert_eq!(set.len(), 3);
        for k in 1..=3u64 {
            assert!(set.contains(&k, &g), "key {k} missing after extend");
        }
    }

    #[test]
    fn bulk_load_set() {
        let keys: Vec<u64> = (0..100).collect();
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let g = set.guard();
        assert_eq!(set.len(), 100);
        for k in &keys {
            assert!(set.contains(k, &g));
        }
    }

    /// `first`, `last`, and `range` over a bulk-loaded, ordered key set.
    #[test]
    fn ordered_queries() {
        let keys: Vec<u64> = (0..100).collect();
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let g = set.guard();
        assert_eq!(set.first(&g), Some(&0));
        assert_eq!(set.last(&g), Some(&99));
        let mid: Vec<u64> = set.range(10..20, &g).copied().collect();
        assert_eq!(mid, (10..20).collect::<Vec<u64>>());
    }

    #[test]
    fn bulk_load_deduplicates() {
        let keys: Vec<u64> = vec![1, 1, 2, 3, 3, 3, 4, 5];
        let set = LearnedSet::bulk_load(&keys).unwrap();
        let g = set.guard();
        assert_eq!(set.len(), 5);
        for k in 1..=5u64 {
            assert!(set.contains(&k, &g), "key {k} missing after dedup");
        }
    }

    #[test]
    fn set_ref_convenience() {
        let set = LearnedSet::new();
        let s = set.pin();
        assert!(s.insert(10u64));
        assert!(s.insert(20));
        assert!(!s.insert(10));
        assert_eq!(s.len(), 2);
        assert!(s.contains(&10));
        assert!(s.remove(&10));
        assert!(!s.contains(&10));
        assert_eq!(s.len(), 1);
    }
}