1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
use crate::{check_len, Error, Result, TryRead, TryWrite};

/// Read context for `&[u8]` that selects where the returned slice ends.
///
/// When the end is located by a byte pattern, the matched pattern itself is
/// part of the returned slice.
///
/// # Example
///
/// ```
/// use byte::*;
/// use byte::ctx::*;
///
/// let bytes: &[u8] = &[0xde, 0xad, 0xbe, 0xef, 0x00, 0xff];
///
/// let sub: &[u8] = bytes.read_with(&mut 0, Bytes::Len(2)).unwrap();
/// assert_eq!(sub, &[0xde, 0xad]);
///
/// static PATTERN: &'static [u8; 2] = &[0x00, 0xff];
///
/// let sub: &[u8] = bytes.read_with(&mut 0, Bytes::Pattern(PATTERN)).unwrap();
/// assert_eq!(sub, &[0xde, 0xad, 0xbe, 0xef, 0x00, 0xff]);
///
/// let sub: &[u8] = bytes.read_with(&mut 0, Bytes::PatternUntil(PATTERN, 4)).unwrap();
/// assert_eq!(sub, &[0xde, 0xad, 0xbe, 0xef]);
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Bytes {
    /// Take a fixed number of bytes.
    Len(usize),
    /// Take bytes up to and including the first occurrence of the byte pattern.
    ///
    /// An empty pattern is rejected when reading.
    Pattern(&'static [u8]),
    /// Take bytes up to and including the first occurrence of the byte pattern,
    /// or the given number of bytes when the pattern does not occur within them.
    ///
    /// When reading, the pattern must be non-empty and no longer than the length.
    PatternUntil(&'static [u8], usize),
}

impl<'a> TryRead<'a, Bytes> for &'a [u8] {
    #[inline]
    fn try_read(bytes: &'a [u8], ctx: Bytes) -> Result<(Self, usize)> {
        let len = match ctx {
            Bytes::Len(len) => check_len(bytes, len)?,
            Bytes::Pattern(pattern) => {
                if pattern.is_empty() {
                    return Err(Error::BadInput {
                        err: "Pattern is empty",
                    });
                }
                check_len(bytes, pattern.len())?;
                (0..bytes.len() - pattern.len() + 1)
                    .map(|n| bytes[n..].starts_with(pattern))
                    .position(|p| p)
                    .map(|len| len + pattern.len())
                    .ok_or(Error::Incomplete)?
            }
            Bytes::PatternUntil(pattern, len) => {
                if pattern.is_empty() {
                    return Err(Error::BadInput {
                        err: "Pattern is empty",
                    });
                }
                if pattern.len() > len {
                    return Err(Error::BadInput {
                        err: "Pattern is longer than restricted length",
                    });
                }
                check_len(bytes, pattern.len())?;
                (0..bytes.len() - pattern.len() + 1)
                    .map(|n| bytes[n..].starts_with(pattern))
                    .take(len - pattern.len())
                    .position(|p| p)
                    .map(|position| position + pattern.len())
                    .unwrap_or(check_len(bytes, len)?)
            }
        };

        Ok((&bytes[..len], len))
    }
}

impl<'a> TryWrite for &'a [u8] {
    /// Copies this slice to the front of `bytes` and returns the number of
    /// bytes written, which is the length of the slice.
    ///
    /// `check_len(bytes, self.len())` is consulted first; any error it yields
    /// is propagated before `bytes` is touched.
    #[inline]
    fn try_write(self, bytes: &mut [u8], _ctx: ()) -> Result<usize> {
        let count = self.len();
        check_len(bytes, count)?;

        // `u8` is `Copy`, so a plain byte copy into the destination prefix suffices.
        bytes[..count].copy_from_slice(self);

        Ok(count)
    }
}