gel_protocol/
encoding.rs

1use std::collections::HashMap;
2use std::convert::TryFrom;
3use std::mem::MaybeUninit;
4use std::ops::{Deref, DerefMut, RangeBounds};
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use snafu::{ensure, OptionExt, ResultExt};
8use uuid::Uuid;
9
10use crate::errors::{self, DecodeError, EncodeError};
11use crate::features::ProtocolVersion;
12
13pub type KeyValues = HashMap<u16, Bytes>;
14pub type Annotations = HashMap<String, String>;
15
16pub struct Input {
17    #[allow(dead_code)]
18    proto: ProtocolVersion,
19    bytes: Bytes,
20}
21
22pub struct Output<'a> {
23    #[allow(dead_code)]
24    proto: &'a ProtocolVersion,
25    bytes: &'a mut BytesMut,
26}
27
28pub(crate) trait Encode {
29    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError>;
30}
31
32pub(crate) trait Decode: Sized {
33    fn decode(buf: &mut Input) -> Result<Self, DecodeError>;
34}
35
36impl Input {
37    pub fn new(proto: ProtocolVersion, bytes: Bytes) -> Input {
38        Input { proto, bytes }
39    }
40    pub fn proto(&self) -> &ProtocolVersion {
41        &self.proto
42    }
43    pub fn slice(&self, range: impl RangeBounds<usize>) -> Input {
44        Input {
45            proto: self.proto.clone(),
46            bytes: self.bytes.slice(range),
47        }
48    }
49}
50
51impl Buf for Input {
52    fn remaining(&self) -> usize {
53        self.bytes.remaining()
54    }
55
56    fn chunk(&self) -> &[u8] {
57        self.bytes.chunk()
58    }
59
60    fn advance(&mut self, cnt: usize) {
61        self.bytes.advance(cnt)
62    }
63
64    fn copy_to_bytes(&mut self, len: usize) -> Bytes {
65        self.bytes.copy_to_bytes(len)
66    }
67}
68
69impl Deref for Input {
70    type Target = [u8];
71    fn deref(&self) -> &[u8] {
72        &self.bytes[..]
73    }
74}
75
76impl Deref for Output<'_> {
77    type Target = [u8];
78    fn deref(&self) -> &[u8] {
79        &self.bytes[..]
80    }
81}
82
83impl DerefMut for Output<'_> {
84    fn deref_mut(&mut self) -> &mut [u8] {
85        &mut self.bytes[..]
86    }
87}
88
89impl Output<'_> {
90    pub fn new<'x>(proto: &'x ProtocolVersion, bytes: &'x mut BytesMut) -> Output<'x> {
91        Output { proto, bytes }
92    }
93    pub fn proto(&self) -> &ProtocolVersion {
94        self.proto
95    }
96    pub fn reserve(&mut self, size: usize) {
97        self.bytes.reserve(size)
98    }
99    pub fn extend(&mut self, slice: &[u8]) {
100        self.bytes.extend(slice)
101    }
102    pub fn uninit(&mut self) -> &mut [MaybeUninit<u8>] {
103        unsafe { self.bytes.chunk_mut().as_uninit_slice_mut() }
104    }
105}
106
107unsafe impl BufMut for Output<'_> {
108    fn remaining_mut(&self) -> usize {
109        self.bytes.remaining_mut()
110    }
111    unsafe fn advance_mut(&mut self, cnt: usize) {
112        self.bytes.advance_mut(cnt)
113    }
114    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
115        self.bytes.chunk_mut()
116    }
117}
118
119pub(crate) fn encode<T: Encode>(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> {
120    buf.reserve(5);
121    buf.put_u8(code);
122    let base = buf.len();
123    buf.put_slice(&[0; 4]);
124
125    msg.encode(buf)?;
126
127    let size = u32::try_from(buf.len() - base)
128        .ok()
129        .context(errors::MessageTooLong)?;
130    buf[base..base + 4].copy_from_slice(&size.to_be_bytes()[..]);
131    Ok(())
132}
133
134impl Encode for String {
135    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
136        buf.reserve(2 + self.len());
137        buf.put_u32(
138            u32::try_from(self.len())
139                .ok()
140                .context(errors::StringTooLong)?,
141        );
142        buf.extend(self.as_bytes());
143        Ok(())
144    }
145}
146
147impl Encode for Bytes {
148    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
149        buf.reserve(2 + self.len());
150        buf.put_u32(
151            u32::try_from(self.len())
152                .ok()
153                .context(errors::StringTooLong)?,
154        );
155        buf.extend(&self[..]);
156        Ok(())
157    }
158}
159
160impl Decode for String {
161    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
162        ensure!(buf.remaining() >= 4, errors::Underflow);
163        let len = buf.get_u32() as usize;
164        // TODO(tailhook) ensure size < i32::MAX
165        ensure!(buf.remaining() >= len, errors::Underflow);
166        let mut data = vec![0u8; len];
167        buf.copy_to_slice(&mut data[..]);
168
169        String::from_utf8(data)
170            .map_err(|e| e.utf8_error())
171            .context(errors::InvalidUtf8)
172    }
173}
174
175impl Decode for Bytes {
176    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
177        ensure!(buf.remaining() >= 4, errors::Underflow);
178        let len = buf.get_u32() as usize;
179        // TODO(tailhook) ensure size < i32::MAX
180        ensure!(buf.remaining() >= len, errors::Underflow);
181        Ok(buf.copy_to_bytes(len))
182    }
183}
184
185impl Decode for Uuid {
186    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
187        ensure!(buf.remaining() >= 16, errors::Underflow);
188        let mut bytes = [0u8; 16];
189        buf.copy_to_slice(&mut bytes[..]);
190        let result = Uuid::from_slice(&bytes).context(errors::InvalidUuid)?;
191        Ok(result)
192    }
193}
194
195impl Encode for Uuid {
196    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
197        buf.extend(self.as_bytes());
198        Ok(())
199    }
200}
201
202impl Decode for bool {
203    fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
204        ensure!(buf.remaining() >= 1, errors::Underflow);
205        let res = match buf.get_u8() {
206            0x00 => false,
207            0x01 => true,
208            v => errors::InvalidBool { val: v }.fail()?,
209        };
210        Ok(res)
211    }
212}
213
214impl Encode for bool {
215    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
216        buf.extend(match self {
217            true => &[0x01],
218            false => &[0x00],
219        });
220        Ok(())
221    }
222}