pb_jelly/
lib.rs

1#![warn(rust_2018_idioms)]
2#![allow(clippy::cast_sign_loss)]
3#![allow(clippy::cast_possible_truncation)]
4#![allow(clippy::cast_possible_wrap)]
5
6use std::any::Any;
7use std::collections::BTreeMap;
8use std::default::Default;
9use std::fmt::{
10    self,
11    Debug,
12};
13use std::io::{
14    Cursor,
15    Error,
16    ErrorKind,
17    Result,
18    Write,
19};
20
21use bytes::buf::{
22    Buf,
23    BufMut,
24};
25
26pub mod erased;
27pub mod extensions;
28pub mod helpers;
29pub mod varint;
30pub mod wire_format;
31
32mod buffer;
33pub use crate::buffer::{
34    type_is,
35    CopyWriter,
36    Lazy,
37    PbBuffer,
38    PbBufferReader,
39    PbBufferWriter,
40};
41
42mod base_types;
43pub use crate::base_types::{
44    ClosedProtoEnum,
45    Fixed32,
46    Fixed64,
47    OpenProtoEnum,
48    ProtoEnum,
49    Sfixed32,
50    Sfixed64,
51    Signed32,
52    Signed64,
53};
54
55mod descriptor;
56pub use crate::descriptor::{
57    FieldDescriptor,
58    Label,
59    MessageDescriptor,
60    OneofDescriptor,
61};
62
63pub mod reflection;
64pub use crate::reflection::Reflection;
65
66#[cfg(test)]
67mod tests;
68
69/// Trait implemented by all the messages defined in proto files and base datatypes
70/// like string, bytes, etc. The exact details of this trait is implemented for messages
71/// and base types can be found at - <https://developers.google.com/protocol-buffers/docs/encoding>
72pub trait Message: PartialEq + Default + Debug + Any {
73    /// Returns the `MessageDescriptor` for this message, if this is not a primitive type.
74    fn descriptor(&self) -> Option<MessageDescriptor> {
75        None
76    }
77
78    /// Computes the number of bytes a message will take when serialized. This does not
79    /// include number of bytes required for tag+wire_format or the bytes used to represent
80    /// length of the message in case of LengthDelimited messages/types.
81    fn compute_size(&self) -> usize;
82
83    /// Computes the number of bytes in all grpc slices.
84    /// This information is used to optimize memory allocations in zero-copy encoding.
85    fn compute_grpc_slices_size(&self) -> usize {
86        0
87    }
88
89    /// Serializes the message to the writer.
90    fn serialize<W: PbBufferWriter>(&self, w: &mut W) -> Result<()>;
91
92    /// Reads the message from the blob reader, copying as necessary.
93    fn deserialize<B: PbBufferReader>(&mut self, r: &mut B) -> Result<()>;
94
95    /// Helper method for serializing a message to a [Vec<u8>].
96    #[inline]
97    fn serialize_to_vec(&self) -> Vec<u8> {
98        let size = self.compute_size() as usize;
99        let mut out = Vec::with_capacity(size);
100        // We know that a Cursor<Vec<u8>> only fails on u32 overflow
101        // https://doc.rust-lang.org/src/std/io/cursor.rs.html#295
102        self.serialize(&mut Cursor::new(&mut out)).expect("Vec u32 overflow");
103        debug_assert_eq!(out.len(), size);
104        out
105    }
106
107    /// Helper method for serializing a message to an arbitrary [Write].
108    ///
109    /// If there are [Lazy] fields in the message, their contents will be copied out.
110    #[inline]
111    fn serialize_to_writer<W: Write>(&self, writer: &mut W) -> Result<()> {
112        let mut copy_writer = CopyWriter { writer };
113        self.serialize(&mut copy_writer)?;
114        Ok(())
115    }
116
117    /// Helper method for deserializing a message from a u8 slice.
118    ///
119    /// This will error if there are any [Lazy] fields in the message.
120    #[inline]
121    fn deserialize_from_slice(slice: &[u8]) -> Result<Self> {
122        let mut buf = Cursor::new(slice);
123        let mut m = Self::default();
124        m.deserialize(&mut buf)?;
125        Ok(m)
126    }
127}
128
129pub fn ensure_wire_format(
130    format: wire_format::Type,
131    expected: wire_format::Type,
132    msg_name: &str,
133    field_number: u32,
134) -> Result<()> {
135    if format != expected {
136        return Err(Error::new(
137            ErrorKind::Other,
138            format!(
139                "expected wire_format {:?}, found {:?}, at {:?}:{:?}",
140                expected, format, msg_name, field_number
141            ),
142        ));
143    }
144
145    Ok(())
146}
147
/// Builds the [`ErrorKind::UnexpectedEof`] error returned when the input ends
/// before a complete value could be read.
pub fn unexpected_eof() -> Error {
    let kind = ErrorKind::UnexpectedEof;
    Error::new(kind, "unexpected EOF")
}
151
// XXX: arguably this should not impl PartialEq since we cannot canonicalize the unparsed field contents
/// Raw wire bytes of fields collected through `gather`, grouped by field number.
#[derive(Clone, Default, PartialEq)]
pub struct Unrecognized {
    /// For each field number, the bytes written by `gather` for that field: every
    /// gathered occurrence (tag followed by its payload) appended in arrival order.
    by_field_number: BTreeMap<u32, Vec<u8>>,
}
157
158impl fmt::Debug for Unrecognized {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        f.debug_map()
161            .entries(self.by_field_number.keys().map(|k| (k, ..)))
162            .finish()
163    }
164}
165
166impl Unrecognized {
167    pub fn new() -> Self {
168        Self::default()
169    }
170
171    pub fn serialize(&self, unrecognized_buf: &mut impl PbBufferWriter) -> Result<()> {
172        // Write out sorted by field number
173        for serialized_field in self.by_field_number.values() {
174            unrecognized_buf.write_all(&serialized_field)?;
175        }
176        Ok(())
177    }
178
179    pub fn compute_size(&self) -> usize {
180        self.by_field_number.values().map(|v| v.len()).sum()
181    }
182
183    pub fn gather<B: Buf>(&mut self, field_number: u32, typ: wire_format::Type, buf: &mut B) -> Result<()> {
184        let unrecognized_buf = self.by_field_number.entry(field_number).or_default();
185
186        wire_format::write(field_number, typ, unrecognized_buf)?;
187        let advance = match typ {
188            wire_format::Type::Varint => {
189                if let Some(num) = varint::read(buf)? {
190                    varint::write(num, unrecognized_buf)?;
191                } else {
192                    return Err(unexpected_eof());
193                };
194
195                0
196            },
197            wire_format::Type::Fixed64 => 8,
198            wire_format::Type::Fixed32 => 4,
199            wire_format::Type::LengthDelimited => match varint::read(buf)? {
200                Some(n) => {
201                    varint::write(n, unrecognized_buf)?;
202                    n as usize
203                },
204                None => return Err(unexpected_eof()),
205            },
206        };
207
208        if buf.remaining() < advance {
209            return Err(unexpected_eof());
210        }
211
212        unrecognized_buf.put(buf.take(advance));
213
214        Ok(())
215    }
216
217    pub(crate) fn get_singular_field(&self, field_number: u32) -> Option<(&[u8], wire_format::Type)> {
218        let mut buf = Cursor::new(&self.by_field_number.get(&field_number)?[..]);
219        let mut result = None;
220        // It's technically legal for a singular field to occur multiple times on the wire,
221        // so skip over all but the last instance.
222        while let Some((_field_number, wire_format)) =
223            wire_format::read(&mut buf).expect("self.by_field_number malformed")
224        {
225            result = Some((&buf.get_ref()[buf.position() as usize..], wire_format));
226
227            skip(wire_format, &mut buf).expect("self.by_field_number malformed");
228        }
229        result
230    }
231
232    pub(crate) fn get_fields(&self, field_number: u32) -> &[u8] {
233        self.by_field_number.get(&field_number).map_or(&[], Vec::as_ref)
234    }
235}
236
237pub fn skip<B: Buf>(typ: wire_format::Type, buf: &mut B) -> Result<()> {
238    let advance = match typ {
239        wire_format::Type::Varint => {
240            if varint::read(buf)?.is_none() {
241                return Err(unexpected_eof());
242            };
243
244            0
245        },
246        wire_format::Type::Fixed64 => 8,
247        wire_format::Type::Fixed32 => 4,
248        wire_format::Type::LengthDelimited => match varint::read(buf)? {
249            Some(n) => n as usize,
250            None => return Err(unexpected_eof()),
251        },
252    };
253
254    if buf.remaining() < advance {
255        return Err(unexpected_eof());
256    }
257
258    buf.advance(advance);
259    Ok(())
260}
261
262pub fn ensure_split<B: PbBufferReader>(buf: &mut B, len: usize) -> Result<B> {
263    if buf.remaining() < len {
264        Err(unexpected_eof())
265    } else {
266        Ok(buf.split(len))
267    }
268}