1use std::collections::HashMap;
2use std::convert::TryFrom;
3use std::ops::{Deref, DerefMut, RangeBounds};
4
5use bytes::{Buf, BufMut, Bytes, BytesMut};
6use snafu::{ensure, OptionExt, ResultExt};
7use uuid::Uuid;
8
9use crate::errors::{self, DecodeError, EncodeError};
10use crate::features::ProtocolVersion;
11
/// Map from a `u16` key to a raw byte payload.
pub type KeyValues = HashMap<u16, Bytes>;
/// Map from string names to string values.
pub type Annotations = HashMap<String, String>;
14
/// Read-side buffer: a `Bytes` payload stored together with a
/// `ProtocolVersion`.
pub struct Input {
    // NOTE(review): the `dead_code` allowance looks like it is there in case
    // nothing reads the field — confirm whether it is still needed.
    #[allow(dead_code)]
    proto: ProtocolVersion,
    // The bytes still available for reading.
    bytes: Bytes,
}
20
/// Write-side buffer: a mutable borrow of a `BytesMut` stored together with
/// a borrowed `ProtocolVersion`.
pub struct Output<'a> {
    // NOTE(review): the `dead_code` allowance looks like it is there in case
    // nothing reads the field — confirm whether it is still needed.
    #[allow(dead_code)]
    proto: &'a ProtocolVersion,
    // The buffer that encoded data is appended to.
    bytes: &'a mut BytesMut,
}
26
/// Implemented by values that can write themselves into an [`Output`].
pub(crate) trait Encode {
    /// Writes `self` into `buf`, returning an [`EncodeError`] on failure.
    fn encode(&self, buf: &mut Output) -> Result<(), EncodeError>;
}
30
/// Implemented by values that can be read out of an [`Input`].
pub(crate) trait Decode: Sized {
    /// Reads one value from `buf`, returning a [`DecodeError`] on failure.
    fn decode(buf: &mut Input) -> Result<Self, DecodeError>;
}
34
35impl Input {
36 pub fn new(proto: ProtocolVersion, bytes: Bytes) -> Input {
37 Input { proto, bytes }
38 }
39 pub fn proto(&self) -> &ProtocolVersion {
40 &self.proto
41 }
42 pub fn slice(&self, range: impl RangeBounds<usize>) -> Input {
43 Input {
44 proto: self.proto.clone(),
45 bytes: self.bytes.slice(range),
46 }
47 }
48}
49
50impl Buf for Input {
51 fn remaining(&self) -> usize {
52 self.bytes.remaining()
53 }
54
55 fn chunk(&self) -> &[u8] {
56 self.bytes.chunk()
57 }
58
59 fn advance(&mut self, cnt: usize) {
60 self.bytes.advance(cnt)
61 }
62
63 fn copy_to_bytes(&mut self, len: usize) -> Bytes {
64 self.bytes.copy_to_bytes(len)
65 }
66}
67
68impl Deref for Input {
69 type Target = [u8];
70 fn deref(&self) -> &[u8] {
71 &self.bytes[..]
72 }
73}
74
75impl Deref for Output<'_> {
76 type Target = [u8];
77 fn deref(&self) -> &[u8] {
78 &self.bytes[..]
79 }
80}
81
82impl DerefMut for Output<'_> {
83 fn deref_mut(&mut self) -> &mut [u8] {
84 &mut self.bytes[..]
85 }
86}
87
88impl Output<'_> {
89 pub fn new<'x>(proto: &'x ProtocolVersion, bytes: &'x mut BytesMut) -> Output<'x> {
90 Output { proto, bytes }
91 }
92 pub fn proto(&self) -> &ProtocolVersion {
93 self.proto
94 }
95 pub fn reserve(&mut self, size: usize) {
96 self.bytes.reserve(size)
97 }
98 pub fn extend(&mut self, slice: &[u8]) {
99 self.bytes.extend(slice)
100 }
101}
102
103unsafe impl BufMut for Output<'_> {
104 fn remaining_mut(&self) -> usize {
105 self.bytes.remaining_mut()
106 }
107 unsafe fn advance_mut(&mut self, cnt: usize) {
108 self.bytes.advance_mut(cnt)
109 }
110 fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
111 self.bytes.chunk_mut()
112 }
113}
114
115pub(crate) fn encode<T: Encode>(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> {
116 buf.reserve(5);
117 buf.put_u8(code);
118 let base = buf.len();
119 buf.put_slice(&[0; 4]);
120
121 msg.encode(buf)?;
122
123 let size = u32::try_from(buf.len() - base)
124 .ok()
125 .context(errors::MessageTooLong)?;
126 buf[base..base + 4].copy_from_slice(&size.to_be_bytes()[..]);
127 Ok(())
128}
129
130impl Encode for String {
131 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
132 buf.reserve(2 + self.len());
133 buf.put_u32(
134 u32::try_from(self.len())
135 .ok()
136 .context(errors::StringTooLong)?,
137 );
138 buf.extend(self.as_bytes());
139 Ok(())
140 }
141}
142
143impl Encode for Bytes {
144 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
145 buf.reserve(2 + self.len());
146 buf.put_u32(
147 u32::try_from(self.len())
148 .ok()
149 .context(errors::StringTooLong)?,
150 );
151 buf.extend(&self[..]);
152 Ok(())
153 }
154}
155
156impl Decode for String {
157 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
158 ensure!(buf.remaining() >= 4, errors::Underflow);
159 let len = buf.get_u32() as usize;
160 ensure!(buf.remaining() >= len, errors::Underflow);
162 let mut data = vec![0u8; len];
163 buf.copy_to_slice(&mut data[..]);
164
165 String::from_utf8(data)
166 .map_err(|e| e.utf8_error())
167 .context(errors::InvalidUtf8)
168 }
169}
170
171impl Decode for Bytes {
172 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
173 ensure!(buf.remaining() >= 4, errors::Underflow);
174 let len = buf.get_u32() as usize;
175 ensure!(buf.remaining() >= len, errors::Underflow);
177 Ok(buf.copy_to_bytes(len))
178 }
179}
180
181impl Decode for Uuid {
182 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
183 ensure!(buf.remaining() >= 16, errors::Underflow);
184 let mut bytes = [0u8; 16];
185 buf.copy_to_slice(&mut bytes[..]);
186 let result = Uuid::from_slice(&bytes).context(errors::InvalidUuid)?;
187 Ok(result)
188 }
189}
190
191impl Encode for Uuid {
192 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
193 buf.extend(self.as_bytes());
194 Ok(())
195 }
196}
197
198impl Decode for bool {
199 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
200 ensure!(buf.remaining() >= 1, errors::Underflow);
201 let res = match buf.get_u8() {
202 0x00 => false,
203 0x01 => true,
204 v => errors::InvalidBool { val: v }.fail()?,
205 };
206 Ok(res)
207 }
208}
209
210impl Encode for bool {
211 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
212 buf.extend(match self {
213 true => &[0x01],
214 false => &[0x00],
215 });
216 Ok(())
217 }
218}