use std::sync::Arc;
use std::{io, vec};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::TcpStream;
use crate::command_request::CommandRequest;
use crate::errors::{OptionReplyError, ProtocolError};
use crate::flags::{CommandFlags, HandshakeFlags, ServerFeatures, TransmissionFlags};
use crate::io::command_reply::SimpleReplyRaw;
use crate::io::command_request::CommandRequestRaw;
use crate::io::option_reply::OptionReplyRaw;
use crate::io::option_request::OptionRequestRaw;
use crate::magic::{NBD_IHAVEOPT, NBD_MAGIC};
use crate::option_reply::{InfoPayload, OptionReply};
use crate::option_request::OptionRequest;
use std::future::Future;
/// Backend that supplies export metadata and executes NBD commands on behalf
/// of [`NBDServer`].
///
/// Option-phase queries (`list_devices`, `get_*`) return [`OptionReplyError`],
/// which the server writes back to the client as an option error reply.
/// Transmission-phase operations return [`ProtocolError`], which the server
/// answers with an error reply for the offending command.
pub trait NBDDriver {
/// Feature set advertised to clients; the server converts it into the
/// transmission flags reported for an export.
fn get_features(&self) -> ServerFeatures;
/// Export names sent back, one per reply, in response to a `List` option.
fn list_devices(&self) -> impl Future<Output = Result<Vec<String>, OptionReplyError>>;
/// Whether `device_name` is read-only. When `true`, the server advertises the
/// `READ_ONLY` transmission flag and refuses write-type commands.
fn get_read_only(
&self,
device_name: &str,
) -> impl Future<Output = Result<bool, OptionReplyError>>;
/// `(min, optimal, max)` block sizes of `device_name`, sent as a block-size
/// info payload for `Info`/`Go` requests.
fn get_block_size(
&self,
device_name: &str,
) -> impl Future<Output = Result<(u32, u32, u32), OptionReplyError>>;
/// Canonical name of `device_name`.
/// NOTE(review): not called by the server code in this file.
fn get_canonical_name(
&self,
device_name: &str,
) -> impl Future<Output = Result<String, OptionReplyError>>;
/// Description of `device_name`, sent as a description info payload for
/// `Info`/`Go` requests.
fn get_description(
&self,
device_name: &str,
) -> impl Future<Output = Result<String, OptionReplyError>>;
/// Size of `device_name` in bytes, sent in the export info payload.
fn get_device_size(
&self,
device_name: &str,
) -> impl Future<Output = Result<u64, OptionReplyError>>;
/// Reads `length` bytes at `offset`; the returned bytes become the payload of
/// the successful reply. `flags` are the client's command flags.
fn read(
&self,
flags: CommandFlags,
offset: u64,
length: u32,
) -> impl Future<Output = Result<Vec<u8>, ProtocolError>>;
/// Writes `data` at `offset`. Only called for exports that are not read-only.
fn write(
&self,
flags: CommandFlags,
offset: u64,
data: Vec<u8>,
) -> impl Future<Output = Result<(), ProtocolError>>;
/// Handles a `Flush` command. Only called for exports that are not read-only.
fn flush(&self, flags: CommandFlags) -> impl Future<Output = Result<(), ProtocolError>>;
/// Handles a `Trim` command over `length` bytes at `offset`. Only called for
/// exports that are not read-only.
fn trim(
&self,
flags: CommandFlags,
offset: u64,
length: u32,
) -> impl Future<Output = Result<(), ProtocolError>>;
/// Handles a `WriteZeroes` command over `length` bytes at `offset`. Only
/// called for exports that are not read-only.
fn write_zeroes(
&self,
flags: CommandFlags,
offset: u64,
length: u32,
) -> impl Future<Output = Result<(), ProtocolError>>;
/// Called when the client sends `Disconnect`; the server stops processing
/// commands afterwards and sends no reply.
fn disconnect(&self, flags: CommandFlags) -> impl Future<Output = Result<(), ProtocolError>>;
/// Handles a `Resize` command to `size`. Only called for exports that are
/// not read-only.
fn resize(
&self,
flags: CommandFlags,
size: u64,
) -> impl Future<Output = Result<(), ProtocolError>>;
/// Handles a `Cache` command over `length` bytes at `offset`.
fn cache(
&self,
flags: CommandFlags,
offset: u64,
length: u32,
) -> impl Future<Output = Result<(), ProtocolError>>;
/// Handles a `BlockStatus` command over `length` bytes at `offset`.
/// NOTE(review): returns no data, so the server's success reply carries an
/// empty payload.
fn block_status(
&self,
flags: CommandFlags,
offset: u64,
length: u32,
) -> impl Future<Output = Result<(), ProtocolError>>;
}
/// Per-connection state for the export chosen during option negotiation.
#[derive(Debug)]
struct SelectedDevice {
    /// Whether the driver reported the export as read-only; when set,
    /// data-modifying commands are refused.
    read_only: bool,
}

impl SelectedDevice {
    /// Returns `true` when `command` may run against this export.
    ///
    /// Writable exports accept every command. Read-only exports reject
    /// `Write`, `Flush`, `Trim`, `WriteZeroes` and `Resize`.
    fn is_command_permitted(&self, command: &CommandRequest) -> bool {
        if !self.read_only {
            return true;
        }
        let modifies_export = matches!(
            command,
            CommandRequest::Write(..)
                | CommandRequest::Flush
                | CommandRequest::Trim(..)
                | CommandRequest::WriteZeroes(..)
                | CommandRequest::Resize(..)
        );
        !modifies_export
    }
}
/// What `handle_options` does after writing the replies for one option
/// request.
enum OptionReplyFinalize {
/// The client sent `Abort`; `handle_options` returns an error, ending the
/// session.
Abort,
/// Keep reading option requests.
Continue,
/// Negotiation is finished; the transmission phase starts with this export.
End(SelectedDevice),
}
/// NBD server that runs the handshake, option negotiation and command loop
/// for a client connection, delegating export operations to the driver `T`.
pub struct NBDServer<T: NBDDriver> {
/// Backend serving the exports, held through an `Arc` so it can be shared.
driver: Arc<T>,
}
impl<T: NBDDriver> NBDServer<T> {
pub fn new(driver: Arc<T>) -> Self {
NBDServer { driver }
}
pub async fn start(&self, stream: TcpStream) -> std::io::Result<()> {
let (reader, writer) = stream.into_split();
let mut reader = BufReader::new(reader);
let mut writer = BufWriter::new(writer);
dbg!("Starting handshake");
self.handle_handshake(&mut reader, &mut writer).await?;
dbg!("Starting options negotiation");
let selected_device = self.handle_options(&mut reader, &mut writer).await?;
dbg!("Starting command handling");
self.handle_commands(&selected_device, &mut reader, &mut writer)
.await?;
Ok(())
}
async fn handle_handshake<R, W>(&self, reader: &mut R, writer: &mut W) -> std::io::Result<()>
where
R: AsyncReadExt + Unpin,
W: AsyncWrite + Unpin,
{
writer.write_all(&NBD_MAGIC.to_be_bytes()).await?;
writer.write_all(&NBD_IHAVEOPT.to_be_bytes()).await?;
writer
.write_all(&HandshakeFlags::default().bits().to_be_bytes())
.await?;
writer.flush().await?;
let client_flags = reader.read_u32().await?;
let Ok(client_flags) = u16::try_from(client_flags) else {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Invalid client flags",
));
};
let client_flags = HandshakeFlags::from_bits(client_flags)
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Invalid client flags"))?;
if !client_flags.contains(HandshakeFlags::FIXED_NEWSTYLE) {
dbg!("Client did not send FIXED_NEWSTYLE flag, which is required");
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Client did not send FIXED_NEWSTYLE flag, which is required",
));
}
if !client_flags.contains(HandshakeFlags::NO_ZEROES) {
dbg!("Client did not send NO_ZEROES flag, which is required");
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Client did not send NO_ZEROES flag, which is required",
));
}
Ok(())
}
async fn handle_option_request(
&self,
request: &OptionRequest,
) -> Result<(Vec<OptionReply>, OptionReplyFinalize), OptionReplyError> {
let mut responses: Vec<OptionReply> = Vec::new();
match request {
OptionRequest::Abort => {
responses.push(OptionReply::Ack);
return Ok((responses, OptionReplyFinalize::Abort));
}
OptionRequest::List => {
for device in self.driver.list_devices().await? {
responses.push(OptionReply::Server(device));
}
}
OptionRequest::StartTLS => unimplemented!(),
OptionRequest::Info(name, _info_requests) | OptionRequest::Go(name, _info_requests) => {
let read_only = self.driver.get_read_only(name).await?;
let size = self.driver.get_device_size(name).await?;
let mut flags: TransmissionFlags = self.driver.get_features().into();
if read_only {
flags.insert(TransmissionFlags::READ_ONLY);
}
responses.push(OptionReply::Info(InfoPayload::Export(size, flags)));
responses.push(OptionReply::Info(InfoPayload::Name(name.clone())));
responses.push(OptionReply::Info(InfoPayload::Description(
self.driver.get_description(name).await?,
)));
let (min, optimal, max) = self.driver.get_block_size(name).await?;
responses.push(OptionReply::Info(InfoPayload::BlockSize(min, optimal, max)));
responses.push(OptionReply::Ack);
if matches!(request, OptionRequest::Go(..)) {
return Ok((
responses,
OptionReplyFinalize::End(SelectedDevice { read_only }),
));
}
}
OptionRequest::StructuredReply => unimplemented!(),
OptionRequest::ListMetaContext => unimplemented!(),
OptionRequest::SetMetaContext(_) => unimplemented!(),
OptionRequest::ExtendedHeaders(_) => unimplemented!(),
OptionRequest::ExportName(name) => {
return Ok((
vec![],
OptionReplyFinalize::End(SelectedDevice {
read_only: self.driver.get_read_only(name).await?,
}),
));
}
}
Ok((responses, OptionReplyFinalize::Continue))
}
async fn write_option_reply_error<W>(
&self,
writer: &mut W,
option: u32,
error: OptionReplyError,
) -> std::io::Result<()>
where
W: AsyncWrite + Unpin,
{
let reply = OptionReplyRaw::new(option, error.into(), error.to_string().into_bytes());
reply.write(writer).await?;
writer.flush().await?;
if error == OptionReplyError::Shutdown {
Err(io::Error::new(
io::ErrorKind::Other,
"Server is shutting down",
))
} else {
Ok(())
}
}
async fn handle_options<R, W>(
&self,
reader: &mut R,
writer: &mut W,
) -> std::io::Result<SelectedDevice>
where
R: AsyncReadExt + Unpin,
W: AsyncWrite + Unpin,
{
loop {
let request_raw = OptionRequestRaw::read(reader).await?;
dbg!("Received option request, raw");
let request = match OptionRequest::try_from(&request_raw) {
Err(e) => {
self.write_option_reply_error(writer, request_raw.option, e)
.await?;
continue;
}
Ok(req) => req,
};
dbg!("Parsed option request: {:?}", &request);
match self.handle_option_request(&request).await {
Err(e) => {
self.write_option_reply_error(writer, request_raw.option, e)
.await?;
continue;
}
Ok((responses, finalize)) => {
for response in responses {
dbg!("Writing option reply: {:?}", &response);
let response_raw = OptionReplyRaw::new(
request_raw.option,
response.get_reply_type().into(),
response.get_data(),
);
response_raw.write(writer).await?;
}
writer.flush().await?;
match finalize {
OptionReplyFinalize::Abort => {
dbg!("Aborting option negotiation");
return Err(io::Error::new(
io::ErrorKind::Other,
"Abort request received",
));
}
OptionReplyFinalize::Continue => {
dbg!("Continuing option negotiation");
}
OptionReplyFinalize::End(selected_device) => {
dbg!(
"Ending option negotiation with selected device: {:?}",
&selected_device
);
return Ok(selected_device);
}
}
}
};
}
}
async fn handle_commands<R, W>(
&self,
selected_device: &SelectedDevice,
reader: &mut R,
writer: &mut W,
) -> io::Result<()>
where
R: AsyncReadExt + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
loop {
let command_raw = CommandRequestRaw::read(reader).await?;
let cookie = command_raw.cookie;
let Ok(flags) = CommandFlags::try_from(command_raw.flags) else {
let reply =
SimpleReplyRaw::new(ProtocolError::InvalidArgument.into(), cookie, vec![]);
reply.write(writer).await?;
writer.flush().await?;
continue;
};
let command = match CommandRequest::try_from(&command_raw) {
Ok(op) => op,
Err(e) => {
let reply = SimpleReplyRaw::new(e.into(), cookie, vec![]);
reply.write(writer).await?;
writer.flush().await?;
continue;
}
};
let command_not_permitted = !selected_device.is_command_permitted(&command);
let result = if command_not_permitted {
Err(ProtocolError::CommandNotPermitted)
} else {
match command {
CommandRequest::Disconnect => {
self.driver
.disconnect(flags)
.await
.expect("Failed to disconnect");
return Ok(());
}
CommandRequest::Read(offset, length) => {
self.driver.read(flags, offset, length).await
}
CommandRequest::Write(offset, data) => {
self.driver.write(flags, offset, data).await.map(|_| vec![])
}
CommandRequest::Flush => self.driver.flush(flags).await.map(|_| vec![]),
CommandRequest::Trim(offset, length) => self
.driver
.trim(flags, offset, length)
.await
.map(|_| vec![]),
CommandRequest::WriteZeroes(offset, length) => self
.driver
.write_zeroes(flags, offset, length)
.await
.map(|_| vec![]),
CommandRequest::Resize(size) => {
self.driver.resize(flags, size).await.map(|_| vec![])
}
CommandRequest::Cache(offset, length) => self
.driver
.cache(flags, offset, length)
.await
.map(|_| vec![]),
CommandRequest::BlockStatus(..) => self
.driver
.block_status(flags, command_raw.offset, command_raw.length)
.await
.map(|_| vec![]),
}
};
let (reply, abort) = match result {
Err(e) => {
dbg!("Error processing command: {:?}", &e);
(
SimpleReplyRaw::new(
ProtocolError::ServerShuttingDown.into(),
cookie,
vec![],
),
e == ProtocolError::ServerShuttingDown,
)
}
Ok(data) => (SimpleReplyRaw::new(0, cookie, data), false),
};
reply.write(writer).await?;
writer.flush().await?;
if abort {
return Err(io::Error::new(
io::ErrorKind::Other,
"Unrecoverable error in NBD driver",
));
}
}
}
}
#[cfg(test)]
mod tests {
    use super::NBDDriver;
    use crate::errors::{OptionReplyError, ProtocolError};
    use crate::flags::{CommandFlags, ServerFeatures};
    use std::ops::Range;
    use std::sync::RwLock;

    /// In-memory driver exposing a single export named `"memory"`.
    struct MemoryDriver {
        data: RwLock<Vec<u8>>,
    }

    impl MemoryDriver {
        /// Creates a zero-filled export of `size` bytes.
        fn with_size(size: usize) -> Self {
            MemoryDriver {
                data: RwLock::new(vec![0; size]),
            }
        }
    }

    /// Maps an NBD `(offset, length)` request onto a range of a `len`-byte
    /// buffer.
    ///
    /// Returns `ProtocolError::InvalidArgument` in three cases, rather than
    /// panicking on overflow or silently truncating with an `as` cast:
    /// - the offset does not fit in `usize`;
    /// - `offset + length` overflows;
    /// - the range ends past the buffer.
    fn checked_range(
        offset: u64,
        length: usize,
        len: usize,
    ) -> Result<Range<usize>, ProtocolError> {
        let start = usize::try_from(offset).map_err(|_| ProtocolError::InvalidArgument)?;
        let end = start
            .checked_add(length)
            .ok_or(ProtocolError::InvalidArgument)?;
        if end > len {
            return Err(ProtocolError::InvalidArgument);
        }
        Ok(start..end)
    }

    impl NBDDriver for MemoryDriver {
        fn get_features(&self) -> ServerFeatures {
            ServerFeatures::SEND_FLUSH | ServerFeatures::SEND_FUA
        }

        async fn list_devices(&self) -> Result<Vec<String>, OptionReplyError> {
            Ok(vec!["memory".to_string()])
        }

        async fn get_read_only(&self, device_name: &str) -> Result<bool, OptionReplyError> {
            if device_name == "memory" {
                Ok(false)
            } else {
                Err(OptionReplyError::Unknown)
            }
        }

        async fn read(
            &self,
            _flags: CommandFlags,
            offset: u64,
            length: u32,
        ) -> Result<Vec<u8>, ProtocolError> {
            let data = self.data.read().unwrap();
            let range = checked_range(offset, length as usize, data.len())?;
            Ok(data[range].to_vec())
        }

        async fn cache(
            &self,
            _flags: CommandFlags,
            _offset: u64,
            _length: u32,
        ) -> Result<(), ProtocolError> {
            Err(ProtocolError::CommandNotSupported)
        }

        async fn get_block_size(
            &self,
            _device_name: &str,
        ) -> Result<(u32, u32, u32), OptionReplyError> {
            Err(OptionReplyError::Unsupported)
        }

        async fn get_canonical_name(&self, _device_name: &str) -> Result<String, OptionReplyError> {
            Err(OptionReplyError::Unsupported)
        }

        async fn get_description(&self, _device_name: &str) -> Result<String, OptionReplyError> {
            Err(OptionReplyError::Unsupported)
        }

        async fn get_device_size(&self, _device_name: &str) -> Result<u64, OptionReplyError> {
            Ok(self.data.read().unwrap().len() as u64)
        }

        async fn write(
            &self,
            _flags: CommandFlags,
            offset: u64,
            data: Vec<u8>,
        ) -> Result<(), ProtocolError> {
            let mut memory = self.data.write().unwrap();
            let range = checked_range(offset, data.len(), memory.len())?;
            memory[range].copy_from_slice(&data);
            Ok(())
        }

        async fn flush(&self, _flags: CommandFlags) -> Result<(), ProtocolError> {
            Err(ProtocolError::CommandNotSupported)
        }

        async fn trim(
            &self,
            _flags: CommandFlags,
            _offset: u64,
            _length: u32,
        ) -> Result<(), ProtocolError> {
            Err(ProtocolError::CommandNotSupported)
        }

        async fn write_zeroes(
            &self,
            _flags: CommandFlags,
            _offset: u64,
            _length: u32,
        ) -> Result<(), ProtocolError> {
            Err(ProtocolError::CommandNotSupported)
        }

        async fn disconnect(&self, _flags: CommandFlags) -> Result<(), ProtocolError> {
            Ok(())
        }

        async fn resize(&self, _flags: CommandFlags, _size: u64) -> Result<(), ProtocolError> {
            Err(ProtocolError::CommandNotSupported)
        }

        async fn block_status(
            &self,
            _flags: CommandFlags,
            _offset: u64,
            _length: u32,
        ) -> Result<(), ProtocolError> {
            Err(ProtocolError::CommandNotSupported)
        }
    }

    #[tokio::test]
    async fn driver_memory_read() {
        let driver = MemoryDriver::with_size(1024);
        let result = driver.read(CommandFlags::empty(), 0, 512).await;
        assert_eq!(result, Ok(vec![0; 512]));
    }

    #[tokio::test]
    async fn driver_memory_write_read() {
        let driver = MemoryDriver::with_size(1024);
        let write_data = vec![1; 512];
        let write_result = driver
            .write(CommandFlags::empty(), 0, write_data.clone())
            .await;
        assert!(write_result.is_ok());
        let read_result = driver.read(CommandFlags::empty(), 0, 512).await;
        assert_eq!(read_result, Ok(write_data));
    }

    #[tokio::test]
    async fn driver_memory_resize() {
        let driver = MemoryDriver::with_size(1024);
        let resize_result = driver.resize(CommandFlags::empty(), 2048).await;
        assert_eq!(resize_result, Err(ProtocolError::CommandNotSupported));
    }

    #[tokio::test]
    async fn driver_memory_read_out_of_bounds() {
        let driver = MemoryDriver::with_size(1024);
        let result = driver.read(CommandFlags::empty(), 512, 600).await;
        assert_eq!(result, Err(ProtocolError::InvalidArgument));
        let result = driver.read(CommandFlags::empty(), 2000, 10).await;
        assert_eq!(result, Err(ProtocolError::InvalidArgument));
    }

    #[tokio::test]
    async fn driver_memory_write_out_of_bounds() {
        let driver = MemoryDriver::with_size(1024);
        let write_data = vec![1; 600];
        let result = driver.write(CommandFlags::empty(), 512, write_data).await;
        assert_eq!(result, Err(ProtocolError::InvalidArgument));
        let write_data = vec![1; 10];
        let result = driver.write(CommandFlags::empty(), 2000, write_data).await;
        assert_eq!(result, Err(ProtocolError::InvalidArgument));
    }

    #[tokio::test]
    async fn driver_memory_offset_overflow_is_rejected() {
        // Offsets near u64::MAX used to overflow `start + length` and panic.
        let driver = MemoryDriver::with_size(1024);
        let result = driver.read(CommandFlags::empty(), u64::MAX, 10).await;
        assert_eq!(result, Err(ProtocolError::InvalidArgument));
        let result = driver
            .write(CommandFlags::empty(), u64::MAX, vec![1; 10])
            .await;
        assert_eq!(result, Err(ProtocolError::InvalidArgument));
    }

    #[tokio::test]
    async fn driver_memory_partial_write_read() {
        let driver = MemoryDriver::with_size(1024);
        let write_data = vec![1, 2, 3, 4, 5];
        let offset = 100;
        let data_len = write_data.len();
        let result = driver
            .write(CommandFlags::empty(), offset, write_data)
            .await;
        assert!(result.is_ok());
        let read_result = driver
            .read(CommandFlags::empty(), offset, data_len as u32)
            .await;
        assert_eq!(read_result, Ok(vec![1, 2, 3, 4, 5]));
        let before_result = driver.read(CommandFlags::empty(), offset - 10, 10).await;
        assert_eq!(before_result, Ok(vec![0; 10]));
        let after_result = driver
            .read(CommandFlags::empty(), offset + data_len as u64, 10)
            .await;
        assert_eq!(after_result, Ok(vec![0; 10]));
    }

    #[tokio::test]
    async fn driver_memory_zero_length_operations() {
        let driver = MemoryDriver::with_size(1024);
        let read_result = driver.read(CommandFlags::empty(), 100, 0).await;
        assert_eq!(read_result, Ok(vec![]));
        let write_result = driver.write(CommandFlags::empty(), 100, vec![]).await;
        assert!(write_result.is_ok());
    }

    #[tokio::test]
    async fn driver_memory_list_devices_and_info() {
        let driver = MemoryDriver::with_size(1024);
        let devices = driver.list_devices().await.unwrap();
        assert_eq!(devices, vec!["memory".to_string()]);
        let size = driver.get_device_size("memory").await.unwrap();
        assert_eq!(size, 1024);
        let read_only = driver.get_read_only("memory").await.unwrap();
        assert!(!read_only);
        let read_only_result = driver.get_read_only("nonexistent").await;
        assert!(matches!(read_only_result, Err(OptionReplyError::Unknown)));
    }

    #[tokio::test]
    async fn driver_memory_command_flags() {
        let driver = MemoryDriver::with_size(1024);
        let features = driver.get_features();
        assert!(features.contains(ServerFeatures::SEND_FLUSH));
        assert!(features.contains(ServerFeatures::SEND_FUA));
        let fua_flag = CommandFlags::FUA;
        let write_data = vec![42; 128];
        let data_len = write_data.len();
        let write_result = driver.write(fua_flag, 200, write_data).await;
        assert!(write_result.is_ok());
        let read_result = driver
            .read(CommandFlags::empty(), 200, data_len as u32)
            .await;
        assert_eq!(read_result, Ok(vec![42; data_len]));
    }

    #[tokio::test]
    async fn driver_memory_unsupported_operations() {
        let driver = MemoryDriver::with_size(1024);
        assert_eq!(
            driver.flush(CommandFlags::empty()).await,
            Err(ProtocolError::CommandNotSupported)
        );
        assert_eq!(
            driver.trim(CommandFlags::empty(), 0, 100).await,
            Err(ProtocolError::CommandNotSupported)
        );
        assert_eq!(
            driver.write_zeroes(CommandFlags::empty(), 0, 100).await,
            Err(ProtocolError::CommandNotSupported)
        );
        assert_eq!(
            driver.cache(CommandFlags::empty(), 0, 100).await,
            Err(ProtocolError::CommandNotSupported)
        );
        assert_eq!(
            driver.block_status(CommandFlags::empty(), 0, 100).await,
            Err(ProtocolError::CommandNotSupported)
        );
        assert_eq!(
            driver.get_block_size("memory").await,
            Err(OptionReplyError::Unsupported)
        );
        assert_eq!(
            driver.get_canonical_name("memory").await,
            Err(OptionReplyError::Unsupported)
        );
        assert_eq!(
            driver.get_description("memory").await,
            Err(OptionReplyError::Unsupported)
        );
    }

    #[tokio::test]
    async fn driver_memory_disconnect() {
        let driver = MemoryDriver::with_size(1024);
        let result = driver.disconnect(CommandFlags::empty()).await;
        assert_eq!(result, Ok(()));
    }

    #[tokio::test]
    async fn driver_memory_error_handling_and_recovery() {
        let driver = MemoryDriver::with_size(1024);
        let trim_result = driver.trim(CommandFlags::empty(), 0, 100).await;
        assert_eq!(trim_result, Err(ProtocolError::CommandNotSupported));
        let write_data = vec![123; 50];
        let write_result = driver
            .write(CommandFlags::empty(), 200, write_data.clone())
            .await;
        assert!(write_result.is_ok());
        let out_of_bounds = driver.read(CommandFlags::empty(), 2000, 10).await;
        assert_eq!(out_of_bounds, Err(ProtocolError::InvalidArgument));
        let read_result = driver.read(CommandFlags::empty(), 200, 50).await;
        assert_eq!(read_result, Ok(vec![123; 50]));
        let data1 = vec![1, 2, 3, 4, 5];
        let data2 = vec![6, 7, 8, 9, 10];
        let write1_result = driver
            .write(CommandFlags::empty(), 100, data1.clone())
            .await;
        assert!(write1_result.is_ok());
        let write2_result = driver.write(CommandFlags::empty(), 5000, vec![0; 10]).await;
        assert_eq!(write2_result, Err(ProtocolError::InvalidArgument));
        let write3_result = driver
            .write(CommandFlags::empty(), 300, data2.clone())
            .await;
        assert!(write3_result.is_ok());
        assert_eq!(driver.read(CommandFlags::empty(), 100, 5).await, Ok(data1));
        assert_eq!(driver.read(CommandFlags::empty(), 300, 5).await, Ok(data2));
    }
}