zenoh_sync/cache.rs
1//
2// Copyright (c) 2025 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12// ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14
15use std::sync::{
16 atomic::{AtomicBool, Ordering},
17 Arc,
18};
19
20use arc_swap::{ArcSwap, Guard};
21
/// A value stored in a [`Cache`], tagged with the version it was computed for.
///
/// `T` is implicitly `Sized`; no explicit bound is needed on the definition.
pub struct CacheValue<T> {
    /// Version this value corresponds to; compared against the version requested from the cache.
    version: usize,
    /// The cached payload.
    value: T,
}
26
27impl<T> CacheValue<T> {
28 pub fn get_ref(&self) -> &T {
29 &self.value
30 }
31}
32
33/// This is a lock-free concurrent cache.
34/// It stores only the most up-to-date value.
35pub struct Cache<T> {
36 value: ArcSwap<CacheValue<T>>,
37 is_updating: AtomicBool,
38}
39
40pub type CacheValueType<T> = Guard<Arc<CacheValue<T>>>;
41
42impl<T> Cache<T> {
43 pub fn new(value: T, version: usize) -> Self {
44 Cache {
45 value: ArcSwap::new(CacheValue::<T> { version, value }.into()),
46 is_updating: AtomicBool::new(false),
47 }
48 }
49
50 fn finish_update(&self) {
51 self.is_updating.store(false, Ordering::SeqCst);
52 }
53
54 /// Tries to retrieve value for the specified version.
55 /// Returns a result either containing a cached value, or an f (which is guaranteed to be not invoked by function call in this case).
56 /// If requested version corresponds to the value currently stored in cache - the value is returned.
57 /// If requested version is older None will be returned.
58 /// If requested version is newer, the new value will be computed and stored by calling f, and then returned,
59 /// unless the value is being currently updated - in this case None will be returned.
60 /// If None is returned it is guaranteed that f was not called.
61 pub fn value(
62 &self,
63 version: usize,
64 f: impl FnOnce() -> T,
65 ) -> Result<CacheValueType<T>, impl FnOnce() -> T> {
66 let v = self.value.load();
67 match v.version.cmp(&version) {
68 std::cmp::Ordering::Equal => Ok(v),
69 std::cmp::Ordering::Greater => Err(f), //requesting too old version
70 std::cmp::Ordering::Less => {
71 // try to update
72 drop(v);
73 match self.is_updating.compare_exchange(
74 false,
75 true,
76 Ordering::SeqCst,
77 Ordering::SeqCst,
78 ) {
79 Ok(_) => {
80 let v = self.value.load();
81 match v.version.cmp(&version) {
82 std::cmp::Ordering::Equal => {
83 // already updated by someone else to the version we need
84 self.finish_update();
85 Ok(v)
86 }
87 std::cmp::Ordering::Greater => {
88 // already updated by someone else beyond the version we need
89 self.finish_update();
90 Err(f)
91 }
92 std::cmp::Ordering::Less => {
93 drop(v);
94 self.value.store(
95 CacheValue {
96 value: f(),
97 version,
98 }
99 .into(),
100 );
101 let v = self.value.load(); // is_updating set to true guarantees that nobody else will modify the value.
102 self.finish_update();
103 Ok(v)
104 }
105 }
106 }
107 Err(_) => Err(f),
108 }
109 }
110 }
111 }
112}
113
#[cfg(test)]
mod tests {
    use std::sync::{mpsc, Arc};

    use super::{Cache, CacheValueType};

    /// Returns the cached string, or `""` if the cache refused the request.
    fn text<E>(res: &Result<CacheValueType<String>, E>) -> &str {
        res.as_ref().map(|v| v.get_ref().as_str()).unwrap_or("")
    }

    #[test]
    fn test_cache() {
        let cache = Cache::<String>::new("0".to_string(), 0);

        // Current version: served from the cache, f is not used.
        assert_eq!(text(&cache.value(0, || "1".to_string())), "0");
        // Newer version: f is called and its result becomes the cached value.
        assert_eq!(text(&cache.value(1, || "1".to_string())), "1");
        // Older version: refused.
        assert!(cache.value(0, || "2".to_string()).is_err());

        // Update from another thread, pausing inside f so the cache can be observed
        // deterministically while the update is in progress (no timing-based sleeps).
        let cache = Arc::new(cache);
        let updater_cache = Arc::clone(&cache);
        let (started_tx, started_rx) = mpsc::channel::<()>();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let updater = std::thread::spawn(move || {
            let res = updater_cache.value(2, || {
                started_tx.send(()).expect("test thread hung up");
                release_rx.recv().expect("test thread hung up");
                "2".to_string()
            });
            assert_eq!(text(&res), "2");
        });

        started_rx
            .recv()
            .expect("updater thread exited before calling f");
        // While another thread is updating, a request for the new version is refused
        // without calling f...
        assert!(cache
            .value(2, || -> String { panic!("f must not be called during another update") })
            .is_err());
        // ...and the previous value is still served.
        assert_eq!(
            text(&cache.value(1, || -> String { panic!("f must not be called for a cached version") })),
            "1"
        );

        release_tx.send(()).expect("updater thread exited early");
        // Joining propagates any assertion failure from the updater thread.
        updater.join().expect("updater thread panicked");

        assert_eq!(text(&cache.value(2, String::new)), "2");
    }
}