use super::imports::*;
use crate::local_socket::{self as sync, ToLocalSocketName};
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// Async local socket listener wrapping a synchronous [`sync::LocalSocketListener`].
///
/// The synchronous listener is held in an [`Arc`] so that each blocking operation (and the
/// iterator behind [`incoming`](Self::incoming)) can own its own handle to it instead of
/// borrowing `self`.
#[derive(Debug)]
pub struct LocalSocketListener {
    inner: Arc<sync::LocalSocketListener>,
}
impl LocalSocketListener {
    /// Creates a listener bound to `name`.
    ///
    /// The synchronous `bind` is run through `unblock` and awaited.
    ///
    /// # Errors
    /// Returns whatever error the synchronous `bind` reports.
    pub async fn bind<'a>(name: impl ToLocalSocketName<'a> + Send + 'static) -> io::Result<Self> {
        let listener = unblock(move || sync::LocalSocketListener::bind(name)).await?;
        Ok(Self {
            inner: Arc::new(listener),
        })
    }
    /// Waits for one incoming connection and returns it as an async [`LocalSocketStream`].
    ///
    /// The blocking `accept` runs through `unblock` on a cloned `Arc` handle to the listener.
    ///
    /// # Errors
    /// Returns whatever error the synchronous `accept` reports.
    pub async fn accept(&self) -> io::Result<LocalSocketStream> {
        // NOTE(review): if this future is dropped while the blocking accept is already running,
        // the connection it eventually accepts is presumably discarded unseen — confirm against
        // the `blocking::unblock` cancellation semantics.
        let listener = Arc::clone(&self.inner);
        let stream = unblock(move || listener.accept()).await?;
        Ok(LocalSocketStream {
            inner: Unblock::new(stream),
        })
    }
    /// Returns an endless stream of incoming connections on this listener.
    ///
    /// The stream owns its own `Arc` handle to the synchronous listener, so it does not borrow
    /// `self`. The underlying iterator never ends; every item is the result of one `accept`.
    pub fn incoming(&self) -> Incoming {
        // NOTE(review): `Unblock` reads ahead from the wrapped iterator, so connections may be
        // accepted before the caller polls for them; consider `Unblock::with_capacity` to bound
        // the read-ahead — confirm against the `blocking` docs before changing.
        Incoming {
            inner: Unblock::new(SyncArcIncoming {
                inner: Arc::clone(&self.inner),
            }),
        }
    }
}
/// Endless stream of incoming connections, created by [`LocalSocketListener::incoming`].
///
/// Each item is the outcome of one synchronous `accept` on the shared listener, driven through
/// `Unblock`; successful connections are wrapped as async [`LocalSocketStream`]s.
#[derive(Debug)]
pub struct Incoming {
    inner: Unblock<SyncArcIncoming>,
}
#[cfg(feature = "nonblocking")]
impl Stream for Incoming {
    type Item = Result<LocalSocketStream, io::Error>;
    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Pending polls, end-of-stream and errors pass through untouched; only a successfully
        // accepted synchronous stream gets re-wrapped in its own `Unblock`.
        let source = Pin::new(&mut self.inner);
        <Unblock<_> as Stream>::poll_next(source, ctx).map(|next| {
            next.map(|accepted| {
                accepted.map(|stream| LocalSocketStream {
                    inner: Unblock::new(stream),
                })
            })
        })
    }
}
#[cfg(feature = "nonblocking")]
impl FusedStream for Incoming {
    /// Always `false`: the wrapped `SyncArcIncoming` iterator never returns `None`, so this
    /// stream is treated as never finishing.
    fn is_terminated(&self) -> bool {
        false
    }
}
/// Blocking iterator that calls `accept` on a shared synchronous listener, forever.
///
/// Owning an `Arc` handle (rather than a borrow) lets the iterator be handed to `Unblock`
/// independently of the async listener it was created from.
#[derive(Debug)]
struct SyncArcIncoming {
    inner: Arc<sync::LocalSocketListener>,
}
impl Iterator for SyncArcIncoming {
    type Item = Result<sync::LocalSocketStream, io::Error>;
    /// Blocks until the next connection or error arrives; never returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let listener: &sync::LocalSocketListener = &self.inner;
        Some(listener.accept())
    }
}
/// Async local socket stream: a synchronous [`sync::LocalSocketStream`] driven through
/// `Unblock`.
#[derive(Debug)]
pub struct LocalSocketStream {
    inner: Unblock<sync::LocalSocketStream>,
}
impl LocalSocketStream {
    /// Connects to the local socket identified by `name`.
    ///
    /// The synchronous `connect` is run through `unblock` and awaited.
    ///
    /// # Errors
    /// Returns whatever error the synchronous `connect` reports.
    pub async fn connect<'a>(
        name: impl ToLocalSocketName<'a> + Send + 'static,
    ) -> io::Result<Self> {
        let connected = unblock(move || sync::LocalSocketStream::connect(name)).await?;
        Ok(Self {
            inner: Unblock::new(connected),
        })
    }
}
/// Reads are forwarded directly to the inner `Unblock`.
#[cfg(feature = "nonblocking")]
impl AsyncRead for LocalSocketStream {
    fn poll_read(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        AsyncRead::poll_read(Pin::new(&mut this.inner), ctx, buf)
    }
}
/// Writes, flushes and closes are forwarded directly to the inner `Unblock`.
#[cfg(feature = "nonblocking")]
impl AsyncWrite for LocalSocketStream {
    fn poll_write(
        self: Pin<&mut Self>,
        ctx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.get_mut();
        AsyncWrite::poll_write(Pin::new(&mut this.inner), ctx, buf)
    }
    fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.get_mut();
        AsyncWrite::poll_flush(Pin::new(&mut this.inner), ctx)
    }
    fn poll_close(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.get_mut();
        AsyncWrite::poll_close(Pin::new(&mut this.inner), ctx)
    }
}