nfs3_client/
rpc.rs

1//! RPC client implementation
2
3use std::fmt::Debug;
4
5use nfs3_types::rpc::{
6    RPC_VERSION_2, accept_stat_data, call_body, fragment_header, msg_body, opaque_auth, reply_body,
7    rpc_msg,
8};
9use nfs3_types::xdr_codec::{Pack, Unpack};
10
11use crate::error::{Error, RpcError};
12use crate::io::{AsyncRead, AsyncWrite};
13
14/// RPC client
15pub struct RpcClient<IO> {
16    io: IO,
17    xid: u32,
18    credential: opaque_auth<'static>,
19    verifier: opaque_auth<'static>,
20}
21
22impl<IO> Debug for RpcClient<IO> {
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
24        f.debug_struct("RpcClient").finish()
25    }
26}
27
28impl<IO> RpcClient<IO>
29where
30    IO: AsyncRead + AsyncWrite + Send,
31{
32    /// Create a new RPC client. XID is initialized to a random value.
33    pub fn new(io: IO) -> Self {
34        Self::new_with_auth(io, opaque_auth::default(), opaque_auth::default())
35    }
36
37    /// Create a new RPC client with custom credential and verifier.
38    pub fn new_with_auth(
39        io: IO,
40        credential: opaque_auth<'static>,
41        verifier: opaque_auth<'static>,
42    ) -> Self {
43        Self {
44            io,
45            xid: rand::random(),
46            credential,
47            verifier,
48        }
49    }
50
51    /// Call an RPC procedure
52    ///
53    /// This method uses `Pack` trait to serialize the arguments and `Unpack` trait to deserialize
54    /// the reply.
55    #[allow(clippy::similar_names)] // prog and proc are part of call_body struct
56    pub async fn call<C, R>(
57        &mut self,
58        prog: u32,
59        vers: u32,
60        proc: u32,
61        args: &C,
62    ) -> Result<R, Error>
63    where
64        R: Unpack,
65        C: Pack + Send + Sync,
66    {
67        let call = call_body {
68            rpcvers: RPC_VERSION_2,
69            prog,
70            vers,
71            proc,
72            cred: self.credential.borrow(),
73            verf: self.verifier.borrow(),
74        };
75        let msg = rpc_msg {
76            xid: self.xid,
77            body: msg_body::CALL(call),
78        };
79        self.xid = self.xid.wrapping_add(1);
80
81        Self::send_call(&mut self.io, &msg, args).await?;
82        Self::recv_reply::<R>(&mut self.io, msg.xid).await
83    }
84
85    async fn send_call<T>(io: &mut IO, msg: &rpc_msg<'_, '_>, args: &T) -> Result<(), Error>
86    where
87        T: Pack + Send + Sync,
88    {
89        let total_len = msg.packed_size() + args.packed_size();
90        if total_len % 4 != 0 {
91            return Err(RpcError::WrongLength.into());
92        }
93
94        let fragment_header = nfs3_types::rpc::fragment_header::new(
95            u32::try_from(total_len).expect("message is too large"),
96            true,
97        );
98        let mut buf = Vec::with_capacity(total_len + 4);
99        fragment_header.pack(&mut buf)?;
100        msg.pack(&mut buf)?;
101        args.pack(&mut buf)?;
102        if buf.len() - 4 != total_len {
103            return Err(RpcError::WrongLength.into());
104        }
105        io.async_write_all(&buf).await?;
106        Ok(())
107    }
108
109    async fn recv_reply<T>(io: &mut IO, xid: u32) -> Result<T, Error>
110    where
111        T: Unpack,
112    {
113        let mut buf = [0u8; 4];
114        io.async_read_exact(&mut buf).await?;
115        let fragment_header: fragment_header = buf.into();
116        assert!(
117            fragment_header.eof(),
118            "Fragment header does not have EOF flag"
119        );
120
121        let total_len = fragment_header.fragment_length();
122        let mut buf = vec![0u8; total_len as usize];
123        io.async_read_exact(&mut buf).await?;
124
125        let mut cursor = std::io::Cursor::new(buf);
126        let (resp_msg, _) = rpc_msg::unpack(&mut cursor)?;
127
128        if resp_msg.xid != xid {
129            return Err(RpcError::UnexpectedXid.into());
130        }
131
132        let reply = match resp_msg.body {
133            msg_body::REPLY(reply_body::MSG_ACCEPTED(reply)) => reply,
134            msg_body::REPLY(reply_body::MSG_DENIED(r)) => return Err(r.into()),
135            msg_body::CALL(_) => return Err(RpcError::UnexpectedCall.into()),
136        };
137
138        if !matches!(reply.reply_data, accept_stat_data::SUCCESS) {
139            return Err(crate::error::RpcError::try_from(reply.reply_data)
140                .expect("accept_stat_data::SUCCESS is not a valid error")
141                .into());
142        }
143
144        let (final_value, _) = T::unpack(&mut cursor)?;
145        if cursor.position() != u64::from(total_len) {
146            let pos = cursor.position();
147            return Err(RpcError::NotFullyParsed {
148                buf: cursor.into_inner(),
149                pos,
150            }
151            .into());
152        }
153        Ok(final_value)
154    }
155}