use std::mem::{self, MaybeUninit};
use std::fmt;
use std::slice;
use std::marker::PhantomData;
use std::sync::mpsc::{RecvError, SendError};
use std::os::unix::io::{RawFd, AsRawFd, IntoRawFd, FromRawFd};
/// The sending half of a pipe-backed channel.
///
/// Holds the file descriptor that `send` writes the raw bytes of each value
/// to. The descriptor is closed when the `Sender` is dropped.
pub struct Sender<T> {
// File descriptor that `send` writes to.
fd: RawFd,
// Zero-sized marker: mentions `T` in argument position without storing a `T`.
variance: PhantomData<fn(T)>,
// Zero-sized marker: a raw pointer opts the type out of the automatic
// `Send`/`Sync` impls; `Send` is re-added explicitly for `T: Send`.
not_send: PhantomData<*const ()>,
}
/// The receiving half of a pipe-backed channel.
///
/// Holds the file descriptor that `recv` reads the raw bytes of each value
/// from. The descriptor is closed when the `Receiver` is dropped.
pub struct Receiver<T>{
// File descriptor that `recv` reads from.
fd: RawFd,
// Zero-sized marker: mentions `T` in return position without storing a `T`.
variance: PhantomData<fn() -> T>,
// Zero-sized marker: a raw pointer opts the type out of the automatic
// `Send`/`Sync` impls; `Send` is re-added explicitly for `T: Send`.
not_send: PhantomData<*const ()>,
}
impl<T> Drop for Sender<T> {
    /// Closes the file descriptor owned by this sender.
    ///
    /// # Panics
    ///
    /// Panics if `close` reports an error, unless the current thread is
    /// already unwinding from another panic.
    fn drop(&mut self) {
        let result = nix::unistd::close(self.fd);
        // A panic raised from `drop` while the thread is already unwinding
        // aborts the whole process, so only surface the error when it is safe
        // to do so.
        if !std::thread::panicking() {
            result.unwrap();
        }
    }
}
impl<T> Drop for Receiver<T> {
    /// Closes the file descriptor owned by this receiver.
    ///
    /// # Panics
    ///
    /// Panics if `close` reports an error, unless the current thread is
    /// already unwinding from another panic.
    fn drop(&mut self) {
        let result = nix::unistd::close(self.fd);
        // A panic raised from `drop` while the thread is already unwinding
        // aborts the whole process, so only surface the error when it is safe
        // to do so.
        if !std::thread::panicking() {
            result.unwrap();
        }
    }
}
// SAFETY: the only runtime state of a `Sender` is an integer file descriptor;
// the raw-pointer marker field holds no data. Moving the handle to another
// thread lets that thread push `T` values into the channel, hence `T: Send`.
unsafe impl<T> Send for Sender<T>
where
    T: Send,
{
}
// SAFETY: the only runtime state of a `Receiver` is an integer file
// descriptor; the raw-pointer marker field holds no data. Moving the handle to
// another thread lets that thread obtain `T` values, hence `T: Send`.
unsafe impl<T> Send for Receiver<T>
where
    T: Send,
{
}
/// Creates a new channel backed by a pipe and returns its two halves.
///
/// Both pipe descriptors are opened with `O_CLOEXEC`. The second descriptor
/// returned by `pipe2` is given to the [`Sender`], the first to the
/// [`Receiver`].
///
/// # Panics
///
/// Panics if `pipe2` fails.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let (read_fd, write_fd) = nix::unistd::pipe2(nix::fcntl::OFlag::O_CLOEXEC).unwrap();
    let sender = Sender::new(write_fd);
    let receiver = Receiver::new(read_fd);
    (sender, receiver)
}
impl<T> Sender<T> {
fn new(fd: RawFd) -> Self {
Sender { fd, variance: PhantomData, not_send: PhantomData }
}
pub fn send(&mut self, t: T) -> Result<(), SendError<T>> {
let mut s: &[u8] = &[0];
if mem::size_of::<T>() > 0 {
s = unsafe {
slice::from_raw_parts(&t as *const T as *const u8, mem::size_of::<T>())
};
}
let mut n = 0;
while n < s.len() {
match nix::unistd::write(self.fd, &s[n..]) {
Ok(count) => n += count,
Err(nix::Error::Sys(nix::errno::Errno::EPIPE)) => return Err(SendError(t)),
e => { e.unwrap(); }
}
}
mem::forget(t);
Ok(())
}
}
impl<T> Receiver<T> {
fn new(fd: RawFd) -> Self {
Receiver { fd, variance: PhantomData, not_send: PhantomData }
}
pub fn recv(&mut self) -> Result<T, RecvError> {
unsafe {
let mut t = MaybeUninit::<T>::uninit();
let mut s: &mut [u8] = &mut [0];
if mem::size_of::<T>() > 0 {
s = slice::from_raw_parts_mut(t.as_mut_ptr() as *mut u8, mem::size_of::<T>())
}
let mut n = 0;
while n < s.len() {
match nix::unistd::read(self.fd, &mut s[n..]) {
Ok(0) => return Err(RecvError),
Ok(count) => n += count,
e => { e.unwrap(); }
}
}
Ok(t.assume_init())
}
}
pub fn iter(&mut self) -> Iter<T> {
self.into_iter()
}
}
impl<T> AsRawFd for Sender<T> {
fn as_raw_fd(&self) -> RawFd { self.fd }
}
impl<T> AsRawFd for Receiver<T> {
fn as_raw_fd(&self) -> RawFd { self.fd }
}
impl<T> IntoRawFd for Sender<T> {
    /// Consumes the sender and returns its file descriptor without closing it.
    fn into_raw_fd(self) -> RawFd {
        // Suppress `Drop`, which would otherwise close the descriptor that is
        // being handed to the caller.
        let this = mem::ManuallyDrop::new(self);
        this.fd
    }
}
impl<T> IntoRawFd for Receiver<T> {
    /// Consumes the receiver and returns its file descriptor without closing it.
    fn into_raw_fd(self) -> RawFd {
        // Suppress `Drop`, which would otherwise close the descriptor that is
        // being handed to the caller.
        let this = mem::ManuallyDrop::new(self);
        this.fd
    }
}
impl<T> FromRawFd for Sender<T> {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Sender::new(fd)
}
}
impl<T> FromRawFd for Receiver<T> {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Receiver::new(fd)
}
}
impl<T> fmt::Debug for Sender<T> {
    /// Formats the sender as a struct named `Sender` with its `fd` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Sender");
        out.field("fd", &self.fd);
        out.finish()
    }
}
impl<T> fmt::Debug for Receiver<T> {
    /// Formats the receiver as a struct named `Receiver` with its `fd` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Receiver");
        out.field("fd", &self.fd);
        out.finish()
    }
}
/// Owning iterator over the values of a [`Receiver`].
///
/// Produced by `Receiver::into_iter`; each step calls `recv` and the iterator
/// yields `None` once `recv` returns an error.
pub struct IntoIter<T>(Receiver<T>);
impl<T> Iterator for IntoIter<T> {
    type Item = T;

    /// Receives the next value, or yields `None` if `recv` fails.
    fn next(&mut self) -> Option<T> {
        match self.0.recv() {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }
}
impl<T> IntoIterator for Receiver<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    /// Consumes the receiver and turns it into an iterator over received values.
    fn into_iter(self) -> Self::IntoIter {
        let receiver = self;
        IntoIter(receiver)
    }
}
/// Borrowing iterator over the values of a [`Receiver`].
///
/// Produced by `Receiver::iter` or by iterating over `&mut Receiver`; each
/// step calls `recv` and the iterator yields `None` once `recv` returns an
/// error.
pub struct Iter<'a, T: 'a>(&'a mut Receiver<T>);
impl<'a, T: 'a> Iterator for Iter<'a, T> {
    type Item = T;

    /// Receives the next value, or yields `None` if `recv` fails.
    fn next(&mut self) -> Option<T> {
        match self.0.recv() {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }
}
impl<'a, T: 'a> IntoIterator for &'a mut Receiver<T> {
    type Item = T;
    type IntoIter = Iter<'a, T>;

    /// Borrows the receiver mutably as an iterator over received values.
    fn into_iter(self) -> Self::IntoIter {
        let receiver = self;
        Iter(receiver)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::RecvError;

    /// A sent value is not dropped by `send`; it is dropped exactly once after
    /// being received (here on another thread).
    #[test]
    fn no_leak() {
        use std::sync::{Arc, Mutex};
        use std::thread;
        struct T(Arc<Mutex<i32>>);
        impl Drop for T {
            fn drop(&mut self) {
                *self.0.lock().unwrap() += 1;
            }
        }
        let cnt = Arc::new(Mutex::new(0));
        let t = T(cnt.clone());
        let (mut tx, mut rx) = channel();
        assert_eq!(*cnt.lock().unwrap(), 0);
        tx.send(t).unwrap();
        assert_eq!(*cnt.lock().unwrap(), 0);
        thread::spawn(move || rx.recv().unwrap()).join().unwrap();
        assert_eq!(*cnt.lock().unwrap(), 1);
    }

    /// A failed `recv` must not run a destructor on a value it never received.
    #[test]
    fn no_drop_on_recv_err() {
        #[derive(Debug)]
        struct T<'a>(&'a mut i32);
        impl Drop for T<'_> {
            fn drop(&mut self) {
                *self.0 += 1;
            }
        }
        let mut cnt = 0;
        {
            let t = T(&mut cnt);
            let (mut tx, mut rx) = channel();
            tx.send(t).unwrap();
            rx.recv().unwrap();
            drop(tx);
            rx.recv().unwrap_err();
        }
        assert_eq!(cnt, 1);
    }

    /// Zero-sized values can be sent and received.
    #[test]
    fn zero_sized_type() {
        let (mut tx, mut rx) = channel();
        tx.send(()).unwrap();
        assert_eq!(rx.recv().unwrap(), ());
    }

    /// Receiving a zero-sized value fails once the sender is gone.
    #[test]
    fn zero_sized_type_drop() {
        let (tx, mut rx) = channel::<()>();
        drop(tx);
        assert_eq!(rx.recv(), Err(RecvError));
    }

    /// `Debug` output of a sender shows its file descriptor.
    #[test]
    fn debug_print() {
        use std::os::unix::io::AsRawFd;
        let (tx, _) = channel::<i32>();
        let s1 = format!("Sender {{ fd: {:?} }}", tx.as_raw_fd());
        let s2 = format!("{:?}", tx);
        assert_eq!(s1, s2);
    }

    /// A value spanning many bytes survives the round trip unchanged.
    #[test]
    fn large_data() {
        struct Large([usize; 4096]);
        impl Large {
            fn new() -> Large {
                let mut res = [0usize; 4096];
                for (i, slot) in res.iter_mut().enumerate() {
                    *slot = i * i;
                }
                Large(res)
            }
        }
        let (mut tx, mut rx) = channel();
        tx.send(Large::new()).unwrap();
        let res = rx.recv().unwrap();
        let expected = Large::new();
        for (actual, wanted) in res.0.iter().zip(expected.0.iter()) {
            assert_eq!(actual, wanted);
        }
    }

    /// Non-`Send` values may be passed through a channel on a single thread.
    #[test]
    fn no_send_no_threading() {
        use std::rc::Rc;
        let rc = Rc::new(1024);
        let (mut tx, mut rx) = channel();
        tx.send(rc).unwrap();
        let res = rx.recv().unwrap();
        assert_eq!(*res, 1024);
    }

    /// A receiver rebuilt from its raw descriptor keeps working.
    #[test]
    fn raw_fd() {
        use std::os::unix::io::{AsRawFd, IntoRawFd, FromRawFd};
        let (mut tx, rx) = channel();
        let fd = rx.into_raw_fd();
        let mut rx = unsafe { Receiver::<i32>::from_raw_fd(fd) };
        assert_eq!(rx.as_raw_fd(), fd);
        tx.send(42).unwrap();
        assert_eq!(rx.recv().unwrap(), 42);
    }
}