zenoh_sync/
cache.rs

1//
2// Copyright (c) 2025 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14
15use std::sync::{
16    atomic::{AtomicBool, Ordering},
17    Arc,
18};
19
20use arc_swap::{ArcSwap, Guard};
21
/// A value together with the version number it was produced for.
pub struct CacheValue<T> {
    /// Version this value corresponds to.
    version: usize,
    /// The cached payload itself.
    value: T,
}
26
27impl<T> CacheValue<T> {
28    pub fn get_ref(&self) -> &T {
29        &self.value
30    }
31}
32
33/// This is a lock-free concurrent cache.
34/// It stores only the most up-to-date value.
35pub struct Cache<T> {
36    value: ArcSwap<CacheValue<T>>,
37    is_updating: AtomicBool,
38}
39
40pub type CacheValueType<T> = Guard<Arc<CacheValue<T>>>;
41
42impl<T> Cache<T> {
43    pub fn new(value: T, version: usize) -> Self {
44        Cache {
45            value: ArcSwap::new(CacheValue::<T> { version, value }.into()),
46            is_updating: AtomicBool::new(false),
47        }
48    }
49
50    fn finish_update(&self) {
51        self.is_updating.store(false, Ordering::SeqCst);
52    }
53
54    /// Tries to retrieve value for the specified version.
55    /// Returns a result either containing a cached value, or an f (which is guaranteed to be not invoked by function call in this case).
56    /// If requested version corresponds to the value currently stored in cache - the value is returned.
57    /// If requested version is older None will be returned.
58    /// If requested version is newer, the new value will be computed and stored by calling f, and then returned,
59    /// unless the value is being currently updated - in this case None will be returned.
60    /// If None is returned it is guaranteed that f was not called.
61    pub fn value(
62        &self,
63        version: usize,
64        f: impl FnOnce() -> T,
65    ) -> Result<CacheValueType<T>, impl FnOnce() -> T> {
66        let v = self.value.load();
67        match v.version.cmp(&version) {
68            std::cmp::Ordering::Equal => Ok(v),
69            std::cmp::Ordering::Greater => Err(f), //requesting too old version
70            std::cmp::Ordering::Less => {
71                // try to update
72                drop(v);
73                match self.is_updating.compare_exchange(
74                    false,
75                    true,
76                    Ordering::SeqCst,
77                    Ordering::SeqCst,
78                ) {
79                    Ok(_) => {
80                        let v = self.value.load();
81                        match v.version.cmp(&version) {
82                            std::cmp::Ordering::Equal => {
83                                // already updated by someone else to the version we need
84                                self.finish_update();
85                                Ok(v)
86                            }
87                            std::cmp::Ordering::Greater => {
88                                // already updated by someone else beyond the version we need
89                                self.finish_update();
90                                Err(f)
91                            }
92                            std::cmp::Ordering::Less => {
93                                drop(v);
94                                self.value.store(
95                                    CacheValue {
96                                        value: f(),
97                                        version,
98                                    }
99                                    .into(),
100                                );
101                                let v = self.value.load(); // is_updating set to true guarantees that nobody else will modify the value.
102                                self.finish_update();
103                                Ok(v)
104                            }
105                        }
106                    }
107                    Err(_) => Err(f),
108                }
109            }
110        }
111    }
112}
113
#[cfg(test)]
mod tests {
    use std::sync::{mpsc, Arc};

    use super::Cache;

    /// Requests `version` from `cache`, offering `fallback` as the value to compute.
    /// Returns a copy of the string handed back by the cache, or `None` if the request
    /// was refused.
    fn request(cache: &Cache<String>, version: usize, fallback: &str) -> Option<String> {
        cache
            .value(version, || fallback.to_string())
            .ok()
            .map(|entry| entry.get_ref().clone())
    }

    #[test]
    fn test_cache() {
        let cache = Cache::<String>::new("0".to_string(), 0);

        // Same version: the cached value is served and the closure is ignored.
        assert_eq!(request(&cache, 0, "1").as_deref(), Some("0"));
        // Newer version: the value is computed and stored.
        assert_eq!(request(&cache, 1, "1").as_deref(), Some("1"));
        // Older version: refused.
        assert!(cache.value(0, || { "2".to_string() }).is_err());

        // try to get-update value from another thread
        let cache = Arc::new(cache);
        let cache2 = Arc::clone(&cache);
        // Handshake channels replace timing assumptions: the updater reports when it is
        // inside the closure and then waits until the main thread lets it finish.
        let (started_tx, started_rx) = mpsc::channel::<()>();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let updater = std::thread::spawn(move || {
            let res = cache2.value(2, move || {
                started_tx
                    .send(())
                    .expect("main thread should be waiting for the start signal");
                release_rx
                    .recv()
                    .expect("main thread should release the update");
                "2".to_string()
            });
            assert_eq!(
                res.as_ref().map(|v| v.get_ref().as_str()).unwrap_or(""),
                "2"
            );
        });

        // While the other thread is computing version 2, a concurrent request is refused.
        started_rx
            .recv()
            .expect("updater should start computing version 2");
        assert!(cache.value(2, || "".to_string()).is_err());

        release_tx
            .send(())
            .expect("updater should be waiting for the release signal");
        // Joining makes an assertion failure inside the updater fail this test.
        updater.join().expect("updater thread panicked");

        assert_eq!(request(&cache, 2, "").as_deref(), Some("2"));
    }
}