crypto_permutation/
buffer.rs

1//! Potentially uninitialised buffers that guarantee that they are not
2//! deinitialised again after init.
3
4use super::io::{check_write_size, WriteTooLargeError, Writer};
5use core::mem::MaybeUninit;
6use core::slice::SliceIndex;
7
/// A byte buffer whose contents may start out uninitialised, but which can
/// never be made uninitialised again once a byte has been written.
///
/// The wrapped slice is kept private so that safe code outside this module
/// has no way to store `MaybeUninit::uninit()` into it.
pub struct BufMut<'buf> {
    // Private on purpose: the `From<&mut [u8]>` conversion relies on this
    // field never being used to de-initialise memory.
    buf: &'buf mut [core::mem::MaybeUninit<u8>],
}
13
14impl<'a> From<&'a mut [MaybeUninit<u8>]> for BufMut<'a> {
15    fn from(buf: &'a mut [MaybeUninit<u8>]) -> Self {
16        Self { buf }
17    }
18}
19
20// SAFETY: this conversion is safe since a [`BufMut`] cannot be used to
21// deinitialise the underlying memory.
22impl<'a> From<&'a mut [u8]> for BufMut<'a> {
23    fn from(slice: &'a mut [u8]) -> Self {
24        let ptr: *mut MaybeUninit<u8> = slice.as_mut_ptr().cast();
25        let len: usize = slice.len();
26        // SAFETY: `ptr` and `len` formed a slice to `u8`s, so definitely valid
27        // `MaybeUninit<u8>`s. SAFETY: we just have to make sure that the
28        // pointed-to bytes remain initialised but that is what the `Self`
29        // wrapper struct is for.
30        let buf = unsafe { core::slice::from_raw_parts_mut(ptr, len) };
31        Self { buf }
32    }
33}
34
35impl<'a> BufMut<'a> {
36    /// Length of the buffer.
37    pub fn len(&self) -> usize {
38        self.buf.len()
39    }
40
41    /// Copy non-overlapping memory from `buf` to `self`.
42    ///
43    /// Requires that `self.len() >= buf.len()`. Doesn't change where the buffer
44    /// `self` points to.
45    ///
46    /// # Errors
47    /// Errors when `buf.len() > self.buf.len()`, without doing any copying.
48    pub fn copy(&mut self, buf: &[u8]) -> Result<(), WriteTooLargeError> {
49        // SAFETY: `self` has unique mutable access to the buffer referenced by
50        // `self.buf`, so this cannot overlap with `buf`.
51        let _: &mut [MaybeUninit<u8>] = self.buf;
52
53        let len = buf.len();
54        check_write_size(len, self.len())?;
55
56        let src: *const u8 = buf.as_ptr();
57        let dst: *mut u8 = self.buf.as_mut_ptr().cast();
58        // SAFETY: `src` and `dst` don't overlap by the comment above; both slices have
59        // length at least `len`
60        unsafe {
61            core::ptr::copy_nonoverlapping(src, dst, len);
62        }
63
64        Ok(())
65    }
66
67    #[must_use]
68    pub fn reborrow<'b>(&'b mut self) -> BufMut<'b>
69    where
70        'a: 'b,
71    {
72        let buf = &mut self.buf;
73        BufMut { buf }
74    }
75
76    #[must_use = "for inplace mutation use `restrict_inplace` instead"]
77    pub fn restrict<'b, I>(&'b mut self, range: I) -> BufMut<'b>
78    where
79        'a: 'b,
80        I: SliceIndex<[MaybeUninit<u8>], Output = [MaybeUninit<u8>]>,
81    {
82        let reborrowed: &'b mut [MaybeUninit<u8>] = self.buf;
83        let buf = &mut reborrowed[range];
84        BufMut { buf }
85    }
86
87    pub fn restrict_inplace<'b, I>(&'b mut self, range: I)
88    where
89        'a: 'b,
90        I: SliceIndex<[MaybeUninit<u8>], Output = [MaybeUninit<u8>]>,
91    {
92        let mut buf = core::mem::take(&mut self.buf);
93        buf = &mut buf[range];
94        let _ = core::mem::replace(&mut self.buf, buf);
95    }
96}
97
98impl<'a> Writer for BufMut<'a> {
99    type Return = ();
100
101    fn capacity(&self) -> usize {
102        self.len()
103    }
104
105    fn skip(&mut self, n: usize) -> Result<(), WriteTooLargeError> {
106        check_write_size(n, self.capacity())?;
107        self.restrict_inplace(n..);
108        Ok(())
109    }
110
111    fn write_bytes(&mut self, data: &[u8]) -> Result<(), WriteTooLargeError> {
112        self.copy(data)?;
113        self.restrict_inplace(data.len()..);
114        Ok(())
115    }
116
117    /// No-op.
118    fn finish(self) -> Self::Return {}
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn writer_write() {
        let mut buf = [0; 3];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.write_bytes(&[1, 2, 3]).unwrap();
        assert_eq!(buf, [1, 2, 3]);
    }

    #[test]
    fn writer_write_out_of_bounds() {
        let mut buf = [0; 3];
        let mut bufmut = BufMut::from(&mut buf[..]);
        let res = bufmut.write_bytes(&[1, 2, 3, 4]);
        assert!(res.is_err());
        // A rejected write must not advance the buffer.
        assert_eq!(bufmut.capacity(), 3);
        assert_eq!(buf, [0; 3]);
    }

    #[test]
    fn writer_write_write() {
        let mut buf = [0; 5];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.write_bytes(&[1, 2]).unwrap();
        bufmut.write_bytes(&[3]).unwrap();
        assert_eq!(buf, [1, 2, 3, 0, 0]);
    }

    #[test]
    fn writer_skip() {
        let mut buf = [0; 3];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.skip(3).unwrap();
        assert_eq!(buf, [0; 3]);
    }

    #[test]
    fn writer_skip_out_of_bounds() {
        let mut buf = [0; 3];
        let mut bufmut = BufMut::from(&mut buf[..]);
        let res = bufmut.skip(4);
        assert!(res.is_err());
        // A rejected skip must not advance the buffer.
        assert_eq!(bufmut.capacity(), 3);
    }

    #[test]
    fn writer_skip_capacity() {
        let mut buf = [0; 5];
        let mut bufmut = BufMut::from(&mut buf[..]);
        assert_eq!(bufmut.capacity(), 5);
        bufmut.skip(2).unwrap();
        assert_eq!(bufmut.capacity(), 3);
    }

    #[test]
    fn writer_skip_write() {
        let mut buf = [0; 5];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.skip(2).unwrap();
        bufmut.write_bytes(&[1, 1]).unwrap();
        assert_eq!(buf, [0, 0, 1, 1, 0]);
    }

    #[test]
    fn copy_does_not_advance() {
        let mut buf = [0u8; 4];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.copy(&[7, 8]).unwrap();
        assert_eq!(bufmut.len(), 4);
        // A second copy starts at the same position and overwrites byte 0.
        bufmut.copy(&[9]).unwrap();
        assert_eq!(buf, [9, 8, 0, 0]);
    }

    #[test]
    fn copy_out_of_bounds() {
        let mut buf = [0u8; 2];
        let mut bufmut = BufMut::from(&mut buf[..]);
        assert!(bufmut.copy(&[1, 2, 3]).is_err());
        assert_eq!(bufmut.len(), 2);
        assert_eq!(buf, [0; 2]);
    }

    #[test]
    fn restrict_leaves_parent_untouched() {
        let mut buf = [0u8; 5];
        let mut bufmut = BufMut::from(&mut buf[..]);
        let mut middle = bufmut.restrict(1..3);
        assert_eq!(middle.capacity(), 2);
        middle.write_bytes(&[4, 5]).unwrap();
        assert_eq!(bufmut.capacity(), 5);
        assert_eq!(buf, [0, 4, 5, 0, 0]);
    }

    #[test]
    #[should_panic]
    fn restrict_out_of_bounds_panics() {
        let mut buf = [0u8; 2];
        let mut bufmut = BufMut::from(&mut buf[..]);
        let _ = bufmut.restrict(..3);
    }

    #[test]
    fn restrict_inplace_narrows() {
        let mut buf = [0u8; 5];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.restrict_inplace(1..4);
        assert_eq!(bufmut.len(), 3);
        bufmut.write_bytes(&[1, 2, 3]).unwrap();
        assert_eq!(bufmut.len(), 0);
        assert_eq!(buf, [0, 1, 2, 3, 0]);
    }

    #[test]
    fn reborrow_does_not_advance_parent() {
        let mut buf = [0u8; 3];
        let mut bufmut = BufMut::from(&mut buf[..]);
        bufmut.reborrow().write_bytes(&[1]).unwrap();
        // The parent was not advanced, so this write lands on byte 0 again.
        bufmut.write_bytes(&[2]).unwrap();
        assert_eq!(buf, [2, 0, 0]);
    }

    #[test]
    fn from_uninit_slice() {
        let mut storage = [MaybeUninit::<u8>::uninit(); 3];
        let mut bufmut = BufMut::from(&mut storage[..]);
        bufmut.write_bytes(&[1, 2, 3]).unwrap();
        assert_eq!(bufmut.capacity(), 0);
        for (&byte, &want) in storage.iter().zip([1u8, 2, 3].iter()) {
            // SAFETY: the `write_bytes` call above initialised every byte of
            // `storage`.
            assert_eq!(unsafe { byte.assume_init() }, want);
        }
    }
}