1use std::fmt::Debug;
4
5use nfs3_types::rpc::{
6 RPC_VERSION_2, accept_stat_data, call_body, fragment_header, msg_body, opaque_auth, reply_body,
7 rpc_msg,
8};
9use nfs3_types::xdr_codec::{Pack, Unpack};
10
11use crate::error::{Error, RpcError};
12use crate::io::{AsyncRead, AsyncWrite};
13
/// ONC RPC (version 2) client that issues calls over a byte-stream transport `IO`.
///
/// Every call carries the same credential/verifier pair and a transaction id (xid) that
/// advances by one (wrapping) per call.
pub struct RpcClient<IO> {
    /// Transport the call records are written to and the reply records are read from.
    io: IO,
    /// Transaction id to stamp on the next call; seeded randomly at construction.
    xid: u32,
    /// Credential attached to every outgoing call.
    credential: opaque_auth<'static>,
    /// Verifier attached to every outgoing call.
    verifier: opaque_auth<'static>,
}
21
22impl<IO> Debug for RpcClient<IO> {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
24 f.debug_struct("RpcClient").finish()
25 }
26}
27
28impl<IO> RpcClient<IO>
29where
30 IO: AsyncRead + AsyncWrite + Send,
31{
32 pub fn new(io: IO) -> Self {
34 Self::new_with_auth(io, opaque_auth::default(), opaque_auth::default())
35 }
36
37 pub fn new_with_auth(
39 io: IO,
40 credential: opaque_auth<'static>,
41 verifier: opaque_auth<'static>,
42 ) -> Self {
43 Self {
44 io,
45 xid: rand::random(),
46 credential,
47 verifier,
48 }
49 }
50
51 #[allow(clippy::similar_names)] pub async fn call<C, R>(
57 &mut self,
58 prog: u32,
59 vers: u32,
60 proc: u32,
61 args: &C,
62 ) -> Result<R, Error>
63 where
64 R: Unpack,
65 C: Pack + Send + Sync,
66 {
67 let call = call_body {
68 rpcvers: RPC_VERSION_2,
69 prog,
70 vers,
71 proc,
72 cred: self.credential.borrow(),
73 verf: self.verifier.borrow(),
74 };
75 let msg = rpc_msg {
76 xid: self.xid,
77 body: msg_body::CALL(call),
78 };
79 self.xid = self.xid.wrapping_add(1);
80
81 Self::send_call(&mut self.io, &msg, args).await?;
82 Self::recv_reply::<R>(&mut self.io, msg.xid).await
83 }
84
85 async fn send_call<T>(io: &mut IO, msg: &rpc_msg<'_, '_>, args: &T) -> Result<(), Error>
86 where
87 T: Pack + Send + Sync,
88 {
89 let total_len = msg.packed_size() + args.packed_size();
90 if total_len % 4 != 0 {
91 return Err(RpcError::WrongLength.into());
92 }
93
94 let fragment_header = nfs3_types::rpc::fragment_header::new(
95 u32::try_from(total_len).expect("message is too large"),
96 true,
97 );
98 let mut buf = Vec::with_capacity(total_len + 4);
99 fragment_header.pack(&mut buf)?;
100 msg.pack(&mut buf)?;
101 args.pack(&mut buf)?;
102 if buf.len() - 4 != total_len {
103 return Err(RpcError::WrongLength.into());
104 }
105 io.async_write_all(&buf).await?;
106 Ok(())
107 }
108
109 async fn recv_reply<T>(io: &mut IO, xid: u32) -> Result<T, Error>
110 where
111 T: Unpack,
112 {
113 let mut buf = [0u8; 4];
114 io.async_read_exact(&mut buf).await?;
115 let fragment_header: fragment_header = buf.into();
116 assert!(
117 fragment_header.eof(),
118 "Fragment header does not have EOF flag"
119 );
120
121 let total_len = fragment_header.fragment_length();
122 let mut buf = vec![0u8; total_len as usize];
123 io.async_read_exact(&mut buf).await?;
124
125 let mut cursor = std::io::Cursor::new(buf);
126 let (resp_msg, _) = rpc_msg::unpack(&mut cursor)?;
127
128 if resp_msg.xid != xid {
129 return Err(RpcError::UnexpectedXid.into());
130 }
131
132 let reply = match resp_msg.body {
133 msg_body::REPLY(reply_body::MSG_ACCEPTED(reply)) => reply,
134 msg_body::REPLY(reply_body::MSG_DENIED(r)) => return Err(r.into()),
135 msg_body::CALL(_) => return Err(RpcError::UnexpectedCall.into()),
136 };
137
138 if !matches!(reply.reply_data, accept_stat_data::SUCCESS) {
139 return Err(crate::error::RpcError::try_from(reply.reply_data)
140 .expect("accept_stat_data::SUCCESS is not a valid error")
141 .into());
142 }
143
144 let (final_value, _) = T::unpack(&mut cursor)?;
145 if cursor.position() != u64::from(total_len) {
146 let pos = cursor.position();
147 return Err(RpcError::NotFullyParsed {
148 buf: cursor.into_inner(),
149 pos,
150 }
151 .into());
152 }
153 Ok(final_value)
154 }
155}