Skip to main content

esphome_proto/
lib.rs

1#![no_std]
2
3use num_traits::FromPrimitive;
4pub use prost;
5use prost::{bytes::Buf, bytes::BufMut};
6use thiserror::Error;
7
8pub mod api {
9    use prost::Message as prostMessage;
10    include!(concat!(env!("OUT_DIR"), "/api.rs"));
11}
12
13pub use api::MessageType;
14
15#[derive(Error, Debug)]
16pub enum Error {
17    #[error("First byte of message was not zero")]
18    InvalidStartByte,
19
20    #[error("Too short buffer")]
21    ShortBuffer,
22
23    #[error("Unknown message type {0}")]
24    UnknownMessageType(u64),
25
26    #[error("Could not decode protobuf")]
27    Decode(prost::DecodeError),
28
29    #[error("Could not encode protobuf")]
30    Encode(prost::EncodeError),
31}
32
33impl From<prost::DecodeError> for Error {
34    fn from(value: prost::DecodeError) -> Self {
35        Self::Decode(value)
36    }
37}
38
39impl From<prost::EncodeError> for Error {
40    fn from(value: prost::EncodeError) -> Self {
41        Self::Encode(value)
42    }
43}
44
45type Result<T> = core::result::Result<T, Error>;
46
47struct Header {
48    _type: api::MessageType,
49    size: u64,
50}
51
52impl Header {
53    pub fn decode(buffer: &mut &[u8]) -> Result<Self> {
54        if buffer.is_empty() {
55            return Err(Error::ShortBuffer);
56        }
57
58        if buffer.get_u8() != 0 {
59            return Err(Error::InvalidStartByte);
60        }
61
62        let size = prost::encoding::decode_varint(buffer)?;
63        let kind = prost::encoding::decode_varint(buffer)?;
64
65        Ok(Header {
66            _type: api::MessageType::from_u64(kind).ok_or(Error::UnknownMessageType(kind))?,
67            size,
68        })
69    }
70
71    pub fn encode_header(&self, buffer: &mut impl BufMut) -> Result<()> {
72        buffer.put_u8(0);
73        try_encode_variant(self.size, buffer)?;
74        try_encode_variant(self._type as u64, buffer)?;
75
76        Ok(())
77    }
78}
79
80pub fn decode_request(buffer: &mut &[u8]) -> Result<api::Message> {
81    let header = Header::decode(buffer)?;
82
83    if header.size as usize > buffer.len() {
84        return Err(Error::ShortBuffer);
85    }
86
87    let body = api::Message::decode(header._type, &mut &buffer[..header.size as usize])?;
88
89    buffer.advance(header.size as usize);
90
91    Ok(body)
92}
93
94pub fn encode_response<T>(buffer: &mut impl BufMut, body: &T) -> Result<usize>
95where
96    T: prost::Message,
97    for<'a> &'a T: Into<MessageType>,
98{
99    let remaining = buffer.remaining_mut();
100
101    let header = Header {
102        _type: body.into(),
103        size: body.encoded_len() as u64,
104    };
105
106    header.encode_header(buffer)?;
107
108    if buffer.remaining_mut() < header.size as usize {
109        return Err(Error::ShortBuffer);
110    }
111
112    body.encode(buffer)?;
113
114    Ok(remaining - buffer.remaining_mut())
115}
116
117fn try_encode_variant(value: u64, buf: &mut impl BufMut) -> Result<()> {
118    let len = prost::encoding::encoded_len_varint(value);
119
120    if buf.remaining_mut() < len {
121        return Err(Error::ShortBuffer);
122    }
123
124    prost::encoding::encode_varint(value, buf);
125
126    Ok(())
127}