ptr_array/
lib.rs

//! A stack array that stores `&T` or `&mut T` as ([TypeId], `*const ()`/`*mut ()`) pairs.
//! Useful for passing references through a function call whose arguments you do not fully control.
3//!
4//! # Example:
5//! ```rust
6//! # use ptr_array::PointerArray;
//! // a function with a fixed signature that you cannot change to accept arbitrary arguments.
8//! fn bar(_arg1: (), mut arg2: PointerArray<'_>) {
9//!     // remove &mut Vec<String> from array.
10//!     let arg3 = arg2.remove_mut::<Vec<String>>().unwrap();
11//!     assert_eq!("arg3", arg3.pop().unwrap());
12//! }
13//!
14//! # fn foo() {
15//! let ptr_array = PointerArray::new();
16//!
17//! let mut arg3 = vec![String::from("arg3")];
18//!
//! // insert &mut Vec<String> into the array. On Ok(_) the array itself is returned with the
//! // reference inserted.
21//! let ptr_array = ptr_array.insert_mut(&mut arg3).unwrap();
22//!
23//! // pass the array as argument to bar.
24//! bar((), ptr_array);
25//! # }
26//! ```
27
28#![no_std]
29
30use core::{
31    any::TypeId,
32    fmt,
33    hint::unreachable_unchecked,
34    marker::PhantomData,
35    mem::{self, MaybeUninit},
36};
37
38pub struct PointerArray<'a, const N: usize = 4> {
39    queue: [MaybeUninit<(TypeId, RefVariant)>; N],
40    head: usize,
41    _marker: PhantomData<&'a ()>,
42}
43
44impl Default for PointerArray<'_> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl<'a, const N: usize> PointerArray<'a, N> {
51    /// Construct an empty PointerArray.
52    ///
53    /// Array have a length of 4usize by default.
54    pub const fn new() -> Self {
55        PointerArray {
56            // SAFETY:
57            // [MaybeUninit<T>; N] is safe to assume from uninit.
58            queue: unsafe { MaybeUninit::uninit().assume_init() },
59            head: 0,
60            _marker: PhantomData,
61        }
62    }
63}
64
65/// Error type when stack array is full.
66pub struct Full<'a, const N: usize>(PointerArray<'a, N>);
67
68impl<'a, const N: usize> Full<'a, N> {
69    /// Retrieve array from Error.
70    pub fn into_inner(self) -> PointerArray<'a, N> {
71        self.0
72    }
73}
74
75impl<const N: usize> fmt::Debug for Full<'_, N> {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        write!(f, "Full(..)")
78    }
79}
80
81impl<const N: usize> fmt::Display for Full<'_, N> {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        write!(f, "PointerArray is full")
84    }
85}
86
87impl<'a, const N: usize> PointerArray<'a, N> {
88    /// Insert a &mut T to array. This method would consume Self.
89    ///
90    /// If the array have a &mut T already it would replace it and the old one is dropped.
91    /// *. only the &mut T would be dropped and destructor of T does not run.
92    ///
93    /// On Ok(Self) the array with reference successfully inserted would returned.
94    /// (Self's lifetime is shrink to the same of input &mut T).
95    ///
96    /// On Err([Full]) the array without modification would be returned. See [Full::into_inner] for
97    /// detail.
98    pub fn insert_mut<'b, T: 'static>(
99        mut self,
100        value: &'b mut T,
101    ) -> Result<PointerArray<'b, N>, Full<'a, N>>
102    where
103        'a: 'b,
104    {
105        let (id, opt) = self.try_find_one::<T>();
106
107        match opt {
108            Some(ptr) => {
109                let _ = mem::replace(ptr, value.into());
110            }
111            None => {
112                if self.head == N {
113                    return Err(Full(self));
114                }
115                // SAFETY:
116                // head just checked.
117                unsafe { self.write(id, value.into()) }
118            }
119        }
120
121        Ok(self)
122    }
123
124    /// The &T version of [Self::insert_mut]. See it for more detail.
125    pub fn insert_ref<'b, T: 'static>(
126        mut self,
127        value: &'b T,
128    ) -> Result<PointerArray<'b, N>, Full<'a, N>>
129    where
130        'a: 'b,
131    {
132        let (id, opt) = self.try_find_one::<T>();
133
134        match opt {
135            Some(ptr) => {
136                let _ = mem::replace(ptr, value.into());
137            }
138            None => {
139                if self.head == N {
140                    return Err(Full(self));
141                }
142                // SAFETY:
143                // head just checked.
144                unsafe { self.write(id, value.into()) }
145            }
146        }
147
148        Ok(self)
149    }
150
151    /// &T version of [Self::remove_mut]. See it for detail.
152    pub fn remove_ref<T: 'static>(&mut self) -> Option<&'a T> {
153        self.remove::<T, _, _>(|r| match r {
154            RefVariant::Ref(_) => match mem::replace(r, RefVariant::None) {
155                RefVariant::Ref(t) => Some(unsafe { &*(t as *const T) }),
156                // SAFETY:
157                // The branch was just checked before.
158                _ => unsafe { unreachable_unchecked() },
159            },
160            _ => None,
161        })
162    }
163
164    /// Remove &mut T from array. Return None when there is no according type in array.
165    pub fn remove_mut<T: 'static>(&mut self) -> Option<&'a mut T> {
166        self.remove::<T, _, _>(|r| match r {
167            RefVariant::Mut(_) => match mem::replace(r, RefVariant::None) {
168                RefVariant::Mut(t) => Some(unsafe { &mut *(t as *mut T) }),
169                // SAFETY:
170                // The branch was just checked before.
171                _ => unsafe { unreachable_unchecked() },
172            },
173            _ => None,
174        })
175    }
176
177    /// Get &T without removing it from the array.
178    ///
179    /// &T can be get from either insert with [Self::insert_mut] or [Self::insert_ref].
180    pub fn get<T: 'static>(&self) -> Option<&'a T> {
181        let id = TypeId::of::<T>();
182        self.try_find_init(&id).and_then(|r| match r {
183            RefVariant::Ref(t) => Some(unsafe { &*(*t as *const T) }),
184            RefVariant::Mut(t) => Some(unsafe { &*(*t as *mut T as *const T) }),
185            RefVariant::None => None,
186        })
187    }
188
189    /// Get &mut T without removing it from the array.
190    ///
191    /// &mut T can be get from insert with [Self::insert_mut].
192    pub fn get_mut<T: 'static>(&mut self) -> Option<&'a mut T> {
193        let id = TypeId::of::<T>();
194        self.try_find_init_mut(&id).and_then(|r| match r {
195            RefVariant::Mut(t) => Some(unsafe { &mut *(*t as *mut T) }),
196            _ => None,
197        })
198    }
199
200    fn try_find_one<T: 'static>(&mut self) -> (TypeId, Option<&mut RefVariant>) {
201        let id = TypeId::of::<T>();
202        let opt = self.try_find_init_mut(&id);
203        (id, opt)
204    }
205
206    fn remove<T, F, R>(&mut self, func: F) -> Option<R>
207    where
208        T: 'static,
209        R: 'a,
210        F: for<'r> Fn(&'r mut RefVariant) -> Option<R>,
211    {
212        let id = TypeId::of::<T>();
213        self.try_find_init_mut(&id).and_then(func)
214    }
215
216    fn try_find_init(&self, id: &TypeId) -> Option<&RefVariant> {
217        self.queue.iter().take(self.head).find_map(|v| {
218            // SAFETY:
219            // head tracks the items that are initialized and it's safe to assume.
220            let (i, opt) = unsafe { v.assume_init_ref() };
221            (i == id).then(|| opt)
222        })
223    }
224
225    fn try_find_init_mut(&mut self, id: &TypeId) -> Option<&mut RefVariant> {
226        self.queue.iter_mut().take(self.head).find_map(|v| {
227            // SAFETY:
228            // head tracks the items that are initialized and it's safe to assume.
229            let (i, opt) = unsafe { v.assume_init_mut() };
230            (i == id).then(|| opt)
231        })
232    }
233
234    // SAFETY:
235    // Caller must make sure Array is not full and head is not out of bound.
236    unsafe fn write(&mut self, id: TypeId, value: RefVariant) {
237        self.queue.get_unchecked_mut(self.head).write((id, value));
238        self.head += 1;
239    }
240}
241
/// A type-erased reference held in one slot of the array.
enum RefVariant {
    /// The slot's reference has been removed.
    None,
    /// An erased `&mut T`.
    Mut(*mut ()),
    /// An erased `&T`.
    Ref(*const ()),
}

impl<T> From<&T> for RefVariant {
    /// Erase a shared reference into [`RefVariant::Ref`].
    fn from(shared: &T) -> Self {
        let raw: *const T = shared;
        RefVariant::Ref(raw.cast())
    }
}

impl<T> From<&mut T> for RefVariant {
    /// Erase a unique reference into [`RefVariant::Mut`].
    fn from(unique: &mut T) -> Self {
        let raw: *mut T = unique;
        RefVariant::Mut(raw.cast())
    }
}
259
#[cfg(test)]
mod test {
    use super::*;

    extern crate alloc;

    use alloc::{boxed::Box, string::String, vec, vec::Vec};

    /// Round-trips `&mut String`, `&Box<str>` and `&mut Vec<String>` through a default-sized
    /// array, including across a function boundary.
    #[test]
    fn test() {
        let mut greeting = String::from("hello,string!");
        let boxed = String::from("hello,box!").into_boxed_str();

        let array = PointerArray::default();
        let array = array
            .insert_mut(&mut greeting)
            .unwrap()
            .insert_ref(&boxed)
            .unwrap();

        fn scope(mut array: PointerArray<'_>) {
            let boxed = array.remove_ref::<Box<str>>().unwrap();
            let greeting = array.remove_mut::<String>().unwrap();
            assert_eq!(greeting, "hello,string!");

            let mut list = vec![String::from("hello,string!")];
            let mut array = array.insert_mut(&mut list).unwrap();
            let list = array.remove_mut::<Vec<String>>().unwrap();
            assert_eq!(greeting, &list.pop().unwrap());

            assert_eq!(&**boxed, "hello,box!");
        }

        assert_eq!(&*boxed, "hello,box!");

        scope(array);

        assert_eq!(greeting, "hello,string!");
    }

    /// A single-slot array rejects a second, different type.
    #[test]
    fn out_of_bound() {
        let mut greeting = String::from("hello,string!");
        let boxed = String::from("hello,box!").into_boxed_str();

        let array = PointerArray::<1>::new().insert_mut(&mut greeting).unwrap();

        assert!(matches!(array.insert_ref(&boxed), Err(_)));
    }

    /// The array carried by a `Full` error still holds what was inserted before.
    #[test]
    fn error_retake() {
        let mut greeting = String::from("hello,string!");
        let boxed = String::from("hello,box!").into_boxed_str();

        let full = PointerArray::<1>::new()
            .insert_mut(&mut greeting)
            .unwrap()
            .insert_ref(&boxed)
            .err()
            .unwrap();
        let mut array = full.into_inner();

        assert_eq!(array.remove_mut::<String>().unwrap(), "hello,string!");
    }

    /// `remove_ref` / `remove_mut` only match entries of the same reference kind.
    #[test]
    fn ref_variant() {
        let mut greeting = String::from("hello,string!");
        let boxed = String::from("hello,box!").into_boxed_str();

        let mut array = PointerArray::<2>::new()
            .insert_mut(&mut greeting)
            .unwrap()
            .insert_ref(&boxed)
            .unwrap();

        // The wrong reference kind yields nothing...
        assert!(array.remove_ref::<String>().is_none());
        assert!(array.remove_mut::<Box<str>>().is_none());
        // ...while the matching kind succeeds.
        assert!(array.remove_mut::<String>().is_some());
        assert!(array.remove_ref::<Box<str>>().is_some());
    }

    /// `get` sees both kinds of entries; `get_mut` only sees `&mut` entries.
    #[test]
    fn get() {
        let mut greeting = String::from("hello,string!");
        let boxed = String::from("hello,box!").into_boxed_str();

        let mut array = PointerArray::<2>::new()
            .insert_mut(&mut greeting)
            .unwrap()
            .insert_ref(&boxed)
            .unwrap();

        assert!(array.get::<String>().is_some());
        assert!(array.get_mut::<String>().is_some());
        assert!(array.get::<Box<str>>().is_some());
        assert!(array.get_mut::<Box<str>>().is_none());
    }
}
365}