1use std::net::SocketAddr;
2
3use either::Either;
4use nodecraft::{Domain, DomainRef, HostAddr, HostAddrRef, Node, NodeId, NodeIdRef};
5
6use super::{
7 super::{WireType, merge, skip},
8 Data, DataRef, DecodeError, EncodeError,
9};
10
11impl<'a> DataRef<'a, Domain> for DomainRef<'a> {
12 fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> {
13 DomainRef::try_from(buf)
14 .map(|domain| (buf.len(), domain))
15 .map_err(|e| DecodeError::custom(e.as_str()))
16 }
17}
18
19impl Data for Domain {
20 type Ref<'a> = DomainRef<'a>;
21
22 fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
23 where
24 Self: Sized,
25 {
26 Ok(val.to_owned())
27 }
28
29 fn encoded_len(&self) -> usize {
30 self.fqdn_str().len()
31 }
32
33 fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
34 let val = self.fqdn_str();
35 let len = val.len();
36 if buf.len() < len {
37 return Err(EncodeError::insufficient_buffer(len, buf.len()));
38 }
39 buf[..len].copy_from_slice(val.as_bytes());
40 Ok(len)
41 }
42}
43
44impl<'a, const N: usize> DataRef<'a, NodeId<N>> for NodeIdRef<'a, N> {
45 fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> {
46 NodeIdRef::try_from(buf)
47 .map(|node_id| (buf.len(), node_id))
48 .map_err(|e| DecodeError::custom(e.to_string()))
49 }
50}
51
52impl<const N: usize> Data for NodeId<N> {
53 type Ref<'a> = NodeIdRef<'a, N>;
54
55 fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError> {
56 Ok(Self::new(val).expect("reference must be a valid node id"))
57 }
58
59 fn encoded_len(&self) -> usize {
60 self.as_bytes().len()
61 }
62
63 fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
64 let val = self.as_bytes();
65 let len = val.len();
66 if buf.len() < len {
67 return Err(EncodeError::insufficient_buffer(len, buf.len()));
68 }
69 buf[..len].copy_from_slice(val);
70 Ok(len)
71 }
72}
73
74const _: () = {
75 const HOST_ADDR_SOCKET_TAG: u8 = 1;
76 const HOST_ADDR_SOCKET_BYTE: u8 = merge(WireType::LengthDelimited, HOST_ADDR_SOCKET_TAG);
77 const HOST_ADDR_DOMAIN_TAG: u8 = 2;
78 const HOST_ADDR_DOMAIN_BYTE: u8 = merge(WireType::LengthDelimited, HOST_ADDR_DOMAIN_TAG);
79
80 impl<'a> DataRef<'a, HostAddr> for HostAddrRef<'a> {
81 fn decode(buf: &'a [u8]) -> Result<(usize, Self), DecodeError> {
82 let mut offset = 0;
83 let buf_len = buf.len();
84 let mut value = None;
85
86 while offset < buf_len {
87 match buf[offset] {
88 HOST_ADDR_SOCKET_BYTE => {
89 if value.is_some() {
90 return Err(DecodeError::duplicate_field(
91 "HostAddr",
92 "addr",
93 HOST_ADDR_SOCKET_TAG,
94 ));
95 }
96 offset += 1;
97 let (bytes_read, addr) =
98 <SocketAddr as DataRef<SocketAddr>>::decode_length_delimited(&buf[offset..])?;
99 offset += bytes_read;
100 value = Some(Self::from(addr));
101 }
102 HOST_ADDR_DOMAIN_BYTE => {
103 if value.is_some() {
104 return Err(DecodeError::duplicate_field(
105 "HostAddr",
106 "domain",
107 HOST_ADDR_DOMAIN_TAG,
108 ));
109 }
110
111 offset += 1;
112 let (bytes_read, domain) =
113 <DomainRef<'_> as DataRef<Domain>>::decode_length_delimited(&buf[offset..])?;
114 let required = offset + bytes_read + 2;
115 if required > buf_len {
116 return Err(DecodeError::buffer_underflow());
117 }
118 let port = u16::from_be_bytes(buf[offset + bytes_read..required].try_into().unwrap());
119 offset += bytes_read + 2;
120 value = Some(Self::from((domain, port)));
121 }
122 _ => offset += skip("HostAddr", &buf[offset..])?,
123 }
124 }
125
126 let value = value.ok_or_else(|| DecodeError::missing_field("HostAddr", "value"))?;
127 Ok((offset, value))
128 }
129 }
130
131 impl Data for HostAddr {
132 type Ref<'a> = HostAddrRef<'a>;
133
134 fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError>
135 where
136 Self: Sized,
137 {
138 Ok(val.to_owned())
139 }
140
141 fn encoded_len(&self) -> usize {
142 match self.as_inner() {
143 Either::Left(addr) => 1 + addr.encoded_len_with_length_delimited(),
144 Either::Right((_, domain)) => 1 + 2 + domain.encoded_len_with_length_delimited(),
145 }
146 }
147
148 fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
149 let src_len = buf.len();
150
151 match self.as_inner() {
152 Either::Left(addr) => {
153 if src_len < 1 {
154 return Err(EncodeError::insufficient_buffer(1, src_len));
155 }
156 buf[0] = HOST_ADDR_SOCKET_BYTE;
157 let offset = addr.encode_length_delimited(&mut buf[1..])?;
158 Ok(1 + offset)
159 }
160 Either::Right((port, domain)) => {
161 buf[0] = HOST_ADDR_DOMAIN_BYTE;
162 let offset = domain.encode_length_delimited(&mut buf[1..])?;
163 buf[1 + offset..1 + offset + 2].copy_from_slice(&port.to_be_bytes());
164 Ok(1 + offset + 2)
165 }
166 }
167 }
168 }
169};
170
171const _: () = {
172 const NODE_ID_TAG: u8 = 1;
173 const NODE_ADDR_TAG: u8 = 2;
174
175 #[inline]
176 const fn node_id_byte<I: Data>() -> u8 {
177 merge(I::WIRE_TYPE, NODE_ID_TAG)
178 }
179
180 #[inline]
181 const fn node_addr_byte<A: Data>() -> u8 {
182 merge(A::WIRE_TYPE, NODE_ADDR_TAG)
183 }
184
185 impl<'a, I, A> DataRef<'a, Node<I, A>> for Node<I::Ref<'a>, A::Ref<'a>>
186 where
187 I: Data,
188 A: Data,
189 {
190 fn decode(src: &'a [u8]) -> Result<(usize, Self), DecodeError>
191 where
192 Self: Sized,
193 {
194 let mut offset = 0;
195 let mut id = None;
196 let mut address = None;
197
198 while offset < src.len() {
199 match src[offset] {
200 b if b == node_id_byte::<I>() => {
201 if id.is_some() {
202 return Err(DecodeError::duplicate_field("Node", "id", NODE_ID_TAG));
203 }
204
205 offset += 1;
206 let (bytes_read, value) = I::Ref::decode_length_delimited(&src[offset..])?;
207 offset += bytes_read;
208 id = Some(value);
209 }
210 b if b == node_addr_byte::<A>() => {
211 if address.is_some() {
212 return Err(DecodeError::duplicate_field(
213 "Node",
214 "address",
215 NODE_ADDR_TAG,
216 ));
217 }
218
219 offset += 1;
220 let (bytes_read, value) = A::Ref::decode_length_delimited(&src[offset..])?;
221 offset += bytes_read;
222 address = Some(value);
223 }
224 _ => offset += skip("Node", &src[offset..])?,
225 }
226 }
227
228 let id = id.ok_or_else(|| DecodeError::missing_field("Node", "id"))?;
229 let address = address.ok_or_else(|| DecodeError::missing_field("Node", "address"))?;
230 Ok((offset, Self::new(id, address)))
231 }
232 }
233
234 impl<I, A> Data for Node<I, A>
235 where
236 I: Data,
237 A: Data,
238 {
239 type Ref<'a> = Node<I::Ref<'a>, A::Ref<'a>>;
240
241 fn from_ref(val: Self::Ref<'_>) -> Result<Self, DecodeError> {
242 let (id, address) = val.into_components();
243 I::from_ref(id).and_then(|id| A::from_ref(address).map(|address| Self::new(id, address)))
244 }
245
246 fn encoded_len(&self) -> usize {
247 1 + self.id().encoded_len_with_length_delimited()
248 + 1
249 + self.address().encoded_len_with_length_delimited()
250 }
251
252 fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
253 let src_len = buf.len();
254 if src_len == 0 {
255 return Err(EncodeError::insufficient_buffer(
256 self.encoded_len(),
257 src_len,
258 ));
259 }
260
261 let mut offset = 0;
262 buf[offset] = node_id_byte::<I>();
263 offset += 1;
264 offset += self
265 .id()
266 .encode_length_delimited(&mut buf[offset..])
267 .map_err(|e| e.update(self.encoded_len(), src_len))?;
268
269 if offset >= src_len {
270 return Err(EncodeError::insufficient_buffer(
271 self.encoded_len(),
272 src_len,
273 ));
274 }
275 buf[offset] = node_addr_byte::<A>();
276 offset += 1;
277 offset += self
278 .address()
279 .encode_length_delimited(&mut buf[offset..])
280 .map_err(|e| e.update(self.encoded_len(), src_len))?;
281
282 #[cfg(debug_assertions)]
283 super::super::debug_assert_write_eq::<Self>(offset, self.encoded_len());
284 Ok(offset)
285 }
286 }
287};