1use std::collections::HashMap;
2use std::convert::TryFrom;
3use std::mem::MaybeUninit;
4use std::ops::{Deref, DerefMut, RangeBounds};
5
6use bytes::{Buf, BufMut, Bytes, BytesMut};
7use snafu::{ensure, OptionExt, ResultExt};
8use uuid::Uuid;
9
10use crate::errors::{self, DecodeError, EncodeError};
11use crate::features::ProtocolVersion;
12
13pub type KeyValues = HashMap<u16, Bytes>;
14pub type Annotations = HashMap<String, String>;
15
16pub struct Input {
17 #[allow(dead_code)]
18 proto: ProtocolVersion,
19 bytes: Bytes,
20}
21
22pub struct Output<'a> {
23 #[allow(dead_code)]
24 proto: &'a ProtocolVersion,
25 bytes: &'a mut BytesMut,
26}
27
28pub(crate) trait Encode {
29 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError>;
30}
31
32pub(crate) trait Decode: Sized {
33 fn decode(buf: &mut Input) -> Result<Self, DecodeError>;
34}
35
36impl Input {
37 pub fn new(proto: ProtocolVersion, bytes: Bytes) -> Input {
38 Input { proto, bytes }
39 }
40 pub fn proto(&self) -> &ProtocolVersion {
41 &self.proto
42 }
43 pub fn slice(&self, range: impl RangeBounds<usize>) -> Input {
44 Input {
45 proto: self.proto.clone(),
46 bytes: self.bytes.slice(range),
47 }
48 }
49}
50
51impl Buf for Input {
52 fn remaining(&self) -> usize {
53 self.bytes.remaining()
54 }
55
56 fn chunk(&self) -> &[u8] {
57 self.bytes.chunk()
58 }
59
60 fn advance(&mut self, cnt: usize) {
61 self.bytes.advance(cnt)
62 }
63
64 fn copy_to_bytes(&mut self, len: usize) -> Bytes {
65 self.bytes.copy_to_bytes(len)
66 }
67}
68
69impl Deref for Input {
70 type Target = [u8];
71 fn deref(&self) -> &[u8] {
72 &self.bytes[..]
73 }
74}
75
76impl Deref for Output<'_> {
77 type Target = [u8];
78 fn deref(&self) -> &[u8] {
79 &self.bytes[..]
80 }
81}
82
83impl DerefMut for Output<'_> {
84 fn deref_mut(&mut self) -> &mut [u8] {
85 &mut self.bytes[..]
86 }
87}
88
89impl Output<'_> {
90 pub fn new<'x>(proto: &'x ProtocolVersion, bytes: &'x mut BytesMut) -> Output<'x> {
91 Output { proto, bytes }
92 }
93 pub fn proto(&self) -> &ProtocolVersion {
94 self.proto
95 }
96 pub fn reserve(&mut self, size: usize) {
97 self.bytes.reserve(size)
98 }
99 pub fn extend(&mut self, slice: &[u8]) {
100 self.bytes.extend(slice)
101 }
102 pub fn uninit(&mut self) -> &mut [MaybeUninit<u8>] {
103 unsafe { self.bytes.chunk_mut().as_uninit_slice_mut() }
104 }
105}
106
107unsafe impl BufMut for Output<'_> {
108 fn remaining_mut(&self) -> usize {
109 self.bytes.remaining_mut()
110 }
111 unsafe fn advance_mut(&mut self, cnt: usize) {
112 self.bytes.advance_mut(cnt)
113 }
114 fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
115 self.bytes.chunk_mut()
116 }
117}
118
119pub(crate) fn encode<T: Encode>(buf: &mut Output, code: u8, msg: &T) -> Result<(), EncodeError> {
120 buf.reserve(5);
121 buf.put_u8(code);
122 let base = buf.len();
123 buf.put_slice(&[0; 4]);
124
125 msg.encode(buf)?;
126
127 let size = u32::try_from(buf.len() - base)
128 .ok()
129 .context(errors::MessageTooLong)?;
130 buf[base..base + 4].copy_from_slice(&size.to_be_bytes()[..]);
131 Ok(())
132}
133
134impl Encode for String {
135 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
136 buf.reserve(2 + self.len());
137 buf.put_u32(
138 u32::try_from(self.len())
139 .ok()
140 .context(errors::StringTooLong)?,
141 );
142 buf.extend(self.as_bytes());
143 Ok(())
144 }
145}
146
147impl Encode for Bytes {
148 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
149 buf.reserve(2 + self.len());
150 buf.put_u32(
151 u32::try_from(self.len())
152 .ok()
153 .context(errors::StringTooLong)?,
154 );
155 buf.extend(&self[..]);
156 Ok(())
157 }
158}
159
160impl Decode for String {
161 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
162 ensure!(buf.remaining() >= 4, errors::Underflow);
163 let len = buf.get_u32() as usize;
164 ensure!(buf.remaining() >= len, errors::Underflow);
166 let mut data = vec![0u8; len];
167 buf.copy_to_slice(&mut data[..]);
168
169 String::from_utf8(data)
170 .map_err(|e| e.utf8_error())
171 .context(errors::InvalidUtf8)
172 }
173}
174
175impl Decode for Bytes {
176 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
177 ensure!(buf.remaining() >= 4, errors::Underflow);
178 let len = buf.get_u32() as usize;
179 ensure!(buf.remaining() >= len, errors::Underflow);
181 Ok(buf.copy_to_bytes(len))
182 }
183}
184
185impl Decode for Uuid {
186 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
187 ensure!(buf.remaining() >= 16, errors::Underflow);
188 let mut bytes = [0u8; 16];
189 buf.copy_to_slice(&mut bytes[..]);
190 let result = Uuid::from_slice(&bytes).context(errors::InvalidUuid)?;
191 Ok(result)
192 }
193}
194
195impl Encode for Uuid {
196 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
197 buf.extend(self.as_bytes());
198 Ok(())
199 }
200}
201
202impl Decode for bool {
203 fn decode(buf: &mut Input) -> Result<Self, DecodeError> {
204 ensure!(buf.remaining() >= 1, errors::Underflow);
205 let res = match buf.get_u8() {
206 0x00 => false,
207 0x01 => true,
208 v => errors::InvalidBool { val: v }.fail()?,
209 };
210 Ok(res)
211 }
212}
213
214impl Encode for bool {
215 fn encode(&self, buf: &mut Output) -> Result<(), EncodeError> {
216 buf.extend(match self {
217 true => &[0x01],
218 false => &[0x00],
219 });
220 Ok(())
221 }
222}