use crate::error::{Error, Result};
use std::collections::HashMap;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};
/// One entry of the cache's doubly linked recency list.
///
/// Nodes are connected through raw `NonNull` pointers rather than owning references, so the
/// code that links them is responsible for allocating and freeing them.
struct Node<K, V> {
    /// Key of this entry, stored in the node so the entry can be found by key again.
    key: K,
    /// Value cached under `key`.
    value: V,
    /// Neighbouring node in the `prev` direction, or `None` if there is none.
    prev: Option<NonNull<Node<K, V>>>,
    /// Neighbouring node in the `next` direction, or `None` if there is none.
    next: Option<NonNull<Node<K, V>>>,
}

impl<K, V> Node<K, V> {
    /// Builds a detached node: it owns `key` and `value` and has no neighbours yet.
    fn new(key: K, value: V) -> Self {
        let (prev, next) = (None, None);
        Node {
            key,
            value,
            prev,
            next,
        }
    }
}
/// Handle to a least-recently-used cache with a fixed capacity.
///
/// The handle itself holds no entries: all state lives in a mutex-guarded
/// `LruCacheInner` behind an `Arc`, so cloning a handle yields another handle to the
/// same cache rather than a copy of its contents.
///
/// Keys must be `Clone + Eq + Hash` and values must be `Clone`; lookups hand out clones
/// of the stored values.
pub struct LruCache<K, V>
where
K: Clone + Eq + std::hash::Hash,
V: Clone,
{
/// Reference-counted, mutex-guarded cache state shared by every clone of this handle.
inner: Arc<Mutex<LruCacheInner<K, V>>>,
}
/// Mutable state of an `LruCache`: a key index plus a doubly linked recency list.
///
/// `map` points at heap-allocated nodes that are also chained from `head` (most recently
/// used) to `tail` (least recently used). The nodes are owned by this struct and are
/// referenced only through the raw pointers stored here.
struct LruCacheInner<K, V>
where
    K: Clone + Eq + std::hash::Hash,
    V: Clone,
{
    /// Maximum number of entries kept before the least recently used one is evicted.
    capacity: usize,
    /// Index from key to the node holding that key's value.
    map: HashMap<K, NonNull<Node<K, V>>>,
    /// Most recently used node, or `None` when the cache is empty.
    head: Option<NonNull<Node<K, V>>>,
    /// Least recently used node, or `None` when the cache is empty.
    tail: Option<NonNull<Node<K, V>>>,
    /// Number of entries currently stored.
    len: usize,
}

// SAFETY: the raw `NonNull` pointers make this type `!Send` by default, which in turn makes
// `Arc<Mutex<LruCacheInner<..>>>` neither `Send` nor `Sync` and so prevents the cache from
// ever being shared between threads. Every node is a private heap allocation that is reachable
// only through this struct, so moving the struct to another thread moves exclusive ownership
// of all keys and values with it. That is sound exactly when `K` and `V` are `Send`.
// `Sync` is deliberately not implemented: shared access always goes through the `Mutex`.
unsafe impl<K, V> Send for LruCacheInner<K, V>
where
    K: Clone + Eq + std::hash::Hash + Send,
    V: Clone + Send,
{
}
impl<K, V> LruCache<K, V>
where
K: Clone + Eq + std::hash::Hash,
V: Clone,
{
pub fn new(capacity: usize) -> Self {
assert!(capacity > 0, "Capacity must be greater than 0");
Self {
inner: Arc::new(Mutex::new(LruCacheInner {
capacity,
map: HashMap::new(),
head: None,
tail: None,
len: 0,
})),
}
}
pub fn get(&self, key: &K) -> Result<Option<V>> {
let mut inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
if let Some(&node_ptr) = inner.map.get(key) {
unsafe {
let node_ref = node_ptr.as_ref();
let value = node_ref.value.clone();
inner.move_to_front(node_ptr);
Ok(Some(value))
}
} else {
Ok(None)
}
}
pub fn put(&self, key: K, value: V) -> Result<()> {
let mut inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
if let Some(&existing_node) = inner.map.get(&key) {
unsafe {
let mut existing_node_mut = existing_node;
let existing_ref = existing_node_mut.as_mut();
existing_ref.value = value;
inner.move_to_front(existing_node_mut);
}
} else {
let new_node = Box::new(Node::new(key.clone(), value));
let new_node_ptr = NonNull::from(Box::leak(new_node));
inner.map.insert(key, new_node_ptr);
unsafe {
inner.add_to_front(new_node_ptr);
}
inner.len += 1;
if inner.len > inner.capacity {
unsafe {
inner.remove_tail();
}
}
}
Ok(())
}
pub fn remove(&self, key: &K) -> Result<Option<V>> {
let mut inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
if let Some(node_ptr) = inner.map.remove(key) {
unsafe {
let value = node_ptr.as_ref().value.clone();
inner.remove_node(node_ptr);
inner.len -= 1;
let _ = Box::from_raw(node_ptr.as_ptr());
Ok(Some(value))
}
} else {
Ok(None)
}
}
pub fn contains_key(&self, key: &K) -> Result<bool> {
let inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
Ok(inner.map.contains_key(key))
}
pub fn len(&self) -> Result<usize> {
let inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
Ok(inner.len)
}
pub fn is_empty(&self) -> Result<bool> {
Ok(self.len()? == 0)
}
pub fn capacity(&self) -> Result<usize> {
let inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
Ok(inner.capacity)
}
pub fn clear(&self) -> Result<()> {
let mut inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
unsafe {
let mut current = inner.head;
while let Some(node_ptr) = current {
let node_ref = node_ptr.as_ref();
current = node_ref.next;
let _ = Box::from_raw(node_ptr.as_ptr());
}
}
inner.map.clear();
inner.head = None;
inner.tail = None;
inner.len = 0;
Ok(())
}
pub fn keys(&self) -> Result<Vec<K>> {
let inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
let mut keys = Vec::new();
let mut current = inner.head;
unsafe {
while let Some(node_ptr) = current {
let node_ref = node_ptr.as_ref();
keys.push(node_ref.key.clone());
current = node_ref.next;
}
}
Ok(keys)
}
pub fn peek_lru(&self) -> Result<Option<(K, V)>> {
let inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
if let Some(tail_ptr) = inner.tail {
unsafe {
let tail_ref = tail_ptr.as_ref();
Ok(Some((tail_ref.key.clone(), tail_ref.value.clone())))
}
} else {
Ok(None)
}
}
pub fn peek_mru(&self) -> Result<Option<(K, V)>> {
let inner = self.inner.lock()
.map_err(|_| Error::concurrency("Failed to acquire lock".to_string()))?;
if let Some(head_ptr) = inner.head {
unsafe {
let head_ref = head_ptr.as_ref();
Ok(Some((head_ref.key.clone(), head_ref.value.clone())))
}
} else {
Ok(None)
}
}
pub fn get_or_insert<F>(&self, key: K, compute_fn: F) -> Result<V>
where
F: FnOnce() -> V,
{
if let Some(value) = self.get(&key)? {
return Ok(value);
}
let value = compute_fn();
self.put(key, value.clone())?;
Ok(value)
}
}
impl<K, V> LruCacheInner<K, V>
where
    K: Clone + Eq + std::hash::Hash,
    V: Clone,
{
    /// Makes `node_ptr` the head (most recently used end) of the list.
    ///
    /// # Safety
    ///
    /// `node_ptr` must point to a live node that is currently linked into this list.
    unsafe fn move_to_front(&mut self, node_ptr: NonNull<Node<K, V>>) {
        // Already the most recent entry: unlinking and relinking would be a no-op.
        if self.head == Some(node_ptr) {
            return;
        }
        // SAFETY: guaranteed by the caller; after `remove_node` the node is detached, which is
        // what `add_to_front` requires.
        unsafe {
            self.remove_node(node_ptr);
            self.add_to_front(node_ptr);
        }
    }

    /// Links a detached node in as the new head of the list.
    ///
    /// # Safety
    ///
    /// `node_ptr` must point to a live node that is not currently linked into this list.
    unsafe fn add_to_front(&mut self, mut node_ptr: NonNull<Node<K, V>>) {
        // SAFETY: the caller guarantees the node is live and not aliased by the list.
        let node_ref = unsafe { node_ptr.as_mut() };
        node_ref.prev = None;
        node_ref.next = self.head;
        if let Some(mut old_head) = self.head {
            // SAFETY: `head` always points at a live node distinct from the detached one.
            unsafe { old_head.as_mut() }.prev = Some(node_ptr);
        } else {
            // The list was empty, so the new node is both head and tail.
            self.tail = Some(node_ptr);
        }
        self.head = Some(node_ptr);
    }

    /// Unlinks a node from the list, patching its neighbours and `head`/`tail` as needed.
    ///
    /// The node itself is neither freed nor modified, and `map`/`len` are left untouched.
    ///
    /// # Safety
    ///
    /// `node_ptr` must point to a live node that is currently linked into this list.
    unsafe fn remove_node(&mut self, node_ptr: NonNull<Node<K, V>>) {
        // SAFETY: the caller guarantees the node is live.
        let node_ref = unsafe { node_ptr.as_ref() };
        if let Some(mut prev) = node_ref.prev {
            // SAFETY: neighbours of a linked node are live nodes distinct from it.
            unsafe { prev.as_mut() }.next = node_ref.next;
        } else {
            // No predecessor: the node was the head.
            self.head = node_ref.next;
        }
        if let Some(mut next) = node_ref.next {
            // SAFETY: neighbours of a linked node are live nodes distinct from it.
            unsafe { next.as_mut() }.prev = node_ref.prev;
        } else {
            // No successor: the node was the tail.
            self.tail = node_ref.prev;
        }
    }

    /// Evicts the tail (least recently used) entry: removes it from `map` and the list,
    /// decrements `len`, and frees the node. Does nothing when the list is empty.
    ///
    /// # Safety
    ///
    /// The list, `map` and `len` must be consistent: every linked node is live, was allocated
    /// through `Box`, and is counted in `len`.
    unsafe fn remove_tail(&mut self) {
        let tail_ptr = match self.tail {
            Some(ptr) => ptr,
            None => return,
        };
        // SAFETY: `tail` points at a live node.
        let tail_ref = unsafe { tail_ptr.as_ref() };
        // Look the entry up through the key stored in the node itself; cloning it first would
        // cost an allocation per eviction for heap-backed keys.
        self.map.remove(&tail_ref.key);
        // SAFETY: the tail is linked into this list.
        unsafe { self.remove_node(tail_ptr) };
        self.len -= 1;
        // SAFETY: the node is now unreachable from `map` and the list, so this is the only
        // owner; it was allocated through `Box` when it was inserted.
        drop(unsafe { Box::from_raw(tail_ptr.as_ptr()) });
    }
}
impl<K, V> Clone for LruCache<K, V>
where
K: Clone + Eq + std::hash::Hash,
V: Clone,
{
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
impl<K, V> Drop for LruCacheInner<K, V>
where
K: Clone + Eq + std::hash::Hash,
V: Clone,
{
fn drop(&mut self) {
unsafe {
let mut current = self.head;
while let Some(node_ptr) = current {
let node_ref = node_ptr.as_ref();
current = node_ref.next;
let _ = Box::from_raw(node_ptr.as_ptr());
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for turning a literal into an owned `String`.
    fn s(text: &str) -> String {
        text.to_string()
    }

    /// Builds a cache with the given capacity and inserts `keyN -> valueN` for `N` in
    /// `1..=count`, in ascending order.
    fn filled(capacity: usize, count: usize) -> LruCache<String, String> {
        let cache: LruCache<String, String> = LruCache::new(capacity);
        for i in 1..=count {
            cache
                .put(format!("key{}", i), format!("value{}", i))
                .unwrap();
        }
        cache
    }

    /// A single insert is visible through every read accessor.
    #[test]
    fn test_basic_operations() {
        let cache = filled(2, 1);

        assert_eq!(cache.get(&s("key1")).unwrap(), Some(s("value1")));
        assert!(cache.contains_key(&s("key1")).unwrap());
        assert!(!cache.contains_key(&s("nonexistent")).unwrap());
        assert_eq!(cache.len().unwrap(), 1);
        assert!(!cache.is_empty().unwrap());
        assert_eq!(cache.capacity().unwrap(), 2);
    }

    /// Inserting beyond capacity drops the oldest entry.
    #[test]
    fn test_lru_eviction() {
        let cache = filled(2, 3);

        assert_eq!(cache.len().unwrap(), 2);
        assert_eq!(cache.get(&s("key1")).unwrap(), None);
        assert_eq!(cache.get(&s("key2")).unwrap(), Some(s("value2")));
        assert_eq!(cache.get(&s("key3")).unwrap(), Some(s("value3")));
    }

    /// Reading an entry protects it from the next eviction.
    #[test]
    fn test_lru_order() {
        let cache = filled(3, 3);

        cache.get(&s("key1")).unwrap();
        cache.put(s("key4"), s("value4")).unwrap();

        assert_eq!(cache.get(&s("key2")).unwrap(), None);
        assert_eq!(cache.get(&s("key1")).unwrap(), Some(s("value1")));
        assert_eq!(cache.get(&s("key3")).unwrap(), Some(s("value3")));
        assert_eq!(cache.get(&s("key4")).unwrap(), Some(s("value4")));
    }

    /// Re-inserting an existing key replaces its value without growing the cache.
    #[test]
    fn test_update_existing_key() {
        let cache = filled(2, 2);

        cache.put(s("key1"), s("updated_value1")).unwrap();

        assert_eq!(cache.len().unwrap(), 2);
        assert_eq!(cache.get(&s("key1")).unwrap(), Some(s("updated_value1")));
        assert_eq!(cache.get(&s("key2")).unwrap(), Some(s("value2")));
    }

    /// Removal returns the stored value and leaves other entries alone.
    #[test]
    fn test_remove() {
        let cache = filled(3, 2);

        let removed = cache.remove(&s("key1")).unwrap();
        assert_eq!(removed, Some(s("value1")));
        assert_eq!(cache.len().unwrap(), 1);
        assert_eq!(cache.get(&s("key1")).unwrap(), None);
        assert_eq!(cache.get(&s("key2")).unwrap(), Some(s("value2")));

        let removed = cache.remove(&s("nonexistent")).unwrap();
        assert_eq!(removed, None);
    }

    /// Clearing empties the cache.
    #[test]
    fn test_clear() {
        let cache = filled(3, 3);

        cache.clear().unwrap();

        assert!(cache.is_empty().unwrap());
        assert_eq!(cache.len().unwrap(), 0);
        assert_eq!(cache.get(&s("key1")).unwrap(), None);
    }

    /// Keys are listed from most to least recently used.
    #[test]
    fn test_keys() {
        let cache = filled(3, 3);

        let keys = cache.keys().unwrap();
        assert_eq!(keys, vec![s("key3"), s("key2"), s("key1")]);

        cache.get(&s("key1")).unwrap();
        let keys = cache.keys().unwrap();
        assert_eq!(keys, vec![s("key1"), s("key3"), s("key2")]);
    }

    /// Peeking reports both ends of the recency order and follows reads.
    #[test]
    fn test_peek_lru_mru() {
        let cache = filled(3, 3);

        let lru = cache.peek_lru().unwrap();
        assert_eq!(lru, Some((s("key1"), s("value1"))));
        let mru = cache.peek_mru().unwrap();
        assert_eq!(mru, Some((s("key3"), s("value3"))));

        cache.get(&s("key1")).unwrap();

        let lru_key = cache.peek_lru().unwrap().map(|(key, _)| key);
        assert_eq!(lru_key, Some(s("key2")));
        let mru_key = cache.peek_mru().unwrap().map(|(key, _)| key);
        assert_eq!(mru_key, Some(s("key1")));
    }

    /// The compute closure is only used on a miss.
    #[test]
    fn test_get_or_insert() {
        let cache: LruCache<String, String> = LruCache::new(2);

        let value = cache.get_or_insert(s("key"), || s("computed")).unwrap();
        assert_eq!(value, s("computed"));

        let cached = cache.get_or_insert(s("key"), || s("new_computed")).unwrap();
        assert_eq!(cached, s("computed"));
        assert_eq!(cache.len().unwrap(), 1);
    }

    /// Cloned handles share the same entries.
    #[test]
    fn test_clone() {
        let cache1 = LruCache::new(2);
        cache1.put(s("key"), s("value")).unwrap();

        let cache2 = cache1.clone();
        assert_eq!(cache2.get(&s("key")).unwrap(), Some(s("value")));

        cache2.put(s("key2"), s("value2")).unwrap();
        assert_eq!(cache1.get(&s("key2")).unwrap(), Some(s("value2")));
    }

    /// A fresh cache reports emptiness through every accessor.
    #[test]
    fn test_empty_cache() {
        let cache: LruCache<String, String> = LruCache::new(1);

        assert!(cache.is_empty().unwrap());
        assert_eq!(cache.peek_lru().unwrap(), None);
        assert_eq!(cache.peek_mru().unwrap(), None);
        assert_eq!(cache.keys().unwrap(), Vec::<String>::new());
    }

    /// A zero capacity is rejected at construction time.
    #[test]
    #[should_panic(expected = "Capacity must be greater than 0")]
    fn test_zero_capacity() {
        LruCache::<i32, i32>::new(0);
    }
}