use std::collections::HashMap;
use std::hash::Hash;
use std::net::SocketAddr;
use std::time::Instant;
use bytes::Bytes;
use crate::packet_utils::read_op_code;
use crate::packets::RemapConnection;
use crate::protocol::{DisconnectReason, OpCode};
use crate::session::{
ApplicationParameters, SessionEvent, SessionMode, SessionParameters, SessionState, SoeSession,
};
/// Address-keyed socket that multiplexes SOE sessions over `SocketAddr`s.
///
/// NOTE(review): no implementation of this trait is visible in this file; the per-method
/// descriptions below assume it mirrors the same-named methods on `SoeMultiplexer` — confirm
/// against the implementors.
pub trait SoeSocket {
    /// Returns the local address of the underlying socket.
    fn local_addr(&self) -> std::io::Result<SocketAddr>;
    /// Returns the number of sessions currently tracked.
    fn session_count(&self) -> usize;
    /// Starts a client session to `remote`.
    fn connect(&mut self, remote: SocketAddr);
    /// Queues `data` on the session bound to `remote`.
    ///
    /// Returns `false` when the data was not queued (for `SoeMultiplexer`, because no running
    /// session exists for the address), so ignoring the result hides dropped data.
    #[must_use = "a false return means the data was dropped and will not be sent"]
    fn enqueue_data(&mut self, remote: &SocketAddr, data: &[u8]) -> bool;
    /// Terminates the session bound to `remote` with `reason`.
    fn terminate(&mut self, remote: &SocketAddr, reason: DisconnectReason);
}
/// Key that identifies the peer of a session, for example a UDP `SocketAddr`.
///
/// Sessions are stored in a `HashMap` keyed by this type, hence the `Clone + Eq + Hash`
/// requirement.
pub trait RemoteAddr
where
    Self: Clone + Eq + Hash,
{
    /// Reports whether `self` and `other` refer to the same host.
    ///
    /// A connection remap is only accepted between two addresses for which this is `true`.
    fn same_host(&self, other: &Self) -> bool;
}
impl RemoteAddr for SocketAddr {
    /// Two socket addresses share a host when their IP addresses are equal; ports are ignored.
    fn same_host(&self, other: &Self) -> bool {
        let (ours, theirs) = (self.ip(), other.ip());
        ours == theirs
    }
}
/// Settings shared by every session a `SoeMultiplexer` creates.
#[derive(Clone, Debug, Default)]
pub struct SocketConfig {
    /// Session parameters cloned into each newly created session.
    pub default_session_params: SessionParameters,
    /// Application parameters cloned into each newly created session (this includes the
    /// optional `encryption_key_state`).
    pub app_params: ApplicationParameters,
    /// Whether a `RemapConnection` datagram arriving from another port on the same host may
    /// take over an existing session.
    pub allow_port_remaps: bool,
    /// RNG seed given to the first session; each later session gets the next value (wrapping).
    pub base_rng_seed: u64,
}
/// Notification queued by a `SoeMultiplexer` for the application, tagged with the peer.
///
/// When one session is drained, its events are queued in this order: `SessionOpened`, then
/// `DataReceived`, then `SessionClosed`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SocketEvent<A> {
    /// The session with `remote` reported that it opened.
    SessionOpened { remote: A },
    /// An application payload was received on the session with `remote`.
    DataReceived { remote: A, data: Bytes },
    /// The session with `remote` reported that it closed, for `reason`.
    SessionClosed { remote: A, reason: DisconnectReason },
}
#[derive(Debug)]
pub struct SoeMultiplexer<A: RemoteAddr> {
config: SocketConfig,
sessions: HashMap<A, SoeSession>,
outgoing: Vec<(A, Bytes)>,
events: Vec<SocketEvent<A>>,
next_seed: u64,
}
impl<A: RemoteAddr> SoeMultiplexer<A> {
pub fn new(config: SocketConfig) -> Self {
let next_seed = config.base_rng_seed;
Self {
config,
sessions: HashMap::new(),
outgoing: Vec::new(),
events: Vec::new(),
next_seed,
}
}
pub fn session_count(&self) -> usize {
self.sessions.len()
}
pub fn session(&self, remote: &A) -> Option<&SoeSession> {
self.sessions.get(remote)
}
pub fn take_outgoing(&mut self) -> Vec<(A, Bytes)> {
std::mem::take(&mut self.outgoing)
}
pub fn take_events(&mut self) -> Vec<SocketEvent<A>> {
std::mem::take(&mut self.events)
}
pub fn connect(&mut self, remote: A, now: Instant) {
self.create_session(remote.clone(), SessionMode::Client, now);
if let Some(session) = self.sessions.get_mut(&remote) {
session.send_session_request();
}
self.drain_session(&remote);
}
#[must_use = "a false return means the data was dropped because no running session exists for the address"]
pub fn enqueue_data(&mut self, remote: &A, data: &[u8]) -> bool {
let queued = match self.sessions.get_mut(remote) {
Some(session) => session.enqueue_data(data),
None => false,
};
self.drain_session(remote);
queued
}
pub fn terminate(&mut self, remote: &A, reason: DisconnectReason, now: Instant) {
if let Some(session) = self.sessions.get_mut(remote) {
session.terminate(reason, true, now);
}
self.drain_session(remote);
self.remove_if_terminated(remote);
}
pub fn process_incoming(&mut self, remote: A, datagram: Bytes, now: Instant) {
if !self.sessions.contains_key(&remote) {
match read_op_code(&datagram) {
Some(OpCode::SessionRequest) => {
self.create_session(remote.clone(), SessionMode::Server, now);
}
Some(OpCode::RemapConnection) => {
self.handle_remap(&remote, &datagram);
return;
}
_ => return,
}
}
if let Some(session) = self.sessions.get_mut(&remote) {
session.process_incoming(datagram, now);
}
self.drain_session(&remote);
self.remove_if_terminated(&remote);
}
pub fn run_tick(&mut self, now: Instant) {
let mut outgoing = std::mem::take(&mut self.outgoing);
let mut events = std::mem::take(&mut self.events);
self.sessions.retain(|remote, session| {
session.run_tick(now);
Self::drain_into(remote, session, &mut outgoing, &mut events);
session.state() != SessionState::Terminated
});
self.outgoing = outgoing;
self.events = events;
}
pub fn drive<T>(&mut self, transport: &mut T, now: Instant) -> std::io::Result<()>
where
T: UdpTransport<Addr = A>,
{
let mut buf = [0u8; 2048];
while let Some((len, from)) = transport.try_recv(&mut buf)? {
self.process_incoming(from, Bytes::copy_from_slice(&buf[..len]), now);
}
self.run_tick(now);
for (addr, datagram) in self.take_outgoing() {
transport.send_to(&datagram, &addr)?;
}
Ok(())
}
fn create_session(&mut self, remote: A, mode: SessionMode, now: Instant) {
let seed = self.next_seed;
self.next_seed = self.next_seed.wrapping_add(1);
let session = SoeSession::new(
mode,
self.config.default_session_params.clone(),
self.config.app_params.clone(),
seed,
now,
);
self.sessions.insert(remote, session);
}
fn handle_remap(&mut self, from: &A, datagram: &[u8]) {
if !self.config.allow_port_remaps {
return;
}
let remap = match RemapConnection::deserialize(datagram, true) {
Ok(remap) => remap,
Err(_) => return,
};
let old_key = self.sessions.iter().find_map(|(key, session)| {
(session.session_id() == remap.session_id && session.crc_seed() == remap.crc_seed)
.then(|| key.clone())
});
let Some(old_key) = old_key else { return };
if &old_key == from || !old_key.same_host(from) {
return;
}
if let Some(session) = self.sessions.remove(&old_key) {
self.sessions.insert(from.clone(), session);
}
}
fn drain_session(&mut self, remote: &A) {
if let Some(session) = self.sessions.get_mut(remote) {
Self::drain_into(remote, session, &mut self.outgoing, &mut self.events);
}
}
fn drain_into(
remote: &A,
session: &mut SoeSession,
outgoing: &mut Vec<(A, Bytes)>,
events: &mut Vec<SocketEvent<A>>,
) {
for datagram in session.take_outgoing() {
outgoing.push((remote.clone(), datagram));
}
let session_events = session.take_events();
for event in &session_events {
if matches!(event, SessionEvent::Opened) {
events.push(SocketEvent::SessionOpened {
remote: remote.clone(),
});
}
}
for data in session.take_received() {
events.push(SocketEvent::DataReceived {
remote: remote.clone(),
data,
});
}
for event in session_events {
if let SessionEvent::Closed(reason) = event {
events.push(SocketEvent::SessionClosed {
remote: remote.clone(),
reason,
});
}
}
}
fn remove_if_terminated(&mut self, remote: &A) {
if let Some(session) = self.sessions.get(remote)
&& session.state() == SessionState::Terminated
{
self.sessions.remove(remote);
}
}
}
/// Minimal datagram transport pumped by `SoeMultiplexer::drive`.
pub trait UdpTransport {
    /// Address type identifying the peer of each datagram.
    type Addr: RemoteAddr;

    /// Sends the datagram `buf` to `addr`, returning the number of bytes written.
    fn send_to(&mut self, buf: &[u8], addr: &Self::Addr) -> std::io::Result<usize>;

    /// Receives one datagram into `buf`.
    ///
    /// Returns `Ok(Some((len, from)))` for a received datagram and `Ok(None)` when nothing is
    /// pending. `drive` keeps calling this until it gets `Ok(None)` or an error, so it should
    /// not block.
    fn try_recv(&mut self, buf: &mut [u8]) -> std::io::Result<Option<(usize, Self::Addr)>>;
}
/// Adapter that lets `SoeMultiplexer::drive` use a standard UDP socket.
///
/// The socket should be put in non-blocking mode (`set_nonblocking(true)`). On a blocking
/// socket `recv_from` waits for the next datagram, which stalls `drive` whenever nothing is
/// pending.
impl UdpTransport for std::net::UdpSocket {
    type Addr = std::net::SocketAddr;

    /// Receives one datagram, mapping `WouldBlock` to `Ok(None)`.
    ///
    /// `Interrupted` is retried, and `ConnectionReset` is skipped. On Windows, an unconnected
    /// UDP socket reports an ICMP "port unreachable" from an earlier `send_to` as a reset on
    /// the next receive. That error names no peer and does not mean this socket failed, so
    /// propagating it would only abort `drive`, skipping the tick and all sends.
    fn try_recv(&mut self, buf: &mut [u8]) -> std::io::Result<Option<(usize, Self::Addr)>> {
        loop {
            match self.recv_from(buf) {
                Ok((len, from)) => return Ok(Some((len, from))),
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => return Ok(None),
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) if e.kind() == std::io::ErrorKind::ConnectionReset => continue,
                Err(e) => return Err(e),
            }
        }
    }

    /// Sends `buf` to `addr` with `UdpSocket::send_to`.
    fn send_to(&mut self, buf: &[u8], addr: &Self::Addr) -> std::io::Result<usize> {
        std::net::UdpSocket::send_to(self, buf, addr)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rc4::Rc4KeyState;
    use std::collections::VecDeque;
    use std::net::SocketAddr;
    const CLIENT: &str = "127.0.0.1:4001";
    const SERVER: &str = "127.0.0.1:4002";
    fn addr(s: &str) -> SocketAddr {
        s.parse().unwrap()
    }
    /// Builds a config for `protocol` with remaps disabled.
    ///
    /// `heartbeat_after` and `inactivity_timeout` are set to zero, presumably disabling them so
    /// ticks never time a session out.
    fn config(protocol: &str, seed: u64) -> SocketConfig {
        let mut params = SessionParameters {
            application_protocol: protocol.to_owned(),
            ..SessionParameters::default()
        };
        params.heartbeat_after = std::time::Duration::ZERO;
        params.inactivity_timeout = std::time::Duration::ZERO;
        SocketConfig {
            default_session_params: params,
            app_params: ApplicationParameters::default(),
            allow_port_remaps: false,
            base_rng_seed: seed,
        }
    }
    /// Alternates ticking and cross-delivering datagrams between the two multiplexers.
    ///
    /// Uses a single `now` and stops once neither side has output, or after 64 rounds.
    fn pump(client: &mut SoeMultiplexer<SocketAddr>, server: &mut SoeMultiplexer<SocketAddr>) {
        let now = Instant::now();
        for _ in 0..64 {
            client.run_tick(now);
            server.run_tick(now);
            let from_client = client.take_outgoing();
            let from_server = server.take_outgoing();
            if from_client.is_empty() && from_server.is_empty() {
                break;
            }
            for (_dest, dg) in from_client {
                server.process_incoming(addr(CLIENT), dg, now);
            }
            for (_dest, dg) in from_server {
                client.process_incoming(addr(SERVER), dg, now);
            }
        }
    }
    /// In-memory transport: `try_recv` pops from `inbox`, `send_to` records into `sent`.
    #[derive(Default)]
    struct MemoryTransport {
        inbox: VecDeque<(Vec<u8>, SocketAddr)>,
        sent: Vec<(Vec<u8>, SocketAddr)>,
    }
    impl UdpTransport for MemoryTransport {
        type Addr = SocketAddr;
        fn try_recv(&mut self, buf: &mut [u8]) -> std::io::Result<Option<(usize, Self::Addr)>> {
            Ok(self.inbox.pop_front().map(|(datagram, from)| {
                let len = datagram.len().min(buf.len());
                buf[..len].copy_from_slice(&datagram[..len]);
                (len, from)
            }))
        }
        fn send_to(&mut self, buf: &[u8], addr: &Self::Addr) -> std::io::Result<usize> {
            self.sent.push((buf.to_vec(), *addr));
            Ok(buf.len())
        }
    }
    #[test]
    fn establishes_session_and_emits_opened() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));
        client.connect(addr(SERVER), now);
        pump(&mut client, &mut server);
        assert_eq!(client.session_count(), 1);
        assert_eq!(server.session_count(), 1);
        assert!(client.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionOpened { remote } if *remote == addr(SERVER)
        )));
        assert!(client.enqueue_data(&addr(SERVER), b"hi"));
        pump(&mut client, &mut server);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionOpened { remote } if *remote == addr(CLIENT)
        )));
    }
    #[test]
    fn drive_exchanges_datagrams_over_transport() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));
        let mut client_net = MemoryTransport::default();
        let mut server_net = MemoryTransport::default();
        client.connect(addr(SERVER), now);
        for _ in 0..64 {
            client
                .drive(&mut client_net, now)
                .expect("in-memory transport never fails");
            server
                .drive(&mut server_net, now)
                .expect("in-memory transport never fails");
            if client_net.sent.is_empty() && server_net.sent.is_empty() {
                break;
            }
            for (datagram, dest) in client_net.sent.drain(..) {
                assert_eq!(dest, addr(SERVER));
                server_net.inbox.push_back((datagram, addr(CLIENT)));
            }
            for (datagram, dest) in server_net.sent.drain(..) {
                assert_eq!(dest, addr(CLIENT));
                client_net.inbox.push_back((datagram, addr(SERVER)));
            }
        }
        assert_eq!(client.session_count(), 1);
        assert_eq!(server.session_count(), 1);
        assert!(client.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionOpened { remote } if *remote == addr(SERVER)
        )));
    }
    #[test]
    fn routes_data_between_peers() {
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));
        client.connect(addr(SERVER), Instant::now());
        pump(&mut client, &mut server);
        assert!(client.enqueue_data(&addr(SERVER), b"ping"));
        pump(&mut client, &mut server);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::DataReceived { remote, data } if *remote == addr(CLIENT) && data == "ping"
        )));
        assert!(server.enqueue_data(&addr(CLIENT), b"pong"));
        pump(&mut client, &mut server);
        assert!(client.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::DataReceived { remote, data } if *remote == addr(SERVER) && data == "pong"
        )));
    }
    #[test]
    fn encrypted_data_routes_between_peers() {
        let key = Rc4KeyState::new(&[1, 2, 3, 4, 5]);
        let mut client_cfg = config("TestProtocol", 1);
        let mut server_cfg = config("TestProtocol", 2);
        client_cfg.app_params.encryption_key_state = Some(key.clone());
        server_cfg.app_params.encryption_key_state = Some(key);
        let mut client = SoeMultiplexer::new(client_cfg);
        let mut server = SoeMultiplexer::new(server_cfg);
        client.connect(addr(SERVER), Instant::now());
        pump(&mut client, &mut server);
        let payload = vec![0u8; 200];
        assert!(client.enqueue_data(&addr(SERVER), &payload));
        pump(&mut client, &mut server);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::DataReceived { remote, data }
                if *remote == addr(CLIENT) && data.as_ref() == payload.as_slice()
        )));
    }
    #[test]
    fn terminate_notifies_remote_and_removes_session() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server = SoeMultiplexer::new(config("TestProtocol", 2));
        client.connect(addr(SERVER), now);
        pump(&mut client, &mut server);
        client.take_events();
        server.take_events();
        client.terminate(&addr(SERVER), DisconnectReason::Application, now);
        pump(&mut client, &mut server);
        assert_eq!(client.session_count(), 0);
        assert_eq!(server.session_count(), 0);
        assert!(server.take_events().iter().any(|e| matches!(
            e,
            SocketEvent::SessionClosed { remote, reason }
                if *remote == addr(CLIENT) && *reason == DisconnectReason::Application
        )));
    }
    #[test]
    fn ignores_stray_datagram_without_session() {
        let now = Instant::now();
        let mut server = SoeMultiplexer::<SocketAddr>::new(config("TestProtocol", 1));
        server.process_incoming(addr(CLIENT), Bytes::from_static(&[0x00, 0x09, 0x00]), now);
        assert_eq!(server.session_count(), 0);
        assert!(server.take_outgoing().is_empty());
        assert!(server.take_events().is_empty());
    }
    #[test]
    fn remaps_port_for_matching_session() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server_cfg = config("TestProtocol", 2);
        server_cfg.allow_port_remaps = true;
        let mut server = SoeMultiplexer::new(server_cfg);
        client.connect(addr(SERVER), now);
        pump(&mut client, &mut server);
        let session = server.session(&addr(CLIENT)).expect("server session");
        let remap = RemapConnection {
            session_id: session.session_id(),
            crc_seed: session.crc_seed(),
        };
        let mut buf = [0u8; RemapConnection::SIZE];
        let n = remap.serialize(&mut buf).unwrap();
        let new_client = addr("127.0.0.1:4099");
        server.process_incoming(new_client, Bytes::copy_from_slice(&buf[..n]), now);
        assert!(server.session(&addr(CLIENT)).is_none());
        assert!(server.session(&new_client).is_some());
    }
    #[test]
    fn rejects_remap_from_different_host() {
        let now = Instant::now();
        let mut client = SoeMultiplexer::new(config("TestProtocol", 1));
        let mut server_cfg = config("TestProtocol", 2);
        server_cfg.allow_port_remaps = true;
        let mut server = SoeMultiplexer::new(server_cfg);
        client.connect(addr(SERVER), now);
        pump(&mut client, &mut server);
        let session = server.session(&addr(CLIENT)).expect("server session");
        let remap = RemapConnection {
            session_id: session.session_id(),
            crc_seed: session.crc_seed(),
        };
        let mut buf = [0u8; RemapConnection::SIZE];
        let n = remap.serialize(&mut buf).unwrap();
        let attacker = addr("10.0.0.1:5000");
        server.process_incoming(attacker, Bytes::copy_from_slice(&buf[..n]), now);
        assert!(server.session(&addr(CLIENT)).is_some());
        assert!(server.session(&attacker).is_none());
    }
}