ironrdp_core/
encode.rs

1#[cfg(feature = "alloc")]
2use alloc::string::String;
3#[cfg(feature = "alloc")]
4use alloc::{vec, vec::Vec};
5use core::fmt;
6
7#[cfg(feature = "alloc")]
8use crate::WriteBuf;
9use crate::{
10    InvalidFieldErr, NotEnoughBytesErr, OtherErr, UnexpectedMessageTypeErr, UnsupportedValueErr, UnsupportedVersionErr,
11    WriteCursor,
12};
13
14/// A result type for encoding operations, which can either succeed with a value of type `T`
15/// or fail with an [`EncodeError`].
16pub type EncodeResult<T> = Result<T, EncodeError>;
17
18/// An error type specifically for encoding operations, wrapping an [`EncodeErrorKind`].
19pub type EncodeError = ironrdp_error::Error<EncodeErrorKind>;
20
/// The different kinds of errors that can occur while encoding a PDU.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum EncodeErrorKind {
    /// Not enough bytes were available to complete the encoding operation.
    ///
    /// NOTE(review): when encoding, these counts presumably describe the destination
    /// buffer (space available vs. space required) — confirm against `ensure_size!`.
    NotEnoughBytes {
        /// The number of bytes that were available (named `received` for symmetry with decoding).
        received: usize,
        /// The number of bytes expected or required.
        expected: usize,
    },
    /// A field of the data being encoded is invalid.
    InvalidField {
        /// The name of the invalid field.
        field: &'static str,
        /// Why the field is considered invalid.
        reason: &'static str,
    },
    /// An unexpected message type was encountered during encoding.
    UnexpectedMessageType {
        /// The message type that was encountered.
        got: u8,
    },
    /// An unsupported version was encountered during encoding.
    UnsupportedVersion {
        /// The version that was encountered.
        got: u8,
    },
    /// An unsupported value was encountered during encoding.
    #[cfg(feature = "alloc")]
    UnsupportedValue {
        /// The name of the field or parameter holding the unsupported value.
        name: &'static str,
        /// The unsupported value, rendered as a string.
        value: String,
    },
    /// An unsupported value was encountered during encoding (no-alloc variant: the value itself is not kept).
    #[cfg(not(feature = "alloc"))]
    UnsupportedValue {
        /// The name of the field or parameter holding the unsupported value.
        name: &'static str,
    },
    /// Any other error that doesn't fit into the categories above.
    Other {
        /// A description of the error.
        description: &'static str,
    },
}

#[cfg(feature = "std")]
impl std::error::Error for EncodeErrorKind {}
72
73impl fmt::Display for EncodeErrorKind {
74    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75        match self {
76            Self::NotEnoughBytes { received, expected } => write!(
77                f,
78                "not enough bytes provided to decode: received {received} bytes, expected {expected} bytes"
79            ),
80            Self::InvalidField { field, reason } => {
81                write!(f, "invalid `{field}`: {reason}")
82            }
83            Self::UnexpectedMessageType { got } => {
84                write!(f, "invalid message type ({got})")
85            }
86            Self::UnsupportedVersion { got } => {
87                write!(f, "unsupported version ({got})")
88            }
89            #[cfg(feature = "alloc")]
90            Self::UnsupportedValue { name, value } => {
91                write!(f, "unsupported {name} ({value})")
92            }
93            #[cfg(not(feature = "alloc"))]
94            Self::UnsupportedValue { name } => {
95                write!(f, "unsupported {name}")
96            }
97            Self::Other { description } => {
98                write!(f, "other ({description})")
99            }
100        }
101    }
102}
103
104impl NotEnoughBytesErr for EncodeError {
105    fn not_enough_bytes(context: &'static str, received: usize, expected: usize) -> Self {
106        Self::new(context, EncodeErrorKind::NotEnoughBytes { received, expected })
107    }
108}
109
110impl InvalidFieldErr for EncodeError {
111    fn invalid_field(context: &'static str, field: &'static str, reason: &'static str) -> Self {
112        Self::new(context, EncodeErrorKind::InvalidField { field, reason })
113    }
114}
115
116impl UnexpectedMessageTypeErr for EncodeError {
117    fn unexpected_message_type(context: &'static str, got: u8) -> Self {
118        Self::new(context, EncodeErrorKind::UnexpectedMessageType { got })
119    }
120}
121
122impl UnsupportedVersionErr for EncodeError {
123    fn unsupported_version(context: &'static str, got: u8) -> Self {
124        Self::new(context, EncodeErrorKind::UnsupportedVersion { got })
125    }
126}
127
128impl UnsupportedValueErr for EncodeError {
129    #[cfg(feature = "alloc")]
130    fn unsupported_value(context: &'static str, name: &'static str, value: String) -> Self {
131        Self::new(context, EncodeErrorKind::UnsupportedValue { name, value })
132    }
133    #[cfg(not(feature = "alloc"))]
134    fn unsupported_value(context: &'static str, name: &'static str) -> Self {
135        Self::new(context, EncodeErrorKind::UnsupportedValue { name })
136    }
137}
138
139impl OtherErr for EncodeError {
140    fn other(context: &'static str, description: &'static str) -> Self {
141        Self::new(context, EncodeErrorKind::Other { description })
142    }
143}
144
145/// PDU that can be encoded into its binary form.
146///
147/// The resulting binary payload is a fully encoded PDU that may be sent to the peer.
148///
149/// This trait is object-safe and may be used in a dynamic context.
150pub trait Encode {
151    /// Encodes this PDU in-place using the provided `WriteCursor`.
152    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()>;
153
154    /// Returns the associated PDU name associated.
155    fn name(&self) -> &'static str;
156
157    /// Computes the size in bytes for this PDU.
158    fn size(&self) -> usize;
159}
160
161crate::assert_obj_safe!(Encode);
162
163/// Encodes the given PDU in-place into the provided buffer and returns the number of bytes written.
164pub fn encode<T>(pdu: &T, dst: &mut [u8]) -> EncodeResult<usize>
165where
166    T: Encode + ?Sized,
167{
168    let mut cursor = WriteCursor::new(dst);
169    encode_cursor(pdu, &mut cursor)?;
170    Ok(cursor.pos())
171}
172
173/// Encodes the given PDU in-place using the provided `WriteCursor`.
174pub fn encode_cursor<T>(pdu: &T, dst: &mut WriteCursor<'_>) -> EncodeResult<()>
175where
176    T: Encode + ?Sized,
177{
178    pdu.encode(dst)
179}
180
/// Like [`encode`], but writes into the unfilled region of `buf`, growing it when it is
/// too small to hold the PDU.
///
/// On success, marks the written bytes as filled and returns their count.
#[cfg(feature = "alloc")]
pub fn encode_buf<T>(pdu: &T, buf: &mut WriteBuf) -> EncodeResult<usize>
where
    T: Encode + ?Sized,
{
    let expected_len = pdu.size();
    let written = encode(pdu, buf.unfilled_to(expected_len))?;
    // `Encode::size` must agree with what `Encode::encode` actually writes.
    debug_assert_eq!(written, expected_len);
    buf.advance(written);
    Ok(written)
}
194
/// Like [`encode`], but allocates and returns a fresh zero-initialised buffer of
/// `pdu.size()` bytes on every call.
///
/// Convenient, but not very resource efficient because of the per-call allocation.
#[cfg(any(feature = "alloc", test))]
pub fn encode_vec<T>(pdu: &T) -> EncodeResult<Vec<u8>>
where
    T: Encode + ?Sized,
{
    let expected_len = pdu.size();
    let mut out = vec![0u8; expected_len];
    let written = encode(pdu, &mut out)?;
    // `Encode::size` must agree with what `Encode::encode` actually writes.
    debug_assert_eq!(written, expected_len);
    Ok(out)
}
209
210/// Gets the name of this PDU.
211pub fn name<T: Encode>(pdu: &T) -> &'static str {
212    pdu.name()
213}
214
215/// Computes the size in bytes for this PDU.
216pub fn size<T: Encode>(pdu: &T) -> usize {
217    pdu.size()
218}
219
#[cfg(feature = "alloc")]
mod legacy {
    use super::{Encode, EncodeResult};
    use crate::WriteCursor;

    /// Writes the vector's bytes verbatim, so raw bytes can be passed wherever an
    /// [`Encode`] implementor is expected.
    impl Encode for alloc::vec::Vec<u8> {
        fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
            ensure_size!(in: dst, size: self.len());

            dst.write_slice(self);
            Ok(())
        }

        /// Returns a fixed name, since a raw byte vector carries no PDU type information.
        fn name(&self) -> &'static str {
            "legacy-pdu-encode"
        }

        /// Returns the vector's length, because the bytes are written as-is.
        fn size(&self) -> usize {
            self.len()
        }
    }
}