djin_protocol/logic/
aligned.rs

1use crate::{hint, Error, Parcel, Settings};
2use std::io::prelude::*;
3use std::{marker, mem};
4
5/// A value that is aligned to a specified number of bytes.
6///
7/// When bytes are written, they are zero-padding at the end
8/// until the total size is the smallest multiple of the
9/// size of `ToSizeOfType`.
10///
11/// When an `Aligned` type is read, a value of the inner `T`
12/// is first read, and then the minimum number of zero bytes in
13/// order to maintain alignment are read and ignored.
14///
15/// Type parameters:
16///
17///   * `T` - The `Parcel` type that is to be transmitted
18///   * `ToSizeOfType` The transmitted bytes will be aligned to a multiple
19///     of `size_of::<ToSizeOfType>()`. For example, if `ToSizeOfType = u32`,
20///     then the written bytes will be aligned to a multiple of 4 bytes.
21///
22/// Examples:
23///
24/// ```
25/// extern crate djin_protocol;
26/// #[macro_use] extern crate djin_protocol_derive;
27/// use djin_protocol::Parcel;
28///
29/// /// An example packet with a length prefix disjoint
30/// /// from its data, with the data also
31/// #[derive(Protocol, Clone, Debug, PartialEq)]
32/// struct Packet {
33///     /// The length of the 'reason' string.
34///     pub reason_length: u8,
35///     /// The version number of the protocol.
36///     pub version_number: (u32, u32),
37///     #[protocol(length_prefix(bytes(reason_length)))]
38///     pub reason: djin_protocol::logic::Aligned<String, u64>,
39///
40/// }
41///
42/// let raw_bytes = Packet {
43///     reason_length: 12,
44///     version_number: (11, 0xdeadbeef),
45///     reason: "hello world!".to_owned().into(),
46/// }.raw_bytes(&djin_protocol::Settings::default()).unwrap();
47///
48/// assert_eq!(&[
49///     12, // reason length
50///     0, 0, 0, 11, 0xde, 0xad, 0xbe, 0xef, // version number
51///     // the string "hello world".
52///     b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', b'!',
53///     0x00, 0x00, 0x00, 0x00, // padding bytes to align to string to 16 bytes.
54///     ], &raw_bytes[..]);
55/// ```
56
57#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
58pub struct Aligned<T, ToSizeOfType>
59    where T: Parcel,
60          ToSizeOfType: Sized {
61    /// The inner value.
62    pub value: T,
63    _phantom: marker::PhantomData<ToSizeOfType>,
64}
65
66impl<T, ToSizeOfType> Aligned<T, ToSizeOfType>
67    where T: Parcel,
68          ToSizeOfType: Sized {
69    /// Creates a new aligned value.
70    pub fn new(value: T) -> Self {
71        Aligned { value, _phantom: marker::PhantomData }
72    }
73
74    /// Gets the number of bytes of the alignment.
75    pub fn align_to_bytes() -> usize {
76        mem::size_of::<ToSizeOfType>()
77    }
78}
79
80impl<T, ToSizeOfType> Parcel for Aligned<T, ToSizeOfType>
81    where T: Parcel,
82          ToSizeOfType: Sized {
83    const TYPE_NAME: &'static str = "Aligned";
84
85    fn read_field(read: &mut dyn Read,
86                  settings: &Settings,
87                  hints: &mut hint::Hints) -> Result<Self, Error> {
88        let inner_value = T::read_field(read, settings, hints)?;
89        let value_size = inner_value.raw_bytes_field(settings, hints).unwrap().len();
90        let padding_size = calculate_padding(Self::align_to_bytes(), value_size);
91
92        for _ in 0..padding_size {
93            let padding_byte = u8::read(read, settings)?;
94
95            // FIXME: promote to error.
96            assert_eq!(0x00, padding_byte, "padding bytes should be zero");
97        }
98
99        Ok(Aligned { value: inner_value, _phantom: marker::PhantomData })
100    }
101
102    fn write_field(&self,
103                   write: &mut dyn Write,
104                   settings: &Settings,
105                   hints: &mut hint::Hints) -> Result<(), Error> {
106        let unaligned_bytes = self.value.raw_bytes_field(settings, hints)?;
107        let aligned_bytes = align_to(Self::align_to_bytes(), 0x00, unaligned_bytes);
108        write.write(&aligned_bytes)?;
109        Ok(())
110    }
111}
112
113impl<T, ToSizeOfType> From<T> for Aligned<T, ToSizeOfType>
114    where T: Parcel,
115          ToSizeOfType: Sized {
116    fn from(value: T) -> Self {
117        Aligned { value, _phantom: marker::PhantomData }
118    }
119}
120
/// Aligns a set of bytes to a multiple of the specified alignment.
///
/// Appends `padding_byte` to `bytes` until its length is a multiple of
/// `align_to`. An `align_to` of zero (e.g. from a zero-sized `ToSizeOfType`)
/// is treated as "no alignment" and leaves `bytes` unchanged.
fn align_to(align_to: usize,
            padding_byte: u8,
            mut bytes: Vec<u8>) -> Vec<u8> {
    let padded_len = bytes.len() + calculate_padding(align_to, bytes.len());
    // `resize` appends in place, reusing the existing allocation when it can.
    bytes.resize(padded_len, padding_byte);

    debug_assert!(align_to == 0 || bytes.len() % align_to == 0, "failed to align");
    bytes
}

/// Returns how many bytes must be appended to `unaligned_size` to reach the
/// next multiple of `align_to` (zero when it is already aligned).
///
/// An `align_to` of zero yields no padding instead of dividing by zero.
fn calculate_padding(align_to: usize,
                     unaligned_size: usize) -> usize {
    if align_to == 0 {
        return 0;
    }

    // Thanks for the formula Ned!
    // https://stackoverflow.com/a/11642218
    (align_to - (unaligned_size % align_to)) % align_to
}
143
#[cfg(test)]
mod test {
    use super::*;

    mod alignment_calculations {
        use super::*;

        #[test]
        fn test_aligning_when_none_needed() {
            // Length 2 is already a multiple of both 1 and 2.
            for &alignment in &[1, 2] {
                assert_eq!(vec![1, 2], align_to(alignment, 0x00, vec![1, 2]));
            }
        }

        #[test]
        fn test_align_to_3_with_size_2() {
            let aligned = align_to(3, 0x00, vec![1, 2]);
            assert_eq!(aligned, vec![1, 2, 0]);
        }

        #[test]
        fn test_align_to_4_with_size_2() {
            let aligned = align_to(4, 0xff, vec![1, 2]);
            assert_eq!(aligned, vec![1, 2, 0xff, 0xff]);
        }

        #[test]
        fn test_align_to_3_with_size_5() {
            let aligned = align_to(3, 0x00, vec![1, 2, 3, 4, 5]);
            assert_eq!(aligned, vec![1, 2, 3, 4, 5, 0]);
        }

        #[test]
        fn test_align_to_4_with_size_97() {
            let aligned = align_to(4, 0x00, vec![1; 97]);
            let occurrences = |needle: u8| aligned.iter().filter(|&&byte| byte == needle).count();

            assert_eq!(97, occurrences(1));
            assert_eq!(3, occurrences(0));
        }
    }
}
185