use {
crate::{bind_to_unspecified, HEADER_LENGTH, IP_ECHO_SERVER_RESPONSE_LENGTH},
log::*,
serde::{Deserialize, Serialize},
solana_serde::default_on_eof,
std::{
collections::HashSet,
io,
net::{IpAddr, SocketAddr},
num::NonZeroUsize,
sync::{Arc, Mutex},
time::Duration,
},
tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
runtime::{self, Runtime},
time::{timeout_at, Instant},
},
};
/// Handle to a running ip-echo server: the dedicated tokio runtime that drives it.
///
/// The server's tasks live on this runtime, so keep the value alive for as long as the
/// server should keep serving.
pub type IpEchoServer = Runtime;
/// Smallest worker-thread count accepted for the ip-echo server runtime.
pub const MINIMUM_IP_ECHO_SERVER_THREADS: NonZeroUsize = NonZeroUsize::new(2).unwrap();
/// Worker-thread count used when the caller has no preference (same as the minimum).
pub const DEFAULT_IP_ECHO_SERVER_THREADS: NonZeroUsize = MINIMUM_IP_ECHO_SERVER_THREADS;
/// Number of TCP slots and of UDP slots carried by one request message.
pub const MAX_PORT_COUNT_PER_MESSAGE: usize = 4;
/// Total time budget for one session's network I/O.
const IO_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Upper bound on the number of distinct non-loopback peer IPs served at the same time.
const MAX_CONCURRENT_CONNECTIONS: usize = 2_048;
/// Per-connection token that releases a peer's admission slot when dropped.
struct ConnectionCleanup {
    /// Shared set of peer IPs that currently have a session in flight.
    active_ips: Arc<Mutex<HashSet<IpAddr>>>,
    /// The IP this token removes from `active_ips` on drop.
    ip: IpAddr,
}
impl ConnectionCleanup {
    /// Creates a token that will release `ip` from `active_ips` when it is dropped.
    ///
    /// The caller is expected to have already inserted `ip` into the set.
    fn new(active_ips: Arc<Mutex<HashSet<IpAddr>>>, ip: IpAddr) -> Self {
        Self { ip, active_ips }
    }
}
impl Drop for ConnectionCleanup {
    /// Removes this connection's IP from the shared admission set.
    ///
    /// A poisoned lock is recovered instead of propagated. Each insert/remove on the set is
    /// a single self-contained operation, so a panicking holder should not leave it
    /// inconsistent. Panicking here instead would skip the removal and leave `self.ip`
    /// registered permanently. The accept loop refuses IPs already in the set, so that
    /// peer would be locked out for good.
    fn drop(&mut self) {
        let mut active_ips = self
            .active_ips
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        release_active_ip(&mut active_ips, self.ip);
    }
}
/// Request body an ip-echo client sends after the header bytes.
///
/// Each array has `MAX_PORT_COUNT_PER_MESSAGE` slots; a zero entry marks an unused slot
/// and is skipped by the server. Field order is part of the bincode wire layout.
#[derive(Serialize, Deserialize, Default, Debug)]
pub(crate) struct IpEchoServerMessage {
    /// TCP ports on the requester's IP that the server should connect to (0 = unused).
    tcp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE],
    /// UDP ports on the requester's IP that the server should send a datagram to (0 = unused).
    udp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE],
}
/// Reply written back to an ip-echo client.
///
/// Field order is part of the bincode wire layout and must not change.
#[derive(Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct IpEchoServerResponse {
    /// The client's IP address as observed by the server.
    pub(crate) address: IpAddr,
    /// Shred version the server was configured with, if any.
    // NOTE(review): `default_on_eof` presumably lets replies that end before this field
    // decode as `None` — confirm against solana_serde.
    #[serde(deserialize_with = "default_on_eof")]
    pub(crate) shred_version: Option<u16>,
}
impl IpEchoServerMessage {
    /// Builds a request listing `tcp_ports` and `udp_ports`.
    ///
    /// Ports fill the leading slots in order; remaining slots stay zero (unused).
    ///
    /// # Panics
    ///
    /// Panics if either slice has more entries than the message has slots
    /// (`MAX_PORT_COUNT_PER_MESSAGE`).
    pub fn new(tcp_ports: &[u16], udp_ports: &[u16]) -> Self {
        let mut msg = Self::default();
        assert!(tcp_ports.len() <= msg.tcp_ports.len());
        assert!(udp_ports.len() <= msg.udp_ports.len());
        for (slot, &port) in msg.tcp_ports.iter_mut().zip(tcp_ports) {
            *slot = port;
        }
        for (slot, &port) in msg.udp_ports.iter_mut().zip(udp_ports) {
            *slot = port;
        }
        msg
    }
}
/// Exact number of bytes the server reads for one request.
///
/// The request is laid out as `HEADER_LENGTH` header bytes, then the bincode-serialized
/// [`IpEchoServerMessage`], then one terminating byte.
pub(crate) fn ip_echo_server_request_length() -> usize {
    const REQUEST_TERMINUS_LENGTH: usize = 1;
    let message_length = bincode::serialized_size(&IpEchoServerMessage::default())
        .expect("IpEchoServerMessage serialized size must be computable");
    // The message is a handful of fixed-size integer arrays, so its size fits in usize.
    // Plain `+` (rather than `wrapping_add`) keeps overflow checked in debug builds.
    HEADER_LENGTH + message_length as usize + REQUEST_TERMINUS_LENGTH
}
async fn process_connection(
mut socket: TcpStream,
peer_addr: SocketAddr,
shred_version: Option<u16>,
) -> io::Result<()> {
info!("connection from {peer_addr:?}");
let deadline = Instant::now()
.checked_add(IO_TIMEOUT)
.ok_or_else(|| io::Error::other("failed to compute request deadline"))?;
let mut data = vec![0u8; ip_echo_server_request_length()];
let mut writer = {
let (mut reader, writer) = socket.split();
let _ = timeout_at(deadline, reader.read_exact(&mut data)).await??;
writer
};
let request_header: String = data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect();
if request_header != "\0\0\0\0" {
if request_header == "GET " || request_header == "POST" {
timeout_at(
deadline,
writer.write_all(b"HTTP/1.1 400 Bad Request\nContent-length: 0\n\n"),
)
.await??;
return Ok(());
}
return Err(io::Error::other(format!(
"Bad request header: {request_header}"
)));
}
let msg =
bincode::deserialize::<IpEchoServerMessage>(&data[HEADER_LENGTH..]).map_err(|err| {
io::Error::other(format!(
"Failed to deserialize IpEchoServerMessage: {err:?}"
))
})?;
trace!("request: {msg:?}");
match bind_to_unspecified() {
Ok(udp_socket) => {
for udp_port in &msg.udp_ports {
if *udp_port != 0 {
let result =
udp_socket.send_to(&[0], SocketAddr::from((peer_addr.ip(), *udp_port)));
match result {
Ok(_) => debug!("Successful send_to udp/{udp_port}"),
Err(err) => info!("Failed to send_to udp/{udp_port}: {err}"),
}
}
}
}
Err(err) => {
warn!("Failed to bind local udp socket: {err}");
}
}
for tcp_port in &msg.tcp_ports {
if *tcp_port != 0 {
debug!("Connecting to tcp/{tcp_port}");
let mut tcp_stream = timeout_at(
deadline,
TcpStream::connect(&SocketAddr::new(peer_addr.ip(), *tcp_port)),
)
.await??;
debug!("Connection established to tcp/{}", *tcp_port);
tcp_stream.shutdown().await?;
}
}
let response = IpEchoServerResponse {
address: peer_addr.ip(),
shred_version,
};
let mut bytes = vec![0u8; IP_ECHO_SERVER_RESPONSE_LENGTH];
bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &response).unwrap();
trace!("response: {bytes:?}");
timeout_at(deadline, writer.write_all(&bytes)).await?
}
/// Removes `ip` from the set of peers with an in-flight session.
///
/// In debug builds, asserts that `ip` was actually present, which catches unbalanced
/// admit/release bookkeeping.
fn release_active_ip(active_ips: &mut HashSet<IpAddr>, ip: IpAddr) {
    let was_present = active_ips.remove(&ip);
    debug_assert!(was_present, "cleanup for unknown IP {ip}");
}
async fn run_echo_server(tcp_listener: std::net::TcpListener, shred_version: Option<u16>) {
info!("bound to {:?}", tcp_listener.local_addr().unwrap());
let tcp_listener =
TcpListener::from_std(tcp_listener).expect("Failed to convert std::TcpListener");
let active_ips = Arc::new(Mutex::new(HashSet::new()));
loop {
let connection = tcp_listener.accept().await;
match connection {
Ok((socket, peer_addr)) => {
let tracked_ip = (!peer_addr.ip().is_loopback()).then_some(peer_addr.ip());
if let Some(ip) = tracked_ip {
let mut active_ip_set = active_ips
.lock()
.expect("active_ips lock poisoned while admitting");
if active_ip_set.len() >= MAX_CONCURRENT_CONNECTIONS {
debug!(
"dropping connection from {peer_addr:?}: max concurrent connections \
({MAX_CONCURRENT_CONNECTIONS}) reached",
);
continue;
}
if !active_ip_set.insert(ip) {
debug!(
"dropping connection from {peer_addr:?}: max concurrent connections \
per IP (1) reached"
);
continue;
}
}
let cleanup =
tracked_ip.map(|ip| ConnectionCleanup::new(Arc::clone(&active_ips), ip));
runtime::Handle::current().spawn(async move {
let cleanup = cleanup;
if let Err(err) = process_connection(socket, peer_addr, shred_version).await {
info!("session failed: {err:?}");
}
drop(cleanup);
});
}
Err(err) => warn!("listener accept failed: {err:?}"),
}
}
}
/// Starts the ip-echo server on `tcp_listener` and returns the runtime that drives it.
///
/// Builds a dedicated multi-threaded tokio runtime with `num_server_threads` worker
/// threads and spawns the accept loop on it. Each reply carries `shred_version`. The
/// server runs for as long as the returned [`IpEchoServer`] is kept alive.
///
/// # Panics
///
/// Panics if `tcp_listener` cannot be switched to non-blocking mode, or if the tokio
/// runtime cannot be built.
pub fn ip_echo_server(
    tcp_listener: std::net::TcpListener,
    num_server_threads: NonZeroUsize,
    shred_version: Option<u16>,
) -> IpEchoServer {
    // tokio's `TcpListener::from_std` (used by the accept loop) expects a non-blocking socket.
    tcp_listener
        .set_nonblocking(true)
        .expect("failed to set ip echo server listener to non-blocking mode");
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .thread_name("solIpEchoSrvrRt")
        .worker_threads(num_server_threads.get())
        .enable_all()
        .build()
        .expect("new tokio runtime");
    runtime.spawn(run_echo_server(tcp_listener, shred_version));
    runtime
}