use anyhow::anyhow;
use std::io::Cursor;
use std::io::{Read, Write};
use tracing::{error, trace, warn};
use crate::context::RPCContext;
use crate::rpc::*;
use crate::xdr::*;
use crate::mount;
use crate::mount_handlers;
use crate::nfs;
use crate::nfs_handlers;
use crate::portmap;
use crate::portmap_handlers;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::DuplexStream;
use tokio::sync::mpsc;
/// ONC RPC program number of the NFS ACL side protocol. `handle_rpc` answers
/// calls to it with a `prog_unavail_reply_message`.
const NFS_ACL_PROGRAM: u32 = 100_227;
/// RPC program number 100270 (an ID-mapping service, judging by the name —
/// confirm). `handle_rpc` answers calls to it with a `prog_unavail_reply_message`.
const NFS_ID_MAP_PROGRAM: u32 = 100_270;
/// RPC program number 200024 (a metadata service, judging by the name —
/// confirm). `handle_rpc` answers calls to it with a `prog_unavail_reply_message`.
const NFS_METADATA_PROGRAM: u32 = 200_024;
/// Decodes one complete ONC RPC message from `input` and dispatches the call
/// to the NFS, portmap, or mount handler chosen by its program number. The
/// serialized reply is written to `output`.
///
/// Calls whose RPC version is not 2 are answered with `rpc_vers_mismatch`.
/// This check happens before the credential is inspected, so a version
/// mismatch is always reported as such. For version-2 calls carrying an
/// `AUTH_UNIX` credential, the credential is decoded into `context.auth`
/// before dispatch.
///
/// Calls for the NFS ACL, ID-map, or metadata programs, and for any unknown
/// program, are answered with `prog_unavail_reply_message`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the message or its `AUTH_UNIX` credential cannot be decoded;
/// - the message is not a CALL;
/// - a reply cannot be serialized;
/// - the selected program handler fails.
async fn handle_rpc(
    input: &mut impl Read,
    output: &mut impl Write,
    mut context: RPCContext,
) -> Result<(), anyhow::Error> {
    let mut recv = rpc_msg::default();
    recv.deserialize(input)?;
    let xid = recv.xid;

    let call = match recv.body {
        rpc_body::CALL(call) => call,
        _ => {
            error!("Unexpectedly received a Reply instead of a Call");
            return Err(anyhow!("Bad RPC Call format"));
        }
    };

    // Check the protocol version first. A version-mismatched call must get an
    // RPC-mismatch reply even if its credential body cannot be decoded.
    if call.rpcvers != 2 {
        warn!("Invalid RPC version {} != 2", call.rpcvers);
        rpc_vers_mismatch(xid).serialize(output)?;
        return Ok(());
    }

    if let auth_flavor::AUTH_UNIX = call.cred.flavor {
        let mut auth = auth_unix::default();
        auth.deserialize(&mut Cursor::new(&call.cred.body))?;
        context.auth = auth;
    }

    if call.prog == nfs::PROGRAM {
        nfs_handlers::handle_nfs(xid, call, input, output, &context).await
    } else if call.prog == portmap::PROGRAM {
        portmap_handlers::handle_portmap(xid, call, input, output, &context)
    } else if call.prog == mount::PROGRAM {
        mount_handlers::handle_mount(xid, call, input, output, &context).await
    } else if call.prog == NFS_ACL_PROGRAM
        || call.prog == NFS_ID_MAP_PROGRAM
        || call.prog == NFS_METADATA_PROGRAM
    {
        // Auxiliary programs that some clients probe. Report them as
        // unavailable so the client falls back instead of retrying.
        trace!(
            "ignoring call to unsupported auxiliary RPC program {}",
            call.prog
        );
        prog_unavail_reply_message(xid).serialize(output)?;
        Ok(())
    } else {
        warn!("Unknown RPC Program number {}", call.prog);
        prog_unavail_reply_message(xid).serialize(output)?;
        Ok(())
    }
}
/// Reads one RPC record-marking fragment from `socket` and appends its payload
/// to `append_to`.
///
/// Each fragment starts with a 4-byte big-endian header. The top bit marks the
/// last fragment of a record, and the low 31 bits give the payload length.
///
/// Returns `Ok(true)` when the fragment just read was the last one of its
/// record.
///
/// # Errors
/// Returns an error if reading the header or payload fails, or if the stream
/// ends before the whole payload arrives. If the payload read fails,
/// `append_to` is truncated back to its original length, so no partial
/// fragment is left in it.
async fn read_fragment(
    socket: &mut DuplexStream,
    append_to: &mut Vec<u8>,
) -> Result<bool, anyhow::Error> {
    const LAST_FRAGMENT_BIT: u32 = 1 << 31;

    let mut header_buf = [0_u8; 4];
    socket.read_exact(&mut header_buf).await?;
    let fragment_header = u32::from_be_bytes(header_buf);
    let is_last = (fragment_header & LAST_FRAGMENT_BIT) != 0;
    let length = (fragment_header & !LAST_FRAGMENT_BIT) as usize;
    trace!("Reading fragment length:{}, last:{}", length, is_last);

    let start_offset = append_to.len();
    // The length comes from the peer, so it is untrusted. Resizing the buffer
    // up front would let a single 4-byte header force a zero-filled allocation
    // of up to 2 GiB before any payload arrives. Reading through `take` grows
    // the buffer only as bytes are actually received.
    let read_result = (&mut *socket)
        .take(length as u64)
        .read_to_end(append_to)
        .await;
    let received = match read_result {
        Ok(n) => n,
        Err(e) => {
            append_to.truncate(start_offset);
            return Err(e.into());
        }
    };
    if received != length {
        // `take` stopped early, so the stream hit EOF mid-fragment.
        append_to.truncate(start_offset);
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "stream ended before the full RPC fragment was received",
        )
        .into());
    }

    trace!(
        "Finishing Reading fragment length:{}, last:{}",
        length,
        is_last
    );
    Ok(is_last)
}
pub async fn write_fragment(
socket: &mut tokio::net::TcpStream,
buf: &Vec<u8>,
) -> Result<(), anyhow::Error> {
assert!(buf.len() < (1 << 31));
let fragment_header = buf.len() as u32 + (1 << 31);
let header_buf = u32::to_be_bytes(fragment_header);
socket.write_all(&header_buf).await?;
trace!("Writing fragment length:{}", buf.len());
socket.write_all(buf).await?;
Ok(())
}
/// Item delivered on the reply channel once per fully received RPC record.
/// It holds either the serialized reply bytes produced by `handle_rpc`, or the
/// error `handle_rpc` returned for that record.
pub type SocketMessageType = Result<Vec<u8>, anyhow::Error>;
/// Reassembles RPC records from record-marked bytes read off an in-memory
/// duplex pipe. Each complete record is handled on its own spawned task, and
/// the result is forwarded over an unbounded channel.
#[derive(Debug)]
pub struct SocketMessageHandler {
    /// Payload bytes of the record currently being reassembled; fragments are
    /// appended here until the last one arrives.
    cur_fragment: Vec<u8>,
    /// This handler's end of the duplex pipe created in `new`. The other end is
    /// returned to the caller, presumably to forward raw socket bytes into
    /// (confirm at call site).
    socket_receive_channel: DuplexStream,
    /// Cloned into every spawned task so each task can deliver its result.
    reply_send_channel: mpsc::UnboundedSender<SocketMessageType>,
    /// Per-connection context, cloned for every dispatched record.
    context: RPCContext,
}
impl SocketMessageHandler {
    /// Builds a handler that owns a clone of `context`.
    ///
    /// Returns a tuple of three items:
    /// - the handler itself;
    /// - the other end of a 256000-byte in-memory duplex pipe, through which
    ///   record-marked bytes are fed to the handler;
    /// - the receiver on which one `SocketMessageType` arrives per completed
    ///   record.
    pub fn new(
        context: &RPCContext,
    ) -> (
        Self,
        DuplexStream,
        mpsc::UnboundedReceiver<SocketMessageType>,
    ) {
        let (pipe_writer, pipe_reader) = tokio::io::duplex(256000);
        let (reply_tx, reply_rx) = mpsc::unbounded_channel();
        let handler = Self {
            cur_fragment: Vec::new(),
            socket_receive_channel: pipe_reader,
            reply_send_channel: reply_tx,
            context: context.clone(),
        };
        (handler, pipe_writer, reply_rx)
    }

    /// Reads one fragment from the pipe into the pending record.
    ///
    /// When that fragment completes the record, the record is moved out and
    /// handed to `handle_rpc` on a freshly spawned task. That task sends
    /// `Ok(reply bytes)` or `Err(error)` over the reply channel. Each record
    /// runs on its own task, so nothing here orders replies relative to
    /// requests.
    ///
    /// # Errors
    /// Fails only if reading the fragment fails. Errors from handling the
    /// record go to the reply channel instead.
    pub async fn read(&mut self) -> Result<(), anyhow::Error> {
        let record_complete =
            read_fragment(&mut self.socket_receive_channel, &mut self.cur_fragment).await?;
        if !record_complete {
            return Ok(());
        }

        let record = std::mem::take(&mut self.cur_fragment);
        let ctx = self.context.clone();
        let replies = self.reply_send_channel.clone();
        tokio::spawn(async move {
            let mut reply_buf: Vec<u8> = Vec::new();
            let mut reply_writer = Cursor::new(&mut reply_buf);
            let mut request_reader = Cursor::new(record);
            let outcome = handle_rpc(&mut request_reader, &mut reply_writer, ctx).await;
            let message: SocketMessageType = match outcome {
                Ok(_) => {
                    // Fully qualified: `AsyncWriteExt::flush` also applies to
                    // `Cursor` and would make a plain `.flush()` ambiguous.
                    let _ = std::io::Write::flush(&mut reply_writer);
                    Ok(reply_buf)
                }
                Err(e) => {
                    error!("RPC Error: {:?}", e);
                    Err(e)
                }
            };
            // The receiver may already be gone; that is not this task's concern.
            let _ = replies.send(message);
        });
        Ok(())
    }
}