use std::io::Cursor;
use std::io::{Read, Write};
use std::sync::Arc;
use anyhow::anyhow;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::DuplexStream;
use tokio::sync::mpsc;
use tracing::{debug, error, trace, warn};
use crate::protocol::rpc::command_queue::{CommandQueue, CommandResult, ResponseBuffer};
use crate::protocol::xdr::{self, deserialize, mount, nfs3, portmap, Serialize};
use crate::protocol::{nfs, rpc};
/// RPC program numbers that `handle_rpc` recognizes but does not implement;
/// calls to any of them are answered with a PROG_UNAVAIL reply.
const NFS_ACL_PROGRAM: u32 = 100227;
const NFS_ID_MAP_PROGRAM: u32 = 100270;
const NFS_LOCALIO_PROGRAM: u32 = 400122;
const NFS_METADATA_PROGRAM: u32 = 200024;
/// Passed to `CommandQueue::new` in `SocketMessageHandler::new`; presumably the
/// initial capacity (in bytes) of each response buffer — TODO confirm against CommandQueue.
const DEFAULT_RESPONSE_BUFFER_CAPACITY: usize = 8192;
/// What to do with the output produced by a single `handle_rpc` invocation.
enum RpcOutcome {
    /// `output` holds a reply for transaction `xid`. `record_response` tells the
    /// caller whether to store the reply in the transaction tracker; it is `false`
    /// when the reply was itself replayed from the tracker.
    Send { xid: u32, record_response: bool },
    /// No reply is produced (a retransmission of a call that is still in progress).
    Drop,
}
async fn handle_rpc(
input: &mut impl Read,
output: &mut impl Write,
mut context: rpc::Context,
) -> Result<RpcOutcome, anyhow::Error> {
let recv = deserialize::<xdr::rpc::rpc_msg>(input)?;
let xid = recv.xid;
if let xdr::rpc::rpc_body::CALL(call) = recv.body {
if let xdr::rpc::auth_flavor::AUTH_UNIX = call.cred.flavor {
context.auth = deserialize(&mut Cursor::new(&call.cred.body))?;
}
let status = context.transaction_tracker.check(xid, &context.client_addr);
match status {
rpc::TransactionStatus::Completed(response) => {
output.write_all(&response)?;
return Ok(RpcOutcome::Send { xid, record_response: false });
}
rpc::TransactionStatus::InProgress => {
debug!(
"Retransmission in progress, xid: {}, client_addr: {}, call: {:?}",
xid, context.client_addr, call
);
return Ok(RpcOutcome::Drop);
}
rpc::TransactionStatus::New => {}
}
if call.rpcvers != 2 {
warn!("Invalid RPC version {} != 2", call.rpcvers);
xdr::rpc::rpc_vers_mismatch(xid).serialize(output)?;
return Ok(RpcOutcome::Send { xid, record_response: true });
}
let res = {
match call.prog {
nfs3::PROGRAM => match call.vers {
nfs3::VERSION => nfs::v3::handle_nfs(xid, call, input, output, &context).await,
_ => {
warn!(
"Unsupported NFS program version {} (supported {})",
call.vers,
nfs3::VERSION
);
xdr::rpc::prog_mismatch_reply_message(xid, nfs3::VERSION)
.serialize(output)?;
Ok(())
}
},
portmap::PROGRAM => {
nfs::portmap::handle_portmap(xid, &call, input, output, &mut context)
}
mount::PROGRAM => {
nfs::mount::handle_mount(xid, call, input, output, &context).await
}
NFS_ACL_PROGRAM | NFS_ID_MAP_PROGRAM | NFS_METADATA_PROGRAM => {
trace!("ignoring NFS_ACL packet");
xdr::rpc::prog_unavail_reply_message(xid).serialize(output)?;
Ok(())
}
NFS_LOCALIO_PROGRAM => {
trace!("Ignoring NFS_LOCALIO packet");
xdr::rpc::prog_unavail_reply_message(xid).serialize(output)?;
Ok(())
}
unknown_number => {
warn!("Unknown RPC Program number {} != {}", unknown_number, nfs3::PROGRAM);
xdr::rpc::prog_unavail_reply_message(xid).serialize(output)?;
Ok(())
}
}
};
match res {
Ok(()) => Ok(RpcOutcome::Send { xid, record_response: true }),
Err(e) => {
context.transaction_tracker.clear(xid, &context.client_addr);
Err(e)
}
}
} else {
error!("Unexpectedly received a Reply instead of a Call");
Err(anyhow!("Bad RPC Call format"))
}
}
/// Reads one record-marking fragment from `socket` and appends its payload to `append_to`.
///
/// Each fragment begins with a 4-byte big-endian header. The top bit flags the
/// final fragment of a record, and the low 31 bits give the payload length.
///
/// Returns `true` when the fragment just read was the last one of its record.
///
/// # Errors
/// Fails if the socket cannot supply the full header or payload. Also fails if
/// appending this payload would grow the record past `rpc::MAX_RPC_RECORD_LENGTH`;
/// the error reports the size the record would have reached.
async fn read_fragment(
    socket: &mut DuplexStream,
    append_to: &mut Vec<u8>,
) -> Result<bool, anyhow::Error> {
    const LAST_FRAGMENT_FLAG: u32 = 1 << 31;

    let mut header_buf = [0_u8; 4];
    socket.read_exact(&mut header_buf).await?;
    let fragment_header = u32::from_be_bytes(header_buf);
    let is_last = (fragment_header & LAST_FRAGMENT_FLAG) != 0;
    let length = (fragment_header & !LAST_FRAGMENT_FLAG) as usize;
    trace!("Reading fragment length:{}, last:{}", length, is_last);

    // Bound the whole reassembled record, not just this fragment, so a peer cannot
    // grow the buffer without limit by splitting a record into many fragments.
    let record_length = append_to.len().saturating_add(length);
    if record_length > rpc::MAX_RPC_RECORD_LENGTH {
        return Err(anyhow!(
            "RPC record length {} exceeds max {}",
            record_length,
            rpc::MAX_RPC_RECORD_LENGTH
        ));
    }

    let start_offset = append_to.len();
    // `record_length` did not saturate: it is bounded by the check above.
    append_to.resize(record_length, 0);
    socket.read_exact(&mut append_to[start_offset..]).await?;
    trace!("Finishing Reading fragment length:{}, last:{}", length, is_last);
    Ok(is_last)
}
/// Writes `buf` to `socket` as a single record-marked RPC record.
///
/// The payload is split into fragments of at most 2^31 - 1 bytes. Each fragment
/// is preceded by a 4-byte big-endian header holding its length, with the top bit
/// set on the final fragment. An empty `buf` writes nothing.
///
/// # Errors
/// Propagates any error from writing to the socket.
pub async fn write_fragment(
    socket: &mut tokio::net::TcpStream,
    buf: &[u8],
) -> Result<(), anyhow::Error> {
    // The header's length field is 31 bits wide.
    const MAX_FRAGMENT_SIZE: usize = (1 << 31) - 1;
    const LAST_FRAGMENT_FLAG: u32 = 1 << 31;

    let pieces = buf.chunks(MAX_FRAGMENT_SIZE);
    let final_index = pieces.len().saturating_sub(1);
    for (index, piece) in pieces.enumerate() {
        let last = index == final_index;
        // Cannot truncate: every piece is at most MAX_FRAGMENT_SIZE bytes.
        let piece_len = piece.len() as u32;
        let marker = if last { piece_len | LAST_FRAGMENT_FLAG } else { piece_len };
        socket.write_all(&marker.to_be_bytes()).await?;
        trace!("Writing fragment length:{}, last:{}", piece.len(), last);
        socket.write_all(piece).await?;
    }
    Ok(())
}
/// One item on a connection's outbound channel: either the bytes of an encoded
/// RPC reply, or the error that aborted processing of a record.
pub type SocketMessageType = anyhow::Result<Vec<u8>>;
/// Reassembles record-marked RPC messages arriving on an in-memory duplex pipe
/// and submits each complete record to a [`CommandQueue`].
#[derive(Debug)]
pub struct SocketMessageHandler {
    /// Bytes of the record currently being reassembled from its fragments.
    cur_fragment: Vec<u8>,
    /// Read side of the duplex pipe; the other end is returned by `new`.
    socket_receive_channel: DuplexStream,
    /// Connection context, cloned into every submitted command.
    context: rpc::Context,
    /// Queue built with `process_rpc_command` as its command handler.
    command_queue: CommandQueue,
}
impl SocketMessageHandler {
pub fn new(
context: &rpc::Context,
) -> (Self, DuplexStream, mpsc::UnboundedReceiver<SocketMessageType>) {
let (socksend, sockrecv) = tokio::io::duplex(256_000);
let (msgsend, msgrecv) = mpsc::unbounded_channel();
let (result_sender, mut result_receiver) = mpsc::unbounded_channel::<CommandResult>();
let command_queue =
CommandQueue::new(process_rpc_command, result_sender, DEFAULT_RESPONSE_BUFFER_CAPACITY);
tokio::spawn(async move {
while let Some(result) = result_receiver.recv().await {
match result {
Ok(Some(response_buffer)) if response_buffer.has_content() => {
let _ = msgsend.send(Ok(response_buffer.into_inner()));
}
Ok(None) => {
}
Ok(Some(_)) => {
}
Err(e) => {
error!("RPC error: {:?}", e);
let _ = msgsend.send(Err(e));
}
}
}
debug!("Command result handler finished");
});
(
Self {
cur_fragment: Vec::new(),
socket_receive_channel: sockrecv,
context: context.clone(),
command_queue,
},
socksend,
msgrecv,
)
}
pub async fn read(&mut self) -> Result<(), anyhow::Error> {
let is_last =
read_fragment(&mut self.socket_receive_channel, &mut self.cur_fragment).await?;
if is_last {
let fragment_data = std::mem::take(&mut self.cur_fragment);
let context = self.context.clone();
if let Err(e) = self.command_queue.submit_command(fragment_data, context) {
error!("Failed to submit command to queue: {:?}", e);
return Err(anyhow::anyhow!("Command queue error: {}", e));
}
}
Ok(())
}
}
/// Runs one complete RPC record through `handle_rpc` on behalf of the command queue.
///
/// `data` is copied, so the returned future owns its input and borrows only
/// `output` (for `'a`). The reply is encoded into `output`'s buffer.
///
/// Returns `Ok(true)` when `handle_rpc` produced a reply and `Ok(false)` when it
/// dropped the call. A reply is stored in the transaction tracker under its xid
/// and client address only when `handle_rpc` asked for it to be recorded and the
/// reply is non-empty.
///
/// # Errors
/// Propagates any error returned by `handle_rpc`.
pub fn process_rpc_command<'a>(
    data: &[u8],
    output: &'a mut ResponseBuffer,
    context: rpc::Context,
) -> futures::future::BoxFuture<'a, anyhow::Result<bool>> {
    let request = data.to_vec();
    Box::pin(async move {
        let reply = output.get_mut_buffer();
        let tracker = context.transaction_tracker.clone();
        let client_addr = context.client_addr.clone();

        let outcome = {
            let mut reader = Cursor::new(request);
            let mut writer = Cursor::new(&mut *reply);
            handle_rpc(&mut reader, &mut writer, context).await?
        };

        match outcome {
            RpcOutcome::Drop => Ok(false),
            RpcOutcome::Send { xid, record_response } => {
                let worth_recording = record_response && !reply.is_empty();
                if worth_recording {
                    tracker.record_response(xid, &client_addr, Arc::new(reply.clone()));
                }
                Ok(true)
            }
        }
    })
}