// ockam_core/access_control/cache.rs
use crate::compat::sync::Mutex;
use crate::{
    Address, IncomingAccessControl, LocalInfo, OutgoingAccessControl, RelayMessage, Route,
};
use alloc::vec::Vec;
use async_trait::async_trait;
use core::fmt::Debug;
use core::time::Duration;
use std::time::Instant;
9
/// Upper bound on the number of positive decisions a single cache remembers;
/// once exceeded, the oldest entry is dropped.
const CACHE_MAX_SIZE: usize = 10;
/// Lifetime of a cached positive decision, in whole seconds.
const CACHE_DURATION_SECS: u64 = 1;
12
13#[derive(Debug)]
14struct CacheEntry {
15 source: Address,
16 destination: Address,
17 onward_route: Route,
18 return_route: Route,
19 local_info: Vec<LocalInfo>,
20 timestamp: Instant,
21}
22
23impl CacheEntry {
24 fn from(relay_message: &RelayMessage) -> Self {
25 Self {
26 source: relay_message.source().clone(),
27 destination: relay_message.destination().clone(),
28 onward_route: relay_message.onward_route().clone(),
29 return_route: relay_message.return_route().clone(),
30 local_info: relay_message.local_message().local_info().to_vec(),
31 timestamp: Instant::now(),
32 }
33 }
34
35 fn is_expired(&self) -> bool {
37 self.timestamp.elapsed().as_secs() >= CACHE_DURATION_SECS
38 }
39
40 fn matches(&self, relay_message: &RelayMessage) -> bool {
43 self.source == *relay_message.source()
44 && self.destination == *relay_message.destination()
45 && self.onward_route == *relay_message.onward_route()
46 && self.return_route == *relay_message.return_route()
47 && self.local_info == relay_message.local_message().local_info()
48 }
49}
50
51#[derive(Debug)]
52struct Cache {
53 cache: Mutex<Vec<CacheEntry>>,
54}
55
56impl Cache {
57 pub fn new() -> Self {
58 Self {
59 cache: Mutex::new(Vec::new()),
60 }
61 }
62
63 pub fn exist_in_cache(&self, relay_message: &RelayMessage) -> bool {
65 let mut cache_guard = self.cache.lock().unwrap();
66 cache_guard
67 .iter()
68 .position(|entry| entry.matches(relay_message))
69 .map(|position| {
70 if cache_guard[position].is_expired() {
71 cache_guard.remove(position);
72 false
73 } else {
74 true
75 }
76 })
77 .unwrap_or(false)
78 }
79
80 pub fn add_authorized(&self, relay_message: &RelayMessage) {
82 let mut cache_guard = self.cache.lock().unwrap();
83 let position = cache_guard
84 .iter()
85 .position(|entry| entry.matches(relay_message));
86 if let Some(position) = position {
87 cache_guard.remove(position);
88 }
89 cache_guard.push(CacheEntry::from(relay_message));
90 if cache_guard.len() > CACHE_MAX_SIZE {
91 cache_guard.remove(0);
92 }
93 }
94}
95
96#[derive(Debug)]
101pub struct CachedIncomingAccessControl {
102 cache: Cache,
103 access_control: Box<dyn IncomingAccessControl>,
104}
105
106impl CachedIncomingAccessControl {
107 pub fn new(access_control: Box<dyn IncomingAccessControl>) -> Self {
109 Self {
110 cache: Cache::new(),
111 access_control,
112 }
113 }
114}
115
116#[async_trait]
117impl IncomingAccessControl for CachedIncomingAccessControl {
118 async fn is_authorized(&self, relay_msg: &RelayMessage) -> crate::Result<bool> {
119 if self.cache.exist_in_cache(relay_msg) {
120 return crate::allow();
121 }
122 let is_authorized = self.access_control.is_authorized(relay_msg).await?;
123 if is_authorized {
124 self.cache.add_authorized(relay_msg);
125 crate::allow()
126 } else {
127 crate::deny()
128 }
129 }
130}
131
132#[derive(Debug)]
137pub struct CachedOutgoingAccessControl {
138 cache: Cache,
139 access_control: Box<dyn OutgoingAccessControl>,
140}
141
142impl CachedOutgoingAccessControl {
143 pub fn new(access_control: Box<dyn OutgoingAccessControl>) -> Self {
145 Self {
146 cache: Cache::new(),
147 access_control,
148 }
149 }
150}
151
152#[async_trait]
153impl OutgoingAccessControl for CachedOutgoingAccessControl {
154 async fn is_authorized(&self, relay_msg: &RelayMessage) -> crate::Result<bool> {
155 if self.cache.exist_in_cache(relay_msg) {
156 return crate::allow();
157 }
158 let is_authorized = self.access_control.is_authorized(relay_msg).await?;
159 if is_authorized {
160 self.cache.add_authorized(relay_msg);
161 crate::allow()
162 } else {
163 crate::deny()
164 }
165 }
166}
167
// Tests for the cached access-control wrappers and for `CacheEntry`.
#[cfg(test)]
#[allow(missing_docs)]
pub mod test {
    use crate::access_control::cache::{CacheEntry, CACHE_DURATION_SECS};
    use crate::{
        route, Address, IncomingAccessControl, LocalInfo, OutgoingAccessControl, RelayMessage,
    };
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use std::time::Instant;
    use tokio::time::sleep;

    // Test double: its decision is whatever the shared `authorized` flag holds,
    // so a test can flip the inner answer between calls.
    #[derive(Debug)]
    struct DebugAccessControl {
        authorized: Arc<AtomicBool>,
    }

    #[async_trait]
    impl IncomingAccessControl for DebugAccessControl {
        async fn is_authorized(&self, _relay_msg: &RelayMessage) -> crate::Result<bool> {
            Ok(self.authorized.load(Ordering::Relaxed))
        }
    }

    #[async_trait]
    impl OutgoingAccessControl for DebugAccessControl {
        async fn is_authorized(&self, _relay_msg: &RelayMessage) -> crate::Result<bool> {
            Ok(self.authorized.load(Ordering::Relaxed))
        }
    }
    // Builds a message with fixed routes and local info but random source and
    // destination addresses, so separate calls yield messages that (in practice)
    // do not match each other's cache entries.
    fn relay_message() -> RelayMessage {
        RelayMessage::new(
            Address::random_local(),
            Address::random_local(),
            crate::LocalMessage::new()
                .with_onward_route(route!["onward"])
                .with_return_route(route!["return"])
                .with_local_info(vec![LocalInfo::new("type".into(), vec![1, 2, 3])]),
        )
    }

    // Runs the same scenario against a cached wrapper type:
    // - denials are not cached;
    // - approvals are cached;
    // - cached approvals stop applying after CACHE_DURATION_SECS;
    // - the oldest entry is evicted once CACHE_MAX_SIZE newer ones are added.
    // Expands to `.await` calls, so it must be used inside an async test.
    macro_rules! access_policy_test {
        ($struct_name:tt) => {
            let authorized = Arc::new(AtomicBool::new(false));
            let access_control = DebugAccessControl {
                authorized: authorized.clone(),
            };

            let access_control = crate::$struct_name::new(Box::new(access_control));
            let relay_msg = relay_message();

            // Inner denies -> wrapper denies, and nothing is cached.
            assert!(!access_control.is_authorized(&relay_msg).await.unwrap());
            // Inner allows -> wrapper allows and caches the decision.
            authorized.store(true, Ordering::Relaxed);
            assert!(access_control.is_authorized(&relay_msg).await.unwrap());

            // Inner now denies, but the cached approval is still valid.
            authorized.store(false, Ordering::Relaxed);
            assert!(access_control.is_authorized(&relay_msg).await.unwrap());

            // Wait past the cache lifetime: the inner denial applies again.
            sleep(Duration::from_millis(CACHE_DURATION_SECS * 1000 + 100)).await;
            assert!(!access_control.is_authorized(&relay_msg).await.unwrap());

            // Re-approve (and cache) `relay_msg`, then add CACHE_MAX_SIZE approvals
            // for other messages, which pushes the `relay_msg` entry out.
            authorized.store(true, Ordering::Relaxed);
            assert!(access_control.is_authorized(&relay_msg).await.unwrap());
            for _ in 0..crate::access_control::cache::CACHE_MAX_SIZE {
                let different_relay_msg = relay_message();
                assert!(access_control
                    .is_authorized(&different_relay_msg)
                    .await
                    .unwrap());
            }
            // With its entry evicted, `relay_msg` is checked against the inner denial.
            authorized.store(false, Ordering::Relaxed);
            assert!(!access_control.is_authorized(&relay_msg).await.unwrap());
        };
    }

    #[tokio::test]
    pub async fn incoming_access_control() {
        access_policy_test!(CachedIncomingAccessControl);
    }

    #[tokio::test]
    pub async fn outgoing_access_control() {
        access_policy_test!(CachedOutgoingAccessControl);
    }

    // Checks which fields take part in matching:
    // - a different payload still matches;
    // - a different source, destination, onward route, return route or local
    //   info does not.
    #[test]
    pub fn cache_entry_matches() {
        let relay_msg = relay_message();

        let entry = CacheEntry::from(&relay_msg);
        assert!(entry.matches(&relay_msg));

        // Same message with a different payload: still matches.
        let cloned = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg.local_message().clone().with_payload(vec![1]),
        );
        assert!(entry.matches(&cloned));

        // Different source address: no match.
        let cloned = RelayMessage::new(
            Address::random_local(),
            relay_msg.destination().clone(),
            relay_msg.local_message().clone(),
        );
        assert!(!entry.matches(&cloned));

        // Different destination address: no match.
        let cloned = RelayMessage::new(
            relay_msg.source().clone(),
            Address::random_local(),
            relay_msg.local_message().clone(),
        );
        assert!(!entry.matches(&cloned));

        // Different onward route: no match.
        let cloned = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg
                .local_message()
                .clone()
                .with_onward_route(route!["different"]),
        );
        assert!(!entry.matches(&cloned));

        // Different return route: no match.
        let cloned = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg
                .local_message()
                .clone()
                .with_return_route(route!["different"]),
        );
        assert!(!entry.matches(&cloned));

        // Different local info: no match.
        let cloned = RelayMessage::new(
            relay_msg.source().clone(),
            relay_msg.destination().clone(),
            relay_msg
                .local_message()
                .clone()
                .with_local_info(vec![LocalInfo::new("type".into(), vec![1, 2, 3, 4])]),
        );
        assert!(!entry.matches(&cloned));
    }

    // An entry is fresh right after creation and expired once it is
    // CACHE_DURATION_SECS old.
    #[test]
    pub fn cache_entry_is_expired() {
        let entry = CacheEntry {
            source: Address::random_local(),
            destination: Address::random_local(),
            onward_route: route!["onward"],
            return_route: route!["return"],
            local_info: vec![],
            timestamp: Instant::now(),
        };

        assert!(!entry.is_expired());

        // Backdate the timestamp by exactly CACHE_DURATION_SECS.
        let entry = CacheEntry {
            timestamp: Instant::now() - Duration::from_secs(CACHE_DURATION_SECS),
            ..entry
        };
        assert!(entry.is_expired());
    }
}