use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use crate::bencode::{self, Bencode, Bytes};
use crate::error::{Error, ErrorKind};
/// KRPC transaction identifier, fixed at two bytes in this module.
pub type TransactionId = [u8; 2];
/// A decoded KRPC message: a query, a response, or an error reply.
///
/// `KrpcMessage::to_bytes` encodes the variant into a bencoded dictionary whose
/// `y` key is `"q"`, `"r"` or `"e"` respectively.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum KrpcMessage {
// Outgoing/incoming query: `q` carries the method name, `a` its argument value.
Query {
transaction_id: TransactionId,
method: String,
args: Bencode,
},
// Reply to a query: `r` carries the result value.
Response {
transaction_id: TransactionId,
result: Bencode,
},
// Error reply: `e` is encoded as the list `[code, message]`.
Error {
transaction_id: TransactionId,
code: i64,
message: String,
},
}
impl KrpcMessage {
pub fn to_bytes(&self) -> Vec<u8> {
tracing::trace!("encoding KRPC message: {:?}", self);
let dict = match self {
KrpcMessage::Query {
transaction_id,
method,
args,
} => Bencode::Dict(vec![
(
t_key(),
Bencode::Bytes(Bytes::copy_from_slice(transaction_id)),
),
(y_key(), Bencode::Bytes(Bytes::copy_from_slice(b"q"))),
(
q_key(),
Bencode::Bytes(Bytes::copy_from_slice(method.as_bytes())),
),
(a_key(), args.clone()),
]),
KrpcMessage::Response {
transaction_id,
result,
} => Bencode::Dict(vec![
(
t_key(),
Bencode::Bytes(Bytes::copy_from_slice(transaction_id)),
),
(y_key(), Bencode::Bytes(Bytes::copy_from_slice(b"r"))),
(r_key(), result.clone()),
]),
KrpcMessage::Error {
transaction_id,
code,
message,
} => Bencode::Dict(vec![
(
t_key(),
Bencode::Bytes(Bytes::copy_from_slice(transaction_id)),
),
(y_key(), Bencode::Bytes(Bytes::copy_from_slice(b"e"))),
(
e_key(),
Bencode::List(vec![
Bencode::Integer(*code),
Bencode::Bytes(Bytes::copy_from_slice(message.as_bytes())),
]),
),
]),
};
bencode::encode(&dict)
}
pub fn from_bytes(data: &[u8]) -> Result<Self, Error> {
tracing::trace!("decoding KRPC message ({} bytes)", data.len());
let (val, _rest) = bencode::decode(data)?;
Self::from_bencode(&val)
}
pub fn from_bencode(val: &Bencode) -> Result<Self, Error> {
let t = dict_get_bytes(val, b"t").ok_or(Error::new(ErrorKind::Protocol))?;
let mut transaction_id = [0u8; 2];
let len = std::cmp::min(t.len(), 2);
transaction_id[..len].copy_from_slice(&t[..len]);
let y = dict_get_bytes(val, b"y").ok_or(Error::new(ErrorKind::Protocol))?;
let y_byte = if !y.is_empty() { y[0] } else { 0 };
match y_byte {
b'q' => {
let method = dict_get_bytes(val, b"q")
.and_then(|b| String::from_utf8(b.to_vec()).ok())
.ok_or(Error::new(ErrorKind::Protocol))?;
let args = dict_get(val, b"a")
.cloned()
.unwrap_or(Bencode::Dict(vec![]));
Ok(KrpcMessage::Query {
transaction_id,
method,
args,
})
}
b'r' => {
let result = dict_get(val, b"r")
.cloned()
.unwrap_or(Bencode::Dict(vec![]));
Ok(KrpcMessage::Response {
transaction_id,
result,
})
}
b'e' => {
let err_val = dict_get(val, b"e").ok_or(Error::new(ErrorKind::Protocol))?;
match err_val {
Bencode::List(items) if items.len() >= 2 => {
let code = match &items[0] {
Bencode::Integer(c) => *c,
_ => return Err(Error::new(ErrorKind::Protocol)),
};
let message = match &items[1] {
Bencode::Bytes(b) => String::from_utf8(b.to_vec()).unwrap_or_default(),
_ => return Err(Error::new(ErrorKind::Protocol)),
};
Ok(KrpcMessage::Error {
transaction_id,
code,
message,
})
}
_ => Err(Error::new(ErrorKind::Protocol)),
}
}
_ => Err(Error::new(ErrorKind::Protocol)),
}
}
}
/// Encodes a `ping` query whose arguments carry only this node's ID.
pub fn build_ping(tid: TransactionId, node_id: &[u8; 20]) -> Vec<u8> {
    let own_id = Bencode::Bytes(Bytes::copy_from_slice(node_id));
    let query = KrpcMessage::Query {
        transaction_id: tid,
        method: String::from("ping"),
        args: Bencode::Dict(vec![(id_key(), own_id)]),
    };
    query.to_bytes()
}
/// Encodes a `find_node` query asking for nodes close to `target`.
pub fn build_find_node(tid: TransactionId, node_id: &[u8; 20], target: &[u8; 20]) -> Vec<u8> {
    let raw = |bytes: &[u8]| Bencode::Bytes(Bytes::copy_from_slice(bytes));
    let args = Bencode::Dict(vec![
        (id_key(), raw(&node_id[..])),
        (target_key(), raw(&target[..])),
    ]);
    let query = KrpcMessage::Query {
        transaction_id: tid,
        method: String::from("find_node"),
        args,
    };
    query.to_bytes()
}
/// Encodes a `get_peers` query for the torrent identified by `info_hash`.
pub fn build_get_peers(tid: TransactionId, node_id: &[u8; 20], info_hash: &[u8; 20]) -> Vec<u8> {
    let raw = |bytes: &[u8]| Bencode::Bytes(Bytes::copy_from_slice(bytes));
    let args = Bencode::Dict(vec![
        (id_key(), raw(&node_id[..])),
        (info_hash_key(), raw(&info_hash[..])),
    ]);
    let query = KrpcMessage::Query {
        transaction_id: tid,
        method: String::from("get_peers"),
        args,
    };
    query.to_bytes()
}
/// Encodes an `announce_peer` query.
///
/// `port` is sent as a bencode integer. `token` is passed through as raw
/// bytes (presumably the token from an earlier `get_peers` reply; confirm
/// against callers).
pub fn build_announce_peer(
    tid: TransactionId, node_id: &[u8; 20], info_hash: &[u8; 20], port: u16, token: &[u8],
) -> Vec<u8> {
    let raw = |bytes: &[u8]| Bencode::Bytes(Bytes::copy_from_slice(bytes));
    let args = Bencode::Dict(vec![
        (id_key(), raw(&node_id[..])),
        (info_hash_key(), raw(&info_hash[..])),
        (Bytes::from("port"), Bencode::Integer(i64::from(port))),
        (token_key(), raw(token)),
    ]);
    let query = KrpcMessage::Query {
        transaction_id: tid,
        method: String::from("announce_peer"),
        args,
    };
    query.to_bytes()
}
pub fn parse_ping_response(msg: &KrpcMessage) -> Result<[u8; 20], Error> {
match msg {
KrpcMessage::Response { result, .. } => {
let node_id = dict_get_bytes(result, b"id").ok_or(Error::new(ErrorKind::Protocol))?;
let mut id = [0u8; 20];
let len = std::cmp::min(node_id.len(), 20);
id[..len].copy_from_slice(&node_id[..len]);
Ok(id)
}
_ => Err(Error::new(ErrorKind::Protocol)),
}
}
/// Outcome of a `get_peers` response, as produced by `parse_get_peers_response`.
#[derive(Debug, Clone)]
pub enum GetPeersResult {
// The responder returned a `values` list: the token plus the IPv4 peers
// decoded from its 6-byte compact entries.
Values {
token: Vec<u8>,
peers: Vec<SocketAddr>,
},
// The responder returned no `values` list but a compact `nodes` string,
// decoded into nodes via `parse_compact_nodes`.
Nodes(Vec<super::Node>),
}
pub fn parse_get_peers_response(msg: &KrpcMessage) -> Result<GetPeersResult, Error> {
match msg {
KrpcMessage::Response { result, .. } => {
let token = dict_get_bytes(result, b"token")
.map(|b| b.to_vec())
.ok_or(Error::new(ErrorKind::Protocol))?;
if let Some(Bencode::List(values)) = dict_get(result, b"values") {
let mut peers = Vec::new();
for v in values {
if let Bencode::Bytes(b) = v
&& b.len() == 6
{
let ip = Ipv4Addr::new(b[0], b[1], b[2], b[3]);
let port = u16::from_be_bytes([b[4], b[5]]);
peers.push(SocketAddr::new(IpAddr::V4(ip), port));
}
}
return Ok(GetPeersResult::Values { token, peers });
}
if let Some(nodes_bytes) = dict_get_bytes(result, b"nodes") {
let nodes = parse_compact_nodes(nodes_bytes);
return Ok(GetPeersResult::Nodes(nodes));
}
Err(Error::new(ErrorKind::Protocol))
}
_ => Err(Error::new(ErrorKind::Protocol)),
}
}
/// Decodes BEP 5 compact node info.
///
/// Each entry is 26 bytes: a 20-byte node ID, a 4-byte IPv4 address and a
/// big-endian 2-byte port. Trailing bytes that do not form a full entry are
/// ignored.
pub fn parse_compact_nodes(data: &[u8]) -> Vec<super::Node> {
    const ENTRY_LEN: usize = 26;
    let mut nodes = Vec::with_capacity(data.len() / ENTRY_LEN);
    for entry in data.chunks_exact(ENTRY_LEN) {
        let (id_bytes, endpoint) = entry.split_at(20);
        let mut id = [0u8; 20];
        id.copy_from_slice(id_bytes);
        let ip = Ipv4Addr::new(endpoint[0], endpoint[1], endpoint[2], endpoint[3]);
        let port = u16::from_be_bytes([endpoint[4], endpoint[5]]);
        nodes.push(super::Node {
            addr: SocketAddr::new(IpAddr::V4(ip), port),
            id,
        });
    }
    nodes
}
/// Encodes nodes as BEP 5 compact node info.
///
/// Each node becomes a 26-byte entry: 20-byte ID, 4-byte IPv4 address and
/// big-endian port. Nodes with non-IPv4 addresses are skipped entirely, so the
/// output length is always a multiple of 26.
pub fn encode_compact_nodes(nodes: &[super::Node]) -> Vec<u8> {
    let mut data = Vec::with_capacity(nodes.len() * 26);
    for node in nodes {
        // Check the address family before writing anything. Emitting the ID
        // first and then skipping an IPv6 node would leave 20 orphan bytes and
        // misalign every following entry.
        let IpAddr::V4(v4) = node.addr.ip() else {
            continue;
        };
        data.extend_from_slice(&node.id);
        data.extend_from_slice(&v4.octets());
        data.extend_from_slice(&node.addr.port().to_be_bytes());
    }
    data
}
/// Encodes a reply to a `ping` query, echoing `tid` and carrying this node's ID.
pub fn build_ping_response(tid: TransactionId, node_id: &[u8; 20]) -> Vec<u8> {
    let own_id = Bencode::Bytes(Bytes::copy_from_slice(node_id));
    let reply = KrpcMessage::Response {
        transaction_id: tid,
        result: Bencode::Dict(vec![(id_key(), own_id)]),
    };
    reply.to_bytes()
}
/// Encodes a reply to a `find_node` query.
///
/// `nodes` is encoded as a compact node string via `encode_compact_nodes`.
pub fn build_find_node_response(
    tid: TransactionId, node_id: &[u8; 20], nodes: &[super::Node],
) -> Vec<u8> {
    let own_id = Bencode::Bytes(Bytes::copy_from_slice(node_id));
    let compact_nodes = Bencode::Bytes(Bytes::from(encode_compact_nodes(nodes)));
    let reply = KrpcMessage::Response {
        transaction_id: tid,
        result: Bencode::Dict(vec![
            (id_key(), own_id),
            (Bytes::from("nodes"), compact_nodes),
        ]),
    };
    reply.to_bytes()
}
pub fn build_get_peers_response_values(
tid: TransactionId, node_id: &[u8; 20], token: &[u8], peers: &[SocketAddr],
) -> Vec<u8> {
let peer_list: Vec<Bencode> = peers
.iter()
.filter_map(|addr| match addr.ip() {
IpAddr::V4(v4) => {
let mut data = Vec::new();
data.extend_from_slice(&v4.octets());
data.extend_from_slice(&addr.port().to_be_bytes());
Some(Bencode::Bytes(Bytes::from(data)))
}
_ => None,
})
.collect();
KrpcMessage::Response {
transaction_id: tid,
result: Bencode::Dict(vec![
(id_key(), Bencode::Bytes(Bytes::copy_from_slice(node_id))),
(
Bytes::from("token"),
Bencode::Bytes(Bytes::copy_from_slice(token)),
),
(Bytes::from("values"), Bencode::List(peer_list)),
]),
}
.to_bytes()
}
pub fn build_get_peers_response_nodes(
tid: TransactionId, node_id: &[u8; 20], token: &[u8], nodes: &[super::Node],
) -> Vec<u8> {
let compact = encode_compact_nodes(nodes);
KrpcMessage::Response {
transaction_id: tid,
result: Bencode::Dict(vec![
(id_key(), Bencode::Bytes(Bytes::copy_from_slice(node_id))),
(
Bytes::from("token"),
Bencode::Bytes(Bytes::copy_from_slice(token)),
),
(Bytes::from("nodes"), Bencode::Bytes(Bytes::from(compact))),
]),
}
.to_bytes()
}
pub fn build_error_response(tid: TransactionId, code: i64, message: &str) -> Vec<u8> {
KrpcMessage::Error {
transaction_id: tid,
code,
message: message.into(),
}
.to_bytes()
}
fn t_key() -> Bytes {
Bytes::from("t")
}
fn y_key() -> Bytes {
Bytes::from("y")
}
fn q_key() -> Bytes {
Bytes::from("q")
}
fn a_key() -> Bytes {
Bytes::from("a")
}
fn r_key() -> Bytes {
Bytes::from("r")
}
fn e_key() -> Bytes {
Bytes::from("e")
}
fn id_key() -> Bytes {
Bytes::from("id")
}
fn target_key() -> Bytes {
Bytes::from("target")
}
fn info_hash_key() -> Bytes {
Bytes::from("info_hash")
}
fn token_key() -> Bytes {
Bytes::from("token")
}
/// Looks up `key` in a bencode dictionary.
///
/// Performs a linear scan and returns the first matching entry. Yields `None`
/// when `val` is not a dictionary or the key is absent.
fn dict_get<'a>(val: &'a Bencode, key: &[u8]) -> Option<&'a Bencode> {
    let Bencode::Dict(entries) = val else {
        return None;
    };
    entries
        .iter()
        .find_map(|(entry_key, entry_val)| (entry_key.as_ref() == key).then_some(entry_val))
}
/// Looks up `key` in a bencode dictionary and returns it only if it is a byte string.
pub fn dict_get_bytes<'a>(val: &'a Bencode, key: &[u8]) -> Option<&'a [u8]> {
    if let Some(Bencode::Bytes(raw)) = dict_get(val, key) {
        return Some(raw);
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    // Encode -> decode round trips.

    #[test]
    fn krpc_ping_roundtrip() {
        let tid = [0xAB, 0xCD];
        let node_id = [0x42u8; 20];
        let bytes = build_ping(tid, &node_id);
        let msg = KrpcMessage::from_bytes(&bytes).unwrap();
        match &msg {
            KrpcMessage::Query {
                transaction_id,
                method,
                ..
            } => {
                assert_eq!(*transaction_id, tid);
                assert_eq!(method, "ping");
            }
            _ => panic!("expected query"),
        }
    }

    #[test]
    fn krpc_find_node_roundtrip() {
        let tid = [0x12, 0x34];
        let node_id = [0x11u8; 20];
        let target = [0x22u8; 20];
        let bytes = build_find_node(tid, &node_id, &target);
        let msg = KrpcMessage::from_bytes(&bytes).unwrap();
        match &msg {
            KrpcMessage::Query { method, .. } => {
                assert_eq!(method, "find_node");
            }
            _ => panic!("expected query"),
        }
    }

    #[test]
    fn krpc_response_roundtrip() {
        let tid = [0xFF, 0xEE];
        let msg = KrpcMessage::Response {
            transaction_id: tid,
            result: Bencode::Dict(vec![(
                Bytes::from("id"),
                Bencode::Bytes(Bytes::copy_from_slice(&[0x55u8; 20])),
            )]),
        };
        let bytes = msg.to_bytes();
        let decoded = KrpcMessage::from_bytes(&bytes).unwrap();
        match decoded {
            KrpcMessage::Response {
                transaction_id,
                result,
            } => {
                assert_eq!(transaction_id, tid);
                let id = dict_get_bytes(&result, b"id").unwrap();
                assert_eq!(id, &[0x55u8; 20]);
            }
            _ => panic!("expected response"),
        }
    }

    #[test]
    fn krpc_error_roundtrip() {
        let msg = KrpcMessage::Error {
            transaction_id: [0x01, 0x02],
            code: 203,
            message: "Server Error".into(),
        };
        let bytes = msg.to_bytes();
        let decoded = KrpcMessage::from_bytes(&bytes).unwrap();
        match decoded {
            KrpcMessage::Error { code, message, .. } => {
                assert_eq!(code, 203);
                assert_eq!(message, "Server Error");
            }
            _ => panic!("expected error"),
        }
    }

    // Compact node / response parsing.

    #[test]
    fn test_parse_compact_nodes() {
        let mut data = Vec::new();
        data.extend_from_slice(&[0x01u8; 20]);
        data.extend_from_slice(&[127, 0, 0, 1]);
        data.extend_from_slice(&6881u16.to_be_bytes());
        data.extend_from_slice(&[0x02u8; 20]);
        data.extend_from_slice(&[192, 168, 1, 1]);
        data.extend_from_slice(&51413u16.to_be_bytes());
        let nodes = parse_compact_nodes(&data);
        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].id, [0x01u8; 20]);
        assert_eq!(nodes[0].addr.to_string(), "127.0.0.1:6881");
        assert_eq!(nodes[1].addr.to_string(), "192.168.1.1:51413");
    }

    #[test]
    fn parse_ping_response_valid() {
        let msg = KrpcMessage::Response {
            transaction_id: [0xAB, 0xCD],
            result: Bencode::Dict(vec![(
                Bytes::from("id"),
                Bencode::Bytes(Bytes::copy_from_slice(&[0x42u8; 20])),
            )]),
        };
        let id = parse_ping_response(&msg).unwrap();
        assert_eq!(id, [0x42u8; 20]);
    }

    #[test]
    fn parse_ping_response_not_a_response() {
        let msg = KrpcMessage::Query {
            transaction_id: [0; 2],
            method: "ping".into(),
            args: Bencode::Dict(vec![]),
        };
        assert!(parse_ping_response(&msg).is_err());
    }

    #[test]
    fn parse_get_peers_values() {
        let msg = KrpcMessage::Response {
            transaction_id: [0; 2],
            result: Bencode::Dict(vec![
                (Bytes::from("token"), Bencode::Bytes(Bytes::from("tok"))),
                (
                    Bytes::from("values"),
                    Bencode::List(vec![
                        Bencode::Bytes(Bytes::from(vec![127, 0, 0, 1, 0x1A, 0xE1])),
                    ]),
                ),
            ]),
        };
        match parse_get_peers_response(&msg).unwrap() {
            GetPeersResult::Values { token, peers } => {
                assert_eq!(token, b"tok");
                assert_eq!(peers.len(), 1);
            }
            _ => panic!("expected Values"),
        }
    }

    #[test]
    fn parse_get_peers_nodes() {
        let mut compact = Vec::new();
        compact.extend_from_slice(&[0x01u8; 20]);
        compact.extend_from_slice(&[10, 0, 0, 1]);
        compact.extend_from_slice(&6881u16.to_be_bytes());
        let msg = KrpcMessage::Response {
            transaction_id: [0; 2],
            result: Bencode::Dict(vec![
                (Bytes::from("token"), Bencode::Bytes(Bytes::from("tok"))),
                (Bytes::from("nodes"), Bencode::Bytes(Bytes::from(compact))),
            ]),
        };
        match parse_get_peers_response(&msg).unwrap() {
            GetPeersResult::Nodes(nodes) => {
                assert_eq!(nodes.len(), 1);
                assert_eq!(nodes[0].addr.to_string(), "10.0.0.1:6881");
            }
            _ => panic!("expected Nodes"),
        }
    }

    #[test]
    fn parse_get_peers_neither_values_nor_nodes() {
        let msg = KrpcMessage::Response {
            transaction_id: [0; 2],
            result: Bencode::Dict(vec![]),
        };
        assert!(parse_get_peers_response(&msg).is_err());
    }

    #[test]
    fn parse_get_peers_non_response() {
        let msg = KrpcMessage::Query {
            transaction_id: [0; 2],
            method: "get_peers".into(),
            args: Bencode::Dict(vec![]),
        };
        assert!(parse_get_peers_response(&msg).is_err());
    }

    // Malformed input.

    #[test]
    fn decode_truncated_krpc() {
        // Dictionary is never terminated with `e`.
        let data = b"d1:t2:ab1:y1:q";
        assert!(KrpcMessage::from_bytes(data).is_err());
    }

    #[test]
    fn decode_unknown_y_type() {
        let dict = Bencode::Dict(vec![
            (Bytes::from("t"), Bencode::Bytes(Bytes::from(vec![0, 0]))),
            (Bytes::from("y"), Bencode::Bytes(Bytes::from(&b"x"[..]))),
            (
                Bytes::from("e"),
                Bencode::List(vec![Bencode::Integer(0), Bencode::Bytes(Bytes::from(""))]),
            ),
        ]);
        assert!(KrpcMessage::from_bencode(&dict).is_err());
    }

    #[test]
    fn decode_missing_t_field() {
        let dict = Bencode::Dict(vec![
            (Bytes::from("y"), Bencode::Bytes(Bytes::from(&b"q"[..]))),
            (Bytes::from("q"), Bencode::Bytes(Bytes::from(&b"ping"[..]))),
            (Bytes::from("a"), Bencode::Dict(vec![])),
        ]);
        assert!(KrpcMessage::from_bencode(&dict).is_err());
    }

    #[test]
    fn decode_error_missing_list() {
        // `e` must be a `[code, message]` list, not a bare integer.
        let dict = Bencode::Dict(vec![
            (Bytes::from("t"), Bencode::Bytes(Bytes::from(vec![0, 0]))),
            (Bytes::from("y"), Bencode::Bytes(Bytes::from(&b"e"[..]))),
            (Bytes::from("e"), Bencode::Integer(203)),
        ]);
        assert!(KrpcMessage::from_bencode(&dict).is_err());
    }
}