use std::sync::Arc;
use std::{io, vec};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpListener, TcpStream};
use crate::command_request::CommandRequest;
use crate::device::NbdDriver;
use crate::errors::{OptionReplyError, ProtocolError};
use crate::flags::{CommandFlags, HandshakeFlags, TransmissionFlags};
use crate::io::command_reply::SimpleReplyRaw;
use crate::io::command_request::CommandRequestRaw;
use crate::io::option_reply::OptionReplyRaw;
use crate::io::option_request::OptionRequestRaw;
use crate::magic::{NBD_IHAVEOPT, NBD_MAGIC};
use crate::option_reply::{InfoPayload, OptionReply};
use crate::option_request::OptionRequest;
/// The export chosen during option negotiation, carried into the transmission phase.
///
/// Trait bounds are intentionally left to the `impl` blocks that need them; the
/// `T: 'a` requirement is implied by the `&'a T` field.
struct SelectedDevice<'a, T> {
    /// The negotiated device, borrowed from the server's device list.
    device: &'a T,
    /// Whether the export is read-only; write-type commands are then refused.
    read_only: bool,
    /// Export size; command ranges (offset + length) are checked against it.
    size: u64,
    /// Export name the client requested (empty selects the default/first device).
    name: String,
}
/// What option negotiation should do after the replies for one option are written.
enum OptionReplyFinalize<'a, T: NbdDriver + 'a> {
    /// The client aborted; negotiation stops and the session ends with an error.
    Abort,
    /// Keep reading further option requests.
    Continue,
    /// Negotiation is finished; enter the transmission phase with this device.
    End(SelectedDevice<'a, T>),
}
/// An NBD server serving a shared list of block devices.
///
/// The `NbdDriver` bound lives on the `impl` block rather than on the struct, so
/// the type itself places no requirements on `T`.
pub struct NbdServer<T> {
    /// Devices exposed as exports; shared between all connection tasks.
    devices: Arc<Vec<T>>,
}
impl<T> NbdServer<T>
where
    T: NbdDriver + Send + Sync + 'static,
{
    /// Creates a server that serves the given shared list of devices.
    pub fn new(devices: Arc<Vec<T>>) -> Self {
        Self { devices }
    }

    /// Binds a TCP listener and serves NBD clients until accepting fails.
    ///
    /// `port` defaults to 10809 when `None`. Each accepted connection runs on its own
    /// spawned task with a clone of the shared device list; a failing session is
    /// logged and does not stop the accept loop.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the listener or accepting a connection fails.
    pub async fn listen(devices: Vec<T>, host: &str, port: Option<u16>) -> std::io::Result<()> {
        let port = port.unwrap_or(10809);
        let devices = Arc::new(devices);
        // Binding a `(host, port)` tuple instead of a formatted "host:port" string
        // lets IPv6 literals such as "::1" resolve; "::1:10809" would not parse.
        let listener = TcpListener::bind((host, port)).await?;
        loop {
            let (stream, addr) = listener.accept().await?;
            println!("NBD client connected from {}", addr);
            let devices = Arc::clone(&devices);
            tokio::spawn(async move {
                let server = NbdServer::new(devices);
                if let Err(e) = server.start(stream).await {
                    eprintln!("Error starting NBD server: {:?}", e);
                }
            });
        }
    }

    /// Returns the name of every served device, in order.
    ///
    /// # Errors
    ///
    /// Returns `OptionReplyError::UnknownExport` when no devices are configured.
    async fn list_devices(&self) -> Result<Vec<String>, OptionReplyError> {
        if self.devices.is_empty() {
            return Err(OptionReplyError::UnknownExport);
        }
        Ok(self.devices.iter().map(|device| device.get_name()).collect())
    }

    /// Looks up a device by export name; an empty name selects the first device.
    fn get_device(&self, device_name: &str) -> Option<&T> {
        if device_name.is_empty() {
            return self.devices.first();
        }
        self.devices.iter().find(|d| d.get_name() == device_name)
    }

    /// Runs one NBD session on `stream`: handshake, option negotiation, then the
    /// transmission phase on the negotiated device.
    ///
    /// # Errors
    ///
    /// Returns an error if no devices are configured, on any I/O failure, when the
    /// client aborts negotiation, or when a phase rejects the client.
    pub async fn start(&self, stream: TcpStream) -> std::io::Result<()> {
        if self.devices.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "No devices available for NBD server",
            ));
        }
        let (reader, writer) = stream.into_split();
        let mut reader = BufReader::new(reader);
        let mut writer = BufWriter::new(writer);
        eprintln!("Starting handshake");
        self.handle_handshake(&mut reader, &mut writer).await?;
        eprintln!("Starting options negotiation");
        let selected_device = self.handle_options(&mut reader, &mut writer).await?;
        eprintln!("Starting command handling");
        self.handle_commands(
            selected_device.device,
            &mut reader,
            &mut writer,
            selected_device.read_only,
            selected_device.size,
        )
        .await
    }

    /// Sends the server greeting (`NBD_MAGIC`, `NBD_IHAVEOPT`, default handshake
    /// flags) and validates the client's flags.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when the client flags do not fit in 16 bits, contain
    /// unknown bits, or lack `FIXED_NEWSTYLE` or `NO_ZEROES`.
    async fn handle_handshake<R, W>(&self, reader: &mut R, writer: &mut W) -> std::io::Result<()>
    where
        R: AsyncReadExt + Unpin,
        W: AsyncWrite + Unpin,
    {
        writer.write_all(&NBD_MAGIC.to_be_bytes()).await?;
        writer.write_all(&NBD_IHAVEOPT.to_be_bytes()).await?;
        writer
            .write_all(&HandshakeFlags::default().bits().to_be_bytes())
            .await?;
        writer.flush().await?;
        let client_flags = reader.read_u32().await?;
        let Ok(client_flags) = u16::try_from(client_flags) else {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid client flags",
            ));
        };
        let client_flags = HandshakeFlags::from_bits(client_flags)
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid client flags"))?;
        if !client_flags.contains(HandshakeFlags::FIXED_NEWSTYLE) {
            eprintln!("Client did not send FIXED_NEWSTYLE flag, which is required");
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Client did not send FIXED_NEWSTYLE flag, which is required",
            ));
        }
        if !client_flags.contains(HandshakeFlags::NO_ZEROES) {
            eprintln!("Client did not send NO_ZEROES flag, which is required");
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Client did not send NO_ZEROES flag, which is required",
            ));
        }
        Ok(())
    }

    /// Builds the transmission flags advertised for `device`, adding `READ_ONLY`
    /// when the export is read-only.
    fn transmission_flags(device: &T, read_only: bool) -> TransmissionFlags {
        let mut flags: TransmissionFlags = device.get_features().into();
        if read_only {
            flags.insert(TransmissionFlags::READ_ONLY);
        }
        flags
    }

    /// Computes the replies for one option request and what negotiation does next.
    ///
    /// # Errors
    ///
    /// Returns `OptionReplyError::UnknownExport` for an unknown export name, or any
    /// error the driver reports while describing the export.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) on STARTTLS, STRUCTURED_REPLY, meta-context
    /// and extended-header options.
    async fn handle_option_request(
        &self,
        request: &OptionRequest,
    ) -> Result<(Vec<OptionReply>, OptionReplyFinalize<T>), OptionReplyError> {
        match request {
            OptionRequest::Abort => Ok((vec![OptionReply::Ack], OptionReplyFinalize::Abort)),
            OptionRequest::List => {
                // NBD_OPT_LIST is answered with one NBD_REP_SERVER per export followed
                // by NBD_REP_ACK; without the ACK the client keeps waiting.
                let mut responses: Vec<OptionReply> = self
                    .list_devices()
                    .await?
                    .into_iter()
                    .map(OptionReply::Server)
                    .collect();
                responses.push(OptionReply::Ack);
                Ok((responses, OptionReplyFinalize::Continue))
            }
            // NOTE(review): the unimplemented arms panic on client-controlled input;
            // they should reply with an "unsupported" option error instead.
            OptionRequest::StartTLS => unimplemented!(),
            OptionRequest::Info(name, _info_requests) | OptionRequest::Go(name, _info_requests) => {
                let device = self
                    .get_device(name)
                    .ok_or(OptionReplyError::UnknownExport)?;
                let read_only = device.get_read_only().await?;
                let size = device.get_device_size().await?;
                let flags = Self::transmission_flags(device, read_only);
                let description = device.get_description().await?;
                let (min, optimal, max) = device.get_block_size().await?;
                let responses = vec![
                    OptionReply::Info(InfoPayload::Export(size, flags)),
                    OptionReply::Info(InfoPayload::Name(name.clone())),
                    OptionReply::Info(InfoPayload::Description(description)),
                    OptionReply::Info(InfoPayload::BlockSize(min, optimal, max)),
                    OptionReply::Ack,
                ];
                // INFO only describes the export; GO also selects it.
                let finalize = if matches!(request, OptionRequest::Go(..)) {
                    OptionReplyFinalize::End(SelectedDevice {
                        device,
                        read_only,
                        size,
                        name: name.clone(),
                    })
                } else {
                    OptionReplyFinalize::Continue
                };
                Ok((responses, finalize))
            }
            OptionRequest::StructuredReply => unimplemented!(),
            OptionRequest::ListMetaContext => unimplemented!(),
            OptionRequest::SetMetaContext(_) => unimplemented!(),
            OptionRequest::ExtendedHeaders(_) => unimplemented!(),
            OptionRequest::ExportName(name) => {
                // EXPORT_NAME has no option replies; `handle_options` writes its
                // size + flags answer itself.
                let device = self
                    .get_device(name)
                    .ok_or(OptionReplyError::UnknownExport)?;
                Ok((
                    vec![],
                    OptionReplyFinalize::End(SelectedDevice {
                        device,
                        read_only: device.get_read_only().await?,
                        size: device.get_device_size().await?,
                        name: name.clone(),
                    }),
                ))
            }
        }
    }

    /// Sends `error` as the reply to `option` and flushes.
    ///
    /// # Errors
    ///
    /// Returns an I/O error on write failure, and also after successfully sending
    /// `OptionReplyError::Shutdown`, so the caller ends the session.
    async fn write_option_reply_error<W>(
        &self,
        writer: &mut W,
        option: u32,
        error: OptionReplyError,
    ) -> std::io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let reply = OptionReplyRaw::new(option, error.into(), error.to_string().into_bytes());
        reply.write(writer).await?;
        writer.flush().await?;
        if error == OptionReplyError::Shutdown {
            Err(io::Error::new(
                io::ErrorKind::Other,
                "Server is shutting down",
            ))
        } else {
            Ok(())
        }
    }

    /// Writes the NBD_OPT_EXPORT_NAME success reply: export size (64-bit) then
    /// transmission flags (16-bit), big-endian. The trailing 124 zero bytes are
    /// omitted because `handle_handshake` only accepts clients that set NO_ZEROES.
    /// Does not flush.
    async fn write_export_name_reply<W>(
        writer: &mut W,
        selected_device: &SelectedDevice<'_, T>,
    ) -> std::io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let flags = Self::transmission_flags(selected_device.device, selected_device.read_only);
        // The wire field is exactly 16 bits wide whatever integer type `bits()` uses.
        let flags = u16::try_from(flags.bits()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "transmission flags do not fit in 16 bits",
            )
        })?;
        writer.write_all(&selected_device.size.to_be_bytes()).await?;
        writer.write_all(&flags.to_be_bytes()).await
    }

    /// Runs option negotiation until the client selects an export (GO or
    /// EXPORT_NAME) and returns it.
    ///
    /// Unparsable or failing options get an error reply and negotiation continues.
    /// The exception is EXPORT_NAME: it has no error reply, so its failure closes
    /// the session.
    ///
    /// # Errors
    ///
    /// Returns an error on I/O failure, on client abort, on a failed EXPORT_NAME,
    /// or after a shutdown error reply.
    async fn handle_options<R, W>(
        &self,
        reader: &mut R,
        writer: &mut W,
    ) -> std::io::Result<SelectedDevice<T>>
    where
        R: AsyncReadExt + Unpin,
        W: AsyncWrite + Unpin,
    {
        loop {
            let request_raw = OptionRequestRaw::read(reader).await?;
            eprintln!("Received option request, raw");
            let request = match OptionRequest::try_from(&request_raw) {
                Ok(request) => request,
                Err(e) => {
                    self.write_option_reply_error(writer, request_raw.option, e)
                        .await?;
                    continue;
                }
            };
            eprintln!("Parsed option request: {:?}", &request);
            let is_export_name = matches!(request, OptionRequest::ExportName(..));
            let (responses, finalize) = match self.handle_option_request(&request).await {
                Ok(outcome) => outcome,
                Err(e) => {
                    if is_export_name {
                        // The protocol offers no error reply for EXPORT_NAME: the
                        // server must terminate the session instead.
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("NBD_OPT_EXPORT_NAME failed: {e}"),
                        ));
                    }
                    self.write_option_reply_error(writer, request_raw.option, e)
                        .await?;
                    continue;
                }
            };
            for response in responses {
                eprintln!("Writing option reply: {:?}", &response);
                let response_raw = OptionReplyRaw::new(
                    request_raw.option,
                    response.get_reply_type().into(),
                    response.get_data(),
                );
                response_raw.write(writer).await?;
            }
            if is_export_name {
                if let OptionReplyFinalize::End(selected_device) = &finalize {
                    Self::write_export_name_reply(writer, selected_device).await?;
                }
            }
            writer.flush().await?;
            match finalize {
                OptionReplyFinalize::Abort => {
                    eprintln!("Aborting option negotiation");
                    return Err(io::Error::new(
                        io::ErrorKind::Other,
                        "Abort request received",
                    ));
                }
                OptionReplyFinalize::Continue => {
                    eprintln!("Continuing option negotiation");
                }
                OptionReplyFinalize::End(selected_device) => {
                    eprintln!(
                        "Ending option negotiation with selected device: {:?}",
                        selected_device.name
                    );
                    return Ok(selected_device);
                }
            }
        }
    }

    /// Validates a parsed command against the negotiated export.
    ///
    /// # Errors
    ///
    /// `ProtocolError::CommandNotPermitted` for a write-type command on a read-only
    /// export, and `ProtocolError::ValueTooLarge` when the addressed range does not
    /// fit inside `device_size`.
    fn check_command(
        &self,
        command: &CommandRequest,
        read_only: bool,
        device_size: u64,
    ) -> Result<(), ProtocolError> {
        if read_only && command.is_write_command() {
            Err(ProtocolError::CommandNotPermitted)
        } else if self.bounds_check(command, device_size) {
            Ok(())
        } else {
            Err(ProtocolError::ValueTooLarge)
        }
    }

    /// Returns `true` when the range addressed by `command` ends within
    /// `device_size`. Commands that address no range always pass.
    fn bounds_check(&self, command: &CommandRequest, device_size: u64) -> bool {
        // `offset` comes straight from the client, so the end is computed with
        // `checked_add`. A wrapping sum could land back inside the device and let an
        // out-of-range request through; in debug builds it would panic instead.
        let command_end = match command {
            CommandRequest::Read(offset, length)
            | CommandRequest::Trim(offset, length)
            | CommandRequest::WriteZeroes(offset, length)
            | CommandRequest::Cache(offset, length)
            | CommandRequest::BlockStatus(offset, length) => offset.checked_add(*length as u64),
            CommandRequest::Write(offset, data) => offset.checked_add(data.len() as u64),
            _ => return true,
        };
        matches!(command_end, Some(end) if end <= device_size)
    }

    /// Decodes flags and command from a raw request and validates it.
    ///
    /// # Errors
    ///
    /// `ProtocolError::InvalidArgument` when decoding fails, otherwise any error
    /// from `check_command`.
    fn parse_command(
        &self,
        command_raw: &CommandRequestRaw,
        read_only: bool,
        device_size: u64,
    ) -> Result<(CommandFlags, CommandRequest), ProtocolError> {
        let flags = CommandFlags::try_from(command_raw.flags)
            .map_err(|_| ProtocolError::InvalidArgument)?;
        let command =
            CommandRequest::try_from(command_raw).map_err(|_| ProtocolError::InvalidArgument)?;
        self.check_command(&command, read_only, device_size)?;
        Ok((flags, command))
    }

    /// Transmission phase: reads commands, runs them on `device`, and answers each
    /// with a simple reply carrying the command's own error code on failure.
    ///
    /// Returns `Ok(())` after a successful disconnect command.
    ///
    /// # Errors
    ///
    /// Returns an error on I/O failure, when the driver's disconnect fails, or after
    /// replying to a command that failed with `ProtocolError::ServerShuttingDown`.
    async fn handle_commands<R, W>(
        &self,
        device: &T,
        reader: &mut R,
        writer: &mut W,
        read_only: bool,
        device_size: u64,
    ) -> io::Result<()>
    where
        R: AsyncReadExt + Unpin,
        W: tokio::io::AsyncWrite + Unpin,
    {
        loop {
            let command_raw = CommandRequestRaw::read(reader).await?;
            let cookie = command_raw.cookie;
            let (flags, command) = match self.parse_command(&command_raw, read_only, device_size) {
                Ok(parsed) => parsed,
                Err(e) => {
                    let reply = SimpleReplyRaw::new(e.into(), cookie, vec![]);
                    reply.write(writer).await?;
                    writer.flush().await?;
                    continue;
                }
            };
            let result = match command {
                CommandRequest::Disconnect => {
                    // Disconnect gets no reply. A driver failure ends the session with an
                    // error instead of panicking the connection task.
                    if let Err(e) = device.disconnect(flags).await {
                        return Err(io::Error::new(
                            io::ErrorKind::Other,
                            format!("Failed to disconnect: {e:?}"),
                        ));
                    }
                    return Ok(());
                }
                CommandRequest::Read(offset, length) => device.read(flags, offset, length).await,
                CommandRequest::Write(offset, data) => {
                    device.write(flags, offset, data).await.map(|_| vec![])
                }
                CommandRequest::Flush => device.flush(flags).await.map(|_| vec![]),
                CommandRequest::Trim(offset, length) => {
                    device.trim(flags, offset, length).await.map(|_| vec![])
                }
                CommandRequest::WriteZeroes(offset, length) => device
                    .write_zeroes(flags, offset, length)
                    .await
                    .map(|_| vec![]),
                CommandRequest::Resize(size) => device.resize(flags, size).await.map(|_| vec![]),
                CommandRequest::Cache(offset, length) => {
                    device.cache(flags, offset, length).await.map(|_| vec![])
                }
                // NOTE(review): the block-status payload is discarded because only simple
                // replies are implemented here.
                CommandRequest::BlockStatus(..) => device
                    .block_status(flags, command_raw.offset, command_raw.length)
                    .await
                    .map(|_| vec![]),
            };
            let (reply, abort) = match result {
                Ok(data) => (SimpleReplyRaw::new(0, cookie, data), false),
                Err(e) => {
                    eprintln!("Error processing command: {:?}", e);
                    // Only a shutdown error ends the session; every other error is reported
                    // to the client with its own code.
                    let abort = e == ProtocolError::ServerShuttingDown;
                    (SimpleReplyRaw::new(e.into(), cookie, vec![]), abort)
                }
            };
            reply.write(writer).await?;
            writer.flush().await?;
            if abort {
                return Err(io::Error::new(
                    io::ErrorKind::Other,
                    "Unrecoverable error in NBD driver",
                ));
            }
        }
    }
}