ockam_core/access_control/
cache.rs

1use crate::compat::sync::Mutex;
2use crate::{
3    Address, IncomingAccessControl, LocalInfo, OutgoingAccessControl, RelayMessage, Route,
4};
5use alloc::vec::Vec;
6use async_trait::async_trait;
7use core::fmt::Debug;
8use std::time::Instant;
9
/// Upper bound on the number of entries kept by a [`Cache`]; once exceeded,
/// the entry at the front of the list (the oldest one) is dropped.
const CACHE_MAX_SIZE: usize = 10;

/// Lifetime of a cache entry, in seconds, counted from the moment it was created.
const CACHE_DURATION_SECS: u64 = 1;
12
13#[derive(Debug)]
14struct CacheEntry {
15    source: Address,
16    destination: Address,
17    onward_route: Route,
18    return_route: Route,
19    local_info: Vec<LocalInfo>,
20    timestamp: Instant,
21}
22
23impl CacheEntry {
24    fn from(relay_message: &RelayMessage) -> Self {
25        Self {
26            source: relay_message.source().clone(),
27            destination: relay_message.destination().clone(),
28            onward_route: relay_message.onward_route().clone(),
29            return_route: relay_message.return_route().clone(),
30            local_info: relay_message.local_message().local_info().to_vec(),
31            timestamp: Instant::now(),
32        }
33    }
34
35    /// Returns true if the cache entry is expired.
36    fn is_expired(&self) -> bool {
37        self.timestamp.elapsed().as_secs() >= CACHE_DURATION_SECS
38    }
39
40    /// Returns true if the relay message matches the cache entry.
41    /// Everything except the payload is compared.
42    fn matches(&self, relay_message: &RelayMessage) -> bool {
43        self.source == *relay_message.source()
44            && self.destination == *relay_message.destination()
45            && self.onward_route == *relay_message.onward_route()
46            && self.return_route == *relay_message.return_route()
47            && self.local_info == relay_message.local_message().local_info()
48    }
49}
50
51#[derive(Debug)]
52struct Cache {
53    cache: Mutex<Vec<CacheEntry>>,
54}
55
56impl Cache {
57    pub fn new() -> Self {
58        Self {
59            cache: Mutex::new(Vec::new()),
60        }
61    }
62
63    /// Returns true if the relay message is in the cache and not expired.
64    pub fn exist_in_cache(&self, relay_message: &RelayMessage) -> bool {
65        let mut cache_guard = self.cache.lock().unwrap();
66        cache_guard
67            .iter()
68            .position(|entry| entry.matches(relay_message))
69            .map(|position| {
70                if cache_guard[position].is_expired() {
71                    cache_guard.remove(position);
72                    false
73                } else {
74                    true
75                }
76            })
77            .unwrap_or(false)
78    }
79
80    /// Adds the relay message to the cache.
81    pub fn add_authorized(&self, relay_message: &RelayMessage) {
82        let mut cache_guard = self.cache.lock().unwrap();
83        let position = cache_guard
84            .iter()
85            .position(|entry| entry.matches(relay_message));
86        if let Some(position) = position {
87            cache_guard.remove(position);
88        }
89        cache_guard.push(CacheEntry::from(relay_message));
90        if cache_guard.len() > CACHE_MAX_SIZE {
91            cache_guard.remove(0);
92        }
93    }
94}
95
96/// A wrapper for an incoming access control that caches successful authorizations.
97/// The message is considered the same if everything except the payload is the same.
98/// Keeps a cache of the last [`CACHE_MAX_SIZE`] authorized messages with validity of
99/// [`CACHE_DURATION_SECS`] seconds.
100#[derive(Debug)]
101pub struct CachedIncomingAccessControl {
102    cache: Cache,
103    access_control: Box<dyn IncomingAccessControl>,
104}
105
106impl CachedIncomingAccessControl {
107    /// Wraps an incoming access control with a cache.
108    pub fn new(access_control: Box<dyn IncomingAccessControl>) -> Self {
109        Self {
110            cache: Cache::new(),
111            access_control,
112        }
113    }
114}
115
116#[async_trait]
117impl IncomingAccessControl for CachedIncomingAccessControl {
118    async fn is_authorized(&self, relay_msg: &RelayMessage) -> crate::Result<bool> {
119        if self.cache.exist_in_cache(relay_msg) {
120            return crate::allow();
121        }
122        let is_authorized = self.access_control.is_authorized(relay_msg).await?;
123        if is_authorized {
124            self.cache.add_authorized(relay_msg);
125            crate::allow()
126        } else {
127            crate::deny()
128        }
129    }
130}
131
132/// A wrapper for an outgoing access control that caches successful authorizations.
133/// The message is considered the same if everything except the payload is the same.
134/// Keeps a cache of the last [`CACHE_MAX_SIZE`] authorized messages with validity of
135/// [`CACHE_DURATION_SECS`] seconds.
136#[derive(Debug)]
137pub struct CachedOutgoingAccessControl {
138    cache: Cache,
139    access_control: Box<dyn OutgoingAccessControl>,
140}
141
142impl CachedOutgoingAccessControl {
143    /// Wraps an outgoing access control with a cache.
144    pub fn new(access_control: Box<dyn OutgoingAccessControl>) -> Self {
145        Self {
146            cache: Cache::new(),
147            access_control,
148        }
149    }
150}
151
152#[async_trait]
153impl OutgoingAccessControl for CachedOutgoingAccessControl {
154    async fn is_authorized(&self, relay_msg: &RelayMessage) -> crate::Result<bool> {
155        if self.cache.exist_in_cache(relay_msg) {
156            return crate::allow();
157        }
158        let is_authorized = self.access_control.is_authorized(relay_msg).await?;
159        if is_authorized {
160            self.cache.add_authorized(relay_msg);
161            crate::allow()
162        } else {
163            crate::deny()
164        }
165    }
166}
167
#[cfg(test)]
#[allow(missing_docs)]
pub mod test {
    use crate::access_control::cache::{CacheEntry, CACHE_DURATION_SECS};
    use crate::{
        route, Address, IncomingAccessControl, LocalInfo, OutgoingAccessControl, RelayMessage,
    };
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use std::time::Instant;
    use tokio::time::sleep;

    /// Access control stub whose verdict is whatever the shared flag currently holds.
    #[derive(Debug)]
    struct DebugAccessControl {
        authorized: Arc<AtomicBool>,
    }

    #[async_trait]
    impl IncomingAccessControl for DebugAccessControl {
        async fn is_authorized(&self, _relay_msg: &RelayMessage) -> crate::Result<bool> {
            let verdict = self.authorized.load(Ordering::Relaxed);
            Ok(verdict)
        }
    }

    #[async_trait]
    impl OutgoingAccessControl for DebugAccessControl {
        async fn is_authorized(&self, _relay_msg: &RelayMessage) -> crate::Result<bool> {
            let verdict = self.authorized.load(Ordering::Relaxed);
            Ok(verdict)
        }
    }

    /// Builds a relay message with random source and destination addresses and
    /// fixed routes and local info.
    fn relay_message() -> RelayMessage {
        let local_message = crate::LocalMessage::new()
            .with_onward_route(route!["onward"])
            .with_return_route(route!["return"])
            .with_local_info(vec![LocalInfo::new("type".into(), vec![1, 2, 3])]);
        RelayMessage::new(
            Address::random_local(),
            Address::random_local(),
            local_message,
        )
    }

    // One scenario shared by the incoming and the outgoing wrapper; the macro
    // argument is the name of the wrapper type to exercise.
    macro_rules! access_policy_test {
        ($struct_name:tt) => {
            let authorized = Arc::new(AtomicBool::new(false));
            let inner = DebugAccessControl {
                authorized: authorized.clone(),
            };

            let cached = crate::$struct_name::new(Box::new(inner));
            let relay_msg = relay_message();

            // a denial is never cached
            assert!(!cached.is_authorized(&relay_msg).await.unwrap());
            authorized.store(true, Ordering::Relaxed);
            assert!(cached.is_authorized(&relay_msg).await.unwrap());

            // an approval is served from the cache
            authorized.store(false, Ordering::Relaxed);
            assert!(cached.is_authorized(&relay_msg).await.unwrap());

            // until it expires
            sleep(Duration::from_secs(CACHE_DURATION_SECS) + Duration::from_millis(100)).await;
            assert!(!cached.is_authorized(&relay_msg).await.unwrap());

            // the approval is cached again, then pushed out by newer entries
            authorized.store(true, Ordering::Relaxed);
            assert!(cached.is_authorized(&relay_msg).await.unwrap());
            for _ in 0..crate::access_control::cache::CACHE_MAX_SIZE {
                let filler_msg = relay_message();
                assert!(cached.is_authorized(&filler_msg).await.unwrap());
            }
            // the first relay message has been evicted
            authorized.store(false, Ordering::Relaxed);
            assert!(!cached.is_authorized(&relay_msg).await.unwrap());
        };
    }

    #[tokio::test]
    pub async fn incoming_access_control() {
        access_policy_test!(CachedIncomingAccessControl);
    }

    #[tokio::test]
    pub async fn outgoing_access_control() {
        access_policy_test!(CachedOutgoingAccessControl);
    }

    #[test]
    pub fn cache_entry_matches() {
        let relay_msg = relay_message();
        let entry = CacheEntry::from(&relay_msg);

        // the message the entry was built from matches
        assert!(entry.matches(&relay_msg));

        // the payload does not take part in the comparison
        let other_payload = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg.local_message().clone().with_payload(vec![1]),
        );
        assert!(entry.matches(&other_payload));

        // changing any single other field breaks the match

        // source
        let other_source = RelayMessage::new(
            Address::random_local(),
            relay_msg.destination().clone(),
            relay_msg.local_message().clone(),
        );
        assert!(!entry.matches(&other_source));

        // destination
        let other_destination = RelayMessage::new(
            relay_msg.source().clone(),
            Address::random_local(),
            relay_msg.local_message().clone(),
        );
        assert!(!entry.matches(&other_destination));

        // onward route
        let other_onward = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg
                .local_message()
                .clone()
                .with_onward_route(route!["different"]),
        );
        assert!(!entry.matches(&other_onward));

        // return route
        let other_return = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg
                .local_message()
                .clone()
                .with_return_route(route!["different"]),
        );
        assert!(!entry.matches(&other_return));

        // local info
        let other_local_info = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg
                .local_message()
                .clone()
                .with_local_info(vec![LocalInfo::new("type".into(), vec![1, 2, 3, 4])]),
        );
        assert!(!entry.matches(&other_local_info));
    }

    #[test]
    pub fn cache_entry_is_expired() {
        // an entry created just now is still valid
        let fresh = CacheEntry {
            source: Address::random_local(),
            destination: Address::random_local(),
            onward_route: route!["onward"],
            return_route: route!["return"],
            local_info: vec![],
            timestamp: Instant::now(),
        };
        assert!(!fresh.is_expired());

        // the same entry, back-dated by the full cache duration, is expired
        let stale = CacheEntry {
            timestamp: Instant::now() - Duration::from_secs(CACHE_DURATION_SECS),
            ..fresh
        };
        assert!(stale.is_expired());
    }
}