nfs 0.1.0

A userspace NFSv3 and NFSv4 client library.
Documentation
use std::time::Duration;

use ::tokio::io::{AsyncReadExt, AsyncWriteExt};
use ::tokio::net::{TcpStream, ToSocketAddrs};

use crate::error::{Error, Result};
use crate::rpc::{
    Auth, DEFAULT_MAX_RECORD_SIZE, FRAGMENT_LEN_MASK, LAST_FRAGMENT, decode_reply, default_stamp,
    encode_call,
};
use crate::xdr::Encode;

/// A single-connection ONC RPC client that frames calls and replies with TCP
/// record marking.
///
/// Only one call is in flight at a time: `call` takes `&mut self` and waits
/// for its reply before returning.
#[derive(Debug)]
pub(crate) struct RpcClient {
    /// Connected TCP socket carrying the framed RPC records.
    stream: TcpStream,
    /// Most recently issued transaction ID; `next_xid` advances it before each call.
    xid: u32,
    /// Credentials passed to `encode_call` for every request.
    auth: Auth,
    /// Upper bound, in bytes, on a reassembled incoming record.
    max_record_size: usize,
    /// Deadline applied separately to each connect/read/write/flush step;
    /// `None` means no deadline.
    timeout: Option<Duration>,
}

impl RpcClient {
    pub(crate) async fn connect_with_timeout<A: ToSocketAddrs>(
        addr: A,
        auth: Auth,
        timeout: Option<Duration>,
    ) -> Result<Self> {
        let stream = connect_tcp_stream(addr, timeout).await?;
        stream.set_nodelay(true)?;
        Ok(Self {
            stream,
            xid: default_stamp(),
            auth,
            max_record_size: DEFAULT_MAX_RECORD_SIZE,
            timeout,
        })
    }

    pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
        self.timeout = timeout;
    }

    pub(crate) fn set_max_record_size(&mut self, max_record_size: usize) {
        self.max_record_size = max_record_size;
    }

    pub(crate) async fn call<T: Encode + ?Sized>(
        &mut self,
        program: u32,
        version: u32,
        procedure: u32,
        args: &T,
    ) -> Result<Vec<u8>> {
        let xid = self.next_xid();
        let request = encode_call(xid, program, version, procedure, &self.auth, args)?;
        self.write_record(&request).await?;
        let reply = self.read_record().await?;
        decode_reply(xid, &reply)
    }

    fn next_xid(&mut self) -> u32 {
        self.xid = self.xid.wrapping_add(1);
        if self.xid == 0 {
            self.xid = 1;
        }
        self.xid
    }

    async fn write_record(&mut self, payload: &[u8]) -> Result<()> {
        if payload.len() > FRAGMENT_LEN_MASK as usize {
            return Err(Error::RpcRecordTooLarge {
                len: payload.len(),
                max: FRAGMENT_LEN_MASK as usize,
            });
        }

        let len = u32::try_from(payload.len()).map_err(|_| Error::RpcRecordTooLarge {
            len: payload.len(),
            max: FRAGMENT_LEN_MASK as usize,
        })?;
        let header = LAST_FRAGMENT | len;
        write_all(&mut self.stream, &header.to_be_bytes(), self.timeout).await?;
        write_all(&mut self.stream, payload, self.timeout).await?;
        flush(&mut self.stream, self.timeout).await
    }

    async fn read_record(&mut self) -> Result<Vec<u8>> {
        let mut record = Vec::new();
        loop {
            let mut header_bytes = [0; 4];
            read_exact(&mut self.stream, &mut header_bytes, self.timeout).await?;
            let header = u32::from_be_bytes(header_bytes);
            let is_last = (header & LAST_FRAGMENT) != 0;
            let fragment_len = (header & FRAGMENT_LEN_MASK) as usize;

            if record.len().saturating_add(fragment_len) > self.max_record_size {
                return Err(Error::RpcRecordTooLarge {
                    len: record.len().saturating_add(fragment_len),
                    max: self.max_record_size,
                });
            }

            let start = record.len();
            record.resize(start + fragment_len, 0);
            read_exact(&mut self.stream, &mut record[start..], self.timeout).await?;

            if is_last {
                return Ok(record);
            }
        }
    }
}

async fn connect_tcp_stream<A: ToSocketAddrs>(
    addr: A,
    timeout: Option<Duration>,
) -> Result<TcpStream> {
    if let Some(timeout) = timeout {
        ::tokio::time::timeout(timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| timeout_error())?
            .map_err(Error::from)
    } else {
        TcpStream::connect(addr).await.map_err(Error::from)
    }
}

async fn write_all(stream: &mut TcpStream, buf: &[u8], timeout: Option<Duration>) -> Result<()> {
    if let Some(timeout) = timeout {
        ::tokio::time::timeout(timeout, stream.write_all(buf))
            .await
            .map_err(|_| timeout_error())??;
    } else {
        stream.write_all(buf).await?;
    }
    Ok(())
}

async fn flush(stream: &mut TcpStream, timeout: Option<Duration>) -> Result<()> {
    if let Some(timeout) = timeout {
        ::tokio::time::timeout(timeout, stream.flush())
            .await
            .map_err(|_| timeout_error())??;
    } else {
        stream.flush().await?;
    }
    Ok(())
}

/// Fills all of `buf` from `stream`.
///
/// If `timeout` is set and elapses before `buf` is full, a `TimedOut` I/O
/// error is returned. The byte count reported by the underlying read is
/// discarded.
async fn read_exact(
    stream: &mut TcpStream,
    buf: &mut [u8],
    timeout: Option<Duration>,
) -> Result<()> {
    let filled = match timeout {
        None => stream.read_exact(buf).await,
        Some(limit) => ::tokio::time::timeout(limit, stream.read_exact(buf))
            .await
            .map_err(|_| timeout_error())?,
    };
    filled?;
    Ok(())
}

fn timeout_error() -> Error {
    Error::Io(std::io::Error::new(
        std::io::ErrorKind::TimedOut,
        "NFS async operation timed out",
    ))
}