#![warn(missing_docs)]
use core::mem;
use std::fmt;
use std::io::Error;
use std::os::unix::io::RawFd;
use std::time::Duration;
use mctp::{
Eid, MsgIC, MsgType, Result, Tag, TagValue, MCTP_ADDR_ANY, MCTP_TAG_OWNER,
};
/// Linux socket address family number for MCTP (`AF_MCTP`).
// NOTE(review): hard-coded, presumably because the `libc` version in use
// does not export `AF_MCTP` — confirm before replacing with a libc constant.
const AF_MCTP: libc::sa_family_t = 45;
/// Rust mirror of the kernel's `struct sockaddr_mctp`.
///
/// Field order and types must match the C definition exactly (hence
/// `#[repr(C)]`); values of this type are handed by pointer to `bind`,
/// `sendto` and `recvfrom`.
#[repr(C)]
// Named after the C struct, hence the non-CamelCase type name.
#[allow(non_camel_case_types)]
struct sockaddr_mctp {
    /// Address family; addresses built in this file always use `AF_MCTP`.
    smctp_family: libc::sa_family_t,
    /// Explicit padding; always written as zero here.
    __smctp_pad0: u16,
    /// MCTP network ID.
    smctp_network: u32,
    /// Endpoint ID (EID).
    smctp_addr: u8,
    /// Message type byte: type plus integrity-check bit, as produced by
    /// `mctp::encode_type_ic` / consumed by `mctp::decode_type_ic`.
    smctp_type: u8,
    /// Tag byte: tag value combined with the tag-owner bit.
    smctp_tag: u8,
    /// Explicit padding; always written as zero here.
    __smctp_pad1: u8,
}
/// Network ID used when no specific MCTP network is selected.
pub const MCTP_NET_ANY: u32 = 0;

/// A Linux MCTP socket address (`struct sockaddr_mctp`).
pub struct MctpSockAddr(sockaddr_mctp);

impl MctpSockAddr {
    /// Builds an address from its raw components.
    ///
    /// * `eid` – endpoint ID
    /// * `net` – network ID ([`MCTP_NET_ANY`] for no specific network)
    /// * `typ` – message type byte (type plus integrity-check bit)
    /// * `tag` – tag byte (tag value plus tag-owner bit)
    pub fn new(eid: u8, net: u32, typ: u8, tag: u8) -> Self {
        let raw = sockaddr_mctp {
            smctp_family: AF_MCTP,
            smctp_network: net,
            smctp_addr: eid,
            smctp_type: typ,
            smctp_tag: tag,
            // Padding fields are always zeroed.
            __smctp_pad0: 0,
            __smctp_pad1: 0,
        };
        Self(raw)
    }

    /// An otherwise-zero address (family still `AF_MCTP`), used as the
    /// output buffer for `recvfrom`.
    fn zero() -> Self {
        Self::new(0, MCTP_NET_ANY, 0, 0)
    }

    /// Views the address as a read-only `sockaddr` pointer plus its length.
    fn as_raw(&self) -> (*const libc::sockaddr, libc::socklen_t) {
        let len = mem::size_of::<sockaddr_mctp>() as libc::socklen_t;
        let ptr: *const sockaddr_mctp = &self.0;
        (ptr.cast::<libc::sockaddr>(), len)
    }

    /// Views the address as a writable `sockaddr` pointer plus its length.
    fn as_raw_mut(&mut self) -> (*mut libc::sockaddr, libc::socklen_t) {
        let len = mem::size_of::<sockaddr_mctp>() as libc::socklen_t;
        let ptr: *mut sockaddr_mctp = &mut self.0;
        (ptr.cast::<libc::sockaddr>(), len)
    }
}
/// Formats the address fields for diagnostics.
///
/// Output shape:
/// `MctpSockAddr(family=.., net=.., addr=.., type=.., tag=..)`.
impl fmt::Debug for MctpSockAddr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The type name in the output must match the Rust type
        // (previously misspelled "McptSockAddr").
        write!(
            f,
            "MctpSockAddr(family={}, net={}, addr={}, type={}, tag={})",
            self.0.smctp_family,
            self.0.smctp_network,
            self.0.smctp_addr,
            self.0.smctp_type,
            self.0.smctp_tag
        )
    }
}
/// Decodes a `sockaddr_mctp` tag byte into a [`Tag`].
///
/// The tag-owner bit selects [`Tag::Owned`] or [`Tag::Unowned`]; all other
/// bits become the [`TagValue`].
fn tag_from_smctp(to: u8) -> Tag {
    let value = TagValue(to & !MCTP_TAG_OWNER);
    match to & MCTP_TAG_OWNER {
        0 => Tag::Unowned(value),
        _ => Tag::Owned(value),
    }
}
/// Encodes a [`Tag`] as a `sockaddr_mctp` tag byte: the tag value, with the
/// tag-owner bit set when the tag is owned.
fn tag_to_smctp(tag: &Tag) -> u8 {
    let value = tag.tag().0;
    match tag.is_owner() {
        true => value | MCTP_TAG_OWNER,
        false => value,
    }
}
fn last_os_error() -> mctp::Error {
mctp::Error::Io(Error::last_os_error())
}
/// An owned `AF_MCTP` socket file descriptor; the descriptor is closed when
/// this value is dropped.
pub struct MctpSocket(RawFd);

impl Drop for MctpSocket {
    fn drop(&mut self) {
        let fd = self.0;
        // SAFETY: the (private) field is only set from a successful socket(2)
        // call, so `fd` is a descriptor this value owns and nothing uses it
        // after drop. close(2) errors cannot be reported from `drop`, so the
        // return value is intentionally discarded.
        let _ = unsafe { libc::close(fd) };
    }
}
impl MctpSocket {
pub fn new() -> Result<Self> {
let rc = unsafe {
libc::socket(
AF_MCTP.into(),
libc::SOCK_DGRAM | libc::SOCK_CLOEXEC,
0,
)
};
if rc < 0 {
return Err(last_os_error());
}
Ok(MctpSocket(rc))
}
pub fn recvfrom(&self, buf: &mut [u8]) -> Result<(usize, MctpSockAddr)> {
let mut addr = MctpSockAddr::zero();
let (addr_ptr, mut addr_len) = addr.as_raw_mut();
let buf_ptr = buf.as_mut_ptr() as *mut libc::c_void;
let buf_len = buf.len() as libc::size_t;
let rc = unsafe {
libc::recvfrom(self.0, buf_ptr, buf_len, 0, addr_ptr, &mut addr_len)
};
if rc < 0 {
Err(last_os_error())
} else {
Ok((rc as usize, addr))
}
}
pub fn sendto(&self, buf: &[u8], addr: &MctpSockAddr) -> Result<usize> {
let (addr_ptr, addr_len) = addr.as_raw();
let buf_ptr = buf.as_ptr() as *const libc::c_void;
let buf_len = buf.len() as libc::size_t;
let rc = unsafe {
libc::sendto(self.0, buf_ptr, buf_len, 0, addr_ptr, addr_len)
};
if rc < 0 {
Err(last_os_error())
} else {
Ok(rc as usize)
}
}
pub fn bind(&self, addr: &MctpSockAddr) -> Result<()> {
let (addr_ptr, addr_len) = addr.as_raw();
let rc = unsafe { libc::bind(self.0, addr_ptr, addr_len) };
if rc < 0 {
Err(last_os_error())
} else {
Ok(())
}
}
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
#![allow(deprecated)]
let dur = dur.unwrap_or(Duration::ZERO);
let tv = libc::timeval {
tv_sec: dur.as_secs() as libc::time_t,
tv_usec: dur.subsec_micros() as libc::suseconds_t,
};
let rc = unsafe {
libc::setsockopt(
self.0,
libc::SOL_SOCKET,
libc::SO_RCVTIMEO,
(&tv as *const libc::timeval) as *const libc::c_void,
std::mem::size_of::<libc::timeval>() as libc::socklen_t,
)
};
if rc < 0 {
Err(last_os_error())
} else {
Ok(())
}
}
pub fn read_timeout(&self) -> Result<Option<Duration>> {
#![allow(deprecated)]
let mut tv = std::mem::MaybeUninit::<libc::timeval>::uninit();
let mut tv_len =
std::mem::size_of::<libc::timeval>() as libc::socklen_t;
let rc = unsafe {
libc::getsockopt(
self.0,
libc::SOL_SOCKET,
libc::SO_RCVTIMEO,
tv.as_mut_ptr() as *mut libc::c_void,
&mut tv_len as *mut libc::socklen_t,
)
};
if rc < 0 {
Err(last_os_error())
} else {
let tv = unsafe { tv.assume_init() };
if tv.tv_sec < 0 || tv.tv_usec < 0 {
return Err(mctp::Error::Other);
}
if tv.tv_sec == 0 && tv.tv_usec == 0 {
Ok(None)
} else {
Ok(Some(
Duration::from_secs(tv.tv_sec as u64)
+ Duration::from_micros(tv.tv_usec as u64),
))
}
}
}
}
impl std::os::fd::AsRawFd for MctpSocket {
fn as_raw_fd(&self) -> RawFd {
self.0
}
}
/// A request channel to a single remote MCTP endpoint over a Linux socket.
pub struct MctpLinuxReq {
    /// Remote endpoint that requests go to and responses must come from.
    eid: Eid,
    /// Network ID; [`MCTP_NET_ANY`] when none was specified.
    net: u32,
    /// Socket used for both sending and receiving (not bound by this type).
    sock: MctpSocket,
    /// Set once a request has been sent; `recv` is refused before that.
    sent: bool,
}

impl MctpLinuxReq {
    /// Creates a request channel to `eid` on network `net`, or on
    /// [`MCTP_NET_ANY`] when `net` is `None`.
    ///
    /// # Errors
    ///
    /// Fails if the underlying socket cannot be created.
    pub fn new(eid: Eid, net: Option<u32>) -> Result<Self> {
        let sock = MctpSocket::new()?;
        Ok(Self {
            eid,
            net: net.unwrap_or(MCTP_NET_ANY),
            sock,
            sent: false,
        })
    }

    /// Mutable access to the underlying socket (e.g. to set a read timeout).
    pub fn as_socket(&mut self) -> &mut MctpSocket {
        &mut self.sock
    }

    /// The network ID, or `None` when it is [`MCTP_NET_ANY`].
    pub fn net(&self) -> Option<u32> {
        (self.net != MCTP_NET_ANY).then_some(self.net)
    }
}
impl mctp::ReqChannel for MctpLinuxReq {
    /// Sends a request of type `typ` to the remote endpoint.
    ///
    /// The fragments in `bufs` are joined and sent as one datagram. The tag
    /// byte carries only the tag-owner bit.
    /// NOTE(review): the tag value is left as zero, presumably so the kernel
    /// allocates one — confirm against the kernel MCTP docs.
    fn send_vectored(
        &mut self,
        typ: MsgType,
        ic: MsgIC,
        bufs: &[&[u8]],
    ) -> Result<()> {
        let typ_ic = mctp::encode_type_ic(typ, ic);
        let addr = MctpSockAddr::new(
            self.eid.0,
            self.net,
            typ_ic,
            mctp::MCTP_TAG_OWNER,
        );
        // std `concat` sizes the output once, unlike `flat_map().collect()`,
        // which loses the size hint and regrows the Vec repeatedly.
        let msg = bufs.concat();
        self.sock.sendto(&msg, &addr)?;
        self.sent = true;
        Ok(())
    }

    /// Receives a response into `buf` and returns its type, IC bit and
    /// payload (a prefix of `buf`).
    ///
    /// # Errors
    ///
    /// * `mctp::Error::BadArgument` if no request has been sent yet.
    /// * Socket errors from `recvfrom`.
    /// * `mctp::Error::Other` if the message came from an EID other than
    ///   this channel's remote endpoint.
    fn recv<'f>(
        &mut self,
        buf: &'f mut [u8],
    ) -> Result<(MsgType, MsgIC, &'f mut [u8])> {
        if !self.sent {
            return Err(mctp::Error::BadArgument);
        }
        let (sz, addr) = self.sock.recvfrom(buf)?;
        let src = Eid(addr.0.smctp_addr);
        let (typ, ic) = mctp::decode_type_ic(addr.0.smctp_type);
        if src != self.eid {
            return Err(mctp::Error::Other);
        }
        Ok((typ, ic, &mut buf[..sz]))
    }

    /// The remote endpoint this channel talks to.
    fn remote_eid(&self) -> Eid {
        self.eid
    }
}
/// A listener that receives MCTP messages of one message type.
pub struct MctpLinuxListener {
    /// Socket bound to the listened-for message type.
    sock: MctpSocket,
    /// Network ID; [`MCTP_NET_ANY`] when none was specified.
    net: u32,
    /// Message type this listener is bound to.
    typ: MsgType,
}

impl MctpLinuxListener {
    /// Creates a listener for message type `typ` on network `net` (or
    /// [`MCTP_NET_ANY`] when `None`).
    ///
    /// The socket is bound with `MCTP_ADDR_ANY` as the local EID and the
    /// tag-owner bit set in the bind address.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be created or bound.
    pub fn new(typ: MsgType, net: Option<u32>) -> Result<Self> {
        let sock = MctpSocket::new()?;
        let net = net.unwrap_or(MCTP_NET_ANY);
        let local =
            MctpSockAddr::new(MCTP_ADDR_ANY.0, net, typ.0, mctp::MCTP_TAG_OWNER);
        sock.bind(&local)?;
        Ok(Self { sock, net, typ })
    }

    /// Mutable access to the underlying socket (e.g. to set a read timeout).
    pub fn as_socket(&mut self) -> &mut MctpSocket {
        &mut self.sock
    }

    /// The network ID, or `None` when it is [`MCTP_NET_ANY`].
    pub fn net(&self) -> Option<u32> {
        (self.net != MCTP_NET_ANY).then_some(self.net)
    }
}
impl mctp::Listener for MctpLinuxListener {
    type RespChannel<'a> = MctpLinuxResp<'a>;

    /// Waits for a message and returns its type, IC bit, payload (a prefix
    /// of `buf`) and a response channel back to the sender.
    ///
    /// # Errors
    ///
    /// Socket errors are returned as-is. A message whose tag lacks the
    /// tag-owner bit, or whose type differs from the listener's type, yields
    /// `mctp::Error::InternalError`.
    fn recv<'f>(
        &mut self,
        buf: &'f mut [u8],
    ) -> Result<(MsgType, MsgIC, &'f mut [u8], MctpLinuxResp<'_>)> {
        let (len, from) = self.sock.recvfrom(buf)?;
        let (typ, ic) = mctp::decode_type_ic(from.0.smctp_type);
        let tag = tag_from_smctp(from.0.smctp_tag);
        // Both conditions contradict the bind address (owned tag, `self.typ`).
        if matches!(tag, Tag::Unowned(_)) || typ != self.typ {
            return Err(mctp::Error::InternalError);
        }
        let resp = MctpLinuxResp {
            eid: Eid(from.0.smctp_addr),
            tv: tag.tag(),
            listener: self,
            typ,
        };
        Ok((typ, ic, &mut buf[..len], resp))
    }
}
/// A response channel for one received message, replying through the
/// listener's socket.
pub struct MctpLinuxResp<'a> {
    /// EID of the sender of the received message.
    eid: Eid,
    /// Tag value of the received message, reused (unowned) for replies.
    tv: TagValue,
    /// Listener whose socket and network ID are used to reply.
    listener: &'a MctpLinuxListener,
    /// Message type of the received message; replies use the same type.
    typ: MsgType,
}

impl mctp::RespChannel for MctpLinuxResp<'_> {
    type ReqChannel = MctpLinuxReq;

    /// Sends a reply, joining the fragments in `bufs` into one datagram.
    ///
    /// The reply echoes the received tag value with the tag-owner bit
    /// clear.
    fn send_vectored(&mut self, ic: MsgIC, bufs: &[&[u8]]) -> Result<()> {
        let typ_ic = mctp::encode_type_ic(self.typ, ic);
        let tag = tag_to_smctp(&Tag::Unowned(self.tv));
        let addr =
            MctpSockAddr::new(self.eid.0, self.listener.net, typ_ic, tag);
        // std `concat` sizes the output once, unlike `flat_map().collect()`.
        let msg = bufs.concat();
        self.listener.sock.sendto(&msg, &addr)?;
        Ok(())
    }

    /// The EID of the peer this channel replies to.
    fn remote_eid(&self) -> Eid {
        self.eid
    }

    /// Opens a new request channel to the same peer on the listener's
    /// network.
    fn req_channel(&self) -> Result<Self::ReqChannel> {
        MctpLinuxReq::new(self.eid, Some(self.listener.net))
    }
}
/// An MCTP endpoint address: an EID plus an optional network ID.
///
/// Parsed from `"<eid>"` or `"<net>,<eid>"` via [`str::parse`].
#[derive(Debug)]
pub struct MctpAddr {
    eid: Eid,
    net: Option<u32>,
}

impl std::str::FromStr for MctpAddr {
    type Err = String;

    /// Parses `"<eid>"` or `"<net>,<eid>"`.
    ///
    /// The EID is decimal, or hexadecimal with a `0x`/`0X` prefix. The
    /// network ID is decimal. Any other shape, including three or more
    /// comma-separated fields, is rejected.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message describing the failure.
    fn from_str(s: &str) -> std::result::Result<MctpAddr, String> {
        let mut fields = s.split(',');
        let (net_str, eid_str) =
            match (fields.next(), fields.next(), fields.next()) {
                (Some(net), Some(eid), None) => (Some(net), eid),
                (Some(eid), None, _) => (None, eid),
                // Extra fields (e.g. "1,8,9") are an error rather than being
                // silently ignored.
                _ => return Err("invalid MCTP address format".to_string()),
            };

        // Prefix check without allocating a lowercased copy of the input.
        let eid = match eid_str
            .strip_prefix("0x")
            .or_else(|| eid_str.strip_prefix("0X"))
        {
            Some(hex) => u8::from_str_radix(hex, 16),
            None => eid_str.parse::<u8>(),
        }
        .map_err(|e| e.to_string())?;

        let net = net_str
            .map(|n| n.parse::<u32>().map_err(|e| e.to_string()))
            .transpose()?;

        Ok(MctpAddr {
            net,
            eid: Eid(eid),
        })
    }
}
impl MctpAddr {
    /// The endpoint ID.
    pub fn eid(&self) -> Eid {
        self.eid
    }

    /// The network ID, or [`MCTP_NET_ANY`] if none was given.
    pub fn net(&self) -> u32 {
        match self.net {
            Some(n) => n,
            None => MCTP_NET_ANY,
        }
    }

    /// Opens a request channel to this address.
    ///
    /// # Errors
    ///
    /// Fails if the underlying socket cannot be created.
    pub fn create_endpoint(&self) -> Result<MctpLinuxReq> {
        let (eid, net) = (self.eid, self.net);
        MctpLinuxReq::new(eid, net)
    }

    /// Creates a listener for message type `typ` on this address's network.
    /// Only the network ID is used; the EID is ignored.
    ///
    /// # Errors
    ///
    /// Fails if the socket cannot be created or bound.
    pub fn create_listener(&self, typ: MsgType) -> Result<MctpLinuxListener> {
        let net = self.net;
        MctpLinuxListener::new(typ, net)
    }
}