#![cfg(feature = "proc_macro")]
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide tally of `Counted::clone` invocations.
///
/// Tests zero this counter, drive a cached function, and read it back to learn how
/// many value clones the cache machinery performed.
static CLONES: AtomicUsize = AtomicUsize::new(0);

/// A `u32` newtype whose `Clone` impl bumps [`CLONES`] every time it runs.
///
/// `Clone` is written by hand (not derived) precisely so that it can be counted.
#[derive(Debug, PartialEq)]
struct Counted(u32);

impl Clone for Counted {
    fn clone(&self) -> Self {
        let &Counted(inner) = self;
        // Record the clone, then hand back a fresh wrapper around the same number.
        CLONES.fetch_add(1, Ordering::SeqCst);
        Self(inner)
    }
}
fn counted_to_string(v: &Counted) -> String {
v.0.to_string()
}
/// Parses decimal text (as written by `counted_to_string`) back into a [`Counted`].
///
/// # Panics
///
/// Panics with `"parse Counted"` if `s` is not a valid `u32`. In this file the store
/// only ever holds `counted_to_string` output, so a panic here indicates a test bug.
fn counted_from_str(s: &str) -> Counted {
    let parsed: u32 = s.parse().expect("parse Counted");
    Counted(parsed)
}
mod stores {
    //! Hand-written cache stores used to observe how `#[concurrent_cached]` moves values.
    //!
    //! [`SerStore`] keeps values as text and also implements `SerializeCached`, so it can
    //! store from a borrowed value. [`OwnedStore`] keeps the values themselves and only
    //! implements `ConcurrentCached`, so every insert needs an owned value.

    use super::{Counted, counted_from_str, counted_to_string};
    use cached::{ConcurrentCacheBase, ConcurrentCached, SerializeCached};
    use std::collections::HashMap;
    use std::convert::Infallible;
    use std::sync::Mutex;

    /// A store that persists each `Counted` as its decimal string.
    ///
    /// Reads deserialize a fresh `Counted` and writes serialize from a reference, so no
    /// method of this store invokes `Counted::clone`.
    #[derive(Default)]
    pub struct SerStore {
        map: Mutex<HashMap<u32, String>>,
    }

    impl SerStore {
        /// Creates an empty store (equivalent to [`Default::default`]).
        pub fn new() -> Self {
            Self::default()
        }

        /// Serializes `v`, stores it under `k`, and returns the deserialized previous value.
        ///
        /// Shared by `cache_set` and `cache_set_ref` so the two write paths cannot drift.
        fn insert_serialized(&self, k: u32, v: &Counted) -> Option<Counted> {
            // Serialize before locking, and let the guard drop right after the insert,
            // so the critical section covers only the map mutation.
            let s = counted_to_string(v);
            let previous = self.map.lock().unwrap().insert(k, s);
            previous.map(|s| counted_from_str(&s))
        }
    }

    impl ConcurrentCacheBase for SerStore {
        type Error = Infallible;
    }

    impl ConcurrentCached<u32, Counted> for SerStore {
        fn cache_get(&self, k: &u32) -> Result<Option<Counted>, Infallible> {
            let map = self.map.lock().unwrap();
            // A hit deserializes a brand-new value instead of cloning a stored one.
            Ok(map.get(k).map(|s| counted_from_str(s)))
        }

        fn cache_set(&self, k: u32, v: Counted) -> Result<Option<Counted>, Infallible> {
            Ok(self.insert_serialized(k, &v))
        }

        fn cache_remove(&self, k: &u32) -> Result<Option<Counted>, Infallible> {
            let mut map = self.map.lock().unwrap();
            Ok(map.remove(k).map(|s| counted_from_str(&s)))
        }

        fn cache_remove_entry(&self, k: &u32) -> Result<Option<(u32, Counted)>, Infallible> {
            let mut map = self.map.lock().unwrap();
            Ok(map.remove_entry(k).map(|(k, s)| (k, counted_from_str(&s))))
        }

        fn cache_clear(&self) -> Result<(), Infallible> {
            self.map.lock().unwrap().clear();
            Ok(())
        }

        fn cache_reset(&self) -> Result<(), Infallible> {
            // The store has no state besides its contents, so reset is just clear.
            self.cache_clear()
        }
    }

    impl SerializeCached<u32, Counted> for SerStore {
        /// Stores a value the caller only borrows, so no clone is needed at the set site.
        fn cache_set_ref(&self, k: &u32, v: &Counted) -> Result<Option<Counted>, Infallible> {
            Ok(self.insert_serialized(*k, v))
        }
    }

    /// A store that keeps owned `Counted` values directly.
    ///
    /// It implements only `ConcurrentCached`. Inserting therefore consumes an owned value,
    /// and `cache_get` clones the stored value (one `Counted::clone` per hit).
    #[derive(Default)]
    pub struct OwnedStore {
        map: Mutex<HashMap<u32, Counted>>,
    }

    impl OwnedStore {
        /// Creates an empty store (equivalent to [`Default::default`]).
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl ConcurrentCacheBase for OwnedStore {
        type Error = Infallible;
    }

    impl ConcurrentCached<u32, Counted> for OwnedStore {
        fn cache_get(&self, k: &u32) -> Result<Option<Counted>, Infallible> {
            let map = self.map.lock().unwrap();
            // The map keeps ownership, so a hit must hand out a clone.
            Ok(map.get(k).cloned())
        }

        fn cache_set(&self, k: u32, v: Counted) -> Result<Option<Counted>, Infallible> {
            let mut map = self.map.lock().unwrap();
            Ok(map.insert(k, v))
        }

        fn cache_remove(&self, k: &u32) -> Result<Option<Counted>, Infallible> {
            let mut map = self.map.lock().unwrap();
            Ok(map.remove(k))
        }

        fn cache_remove_entry(&self, k: &u32) -> Result<Option<(u32, Counted)>, Infallible> {
            let mut map = self.map.lock().unwrap();
            Ok(map.remove_entry(k))
        }

        fn cache_clear(&self) -> Result<(), Infallible> {
            self.map.lock().unwrap().clear();
            Ok(())
        }

        fn cache_reset(&self) -> Result<(), Infallible> {
            // The store has no state besides its contents, so reset is just clear.
            self.cache_clear()
        }
    }
}
mod fns {
    //! Cached functions under test. Both compute `Counted(n)` and differ only in the store
    //! backing their cache: `SerStore` (which also implements `SerializeCached`) versus
    //! `OwnedStore` (which implements only `ConcurrentCached`).
    use super::Counted;
    use super::stores::{OwnedStore, SerStore};
    use cached::macros::concurrent_cached;
    // Backed by `SerStore`. That store's error type is `Infallible`, so `map_error` is the
    // identity closure, and the `u32` argument is used directly as the cache key.
    // NOTE(review): the tests import a `VIA_SERIALIZE` static from this module, presumably
    // generated by this attribute; confirm against the macro docs.
    #[concurrent_cached(
        ty = "SerStore",
        create = "{ SerStore::new() }",
        map_error = "|e| e",
        key = "u32",
        convert = "{ n }"
    )]
    pub fn via_serialize(n: u32) -> Result<Counted, std::convert::Infallible> {
        Ok(Counted(n))
    }
    // Backed by `OwnedStore`; otherwise configured identically to `via_serialize`.
    // NOTE(review): the matching `VIA_OWNED` static used by the tests is presumably
    // generated here as well.
    #[concurrent_cached(
        ty = "OwnedStore",
        create = "{ OwnedStore::new() }",
        map_error = "|e| e",
        key = "u32",
        convert = "{ n }"
    )]
    pub fn via_owned(n: u32) -> Result<Counted, std::convert::Infallible> {
        Ok(Counted(n))
    }
}
mod clone_count_tests {
    //! Counts `Counted::clone` calls made by `#[concurrent_cached]` for each store kind.

    use super::*;
    use cached::ConcurrentCached;
    use fns::{VIA_OWNED, VIA_SERIALIZE, via_owned, via_serialize};
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module. They all reset and read the global `CLONES`
    /// counter, so running them in parallel would make the counts meaningless.
    static CLONE_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`CLONE_TEST_LOCK`], recovering from poisoning.
    ///
    /// A failing test panics while holding the guard, which poisons the mutex. Plain
    /// `unwrap()` would then make every later test fail with `PoisonError` and hide the
    /// real failure. The guarded data is `()`, so recovering the guard is always sound.
    fn lock_clone_tests() -> MutexGuard<'static, ()> {
        CLONE_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn serialize_store_skips_value_clone() {
        let _guard = lock_clone_tests();
        CLONES.store(0, Ordering::SeqCst);
        VIA_SERIALIZE.cache_clear().unwrap();

        // Miss: the value is computed and then stored.
        let v = via_serialize(7).unwrap();
        assert_eq!(v, Counted(7));
        assert_eq!(
            CLONES.load(Ordering::SeqCst),
            0,
            "SerializeCached path must not clone the value at the set site"
        );

        // Hit: `SerStore::cache_get` deserializes, so any clone would come from the macro.
        CLONES.store(0, Ordering::SeqCst);
        let hit = via_serialize(7).unwrap();
        assert_eq!(hit, Counted(7));
        assert_eq!(
            CLONES.load(Ordering::SeqCst),
            0,
            "Cache hit on SerializeCached path must not clone"
        );
    }

    #[test]
    fn owned_store_clones_once() {
        let _guard = lock_clone_tests();
        CLONES.store(0, Ordering::SeqCst);
        VIA_OWNED.cache_clear().unwrap();

        // Miss: `OwnedStore` needs an owned value, so the macro clones the result once.
        let v = via_owned(7).unwrap();
        assert_eq!(v, Counted(7));
        assert_eq!(
            CLONES.load(Ordering::SeqCst),
            1,
            "Owned fallback path must clone the value exactly once at the set site"
        );

        // Hit: `OwnedStore::cache_get` returns `map.get(k).cloned()`. The serialize test
        // shows the macro adds no clones of its own on a hit, so exactly one is expected.
        CLONES.store(0, Ordering::SeqCst);
        let hit = via_owned(7).unwrap();
        assert_eq!(hit, Counted(7), "Cache hit must return the correct value");
        assert_eq!(
            CLONES.load(Ordering::SeqCst),
            1,
            "Owned cache hit must clone exactly once, inside OwnedStore::cache_get"
        );
    }
}
#[cfg(feature = "async")]
mod async_serialize_store {
    //! Async counterpart of the synchronous clone-count tests.
    //!
    //! `AsyncSerStore` implements `SerializeCachedAsync`: it stores from a borrow and keeps
    //! values as text. `AsyncOwnedStore` implements only `ConcurrentCachedAsync` and keeps
    //! owned values. The tests count `AsyncVal::clone` calls made by `#[concurrent_cached]`.
    //!
    //! Every store method takes its `std::sync::Mutex` lock synchronously and drops the
    //! guard before building the returned future, so no guard is held across an `.await`.

    use cached::macros::concurrent_cached;
    use cached::{ConcurrentCacheBase, ConcurrentCachedAsync, SerializeCachedAsync};
    use std::collections::HashMap;
    use std::convert::Infallible;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Number of `AsyncVal::clone` calls since the last reset.
    static ASYNC_CLONES: AtomicUsize = AtomicUsize::new(0);

    /// A `u32` newtype whose hand-written `Clone` bumps [`ASYNC_CLONES`].
    #[derive(PartialEq, Debug)]
    pub struct AsyncVal(u32);

    impl Clone for AsyncVal {
        fn clone(&self) -> Self {
            ASYNC_CLONES.fetch_add(1, Ordering::SeqCst);
            AsyncVal(self.0)
        }
    }

    /// Renders an [`AsyncVal`] as decimal text without cloning it.
    fn async_val_to_string(v: &AsyncVal) -> String {
        v.0.to_string()
    }

    /// Parses text produced by [`async_val_to_string`]; panics on malformed input.
    fn async_val_from_str(s: &str) -> AsyncVal {
        AsyncVal(s.parse().expect("parse AsyncVal"))
    }

    /// Async store that keeps each value as a string, so neither reads nor writes clone.
    #[derive(Default)]
    pub struct AsyncSerStore {
        map: Mutex<HashMap<u32, String>>,
    }

    impl AsyncSerStore {
        /// Creates an empty store (equivalent to [`Default::default`]).
        pub fn new() -> Self {
            Self::default()
        }

        /// Serializes `v`, stores it under `k`, and returns the deserialized previous value.
        ///
        /// Shared by `async_cache_set` and `async_cache_set_ref`. It is deliberately
        /// synchronous, so the lock guard is dropped before the caller builds its future.
        fn insert_serialized(&self, k: u32, v: &AsyncVal) -> Option<AsyncVal> {
            let s = async_val_to_string(v);
            let prev = self.map.lock().unwrap().insert(k, s);
            prev.map(|s| async_val_from_str(&s))
        }
    }

    impl ConcurrentCacheBase for AsyncSerStore {
        type Error = Infallible;
    }

    impl ConcurrentCachedAsync<u32, AsyncVal> for AsyncSerStore {
        fn async_cache_get(
            &self,
            k: &u32,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            let result = {
                let map = self.map.lock().unwrap();
                // Deserialize a fresh value; nothing is cloned on a hit.
                map.get(k).map(|s| async_val_from_str(s))
            };
            async move { Ok(result) }
        }

        fn async_cache_set(
            &self,
            k: u32,
            v: AsyncVal,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            let prev_val = self.insert_serialized(k, &v);
            async move { Ok(prev_val) }
        }

        fn async_cache_remove(
            &self,
            k: &u32,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            let result = {
                let mut map = self.map.lock().unwrap();
                map.remove(k).map(|s| async_val_from_str(&s))
            };
            async move { Ok(result) }
        }

        fn async_cache_remove_entry(
            &self,
            k: &u32,
        ) -> impl std::future::Future<Output = Result<Option<(u32, AsyncVal)>, Infallible>> + Send
        {
            let result = {
                let mut map = self.map.lock().unwrap();
                map.remove_entry(k)
                    .map(|(k, s)| (k, async_val_from_str(&s)))
            };
            async move { Ok(result) }
        }

        fn async_cache_clear(
            &self,
        ) -> impl std::future::Future<Output = Result<(), Infallible>> + Send
        where
            Self: Sync,
        {
            self.map.lock().unwrap().clear();
            async move { Ok(()) }
        }

        fn async_cache_reset(
            &self,
        ) -> impl std::future::Future<Output = Result<(), Infallible>> + Send
        where
            Self: Sync,
        {
            // The store has no state besides its contents, so reset is just clear.
            self.map.lock().unwrap().clear();
            async move { Ok(()) }
        }
    }

    impl SerializeCachedAsync<u32, AsyncVal> for AsyncSerStore {
        /// Stores a value the caller only borrows, so no clone is needed at the set site.
        fn async_cache_set_ref(
            &self,
            k: &u32,
            v: &AsyncVal,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            // Do the work eagerly: the returned future must not borrow `k` or `v`.
            let prev_val = self.insert_serialized(*k, v);
            async move { Ok(prev_val) }
        }
    }

    /// Async store that keeps owned `AsyncVal`s; `async_cache_get` clones once per hit.
    #[derive(Default)]
    pub struct AsyncOwnedStore {
        map: Mutex<HashMap<u32, AsyncVal>>,
    }

    impl AsyncOwnedStore {
        /// Creates an empty store (equivalent to [`Default::default`]).
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl ConcurrentCacheBase for AsyncOwnedStore {
        type Error = Infallible;
    }

    impl ConcurrentCachedAsync<u32, AsyncVal> for AsyncOwnedStore {
        fn async_cache_get(
            &self,
            k: &u32,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            let result = {
                let map = self.map.lock().unwrap();
                // The map keeps ownership, so a hit must hand out a clone.
                map.get(k).cloned()
            };
            async move { Ok(result) }
        }

        fn async_cache_set(
            &self,
            k: u32,
            v: AsyncVal,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            let prev = {
                let mut map = self.map.lock().unwrap();
                map.insert(k, v)
            };
            async move { Ok(prev) }
        }

        fn async_cache_remove(
            &self,
            k: &u32,
        ) -> impl std::future::Future<Output = Result<Option<AsyncVal>, Infallible>> + Send
        {
            let result = {
                let mut map = self.map.lock().unwrap();
                map.remove(k)
            };
            async move { Ok(result) }
        }

        fn async_cache_remove_entry(
            &self,
            k: &u32,
        ) -> impl std::future::Future<Output = Result<Option<(u32, AsyncVal)>, Infallible>> + Send
        {
            let result = {
                let mut map = self.map.lock().unwrap();
                map.remove_entry(k)
            };
            async move { Ok(result) }
        }

        fn async_cache_clear(
            &self,
        ) -> impl std::future::Future<Output = Result<(), Infallible>> + Send
        where
            Self: Sync,
        {
            self.map.lock().unwrap().clear();
            async move { Ok(()) }
        }

        fn async_cache_reset(
            &self,
        ) -> impl std::future::Future<Output = Result<(), Infallible>> + Send
        where
            Self: Sync,
        {
            // The store has no state besides its contents, so reset is just clear.
            self.map.lock().unwrap().clear();
            async move { Ok(()) }
        }
    }

    // Backed by `AsyncSerStore`; `map_error` is the identity because the error is `Infallible`.
    #[concurrent_cached(
        ty = "AsyncSerStore",
        create = "{ AsyncSerStore::new() }",
        map_error = "|e| e",
        key = "u32",
        convert = "{ n }"
    )]
    async fn async_via_serialize(n: u32) -> Result<AsyncVal, Infallible> {
        Ok(AsyncVal(n))
    }

    // Backed by `AsyncOwnedStore`; otherwise configured identically to `async_via_serialize`.
    #[concurrent_cached(
        ty = "AsyncOwnedStore",
        create = "{ AsyncOwnedStore::new() }",
        map_error = "|e| e",
        key = "u32",
        convert = "{ n }"
    )]
    async fn async_via_owned(n: u32) -> Result<AsyncVal, Infallible> {
        Ok(AsyncVal(n))
    }

    #[tokio::test]
    #[serial_test::serial(async_clones)]
    async fn async_serialize_store_skips_value_clone() {
        ASYNC_CLONES.store(0, Ordering::SeqCst);
        // `get()` yields `None` until the store exists; then there is nothing to clear.
        if let Some(store) = ASYNC_VIA_SERIALIZE.get() {
            store.async_cache_clear().await.unwrap();
        }

        let first = async_via_serialize(42).await.unwrap();
        assert_eq!(first, AsyncVal(42));
        assert_eq!(
            ASYNC_CLONES.load(Ordering::SeqCst),
            0,
            "SerializeCachedAsync path must not clone the value at the set site"
        );

        ASYNC_CLONES.store(0, Ordering::SeqCst);
        let second = async_via_serialize(42).await.unwrap();
        assert_eq!(second, AsyncVal(42));
        assert_eq!(
            ASYNC_CLONES.load(Ordering::SeqCst),
            0,
            "Cache hit on SerializeCachedAsync path must not clone"
        );

        // A second, distinct key also goes through the borrow-based set path.
        ASYNC_CLONES.store(0, Ordering::SeqCst);
        let other = async_via_serialize(99).await.unwrap();
        assert_eq!(other, AsyncVal(99));
        assert_eq!(ASYNC_CLONES.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    #[serial_test::serial(async_clones)]
    async fn async_owned_store_clones_once() {
        ASYNC_CLONES.store(0, Ordering::SeqCst);
        // `get()` yields `None` until the store exists; then there is nothing to clear.
        if let Some(store) = ASYNC_VIA_OWNED.get() {
            store.async_cache_clear().await.unwrap();
        }

        let v = async_via_owned(7).await.unwrap();
        assert_eq!(v, AsyncVal(7));
        assert_eq!(
            ASYNC_CLONES.load(Ordering::SeqCst),
            1,
            "Owned async fallback path must clone the value exactly once at the set site"
        );

        // Hit: `AsyncOwnedStore::async_cache_get` returns `map.get(k).cloned()`. The serialize
        // test shows the macro adds no clones on a hit, so exactly one is expected.
        ASYNC_CLONES.store(0, Ordering::SeqCst);
        let hit = async_via_owned(7).await.unwrap();
        assert_eq!(hit, AsyncVal(7), "Cache hit must return the correct value");
        assert_eq!(
            ASYNC_CLONES.load(Ordering::SeqCst),
            1,
            "Owned async cache hit must clone exactly once, inside async_cache_get"
        );
    }
}