use crate::error::SerDeError;
use crate::platform::{self, OsIpcChannel, OsIpcReceiver, OsIpcReceiverSet, OsIpcSender};
use crate::platform::{
OsIpcOneShotServer, OsIpcSelectionResult, OsIpcSharedMemory, OsOpaqueIpcChannel,
OsTrySelectError,
};
use crate::{IpcError, TryRecvError, TrySelectError};
use serde_core::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
use std::cell::RefCell;
use std::cmp::min;
use std::fmt::{self, Debug, Formatter};
use std::io;
use std::marker::PhantomData;
use std::ops::Deref;
use std::time::Duration;
// Per-thread scratch space that lets OS handles travel next to the postcard byte
// stream. Serializing a channel or shared-memory region pushes the handle into one
// of the `*_FOR_SERIALIZATION` vectors and encodes only its index there;
// `IpcSender::send` drains both vectors and passes them to the OS sender.
// `IpcMessage::to` fills the `*_FOR_DESERIALIZATION` vectors with the handles that
// arrived with a message so `Deserialize` impls can claim them by index, and empties
// them again once decoding finishes.
thread_local! {
// Channels received with the message currently being decoded.
static OS_IPC_CHANNELS_FOR_DESERIALIZATION: RefCell<Vec<OsOpaqueIpcChannel>> =
const { RefCell::new(Vec::new()) };
// Shared-memory regions received with the message being decoded; a slot becomes
// `None` once `IpcSharedMemory::deserialize` has claimed it, so each region is
// handed out at most once.
static OS_IPC_SHARED_MEMORY_REGIONS_FOR_DESERIALIZATION:
RefCell<Vec<Option<OsIpcSharedMemory>>> = const { RefCell::new(Vec::new()) };
// Channels collected while serializing an outgoing message.
static OS_IPC_CHANNELS_FOR_SERIALIZATION: RefCell<Vec<OsIpcChannel>> = const { RefCell::new(Vec::new()) };
// Shared-memory regions collected while serializing an outgoing message.
static OS_IPC_SHARED_MEMORY_REGIONS_FOR_SERIALIZATION: RefCell<Vec<OsIpcSharedMemory>> =
const { RefCell::new(Vec::new()) }
}
/// Creates a typed IPC channel and returns its `(sender, receiver)` halves.
///
/// Both halves wrap the two ends of one freshly created OS-level channel. Values of
/// type `T` are serialized with postcard when sent and deserialized when received.
///
/// # Errors
///
/// Returns the platform error, converted to `io::Error`, if channel creation fails.
pub fn channel<T>() -> Result<(IpcSender<T>, IpcReceiver<T>), io::Error>
where
    T: for<'de> Deserialize<'de> + Serialize,
{
    let (os_sender, os_receiver) = platform::channel()?;
    let sender = IpcSender {
        os_sender,
        phantom: PhantomData,
    };
    let receiver = IpcReceiver {
        os_receiver,
        phantom: PhantomData,
    };
    Ok((sender, receiver))
}
/// Creates a raw byte channel whose messages bypass serialization entirely.
///
/// # Errors
///
/// Returns the platform error, converted to `io::Error`, if channel creation fails.
pub fn bytes_channel() -> Result<(IpcBytesSender, IpcBytesReceiver), io::Error> {
    let (sender_end, receiver_end) = platform::channel()?;
    Ok((
        IpcBytesSender {
            os_sender: sender_end,
        },
        IpcBytesReceiver {
            os_receiver: receiver_end,
        },
    ))
}
/// The receiving half of a typed IPC channel.
///
/// Each incoming message carries raw bytes plus any out-of-band OS handles and is
/// deserialized into a `T`.
#[derive(Debug)]
pub struct IpcReceiver<T> {
    os_receiver: OsIpcReceiver,
    phantom: PhantomData<T>,
}

impl<T> IpcReceiver<T>
where
    T: for<'de> Deserialize<'de> + Serialize,
{
    /// Blocks until a message arrives and deserializes it.
    ///
    /// # Errors
    ///
    /// Propagates receive errors from the OS channel, or returns
    /// `IpcError::SerializationError` if the payload cannot be decoded as `T`.
    pub fn recv(&self) -> Result<T, IpcError> {
        let message = self.os_receiver.recv()?;
        message.to().map_err(IpcError::SerializationError)
    }

    /// Returns a message if one is available right now, without blocking.
    ///
    /// # Errors
    ///
    /// Propagates `TryRecvError` from the OS channel; decoding failures are reported
    /// as `TryRecvError::IpcError(IpcError::SerializationError(_))`.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        let message = self.os_receiver.try_recv()?;
        message
            .to()
            .map_err(|err| TryRecvError::IpcError(IpcError::SerializationError(err)))
    }

    /// Waits at most `duration` for a message and deserializes it.
    ///
    /// # Errors
    ///
    /// Same as [`IpcReceiver::try_recv`].
    pub fn try_recv_timeout(&self, duration: Duration) -> Result<T, TryRecvError> {
        let message = self.os_receiver.try_recv_timeout(duration)?;
        message
            .to()
            .map_err(|err| TryRecvError::IpcError(IpcError::SerializationError(err)))
    }

    /// Erases the message type, yielding a receiver that can later be re-typed.
    pub fn to_opaque(self) -> OpaqueIpcReceiver {
        let IpcReceiver { os_receiver, .. } = self;
        OpaqueIpcReceiver { os_receiver }
    }
}
impl<'de, T> Deserialize<'de> for IpcReceiver<T> {
    /// Reads a channel index and claims the matching received channel as a receiver.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_os_ipc_receiver(deserializer).map(|os_receiver| IpcReceiver {
            os_receiver,
            phantom: PhantomData,
        })
    }
}

impl<T> Serialize for IpcReceiver<T> {
    /// Hands the OS receiver (via `consume()`) to the per-thread outgoing channel list
    /// and writes its index.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let os_receiver = &self.os_receiver;
        serialize_os_ipc_receiver(os_receiver, serializer)
    }
}
/// The sending half of a typed IPC channel.
#[derive(Debug)]
pub struct IpcSender<T> {
    os_sender: OsIpcSender,
    phantom: PhantomData<T>,
}

// Written by hand so cloning does not require `T: Clone` (a derive would add that
// bound); only the underlying OS sender handle is cloned.
impl<T> Clone for IpcSender<T>
where
    T: Serialize,
{
    fn clone(&self) -> IpcSender<T> {
        let os_sender = self.os_sender.clone();
        IpcSender {
            os_sender,
            phantom: PhantomData,
        }
    }
}
impl<T> IpcSender<T>
where
    T: Serialize,
{
    /// Connects to a named endpoint, e.g. the name returned by `IpcOneShotServer::new`.
    ///
    /// # Errors
    ///
    /// Returns the platform error, converted to `io::Error`, if the connection fails.
    pub fn connect(name: String) -> Result<IpcSender<T>, io::Error> {
        Ok(IpcSender {
            os_sender: OsIpcSender::connect(name)?,
            phantom: PhantomData,
        })
    }

    /// Serializes `data` and sends it over this channel.
    ///
    /// Any channels or shared-memory regions embedded in `data` are collected in the
    /// per-thread serialization lists while serializing, then sent out-of-band
    /// alongside the bytes.
    ///
    /// # Errors
    ///
    /// Returns an `IpcError` if serialization fails or if the OS-level send fails.
    pub fn send(&self, data: T) -> Result<(), IpcError> {
        OS_IPC_CHANNELS_FOR_SERIALIZATION.with(|os_ipc_channels_for_serialization| {
            OS_IPC_SHARED_MEMORY_REGIONS_FOR_SERIALIZATION.with(
                |os_ipc_shared_memory_regions_for_serialization| {
                    let serialized = postcard::to_stdvec(&data);
                    // Drain the per-thread lists *before* looking at the result. If
                    // serialization failed part-way, handles already pushed must not
                    // leak into the next message sent from this thread.
                    let os_ipc_channels = os_ipc_channels_for_serialization.take();
                    let os_ipc_shared_memory_regions =
                        os_ipc_shared_memory_regions_for_serialization.take();
                    let bytes = serialized.map_err(SerDeError)?;
                    Ok(self.os_sender.send(
                        &bytes[..],
                        os_ipc_channels,
                        os_ipc_shared_memory_regions,
                    )?)
                },
            )
        })
    }

    /// Erases the message type, yielding a sender that can later be re-typed.
    pub fn to_opaque(self) -> OpaqueIpcSender {
        OpaqueIpcSender {
            os_sender: self.os_sender,
        }
    }
}
impl<'de, T> Deserialize<'de> for IpcSender<T> {
    /// Reads a channel index and turns the matching received channel into a sender.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let sender = IpcSender {
            os_sender: deserialize_os_ipc_sender(deserializer)?,
            phantom: PhantomData,
        };
        Ok(sender)
    }
}

impl<T> Serialize for IpcSender<T> {
    /// Records a clone of the OS sender in the per-thread outgoing channel list and
    /// writes its index.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let os_sender = &self.os_sender;
        serialize_os_ipc_sender(os_sender, serializer)
    }
}
pub struct IpcReceiverSet {
os_receiver_set: OsIpcReceiverSet,
}
impl IpcReceiverSet {
pub fn new() -> Result<IpcReceiverSet, io::Error> {
Ok(IpcReceiverSet {
os_receiver_set: OsIpcReceiverSet::new()?,
})
}
pub fn add<T>(&mut self, receiver: IpcReceiver<T>) -> Result<u64, io::Error>
where
T: for<'de> Deserialize<'de> + Serialize,
{
Ok(self.os_receiver_set.add(receiver.os_receiver)?)
}
pub fn add_opaque(&mut self, receiver: OpaqueIpcReceiver) -> Result<u64, io::Error> {
Ok(self.os_receiver_set.add(receiver.os_receiver)?)
}
pub fn select(&mut self) -> Result<Vec<IpcSelectionResult>, io::Error> {
let results = self.os_receiver_set.select()?;
Ok(results
.into_iter()
.map(|result| match result {
OsIpcSelectionResult::DataReceived(os_receiver_id, ipc_message) => {
IpcSelectionResult::MessageReceived(os_receiver_id, ipc_message)
},
OsIpcSelectionResult::ChannelClosed(os_receiver_id) => {
IpcSelectionResult::ChannelClosed(os_receiver_id)
},
})
.collect())
}
pub fn try_select(&mut self) -> Result<Vec<IpcSelectionResult>, TrySelectError> {
let results: Vec<OsIpcSelectionResult> =
self.os_receiver_set.try_select().map_err(|e| match e {
OsTrySelectError::IoError(e) => TrySelectError::IoError(e.into()),
OsTrySelectError::Empty => TrySelectError::Empty,
})?;
let results = results
.into_iter()
.map(|result| match result {
OsIpcSelectionResult::DataReceived(os_receiver_id, ipc_message) => {
IpcSelectionResult::MessageReceived(os_receiver_id, ipc_message)
},
OsIpcSelectionResult::ChannelClosed(os_receiver_id) => {
IpcSelectionResult::ChannelClosed(os_receiver_id)
},
})
.collect::<Vec<IpcSelectionResult>>();
Ok(results)
}
pub fn try_select_timeout(
&mut self,
duration: Duration,
) -> Result<Vec<IpcSelectionResult>, TrySelectError> {
let results = self
.os_receiver_set
.try_select_timeout(duration)
.map_err(|e| match e {
OsTrySelectError::IoError(e) => TrySelectError::IoError(e.into()),
OsTrySelectError::Empty => TrySelectError::Empty,
})?;
let results = results
.into_iter()
.map(|result| match result {
OsIpcSelectionResult::DataReceived(os_receiver_id, ipc_message) => {
IpcSelectionResult::MessageReceived(os_receiver_id, ipc_message)
},
OsIpcSelectionResult::ChannelClosed(os_receiver_id) => {
IpcSelectionResult::ChannelClosed(os_receiver_id)
},
})
.collect::<Vec<IpcSelectionResult>>();
if results.is_empty() {
Err(TrySelectError::Empty)
} else {
Ok(results)
}
}
}
/// A region of shared memory that can be sent over IPC channels.
///
/// The empty region is represented by `None`, so no OS object backs it.
#[derive(Clone, Debug, PartialEq)]
pub struct IpcSharedMemory {
    os_shared_memory: Option<OsIpcSharedMemory>,
}

impl Deref for IpcSharedMemory {
    type Target = [u8];

    /// Views the region as a byte slice; the empty region yields `&[]`.
    #[inline]
    fn deref(&self) -> &[u8] {
        match &self.os_shared_memory {
            Some(os_shared_memory) => os_shared_memory,
            None => &[],
        }
    }
}
impl IpcSharedMemory {
    /// Returns a mutable view of the region's bytes; the empty region yields `&mut []`.
    ///
    /// # Safety
    ///
    /// The same memory may be visible through other handles. `IpcSharedMemory` is
    /// `Clone`, and serializing it clones the underlying OS handle. The memory may
    /// also be mapped by other processes. The caller must ensure that nothing else
    /// reads or writes the region while the returned slice is alive.
    /// NOTE(review): confirm this against the contract of `OsIpcSharedMemory::deref_mut`.
    #[inline]
    pub unsafe fn deref_mut(&mut self) -> &mut [u8] {
        if let Some(os_shared_memory) = &mut self.os_shared_memory {
            os_shared_memory.deref_mut()
        } else {
            &mut []
        }
    }
}
impl<'de> Deserialize<'de> for IpcSharedMemory {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let index: usize = Deserialize::deserialize(deserializer)?;
if index == usize::MAX {
return Ok(IpcSharedMemory::empty());
}
let os_shared_memory = OS_IPC_SHARED_MEMORY_REGIONS_FOR_DESERIALIZATION.with(
|os_ipc_shared_memory_regions_for_deserialization| {
let mut regions = os_ipc_shared_memory_regions_for_deserialization.borrow_mut();
let Some(region) = regions.get_mut(index) else {
return Err(format!("Cannot consume shared memory region at index {index}, there are only {} regions available", regions.len()));
};
region.take().ok_or_else(|| format!("Shared memory region {index} has already been consumed"))
},
).map_err(D::Error::custom)?;
Ok(IpcSharedMemory {
os_shared_memory: Some(os_shared_memory),
})
}
}
impl Serialize for IpcSharedMemory {
    /// Writes the region as an index into the per-thread outgoing region list, or as
    /// `usize::MAX` for the empty region. A clone of the region is placed in that
    /// list so `IpcSender::send` can transmit it out-of-band.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let index = match &self.os_shared_memory {
            None => usize::MAX,
            Some(os_shared_memory) => {
                let slot = OS_IPC_SHARED_MEMORY_REGIONS_FOR_SERIALIZATION.with(|cell| {
                    let mut outgoing = cell.borrow_mut();
                    outgoing.push(os_shared_memory.clone());
                    outgoing.len() - 1
                });
                // `usize::MAX` is reserved as the empty-region sentinel.
                debug_assert!(slot < usize::MAX);
                slot
            },
        };
        index.serialize(serializer)
    }
}
impl IpcSharedMemory {
    /// The empty region, backed by no OS object.
    const fn empty() -> Self {
        IpcSharedMemory {
            os_shared_memory: None,
        }
    }

    /// Creates a region holding a copy of `bytes`; empty input allocates nothing.
    pub fn from_bytes(bytes: &[u8]) -> IpcSharedMemory {
        if bytes.is_empty() {
            return IpcSharedMemory::empty();
        }
        let region = OsIpcSharedMemory::from_bytes(bytes);
        IpcSharedMemory {
            os_shared_memory: Some(region),
        }
    }

    /// Creates a region of `length` bytes, each set to `byte`; a zero length
    /// allocates nothing.
    pub fn from_byte(byte: u8, length: usize) -> IpcSharedMemory {
        if length == 0 {
            return IpcSharedMemory::empty();
        }
        let region = OsIpcSharedMemory::from_byte(byte, length);
        IpcSharedMemory {
            os_shared_memory: Some(region),
        }
    }

    /// Consumes the region and returns its contents.
    ///
    /// The empty region yields `Some` of an empty vector. Otherwise the result of
    /// `OsIpcSharedMemory::take` is returned as-is.
    pub fn take(mut self) -> Option<Vec<u8>> {
        match self.os_shared_memory.take() {
            Some(os_shared_memory) => os_shared_memory.take(),
            None => Some(Vec::new()),
        }
    }
}
/// One event reported by an `IpcReceiverSet` selection.
#[derive(Debug)]
pub enum IpcSelectionResult {
    /// A message arrived on the receiver with the given id.
    MessageReceived(u64, IpcMessage),
    /// The receiver with the given id was closed.
    ChannelClosed(u64),
}

impl IpcSelectionResult {
    /// Returns the receiver id and message of a `MessageReceived` event.
    ///
    /// # Panics
    ///
    /// Panics if the event is `ChannelClosed`.
    pub fn unwrap(self) -> (u64, IpcMessage) {
        match self {
            Self::ChannelClosed(id) => {
                panic!("IpcSelectionResult::unwrap(): channel {id} closed")
            },
            Self::MessageReceived(id, message) => (id, message),
        }
    }
}
/// A raw message: serialized bytes plus the OS channels and shared-memory regions
/// that travelled with it out-of-band.
pub struct IpcMessage {
    pub(crate) data: Vec<u8>,
    pub(crate) os_ipc_channels: Vec<OsOpaqueIpcChannel>,
    pub(crate) os_ipc_shared_memory_regions: Vec<OsIpcSharedMemory>,
}

impl IpcMessage {
    /// Wraps already-serialized bytes in a message with no attached channels or regions.
    pub fn from_data(data: Vec<u8>) -> Self {
        IpcMessage {
            os_ipc_channels: Vec::new(),
            os_ipc_shared_memory_regions: Vec::new(),
            data,
        }
    }
}
impl Debug for IpcMessage {
    /// Debug-formats a preview of the payload.
    ///
    /// Valid UTF-8 payloads show their first 256 characters as a string; anything
    /// else shows the first 256 raw bytes.
    fn fmt(&self, formatter: &mut Formatter) -> Result<(), fmt::Error> {
        // Validate in place rather than cloning the whole (possibly large) payload
        // just to print a short prefix.
        match std::str::from_utf8(&self.data) {
            Ok(text) => {
                // Byte offset where the 257th character starts, or the full length.
                let end = text
                    .char_indices()
                    .nth(256)
                    .map_or(text.len(), |(offset, _)| offset);
                text[..end].fmt(formatter)
            },
            Err(..) => self.data[0..min(self.data.len(), 256)].fmt(formatter),
        }
    }
}
impl IpcMessage {
    /// Assembles a message from its bytes and out-of-band handles.
    pub(crate) fn new(
        data: Vec<u8>,
        os_ipc_channels: Vec<OsOpaqueIpcChannel>,
        os_ipc_shared_memory_regions: Vec<OsIpcSharedMemory>,
    ) -> IpcMessage {
        Self {
            data,
            os_ipc_channels,
            os_ipc_shared_memory_regions,
        }
    }

    /// Deserializes the payload as a `T`.
    ///
    /// The attached channels and regions are installed in the per-thread
    /// deserialization lists for the duration of decoding, so embedded senders,
    /// receivers and shared memory can claim them by index. Both lists are emptied
    /// afterwards, on success and failure alike.
    ///
    /// # Errors
    ///
    /// Returns a `SerDeError` if postcard cannot decode the payload.
    pub fn to<T>(self) -> Result<T, SerDeError>
    where
        T: for<'de> Deserialize<'de> + Serialize,
    {
        let IpcMessage {
            data,
            os_ipc_channels,
            os_ipc_shared_memory_regions,
        } = self;
        OS_IPC_CHANNELS_FOR_DESERIALIZATION.with(|channels| {
            OS_IPC_SHARED_MEMORY_REGIONS_FOR_DESERIALIZATION.with(|regions| {
                *channels.borrow_mut() = os_ipc_channels;
                // Wrap each region in `Some` so it can be taken exactly once.
                *regions.borrow_mut() = os_ipc_shared_memory_regions
                    .into_iter()
                    .map(Some)
                    .collect();
                let outcome = postcard::from_bytes(&data).map_err(|e| e.into());
                // Drop any handles the payload did not claim.
                drop(regions.take());
                drop(channels.take());
                outcome
            })
        })
    }
}
/// A type-erased sending half; convert it back with [`OpaqueIpcSender::to`].
#[derive(Clone, Debug)]
pub struct OpaqueIpcSender {
    os_sender: OsIpcSender,
}

impl OpaqueIpcSender {
    /// Re-attaches a message type to this sender.
    ///
    /// Nothing checks that `T` matches the type the channel was created for.
    pub fn to<'de, T>(self) -> IpcSender<T>
    where
        T: Deserialize<'de> + Serialize,
    {
        let OpaqueIpcSender { os_sender } = self;
        IpcSender {
            os_sender,
            phantom: PhantomData,
        }
    }
}
impl<'de> Deserialize<'de> for OpaqueIpcSender {
    /// Reads a channel index and turns the matching received channel into a sender.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_os_ipc_sender(deserializer).map(|os_sender| OpaqueIpcSender { os_sender })
    }
}

impl Serialize for OpaqueIpcSender {
    /// Records a clone of the OS sender in the per-thread outgoing list and writes its index.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let os_sender = &self.os_sender;
        serialize_os_ipc_sender(os_sender, serializer)
    }
}
/// A type-erased receiving half; convert it back with [`OpaqueIpcReceiver::to`].
#[derive(Debug)]
pub struct OpaqueIpcReceiver {
    os_receiver: OsIpcReceiver,
}

impl OpaqueIpcReceiver {
    /// Re-attaches a message type to this receiver.
    ///
    /// Nothing checks that `T` matches the type the channel was created for.
    pub fn to<'de, T>(self) -> IpcReceiver<T>
    where
        T: Deserialize<'de> + Serialize,
    {
        let OpaqueIpcReceiver { os_receiver } = self;
        IpcReceiver {
            os_receiver,
            phantom: PhantomData,
        }
    }
}
impl<'de> Deserialize<'de> for OpaqueIpcReceiver {
    /// Reads a channel index and claims the matching received channel as a receiver.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserialize_os_ipc_receiver(deserializer)
            .map(|os_receiver| OpaqueIpcReceiver { os_receiver })
    }
}

impl Serialize for OpaqueIpcReceiver {
    /// Hands the OS receiver (via `consume()`) to the per-thread outgoing list and
    /// writes its index.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let os_receiver = &self.os_receiver;
        serialize_os_ipc_receiver(os_receiver, serializer)
    }
}
/// A server that accepts a single incoming connection.
///
/// Clients find the server by the name that `new` returns.
pub struct IpcOneShotServer<T> {
    os_server: OsIpcOneShotServer,
    phantom: PhantomData<T>,
}

impl<T> IpcOneShotServer<T>
where
    T: for<'de> Deserialize<'de> + Serialize,
{
    /// Creates the server and returns it together with the name clients connect to.
    ///
    /// # Errors
    ///
    /// Returns the platform error, converted to `io::Error`, if creation fails.
    pub fn new() -> Result<(IpcOneShotServer<T>, String), io::Error> {
        let (os_server, name) = OsIpcOneShotServer::new()?;
        let server = IpcOneShotServer {
            os_server,
            phantom: PhantomData,
        };
        Ok((server, name))
    }

    /// Waits for the one client, returning a receiver for its channel and the first
    /// message it sent, decoded as `T`.
    ///
    /// # Errors
    ///
    /// Returns an `IpcError` if accepting fails or the first message cannot be decoded.
    pub fn accept(self) -> Result<(IpcReceiver<T>, T), IpcError> {
        let (os_receiver, first_message) = self.os_server.accept()?;
        let first_value: T = first_message.to()?;
        let receiver = IpcReceiver {
            os_receiver,
            phantom: PhantomData,
        };
        Ok((receiver, first_value))
    }
}
#[derive(Debug)]
pub struct IpcBytesReceiver {
os_receiver: OsIpcReceiver,
}
impl IpcBytesReceiver {
#[inline]
pub fn recv(&self) -> Result<Vec<u8>, IpcError> {
match self.os_receiver.recv() {
Ok(ipc_message) => Ok(ipc_message.data),
Err(err) => Err(err.into()),
}
}
pub fn try_recv(&self) -> Result<Vec<u8>, TryRecvError> {
match self.os_receiver.try_recv() {
Ok(ipc_message) => Ok(ipc_message.data),
Err(err) => Err(err.into()),
}
}
}
impl<'de> Deserialize<'de> for IpcBytesReceiver {
    /// Reads a channel index and claims the matching received channel as a receiver.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        Ok(IpcBytesReceiver {
            os_receiver: deserialize_os_ipc_receiver(deserializer)?,
        })
    }
}

impl Serialize for IpcBytesReceiver {
    /// Hands the OS receiver (via `consume()`) to the per-thread outgoing list and
    /// writes its index.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let os_receiver = &self.os_receiver;
        serialize_os_ipc_receiver(os_receiver, serializer)
    }
}
#[derive(Debug)]
pub struct IpcBytesSender {
os_sender: OsIpcSender,
}
impl Clone for IpcBytesSender {
fn clone(&self) -> IpcBytesSender {
IpcBytesSender {
os_sender: self.os_sender.clone(),
}
}
}
impl<'de> Deserialize<'de> for IpcBytesSender {
    /// Reads a channel index and turns the matching received channel into a sender.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        Ok(IpcBytesSender {
            os_sender: deserialize_os_ipc_sender(deserializer)?,
        })
    }
}

impl Serialize for IpcBytesSender {
    /// Records a clone of the OS sender in the per-thread outgoing list and writes its index.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let os_sender = &self.os_sender;
        serialize_os_ipc_sender(os_sender, serializer)
    }
}
impl IpcBytesSender {
    /// Sends `data` unchanged, with no attached channels or shared-memory regions.
    ///
    /// # Errors
    ///
    /// Returns the OS send error converted to `io::Error`.
    #[inline]
    pub fn send(&self, data: &[u8]) -> Result<(), io::Error> {
        let no_channels = Vec::new();
        let no_regions = Vec::new();
        let outcome = self.os_sender.send(data, no_channels, no_regions);
        outcome.map_err(io::Error::from)
    }
}
/// Appends a clone of `os_ipc_sender` to the per-thread outgoing channel list and
/// serializes only its position in that list.
fn serialize_os_ipc_sender<S>(os_ipc_sender: &OsIpcSender, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let channel = OsIpcChannel::Sender(os_ipc_sender.clone());
    let position = OS_IPC_CHANNELS_FOR_SERIALIZATION.with(|cell| {
        let mut outgoing = cell.borrow_mut();
        outgoing.push(channel);
        outgoing.len() - 1
    });
    position.serialize(serializer)
}
/// Reads a channel index and converts the matching received channel into a sender.
///
/// # Errors
///
/// Returns a custom deserialization error, instead of panicking on an out-of-bounds
/// access, when the index does not refer to one of the channels received with the
/// message being decoded (e.g. corrupt or untrusted input).
fn deserialize_os_ipc_sender<'de, D>(deserializer: D) -> Result<OsIpcSender, D::Error>
where
    D: Deserializer<'de>,
{
    let index: usize = Deserialize::deserialize(deserializer)?;
    OS_IPC_CHANNELS_FOR_DESERIALIZATION.with(|os_ipc_channels_for_deserialization| {
        let mut channels = os_ipc_channels_for_deserialization.borrow_mut();
        let available = channels.len();
        channels
            .get_mut(index)
            .map(|channel| channel.to_sender())
            .ok_or_else(|| {
                D::Error::custom(format!(
                    "Cannot consume channel at index {index}, there are only {available} channels available"
                ))
            })
    })
}
/// Hands `os_receiver` (via `consume()`) to the per-thread outgoing channel list and
/// serializes only its position in that list.
fn serialize_os_ipc_receiver<S>(
    os_receiver: &OsIpcReceiver,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    let position = OS_IPC_CHANNELS_FOR_SERIALIZATION.with(|cell| {
        let mut outgoing = cell.borrow_mut();
        outgoing.push(OsIpcChannel::Receiver(os_receiver.consume()));
        outgoing.len() - 1
    });
    position.serialize(serializer)
}
/// Reads a channel index and converts the matching received channel into a receiver.
///
/// # Errors
///
/// Returns a custom deserialization error, instead of panicking on an out-of-bounds
/// access, when the index does not refer to one of the channels received with the
/// message being decoded (e.g. corrupt or untrusted input).
fn deserialize_os_ipc_receiver<'de, D>(deserializer: D) -> Result<OsIpcReceiver, D::Error>
where
    D: Deserializer<'de>,
{
    let index: usize = Deserialize::deserialize(deserializer)?;
    OS_IPC_CHANNELS_FOR_DESERIALIZATION.with(|os_ipc_channels_for_deserialization| {
        let mut channels = os_ipc_channels_for_deserialization.borrow_mut();
        let available = channels.len();
        channels
            .get_mut(index)
            .map(|channel| channel.to_receiver())
            .ok_or_else(|| {
                D::Error::custom(format!(
                    "Cannot consume channel at index {index}, there are only {available} channels available"
                ))
            })
    })
}