autoincrement/
atomic.rs

1use std::sync::atomic::{AtomicU16, AtomicU32, AtomicU64, AtomicU8, AtomicUsize};
2
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5
6/// Thread-safe container for keeping autoincrement counter
7#[derive(Debug, Clone)]
8#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
9pub struct AsyncIncrement<T: AsyncIncremental>(T::Atomic);
10
11impl<T: AsyncIncremental> AsyncIncrement<T> {
12    #[allow(clippy::should_implement_trait)]
13    pub fn pull(&self) -> T {
14        AsyncIncremental::get_next(&self.0)
15    }
16
17    pub fn init_with(initial_value: T) -> Self {
18        Self(AsyncIncremental::into_atomic(initial_value))
19    }
20
21    pub fn current(&self) -> T {
22        AsyncIncremental::from_inner(Atomic::current(&self.0))
23    }
24}
25
26/// Trait for implementing over thread-safe incremental types
27pub trait AsyncIncremental: Sized {
28    type Atomic: Atomic;
29
30    fn initial() -> Self;
31
32    fn get_next(atomic: &Self::Atomic) -> Self;
33
34    fn into_atomic(value: Self) -> Self::Atomic;
35
36    fn from_inner(inner: <Self::Atomic as Atomic>::Inner) -> Self;
37
38    fn init() -> AsyncIncrement<Self> {
39        Self::init_with(Self::initial())
40    }
41
42    fn init_with(value: Self) -> AsyncIncrement<Self> {
43        AsyncIncrement(Self::into_atomic(value))
44    }
45
46    fn init_from(self) -> AsyncIncrement<Self> {
47        Self::init_with(self)
48    }
49}
50
/// Abstraction over the atomic integer types used as counter storage.
///
/// Exists for type-safety only; users of the crate normally don't need it.
pub trait Atomic: Send + Sync + std::fmt::Debug {
    /// Plain integer type stored inside the atomic.
    type Inner: Copy;

    /// Wraps `initial_value` in a new atomic.
    fn new(initial_value: Self::Inner) -> Self;

    /// Reads the stored value.
    fn current(&self) -> Self::Inner;

    /// Adds `step` to the stored value and returns an `Inner` result.
    fn next(&self, step: Self::Inner) -> Self::Inner;
}
61
/// Implements [`Atomic`] for every listed `primitive => AtomicType` pair.
///
/// Each generated impl uses `Ordering::SeqCst`. `next` is a plain
/// `fetch_add`, so it returns the value held *before* the step is added and
/// wraps around on overflow, as `fetch_add` does.
macro_rules! impl_atomic {
    ($($basic:ty => $atomic:ty),+) => {
        $(
            impl Atomic for $atomic {
                type Inner = $basic;

                fn new(initial_value: Self::Inner) -> Self {
                    // Inherent `new` takes priority over this trait method,
                    // so this calls the std constructor (no recursion).
                    Self::new(initial_value)
                }

                fn next(&self, step: Self::Inner) -> Self::Inner {
                    let ordering = ::std::sync::atomic::Ordering::SeqCst;
                    self.fetch_add(step, ordering)
                }

                fn current(&self) -> Self::Inner {
                    let ordering = ::std::sync::atomic::Ordering::SeqCst;
                    self.load(ordering)
                }
            }
        )+
    };
}
86
87impl_atomic!(
88    u8 => AtomicU8,
89    u16 => AtomicU16,
90    u32 => AtomicU32,
91    u64 => AtomicU64,
92    usize => AtomicUsize
93);
94
#[cfg(test)]
mod tests {
    use crate as autoincrement;

    #[cfg(feature = "derive")]
    use autoincrement::AsyncIncremental;
    #[cfg(not(feature = "derive"))]
    use autoincrement_derive::AsyncIncremental;
    // A proc-macro crate exports only the derive macro, but the tests call the
    // trait's `init()`, which needs the trait itself in scope. The underscore
    // import avoids clashing with the macro of the same name.
    #[cfg(not(feature = "derive"))]
    use super::AsyncIncremental as _;

    /// Generates one test per backing integer type. Each test checks that a
    /// freshly initialised counter reports 1, 2, 3 from `current()` and that
    /// `pull()` returns that same value before the counter moves on.
    macro_rules! counter_test {
        ($($name:ident => $inner:ident),+ $(,)?) => {
            $(
                #[test]
                #[cfg(feature = "async")]
                fn $name() {
                    #[derive(AsyncIncremental, Debug, PartialEq, Eq)]
                    struct MyID($inner);

                    let counter = MyID::init();

                    for expected in 1..=3 {
                        assert_eq!(counter.current(), MyID(expected));
                        assert_eq!(counter.pull(), MyID(expected));
                    }
                }
            )+
        };
    }

    counter_test!(
        test_async_u8 => u8,
        test_async_u16 => u16,
        test_async_u32 => u32,
        test_async_u64 => u64,
        test_async_usize => usize,
    );
}