fbthrift_git/
binary_protocol.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use std::io::Cursor;
18
19use anyhow::anyhow;
20use anyhow::Result;
21use bufsize::SizeCounter;
22use bytes::Bytes;
23use bytes::BytesMut;
24use ghost::phantom;
25
26use crate::binary_type::CopyFromBuf;
27use crate::bufext::BufExt;
28use crate::bufext::BufMutExt;
29use crate::bufext::DeserializeSource;
30use crate::deserialize::Deserialize;
31use crate::errors::ProtocolError;
32use crate::framing::Framing;
33use crate::protocol::Field;
34use crate::protocol::Protocol;
35use crate::protocol::ProtocolReader;
36use crate::protocol::ProtocolWriter;
37use crate::serialize::Serialize;
38use crate::thrift_protocol::MessageType;
39use crate::thrift_protocol::ProtocolID;
40use crate::ttype::TType;
41
/// Mask selecting the upper 16 bits of the first word of a message header.
/// `read_message_begin` uses it to separate the version tag from the message
/// type stored in the remaining low bits.
pub const BINARY_VERSION_MASK: u32 = 0xFFFF_0000;
/// Version tag, already positioned in the upper 16 bits, that
/// `write_message_begin` ORs with the message type and that
/// `read_message_begin` requires to be present.
pub const BINARY_VERSION_1: u32 = 0x8001 << 16;
44
/// A straightforward binary format that encodes numeric values in fixed width.
///
/// ```ignore
/// let protocol = BinaryProtocol;
/// let transport = HttpClient::new(ENDPOINT)?;
/// let client = <dyn BuckGraphService>::new(protocol, transport);
/// ```
///
/// The type parameter is the Framing expected by the transport on which this
/// protocol is operating. Usually by convention the transport itself serves as
/// the Framing impl, so for example in the case of HttpClient above, the
/// compiler has inferred `F = HttpClient`.
///
/// Where the compiler reports that a Framing can't be inferred, one can be
/// specified explicitly:
///
/// ```ignore
/// let protocol = BinaryProtocol::<SRHeaderTransport>;
/// ```
///
/// The declaration defaults `F` to `Bytes`. The type carries no fields; it is
/// declared through `ghost::phantom` so that it can be generic over `F`.
#[phantom]
#[derive(Copy, Clone)]
pub struct BinaryProtocol<F = Bytes>;
67
/// Writer side of the binary protocol.
///
/// Wraps an output buffer `B`; the `ProtocolWriter` implementation for this
/// type appends encoded values to that buffer.
pub struct BinaryProtocolSerializer<B> {
    // Destination that every `write_*` call appends to.
    buffer: B,
}
71
/// Reader side of the binary protocol.
///
/// Wraps an input buffer `B`; the `ProtocolReader` implementation for this
/// type consumes encoded values from the front of that buffer.
pub struct BinaryProtocolDeserializer<B> {
    // Source that every `read_*` call consumes from.
    buffer: B,
}
75
76impl<F> Protocol for BinaryProtocol<F>
77where
78    F: Framing + 'static,
79{
80    type Frame = F;
81    type Sizer = BinaryProtocolSerializer<SizeCounter>;
82    type Serializer = BinaryProtocolSerializer<F::EncBuf>;
83    type Deserializer = BinaryProtocolDeserializer<F::DecBuf>;
84
85    const PROTOCOL_ID: ProtocolID = ProtocolID::BinaryProtocol;
86
87    fn serializer<SZ, SER>(size: SZ, ser: SER) -> <Self::Serializer as ProtocolWriter>::Final
88    where
89        SZ: FnOnce(&mut Self::Sizer),
90        SER: FnOnce(&mut Self::Serializer),
91    {
92        let mut sizer = BinaryProtocolSerializer {
93            buffer: SizeCounter::new(),
94        };
95        size(&mut sizer);
96        let sz = sizer.finish();
97        let mut buf = BinaryProtocolSerializer {
98            buffer: F::enc_with_capacity(sz),
99        };
100        ser(&mut buf);
101        buf.finish()
102    }
103
104    fn deserializer(buf: F::DecBuf) -> Self::Deserializer {
105        BinaryProtocolDeserializer::new(buf)
106    }
107
108    fn into_buffer(deser: Self::Deserializer) -> F::DecBuf {
109        deser.into_inner()
110    }
111}
112
113impl<B> BinaryProtocolSerializer<B> {
114    pub fn with_buffer(buffer: B) -> Self {
115        Self { buffer }
116    }
117}
118
119impl<B: BufMutExt> BinaryProtocolSerializer<B> {
120    fn write_u32(&mut self, value: u32) {
121        self.buffer.put_u32(value)
122    }
123}
124
125impl<B: BufExt> BinaryProtocolDeserializer<B> {
126    pub fn new(buffer: B) -> Self {
127        BinaryProtocolDeserializer { buffer }
128    }
129
130    pub fn into_inner(self) -> B {
131        self.buffer
132    }
133
134    fn peek_bytes(&self, len: usize) -> Option<&[u8]> {
135        if self.buffer.chunk().len() >= len {
136            Some(&self.buffer.chunk()[..len])
137        } else {
138            None
139        }
140    }
141
142    fn read_u32(&mut self) -> Result<u32> {
143        ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
144
145        Ok(self.buffer.get_u32())
146    }
147}
148
149impl<B: BufMutExt> ProtocolWriter for BinaryProtocolSerializer<B> {
150    type Final = B::Final;
151
152    fn write_message_begin(&mut self, name: &str, type_id: MessageType, seqid: u32) {
153        let version = BINARY_VERSION_1 | (type_id as u32);
154        self.write_i32(version as i32);
155        self.write_string(name);
156        self.write_u32(seqid);
157    }
158
159    #[inline]
160    fn write_message_end(&mut self) {}
161
162    #[inline]
163    fn write_struct_begin(&mut self, _name: &str) {}
164
165    #[inline]
166    fn write_struct_end(&mut self) {}
167
168    fn write_field_begin(&mut self, _name: &str, type_id: TType, id: i16) {
169        self.write_byte(type_id as i8);
170        self.write_i16(id);
171    }
172
173    #[inline]
174    fn write_field_end(&mut self) {}
175
176    #[inline]
177    fn write_field_stop(&mut self) {
178        self.write_byte(TType::Stop as i8)
179    }
180
181    fn write_map_begin(&mut self, key_type: TType, value_type: TType, size: usize) {
182        self.write_byte(key_type as i8);
183        self.write_byte(value_type as i8);
184        self.write_i32(i32::try_from(size as u64).expect("map size overflow"));
185    }
186
187    #[inline]
188    fn write_map_key_begin(&mut self) {}
189
190    #[inline]
191    fn write_map_value_begin(&mut self) {}
192
193    #[inline]
194    fn write_map_end(&mut self) {}
195
196    fn write_list_begin(&mut self, elem_type: TType, size: usize) {
197        self.write_byte(elem_type as i8);
198        self.write_i32(i32::try_from(size as u64).expect("list size overflow"));
199    }
200
201    #[inline]
202    fn write_list_value_begin(&mut self) {}
203
204    #[inline]
205    fn write_list_end(&mut self) {}
206
207    fn write_set_begin(&mut self, elem_type: TType, size: usize) {
208        self.write_byte(elem_type as i8);
209        self.write_i32(i32::try_from(size as u64).expect("set size overflow"));
210    }
211
212    #[inline]
213    fn write_set_value_begin(&mut self) {}
214
215    fn write_set_end(&mut self) {}
216
217    fn write_bool(&mut self, value: bool) {
218        if value {
219            self.write_byte(1)
220        } else {
221            self.write_byte(0)
222        }
223    }
224
225    fn write_byte(&mut self, value: i8) {
226        self.buffer.put_i8(value)
227    }
228
229    fn write_i16(&mut self, value: i16) {
230        self.buffer.put_i16(value)
231    }
232
233    fn write_i32(&mut self, value: i32) {
234        self.buffer.put_i32(value)
235    }
236
237    fn write_i64(&mut self, value: i64) {
238        self.buffer.put_i64(value)
239    }
240
241    fn write_double(&mut self, value: f64) {
242        self.buffer.put_f64(value)
243    }
244
245    fn write_float(&mut self, value: f32) {
246        self.buffer.put_f32(value)
247    }
248
249    fn write_string(&mut self, value: &str) {
250        self.write_i32(value.len() as i32);
251        self.buffer.put_slice(value.as_bytes())
252    }
253
254    fn write_binary(&mut self, value: &[u8]) {
255        self.write_i32(value.len() as i32);
256        self.buffer.put_slice(value)
257    }
258
259    fn finish(self) -> B::Final {
260        self.buffer.finalize()
261    }
262}
263
264impl<B: BufExt> ProtocolReader for BinaryProtocolDeserializer<B> {
265    fn read_message_begin<F, T>(&mut self, msgfn: F) -> Result<(T, MessageType, u32)>
266    where
267        F: FnOnce(&[u8]) -> T,
268    {
269        let versionty = self.read_i32()? as u32;
270
271        let msgtype = MessageType::try_from(versionty & !BINARY_VERSION_MASK)?; // !u32 -> ~u32
272        let version = versionty & BINARY_VERSION_MASK;
273        ensure_err!(version == BINARY_VERSION_1, ProtocolError::BadVersion);
274
275        let name = {
276            let len = self.read_i32()? as usize;
277            let (len, name) = {
278                if self.peek_bytes(len).is_some() {
279                    let namebuf = self.peek_bytes(len).unwrap();
280                    (namebuf.len(), msgfn(namebuf))
281                } else {
282                    ensure_err!(
283                        self.buffer.remaining() >= len,
284                        ProtocolError::InvalidDataLength
285                    );
286                    let namebuf: Vec<u8> = Vec::copy_from_buf(&mut self.buffer, len);
287                    (0, msgfn(namebuf.as_slice()))
288                }
289            };
290            self.buffer.advance(len);
291            name
292        };
293        let seq_id = self.read_u32()?;
294
295        Ok((name, msgtype, seq_id))
296    }
297
298    fn read_message_end(&mut self) -> Result<()> {
299        Ok(())
300    }
301
302    fn read_struct_begin<F, T>(&mut self, namefn: F) -> Result<T>
303    where
304        F: FnOnce(&[u8]) -> T,
305    {
306        Ok(namefn(&[]))
307    }
308
309    fn read_struct_end(&mut self) -> Result<()> {
310        Ok(())
311    }
312
313    fn read_field_begin<F, T>(&mut self, fieldfn: F, _fields: &[Field]) -> Result<(T, TType, i16)>
314    where
315        F: FnOnce(&[u8]) -> T,
316    {
317        let type_id = TType::try_from(self.read_byte()?)?;
318        let seq_id = match type_id {
319            TType::Stop => 0,
320            _ => self.read_i16()?,
321        };
322        Ok((fieldfn(&[]), type_id, seq_id))
323    }
324
325    fn read_field_end(&mut self) -> Result<()> {
326        Ok(())
327    }
328
329    fn read_map_begin(&mut self) -> Result<(TType, TType, Option<usize>)> {
330        let k_type = TType::try_from(self.read_byte()?)?;
331        let v_type = TType::try_from(self.read_byte()?)?;
332
333        let size = self.read_i32()?;
334        ensure_err!(size >= 0, ProtocolError::InvalidDataLength);
335        Ok((k_type, v_type, Some(size as usize)))
336    }
337
338    #[inline]
339    fn read_map_key_begin(&mut self) -> Result<bool> {
340        Ok(true)
341    }
342
343    #[inline]
344    fn read_map_value_begin(&mut self) -> Result<()> {
345        Ok(())
346    }
347
348    #[inline]
349    fn read_map_value_end(&mut self) -> Result<()> {
350        Ok(())
351    }
352
353    fn read_map_end(&mut self) -> Result<()> {
354        Ok(())
355    }
356
357    fn read_list_begin(&mut self) -> Result<(TType, Option<usize>)> {
358        let elem_type = TType::try_from(self.read_byte()?)?;
359        let size = self.read_i32()?;
360        ensure_err!(size >= 0, ProtocolError::InvalidDataLength);
361        Ok((elem_type, Some(size as usize)))
362    }
363
364    #[inline]
365    fn read_list_value_begin(&mut self) -> Result<bool> {
366        Ok(true)
367    }
368
369    #[inline]
370    fn read_list_value_end(&mut self) -> Result<()> {
371        Ok(())
372    }
373
374    fn read_list_end(&mut self) -> Result<()> {
375        Ok(())
376    }
377
378    fn read_set_begin(&mut self) -> Result<(TType, Option<usize>)> {
379        let elem_type = TType::try_from(self.read_byte()?)?;
380        let size = self.read_i32()?;
381        ensure_err!(size >= 0, ProtocolError::InvalidDataLength);
382        Ok((elem_type, Some(size as usize)))
383    }
384
385    #[inline]
386    fn read_set_value_begin(&mut self) -> Result<bool> {
387        Ok(true)
388    }
389
390    #[inline]
391    fn read_set_value_end(&mut self) -> Result<()> {
392        Ok(())
393    }
394
395    fn read_set_end(&mut self) -> Result<()> {
396        Ok(())
397    }
398
399    fn read_bool(&mut self) -> Result<bool> {
400        match self.read_byte()? {
401            0 => Ok(false),
402            _ => Ok(true),
403        }
404    }
405
406    fn read_byte(&mut self) -> Result<i8> {
407        ensure_err!(self.buffer.remaining() >= 1, ProtocolError::EOF);
408
409        Ok(self.buffer.get_i8())
410    }
411
412    fn read_i16(&mut self) -> Result<i16> {
413        ensure_err!(self.buffer.remaining() >= 2, ProtocolError::EOF);
414
415        Ok(self.buffer.get_i16())
416    }
417
418    fn read_i32(&mut self) -> Result<i32> {
419        ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
420
421        Ok(self.buffer.get_i32())
422    }
423
424    fn read_i64(&mut self) -> Result<i64> {
425        ensure_err!(self.buffer.remaining() >= 8, ProtocolError::EOF);
426
427        Ok(self.buffer.get_i64())
428    }
429
430    fn read_double(&mut self) -> Result<f64> {
431        ensure_err!(self.buffer.remaining() >= 8, ProtocolError::EOF);
432
433        Ok(self.buffer.get_f64())
434    }
435
436    fn read_float(&mut self) -> Result<f32> {
437        ensure_err!(self.buffer.remaining() >= 4, ProtocolError::EOF);
438
439        Ok(self.buffer.get_f32())
440    }
441
442    fn read_string(&mut self) -> Result<String> {
443        let vec = self.read_binary::<Vec<u8>>()?;
444
445        String::from_utf8(vec)
446            .map_err(|utf8_error| anyhow!("deserializing `string` from Thrift binary protocol got invalid utf-8, you need to use `binary` instead: {utf8_error}"))
447    }
448
449    fn read_binary<V: CopyFromBuf>(&mut self) -> Result<V> {
450        let received_len = self.read_i32()?;
451        ensure_err!(received_len >= 0, ProtocolError::InvalidDataLength);
452
453        let received_len = received_len as usize;
454
455        ensure_err!(self.buffer.remaining() >= received_len, ProtocolError::EOF);
456        Ok(V::copy_from_buf(&mut self.buffer, received_len))
457    }
458}
459
460/// How large an item will be when `serialize()` is called
461pub fn serialize_size<T>(v: &T) -> usize
462where
463    T: Serialize<BinaryProtocolSerializer<SizeCounter>>,
464{
465    let mut sizer = BinaryProtocolSerializer::with_buffer(SizeCounter::new());
466    v.write(&mut sizer);
467    sizer.finish()
468}
469
470/// Serialize a Thrift value using the binary protocol to a pre-allocated buffer.
471/// This will panic if the buffer is not large enough. A buffer at least as
472/// large as the return value of `serialize_size` will not panic.
473pub fn serialize_to_buffer<T>(v: T, buffer: BytesMut) -> BinaryProtocolSerializer<BytesMut>
474where
475    T: Serialize<BinaryProtocolSerializer<BytesMut>>,
476{
477    // Now that we have the size, allocate an output buffer and serialize into it
478    let mut buf = BinaryProtocolSerializer::with_buffer(buffer);
479    v.write(&mut buf);
480    buf
481}
482
/// Marker trait for types that can be serialized with the binary protocol both
/// by value and through a shared reference.
///
/// It requires `Self` and `&Self` to implement `Serialize` for both the
/// size-counting serializer and the `BytesMut`-backed serializer. It declares
/// no methods of its own.
pub trait SerializeRef:
    Serialize<BinaryProtocolSerializer<SizeCounter>> + Serialize<BinaryProtocolSerializer<BytesMut>>
where
    for<'a> &'a Self: Serialize<BinaryProtocolSerializer<SizeCounter>>,
    for<'a> &'a Self: Serialize<BinaryProtocolSerializer<BytesMut>>,
{
}
490
/// Blanket implementation: every type whose value and shared reference are
/// both serializable into the size-counting and `BytesMut`-backed binary
/// serializers is automatically `SerializeRef`.
impl<T> SerializeRef for T
where
    T: Serialize<BinaryProtocolSerializer<BytesMut>>,
    T: Serialize<BinaryProtocolSerializer<SizeCounter>>,
    for<'a> &'a T: Serialize<BinaryProtocolSerializer<BytesMut>>,
    for<'a> &'a T: Serialize<BinaryProtocolSerializer<SizeCounter>>,
{
}
499
500/// Serialize a Thrift value using the binary protocol.
501pub fn serialize<T>(v: T) -> Bytes
502where
503    T: Serialize<BinaryProtocolSerializer<SizeCounter>>
504        + Serialize<BinaryProtocolSerializer<BytesMut>>,
505{
506    let sz = serialize_size(&v);
507    let buf = serialize_to_buffer(v, BytesMut::with_capacity(sz));
508    // Done
509    buf.finish()
510}
511
/// Marker trait for types that can be deserialized with the binary protocol
/// from a `Cursor` over a borrowed byte slice of any lifetime.
///
/// It declares no methods of its own.
pub trait DeserializeSlice:
    for<'a> Deserialize<BinaryProtocolDeserializer<Cursor<&'a [u8]>>>
{
}
516
/// Blanket implementation: every type deserializable from a
/// `Cursor<&[u8]>`-backed binary deserializer, for any slice lifetime, is
/// automatically `DeserializeSlice`.
impl<T> DeserializeSlice for T where
    T: for<'a> Deserialize<BinaryProtocolDeserializer<Cursor<&'a [u8]>>>
{
}
521
522/// Deserialize a Thrift blob using the binary protocol.
523pub fn deserialize<T, B, C>(b: B) -> Result<T>
524where
525    B: Into<DeserializeSource<C>>,
526    C: BufExt,
527    T: Deserialize<BinaryProtocolDeserializer<C>>,
528{
529    let source: DeserializeSource<C> = b.into();
530    let mut deser = BinaryProtocolDeserializer::new(source.0);
531    T::read(&mut deser)
532}