use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use std::{io, net::IpAddr};
use anyhow;
use async_trait::async_trait;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
use crate::protocol::nfs::portmap::PortmapTable;
use crate::protocol::{rpc, xdr};
use crate::vfs::NFSFileSystem;
/// An NFS server bound to a TCP socket, serving one filesystem under one export name.
///
/// Construct with [`NFSTcpListener::bind`], optionally configure it, then call
/// [`NFSTcp::handle_forever`] to accept connections.
pub struct NFSTcpListener<T: NFSFileSystem + Send + Sync + 'static> {
    /// The bound listening socket.
    listener: TcpListener,
    /// Local port of `listener`, captured from `local_addr()` at bind time; handed to each
    /// connection's `rpc::Context` as `local_port`.
    port: u16,
    /// The filesystem implementation, shared (via `Arc`) with every connection.
    arcfs: Arc<T>,
    /// Optional sender installed by `set_mount_listener` and cloned into every connection's
    /// context. NOTE(review): presumably used to report mount events — confirm in `rpc`.
    mount_signal: Option<mpsc::Sender<bool>>,
    /// Export path; defaults to `"/"` and is normalized to a single leading `/` (and no
    /// trailing `/`) by `with_export_name`.
    export_name: Arc<String>,
    /// Tracker shared by all connections, created with a 60-second `Duration`.
    /// NOTE(review): presumably used to recognize retransmitted RPC requests — confirm.
    transaction_tracker: Arc<rpc::TransactionTracker>,
    /// Portmap table shared by all connections; starts as `PortmapTable::default()`.
    portmap_table: Arc<RwLock<PortmapTable>>,
    /// When `true`, connections whose source port is 1024 or higher are dropped on accept.
    require_privileged_source_port: bool,
}
/// Returns the loopback address `127.88.<high>.<low>` as a string, where `high` and `low`
/// are the big-endian bytes of `hostnum`.
///
/// `NFSTcpListener::bind` uses this in `auto:` mode to enumerate candidate addresses.
pub fn generate_host_ip(hostnum: u16) -> String {
    let [high, low] = hostnum.to_be_bytes();
    format!("127.88.{high}.{low}")
}
/// Drives one accepted connection until the client disconnects or an error occurs.
///
/// Bytes read from `socket` are forwarded into the write half returned by
/// `rpc::SocketMessageHandler::new`. A detached task repeatedly calls
/// `message_handler.read()` to consume them. Each reply received on `msgrecvchan` is written
/// back to the client with `rpc::write_fragment`.
///
/// # Errors
///
/// Returns `Ok(())` when the client closes the connection (a zero-byte read). Returns an
/// error when:
/// - reading from the socket fails;
/// - forwarding bytes to the message handler fails;
/// - writing a reply fails;
/// - the handler yields an error;
/// - the reply channel closes.
async fn process_socket(
    mut socket: tokio::net::TcpStream,
    context: rpc::Context,
) -> Result<(), anyhow::Error> {
    let (mut message_handler, mut socksend, mut msgrecvchan) =
        rpc::SocketMessageHandler::new(&context);
    // Best-effort: disabling Nagle's algorithm avoids delaying small replies.
    let _ = socket.set_nodelay(true);
    tokio::spawn(async move {
        loop {
            if let Err(e) = message_handler.read().await {
                debug!("Message loop broken due to {:?}", e);
                break;
            }
        }
    });
    // Allocated once per connection and reused. Previously a fresh 128 KB array was
    // zero-filled on every readable event.
    let mut buf = vec![0u8; 128_000];
    loop {
        tokio::select! {
            _ = socket.readable() => {
                match socket.try_read(&mut buf) {
                    Ok(0) => {
                        return Ok(());
                    }
                    Ok(n) => {
                        // If the handler side is gone these bytes can never be processed;
                        // close the connection rather than silently discarding requests.
                        if let Err(e) = socksend.write_all(&buf[..n]).await {
                            debug!("Message handling closed : {:?}", e);
                            return Err(e.into());
                        }
                    }
                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
                        // Spurious readiness: wait for the next readable event.
                    }
                    Err(e) => {
                        debug!("Message handling closed : {:?}", e);
                        return Err(e.into());
                    }
                }
            },
            reply = msgrecvchan.recv() => {
                match reply {
                    Some(Err(e)) => {
                        debug!("Message handling closed : {:?}", e);
                        return Err(e);
                    }
                    Some(Ok(msg)) => {
                        if let Err(e) = rpc::write_fragment(&mut socket, &msg).await {
                            // After a failed (possibly partial) fragment write, the outgoing
                            // record stream is in an unknown state, so stop using it.
                            error!("Write error {:?}", e);
                            return Err(e.into());
                        }
                    }
                    None => {
                        return Err(anyhow::anyhow!("Unexpected socket context termination"));
                    }
                }
            }
        }
    }
}
/// Interface for a TCP listener that serves NFS requests.
#[async_trait]
pub trait NFSTcp: Send + Sync {
    /// Returns the local TCP port the listener is bound to.
    fn get_listen_port(&self) -> u16;
    /// Returns the local IP address the listener is bound to.
    fn get_listen_ip(&self) -> IpAddr;
    /// Installs a channel sender for mount signalling. `NFSTcpListener` clones it into the
    /// `rpc::Context` of every connection accepted afterwards.
    /// NOTE(review): the meaning of the `bool` payload is defined in `rpc` — confirm there.
    fn set_mount_listener(&mut self, signal: mpsc::Sender<bool>);
    /// Runs the accept loop, serving each connection concurrently. Its only way to return is
    /// with an error from accepting a connection.
    async fn handle_forever(&self) -> io::Result<()>;
}
impl<T: NFSFileSystem + Send + Sync + 'static> NFSTcpListener<T> {
pub async fn bind(ipstr: &str, fs: T) -> io::Result<NFSTcpListener<T>> {
let arcfs: Arc<T> = Arc::new(fs);
if let Some(port_str) = ipstr.strip_prefix("auto:") {
let port = port_str.parse::<u16>().map_err(|_| {
io::Error::new(io::ErrorKind::AddrNotAvailable, "Port not in range 0..=65535")
})?;
const NUM_TRIES: u16 = 32;
for try_ip in 1..=NUM_TRIES {
let ip: IpAddr = generate_host_ip(try_ip).parse().map_err(|_| {
io::Error::new(io::ErrorKind::AddrNotAvailable, "Invalid auto IP address")
})?;
let addr = SocketAddr::new(ip, port);
let result = NFSTcpListener::bind_internal(addr, arcfs.clone()).await;
if result.is_ok() {
return result;
}
}
return Err(io::Error::other("Can't bind automatically"));
}
let addr = ipstr.parse::<SocketAddr>().map_err(|_| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
"Address must be in IP:PORT or [IPV6]:PORT form",
)
})?;
NFSTcpListener::bind_internal(addr, arcfs).await
}
async fn bind_internal(addr: SocketAddr, arcfs: Arc<T>) -> io::Result<NFSTcpListener<T>> {
let listener = TcpListener::bind(addr).await?;
let local_addr = listener.local_addr()?;
info!("Listening on {:?}", local_addr);
let port = local_addr.port();
Ok(NFSTcpListener {
listener,
port,
arcfs,
mount_signal: None,
export_name: Arc::from("/".to_string()),
transaction_tracker: Arc::new(rpc::TransactionTracker::new(Duration::from_secs(60))),
portmap_table: Arc::from(RwLock::from(PortmapTable::default())),
require_privileged_source_port: false,
})
}
pub fn with_export_name<S: AsRef<str>>(&mut self, export_name: S) {
self.export_name = Arc::new(format!(
"/{}",
export_name.as_ref().trim_end_matches('/').trim_start_matches('/')
));
}
pub fn require_privileged_source_port(&mut self, require: bool) {
self.require_privileged_source_port = require;
}
}
#[async_trait]
impl<T: NFSFileSystem + Send + Sync + 'static> NFSTcp for NFSTcpListener<T> {
fn get_listen_port(&self) -> u16 {
let addr = self.listener.local_addr().unwrap();
addr.port()
}
fn get_listen_ip(&self) -> IpAddr {
let addr = self.listener.local_addr().unwrap();
addr.ip()
}
fn set_mount_listener(&mut self, signal: mpsc::Sender<bool>) {
self.mount_signal = Some(signal);
}
async fn handle_forever(&self) -> io::Result<()> {
loop {
let (socket, peer_addr) = self.listener.accept().await?;
if self.require_privileged_source_port && peer_addr.port() >= 1024 {
warn!(
"Rejecting connection from {}: source port {} is not privileged",
peer_addr.ip(),
peer_addr.port()
);
continue;
}
let context = rpc::Context {
local_port: self.port,
client_addr: peer_addr.to_string(),
auth: xdr::rpc::auth_unix::default(),
vfs: self.arcfs.clone(),
mount_signal: self.mount_signal.clone(),
export_name: self.export_name.clone(),
transaction_tracker: self.transaction_tracker.clone(),
portmap_table: self.portmap_table.clone(),
};
info!("Accepting connection from {}", context.client_addr);
debug!("Accepting socket {:?} {:?}", socket, context);
tokio::spawn(async move {
let _ = process_socket(socket, context).await;
});
}
}
}