ration/
array.rs

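//! Shared Memory Array
//!
//! A fixed-capacity ring buffer ([`Array`]) stored in a named shared-memory block, so that one
//! process can push elements that another process pops.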
5use std::{path::Path, sync::atomic::{AtomicIsize, AtomicU8, Ordering}};
6
7use crate::{Error, Result};
8
9
10
11/// A shared array that can store `capacity` elements of type `T`.
12///
13/// # Example
14/// *In your "parent" process:*
15/// ```no_run
16/// use ration::Array;
17///
18/// let mut array: Array<char> = Array::alloc("/tmp/MY_ARRAY", 6).unwrap();
19/// array.push_many("ration".chars());
20/// ```
21/// *In your "child" process:*
22/// ```no_run
23/// use ration::Array;
24///
25/// let mut array: Array<char> = Array::open("/tmp/MY_ARRAY").unwrap();
26///
27/// let mut s = String::new();
28/// while let Some(c) = array.pop() {
29///     s.push(c);
30/// }
31/// println!("MY_ARRAY: {}", s); // "ration"
32/// ```
33// TODO: Some sort of mutable access check.
34pub struct Array<T: Sized> {
35    shm: shared_memory::Shmem,
36
37    empty_flag: *mut AtomicU8,
38    base: *mut Option<T>,
39    capacity: isize,
40    first: isize,
41    last: *mut AtomicIsize,
42    len: *mut AtomicIsize,
43}
44
45impl<T: Sized> Array<T> {
46    /// Allocate an array to shared memory identified by the given path, with the given capacity.
47    pub fn alloc(path: impl AsRef<Path>, capacity: usize) -> Result<Self> {
48        let block_size
49            = (std::mem::size_of::<Option<T>>() * capacity) // elements
50            + std::mem::size_of::<AtomicU8>()               // empty_flag
51            + (std::mem::size_of::<AtomicIsize>() * 2);     // last & len
52
53        let shm = match shared_memory::ShmemConf::new().flink(&path).size(block_size).create() {
54            Ok(shmem) => shmem,
55            Err(shared_memory::ShmemError::LinkExists) => {
56                return Err(Error::BlockAlreadyAllocated);
57            }
58            Err(e) => { return Err(Error::Shm(e)); }
59        };
60
61        unsafe {
62            let empty_flag = shm.as_ptr() as *mut AtomicU8;
63            let len = empty_flag.offset(1) as *mut AtomicIsize;
64            let first = 1;
65            let last = len.offset(1);
66            let base = len.offset(2) as *mut Option<T>;
67            let capacity = capacity as isize;
68
69            (&*len).store(0, Ordering::SeqCst);
70            (&*last).store(first, Ordering::SeqCst);
71            for i in 0..capacity {
72                base.offset(i).write(None);
73            }
74
75            Ok(Self {
76                shm,
77                empty_flag,
78                base,
79                capacity,
80                first,
81                last,
82                len,
83            })
84        }
85    }
86
87    /// Open an array in shared memory identified by the given path.
88    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
89        let shm = shared_memory::ShmemConf::new()
90            .flink(path)
91            .open()
92            .map_err(|e| Error::Shm(e))?;
93
94        let metadata_size
95            = std::mem::size_of::<AtomicU8>()               // empty_flag
96            + (std::mem::size_of::<AtomicIsize>() * 2);     // last & len
97
98        let array_size = shm.len() - metadata_size;
99        let slot_size = std::mem::size_of::<Option<T>>();
100        let capacity = array_size / slot_size;
101
102        unsafe {
103            let empty_flag = shm.as_ptr() as *mut AtomicU8;
104            let len = empty_flag.offset(1) as *mut AtomicIsize;
105            let first = 1;
106            let last = len.offset(1);
107            let base = len.offset(2) as *mut Option<T>;
108            let capacity = capacity as isize;
109
110            Ok(Self {
111                shm,
112                empty_flag,
113                base,
114                capacity,
115                first,
116                last,
117                len,
118            })
119        }
120    }
121
122    /// Returns `true` if the array contains no elements.
123    pub fn is_empty(&self) -> bool {
124        unsafe { &*self.empty_flag }.load(Ordering::Relaxed) == 0
125    }
126
127    /// Returns the number of array slots that are empty.
128    pub fn slots_remaining(&self) -> usize {
129        (self.capacity - unsafe { &*self.len }.load(Ordering::SeqCst)).unsigned_abs()
130    }
131
132    /// Push an element to the back of the array.
133    pub fn push(&mut self, element: T) -> bool {
134        // Ensure the internal ring buffer isn't full.
135        let count = unsafe { &*self.len }.fetch_add(1, Ordering::SeqCst);
136        if count >= self.capacity {
137            // The buffer is full; give up.
138            unsafe { &*self.len }.fetch_sub(1, Ordering::SeqCst);
139            return false;
140        }
141
142        self.push_unchecked(element);
143
144        // Signal.
145        unsafe { &mut *self.empty_flag }.store(1, Ordering::Relaxed);
146
147        true
148    }
149
150    /// Push an iterator of elements to the back of the array.
151    pub fn push_many(&mut self, elements: impl IntoIterator<Item = T>) {
152        let slots_remaining = self.slots_remaining();
153        for element in elements.into_iter().take(slots_remaining) {
154            let _ = unsafe { &*self.len }.fetch_add(1, Ordering::SeqCst);
155            self.push_unchecked(element);
156        }
157
158        // Signal.
159        unsafe { &mut *self.empty_flag }.store(1, Ordering::Relaxed);
160    }
161
162    /// Push an element to the back of the array without checking for overflows, raising the empty
163    /// flag, or checking access.
164    pub fn push_unchecked(&mut self, element: T) {
165        // Get the next available index, wrapping if need be.
166        let index = unsafe { &*self.last }.fetch_add(1, Ordering::SeqCst) % self.capacity;
167        if index == 0 {
168            // Just mod on overflow; the buffer is circular.
169            unsafe { &*self.last }.fetch_sub(self.capacity, Ordering::SeqCst);
170        }
171
172        // Write the element into the shared memory.
173        unsafe {
174            self.base.offset(index).write(Some(element));
175        }
176    }
177
178    /// Push an iterator of elements to the back of the array without checking for overflows,
179    /// raising the empty flag, or checking access.
180    pub fn push_many_unchecked(&mut self, elements: impl Iterator<Item = T>) {
181        for elem in elements {
182            self.push_unchecked(elem)
183        }
184    }
185
186    /// Pop an element from the front of the array.
187    pub fn pop(&mut self) -> Option<T> {
188        if self.is_empty() {
189            return None;
190        }
191
192        let result = unsafe { &mut *self.base.offset(self.first) }.take();
193        if !result.is_none() {
194            self.first = (self.first + 1) % self.capacity;
195            unsafe { &*self.len }.fetch_sub(1, Ordering::SeqCst);
196        } else {
197            // Signal.
198            unsafe { &mut *self.empty_flag }.store(0, Ordering::Relaxed);
199        }
200
201        result
202    }
203
204    /// Pop an element from the front of the array without checking for overflows, raising the
205    /// empty flag, or checking access.
206    pub fn pop_unchecked(&mut self) -> Option<T> {
207        let result = unsafe { &mut *self.base.offset(self.first) }.take();
208        if !result.is_none() {
209            self.first = (self.first + 1) % self.capacity;
210            unsafe { &*self.len }.fetch_sub(1, Ordering::SeqCst);
211        }
212
213        result
214    }
215}
216
217impl<T> Array<T> {
218    /// Returns `true` if the underlying shared memory mapping is owned by this array instance.
219    pub fn is_owner(&self) -> bool {
220        self.shm.is_owner()
221    }
222}
223
224impl<T: std::fmt::Debug> std::fmt::Debug for Array<T> {
225    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226        f.debug_struct("Array")
227            .field("capacity", &self.capacity)
228            .field("len", &self.len)
229            // .field("elements", self.iter().collect())
230            .finish_non_exhaustive()
231    }
232}
233
234
235
236/// # Warning
237///
238/// This can be wildly unsafe if the array is being mutated while you are iterating over its
239/// elements. **Use at your own risk.**
240pub struct ArrayIter<'a, T> {
241    array: &'a Array<T>,
242    index: isize,
243    capacity: isize,
244    // FIXME: There is probably a better way to stop the iteration than counting up to length.
245    count: isize,
246    len: isize,
247}
248
249impl<'a, T> Iterator for ArrayIter<'a, T> {
250    type Item = &'a T;
251
252    fn next(&mut self) -> Option<Self::Item> {
253        let elem = unsafe { &*self.array.base.offset(self.index) }.as_ref();
254        if elem.is_some() {
255            self.index = (self.index + 1) % self.capacity;
256            self.count += 1;
257        }
258        if self.count > self.len {
259            // NOTE: This is a ring buffer, so the iterator will continue indefinitely if the
260            //       array is full without this check.
261            None
262        } else {
263            elem
264        }
265    }
266}
267
268/// # Warning
269///
270/// This can be wildly unsafe if the array is being mutated while you are iterating over its
271/// elements. **Use at your own risk.**
272pub struct ArrayIterMut<'a, T> {
273    array: &'a mut Array<T>,
274    index: isize,
275    capacity: isize,
276    // FIXME: There is probably a better way to stop the iteration than counting up to length.
277    count: isize,
278    len: isize,
279}
280
281impl<'a, T> Iterator for ArrayIterMut<'a, T> {
282    type Item = &'a mut T;
283
284    fn next(&mut self) -> Option<Self::Item> {
285        let elem = unsafe { &mut *self.array.base.offset(self.index) }.as_mut();
286        if elem.is_some() {
287            self.index = (self.index + 1) % self.capacity;
288            self.count += 1;
289        }
290        if self.count > self.len {
291            // NOTE: This is a ring buffer, so the iterator will continue indefinitely if the
292            //       array is full without this check.
293            None
294        } else {
295            elem
296        }
297    }
298}
299
300// Iteration methods.
301impl<T: Sized> Array<T> {
302    /// Iterate over this array's elements.
303    ///
304    /// # Warning
305    ///
306    /// This can be wildly unsafe if the array is being mutated while you are iterating over its
307    /// elements. **Use at your own risk.**
308    pub fn iter(&self) -> ArrayIter<'_, T> {
309        ArrayIter {
310            array: self,
311            index: self.first,
312            capacity: self.capacity,
313            count: 0,
314            len: unsafe { &*self.len }.load(Ordering::Relaxed),
315        }
316    }
317
318    /// Mutably iterate over this array's elements.
319    ///
320    /// # Warning
321    ///
322    /// This can be wildly unsafe if the array is being mutated while you are iterating over its
323    /// elements. **Use at your own risk.**
324    pub fn iter_mut(&mut self) -> ArrayIterMut<'_, T> {
325        let index = self.first.clone();
326        let capacity = self.capacity.clone();
327        let len = unsafe { &*self.len }.load(Ordering::Relaxed).clone();
328
329        ArrayIterMut {
330            array: self,
331            index,
332            capacity,
333            count: 0,
334            len,
335        }
336    }
337}
338
339
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn array_test_1() {
        let text = "Something...";

        let mut owner: Array<char> = Array::alloc("/tmp/TEST_ARRAY_1", 16).unwrap();
        assert!(owner.is_owner());
        assert!(owner.is_empty());

        owner.push_many(text.chars());
        assert!(!owner.is_empty());
        assert_eq!(owner.slots_remaining(), 4);

        {
            // A second handle to the same block sees what the owner pushed.
            let mut reader: Array<char> = Array::open("/tmp/TEST_ARRAY_1").unwrap();
            assert!(!reader.is_owner());
            assert!(!reader.is_empty());
            assert_eq!(owner.capacity, reader.capacity);

            // Pop until the reader reports `None`.
            let drained: String = std::iter::from_fn(|| reader.pop()).collect();

            assert_eq!(reader.slots_remaining(), 16);
            assert_eq!(drained, text.to_string());
        }

        assert!(owner.is_empty());
    }

    #[test]
    fn array_push_overflow() {
        let mut array: Array<u8> = Array::alloc("/tmp/TEST_ARRAY_OVERFLOW", 8).unwrap();

        // The first value that fails to push marks where the buffer filled up.
        let stopped_at = (0..16).find(|&value| !array.push(value)).unwrap_or(0);

        assert_eq!(stopped_at, 8);
        assert_eq!(array.slots_remaining(), 0);
    }

    #[test]
    fn array_slots_update_correctly() {
        let mut array: Array<u8> = Array::alloc("/tmp/TEST_ARRAY_SLOTSUPDATE", 8).unwrap();

        for value in 0..9 {
            if array.push(value) {
                continue;
            }

            // The buffer is full: drain four, then refill four.
            assert_eq!(array.slots_remaining(), 0);

            for expected in 0..4_u8 {
                let popped = array.pop().expect("array should have filled slots");
                assert_eq!(popped, expected);
                assert_eq!(array.slots_remaining(), usize::from(expected + 1));
            }
            for refill in (0..4_u8).rev() {
                assert!(array.push(refill));
                assert_eq!(array.slots_remaining(), usize::from(refill));
            }
        }
    }

    #[test]
    fn array_traverse_full() {
        let mut array = Array::alloc("/tmp/TEST_ARRAY_ITERFULL", 16).unwrap();
        // Only 16 characters fit; the 16th is 'j'.
        array.push_many("This is a test just to see if the array iterates correctly.".chars());

        let iter = array.iter();
        assert_eq!(iter.len, 16);

        let collected: String = iter.copied().collect();
        assert_eq!(collected, "This is a test j");
    }

    #[test]
    fn array_traverse_partial() {
        let mut array = Array::alloc("/tmp/TEST_ARRAY_ITERPARTIAL", 16).unwrap();
        array.push_many("Testing...".chars());

        let iter = array.iter();
        assert_eq!(iter.len, 10);
        assert_eq!(array.slots_remaining(), 6);

        let collected: String = iter.copied().collect();
        assert_eq!(collected, "Testing...");
    }

    #[test]
    fn array_traverse_mutable() {
        let mut array = Array::alloc("/tmp/TEST_ARRAY_ITERMUT", 16).unwrap();
        array.push_many([1, 2, 3, 4, 5]);

        array.iter_mut().for_each(|n| *n += 1);

        let incremented: Vec<_> = array.iter().copied().collect();
        assert_eq!(incremented, vec![2, 3, 4, 5, 6]);
    }
}