use std::{borrow::Borrow, collections::HashMap, hash::Hash, hash::Hasher};
/// A non-owning handle to a key that physically lives inside a cache `Entry`.
///
/// The index map is keyed by `KeyRef` so that every key is stored exactly once
/// (in its node). Hashing and equality are forwarded to the pointed-to `K`, so
/// the pointer must remain valid for as long as the `KeyRef` sits in the map.
#[derive(Debug, Clone, Copy)]
struct KeyRef<K> {
    k: *const K,
}

impl<K: Hash> Hash for KeyRef<K> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // SAFETY: a `KeyRef` is only created for a key inside a live entry and
        // is removed from the map before that entry is freed.
        let target: &K = unsafe { &*self.k };
        target.hash(state);
    }
}

impl<K: PartialEq> PartialEq for KeyRef<K> {
    fn eq(&self, other: &Self) -> bool {
        // SAFETY: both handles point at keys of live entries (see `Hash` above).
        let (lhs, rhs): (&K, &K) = unsafe { (&*self.k, &*other.k) };
        lhs == rhs
    }
}

impl<K: Eq> Eq for KeyRef<K> {}
/// Transparent newtype over a (possibly unsized) borrowed key form.
///
/// It lets `KeyRef<K>` implement `Borrow<KeyWrapper<Q>>` for every `Q` that `K`
/// borrows as, so the map can be queried with `&Q` without building an owned
/// `K`. `#[repr(transparent)]` guarantees the same layout as the wrapped type,
/// which `from_ref` relies on.
#[repr(transparent)]
struct KeyWrapper<K: ?Sized>(K);

impl<K: ?Sized> KeyWrapper<K> {
    /// Reinterprets `&K` as `&KeyWrapper<K>` without copying.
    fn from_ref(key: &K) -> &Self {
        let wrapped = key as *const K as *const KeyWrapper<K>;
        // SAFETY: `KeyWrapper<K>` is `#[repr(transparent)]` over `K`, so the
        // pointer (including any unsized metadata) is valid for `key`'s lifetime.
        unsafe { &*wrapped }
    }
}

impl<K: ?Sized + Hash> Hash for KeyWrapper<K> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.0, state)
    }
}

impl<K: ?Sized + PartialEq> PartialEq for KeyWrapper<K> {
    // Older clippy versions do not know `unconditional_recursion`; newer ones
    // falsely report this delegation to the inner type as recursive.
    #![allow(unknown_lints)]
    #[allow(clippy::unconditional_recursion)]
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.0, &other.0)
    }
}

impl<K: ?Sized + Eq> Eq for KeyWrapper<K> {}
/// Views a stored key as `KeyWrapper<Q>` so the map can be looked up by `&Q`.
///
/// This satisfies the `Borrow` contract: `KeyRef<K>` hashes/compares like `K`,
/// `KeyWrapper<Q>` like `Q`, and `K: Borrow<Q>` already requires `K` and `Q` to
/// hash and compare consistently.
impl<K, Q> Borrow<KeyWrapper<Q>> for KeyRef<K>
where
    K: Borrow<Q>,
    Q: ?Sized,
{
    fn borrow(&self) -> &KeyWrapper<Q> {
        // SAFETY: `self.k` points at the key of a live cache entry.
        let owned: &K = unsafe { &*self.k };
        let borrowed: &Q = owned.borrow();
        KeyWrapper::from_ref(borrowed)
    }
}
/// A node of the cache's doubly linked recency list.
///
/// `prev` points toward the more recently used neighbour (the head side) and
/// `next` toward the less recently used one (the tail side); `None` marks the
/// end of the list.
struct Entry<K, V> {
    key: K,
    value: V,
    prev: Option<*mut Entry<K, V>>,
    next: Option<*mut Entry<K, V>>,
}

impl<K, V> Entry<K, V> {
    /// Builds an unlinked node holding `key` and `value`.
    fn new(key: K, value: V) -> Self {
        let (prev, next) = (None, None);
        Entry {
            key,
            value,
            prev,
            next,
        }
    }
}
pub struct LruCache<K, V> {
capacity: usize,
map: HashMap<KeyRef<K>, Box<Entry<K, V>>>,
head: Option<*mut Entry<K, V>>,
tail: Option<*mut Entry<K, V>>,
_marker: std::marker::PhantomData<K>,
}
impl<K: Hash + Eq, V> LruCache<K, V> {
pub fn new(capacity: usize) -> Self {
Self {
capacity,
map: HashMap::new(),
head: None,
tail: None,
_marker: std::marker::PhantomData,
}
}
pub fn get<'a, Q>(&'a mut self, k: &Q) -> Option<&'a V>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
if let Some(entry) = self.map.get_mut(KeyWrapper::from_ref(k)) {
let entry_ptr: *mut Entry<K, V> = &mut **entry;
self.detach(entry_ptr);
self.attach(entry_ptr);
Some(&unsafe { &*entry_ptr }.value)
} else {
None
}
}
pub fn put(&mut self, key: K, mut value: V) -> Option<V> {
if let Some(existing_entry) = self.map.get_mut(KeyWrapper::from_ref(&key)) {
let entry = existing_entry.as_mut();
std::mem::swap(&mut entry.value, &mut value);
let entry_ptr: *mut Entry<K, V> = entry;
self.detach(entry_ptr);
self.attach(entry_ptr);
return Some(value);
}
let mut evicted_value = None;
if self.map.len() >= self.capacity {
if let Some(tail) = self.tail {
self.detach(tail);
unsafe {
if let Some(entry) = self.map.remove(KeyWrapper::from_ref(&(*tail).key)) {
evicted_value = Some(entry.value);
}
}
}
}
let new_entry = Box::new(Entry::new(key, value));
let key_ptr: *const K = &new_entry.key;
let entry_ptr = Box::into_raw(new_entry);
unsafe {
self.attach(entry_ptr);
self.map
.insert(KeyRef { k: key_ptr }, Box::from_raw(entry_ptr));
}
evicted_value
}
fn detach(&mut self, entry: *mut Entry<K, V>) {
unsafe {
let prev = (*entry).prev;
let next = (*entry).next;
match prev {
Some(prev) => (*prev).next = next,
None => self.head = next,
}
match next {
Some(next) => (*next).prev = prev,
None => self.tail = prev,
}
(*entry).prev = None;
(*entry).next = None;
}
}
fn attach(&mut self, entry: *mut Entry<K, V>) {
match self.head {
Some(head) => {
unsafe {
(*entry).next = Some(head);
(*head).prev = Some(entry);
}
self.head = Some(entry);
}
None => {
self.head = Some(entry);
self.tail = Some(entry);
}
}
}
pub fn contains<Q>(&self, k: &Q) -> bool
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.map.contains_key(KeyWrapper::from_ref(k))
}
pub fn peek<'a, Q>(&'a self, k: &Q) -> Option<&'a V>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.map
.get(KeyWrapper::from_ref(k))
.map(|entry| &entry.value)
}
pub fn pop_lru(&mut self) -> Option<(K, V)> {
if self.is_empty() {
return None;
}
let tail = self.tail?;
self.detach(tail);
unsafe {
self.map
.remove(KeyWrapper::from_ref(&(*tail).key))
.map(|entry| (entry.key, entry.value))
}
}
pub fn pop<Q>(&mut self, k: &Q) -> Option<(K, V)>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
match self.map.remove(KeyWrapper::from_ref(k)) {
None => None,
Some(entry) => {
let entry_ptr = Box::into_raw(entry);
self.detach(entry_ptr);
unsafe {
let entry = Box::from_raw(entry_ptr);
Some((entry.key, entry.value))
}
}
}
}
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
}
unsafe impl<K: Send, V: Send> Send for LruCache<K, V> {}
unsafe impl<K: Sync, V: Sync> Sync for LruCache<K, V> {}
impl<K, V> Drop for LruCache<K, V> {
fn drop(&mut self) {
self.map.clear();
self.head = None;
self.tail = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::rc::Rc;

    #[test]
    fn test_new() {
        let test_cases = vec![
            (5, 5),
            (1, 1),
            (0, 0),
            (usize::MAX, usize::MAX),
        ];
        for (capacity, expected_capacity) in test_cases {
            let cache: LruCache<String, i32> = LruCache::new(capacity);
            assert!(cache.is_empty());
            assert_eq!(cache.capacity, expected_capacity);
        }
    }

    #[test]
    fn test_get() {
        let mut cache: LruCache<String, i32> = LruCache::new(3);
        // (key, value, expected return of `put`)
        let test_cases = vec![
            ("key1", 1, None),
            ("key2", 2, None),
            ("key3", 3, None),
            ("key2", 22, Some(2)),
            ("key4", 4, Some(1)),
        ];
        for (key, value, expected_result) in test_cases {
            let result = cache.put(key.to_string(), value);
            assert_eq!(result, expected_result);
        }
        assert_eq!(cache.get(&"key1".to_string()), None);
        assert_eq!(cache.get(&"key2".to_string()).copied(), Some(22));
        assert_eq!(cache.get(&"key3".to_string()).copied(), Some(3));
        assert_eq!(cache.get(&"key4".to_string()).copied(), Some(4));
    }

    #[test]
    fn test_get_after_eviction() {
        let mut cache = LruCache::new(3);
        assert_eq!(cache.get(&"nonexistent".to_string()), None);
        for (key, value) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            cache.put(key.to_string(), value);
        }
        let test_cases = vec![
            ("key1", Some(1)),
            ("nonexistent", None),
            ("key1", Some(1)),
            ("key3", Some(3)),
        ];
        for (key, expected_value) in test_cases {
            assert_eq!(cache.get(&key.to_string()).copied(), expected_value);
        }
        // "key2" is now least recently used, so it is evicted.
        cache.put("key4".to_string(), 4);
        assert_eq!(cache.get(&"key1".to_string()).copied(), Some(1));
        assert_eq!(cache.get(&"key2".to_string()), None);
        assert_eq!(cache.get(&"key3".to_string()).copied(), Some(3));
        assert_eq!(cache.get(&"key4".to_string()).copied(), Some(4));
    }

    #[test]
    fn test_put() {
        let mut cache = LruCache::new(3);
        let test_cases = vec![
            ("key1", 1, None),
            ("key2", 2, None),
            ("key3", 3, None),
            ("key4", 4, Some(1)),
            ("key5", 5, Some(2)),
            ("key4", 44, Some(4)),
        ];
        for (key, value, expected_result) in test_cases {
            let result = cache.put(key.to_string(), value);
            assert_eq!(result, expected_result);
        }
        assert_eq!(cache.get(&"key1".to_string()), None);
        assert_eq!(cache.get(&"key2".to_string()), None);
        assert_eq!(cache.get(&"key3".to_string()).copied(), Some(3));
        assert_eq!(cache.get(&"key4".to_string()).copied(), Some(44));
        assert_eq!(cache.get(&"key5".to_string()).copied(), Some(5));
    }

    #[test]
    fn test_peek() {
        let mut cache: LruCache<String, i32> = LruCache::new(3);
        assert_eq!(cache.peek(&"nonexistent".to_string()), None);
        for (key, value) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            cache.put(key.to_string(), value);
        }
        let test_cases = vec![
            ("nonexistent", None),
            ("key1", Some(1)),
            ("key2", Some(2)),
            ("key3", Some(3)),
        ];
        for (key, expected_value) in test_cases {
            assert_eq!(cache.peek(&key.to_string()).copied(), expected_value);
        }
        cache.put("key4".to_string(), 4);
        assert_eq!(cache.peek(&"key1".to_string()), None);
        assert_eq!(cache.peek(&"key2".to_string()).copied(), Some(2));
        assert_eq!(cache.peek(&"key3".to_string()).copied(), Some(3));
        assert_eq!(cache.peek(&"key4".to_string()).copied(), Some(4));
    }

    #[test]
    fn test_peek_does_not_promote() {
        let mut cache: LruCache<String, i32> = LruCache::new(2);
        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);
        assert_eq!(cache.peek("a").copied(), Some(1));
        // "a" is still least recently used, so it is the one evicted.
        assert_eq!(cache.put("c".to_string(), 3), Some(1));
        assert!(!cache.contains("a"));
        assert!(cache.contains("b"));
    }

    #[test]
    fn test_put_existing_promotes() {
        let mut cache: LruCache<String, i32> = LruCache::new(2);
        cache.put("a".to_string(), 1);
        cache.put("b".to_string(), 2);
        assert_eq!(cache.put("a".to_string(), 10), Some(1));
        // Updating "a" made it most recently used, so "b" is evicted.
        assert_eq!(cache.put("c".to_string(), 3), Some(2));
        assert_eq!(cache.peek("a").copied(), Some(10));
        assert!(!cache.contains("b"));
    }

    #[test]
    fn test_get_with_borrowed_str() {
        let mut cache: LruCache<String, i32> = LruCache::new(2);
        cache.put("key1".to_string(), 1);
        assert_eq!(cache.get("key1").copied(), Some(1));
        assert_eq!(cache.peek("key1").copied(), Some(1));
        assert!(cache.contains("key1"));
        assert_eq!(cache.pop("key1"), Some(("key1".to_string(), 1)));
        assert!(!cache.contains("key1"));
    }

    #[test]
    fn test_contains() {
        let mut cache: LruCache<String, i32> = LruCache::new(3);
        assert!(!cache.contains(&"nonexistent".to_string()));
        for (key, value) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            cache.put(key.to_string(), value);
        }
        let test_cases = vec![
            ("nonexistent", false),
            ("key1", true),
            ("key2", true),
            ("key3", true),
        ];
        for (key, expected_result) in test_cases {
            assert_eq!(cache.contains(&key.to_string()), expected_result);
        }
        cache.put("key4".to_string(), 4);
        assert!(!cache.contains(&"key1".to_string()));
        assert!(cache.contains(&"key2".to_string()));
        assert!(cache.contains(&"key3".to_string()));
        assert!(cache.contains(&"key4".to_string()));
    }

    #[test]
    fn test_pop_lru() {
        let mut cache: LruCache<String, i32> = LruCache::new(3);
        assert_eq!(cache.pop_lru(), None);
        for (key, value) in [("key1", 1), ("key2", 2), ("key3", 3)] {
            cache.put(key.to_string(), value);
        }
        assert_eq!(cache.pop_lru(), Some(("key1".to_string(), 1)));
        assert_eq!(cache.pop_lru(), Some(("key2".to_string(), 2)));
        assert_eq!(cache.pop_lru(), Some(("key3".to_string(), 3)));
        assert_eq!(cache.pop_lru(), None);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_pop() {
        let mut cache: LruCache<String, i32> = LruCache::new(3);
        let test_cases = vec![
            ("key1".to_string(), Some(("key1".to_string(), 1))),
            ("key2".to_string(), Some(("key2".to_string(), 2))),
            ("key3".to_string(), Some(("key3".to_string(), 3))),
            ("key1".to_string(), None),
            ("key2".to_string(), None),
            ("key3".to_string(), None),
        ];
        cache.put("key1".to_string(), 1);
        cache.put("key2".to_string(), 2);
        cache.put("key3".to_string(), 3);
        for (key, expected) in test_cases {
            assert_eq!(cache.pop(&key), expected);
        }
        assert!(cache.is_empty());
    }

    #[test]
    fn test_pop_middle_preserves_order() {
        let mut cache: LruCache<i32, i32> = LruCache::new(3);
        for i in 1..=3 {
            cache.put(i, i * 10);
        }
        assert_eq!(cache.pop(&2), Some((2, 20)));
        assert_eq!(cache.pop_lru(), Some((1, 10)));
        assert_eq!(cache.pop_lru(), Some((3, 30)));
        assert_eq!(cache.pop_lru(), None);
    }

    #[test]
    fn test_drop_releases_values() {
        let tracker = Rc::new(());
        {
            let mut cache = LruCache::new(3);
            for i in 0..5 {
                cache.put(i, Rc::clone(&tracker));
            }
            // Three cached clones plus our own handle; evicted ones are gone.
            assert_eq!(Rc::strong_count(&tracker), 4);
        }
        assert_eq!(Rc::strong_count(&tracker), 1);
    }
}