#[cfg(feature = "alloc")]
use alloc::string::String;
#[cfg(feature = "alloc")]
use alloc::{vec, vec::Vec};
use core::fmt;
#[cfg(feature = "alloc")]
use crate::WriteBuf;
use crate::{
InvalidFieldErr, NotEnoughBytesErr, OtherErr, UnexpectedMessageTypeErr, UnsupportedValueErr, UnsupportedVersionErr,
WriteCursor,
};
/// Error returned by encoding routines: an [`ironrdp_error::Error`] whose kind is an
/// [`EncodeErrorKind`].
pub type EncodeError = ironrdp_error::Error<EncodeErrorKind>;

/// Shorthand for a `Result` whose error type is [`EncodeError`].
pub type EncodeResult<T> = Result<T, EncodeError>;
/// The specific reason an encoding operation failed.
///
/// Values of this type are wrapped, together with a static context string, into an
/// `EncodeError` by the `*Err` trait constructors implemented for it.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum EncodeErrorKind {
    /// Fewer bytes were available than the operation required.
    NotEnoughBytes {
        /// Number of bytes that were available.
        received: usize,
        /// Number of bytes that were required.
        expected: usize,
    },
    /// A field held a value that cannot be encoded.
    InvalidField {
        /// Name of the offending field.
        field: &'static str,
        /// Human-readable explanation of why the field is invalid.
        reason: &'static str,
    },
    /// An unexpected message type discriminant was encountered.
    UnexpectedMessageType {
        /// The message type value that was found.
        got: u8,
    },
    /// A version value that is not supported was encountered.
    UnsupportedVersion {
        /// The version value that was found.
        got: u8,
    },
    /// A value is not supported (with an owned rendering of the value).
    #[cfg(feature = "alloc")]
    UnsupportedValue {
        /// Name of the unsupported value.
        name: &'static str,
        /// Textual rendering of the offending value.
        value: String,
    },
    /// A value is not supported (`no_std`, no-`alloc` variant without the value itself).
    #[cfg(not(feature = "alloc"))]
    UnsupportedValue {
        /// Name of the unsupported value.
        name: &'static str,
    },
    /// Any other failure, described by a static message.
    Other {
        /// Human-readable description of the failure.
        description: &'static str,
    },
}

#[cfg(feature = "std")]
impl std::error::Error for EncodeErrorKind {}

impl fmt::Display for EncodeErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // This kind is only produced by the encode path, so the message must say "encode"
            // (it previously said "decode", copied from the decode-side error type).
            Self::NotEnoughBytes { received, expected } => write!(
                f,
                "not enough bytes provided to encode: received {received} bytes, expected {expected} bytes"
            ),
            Self::InvalidField { field, reason } => {
                write!(f, "invalid `{field}`: {reason}")
            }
            Self::UnexpectedMessageType { got } => {
                write!(f, "invalid message type ({got})")
            }
            Self::UnsupportedVersion { got } => {
                write!(f, "unsupported version ({got})")
            }
            #[cfg(feature = "alloc")]
            Self::UnsupportedValue { name, value } => {
                write!(f, "unsupported {name} ({value})")
            }
            #[cfg(not(feature = "alloc"))]
            Self::UnsupportedValue { name } => {
                write!(f, "unsupported {name}")
            }
            Self::Other { description } => {
                write!(f, "other ({description})")
            }
        }
    }
}
impl NotEnoughBytesErr for EncodeError {
fn not_enough_bytes(context: &'static str, received: usize, expected: usize) -> Self {
Self::new(context, EncodeErrorKind::NotEnoughBytes { received, expected })
}
}
impl InvalidFieldErr for EncodeError {
fn invalid_field(context: &'static str, field: &'static str, reason: &'static str) -> Self {
Self::new(context, EncodeErrorKind::InvalidField { field, reason })
}
}
impl UnexpectedMessageTypeErr for EncodeError {
fn unexpected_message_type(context: &'static str, got: u8) -> Self {
Self::new(context, EncodeErrorKind::UnexpectedMessageType { got })
}
}
impl UnsupportedVersionErr for EncodeError {
fn unsupported_version(context: &'static str, got: u8) -> Self {
Self::new(context, EncodeErrorKind::UnsupportedVersion { got })
}
}
impl UnsupportedValueErr for EncodeError {
#[cfg(feature = "alloc")]
fn unsupported_value(context: &'static str, name: &'static str, value: String) -> Self {
Self::new(context, EncodeErrorKind::UnsupportedValue { name, value })
}
#[cfg(not(feature = "alloc"))]
fn unsupported_value(context: &'static str, name: &'static str) -> Self {
Self::new(context, EncodeErrorKind::UnsupportedValue { name })
}
}
impl OtherErr for EncodeError {
fn other(context: &'static str, description: &'static str) -> Self {
Self::new(context, EncodeErrorKind::Other { description })
}
}
/// A PDU (or other structure) that can be serialized into a [`WriteCursor`].
///
/// The trait is kept object-safe so that `&dyn Encode` can be passed to the
/// free-standing helpers in this module.
pub trait Encode {
    /// Writes `self` into `dst`, returning an error if encoding fails.
    fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()>;

    /// A static name identifying this PDU.
    fn name(&self) -> &'static str;

    /// Number of bytes `encode` is expected to write.
    ///
    /// The buffer-allocating helpers size their destination from this value and
    /// debug-assert that the actual written length matches it.
    fn size(&self) -> usize;
}

// Compile-time guarantee that `Encode` stays usable as a trait object.
crate::assert_obj_safe!(Encode);
/// Encodes `pdu` into the start of `dst`.
///
/// Returns the cursor position after encoding, i.e. the number of bytes written
/// from the beginning of `dst`.
pub fn encode<T>(pdu: &T, dst: &mut [u8]) -> EncodeResult<usize>
where
    T: Encode + ?Sized,
{
    let mut writer = WriteCursor::new(dst);
    encode_cursor(pdu, &mut writer).map(|()| writer.pos())
}
/// Encodes `pdu` at the current position of an existing cursor.
///
/// This is a thin forwarding helper over [`Encode::encode`] that also accepts
/// unsized implementors such as `dyn Encode`.
pub fn encode_cursor<T>(pdu: &T, dst: &mut WriteCursor<'_>) -> EncodeResult<()>
where
    T: Encode + ?Sized,
{
    <T as Encode>::encode(pdu, dst)
}
/// Encodes `pdu` into the unfilled region of `buf` and advances `buf` past it.
///
/// The destination slice is obtained with `buf.unfilled_to(pdu.size())`. This presumably
/// ensures at least that many unfilled bytes are available (see `WriteBuf`). On error,
/// `buf` is not advanced. Returns the number of bytes written.
#[cfg(feature = "alloc")]
pub fn encode_buf<T>(pdu: &T, buf: &mut WriteBuf) -> EncodeResult<usize>
where
    T: Encode + ?Sized,
{
    let expected_len = pdu.size();
    let written = encode(pdu, buf.unfilled_to(expected_len))?;
    // `Encode::size` is expected to be exact; catch mismatches in debug builds.
    debug_assert_eq!(written, expected_len);
    buf.advance(written);
    Ok(written)
}
/// Encodes `pdu` into a freshly allocated, zero-initialized `Vec<u8>` of length `pdu.size()`.
///
/// Convenient, but it allocates on every call; prefer `encode_buf` when a reusable
/// buffer is available.
// NOTE(review): this item is also enabled under `cfg(test)`, but the `vec`/`Vec` imports at
// the top of the file are gated on `feature = "alloc"` only — confirm a `test` build without
// `alloc` still resolves them.
#[cfg(any(feature = "alloc", test))]
pub fn encode_vec<T>(pdu: &T) -> EncodeResult<Vec<u8>>
where
    T: Encode + ?Sized,
{
    let expected_len = pdu.size();
    let mut out = vec![0u8; expected_len];
    let written = encode(pdu, out.as_mut_slice())?;
    // `Encode::size` is expected to be exact; catch mismatches in debug builds.
    debug_assert_eq!(written, expected_len);
    Ok(out)
}
/// Returns the static name of `pdu`, as reported by [`Encode::name`].
///
/// Accepts unsized implementors (e.g. `&dyn Encode`), consistent with the other
/// helpers in this module.
pub fn name<T>(pdu: &T) -> &'static str
where
    T: Encode + ?Sized,
{
    pdu.name()
}
/// Returns the encoded size of `pdu` in bytes, as reported by [`Encode::size`].
///
/// Accepts unsized implementors (e.g. `&dyn Encode`), consistent with the other
/// helpers in this module.
pub fn size<T>(pdu: &T) -> usize
where
    T: Encode + ?Sized,
{
    pdu.size()
}
#[cfg(feature = "alloc")]
mod legacy {
    use super::{Encode, EncodeResult};
    use crate::WriteCursor;

    /// Lets an already-serialized byte vector be used wherever an [`Encode`] is expected:
    /// encoding copies the bytes verbatim, and the size is the vector's length.
    impl Encode for alloc::vec::Vec<u8> {
        fn encode(&self, dst: &mut WriteCursor<'_>) -> EncodeResult<()> {
            let bytes: &[u8] = self.as_slice();
            // Fail with a size error before writing if `dst` cannot hold every byte.
            ensure_size!(in: dst, size: bytes.len());
            dst.write_slice(bytes);
            Ok(())
        }

        fn name(&self) -> &'static str {
            "legacy-pdu-encode"
        }

        fn size(&self) -> usize {
            self.as_slice().len()
        }
    }
}