use std::sync::atomic::AtomicU64;
use crate::errors::{OptionReplyError, ProtocolError};
use crate::flags::{CommandFlags, ServerFeatures};
/// Backend driver for a single NBD (Network Block Device) export.
///
/// Metadata queries (`get_read_only`, `get_block_size`, `get_canonical_name`,
/// `get_description`) report failures as [`OptionReplyError`]; data commands report failures
/// as [`ProtocolError`].
///
/// The metadata methods plus [`read`](Self::read), [`write`](Self::write) and
/// [`disconnect`](Self::disconnect) must be implemented. The remaining commands (`flush`,
/// `trim`, `write_zeroes`, `resize`, `cache`, `block_status`) have default implementations
/// that return `Err(ProtocolError::CommandNotSupported)`, so a driver overrides only the
/// commands it actually supports.
pub trait NbdDriver {
    /// Returns the device name.
    fn get_name(&self) -> String;

    /// Returns the [`ServerFeatures`] this driver supports.
    fn get_features(&self) -> ServerFeatures;

    /// Reports whether the device is read-only.
    fn get_read_only(&self) -> impl Future<Output = Result<bool, OptionReplyError>> + Send;

    /// Returns the device's block-size constraints as a `(u32, u32, u32)` triple.
    ///
    /// NOTE(review): presumably `(minimum, preferred, maximum)` in bytes — confirm against
    /// the server code that consumes it.
    fn get_block_size(
        &self,
    ) -> impl Future<Output = Result<(u32, u32, u32), OptionReplyError>> + Send;

    /// Returns the canonical name of the device.
    fn get_canonical_name(&self) -> impl Future<Output = Result<String, OptionReplyError>> + Send;

    /// Returns a description of the device.
    fn get_description(&self) -> impl Future<Output = Result<String, OptionReplyError>> + Send;

    /// Returns the current device size in bytes.
    ///
    /// The size is exposed as an [`AtomicU64`], so a driver that supports
    /// [`resize`](Self::resize) can update it through `&self`.
    fn get_device_size(&self) -> &AtomicU64;

    /// Reads `length` bytes starting at byte `offset`.
    fn read(
        &self,
        flags: CommandFlags,
        offset: u64,
        length: u32,
    ) -> impl Future<Output = Result<Vec<u8>, ProtocolError>> + Send;

    /// Writes `data` starting at byte `offset`.
    fn write(
        &self,
        flags: CommandFlags,
        offset: u64,
        data: Vec<u8>,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send;

    /// Handles a disconnect request.
    fn disconnect(
        &self,
        flags: CommandFlags,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send;

    /// Flushes previously written data.
    ///
    /// The default implementation returns `Err(ProtocolError::CommandNotSupported)`.
    fn flush(
        &self,
        _flags: CommandFlags,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send {
        async move { Err(ProtocolError::CommandNotSupported) }
    }

    /// Trims `length` bytes starting at `offset`.
    ///
    /// The default implementation returns `Err(ProtocolError::CommandNotSupported)`.
    fn trim(
        &self,
        _flags: CommandFlags,
        _offset: u64,
        _length: u32,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send {
        async move { Err(ProtocolError::CommandNotSupported) }
    }

    /// Writes `length` zero bytes starting at `offset`.
    ///
    /// The default implementation returns `Err(ProtocolError::CommandNotSupported)`.
    fn write_zeroes(
        &self,
        _flags: CommandFlags,
        _offset: u64,
        _length: u32,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send {
        async move { Err(ProtocolError::CommandNotSupported) }
    }

    /// Resizes the device to `size` bytes.
    ///
    /// The default implementation returns `Err(ProtocolError::CommandNotSupported)`.
    // Parameters are `_`-prefixed because the default body ignores them; implementors may
    // name them however they like.
    fn resize(
        &self,
        _flags: CommandFlags,
        _size: u64,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send {
        async move { Err(ProtocolError::CommandNotSupported) }
    }

    /// Asks the driver to cache `length` bytes starting at `offset`.
    ///
    /// The default implementation returns `Err(ProtocolError::CommandNotSupported)`.
    fn cache(
        &self,
        _flags: CommandFlags,
        _offset: u64,
        _length: u32,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send {
        async move { Err(ProtocolError::CommandNotSupported) }
    }

    /// Queries block status for `length` bytes starting at `offset`.
    ///
    /// The default implementation returns `Err(ProtocolError::CommandNotSupported)`.
    ///
    /// NOTE(review): the success type is `()`, so no status information can be returned
    /// through this signature.
    fn block_status(
        &self,
        _flags: CommandFlags,
        _offset: u64,
        _length: u32,
    ) -> impl Future<Output = Result<(), ProtocolError>> + Send {
        async move { Err(ProtocolError::CommandNotSupported) }
    }
}
#[cfg(test)]
pub(crate) mod tests {
use super::NbdDriver;
use crate::errors::{OptionReplyError, ProtocolError};
use crate::flags::{CommandFlags, ServerFeatures};
use tokio;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::RwLock;
/// In-memory [`NbdDriver`] backing the unit tests in this module.
#[derive(Debug)]
pub(crate) struct MemoryDriver {
    /// Device size in bytes, returned by `get_device_size`; kept equal to `data.len()`.
    size: AtomicU64,
    /// Backing storage for the device contents.
    data: RwLock<Vec<u8>>,
    /// Value returned by `get_read_only`.
    read_only: bool,
    /// Value returned by `get_name`.
    name: String,
}

impl Default for MemoryDriver {
    /// Creates a writable, unnamed, zero-filled 1 KiB device.
    fn default() -> Self {
        const INITIAL_SIZE: usize = 1024;
        let storage = vec![0u8; INITIAL_SIZE];
        Self {
            // Derived from the buffer so the advertised size and the storage cannot disagree.
            size: AtomicU64::new(storage.len() as u64),
            data: RwLock::new(storage),
            read_only: false,
            name: String::new(),
        }
    }
}
impl NbdDriver for MemoryDriver {
fn get_features(&self) -> ServerFeatures {
ServerFeatures::SEND_FUA | ServerFeatures::SEND_RESIZE
}
fn get_name(&self) -> String {
self.name.clone()
}
async fn get_read_only(&self) -> Result<bool, OptionReplyError> {
Ok(self.read_only)
}
async fn get_block_size(&self) -> Result<(u32, u32, u32), OptionReplyError> {
Err(OptionReplyError::Unsupported)
}
async fn get_canonical_name(&self) -> Result<String, OptionReplyError> {
Err(OptionReplyError::Unsupported)
}
async fn get_description(&self) -> Result<String, OptionReplyError> {
Err(OptionReplyError::Unsupported)
}
fn get_device_size(&self) -> &AtomicU64 {
&self.size
}
async fn read(
&self,
_flags: CommandFlags,
offset: u64,
length: u32,
) -> Result<Vec<u8>, ProtocolError> {
if offset + length as u64 >= self.size.load(Ordering::Acquire) {
return Err(ProtocolError::InvalidArgument);
}
let start = offset as usize;
let end = start + length as usize;
let data = self.data.read().await;
Ok(data[start..end].to_vec())
}
async fn write(
&self,
_flags: CommandFlags,
offset: u64,
data: Vec<u8>,
) -> Result<(), ProtocolError> {
if offset + data.len() as u64 >= self.size.load(Ordering::Acquire) {
return Err(ProtocolError::InvalidArgument);
}
let mut memory = self.data.write().await;
let start = offset as usize;
let end = start + data.len();
memory[start..end].copy_from_slice(&data);
Ok(())
}
async fn disconnect(&self, _flags: CommandFlags) -> Result<(), ProtocolError> {
Ok(())
}
async fn resize(&self, _flags: CommandFlags, size: u64) -> Result<(), ProtocolError> {
let current_size = self.size.load(Ordering::Acquire);
match size.cmp(¤t_size) {
std::cmp::Ordering::Equal => {
Ok(())
}
std::cmp::Ordering::Less => {
Err(ProtocolError::CommandNotSupported)
}
std::cmp::Ordering::Greater => {
let mut memory = self.data.write().await;
memory.resize(size as usize, 0);
self.size.store(size, Ordering::Release);
Ok(())
}
}
}
}
/// A fresh device is zero-filled, so its first 512 bytes read back as zeros.
#[tokio::test]
async fn driver_memory_read() {
    let expected = vec![0u8; 512];
    let got = MemoryDriver::default()
        .read(CommandFlags::empty(), 0, 512)
        .await;
    assert_eq!(got, Ok(expected));
}
/// Data written at offset 0 is returned unchanged by a later read of the same span.
#[tokio::test]
async fn driver_memory_write_read() {
    let driver = MemoryDriver::default();
    let payload = vec![1; 512];

    let written = driver.write(CommandFlags::empty(), 0, payload.clone()).await;
    assert!(written.is_ok());

    let round_trip = driver.read(CommandFlags::empty(), 0, 512).await;
    assert_eq!(round_trip, Ok(payload));
}
/// Growing the device updates its reported size and makes the new region usable;
/// shrinking is rejected.
#[tokio::test]
async fn driver_memory_resize() {
    let driver = MemoryDriver::default();

    // Grow 1024 -> 2048 bytes.
    assert!(driver.resize(CommandFlags::empty(), 2048).await.is_ok());
    assert_eq!(driver.get_device_size().load(Ordering::Acquire), 2048);

    // The region past the old end is now writable and readable.
    let chunk = vec![42; 100];
    assert!(driver
        .write(CommandFlags::empty(), 1500, chunk.clone())
        .await
        .is_ok());
    assert_eq!(driver.read(CommandFlags::empty(), 1500, 100).await, Ok(chunk));

    // Shrinking back is not supported.
    assert_eq!(
        driver.resize(CommandFlags::empty(), 1024).await,
        Err(ProtocolError::CommandNotSupported)
    );
}
/// Reads that extend past the end of the 1 KiB device are rejected.
#[tokio::test]
async fn driver_memory_read_out_of_bounds() {
    let driver = MemoryDriver::default();
    // (offset, length): one straddles the end, the other starts beyond it.
    for (offset, length) in [(512, 600), (2000, 10)] {
        assert_eq!(
            driver.read(CommandFlags::empty(), offset, length).await,
            Err(ProtocolError::InvalidArgument)
        );
    }
}
/// Writes that extend past the end of the 1 KiB device are rejected.
#[tokio::test]
async fn driver_memory_write_out_of_bounds() {
    let driver = MemoryDriver::default();
    // (offset, payload length): one straddles the end, the other starts beyond it.
    for (offset, len) in [(512, 600), (2000, 10)] {
        assert_eq!(
            driver.write(CommandFlags::empty(), offset, vec![1; len]).await,
            Err(ProtocolError::InvalidArgument)
        );
    }
}
/// A small write lands exactly at its offset and leaves the neighbouring bytes zeroed.
#[tokio::test]
async fn driver_memory_partial_write_read() {
    const OFFSET: u64 = 100;
    let driver = MemoryDriver::default();
    let bytes = vec![1, 2, 3, 4, 5];
    let len = bytes.len();

    assert!(driver.write(CommandFlags::empty(), OFFSET, bytes).await.is_ok());
    assert_eq!(
        driver.read(CommandFlags::empty(), OFFSET, len as u32).await,
        Ok(vec![1, 2, 3, 4, 5])
    );

    // Ten bytes on either side of the written span are untouched.
    let before = driver.read(CommandFlags::empty(), OFFSET - 10, 10).await;
    assert_eq!(before, Ok(vec![0; 10]));
    let after = driver
        .read(CommandFlags::empty(), OFFSET + len as u64, 10)
        .await;
    assert_eq!(after, Ok(vec![0; 10]));
}
/// Zero-length reads and writes inside the device succeed and transfer no data.
#[tokio::test]
async fn driver_memory_zero_length_operations() {
    let driver = MemoryDriver::default();
    let nothing: Vec<u8> = Vec::new();
    assert_eq!(
        driver.read(CommandFlags::empty(), 100, 0).await,
        Ok(nothing.clone())
    );
    assert!(driver.write(CommandFlags::empty(), 100, nothing).await.is_ok());
}
/// A default device reports a size of 1024 bytes and is writable.
#[tokio::test]
async fn driver_memory_list_devices_and_info() {
    let driver = MemoryDriver::default();
    assert_eq!(driver.get_device_size().load(Ordering::Acquire), 1024);
    let read_only = driver
        .get_read_only()
        .await
        .expect("MemoryDriver::get_read_only always returns Ok");
    // `assert!(!x)` rather than `assert_eq!(x, false)` (clippy::bool_assert_comparison).
    assert!(!read_only, "a default MemoryDriver must be writable");
}
/// `get_name` reflects the current value of the `name` field on every call.
#[tokio::test]
async fn driver_memory_get_name() {
    let mut driver = MemoryDriver::default();
    for expected in ["test-device", "another-device"] {
        driver.name = expected.to_string();
        assert_eq!(driver.get_name(), expected);
    }
}
/// FUA is among the reported features, and a write carrying the FUA flag is readable
/// afterwards.
#[tokio::test]
async fn driver_memory_command_flags() {
    const LEN: usize = 128;
    let driver = MemoryDriver::default();
    assert!(driver.get_features().contains(ServerFeatures::SEND_FUA));

    let write_result = driver.write(CommandFlags::FUA, 200, vec![42; LEN]).await;
    assert!(write_result.is_ok());

    let read_back = driver.read(CommandFlags::empty(), 200, LEN as u32).await;
    assert_eq!(read_back, Ok(vec![42; LEN]));
}
/// Commands that `MemoryDriver` does not implement fall back to the trait defaults and are
/// rejected; its metadata queries report `Unsupported`.
#[tokio::test]
async fn driver_memory_unsupported_operations() {
    let driver = MemoryDriver::default();

    // None of these is overridden by MemoryDriver, so each uses the trait default.
    let command_results = [
        driver.flush(CommandFlags::empty()).await,
        driver.trim(CommandFlags::empty(), 0, 100).await,
        driver.write_zeroes(CommandFlags::empty(), 0, 100).await,
        driver.cache(CommandFlags::empty(), 0, 100).await,
        driver.block_status(CommandFlags::empty(), 0, 100).await,
    ];
    for result in command_results {
        assert_eq!(result, Err(ProtocolError::CommandNotSupported));
    }

    // MemoryDriver answers these metadata queries with `Unsupported`.
    assert_eq!(
        driver.get_block_size().await.err(),
        Some(OptionReplyError::Unsupported)
    );
    assert_eq!(
        driver.get_canonical_name().await.err(),
        Some(OptionReplyError::Unsupported)
    );
    assert_eq!(
        driver.get_description().await.err(),
        Some(OptionReplyError::Unsupported)
    );
}
/// Disconnecting from the in-memory driver always succeeds.
#[tokio::test]
async fn driver_memory_disconnect() {
    let outcome = MemoryDriver::default()
        .disconnect(CommandFlags::empty())
        .await;
    assert_eq!(outcome, Ok(()));
}
/// Failed or unsupported requests leave the device usable and do not disturb data written
/// before or after them.
#[tokio::test]
async fn driver_memory_error_handling_and_recovery() {
    let driver = MemoryDriver::default();

    // An unsupported command fails without side effects, and a normal write still works.
    assert_eq!(
        driver.trim(CommandFlags::empty(), 0, 100).await,
        Err(ProtocolError::CommandNotSupported)
    );
    let block = vec![123; 50];
    assert!(driver
        .write(CommandFlags::empty(), 200, block.clone())
        .await
        .is_ok());

    // An out-of-bounds read fails, but the earlier write is still readable.
    assert_eq!(
        driver.read(CommandFlags::empty(), 2000, 10).await,
        Err(ProtocolError::InvalidArgument)
    );
    assert_eq!(driver.read(CommandFlags::empty(), 200, 50).await, Ok(block));

    // A rejected write between two good writes affects neither of them.
    let first = vec![1, 2, 3, 4, 5];
    let second = vec![6, 7, 8, 9, 10];
    assert!(driver
        .write(CommandFlags::empty(), 100, first.clone())
        .await
        .is_ok());
    assert_eq!(
        driver.write(CommandFlags::empty(), 5000, vec![0; 10]).await,
        Err(ProtocolError::InvalidArgument)
    );
    assert!(driver
        .write(CommandFlags::empty(), 300, second.clone())
        .await
        .is_ok());
    assert_eq!(driver.read(CommandFlags::empty(), 100, 5).await, Ok(first));
    assert_eq!(driver.read(CommandFlags::empty(), 300, 5).await, Ok(second));
}
}