use anyhow::Result;
use hexz_core::{Archive, ArchiveStream};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
// ---- Handshake (fixed-newstyle negotiation) ----

/// First word of the server greeting: the ASCII bytes `NBDMAGIC`.
const NBD_MAGIC: u64 = u64::from_be_bytes(*b"NBDMAGIC");
/// The ASCII bytes `IHAVEOPT`. Sent as the second greeting word, and expected at the
/// start of every option the client sends.
const NBD_OPT_MAGIC: u64 = u64::from_be_bytes(*b"IHAVEOPT");
/// Prefix of every option reply the server sends.
const NBD_REP_MAGIC: u64 = 0x0003_e889_0455_65a9;

/// Handshake flags: advertised by the server in the greeting. The `NO_ZEROES` bit is
/// also tested in the client's flags word.
const NBD_FLAG_FIXED_NEWSTYLE: u16 = 0x0001;
const NBD_FLAG_NO_ZEROES: u16 = 0x0002;

/// Transmission flags: describe the export itself.
const NBD_FLAG_HAS_FLAGS: u16 = 0x0001;
const NBD_FLAG_READ_ONLY: u16 = 0x0002;

/// Option ids a client may send during negotiation.
const NBD_OPT_EXPORT_NAME: u32 = 1;
const NBD_OPT_ABORT: u32 = 2;
const NBD_OPT_INFO: u32 = 6;
const NBD_OPT_GO: u32 = 7;

/// Option reply types.
const NBD_REP_ACK: u32 = 1;
const NBD_REP_INFO: u32 = 3;

/// Information type carried inside an `NBD_REP_INFO` reply: export size plus flags.
const NBD_INFO_EXPORT: u16 = 0;

// ---- Transmission phase ----

/// Request command types.
const NBD_CMD_READ: u16 = 0;
const NBD_CMD_WRITE: u16 = 1;
const NBD_CMD_DISC: u16 = 2;
const NBD_CMD_FLUSH: u16 = 3;
const NBD_CMD_TRIM: u16 = 4;

/// Magic opening every client request.
const NBD_REQUEST_MAGIC: u32 = 0x2560_9513;
/// Magic opening every (simple) server reply.
const NBD_REPLY_MAGIC: u32 = 0x6744_6698;

/// Upper bound (32 MiB) accepted for both option payloads and request lengths.
const NBD_MAX_BUFFER_SIZE: u32 = 32 << 20;
/// Option reply type `NBD_REP_ERR_UNSUP` (bit 31 set = error): option not supported.
const NBD_REP_ERR_UNSUP: u32 = (1 << 31) | 1;
/// errno-style values placed in the `error` field of a simple reply.
const NBD_EPERM: u32 = 1;
const NBD_EIO: u32 = 5;
const NBD_EINVAL: u32 = 22;
/// Zero padding appended to the `NBD_OPT_EXPORT_NAME` reply when the client did not
/// set the `NO_ZEROES` flag.
const NBD_EXPORT_NAME_ZERO_PAD: usize = 124;

/// Serves a single NBD client connection, exporting the archive's `Main` stream read-only.
///
/// The session has two phases: fixed-newstyle option negotiation (`negotiate`) followed by
/// the transmission phase (`serve_requests`). It ends normally when the client sends
/// `NBD_OPT_ABORT` during negotiation or `NBD_CMD_DISC` during transmission.
///
/// # Errors
///
/// Returns an error if a socket read or write fails (including the client hanging up in the
/// middle of a message), if an option or request carries the wrong magic, or if an option
/// payload or request length exceeds `NBD_MAX_BUFFER_SIZE`.
pub async fn handle_client(mut socket: TcpStream, snap: Arc<Archive>) -> Result<()> {
    // Every reply is a complete message the peer is waiting on. With Nagle enabled, the
    // tail of a reply can sit in the kernel until the peer's delayed ACK fires. This is
    // best effort: the protocol still works without it.
    if let Err(e) = socket.set_nodelay(true) {
        tracing::warn!("Failed to set TCP_NODELAY: {}", e);
    }

    let export_size = snap.size(ArchiveStream::Main);
    if !negotiate(&mut socket, export_size).await? {
        return Ok(());
    }
    serve_requests(&mut socket, &snap, export_size).await
}

/// Runs fixed-newstyle option negotiation.
///
/// Returns `Ok(true)` when the client has selected the export (`NBD_OPT_EXPORT_NAME` or
/// `NBD_OPT_GO`) and transmission should begin. Returns `Ok(false)` when the client sent
/// `NBD_OPT_ABORT`.
async fn negotiate(socket: &mut TcpStream, export_size: u64) -> Result<bool> {
    let export_flags = NBD_FLAG_HAS_FLAGS | NBD_FLAG_READ_ONLY;

    // Greeting ("NBDMAGIC", "IHAVEOPT", handshake flags), sent as a single write.
    let mut greeting = Vec::with_capacity(18);
    greeting.extend_from_slice(&NBD_MAGIC.to_be_bytes());
    greeting.extend_from_slice(&NBD_OPT_MAGIC.to_be_bytes());
    greeting.extend_from_slice(&(NBD_FLAG_FIXED_NEWSTYLE | NBD_FLAG_NO_ZEROES).to_be_bytes());
    socket.write_all(&greeting).await?;

    let client_flags = socket.read_u32().await?;
    let client_supports_no_zeroes = (client_flags & u32::from(NBD_FLAG_NO_ZEROES)) != 0;

    loop {
        let magic = socket.read_u64().await?;
        if magic != NBD_OPT_MAGIC {
            anyhow::bail!("Invalid option magic");
        }
        let opt_id = socket.read_u32().await?;
        let opt_len = socket.read_u32().await?;
        if opt_len > NBD_MAX_BUFFER_SIZE {
            anyhow::bail!("NBD option data too large: {opt_len} bytes");
        }
        // The option payload (e.g. an export name) must be consumed to stay in sync with
        // the stream. Its contents are not inspected: every option targets the one export.
        let mut opt_data = vec![0u8; opt_len as usize];
        socket.read_exact(&mut opt_data).await?;

        match opt_id {
            NBD_OPT_ABORT => {
                // The server is expected to ACK an abort. The client may already have
                // closed the connection, so a failed write is not treated as an error:
                // the session ends either way.
                if let Err(e) = write_option_reply(socket, opt_id, NBD_REP_ACK, &[]).await {
                    tracing::debug!("Failed to acknowledge NBD_OPT_ABORT: {}", e);
                }
                return Ok(false);
            }
            NBD_OPT_EXPORT_NAME => {
                // This option is answered without option-reply framing: export size and
                // transmission flags, plus zero padding unless NO_ZEROES was negotiated.
                let mut reply = Vec::with_capacity(10 + NBD_EXPORT_NAME_ZERO_PAD);
                reply.extend_from_slice(&export_size.to_be_bytes());
                reply.extend_from_slice(&export_flags.to_be_bytes());
                if !client_supports_no_zeroes {
                    reply.resize(reply.len() + NBD_EXPORT_NAME_ZERO_PAD, 0);
                }
                socket.write_all(&reply).await?;
                return Ok(true);
            }
            NBD_OPT_INFO | NBD_OPT_GO => {
                // One NBD_REP_INFO carrying NBD_INFO_EXPORT
                // (info type u16 + size u64 + flags u16 = 12 bytes), then NBD_REP_ACK.
                let mut info = Vec::with_capacity(12);
                info.extend_from_slice(&NBD_INFO_EXPORT.to_be_bytes());
                info.extend_from_slice(&export_size.to_be_bytes());
                info.extend_from_slice(&export_flags.to_be_bytes());
                write_option_reply(socket, opt_id, NBD_REP_INFO, &info).await?;
                write_option_reply(socket, opt_id, NBD_REP_ACK, &[]).await?;
                // INFO only queries the export. GO also ends negotiation.
                if opt_id == NBD_OPT_GO {
                    return Ok(true);
                }
            }
            _ => write_option_reply(socket, opt_id, NBD_REP_ERR_UNSUP, &[]).await?,
        }
    }
}

/// Answers transmission-phase requests until the client sends `NBD_CMD_DISC`.
///
/// The export is read-only. Writes, trims and flushes are refused with EPERM. Unknown
/// commands get EINVAL.
async fn serve_requests(socket: &mut TcpStream, snap: &Archive, export_size: u64) -> Result<()> {
    loop {
        let magic = socket.read_u32().await?;
        if magic != NBD_REQUEST_MAGIC {
            anyhow::bail!("Invalid request magic: {magic:x}");
        }
        let _flags = socket.read_u16().await?;
        let type_ = socket.read_u16().await?;
        let handle = socket.read_u64().await?;
        let offset = socket.read_u64().await?;
        let length = socket.read_u32().await?;
        if length > NBD_MAX_BUFFER_SIZE {
            anyhow::bail!("NBD request length too large: {length} bytes");
        }

        match type_ {
            NBD_CMD_READ => {
                // Reject reads that extend past the advertised export size. Also catches
                // `offset + length` overflowing u64.
                let end = offset.checked_add(u64::from(length));
                if !matches!(end, Some(end) if end <= export_size) {
                    write_simple_reply(socket, NBD_EINVAL, handle).await?;
                    continue;
                }
                // NOTE(review): `read_at` is called synchronously on the async task. If it
                // performs blocking disk I/O, consider moving it to `spawn_blocking`.
                match snap.read_at(ArchiveStream::Main, offset, length as usize) {
                    Ok(mut data) => {
                        // The payload must be exactly `length` bytes. A short read is
                        // zero-filled and any excess is dropped, so the stream stays in sync.
                        data.resize(length as usize, 0);
                        write_simple_reply(socket, 0, handle).await?;
                        socket.write_all(&data).await?;
                    }
                    Err(e) => {
                        tracing::error!("Read error: {}", e);
                        // A simple reply carries data only when the error is zero. Sending
                        // padding here would be parsed by the client as the next reply.
                        write_simple_reply(socket, NBD_EIO, handle).await?;
                    }
                }
            }
            NBD_CMD_DISC => return Ok(()),
            NBD_CMD_WRITE | NBD_CMD_TRIM | NBD_CMD_FLUSH => {
                if type_ == NBD_CMD_WRITE {
                    // Drain the write payload so the next request header is read correctly.
                    let mut discard = vec![0u8; length as usize];
                    socket.read_exact(&mut discard).await?;
                }
                write_simple_reply(socket, NBD_EPERM, handle).await?;
            }
            _ => write_simple_reply(socket, NBD_EINVAL, handle).await?,
        }
    }
}

/// Sends one option reply (`NBD_REP_MAGIC`, option id, reply type, payload length,
/// payload) as a single write.
async fn write_option_reply(
    socket: &mut TcpStream,
    opt_id: u32,
    reply_type: u32,
    payload: &[u8],
) -> Result<()> {
    let payload_len = u32::try_from(payload.len())?;
    let mut reply = Vec::with_capacity(20 + payload.len());
    reply.extend_from_slice(&NBD_REP_MAGIC.to_be_bytes());
    reply.extend_from_slice(&opt_id.to_be_bytes());
    reply.extend_from_slice(&reply_type.to_be_bytes());
    reply.extend_from_slice(&payload_len.to_be_bytes());
    reply.extend_from_slice(payload);
    socket.write_all(&reply).await?;
    Ok(())
}

/// Sends the 16-byte simple-reply header (`NBD_REPLY_MAGIC`, error, handle) as a single
/// write. Any read payload is written separately by the caller.
async fn write_simple_reply(socket: &mut TcpStream, error: u32, handle: u64) -> Result<()> {
    let mut header = [0u8; 16];
    header[..4].copy_from_slice(&NBD_REPLY_MAGIC.to_be_bytes());
    header[4..8].copy_from_slice(&error.to_be_bytes());
    header[8..].copy_from_slice(&handle.to_be_bytes());
    socket.write_all(&header).await?;
    Ok(())
}