1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use crate::{
    errors::{Error, Result},
    types::BoltWireFormat,
    version::Version,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::mem;

/// Marker byte for a byte array whose length prefix is a single `u8`.
pub const SMALL: u8 = 0xCC;
/// Marker byte for a byte array whose length prefix is a big-endian `u16`.
pub const MEDIUM: u8 = 0xCD;
/// Marker byte for a byte array whose length prefix is a big-endian 32-bit integer.
pub const LARGE: u8 = 0xCE;

/// A raw byte-array value that can be written to and read from the wire via
/// [`BoltWireFormat`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct BoltBytes {
    /// The payload bytes, without marker or length prefix.
    pub value: Bytes,
}

impl BoltBytes {
    pub fn new(value: Bytes) -> Self {
        BoltBytes { value }
    }

    pub fn len(&self) -> usize {
        self.value.len()
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.value.is_empty()
    }
}

impl BoltWireFormat for BoltBytes {
    /// Returns `true` when `input` starts with one of the byte-array markers
    /// ([`SMALL`], [`MEDIUM`] or [`LARGE`]).
    ///
    /// An empty `input` has no marker to inspect and yields `false`.
    fn can_parse(_version: Version, input: &[u8]) -> bool {
        // `first()` instead of `input[0]` so an empty slice is rejected
        // rather than triggering an out-of-bounds panic.
        input
            .first()
            .map_or(false, |marker| [SMALL, MEDIUM, LARGE].contains(marker))
    }

    /// Reads one byte array from the front of `input`, consuming the marker,
    /// the length prefix and the payload.
    ///
    /// The length prefix is a `u8` after [`SMALL`], a big-endian `u16` after
    /// [`MEDIUM`] and a big-endian `u32` after [`LARGE`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidTypeMarker`] when the first byte is not one of
    /// the three byte-array markers.
    ///
    /// # Panics
    ///
    /// Panics when `input` is shorter than the marker, the length prefix or
    /// the announced payload, because the underlying `Buf::get_*` and
    /// `Bytes::split_to` calls panic on short buffers.
    // NOTE(review): truncated input should probably surface as an `Err`
    // instead of a panic; that needs a suitable `Error` variant — confirm
    // which one the crate provides.
    fn parse(_version: Version, input: &mut Bytes) -> Result<Self> {
        let marker = input.get_u8();
        let size = match marker {
            SMALL => input.get_u8() as usize,
            MEDIUM => input.get_u16() as usize,
            LARGE => input.get_u32() as usize,
            _ => {
                return Err(Error::InvalidTypeMarker(format!(
                    "invalid bytes marker {}",
                    marker
                )))
            }
        };

        // The payload shares the allocation of `input`; no bytes are copied.
        let bytes = input.split_to(size);
        Ok(BoltBytes::new(bytes))
    }

    /// Appends the marker, the length prefix and the payload to `bytes`,
    /// choosing the smallest of the three encodings that can hold the length.
    ///
    /// # Errors
    ///
    /// Returns [`Error::BytesTooBig`] when the payload is longer than
    /// 2_147_483_647 bytes, the largest length representable in the signed
    /// 32-bit size field of the [`LARGE`] encoding.
    fn write_into(&self, _version: Version, bytes: &mut BytesMut) -> Result<()> {
        let len = self.value.len();

        // Each arm's range guarantees the `as` cast below is lossless.
        match len {
            0..=0xFF => {
                bytes.reserve(1 + mem::size_of::<u8>() + len);
                bytes.put_u8(SMALL);
                bytes.put_u8(len as u8);
            }
            0x100..=0xFFFF => {
                bytes.reserve(1 + mem::size_of::<u16>() + len);
                bytes.put_u8(MEDIUM);
                bytes.put_u16(len as u16);
            }
            // Upper bound is i32::MAX: the size field is a signed 32-bit
            // integer, so 2_147_483_648 must already be rejected.
            0x1_0000..=0x7FFF_FFFF => {
                bytes.reserve(1 + mem::size_of::<u32>() + len);
                bytes.put_u8(LARGE);
                bytes.put_u32(len as u32);
            }
            _ => return Err(Error::BytesTooBig),
        }

        bytes.put(&*self.value);
        Ok(())
    }
}

/// Round-trip and boundary tests for the three byte-array encodings.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_serialize_small_bytes() {
        let bolt_bytes = BoltBytes::new(Bytes::from_static(b"hello"));

        let mut serialized: Bytes = bolt_bytes.into_bytes(Version::V4_1).unwrap();

        assert_eq!(
            serialized,
            Bytes::from_static(&[SMALL, 0x05, b'h', b'e', b'l', b'l', b'o'])
        );

        let deserialized: BoltBytes = BoltBytes::parse(Version::V4_1, &mut serialized).unwrap();

        assert_eq!(deserialized.value, Bytes::from_static(b"hello"));
    }

    #[test]
    fn should_serialize_empty_bytes() {
        let bolt_bytes = BoltBytes::new(Bytes::new());
        assert!(bolt_bytes.is_empty());

        let mut serialized: Bytes = bolt_bytes.into_bytes(Version::V4_1).unwrap();

        assert_eq!(serialized, Bytes::from_static(&[SMALL, 0x00]));

        let deserialized: BoltBytes = BoltBytes::parse(Version::V4_1, &mut serialized).unwrap();
        assert!(deserialized.is_empty());
    }

    #[test]
    fn should_use_small_marker_up_to_255_bytes() {
        // 255 is the largest length that still fits the one-byte prefix.
        let raw_bytes = Bytes::from(vec![0u8; 255]);
        let bolt_bytes = BoltBytes::new(raw_bytes.clone());
        let mut serialized: Bytes = bolt_bytes.into_bytes(Version::V4_1).unwrap();

        assert_eq!(serialized[0], SMALL);
        assert_eq!(serialized[1], 255);

        let deserialized: BoltBytes = BoltBytes::parse(Version::V4_1, &mut serialized).unwrap();
        assert_eq!(deserialized.value, raw_bytes);
    }

    #[test]
    fn should_serialize_medium_bytes() {
        let raw_bytes = Bytes::from(vec![0u8; 256]);
        let bolt_bytes = BoltBytes::new(raw_bytes.clone());
        let mut serialized: Bytes = bolt_bytes.into_bytes(Version::V4_1).unwrap();

        assert_eq!(serialized[0], MEDIUM);
        assert_eq!(u16::from_be_bytes([serialized[1], serialized[2]]), 256);

        let deserialized: BoltBytes = BoltBytes::parse(Version::V4_1, &mut serialized).unwrap();
        assert_eq!(deserialized.value, raw_bytes);
    }

    #[test]
    fn should_use_medium_marker_up_to_65535_bytes() {
        // 65_535 is the largest length that still fits the two-byte prefix.
        let raw_bytes = Bytes::from(vec![0u8; 65_535]);
        let bolt_bytes = BoltBytes::new(raw_bytes.clone());
        let mut serialized: Bytes = bolt_bytes.into_bytes(Version::V4_1).unwrap();

        assert_eq!(serialized[0], MEDIUM);
        assert_eq!(u16::from_be_bytes([serialized[1], serialized[2]]), 65_535);

        let deserialized: BoltBytes = BoltBytes::parse(Version::V4_1, &mut serialized).unwrap();
        assert_eq!(deserialized.value, raw_bytes);
    }

    #[test]
    fn should_serialize_large_bytes() {
        let raw_bytes = Bytes::from(vec![0u8; 65_537]);
        let bolt_bytes = BoltBytes::new(raw_bytes.clone());
        let mut serialized: Bytes = bolt_bytes.into_bytes(Version::V4_1).unwrap();

        assert_eq!(serialized[0], LARGE);
        assert_eq!(
            u32::from_be_bytes([serialized[1], serialized[2], serialized[3], serialized[4]]),
            65_537
        );

        let deserialized: BoltBytes = BoltBytes::parse(Version::V4_1, &mut serialized).unwrap();
        assert_eq!(deserialized.value, raw_bytes);
    }

    #[test]
    fn should_reject_unknown_marker() {
        let mut input = Bytes::from_static(&[0x00]);

        assert!(BoltBytes::parse(Version::V4_1, &mut input).is_err());
    }

    #[test]
    fn should_recognise_bytes_markers() {
        assert!(BoltBytes::can_parse(Version::V4_1, &[SMALL]));
        assert!(BoltBytes::can_parse(Version::V4_1, &[MEDIUM]));
        assert!(BoltBytes::can_parse(Version::V4_1, &[LARGE]));
        assert!(!BoltBytes::can_parse(Version::V4_1, &[0x00]));
    }
}