1#![warn(clippy::all)]
2
3use raw_sync::locks::{LockInit, Mutex};
263use serde::{de, de::DeserializeOwned, ser, Deserialize, Deserializer, Serialize, Serializer};
264use std::cell::UnsafeCell;
265use std::convert::{From, Into, TryFrom, TryInto};
266use std::ops::{Deref, DerefMut};
267
268#[allow(dead_code)]
269mod shared_memory;
270use shared_memory::{Shmem, ShmemConf, ShmemError};
271
272#[allow(dead_code)]
273mod memory;
274use memory::{is_aligned, ALIGNMENT};
275
276#[derive(Debug)]
277pub enum Error {
278 UnalignedMemory,
279 Mutex(String),
280 Shmem(ShmemError),
281 Serialization(String),
282 Deserialization(String),
283 InvalidSharedMut,
284}
285
286impl std::fmt::Display for Error {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match self {
289 Self::UnalignedMemory => write!(f, "Encountered unaligned memory."),
290 Self::Shmem(e) => e.fmt(f),
291 Self::Mutex(s) | Self::Serialization(s) | Self::Deserialization(s) => {
292 write!(f, "{}", s)
293 }
294 Self::InvalidSharedMut => write!(f, "Trying to use a `SharedMut` previously invalidated with a call to `serde::Serialize::serialize`.")
295 }
296 }
297}
298
299impl std::error::Error for Error {}
300
301pub unsafe trait ShmemBacked {
319 type NewArg: ?Sized;
325 type MetaData: Serialize + DeserializeOwned + Clone;
328
329 fn required_memory_arg(arg: &Self::NewArg) -> usize;
333 fn required_memory_src(src: &Self) -> usize;
336 fn new(data: &mut [u8], arg: &Self::NewArg) -> Self::MetaData;
341 fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData;
345}
346
347pub trait ShmemView<'a>: ShmemBacked {
349 type View;
350 fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View;
353}
354
355pub trait ShmemViewMut<'a>: ShmemBacked {
357 type View;
358 fn view_mut(
361 data: &'a mut [u8],
362 metadata: &'a mut <Self as ShmemBacked>::MetaData,
363 ) -> Self::View;
364}
365
366unsafe impl ShmemBacked for str {
367 type NewArg = str;
368 type MetaData = usize;
370
371 fn required_memory_arg(src: &Self::NewArg) -> usize {
372 src.len()
373 }
374
375 fn required_memory_src(src: &Self) -> usize {
376 src.len()
377 }
378
379 fn new(data: &mut [u8], src: &Self::NewArg) -> Self::MetaData {
381 assert_eq!(data.len(), Self::required_memory_arg(src));
382 data.copy_from_slice(src.as_bytes());
383 data.len()
384 }
385
386 fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
388 assert_eq!(data.len(), Self::required_memory_src(src));
389 data.copy_from_slice(src.as_bytes());
390 data.len()
391 }
392}
393
394impl<'a> ShmemView<'a> for str {
395 type View = &'a str;
396
397 fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
398 assert_eq!(data.len(), *metadata);
399 unsafe { std::str::from_utf8_unchecked(data) }
400 }
401}
402
403impl<'a> ShmemViewMut<'a> for str {
404 type View = &'a mut str;
405
406 fn view_mut(
407 data: &'a mut [u8],
408 metadata: &'a mut <Self as ShmemBacked>::MetaData,
409 ) -> Self::View {
410 assert_eq!(data.len(), *metadata);
411 unsafe { std::str::from_utf8_unchecked_mut(data) }
412 }
413}
414
415unsafe impl<T> ShmemBacked for [T]
416where
417 T: Copy,
418{
419 type NewArg = (T, usize);
420 type MetaData = usize;
422
423 fn required_memory_arg((_, len): &Self::NewArg) -> usize {
424 *len * std::mem::size_of::<T>()
425 }
426
427 fn required_memory_src(src: &Self) -> usize {
428 src.len() * std::mem::size_of::<T>()
429 }
430
431 fn new(data: &mut [u8], &(init, len): &Self::NewArg) -> Self::MetaData {
433 assert_eq!(data.len(), Self::required_memory_arg(&(init, len)));
434 let data_typed =
435 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, len) };
436 for elem in data_typed.iter_mut() {
437 *elem = init;
438 }
439 len
440 }
441
442 fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
444 assert_eq!(data.len(), Self::required_memory_src(src));
445 let data_typed =
446 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, src.len()) };
447 data_typed.copy_from_slice(src);
448 data_typed.len()
449 }
450}
451
452impl<'a, T> ShmemView<'a> for [T]
453where
454 T: Copy + 'a,
455{
456 type View = &'a [T];
457
458 fn view(data: &'a [u8], metadata: &'a <Self as ShmemBacked>::MetaData) -> Self::View {
459 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, *metadata) }
460 }
461}
462
463impl<'a, T> ShmemViewMut<'a> for [T]
464where
465 T: Copy + 'a,
466{
467 type View = &'a mut [T];
468
469 fn view_mut(
470 data: &'a mut [u8],
471 metadata: &'a mut <Self as ShmemBacked>::MetaData,
472 ) -> Self::View {
473 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, *metadata) }
474 }
475}
476
477use num_traits::{ops::checked::CheckedAdd, sign::Unsigned, NumOps};
478
479trait AccessCounter: Copy + PartialOrd + Eq + NumOps + Unsigned + CheckedAdd {}
480
481impl<T: Copy + PartialOrd + Eq + NumOps + Unsigned + CheckedAdd> AccessCounter for T {}
482
483struct ShmemBase<A: AccessCounter, P: DropBehaviour, const N: usize> {
484 tag: P,
485 shmem: Shmem,
486 access_counter_type: std::marker::PhantomData<A>,
487 counter_offset: usize,
489 data_offset: usize,
491 data_size: usize,
493 free_shmem: bool,
494}
495
496impl<A: AccessCounter, P: DropBehaviour, const N: usize> ShmemBase<A, P, N> {
497 fn new(access_counter: &[A; N], data_size: usize, tag: P) -> Result<Self, Error> {
503 let lock_size = Mutex::size_of(None);
510 let reserved_size =
513 ((lock_size + std::mem::size_of::<A>() * N) / ALIGNMENT + 1) * ALIGNMENT;
514 let shmem = ShmemConf::new()
516 .size(reserved_size + data_size)
517 .create()
518 .map_err(Error::Shmem)?;
519 let (mutex, _) = unsafe {
521 Mutex::new(shmem.as_ptr(), shmem.as_ptr().add(lock_size))
522 .map_err(|e| Error::Mutex(format!("{}", e)))?
523 };
524 {
526 let lock = mutex.lock().map_err(|e| Error::Mutex(format!("{}", e)))?;
527 unsafe {
528 let counter_ptr = std::slice::from_raw_parts_mut(*lock as *mut A, N);
529 counter_ptr.copy_from_slice(access_counter);
530 }
531 }
532 std::mem::forget(mutex);
535 let aligned = unsafe { is_aligned(shmem.as_ptr().add(reserved_size), ALIGNMENT) };
537 if !aligned {
538 Err(Error::UnalignedMemory)
539 } else {
540 Ok(Self {
541 tag,
542 shmem,
543 access_counter_type: std::marker::PhantomData,
544 counter_offset: lock_size,
545 data_offset: reserved_size,
546 data_size,
547 free_shmem: true,
548 })
549 }
550 }
551
552 fn mutex_write(&self, write: fn(&[A; N]) -> Option<[A; N]>) -> Result<[A; N], Error> {
561 let (mutex, _) = unsafe {
562 let counter_ptr = self.shmem.as_ptr().add(self.counter_offset);
563 Mutex::from_existing(self.shmem.as_ptr(), counter_ptr)
564 .map_err(|e| Error::Mutex(format!("{}", e)))?
565 };
566 let counter_values = {
567 let lock = mutex.lock().map_err(|e| Error::Mutex(format!("{}", e)))?;
568 let counter_ptr = unsafe { std::slice::from_raw_parts_mut(*lock as *mut A, N) };
569 let mut old = [A::zero(); N];
570 old.copy_from_slice(counter_ptr);
571 if let Some(new) = write(&old) {
572 counter_ptr.copy_from_slice(&new);
573 }
574 old
575 };
576 std::mem::forget(mutex);
579 Ok(counter_values)
580 }
581
582 fn free_on_drop(&mut self, free: bool) {
585 self.free_shmem = free;
586 }
587
588 fn data_ptr(&self) -> *const u8 {
590 unsafe { self.shmem.as_ptr().add(self.data_offset) }
591 }
592
593 fn data_ptr_mut(&mut self) -> *mut u8 {
596 unsafe { self.shmem.as_ptr().add(self.data_offset) }
597 }
598
599 fn to_wire_format<T>(&self, metadata: T) -> WireFormat<T> {
600 WireFormat {
601 tag: self.tag.clone().into(),
602 os_id: String::from(self.shmem.get_os_id()),
603 mem_size: self.shmem.len(),
604 counter_offset: self.counter_offset,
605 data_offset: self.data_offset,
606 data_size: self.data_size,
607 meta: metadata,
608 }
609 }
610
611 fn into_wire_format<T>(mut self, metadata: T) -> WireFormat<T> {
614 self.free_shmem = false;
616 WireFormat {
617 tag: self.tag.clone().into(),
618 os_id: String::from(self.shmem.get_os_id()),
619 mem_size: self.shmem.len(),
620 counter_offset: self.counter_offset,
621 data_offset: self.data_offset,
622 data_size: self.data_size,
623 meta: metadata,
624 }
625 }
626
627 fn from_wire_format<T>(wire_format: WireFormat<T>) -> Result<(Self, T), Error> {
630 let WireFormat {
631 tag,
632 os_id,
633 mem_size,
634 counter_offset,
635 data_offset,
636 data_size,
637 meta,
638 } = wire_format;
639 let shmem = ShmemConf::new()
640 .os_id(os_id)
641 .size(mem_size)
642 .open()
643 .map_err(Error::Shmem)?;
644 Ok((
645 Self {
646 tag: tag.try_into()?,
647 shmem,
648 access_counter_type: std::marker::PhantomData,
649 counter_offset,
650 data_offset,
651 data_size,
652 free_shmem: false,
653 },
654 meta,
655 ))
656 }
657}
658
659impl<A: AccessCounter, P: DropBehaviour, const N: usize> Drop for ShmemBase<A, P, N> {
660 fn drop(&mut self) {
661 P::called_on_drop(self);
662 }
663}
664
665trait DropBehaviour: Clone + TryFrom<Tag, Error = Error> + Into<Tag> {
666 fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
667 base: &mut ShmemBase<A, P, N>,
668 );
669}
670
671impl<A: AccessCounter> Clone for ShmemBase<A, SharedTag, 2> {
672 fn clone(&self) -> Self {
673 let shmem = ShmemConf::new()
674 .os_id(self.shmem.get_os_id())
675 .size(self.shmem.len())
676 .open()
677 .unwrap();
678 let new = Self {
679 tag: SharedTag(),
680 shmem,
681 access_counter_type: std::marker::PhantomData,
682 counter_offset: self.counter_offset,
683 data_offset: self.data_offset,
684 data_size: self.data_size,
685 free_shmem: true,
686 };
687 let write: fn(&[A; 2]) -> Option<[A; 2]> = |old| {
689 let mut new = [A::zero(); 2];
690 if let Some(new_acc_count) = old[0].checked_add(&A::one()) {
691 new[0] = new_acc_count;
692 new[1] = old[1];
693 Some(new)
694 } else {
695 panic!("Can't have more than A::MAX `Shared`s with simultaneous access.")
696 }
697 };
698 new.mutex_write(write).unwrap();
699 new
700 }
701}
702
703impl<A: AccessCounter> From<ShmemBase<A, SharedMutTag, 2>> for ShmemBase<A, SharedTag, 2> {
704 fn from(mut shared_mut: ShmemBase<A, SharedMutTag, 2>) -> Self {
705 shared_mut.free_on_drop(false);
707 let shmem = ShmemConf::new()
710 .os_id(shared_mut.shmem.get_os_id())
711 .size(shared_mut.shmem.len())
712 .open()
713 .unwrap();
714 let new = Self {
715 tag: SharedTag(),
716 shmem,
717 access_counter_type: std::marker::PhantomData,
718 counter_offset: shared_mut.counter_offset,
719 data_offset: shared_mut.data_offset,
720 data_size: shared_mut.data_size,
721 free_shmem: true,
722 };
723 new.mutex_write(|_| Some([A::one(), A::zero()])).unwrap();
725 new
726 }
727}
728
729#[derive(Serialize, Deserialize, Clone)]
730struct SharedTag();
731
732impl TryFrom<Tag> for SharedTag {
733 type Error = Error;
734 fn try_from(value: Tag) -> Result<Self, Self::Error> {
735 match value {
736 Tag::Shared(tag) => Ok(tag),
737 Tag::SharedMut(_) => Err(Error::Deserialization(String::from(
738 "Can't deserialize a `Shared` from a `SharedMut` serialization.",
739 ))),
740 }
741 }
742}
743
744impl DropBehaviour for SharedTag {
745 fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
746 base: &mut ShmemBase<A, P, N>,
747 ) {
748 if base.free_shmem {
751 let write: fn(&[A; N]) -> Option<[A; N]> = |old| {
753 let mut new = [A::zero(); N];
754 new[0] = old[0] - A::one();
755 new[1] = old[1];
756 Some(new)
757 };
758 let counter_value = base.mutex_write(write).unwrap();
759 if (counter_value[0] == A::one()) && (counter_value[1] == A::zero()) {
762 unsafe {
766 let counter_ptr = base.shmem.as_ptr().add(base.counter_offset);
767 Mutex::from_existing(base.shmem.as_ptr(), counter_ptr).unwrap();
768 }
769 base.shmem.set_owner(true);
770 } else {
771 base.shmem.set_owner(false);
772 }
773 } else {
774 base.shmem.set_owner(false);
777 }
778 }
779}
780
781pub struct Shared<T>
783where
784 T: ShmemBacked + ?Sized,
785{
786 metadata: <T as ShmemBacked>::MetaData,
787 shmem: ShmemBase<u64, SharedTag, 2>,
788}
789
790impl<T> Shared<T>
791where
792 T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
793{
794 pub fn new(arg: &<T as ShmemBacked>::NewArg) -> Result<Self, Error> {
795 let size = T::required_memory_arg(&arg);
796 let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedTag())?;
812 let metadata = unsafe {
813 let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
814 T::new(data, arg)
815 };
816 Ok(Shared { metadata, shmem })
817 }
818
819 pub fn new_from_inner(arg: &T) -> Result<Self, Error> {
820 let size = T::required_memory_src(&arg);
821 let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedTag())?;
822 let inner = unsafe {
823 let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
824 T::new_from_src(data, arg)
825 };
826 Ok(Shared {
827 metadata: inner,
828 shmem,
829 })
830 }
831
832 #[allow(clippy::clippy::needless_lifetimes)]
833 pub fn as_view<'a>(&'a self) -> <T as ShmemView<'a>>::View {
834 let data =
835 unsafe { std::slice::from_raw_parts(self.shmem.data_ptr(), self.shmem.data_size) };
836 T::view(data, &self.metadata)
837 }
838
839 #[cfg(test)]
840 pub fn counts(&self) -> Result<[u64; 2], Error> {
841 self.shmem.mutex_write(|_| None)
842 }
843
844 pub fn into_serialized<S: Serializer>(self, serializer: S) -> Result<S::Ok, S::Error> {
845 let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
848 let mut new = [0_u64, 0];
849 if let Some(new_ser_count) = old[1].checked_add(1) {
850 new[0] = old[0] - 1;
851 new[1] = new_ser_count;
852 Some(new)
853 } else {
854 panic!("Can't have more than A::MAX serialized `Shared`s.")
855 }
856 };
857 self.shmem.mutex_write(write).map_err(ser::Error::custom)?;
858 let wire_format = self.shmem.into_wire_format(self.metadata);
859 wire_format.serialize(serializer)
860 }
861}
862
863impl<T> Serialize for Shared<T>
864where
865 T: ShmemBacked + ?Sized,
866{
867 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
868 where
869 S: Serializer,
870 {
871 let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
873 let mut new = [0_u64, 0];
874 if let Some(new_ser_count) = old[1].checked_add(1) {
875 new[0] = old[0];
876 new[1] = new_ser_count;
877 Some(new)
878 } else {
879 panic!("Can't have more than A::MAX serialized `Shared`s.")
880 }
881 };
882 self.shmem.mutex_write(write).map_err(ser::Error::custom)?;
883 let wire_format = self.shmem.to_wire_format(&self.metadata);
884 wire_format.serialize(serializer)
885 }
886}
887
888impl<'de, T> Deserialize<'de> for Shared<T>
889where
890 T: ShmemBacked + ?Sized,
891{
892 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
893 where
894 D: Deserializer<'de>,
895 {
896 let wire_format = WireFormat::deserialize(deserializer)?;
897 let (mut shmem, metadata) = ShmemBase::<u64, SharedTag, 2>::from_wire_format(wire_format)
900 .map_err(de::Error::custom)?;
901 let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
904 let mut new = [0_u64, 0];
905 if let Some(new_acc_count) = old[0].checked_add(1) {
906 new[0] = new_acc_count;
907 new[1] = old[1].saturating_sub(1);
908 Some(new)
909 } else {
910 None
911 }
912 };
913 if shmem.mutex_write(write).map_err(de::Error::custom)?[0] < u64::MAX {
914 shmem.free_on_drop(true);
917 Ok(Shared { metadata, shmem })
918 } else {
919 Err(de::Error::custom(
920 "Can't have more than u64::MAX `Shared`s with simultaneous access.",
921 ))
922 }
923 }
924}
925
926unsafe impl<T> Send for Shared<T>
927where
928 T: ShmemBacked + ?Sized,
929 T::MetaData: Send,
930{
931}
932impl<'a, T> From<&'a T> for Shared<T>
937where
938 T: ShmemBacked + for<'b> ShmemView<'b> + ?Sized,
939{
940 fn from(src: &'a T) -> Self {
941 Self::new_from_inner(src).unwrap()
942 }
943}
944
945impl<T> Clone for Shared<T>
946where
947 T: ShmemBacked + ?Sized,
948{
949 fn clone(&self) -> Self {
952 let Shared {
953 shmem, metadata, ..
954 } = self;
955 let shmem = shmem.clone();
956 Self {
957 metadata: metadata.clone(),
958 shmem,
959 }
960 }
961}
962
963impl<T> TryFrom<SharedMut<T>> for Shared<T>
964where
965 T: ShmemBacked + ?Sized,
966{
967 type Error = Error;
968
969 fn try_from(shared_mut: SharedMut<T>) -> Result<Self, Self::Error> {
972 let SharedMut { shmem, metadata } = shared_mut;
973 if let Some(shmem) = shmem.into_inner() {
974 let shmem: ShmemBase<_, SharedTag, 2> = shmem.into();
975 Ok(Self { metadata, shmem })
976 } else {
977 Err(Error::InvalidSharedMut)
978 }
979 }
980}
981
982#[derive(Serialize, Deserialize, Clone)]
983struct SharedMutTag();
984
985impl TryFrom<Tag> for SharedMutTag {
986 type Error = Error;
987 fn try_from(value: Tag) -> Result<Self, Self::Error> {
988 match value {
989 Tag::SharedMut(tag) => Ok(tag),
990 Tag::Shared(_) => Err(Error::Deserialization(String::from(
991 "Can't deserialize a `SharedMut` from a `Shared` serialization.",
992 ))),
993 }
994 }
995}
996
997impl DropBehaviour for SharedMutTag {
998 fn called_on_drop<A: AccessCounter, P: DropBehaviour, const N: usize>(
999 base: &mut ShmemBase<A, P, N>,
1000 ) {
1001 if base.free_shmem {
1002 unsafe {
1006 let counter_ptr = base.shmem.as_ptr().add(base.counter_offset);
1007 Mutex::from_existing(base.shmem.as_ptr(), counter_ptr).unwrap();
1008 }
1009 base.shmem.set_owner(true);
1010 } else {
1011 base.shmem.set_owner(false);
1014 }
1015 }
1016}
1017
1018pub struct SharedMut<T>
1020where
1021 T: ShmemBacked + ?Sized,
1022{
1023 metadata: <T as ShmemBacked>::MetaData,
1024 shmem: UnsafeCell<Option<ShmemBase<u64, SharedMutTag, 2>>>,
1025}
1026
1027impl<T> SharedMut<T>
1028where
1029 T: ShmemBacked + for<'a> ShmemView<'a> + for<'a> ShmemViewMut<'a> + ?Sized,
1030{
1031 pub fn new(arg: &<T as ShmemBacked>::NewArg) -> Result<Self, Error> {
1032 let size = T::required_memory_arg(&arg);
1033 let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedMutTag())?;
1043 let metadata = unsafe {
1044 let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
1045 T::new(data, arg)
1046 };
1047 Ok(SharedMut {
1048 metadata,
1049 shmem: UnsafeCell::new(Some(shmem)),
1050 })
1051 }
1052
1053 pub fn new_from_inner(arg: &T) -> Result<Self, Error> {
1054 let size = T::required_memory_src(&arg);
1055 let mut shmem = ShmemBase::new(&[1_u64, 0], size, SharedMutTag())?;
1056 let metadata = unsafe {
1057 let data = std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size);
1058 T::new_from_src(data, arg)
1059 };
1060 Ok(SharedMut {
1061 metadata,
1062 shmem: UnsafeCell::new(Some(shmem)),
1063 })
1064 }
1065
1066 #[allow(clippy::clippy::needless_lifetimes)]
1067 pub fn as_view_mut<'a>(&'a mut self) -> <T as ShmemViewMut<'a>>::View {
1068 let shmem =
1069 self.shmem.get_mut().as_mut().expect(
1070 "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
1071 );
1072 let data = unsafe { std::slice::from_raw_parts_mut(shmem.data_ptr_mut(), shmem.data_size) };
1073 T::view_mut(data, &mut self.metadata)
1074 }
1075}
1076
1077impl<T> SharedMut<T>
1078where
1079 T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1080{
1081 #[allow(clippy::clippy::needless_lifetimes)]
1082 pub fn as_view<'a>(&'a self) -> <T as ShmemView<'a>>::View {
1083 let data = unsafe {
1084 let shmem: &mut _ = &mut *self.shmem.get();
1085 let shmem = shmem.as_mut().expect(
1086 "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
1087 );
1088 std::slice::from_raw_parts(shmem.data_ptr(), shmem.data_size)
1089 };
1090 T::view(data, &self.metadata)
1091 }
1092
1093 pub unsafe fn metadata_mut(&mut self) -> &mut <T as ShmemBacked>::MetaData {
1103 &mut self.metadata
1104 }
1105
1106 #[cfg(test)]
1107 pub fn counts(&mut self) -> Result<[u64; 1], Error> {
1108 let shmem =
1109 self.shmem.get_mut().as_mut().expect(
1110 "`SharedMut` must not be used after a call to `serde::Serialize::serialize`.",
1111 );
1112 let mut access_count = [0];
1113 let counts = shmem.mutex_write(|_| None)?;
1114 access_count[0] = counts[0];
1115 Ok(access_count)
1116 }
1117
1118 pub fn into_serialized<S: Serializer>(self, serializer: S) -> Result<S::Ok, S::Error> {
1119 let shmem = self
1120 .shmem
1121 .into_inner()
1122 .expect("`SharedMut` must not be used after a call to `serde::Serialize::serialize`.");
1123 let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
1126 debug_assert_eq!(old[0], 1);
1127 Some([0_u64, 0])
1128 };
1129 shmem.mutex_write(write).map_err(ser::Error::custom)?;
1130 let wire_format = shmem.into_wire_format(self.metadata);
1131 wire_format.serialize(serializer)
1132 }
1133}
1134
1135impl<T> Serialize for SharedMut<T>
1136where
1137 T: ShmemBacked + ?Sized,
1138{
1139 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
1146 where
1147 S: Serializer,
1148 {
1149 let shmem: &mut _ = unsafe { &mut *self.shmem.get() };
1150 let shmem = shmem
1151 .take()
1152 .expect("`SharedMut` must not be used after a call to `serde::Serialize::serialize`.");
1153 let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
1156 debug_assert_eq!(old[0], 1);
1157 Some([0_u64, 0])
1158 };
1159 shmem.mutex_write(write).map_err(ser::Error::custom)?;
1160 let wire_format = shmem.into_wire_format(self.metadata.clone());
1161 wire_format.serialize(serializer)
1162 }
1163}
1164
1165impl<'de, T> Deserialize<'de> for SharedMut<T>
1166where
1167 T: ShmemBacked + ?Sized,
1168{
1169 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1170 where
1171 D: Deserializer<'de>,
1172 {
1173 let wire_format = WireFormat::deserialize(deserializer)?;
1174 let (mut shmem, metadata) =
1177 ShmemBase::<u64, SharedMutTag, 2>::from_wire_format(wire_format)
1178 .map_err(de::Error::custom)?;
1179 let write: fn(&[u64; 2]) -> Option<[u64; 2]> = |old| {
1182 debug_assert!(old[0] <= 1);
1183 if old[0] == 0 {
1184 Some([1, 0])
1185 } else {
1186 None
1187 }
1188 };
1189 if shmem.mutex_write(write).map_err(de::Error::custom)?[0] == 0 {
1190 shmem.free_on_drop(true);
1193 Ok(SharedMut {
1194 metadata,
1195 shmem: UnsafeCell::new(Some(shmem)),
1196 })
1197 } else {
1198 Err(de::Error::custom("A shared memory region can only be accessed by one `SharedMut` instance at any time. Note that the existing instance may live in a different process."))
1199 }
1200 }
1201}
1202
1203unsafe impl<T> Send for SharedMut<T>
1204where
1205 T: ShmemBacked + ?Sized,
1206 T::MetaData: Send,
1207{
1208}
1209impl<'a, T> From<&'a T> for SharedMut<T>
1213where
1214 T: ShmemBacked + for<'b> ShmemView<'b> + for<'b> ShmemViewMut<'b> + ?Sized,
1215{
1216 fn from(src: &'a T) -> Self {
1217 Self::new_from_inner(src).unwrap()
1218 }
1219}
1220
1221#[derive(Serialize, Deserialize)]
1222struct WireFormat<T> {
1223 tag: Tag,
1224 os_id: String,
1225 mem_size: usize,
1226 counter_offset: usize,
1227 data_offset: usize,
1228 data_size: usize,
1229 meta: T,
1230}
1231
1232#[derive(Serialize, Deserialize)]
1233enum Tag {
1234 Shared(SharedTag),
1235 SharedMut(SharedMutTag),
1236}
1237
1238impl From<SharedTag> for Tag {
1239 fn from(shared: SharedTag) -> Self {
1240 Tag::Shared(shared)
1241 }
1242}
1243
1244impl From<SharedMutTag> for Tag {
1245 fn from(shared: SharedMutTag) -> Self {
1246 Tag::SharedMut(shared)
1247 }
1248}
1249
1250pub type SharedStr = Shared<str>;
1252
1253impl Deref for SharedStr {
1254 type Target = str;
1255
1256 fn deref(&self) -> &Self::Target {
1257 self.as_view()
1258 }
1259}
1260
1261pub type SharedStrMut = SharedMut<str>;
1263
1264impl Deref for SharedStrMut {
1265 type Target = str;
1266
1267 fn deref(&self) -> &Self::Target {
1268 self.as_view()
1269 }
1270}
1271
1272impl DerefMut for SharedStrMut {
1273 fn deref_mut(&mut self) -> &mut Self::Target {
1274 self.as_view_mut()
1275 }
1276}
1277
1278pub type SharedSlice<T> = Shared<[T]>;
1280
1281impl<T: Copy + 'static> Deref for SharedSlice<T> {
1282 type Target = [T];
1283
1284 fn deref(&self) -> &Self::Target {
1285 self.as_view()
1286 }
1287}
1288
1289pub type SharedSliceMut<T> = SharedMut<[T]>;
1291
1292impl<T: Copy + 'static> Deref for SharedSliceMut<T> {
1293 type Target = [T];
1294
1295 fn deref(&self) -> &Self::Target {
1296 self.as_view()
1297 }
1298}
1299
1300impl<T: Copy + 'static> DerefMut for SharedSliceMut<T> {
1301 fn deref_mut(&mut self) -> &mut Self::Target {
1302 self.as_view_mut()
1303 }
1304}
1305
1306#[cfg(feature = "shared_ndarray")]
1307pub use sharify_ndarray::{SharedArray, SharedArrayMut};
1308
1309#[cfg(feature = "shared_ndarray")]
1310pub mod sharify_ndarray {
1311 use super::*;
1312 use ndarray::{Array, ArrayView, ArrayViewMut, Dimension};
1313
1314 pub type SharedArray<T, D> = Shared<Array<T, D>>;
1316 pub type SharedArrayMut<T, D> = SharedMut<Array<T, D>>;
1318
1319 unsafe impl<'a, T, D> ShmemBacked for Array<T, D>
1320 where
1321 T: Copy,
1322 D: Dimension + Serialize + DeserializeOwned,
1323 {
1324 type NewArg = (T, D);
1325 type MetaData = (Vec<usize>, Vec<isize>);
1326
1327 fn required_memory_arg((_, dim): &Self::NewArg) -> usize {
1328 dim.size() * std::mem::size_of::<T>()
1329 }
1330
1331 fn required_memory_src(src: &Self) -> usize {
1332 src.len() * std::mem::size_of::<T>()
1333 }
1334
1335 fn new(data: &mut [u8], (init, dim): &Self::NewArg) -> Self::MetaData {
1336 let data =
1337 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, dim.size()) };
1338 for element in data.iter_mut() {
1339 *element = *init;
1340 }
1341 let view = ArrayView::from_shape(dim.clone(), data).unwrap();
1342 let shape = Vec::from(view.shape());
1343 let strides = Vec::from(view.strides());
1344 (shape, strides)
1345 }
1346
1347 fn new_from_src(data: &mut [u8], src: &Self) -> Self::MetaData {
1348 let data =
1349 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, src.len()) };
1350 let mut view = ArrayViewMut::from_shape(src.raw_dim(), data).unwrap();
1351 for (src, dst) in src.iter().zip(view.iter_mut()) {
1352 *dst = *src;
1353 }
1354 let shape = Vec::from(view.shape());
1355 let strides = Vec::from(view.strides());
1356 (shape, strides)
1357 }
1358 }
1359
1360 impl<'a, T, D> ShmemView<'a> for Array<T, D>
1361 where
1362 T: Copy + Default + 'a,
1363 D: Dimension + Serialize + DeserializeOwned,
1364 {
1365 type View = ArrayView<'a, T, D>;
1366
1367 fn view(
1368 data: &'a [u8],
1369 (shape, strides): &'a <Self as ShmemBacked>::MetaData,
1370 ) -> Self::View {
1371 use ndarray::ShapeBuilder;
1372 debug_assert!(shape.iter().product::<usize>() <= data.len());
1373 let data = unsafe {
1374 std::slice::from_raw_parts(data.as_ptr() as *const T, shape.iter().product())
1375 };
1376 let mut shape_dim = D::zeros(shape.len());
1377 for (src, dst) in shape.iter().zip(shape_dim.as_array_view_mut().iter_mut()) {
1378 *dst = *src;
1379 }
1380 let mut strides_dim = D::zeros(strides.len());
1381 for (src, dst) in strides
1382 .iter()
1383 .zip(strides_dim.as_array_view_mut().iter_mut())
1384 {
1385 *dst = *src as usize;
1388 }
1389 ArrayView::from_shape(shape_dim.strides(strides_dim), data).unwrap()
1390 }
1391 }
1392
1393 impl<'a, T, D> ShmemViewMut<'a> for Array<T, D>
1394 where
1395 T: Copy + Default + 'a,
1396 D: Dimension + Serialize + DeserializeOwned,
1397 {
1398 type View = ArrayViewMut<'a, T, D>;
1399
1400 fn view_mut(
1401 data: &'a mut [u8],
1402 (shape, strides): &'a mut <Self as ShmemBacked>::MetaData,
1403 ) -> Self::View {
1404 use ndarray::ShapeBuilder;
1405 debug_assert!(shape.iter().product::<usize>() <= data.len());
1406 let data = unsafe {
1407 std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut T, shape.iter().product())
1408 };
1409 let mut shape_dim = D::zeros(shape.len());
1410 for (src, dst) in shape.iter().zip(shape_dim.as_array_view_mut().iter_mut()) {
1411 *dst = *src;
1412 }
1413 let mut strides_dim = D::zeros(strides.len());
1414 for (src, dst) in strides
1415 .iter()
1416 .zip(strides_dim.as_array_view_mut().iter_mut())
1417 {
1418 *dst = *src as usize;
1421 }
1422 ArrayViewMut::from_shape(shape_dim.strides(strides_dim), data).unwrap()
1423 }
1424 }
1425}
1426
1427#[cfg(test)]
1428mod tests {
1429 use super::*;
1430 use bincode::{self, de, options};
1431 use rand::prelude::*;
1432 use std::thread;
1433
1434 fn serialize_shared<T>(shared: Shared<T>) -> Vec<u8>
1435 where
1436 T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1437 {
1438 let mut bytes = Vec::new();
1439 let mut serializer = bincode::Serializer::new(&mut bytes, options());
1440 shared.into_serialized(&mut serializer).unwrap();
1441 bytes
1442 }
1443
1444 fn serialize_shared_mut<T>(shared: SharedMut<T>) -> Vec<u8>
1445 where
1446 T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1447 {
1448 let mut bytes = Vec::new();
1449 let mut serializer = bincode::Serializer::new(&mut bytes, options());
1450 shared.into_serialized(&mut serializer).unwrap();
1451 bytes
1452 }
1453
1454 fn deserialize<S: DeserializeOwned>(bytes: &[u8]) -> Result<S, String> {
1455 let mut deserializer = de::Deserializer::from_slice(bytes, options());
1456 S::deserialize(&mut deserializer).map_err(|e| format!("{}", e))
1457 }
1458
1459 fn serialization_roundtrip_shared<T>(shared: Shared<T>) -> Shared<T>
1460 where
1461 T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1462 {
1463 let mut bytes = Vec::new();
1464 let mut serializer = bincode::Serializer::new(&mut bytes, options());
1465 shared.into_serialized(&mut serializer).unwrap();
1466 let mut deserializer = de::Deserializer::from_slice(bytes.as_slice(), options());
1467 Shared::<T>::deserialize(&mut deserializer).unwrap()
1468 }
1469
1470 fn serialization_roundtrip_shared_mut<T>(shared: SharedMut<T>) -> SharedMut<T>
1471 where
1472 T: ShmemBacked + for<'a> ShmemView<'a> + ?Sized,
1473 {
1474 let mut bytes = Vec::new();
1475 let mut serializer = bincode::Serializer::new(&mut bytes, options());
1476 shared.into_serialized(&mut serializer).unwrap();
1477 let mut deserializer = de::Deserializer::from_slice(bytes.as_slice(), options());
1478 SharedMut::<T>::deserialize(&mut deserializer).unwrap()
1479 }
1480
1481 fn slice_check_src_shared<T>(slice: &[T])
1482 where
1483 T: Copy + PartialEq + std::fmt::Debug + ?Sized + 'static,
1484 {
1485 let shared = Shared::<[T]>::from(slice);
1486 let roundtrip = serialization_roundtrip_shared(shared);
1487 assert_eq!(roundtrip.as_view(), slice);
1488 }
1489
1490 fn slice_check_src_shared_mut<T>(slice: &[T])
1491 where
1492 T: Copy + PartialEq + std::fmt::Debug + ?Sized + 'static,
1493 {
1494 let shared = SharedMut::<[T]>::from(slice);
1495 let roundtrip = serialization_roundtrip_shared_mut(shared);
1496 assert_eq!(roundtrip.as_view(), slice);
1497 }
1498
1499 #[test]
1500 fn shared_str() {
1501 let s = "sharify_test";
1502 let shared: Shared<str> = Shared::new(s).unwrap();
1503 let roundtrip = serialization_roundtrip_shared(shared);
1504 assert_eq!(roundtrip.as_view(), s);
1505 }
1506
1507 #[test]
1508 fn shared_mut_str() {
1509 let s = "sharify_test";
1510 let shared: SharedMut<str> = SharedMut::new(s).unwrap();
1511 let roundtrip = serialization_roundtrip_shared_mut(shared);
1512 assert_eq!(roundtrip.as_view(), s);
1513 }
1514
1515 enum Slice {
1516 Usize(&'static [usize]),
1517 U8(&'static [u8]),
1518 U64(&'static [u16]),
1519 I16(&'static [i16]),
1520 F64(&'static [f64]),
1521 }
1522
1523 impl Slice {
1524 fn create_slices() -> Vec<Slice> {
1525 vec![
1526 Slice::Usize(&[1, 2, 3, 4, 5]),
1527 Slice::U8(&[1, 2, 3, 4, 5]),
1528 Slice::U64(&[1, 2, 3, 4, 5]),
1529 Slice::I16(&[1, 2, 3, 4, 5]),
1530 Slice::F64(&[1.0, 2.0, 3.0, 4.0, 5.0]),
1531 ]
1532 }
1533 }
1534
1535 #[test]
1536 fn shared_slice() {
1537 let slice: &[usize] = &[0_usize, 0, 0, 0, 0];
1538 let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
1539 let roundtrip = serialization_roundtrip_shared(shared);
1540 assert_eq!(roundtrip.as_view(), slice);
1541 }
1542
1543 #[test]
1544 fn shared_mut_slice() {
1545 let slice: &mut [usize] = &mut [1_usize, 2, 3, 4, 5];
1546 let mut shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
1547 shared.deref_mut().copy_from_slice(slice);
1548 let roundtrip = serialization_roundtrip_shared_mut(shared);
1549 assert_eq!(roundtrip.as_view(), slice);
1550 }
1551
1552 #[test]
1553 fn shared_slice_from_src() {
1554 let slices = Slice::create_slices();
1555 for s in slices {
1556 match s {
1557 Slice::Usize(s) => slice_check_src_shared(s),
1558 Slice::U8(s) => slice_check_src_shared(s),
1559 Slice::U64(s) => slice_check_src_shared(s),
1560 Slice::I16(s) => slice_check_src_shared(s),
1561 Slice::F64(s) => slice_check_src_shared(s),
1562 }
1563 }
1564 }
1565
1566 #[test]
1567 fn shared_mut_slice_from_src() {
1568 let slices = Slice::create_slices();
1569 for s in slices {
1570 match s {
1571 Slice::Usize(s) => slice_check_src_shared_mut(s),
1572 Slice::U8(s) => slice_check_src_shared_mut(s),
1573 Slice::U64(s) => slice_check_src_shared_mut(s),
1574 Slice::I16(s) => slice_check_src_shared_mut(s),
1575 Slice::F64(s) => slice_check_src_shared_mut(s),
1576 }
1577 }
1578 }
1579
1580 #[test]
1581 fn shared_memory() {
1582 let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
1583 assert_eq!(shared.counts().unwrap(), [1, 0]);
1584 let bytes = serialize_shared(shared);
1585 let deser: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
1586 assert_eq!(deser.counts().unwrap(), [1, 0]);
1587 let mut instances = Vec::new();
1589 for i in 1..=10 {
1590 let inst: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
1591 assert_eq!(deser.counts().unwrap(), [1 + i, 0]);
1592 instances.push(inst);
1593 }
1594 assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
1596 std::mem::drop(instances);
1597 assert_eq!(deser.counts().unwrap(), [1, 0]);
1598 assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
1599 std::mem::drop(deser);
1600 assert!(deserialize::<Shared::<[usize]>>(bytes.as_slice()).is_err());
1601 }
1602
1603 #[test]
1604 fn shared_mut_memory() {
1605 let mut shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
1606 assert_eq!(shared.counts().unwrap(), [1]);
1607 let bytes = serialize_shared_mut(shared);
1608 let mut deser: SharedMut<[usize]> = deserialize(bytes.as_slice()).unwrap();
1609 assert_eq!(deser.counts().unwrap(), [1]);
1610 assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
1612 assert_eq!(&[0_usize, 0, 0, 0, 0], deser.deref());
1614 assert_eq!(deser.counts().unwrap(), [1]);
1615 std::mem::drop(deser);
1617 assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
1618 }
1619
1620 #[test]
1621 fn cross_serialization_from_shared() {
1622 let shared: Shared<[usize]> = Shared::new(&(0_usize, 5)).unwrap();
1623 assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
1624 let bytes = serialize_shared(shared);
1625 assert!(deserialize::<SharedMut::<[usize]>>(bytes.as_slice()).is_err());
1627 let shared: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
1629 assert_eq!(shared.counts().unwrap(), [1, 0]);
1630 assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
1631 }
1632
1633 #[test]
1634 fn cross_serialization_from_shared_mut() {
1635 let shared: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
1636 assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
1637 let bytes = serialize_shared_mut(shared);
1638 assert!(deserialize::<Shared::<[usize]>>(bytes.as_slice()).is_err());
1640 let mut shared: SharedMut<[usize]> = deserialize(bytes.as_slice()).unwrap();
1642 assert_eq!(shared.counts().unwrap(), [1]);
1643 assert_eq!(&[0_usize, 0, 0, 0, 0], shared.deref());
1644 }
1645
1646 #[test]
1647 fn shared_mut_into_shared() {
1648 let mut shared_mut: SharedMut<[usize]> = SharedMut::new(&(0_usize, 5)).unwrap();
1649 shared_mut.deref_mut().copy_from_slice(&[1, 2, 3, 4, 5]);
1650 let shared: Shared<[usize]> = shared_mut.try_into().unwrap();
1651 assert_eq!(shared.counts().unwrap(), [1, 0]);
1652 assert_eq!(&[1, 2, 3, 4, 5], shared.deref());
1653 }
1654
1655 #[test]
1656 fn shared_clone() {
1657 let shared = Shared::from(&[1_usize, 2, 3, 4, 5] as &[_]);
1658 let mut container = Vec::new();
1659 for i in 0..100 {
1660 let bytes = serialize_shared(shared.clone());
1661 container.push((bytes, shared.clone()));
1662 assert_eq!(shared.counts().unwrap(), [2 + i, 1 + i]);
1663 }
1664 assert_eq!(shared.counts().unwrap(), [101, 100]);
1665 for (i, (bytes, cl)) in container.into_iter().enumerate() {
1666 let mut _deser: Shared<[usize]> = deserialize(bytes.as_slice()).unwrap();
1667 assert_eq!(&[1_usize, 2, 3, 4, 5], cl.deref());
1668 assert_eq!(shared.counts().unwrap(), [102 - i as u64, 99 - i as u64]);
1669 }
1670 assert_eq!(shared.counts().unwrap(), [1, 0]);
1671 assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
1672 }
1673
1674 #[test]
1675 fn races() {
1676 let shared = Shared::from(&[1_usize, 2, 3, 4, 5] as &[_]);
1677 let mut handles = Vec::new();
1678 for _ in 0..50 {
1679 let (send, recv) = std::sync::mpsc::sync_channel(0);
1680 let bytes_send = serialize_shared(shared.clone());
1681 let handle = thread::spawn(move || {
1682 let mut rng = rand::thread_rng();
1683 recv.recv().unwrap();
1684 let mut shared: Shared<[usize]> = deserialize(bytes_send.as_slice()).unwrap();
1685 assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
1686 for _ in 0..1000 {
1687 thread::sleep(std::time::Duration::from_millis(rng.gen_range(0..=5)));
1688 let tmp = serialize_shared(shared);
1689 thread::sleep(std::time::Duration::from_millis(rng.gen_range(0..=5)));
1690 shared = deserialize(tmp.as_slice()).unwrap();
1691 assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
1692 }
1693 });
1694 handles.push((handle, send));
1695 }
1696 thread::sleep(std::time::Duration::from_millis(100));
1697 for (_, send) in handles.iter() {
1698 send.send(()).unwrap();
1699 }
1700 for (handle, _) in handles {
1701 handle.join().unwrap();
1702 }
1703 assert_eq!(&[1_usize, 2, 3, 4, 5], shared.deref());
1704 assert_eq!(shared.counts().unwrap(), [1, 0]);
1705 }
1706
1707 #[cfg(feature = "shared_ndarray")]
1708 mod ndarray_tests {
1709 use super::*;
1710 use ndarray::{Array, Axis, IxDyn};
1711
1712 #[test]
1713 fn shared_ndarray() {
1714 let shared: SharedArray<u64, IxDyn> = Shared::new(&(0, IxDyn(&[3, 2]))).unwrap();
1715 let shared = serialization_roundtrip_shared(shared);
1716 assert_eq!(&[0; 6], shared.as_view().as_slice().unwrap());
1717 }
1718
1719 #[test]
1720 fn shared_mut_ndarray() {
1721 let mut shared: SharedArrayMut<f64, IxDyn> =
1722 SharedMut::new(&(0.0, IxDyn(&[3, 2]))).unwrap();
1723 let slice: &[f64] = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1724 for (&x, element) in slice.iter().zip(shared.as_view_mut().iter_mut()) {
1725 *element = x;
1726 }
1727 let roundtrip = serialization_roundtrip_shared_mut(shared);
1728 assert_eq!(slice, roundtrip.as_view().as_slice().unwrap());
1729 }
1730
1731 #[test]
1732 fn shared_mut_array_layout() {
1733 let mut array: SharedArrayMut<f64, ndarray::IxDyn> =
1734 SharedArrayMut::new(&(0.0, ndarray::IxDyn(&[100, 200, 300]))).unwrap();
1735 assert_eq!(array.as_view().strides(), &[200 * 300, 300, 1]);
1736 {
1737 let mut view = array.as_view_mut();
1738 assert!(view.is_standard_layout());
1739 view.swap_axes(0, 1);
1740 assert_eq!(view.strides(), &[300, 200 * 300, 1]);
1741 unsafe {
1742 *array.metadata_mut() = (Vec::from(view.shape()), Vec::from(view.strides()));
1743 }
1744 }
1745 assert_eq!(array.as_view().shape(), &[200, 100, 300]);
1746 assert_eq!(array.as_view().strides(), &[300, 200 * 300, 1]);
1747 assert!(!array.as_view().is_standard_layout());
1748 let bytes = serialize_shared_mut(array);
1749 let deser: SharedArrayMut<f64, ndarray::IxDyn> = deserialize(bytes.as_slice()).unwrap();
1750 assert_eq!(deser.as_view().shape(), &[200, 100, 300]);
1751 assert_eq!(deser.as_view().strides(), &[300, 200 * 300, 1]);
1752 assert!(!deser.as_view().is_standard_layout());
1753 }
1754
1755 #[test]
1756 fn shared_ndarray_from_src() {
1757 let mut src = Array::from_elem((100, 200, 300), 0_u64);
1758 src.invert_axis(Axis(1));
1759 let src_shape = Vec::from(src.shape());
1760 let array = Shared::new_from_inner(&src).unwrap();
1761 assert_eq!(src_shape, array.as_view().shape());
1762 assert!(src
1763 .iter()
1764 .zip(array.as_view().iter())
1765 .all(|(src, dst)| src == dst))
1766 }
1767 }
1768}