use crate::error::{Error, ErrorKind};
/// One message of the peer wire protocol, as produced by `decode` and
/// consumed by `encode` in this module.
///
/// All integer fields travel big-endian on the wire.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PeerMessage {
    /// Bare zero-length frame; carries no message ID.
    KeepAlive,
    /// Message ID 0, no payload.
    Choke,
    /// Message ID 1, no payload.
    Unchoke,
    /// Message ID 2, no payload.
    Interested,
    /// Message ID 3, no payload.
    NotInterested,
    /// Message ID 4, carrying a single 32-bit index.
    Have(u32),
    /// Message ID 5, carrying raw bitfield bytes.
    Bitfield(Vec<u8>),
    /// Message ID 6, carrying `index`, `begin` and `length`.
    Request { index: u32, begin: u32, length: u32 },
    /// Message ID 7, carrying `index`, `begin` and the block bytes.
    Piece { index: u32, begin: u32, data: Vec<u8> },
    /// Message ID 8, with the same fields as [`PeerMessage::Request`].
    Cancel { index: u32, begin: u32, length: u32 },
    /// Message ID 9, carrying a 16-bit port number.
    Port(u16),
}
/// Serializes a [`PeerMessage`] into its length-prefixed wire frame.
///
/// Every frame begins with a 4-byte big-endian length. That length counts the
/// one-byte message ID plus the payload. [`PeerMessage::KeepAlive`] is the
/// bare zero-length prefix. All integer fields are written big-endian.
///
/// # Panics
///
/// Panics if a `Bitfield` or `Piece` payload is so large that the frame length
/// does not fit in the 4-byte `u32` prefix. Such a frame cannot be represented
/// on the wire.
pub fn encode(msg: &PeerMessage) -> Vec<u8> {
    tracing::trace!("encoding peer message: {:?}", msg);
    match msg {
        PeerMessage::KeepAlive => vec![0, 0, 0, 0],
        PeerMessage::Choke => encode_frame(0, &[]),
        PeerMessage::Unchoke => encode_frame(1, &[]),
        PeerMessage::Interested => encode_frame(2, &[]),
        PeerMessage::NotInterested => encode_frame(3, &[]),
        PeerMessage::Have(index) => encode_frame(4, &[&index.to_be_bytes()[..]]),
        PeerMessage::Bitfield(bitfield) => encode_frame(5, &[bitfield.as_slice()]),
        PeerMessage::Request {
            index,
            begin,
            length,
        } => encode_frame(
            6,
            &[
                &index.to_be_bytes()[..],
                &begin.to_be_bytes()[..],
                &length.to_be_bytes()[..],
            ],
        ),
        PeerMessage::Piece { index, begin, data } => encode_frame(
            7,
            &[
                &index.to_be_bytes()[..],
                &begin.to_be_bytes()[..],
                data.as_slice(),
            ],
        ),
        PeerMessage::Cancel {
            index,
            begin,
            length,
        } => encode_frame(
            8,
            &[
                &index.to_be_bytes()[..],
                &begin.to_be_bytes()[..],
                &length.to_be_bytes()[..],
            ],
        ),
        PeerMessage::Port(port) => encode_frame(9, &[&port.to_be_bytes()[..]]),
    }
}

/// Builds `<len: u32 BE><id><parts...>`, where `len` is 1 (the ID byte) plus
/// the combined length of `parts`.
///
/// # Panics
///
/// Panics if `len` exceeds `u32::MAX`.
fn encode_frame(id: u8, parts: &[&[u8]]) -> Vec<u8> {
    let len = 1 + parts.iter().map(|part| part.len()).sum::<usize>();
    // Checked conversion rather than `as`: a truncated prefix would yield a
    // frame whose announced length disagrees with its actual contents.
    let prefix = u32::try_from(len).expect("peer message too large for a u32 length prefix");
    let mut buf = Vec::with_capacity(4 + len);
    buf.extend_from_slice(&prefix.to_be_bytes());
    buf.push(id);
    for part in parts {
        buf.extend_from_slice(part);
    }
    buf
}
pub fn decode(data: &[u8]) -> Result<PeerMessage, Error> {
if data.len() < 4 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let len = u32::from_be_bytes([data[0], data[1], data[2], data[3]]);
if len == 0 {
tracing::trace!("decoded peer message: KeepAlive");
return Ok(PeerMessage::KeepAlive);
}
if data.len() < 5 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
if data.len() < 4 + len as usize {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let payload = &data[5..4 + len as usize];
let msg_id = data[4];
match msg_id {
0 => {
if len != 1 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
Ok(PeerMessage::Choke)
}
1 => {
if len != 1 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
Ok(PeerMessage::Unchoke)
}
2 => {
if len != 1 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
Ok(PeerMessage::Interested)
}
3 => {
if len != 1 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
Ok(PeerMessage::NotInterested)
}
4 => {
if len != 5 || payload.len() != 4 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let index = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
Ok(PeerMessage::Have(index))
}
5 => Ok(PeerMessage::Bitfield(payload.to_vec())),
6 => {
if len != 13 || payload.len() != 12 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let index = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
let begin = u32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
let length = u32::from_be_bytes([payload[8], payload[9], payload[10], payload[11]]);
Ok(PeerMessage::Request {
index,
begin,
length,
})
}
7 => {
if payload.len() < 8 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let index = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
let begin = u32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
let data = payload[8..].to_vec();
Ok(PeerMessage::Piece { index, begin, data })
}
8 => {
if len != 13 || payload.len() != 12 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let index = u32::from_be_bytes([payload[0], payload[1], payload[2], payload[3]]);
let begin = u32::from_be_bytes([payload[4], payload[5], payload[6], payload[7]]);
let length = u32::from_be_bytes([payload[8], payload[9], payload[10], payload[11]]);
Ok(PeerMessage::Cancel {
index,
begin,
length,
})
}
9 => {
if len != 3 || payload.len() != 2 {
return Err(Error::new(ErrorKind::PeerInvalidMessage));
}
let port = u16::from_be_bytes([payload[0], payload[1]]);
Ok(PeerMessage::Port(port))
}
_ => Err(Error::new(ErrorKind::PeerInvalidMessage)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a raw frame from an explicit length prefix and the bytes after it.
    fn raw_frame(len: u32, rest: &[u8]) -> Vec<u8> {
        let mut out = len.to_be_bytes().to_vec();
        out.extend_from_slice(rest);
        out
    }

    #[test]
    fn encode_keepalive() {
        assert_eq!(encode(&PeerMessage::KeepAlive), vec![0, 0, 0, 0]);
    }

    #[test]
    fn encode_choke() {
        assert_eq!(encode(&PeerMessage::Choke), vec![0, 0, 0, 1, 0]);
    }

    #[test]
    fn encode_unchoke() {
        assert_eq!(encode(&PeerMessage::Unchoke), vec![0, 0, 0, 1, 1]);
    }

    #[test]
    fn encode_interested() {
        assert_eq!(encode(&PeerMessage::Interested), vec![0, 0, 0, 1, 2]);
    }

    #[test]
    fn encode_not_interested() {
        assert_eq!(encode(&PeerMessage::NotInterested), vec![0, 0, 0, 1, 3]);
    }

    #[test]
    fn encode_have() {
        let encoded = encode(&PeerMessage::Have(42));
        assert_eq!(encoded.len(), 9);
        assert_eq!(&encoded[0..5], &[0, 0, 0, 5, 4]);
        assert_eq!(&encoded[5..], &42u32.to_be_bytes());
    }

    #[test]
    fn encode_request() {
        let msg = PeerMessage::Request {
            index: 1,
            begin: 1024,
            length: 16384,
        };
        let encoded = encode(&msg);
        assert_eq!(encoded.len(), 17);
        assert_eq!(&encoded[0..5], &[0, 0, 0, 13, 6]);
        assert_eq!(&encoded[5..9], &1u32.to_be_bytes());
        assert_eq!(&encoded[9..13], &1024u32.to_be_bytes());
        assert_eq!(&encoded[13..17], &16384u32.to_be_bytes());
    }

    #[test]
    fn encode_piece() {
        let data = vec![0xAB; 16384];
        let msg = PeerMessage::Piece {
            index: 2,
            begin: 512,
            data: data.clone(),
        };
        let encoded = encode(&msg);
        assert_eq!(encoded.len(), 4 + 1 + 8 + 16384);
        assert_eq!(&encoded[0..4], &(9u32 + 16384).to_be_bytes());
        assert_eq!(encoded[4], 7);
        assert_eq!(&encoded[5..9], &2u32.to_be_bytes());
        assert_eq!(&encoded[9..13], &512u32.to_be_bytes());
        assert_eq!(&encoded[13..], data.as_slice());
    }

    #[test]
    fn encode_cancel() {
        let msg = PeerMessage::Cancel {
            index: 5,
            begin: 2048,
            length: 8192,
        };
        let encoded = encode(&msg);
        assert_eq!(encoded.len(), 17);
        assert_eq!(&encoded[0..5], &[0, 0, 0, 13, 8]);
        assert_eq!(&encoded[5..9], &5u32.to_be_bytes());
        assert_eq!(&encoded[9..13], &2048u32.to_be_bytes());
        assert_eq!(&encoded[13..17], &8192u32.to_be_bytes());
    }

    #[test]
    fn encode_port() {
        let encoded = encode(&PeerMessage::Port(6881));
        assert_eq!(encoded.len(), 7);
        assert_eq!(&encoded[0..5], &[0, 0, 0, 3, 9]);
        assert_eq!(&encoded[5..], &6881u16.to_be_bytes());
    }

    #[test]
    fn encode_bitfield() {
        let bits = vec![0xAA, 0x55, 0xFF];
        let encoded = encode(&PeerMessage::Bitfield(bits.clone()));
        assert_eq!(encoded.len(), 5 + 3);
        assert_eq!(&encoded[0..5], &[0, 0, 0, 4, 5]);
        assert_eq!(&encoded[5..], bits.as_slice());
    }

    #[test]
    fn roundtrip_all_messages() {
        let messages = vec![
            PeerMessage::KeepAlive,
            PeerMessage::Choke,
            PeerMessage::Unchoke,
            PeerMessage::Interested,
            PeerMessage::NotInterested,
            PeerMessage::Have(7),
            PeerMessage::Bitfield(vec![0xFF, 0x00]),
            PeerMessage::Bitfield(Vec::new()),
            PeerMessage::Request {
                index: 0,
                begin: 0,
                length: 16384,
            },
            PeerMessage::Piece {
                index: 9,
                begin: 16,
                data: Vec::new(),
            },
            PeerMessage::Cancel {
                index: 1,
                begin: 1024,
                length: 8192,
            },
            PeerMessage::Port(6881),
        ];
        for msg in &messages {
            let encoded = encode(msg);
            let decoded = decode(&encoded).unwrap();
            assert_eq!(msg, &decoded, "roundtrip failed for {:?}", msg);
        }
    }

    #[test]
    fn roundtrip_piece() {
        let msg = PeerMessage::Piece {
            index: 3,
            begin: 4096,
            data: vec![0xCC; 512],
        };
        let encoded = encode(&msg);
        let decoded = decode(&encoded).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn decode_empty_data() {
        assert!(decode(b"").is_err());
    }

    #[test]
    fn decode_truncated() {
        assert!(decode(&[0, 0, 0]).is_err());
    }

    #[test]
    fn decode_unknown_message_id() {
        let data = [0, 0, 0, 1, 255];
        assert!(decode(&data).is_err());
    }

    #[test]
    fn decode_frame_shorter_than_announced_length() {
        // Announces a 13-byte Request but carries only the ID and two bytes.
        assert!(decode(&raw_frame(13, &[6, 0, 0])).is_err());
    }

    #[test]
    fn decode_huge_announced_length() {
        assert!(decode(&raw_frame(u32::MAX, &[5, 0, 0])).is_err());
    }

    #[test]
    fn decode_rejects_wrong_fixed_lengths() {
        // Choke carrying a payload byte.
        assert!(decode(&raw_frame(2, &[0, 0])).is_err());
        // Have with a 3-byte index.
        assert!(decode(&raw_frame(4, &[4, 0, 0, 0])).is_err());
        // Request with 13 payload bytes instead of 12.
        assert!(decode(&raw_frame(14, &[6; 14])).is_err());
        // Cancel with 11 payload bytes instead of 12.
        assert!(decode(&raw_frame(12, &[8; 12])).is_err());
        // Port with a 1-byte payload.
        assert!(decode(&raw_frame(2, &[9, 0])).is_err());
    }

    #[test]
    fn decode_rejects_short_piece_header() {
        assert!(decode(&raw_frame(8, &[7, 0, 0, 0, 0, 0, 0, 0])).is_err());
    }

    #[test]
    fn decode_piece_with_empty_block() {
        let decoded = decode(&raw_frame(9, &[7, 0, 0, 0, 1, 0, 0, 0, 2])).unwrap();
        assert_eq!(
            decoded,
            PeerMessage::Piece {
                index: 1,
                begin: 2,
                data: Vec::new(),
            }
        );
    }

    #[test]
    fn decode_empty_bitfield() {
        assert_eq!(
            decode(&raw_frame(1, &[5])).unwrap(),
            PeerMessage::Bitfield(Vec::new())
        );
    }

    #[test]
    fn decode_ignores_trailing_bytes() {
        let mut bytes = encode(&PeerMessage::Have(3));
        bytes.extend_from_slice(&[0xFF, 0xFF]);
        assert_eq!(decode(&bytes).unwrap(), PeerMessage::Have(3));
    }
}