Skip to main content

ethrex_p2p/discv4/
messages.rs

1use crate::{
2    types::{Endpoint, Node, NodeRecord},
3    utils::{current_unix_time, node_id},
4};
5use bytes::BufMut;
6use ethrex_common::{H256, H512, H520, utils::keccak};
7use ethrex_crypto::keccak::keccak_hash;
8use ethrex_rlp::{
9    decode::RLPDecode,
10    encode::RLPEncode,
11    error::RLPDecodeError,
12    structs::{self, Decoder, Encoder},
13};
14use secp256k1::{
15    SecretKey,
16    ecdsa::{RecoverableSignature, RecoveryId},
17};
18use std::{convert::Into, io::ErrorKind};
19
20#[derive(Debug, thiserror::Error)]
21pub enum PacketDecodeErr {
22    #[error("RLP decoding error")]
23    RLPDecodeError(#[from] RLPDecodeError),
24    #[error("Invalid packet size")]
25    InvalidSize,
26    #[error("Hash mismatch")]
27    HashMismatch,
28    #[error("Invalid signature")]
29    InvalidSignature,
30    #[error("Discv4 decoding error: {0}")]
31    Discv4DecodingError(String),
32    #[error("Io Error: {0}")]
33    IoError(#[from] std::io::Error),
34}
35
36impl From<PacketDecodeErr> for std::io::Error {
37    fn from(error: PacketDecodeErr) -> Self {
38        std::io::Error::new(ErrorKind::InvalidData, error.to_string())
39    }
40}
41
42#[derive(Debug, Clone)]
43pub struct Packet {
44    hash: H256,
45    signature: H520,
46    message: Message,
47    public_key: H512,
48}
49
50impl Packet {
51    pub fn decode(encoded_packet: &[u8]) -> Result<Packet, PacketDecodeErr> {
52        // the packet structure is
53        // hash || signature || packet-type || packet-data
54        let hash_len = 32;
55        let signature_len = 65;
56        let header_size = hash_len + signature_len; // 97
57
58        if encoded_packet.len() < header_size + 1 {
59            return Err(PacketDecodeErr::InvalidSize);
60        };
61
62        let hash = H256::from_slice(&encoded_packet[..hash_len]);
63        let signature_bytes = &encoded_packet[hash_len..header_size];
64        let packet_type = encoded_packet[header_size];
65        let encoded_msg = &encoded_packet[header_size..];
66
67        let header_hash = keccak(&encoded_packet[hash_len..]);
68
69        if hash != header_hash {
70            return Err(PacketDecodeErr::HashMismatch);
71        }
72
73        let digest: [u8; 32] = keccak_hash(encoded_msg);
74
75        let rid = RecoveryId::try_from(Into::<i32>::into(signature_bytes[64]))
76            .map_err(|_| PacketDecodeErr::InvalidSignature)?;
77
78        let peer_pk = secp256k1::SECP256K1
79            .recover_ecdsa(
80                &secp256k1::Message::from_digest(digest),
81                &RecoverableSignature::from_compact(&signature_bytes[0..64], rid)
82                    .map_err(|_| PacketDecodeErr::InvalidSignature)?,
83            )
84            .map_err(|_| PacketDecodeErr::InvalidSignature)?;
85
86        let encoded = peer_pk.serialize_uncompressed();
87
88        let public_key = H512::from_slice(&encoded[1..]);
89        let signature = H520::from_slice(signature_bytes);
90        let message = Message::decode_with_type(packet_type, &encoded_msg[1..])
91            .map_err(PacketDecodeErr::RLPDecodeError)?;
92
93        Ok(Self {
94            hash,
95            signature,
96            message,
97            public_key,
98        })
99    }
100
101    pub fn get_hash(&self) -> H256 {
102        self.hash
103    }
104
105    pub fn get_message(&self) -> &Message {
106        &self.message
107    }
108
109    #[allow(unused)]
110    pub fn get_signature(&self) -> H520 {
111        self.signature
112    }
113
114    pub fn get_public_key(&self) -> H512 {
115        self.public_key
116    }
117
118    pub fn get_node_id(&self) -> H256 {
119        node_id(&self.public_key)
120    }
121}
122
123#[derive(Debug, Eq, PartialEq, Clone)]
124pub enum Message {
125    Ping(PingMessage),
126    Pong(PongMessage),
127    FindNode(FindNodeMessage),
128    Neighbors(NeighborsMessage),
129    ENRRequest(ENRRequestMessage),
130    ENRResponse(ENRResponseMessage),
131}
132
133impl std::fmt::Display for Message {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        let variant = match self {
136            Message::Ping(_) => "Ping",
137            Message::Pong(_) => "Pong",
138            Message::FindNode(_) => "FindNode",
139            Message::Neighbors(_) => "Neighbors",
140            Message::ENRRequest(_) => "ENRRequest",
141            Message::ENRResponse(_) => "ENRResponse",
142        };
143        write!(f, "{variant}")
144    }
145}
146
147impl Message {
148    /// Returns a short, stable label suitable for use as a Prometheus metric label value.
149    pub fn metric_label(&self) -> &'static str {
150        match self {
151            Message::Ping(_) => "Ping",
152            Message::Pong(_) => "Pong",
153            Message::FindNode(_) => "FindNode",
154            Message::Neighbors(_) => "Neighbors",
155            Message::ENRRequest(_) => "ENRRequest",
156            Message::ENRResponse(_) => "ENRResponse",
157        }
158    }
159
160    pub fn encode_with_header(&self, buf: &mut dyn BufMut, node_signer: &SecretKey) {
161        let signature_size = 65_usize;
162        let mut data: Vec<u8> = Vec::with_capacity(signature_size.next_power_of_two());
163        data.resize(signature_size, 0);
164
165        self.encode_with_type(&mut data);
166
167        let digest: [u8; 32] = keccak_hash(&data[signature_size..]);
168
169        let (recovery_id, signature) = secp256k1::SECP256K1
170            .sign_ecdsa_recoverable(&secp256k1::Message::from_digest(digest), node_signer)
171            .serialize_compact();
172
173        data[..signature_size - 1].copy_from_slice(&signature);
174        data[signature_size - 1] = Into::<i32>::into(recovery_id) as u8;
175
176        let hash = keccak_hash(&data[..]);
177        buf.put_slice(&hash);
178        buf.put_slice(&data[..]);
179    }
180
181    fn encode_with_type(&self, buf: &mut dyn BufMut) {
182        buf.put_u8(self.packet_type());
183        match self {
184            Message::Ping(msg) => msg.encode(buf),
185            Message::Pong(msg) => msg.encode(buf),
186            Message::FindNode(msg) => msg.encode(buf),
187            Message::ENRRequest(msg) => msg.encode(buf),
188            Message::ENRResponse(msg) => msg.encode(buf),
189            Message::Neighbors(msg) => msg.encode(buf),
190        }
191    }
192
193    pub fn decode_with_type(packet_type: u8, msg: &[u8]) -> Result<Message, RLPDecodeError> {
194        // NOTE: extra elements inside the message should be ignored, along with extra data
195        // after the message.
196        match packet_type {
197            0x01 => {
198                let (ping, _rest) = PingMessage::decode_unfinished(msg)?;
199                Ok(Message::Ping(ping))
200            }
201            0x02 => {
202                let (pong, _rest) = PongMessage::decode_unfinished(msg)?;
203                Ok(Message::Pong(pong))
204            }
205            0x03 => {
206                let (find_node_msg, _rest) = FindNodeMessage::decode_unfinished(msg)?;
207                Ok(Message::FindNode(find_node_msg))
208            }
209            0x04 => {
210                let (neighbors_msg, _rest) = NeighborsMessage::decode_unfinished(msg)?;
211                Ok(Message::Neighbors(neighbors_msg))
212            }
213            0x05 => {
214                let (enr_request_msg, _rest) = ENRRequestMessage::decode_unfinished(msg)?;
215                Ok(Message::ENRRequest(enr_request_msg))
216            }
217            0x06 => {
218                let (enr_response_msg, _rest) = ENRResponseMessage::decode_unfinished(msg)?;
219                Ok(Message::ENRResponse(enr_response_msg))
220            }
221            _ => Err(RLPDecodeError::MalformedData),
222        }
223    }
224
225    fn packet_type(&self) -> u8 {
226        match self {
227            Message::Ping(_) => 0x01,
228            Message::Pong(_) => 0x02,
229            Message::FindNode(_) => 0x03,
230            Message::Neighbors(_) => 0x04,
231            Message::ENRRequest(_) => 0x05,
232            Message::ENRResponse(_) => 0x06,
233        }
234    }
235}
236
237#[derive(Debug, Clone, PartialEq, Eq)]
238pub struct PingMessage {
239    /// The Ping message version. Should be set to 4, but mustn't be enforced.
240    pub version: u8,
241    /// The endpoint of the sender.
242    pub from: Endpoint,
243    /// The endpoint of the receiver.
244    pub to: Endpoint,
245    /// The expiration time of the message. If the message is older than this time,
246    /// it shouldn't be responded to.
247    pub expiration: u64,
248    /// The ENR sequence number of the sender. This field is optional.
249    pub enr_seq: Option<u64>,
250}
251
252impl PingMessage {
253    pub fn new(from: Endpoint, to: Endpoint, expiration: u64) -> Self {
254        Self {
255            version: 4,
256            from,
257            to,
258            expiration,
259            enr_seq: None,
260        }
261    }
262
263    // TODO: remove when used
264    #[allow(unused)]
265    pub fn with_enr_seq(self, enr_seq: u64) -> Self {
266        Self {
267            enr_seq: Some(enr_seq),
268            ..self
269        }
270    }
271}
272
273impl RLPEncode for PingMessage {
274    fn encode(&self, buf: &mut dyn BufMut) {
275        structs::Encoder::new(buf)
276            .encode_field(&self.version)
277            .encode_field(&self.from)
278            .encode_field(&self.to)
279            .encode_field(&self.expiration)
280            .encode_optional_field(&self.enr_seq)
281            .finish();
282    }
283}
284
285#[derive(Debug, PartialEq, Eq, Clone)]
286pub struct FindNodeMessage {
287    /// The target is a 64-byte secp256k1 public key.
288    pub target: H512,
289    /// The expiration time of the message. If the message is older than this time,
290    /// it shouldn't be responded to.
291    pub expiration: u64,
292}
293
294impl FindNodeMessage {
295    #[allow(unused)]
296    pub fn new(target: H512, expiration: u64) -> Self {
297        Self { target, expiration }
298    }
299}
300
301impl RLPEncode for FindNodeMessage {
302    fn encode(&self, buf: &mut dyn BufMut) {
303        structs::Encoder::new(buf)
304            .encode_field(&self.target)
305            .encode_field(&self.expiration)
306            .finish();
307    }
308}
309
310impl RLPDecode for FindNodeMessage {
311    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
312        let decoder = Decoder::new(rlp)?;
313        let (target, decoder) = decoder.decode_field("target")?;
314        let (expiration, decoder) = decoder.decode_field("expiration")?;
315        let remaining = decoder.finish_unchecked();
316        let msg = FindNodeMessage { target, expiration };
317        Ok((msg, remaining))
318    }
319}
320
321#[derive(Debug, Clone)]
322pub struct FindNodeRequest {
323    /// the number of nodes sent
324    /// we keep track of this number since we will accept neighbor messages until the max_per_bucket
325    pub nodes_sent: u64,
326    /// unix timestamp tracking when we have sent the request
327    pub sent_at: u64,
328    /// if present, server will send the nodes through this channel when receiving neighbors
329    /// useful to wait for the response in lookups
330    pub tx: Option<tokio::sync::mpsc::UnboundedSender<Vec<Node>>>,
331}
332
333impl Default for FindNodeRequest {
334    fn default() -> Self {
335        Self {
336            nodes_sent: 0,
337            sent_at: current_unix_time(),
338            tx: None,
339        }
340    }
341}
342
343impl FindNodeRequest {
344    pub fn new_with_sender(sender: tokio::sync::mpsc::UnboundedSender<Vec<Node>>) -> Self {
345        Self {
346            tx: Some(sender),
347            ..Self::default()
348        }
349    }
350}
351
352impl RLPDecode for PingMessage {
353    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
354        let decoder = Decoder::new(rlp)?;
355        let (version, decoder): (u8, Decoder) = decoder.decode_field("version")?;
356        let (from, decoder) = decoder.decode_field("from")?;
357        let (to, decoder) = decoder.decode_field("to")?;
358        let (expiration, decoder) = decoder.decode_field("expiration")?;
359        let (enr_seq, decoder) = decoder.decode_optional_field();
360
361        let ping = PingMessage {
362            version,
363            from,
364            to,
365            expiration,
366            enr_seq,
367        };
368        // NOTE: as per the spec, any additional elements should be ignored.
369        let remaining = decoder.finish_unchecked();
370        Ok((ping, remaining))
371    }
372}
373
374#[derive(Debug, Clone, Copy, PartialEq, Eq)]
375pub struct PongMessage {
376    /// The endpoint of the receiver.
377    pub to: Endpoint,
378    /// The hash of the corresponding ping packet.
379    pub ping_hash: H256,
380    /// The expiration time of the message. If the message is older than this time,
381    /// it shouldn't be responded to.
382    pub expiration: u64,
383    /// The ENR sequence number of the sender. This field is optional.
384    pub enr_seq: Option<u64>,
385}
386
387impl PongMessage {
388    #[allow(unused)]
389    pub fn new(to: Endpoint, ping_hash: H256, expiration: u64) -> Self {
390        Self {
391            to,
392            ping_hash,
393            expiration,
394            enr_seq: None,
395        }
396    }
397
398    pub fn with_enr_seq(self, enr_seq: u64) -> Self {
399        Self {
400            enr_seq: Some(enr_seq),
401            ..self
402        }
403    }
404}
405
406impl RLPEncode for PongMessage {
407    fn encode(&self, buf: &mut dyn BufMut) {
408        Encoder::new(buf)
409            .encode_field(&self.to)
410            .encode_field(&self.ping_hash)
411            .encode_field(&self.expiration)
412            .encode_optional_field(&self.enr_seq)
413            .finish();
414    }
415}
416
417impl RLPDecode for PongMessage {
418    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
419        let decoder = Decoder::new(rlp)?;
420        let (to, decoder) = decoder.decode_field("to")?;
421        let (ping_hash, decoder) = decoder.decode_field("ping_hash")?;
422        let (expiration, decoder) = decoder.decode_field("expiration")?;
423        let (enr_seq, decoder) = decoder.decode_optional_field();
424
425        let pong = PongMessage {
426            to,
427            ping_hash,
428            expiration,
429            enr_seq,
430        };
431        // NOTE: as per the spec, any additional elements should be ignored.
432        let remaining = decoder.finish_unchecked();
433        Ok((pong, remaining))
434    }
435}
436
437#[derive(Debug, Clone, PartialEq, Eq)]
438pub struct NeighborsMessage {
439    // nodes is the list of neighbors
440    pub nodes: Vec<Node>,
441    pub expiration: u64,
442}
443
444impl NeighborsMessage {
445    pub fn new(nodes: Vec<Node>, expiration: u64) -> Self {
446        Self { nodes, expiration }
447    }
448}
449
450impl RLPDecode for NeighborsMessage {
451    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
452        let decoder = Decoder::new(rlp)?;
453        let (nodes, decoder) = decoder.decode_field("nodes")?;
454        let (expiration, decoder) = decoder.decode_field("expiration")?;
455        let remaining = decoder.finish_unchecked();
456
457        let neighbors = NeighborsMessage::new(nodes, expiration);
458        Ok((neighbors, remaining))
459    }
460}
461
462impl RLPEncode for NeighborsMessage {
463    fn encode(&self, buf: &mut dyn BufMut) {
464        structs::Encoder::new(buf)
465            .encode_field(&self.nodes)
466            .encode_field(&self.expiration)
467            .finish();
468    }
469}
470
471#[derive(Debug, PartialEq, Eq, Clone)]
472pub struct ENRResponseMessage {
473    pub request_hash: H256,
474    pub node_record: NodeRecord,
475}
476
477impl ENRResponseMessage {
478    pub fn new(request_hash: H256, node_record: NodeRecord) -> Self {
479        Self {
480            request_hash,
481            node_record,
482        }
483    }
484}
485
486impl RLPDecode for ENRResponseMessage {
487    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
488        let decoder = Decoder::new(rlp)?;
489        let (request_hash, decoder) = decoder.decode_field("request_hash")?;
490        let (node_record, decoder) = decoder.decode_field("node_record")?;
491        let remaining = decoder.finish_unchecked();
492        let response = ENRResponseMessage {
493            request_hash,
494            node_record,
495        };
496        Ok((response, remaining))
497    }
498}
499
500#[derive(Debug, Clone, Copy, PartialEq, Eq)]
501pub struct ENRRequestMessage {
502    pub expiration: u64,
503}
504
505impl ENRRequestMessage {
506    pub fn new(expiration: u64) -> Self {
507        Self { expiration }
508    }
509}
510
511impl RLPDecode for ENRRequestMessage {
512    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
513        let decoder = Decoder::new(rlp)?;
514        let (expiration, decoder) = decoder.decode_field("expiration")?;
515        let remaining = decoder.finish_unchecked();
516        let enr_request = ENRRequestMessage { expiration };
517        Ok((enr_request, remaining))
518    }
519}
520
521impl RLPEncode for ENRRequestMessage {
522    fn encode(&self, buf: &mut dyn BufMut) {
523        structs::Encoder::new(buf)
524            .encode_field(&self.expiration)
525            .finish();
526    }
527}
528
529impl RLPEncode for ENRResponseMessage {
530    fn encode(&self, buf: &mut dyn BufMut) {
531        structs::Encoder::new(buf)
532            .encode_field(&self.request_hash)
533            .encode_field(&self.node_record)
534            .finish();
535    }
536}