Skip to main content

ethrex_p2p/rlpx/
p2p.rs

1use super::{
2    message::RLPxMessage,
3    utils::{decompress_pubkey, snappy_compress},
4};
5use crate::rlpx::utils::{compress_pubkey, snappy_decompress};
6use bytes::BufMut;
7use ethrex_common::H512;
8use ethrex_rlp::structs::{Decoder, Encoder};
9use ethrex_rlp::{
10    decode::{RLPDecode, decode_rlp_item},
11    encode::RLPEncode,
12    error::{RLPDecodeError, RLPEncodeError},
13};
14use secp256k1::PublicKey;
15use serde::Serialize;
16
17pub const SUPPORTED_ETH_CAPABILITIES: [Capability; 4] = [
18    Capability::eth(68),
19    Capability::eth(69),
20    Capability::eth(70),
21    Capability::eth(71),
22];
23pub const SUPPORTED_SNAP_CAPABILITIES: [Capability; 1] = [Capability::snap(1)];
24
25/// The version of the base P2P protocol we support.
26/// This is sent at the start of the Hello message instead of the capabilities list.
27pub const SUPPORTED_P2P_CAPABILITY_VERSION: u8 = 5;
28
/// Maximum number of bytes a capability name may occupy.
const CAPABILITY_NAME_MAX_LENGTH: usize = 8;

/// Copies `input` into the front of a zero-filled buffer of
/// `CAPABILITY_NAME_MAX_LENGTH` bytes, leaving the tail as zeros.
///
/// # Panics
///
/// Panics if `N` is greater than `CAPABILITY_NAME_MAX_LENGTH`.
const fn pad_right<const N: usize>(input: &[u8; N]) -> [u8; CAPABILITY_NAME_MAX_LENGTH] {
    assert!(
        N <= CAPABILITY_NAME_MAX_LENGTH,
        "Input array must be 8 bytes or less"
    );

    let mut out = [0_u8; CAPABILITY_NAME_MAX_LENGTH];
    // A manual `while` loop is used because `for` loops are not allowed in `const fn`.
    let mut idx = 0;
    while idx < N {
        out[idx] = input[idx];
        idx += 1;
    }
    out
}
47
48#[derive(Debug, Clone, PartialEq)]
49/// A capability is identified by a short ASCII name (max eight characters) and version number
50pub struct Capability {
51    protocol: [u8; CAPABILITY_NAME_MAX_LENGTH],
52    pub version: u8,
53}
54
55impl Capability {
56    pub const fn eth(version: u8) -> Self {
57        Capability {
58            protocol: pad_right(b"eth"),
59            version,
60        }
61    }
62
63    pub const fn snap(version: u8) -> Self {
64        Capability {
65            protocol: pad_right(b"snap"),
66            version,
67        }
68    }
69
70    pub const fn based(version: u8) -> Self {
71        Capability {
72            protocol: pad_right(b"based"),
73            version,
74        }
75    }
76
77    pub fn protocol(&self) -> &str {
78        let len = self
79            .protocol
80            .iter()
81            .position(|c| c == &b'\0')
82            .unwrap_or(CAPABILITY_NAME_MAX_LENGTH);
83        str::from_utf8(&self.protocol[..len]).expect("value parsed as utf8 in RLPDecode")
84    }
85}
86
87impl RLPEncode for Capability {
88    fn encode(&self, buf: &mut dyn BufMut) {
89        Encoder::new(buf)
90            .encode_field(&self.protocol())
91            .encode_field(&self.version)
92            .finish();
93    }
94}
95
96impl RLPDecode for Capability {
97    fn decode_unfinished(rlp: &[u8]) -> Result<(Self, &[u8]), RLPDecodeError> {
98        let (protocol_name, rest) = String::decode_unfinished(&rlp[1..])?;
99        if protocol_name.len() > CAPABILITY_NAME_MAX_LENGTH {
100            return Err(RLPDecodeError::InvalidLength);
101        }
102        let (version, rest) = u8::decode_unfinished(rest)?;
103        let mut protocol = [0; CAPABILITY_NAME_MAX_LENGTH];
104        protocol[..protocol_name.len()].copy_from_slice(protocol_name.as_bytes());
105        Ok((Capability { protocol, version }, rest))
106    }
107}
108
109impl Serialize for Capability {
110    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
111    where
112        S: serde::Serializer,
113    {
114        serializer.serialize_str(&format!("{}/{}", self.protocol(), self.version))
115    }
116}
117
118#[derive(Debug, Clone)]
119pub struct HelloMessage {
120    pub capabilities: Vec<Capability>,
121    pub node_id: PublicKey,
122    pub client_id: String,
123}
124
125impl HelloMessage {
126    pub fn new(capabilities: Vec<Capability>, node_id: PublicKey, client_id: String) -> Self {
127        Self {
128            capabilities,
129            node_id,
130            client_id,
131        }
132    }
133}
134
135impl RLPxMessage for HelloMessage {
136    const CODE: u8 = 0x00;
137    fn encode(&self, mut buf: &mut dyn BufMut) -> Result<(), RLPEncodeError> {
138        Encoder::new(&mut buf)
139            .encode_field(&SUPPORTED_P2P_CAPABILITY_VERSION) // protocolVersion
140            .encode_field(&self.client_id) // clientId
141            .encode_field(&self.capabilities) // capabilities
142            .encode_field(&0u8) // listenPort (ignored)
143            .encode_field(&decompress_pubkey(&self.node_id)) // nodeKey
144            .finish();
145        Ok(())
146    }
147
148    fn decode(msg_data: &[u8]) -> Result<Self, RLPDecodeError> {
149        // decode hello message: [protocolVersion: P, clientId: B, capabilities, listenPort: P, nodeId: B_64, ...]
150        let decoder = Decoder::new(msg_data)?;
151        let (protocol_version, decoder): (u64, _) = decoder.decode_field("protocolVersion")?;
152
153        if protocol_version != SUPPORTED_P2P_CAPABILITY_VERSION as u64 {
154            return Err(RLPDecodeError::IncompatibleProtocol(format!(
155                "Received message is encoded in p2p version {} when negotiated p2p version was {} ",
156                protocol_version, SUPPORTED_P2P_CAPABILITY_VERSION
157            )));
158        }
159
160        let (client_id, decoder): (String, _) = decoder.decode_field("clientId")?;
161
162        // [[cap1, capVersion1], [cap2, capVersion2], ...]
163        let (capabilities, decoder): (Vec<Capability>, _) = decoder.decode_field("capabilities")?;
164
165        // This field should be ignored
166        let (_listen_port, decoder): (u16, _) = decoder.decode_field("listenPort")?;
167
168        let (node_id, decoder): (H512, _) = decoder.decode_field("nodeId")?;
169
170        // Implementations must ignore any additional list elements
171        let _padding = decoder.finish_unchecked();
172
173        Ok(Self::new(
174            capabilities,
175            compress_pubkey(node_id).ok_or(RLPDecodeError::MalformedData)?,
176            client_id,
177        ))
178    }
179}
180
/// Reason codes carried by an RLPx Disconnect message.
///
/// Each variant's discriminant is the numeric value of its code.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DisconnectReason {
    /// `0x00`: disconnect requested.
    DisconnectRequested = 0x00,
    /// `0x01`: TCP subsystem error.
    NetworkError = 0x01,
    /// `0x02`: breach of protocol.
    ProtocolError = 0x02,
    /// `0x03`: useless peer.
    UselessPeer = 0x03,
    /// `0x04`: too many peers.
    TooManyPeers = 0x04,
    /// `0x05`: already connected.
    AlreadyConnected = 0x05,
    /// `0x06`: incompatible P2P protocol version.
    IncompatibleVersion = 0x06,
    /// `0x07`: null node identity received.
    InvalidIdentity = 0x07,
    /// `0x08`: client quitting.
    ClientQuitting = 0x08,
    /// `0x09`: unexpected identity in handshake.
    UnexpectedIdentity = 0x09,
    /// `0x0a`: identity is the same as this node.
    SelfIdentity = 0x0a,
    /// `0x0b`: ping timeout.
    PingTimeout = 0x0b,
    /// `0x10`: some other reason specific to a subprotocol.
    SubprotocolError = 0x10,
    /// `0xff`: stand-in for any code that is not one of the above.
    InvalidReason = 0xff,
}
199
200impl DisconnectReason {
201    // Returns a vector of all DisconnectReason variants, we need to update this method when we add,
202    // change or remove any DisconnectReason variants which are used in metrics.
203    // A test ensures this method is up to date.
204    pub fn all() -> Vec<DisconnectReason> {
205        vec![
206            DisconnectReason::DisconnectRequested,
207            DisconnectReason::NetworkError,
208            DisconnectReason::ProtocolError,
209            DisconnectReason::UselessPeer,
210            DisconnectReason::TooManyPeers,
211            DisconnectReason::AlreadyConnected,
212            DisconnectReason::IncompatibleVersion,
213            DisconnectReason::InvalidIdentity,
214            DisconnectReason::ClientQuitting,
215            DisconnectReason::UnexpectedIdentity,
216            DisconnectReason::SelfIdentity,
217            DisconnectReason::PingTimeout,
218            DisconnectReason::SubprotocolError,
219            DisconnectReason::InvalidReason,
220        ]
221    }
222}
223
224// impl display for disconnectreason
225impl std::fmt::Display for DisconnectReason {
226    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227        match self {
228            DisconnectReason::DisconnectRequested => write!(f, "Disconnect Requested"),
229            DisconnectReason::NetworkError => write!(f, "TCP Subsystem Error"),
230            DisconnectReason::ProtocolError => write!(f, "Breach of Protocol"),
231            DisconnectReason::UselessPeer => write!(f, "Useless Peer"),
232            DisconnectReason::TooManyPeers => write!(f, "Too Many Peers"),
233            DisconnectReason::AlreadyConnected => write!(f, "Already Connected"),
234            DisconnectReason::IncompatibleVersion => {
235                write!(f, "Incompatible P2P Protocol Version")
236            }
237            DisconnectReason::InvalidIdentity => write!(f, "Null Node Identity Received"),
238            DisconnectReason::ClientQuitting => write!(f, "Client Quitting"),
239            DisconnectReason::UnexpectedIdentity => {
240                write!(f, "Unexpected Identity in Handshake")
241            }
242            DisconnectReason::SelfIdentity => {
243                write!(f, "Identity is the Same as This Node")
244            }
245            DisconnectReason::PingTimeout => write!(f, "Ping Timeout"),
246            DisconnectReason::SubprotocolError => {
247                write!(f, "Some Other Reason Specific to a Subprotocol")
248            }
249            DisconnectReason::InvalidReason => write!(f, "Invalid Disconnect Reason"),
250        }
251    }
252}
253
254impl From<u8> for DisconnectReason {
255    fn from(value: u8) -> Self {
256        match value {
257            0x00 => DisconnectReason::DisconnectRequested,
258            0x01 => DisconnectReason::NetworkError,
259            0x02 => DisconnectReason::ProtocolError,
260            0x03 => DisconnectReason::UselessPeer,
261            0x04 => DisconnectReason::TooManyPeers,
262            0x05 => DisconnectReason::AlreadyConnected,
263            0x06 => DisconnectReason::IncompatibleVersion,
264            0x07 => DisconnectReason::InvalidIdentity,
265            0x08 => DisconnectReason::ClientQuitting,
266            0x09 => DisconnectReason::UnexpectedIdentity,
267            0x0a => DisconnectReason::SelfIdentity,
268            0x0b => DisconnectReason::PingTimeout,
269            0x10 => DisconnectReason::SubprotocolError,
270            _ => DisconnectReason::InvalidReason,
271        }
272    }
273}
274
275impl From<DisconnectReason> for u8 {
276    fn from(val: DisconnectReason) -> Self {
277        val as u8
278    }
279}
280#[derive(Debug, Clone)]
281pub struct DisconnectMessage {
282    pub reason: Option<DisconnectReason>,
283}
284
285impl DisconnectMessage {
286    pub fn new(reason: Option<DisconnectReason>) -> Self {
287        Self { reason }
288    }
289
290    /// Returns the meaning of the disconnect reason's error code
291    /// The meaning of each error code is defined by the spec: https://github.com/ethereum/devp2p/blob/master/rlpx.md#disconnect-0x01
292    pub fn reason(&self) -> DisconnectReason {
293        self.reason.unwrap_or(DisconnectReason::InvalidReason)
294    }
295}
296
297impl RLPxMessage for DisconnectMessage {
298    const CODE: u8 = 0x01;
299    fn encode(&self, buf: &mut dyn BufMut) -> Result<(), RLPEncodeError> {
300        let mut encoded_data = vec![];
301        // Disconnect msg_data is reason or none
302        match self.reason.map(Into::<u8>::into) {
303            Some(value) => Encoder::new(&mut encoded_data)
304                .encode_field(&value)
305                .finish(),
306            None => Vec::<u8>::new().encode(&mut encoded_data),
307        }
308        let msg_data = snappy_compress(encoded_data)?;
309        buf.put_slice(&msg_data);
310        Ok(())
311    }
312
313    fn decode(msg_data: &[u8]) -> Result<Self, RLPDecodeError> {
314        // decode disconnect message: [reason (optional)]
315        // The msg data may be compressed or not
316        let msg_data = if let Ok(decompressed) = snappy_decompress(msg_data) {
317            decompressed
318        } else {
319            msg_data.to_vec()
320        };
321        // It seems that disconnect reason can be encoded in different ways:
322        let reason = match msg_data.len() {
323            0 => None,
324            // As a single u8
325            1 => Some(msg_data[0]),
326            // As an RLP encoded Vec<u8>
327            _ => {
328                let decoder = Decoder::new(&msg_data)?;
329                let (reason, _): (Option<u8>, _) = decoder.decode_optional_field();
330                reason
331            }
332        };
333
334        Ok(Self::new(reason.map(|r| r.into())))
335    }
336}
337
338#[derive(Debug, Clone, Copy)]
339pub struct PingMessage {}
340
341impl RLPxMessage for PingMessage {
342    const CODE: u8 = 0x02;
343    fn encode(&self, buf: &mut dyn BufMut) -> Result<(), RLPEncodeError> {
344        let mut encoded_data = vec![];
345        // Ping msg_data is only []
346        Vec::<u8>::new().encode(&mut encoded_data);
347        let msg_data = snappy_compress(encoded_data)?;
348        buf.put_slice(&msg_data);
349        Ok(())
350    }
351
352    fn decode(msg_data: &[u8]) -> Result<Self, RLPDecodeError> {
353        // decode ping message: data is empty list [] or string but it is snappy compressed
354        let decompressed_data = snappy_decompress(msg_data)?;
355        let (_, payload, remaining) = decode_rlp_item(&decompressed_data)?;
356
357        let empty: &[u8] = &[];
358        assert_eq!(payload, empty, "Ping payload should be &[]");
359        assert_eq!(remaining, empty, "Ping remaining should be &[]");
360        Ok(Self {})
361    }
362}
363
364#[derive(Debug, Clone, Copy)]
365pub struct PongMessage {}
366
367impl RLPxMessage for PongMessage {
368    const CODE: u8 = 0x03;
369    fn encode(&self, buf: &mut dyn BufMut) -> Result<(), RLPEncodeError> {
370        let mut encoded_data = vec![];
371        // Pong msg_data is only []
372        Vec::<u8>::new().encode(&mut encoded_data);
373        let msg_data = snappy_compress(encoded_data)?;
374        buf.put_slice(&msg_data);
375        Ok(())
376    }
377
378    fn decode(msg_data: &[u8]) -> Result<Self, RLPDecodeError> {
379        // decode pong message: data is empty list [] or string but it is snappy compressed
380        let decompressed_data = snappy_decompress(msg_data)?;
381        let (_, payload, remaining) = decode_rlp_item(&decompressed_data)?;
382
383        let empty: &[u8] = &[];
384        assert_eq!(payload, empty, "Pong payload should be &[]");
385        assert_eq!(remaining, empty, "Pong remaining should be &[]");
386        Ok(Self {})
387    }
388}