memberlist_proto/data/
nodecraft.rs

1use std::net::SocketAddr;
2
3use either::Either;
4use nodecraft::{Domain, DomainRef, HostAddr, HostAddrRef, Node, NodeId, NodeIdRef};
5
6use super::{
7  super::{WireType, merge, skip},
8  Data, DataRef, DecodeError, EncodeError,
9};
10
11impl<'a> DataRef<'a, Domain> for DomainRef<'a> {
12  fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> {
13    DomainRef::try_from(buf)
14      .map(|domain| (buf.len(), domain))
15      .map_err(|e| DecodeError::custom(e.as_str()))
16  }
17}
18
19impl Data for Domain {
20  type Ref<'a> = DomainRef<'a>;
21
22  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
23  where
24    Self: Sized,
25  {
26    Ok(val.to_owned())
27  }
28
29  fn encoded_len(&self) -> usize {
30    self.fqdn_str().len()
31  }
32
33  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
34    let val = self.fqdn_str();
35    let len = val.len();
36    if buf.len() < len {
37      return Err(EncodeError::insufficient_buffer(len, buf.len()));
38    }
39    buf[..len].copy_from_slice(val.as_bytes());
40    Ok(len)
41  }
42}
43
44impl<'a, const N: usize> DataRef<'a, NodeId<N>> for NodeIdRef<'a, N> {
45  fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> {
46    NodeIdRef::try_from(buf)
47      .map(|node_id| (buf.len(), node_id))
48      .map_err(|e| DecodeError::custom(e.to_string()))
49  }
50}
51
52impl<const N: usize> Data for NodeId<N> {
53  type Ref<'a> = NodeIdRef<'a, N>;
54
55  fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError> {
56    Ok(Self::new(val).expect("reference must be a valid node id"))
57  }
58
59  fn encoded_len(&self) -> usize {
60    self.as_bytes().len()
61  }
62
63  fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
64    let val = self.as_bytes();
65    let len = val.len();
66    if buf.len() < len {
67      return Err(EncodeError::insufficient_buffer(len, buf.len()));
68    }
69    buf[..len].copy_from_slice(val);
70    Ok(len)
71  }
72}
73
74const _: () = {
75  const HOST_ADDR_SOCKET_TAG: u8 = 1;
76  const HOST_ADDR_SOCKET_BYTE: u8 = merge(WireType::LengthDelimited, HOST_ADDR_SOCKET_TAG);
77  const HOST_ADDR_DOMAIN_TAG: u8 = 2;
78  const HOST_ADDR_DOMAIN_BYTE: u8 = merge(WireType::LengthDelimited, HOST_ADDR_DOMAIN_TAG);
79
80  impl<'a> DataRef<'a, HostAddr> for HostAddrRef<'a> {
81    fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> {
82      let mut offset = 0;
83      let buf_len = buf.len();
84      let mut value = None;
85
86      while offset < buf_len {
87        match buf[offset] {
88          HOST_ADDR_SOCKET_BYTE => {
89            if value.is_some() {
90              return Err(DecodeError::duplicate_field(
91                "HostAddr",
92                "addr",
93                HOST_ADDR_SOCKET_TAG,
94              ));
95            }
96            offset += 1;
97            let (bytes_read, addr) =
98              <SocketAddr as DataRef<SocketAddr>>::decode_length_delimited(&buf[offset..])?;
99            offset += bytes_read;
100            value = Some(Self::from(addr));
101          }
102          HOST_ADDR_DOMAIN_BYTE => {
103            if value.is_some() {
104              return Err(DecodeError::duplicate_field(
105                "HostAddr",
106                "domain",
107                HOST_ADDR_DOMAIN_TAG,
108              ));
109            }
110
111            offset += 1;
112            let (bytes_read, domain) =
113              <DomainRef<'_> as DataRef<Domain>>::decode_length_delimited(&buf[offset..])?;
114            let required = offset + bytes_read + 2;
115            if required > buf_len {
116              return Err(DecodeError::buffer_underflow());
117            }
118            let port = u16::from_be_bytes(buf[offset + bytes_read..required].try_into().unwrap());
119            offset += bytes_read + 2;
120            value = Some(Self::from((domain, port)));
121          }
122          _ => offset += skip("HostAddr", &buf[offset..])?,
123        }
124      }
125
126      let value = value.ok_or_else(|| DecodeError::missing_field("HostAddr", "value"))?;
127      Ok((offset, value))
128    }
129  }
130
131  impl Data for HostAddr {
132    type Ref<'a> = HostAddrRef<'a>;
133
134    fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
135    where
136      Self: Sized,
137    {
138      Ok(val.to_owned())
139    }
140
141    fn encoded_len(&self) -> usize {
142      match self.as_inner() {
143        Either::Left(addr) => 1 + addr.encoded_len_with_length_delimited(),
144        Either::Right((_, domain)) => 1 + 2 + domain.encoded_len_with_length_delimited(),
145      }
146    }
147
148    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
149      let src_len = buf.len();
150
151      match self.as_inner() {
152        Either::Left(addr) => {
153          if src_len < 1 {
154            return Err(EncodeError::insufficient_buffer(1, src_len));
155          }
156          buf[0] = HOST_ADDR_SOCKET_BYTE;
157          let offset = addr.encode_length_delimited(&mut buf[1..])?;
158          Ok(1 + offset)
159        }
160        Either::Right((port, domain)) => {
161          buf[0] = HOST_ADDR_DOMAIN_BYTE;
162          let offset = domain.encode_length_delimited(&mut buf[1..])?;
163          buf[1 + offset..1 + offset + 2].copy_from_slice(&port.to_be_bytes());
164          Ok(1 + offset + 2)
165        }
166      }
167    }
168  }
169};
170
171const _: () = {
172  const NODE_ID_TAG: u8 = 1;
173  const NODE_ADDR_TAG: u8 = 2;
174
175  #[inline]
176  const fn node_id_byte<I: Data>() -> u8 {
177    merge(I::WIRE_TYPE, NODE_ID_TAG)
178  }
179
180  #[inline]
181  const fn node_addr_byte<A: Data>() -> u8 {
182    merge(A::WIRE_TYPE, NODE_ADDR_TAG)
183  }
184
185  impl<'a, I, A> DataRef<'a, Node<I, A>> for Node<I::Ref<'a>, A::Ref<'a>>
186  where
187    I: Data,
188    A: Data,
189  {
190    fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError>
191    where
192      Self: Sized,
193    {
194      let mut offset = 0;
195      let mut id = None;
196      let mut address = None;
197
198      while offset < src.len() {
199        match src[offset] {
200          b if b == node_id_byte::<I>() => {
201            if id.is_some() {
202              return Err(DecodeError::duplicate_field("Node", "id", NODE_ID_TAG));
203            }
204
205            offset += 1;
206            let (bytes_read, value) = I::Ref::decode_length_delimited(&src[offset..])?;
207            offset += bytes_read;
208            id = Some(value);
209          }
210          b if b == node_addr_byte::<A>() => {
211            if address.is_some() {
212              return Err(DecodeError::duplicate_field(
213                "Node",
214                "address",
215                NODE_ADDR_TAG,
216              ));
217            }
218
219            offset += 1;
220            let (bytes_read, value) = A::Ref::decode_length_delimited(&src[offset..])?;
221            offset += bytes_read;
222            address = Some(value);
223          }
224          _ => offset += skip("Node", &src[offset..])?,
225        }
226      }
227
228      let id = id.ok_or_else(|| DecodeError::missing_field("Node", "id"))?;
229      let address = address.ok_or_else(|| DecodeError::missing_field("Node", "address"))?;
230      Ok((offset, Self::new(id, address)))
231    }
232  }
233
234  impl<I, A> Data for Node<I, A>
235  where
236    I: Data,
237    A: Data,
238  {
239    type Ref<'a> = Node<I::Ref<'a>, A::Ref<'a>>;
240
241    fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError> {
242      let (id, address) = val.into_components();
243      I::from_ref(id).and_then(|id| A::from_ref(address).map(|address| Self::new(id, address)))
244    }
245
246    fn encoded_len(&self) -> usize {
247      1 + self.id().encoded_len_with_length_delimited()
248        + 1
249        + self.address().encoded_len_with_length_delimited()
250    }
251
252    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
253      let src_len = buf.len();
254      if src_len == 0 {
255        return Err(EncodeError::insufficient_buffer(
256          self.encoded_len(),
257          src_len,
258        ));
259      }
260
261      let mut offset = 0;
262      buf[offset] = node_id_byte::<I>();
263      offset += 1;
264      offset += self
265        .id()
266        .encode_length_delimited(&mut buf[offset..])
267        .map_err(|e| e.update(self.encoded_len(), src_len))?;
268
269      if offset >= src_len {
270        return Err(EncodeError::insufficient_buffer(
271          self.encoded_len(),
272          src_len,
273        ));
274      }
275      buf[offset] = node_addr_byte::<A>();
276      offset += 1;
277      offset += self
278        .address()
279        .encode_length_delimited(&mut buf[offset..])
280        .map_err(|e| e.update(self.encoded_len(), src_len))?;
281
282      #[cfg(debug_assertions)]
283      super::super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
284      Ok(offset)
285    }
286  }
287};