maelstrom_base/
ring_buffer.rs

//! A simple fixed-capacity ring buffer backed by an inline array, with serde support.
//!
//! Once the buffer holds `N` elements, each further push overwrites the oldest one.
2
3use serde::{
4    de::{Deserializer, SeqAccess, Visitor},
5    ser::SerializeSeq as _,
6    Deserialize, Serialize, Serializer,
7};
8use std::{
9    fmt::{self, Debug, Formatter},
10    iter::FusedIterator,
11    marker::PhantomData,
12    mem::{self, MaybeUninit},
13};
14
/// A fixed-capacity ring buffer holding at most `N` elements.
///
/// Elements are pushed at the back; once the buffer is full, each new element overwrites the
/// oldest one. Iteration, indexing, equality, `Debug`, and serialization all observe the
/// logical order (oldest first), independent of where elements sit in the backing array.
pub struct RingBuffer<T, const N: usize> {
    /// Backing storage. Invariant: slots `0..length` are initialized and all others are not.
    buf: [MaybeUninit<T>; N],
    /// Number of initialized elements; never exceeds `N`.
    length: usize,
    /// Physical index of the oldest element. Stays `0` until the buffer is full, after which it
    /// advances (modulo `N`) on every push.
    cursor: usize,
}
20
21impl<T, const N: usize> Default for RingBuffer<T, N> {
22    fn default() -> Self {
23        assert!(N > 0, "capacity must not be zero");
24        Self {
25            buf: [const { MaybeUninit::uninit() }; N],
26            length: 0,
27            cursor: 0,
28        }
29    }
30}
31
32impl<T: Clone, const N: usize> Clone for RingBuffer<T, N> {
33    fn clone(&self) -> Self {
34        Self::from_iter(self.iter().cloned())
35    }
36}
37
38impl<T, const N: usize> Drop for RingBuffer<T, N> {
39    fn drop(&mut self) {
40        for i in self.cursor..self.length {
41            unsafe { self.buf[i].assume_init_drop() };
42        }
43        for i in 0..self.cursor {
44            unsafe { self.buf[i].assume_init_drop() };
45        }
46    }
47}
48
49impl<T: PartialEq, const N: usize> PartialEq for RingBuffer<T, N> {
50    fn eq(&self, other: &Self) -> bool {
51        self.iter().eq(other.iter())
52    }
53}
54
55impl<T: Eq, const N: usize> Eq for RingBuffer<T, N> {}
56
57impl<T: Debug, const N: usize> Debug for RingBuffer<T, N> {
58    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
59        f.debug_list().entries(self.iter()).finish()
60    }
61}
62
63impl<T, const N: usize> FromIterator<T> for RingBuffer<T, N> {
64    fn from_iter<IterT: IntoIterator<Item = T>>(iter: IterT) -> Self {
65        let mut result = Self::default();
66        for item in iter.into_iter() {
67            result.push(item)
68        }
69        result
70    }
71}
72
73impl<T, const N: usize> RingBuffer<T, N> {
74    pub fn push(&mut self, element: T) {
75        if self.length < N {
76            self.buf[self.length].write(element);
77            self.length += 1;
78        } else {
79            let old = unsafe { self.buf[self.cursor].assume_init_mut() };
80            *old = element;
81            self.cursor = (self.cursor + 1) % N;
82        }
83    }
84
85    pub fn len(&self) -> usize {
86        self.length
87    }
88
89    pub fn is_empty(&self) -> bool {
90        self.length == 0
91    }
92
93    pub fn iter(&self) -> Iter<'_, T, N> {
94        Iter { buf: self, idx: 0 }
95    }
96
97    pub fn get(&self, i: usize) -> Option<&T> {
98        if i >= self.length {
99            None
100        } else {
101            let i = if self.cursor < N - i {
102                self.cursor + i
103            } else {
104                self.cursor - (N - i)
105            };
106            Some(unsafe { self.buf[i].assume_init_ref() })
107        }
108    }
109}
110
111#[derive(Clone)]
112pub struct Iter<'a, T, const N: usize> {
113    buf: &'a RingBuffer<T, N>,
114    idx: usize,
115}
116
117impl<'a, T, const N: usize> Iterator for Iter<'a, T, N> {
118    type Item = &'a T;
119
120    fn next(&mut self) -> Option<Self::Item> {
121        let result = self.buf.get(self.idx);
122        if result.is_some() {
123            self.idx += 1;
124        }
125        result
126    }
127
128    fn size_hint(&self) -> (usize, Option<usize>) {
129        let size = self.len();
130        (size, Some(size))
131    }
132}
133
134impl<T, const N: usize> ExactSizeIterator for Iter<'_, T, N> {
135    fn len(&self) -> usize {
136        self.buf.length - self.idx
137    }
138}
139
140impl<T, const N: usize> FusedIterator for Iter<'_, T, N> {}
141
142impl<'a, T, const N: usize> IntoIterator for &'a RingBuffer<T, N> {
143    type Item = &'a T;
144    type IntoIter = Iter<'a, T, N>;
145    fn into_iter(self) -> Self::IntoIter {
146        self.iter()
147    }
148}
149
150impl<T, const N: usize> IntoIterator for RingBuffer<T, N> {
151    type Item = T;
152    type IntoIter = IntoIter<T, N>;
153    fn into_iter(mut self) -> Self::IntoIter {
154        let buf = mem::replace(&mut self.buf, [const { MaybeUninit::uninit() }; N]);
155        let length = mem::take(&mut self.length);
156        let cursor = mem::take(&mut self.cursor);
157        if length < N {
158            IntoIter {
159                buf,
160                first_beg: 0,
161                first_end: length,
162                rest_beg: 0,
163                rest_end: 0,
164            }
165        } else {
166            IntoIter {
167                buf,
168                first_beg: cursor,
169                first_end: N,
170                rest_beg: 0,
171                rest_end: cursor,
172            }
173        }
174    }
175}
176
/// Owning iterator over a [`RingBuffer`], yielding elements by value, oldest first.
///
/// Produced by `RingBuffer::into_iter`. Elements not yet yielded are dropped along with the
/// iterator. Invariant: the slots in `first_beg..first_end` and `rest_beg..rest_end` are
/// initialized and still owned by the iterator; all other slots are logically uninitialized.
pub struct IntoIter<T, const N: usize> {
    /// Storage taken over from the consumed buffer.
    buf: [MaybeUninit<T>; N],
    /// Start (inclusive) of the first run of pending slots.
    first_beg: usize,
    /// End (exclusive) of the first run of pending slots.
    first_end: usize,
    /// Start (inclusive) of the second run, drained after the first one is exhausted.
    rest_beg: usize,
    /// End (exclusive) of the second run.
    rest_end: usize,
}
184
185impl<T, const N: usize> Drop for IntoIter<T, N> {
186    fn drop(&mut self) {
187        for i in self.first_beg..self.first_end {
188            unsafe { self.buf[i].assume_init_drop() };
189        }
190        for i in self.rest_beg..self.rest_end {
191            unsafe { self.buf[i].assume_init_drop() };
192        }
193    }
194}
195
196impl<T, const N: usize> Iterator for IntoIter<T, N> {
197    type Item = T;
198
199    fn next(&mut self) -> Option<Self::Item> {
200        if self.first_beg < self.first_end {
201            let loc = mem::replace(&mut self.buf[self.first_beg], MaybeUninit::uninit());
202            self.first_beg += 1;
203            Some(unsafe { loc.assume_init() })
204        } else if self.rest_beg < self.rest_end {
205            let loc = mem::replace(&mut self.buf[self.rest_beg], MaybeUninit::uninit());
206            self.rest_beg += 1;
207            Some(unsafe { loc.assume_init() })
208        } else {
209            None
210        }
211    }
212
213    fn size_hint(&self) -> (usize, Option<usize>) {
214        let size = self.len();
215        (size, Some(size))
216    }
217}
218
219impl<T, const N: usize> ExactSizeIterator for IntoIter<T, N> {
220    fn len(&self) -> usize {
221        (self.first_end - self.first_beg) + (self.rest_end - self.rest_beg)
222    }
223}
224
225impl<T, const N: usize> FusedIterator for IntoIter<T, N> {}
226
227impl<T, const N: usize> Serialize for RingBuffer<T, N>
228where
229    T: Serialize,
230{
231    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
232    where
233        S: Serializer,
234    {
235        let mut seq = serializer.serialize_seq(Some(self.len()))?;
236        for element in self {
237            seq.serialize_element(element)?;
238        }
239        seq.end()
240    }
241}
242
243struct DeserializeVisitor<T, const N: usize>(PhantomData<fn() -> RingBuffer<T, N>>);
244
245impl<T, const N: usize> DeserializeVisitor<T, N> {
246    fn new() -> Self {
247        Self(PhantomData)
248    }
249}
250
251impl<'de, T: Deserialize<'de>, const N: usize> Visitor<'de> for DeserializeVisitor<T, N> {
252    type Value = RingBuffer<T, N>;
253
254    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
255        formatter.write_str("a RingBuffer sequence")
256    }
257
258    fn visit_seq<A: SeqAccess<'de>>(self, mut access: A) -> Result<Self::Value, A::Error> {
259        let mut result = RingBuffer::default();
260        while let Some(elem) = access.next_element()? {
261            result.push(elem);
262        }
263        Ok(result)
264    }
265}
266
267// This is the trait that informs Serde how to deserialize MyMap.
268impl<'de, T, const N: usize> Deserialize<'de> for RingBuffer<T, N>
269where
270    T: Deserialize<'de>,
271{
272    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
273    where
274        D: Deserializer<'de>,
275    {
276        deserializer.deserialize_seq(DeserializeVisitor::new())
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use serde_test::{assert_tokens, Token};
    use std::{cell::RefCell, mem, rc::Rc};

    /// Fills a buffer of capacity `N`, then keeps pushing past capacity, checking the length and
    /// the logical contents after every push.
    #[track_caller]
    fn insert_test<const N: usize>() {
        let mut r = RingBuffer::<usize, N>::default();

        for i in 0..N {
            assert_eq!(r.len(), i);
            r.push(i);
        }
        assert_eq!(r.len(), N);
        assert_eq!(Vec::from_iter(r.iter().copied()), Vec::from_iter(0..N));

        // Every further push evicts the oldest element, so the window slides by one.
        for i in N..=(N * 2) {
            r.push(i);
            assert_eq!(r.len(), N);
            assert_eq!(
                Vec::from_iter(r.iter().copied()),
                Vec::from_iter(i - N + 1..=i),
            );
        }
    }

    macro_rules! insert_test {
        ($name:ident, $n:expr) => {
            #[test]
            fn $name() {
                insert_test::<$n>();
            }
        };
    }

    insert_test!(insert_test_1, 1);
    insert_test!(insert_test_2, 2);
    insert_test!(insert_test_3, 3);
    insert_test!(insert_test_10, 10);
    insert_test!(insert_test_100, 100);

    #[test]
    fn equal_with_different_cursor() {
        let mut r1 = RingBuffer::<usize, 3>::default();
        r1.push(1);
        r1.push(2);
        r1.push(3);

        // Same logical contents, but the extra push has advanced r2's internal cursor.
        let mut r2 = RingBuffer::<usize, 3>::default();
        r2.push(1);
        r2.push(1);
        r2.push(2);
        r2.push(3);

        assert_eq!(r1, r2);
    }

    #[test]
    fn not_equal_with_different_elements() {
        let mut r1 = RingBuffer::<usize, 4>::default();
        r1.push(1);

        let mut r2 = RingBuffer::<usize, 4>::default();
        r2.push(2);

        assert_ne!(r1, r2);
    }

    #[test]
    fn equal_with_same_elements() {
        let mut r1 = RingBuffer::<usize, 3>::default();
        r1.push(1);

        let mut r2 = RingBuffer::<usize, 3>::default();
        r2.push(1);

        assert_eq!(r1, r2);
    }

    #[test]
    fn debug_fmt() {
        assert_eq!(
            format!("{:?}", RingBuffer::<_, 2>::from_iter::<[usize; 0]>([])),
            "[]",
        );
        assert_eq!(format!("{:?}", RingBuffer::<_, 2>::from_iter([1])), "[1]");
        assert_eq!(
            format!("{:?}", RingBuffer::<_, 2>::from_iter([1, 2])),
            "[1, 2]",
        );
        assert_eq!(
            format!("{:?}", RingBuffer::<_, 2>::from_iter([1, 2, 3])),
            "[2, 3]",
        );
    }

    #[test]
    fn serialize_deserialize_empty() {
        let r = RingBuffer::<i32, 5>::from_iter([]);
        assert_tokens(&r, &[Token::Seq { len: Some(0) }, Token::SeqEnd])
    }

    #[test]
    fn serialize_deserialize_half_empty() {
        let r = RingBuffer::<i32, 5>::from_iter(1..=3);
        assert_tokens(
            &r,
            &[
                Token::Seq { len: Some(3) },
                Token::I32(1),
                Token::I32(2),
                Token::I32(3),
                Token::SeqEnd,
            ],
        )
    }

    #[test]
    fn serialize_deserialize_full() {
        let r = RingBuffer::<i32, 5>::from_iter(1..=5);
        assert_tokens(
            &r,
            &[
                Token::Seq { len: Some(5) },
                Token::I32(1),
                Token::I32(2),
                Token::I32(3),
                Token::I32(4),
                Token::I32(5),
                Token::SeqEnd,
            ],
        )
    }

    #[test]
    fn from_iterator() {
        let mut r = RingBuffer::<_, 3>::default();
        assert_eq!(r, RingBuffer::<_, 3>::from_iter([]));
        r.push(1);
        assert_eq!(r, RingBuffer::<_, 3>::from_iter([1]));
        r.push(2);
        assert_eq!(r, RingBuffer::<_, 3>::from_iter([1, 2]));
        r.push(3);
        assert_eq!(r, RingBuffer::<_, 3>::from_iter([1, 2, 3]));
        r.push(4);
        assert_eq!(r, RingBuffer::<_, 3>::from_iter([2, 3, 4]));
    }

    /// Element type that records its `value` in a shared log when dropped.
    struct TestStruct {
        value: i32,
        drop_log: Rc<RefCell<Vec<i32>>>,
    }

    impl Drop for TestStruct {
        fn drop(&mut self) {
            self.drop_log.borrow_mut().push(self.value);
        }
    }

    #[rstest]
    fn drop(#[values(0, 1, 2, 3, 4)] count: i32) {
        let log = Rc::new(RefCell::new(Vec::<_>::new()));
        let mut r = RingBuffer::<_, 2>::default();
        for i in 1..=count {
            r.push(TestStruct {
                value: i,
                drop_log: log.clone(),
            });
        }
        // Ignore the drops caused by overwriting; only the buffer's own `Drop` is under test.
        log.borrow_mut().clear();
        mem::drop(r);

        let expected = match count {
            0 => vec![],
            1 => vec![1],
            n => vec![n - 1, n],
        };
        assert_eq!(*log.borrow(), expected);
    }

    #[rstest]
    fn iter(#[values(0, 1, 2, 3, 4, 5)] count: usize) {
        let r = RingBuffer::<_, 2>::from_iter(0..count);
        let mut i = r.iter();
        if count > 1 {
            assert_eq!(i.len(), 2);
            assert_eq!(i.size_hint(), (2, Some(2)));
            assert_eq!(i.next(), Some(count - 2).as_ref());
        }
        if count > 0 {
            assert_eq!(i.len(), 1);
            assert_eq!(i.size_hint(), (1, Some(1)));
            assert_eq!(i.next(), Some(count - 1).as_ref());
        }
        // Once exhausted the iterator must stay exhausted (`FusedIterator`). A handful of extra
        // calls is enough to check that; a huge loop only slows the suite (badly under Miri).
        for _ in 0..100 {
            assert_eq!(i.len(), 0);
            assert_eq!(i.size_hint(), (0, Some(0)));
            assert_eq!(i.next(), None);
        }
    }

    #[rstest]
    fn into_iter(#[values(0, 1, 2, 3, 4, 5)] count: usize) {
        let r = RingBuffer::<_, 2>::from_iter(0..count);
        let expected = match count {
            0 => vec![],
            1 => vec![0],
            n => vec![n - 2, n - 1],
        };
        assert_eq!(r.into_iter().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    #[case(0, 0)]
    #[case(1, 0)]
    #[case(1, 1)]
    #[case(2, 0)]
    #[case(2, 1)]
    #[case(2, 2)]
    #[case(3, 0)]
    #[case(3, 1)]
    #[case(3, 2)]
    #[case(4, 0)]
    #[case(4, 1)]
    #[case(4, 2)]
    fn into_iter_drop(#[case] to_insert: i32, #[case] to_consume: i32) {
        let log = Rc::new(RefCell::new(Vec::<_>::new()));
        let mut r = RingBuffer::<_, 2>::default();
        for i in 0..to_insert {
            r.push(TestStruct {
                value: i,
                drop_log: log.clone(),
            });
        }
        let mut iter = r.into_iter();
        for _ in 0..to_consume {
            iter.next().unwrap();
        }
        // Only the elements still pending inside the iterator should be dropped below.
        log.borrow_mut().clear();
        mem::drop(iter);

        let expected = match to_insert.min(2) - to_consume {
            0 => vec![],
            1 => vec![to_insert - 1],
            _ => vec![to_insert - 2, to_insert - 1],
        };
        assert_eq!(*log.borrow(), expected);
    }

    #[rstest]
    fn get_maps_logical_indices(#[values(0, 1, 2, 3, 4, 5, 6, 7)] count: usize) {
        let r = RingBuffer::<_, 3>::from_iter(0..count);
        assert_eq!(r.is_empty(), count == 0);
        // Only the last three pushed values survive, and index 0 is the oldest of them.
        let oldest = count.saturating_sub(3);
        for (i, expected) in (oldest..count).enumerate() {
            assert_eq!(r.get(i), Some(&expected));
        }
        assert_eq!(r.get(r.len()), None);
        assert_eq!(r.get(usize::MAX), None);
    }

    #[rstest]
    fn clone_preserves_contents(#[values(0, 1, 2, 3, 4, 5)] count: usize) {
        let r = RingBuffer::<_, 3>::from_iter(0..count);
        let c = r.clone();
        assert_eq!(c, r);
        assert_eq!(
            Vec::from_iter(c.iter().copied()),
            Vec::from_iter(r.iter().copied()),
        );
    }

    #[rstest]
    fn into_iter_reports_exact_len(#[values(0, 1, 2, 3, 4, 5)] count: usize) {
        let mut iter = RingBuffer::<_, 2>::from_iter(0..count).into_iter();
        let mut remaining = count.min(2);
        loop {
            assert_eq!(iter.len(), remaining);
            assert_eq!(iter.size_hint(), (remaining, Some(remaining)));
            if iter.next().is_none() {
                break;
            }
            remaining -= 1;
        }
        assert_eq!(remaining, 0);
    }
}