use std::io::{Read, Write};
use std::sync::Arc;
use tracing::{debug, warn};
pub trait NbdBackend: Send + Sync {
fn size(&self) -> u64;
fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize>;
fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize>;
fn flush(&self) -> std::io::Result<()>;
}
impl NbdBackend for crate::backend::FlatFileBackend {
fn size(&self) -> u64 { self.size() }
fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> { self.read(offset, buf) }
fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize> { self.write(offset, buf) }
fn flush(&self) -> std::io::Result<()> { self.flush() }
}
impl NbdBackend for crate::cas::CasBackend {
fn size(&self) -> u64 { self.size() }
fn read(&self, offset: u64, buf: &mut [u8]) -> std::io::Result<usize> { self.read(offset, buf) }
fn write(&self, offset: u64, buf: &[u8]) -> std::io::Result<usize> { self.write(offset, buf) }
fn flush(&self) -> std::io::Result<()> { self.flush() }
}
const NBDMAGIC: u64 = 0x4e42444d41474943;
const IHAVEOPT: u64 = 0x49484156454F5054;
const REPLY_MAGIC: u64 = 0x3e889045565a9;
const NBD_FLAG_FIXED_NEWSTYLE: u16 = 1 << 0;
const NBD_FLAG_NO_ZEROES: u16 = 1 << 1;
const NBD_FLAG_C_NO_ZEROES: u32 = 1 << 1;
const NBD_FLAG_HAS_FLAGS: u16 = 1 << 0;
const NBD_FLAG_SEND_FLUSH: u16 = 1 << 2;
const NBD_OPT_EXPORT_NAME: u32 = 1;
const NBD_OPT_ABORT: u32 = 2;
const NBD_OPT_INFO: u32 = 6;
const NBD_OPT_GO: u32 = 7;
const NBD_REP_ACK: u32 = 1;
const NBD_REP_INFO: u32 = 3;
const NBD_REP_ERR_UNSUP: u32 = (1 << 31) | 1;
const NBD_INFO_EXPORT: u16 = 0;
const NBD_CMD_READ: u16 = 0;
const NBD_CMD_WRITE: u16 = 1;
const NBD_CMD_DISC: u16 = 2;
const NBD_CMD_FLUSH: u16 = 3;
const NBD_SIMPLE_REPLY_MAGIC: u32 = 0x67446698;
const NBD_OK: u32 = 0;
const NBD_EIO: u32 = 5;
const NBD_EINVAL: u32 = 22;
pub fn handle_client(
mut stream: std::os::unix::net::UnixStream,
backend: Arc<dyn NbdBackend>,
) -> anyhow::Result<()> {
handshake(&mut stream, backend.as_ref())?;
transmission(&mut stream, backend.as_ref())?;
Ok(())
}
fn handshake(
stream: &mut std::os::unix::net::UnixStream,
backend: &dyn NbdBackend,
) -> anyhow::Result<()> {
stream.write_all(&NBDMAGIC.to_be_bytes())?;
stream.write_all(&IHAVEOPT.to_be_bytes())?;
let server_flags = NBD_FLAG_FIXED_NEWSTYLE | NBD_FLAG_NO_ZEROES;
stream.write_all(&server_flags.to_be_bytes())?;
stream.flush()?;
let mut buf = [0u8; 4];
stream.read_exact(&mut buf)?;
let client_flags = u32::from_be_bytes(buf);
let no_zeroes = (client_flags & NBD_FLAG_C_NO_ZEROES) != 0;
loop {
let mut opt_header = [0u8; 16];
stream.read_exact(&mut opt_header)?;
let magic = u64::from_be_bytes(opt_header[0..8].try_into().unwrap());
if magic != IHAVEOPT {
anyhow::bail!("bad option magic: {:#x}", magic);
}
let option = u32::from_be_bytes(opt_header[8..12].try_into().unwrap());
let data_len = u32::from_be_bytes(opt_header[12..16].try_into().unwrap());
let mut opt_data = vec![0u8; data_len as usize];
if data_len > 0 {
stream.read_exact(&mut opt_data)?;
}
match option {
NBD_OPT_EXPORT_NAME => {
let trans_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_FLUSH;
stream.write_all(&backend.size().to_be_bytes())?;
stream.write_all(&trans_flags.to_be_bytes())?;
if !no_zeroes {
stream.write_all(&[0u8; 124])?;
}
stream.flush()?;
debug!("NBD handshake complete (EXPORT_NAME), size={}", backend.size());
return Ok(());
}
NBD_OPT_INFO | NBD_OPT_GO => {
send_export_info(stream, option, backend)?;
if option == NBD_OPT_GO {
debug!("NBD handshake complete (GO), size={}", backend.size());
return Ok(());
}
debug!("NBD INFO reply sent, size={}", backend.size());
}
NBD_OPT_ABORT => {
send_option_reply(stream, option, NBD_REP_ACK, &[])?;
stream.flush()?;
anyhow::bail!("client aborted");
}
_ => {
debug!("unsupported NBD option: {}", option);
send_option_reply(stream, option, NBD_REP_ERR_UNSUP, &[])?;
stream.flush()?;
}
}
}
}
fn send_export_info(
stream: &mut std::os::unix::net::UnixStream,
option: u32,
backend: &dyn NbdBackend,
) -> std::io::Result<()> {
let trans_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_FLUSH;
let mut info = [0u8; 12];
info[0..2].copy_from_slice(&NBD_INFO_EXPORT.to_be_bytes());
info[2..10].copy_from_slice(&backend.size().to_be_bytes());
info[10..12].copy_from_slice(&trans_flags.to_be_bytes());
send_option_reply(stream, option, NBD_REP_INFO, &info)?;
send_option_reply(stream, option, NBD_REP_ACK, &[])?;
stream.flush()
}
fn send_option_reply(
stream: &mut std::os::unix::net::UnixStream,
option: u32,
reply_type: u32,
data: &[u8],
) -> std::io::Result<()> {
stream.write_all(&REPLY_MAGIC.to_be_bytes())?;
stream.write_all(&option.to_be_bytes())?;
stream.write_all(&reply_type.to_be_bytes())?;
stream.write_all(&(data.len() as u32).to_be_bytes())?;
if !data.is_empty() {
stream.write_all(data)?;
}
Ok(())
}
fn transmission(
stream: &mut std::os::unix::net::UnixStream,
backend: &dyn NbdBackend,
) -> anyhow::Result<()> {
let mut req_header = [0u8; 28];
loop {
if let Err(e) = stream.read_exact(&mut req_header) {
debug!("NBD session ended while reading command: {}", e);
return Ok(());
}
let magic = u32::from_be_bytes(req_header[0..4].try_into().unwrap());
if magic != 0x25609513 {
anyhow::bail!("bad request magic: {:#x}", magic);
}
let cmd_type = u16::from_be_bytes(req_header[6..8].try_into().unwrap());
let handle = &req_header[8..16];
let offset = u64::from_be_bytes(req_header[16..24].try_into().unwrap());
let length = u32::from_be_bytes(req_header[24..28].try_into().unwrap());
match cmd_type {
NBD_CMD_READ => {
let mut buf = vec![0u8; length as usize];
let error = match backend.read(offset, &mut buf) {
Ok(n) => {
if n < length as usize {
buf[n..].fill(0);
}
NBD_OK
}
Err(e) => {
warn!("NBD read error at offset {}: {}", offset, e);
NBD_EIO
}
};
send_reply(stream, error, handle, if error == NBD_OK { Some(&buf) } else { None })?;
}
NBD_CMD_WRITE => {
let mut data = vec![0u8; length as usize];
stream.read_exact(&mut data)?;
let error = match backend.write(offset, &data) {
Ok(_) => NBD_OK,
Err(e) => {
warn!("NBD write error at offset {}: {}", offset, e);
NBD_EIO
}
};
send_reply(stream, error, handle, None)?;
}
NBD_CMD_FLUSH => {
let error = match backend.flush() {
Ok(()) => NBD_OK,
Err(e) => {
warn!("NBD flush error: {}", e);
NBD_EIO
}
};
send_reply(stream, error, handle, None)?;
}
NBD_CMD_DISC => {
debug!("NBD client sent disconnect");
return Ok(());
}
_ => {
warn!("unsupported NBD command: {}", cmd_type);
send_reply(stream, NBD_EINVAL, handle, None)?;
}
}
}
}
fn send_reply(
stream: &mut std::os::unix::net::UnixStream,
error: u32,
handle: &[u8],
data: Option<&[u8]>,
) -> std::io::Result<()> {
stream.write_all(&NBD_SIMPLE_REPLY_MAGIC.to_be_bytes())?;
stream.write_all(&error.to_be_bytes())?;
stream.write_all(handle)?;
if let Some(data) = data {
stream.write_all(data)?;
}
stream.flush()?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use crate::FlatFileBackend;
fn create_test_backend() -> (tempfile::NamedTempFile, Arc<FlatFileBackend>) {
let mut tmp = tempfile::NamedTempFile::new().unwrap();
let data = vec![0xABu8; 1024 * 1024]; tmp.write_all(&data).unwrap();
tmp.flush().unwrap();
let backend = Arc::new(FlatFileBackend::open(tmp.path().to_str().unwrap()).unwrap());
(tmp, backend)
}
#[test]
fn test_backend_read_write() {
let (_tmp, backend) = create_test_backend();
let mut buf = [0u8; 4];
backend.read(0, &mut buf).unwrap();
assert_eq!(buf, [0xAB; 4]);
backend.write(0, &[1, 2, 3, 4]).unwrap();
backend.read(0, &mut buf).unwrap();
assert_eq!(buf, [1, 2, 3, 4]);
let mut buf2 = [0u8; 4];
backend.read(4, &mut buf2).unwrap();
assert_eq!(buf2, [0xAB; 4]);
}
}