//! Minimal client for the Linux kernel's `nbd` generic netlink family.
//!
//! [`NBD`] holds a generic netlink socket bound to the resolved `nbd`
//! family, and [`NBDConnect`] builds and sends the `Connect` command that
//! hands sockets to the kernel and returns the resulting device index.
#![deny(missing_docs)]
// Netlink sockets only exist on Linux, so refuse to build anywhere else
// instead of failing at runtime.
#[cfg(not(target_os = "linux"))]
compile_error!("Netlink only works on Linux");
use std::os::unix::io::AsRawFd;
use anyhow::{anyhow, Context};
use neli::{
consts::genl::NlAttrType,
consts::nl::{NlmF, NlmFFlags},
consts::socket::NlFamily,
err::SerError,
genl::{Genlmsghdr, Nlattr},
nl::{NlPayload, Nlmsghdr},
socket::NlSocketHandle,
types::{Buffer, GenlBuffer},
Size, ToBytes,
};
use neli_proc_macros::neli_enum;
/// Commands of the `nbd` generic netlink family, serialized as a `u8`.
///
/// The explicit discriminants are the values put on the wire, so they must
/// match the kernel's numbering (presumably `enum nbd_commands` in
/// `linux/nbd-netlink.h` — verify before adding variants).
#[neli_enum(serialized_type = "u8")]
enum NbdCmd {
    Unspec = 0,
    // The only command this file sends (see `NBDConnect::connect`).
    Connect = 1,
    Disconnect = 2,
    Reconfigure = 3,
    LinkDead = 4,
    Status = 5,
}
// Allows `NbdCmd` to be used as the command of a `Genlmsghdr`.
impl neli::consts::genl::Cmd for NbdCmd {}
/// Top-level attribute types of the `nbd` generic netlink family,
/// serialized as a `u16`.
///
/// The discriminants are on-wire values and must match the kernel's
/// numbering (presumably `NBD_ATTR_*` in `linux/nbd-netlink.h` — verify).
#[neli_enum(serialized_type = "u16")]
enum NbdAttr {
    Unspec = 0,
    // Read back from the kernel's reply to `Connect`.
    Index = 1,
    // The next four are sent as plain `u64` payloads by `Connect`.
    SizeBytes = 2,
    BlockSizeBytes = 3,
    Timeout = 4,
    ServerFlags = 5,
    ClientFlags = 6,
    // Nested container of `NbdSockItem::Item` attributes.
    Sockets = 7,
    DeadConnTimeout = 8,
    DeviceList = 9,
}
// Allows `NbdAttr` to be used as the type of an `Nlattr`.
impl NlAttrType for NbdAttr {}
/// Attribute types nested directly inside `NbdAttr::Sockets`, serialized as
/// a `u16`. Each socket handed to the kernel gets one `Item`.
#[neli_enum(serialized_type = "u16")]
enum NbdSockItem {
    Unspec = 0,
    // Nested container holding one `NbdSock::Fd` attribute.
    Item = 1,
}
// Allows `NbdSockItem` to be used as the type of an `Nlattr`.
impl NlAttrType for NbdSockItem {}
/// Attribute types nested inside a `NbdSockItem::Item`, serialized as a
/// `u16`.
#[neli_enum(serialized_type = "u16")]
enum NbdSock {
    Unspec = 0,
    // Carries the socket's raw file descriptor.
    Fd = 1,
}
// Allows `NbdSock` to be used as the type of an `Nlattr`.
impl NlAttrType for NbdSock {}
// Server flag bits, sent as `NbdAttr::ServerFlags`. The bit positions are
// intended to follow the NBD protocol's transmission flags (NBD_FLAG_*).
/// "Flags are valid" bit; a fresh `NBDConnect` always carries it.
const HAS_FLAGS: u64 = 0x0001;
/// Read-only export; toggled by `NBDConnect::read_only`.
const READ_ONLY: u64 = 0x0002;
/// Multiple connections permitted; toggled by `NBDConnect::can_multi_conn`.
const CAN_MULTI_CONN: u64 = 0x0100;

// Client flag bits, sent as `NbdAttr::ClientFlags`.
/// Disconnect-on-close request; toggled by `NBDConnect::disconnect_on_close`.
const NBD_CFLAG_DISCONNECT_ON_CLOSE: u64 = 0x0002;
/// Handle to the kernel's `nbd` generic netlink family.
///
/// Owns a generic netlink socket together with the numeric family id that
/// the name `"nbd"` resolved to, so NBD requests can be addressed to it.
// The all-caps public name is part of the existing API and is kept as is.
#[allow(clippy::upper_case_acronyms)]
pub struct NBD {
    /// Generic netlink socket used to send requests and receive replies.
    nl: NlSocketHandle,
    /// Family id resolved for the `"nbd"` generic netlink family.
    nbd_family: u16,
}

impl NBD {
    /// Opens a generic netlink socket and resolves the `"nbd"` family on it.
    ///
    /// # Errors
    ///
    /// Fails if the generic netlink socket cannot be created, or if the
    /// `"nbd"` family cannot be resolved (e.g. no such family is registered
    /// with the running kernel).
    pub fn new() -> anyhow::Result<Self> {
        let mut nl = NlSocketHandle::new(NlFamily::Generic)
            .context("Could not open a generic netlink socket")?;
        let nbd_family = nl
            .resolve_genl_family("nbd")
            .context("Could not resolve the NBD generic netlink family")?;
        Ok(Self { nl, nbd_family })
    }
}
/// Builder for the `Connect` command of the `nbd` generic netlink family.
///
/// Configure the device with the setter methods, then call
/// [`NBDConnect::connect`] to hand one or more sockets to the kernel.
///
/// A fresh builder (from [`NBDConnect::new`] or [`Default`]) requests a
/// 0-byte device with 4096-byte blocks, sets only the `HAS_FLAGS` server
/// flag, and sets no client flags.
pub struct NBDConnect {
    /// Device size in bytes, sent as `NbdAttr::SizeBytes`.
    size_bytes: u64,
    /// Block size in bytes, sent as `NbdAttr::BlockSizeBytes`.
    block_size_bytes: u64,
    /// Server flag bits, sent as `NbdAttr::ServerFlags`.
    server_flags: u64,
    /// Client flag bits, sent as `NbdAttr::ClientFlags`.
    client_flags: u64,
}

impl Default for NBDConnect {
    /// Same as [`NBDConnect::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl NBDConnect {
    /// Creates a builder with the defaults described on [`NBDConnect`].
    pub fn new() -> Self {
        Self {
            size_bytes: 0,
            block_size_bytes: 4096,
            server_flags: HAS_FLAGS,
            client_flags: 0,
        }
    }

    /// Sets the device size, in bytes.
    pub fn size_bytes(&mut self, bytes: u64) -> &mut Self {
        self.size_bytes = bytes;
        self
    }

    /// Sets the block size, in bytes.
    ///
    /// The value is passed to the kernel unchanged; it is not validated here.
    pub fn block_size(&mut self, bytes: u64) -> &mut Self {
        self.block_size_bytes = bytes;
        self
    }

    /// Sets (`true`) or clears (`false`) the read-only server flag.
    pub fn read_only(&mut self, read_only: bool) -> &mut Self {
        Self::set_flag(&mut self.server_flags, READ_ONLY, read_only);
        self
    }

    /// Sets (`true`) or clears (`false`) the server flag that permits
    /// multiple connections.
    pub fn can_multi_conn(&mut self, can_multi_conn: bool) -> &mut Self {
        Self::set_flag(&mut self.server_flags, CAN_MULTI_CONN, can_multi_conn);
        self
    }

    /// Sets (`true`) or clears (`false`) the disconnect-on-close client flag
    /// (`NBD_CFLAG_DISCONNECT_ON_CLOSE`).
    pub fn disconnect_on_close(&mut self, disconnect_on_close: bool) -> &mut Self {
        Self::set_flag(
            &mut self.client_flags,
            NBD_CFLAG_DISCONNECT_ON_CLOSE,
            disconnect_on_close,
        );
        self
    }

    /// Sets `mask` in `flags` when `enabled`, otherwise clears it.
    fn set_flag(flags: &mut u64, mask: u64, enabled: bool) {
        if enabled {
            *flags |= mask;
        } else {
            *flags &= !mask;
        }
    }

    /// Asks the kernel to set up an NBD device backed by `sockets`.
    ///
    /// Sends a `Connect` command to the family resolved in `nbd`. The command
    /// carries the configured size, block size, server flags and client
    /// flags, plus one nested socket item per element of `sockets`, each
    /// holding that socket's raw file descriptor. The sockets are only
    /// borrowed for the duration of this call.
    ///
    /// Returns the device index from the `NbdAttr::Index` attribute of the
    /// kernel's reply.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - an attribute cannot be serialized;
    /// - the request cannot be sent;
    /// - receiving or parsing the reply fails (NOTE(review): this presumably
    ///   includes a netlink error sent back by the kernel — depends on how
    ///   neli's `recv` surfaces error messages);
    /// - no reply arrives;
    /// - the reply lacks a device index.
    pub fn connect<'a>(
        &self,
        nbd: &mut NBD,
        sockets: impl IntoIterator<Item = &'a (impl AsRawFd + 'a)>,
    ) -> anyhow::Result<u32> {
        /// Builds a plain (non-nested) attribute of type `t` carrying `p`.
        fn attr<T: NlAttrType>(
            t: T,
            p: impl Size + ToBytes,
        ) -> Result<Nlattr<T, Buffer>, SerError> {
            Nlattr::new(false, false, t, p)
        }

        // Layout: Sockets { Item { Fd }, Item { Fd }, ... }.
        let mut sockets_attr = Nlattr::new(true, false, NbdAttr::Sockets, Buffer::new())?;
        for socket in sockets {
            let fd = attr(NbdSock::Fd, socket.as_raw_fd())?;
            sockets_attr.add_nested_attribute(&Nlattr::new(
                true,
                false,
                NbdSockItem::Item,
                fd,
            )?)?;
        }

        let mut attrs = GenlBuffer::new();
        attrs.push(attr(NbdAttr::SizeBytes, self.size_bytes)?);
        attrs.push(attr(NbdAttr::BlockSizeBytes, self.block_size_bytes)?);
        attrs.push(attr(NbdAttr::ServerFlags, self.server_flags)?);
        attrs.push(attr(NbdAttr::ClientFlags, self.client_flags)?);
        attrs.push(sockets_attr);

        // NOTE(review): version 1 is presumably the kernel's NBD_GENL_VERSION — confirm.
        let genl_header = Genlmsghdr::new(NbdCmd::Connect, 1, attrs);
        // No `Ack` flag is requested: exactly one reply message is read below.
        let nl_header = Nlmsghdr::new(
            None,
            nbd.nbd_family,
            NlmFFlags::new(&[NlmF::Request]),
            None,
            None,
            NlPayload::Payload(genl_header),
        );
        nbd.nl
            .send(nl_header)
            .context("Error connecting NBD device: failed to send request")?;

        let response: Nlmsghdr<u16, Genlmsghdr<NbdCmd, NbdAttr>> = nbd
            .nl
            .recv()
            .context("Error connecting NBD device: failed to receive reply")?
            .ok_or_else(|| anyhow!("Error connecting NBD device: No response received"))?;
        let handle = response
            .get_payload()
            .context("Error connecting NBD device: unexpected reply")?
            .get_attr_handle();
        let index = handle
            .get_attr_payload_as::<u32>(NbdAttr::Index)
            .context("Error connecting NBD device: reply has no device index")?;
        Ok(index)
    }
}