use std::cell::UnsafeCell;
use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem;
use std::ptr::NonNull;
#[doc(hidden)]
pub(crate) mod private {
    use std::fmt::Debug;
    use std::hash::Hash;

    /// Common interface for the cache implementations in this crate.
    ///
    /// Every method takes `&self`, so implementors provide their own interior
    /// mutability; the `Send + Sync` supertraits require that it be shareable
    /// across threads.
    // Not every method is necessarily invoked through the trait inside the
    // crate; silence `dead_code` for the ones that are not.
    #[allow(dead_code)]
    pub(crate) trait Cache<K, V>: Send + Sync
    where
        K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
        V: Clone + Debug + Send + Sync + 'static,
    {
        /// Number of entries currently stored.
        fn len(&self) -> usize;

        /// Returns `true` when the cache holds no entries (`len() == 0`).
        fn is_empty(&self) -> bool {
            self.len() == 0
        }

        /// Looks up `key`, returning a copy of its value if present.
        fn get(&self, key: &K) -> Option<V>;

        /// Stores `value` under `key`, returning the value it replaced, if any.
        fn put(&self, key: K, value: V) -> Option<V>;

        /// Deletes `key`, returning its value if it was present.
        fn remove(&self, key: &K) -> Option<V>;

        /// Deletes every entry.
        fn clear(&self);
    }
}
use private::Cache;
/// One cache entry, linked into the cache's doubly linked list.
///
/// The cache allocates nodes with `Box` and manipulates them via raw pointers.
struct Node<K, V> {
    /// The entry's key (a copy of the key stored in the lookup map).
    key: K,
    /// The cached value.
    value: V,
    /// Neighbour towards the list head; null when first or unlinked.
    prev: *mut Node<K, V>,
    /// Neighbour towards the list tail; null when last or unlinked.
    next: *mut Node<K, V>,
}

impl<K, V> Node<K, V> {
    /// Builds an unlinked node: both neighbour pointers start out null.
    fn new(key: K, value: V) -> Self {
        let unlinked: *mut Self = std::ptr::null_mut();
        Node {
            key,
            value,
            prev: unlinked,
            next: unlinked,
        }
    }
}
/// Intrusive doubly linked list over raw [`Node`] pointers.
///
/// The list never allocates or frees nodes: callers allocate them (the cache
/// uses `Box::into_raw`), link them in, and free them after unlinking.
/// `push_front` links at `head`; `pop_back` unlinks from `tail`.
///
/// Invariant (upheld through the `unsafe fn` contracts below): every node
/// reachable from `head` via `next` is a live allocation, its `prev` is its
/// predecessor (null for the head), `tail` is the last such node (null when the
/// list is empty), and `len` is the number of linked nodes.
struct DoublyLinkedList<K, V> {
    /// First node, or null when the list is empty.
    head: *mut Node<K, V>,
    /// Last node, or null when the list is empty.
    tail: *mut Node<K, V>,
    /// Number of linked nodes.
    len: usize,
}

impl<K, V> DoublyLinkedList<K, V> {
    /// Creates an empty list.
    fn new() -> Self {
        Self {
            head: std::ptr::null_mut(),
            tail: std::ptr::null_mut(),
            len: 0,
        }
    }

    /// Links `node` in as the new head.
    ///
    /// # Safety
    ///
    /// `node` must point to a live `Node` that is not currently linked into any
    /// list, and it must stay alive until it has been unlinked again.
    unsafe fn push_front(&mut self, node: *mut Node<K, V>) {
        // SAFETY: the caller guarantees `node` is live and detached; `self.head`
        // is null or a live linked node by the list invariant.
        unsafe {
            (*node).prev = std::ptr::null_mut();
            (*node).next = self.head;
            if self.head.is_null() {
                self.tail = node;
            } else {
                (*self.head).prev = node;
            }
        }
        self.head = node;
        self.len += 1;
    }

    /// Unlinks `node` and clears its `prev`/`next`, so a detached node never
    /// carries stale pointers into the list.
    ///
    /// # Safety
    ///
    /// `node` must be a live node that is currently linked into *this* list.
    unsafe fn remove(&mut self, node: *mut Node<K, V>) {
        // SAFETY: the caller guarantees `node` is live and linked here, so its
        // neighbours are either null or live nodes of this list.
        unsafe {
            let prev = (*node).prev;
            let next = (*node).next;
            if prev.is_null() {
                self.head = next;
            } else {
                (*prev).next = next;
            }
            if next.is_null() {
                self.tail = prev;
            } else {
                (*next).prev = prev;
            }
            (*node).prev = std::ptr::null_mut();
            (*node).next = std::ptr::null_mut();
        }
        self.len -= 1;
    }

    /// Unlinks and returns the tail node, or `None` when the list is empty.
    ///
    /// The returned node is detached; freeing it is the caller's job.
    fn pop_back(&mut self) -> Option<*mut Node<K, V>> {
        let tail = self.tail;
        if tail.is_null() {
            return None;
        }
        // SAFETY: a non-null `tail` is a live node linked into this list
        // (list invariant).
        unsafe { self.remove(tail) };
        Some(tail)
    }

    /// Number of linked nodes.
    fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when no nodes are linked.
    // Kept for API completeness; not every build calls it.
    #[allow(dead_code)]
    fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Moves an already-linked `node` to the head of the list.
    ///
    /// # Safety
    ///
    /// `node` must be a live node that is currently linked into *this* list.
    unsafe fn reinsert_front(&mut self, node: *mut Node<K, V>) {
        if self.head == node {
            // Already at the front; unlinking and relinking would be a no-op.
            return;
        }
        // SAFETY: forwarded from the caller — `node` is live and linked here, so
        // it may be unlinked, after which it is detached and may be pushed.
        unsafe {
            self.remove(node);
            self.push_front(node);
        }
    }
}
pub struct BasicLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
cap: usize,
list: UnsafeCell<DoublyLinkedList<K, V>>,
map: UnsafeCell<HashMap<K, NonNull<Node<K, V>>>>,
}
unsafe impl<K, V> Send for BasicLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
}
unsafe impl<K, V> Sync for BasicLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
}
impl<K, V> BasicLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0, "Capacity must be positive");
Self {
cap: capacity,
list: UnsafeCell::new(DoublyLinkedList::new()),
map: UnsafeCell::new(HashMap::with_capacity(capacity)),
}
}
pub fn capacity(&self) -> usize {
self.cap
}
pub fn get(&self, key: &K) -> Option<V> {
unsafe {
let map = &mut *self.map.get();
if let Some(entry) = map.get(key) {
let node_ptr = entry.as_ptr();
(*self.list.get()).reinsert_front(node_ptr);
Some((*node_ptr).value.clone())
} else {
None
}
}
}
pub fn put(&self, key: K, value: V) -> Option<V> {
unsafe {
let map = &mut *self.map.get();
if let Some(entry) = map.get(&key) {
let node_ptr = entry.as_ptr();
let old_value = mem::replace(&mut (*node_ptr).value, value);
(*self.list.get()).reinsert_front(node_ptr);
Some(old_value)
} else {
let new_node = Box::new(Node::new(key.clone(), value));
let node_ptr = Box::into_raw(new_node);
let list = &mut *self.list.get();
while list.len() >= self.cap {
if let Some(last_node) = list.pop_back() {
map.remove(&(*last_node).key);
drop(Box::from_raw(last_node));
}
}
list.push_front(node_ptr);
map.insert(key, NonNull::new_unchecked(node_ptr));
None
}
}
}
pub fn remove(&self, key: &K) -> Option<V> {
unsafe {
let map = &mut *self.map.get();
if let Some(entry) = map.remove(key) {
let node_ptr = entry.as_ptr();
let list = &mut *self.list.get();
list.remove(node_ptr);
let value = (*node_ptr).value.clone();
drop(Box::from_raw(node_ptr));
Some(value)
} else {
None
}
}
}
pub fn len(&self) -> usize {
unsafe { (*self.list.get()).len }
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn clear(&self) {
unsafe {
let map = &mut *self.map.get();
let list = &mut *self.list.get();
let mut current = list.head;
while !current.is_null() {
let next = (*current).next;
drop(Box::from_raw(current));
current = next;
}
list.head = std::ptr::null_mut();
list.tail = std::ptr::null_mut();
list.len = 0;
map.clear();
}
}
}
impl<K, V> Cache<K, V> for BasicLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
fn get(&self, key: &K) -> Option<V> {
self.get(key)
}
fn put(&self, key: K, value: V) -> Option<V> {
self.put(key, value)
}
fn remove(&self, key: &K) -> Option<V> {
self.remove(key)
}
fn len(&self) -> usize {
self.len()
}
fn is_empty(&self) -> bool {
self.is_empty()
}
fn clear(&self) {
self.clear()
}
}
impl<K, V> Drop for BasicLruCache<K, V>
where
K: Clone + Debug + Hash + Eq + Send + Sync + 'static,
V: Clone + Debug + Send + Sync + 'static,
{
fn drop(&mut self) {
unsafe {
let list = &mut *self.list.get();
let mut current = list.head;
while !current.is_null() {
let next = (*current).next;
drop(Box::from_raw(current));
current = next;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building owned `String` keys and values.
    fn s(text: &str) -> String {
        text.to_string()
    }

    #[test]
    fn test_basic_operations() {
        let cache = BasicLruCache::new(2);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.put(s("key1"), s("one")), None);
        assert_eq!(cache.put(s("key2"), s("two")), None);
        assert!(!cache.is_empty());
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&s("key1")), Some(s("one")));
        assert_eq!(cache.get(&s("key2")), Some(s("two")));
        // key1 is now the least recently used entry, so key3 evicts it.
        assert_eq!(cache.put(s("key3"), s("three")), None);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(&s("key1")), None);
        assert_eq!(cache.get(&s("key2")), Some(s("two")));
        assert_eq!(cache.get(&s("key3")), Some(s("three")));
    }

    #[test]
    fn test_get_refreshes_recency() {
        let cache = BasicLruCache::new(2);
        cache.put(1u32, "one");
        cache.put(2, "two");
        // Touching 1 makes 2 the least recently used entry.
        assert_eq!(cache.get(&1), Some("one"));
        cache.put(3, "three");
        assert_eq!(cache.get(&2), None);
        assert_eq!(cache.get(&1), Some("one"));
        assert_eq!(cache.get(&3), Some("three"));
    }

    #[test]
    fn test_update_existing() {
        let cache = BasicLruCache::new(2);
        cache.put(s("key1"), s("one"));
        assert_eq!(cache.put(s("key1"), s("new_one")), Some(s("one")));
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&s("key1")), Some(s("new_one")));
        // Updating also refreshes recency: key2 becomes the eviction victim.
        cache.put(s("key2"), s("two"));
        cache.put(s("key1"), s("newer_one"));
        cache.put(s("key3"), s("three"));
        assert_eq!(cache.get(&s("key2")), None);
        assert_eq!(cache.get(&s("key1")), Some(s("newer_one")));
    }

    #[test]
    fn test_remove() {
        let cache = BasicLruCache::new(2);
        cache.put(s("key1"), s("one"));
        cache.put(s("key2"), s("two"));
        assert_eq!(cache.remove(&s("key1")), Some(s("one")));
        assert_eq!(cache.remove(&s("key1")), None);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&s("key1")), None);
        // The freed slot is reused without evicting key2.
        cache.put(s("key3"), s("three"));
        assert_eq!(cache.get(&s("key2")), Some(s("two")));
        assert_eq!(cache.get(&s("key3")), Some(s("three")));
    }

    #[test]
    fn test_capacity_one() {
        let cache = BasicLruCache::new(1);
        cache.put(1u32, 10u32);
        cache.put(2, 20);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&2), Some(20));
    }

    #[test]
    #[should_panic(expected = "Capacity must be positive")]
    fn test_zero_capacity_panics() {
        let _ = BasicLruCache::<u32, u32>::new(0);
    }

    #[test]
    fn test_clear() {
        let cache = BasicLruCache::new(2);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        cache.put(s("key1"), s("one"));
        cache.put(s("key2"), s("two"));
        assert_eq!(cache.len(), 2);
        assert!(!cache.is_empty());
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.get(&s("key1")), None);
        assert_eq!(cache.get(&s("key2")), None);
        // The cache remains fully usable after clearing.
        cache.put(s("key3"), s("three"));
        assert_eq!(cache.get(&s("key3")), Some(s("three")));
    }

    #[test]
    fn test_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BasicLruCache<String, String>>();
    }
}