cs_mwc_bch/messages/message_header.rs

1use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
2use ring::digest;
3use std::fmt;
4use std::io;
5use std::io::{Cursor, Read, Write};
6use std::str;
7use util::{Error, Result, Serializable};
8
/// Fixed-size header placed in front of every message.
///
/// Fields are listed in the same order the `Serializable` impl reads and writes them.
#[derive(Clone, Default, Eq, Hash, PartialEq)]
pub struct MessageHeader {
    /// Magic bytes that identify which network the message belongs to
    pub magic: [u8; 4],
    /// Command name as 12 raw bytes (not guaranteed to be valid UTF-8)
    pub command: [u8; 12],
    /// Length in bytes of the payload that follows the header
    pub payload_size: u32,
    /// Leading 4 bytes of SHA256(SHA256(payload))
    pub checksum: [u8; 4],
}
21
22impl MessageHeader {
23    /// Size of the message header in bytes
24    pub const SIZE: usize = 24;
25
26    /// Returns the size of the header in bytes
27    pub fn size(&self) -> usize {
28        MessageHeader::SIZE
29    }
30
31    /// Checks if the header is valid
32    ///
33    /// `magic` - Expected magic bytes for the network
34    /// `max_size` - Max size in bytes for the payload
35    pub fn validate(&self, magic: [u8; 4], max_size: u32) -> Result<()> {
36        if self.magic != magic {
37            let msg = format!("Bad magic: {:?}", self.magic);
38            return Err(Error::BadData(msg));
39        }
40        if self.payload_size > max_size {
41            let msg = format!("Bad size: {:?}", self.payload_size);
42            return Err(Error::BadData(msg));
43        }
44        Ok(())
45    }
46
47    /// Reads the payload and verifies its checksum
48    pub fn payload(&self, reader: &mut dyn Read) -> Result<Vec<u8>> {
49        let mut p = vec![0; self.payload_size as usize];
50        reader.read_exact(p.as_mut())?;
51        let hash = digest::digest(&digest::SHA256, p.as_ref());
52        let hash = digest::digest(&digest::SHA256, &hash.as_ref());
53        let h = &hash.as_ref();
54        let j = &self.checksum;
55        if h[0] != j[0] || h[1] != j[1] || h[2] != j[2] || h[3] != j[3] {
56            let msg = format!("Bad checksum: {:?} != {:?}", &h[..4], j);
57            return Err(Error::BadData(msg));
58        }
59        Ok(p)
60    }
61}
62
63impl Serializable<MessageHeader> for MessageHeader {
64    fn read(reader: &mut dyn Read) -> Result<MessageHeader> {
65        // Read all the bytes at once so that the stream doesn't get in a partially-read state
66        let mut p = vec![0; MessageHeader::SIZE];
67        reader.read_exact(p.as_mut())?;
68        let mut c = Cursor::new(p);
69
70        // Now parse the results from the stream
71        let mut ret = MessageHeader {
72            ..Default::default()
73        };
74        c.read(&mut ret.magic)?;
75        c.read(&mut ret.command)?;
76        ret.payload_size = c.read_u32::<LittleEndian>()?;
77        c.read(&mut ret.checksum)?;
78
79        Ok(ret)
80    }
81
82    fn write(&self, writer: &mut dyn Write) -> io::Result<()> {
83        writer.write(&self.magic)?;
84        writer.write(&self.command)?;
85        writer.write_u32::<LittleEndian>(self.payload_size)?;
86        writer.write(&self.checksum)?;
87        Ok(())
88    }
89}
90
91// Prints so the command is easier to read
92impl fmt::Debug for MessageHeader {
93    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
94        let command = match str::from_utf8(&self.command) {
95            Ok(s) => s.to_string(),
96            Err(_) => format!("Not Ascii ({:?})", self.command),
97        };
98        write!(
99            f,
100            "Header {{ magic: {:?}, command: {:?}, payload_size: {}, checksum: {:?} }}",
101            self.magic, command, self.payload_size, self.checksum
102        )
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use hex;
    use std::io::Cursor;

    #[test]
    fn read_bytes() {
        let b = hex::decode("f9beb4d976657273696f6e00000000007a0000002a1957bb".as_bytes())
            .unwrap();
        let h = MessageHeader::read(&mut Cursor::new(&b)).unwrap();
        assert_eq!(h.magic, [0xf9, 0xbe, 0xb4, 0xd9]);
        assert_eq!(h.command, *b"version\0\0\0\0\0");
        assert_eq!(h.payload_size, 122);
        assert_eq!(h.checksum, [0x2a, 0x19, 0x57, 0xbb]);
    }

    #[test]
    fn read_too_short() {
        // Only 12 of the 24 header bytes are available
        let b = hex::decode("f9beb4d976657273696f6e00".as_bytes()).unwrap();
        assert!(MessageHeader::read(&mut Cursor::new(&b)).is_err());
    }

    #[test]
    fn write_read() {
        let mut v = Vec::new();
        let h = MessageHeader {
            magic: [0x00, 0x01, 0x02, 0x03],
            command: *b"command\0\0\0\0\0",
            payload_size: 42,
            checksum: [0xa0, 0xa1, 0xa2, 0xa3],
        };
        h.write(&mut v).unwrap();
        assert_eq!(v.len(), h.size());
        assert_eq!(MessageHeader::read(&mut Cursor::new(&v)).unwrap(), h);
    }

    #[test]
    fn validate() {
        let magic = [0xa0, 0xa1, 0xa2, 0xa3];
        let h = MessageHeader {
            magic,
            command: *b"verack\0\0\0\0\0\0",
            payload_size: 88,
            checksum: [0x12, 0x34, 0x56, 0x78],
        };
        // Valid
        assert!(h.validate(magic, 100).is_ok());
        // Limit is inclusive
        assert!(h.validate(magic, 88).is_ok());
        // Bad magic
        let bad_magic = [0xb0, 0xb1, 0xb2, 0xb3];
        assert!(h.validate(bad_magic, 100).is_err());
        // Bad size
        assert!(h.validate(magic, 50).is_err());
    }

    #[test]
    fn payload() {
        let p = [0x22, 0x33, 0x44, 0x00, 0x11, 0x22, 0x45, 0x67, 0x89];
        let hash = digest::digest(&digest::SHA256, &p);
        let hash = digest::digest(&digest::SHA256, hash.as_ref());
        let hash = hash.as_ref();
        let checksum = [hash[0], hash[1], hash[2], hash[3]];
        let header = MessageHeader {
            magic: [0x00, 0x00, 0x00, 0x00],
            command: *b"version\0\0\0\0\0",
            payload_size: p.len() as u32,
            checksum,
        };
        // Valid
        let v = header.payload(&mut Cursor::new(&p)).unwrap();
        assert_eq!(v, p);
        // Bad checksum
        let p2 = [0xf2, 0xf3, 0xf4, 0xf0, 0xf1, 0xf2, 0xf5, 0xf7, 0xf9];
        assert!(header.payload(&mut Cursor::new(&p2)).is_err());
        // Truncated payload
        assert!(header.payload(&mut Cursor::new(&p[..4])).is_err());
    }
}