sha3_utils/
enc.rs

1#![allow(
2    clippy::indexing_slicing,
3    reason = "The compiler can prove that the indices are in bounds"
4)]
5#![allow(
6    clippy::arithmetic_side_effects,
7    reason = "All arithmetic is in bounds"
8)]
9
10use core::{
11    array, fmt,
12    hash::Hash,
13    hint,
14    iter::{ExactSizeIterator, FusedIterator},
15    slice,
16};
17
18#[cfg(feature = "no-panic")]
19use no_panic::no_panic;
20use zerocopy::{ByteEq, ByteHash, Immutable, IntoBytes, KnownLayout, Unaligned};
21
/// Number of bytes occupied by a [`usize`] on the target.
pub(crate) const USIZE_BYTES: usize = (usize::BITS / 8) as usize;

// Trivially true on real targets, but it documents (and
// enforces at compile time) that we always stay within
//    [0, ((2^2040)-1)/8]
// SP 800-185 requires `left_encode`, `right_encode`, etc. to
// accept integers up to (2^2040)-1; the division by 8 accounts
// for the `*_bytes` routines.
//
// It also guarantees that `USIZE_BYTES` itself fits in the `u8`
// length fields used by the encodings below.
const _: () = {
    assert!(USIZE_BYTES <= u8::MAX as usize);
};
33
34/// Encodes `x` as a byte string in a way that can be
35/// unambiguously parsed from the beginning.
36#[inline]
37pub const fn left_encode(mut x: usize) -> LeftEncode {
38    // `x|1` ensures that `n < USIZE_BYTES`. It's cheaper than
39    // using a conditional.
40    let n = (x | 1).leading_zeros() / 8;
41    // Shift into the leading zeros so that we write everything
42    // at the start of the buffer. This lets us use constants for
43    // writing, as well as lets us use fixed-size writes (see
44    // `bytepad_blocks`, etc.).
45    x <<= n * 8;
46
47    LeftEncode(LeftEncodeRepr {
48        n: (USIZE_BYTES - n as usize) as u8,
49        w: x.to_be(),
50    })
51}
52
53/// The result of [`left_encode`].
54#[derive(Copy, Clone, Hash, Eq, PartialEq)]
55pub struct LeftEncode(LeftEncodeRepr);
56
57impl LeftEncode {
58    /// Returns the number of encoded bytes.
59    ///
60    /// The result is always non-zero.
61    #[inline]
62    #[allow(clippy::len_without_is_empty, reason = "Meaningless for this type")]
63    pub const fn len(&self) -> usize {
64        // SAFETY: See the invariant for `n` in `LeftEncodeRepr`.
65        unsafe { hint::assert_unchecked(self.0.n <= USIZE_BYTES as u8) }
66
67        (self.0.n + 1) as usize
68    }
69
70    /// Returns the encoded bytes.
71    ///
72    /// The result always has a non-zero length.
73    #[inline]
74    pub const fn as_bytes(&self) -> &[u8] {
75        // SAFETY: `self.len()` is in [1, USIZE_BYTES + 1].
76        unsafe { slice::from_raw_parts(self.as_fixed_bytes().as_ptr(), self.len()) }
77    }
78
79    pub(crate) const fn as_fixed_bytes(&self) -> &[u8; size_of::<Self>()] {
80        zerocopy::transmute_ref!(&self.0)
81    }
82}
83
84impl AsRef<[u8]> for LeftEncode {
85    #[inline]
86    fn as_ref(&self) -> &[u8] {
87        self.as_bytes()
88    }
89}
90
91impl fmt::Debug for LeftEncode {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        f.debug_tuple("LeftEncode").field(&self.as_bytes()).finish()
94    }
95}
96
97#[repr(C, packed)]
98#[derive(Copy, Clone, KnownLayout, Immutable, Unaligned, IntoBytes, ByteEq, ByteHash)]
99struct LeftEncodeRepr {
100    /// Invariant: `n` is in [0, USIZE_BYTES]
101    n: u8,
102    w: usize,
103}
104const _: () = {
105    assert!(size_of::<LeftEncodeRepr>() == 1 + USIZE_BYTES);
106};
107
108/// Encodes `x*8` as a byte string in a way that can be
109/// unambiguously parsed from the beginning.
110///
111/// # Rationale
112///
113/// [`left_encode`] is typically used to encode a length in
114/// *bits*. In practice, we usually have a length in *bytes*. The
115/// conversion from bytes to bits might overflow if the number of
116/// bytes is large. This method avoids overflowing.
117///
118/// # Example
119///
120/// ```rust
121/// use sha3_utils::{left_encode, left_encode_bytes};
122///
123/// assert_eq!(
124///     left_encode(8192 * 8).as_bytes(),
125///     left_encode_bytes(8192).as_bytes(),
126/// );
127///
128/// // usize::MAX*8 overflows, causing an incorrect result.
129/// assert_ne!(
130///     left_encode(usize::MAX.wrapping_mul(8)).as_bytes(),
131///     left_encode_bytes(usize::MAX).as_bytes(),
132/// );
133/// ```
134#[inline]
135pub const fn left_encode_bytes(x: usize) -> LeftEncodeBytes {
136    // Break `x*8` into double word arithmetic.
137    let mut hi = (x >> (usize::BITS - 3)) as u8;
138    let mut lo = x << 3;
139
140    let n = if hi == 0 {
141        // `lo|1` ensures that `n < USIZE_BYTES`. It's cheaper
142        // than using a conditional.
143        let n = (lo | 1).leading_zeros() / 8;
144        lo <<= n * 8;
145        // `hi == 0`, so we have one more leading byte to shift
146        // off.
147        hi = (lo >> (usize::BITS - 8)) as u8;
148        lo <<= 8;
149        (n + 1) as usize
150    } else {
151        0
152    };
153
154    LeftEncodeBytes(LeftEncodeBytesRepr {
155        n: (1 + USIZE_BYTES - n) as u8,
156        hi,
157        lo: lo.to_be(),
158    })
159}
160
161/// The result of [`left_encode_bytes`].
162#[derive(Copy, Clone, Hash, Eq, PartialEq)]
163pub struct LeftEncodeBytes(LeftEncodeBytesRepr);
164
165impl LeftEncodeBytes {
166    /// Returns the number of encoded bytes.
167    ///
168    /// The result is always non-zero.
169    #[inline]
170    #[allow(clippy::len_without_is_empty, reason = "Meaningless for this type")]
171    pub const fn len(&self) -> usize {
172        // SAFETY: See the invariant for `n` in
173        // `LeftEncodeBytesRepr`.
174        unsafe { hint::assert_unchecked(self.0.n <= (USIZE_BYTES + 1) as u8) }
175
176        (self.0.n + 1) as usize
177    }
178
179    /// Returns the encoded bytes.
180    ///
181    /// The result always has a non-zero length.
182    #[inline]
183    pub const fn as_bytes(&self) -> &[u8] {
184        // SAFETY: `self.len()` is in [1, USIZE_BYTES + 2].
185        unsafe { slice::from_raw_parts(self.as_fixed_bytes().as_ptr(), self.len()) }
186    }
187
188    pub(crate) const fn as_fixed_bytes(&self) -> &[u8; size_of::<Self>()] {
189        zerocopy::transmute_ref!(&self.0)
190    }
191}
192
193impl AsRef<[u8]> for LeftEncodeBytes {
194    #[inline]
195    fn as_ref(&self) -> &[u8] {
196        self.as_bytes()
197    }
198}
199
200impl fmt::Debug for LeftEncodeBytes {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        f.debug_tuple("LeftEncodeBytes")
203            .field(&self.as_bytes())
204            .finish()
205    }
206}
207
208#[repr(C, packed)]
209#[derive(Copy, Clone, KnownLayout, Immutable, Unaligned, IntoBytes, ByteEq, ByteHash)]
210struct LeftEncodeBytesRepr {
211    /// Invariant: `n` is in [0, USIZE_BYTES+1]
212    n: u8,
213    hi: u8,
214    lo: usize,
215}
216const _: () = {
217    assert!(size_of::<LeftEncodeBytesRepr>() == 2 + USIZE_BYTES);
218};
219
220/// Encodes `x` as a byte string in a way that can be
221/// unambiguously parsed from the end.
222#[inline]
223pub const fn right_encode(x: usize) -> RightEncode {
224    // `x|1` ensures that `n < USIZE_BYTES`. It's cheaper than
225    // using a conditional.
226    let n = (x | 1).leading_zeros() / 8;
227
228    RightEncode(RightEncodeRepr {
229        w: x.to_be(),
230        n: (USIZE_BYTES - n as usize) as u8,
231    })
232}
233
234/// The result of [`right_encode`].
235#[derive(Copy, Clone, Hash, Eq, PartialEq)]
236pub struct RightEncode(RightEncodeRepr);
237
238impl RightEncode {
239    /// Returns the number of encoded bytes.
240    ///
241    /// The result is always non-zero.
242    #[inline]
243    #[allow(clippy::len_without_is_empty, reason = "Meaningless for this type")]
244    pub const fn len(&self) -> usize {
245        let n = self.0.n as usize;
246        self.as_fixed_bytes().len() - 1 - n
247    }
248
249    /// Returns the encoded bytes.
250    ///
251    /// The result always has a non-zero length.
252    #[inline]
253    pub const fn as_bytes(&self) -> &[u8] {
254        let buf = self.as_fixed_bytes();
255        let off = self.len();
256        let len = buf.len() - off;
257
258        // SAFETY: `self.len()` is in [1, self.buf.len()).
259        unsafe { slice::from_raw_parts(buf.as_ptr().add(off), len) }
260    }
261
262    const fn as_fixed_bytes(&self) -> &[u8; size_of::<Self>()] {
263        zerocopy::transmute_ref!(&self.0)
264    }
265}
266
267impl AsRef<[u8]> for RightEncode {
268    #[inline]
269    fn as_ref(&self) -> &[u8] {
270        self.as_bytes()
271    }
272}
273
274impl fmt::Debug for RightEncode {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        f.debug_tuple("RightEncode")
277            .field(&self.as_bytes())
278            .finish()
279    }
280}
281
282#[repr(C, packed)]
283#[derive(Copy, Clone, KnownLayout, Immutable, Unaligned, IntoBytes, ByteEq, ByteHash)]
284struct RightEncodeRepr {
285    w: usize,
286    /// Invariant: `n` is in [0, USIZE_BYTES]
287    n: u8,
288}
289const _: () = {
290    assert!(size_of::<RightEncodeRepr>() == USIZE_BYTES + 1);
291};
292
293/// Encodes `x*8` as a byte string in a way that can be
294/// unambiguously parsed from the beginning.
295///
296/// # Rationale
297///
298/// [`right_encode`] is typically used to encode a length in
299/// *bits*. In practice, we usually have a length in *bytes*. The
300/// conversion from bytes to bits might overflow if the number of
301/// bytes is large. This method avoids overflowing.
302///
303/// # Example
304///
305/// ```rust
306/// use sha3_utils::{right_encode, right_encode_bytes};
307///
308/// assert_eq!(
309///     right_encode(8192 * 8).as_bytes(),
310///     right_encode_bytes(8192).as_bytes(),
311/// );
312///
313/// // usize::MAX*8 overflows, causing an incorrect result.
314/// assert_ne!(
315///     right_encode(usize::MAX.wrapping_mul(8)).as_bytes(),
316///     right_encode_bytes(usize::MAX).as_bytes(),
317/// );
318/// ```
319#[inline]
320pub const fn right_encode_bytes(mut x: usize) -> RightEncodeBytes {
321    // Break `x*8` into double word arithmetic.
322    let hi = (x >> (usize::BITS - 3)) & 0x7;
323    x <<= 3;
324
325    // `x|1` ensures that `n < USIZE_BYTES`. It's cheaper than
326    // using a conditional.
327    let n = if hi == 0 {
328        1 + ((x | 1).leading_zeros() / 8)
329    } else {
330        0
331    };
332
333    RightEncodeBytes(RightEncodeBytesRepr {
334        hi: hi as u8,
335        lo: x.to_be(),
336        n: (1 + USIZE_BYTES - n as usize) as u8,
337    })
338}
339
340/// The result of [`right_encode_bytes`].
341#[derive(Copy, Clone, Hash, Eq, PartialEq)]
342pub struct RightEncodeBytes(RightEncodeBytesRepr);
343
344impl RightEncodeBytes {
345    /// Returns the number of encoded bytes.
346    ///
347    /// The result is always non-zero.
348    #[inline]
349    #[allow(clippy::len_without_is_empty, reason = "Meaningless for this type")]
350    pub const fn len(&self) -> usize {
351        let n = self.0.n as usize;
352        self.as_fixed_bytes().len() - 1 - n
353    }
354
355    /// Returns the encoded bytes.
356    ///
357    /// The result always has a non-zero length.
358    #[inline]
359    pub const fn as_bytes(&self) -> &[u8] {
360        let buf = self.as_fixed_bytes();
361        let off = self.len();
362        let len = buf.len() - off;
363
364        // SAFETY: `self.len()` is in [1, self.buf.len()).
365        unsafe { slice::from_raw_parts(buf.as_ptr().add(off), len) }
366    }
367
368    const fn as_fixed_bytes(&self) -> &[u8; size_of::<Self>()] {
369        zerocopy::transmute_ref!(&self.0)
370    }
371}
372
373impl AsRef<[u8]> for RightEncodeBytes {
374    #[inline]
375    fn as_ref(&self) -> &[u8] {
376        self.as_bytes()
377    }
378}
379
380impl fmt::Debug for RightEncodeBytes {
381    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382        f.debug_tuple("RightEncodeBytes")
383            .field(&self.as_bytes())
384            .finish()
385    }
386}
387
388#[repr(C, packed)]
389#[derive(Copy, Clone, KnownLayout, Immutable, Unaligned, IntoBytes, ByteEq, ByteHash)]
390struct RightEncodeBytesRepr {
391    hi: u8,
392    lo: usize,
393    /// Invariant: `n` is in [0, USIZE_BYTES+1]
394    n: u8,
395}
396const _: () = {
397    assert!(size_of::<RightEncodeBytesRepr>() == 1 + USIZE_BYTES + 1);
398};
399
/// Encodes `s` such that it can be unambiguously parsed from
/// the beginning.
///
/// This is the same thing as [`encode_string`], but evaluates to
/// a constant `&[u8]`.
///
/// # Example
///
/// ```rust
/// use sha3_utils::encode_string;
///
/// let s = encode_string!(b"hello, world!");
/// assert_eq!(
///     s,
///     &[
///         1, 104,
///         104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
///     ],
/// );
/// ```
#[macro_export]
macro_rules! encode_string {
    ($s:expr) => {{
        // Items declared inside a `macro_rules!` body are not
        // hygienic: they are visible to (and shadow names used
        // by) the caller-supplied `$s`. Long, prefixed names keep
        // the macro from colliding with caller constants such as
        // `S` or `PREFIX`.
        const __ENCODE_STRING_S: &[u8] = $s;
        const __ENCODE_STRING_PREFIX: &[u8] =
            $crate::left_encode_bytes(__ENCODE_STRING_S.len()).as_bytes();
        const __ENCODE_STRING_LEN: usize =
            __ENCODE_STRING_PREFIX.len() + __ENCODE_STRING_S.len();
        const __ENCODE_STRING_OUT: [u8; __ENCODE_STRING_LEN] = {
            // Iterator adaptors are not usable in const context,
            // so copy byte by byte: prefix first, then `s`.
            let mut buf = [0u8; __ENCODE_STRING_LEN];
            let mut i = 0;
            while i < __ENCODE_STRING_PREFIX.len() {
                buf[i] = __ENCODE_STRING_PREFIX[i];
                i += 1;
            }
            let mut j = 0;
            while j < __ENCODE_STRING_S.len() {
                buf[i + j] = __ENCODE_STRING_S[j];
                j += 1;
            }
            buf
        };
        __ENCODE_STRING_OUT.as_slice()
    }};
}
446
447/// Encodes `s` such that it can be unambiguously encoded from
448/// the beginning.
449///
450/// # Example
451///
452/// ```rust
453/// use sha3_utils::encode_string;
454///
455/// let s = encode_string(b"hello, world!");
456/// assert_eq!(
457///     s.iter().flatten().copied().collect::<Vec<_>>(),
458///     &[
459///         1, 104,
460///         104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33,
461///     ],
462/// );
463/// ```
464#[inline]
465pub const fn encode_string(s: &[u8]) -> EncodedString<'_> {
466    let prefix = left_encode_bytes(s.len());
467    EncodedString { prefix, s }
468}
469
470/// The result of [`encode_string`].
471#[derive(Copy, Clone, Debug)]
472pub struct EncodedString<'a> {
473    prefix: LeftEncodeBytes,
474    s: &'a [u8],
475}
476
477impl EncodedString<'_> {
478    /// Returns the length of the encoded string.
479    ///
480    /// The result is always non-zero.
481    #[inline]
482    #[allow(clippy::len_without_is_empty, reason = "Meaningless for this type")]
483    pub const fn len(&self) -> usize {
484        self.prefix.len() + self.s.len()
485    }
486
487    /// Returns an iterator over the encoded string.
488    #[inline]
489    #[cfg_attr(feature = "no-panic", no_panic)]
490    pub fn iter(&self) -> EncodedStringIter<'_> {
491        EncodedStringIter {
492            iter: [self.prefix.as_bytes(), self.s].into_iter(),
493        }
494    }
495
496    /// Returns the two parts of the encoded string.
497    #[inline]
498    pub const fn as_parts(&self) -> (&LeftEncodeBytes, &[u8]) {
499        (&self.prefix, self.s)
500    }
501}
502
503impl<'a> EncodedString<'a> {
504    /// Returns the two parts of the encoded string.
505    #[inline]
506    pub const fn to_parts(self) -> (LeftEncodeBytes, &'a [u8]) {
507        (self.prefix, self.s)
508    }
509}
510
511impl<'a> IntoIterator for &'a EncodedString<'a> {
512    type Item = &'a [u8];
513    type IntoIter = EncodedStringIter<'a>;
514
515    #[inline]
516    #[cfg_attr(feature = "no-panic", no_panic)]
517    fn into_iter(self) -> Self::IntoIter {
518        self.iter()
519    }
520}
521
/// An iterator over the parts of an [`EncodedString`].
///
/// Every method forwards to the wrapped array iterator.
#[derive(Clone, Debug)]
pub struct EncodedStringIter<'a> {
    iter: array::IntoIter<&'a [u8], 2>,
}

impl<'a> Iterator for EncodedStringIter<'a> {
    type Item = &'a [u8];

    #[inline]
    fn next(&mut self) -> Option<&'a [u8]> {
        self.iter.next()
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }

    #[inline]
    fn nth(&mut self, n: usize) -> Option<&'a [u8]> {
        self.iter.nth(n)
    }

    #[inline]
    fn count(self) -> usize {
        self.iter.count()
    }

    #[inline]
    fn last(self) -> Option<&'a [u8]> {
        self.iter.last()
    }

    #[inline]
    fn fold<B, G: FnMut(B, Self::Item) -> B>(self, init: B, g: G) -> B {
        self.iter.fold(init, g)
    }
}

impl ExactSizeIterator for EncodedStringIter<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.iter.len()
    }
}

// The wrapped array iterator is fused, so this wrapper is too.
impl FusedIterator for EncodedStringIter<'_> {}
573
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the big-endian bytes of `v` with leading zero
    /// bytes removed. The result is empty for zero, so the zero
    /// encodings are checked separately in each test.
    fn significant_be(v: u128) -> Vec<u8> {
        v.to_be_bytes()
            .iter()
            .copied()
            .skip_while(|&b| b == 0)
            .collect()
    }

    /// Expected left encoding of `v`: length byte, then value.
    fn left_want(v: u128) -> Vec<u8> {
        let digits = significant_be(v);
        let mut want = vec![digits.len() as u8];
        want.extend(digits);
        want
    }

    /// Expected right encoding of `v`: value, then length byte.
    fn right_want(v: u128) -> Vec<u8> {
        let mut want = significant_be(v);
        want.push(want.len() as u8);
        want
    }

    #[test]
    fn test_left_encode() {
        assert_eq!(left_encode(0).as_bytes(), &[1, 0], "#0");
        for i in 0..usize::BITS {
            let x: usize = 1 << i;
            assert_eq!(left_encode(x).as_bytes(), left_want(x as u128), "#{x}");
        }
    }

    #[test]
    fn test_left_encode_bytes() {
        assert_eq!(left_encode_bytes(0).as_bytes(), &[1, 0], "#0");
        for i in 0..usize::BITS {
            let x: usize = 1 << i;
            // Widen before multiplying so that `x*8` cannot
            // overflow.
            let want = left_want(8 * x as u128);
            assert_eq!(left_encode_bytes(x).as_bytes(), want, "#{x}");
        }
    }

    #[test]
    fn test_right_encode() {
        assert_eq!(right_encode(0).as_bytes(), &[0, 1], "#0");
        for i in 0..usize::BITS {
            let x: usize = 1 << i;
            assert_eq!(right_encode(x).as_bytes(), right_want(x as u128), "#{x}");
        }
    }

    #[test]
    fn test_right_encode_bytes() {
        assert_eq!(right_encode_bytes(0).as_bytes(), &[0, 1], "#0");
        for i in 0..usize::BITS {
            let x: usize = 1 << i;
            // Widen before multiplying so that `x*8` cannot
            // overflow.
            let want = right_want(8 * x as u128);
            assert_eq!(right_encode_bytes(x).as_bytes(), want, "#{x}");
        }
    }

    /// The macro and the function must produce the same bytes.
    #[test]
    fn test_encode_string() {
        let want = encode_string(b"hello, world!")
            .into_iter()
            .flatten()
            .copied()
            .collect::<Vec<_>>();
        let got = encode_string!(b"hello, world!");
        assert_eq!(got, want);
    }
}