gel_protocol/
encoding.rs

1use std::collections::HashMap;
2use std::convert::TryFrom;
3use std::ops::{Deref, DerefMut, RangeBounds};
4
5use bytes::{Buf, BufMut, Bytes, BytesMut};
6use snafu::{ensure, OptionExt, ResultExt};
7use uuid::Uuid;
8
9use crate::errors::{self, DecodeError, EncodeError};
10use crate::features::ProtocolVersion;
11
/// Map from a `u16` key to a raw, undecoded byte value.
// NOTE(review): presumably keyed by protocol-defined attribute/header codes —
// confirm against the callers that build and read these maps.
pub type KeyValues = HashMap<u16, Bytes>;
/// Map from string keys to string values.
pub type Annotations = HashMap<String, String>;
14
/// Decoding cursor: an immutable byte buffer paired with the protocol
/// version it is decoded under. Bytes are consumed through its [`Buf`]
/// implementation.
pub struct Input {
    // NOTE(review): `proto` is read by `Input::proto` and `Input::slice`, so
    // this `allow` may be stale — confirm before removing it.
    #[allow(dead_code)]
    proto: ProtocolVersion,
    /// The not-yet-consumed bytes.
    bytes: Bytes,
}
20
/// Encoding sink: a borrowed, growable byte buffer that encoded data is
/// appended to, paired with the protocol version being encoded for.
pub struct Output<'a> {
    // NOTE(review): `proto` is read by `Output::proto`, so this `allow` may be
    // stale — confirm before removing it.
    #[allow(dead_code)]
    proto: &'a ProtocolVersion,
    /// Destination buffer; encoders append to its end.
    bytes: &'a mut BytesMut,
}
26
/// Types that can serialize themselves by appending their wire
/// representation to an [`Output`].
pub(crate) trait Encode {
    /// Appends the encoded form of `self` to `buf`.
    ///
    /// # Errors
    /// Returns an [`EncodeError`] when the value cannot be represented on the
    /// wire (for example, a length that does not fit the length prefix).
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError>;
}
30
/// Types that can deserialize themselves by consuming bytes from the front
/// of an [`Input`].
pub(crate) trait Decode: Sized {
    /// Reads one value from `buf`, advancing it past the consumed bytes.
    ///
    /// # Errors
    /// Returns a [`DecodeError`] when `buf` is too short or its contents are
    /// not a valid encoding of `Self`.
    fn decode(buf: &mut Input) -> Result<Self, DecodeError>;
}
34
35impl Input {
36    pub fn new(proto: ProtocolVersion, bytes: Bytes) -> Input {
37        Input { proto, bytes }
38    }
39    pub fn proto(&self) -> &ProtocolVersion {
40        &self.proto
41    }
42    pub fn slice(&self, range: impl RangeBounds<usize>) -> Input {
43        Input {
44            proto: self.proto.clone(),
45            bytes: self.bytes.slice(range),
46        }
47    }
48}
49
50impl Buf for Input {
51    fn remaining(&self) -> usize {
52        self.bytes.remaining()
53    }
54
55    fn chunk(&self) -> &[u8] {
56        self.bytes.chunk()
57    }
58
59    fn advance(&mut self, cnt: usize) {
60        self.bytes.advance(cnt)
61    }
62
63    fn copy_to_bytes(&mut self, len: usize) -> Bytes {
64        self.bytes.copy_to_bytes(len)
65    }
66}
67
68impl Deref for Input {
69    type Target = [u8];
70    fn deref(&self) -> &[u8] {
71        &self.bytes[..]
72    }
73}
74
75impl Deref for Output<'_> {
76    type Target = [u8];
77    fn deref(&self) -> &[u8] {
78        &self.bytes[..]
79    }
80}
81
82impl DerefMut for Output<'_> {
83    fn deref_mut(&mut self) -> &mut [u8] {
84        &mut self.bytes[..]
85    }
86}
87
88impl Output<'_> {
89    pub fn new<'x>(proto: &'x ProtocolVersion, bytes: &'x mut BytesMut) -> Output<'x> {
90        Output { proto, bytes }
91    }
92    pub fn proto(&self) -> &ProtocolVersion {
93        self.proto
94    }
95    pub fn reserve(&mut self, size: usize) {
96        self.bytes.reserve(size)
97    }
98    pub fn extend(&mut self, slice: &[u8]) {
99        self.bytes.extend(slice)
100    }
101}
102
// SAFETY: every method forwards directly to the wrapped `BytesMut`, whose own
// `BufMut` implementation upholds the trait's contract. `Output` keeps no
// extra state, so `remaining_mut`, `chunk_mut` and `advance_mut` always
// describe the same underlying buffer and cannot disagree with one another.
unsafe impl BufMut for Output<'_> {
    fn remaining_mut(&self) -> usize {
        self.bytes.remaining_mut()
    }
    // SAFETY: the caller must have initialized the first `cnt` bytes of the
    // slice returned by `chunk_mut()`. That obligation is passed unchanged to
    // `BytesMut::advance_mut`, whose `chunk_mut` is the one returned below.
    unsafe fn advance_mut(&mut self, cnt: usize) {
        self.bytes.advance_mut(cnt)
    }
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        self.bytes.chunk_mut()
    }
}
114
115pub(crate) fn encode<T: Encode>(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> {
116    buf.reserve(5);
117    buf.put_u8(code);
118    let base = buf.len();
119    buf.put_slice(&[0; 4]);
120
121    msg.encode(buf)?;
122
123    let size = u32::try_from(buf.len() - base)
124        .ok()
125        .context(errors::MessageTooLong)?;
126    buf[base..base + 4].copy_from_slice(&size.to_be_bytes()[..]);
127    Ok(())
128}
129
130impl Encode for String {
131    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
132        buf.reserve(2 + self.len());
133        buf.put_u32(
134            u32::try_from(self.len())
135                .ok()
136                .context(errors::StringTooLong)?,
137        );
138        buf.extend(self.as_bytes());
139        Ok(())
140    }
141}
142
143impl Encode for Bytes {
144    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
145        buf.reserve(2 + self.len());
146        buf.put_u32(
147            u32::try_from(self.len())
148                .ok()
149                .context(errors::StringTooLong)?,
150        );
151        buf.extend(&self[..]);
152        Ok(())
153    }
154}
155
156impl Decode for String {
157    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
158        ensure!(buf.remaining() >= 4, errors::Underflow);
159        let len = buf.get_u32() as usize;
160        // TODO(tailhook) ensure size < i32::MAX
161        ensure!(buf.remaining() >= len, errors::Underflow);
162        let mut data = vec![0u8; len];
163        buf.copy_to_slice(&mut data[..]);
164
165        String::from_utf8(data)
166            .map_err(|e| e.utf8_error())
167            .context(errors::InvalidUtf8)
168    }
169}
170
171impl Decode for Bytes {
172    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
173        ensure!(buf.remaining() >= 4, errors::Underflow);
174        let len = buf.get_u32() as usize;
175        // TODO(tailhook) ensure size < i32::MAX
176        ensure!(buf.remaining() >= len, errors::Underflow);
177        Ok(buf.copy_to_bytes(len))
178    }
179}
180
181impl Decode for Uuid {
182    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
183        ensure!(buf.remaining() >= 16, errors::Underflow);
184        let mut bytes = [0u8; 16];
185        buf.copy_to_slice(&mut bytes[..]);
186        let result = Uuid::from_slice(&bytes).context(errors::InvalidUuid)?;
187        Ok(result)
188    }
189}
190
191impl Encode for Uuid {
192    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
193        buf.extend(self.as_bytes());
194        Ok(())
195    }
196}
197
198impl Decode for bool {
199    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
200        ensure!(buf.remaining() >= 1, errors::Underflow);
201        let res = match buf.get_u8() {
202            0x00 => false,
203            0x01 => true,
204            v => errors::InvalidBool { val: v }.fail()?,
205        };
206        Ok(res)
207    }
208}
209
210impl Encode for bool {
211    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
212        buf.extend(match self {
213            true => &[0x01],
214            false => &[0x00],
215        });
216        Ok(())
217    }
218}