solana_net_utils/
ip_echo_server.rs1use {
2 crate::{bind_to_unspecified, HEADER_LENGTH, IP_ECHO_SERVER_RESPONSE_LENGTH},
3 log::*,
4 serde::{Deserialize, Serialize},
5 solana_serde::default_on_eof,
6 std::{
7 collections::HashSet,
8 io,
9 net::{IpAddr, SocketAddr},
10 num::NonZeroUsize,
11 sync::{Arc, Mutex},
12 time::Duration,
13 },
14 tokio::{
15 io::{AsyncReadExt, AsyncWriteExt},
16 net::{TcpListener, TcpStream},
17 runtime::{self, Runtime},
18 time::{timeout_at, Instant},
19 },
20};
21
22pub type IpEchoServer = Runtime;
23
24pub const MINIMUM_IP_ECHO_SERVER_THREADS: NonZeroUsize = NonZeroUsize::new(2).unwrap();
28pub const DEFAULT_IP_ECHO_SERVER_THREADS: NonZeroUsize = MINIMUM_IP_ECHO_SERVER_THREADS;
31pub const MAX_PORT_COUNT_PER_MESSAGE: usize = 4;
32
33const IO_TIMEOUT: Duration = Duration::from_secs(5);
34const MAX_CONCURRENT_CONNECTIONS: usize = 2048;
36
37struct ConnectionCleanup {
38 active_ips: Arc<Mutex<HashSet<IpAddr>>>,
39 ip: IpAddr,
40}
41
42impl ConnectionCleanup {
43 fn new(active_ips: Arc<Mutex<HashSet<IpAddr>>>, ip: IpAddr) -> Self {
44 Self { active_ips, ip }
45 }
46}
47
48impl Drop for ConnectionCleanup {
49 fn drop(&mut self) {
50 let mut active_ips = self.active_ips.lock().expect("active_ips lock poisoned");
51 release_active_ip(&mut active_ips, self.ip);
52 }
53}
54
55#[derive(Serialize, Deserialize, Default, Debug)]
56pub(crate) struct IpEchoServerMessage {
57 tcp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], udp_ports: [u16; MAX_PORT_COUNT_PER_MESSAGE], }
60
61#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
62pub struct IpEchoServerResponse {
63 pub(crate) address: IpAddr,
65 #[serde(deserialize_with = "default_on_eof")]
67 pub(crate) shred_version: Option<u16>,
68}
69
70impl IpEchoServerMessage {
71 pub fn new(tcp_ports: &[u16], udp_ports: &[u16]) -> Self {
72 let mut msg = Self::default();
73 assert!(tcp_ports.len() <= msg.tcp_ports.len());
74 assert!(udp_ports.len() <= msg.udp_ports.len());
75
76 msg.tcp_ports[..tcp_ports.len()].copy_from_slice(tcp_ports);
77 msg.udp_ports[..udp_ports.len()].copy_from_slice(udp_ports);
78 msg
79 }
80}
81
82pub(crate) fn ip_echo_server_request_length() -> usize {
83 const REQUEST_TERMINUS_LENGTH: usize = 1;
84 (HEADER_LENGTH + REQUEST_TERMINUS_LENGTH)
85 .wrapping_add(bincode::serialized_size(&IpEchoServerMessage::default()).unwrap() as usize)
86}
87
88async fn process_connection(
89 mut socket: TcpStream,
90 peer_addr: SocketAddr,
91 shred_version: Option<u16>,
92) -> io::Result<()> {
93 info!("connection from {peer_addr:?}");
94 let deadline = Instant::now()
95 .checked_add(IO_TIMEOUT)
96 .ok_or_else(|| io::Error::other("failed to compute request deadline"))?;
97
98 let mut data = vec![0u8; ip_echo_server_request_length()];
99
100 let mut writer = {
101 let (mut reader, writer) = socket.split();
102 let _ = timeout_at(deadline, reader.read_exact(&mut data)).await??;
103 writer
104 };
105
106 let request_header: String = data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect();
107 if request_header != "\0\0\0\0" {
108 if request_header == "GET " || request_header == "POST" {
112 timeout_at(
114 deadline,
115 writer.write_all(b"HTTP/1.1 400 Bad Request\nContent-length: 0\n\n"),
116 )
117 .await??;
118 return Ok(());
119 }
120 return Err(io::Error::other(format!(
121 "Bad request header: {request_header}"
122 )));
123 }
124
125 let msg =
126 bincode::deserialize::<IpEchoServerMessage>(&data[HEADER_LENGTH..]).map_err(|err| {
127 io::Error::other(format!(
128 "Failed to deserialize IpEchoServerMessage: {err:?}"
129 ))
130 })?;
131
132 trace!("request: {msg:?}");
133
134 match bind_to_unspecified() {
136 Ok(udp_socket) => {
137 for udp_port in &msg.udp_ports {
138 if *udp_port != 0 {
139 let result =
140 udp_socket.send_to(&[0], SocketAddr::from((peer_addr.ip(), *udp_port)));
141 match result {
142 Ok(_) => debug!("Successful send_to udp/{udp_port}"),
143 Err(err) => info!("Failed to send_to udp/{udp_port}: {err}"),
144 }
145 }
146 }
147 }
148 Err(err) => {
149 warn!("Failed to bind local udp socket: {err}");
150 }
151 }
152
153 for tcp_port in &msg.tcp_ports {
155 if *tcp_port != 0 {
156 debug!("Connecting to tcp/{tcp_port}");
157
158 let mut tcp_stream = timeout_at(
159 deadline,
160 TcpStream::connect(&SocketAddr::new(peer_addr.ip(), *tcp_port)),
161 )
162 .await??;
163
164 debug!("Connection established to tcp/{}", *tcp_port);
165 tcp_stream.shutdown().await?;
166 }
167 }
168 let response = IpEchoServerResponse {
169 address: peer_addr.ip(),
170 shred_version,
171 };
172 let mut bytes = vec![0u8; IP_ECHO_SERVER_RESPONSE_LENGTH];
175 bincode::serialize_into(&mut bytes[HEADER_LENGTH..], &response).unwrap();
176 trace!("response: {bytes:?}");
177 timeout_at(deadline, writer.write_all(&bytes)).await?
178}
179
/// Removes `ip` from the set of peers with an in-flight connection.
///
/// In debug builds, releasing an IP that is not in the set panics, since it means admission
/// and cleanup bookkeeping have diverged. In release builds the removal is simply a no-op.
fn release_active_ip(active_ips: &mut HashSet<IpAddr>, ip: IpAddr) {
    let was_tracked = active_ips.remove(&ip);
    if cfg!(debug_assertions) && !was_tracked {
        panic!("cleanup for unknown IP {ip}");
    }
}
184
185async fn run_echo_server(tcp_listener: std::net::TcpListener, shred_version: Option<u16>) {
186 info!("bound to {:?}", tcp_listener.local_addr().unwrap());
187 let tcp_listener =
188 TcpListener::from_std(tcp_listener).expect("Failed to convert std::TcpListener");
189 let active_ips = Arc::new(Mutex::new(HashSet::new()));
190
191 loop {
192 let connection = tcp_listener.accept().await;
193 match connection {
194 Ok((socket, peer_addr)) => {
195 let tracked_ip = (!peer_addr.ip().is_loopback()).then_some(peer_addr.ip());
196 if let Some(ip) = tracked_ip {
197 let mut active_ip_set = active_ips
198 .lock()
199 .expect("active_ips lock poisoned while admitting");
200 if active_ip_set.len() >= MAX_CONCURRENT_CONNECTIONS {
201 debug!(
202 "dropping connection from {peer_addr:?}: max concurrent connections \
203 ({MAX_CONCURRENT_CONNECTIONS}) reached",
204 );
205 continue;
206 }
207 if !active_ip_set.insert(ip) {
208 debug!(
209 "dropping connection from {peer_addr:?}: max concurrent connections \
210 per IP (1) reached"
211 );
212 continue;
213 }
214 }
215 let cleanup =
216 tracked_ip.map(|ip| ConnectionCleanup::new(Arc::clone(&active_ips), ip));
217 runtime::Handle::current().spawn(async move {
218 let cleanup = cleanup;
219 if let Err(err) = process_connection(socket, peer_addr, shred_version).await {
220 info!("session failed: {err:?}");
221 }
222 drop(cleanup);
223 });
224 }
225 Err(err) => warn!("listener accept failed: {err:?}"),
226 }
227 }
228}
229
230pub fn ip_echo_server(
233 tcp_listener: std::net::TcpListener,
234 num_server_threads: NonZeroUsize,
235 shred_version: Option<u16>,
237) -> IpEchoServer {
238 tcp_listener.set_nonblocking(true).unwrap();
239
240 let runtime = tokio::runtime::Builder::new_multi_thread()
241 .thread_name("solIpEchoSrvrRt")
242 .worker_threads(num_server_threads.get())
243 .enable_all()
244 .build()
245 .expect("new tokio runtime");
246 runtime.spawn(run_echo_server(tcp_listener, shred_version));
247 runtime
248}