1use std::io::Cursor;
18
19use anyhow::anyhow;
20use anyhow::Result;
21use bufsize::SizeCounter;
22use bytes::Bytes;
23use bytes::BytesMut;
24use ghost::phantom;
25
26use crate::binary_type::CopyFromBuf;
27use crate::bufext::BufExt;
28use crate::bufext::BufMutExt;
29use crate::bufext::DeserializeSource;
30use crate::deserialize::Deserialize;
31use crate::errors::ProtocolError;
32use crate::framing::Framing;
33use crate::protocol::Field;
34use crate::protocol::Protocol;
35use crate::protocol::ProtocolReader;
36use crate::protocol::ProtocolWriter;
37use crate::serialize::Serialize;
38use crate::thrift_protocol::MessageType;
39use crate::thrift_protocol::ProtocolID;
40use crate::ttype::TType;
41
/// Selects the version bits (the high 16 bits) of the first 32-bit word of a
/// versioned binary-protocol message header. The remaining low bits carry the
/// `MessageType`.
pub const BINARY_VERSION_MASK: u32 = 0xffff_0000;
/// Version-1 marker. Writers OR it with the `MessageType`; readers require the
/// masked high bits to equal exactly this value.
pub const BINARY_VERSION_1: u32 = 0x8001_0000;
44
#[phantom]
/// Zero-sized marker type selecting the Thrift binary protocol.
///
/// `F` is the framing used for encode/decode buffers (defaults to `Bytes`).
/// `#[phantom]` (from the `ghost` crate) allows `F` to appear only as a type
/// parameter without a `PhantomData` field.
///
/// The encoding produced by `BinaryProtocolSerializer` is dense:
/// - fixed-width integers and floats are written with the `bytes` `put_*` methods, which are big-endian;
/// - strings and binaries are prefixed with an `i32` length;
/// - containers are prefixed with their element type byte(s) and an `i32` count.
#[derive(Copy, Clone)]
pub struct BinaryProtocol<F = Bytes>;
67
/// Writes Thrift values in binary-protocol encoding into `buffer`.
///
/// Instantiated with a `SizeCounter` to measure the encoded size, and with a
/// real output buffer (e.g. `BytesMut` or `F::EncBuf`) to produce the bytes.
pub struct BinaryProtocolSerializer<B> {
    // Destination of every write; converted into the final output by `finish`.
    buffer: B,
}
71
/// Reads Thrift values in binary-protocol encoding from `buffer`.
pub struct BinaryProtocolDeserializer<B> {
    // Input being consumed; every successful read advances it.
    buffer: B,
}
75
76impl<F> Protocol for BinaryProtocol<F>
77where
78 F: Framing + 'static,
79{
80 type Frame = F;
81 type Sizer = BinaryProtocolSerializer<SizeCounter>;
82 type Serializer = BinaryProtocolSerializer<F::EncBuf>;
83 type Deserializer = BinaryProtocolDeserializer<F::DecBuf>;
84
85 const PROTOCOL_ID: ProtocolID = ProtocolID::BinaryProtocol;
86
87 fn serializer<SZ, SER>(size: SZ, ser: SER) -> <Self::Serializer as ProtocolWriter>::Final
88 where
89 SZ: FnOnce(&mut Self::Sizer),
90 SER: FnOnce(&mut Self::Serializer),
91 {
92 let mut sizer = BinaryProtocolSerializer {
93 buffer: SizeCounter::new(),
94 };
95 size(&mut sizer);
96 let sz = sizer.finish();
97 let mut buf = BinaryProtocolSerializer {
98 buffer: F::enc_with_capacity(sz),
99 };
100 ser(&mut buf);
101 buf.finish()
102 }
103
104 fn deserializer(buf: F::DecBuf) -> Self::Deserializer {
105 BinaryProtocolDeserializer::new(buf)
106 }
107
108 fn into_buffer(deser: Self::Deserializer) -> F::DecBuf {
109 deser.into_inner()
110 }
111}
112
113impl<B> BinaryProtocolSerializer<B> {
114 pub fn with_buffer(buffer: B) -> Self {
115 Self { buffer }
116 }
117}
118
119impl<B: BufMutExt> BinaryProtocolSerializer<B> {
120 fn write_u32(&mut self, value: u32) {
121 self.buffer.put_u32(value)
122 }
123}
124
125impl<B: BufExt> BinaryProtocolDeserializer<B> {
126 pub fn new(buffer: B) -> Self {
127 BinaryProtocolDeserializer { buffer }
128 }
129
130 pub fn into_inner(self) -> B {
131 self.buffer
132 }
133
134 fn peek_bytes(&self, len: usize) -> Option<&[u8]> {
135 if self.buffer.chunk().len() >= len {
136 Some(&self.buffer.chunk()[..len])
137 } else {
138 None
139 }
140 }
141
142 fn read_u32(&mut self) -> Result<u32> {
143 ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
144
145 Ok(self.buffer.get_u32())
146 }
147}
148
149impl<B: BufMutExt> ProtocolWriter for BinaryProtocolSerializer<B> {
150 type Final = B::Final;
151
152 fn write_message_begin(&mut self, name: &str, type_id: MessageType, seqid: u32) {
153 let version = BINARY_VERSION_1 | (type_id as u32);
154 self.write_i32(version as i32);
155 self.write_string(name);
156 self.write_u32(seqid);
157 }
158
159 #[inline]
160 fn write_message_end(&mut self) {}
161
162 #[inline]
163 fn write_struct_begin(&mut self, _name: &str) {}
164
165 #[inline]
166 fn write_struct_end(&mut self) {}
167
168 fn write_field_begin(&mut self, _name: &str, type_id: TType, id: i16) {
169 self.write_byte(type_id as i8);
170 self.write_i16(id);
171 }
172
173 #[inline]
174 fn write_field_end(&mut self) {}
175
176 #[inline]
177 fn write_field_stop(&mut self) {
178 self.write_byte(TType::Stop as i8)
179 }
180
181 fn write_map_begin(&mut self, key_type: TType, value_type: TType, size: usize) {
182 self.write_byte(key_type as i8);
183 self.write_byte(value_type as i8);
184 self.write_i32(i32::try_from(size as u64).expect("map size overflow"));
185 }
186
187 #[inline]
188 fn write_map_key_begin(&mut self) {}
189
190 #[inline]
191 fn write_map_value_begin(&mut self) {}
192
193 #[inline]
194 fn write_map_end(&mut self) {}
195
196 fn write_list_begin(&mut self, elem_type: TType, size: usize) {
197 self.write_byte(elem_type as i8);
198 self.write_i32(i32::try_from(size as u64).expect("list size overflow"));
199 }
200
201 #[inline]
202 fn write_list_value_begin(&mut self) {}
203
204 #[inline]
205 fn write_list_end(&mut self) {}
206
207 fn write_set_begin(&mut self, elem_type: TType, size: usize) {
208 self.write_byte(elem_type as i8);
209 self.write_i32(i32::try_from(size as u64).expect("set size overflow"));
210 }
211
212 #[inline]
213 fn write_set_value_begin(&mut self) {}
214
215 fn write_set_end(&mut self) {}
216
217 fn write_bool(&mut self, value: bool) {
218 if value {
219 self.write_byte(1)
220 } else {
221 self.write_byte(0)
222 }
223 }
224
225 fn write_byte(&mut self, value: i8) {
226 self.buffer.put_i8(value)
227 }
228
229 fn write_i16(&mut self, value: i16) {
230 self.buffer.put_i16(value)
231 }
232
233 fn write_i32(&mut self, value: i32) {
234 self.buffer.put_i32(value)
235 }
236
237 fn write_i64(&mut self, value: i64) {
238 self.buffer.put_i64(value)
239 }
240
241 fn write_double(&mut self, value: f64) {
242 self.buffer.put_f64(value)
243 }
244
245 fn write_float(&mut self, value: f32) {
246 self.buffer.put_f32(value)
247 }
248
249 fn write_string(&mut self, value: &str) {
250 self.write_i32(value.len() as i32);
251 self.buffer.put_slice(value.as_bytes())
252 }
253
254 fn write_binary(&mut self, value: &[u8]) {
255 self.write_i32(value.len() as i32);
256 self.buffer.put_slice(value)
257 }
258
259 fn finish(self) -> B::Final {
260 self.buffer.finalize()
261 }
262}
263
264impl<B: BufExt> ProtocolReader for BinaryProtocolDeserializer<B> {
265 fn read_message_begin<F, T>(&mut self, msgfn: F) -> Result<(T, MessageType, u32)>
266 where
267 F: FnOnce(&[u8]) -> T,
268 {
269 let versionty = self.read_i32()? as u32;
270
271 let msgtype = MessageType::try_from(versionty & !BINARY_VERSION_MASK)?; let version = versionty & BINARY_VERSION_MASK;
273 ensure_err!(version == BINARY_VERSION_1, ProtocolError::BadVersion);
274
275 let name = {
276 let len = self.read_i32()? as usize;
277 let (len, name) = {
278 if self.peek_bytes(len).is_some() {
279 let namebuf = self.peek_bytes(len).unwrap();
280 (namebuf.len(), msgfn(namebuf))
281 } else {
282 ensure_err!(
283 self.buffer.remaining() >= len,
284 ProtocolError::InvalidDataLength
285 );
286 let namebuf: Vec<u8> = Vec::copy_from_buf(&mut self.buffer, len);
287 (0, msgfn(namebuf.as_slice()))
288 }
289 };
290 self.buffer.advance(len);
291 name
292 };
293 let seq_id = self.read_u32()?;
294
295 Ok((name, msgtype, seq_id))
296 }
297
298 fn read_message_end(&mut self) -> Result<()> {
299 Ok(())
300 }
301
302 fn read_struct_begin<F, T>(&mut self, namefn: F) -> Result<T>
303 where
304 F: FnOnce(&[u8]) -> T,
305 {
306 Ok(namefn(&[]))
307 }
308
309 fn read_struct_end(&mut self) -> Result<()> {
310 Ok(())
311 }
312
313 fn read_field_begin<F, T>(&mut self, fieldfn: F, _fields: &[Field]) -> Result<(T, TType, i16)>
314 where
315 F: FnOnce(&[u8]) -> T,
316 {
317 let type_id = TType::try_from(self.read_byte()?)?;
318 let seq_id = match type_id {
319 TType::Stop => 0,
320 _ => self.read_i16()?,
321 };
322 Ok((fieldfn(&[]), type_id, seq_id))
323 }
324
325 fn read_field_end(&mut self) -> Result<()> {
326 Ok(())
327 }
328
329 fn read_map_begin(&mut self) -> Result<(TType, TType, Option<usize>)> {
330 let k_type = TType::try_from(self.read_byte()?)?;
331 let v_type = TType::try_from(self.read_byte()?)?;
332
333 let size = self.read_i32()?;
334 ensure_err!(size >= 0, ProtocolError::InvalidDataLength);
335 Ok((k_type, v_type, Some(size as usize)))
336 }
337
338 #[inline]
339 fn read_map_key_begin(&mut self) -> Result<bool> {
340 Ok(true)
341 }
342
343 #[inline]
344 fn read_map_value_begin(&mut self) -> Result<()> {
345 Ok(())
346 }
347
348 #[inline]
349 fn read_map_value_end(&mut self) -> Result<()> {
350 Ok(())
351 }
352
353 fn read_map_end(&mut self) -> Result<()> {
354 Ok(())
355 }
356
357 fn read_list_begin(&mut self) -> Result<(TType, Option<usize>)> {
358 let elem_type = TType::try_from(self.read_byte()?)?;
359 let size = self.read_i32()?;
360 ensure_err!(size >= 0, ProtocolError::InvalidDataLength);
361 Ok((elem_type, Some(size as usize)))
362 }
363
364 #[inline]
365 fn read_list_value_begin(&mut self) -> Result<bool> {
366 Ok(true)
367 }
368
369 #[inline]
370 fn read_list_value_end(&mut self) -> Result<()> {
371 Ok(())
372 }
373
374 fn read_list_end(&mut self) -> Result<()> {
375 Ok(())
376 }
377
378 fn read_set_begin(&mut self) -> Result<(TType, Option<usize>)> {
379 let elem_type = TType::try_from(self.read_byte()?)?;
380 let size = self.read_i32()?;
381 ensure_err!(size >= 0, ProtocolError::InvalidDataLength);
382 Ok((elem_type, Some(size as usize)))
383 }
384
385 #[inline]
386 fn read_set_value_begin(&mut self) -> Result<bool> {
387 Ok(true)
388 }
389
390 #[inline]
391 fn read_set_value_end(&mut self) -> Result<()> {
392 Ok(())
393 }
394
395 fn read_set_end(&mut self) -> Result<()> {
396 Ok(())
397 }
398
399 fn read_bool(&mut self) -> Result<bool> {
400 match self.read_byte()? {
401 0 => Ok(false),
402 _ => Ok(true),
403 }
404 }
405
406 fn read_byte(&mut self) -> Result<i8> {
407 ensure_err!(self.buffer.remaining() >= 1, ProtocolError::EOF);
408
409 Ok(self.buffer.get_i8())
410 }
411
412 fn read_i16(&mut self) -> Result<i16> {
413 ensure_err!(self.buffer.remaining() >= 2, ProtocolError::EOF);
414
415 Ok(self.buffer.get_i16())
416 }
417
418 fn read_i32(&mut self) -> Result<i32> {
419 ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
420
421 Ok(self.buffer.get_i32())
422 }
423
424 fn read_i64(&mut self) -> Result<i64> {
425 ensure_err!(self.buffer.remaining() >= 8, ProtocolError::EOF);
426
427 Ok(self.buffer.get_i64())
428 }
429
430 fn read_double(&mut self) -> Result<f64> {
431 ensure_err!(self.buffer.remaining() >= 8, ProtocolError::EOF);
432
433 Ok(self.buffer.get_f64())
434 }
435
436 fn read_float(&mut self) -> Result<f32> {
437 ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
438
439 Ok(self.buffer.get_f32())
440 }
441
442 fn read_string(&mut self) -> Result<String> {
443 let vec = self.read_binary::<Vec<u8>>()?;
444
445 String::from_utf8(vec)
446 .map_err(|utf8_error| anyhow!("deserializing `string` from Thrift binary protocol got invalid utf-8, you need to use `binary` instead: {utf8_error}"))
447 }
448
449 fn read_binary<V: CopyFromBuf>(&mut self) -> Result<V> {
450 let received_len = self.read_i32()?;
451 ensure_err!(received_len >= 0, ProtocolError::InvalidDataLength);
452
453 let received_len = received_len as usize;
454
455 ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF);
456 Ok(V::copy_from_buf(&mut self.buffer, received_len))
457 }
458}
459
460pub fn serialize_size<T>(v: &T) -> usize
462where
463 T: Serialize<BinaryProtocolSerializer<SizeCounter>>,
464{
465 let mut sizer = BinaryProtocolSerializer::with_buffer(SizeCounter::new());
466 v.write(&mut sizer);
467 sizer.finish()
468}
469
470pub fn serialize_to_buffer<T>(v: T, buffer: BytesMut) -> BinaryProtocolSerializer<BytesMut>
474where
475 T: Serialize<BinaryProtocolSerializer<BytesMut>>,
476{
477 let mut buf = BinaryProtocolSerializer::with_buffer(buffer);
479 v.write(&mut buf);
480 buf
481}
482
/// Trait alias for types that can be binary-protocol serialized both by value
/// and by reference (`&T`). Each form must support both sinks:
/// - a `SizeCounter`, for measuring size;
/// - a `BytesMut`, for actual encoding.
///
/// NOTE: in current Rust, only supertrait bounds are implied by `T: SerializeRef`.
/// The `where` bounds on `&Self` are not, so generic code naming this trait
/// may have to restate them.
pub trait SerializeRef:
    Serialize<BinaryProtocolSerializer<SizeCounter>> + Serialize<BinaryProtocolSerializer<BytesMut>>
where
    for<'a> &'a Self: Serialize<BinaryProtocolSerializer<SizeCounter>>,
    for<'a> &'a Self: Serialize<BinaryProtocolSerializer<BytesMut>>,
{
}
490
// Blanket impl: every type meeting the by-value and by-reference serialization
// bounds implements `SerializeRef` automatically.
impl<T> SerializeRef for T
where
    T: Serialize<BinaryProtocolSerializer<BytesMut>>,
    T: Serialize<BinaryProtocolSerializer<SizeCounter>>,
    for<'a> &'a T: Serialize<BinaryProtocolSerializer<BytesMut>>,
    for<'a> &'a T: Serialize<BinaryProtocolSerializer<SizeCounter>>,
{
}
499
500pub fn serialize<T>(v: T) -> Bytes
502where
503 T: Serialize<BinaryProtocolSerializer<SizeCounter>>
504 + Serialize<BinaryProtocolSerializer<BytesMut>>,
505{
506 let sz = serialize_size(&v);
507 let buf = serialize_to_buffer(v, BytesMut::with_capacity(sz));
508 buf.finish()
510}
511
/// Trait alias for types that can be binary-protocol deserialized from a
/// borrowed byte slice (wrapped in a `Cursor`), for any slice lifetime.
pub trait DeserializeSlice:
    for<'a> Deserialize<BinaryProtocolDeserializer<Cursor<&'a [u8]>>>
{
}
516
// Blanket impl: any type deserializable from a `Cursor` over a slice of any
// lifetime implements `DeserializeSlice` automatically.
impl<T> DeserializeSlice for T where
    T: for<'a> Deserialize<BinaryProtocolDeserializer<Cursor<&'a [u8]>>>
{
}
521
522pub fn deserialize<T, B, C>(b: B) -> Result<T>
524where
525 B: Into<DeserializeSource<C>>,
526 C: BufExt,
527 T: Deserialize<BinaryProtocolDeserializer<C>>,
528{
529 let source: DeserializeSource<C> = b.into();
530 let mut deser = BinaryProtocolDeserializer::new(source.0);
531 T::read(&mut deser)
532}