use std::cmp::Ordering;
use std::net::{SocketAddr, SocketAddrV4, UdpSocket};
use std::time::{Duration, Instant};
use tracing::{debug, error, trace};
use crate::common::{ErrorSpecific, Message, MessageType, RequestSpecific, ResponseSpecific};
use super::config::Config;
/// Version tag attached to every outgoing message: the ASCII client code `RS`
/// followed by the two bytes `0` and `5`.
const VERSION: [u8; 4] = *b"RS\x00\x05";

/// Size of the buffer a single incoming datagram is read into.
const MTU: usize = 2048;

/// Port tried first when the configuration does not name one.
pub const DEFAULT_PORT: u16 = 6881;

/// Default time to wait for a response to a request (2 seconds).
/// NOTE(review): not referenced in this file; presumably the `Config` default — confirm.
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2);

/// Read timeout installed on the UDP socket so that each receive poll returns quickly.
pub const READ_TIMEOUT: Duration = Duration::from_millis(10);
/// A UDP socket speaking KRPC that tracks outstanding requests, so only
/// solicited responses are handed back to the caller.
#[derive(Debug)]
pub struct KrpcSocket {
    /// Transaction id assigned to the next outgoing request (wraps at `u16::MAX`).
    next_tid: u16,
    /// The bound UDP socket used for both sending and receiving.
    socket: UdpSocket,
    /// When `false`, outgoing messages are marked read-only.
    pub(crate) server_mode: bool,
    /// How long a sent request is kept as inflight before it is discarded.
    request_timeout: Duration,
    /// Requests awaiting a response, kept in the order they were sent.
    inflight_requests: Vec<InflightRequest>,
    /// The IPv4 address the socket is bound to.
    local_addr: SocketAddrV4,
}
/// Bookkeeping for one request that has been sent and not yet answered.
#[derive(Debug)]
pub struct InflightRequest {
    /// Transaction id the response must carry.
    tid: u16,
    /// Address the request was sent to; responses must come from it.
    to: SocketAddrV4,
    /// When the request was recorded, used to expire it after the timeout.
    sent_at: Instant,
}
impl KrpcSocket {
pub(crate) fn new(config: &Config) -> Result<Self, std::io::Error> {
let request_timeout = config.request_timeout;
let port = config.port;
let socket = if let Some(port) = port {
UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], port)))?
} else {
match UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], DEFAULT_PORT))) {
Ok(socket) => Ok(socket),
Err(_) => UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0))),
}?
};
let local_addr = match socket.local_addr()? {
SocketAddr::V4(addr) => addr,
SocketAddr::V6(_) => unimplemented!("KrpcSocket does not support Ipv6"),
};
socket.set_read_timeout(Some(READ_TIMEOUT))?;
Ok(Self {
socket,
next_tid: 0,
server_mode: config.server_mode,
request_timeout,
inflight_requests: Vec::with_capacity(u16::MAX as usize),
local_addr,
})
}
#[cfg(test)]
pub(crate) fn server() -> Result<Self, std::io::Error> {
Self::new(&Config {
server_mode: true,
..Default::default()
})
}
#[cfg(test)]
pub(crate) fn client() -> Result<Self, std::io::Error> {
Self::new(&Config::default())
}
#[inline]
pub fn local_addr(&self) -> SocketAddrV4 {
self.local_addr
}
pub fn inflight(&self, transaction_id: &u16) -> bool {
self.inflight_requests
.binary_search_by(|request| request.tid.cmp(transaction_id))
.is_ok()
}
pub fn request(&mut self, address: SocketAddrV4, request: RequestSpecific) -> u16 {
let message = self.request_message(request);
trace!(context = "socket_message_sending", message = ?message);
self.inflight_requests.push(InflightRequest {
tid: message.transaction_id,
to: address,
sent_at: Instant::now(),
});
let tid = message.transaction_id;
let _ = self.send(address, message).map_err(|e| {
debug!(?e, "Error sending request message");
});
tid
}
pub fn response(
&mut self,
address: SocketAddrV4,
transaction_id: u16,
response: ResponseSpecific,
) {
let message =
self.response_message(MessageType::Response(response), address, transaction_id);
trace!(context = "socket_message_sending", message = ?message);
let _ = self.send(address, message).map_err(|e| {
debug!(?e, "Error sending response message");
});
}
pub fn error(&mut self, address: SocketAddrV4, transaction_id: u16, error: ErrorSpecific) {
let message = self.response_message(MessageType::Error(error), address, transaction_id);
let _ = self.send(address, message).map_err(|e| {
debug!(?e, "Error sending error message");
});
}
pub fn recv_from(&mut self) -> Option<(Message, SocketAddrV4)> {
let mut buf = [0u8; MTU];
match self.inflight_requests.binary_search_by(|request| {
if request.sent_at.elapsed() > self.request_timeout {
Ordering::Less
} else {
Ordering::Greater
}
}) {
Ok(index) => {
self.inflight_requests.drain(..index);
}
Err(index) => {
self.inflight_requests.drain(..index);
}
};
if let Ok((amt, SocketAddr::V4(from))) = self.socket.recv_from(&mut buf) {
let bytes = &buf[..amt];
if from.port() == 0 {
trace!(
context = "socket_validation",
message = "Response from port 0"
);
return None;
}
match Message::from_bytes(bytes) {
Ok(message) => {
let should_return = match message.message_type {
MessageType::Request(_) => {
trace!(
context = "socket_message_receiving",
?message,
?from,
"Received request message"
);
true
}
MessageType::Response(_) => {
trace!(
context = "socket_message_receiving",
?message,
?from,
"Received response message"
);
self.is_expected_response(&message, &from)
}
MessageType::Error(_) => {
trace!(
context = "socket_message_receiving",
?message,
?from,
"Received error message"
);
self.is_expected_response(&message, &from)
}
};
if should_return {
return Some((message, from));
}
}
Err(error) => {
trace!(context = "socket_error", ?error, ?from, message = ?String::from_utf8_lossy(bytes), "Received invalid Bencode message.");
}
};
};
None
}
fn is_expected_response(&mut self, message: &Message, from: &SocketAddrV4) -> bool {
match self
.inflight_requests
.binary_search_by(|request| request.tid.cmp(&message.transaction_id))
{
Ok(index) => {
let inflight_request = self
.inflight_requests
.get(index)
.expect("should be infallible");
if compare_socket_addr(&inflight_request.to, from) {
self.inflight_requests.remove(index);
return true;
} else {
trace!(
context = "socket_validation",
message = "Response from wrong address"
);
}
}
Err(_) => {
trace!(
context = "socket_validation",
message = "Unexpected response id"
);
}
}
false
}
fn tid(&mut self) -> u16 {
let tid = self.next_tid;
self.next_tid = self.next_tid.wrapping_add(1);
tid
}
fn request_message(&mut self, message: RequestSpecific) -> Message {
let transaction_id = self.tid();
Message {
transaction_id,
message_type: MessageType::Request(message),
version: Some(VERSION),
read_only: !self.server_mode,
requester_ip: None,
}
}
fn response_message(
&mut self,
message: MessageType,
requester_ip: SocketAddrV4,
request_tid: u16,
) -> Message {
Message {
transaction_id: request_tid,
message_type: message,
version: Some(VERSION),
read_only: !self.server_mode,
requester_ip: Some(requester_ip),
}
}
fn send(&mut self, address: SocketAddrV4, message: Message) -> Result<(), SendMessageError> {
self.socket.send_to(&message.to_bytes()?, address)?;
trace!(context = "socket_message_sending", message = ?message);
Ok(())
}
}
/// Failure while sending a KRPC message.
#[derive(thiserror::Error, Debug)]
pub enum SendMessageError {
    /// The message could not be bencode-encoded.
    #[error("Failed to parse packet bytes: {0}")]
    BencodeError(#[from] serde_bencode::Error),

    /// The underlying socket write failed.
    #[error(transparent)]
    IO(#[from] std::io::Error),
}
/// Returns `true` when a datagram from `b` is accepted as coming from `a`
/// (the address a request was sent to).
///
/// Ports must always be equal. IPs must be equal too, unless `a`'s IP is the
/// unspecified address `0.0.0.0`, which accepts any source IP.
fn compare_socket_addr(a: &SocketAddrV4, b: &SocketAddrV4) -> bool {
    let same_port = a.port() == b.port();
    let ip_accepted = a.ip().is_unspecified() || a.ip() == b.ip();
    same_port && ip_accepted
}
#[cfg(test)]
mod test {
    use std::thread;

    use crate::common::{Id, PingResponseArguments, RequestTypeSpecific};

    use super::*;

    /// Longest a test waits for a datagram, so a lost packet fails the test
    /// instead of hanging it forever.
    const TEST_DEADLINE: Duration = Duration::from_secs(5);

    #[test]
    fn tid() {
        let mut socket = KrpcSocket::server().unwrap();

        assert_eq!(socket.tid(), 0);
        assert_eq!(socket.tid(), 1);
        assert_eq!(socket.tid(), 2);

        // Transaction ids wrap around rather than overflowing.
        socket.next_tid = u16::MAX;
        assert_eq!(socket.tid(), 65535);
        assert_eq!(socket.tid(), 0);
    }

    #[test]
    fn recv_request() {
        let mut server = KrpcSocket::server().unwrap();
        let server_address = server.local_addr();

        let mut client = KrpcSocket::client().unwrap();
        client.next_tid = 120;
        let client_address = client.local_addr();

        let request = RequestSpecific {
            requester_id: Id::random(),
            request_type: RequestTypeSpecific::Ping,
        };
        let expected_request = request.clone();

        let server_thread = thread::spawn(move || {
            let deadline = Instant::now() + TEST_DEADLINE;
            loop {
                assert!(Instant::now() < deadline, "timed out waiting for request");
                if let Some((message, from)) = server.recv_from() {
                    assert_eq!(from.port(), client_address.port());
                    assert_eq!(message.transaction_id, 120);
                    assert!(message.read_only, "Read-only should be true");
                    assert_eq!(message.version, Some(VERSION), "Version should be 'RS'");
                    assert_eq!(message.message_type, MessageType::Request(expected_request));
                    break;
                }
            }
        });

        client.request(server_address, request);
        server_thread.join().unwrap();
    }

    #[test]
    fn recv_response() {
        let (tx, rx) = flume::bounded(1);
        let mut client = KrpcSocket::client().unwrap();
        let client_address = client.local_addr();
        let responder_id = Id::random();
        let response = ResponseSpecific::Ping(PingResponseArguments { responder_id });

        let server_thread = thread::spawn(move || {
            let mut server = KrpcSocket::client().unwrap();
            let server_address = server.local_addr();

            // Register the expected request before announcing our address, so the
            // response can never arrive before it is being waited for.
            server.inflight_requests.push(InflightRequest {
                tid: 8,
                to: client_address,
                sent_at: Instant::now(),
            });
            tx.send(server_address).unwrap();

            let deadline = Instant::now() + TEST_DEADLINE;
            loop {
                assert!(Instant::now() < deadline, "timed out waiting for response");
                if let Some((message, from)) = server.recv_from() {
                    assert_eq!(from.port(), client_address.port());
                    assert_eq!(message.transaction_id, 8);
                    assert!(message.read_only, "Read-only should be true");
                    assert_eq!(message.version, Some(VERSION), "Version should be 'RS'");
                    assert_eq!(
                        message.message_type,
                        MessageType::Response(ResponseSpecific::Ping(PingResponseArguments {
                            responder_id,
                        }))
                    );
                    break;
                }
            }
        });

        let server_address = rx.recv().unwrap();
        client.response(server_address, 8, response);
        server_thread.join().unwrap();
    }

    #[test]
    fn ignore_response_from_wrong_address() {
        let mut server = KrpcSocket::client().unwrap();
        let server_address = server.local_addr();

        let mut client = KrpcSocket::client().unwrap();
        let client_address = client.local_addr();

        // Expect the response from a different port than the one the client uses.
        // `wrapping_add` avoids a debug-build overflow panic if the client got port 65535.
        server.inflight_requests.push(InflightRequest {
            tid: 8,
            to: SocketAddrV4::new([127, 0, 0, 1].into(), client_address.port().wrapping_add(1)),
            sent_at: Instant::now(),
        });

        let response = ResponseSpecific::Ping(PingResponseArguments {
            responder_id: Id::random(),
        });

        let server_thread = thread::spawn(move || {
            thread::sleep(Duration::from_millis(5));
            assert!(
                server.recv_from().is_none(),
                "Should not receive a response from wrong address"
            );
        });

        client.response(server_address, 8, response);
        server_thread.join().unwrap();
    }
}