use bitflags::bitflags;
use libc::IPPROTO_SCTP;
use socket2::{Domain, SockAddr, Socket, Type};
use std::convert::From;
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use std::net::{Shutdown, SocketAddr};
use std::os::unix::io::AsRawFd;
use std::time::Duration;
use tokio::io::ReadBuf;
use crate::sys::common::{sctp_assoc_t, SctpSndRcvInfo};
use crate::{SctpListener, SctpStream};
use crate::sys::linux::*;
/// An SCTP socket: a thin wrapper around a `socket2::Socket` that adds the SCTP-specific
/// calls (`sctp_sendmsg`, `sctp_recvmsg`, `sctp_bindx`, `sctp_connectx`) and SCTP socket
/// options on top of the generic socket operations.
#[derive(Debug)]
pub struct SctpSocket {
    /// The underlying OS socket; generic socket operations delegate to it.
    pub inner: Socket,
}
impl SctpSocket {
    /// Creates a new `SOCK_STREAM` SCTP socket in `domain` with data I/O events enabled.
    ///
    /// # Errors
    /// Returns the OS error if the socket cannot be created or if subscribing to events
    /// fails; in the latter case the freshly created socket is dropped (and thereby closed).
    pub fn new(domain: Domain) -> io::Result<Self> {
        let inner = Socket::new(domain, Type::STREAM, Some(IPPROTO_SCTP.into()))?;
        Self::enable_notifications(&inner)?;
        Ok(SctpSocket { inner })
    }

    /// Subscribes `socket` to SCTP data I/O events via `SCTP_EVENTS`; every other event
    /// stays disabled.
    ///
    /// Per `sctp_recvmsg(3)`, the `sctp_sndrcvinfo` that `recvmsg` returns is only filled in
    /// while this subscription is active.
    fn enable_notifications(socket: &Socket) -> io::Result<()> {
        let events = EventSubscribe {
            data_io_event: 1,
            ..Default::default()
        };
        // SAFETY: `EventSubscribe` is `#[repr(C)]` plain data; the helper passes a pointer to
        // it together with its exact size, and `socket` owns a valid descriptor.
        unsafe { setsockopt(socket.as_raw_fd(), IPPROTO_SCTP, SCTP_EVENTS, events) }
    }

    /// Connects to `addr`, consuming this socket; delegates to `SctpStream::connect_from`.
    pub async fn connect(self, addr: SocketAddr) -> io::Result<SctpStream> {
        SctpStream::connect_from(self, addr).await
    }

    /// Connects to a peer reachable at any of `addrs`, consuming this socket; delegates to
    /// `SctpStream::connectx_from`.
    pub async fn connectx(self, addrs: &[SocketAddr]) -> io::Result<SctpStream> {
        SctpStream::connectx_from(self, addrs).await
    }

    /// Issues a plain `connect` to `addr` on the underlying socket.
    pub fn connect_sys(&self, addr: SocketAddr) -> io::Result<()> {
        self.inner.connect(&SockAddr::from(addr))
    }

    /// Calls `sctp_connectx` with the address array `addrs`, without requesting the
    /// resulting association id.
    ///
    /// # Errors
    /// `InvalidInput` if `addrs.len()` does not fit in a C `int`; otherwise the OS error
    /// reported by `sctp_connectx`.
    pub fn connectx_sys(&self, addrs: &mut [libc::sockaddr]) -> io::Result<()> {
        let count = Self::address_count(addrs)?;
        // SAFETY: the pointer/count pair describes `addrs`, which stays borrowed for the
        // whole call; the association-id out-pointer is null, i.e. not requested.
        let res = unsafe {
            sctp_connectx(
                self.as_raw_fd(),
                addrs.as_mut_ptr(),
                count,
                std::ptr::null_mut(),
            )
        };
        match res {
            -1 => Err(io::Error::last_os_error()),
            _ => Ok(()),
        }
    }

    /// Binds to `addr`, consuming this socket; delegates to `SctpListener::bind_from`.
    pub async fn bind(self, addr: SocketAddr) -> io::Result<SctpListener> {
        SctpListener::bind_from(self, addr)
    }

    /// Binds to all of `addrs`, consuming this socket; delegates to
    /// `SctpListener::bindx_from`.
    pub async fn bindx(self, addrs: &[SocketAddr]) -> io::Result<SctpListener> {
        SctpListener::bindx_from(self, addrs)
    }

    /// Issues a plain `bind` to `addr` on the underlying socket.
    pub fn bind_sys(&self, addr: SocketAddr) -> io::Result<()> {
        self.inner.bind(&socket2::SockAddr::from(addr))
    }

    /// Adds every address in `addrs` to the socket's bound set via `sctp_bindx` with
    /// `SCTP_BINDX_ADD_ADDR`.
    ///
    /// # Errors
    /// `InvalidInput` if `addrs.len()` does not fit in a C `int`; otherwise the OS error
    /// reported by `sctp_bindx`.
    pub fn bindx_sys(&self, addrs: &mut [libc::sockaddr]) -> io::Result<()> {
        let count = Self::address_count(addrs)?;
        // SAFETY: the pointer/count pair describes `addrs`, which stays borrowed for the
        // whole call.
        let res = unsafe {
            sctp_bindx(
                self.as_raw_fd(),
                addrs.as_mut_ptr(),
                count,
                SCTP_BINDX_ADD_ADDR,
            )
        };
        match res {
            -1 => Err(io::Error::last_os_error()),
            _ => Ok(()),
        }
    }

    /// Marks the socket as passive with the given `backlog`.
    pub fn listen(&self, backlog: i32) -> io::Result<()> {
        self.inner.listen(backlog)
    }

    /// Accepts a pending connection, returning the new socket and the peer's address.
    ///
    /// NOTE(review): `enable_notifications` is not called on the accepted socket; this
    /// presumably relies on it inheriting the listener's `SCTP_EVENTS` subscription — confirm.
    ///
    /// # Errors
    /// The OS error from `accept`, or `InvalidInput` if the peer address is not an IP address.
    pub fn accept(&self) -> io::Result<(SctpSocket, SocketAddr)> {
        let (inner, addr) = self.inner.accept()?;
        let addr = Self::to_std_addr(&addr)?;
        Ok((SctpSocket { inner }, addr))
    }

    /// Shuts down the read half, write half, or both, as selected by `how`.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.inner.shutdown(how)
    }

    /// Returns the local address, or `InvalidInput` if it is not an IP address.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        Self::to_std_addr(&self.inner.local_addr()?)
    }

    /// Returns the peer address, or `InvalidInput` if it is not an IP address.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        Self::to_std_addr(&self.inner.peer_addr()?)
    }

    /// Sends `msg` as one SCTP message via `sctp_sendmsg`, optionally to the explicit
    /// destination `to`, using the ppid/flags/stream/ttl from `opts`.
    ///
    /// Returns the number of bytes sent. A zero-length send that the kernel accepts (for
    /// example one carrying only the `SCTP_EOF` flag) is reported as `Ok(0)`.
    ///
    /// # Errors
    /// The OS error reported by `sctp_sendmsg`.
    pub fn sendmsg(
        &self,
        msg: &[u8],
        to: Option<SocketAddr>,
        opts: &SendOptions,
    ) -> io::Result<usize> {
        let to = to.map(socket2::SockAddr::from);
        // Borrow `to` for both pointer and length: it must stay in place until the call
        // returns, since the kernel reads the destination through `to_ptr`.
        let (to_ptr, to_len) = match to.as_ref() {
            Some(addr) => (addr.as_ptr() as *mut _, addr.len()),
            None => (std::ptr::null_mut(), 0),
        };
        // SAFETY: `msg` is valid for `msg.len()` bytes, and `to_ptr`/`to_len` either
        // describe the live `to` address or are null/0.
        let res = unsafe {
            sctp_sendmsg(
                self.as_raw_fd(),
                msg.as_ptr() as *const libc::c_void,
                msg.len() as libc::size_t,
                to_ptr,
                to_len,
                opts.ppid.into(),
                opts.flags.into(),
                opts.stream,
                opts.ttl.into(),
                0,
            )
        };
        if res < 0 {
            return Err(io::Error::last_os_error());
        }
        debug_assert!(res as usize == msg.len());
        Ok(res as usize)
    }

    /// Receives data into the unfilled part of `buf` via `sctp_recvmsg` and advances `buf`
    /// past the bytes read.
    ///
    /// Returns the byte count, the receive info converted to `RecvInfo`, and the message
    /// flags; check `RecvFlags::EOR` to see whether a whole message was delivered.
    ///
    /// # Errors
    /// The OS error reported by `sctp_recvmsg`.
    pub fn recvmsg(&self, buf: &mut ReadBuf<'_>) -> io::Result<(usize, RecvInfo, RecvFlags)> {
        let mut flags: libc::c_int = 0;
        let mut info = SctpSndRcvInfo::default();
        // SAFETY: the kernel writes at most `unfilled.len()` bytes into `unfilled`; only the
        // `n` bytes it reports are marked initialized and filled. `info` and `flags` are
        // valid out-parameters for the duration of the call.
        unsafe {
            let unfilled: &mut [MaybeUninit<u8>] = buf.unfilled_mut();
            let res = sctp_recvmsg(
                self.as_raw_fd(),
                unfilled.as_mut_ptr() as *mut libc::c_void,
                unfilled.len() as libc::size_t,
                std::ptr::null_mut(),
                std::ptr::null_mut(),
                &mut info as *mut SctpSndRcvInfo,
                &mut flags,
            );
            if res < 0 {
                return Err(io::Error::last_os_error());
            }
            let n = res as usize;
            buf.assume_init(n);
            buf.advance(n);
            Ok((n, info.into(), RecvFlags::from_bits_unchecked(flags)))
        }
    }

    /// Takes and clears the pending `SO_ERROR` of the socket.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.take_error()
    }

    /// Reads incoming data into `buf` without removing it from the receive queue.
    pub fn peek(&self, buf: &mut [MaybeUninit<u8>]) -> io::Result<usize> {
        self.inner.peek(buf)
    }

    /// Enables or disables `SCTP_NODELAY`.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        // SAFETY: `SCTP_NODELAY` takes a C `int`, which is exactly what is passed.
        unsafe {
            setsockopt(
                self.as_raw_fd(),
                IPPROTO_SCTP,
                SCTP_NODELAY,
                nodelay as libc::c_int,
            )
        }
    }

    /// Returns whether `SCTP_NODELAY` is enabled.
    pub fn nodelay(&self) -> io::Result<bool> {
        // SAFETY: `SCTP_NODELAY` reports a C `int`.
        unsafe {
            getsockopt::<libc::c_int>(self.as_raw_fd(), IPPROTO_SCTP, SCTP_NODELAY).map(|r| r != 0)
        }
    }

    /// Switches the underlying socket between blocking and non-blocking mode.
    pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        self.inner.set_nonblocking(nonblocking)
    }

    /// Sets the IP time-to-live of the underlying socket.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.inner.set_ttl(ttl)
    }

    /// Returns the IP time-to-live of the underlying socket.
    pub fn ttl(&self) -> io::Result<u32> {
        self.inner.ttl()
    }

    /// Sets the read timeout; `None` blocks indefinitely.
    pub fn set_read_timeout(&self, duration: Option<Duration>) -> io::Result<()> {
        self.inner.set_read_timeout(duration)
    }

    /// Returns the current read timeout.
    pub fn read_timeout(&self) -> io::Result<Option<Duration>> {
        self.inner.read_timeout()
    }

    /// Sets `SO_LINGER`; `None` disables lingering.
    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
        self.inner.set_linger(duration)
    }

    /// Returns the current `SO_LINGER` setting.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        self.inner.linger()
    }

    /// Sets the `SCTP_INITMSG` option (stream counts and INIT retransmission limits).
    pub fn set_sctp_initmsg(&self, init_msg: &InitMsg) -> io::Result<()> {
        // SAFETY: `InitMsg` is `#[repr(C)]` plain data mirroring `struct sctp_initmsg`; the
        // helper passes it with its exact size.
        unsafe { setsockopt(self.as_raw_fd(), IPPROTO_SCTP, SCTP_INITMSG, *init_msg) }
    }

    /// Queries the `SCTP_STATUS` option for the socket's association.
    pub fn status(&self) -> io::Result<Status> {
        // SAFETY: `Status` is `#[repr(C)]` plain data intended to mirror `struct sctp_status`.
        unsafe { getsockopt::<Status>(self.as_raw_fd(), IPPROTO_SCTP, SCTP_STATUS) }
    }

    /// Returns whether `SO_REUSEADDR` is enabled.
    pub fn reuseaddr(&self) -> io::Result<bool> {
        self.inner.reuse_address()
    }

    /// Enables or disables `SO_REUSEADDR`.
    pub fn set_reuseaddr(&self, reuse: bool) -> io::Result<()> {
        self.inner.set_reuse_address(reuse)
    }

    /// Converts a `socket2` address to a std `SocketAddr`, failing with `InvalidInput` for
    /// non-IP address families.
    fn to_std_addr(addr: &SockAddr) -> io::Result<SocketAddr> {
        addr.as_socket()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "Address is not valid"))
    }

    /// Converts an address-array length into the C `int` count expected by
    /// `sctp_connectx`/`sctp_bindx`, rejecting lengths that do not fit instead of panicking.
    fn address_count(addrs: &[libc::sockaddr]) -> io::Result<libc::c_int> {
        addrs
            .len()
            .try_into()
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "too many addresses"))
    }
}
impl Read for SctpSocket {
    /// Reads from the wrapped socket through `socket2::Socket`'s own `Read` impl.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.inner, buf)
    }
}
impl Read for &SctpSocket {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
(&self.inner).read(buf)
}
}
impl Write for SctpSocket {
    /// Writes to the wrapped socket through `socket2::Socket`'s own `Write` impl.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, buf)
    }

    /// Forwards the flush to the wrapped socket.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl Write for &SctpSocket {
    /// Writes through a shared reference, using the `Write` impl for `&socket2::Socket`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut &self.inner, buf)
    }

    /// Forwards the flush to the wrapped socket through a shared reference.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut &self.inner)
    }
}
impl AsRawFd for SctpSocket {
    /// Exposes the descriptor owned by the wrapped `socket2::Socket`.
    fn as_raw_fd(&self) -> std::os::unix::prelude::RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
/// Sets socket option `val` at protocol level `opt` on `fd` to the bytes of `payload`.
///
/// The option length passed to the kernel is `size_of::<T>()`.
///
/// # Errors
/// The OS error if `setsockopt(2)` returns -1.
///
/// # Safety
/// `fd` must be a valid descriptor and `T` must have the exact in-memory layout the kernel
/// expects for this option.
pub(crate) unsafe fn setsockopt<T>(
    fd: libc::c_int,
    opt: libc::c_int,
    val: libc::c_int,
    payload: T,
) -> io::Result<()> {
    let value_ptr: *const libc::c_void = (&payload as *const T).cast();
    let value_len = std::mem::size_of::<T>() as libc::socklen_t;
    if libc::setsockopt(fd, opt, val, value_ptr, value_len) == -1 {
        Err(io::Error::last_os_error())
    } else {
        Ok(())
    }
}
/// Reads socket option `val` at protocol level `opt` from `fd` into a value of type `T`.
///
/// The buffer starts out zeroed rather than uninitialized, for two reasons:
/// - If the kernel writes fewer than `size_of::<T>()` bytes, the remaining bytes are zero
///   instead of uninitialized memory that `assume_init` would turn into UB.
/// - Options whose getsockopt also reads input fields from the buffer (e.g. the association
///   id of `SCTP_STATUS`) see a deterministic zero instead of garbage.
///
/// # Errors
/// The OS error if `getsockopt(2)` returns -1.
///
/// # Safety
/// `fd` must be a valid descriptor, `T` must be plain old data whose layout matches what the
/// kernel writes for this option, and the all-zero bit pattern must be a valid `T`.
pub(crate) unsafe fn getsockopt<T>(
    fd: libc::c_int,
    opt: libc::c_int,
    val: libc::c_int,
) -> io::Result<T> {
    let mut payload: MaybeUninit<T> = MaybeUninit::zeroed();
    let mut len = std::mem::size_of::<T>() as libc::socklen_t;
    match libc::getsockopt(fd, opt, val, payload.as_mut_ptr().cast(), &mut len) {
        -1 => Err(io::Error::last_os_error()),
        // Any bytes the kernel did not overwrite are still zero, so this is initialized.
        _ => Ok(payload.assume_init()),
    }
}
/// Parameters for the `SCTP_INITMSG` socket option, laid out to mirror the C
/// `struct sctp_initmsg` (`#[repr(C)]`, so field order and types must not change).
#[derive(Debug, Copy, Clone, Default)]
#[repr(C)]
pub struct InitMsg {
    /// Number of outbound streams to request (`sinit_num_ostreams`).
    pub num_ostreams: u16,
    /// Maximum number of inbound streams to accept (`sinit_max_instreams`).
    pub max_instreams: u16,
    /// Maximum number of INIT retransmission attempts (`sinit_max_attempts`).
    pub max_attempts: u16,
    /// Maximum INIT retransmission timeout (`sinit_max_init_timeo`).
    /// NOTE(review): RFC 6458 specifies milliseconds — confirm.
    pub max_init_timeout: u16,
}
/// Information about one peer transport address, mirroring the kernel's
/// `struct sctp_paddrinfo` (embedded in `Status::primary`).
///
/// The C struct is declared `__attribute__((packed, aligned(4)))`, so `address` begins at
/// byte offset 4. Plain `repr(C)` placed it at offset 8 on 64-bit targets (where
/// `sockaddr_storage` is 8-byte aligned), shifting every later field and growing `Status`
/// beyond what the kernel writes. `packed(4)` reproduces the kernel layout, assuming
/// `sctp_assoc_t` is 32 bits as on Linux.
///
/// Because the struct is packed, read fields by value (e.g. `let addr = info.address;`)
/// rather than by reference. `Copy` is derived so that the `Clone`/`Debug` derives can copy
/// fields out.
#[derive(Clone, Copy, Debug)]
#[repr(C, packed(4))]
pub struct PeerAddrInfo {
    /// Association id (`spinfo_assoc_id`).
    pub assoc_id: sctp_assoc_t,
    /// Peer transport address (`spinfo_address`).
    pub address: libc::sockaddr_storage,
    /// Reachability state of this address (`spinfo_state`).
    pub state: i32,
    /// Congestion window (`spinfo_cwnd`).
    pub cwnd: u32,
    /// Smoothed round-trip time (`spinfo_srtt`).
    pub srtt: u32,
    /// Retransmission timeout (`spinfo_rto`).
    pub rto: u32,
    /// Path MTU (`spinfo_mtu`).
    pub mtu: u32,
}
/// Metadata for received data, built from the `SctpSndRcvInfo` that `sctp_recvmsg` fills
/// in (see the `From<SctpSndRcvInfo>` impl).
#[derive(Debug)]
pub struct RecvInfo {
    /// Stream the data arrived on (`stream`).
    pub stream: u16,
    /// Stream sequence number (`ssn`).
    pub ssn: u16,
    /// Payload protocol identifier, as delivered by the kernel (`ppid`).
    pub ppid: u32,
    /// Transmission sequence number (`tsn`).
    pub tsn: u32,
    /// Cumulative TSN (`cumtsn`); the field name's spelling is part of the public API.
    pub cummulative_tsn: u32,
    /// Flags copied from the receive info (`flags`).
    pub flags: u16,
}
impl From<SctpSndRcvInfo> for RecvInfo {
fn from(v: SctpSndRcvInfo) -> Self {
RecvInfo {
stream: v.stream,
ssn: v.ssn,
ppid: v.ppid,
tsn: v.tsn,
cummulative_tsn: v.cumtsn,
flags: v.flags,
}
}
}
/// Per-message options handed to `sctp_sendmsg` by `SctpSocket::sendmsg`.
#[derive(Debug, Default)]
pub struct SendOptions {
    /// Payload protocol identifier, passed through unchanged.
    /// NOTE(review): SCTP treats it as opaque and peers commonly expect network byte
    /// order — confirm what callers supply.
    pub ppid: u32,
    /// Flags argument of `sctp_sendmsg`.
    pub flags: u32,
    /// Outgoing stream number.
    pub stream: u16,
    /// `timetolive` argument of `sctp_sendmsg`.
    /// NOTE(review): `sctp_sendmsg(3)` documents milliseconds, with 0 meaning no limit — confirm.
    pub ttl: u32,
}
bitflags! {
    /// Message flags reported by `sctp_recvmsg` (see `SctpSocket::recvmsg`).
    pub struct RecvFlags: i32 {
        /// `MSG_EOR`: for SCTP, the received data completes a message.
        const EOR = libc::MSG_EOR;
        /// Presumably SCTP's `MSG_NOTIFICATION` (0x8000): the data is an event
        /// notification rather than user data — confirm against the platform headers.
        const NOTIFICATION = 0x8000;
    }
}
/// Event subscription toggles for the `SCTP_EVENTS` socket option, mirroring the C
/// `struct sctp_event_subscribe` field-for-field (`#[repr(C)]`; each field is 0 or 1).
///
/// NOTE(review): newer Linux kernels append further fields to this struct; the shorter
/// option length produced from this definition appears to be accepted — confirm for the
/// kernels you target.
#[derive(Debug, Copy, Clone, Default)]
#[repr(C)]
struct EventSubscribe {
    /// `sctp_data_io_event`: attach `sctp_sndrcvinfo` to received data.
    data_io_event: u8,
    /// `sctp_association_event`: association change notifications.
    association_event: u8,
    /// `sctp_address_event`: peer address change notifications.
    address_event: u8,
    /// `sctp_send_failure_event`: send failure notifications.
    send_failure_event: u8,
    /// `sctp_peer_error_event`: remote error notifications.
    peer_error_event: u8,
    /// `sctp_shutdown_event`: shutdown notifications.
    shutdown_event: u8,
    /// `sctp_partial_delivery_event`: partial delivery notifications.
    partial_delivery_event: u8,
    /// `sctp_adaptation_layer_event`: adaptation layer indications.
    adaptation_layer_event: u8,
    /// `sctp_authentication_event`: authentication key notifications.
    authentication_event: u8,
    /// `sctp_sender_dry_event`: sender-dry notifications.
    sender_dry_event: u8,
}
/// Association status returned by the `SCTP_STATUS` socket option (see
/// `SctpSocket::status`), mirroring the C `struct sctp_status` (`#[repr(C)]`, so field
/// order and types must not change).
///
/// NOTE(review): the kernel declares the embedded `struct sctp_paddrinfo` as
/// `packed, aligned(4)`; `primary` (and the overall size) only match the kernel if
/// `PeerAddrInfo` reproduces that layout — confirm.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct Status {
    /// Association id (`sstat_assoc_id`).
    pub assoc_id: sctp_assoc_t,
    /// Association state (`sstat_state`).
    pub state: i32,
    /// Peer's receive window (`sstat_rwnd`).
    pub rwnd: u32,
    /// Unacknowledged data chunks (`sstat_unackdata`).
    pub unackdata: u16,
    /// Data chunks pending receipt (`sstat_penddata`).
    pub penddata: u16,
    /// Number of inbound streams (`sstat_instrms`).
    pub instrms: u16,
    /// Number of outbound streams (`sstat_outstrms`).
    pub outstrms: u16,
    /// Fragmentation point (`sstat_fragmentation_point`).
    pub fragmentation_point: u32,
    /// Information about the primary peer address (`sstat_primary`).
    pub primary: PeerAddrInfo,
}