use bon::Builder;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use std::{io, vec};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream};
use tokio::signal;
use tokio::sync::broadcast;
use tokio::time::sleep;
use tracing::{debug, error, info, instrument, trace, warn};
use crate::command_request::CommandRequest;
use crate::device::NbdDriver;
use crate::errors::{OptionReplyError, ProtocolError};
use crate::flags::{CommandFlags, HandshakeFlags, TransmissionFlags};
use crate::io::command_reply::SimpleReplyRaw;
use crate::io::command_request::CommandRequestRaw;
use crate::io::option_reply::OptionReplyRaw;
use crate::io::option_request::OptionRequestRaw;
use crate::magic::{NBD_IHAVEOPT, NBD_MAGIC};
use crate::option_reply::{InfoPayload, OptionReply};
use crate::option_request::OptionRequest;
/// The export chosen by the client when option negotiation ends, together with the
/// properties that were computed for it at selection time.
struct SelectedDevice<'a, T: NbdDriver + 'a> {
    /// Driver backing the selected export, borrowed from the server's device list.
    device: &'a T,
    /// Whether the export is read-only.
    read_only: bool,
    /// Export size in bytes, read from the driver when the export was selected.
    size: u64,
    /// Export name as requested by the client (empty when the first export was
    /// selected by default).
    name: String,
}
/// What option negotiation should do after the replies for one option were sent.
enum OptionReplyFinalize<'a, T: NbdDriver + 'a> {
    /// The client asked to abort; negotiation ends without an export.
    Abort,
    /// Keep reading option requests from the client.
    Continue,
    /// An export was selected; proceed to command handling with it.
    End(SelectedDevice<'a, T>),
}
/// Configuration for an NBD server. The `bon`-derived builder fills in these
/// fields; `listen` then serves clients with them.
#[derive(Builder)]
pub struct NbdServerBuilder<'a, T>
where
    T: NbdDriver + 'static,
{
    /// Exports to serve. The builder setter takes a `Vec<T>` and wraps it in an
    /// `Arc` so the accept loop and every session can share the same list.
    #[builder(with = |devices: Vec<T>| Arc::new(devices))]
    devices: Arc<Vec<T>>,
    /// Host name or address the listener binds to.
    host: &'a str,
    /// TCP port to listen on; `listen` falls back to 10809 when unset.
    port: Option<u16>,
    /// Seconds `listen` waits for open sessions to finish during graceful
    /// shutdown; `listen` falls back to 60 when unset.
    shutdown_timeout: Option<u64>,
}
impl<'a, T> NbdServerBuilder<'a, T>
where
    T: NbdDriver + Send + Sync + 'static,
{
    /// Serves NBD clients until a shutdown is requested, then drains open sessions.
    ///
    /// Every accepted connection runs as its own task. Ctrl+C, SIGTERM or SIGHUP — or
    /// the accept loop stopping on its own — broadcasts a shutdown to all sessions. The
    /// call then waits until every session has finished or `shutdown_timeout` seconds
    /// (default 60) have elapsed. The port defaults to 10809.
    ///
    /// # Errors
    /// Returns an error if the signal handlers cannot be installed, the listener cannot
    /// be bound, or the accept-loop task panicked.
    #[instrument(name = "nbd_server_listen", skip(self))]
    pub async fn listen(&self) -> std::io::Result<()> {
        const DEFAULT_PORT: u16 = 10809;
        const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 60;

        /// Counts one live session in the shared connection counter.
        ///
        /// The slot is released in `Drop`, so it is freed even when the session task
        /// panics: tokio drops the task's future after catching the panic, which drops
        /// this guard with it.
        struct ConnectionGuard(Arc<AtomicUsize>);

        impl ConnectionGuard {
            fn acquire(counter: &Arc<AtomicUsize>) -> Self {
                counter.fetch_add(1, Ordering::SeqCst);
                Self(Arc::clone(counter))
            }
        }

        impl Drop for ConnectionGuard {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::SeqCst);
            }
        }

        let port = self.port.unwrap_or(DEFAULT_PORT);
        let shutdown_timeout = self
            .shutdown_timeout
            .unwrap_or(DEFAULT_SHUTDOWN_TIMEOUT_SECS);

        // Install the signal handlers before anything is spawned. A failure is then
        // reported to the caller instead of panicking, and no accept loop is left
        // running detached.
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
        let mut sighup = signal::unix::signal(signal::unix::SignalKind::hangup())?;

        // Binding to (host, port) rather than a formatted "host:port" string also works
        // for bare IPv6 literals such as "::1".
        let listener = TcpListener::bind((self.host, port)).await?;
        let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
        let active_connections = Arc::new(AtomicUsize::new(0));
        info!("NBD server starting on {}:{}", self.host, port);

        let accept_loop = tokio::spawn({
            let devices = Arc::clone(&self.devices);
            let active_connections = Arc::clone(&active_connections);
            let mut stop_rx = shutdown_rx.resubscribe();
            async move {
                loop {
                    tokio::select! {
                        accept_result = listener.accept() => {
                            match accept_result {
                                Ok((stream, addr)) => {
                                    info!("NBD client connected from {}", addr);
                                    let guard = ConnectionGuard::acquire(&active_connections);
                                    let server = NbdServer {
                                        devices: Arc::clone(&devices),
                                        shutdown_rx: shutdown_rx.resubscribe(),
                                    };
                                    tokio::spawn(async move {
                                        // Keep the slot for the whole session.
                                        let _guard = guard;
                                        if let Err(e) = server.start(stream).await {
                                            error!("Error in NBD server session: {:?}", e);
                                        }
                                    });
                                }
                                Err(e) => {
                                    error!("Failed to accept connection: {}", e);
                                    // A client aborting mid-accept is harmless; any other
                                    // accept error stops the loop.
                                    if e.kind() != io::ErrorKind::ConnectionAborted {
                                        return;
                                    }
                                }
                            }
                        }
                        _ = stop_rx.recv() => {
                            info!("Received shutdown signal, stopping accept loop");
                            break;
                        }
                    }
                }
            }
        });

        tokio::select! {
            result = accept_loop => {
                warn!("Accept loop stopped unexpectedly");
                // Sessions only stop when told to. Without this broadcast, the drain
                // below would always run until the timeout.
                let _ = shutdown_tx.send(());
                result?;
            }
            _ = tokio::signal::ctrl_c() => {
                info!("Received Ctrl+C signal, initiating graceful shutdown");
                let _ = shutdown_tx.send(());
            }
            _ = sigterm.recv() => {
                info!("Received SIGTERM signal, initiating graceful shutdown");
                let _ = shutdown_tx.send(());
            }
            _ = sighup.recv() => {
                info!("Received SIGHUP signal, initiating graceful shutdown");
                let _ = shutdown_tx.send(());
            }
        };

        info!("Starting graceful shutdown");
        let shutdown_deadline = tokio::time::Instant::now() + Duration::from_secs(shutdown_timeout);
        loop {
            let remaining = active_connections.load(Ordering::SeqCst);
            if remaining == 0 {
                info!("All connections closed, shutdown complete");
                break;
            }
            if tokio::time::Instant::now() >= shutdown_deadline {
                warn!(
                    "Shutdown timeout reached with {} connections still active",
                    remaining
                );
                break;
            }
            debug!("Waiting for {} active connections to close...", remaining);
            sleep(Duration::from_secs(1)).await;
        }
        info!("Server shutdown complete");
        Ok(())
    }
}
/// One client session: runs the handshake, option negotiation and command phases
/// over a single connection.
pub struct NbdServer<T: NbdDriver> {
    /// Exports offered to the client, shared with the accept loop and other sessions.
    devices: Arc<Vec<T>>,
    /// Receiver for the server-wide shutdown broadcast.
    shutdown_rx: broadcast::Receiver<()>,
}
impl<T> NbdServer<T>
where
    T: NbdDriver + Send + Sync + 'static,
{
    /// Returns the names of all configured exports, in configuration order.
    ///
    /// # Errors
    /// `OptionReplyError::UnknownExport` when no exports are configured.
    async fn list_devices(&self) -> Result<Vec<String>, OptionReplyError> {
        if self.devices.is_empty() {
            return Err(OptionReplyError::UnknownExport);
        }
        Ok(self.devices.iter().map(|device| device.get_name()).collect())
    }

    /// Looks up an export by name. An empty name selects the first configured export.
    fn get_device(&self, device_name: &str) -> Option<&T> {
        if device_name.is_empty() {
            return self.devices.first();
        }
        self.devices.iter().find(|d| d.get_name() == device_name)
    }

    /// Completes once the server has been told to shut down, or once the shutdown
    /// sender has been dropped.
    ///
    /// `rx` must come from `self.shutdown_rx.resubscribe()`. A resubscribed receiver
    /// only sees messages sent after it was created. The stored `self.shutdown_rx` is
    /// never read from, so a signal sent at any point since this session was accepted
    /// is still queued on it. It is therefore checked first; this catches a shutdown
    /// that arrived before `rx` existed (e.g. during the handshake).
    async fn shutdown_requested(&self, rx: &mut tokio::sync::broadcast::Receiver<()>) {
        if self.shutdown_rx.is_empty() {
            // Any outcome (a message, or the sender closing) means the session stops.
            let _ = rx.recv().await;
        }
    }

    /// Runs one client session over `stream`: handshake, option negotiation, then
    /// command processing until the client disconnects or the server shuts down.
    ///
    /// # Errors
    /// Returns an error if no exports are configured, on I/O or protocol failures,
    /// when the client aborts negotiation, or when the server is shutting down.
    #[instrument(name = "nbd_server_session", skip(self, stream))]
    pub async fn start(&self, stream: TcpStream) -> std::io::Result<()> {
        if self.devices.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "No devices available for NBD server",
            ));
        }
        let (reader, writer) = stream.into_split();
        let mut reader = BufReader::new(reader);
        let mut writer = BufWriter::new(writer);

        debug!("Starting handshake");
        // The handshake blocks on the client, so it must also yield to a shutdown.
        // Otherwise a stalled client holds up graceful shutdown until the timeout.
        let mut shutdown_rx = self.shutdown_rx.resubscribe();
        tokio::select! {
            result = self.handle_handshake(&mut reader, &mut writer) => result?,
            _ = self.shutdown_requested(&mut shutdown_rx) => {
                info!("Server is shutting down, aborting handshake");
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Server is shutting down during handshake",
                ));
            }
        }

        debug!("Starting options negotiation");
        let selected_device = self.handle_options(&mut reader, &mut writer).await?;
        debug!("Starting command handling");
        self.handle_commands(
            selected_device.device,
            &mut reader,
            &mut writer,
            selected_device.read_only,
            selected_device.size,
        )
        .await
    }

    /// Sends the server greeting and validates the client's flags.
    ///
    /// Writes `NBD_MAGIC`, `NBD_IHAVEOPT` and the default handshake flags, then reads
    /// the client's 32-bit flags word. The word must fit in 16 bits and be accepted by
    /// `HandshakeFlags::from_bits`. Both `FIXED_NEWSTYLE` and `NO_ZEROES` must be set.
    ///
    /// # Errors
    /// `InvalidData` when the client flags are rejected; I/O errors from the stream.
    #[instrument(name = "nbd_handshake", skip(self, reader, writer))]
    async fn handle_handshake<R, W>(&self, reader: &mut R, writer: &mut W) -> std::io::Result<()>
    where
        R: AsyncReadExt + Unpin,
        W: AsyncWrite + Unpin,
    {
        writer.write_all(&NBD_MAGIC.to_be_bytes()).await?;
        writer.write_all(&NBD_IHAVEOPT.to_be_bytes()).await?;
        writer
            .write_all(&HandshakeFlags::default().bits().to_be_bytes())
            .await?;
        writer.flush().await?;
        let client_flags = reader.read_u32().await?;
        let Ok(client_flags) = u16::try_from(client_flags) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid client flags",
            ));
        };
        let client_flags = HandshakeFlags::from_bits(client_flags)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid client flags"))?;
        if !client_flags.contains(HandshakeFlags::FIXED_NEWSTYLE) {
            error!("Client did not send FIXED_NEWSTYLE flag, which is required");
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Client did not send FIXED_NEWSTYLE flag, which is required",
            ));
        }
        if !client_flags.contains(HandshakeFlags::NO_ZEROES) {
            error!("Client did not send NO_ZEROES flag, which is required");
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Client did not send NO_ZEROES flag, which is required",
            ));
        }
        Ok(())
    }

    /// Handles one parsed option request.
    ///
    /// Returns the replies to send, in order, plus what negotiation should do next.
    /// `Info` and `Go` produce the same replies; `Go` and `ExportName` additionally
    /// end negotiation with the selected export.
    ///
    /// # Errors
    /// An `OptionReplyError` to report to the client for this option, e.g.
    /// `UnknownExport` for an unknown export name.
    #[instrument(name = "nbd_process_option", skip(self), fields(option_type = ?request))]
    async fn handle_option_request(
        &self,
        request: &OptionRequest,
    ) -> Result<(Vec<OptionReply>, OptionReplyFinalize<T>), OptionReplyError> {
        let mut responses: Vec<OptionReply> = Vec::new();
        match request {
            OptionRequest::Abort => {
                responses.push(OptionReply::Ack);
                return Ok((responses, OptionReplyFinalize::Abort));
            }
            OptionRequest::List => {
                for device in self.list_devices().await? {
                    responses.push(OptionReply::Server(device));
                }
            }
            // NOTE(review): the `unimplemented!()` arms in this match panic the session
            // task when a client sends one of these options. Replying with an
            // "unsupported" option error would be friendlier. TODO: do this once a
            // matching `OptionReplyError` variant is confirmed.
            OptionRequest::StartTLS => unimplemented!(),
            OptionRequest::Info(name, _info_requests) | OptionRequest::Go(name, _info_requests) => {
                let Some(device) = self.get_device(name) else {
                    return Err(OptionReplyError::UnknownExport);
                };
                let mut flags: TransmissionFlags = device.get_features().into();
                let read_only = device.get_read_only().await?;
                let size = device.get_device_size().load(Ordering::Acquire);
                if read_only {
                    flags.insert(TransmissionFlags::READ_ONLY);
                }
                responses.push(OptionReply::Info(InfoPayload::Export(size, flags)));
                responses.push(OptionReply::Info(InfoPayload::Name(name.clone())));
                responses.push(OptionReply::Info(InfoPayload::Description(
                    device.get_description().await?,
                )));
                let (min, optimal, max) = device.get_block_size().await?;
                responses.push(OptionReply::Info(InfoPayload::BlockSize(min, optimal, max)));
                responses.push(OptionReply::Ack);
                if matches!(request, OptionRequest::Go(..)) {
                    return Ok((
                        responses,
                        OptionReplyFinalize::End(SelectedDevice {
                            device,
                            read_only,
                            size,
                            name: name.clone(),
                        }),
                    ));
                }
            }
            OptionRequest::StructuredReply => unimplemented!(),
            OptionRequest::ListMetaContext => unimplemented!(),
            OptionRequest::SetMetaContext(_) => unimplemented!(),
            OptionRequest::ExtendedHeaders(_) => unimplemented!(),
            OptionRequest::ExportName(name) => {
                let Some(device) = self.get_device(name) else {
                    return Err(OptionReplyError::UnknownExport);
                };
                // No option replies are sent for ExportName; the export is selected directly.
                return Ok((
                    vec![],
                    OptionReplyFinalize::End(SelectedDevice {
                        device,
                        read_only: device.get_read_only().await?,
                        size: device.get_device_size().load(Ordering::Acquire),
                        name: name.clone(),
                    }),
                ));
            }
        }
        Ok((responses, OptionReplyFinalize::Continue))
    }

    /// Sends `error` as the reply to `option` and flushes it.
    ///
    /// # Errors
    /// Returns an error (ending the session) when `error` is `OptionReplyError::Shutdown`,
    /// as well as on I/O failure; every other option error lets negotiation continue.
    async fn write_option_reply_error<W>(
        &self,
        writer: &mut W,
        option: u32,
        error: OptionReplyError,
    ) -> std::io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let reply = OptionReplyRaw::new(option, error.into(), error.to_string().into_bytes());
        reply.write(writer).await?;
        writer.flush().await?;
        if error == OptionReplyError::Shutdown {
            Err(io::Error::new(
                io::ErrorKind::Other,
                "Server is shutting down",
            ))
        } else {
            Ok(())
        }
    }

    /// Runs the option-negotiation loop until an export is selected.
    ///
    /// Requests that fail to parse or to be handled are answered with an option error
    /// reply, and negotiation continues.
    ///
    /// # Errors
    /// Returns an error on I/O failure, when the client sends an abort, when an option
    /// error is `Shutdown`, or when the server shuts down.
    #[instrument(name = "nbd_options_negotiation", skip(self, reader, writer))]
    async fn handle_options<R, W>(
        &self,
        reader: &mut R,
        writer: &mut W,
    ) -> std::io::Result<SelectedDevice<T>>
    where
        R: AsyncReadExt + Unpin,
        W: AsyncWrite + Unpin,
    {
        let mut shutdown_rx = self.shutdown_rx.resubscribe();
        loop {
            let request_raw = tokio::select! {
                request_result = OptionRequestRaw::read(reader) => request_result?,
                _ = self.shutdown_requested(&mut shutdown_rx) => {
                    info!("Server is shutting down, aborting option negotiation");
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Server is shutting down during option negotiation"
                    ));
                }
            };
            trace!("Received option request, raw");
            let request = match OptionRequest::try_from(&request_raw) {
                Err(e) => {
                    self.write_option_reply_error(writer, request_raw.option, e)
                        .await?;
                    continue;
                }
                Ok(req) => req,
            };
            debug!("Parsed option request: {:?}", &request);
            match self.handle_option_request(&request).await {
                Err(e) => {
                    self.write_option_reply_error(writer, request_raw.option, e)
                        .await?;
                    continue;
                }
                Ok((responses, finalize)) => {
                    for response in responses {
                        trace!("Writing option reply: {:?}", &response);
                        let response_raw = OptionReplyRaw::new(
                            request_raw.option,
                            response.get_reply_type().into(),
                            response.get_data(),
                        );
                        response_raw.write(writer).await?;
                    }
                    writer.flush().await?;
                    match finalize {
                        OptionReplyFinalize::Abort => {
                            info!("Aborting option negotiation");
                            return Err(io::Error::new(
                                io::ErrorKind::Other,
                                "Abort request received",
                            ));
                        }
                        OptionReplyFinalize::Continue => {
                            debug!("Continuing option negotiation");
                        }
                        OptionReplyFinalize::End(selected_device) => {
                            info!(
                                "Ending option negotiation with selected device: {}",
                                &selected_device.name
                            );
                            return Ok(selected_device);
                        }
                    }
                }
            }
        }
    }

    /// Checks that the byte range a command touches lies within the export.
    ///
    /// Offsets and lengths come straight from the client. The end is therefore computed
    /// with `checked_add`: a plain `+` would panic in debug builds, and in release builds
    /// it would wrap, letting an out-of-range request pass. Commands without a byte
    /// range always pass.
    fn bounds_check(&self, command: &CommandRequest, device_size: u64) -> bool {
        let command_end = match command {
            CommandRequest::Read(offset, length)
            | CommandRequest::Trim(offset, length)
            | CommandRequest::WriteZeroes(offset, length)
            | CommandRequest::Cache(offset, length)
            | CommandRequest::BlockStatus(offset, length) => offset.checked_add(*length as u64),
            CommandRequest::Write(offset, data) => offset.checked_add(data.len() as u64),
            _ => return true,
        };
        matches!(command_end, Some(end) if end <= device_size)
    }

    /// Decodes a raw command and validates it against the selected export.
    ///
    /// # Errors
    /// - `InvalidArgument` for unknown flags or a malformed command.
    /// - `CommandNotPermitted` for a write command on a read-only export.
    /// - `ValueTooLarge` when the command's range extends past `device_size`.
    fn parse_command(
        &self,
        command_raw: &CommandRequestRaw,
        read_only: bool,
        device_size: u64,
    ) -> Result<(CommandFlags, CommandRequest), ProtocolError> {
        let flags = CommandFlags::try_from(command_raw.flags)
            .map_err(|_| ProtocolError::InvalidArgument)?;
        let command =
            CommandRequest::try_from(command_raw).map_err(|_| ProtocolError::InvalidArgument)?;
        if read_only && command.is_write_command() {
            Err(ProtocolError::CommandNotPermitted)
        } else if !self.bounds_check(&command, device_size) {
            Err(ProtocolError::ValueTooLarge)
        } else {
            Ok((flags, command))
        }
    }

    /// Returns true for zero-length data commands. These are acknowledged without
    /// calling the driver.
    fn check_noop(&self, command: &CommandRequest) -> bool {
        match command {
            CommandRequest::Read(_, 0)
            | CommandRequest::Cache(_, 0)
            | CommandRequest::Trim(_, 0)
            | CommandRequest::WriteZeroes(_, 0) => true,
            CommandRequest::Write(_, data) if data.is_empty() => true,
            _ => false,
        }
    }

    /// Serves transmission-phase commands for `device` until the client disconnects.
    ///
    /// Every command other than a disconnect gets a simple reply. A successful command
    /// gets error code 0. Invalid commands and driver failures get the corresponding
    /// `ProtocolError` code.
    ///
    /// # Errors
    /// Returns an error on I/O failure, when the driver fails a disconnect, when the
    /// server shuts down, or after replying to a driver `ServerShuttingDown` error.
    #[instrument(name = "nbd_command_handling", skip(self, device, reader, writer), fields(device_name = %device.get_name()))]
    async fn handle_commands<R, W>(
        &self,
        device: &T,
        reader: &mut R,
        writer: &mut W,
        read_only: bool,
        device_size: u64,
    ) -> io::Result<()>
    where
        R: AsyncReadExt + Unpin,
        W: tokio::io::AsyncWrite + Unpin,
    {
        let mut shutdown_rx = self.shutdown_rx.resubscribe();
        loop {
            let command_raw = tokio::select! {
                cmd_result = CommandRequestRaw::read(reader) => cmd_result?,
                _ = self.shutdown_requested(&mut shutdown_rx) => {
                    info!("Server is shutting down, aborting command processing");
                    // Best effort: the session ends either way. A driver failure here
                    // is logged rather than allowed to panic the session task.
                    if let Err(e) = device.disconnect(CommandFlags::empty()).await {
                        error!("Failed to disconnect: {:?}", e);
                    }
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Server is shutting down during command processing"
                    ));
                }
            };
            let cookie = command_raw.cookie;
            let (flags, command) = match self.parse_command(&command_raw, read_only, device_size) {
                Ok(parsed) => parsed,
                Err(e) => {
                    let reply = SimpleReplyRaw::new(e.into(), cookie, vec![]);
                    reply.write(writer).await?;
                    writer.flush().await?;
                    continue;
                }
            };
            if self.check_noop(&command) {
                let reply = SimpleReplyRaw::new(0, cookie, vec![]);
                reply.write(writer).await?;
                writer.flush().await?;
                continue;
            }
            let result = match command {
                CommandRequest::Disconnect => {
                    // No reply is written for a disconnect; the session simply ends.
                    return device
                        .disconnect(flags)
                        .await
                        .map(|_| ())
                        .map_err(|e| {
                            io::Error::new(
                                io::ErrorKind::Other,
                                format!("Failed to disconnect: {:?}", e),
                            )
                        });
                }
                CommandRequest::Read(offset, length) => device.read(flags, offset, length).await,
                CommandRequest::Write(offset, data) => {
                    device.write(flags, offset, data).await.map(|_| vec![])
                }
                CommandRequest::Flush => device.flush(flags).await.map(|_| vec![]),
                CommandRequest::Trim(offset, length) => {
                    device.trim(flags, offset, length).await.map(|_| vec![])
                }
                CommandRequest::WriteZeroes(offset, length) => device
                    .write_zeroes(flags, offset, length)
                    .await
                    .map(|_| vec![]),
                CommandRequest::Resize(size) => device.resize(flags, size).await.map(|_| vec![]),
                CommandRequest::Cache(offset, length) => {
                    device.cache(flags, offset, length).await.map(|_| vec![])
                }
                CommandRequest::BlockStatus(_, _) => device
                    .block_status(flags, command_raw.offset, command_raw.length)
                    .await
                    // NOTE(review): the driver's block-status data is discarded and an
                    // empty simple reply is sent.
                    .map(|_| vec![]),
            };
            let (reply, abort) = match result {
                Ok(data) => (SimpleReplyRaw::new(0, cookie, data), false),
                Err(e) => {
                    error!("Error processing command: {:?}", &e);
                    // Decide before `e` is consumed by the conversion to a wire code.
                    let abort = e == ProtocolError::ServerShuttingDown;
                    // Report the driver's actual error to the client.
                    (SimpleReplyRaw::new(e.into(), cookie, vec![]), abort)
                }
            };
            reply.write(writer).await?;
            writer.flush().await?;
            if abort {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Unrecoverable error in NBD driver",
                ));
            }
        }
    }
}