use crate::future::poll_fn;
use crate::net::udp::UdpSocket;
use std::error::Error;
use std::fmt;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
/// The sending half of a UDP socket, produced by `split`.
///
/// Holds one reference-counted (`Arc`) handle to the underlying socket; the
/// matching [`RecvHalf`] holds the other. Use [`SendHalf::reunite`] to recover
/// the original socket.
#[derive(Debug)]
pub struct SendHalf(Arc<UdpSocket>);
/// The receiving half of a UDP socket, produced by `split`.
///
/// Holds one reference-counted (`Arc`) handle to the underlying socket; the
/// matching [`SendHalf`] holds the other. Use [`RecvHalf::reunite`] to recover
/// the original socket.
#[derive(Debug)]
pub struct RecvHalf(Arc<UdpSocket>);
/// Splits `socket` into a receive half and a send half.
///
/// The socket is moved into an `Arc` and both halves get a handle to that same
/// allocation, so each half can be owned and used independently. The pair can
/// be recombined with [`RecvHalf::reunite`] or [`SendHalf::reunite`].
pub(crate) fn split(socket: UdpSocket) -> (RecvHalf, SendHalf) {
    let recv_handle = Arc::new(socket);
    // Second handle to the same allocation; this only bumps the refcount.
    let send_handle = Arc::clone(&recv_handle);
    (RecvHalf(recv_handle), SendHalf(send_handle))
}
/// Error returned when attempting to reunite a [`SendHalf`] and a [`RecvHalf`]
/// that were not split from the same socket.
///
/// Both halves are handed back unchanged so the caller keeps ownership of them:
/// field `0` is the send half and field `1` is the receive half.
#[derive(Debug)]
pub struct ReuniteError(pub SendHalf, pub RecvHalf);
impl fmt::Display for ReuniteError {
    /// Writes a fixed, human-readable description of the mismatch.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The message has no interpolated parts, so write the literal directly.
        f.write_str("tried to reunite halves that are not from the same socket")
    }
}
/// Marks `ReuniteError` as a standard error type; all `Error` methods
/// (such as `source`) keep their trait defaults.
impl Error for ReuniteError {}
/// Recombines a send half and a receive half into the socket they share.
///
/// # Errors
///
/// Returns [`ReuniteError`] carrying both halves back when `s` and `r` do not
/// point at the same `Arc` allocation.
///
/// # Panics
///
/// Panics if the halves match but `Arc::try_unwrap` fails, i.e. some other
/// strong reference to the socket is still alive. In the code visible here,
/// `split` creates exactly two handles, so this is not expected to happen.
fn reunite(s: SendHalf, r: RecvHalf) -> Result<UdpSocket, ReuniteError> {
    if !Arc::ptr_eq(&s.0, &r.0) {
        return Err(ReuniteError(s, r));
    }
    // Release the receive half's handle first so the send half owns the only one.
    drop(r);
    let socket = Arc::try_unwrap(s.0).expect("udp: try_unwrap failed in reunite");
    Ok(socket)
}
impl RecvHalf {
    /// Attempts to rejoin this receive half with `other` into the original socket.
    ///
    /// # Errors
    ///
    /// Returns a [`ReuniteError`] holding both halves when `other` was not split
    /// from the same socket as `self`.
    pub fn reunite(self, other: SendHalf) -> Result<UdpSocket, ReuniteError> {
        // The shared helper expects the send half first.
        reunite(other, self)
    }

    /// Receives a datagram into `buf`.
    ///
    /// Drives the socket's `poll_recv_from` to completion and resolves to the
    /// byte count and source address it reports.
    pub async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        let socket = &self.0;
        poll_fn(|cx| socket.poll_recv_from(cx, buf)).await
    }

    /// Receives a datagram into `buf`.
    ///
    /// Drives the socket's `poll_recv` to completion and resolves to the byte
    /// count it reports. NOTE(review): presumably this requires a connected
    /// socket; confirm against `UdpSocket::poll_recv`.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let socket = &self.0;
        poll_fn(|cx| socket.poll_recv(cx, buf)).await
    }
}
impl SendHalf {
    /// Attempts to rejoin this send half with `other` into the original socket.
    ///
    /// # Errors
    ///
    /// Returns a [`ReuniteError`] holding both halves when `other` was not split
    /// from the same socket as `self`.
    pub fn reunite(self, other: RecvHalf) -> Result<UdpSocket, ReuniteError> {
        reunite(self, other)
    }

    /// Sends `buf` as a datagram addressed to `target`.
    ///
    /// Drives the socket's `poll_send_to` to completion and resolves to the
    /// byte count it reports.
    pub async fn send_to(&mut self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
        let socket = &self.0;
        poll_fn(|cx| socket.poll_send_to(cx, buf, target)).await
    }

    /// Sends `buf` as a datagram.
    ///
    /// Drives the socket's `poll_send` to completion and resolves to the byte
    /// count it reports. NOTE(review): presumably this requires a connected
    /// socket; confirm against `UdpSocket::poll_send`.
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        let socket = &self.0;
        poll_fn(|cx| socket.poll_send(cx, buf)).await
    }
}
/// Borrows the socket shared by this half.
impl AsRef<UdpSocket> for SendHalf {
    fn as_ref(&self) -> &UdpSocket {
        // Explicitly deref through the `Arc` to reach the socket.
        &*self.0
    }
}
/// Borrows the socket shared by this half.
impl AsRef<UdpSocket> for RecvHalf {
    fn as_ref(&self) -> &UdpSocket {
        // Explicitly deref through the `Arc` to reach the socket.
        &*self.0
    }
}