sharify/
lib.rs

1#![warn(clippy::all)]
2
3//! This crate allows backing types with shared memory to send them cheaply
4//! between processes. Here's an example of doing so with a slice:
5//!
6//! ```
7//! use sharify::SharedMut;
8//! use std::{iter, sync::mpsc::channel, thread};
9//!
10//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
11//! // Create a slice backed by shared memory.
12//! let mut shared_slice: SharedMut<[u64]> = SharedMut::new(&(0, 1_000_000))?;
13//! // Write some data to it.
14//! for (src, dst) in
15//!     iter::successors(Some(0), |&p| Some(p + 1))
16//!     .zip(shared_slice.as_view_mut().iter_mut())
17//! {
18//!     *dst = src;
19//! }
20//! // The shared slice can be sent between processes cheaply without copying the
21//! // data. What is shown here for threads works equally well for processes,
22//! // e.g. using the ipc_channel crate.
23//! let (tx, rx) = channel::<SharedMut<[u64]>>();
24//! let handle = thread::spawn(move || {
25//!     let shared_slice = rx.recv().unwrap();
26//!     // Get a view into the shared memory
27//!     let view: &[u64] = shared_slice.as_view();
28//!     assert_eq!(view.len(), 1_000_000);
29//!     assert!(iter::successors(Some(0), |&p| Some(p + 1))
30//!         .zip(view.iter())
31//!         .all(|(a, &b)| a == b));
32//! });
33//! tx.send(shared_slice)?;
34//! handle.join().unwrap();
35//! # Ok(())
36//! # }
37//! ```
38//!
39//! The [`Shared`] and [`SharedMut`] structs wrap types to be backed by shared
40//! memory. They handle cheap serialization / deserialization by only
41//! serializing the metadata required to recreate the struct on the
42//! deserialization side. As a result, [`Shared`] and [`SharedMut`] can be used
43//! with inter-process channels (e.g. the [ipc-channel](https://crates.io/crates/ipc-channel) crate) the
44//! same way that the wrapped types are used with Rust's
45//! [builtin](std::sync::mpsc::channel) or [crossbeam](https://crates.io/crates/crossbeam-channel) inter-thread channels without copying the
46//! underlying data.
47//!
48//! Memory is managed through reference counts in the underlying shared memory.
49//! The wrappers behave as follows:
50//!
51//! <table>
52//! <tr>
53//! <th>
54//! </th>
55//! <th>
56//! Mutability
57//! </th>
58//! <th>
59//!
60//! Trait bounds on `T`
61//!
62//! </th>
63//! <th>
64//! Ownership
65//! </th>
66//! <th>
67//! Shared memory freed when...
68//! </th>
69//! </tr>
70//!
71//!
72//! <tr>
73//! <td>
74//!
75//! [`Shared<T>`]
76//!
77//! </td>
78//! <td>
79//! Immutable
80//! </td>
81//! <td>
82//!
83//! [`ShmemBacked`] + [`ShmemView`]
84//!
85//! </td>
86//! <td>
87//!
88//! Multiple ownership tracked with refcount, implements [`Clone`] to create
89//! another instance backed by the same shared memory.
90//!
91//! </td>
92//! <td>
93//!
94//! ...an instance with exclusive ownership of the shared memory drops **and**
95//! the [*serialization count*](#-safety-and-the-serialization-count) is 0.
96//!
97//! </td>
98//! </tr>
99//!
100//!
101//! <tr>
102//! <td>
103//!
104//! [`SharedMut<T>`]
105//!
106//! </td>
107//! <td>
108//! Mutable
109//! </td>
110//! <td>
111//!
112//! [`ShmemBacked`] + [`ShmemView`] + [`ShmemViewMut`]
113//!
114//! </td>
115//! <td>
116//!
117//! Exclusive ownership, but implements [`TryInto<Shared>`].
118//!
119//! </td>
120//! <td>
121//!
122//! ...an instance drops without serialization.
123//!
124//! </td>
125//! </tr>
126//!
127//!
128//! </table>
129//!
130//!
131//! # ⚠️ Safety and the serialization count
132//!
133//! When serializing a [`Shared`]/[`SharedMut`] to send it between processes
134//! the underlying shared memory must not be freed. However, calling
135//! [`Shared::into_serialized`]/[`SharedMut::into_serialized`] consumes
136//! the `Self` instance acting as the memory RAII guard. As
137//! a result, not deserializing at the other end can **leak the shared memory**.
138//! While this is not inherently [`unsafe`](https://doc.rust-lang.org/stable/nomicon/what-unsafe-does.html),
139//! it must be kept in mind when serializing.
140//!
141//! [`Shared`]s keep count of how many instances accessing the same shared
142//! memory have been serialized without a matching deserialization. Only when
143//! this *serialization count* is 0, i.e. there are
144//! no 'dangling' serializations, will the shared memory be freed when an
145//! instance with exclusive ownership drops. This is necessary so that the
146//! shared memory persists when a [`Shared`] with exclusive access is
147//! serialized/deserialized while sent between processes.
148//!
149//! The downside is that this approach only allows usage patterns where each
150//! serialization is paired with exactly one deserialization. If multiple
151//! receivers deserialize a [`Shared`] from a single serialization and drop, the
152//! shared memory may be freed before other receivers attempt to deserialize a
153//! different serialization. See `tests/serialization_count.rs` for an example
154//! of this situation. The opposite scenario is also bad - serializing the same
155//! [`Shared`] instance multiple times through [`serde::Serialize::serialize`]
156//! without matching deserializations will likely leak memory.
157//!
158//!
159//! # [`SharedMut`] and [`serde::Serialize`]
160//!
161//! A [`SharedMut`] represents unique ownership of the underlying shared memory.
162//! Because each serialization expects a matching deserialization, serializing
163//! should consume `Self` so that only one memory access exists in either
164//! instance or serialized form. [`serde::Serialize::serialize`], however, takes
165//! a `&Self` argument, which leaves the [`SharedMut`] intact. As a workaround
166//! to provide integration with [`serde`], calls to
167//! [`serde::Serialize::serialize`] invalidate `Self` through interior
168//! mutability. Any future use of `Self` produces a panic. This enforces
169//! the intended usage of dropping a [`SharedMut`] immediately after the
170//! serialization call.
171//!
172//!
173//! # Backing custom types with shared memory
174//!
175//! To be wrappable in [`Shared`], a type must implement
176//! the [`ShmemBacked`] and [`ShmemView`] traits. [`SharedMut`]s
177//! have an additional [`ShmemViewMut`] trait bound. See the example below for
178//! how to back a custom type with shared memory.
179//!
180//! ```
181//! use sharify::{Shared, ShmemBacked, ShmemView};
182//! use std::{sync::mpsc::channel, thread};
183//!
184//! // Holds a stack of images in contiguous memory.
185//! struct ImageStack {
186//!     data: Vec<u8>,
187//!     shape: [u16; 2],
188//! }
189//!
190//! // To back `ImageStack` with shared memory, it needs to implement `ShmemBacked`.
191//! unsafe impl ShmemBacked for ImageStack {
192//!     // Constructor arguments, (shape, n_images, init value).
193//!     type NewArg = ([u16; 2], usize, u8);
194//!     // Information required to create a view of an `ImageStack` from raw memory.
195//!     type MetaData = [u16; 2];
196//!
//!     fn required_memory_arg((shape, n_images, _init): &Self::NewArg) -> usize {
//!         // Compute in `usize`: `product::<u16>()` would overflow for shapes
//!         // as small as [640, 640].
//!         shape.iter().map(|&d| d as usize).product::<usize>() * n_images
//!     }
200//!     fn required_memory_src(src: &Self) -> usize {
201//!         src.data.len()
202//!     }
203//!     fn new(data: &mut [u8], (shape, _n_images, init): &Self::NewArg) -> Self::MetaData {
204//!         data.fill(*init);
205//!         *shape
206//!     }
207//!     fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
208//!         data.copy_from_slice(&src.data);
209//!         src.shape
210//!     }
211//! }
212//!
213//! // Create a referential struct as a view into the memory.
214//! struct ImageStackView<'a> {
215//!     data: &'a [u8],
216//!     shape: [u16; 2],
217//! }
218//!
219//! // The view must implement `ShmemView`.
220//! impl<'a> ShmemView<'a> for ImageStack {
221//!     type View = ImageStackView<'a>;
222//!     fn view(data: &'a [u8], shape: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
223//!         ImageStackView {
224//!             data,
225//!             shape: *shape,
226//!         }
227//!     }
228//! }
229//!
230//!
231//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
232//! // Existing stack with its data in a `Vec`.
233//! let stack = ImageStack {
234//!     data: vec![0; 640 * 640 * 100],
235//!     shape: [640, 640],
236//! };
237//! // Copy the stack into shared memory.
238//! let shared_stack: Shared<ImageStack> = Shared::new_from_inner(&stack)?;
239//! // The `data` field is now backed by shared memory so the stack can be sent
240//! // between processes cheaply. What is shown here for threads works equally
241//! // well for processes, e.g. using the ipc_channel crate.
242//! let (tx, rx) = channel::<Shared<ImageStack>>();
243//! let handle = thread::spawn(move || {
244//!     let shared_stack = rx.recv().unwrap();
245//!     // Get a view into the shared memory.
246//!     let view: ImageStackView = shared_stack.as_view();
247//!     assert!(view.data.iter().all(|&x| x == 0));
248//!     assert_eq!(view.shape, [640, 640]);
249//! });
250//! tx.send(shared_stack)?;
251//! handle.join().unwrap();
252//! # Ok(())
253//! # }
254//! ```
255//!
256//! # [`ndarray`] integration
257//! By default the `shared_ndarray` feature is enabled, which implements [`ShmemBacked`]
258//! for [`ndarray::Array`] and is useful for cheaply sending large arrays between
259//! processes.
260//!
261
262use raw_sync::locks::{LockInit, Mutex};
263use serde::{de, de::DeserializeOwned, ser, Deserialize, Deserializer, Serialize, Serializer};
264use std::cell::UnsafeCell;
265use std::convert::{From, Into, TryFrom, TryInto};
266use std::ops::{Deref, DerefMut};
267
268#[allow(dead_code)]
269mod shared_memory;
270use shared_memory::{Shmem, ShmemConf, ShmemError};
271
272#[allow(dead_code)]
273mod memory;
274use memory::{is_aligned, ALIGNMENT};
275
/// Errors produced while creating, serializing or deserializing shared
/// memory-backed wrappers.
#[derive(Debug)]
pub enum Error {
    /// The data segment of the shared memory did not meet the required
    /// alignment (see `memory::ALIGNMENT`).
    UnalignedMemory,
    /// Creating, opening or locking the mutex in shared memory failed.
    Mutex(String),
    /// Creating or opening the shared memory segment failed.
    Shmem(ShmemError),
    /// Serializing a wrapper into the wire format failed.
    Serialization(String),
    /// Deserializing a wrapper from the wire format failed.
    Deserialization(String),
    /// A `SharedMut` was used after being invalidated by a call to
    /// `serde::Serialize::serialize`.
    InvalidSharedMut,
}
285
286impl std::fmt::Display for Error {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        match self {
289            Self::UnalignedMemory => write!(f, "Encountered unaligned memory."),
290            Self::Shmem(e) => e.fmt(f),
291            Self::Mutex(s) | Self::Serialization(s) | Self::Deserialization(s) => {
292                write!(f, "{}", s)
293            }
294            Self::InvalidSharedMut => write!(f, "Trying to use a `SharedMut` previously invalidated with a call to `serde::Serialize::serialize`.")
295        }
296    }
297}
298
299impl std::error::Error for Error {}
300
/// Implemented for types which can be wrapped in a [`Shared`] or [`SharedMut`]
/// for cheap sharing across processes.
///
/// This trait is closely
/// connected to the [`ShmemView`] and [`ShmemViewMut`] traits. See the
/// crate-level documentation for how to implement this trait on custom types.
///
/// # Safety
///
/// This trait is unsafe because care must be taken to return the correct memory
/// size and metadata. The [`ShmemView`]/[`ShmemViewMut`] traits rely on valid
/// metadata to create view types from raw memory, usually using `unsafe`
/// operations. Incorrect metadata can produce invalid view types.
//
// IDEA: Generic associated types should allow constraining lifetime parameters
// on generic return types, obviating the need for separate
// [`ShmemView`]/[`ShmemViewMut`] traits. https://github.com/rust-lang/rfcs/blob/master/text/1598-generic_associated_types.md
pub unsafe trait ShmemBacked {
    /// The type of user-defined arguments passed to
    /// [`Self::required_memory_arg`] and [`Self::new`]. The
    /// [`new`](SharedMut::new) function on the [`Shared`]/[`SharedMut`]
    /// wrappers accepts an argument of this type to construct a shared
    /// memory-backed version of `Self`.
    type NewArg: ?Sized;
    /// The information required to construct
    /// [`ShmemView::View`]/[`ShmemViewMut::View`] types from raw memory.
    type MetaData: Serialize + DeserializeOwned + Clone;

    /// Returns the memory size in bytes required by a shared memory-backed
    /// version of `Self` created with the user-defined constructor argument
    /// `arg`.
    fn required_memory_arg(arg: &Self::NewArg) -> usize;
    /// Returns the required shared memory size in bytes
    /// to hold a copy of `src`.
    fn required_memory_src(src: &Self) -> usize;
    /// Fills the shared memory at `data` with bytes according to the
    /// user-defined constructor argument `arg` and returns the metadata
    /// required to create a [`ShmemView::View`]/[`ShmemViewMut::View`] from
    /// `data`.
    fn new(data: &mut [u8], arg: &Self::NewArg) -> Self::MetaData;
    /// Fills the shared memory at `data` with a copy of `src` and returns the
    /// metadata required to create a [`ShmemView::View`]/[`ShmemViewMut::View`]
    /// from `data`.
    fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData;
}
346
/// An immutable view into shared memory.
pub trait ShmemView<'a>: ShmemBacked {
    /// The view type handed out, e.g. `&'a [T]` for slices.
    type View;
    /// Creates a [`Self::View`] into the shared memory at `data` based on the information
    /// in `metadata`.
    fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View;
}
354
/// A mutable view into shared memory.
pub trait ShmemViewMut<'a>: ShmemBacked {
    /// The view type handed out, e.g. `&'a mut [T]` for slices.
    type View;
    /// Creates a [`Self::View`] into the shared memory at `data` based on the information
    /// in `metadata`.
    fn view_mut(
        data: &'a mut [u8],
        metadata: &'a mut <Self as ShmemBacked>::MetaData,
    ) -> Self::View;
}
365
366unsafe impl ShmemBacked for str {
367    type NewArg = str;
368    // str size in bytes, only used for asserts
369    type MetaData = usize;
370
371    fn required_memory_arg(src: &Self::NewArg) -> usize {
372        src.len()
373    }
374
375    fn required_memory_src(src: &Self) -> usize {
376        src.len()
377    }
378
379    /// Creates a string slice at `data` with the contents of `src`.
380    fn new(data: &mut [u8], src: &Self::NewArg) -> Self::MetaData {
381        assert_eq!(data.len(), Self::required_memory_arg(src));
382        data.copy_from_slice(src.as_bytes());
383        data.len()
384    }
385
386    /// Creates a string slice at `data` with the contents of `src`.
387    fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
388        assert_eq!(data.len(), Self::required_memory_src(src));
389        data.copy_from_slice(src.as_bytes());
390        data.len()
391    }
392}
393
impl<'a> ShmemView<'a> for str {
    type View = &'a str;

    /// Reinterprets the raw bytes as `&str`; `metadata` is the expected byte
    /// length.
    fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
        assert_eq!(data.len(), *metadata);
        // SAFETY(review): assumes the segment still holds the valid UTF-8
        // written by `ShmemBacked::new`/`new_from_src`. If another process
        // wrote non-UTF-8 bytes into the shared memory this would be UB —
        // TODO confirm the immutability guarantee across processes.
        unsafe { std::str::from_utf8_unchecked(data) }
    }
}
402
impl<'a> ShmemViewMut<'a> for str {
    type View = &'a mut str;

    /// Reinterprets the raw bytes as `&mut str`; `metadata` is the expected
    /// byte length.
    fn view_mut(
        data: &'a mut [u8],
        metadata: &'a mut <Self as ShmemBacked>::MetaData,
    ) -> Self::View {
        assert_eq!(data.len(), *metadata);
        // SAFETY(review): assumes the bytes are valid UTF-8 (see the
        // immutable impl above). NOTE(review): the `&mut str` methods keep
        // UTF-8 validity, but writers in other processes could still break
        // it — TODO confirm.
        unsafe { std::str::from_utf8_unchecked_mut(data) }
    }
}
414
415unsafe impl<T> ShmemBacked for [T]
416where
417    T: Copy,
418{
419    type NewArg = (T, usize);
420    // slice size in elements of T
421    type MetaData = usize;
422
423    fn required_memory_arg((_, len): &Self::NewArg) -> usize {
424        *len * std::mem::size_of::<T>()
425    }
426
427    fn required_memory_src(src: &Self) -> usize {
428        src.len() * std::mem::size_of::<T>()
429    }
430
431    /// Creates a slice at `data` of length `len` filled with `init`.
432    fn new(data: &mut [u8], &(init, len): &Self::NewArg) -> Self::MetaData {
433        assert_eq!(data.len(), Self::required_memory_arg(&(init, len)));
434        let data_typed =
435            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, len) };
436        for elem in data_typed.iter_mut() {
437            *elem = init;
438        }
439        len
440    }
441
442    /// Creates a slice at `data` with the contents of `src`.
443    fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
444        assert_eq!(data.len(), Self::required_memory_src(src));
445        let data_typed =
446            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, src.len()) };
447        data_typed.copy_from_slice(src);
448        data_typed.len()
449    }
450}
451
impl<'a, T> ShmemView<'a> for [T]
where
    T: Copy + 'a,
{
    type View = &'a [T];

    /// Reinterprets the raw bytes as `&[T]` of `*metadata` elements.
    fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
        // SAFETY(review): relies on `metadata` matching the element count the
        // data was written with and on `data` satisfying `T`'s alignment —
        // upheld by the `ShmemBacked` impl above, not re-checked here.
        unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, *metadata) }
    }
}
462
impl<'a, T> ShmemViewMut<'a> for [T]
where
    T: Copy + 'a,
{
    type View = &'a mut [T];

    /// Reinterprets the raw bytes as `&mut [T]` of `*metadata` elements.
    fn view_mut(
        data: &'a mut [u8],
        metadata: &'a mut <Self as ShmemBacked>::MetaData,
    ) -> Self::View {
        // SAFETY(review): same element-count and alignment assumptions as the
        // immutable view above.
        unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, *metadata) }
    }
}
476
477use num_traits::{ops::checked::CheckedAdd, sign::Unsigned, NumOps};
478
/// The unsigned integer type used for the access / serialization counters
/// stored at the start of the shared memory segment.
trait AccessCounter: Copy + PartialOrd + Eq + NumOps + Unsigned + CheckedAdd {}

// Blanket impl: anything with the required numeric traits can act as a counter.
impl<T: Copy + PartialOrd + Eq + NumOps + Unsigned + CheckedAdd> AccessCounter for T {}
482
/// Low-level handle to a shared memory segment with the layout
/// `[mutex | N counters of type A | padding | user data]`.
struct ShmemBase<A: AccessCounter, P: DropBehaviour, const N: usize> {
    // Tag type selecting the behaviour on drop via its `DropBehaviour` impl.
    tag: P,
    // The OS shared memory mapping.
    shmem: Shmem,
    access_counter_type: std::marker::PhantomData<A>,
    // Offset of the counter array from the segment start. In bytes
    counter_offset: usize,
    // Offset of the user data segment from the segment start. In bytes
    data_offset: usize,
    // Size of the user data segment. In bytes
    data_size: usize,
    // When false, dropping must not free the segment (set while serializing).
    free_shmem: bool,
}
495
impl<A: AccessCounter, P: DropBehaviour, const N: usize> ShmemBase<A, P, N> {
    /// Creates uninitialized shared memory to hold data of size `data_size`.
    /// The region at the start of the memory segment contains a mutex and N
    /// access counters of type `A` protected by the mutex. Use
    /// [`Self::mutex_write`] to read / change the access counters. The
    /// behaviour on `Drop` depends on the `DropBehaviour` impl on the `tag`
    /// field.
    fn new(access_counter: &[A; N], data_size: usize, tag: P) -> Result<Self, Error> {
        // `size_of` seems to check for `*mut u8` (or `usize`??) alignment of the
        // internal lock on *nix, see https://github.com/elast0ny/raw_sync-rs/blob/506c542600de9d79439ab4be5297e760bd428a2c/src/locks/unix.rs
        // Since below we're using the alignment checks for arrow which require at least
        // 8-byte alignment, we should be ok passing `None` here, bypassing the
        // check. This assumes a pointer size of no more than 8 bytes (i.e. no
        // larger than 64-bit systems).
        let lock_size = Mutex::size_of(None);
        // Ensure that the start of the memory region for user data follows arrow
        // alignment. Note: this always rounds up to the *next* multiple of
        // ALIGNMENT, even when the header size is already a multiple.
        let reserved_size =
            ((lock_size + std::mem::size_of::<A>() * N) / ALIGNMENT + 1) * ALIGNMENT;
        // Create shared memory sized for header (mutex + counters) plus data.
        let shmem = ShmemConf::new()
            .size(reserved_size + data_size)
            .create()
            .map_err(Error::Shmem)?;
        // Create a mutex at the segment start; its lock data starts
        // `lock_size` bytes in.
        let (mutex, _) = unsafe {
            Mutex::new(shmem.as_ptr(), shmem.as_ptr().add(lock_size))
                .map_err(|e| Error::Mutex(format!("{}", e)))?
        };
        // Initialize the counter values under the lock.
        {
            let lock = mutex.lock().map_err(|e| Error::Mutex(format!("{}", e)))?;
            unsafe {
                let counter_ptr = std::slice::from_raw_parts_mut(*lock as *mut A, N);
                counter_ptr.copy_from_slice(access_counter);
            }
        }
        // Prevents the mutex destructor from running, so we can use the same mutex for
        // serialization/deserialization later
        std::mem::forget(mutex);
        // TODO: Should we guarantee this?
        let aligned = unsafe { is_aligned(shmem.as_ptr().add(reserved_size), ALIGNMENT) };
        if !aligned {
            Err(Error::UnalignedMemory)
        } else {
            Ok(Self {
                tag,
                shmem,
                access_counter_type: std::marker::PhantomData,
                counter_offset: lock_size,
                data_offset: reserved_size,
                data_size,
                free_shmem: true,
            })
        }
    }

    /// Uses the mutex at `shmem.as_ptr()` to write the value returned by
    /// `write` to the counter location. If `write` returns `None`, nothing
    /// is written. Returns the previously held counter values.
    ///
    /// This function allows interior mutability through a non-mutable
    /// reference. This is necessary to allow implementing
    /// [`serde::Serialize`] for [`Shared`]/[`SharedMut`].
    /// ([`serde::Serialize::serialize`] takes a non-mutable reference.)
    fn mutex_write(&self, write: fn(&[A; N]) -> Option<[A; N]>) -> Result<[A; N], Error> {
        // Re-attach to the mutex that `Self::new` left behind in the segment.
        let (mutex, _) = unsafe {
            let counter_ptr = self.shmem.as_ptr().add(self.counter_offset);
            Mutex::from_existing(self.shmem.as_ptr(), counter_ptr)
                .map_err(|e| Error::Mutex(format!("{}", e)))?
        };
        let counter_values = {
            let lock = mutex.lock().map_err(|e| Error::Mutex(format!("{}", e)))?;
            let counter_ptr = unsafe { std::slice::from_raw_parts_mut(*lock as *mut A, N) };
            // Snapshot the old values before (conditionally) overwriting them.
            let mut old = [A::zero(); N];
            old.copy_from_slice(counter_ptr);
            if let Some(new) = write(&old) {
                counter_ptr.copy_from_slice(&new);
            }
            old
        };
        // Prevents the mutex destructor from running, so we can use the same mutex for
        // serialization/deserialization later
        std::mem::forget(mutex);
        Ok(counter_values)
    }

    /// If `free` is `true`, `ShmemBase` frees the shared memory & mutex on
    /// drop depending on the `DropBehaviour` impl on the `tag` field.
    fn free_on_drop(&mut self, free: bool) {
        self.free_shmem = free;
    }

    /// Returns a pointer to the data segment of the underlying shared memory.
    fn data_ptr(&self) -> *const u8 {
        unsafe { self.shmem.as_ptr().add(self.data_offset) }
    }

    /// Returns a mutable pointer to the data segment of the underlying shared
    /// memory.
    fn data_ptr_mut(&mut self) -> *mut u8 {
        unsafe { self.shmem.as_ptr().add(self.data_offset) }
    }

    /// Builds the wire format by reference, leaving `self` — and therefore its
    /// drop behaviour — untouched.
    fn to_wire_format<T>(&self, metadata: T) -> WireFormat<T> {
        WireFormat {
            tag: self.tag.clone().into(),
            os_id: String::from(self.shmem.get_os_id()),
            mem_size: self.shmem.len(),
            counter_offset: self.counter_offset,
            data_offset: self.data_offset,
            data_size: self.data_size,
            meta: metadata,
        }
    }

    /// Serializes a `ShmemBase` into the wire format **without** freeing the
    /// shared memory & mutex.
    fn into_wire_format<T>(mut self, metadata: T) -> WireFormat<T> {
        // Prevents the shared memory & mutex from being freed on drop
        self.free_shmem = false;
        WireFormat {
            tag: self.tag.clone().into(),
            os_id: String::from(self.shmem.get_os_id()),
            mem_size: self.shmem.len(),
            counter_offset: self.counter_offset,
            data_offset: self.data_offset,
            data_size: self.data_size,
            meta: metadata,
        }
    }

    /// Creates a `ShmemBase` which **does not** free the shared memory & mutex
    /// on drop.
    fn from_wire_format<T>(wire_format: WireFormat<T>) -> Result<(Self, T), Error> {
        let WireFormat {
            tag,
            os_id,
            mem_size,
            counter_offset,
            data_offset,
            data_size,
            meta,
        } = wire_format;
        // Open (not create) the existing segment identified by `os_id`.
        let shmem = ShmemConf::new()
            .os_id(os_id)
            .size(mem_size)
            .open()
            .map_err(Error::Shmem)?;
        Ok((
            Self {
                tag: tag.try_into()?,
                shmem,
                access_counter_type: std::marker::PhantomData,
                counter_offset,
                data_offset,
                data_size,
                free_shmem: false,
            },
            meta,
        ))
    }
}
658
impl<A: AccessCounter, P: DropBehaviour, const N: usize> Drop for ShmemBase<A, P, N> {
    fn drop(&mut self) {
        // The actual drop policy (refcounting, freeing the segment) is
        // supplied by the tag type's `DropBehaviour` impl.
        P::called_on_drop(self);
    }
}
664
/// Drop policy for [`ShmemBase`], selected via the tag type. The
/// `TryFrom<Tag>`/`Into<Tag>` bounds make tags round-trip through the
/// serializable wire format.
trait DropBehaviour: Clone + TryFrom<Tag, Error = Error> + Into<Tag> {
    /// Invoked from [`ShmemBase`]'s `Drop` impl to decide whether the shared
    /// memory & mutex are freed.
    fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
        base: &mut ShmemBase<A, P, N>,
    );
}
670
impl<A: AccessCounter> Clone for ShmemBase<A, SharedTag, 2> {
    /// Opens a second handle to the same OS shared memory segment and
    /// increments the access counter.
    fn clone(&self) -> Self {
        // Re-open the existing segment by OS id rather than duplicating the
        // mapping handle.
        let shmem = ShmemConf::new()
            .os_id(self.shmem.get_os_id())
            .size(self.shmem.len())
            .open()
            .unwrap();
        let new = Self {
            tag: SharedTag(),
            shmem,
            access_counter_type: std::marker::PhantomData,
            counter_offset: self.counter_offset,
            data_offset: self.data_offset,
            data_size: self.data_size,
            free_shmem: true,
        };
        // Increment the access counter, keep the serialization counter.
        // Panics (rather than silently wrapping) on counter overflow.
        let write: fn(&[A; 2]) -> Option<[A; 2]> = |old| {
            let mut new = [A::zero(); 2];
            if let Some(new_acc_count) = old[0].checked_add(&A::one()) {
                new[0] = new_acc_count;
                new[1] = old[1];
                Some(new)
            } else {
                panic!("Can't have more than A::MAX `Shared`s with simultaneous access.")
            }
        };
        new.mutex_write(write).unwrap();
        new
    }
}
702
impl<A: AccessCounter> From<ShmemBase<A, SharedMutTag, 2>> for ShmemBase<A, SharedTag, 2> {
    /// Converts an exclusively owned (`SharedMut`-tagged) base into a
    /// refcounted (`Shared`-tagged) base over the same segment.
    fn from(mut shared_mut: ShmemBase<A, SharedMutTag, 2>) -> Self {
        // Don't free the shared memory on drop.
        shared_mut.free_on_drop(false);
        // Need to create new `Shmem` because `shared_mut` implements `Drop`, so we
        // can't move out of it. Option<shmem>?
        let shmem = ShmemConf::new()
            .os_id(shared_mut.shmem.get_os_id())
            .size(shared_mut.shmem.len())
            .open()
            .unwrap();
        let new = Self {
            tag: SharedTag(),
            shmem,
            access_counter_type: std::marker::PhantomData,
            counter_offset: shared_mut.counter_offset,
            data_offset: shared_mut.data_offset,
            data_size: shared_mut.data_size,
            free_shmem: true,
        };
        // Set the access counter to 1 (this new instance) and the
        // serialization counter to 0.
        new.mutex_write(|_| Some([A::one(), A::zero()])).unwrap();
        new
    }
}
728
/// Zero-sized tag selecting the refcounted, immutable (`Shared`) drop
/// behaviour for [`ShmemBase`].
#[derive(Serialize, Deserialize, Clone)]
struct SharedTag();
731
732impl TryFrom<Tag> for SharedTag {
733    type Error = Error;
734    fn try_from(value: Tag) -> Result<Self, Self::Error> {
735        match value {
736            Tag::Shared(tag) => Ok(tag),
737            Tag::SharedMut(_) => Err(Error::Deserialization(String::from(
738                "Can't deserialize a `Shared` from a `SharedMut` serialization.",
739            ))),
740        }
741    }
742}
743
impl DropBehaviour for SharedTag {
    /// Decrements the access counter; frees the shared memory only when this
    /// drop relinquishes exclusive access and no serializations are dangling.
    fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
        base: &mut ShmemBase<A, P, N>,
    ) {
        // `free_shmem` is **only** true when `base` drops outside of a serialization
        // function.
        if base.free_shmem {
            // Decrement the access counter, keep the serialization counter.
            // NOTE(review): indexing 0 and 1 assumes N >= 2 — this impl is
            // only instantiated with N = 2 in this file. The subtraction
            // would underflow if the access counter were already 0;
            // presumably prevented by the counting invariants — TODO confirm.
            let write: fn(&[A; N]) -> Option<[A; N]> = |old| {
                let mut new = [A::zero(); N];
                new[0] = old[0] - A::one();
                new[1] = old[1];
                Some(new)
            };
            let counter_value = base.mutex_write(write).unwrap();
            // Only free the shared memory if the access counter was 1 and the serialization
            // counter was 0 (exclusive access).
            if (counter_value[0] == A::one()) && (counter_value[1] == A::zero()) {
                // Drops the mutex in shared memory, cleaning up any additional (kernel?)
                // memory. This seems to be unnecessary on *nix, see
                // https://stackoverflow.com/questions/39822987/pthread-mutexattr-process-shared-memory-leak
                unsafe {
                    let counter_ptr = base.shmem.as_ptr().add(base.counter_offset);
                    Mutex::from_existing(base.shmem.as_ptr(), counter_ptr).unwrap();
                }
                base.shmem.set_owner(true);
            } else {
                base.shmem.set_owner(false);
            }
        } else {
            // Prevents 'shmem' from freeing the shared memory when dropped.
            // Windows behavior should be fixed with https://github.com/elast0ny/shared_memory-rs/pull/59
            base.shmem.set_owner(false);
        }
    }
}
780
/// Wrapper type for immutable access to shared memory from multiple processes.
pub struct Shared<T>
where
    T: ShmemBacked + ?Sized,
{
    // Metadata needed to rebuild a view of `T` from the raw bytes.
    metadata: <T as ShmemBacked>::MetaData,
    // Segment with a `u64` access counter and a `u64` serialization counter
    // (hence N = 2).
    shmem: ShmemBase<u64, SharedTag, 2>,
}
789
790impl<T> Shared<T>
791where
792    T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
793{
794    pub fn new(arg: &<T as ShmemBacked>::NewArg) -> Result<Self, Error> {
795        let size = T::required_memory_arg(&arg);
796        // The strategy is to use a pair of u64s as counters: the first one keeps track
797        // of the number of `Shared`s accessing the shared memory (access
798        // counter), the second one tracks the number of serialized instances
799        // (serialization counter).
800        //
801        // The counters live at the start of the shared memory region. The access
802        // counter is incremented when calling `new`, `deserialize` or `clone`
803        // and decremented when dropping or calling
804        // `into_serialized`. The serialization
805        // counter is incremented when calling
806        // `into_serialized`/`serde::Serialize::serialize`, and decremented when
807        // calling `deserialize`.
808        //
809        // Shared memory is only freed if the access count is 1 and the serialization
810        // count is 0 (exclusive access) on drop and we're not serializing.
811        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedTag())?;
812        let metadata = unsafe {
813            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
814            T::new(data, arg)
815        };
816        Ok(Shared { metadata, shmem })
817    }
818
819    pub fn new_from_inner(arg: &T) -> Result<Self, Error> {
820        let size = T::required_memory_src(&arg);
821        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedTag())?;
822        let inner = unsafe {
823            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
824            T::new_from_src(data, arg)
825        };
826        Ok(Shared {
827            metadata: inner,
828            shmem,
829        })
830    }
831
832    #[allow(clippy::clippy::needless_lifetimes)]
833    pub fn as_view<'a>(&'a self) -> <T as ShmemView<'a>>::View {
834        let data =
835            unsafe { std::slice::from_raw_parts(self.shmem.data_ptr(), self.shmem.data_size) };
836        T::view(data, &self.metadata)
837    }
838
839    #[cfg(test)]
840    pub fn counts(&self) -> Result<[u64; 2], Error> {
841        self.shmem.mutex_write(|_| None)
842    }
843
844    pub fn into_serialized<S: Serializer>(self, serializer: S) -> Result<S::Ok, S::Error> {
845        // Decrements the access counter and increments the serialized counter (see
846        // [`Self::new`]).
847        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
848            let mut new = [0_u64, 0];
849            if let Some(new_ser_count) = old[1].checked_add(1) {
850                new[0] = old[0] - 1;
851                new[1] = new_ser_count;
852                Some(new)
853            } else {
854                panic!("Can't have more than A::MAX serialized `Shared`s.")
855            }
856        };
857        self.shmem.mutex_write(write).map_err(ser::Error::custom)?;
858        let wire_format = self.shmem.into_wire_format(self.metadata);
859        wire_format.serialize(serializer)
860    }
861}
862
863impl<T> Serialize for Shared<T>
864where
865    T: ShmemBacked + ?Sized,
866{
867    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
868    where
869        S: Serializer,
870    {
871        // Increments the serialized counter (see [`Self::new`]).
872        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
873            let mut new = [0_u64, 0];
874            if let Some(new_ser_count) = old[1].checked_add(1) {
875                new[0] = old[0];
876                new[1] = new_ser_count;
877                Some(new)
878            } else {
879                panic!("Can't have more than A::MAX serialized `Shared`s.")
880            }
881        };
882        self.shmem.mutex_write(write).map_err(ser::Error::custom)?;
883        let wire_format = self.shmem.to_wire_format(&self.metadata);
884        wire_format.serialize(serializer)
885    }
886}
887
impl<'de, T> Deserialize<'de> for Shared<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Re-attaches to an existing segment from its serialized metadata,
    /// incrementing the access counter and decrementing the serialization
    /// counter (see `Shared::new` for the counter protocol).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wire_format = WireFormat::deserialize(deserializer)?;
        // The created `shmem` does not free the shared memory, see
        // (`SharedBase::from_wire_format`).
        let (mut shmem, metadata) = ShmemBase::<u64, SharedTag, 2>::from_wire_format(wire_format)
            .map_err(de::Error::custom)?;
        // Increments the access counter and decrements the serialized counter (see
        // [`Self::new`]). Returning `None` refuses the update when the access
        // counter would overflow; `saturating_sub` tolerates an already-zero
        // serialization counter.
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            let mut new = [0_u64, 0];
            if let Some(new_acc_count) = old[0].checked_add(1) {
                new[0] = new_acc_count;
                new[1] = old[1].saturating_sub(1);
                Some(new)
            } else {
                None
            }
        };
        // `mutex_write` hands back the counters as read before the update
        // (cf. `counts`), so a pre-update access count of u64::MAX means the
        // increment above was refused.
        if shmem.mutex_write(write).map_err(de::Error::custom)?[0] < u64::MAX {
            // `shmem` does not free the shared memory in case of an error.
            // In the `Ok` case it should do so.
            shmem.free_on_drop(true);
            Ok(Shared { metadata, shmem })
        } else {
            Err(de::Error::custom(
                "Can't have more than u64::MAX `Shared`s with simultaneous access.",
            ))
        }
    }
}
925
// SAFETY: a `Shared` can move to another thread as long as its metadata can
// (hence the `T::MetaData: Send` bound). NOTE(review): soundness also relies
// on the counter updates going through the in-segment mutex used by
// `mutex_write` — confirm against `ShmemBase`.
unsafe impl<T> Send for Shared<T>
where
    T: ShmemBacked + ?Sized,
    T::MetaData: Send,
{
}
932// TODO: Is the mutex around the counters enough to make a `Shared` `Sync`?
933// unsafe impl<T> Sync for Shared<T> where T: ShmemBacked + ?Sized, T::MetaData:
934// Sync {}
935
impl<'a, T> From<&'a T> for Shared<T>
where
    T: ShmemBacked + for<'b> ShmemView<'b> + ?Sized,
{
    /// Copies `src` into a fresh shared-memory segment.
    ///
    /// Panics if creating or initializing the segment fails
    /// (`new_from_inner` returns `Err`).
    fn from(src: &'a T) -> Self {
        Self::new_from_inner(src).unwrap()
    }
}
944
945impl<T> Clone for Shared<T>
946where
947    T: ShmemBacked + ?Sized,
948{
949    /// Creates another `Shared` instance backed by the same shared memory. No
950    /// data is copied.
951    fn clone(&self) -> Self {
952        let Shared {
953            shmem, metadata, ..
954        } = self;
955        let shmem = shmem.clone();
956        Self {
957            metadata: metadata.clone(),
958            shmem,
959        }
960    }
961}
962
963impl<T> TryFrom<SharedMut<T>> for Shared<T>
964where
965    T: ShmemBacked + ?Sized,
966{
967    type Error = Error;
968
969    /// This method fails if the `SharedMut` has previously been
970    /// invalidated with a call to `serde::Serialize::serialize`.
971    fn try_from(shared_mut: SharedMut<T>) -> Result<Self, Self::Error> {
972        let SharedMut { shmem, metadata } = shared_mut;
973        if let Some(shmem) = shmem.into_inner() {
974            let shmem: ShmemBase<_, SharedTag, 2> = shmem.into();
975            Ok(Self { metadata, shmem })
976        } else {
977            Err(Error::InvalidSharedMut)
978        }
979    }
980}
981
// Marker tag identifying a segment created by a `SharedMut`; it selects the
// matching `DropBehaviour` impl and lets deserialization reject wire data
// produced by the other wrapper type (see `TryFrom<Tag>`).
#[derive(Serialize, Deserialize, Clone)]
struct SharedMutTag();
984
985impl TryFrom<Tag> for SharedMutTag {
986    type Error = Error;
987    fn try_from(value: Tag) -> Result<Self, Self::Error> {
988        match value {
989            Tag::SharedMut(tag) => Ok(tag),
990            Tag::Shared(_) => Err(Error::Deserialization(String::from(
991                "Can't deserialize a `SharedMut` from a `Shared` serialization.",
992            ))),
993        }
994    }
995}
996
impl DropBehaviour for SharedMutTag {
    /// Invoked when the owning `ShmemBase` is dropped; decides whether the
    /// underlying OS segment is released along with it.
    fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
        base: &mut ShmemBase<A, P, N>,
    ) {
        if base.free_shmem {
            // Drops the mutex in shared memory, cleaning up any additional (kernel?)
            // memory. This seems to be unnecessary on *nix, see
            // https://stackoverflow.com/questions/39822987/pthread-mutexattr-process-shared-memory-leak
            unsafe {
                let counter_ptr = base.shmem.as_ptr().add(base.counter_offset);
                // NOTE(review): the re-created mutex is dropped again at the
                // end of this statement — presumably that drop releases its
                // kernel resources; confirm against `raw_sync::Mutex`.
                Mutex::from_existing(base.shmem.as_ptr(), counter_ptr).unwrap();
            }
            // Owning the mapping makes `shmem`'s own drop free the segment.
            base.shmem.set_owner(true);
        } else {
            // Prevents 'shmem' from freeing the shared memory when dropped.
            // Windows behavior should be fixed with https://github.com/elast0ny/shared_memory-rs/pull/59
            base.shmem.set_owner(false);
        }
    }
}
1017
///  Safe mutable access to shared memory from multiple processes through unique ownership.
pub struct SharedMut<T>
where
    T: ShmemBacked + ?Sized,
{
    // Metadata needed to rebuild views of `T` from the raw segment bytes.
    metadata: <T as ShmemBacked>::MetaData,
    // `UnsafeCell<Option<..>>` so the `serde::Serialize` impl can `take()`
    // the handle through `&self`, invalidating this instance (any later use
    // panics). This is also why `SharedMut` must not be `Sync`.
    shmem: UnsafeCell<Option<ShmemBase<u64, SharedMutTag, 2>>>,
}
1026
1027impl<T> SharedMut<T>
1028where
1029    T: ShmemBacked + for<'a> ShmemView<'a> + for<'a> ShmemViewMut<'a> + ?Sized,
1030{
1031    pub fn new(arg: &<T as ShmemBacked>::NewArg) -> Result<Self, Error> {
1032        let size = T::required_memory_arg(&arg);
1033        // The strategy is to reserve a u64 at the start of the memory segment which is
1034        // set to 1 when a `SharedMut` instance exists in any process and set to
1035        // 0 otherwise. This lets us throw an error if attempting to deserialize
1036        // a `SharedMut` backed by shared memory which is already accessed from
1037        // somewhere else.
1038        //
1039        // The `access_counter` array has two elements to make the memory behind
1040        // `SharedMut` compatible with `Shared` for easy conversion with
1041        // `From<SharedMut> for Shared`, the second element is not used.
1042        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedMutTag())?;
1043        let metadata = unsafe {
1044            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
1045            T::new(data, arg)
1046        };
1047        Ok(SharedMut {
1048            metadata,
1049            shmem: UnsafeCell::new(Some(shmem)),
1050        })
1051    }
1052
1053    pub fn new_from_inner(arg: &T) -> Result<Self, Error> {
1054        let size = T::required_memory_src(&arg);
1055        let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedMutTag())?;
1056        let metadata = unsafe {
1057            let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
1058            T::new_from_src(data, arg)
1059        };
1060        Ok(SharedMut {
1061            metadata,
1062            shmem: UnsafeCell::new(Some(shmem)),
1063        })
1064    }
1065
1066    #[allow(clippy::clippy::needless_lifetimes)]
1067    pub fn as_view_mut<'a>(&'a mut self) -> <T as ShmemViewMut<'a>>::View {
1068        let shmem =
1069            self.shmem.get_mut().as_mut().expect(
1070                "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
1071            );
1072        let data = unsafe { std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size) };
1073        T::view_mut(data, &mut self.metadata)
1074    }
1075}
1076
1077impl<T> SharedMut<T>
1078where
1079    T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1080{
1081    #[allow(clippy::clippy::needless_lifetimes)]
1082    pub fn as_view<'a>(&'a self) -> <T as ShmemView<'a>>::View {
1083        let data = unsafe {
1084            let shmem: &mut _ = &mut *self.shmem.get();
1085            let shmem = shmem.as_mut().expect(
1086                "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
1087            );
1088            std::slice::from_raw_parts(shmem.data_ptr(), shmem.data_size)
1089        };
1090        T::view(data, &self.metadata)
1091    }
1092
1093    /// Allows directly modifying the metadata used to create the view types
1094    /// from raw memory. See the [`ShmemBacked`]/[`ShmemView`]/[`ShmemViewMut`]
1095    /// traits for details.
1096    ///
1097    /// # Safety
1098    /// [`ShmemView::view`]/[`ShmemViewMut::view_mut`] use metadata to create
1099    /// views from raw memory, usually through the use of `unsafe`
1100    /// operations. Directly changing the metadata is almost always a bad
1101    /// idea unless the changes come from a view into the same memory.
1102    pub unsafe fn metadata_mut(&mut self) -> &mut <T as ShmemBacked>::MetaData {
1103        &mut self.metadata
1104    }
1105
1106    #[cfg(test)]
1107    pub fn counts(&mut self) -> Result<[u64; 1], Error> {
1108        let shmem =
1109            self.shmem.get_mut().as_mut().expect(
1110                "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
1111            );
1112        let mut access_count = [0];
1113        let counts = shmem.mutex_write(|_| None)?;
1114        access_count[0] = counts[0];
1115        Ok(access_count)
1116    }
1117
1118    pub fn into_serialized<S: Serializer>(self, serializer: S) -> Result<S::Ok, S::Error> {
1119        let shmem = self
1120            .shmem
1121            .into_inner()
1122            .expect("`SharedMut` must not be used after a call to `serde::Serialize::serialize`.");
1123        // Sets the access counter (see [`Self::new`]) to 0, indicating we're done
1124        // accessing the memory segment.
1125        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
1126            debug_assert_eq!(old[0], 1);
1127            Some([0_u64, 0])
1128        };
1129        shmem.mutex_write(write).map_err(ser::Error::custom)?;
1130        let wire_format = shmem.into_wire_format(self.metadata);
1131        wire_format.serialize(serializer)
1132    }
1133}
1134
impl<T> Serialize for SharedMut<T>
where
    T: ShmemBacked + ?Sized,
{
    /// # Safety
    /// This function can only be called **once**. It makes use of interior
    /// mutability to invalidate `Self`, any future use leads to a panic.
    /// This is a workaround to provide an implementation of
    /// [`serde::Serialize`] even though `serialize` takes an immutable
    /// reference to `Self`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Take the segment handle out of the `UnsafeCell`, leaving `None`
        // behind: this invalidates `self` (every other accessor `expect`s
        // `Some`).
        let shmem: &mut _ = unsafe { &mut *self.shmem.get() };
        let shmem = shmem
            .take()
            .expect("`SharedMut` must not be used after a call to `serde::Serialize::serialize`.");
        // Sets the access counter (see [`Self::new`]) to 0, indicating we're done
        // accessing the memory segment.
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            debug_assert_eq!(old[0], 1);
            Some([0_u64, 0])
        };
        shmem.mutex_write(write).map_err(ser::Error::custom)?;
        // Only the (cloned) metadata and segment description go over the
        // wire; the payload stays in shared memory.
        let wire_format = shmem.into_wire_format(self.metadata.clone());
        wire_format.serialize(serializer)
    }
}
1164
impl<'de, T> Deserialize<'de> for SharedMut<T>
where
    T: ShmemBacked + ?Sized,
{
    /// Re-attaches to an existing segment, claiming exclusive access by
    /// flipping the access flag from 0 to 1 (see `SharedMut::new`).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let wire_format = WireFormat::deserialize(deserializer)?;
        // The created `shmem` does not free the shared memory, see
        // (`SharedBase::from_wire_format`).
        let (mut shmem, metadata) =
            ShmemBase::<u64, SharedMutTag, 2>::from_wire_format(wire_format)
                .map_err(de::Error::custom)?;
        // If the access counter is 0 (no other `SharedMut`s accessing this data exist),
        // set the value to 1 and create a `Self` instance. If not throw an error.
        let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
            debug_assert!(old[0] <= 1);
            if old[0] == 0 {
                Some([1, 0])
            } else {
                None
            }
        };
        // `mutex_write` hands back the counters as read before the update, so
        // a pre-update value of 0 means we won the flag and now own the
        // segment exclusively.
        if shmem.mutex_write(write).map_err(de::Error::custom)?[0] == 0 {
            // `shmem` does not free the shared memory in case of an error.
            // In the `Ok` case it should do so.
            shmem.free_on_drop(true);
            Ok(SharedMut {
                metadata,
                shmem: UnsafeCell::new(Some(shmem)),
            })
        } else {
            Err(de::Error::custom("A shared memory region can only be accessed by one `SharedMut` instance at any time. Note that the existing instance may live in a different process."))
        }
    }
}
1202
// SAFETY: sending a `SharedMut` moves the sole handle to the segment to
// another thread together with its metadata (hence `T::MetaData: Send`).
// NOTE(review): soundness also assumes the raw segment pointers inside
// `ShmemBase` remain valid process-wide — confirm.
unsafe impl<T> Send for SharedMut<T>
where
    T: ShmemBacked + ?Sized,
    T::MetaData: Send,
{
}
1209// A `SharedMut` is **not** `Sync` because the `serde::Serialize` impl relies on
1210// interior mutability through an immutable reference with an `UnsafeCell`.
1211
impl<'a, T> From<&'a T> for SharedMut<T>
where
    T: ShmemBacked + for<'b> ShmemView<'b> + for<'b> ShmemViewMut<'b> + ?Sized,
{
    /// Copies `src` into a fresh shared-memory segment.
    ///
    /// Panics if creating or initializing the segment fails
    /// (`new_from_inner` returns `Err`).
    fn from(src: &'a T) -> Self {
        Self::new_from_inner(src).unwrap()
    }
}
1220
// What actually crosses the process boundary: everything needed to re-attach
// to an existing segment, but none of the payload data.
#[derive(Serialize, Deserialize)]
struct WireFormat<T> {
    // Distinguishes `Shared` from `SharedMut` wire data.
    tag: Tag,
    // Presumably the OS identifier used to reopen the mapping — see
    // `ShmemBase::from_wire_format`.
    os_id: String,
    mem_size: usize,
    // Byte offsets/sizes of the counter block and payload inside the segment.
    counter_offset: usize,
    data_offset: usize,
    data_size: usize,
    // Type-specific metadata (`ShmemBacked::MetaData`).
    meta: T,
}
1231
// Serialized marker telling a deserializer which wrapper type produced the
// wire data; mismatches are rejected via the `TryFrom<Tag>` impls.
#[derive(Serialize, Deserialize)]
enum Tag {
    Shared(SharedTag),
    SharedMut(SharedMutTag),
}
1237
1238impl From<SharedTag> for Tag {
1239    fn from(shared: SharedTag) -> Self {
1240        Tag::Shared(shared)
1241    }
1242}
1243
1244impl From<SharedMutTag> for Tag {
1245    fn from(shared: SharedMutTag) -> Self {
1246        Tag::SharedMut(shared)
1247    }
1248}
1249
/// A `str` in shared memory.
pub type SharedStr = Shared<str>;

impl Deref for SharedStr {
    type Target = str;

    // Lets a `SharedStr` be used anywhere a `&str` is accepted.
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}
1260
/// A mutable `str` in shared memory.
pub type SharedStrMut = SharedMut<str>;

impl Deref for SharedStrMut {
    type Target = str;

    // Panics if the handle was invalidated by `serde::Serialize::serialize`
    // (see `SharedMut::as_view`).
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}

impl DerefMut for SharedStrMut {
    // Same panic caveat as `Deref` above.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_view_mut()
    }
}
1277
/// A `slice` in shared memory.
pub type SharedSlice<T> = Shared<[T]>;

impl<T: Copy + 'static> Deref for SharedSlice<T> {
    type Target = [T];

    // Lets a `SharedSlice` be used anywhere a `&[T]` is accepted.
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}
1288
/// A mutable `slice` in shared memory.
pub type SharedSliceMut<T> = SharedMut<[T]>;

impl<T: Copy + 'static> Deref for SharedSliceMut<T> {
    type Target = [T];

    // Panics if the handle was invalidated by `serde::Serialize::serialize`
    // (see `SharedMut::as_view`).
    fn deref(&self) -> &Self::Target {
        self.as_view()
    }
}

impl<T: Copy + 'static> DerefMut for SharedSliceMut<T> {
    // Same panic caveat as `Deref` above.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.as_view_mut()
    }
}
1305
1306#[cfg(feature = "shared_ndarray")]
1307pub use sharify_ndarray::{SharedArray, SharedArrayMut};
1308
#[cfg(feature = "shared_ndarray")]
pub mod sharify_ndarray {
    use super::*;
    use ndarray::{Array, ArrayView, ArrayViewMut, Dimension};

    /// An immutable [`ndarray::Array`] whose data lives in shared memory.
    pub type SharedArray<T, D> = Shared<Array<T, D>>;
    /// A mutable [`ndarray::Array`] whose data lives in shared memory.
    pub type SharedArrayMut<T, D> = SharedMut<Array<T, D>>;

    // NOTE: the unused lifetime parameter `'a` of the original impl has been
    // removed — nothing in the impl referenced it.
    unsafe impl<T, D> ShmemBacked for Array<T, D>
    where
        T: Copy,
        D: Dimension + Serialize + DeserializeOwned,
    {
        /// Fill value and dimensions of the array to create.
        type NewArg = (T, D);
        /// `(shape, strides)` captured as plain vectors so they serialize
        /// independently of `D`.
        type MetaData = (Vec<usize>, Vec<isize>);

        fn required_memory_arg((_, dim): &Self::NewArg) -> usize {
            dim.size() * std::mem::size_of::<T>()
        }

        fn required_memory_src(src: &Self) -> usize {
            src.len() * std::mem::size_of::<T>()
        }

        fn new(data: &mut [u8], (init, dim): &Self::NewArg) -> Self::MetaData {
            // SAFETY: the caller provides at least `required_memory_arg`
            // bytes, i.e. room for `dim.size()` elements of `T`.
            // NOTE(review): the cast assumes `data` is suitably aligned for
            // `T` — confirm the allocation guarantees of `ShmemBase`.
            let data =
                unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, dim.size()) };
            data.fill(*init);
            let view = ArrayView::from_shape(dim.clone(), data).unwrap();
            let shape = Vec::from(view.shape());
            let strides = Vec::from(view.strides());
            (shape, strides)
        }

        fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
            // SAFETY: the caller provides at least `required_memory_src`
            // bytes, i.e. room for `src.len()` elements of `T`. Same
            // alignment caveat as in `new`.
            let data =
                unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, src.len()) };
            let mut view = ArrayViewMut::from_shape(src.raw_dim(), data).unwrap();
            // Element-wise copy handles non-contiguous source layouts.
            for (src, dst) in src.iter().zip(view.iter_mut()) {
                *dst = *src;
            }
            let shape = Vec::from(view.shape());
            let strides = Vec::from(view.strides());
            (shape, strides)
        }
    }

    impl<'a, T, D> ShmemView<'a> for Array<T, D>
    where
        T: Copy + Default + 'a,
        D: Dimension + Serialize + DeserializeOwned,
    {
        type View = ArrayView<'a, T, D>;

        /// Rebuilds an `ArrayView` over `data` from the stored shape/strides.
        fn view(
            data: &'a [u8],
            (shape, strides): &'a <Self as ShmemBacked>::MetaData,
        ) -> Self::View {
            use ndarray::ShapeBuilder;
            // The element count times the element size must fit in the byte
            // buffer (the original assertion compared elements to bytes).
            debug_assert!(
                shape.iter().product::<usize>() * std::mem::size_of::<T>() <= data.len()
            );
            // SAFETY: the metadata was produced by `new`/`new_from_src` for
            // this very segment, so `shape.product()` elements of `T` fit.
            let data = unsafe {
                std::slice::from_raw_parts(data.as_ptr() as *const T, shape.iter().product())
            };
            let mut shape_dim = D::zeros(shape.len());
            for (src, dst) in shape.iter().zip(shape_dim.as_array_view_mut().iter_mut()) {
                *dst = *src;
            }
            let mut strides_dim = D::zeros(strides.len());
            for (src, dst) in strides
                .iter()
                .zip(strides_dim.as_array_view_mut().iter_mut())
            {
                // Conversion for negative strides, see https://github.com/rust-ndarray/ndarray/pull/948/files
                // and https://github.com/rust-ndarray/ndarray/pull/948
                *dst = *src as usize;
            }
            ArrayView::from_shape(shape_dim.strides(strides_dim), data).unwrap()
        }
    }

    impl<'a, T, D> ShmemViewMut<'a> for Array<T, D>
    where
        T: Copy + Default + 'a,
        D: Dimension + Serialize + DeserializeOwned,
    {
        type View = ArrayViewMut<'a, T, D>;

        /// Mutable counterpart of [`ShmemView::view`].
        fn view_mut(
            data: &'a mut [u8],
            (shape, strides): &'a mut <Self as ShmemBacked>::MetaData,
        ) -> Self::View {
            use ndarray::ShapeBuilder;
            debug_assert!(
                shape.iter().product::<usize>() * std::mem::size_of::<T>() <= data.len()
            );
            // SAFETY: as in `view`, the metadata matches this segment.
            let data = unsafe {
                std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, shape.iter().product())
            };
            let mut shape_dim = D::zeros(shape.len());
            for (src, dst) in shape.iter().zip(shape_dim.as_array_view_mut().iter_mut()) {
                *dst = *src;
            }
            let mut strides_dim = D::zeros(strides.len());
            for (src, dst) in strides
                .iter()
                .zip(strides_dim.as_array_view_mut().iter_mut())
            {
                // Conversion for negative strides, see https://github.com/rust-ndarray/ndarray/pull/948/files
                // and https://github.com/rust-ndarray/ndarray/pull/948
                *dst = *src as usize;
            }
            ArrayViewMut::from_shape(shape_dim.strides(strides_dim), data).unwrap()
        }
    }
}
1426
1427#[cfg(test)]
1428mod tests {
1429    use super::*;
1430    use bincode::{self, de, options};
1431    use rand::prelude::*;
1432    use std::thread;
1433
1434    fn serialize_shared<T>(shared: Shared<T>) -> Vec<u8>
1435    where
1436        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1437    {
1438        let mut bytes = Vec::new();
1439        let mut serializer = bincode::Serializer::new(&mut bytes, options());
1440        shared.into_serialized(&mut serializer).unwrap();
1441        bytes
1442    }
1443
1444    fn serialize_shared_mut<T>(shared: SharedMut<T>) -> Vec<u8>
1445    where
1446        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1447    {
1448        let mut bytes = Vec::new();
1449        let mut serializer = bincode::Serializer::new(&mut bytes, options());
1450        shared.into_serialized(&mut serializer).unwrap();
1451        bytes
1452    }
1453
1454    fn deserialize<S: DeserializeOwned>(bytes: &[u8]) -> Result<S, String> {
1455        let mut deserializer = de::Deserializer::from_slice(bytes, options());
1456        S::deserialize(&mut deserializer).map_err(|e| format!("{}", e))
1457    }
1458
1459    fn serialization_roundtrip_shared<T>(shared: Shared<T>) -> Shared<T>
1460    where
1461        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1462    {
1463        let mut bytes = Vec::new();
1464        let mut serializer = bincode::Serializer::new(&mut bytes, options());
1465        shared.into_serialized(&mut serializer).unwrap();
1466        let mut deserializer = de::Deserializer::from_slice(bytes.as_slice(), options());
1467        Shared::<T>::deserialize(&mut deserializer).unwrap()
1468    }
1469
1470    fn serialization_roundtrip_shared_mut<T>(shared: SharedMut<T>) -> SharedMut<T>
1471    where
1472        T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1473    {
1474        let mut bytes = Vec::new();
1475        let mut serializer = bincode::Serializer::new(&mut bytes, options());
1476        shared.into_serialized(&mut serializer).unwrap();
1477        let mut deserializer = de::Deserializer::from_slice(bytes.as_slice(), options());
1478        SharedMut::<T>::deserialize(&mut deserializer).unwrap()
1479    }
1480
1481    fn slice_check_src_shared<T>(slice: &[T])
1482    where
1483        T: Copy + PartialEq + std::fmt::Debug + ?Sized + 'static,
1484    {
1485        let shared = Shared::<[T]>::from(slice);
1486        let roundtrip = serialization_roundtrip_shared(shared);
1487        assert_eq!(roundtrip.as_view(), slice);
1488    }
1489
1490    fn slice_check_src_shared_mut<T>(slice: &[T])
1491    where
1492        T: Copy + PartialEq + std::fmt::Debug + ?Sized + 'static,
1493    {
1494        let shared = SharedMut::<[T]>::from(slice);
1495        let roundtrip = serialization_roundtrip_shared_mut(shared);
1496        assert_eq!(roundtrip.as_view(), slice);
1497    }
1498
1499    #[test]
1500    fn shared_str() {
1501        let s = "sharify_test";
1502        let shared: Shared<str> = Shared::new(s).unwrap();
1503        let roundtrip = serialization_roundtrip_shared(shared);
1504        assert_eq!(roundtrip.as_view(), s);
1505    }
1506
1507    #[test]
1508    fn shared_mut_str() {
1509        let s = "sharify_test";
1510        let shared: SharedMut<str> = SharedMut::new(s).unwrap();
1511        let roundtrip = serialization_roundtrip_shared_mut(shared);
1512        assert_eq!(roundtrip.as_view(), s);
1513    }
1514
    // Test fixtures covering several element types.
    enum Slice {
        Usize(&'static [usize]),
        U8(&'static [u8]),
        // NOTE(review): variant is named `U64` but carries `u16` data —
        // probably a leftover; renaming would touch every match site.
        U64(&'static [u16]),
        I16(&'static [i16]),
        F64(&'static [f64]),
    }
1522
    impl Slice {
        // Returns one fixture slice per supported element type.
        fn create_slices() -> Vec<Slice> {
            vec![
                Slice::Usize(&[1, 2, 3, 4, 5]),
                Slice::U8(&[1, 2, 3, 4, 5]),
                Slice::U64(&[1, 2, 3, 4, 5]),
                Slice::I16(&[1, 2, 3, 4, 5]),
                Slice::F64(&[1.0, 2.0, 3.0, 4.0, 5.0]),
            ]
        }
    }
1534
1535    #[test]
1536    fn shared_slice() {
1537        let slice: &[usize] = &[0_usize, 0, 0, 0, 0];
1538        let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
1539        let roundtrip = serialization_roundtrip_shared(shared);
1540        assert_eq!(roundtrip.as_view(), slice);
1541    }
1542
1543    #[test]
1544    fn shared_mut_slice() {
1545        let slice: &mut [usize] = &mut [1_usize, 2, 3, 4, 5];
1546        let mut shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
1547        shared.deref_mut().copy_from_slice(slice);
1548        let roundtrip = serialization_roundtrip_shared_mut(shared);
1549        assert_eq!(roundtrip.as_view(), slice);
1550    }
1551
    #[test]
    // Round-trips a `Shared<[T]>` built from each fixture element type.
    fn shared_slice_from_src() {
        let slices = Slice::create_slices();
        for s in slices {
            match s {
                Slice::Usize(s) => slice_check_src_shared(s),
                Slice::U8(s) => slice_check_src_shared(s),
                Slice::U64(s) => slice_check_src_shared(s),
                Slice::I16(s) => slice_check_src_shared(s),
                Slice::F64(s) => slice_check_src_shared(s),
            }
        }
    }
1565
    #[test]
    // Round-trips a `SharedMut<[T]>` built from each fixture element type.
    fn shared_mut_slice_from_src() {
        let slices = Slice::create_slices();
        for s in slices {
            match s {
                Slice::Usize(s) => slice_check_src_shared_mut(s),
                Slice::U8(s) => slice_check_src_shared_mut(s),
                Slice::U64(s) => slice_check_src_shared_mut(s),
                Slice::I16(s) => slice_check_src_shared_mut(s),
                Slice::F64(s) => slice_check_src_shared_mut(s),
            }
        }
    }
1579
    #[test]
    // End-to-end check of the `Shared` access/serialization counters.
    fn shared_memory() {
        let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
        assert_eq!(shared.counts().unwrap(), [1, 0]);
        // `serialize_shared` consumes the instance (access -1, serialized +1);
        // deserializing restores access 1, serialized 0.
        let bytes = serialize_shared(shared);
        let deser: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(deser.counts().unwrap(), [1, 0]);
        // Should be able to create more than one `Shared`
        let mut instances = Vec::new();
        for i in 1..=10 {
            let inst: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
            assert_eq!(deser.counts().unwrap(), [1 + i, 0]);
            instances.push(inst);
        }
        // Data should only be freed when all `Shared`s have been dropped.
        assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
        std::mem::drop(instances);
        assert_eq!(deser.counts().unwrap(), [1, 0]);
        assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
        std::mem::drop(deser);
        // With the last instance gone the segment is freed, so deserializing
        // the stale bytes must fail.
        assert!(deserialize::<Shared::<[usize]>>(bytes.as_slice()).is_err());
    }
1602
    #[test]
    // End-to-end check of the `SharedMut` exclusive-access flag.
    fn shared_mut_memory() {
        let mut shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
        assert_eq!(shared.counts().unwrap(), [1]);
        let bytes = serialize_shared_mut(shared);
        let mut deser: SharedMut<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(deser.counts().unwrap(), [1]);
        // Should not be able to create more than one `SharedMut`
        assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
        // ...but a deserialization error must leave the data intact.
        assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
        assert_eq!(deser.counts().unwrap(), [1]);
        // Data should only be freed on drop.
        std::mem::drop(deser);
        assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
    }
1619
    #[test]
    // The wire `Tag` prevents deserializing `Shared` bytes into a `SharedMut`.
    fn cross_serialization_from_shared() {
        let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
        let bytes = serialize_shared(shared);
        // Serializing into a `SharedMut` is not allowed
        assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
        // ...but should be able to deserialize back into a `Shared`
        let shared: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(shared.counts().unwrap(), [1, 0]);
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
    }
1632
    #[test]
    // The wire `Tag` prevents deserializing `SharedMut` bytes into a `Shared`.
    fn cross_serialization_from_shared_mut() {
        let shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
        let bytes = serialize_shared_mut(shared);
        // Serializing into a `Shared` is not allowed
        assert!(deserialize::<Shared::<[usize]>>(bytes.as_slice()).is_err());
        // ...but should be able to deserialize back into a `SharedMut`
        let mut shared: SharedMut<[usize]> = deserialize(bytes.as_slice()).unwrap();
        assert_eq!(shared.counts().unwrap(), [1]);
        assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
    }
1645
1646    #[test]
1647    fn shared_mut_into_shared() {
1648        let mut shared_mut: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
1649        shared_mut.deref_mut().copy_from_slice(&[1, 2, 3, 4, 5]);
1650        let shared: Shared<[usize]> = shared_mut.try_into().unwrap();
1651        assert_eq!(shared.counts().unwrap(), [1, 0]);
1652        assert_eq!(&[1, 2, 3, 4, 5], shared.deref());
1653    }
1654
1655    #[test]
1656    fn shared_clone() {
1657        let shared = Shared::from(&[1_usize, 2, 3, 4, 5] as &[_]);
1658        let mut container = Vec::new();
1659        for i in 0..100 {
1660            let bytes = serialize_shared(shared.clone());
1661            container.push((bytes, shared.clone()));
1662            assert_eq!(shared.counts().unwrap(), [2 + i, 1 + i]);
1663        }
1664        assert_eq!(shared.counts().unwrap(), [101, 100]);
1665        for (i, (bytes, cl)) in container.into_iter().enumerate() {
1666            let mut _deser: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
1667            assert_eq!(&[1_usize, 2, 3, 4, 5], cl.deref());
1668            assert_eq!(shared.counts().unwrap(), [102 - i as u64, 99 - i as u64]);
1669        }
1670        assert_eq!(shared.counts().unwrap(), [1, 0]);
1671        assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
1672    }
1673
    #[test]
    fn races() {
        // Stress test: 50 threads concurrently serialize/deserialize clones of
        // the same shared slice with randomized sleeps in between, to shake out
        // races in the reference-count bookkeeping.
        let shared = Shared::from(&[1_usize, 2, 3, 4, 5] as &[_]);
        let mut handles = Vec::new();
        for _ in 0..50 {
            // Zero-capacity channel: `send(())` below acts as a start barrier
            // so all threads begin their loops at roughly the same time.
            let (send, recv) = std::sync::mpsc::sync_channel(0);
            // Serialize a clone up front; the thread deserializes it after the
            // start signal.
            let bytes_send = serialize_shared(shared.clone());
            let handle = thread::spawn(move || {
                let mut rng = rand::thread_rng();
                // Block until the main thread fires the start signal.
                recv.recv().unwrap();
                let mut shared: Shared<[usize]> = deserialize(bytes_send.as_slice()).unwrap();
                assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
                for _ in 0..1000 {
                    // Random jitter so serialize/deserialize from different
                    // threads interleave in many different orders.
                    thread::sleep(std::time::Duration::from_millis(rng.gen_range(0..=5)));
                    let tmp = serialize_shared(shared);
                    thread::sleep(std::time::Duration::from_millis(rng.gen_range(0..=5)));
                    // Round-trip back; the data must stay intact every time.
                    shared = deserialize(tmp.as_slice()).unwrap();
                    assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
                }
            });
            handles.push((handle, send));
        }
        // Give all threads time to park on `recv()` before releasing them.
        thread::sleep(std::time::Duration::from_millis(100));
        for (_, send) in handles.iter() {
            send.send(()).unwrap();
        }
        for (handle, _) in handles {
            handle.join().unwrap();
        }
        // After every thread has dropped its handle, only the original
        // remains and the data is unchanged.
        assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
        assert_eq!(shared.counts().unwrap(), [1, 0]);
    }
1706
    #[cfg(feature = "shared_ndarray")]
    mod ndarray_tests {
        //! Tests for the ndarray-backed `SharedArray` / `SharedArrayMut`
        //! wrappers (only built with the `shared_ndarray` feature).
        use super::*;
        use ndarray::{Array, Axis, IxDyn};

        // A freshly created shared array survives a serialization round trip
        // with its zero-initialized contents intact.
        #[test]
        fn shared_ndarray() {
            let shared: SharedArray<u64, IxDyn> = Shared::new(&(0, IxDyn(&[3, 2]))).unwrap();
            let shared = serialization_roundtrip_shared(shared);
            assert_eq!(&[0; 6], shared.as_view().as_slice().unwrap());
        }

        // Data written through a mutable view survives a serialization
        // round trip.
        #[test]
        fn shared_mut_ndarray() {
            let mut shared: SharedArrayMut<f64, IxDyn> =
                SharedMut::new(&(0.0, IxDyn(&[3, 2]))).unwrap();
            let slice: &[f64] = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
            for (&x, element) in slice.iter().zip(shared.as_view_mut().iter_mut()) {
                *element = x;
            }
            let roundtrip = serialization_roundtrip_shared_mut(shared);
            assert_eq!(slice, roundtrip.as_view().as_slice().unwrap());
        }

        // A non-standard layout (swapped axes) recorded in the metadata is
        // preserved across serialization.
        #[test]
        fn shared_mut_array_layout() {
            let mut array: SharedArrayMut<f64, ndarray::IxDyn> =
                SharedArrayMut::new(&(0.0, ndarray::IxDyn(&[100, 200, 300]))).unwrap();
            // Default layout is C order (row-major): strides shrink rightward.
            assert_eq!(array.as_view().strides(), &[200 * 300, 300, 1]);
            {
                let mut view = array.as_view_mut();
                assert!(view.is_standard_layout());
                // Swapping axes only permutes shape/strides in the view; the
                // underlying buffer is untouched.
                view.swap_axes(0, 1);
                assert_eq!(view.strides(), &[300, 200 * 300, 1]);
                // SAFETY: the shape/strides written here are taken verbatim
                // from a view of the same allocation, so they still describe
                // valid, in-bounds indexing for the existing buffer.
                unsafe {
                    *array.metadata_mut() = (Vec::from(view.shape()), Vec::from(view.strides()));
                }
            }
            // The updated metadata is now visible through fresh views...
            assert_eq!(array.as_view().shape(), &[200, 100, 300]);
            assert_eq!(array.as_view().strides(), &[300, 200 * 300, 1]);
            assert!(!array.as_view().is_standard_layout());
            // ...and survives a serialization round trip unchanged.
            let bytes = serialize_shared_mut(array);
            let deser: SharedArrayMut<f64, ndarray::IxDyn> = deserialize(bytes.as_slice()).unwrap();
            assert_eq!(deser.as_view().shape(), &[200, 100, 300]);
            assert_eq!(deser.as_view().strides(), &[300, 200 * 300, 1]);
            assert!(!deser.as_view().is_standard_layout());
        }

        // Constructing from an existing (non-contiguous) source array copies
        // its shape and element order element-by-element.
        #[test]
        fn shared_ndarray_from_src() {
            let mut src = Array::from_elem((100, 200, 300), 0_u64);
            // Inverting an axis makes `src` non-standard-layout, exercising
            // the copy path rather than a flat memcpy.
            src.invert_axis(Axis(1));
            let src_shape = Vec::from(src.shape());
            let array = Shared::new_from_inner(&src).unwrap();
            assert_eq!(src_shape, array.as_view().shape());
            assert!(src
                .iter()
                .zip(array.as_view().iter())
                .all(|(src, dst)| src == dst))
        }
    }
1768}