onc_rpc/
call_body.rs

1use std::{
2    convert::TryFrom,
3    io::{Cursor, Write},
4};
5
6use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
7
8use crate::{auth::AuthFlavor, Error};
9
/// The ONC RPC protocol version (`rpcvers`) implemented by this module.
///
/// Parsing rejects any message carrying a different version, and
/// serialisation always writes this value.
const RPC_VERSION: u32 = 2;
11
12/// A request invoking an RPC.
13///
14/// This structure is the Rust equivalent of the `call_body` structure defined
15/// in the [RFC](https://tools.ietf.org/html/rfc5531#section-9). The `rpcvers`
16/// field (representing the RPC protocol version) is hard coded to `2`.
17#[derive(Debug, PartialEq)]
18pub struct CallBody<T, P>
19where
20    T: AsRef<[u8]>,
21{
22    program: u32,
23    program_version: u32,
24    procedure: u32,
25
26    auth_credentials: AuthFlavor<T>,
27    auth_verifier: AuthFlavor<T>,
28
29    payload: P,
30}
31
32impl<'a> CallBody<&'a [u8], &'a [u8]> {
33    /// Constructs a new `CallBody` by parsing the wire format read from `r`.
34    ///
35    /// `from_cursor` advances the position of `r` to the end of the `CallBody`
36    /// structure.
37    pub(crate) fn from_cursor(r: &mut Cursor<&'a [u8]>) -> Result<Self, Error> {
38        // Read the RPC version and stop if it is not 2.
39        let rpc_version = r.read_u32::<BigEndian>()?;
40        if rpc_version != RPC_VERSION {
41            return Err(Error::InvalidRpcVersion(rpc_version));
42        }
43
44        let program = r.read_u32::<BigEndian>()?;
45        let program_version = r.read_u32::<BigEndian>()?;
46        let procedure = r.read_u32::<BigEndian>()?;
47        let auth_credentials = AuthFlavor::from_cursor(r)?;
48        let auth_verifier = AuthFlavor::from_cursor(r)?;
49
50        // NOTE: this payload does not use an Opaque as it is not defined as an
51        // opaque byte array (that necessitates padding) in the spec.
52
53        let data = *r.get_ref();
54        let start = r.position() as usize;
55        if start > data.len() {
56            return Err(Error::IncompleteHeader);
57        }
58
59        let payload = &data[start..];
60
61        Ok(CallBody {
62            program,
63            program_version,
64            procedure,
65            auth_credentials,
66            auth_verifier,
67            payload,
68        })
69    }
70}
71
72impl<T, P> CallBody<T, P>
73where
74    T: AsRef<[u8]>,
75    P: AsRef<[u8]>,
76{
77    /// Construct a new RPC invocation request.
78    pub fn new(
79        program: u32,
80        program_version: u32,
81        procedure: u32,
82        auth_credentials: AuthFlavor<T>,
83        auth_verifier: AuthFlavor<T>,
84        payload: P,
85    ) -> Self {
86        Self {
87            program,
88            program_version,
89            procedure,
90            auth_credentials,
91            auth_verifier,
92            payload,
93        }
94    }
95
96    /// Serialises this `CallBody` into `buf`, advancing the cursor position by
97    /// [`CallBody::serialised_len()`] bytes.
98    pub fn serialise_into<W: Write>(&self, mut buf: W) -> Result<(), std::io::Error> {
99        buf.write_u32::<BigEndian>(RPC_VERSION)?;
100        buf.write_u32::<BigEndian>(self.program)?;
101        buf.write_u32::<BigEndian>(self.program_version)?;
102        buf.write_u32::<BigEndian>(self.procedure)?;
103
104        self.auth_credentials.serialise_into(&mut buf)?;
105        self.auth_verifier.serialise_into(&mut buf)?;
106
107        buf.write_all(self.payload.as_ref())
108    }
109
110    /// Returns the on-wire length of this call body once serialised.
111    pub fn serialised_len(&self) -> u32 {
112        let mut l = std::mem::size_of::<u32>() * 4;
113
114        l += self.auth_credentials.serialised_len() as usize;
115        l += self.auth_verifier.serialised_len() as usize;
116        l += self.payload.as_ref().len();
117
118        l as u32
119    }
120
121    /// Returns the RPC version of this request.
122    ///
123    /// This crate supports the ONC RPC version 2 only.
124    pub fn rpc_version(&self) -> u32 {
125        2
126    }
127
128    /// Returns the program identifier in this request.
129    pub fn program(&self) -> u32 {
130        self.program
131    }
132
133    /// The version of the program to be invoked.
134    pub fn program_version(&self) -> u32 {
135        self.program_version
136    }
137
138    /// The program procedure number identifying the RPC to invoke.
139    pub fn procedure(&self) -> u32 {
140        self.procedure
141    }
142
143    /// The credentials to use for authenticating the request.
144    pub fn auth_credentials(&self) -> &AuthFlavor<T> {
145        &self.auth_credentials
146    }
147
148    /// The verifier that should be used to validate the authentication
149    /// credentials.
150    ///
151    /// The RFC says the following about the verifier:
152    /// ```text
153    /// The purpose of the authentication verifier is to validate the
154    /// authentication credential.  Note that these two items are
155    /// historically separate, but are always used together as one logical
156    /// entity.
157    /// ```
158    pub fn auth_verifier(&self) -> &AuthFlavor<T> {
159        &self.auth_verifier
160    }
161
162    /// Returns a reference to the opaque message payload bytes.
163    pub fn payload(&self) -> &P {
164        &self.payload
165    }
166}
167
168impl<'a> TryFrom<&'a [u8]> for CallBody<&'a [u8], &'a [u8]> {
169    type Error = Error;
170
171    fn try_from(v: &'a [u8]) -> Result<Self, Self::Error> {
172        let mut c = Cursor::new(v);
173        CallBody::from_cursor(&mut c)
174    }
175}
176
/// Parses a `CallBody` from a `Bytes` buffer.
///
/// The payload is the remainder of `v` after the authentication verifier.
///
/// # Errors
///
/// - `Error::InvalidRpcVersion` if the `rpcvers` field is not `2`.
/// - `Error::IncompleteHeader` if an authentication field claims to occupy
///   more bytes than remain in the buffer.
/// - Any error raised while reading the header fields or parsing the
///   authentication flavors.
#[cfg(feature = "bytes")]
impl TryFrom<crate::Bytes> for CallBody<crate::Bytes, crate::Bytes> {
    type Error = Error;

    fn try_from(mut v: crate::Bytes) -> Result<Self, Self::Error> {
        use crate::{bytes_ext::BytesReaderExt, Buf};

        // Advances `buf` past `n` bytes. `Buf::advance` panics when fewer than
        // `n` bytes remain, so return an error instead. This matches the
        // `IncompleteHeader` check in the cursor-based parser, and matters
        // because `n` is derived from untrusted input.
        let skip = |buf: &mut crate::Bytes, n: usize| -> Result<(), Error> {
            if n > buf.len() {
                return Err(Error::IncompleteHeader);
            }
            buf.advance(n);
            Ok(())
        };

        let rpc_version = v.try_u32()?;
        if rpc_version != RPC_VERSION {
            return Err(Error::InvalidRpcVersion(rpc_version));
        }

        let program = v.try_u32()?;
        let program_version = v.try_u32()?;
        let procedure = v.try_u32()?;

        // Deserialise the auth flavor using a copy of v, and then advance the
        // pointer in v by its on-wire length (including any padding).
        let auth_credentials = AuthFlavor::try_from(v.clone())?;
        skip(&mut v, auth_credentials.serialised_len() as usize)?;

        let auth_verifier = AuthFlavor::try_from(v.clone())?;
        skip(&mut v, auth_verifier.serialised_len() as usize)?;

        Ok(Self {
            program,
            program_version,
            procedure,
            auth_credentials,
            auth_verifier,
            payload: v,
        })
    }
}
211
#[cfg(test)]
mod tests {
    use super::*;

    // Ensures a payload can differ in type from the auth buffer, which is
    // primarily a compile-time check. Also verifies that the accessors return
    // exactly the values given to `CallBody::new`.
    #[test]
    fn test_differing_payload_type() {
        let binding = vec![42];
        let auth = AuthFlavor::AuthNone(Some(binding.as_slice()));
        let payload = [42, 42, 42, 42];

        let call: CallBody<&[u8], &[u8; 4]> =
            CallBody::new(100000, 42, 13, auth.clone(), auth.clone(), &payload);

        assert_eq!(call.rpc_version(), 2);
        assert_eq!(call.program(), 100000);
        assert_eq!(call.program_version(), 42);
        assert_eq!(call.procedure(), 13);
        assert_eq!(call.auth_credentials(), &auth);
        assert_eq!(call.auth_verifier(), &auth);
        assert_eq!(*call.payload(), &payload);
    }
}