use crate::eval::value::Value;
use super::comparator::Comparator;
use super::Container;
use std::sync::{Arc, RwLock};
use std::cmp::Ordering;
/// Color tag carried by every tree node.
///
/// New nodes are created red (`Node::new_red`) and the tree root is always
/// repainted black after an insert or remove.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Color {
// Color given to freshly inserted nodes and to the node pushed down by a rotation.
Red,
// Color forced onto the root after every mutation.
Black,
}
/// A single tree node.
///
/// Children are held behind `Arc`, so cloning a node copies its value and
/// color but only bumps the reference counts of its subtrees.
#[derive(Clone, Debug)]
struct Node {
// The element stored at this node.
value: Value,
// Red/black tag used by the balancing code.
color: Color,
// Subtree reached when the comparator orders the probe before `value`.
left: Option<Arc<Node>>,
// Subtree reached when the comparator orders the probe after `value`.
right: Option<Arc<Node>>,
// Cached node count of the subtree rooted here: 1 + size(left) + size(right).
size: usize, }
impl Node {
    /// Builds a childless red node holding `value`.
    fn new_red(value: Value) -> Self {
        Self::new_with_children(value, Color::Red, None, None)
    }

    /// Builds a childless black node holding `value`.
    fn new_black(value: Value) -> Self {
        Self::new_with_children(value, Color::Black, None, None)
    }

    /// Builds a node with the given color and children; the cached `size`
    /// is derived from the children's cached sizes plus one for the node itself.
    fn new_with_children(
        value: Value,
        color: Color,
        left: Option<Arc<Node>>,
        right: Option<Arc<Node>>,
    ) -> Self {
        let mut node = Self {
            value,
            color,
            left,
            right,
            size: 1,
        };
        node.update_size();
        node
    }

    /// Recomputes the cached `size` from the current children.
    fn update_size(&mut self) {
        let below: usize = [&self.left, &self.right]
            .iter()
            .map(|child| child.as_ref().map_or(0, |n| n.size))
            .sum();
        self.size = 1 + below;
    }

    /// `true` when this node is colored red.
    fn is_red(&self) -> bool {
        matches!(self.color, Color::Red)
    }

    /// `true` when this node is colored black.
    fn is_black(&self) -> bool {
        matches!(self.color, Color::Black)
    }

    /// Consumes the node and returns it recolored red.
    fn make_red(self) -> Self {
        Self {
            color: Color::Red,
            ..self
        }
    }

    /// Consumes the node and returns it recolored black.
    fn make_black(self) -> Self {
        Self {
            color: Color::Black,
            ..self
        }
    }
}
/// Reports whether an optional child link points at a red node.
/// A missing child counts as not red.
fn is_red(node: &Option<Arc<Node>>) -> bool {
    match node {
        Some(n) => n.is_red(),
        None => false,
    }
}
/// Returns the cached subtree size behind an optional child link,
/// or 0 when the link is empty.
fn size(node: &Option<Arc<Node>>) -> usize {
    node.as_ref().map_or(0, |n| n.size)
}
/// Balanced search tree that backs `OrderedSet`.
///
/// Nodes are shared through `Arc`, so cloning the tree clones only the root
/// handle and the comparator.
#[derive(Clone, Debug)]
struct RedBlackTree {
// Root node, or `None` for an empty tree.
root: Option<Arc<Node>>,
// Decides the relative order of two values for every search and mutation.
comparator: Comparator,
}
impl RedBlackTree {
fn new(comparator: Comparator) -> Self {
Self {
root: None,
comparator,
}
}
fn size(&self) -> usize {
size(&self.root)
}
fn is_empty(&self) -> bool {
self.root.is_none()
}
fn contains(&self, value: &Value) -> bool {
self.search_node(&self.root, value).is_some()
}
fn search_node(&self, node: &Option<Arc<Node>>, value: &Value) -> Option<Arc<Node>> {
match node {
None => None,
Some(n) => {
match self.comparator.compare(value, &n.value) {
Ordering::Equal => Some(n.clone()),
Ordering::Less => self.search_node(&n.left, value),
Ordering::Greater => self.search_node(&n.right, value),
}
}
}
}
fn insert(&mut self, value: Value) -> bool {
let root = self.root.take();
let (new_root, inserted) = self.insert_recursive(root, value);
self.root = new_root.map(|n| Arc::new(n.make_black()));
inserted
}
fn insert_recursive(&self, node: Option<Arc<Node>>, value: Value) -> (Option<Node>, bool) {
match node {
None => (Some(Node::new_red(value)), true),
Some(n) => {
match self.comparator.compare(&value, &n.value) {
Ordering::Equal => (Some((*n).clone()), false), Ordering::Less => {
let (new_left, inserted) = self.insert_recursive(n.left.clone(), value);
let mut new_node = Node::new_with_children(
n.value.clone(),
n.color,
new_left.map(Arc::new),
n.right.clone(),
);
new_node = self.fix_up(new_node);
(Some(new_node), inserted)
}
Ordering::Greater => {
let (new_right, inserted) = self.insert_recursive(n.right.clone(), value);
let mut new_node = Node::new_with_children(
n.value.clone(),
n.color,
n.left.clone(),
new_right.map(Arc::new),
);
new_node = self.fix_up(new_node);
(Some(new_node), inserted)
}
}
}
}
}
fn remove(&mut self, value: &Value) -> bool {
if !self.contains(value) {
return false;
}
if let Some(ref root) = self.root {
if !is_red(&root.left) && !is_red(&root.right) {
let cloned_root = (**root).clone();
self.root = Some(Arc::new(cloned_root.make_red()));
}
}
let root = self.root.take();
self.root = self.delete_recursive(root, value);
if let Some(root) = self.root.take() {
self.root = Some(Arc::new((*root).clone().make_black()));
}
true
}
fn delete_recursive(&self, node: Option<Arc<Node>>, value: &Value) -> Option<Arc<Node>> {
match node {
None => None,
Some(n) => {
match self.comparator.compare(value, &n.value) {
Ordering::Less => {
let mut new_node = (*n).clone();
if !is_red(&n.left) && n.left.as_ref().map(|l| !is_red(&l.left)).unwrap_or(true) {
new_node = self.move_red_left(new_node);
}
new_node.left = self.delete_recursive(new_node.left, value);
new_node.update_size();
Some(Arc::new(self.fix_up(new_node)))
}
_ => {
if is_red(&n.left) {
let mut new_node = self.rotate_right((*n).clone());
new_node.left = self.delete_recursive(new_node.left, value);
new_node.update_size();
Some(Arc::new(self.fix_up(new_node)))
} else {
if self.comparator.compare(value, &n.value) == Ordering::Equal && n.right.is_none() {
return None;
}
let mut new_node = (*n).clone();
if !is_red(&n.right) && n.right.as_ref().map(|r| !is_red(&r.left)).unwrap_or(true) {
new_node = self.move_red_right(new_node);
}
if self.comparator.compare(value, &new_node.value) == Ordering::Equal {
if let Some(min_val) = Self::find_min(&new_node.right) {
new_node.value = min_val;
new_node.right = self.delete_min(new_node.right);
} else {
return None;
}
} else {
new_node.right = self.delete_recursive(new_node.right, value);
}
new_node.update_size();
Some(Arc::new(self.fix_up(new_node)))
}
}
}
}
}
}
fn find_min(node: &Option<Arc<Node>>) -> Option<Value> {
match node {
None => None,
Some(n) => {
if n.left.is_none() {
Some(n.value.clone())
} else {
Self::find_min(&n.left)
}
}
}
}
fn delete_min(&self, node: Option<Arc<Node>>) -> Option<Arc<Node>> {
match node {
None => None,
Some(n) => {
if n.left.is_none() {
return n.right.clone();
}
let mut new_node = (*n).clone();
if !is_red(&n.left) && n.left.as_ref().map(|l| !is_red(&l.left)).unwrap_or(true) {
new_node = self.move_red_left(new_node);
}
new_node.left = self.delete_min(new_node.left);
new_node.update_size();
Some(Arc::new(self.fix_up(new_node)))
}
}
}
fn rotate_left(&self, mut node: Node) -> Node {
if let Some(right) = node.right.take() {
let new_right = right.left.clone();
let mut new_root = (*right).clone();
new_root.left = Some(Arc::new(Node::new_with_children(
node.value,
Color::Red,
node.left,
new_right,
)));
new_root.color = node.color;
new_root.update_size();
new_root
} else {
node
}
}
fn rotate_right(&self, mut node: Node) -> Node {
if let Some(left) = node.left.take() {
let new_left = left.right.clone();
let mut new_root = (*left).clone();
new_root.right = Some(Arc::new(Node::new_with_children(
node.value,
Color::Red,
new_left,
node.right,
)));
new_root.color = node.color;
new_root.update_size();
new_root
} else {
node
}
}
fn flip_colors(&self, mut node: Node) -> Node {
node.color = match node.color {
Color::Red => Color::Black,
Color::Black => Color::Red,
};
if let Some(left) = &node.left {
let cloned_left = (**left).clone();
node.left = Some(Arc::new(cloned_left.make_red()));
}
if let Some(right) = &node.right {
let cloned_right = (**right).clone();
node.right = Some(Arc::new(cloned_right.make_red()));
}
node
}
fn fix_up(&self, mut node: Node) -> Node {
if is_red(&node.right) && !is_red(&node.left) {
node = self.rotate_left(node);
}
if is_red(&node.left) && node.left.as_ref().map(|l| is_red(&l.left)).unwrap_or(false) {
node = self.rotate_right(node);
}
if is_red(&node.left) && is_red(&node.right) {
node = self.flip_colors(node);
}
node.update_size();
node
}
fn move_red_left(&self, mut node: Node) -> Node {
node = self.flip_colors(node);
if node.right.as_ref().map(|r| is_red(&r.left)).unwrap_or(false) {
if let Some(right) = node.right.take() {
node.right = Some(Arc::new(self.rotate_right((*right).clone())));
}
node = self.rotate_left(node);
node = self.flip_colors(node);
}
node
}
fn move_red_right(&self, mut node: Node) -> Node {
node = self.flip_colors(node);
if node.left.as_ref().map(|l| is_red(&l.left)).unwrap_or(false) {
node = self.rotate_right(node);
node = self.flip_colors(node);
}
node
}
fn to_vec(&self) -> Vec<Value> {
let mut result = Vec::new();
Self::inorder_traversal(&self.root, &mut result);
result
}
fn inorder_traversal(node: &Option<Arc<Node>>, result: &mut Vec<Value>) {
if let Some(n) = node {
Self::inorder_traversal(&n.left, result);
result.push(n.value.clone());
Self::inorder_traversal(&n.right, result);
}
}
fn min(&self) -> Option<Value> {
Self::find_min(&self.root)
}
fn max(&self) -> Option<Value> {
Self::find_max(&self.root)
}
fn find_max(node: &Option<Arc<Node>>) -> Option<Value> {
match node {
None => None,
Some(n) => {
if n.right.is_none() {
Some(n.value.clone())
} else {
Self::find_max(&n.right)
}
}
}
}
fn range(&self, low: &Value, high: &Value) -> Vec<Value> {
let mut result = Vec::new();
self.range_search(&self.root, low, high, &mut result);
result
}
fn range_search(&self, node: &Option<Arc<Node>>, low: &Value, high: &Value, result: &mut Vec<Value>) {
if let Some(n) = node {
let cmp_low = self.comparator.compare(&n.value, low);
let cmp_high = self.comparator.compare(&n.value, high);
if cmp_low != Ordering::Less {
self.range_search(&n.left, low, high, result);
}
if cmp_low != Ordering::Less && cmp_high != Ordering::Greater {
result.push(n.value.clone());
}
if cmp_high != Ordering::Greater {
self.range_search(&n.right, low, high, result);
}
}
}
}
/// A set of `Value`s kept in the order defined by its `Comparator`.
#[derive(Clone, Debug)]
pub struct OrderedSet {
// Balanced tree that stores the elements and owns the comparator.
tree: RedBlackTree,
// Optional label; populated only by `with_name`.
name: Option<String>,
}
impl OrderedSet {
    /// Creates an empty, unnamed set ordered by `Comparator::with_default()`.
    pub fn new() -> Self {
        Self::with_comparator(Comparator::with_default())
    }

    /// Creates an empty, unnamed set ordered by `comparator`.
    pub fn with_comparator(comparator: Comparator) -> Self {
        let tree = RedBlackTree::new(comparator);
        Self { tree, name: None }
    }

    /// Creates an empty set with the default comparator and the given name.
    pub fn with_name(name: impl Into<String>) -> Self {
        let base = Self::new();
        Self {
            name: Some(name.into()),
            ..base
        }
    }

    /// Builds a default-ordered set from `values`; duplicates collapse to one entry.
    pub fn from_vec(values: Vec<Value>) -> Self {
        values.into_iter().fold(Self::new(), |mut set, value| {
            set.insert(value);
            set
        })
    }

    /// Builds an unnamed set that shares this set's comparator and holds `values`.
    fn derived_from(&self, values: impl Iterator<Item = Value>) -> Self {
        let mut derived = Self::with_comparator(self.tree.comparator.clone());
        for value in values {
            derived.insert(value);
        }
        derived
    }

    /// Adds `value`; returns `false` when an equal value was already present.
    pub fn insert(&mut self, value: Value) -> bool {
        self.tree.insert(value)
    }

    /// Removes `value`; returns `false` when it was not present.
    pub fn remove(&mut self, value: &Value) -> bool {
        self.tree.remove(value)
    }

    /// `true` when an equal value is stored.
    pub fn contains(&self, value: &Value) -> bool {
        self.tree.contains(value)
    }

    /// Smallest element, or `None` when the set is empty.
    pub fn min(&self) -> Option<Value> {
        self.tree.min()
    }

    /// Largest element, or `None` when the set is empty.
    pub fn max(&self) -> Option<Value> {
        self.tree.max()
    }

    /// Elements between `low` and `high`, both inclusive, in order.
    pub fn range(&self, low: &Value, high: &Value) -> Vec<Value> {
        self.tree.range(low, high)
    }

    /// All elements in order.
    pub fn to_vec(&self) -> Vec<Value> {
        self.tree.to_vec()
    }

    /// Iterates over an ordered copy of the elements taken when `iter` is called.
    pub fn iter(&self) -> impl Iterator<Item = Value> + '_ {
        let snapshot = self.to_vec();
        snapshot.into_iter()
    }

    /// Elements of either set. The result is a clone of `self` (comparator and
    /// name included) extended with `other`'s elements.
    pub fn union(&self, other: &Self) -> Self {
        let mut merged = self.clone();
        other.iter().for_each(|value| {
            merged.insert(value);
        });
        merged
    }

    /// Elements of `self` that `other` also contains.
    pub fn intersection(&self, other: &Self) -> Self {
        self.derived_from(self.iter().filter(|value| other.contains(value)))
    }

    /// Elements of `self` that `other` does not contain.
    pub fn difference(&self, other: &Self) -> Self {
        self.derived_from(self.iter().filter(|value| !other.contains(value)))
    }

    /// Elements contained in exactly one of the two sets.
    pub fn symmetric_difference(&self, other: &Self) -> Self {
        let only_here = self.iter().filter(|value| !other.contains(value));
        let only_there = other.iter().filter(|value| !self.contains(value));
        self.derived_from(only_here.chain(only_there))
    }

    /// `true` when every element of `self` is in `other`.
    pub fn is_subset(&self, other: &Self) -> bool {
        !self.iter().any(|value| !other.contains(&value))
    }

    /// `true` when every element of `other` is in `self`.
    pub fn is_superset(&self, other: &Self) -> bool {
        other.is_subset(self)
    }

    /// `true` when the two sets share no element.
    pub fn is_disjoint(&self, other: &Self) -> bool {
        !self.iter().any(|value| other.contains(&value))
    }

    /// New set (same comparator) with the elements for which `predicate` returns `true`.
    pub fn filter<F>(&self, mut predicate: F) -> Self
    where
        F: FnMut(&Value) -> bool,
    {
        self.derived_from(self.iter().filter(|value| predicate(value)))
    }

    /// New set (same comparator) holding `f(element)` for every element.
    pub fn map<F>(&self, mut f: F) -> Self
    where
        F: FnMut(&Value) -> Value,
    {
        self.derived_from(self.iter().map(|value| f(&value)))
    }

    /// Folds the elements, visited in order, into a single accumulator.
    pub fn fold<F, Acc>(&self, init: Acc, mut f: F) -> Acc
    where
        F: FnMut(Acc, &Value) -> Acc,
    {
        let mut acc = init;
        for value in self.iter() {
            acc = f(acc, &value);
        }
        acc
    }

    /// Splits the elements into `(matching, non_matching)` according to `predicate`.
    pub fn partition<F>(&self, mut predicate: F) -> (Self, Self)
    where
        F: FnMut(&Value) -> bool,
    {
        let mut matching = Self::with_comparator(self.tree.comparator.clone());
        let mut rest = Self::with_comparator(self.tree.comparator.clone());
        for value in self.iter() {
            let bucket = if predicate(&value) {
                &mut matching
            } else {
                &mut rest
            };
            bucket.insert(value);
        }
        (matching, rest)
    }
}
impl Container for OrderedSet {
fn len(&self) -> usize {
self.tree.size()
}
fn clear(&mut self) {
self.tree.root = None;
}
}
impl Default for OrderedSet {
fn default() -> Self {
Self::new()
}
}
/// Handle to an `OrderedSet` guarded by an `RwLock`.
///
/// The derived `Clone` copies the `Arc`, so clones refer to the same
/// underlying set rather than to an independent copy.
#[derive(Clone, Debug)]
pub struct ThreadSafeOrderedSet {
// Shared, lock-protected set.
inner: Arc<RwLock<OrderedSet>>,
}
/// Lock-guarded operations on the shared set.
///
/// # Panics
///
/// Every method panics if the lock is poisoned, i.e. a previous holder
/// panicked while holding it.
impl ThreadSafeOrderedSet {
    /// Creates an empty set with the default comparator.
    pub fn new() -> Self {
        Self::wrap(OrderedSet::new())
    }

    /// Creates an empty set ordered by `comparator`.
    pub fn with_comparator(comparator: Comparator) -> Self {
        Self::wrap(OrderedSet::with_comparator(comparator))
    }

    /// Puts `set` behind a fresh lock.
    fn wrap(set: OrderedSet) -> Self {
        Self {
            inner: Arc::new(RwLock::new(set)),
        }
    }

    /// Clones the current contents while holding the read lock only for the
    /// duration of the clone. Tree nodes are reference-counted, so the tree
    /// itself is shared rather than deep-copied.
    fn snapshot(&self) -> OrderedSet {
        self.inner.read().unwrap().clone()
    }

    /// Adds `value`; returns `false` when an equal value was already present.
    pub fn insert(&self, value: Value) -> bool {
        self.inner.write().unwrap().insert(value)
    }

    /// Removes `value`; returns `false` when it was not present.
    pub fn remove(&self, value: &Value) -> bool {
        self.inner.write().unwrap().remove(value)
    }

    /// `true` when an equal value is stored.
    pub fn contains(&self, value: &Value) -> bool {
        self.inner.read().unwrap().contains(value)
    }

    /// Smallest element, or `None` when the set is empty.
    pub fn min(&self) -> Option<Value> {
        self.inner.read().unwrap().min()
    }

    /// Largest element, or `None` when the set is empty.
    pub fn max(&self) -> Option<Value> {
        self.inner.read().unwrap().max()
    }

    /// Elements between `low` and `high`, both inclusive, in order.
    pub fn range(&self, low: &Value, high: &Value) -> Vec<Value> {
        self.inner.read().unwrap().range(low, high)
    }

    /// Number of elements.
    pub fn len(&self) -> usize {
        self.inner.read().unwrap().len()
    }

    /// `true` when the set has no elements.
    pub fn is_empty(&self) -> bool {
        self.inner.read().unwrap().is_empty()
    }

    /// Removes every element.
    pub fn clear(&self) {
        self.inner.write().unwrap().clear();
    }

    /// All elements in order.
    pub fn to_vec(&self) -> Vec<Value> {
        self.inner.read().unwrap().to_vec()
    }

    // The binary operations below work on snapshots taken one lock at a time.
    // Holding both read guards at once would re-enter the same lock whenever
    // `other` is this set (or a clone of this handle), which `RwLock::read`
    // documents as liable to panic or deadlock. It would also impose a lock
    // order between distinct sets. The two snapshots are therefore not taken
    // at the same instant.

    /// New independent set with the elements of either set.
    pub fn union(&self, other: &Self) -> Self {
        let mine = self.snapshot();
        let theirs = other.snapshot();
        Self::wrap(mine.union(&theirs))
    }

    /// New independent set with the elements present in both sets.
    pub fn intersection(&self, other: &Self) -> Self {
        let mine = self.snapshot();
        let theirs = other.snapshot();
        Self::wrap(mine.intersection(&theirs))
    }

    /// New independent set with the elements of `self` that are not in `other`.
    pub fn difference(&self, other: &Self) -> Self {
        let mine = self.snapshot();
        let theirs = other.snapshot();
        Self::wrap(mine.difference(&theirs))
    }

    /// `true` when every element of `self` is in `other`.
    pub fn is_subset(&self, other: &Self) -> bool {
        let mine = self.snapshot();
        let theirs = other.snapshot();
        mine.is_subset(&theirs)
    }

    /// Runs `f` with shared access to the set while holding the read lock.
    ///
    /// `f` must not call back into this handle, since it runs under the lock.
    pub fn with_read<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&OrderedSet) -> R,
    {
        let guard = self.inner.read().unwrap();
        f(&guard)
    }

    /// Runs `f` with exclusive access to the set while holding the write lock.
    ///
    /// `f` must not call back into this handle, since it runs under the lock.
    pub fn with_write<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut OrderedSet) -> R,
    {
        let mut guard = self.inner.write().unwrap();
        f(&mut guard)
    }
}
impl Default for ThreadSafeOrderedSet {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `OrderedSet` and `ThreadSafeOrderedSet`, driven through
// numeric `Value`s.
#[cfg(test)]
mod tests {
use super::*;
// Insert, duplicate insert, lookup, ordered listing and removal on a three-element set.
#[test]
fn test_basic_operations() {
let mut set = OrderedSet::new();
assert!(set.is_empty());
assert_eq!(set.len(), 0);
assert!(set.insert(Value::number(3.0)));
assert!(set.insert(Value::number(1.0)));
assert!(set.insert(Value::number(2.0)));
// A second insert of an equal value is rejected.
assert!(!set.insert(Value::number(2.0)));
assert_eq!(set.len(), 3);
assert!(set.contains(&Value::number(2.0)));
assert!(!set.contains(&Value::number(4.0)));
let values = set.to_vec();
assert_eq!(values, vec![
Value::number(1.0),
Value::number(2.0),
Value::number(3.0),
]);
assert!(set.remove(&Value::number(2.0)));
// Removing an absent value reports `false` and leaves the size unchanged.
assert!(!set.remove(&Value::number(4.0))); assert_eq!(set.len(), 2);
assert!(!set.contains(&Value::number(2.0)));
}
// `min`/`max` return `None` on an empty set and the extremes otherwise.
#[test]
fn test_min_max() {
let mut set = OrderedSet::new();
assert_eq!(set.min(), None);
assert_eq!(set.max(), None);
set.insert(Value::number(5.0));
set.insert(Value::number(1.0));
set.insert(Value::number(9.0));
assert_eq!(set.min(), Some(Value::number(1.0)));
assert_eq!(set.max(), Some(Value::number(9.0)));
}
// `range` includes both bounds.
#[test]
fn test_range() {
let mut set = OrderedSet::new();
for i in 1..=10 {
set.insert(Value::number(i as f64));
}
let range = set.range(&Value::number(3.0), &Value::number(7.0));
assert_eq!(range, vec![
Value::number(3.0),
Value::number(4.0),
Value::number(5.0),
Value::number(6.0),
Value::number(7.0),
]);
}
// Union, intersection, difference and symmetric difference of {1..5} and {3..7}.
#[test]
fn test_set_operations() {
let mut set1 = OrderedSet::new();
let mut set2 = OrderedSet::new();
for i in 1..=5 {
set1.insert(Value::number(i as f64));
}
for i in 3..=7 {
set2.insert(Value::number(i as f64));
}
let union = set1.union(&set2);
assert_eq!(union.len(), 7);
assert!(union.contains(&Value::number(1.0)));
assert!(union.contains(&Value::number(7.0)));
let intersection = set1.intersection(&set2);
assert_eq!(intersection.len(), 3);
assert!(intersection.contains(&Value::number(3.0)));
assert!(intersection.contains(&Value::number(4.0)));
assert!(intersection.contains(&Value::number(5.0)));
let difference = set1.difference(&set2);
assert_eq!(difference.len(), 2);
assert!(difference.contains(&Value::number(1.0)));
assert!(difference.contains(&Value::number(2.0)));
let sym_diff = set1.symmetric_difference(&set2);
assert_eq!(sym_diff.len(), 4);
assert!(sym_diff.contains(&Value::number(1.0)));
assert!(sym_diff.contains(&Value::number(2.0)));
assert!(sym_diff.contains(&Value::number(6.0)));
assert!(sym_diff.contains(&Value::number(7.0)));
}
// Subset, superset and disjointness relations between {1..3}, {1..5} and {6..8}.
#[test]
fn test_subset_operations() {
let mut set1 = OrderedSet::new();
let mut set2 = OrderedSet::new();
let mut set3 = OrderedSet::new();
for i in 1..=3 {
set1.insert(Value::number(i as f64));
}
for i in 1..=5 {
set2.insert(Value::number(i as f64));
}
for i in 6..=8 {
set3.insert(Value::number(i as f64));
}
assert!(set1.is_subset(&set2));
assert!(set2.is_superset(&set1));
assert!(!set1.is_subset(&set3));
assert!(set1.is_disjoint(&set3));
assert!(!set1.is_disjoint(&set2));
}
// `filter`, `map`, `fold` and `partition` over {1..5}.
#[test]
fn test_functional_operations() {
let mut set = OrderedSet::new();
for i in 1..=5 {
set.insert(Value::number(i as f64));
}
let evens = set.filter(|v| {
if let Some(n) = v.as_number() {
n as i64 % 2 == 0
} else {
false
}
});
assert_eq!(evens.len(), 2);
assert!(evens.contains(&Value::number(2.0)));
assert!(evens.contains(&Value::number(4.0)));
let doubled = set.map(|v| {
if let Some(n) = v.as_number() {
Value::number(n * 2.0)
} else {
v.clone()
}
});
assert_eq!(doubled.len(), 5);
assert!(doubled.contains(&Value::number(2.0)));
assert!(doubled.contains(&Value::number(10.0)));
let sum = set.fold(0.0, |acc, v| {
acc + v.as_number().unwrap_or(0.0)
});
assert_eq!(sum, 15.0);
let (odds, evens) = set.partition(|v| {
if let Some(n) = v.as_number() {
n as i64 % 2 == 1
} else {
false
}
});
assert_eq!(odds.len(), 3);
assert_eq!(evens.len(), 2);
}
// 1000 distinct values inserted in a scrambled order, then half of them removed.
#[test]
fn test_large_set() {
let mut set = OrderedSet::new();
// 17 and 1000 share no factor, so (i * 17) % 1000 hits each of 0..1000 exactly once.
for i in 0..1000 {
assert!(set.insert(Value::number((i * 17) as f64 % 1000.0)));
}
assert_eq!(set.len(), 1000);
let values = set.to_vec();
// The listing must be strictly increasing.
for i in 1..values.len() {
let prev = values[i - 1].as_number().unwrap();
let curr = values[i].as_number().unwrap();
assert!(prev < curr);
}
for i in 0..500 {
set.remove(&Value::number(i as f64));
}
assert_eq!(set.len(), 500);
}
// Same basic scenario as above, driven through the lock-guarded wrapper.
#[test]
fn test_thread_safe_ordered_set() {
let set = ThreadSafeOrderedSet::new();
assert!(set.insert(Value::number(1.0)));
assert!(set.insert(Value::number(3.0)));
assert!(set.insert(Value::number(2.0)));
assert_eq!(set.len(), 3);
assert!(set.contains(&Value::number(2.0)));
let values = set.to_vec();
assert_eq!(values, vec![
Value::number(1.0),
Value::number(2.0),
Value::number(3.0),
]);
assert!(set.remove(&Value::number(2.0)));
assert_eq!(set.len(), 2);
}
// A comparator with its arguments swapped must yield a descending listing.
#[test]
fn test_custom_comparator() {
let comparator = Comparator::new(
"reverse-numeric",
|a, b| {
match (a.as_number(), b.as_number()) {
(Some(n1), Some(n2)) => n2.partial_cmp(&n1).unwrap_or(std::cmp::Ordering::Equal),
_ => std::cmp::Ordering::Equal,
}
},
|v| v.as_number().map(|n| n.to_bits()).unwrap_or(0),
);
let mut set = OrderedSet::with_comparator(comparator);
set.insert(Value::number(1.0));
set.insert(Value::number(3.0));
set.insert(Value::number(2.0));
let values = set.to_vec();
assert_eq!(values, vec![
Value::number(3.0),
Value::number(2.0),
Value::number(1.0),
]);
}
}