#![deny(missing_docs)]
//! Set up Linux network block devices (NBD) through the kernel's `nbd`
//! generic netlink family.
//!
//! [`NBD`] owns the generic netlink socket together with the resolved family
//! id. [`NBDConnect`] collects the parameters of a connect request and sends
//! it over that socket.

// Everything in this file talks to the kernel over netlink, so refuse to
// build on other targets instead of failing at runtime.
#[cfg(not(target_os = "linux"))]
compile_error!("Netlink only works on Linux");
23
24use std::os::unix::io::AsRawFd;
25
26use anyhow::{anyhow, Context};
27use neli::{
28 consts::genl::NlAttrType,
29 consts::nl::{NlmF, NlmFFlags},
30 consts::socket::NlFamily,
31 err::SerError,
32 genl::{Genlmsghdr, Nlattr},
33 nl::{NlPayload, Nlmsghdr},
34 socket::NlSocketHandle,
35 types::{Buffer, GenlBuffer},
36 Size, ToBytes,
37};
38use neli_proc_macros::neli_enum;
39
40#[neli_enum(serialized_type = "u8")]
41enum NbdCmd {
42 Unspec = 0,
43 Connect = 1,
44 Disconnect = 2,
45 Reconfigure = 3,
46 LinkDead = 4,
47 Status = 5,
48}
49impl neli::consts::genl::Cmd for NbdCmd {}
50
51#[neli_enum(serialized_type = "u16")]
52enum NbdAttr {
53 Unspec = 0,
54 Index = 1,
55 SizeBytes = 2,
56 BlockSizeBytes = 3,
57 Timeout = 4,
58 ServerFlags = 5,
59 ClientFlags = 6,
60 Sockets = 7,
61 DeadConnTimeout = 8,
62 DeviceList = 9,
63}
64impl NlAttrType for NbdAttr {}
65
66#[neli_enum(serialized_type = "u16")]
67enum NbdSockItem {
68 Unspec = 0,
69 Item = 1,
70}
71impl NlAttrType for NbdSockItem {}
72
73#[neli_enum(serialized_type = "u16")]
74enum NbdSock {
75 Unspec = 0,
76 Fd = 1,
77}
78impl NlAttrType for NbdSock {}
79
// Bits of the server-flags word (`NBDConnect::server_flags`).
// Bit 0: set in every freshly created `NBDConnect`.
const HAS_FLAGS: u64 = 0x0001;
// Bit 1: toggled by `NBDConnect::read_only`.
const READ_ONLY: u64 = 0x0002;
// Bit 8: toggled by `NBDConnect::can_multi_conn`.
const CAN_MULTI_CONN: u64 = 0x0100;

// Bit of the client-flags word (`NBDConnect::client_flags`).
// Bit 1: toggled by `NBDConnect::disconnect_on_close`.
const NBD_CFLAG_DISCONNECT_ON_CLOSE: u64 = 0x0002;
85
86pub struct NBD {
88 nl: NlSocketHandle,
89 nbd_family: u16,
90}
91
92impl NBD {
93 pub fn new() -> anyhow::Result<Self> {
99 let mut nl = NlSocketHandle::new(NlFamily::Generic)?;
100 let nbd_family = nl
101 .resolve_genl_family("nbd")
102 .context("Could not resolve the NBD generic netlink family")?;
103 Ok(Self { nl, nbd_family })
104 }
105}
106
/// Parameters of an NBD connect request.
///
/// Create one with `NBDConnect::new`, adjust it with the chained setter
/// methods, then send it with `NBDConnect::connect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NBDConnect {
    /// Size of the device in bytes.
    size_bytes: u64,
    /// Block size of the device in bytes.
    block_size_bytes: u64,
    /// Bit set sent as the server-flags attribute.
    server_flags: u64,
    /// Bit set sent as the client-flags attribute.
    client_flags: u64,
}
114
115impl NBDConnect {
116 pub fn new() -> Self {
118 Self {
119 size_bytes: 0,
120 block_size_bytes: 4096,
121 server_flags: HAS_FLAGS,
122 client_flags: 0,
123 }
124 }
125
126 pub fn size_bytes(&mut self, bytes: u64) -> &mut Self {
128 self.size_bytes = bytes;
129 self
130 }
131
132 pub fn block_size(&mut self, bytes: u64) -> &mut Self {
134 self.block_size_bytes = bytes;
135 self
136 }
137
138 pub fn read_only(&mut self, read_only: bool) -> &mut Self {
140 if read_only {
141 self.server_flags |= READ_ONLY;
142 } else {
143 self.server_flags &= !READ_ONLY;
144 }
145 self
146 }
147
148 pub fn can_multi_conn(&mut self, can_multi_conn: bool) -> &mut Self {
150 if can_multi_conn {
151 self.server_flags |= CAN_MULTI_CONN;
152 } else {
153 self.server_flags &= !CAN_MULTI_CONN;
154 }
155 self
156 }
157
158 pub fn disconnect_on_close(&mut self, disconnect_on_close: bool) -> &mut Self {
160 if disconnect_on_close {
161 self.client_flags |= NBD_CFLAG_DISCONNECT_ON_CLOSE;
162 } else {
163 self.client_flags &= !NBD_CFLAG_DISCONNECT_ON_CLOSE;
164 }
165 self
166 }
167
168 pub fn connect<'a>(
172 &self,
173 nbd: &mut NBD,
174 sockets: impl IntoIterator<Item = &'a (impl AsRawFd + 'a)>,
175 ) -> anyhow::Result<u32> {
176 fn attr<T: NlAttrType>(
177 t: T,
178 p: impl Size + ToBytes,
179 ) -> Result<Nlattr<T, Buffer>, SerError> {
180 Nlattr::new(false, false, t, p)
181 }
182 let mut sockets_attr = Nlattr::new(true, false, NbdAttr::Sockets, Buffer::new())?;
183 for socket in sockets {
184 sockets_attr.add_nested_attribute(&Nlattr::new(
185 true,
186 false,
187 NbdSockItem::Item,
188 attr(NbdSock::Fd, socket.as_raw_fd())?,
189 )?)?;
190 }
191 let mut attrs = GenlBuffer::new();
192 attrs.push(attr(NbdAttr::SizeBytes, self.size_bytes)?);
193 attrs.push(attr(NbdAttr::BlockSizeBytes, self.block_size_bytes)?);
194 attrs.push(attr(NbdAttr::ServerFlags, self.server_flags)?);
195 attrs.push(attr(NbdAttr::ClientFlags, self.client_flags)?);
196 attrs.push(sockets_attr);
197
198 let genl_header = Genlmsghdr::new(NbdCmd::Connect, 1, attrs);
199 let nl_header = Nlmsghdr::new(
200 None,
201 nbd.nbd_family,
202 NlmFFlags::new(&[NlmF::Request]),
203 None,
204 None,
205 NlPayload::Payload(genl_header),
206 );
207 nbd.nl.send(nl_header)?;
208 let response: Nlmsghdr<u16, Genlmsghdr<NbdCmd, NbdAttr>> = nbd
209 .nl
210 .recv()?
211 .ok_or_else(|| anyhow!("Error connecting NBD device: No response received"))?;
212 let handle = response.get_payload()?.get_attr_handle();
213 let index = handle.get_attr_payload_as::<u32>(NbdAttr::Index)?;
214 Ok(index)
215 }
216}