1#![no_std]
2
3use num_traits::FromPrimitive;
4pub use prost;
5use prost::{bytes::Buf, bytes::BufMut};
6use thiserror::Error;
7
/// Protobuf API types included from `$OUT_DIR/api.rs` (presumably emitted by the build
/// script). This crate uses `MessageType` and `Message` from it.
pub mod api {
    // Brings prost's `Message` trait into scope under the alias `prostMessage` for the
    // included code. NOTE(review): presumably the generated file refers to this alias —
    // confirm before removing or renaming it.
    use prost::Message as prostMessage;
    include!(concat!(env!("OUT_DIR"), "/api.rs"));
}
12
13pub use api::MessageType;
14
/// Errors produced while framing and (de)serializing messages.
#[derive(Error, Debug)]
pub enum Error {
    /// The frame did not begin with the `0` marker byte.
    #[error("First byte of message was not zero")]
    InvalidStartByte,

    /// Not enough input bytes for a complete frame, or not enough room in the output
    /// buffer for the encoded frame.
    #[error("Too short buffer")]
    ShortBuffer,

    /// The header's type field (carried here) did not map to a known `MessageType`.
    #[error("Unknown message type {0}")]
    UnknownMessageType(u64),

    /// prost failed to decode a varint or a message body.
    #[error("Could not decode protobuf")]
    Decode(prost::DecodeError),

    /// prost failed to encode a message body.
    #[error("Could not encode protobuf")]
    Encode(prost::EncodeError),
}
32
33impl From<prost::DecodeError> for Error {
34 fn from(value: prost::DecodeError) -> Self {
35 Self::Decode(value)
36 }
37}
38
39impl From<prost::EncodeError> for Error {
40 fn from(value: prost::EncodeError) -> Self {
41 Self::Encode(value)
42 }
43}
44
45type Result<T> = core::result::Result<T, Error>;
46
/// Frame header that precedes every message body.
///
/// Wire layout (see `Header::decode` and `Header::encode_header`): a `0` marker byte,
/// the body length as a varint, then the message type as a varint.
struct Header {
    /// Type of the message whose body follows the header.
    _type: api::MessageType,
    /// Length of the encoded body in bytes.
    size: u64,
}
51
52impl Header {
53 pub fn decode(buffer: &mut &[u8]) -> Result<Self> {
54 if buffer.is_empty() {
55 return Err(Error::ShortBuffer);
56 }
57
58 if buffer.get_u8() != 0 {
59 return Err(Error::InvalidStartByte);
60 }
61
62 let size = prost::encoding::decode_varint(buffer)?;
63 let kind = prost::encoding::decode_varint(buffer)?;
64
65 Ok(Header {
66 _type: api::MessageType::from_u64(kind).ok_or(Error::UnknownMessageType(kind))?,
67 size,
68 })
69 }
70
71 pub fn encode_header(&self, buffer: &mut impl BufMut) -> Result<()> {
72 buffer.put_u8(0);
73 try_encode_variant(self.size, buffer)?;
74 try_encode_variant(self._type as u64, buffer)?;
75
76 Ok(())
77 }
78}
79
80pub fn decode_request(buffer: &mut &[u8]) -> Result<api::Message> {
81 let header = Header::decode(buffer)?;
82
83 if header.size as usize > buffer.len() {
84 return Err(Error::ShortBuffer);
85 }
86
87 let body = api::Message::decode(header._type, &mut &buffer[..header.size as usize])?;
88
89 buffer.advance(header.size as usize);
90
91 Ok(body)
92}
93
94pub fn encode_response<T>(buffer: &mut impl BufMut, body: &T) -> Result<usize>
95where
96 T: prost::Message,
97 for<'a> &'a T: Into<MessageType>,
98{
99 let remaining = buffer.remaining_mut();
100
101 let header = Header {
102 _type: body.into(),
103 size: body.encoded_len() as u64,
104 };
105
106 header.encode_header(buffer)?;
107
108 if buffer.remaining_mut() < header.size as usize {
109 return Err(Error::ShortBuffer);
110 }
111
112 body.encode(buffer)?;
113
114 Ok(remaining - buffer.remaining_mut())
115}
116
117fn try_encode_variant(value: u64, buf: &mut impl BufMut) -> Result<()> {
118 let len = prost::encoding::encoded_len_varint(value);
119
120 if buf.remaining_mut() < len {
121 return Err(Error::ShortBuffer);
122 }
123
124 prost::encoding::encode_varint(value, buf);
125
126 Ok(())
127}