use dashmap::DashMap;
use parking_lot::RwLock as PLRwLock;
use rand::Rng;
use std::collections::HashMap;
use std::collections::hash_map::RandomState;
use std::fmt::Debug;
use std::hash::{BuildHasher, Hash, Hasher};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
/// A hash map partitioned into independently synchronized segments.
///
/// Each key is assigned to exactly one segment by hashing it with
/// `hasher_builder` and reducing the hash modulo `concurrency_level`.
#[derive(Debug)]
pub struct ConcurrentHashMap<K, V, S = RandomState>
where
    K: Clone + Eq + Hash + Send + Sync + 'static,
    V: Clone + Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// The segments that hold the entries; indexed by the computed segment index.
    segments: Vec<Arc<Segment<K, V, S>>>,
    /// Builds the hashers used to pick a segment for a key.
    hasher_builder: S,
    /// Number of segments the key space is divided into.
    concurrency_level: usize,
}
/// One partition of the map: the entry storage plus its operation statistics.
#[derive(Debug)]
struct Segment<K, V, S>
where
    K: Eq + Hash + Send + Sync + 'static,
    V: Send + Sync + 'static,
    S: BuildHasher + Clone + Send + Sync + 'static,
{
    /// Key/value storage for the entries routed to this segment.
    data: DashMap<K, V, S>,
    /// Statistics for this segment, guarded by a read/write lock.
    stats: PLRwLock<SegmentStats>,
}
/// Statistics recorded for a single segment of a `ConcurrentHashMap`.
///
/// Snapshots of this type are returned by `ConcurrentHashMap::get_stats`, so
/// the type and its fields are public: a private type in that public
/// signature could neither be named nor inspected by callers.
#[derive(Debug, Default, Clone)]
pub struct SegmentStats {
    /// Number of `insert`, `get` and `remove` calls routed to the segment.
    pub operations: u64,
    /// NOTE(review): no code visible in this file updates this counter;
    /// presumably reserved for lock-contention tracking — confirm.
    pub contention_count: u64,
    /// NOTE(review): no code visible in this file sets this timestamp;
    /// presumably reserved for resize tracking — confirm.
    pub last_resize: Option<Instant>,
    /// Timing estimate for operations on the segment, maintained by the
    /// map's `insert`, `get` and `remove`.
    pub avg_operation_time: Duration,
}
impl<K, V, S> ConcurrentHashMap<K, V, S>
where
K: Hash + Eq + Clone + Send + Sync + 'static,
V: Clone + Send + Sync + 'static,
S: BuildHasher + Clone + Send + Sync + Default + 'static,
{
pub fn new() -> Self {
Self::with_capacity_and_hasher(16, S::default())
}
pub fn with_capacity_and_hasher(capacity: usize, hasher_builder: S) -> Self {
let concurrency_level = num_cpus::get().max(16);
let segment_capacity = (capacity + concurrency_level - 1) / concurrency_level;
let segments = (0..concurrency_level)
.map(|_| {
Arc::new(Segment {
data: DashMap::with_capacity_and_hasher(
segment_capacity,
hasher_builder.clone(),
),
stats: PLRwLock::new(SegmentStats::default()),
})
})
.collect();
Self {
segments,
hasher_builder,
concurrency_level,
}
}
pub fn insert(&self, key: K, value: V) -> Option<V> {
let start = Instant::now();
let segment_index = self.get_segment_index(&key);
let segment = &self.segments[segment_index];
let result = segment.data.insert(key, value);
let mut stats = segment.stats.write();
stats.operations += 1;
stats.avg_operation_time = (stats.avg_operation_time + start.elapsed()) / 2;
result
}
pub fn get(&self, key: &K) -> Option<V> {
let start = Instant::now();
let segment_index = self.get_segment_index(key);
let segment = &self.segments[segment_index];
let result = segment.data.get(key).map(|v| v.clone());
let mut stats = segment.stats.write();
stats.operations += 1;
stats.avg_operation_time = (stats.avg_operation_time + start.elapsed()) / 2;
result
}
pub fn remove(&self, key: &K) -> Option<V> {
let start = Instant::now();
let segment_index = self.get_segment_index(key);
let segment = &self.segments[segment_index];
let result = segment.data.remove(key).map(|(_, v)| v);
let mut stats = segment.stats.write();
stats.operations += 1;
stats.avg_operation_time = (stats.avg_operation_time + start.elapsed()) / 2;
result
}
pub fn get_stats(&self) -> HashMap<usize, SegmentStats> {
self.segments
.iter()
.enumerate()
.map(|(i, segment)| (i, segment.stats.read().clone()))
.collect::<HashMap<_, _>>()
}
fn get_segment_index(&self, key: &K) -> usize {
let mut hasher = self.hasher_builder.build_hasher();
key.hash(&mut hasher);
(hasher.finish() as usize) % self.concurrency_level
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    /// Single-threaded check of insert, lookup, overwrite and removal.
    #[test]
    fn test_basic_operations() {
        let map = ConcurrentHashMap::<i32, String>::new();

        // Fresh keys report no previous value.
        assert_eq!(map.insert(1, String::from("one")), None);
        assert_eq!(map.insert(2, String::from("two")), None);

        assert_eq!(map.get(&1), Some(String::from("one")));
        assert_eq!(map.get(&2), Some(String::from("two")));
        assert_eq!(map.get(&3), None);

        // Overwriting hands back the old value and stores the new one.
        assert_eq!(map.insert(1, String::from("ONE")), Some(String::from("one")));
        assert_eq!(map.get(&1), Some(String::from("ONE")));

        // Removal returns the value and leaves the key absent.
        assert_eq!(map.remove(&1), Some(String::from("ONE")));
        assert_eq!(map.get(&1), None);
    }

    /// Writers fill disjoint key sets while readers probe random keys; every
    /// written key must be present once all threads have finished.
    #[test]
    fn test_concurrent_operations() {
        let map = Arc::new(ConcurrentHashMap::<i32, i32>::new());
        let threads: i32 = 8;
        let operations: i32 = 5000;

        // Writer `t` owns the keys congruent to `t` modulo `threads`.
        let mut handles: Vec<_> = (0..threads)
            .map(|t| {
                let map = Arc::clone(&map);
                thread::spawn(move || {
                    for i in 0..operations {
                        let key = i * threads + t;
                        map.insert(key, key);
                    }
                })
            })
            .collect();

        handles.extend((0..threads).map(|_| {
            let map = Arc::clone(&map);
            thread::spawn(move || {
                let mut rng = rand::thread_rng();
                for _ in 0..operations {
                    let key: i32 = rng.gen_range(0..(operations * threads));
                    let _ = map.get(&key);
                }
            })
        }));

        for handle in handles {
            handle.join().unwrap();
        }

        let written_keys =
            (0..threads).flat_map(|t| (0..operations).map(move |i| i * threads + t));
        for key in written_keys {
            assert_eq!(map.get(&key), Some(key));
        }
    }

    /// Same workload as `test_concurrent_operations`, additionally printing
    /// per-thread timings and overall throughput.
    #[test]
    fn test_concurrent_operations_with_metrics() {
        let map = Arc::new(ConcurrentHashMap::<i32, i32>::new());
        let threads: i32 = 8;
        let operations: i32 = 5000;
        let start = Instant::now();

        // Each worker reports its own wall-clock duration through `join`.
        let writers: Vec<_> = (0..threads)
            .map(|t| {
                let map = Arc::clone(&map);
                thread::spawn(move || {
                    let thread_start = Instant::now();
                    for i in 0..operations {
                        let key = i * threads + t;
                        map.insert(key, key);
                    }
                    thread_start.elapsed()
                })
            })
            .collect();

        let readers: Vec<_> = (0..threads)
            .map(|_| {
                let map = Arc::clone(&map);
                thread::spawn(move || {
                    let thread_start = Instant::now();
                    let mut rng = rand::thread_rng();
                    for _ in 0..operations {
                        let key: i32 = rng.gen_range(0..(operations * threads));
                        let _ = map.get(&key);
                    }
                    thread_start.elapsed()
                })
            })
            .collect();

        let write_times: Vec<Duration> = writers
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();
        let read_times: Vec<Duration> = readers
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();
        let total_time = start.elapsed();

        println!("\nPerformance Metrics:");
        println!("Total time: {:?}", total_time);
        println!(
            "Operations per second: {:.2}",
            f64::from(threads * operations * 2) / total_time.as_secs_f64()
        );

        println!("\nWrite Thread Statistics:");
        println!(
            "Average write thread time: {:?}",
            write_times.iter().sum::<Duration>() / write_times.len() as u32
        );

        println!("\nRead Thread Statistics:");
        println!(
            "Average read thread time: {:?}",
            read_times.iter().sum::<Duration>() / read_times.len() as u32
        );

        // Count every written key that is missing or holds the wrong value.
        let errors = (0..threads)
            .flat_map(|t| (0..operations).map(move |i| i * threads + t))
            .filter(|key| map.get(key) != Some(*key))
            .count();
        assert_eq!(errors, 0, "发现 {} 个数据不一致", errors);
    }
}