use crate::socket::{RecvFlags, RecvInfo, SctpSocket};
use crate::{split_owned, wrap_io_error, OwnedReadHalf, OwnedWriteHalf, SendOptions, Status};
use bytes::BufMut;
use socket2::Domain;
use std::io;
use std::net::{Shutdown, SocketAddr};
use std::os::unix::io::{AsRawFd, RawFd};
use tokio::io::unix::AsyncFd;
use tokio::io::ReadBuf;
/// An SCTP socket registered with tokio's [`AsyncFd`] so that connecting,
/// sending and receiving can be awaited instead of blocking the thread.
#[derive(Debug)]
pub struct SctpStream {
/// The underlying socket wrapped in [`AsyncFd`]; readiness for reads and
/// writes is awaited through this handle.
pub(crate) inner: AsyncFd<SctpSocket>,
}
impl SctpStream {
    /// Wraps `socket` in an [`AsyncFd`] and returns the resulting stream.
    ///
    /// This does not change the socket's blocking mode; the `connect*`
    /// constructors switch the socket to non-blocking before calling this.
    ///
    /// # Errors
    ///
    /// Returns the error reported by [`AsyncFd::new`] if the socket cannot be
    /// registered.
    pub fn new(socket: SctpSocket) -> io::Result<Self> {
        Ok(SctpStream {
            inner: AsyncFd::new(socket)?,
        })
    }

    /// Creates a new socket in the address family of `addr` and connects it
    /// to `addr`.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be created, cannot be switched to
    /// non-blocking mode, or the connection attempt fails.
    pub async fn connect(addr: SocketAddr) -> io::Result<Self> {
        let socket = SctpSocket::new(Domain::for_address(addr))
            .map_err(|e| wrap_io_error("Failed to create socket", e))?;
        // The remaining steps (non-blocking mode + connect) are exactly what
        // `connect_from` does for a caller-supplied socket.
        SctpStream::connect_from(socket, addr).await
    }

    /// Creates a new socket and connects it using every address in `addrs`.
    ///
    /// The socket's address family is taken from the first entry of `addrs`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] when `addrs` is empty, and
    /// otherwise fails for the same reasons as [`SctpStream::connectx_from`]
    /// or if the socket cannot be created.
    pub async fn connectx(addrs: &[SocketAddr]) -> io::Result<Self> {
        let Some(first) = addrs.first() else {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "Addresses empty"));
        };
        let socket = SctpSocket::new(Domain::for_address(*first))
            .map_err(|e| wrap_io_error("Failed to create socket", e))?;
        SctpStream::connectx_from(socket, addrs).await
    }

    /// Connects an existing `socket` to `addr`.
    ///
    /// The socket is switched to non-blocking mode before the connection is
    /// started.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be made non-blocking or the connection
    /// attempt fails.
    pub async fn connect_from(socket: SctpSocket, addr: SocketAddr) -> io::Result<Self> {
        socket
            .set_nonblocking(true)
            .map_err(|e| wrap_io_error("Failed to set socket non-blocking", e))?;
        SctpStream::connect_inner(socket, |s| s.connect_sys(addr)).await
    }

    /// Connects an existing `socket` using every address in `addrs`.
    ///
    /// The socket is switched to non-blocking mode before the connection is
    /// started.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] when `addrs` is empty (the same
    /// error `connectx` reports), and otherwise fails if the socket cannot be
    /// made non-blocking or the connection attempt fails.
    pub async fn connectx_from(socket: SctpSocket, addrs: &[SocketAddr]) -> io::Result<Self> {
        // Reject an empty list up front instead of handing it to the syscall,
        // so both `connectx` entry points behave the same way.
        if addrs.is_empty() {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "Addresses empty"));
        }
        socket
            .set_nonblocking(true)
            .map_err(|e| wrap_io_error("Failed to set socket non-blocking", e))?;
        let mut raw_addrs = addrs
            .iter()
            // SAFETY: `as_ptr` points into the temporary `SockAddr`, which is
            // alive for the whole expression, so the read is from valid memory.
            // NOTE(review): if `as_ptr` yields `*const sockaddr`, this by-value
            // copy keeps only `size_of::<sockaddr>()` bytes per entry, which
            // would truncate IPv6 addresses — confirm against `connectx_sys`.
            .map(|a| unsafe { *socket2::SockAddr::from(*a).as_ptr() })
            .collect::<Vec<_>>();
        SctpStream::connect_inner(socket, |s| s.connectx_sys(&mut raw_addrs)).await
    }

    /// Shared tail of all `connect*` constructors.
    ///
    /// Runs `sys_call` on the (already non-blocking) socket to start the
    /// connection, registers the socket with the reactor, waits until it is
    /// writable and finally checks the socket's pending error.
    async fn connect_inner<F>(socket: SctpSocket, mut sys_call: F) -> io::Result<Self>
    where
        F: FnMut(&SctpSocket) -> io::Result<()>,
    {
        match sys_call(&socket) {
            Ok(()) => {}
            // A non-blocking connect that has merely been started is not a
            // failure; completion is signalled through writability below.
            Err(err)
                if err.kind() == io::ErrorKind::WouldBlock
                    || err.raw_os_error() == Some(libc::EINPROGRESS) => {}
            Err(err) => return Err(wrap_io_error("Failed to connect", err)),
        }

        let stream =
            SctpStream::new(socket).map_err(|e| wrap_io_error("Failed to create SctpStream", e))?;

        // Only the readiness event matters here; the guard is dropped at once.
        let _ = stream
            .inner
            .writable()
            .await
            .map_err(|e| wrap_io_error("Failed to wait until writable", e))?;

        // Writability alone does not mean success: the outcome of the
        // connection attempt is reported through the socket's pending error.
        let pending = stream
            .inner
            .get_ref()
            .take_error()
            .map_err(|e| wrap_io_error("Failed to get socket error", e))?;
        match pending {
            Some(e) => Err(wrap_io_error("Connect error", e)),
            None => Ok(stream),
        }
    }

    /// Shuts down the read half, write half or both, as selected by `how`,
    /// by delegating to the underlying socket.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.inner.get_ref().shutdown(how)
    }

    /// Returns the local address reported by the underlying socket.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.get_ref().local_addr()
    }

    /// Returns the peer address reported by the underlying socket.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner.get_ref().peer_addr()
    }

    /// Forwards `nodelay` to the underlying socket's `set_nodelay`.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.get_ref().set_nodelay(nodelay)
    }

    /// Returns the value reported by the underlying socket's `nodelay`.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.get_ref().nodelay()
    }

    /// Returns the [`Status`] reported by the underlying socket.
    pub fn status(&self) -> io::Result<Status> {
        self.inner.get_ref().status()
    }

    /// Sends `buf` as a message, optionally to the address `to`, using the
    /// send options `opts`.
    ///
    /// Waits until the socket is writable and retries whenever the send would
    /// block.
    ///
    /// # Errors
    ///
    /// Returns any error from waiting for writability or from the underlying
    /// `sendmsg` call other than "would block".
    pub async fn sendmsg<'a>(
        &'a self,
        buf: &'a [u8],
        to: Option<SocketAddr>,
        opts: &SendOptions,
    ) -> io::Result<usize> {
        loop {
            let mut guard = self.inner.writable().await?;
            match guard.try_io(|inner| inner.get_ref().sendmsg(buf, to, opts)) {
                Ok(res) => return res,
                // Readiness was stale; `try_io` cleared it, so wait again.
                Err(_would_block) => {}
            }
        }
    }

    /// Receives one chunk of data into `buf`.
    ///
    /// Waits until the socket is readable and retries whenever the receive
    /// would block. Returns the tuple produced by the underlying `recvmsg`
    /// call: a byte count, the [`RecvInfo`] and the [`RecvFlags`].
    ///
    /// # Errors
    ///
    /// Returns any error from waiting for readability or from the underlying
    /// `recvmsg` call other than "would block".
    pub async fn recvmsg<'a>(
        &'a self,
        buf: &'a mut ReadBuf<'a>,
    ) -> io::Result<(usize, RecvInfo, RecvFlags)> {
        loop {
            let mut guard = self.inner.readable().await?;
            match guard.try_io(|inner| inner.get_ref().recvmsg(buf)) {
                Ok(res) => return res,
                // Readiness was stale; `try_io` cleared it, so wait again.
                Err(_would_block) => {}
            }
        }
    }

    /// Receives one chunk of data into the spare capacity of `buf` and
    /// advances `buf` by the returned byte count.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SctpStream::recvmsg`]; `buf` is not
    /// advanced in that case.
    pub async fn recvmsg_buf<'a, B: BufMut>(
        &'a self,
        buf: &'a mut B,
    ) -> io::Result<(usize, RecvInfo, RecvFlags)> {
        // SAFETY: the uninitialised chunk is only handed to `ReadBuf::uninit`,
        // which treats it as uninitialised memory.
        let dst = unsafe { buf.chunk_mut().as_uninit_slice_mut() };
        let mut dst = ReadBuf::uninit(dst);
        let res = self.recvmsg(&mut dst).await?;
        // SAFETY: relies on `recvmsg` reporting in `res.0` only bytes it
        // actually wrote into the chunk obtained above.
        unsafe { buf.advance_mut(res.0) };
        Ok(res)
    }

    /// Receives into `buf` repeatedly until a chunk carrying
    /// [`RecvFlags::EOR`] arrives.
    ///
    /// On success the returned count is the total number of bytes appended
    /// across all chunks, and the [`RecvInfo`] and [`RecvFlags`] are those of
    /// the final chunk. If a receive returns zero bytes, that result is
    /// returned as-is with a count of `0`, even if earlier chunks were already
    /// appended to `buf`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SctpStream::recvmsg_buf`].
    pub async fn recvmsg_eor_buf<'a, B: BufMut>(
        &'a self,
        buf: &'a mut B,
    ) -> io::Result<(usize, RecvInfo, RecvFlags)> {
        let mut total = 0;
        loop {
            let (n, rcvinfo, flags) = self.recvmsg_buf(buf).await?;
            if n == 0 {
                return Ok((n, rcvinfo, flags));
            }
            total += n;
            if flags.contains(RecvFlags::EOR) {
                return Ok((total, rcvinfo, flags));
            }
        }
    }

    /// Consumes the stream and splits it into an owned read half and an
    /// owned write half.
    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
        split_owned(self)
    }
}
impl AsRawFd for SctpStream {
    /// Exposes the raw file descriptor of the socket held inside the
    /// `AsyncFd` wrapper. Ownership of the descriptor stays with the stream.
    #[inline]
    fn as_raw_fd(&self) -> RawFd {
        let socket = self.inner.get_ref();
        socket.as_raw_fd()
    }
}