cat_dev/net/server/
tcp.rs

//! Utilities for serving the TCP layer 4 protocol.
//!
//! Note this "TCP Server" class allows both binding to a particular port, as
//! well as "connect"'ing to another host. This is mostly due to the fact we
//! have to support SDIO, where the side that "connect"'s acts as the server.
//!
//! ## Notes About Layering
//!
//! When layering things on top of a service, it's important to keep the
//! performance characteristics in mind. Every individual `layer` call
//! ***may add significant overhead***. As such, it is best to use
//! [`tower::ServiceBuilder`] to group layers together and keep the cost of
//! heavy layering down.
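//!
//! As a hedged sketch (the `LogLayer`/`MetricsLayer` names here are
//! hypothetical, not part of this crate), grouping layers up front looks
//! like:
//!
//! ```ignore
//! use tower::ServiceBuilder;
//!
//! // Build the middleware stack once, then wrap the routing service in it,
//! // instead of paying for repeated individual `layer` calls.
//! let service = ServiceBuilder::new()
//! 	.layer(LogLayer::new())
//! 	.layer(MetricsLayer::new())
//! 	.service(my_router);
//! ```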
//!
//! ## TCP and NAGLE'ing
//!
//! TCP the protocol is fundamentally built on top of a 'stream', and is
//! not packet oriented. We try to guard against this by setting no-delay
//! on multiple parts of our infrastructure, but there is a whole world out
//! there that may not do what we want.
//!
//! As a result, we need to guard against weird splits of data across read
//! calls. A single read call isn't necessarily going to read a full
//! "packet", and may even read multiple packets.
//!
//! I generally will refer to this in the source code as "NAGLE" protections,
//! as NAGLE's algorithm is usually to blame for weird behaviors here, where
//! a device will combine multiple small packets together.
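//!
//! As a minimal sketch of the idea (simplified; the real logic lives in
//! [`NagleGuard`] and the read loop below), buffer whatever bytes arrive,
//! then split complete packets out on a known end delimiter:
//!
//! ```ignore
//! use bytes::BytesMut;
//!
//! const END: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF];
//!
//! /// `cache` holds leftover bytes from previous reads; `incoming` is the
//! /// buffer filled by the current read call.
//! fn recover_packets(cache: &mut BytesMut, incoming: &[u8]) -> Vec<BytesMut> {
//! 	cache.extend_from_slice(incoming);
//! 	let mut packets = Vec::new();
//! 	// Zero, one, or many complete packets may now sit in the cache.
//! 	while let Some(pos) = cache.windows(END.len()).position(|win| win == END) {
//! 		let mut packet = cache.split_to(pos + END.len());
//! 		packet.truncate(pos); // drop the delimiter itself
//! 		packets.push(packet);
//! 	}
//! 	packets // any incomplete tail stays cached for the next read
//! }
//! ```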
//!
//! ## Notes about Concurrency
//!
//! This TCP Server unfortunately has to make the sacrifice of processing one
//! packet per stream at a time. While you can have as many TCP streams as you
//! want, and we should be able to handle many at the same time, the ordered
//! nature of TCP, along with some protocol designs implemented by nintendo,
//! means this server must also force that we process one packet per TCP
//! stream at a time.
//!
//! Most notably this comes from the fact that our file servers will
//! consistently break their normal "NAGLE" protection, and we have to do just
//! raw reads of N bytes from the stream (in both directions) _BEFORE_
//! processing another request.
//!
//! ## SLOW-loris
//!
//! As mentioned, TCP is built on top of a stream, so there is a chance that
//! someone can just "slow drip" us a packet, bit by bit, eating up an open
//! connection slot for a long period of time. This used to be a lot more
//! damaging when you'd allocate a thread for every single connection.
//!
//! For us, because we _don't_ allocate a thread for every single connection,
//! the cost of any single connection is low. Not to mention we don't really
//! have to deal with untrusted DOS connections in a lot of cases.
//!
//! That being said, there's no reason to allow someone to slow-drip us a
//! packet forever, so our servers will inherently time out a packet if it's
//! taking too long, closing the connection.
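//!
//! A simplified sketch of the check (the real version lives in the read path
//! below): remember when the first partial bytes of a packet arrived, and
//! give up once assembling it has taken longer than the timeout.
//!
//! ```ignore
//! use std::time::{Duration, SystemTime};
//!
//! fn should_close(packet_started_at: SystemTime, timeout: Duration) -> bool {
//! 	// A negative elapsed time (clock skew) is treated as no time passed.
//! 	let elapsed = SystemTime::now()
//! 		.duration_since(packet_started_at)
//! 		.unwrap_or(Duration::ZERO);
//! 	elapsed > timeout
//! }
//! ```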
//!
//! ## Lifecycle of a Packet
//!
//! ```ascii
//!     ┌─────────┐                        ┌─────────┐
//!     │TCPClient│                        │TCPServer│
//!     └─────────┘                        └─────────┘
//!          │Sends Arbitrary Sequence of Bytes │
//!          │─────────────────────────────────>│
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Call PreNagleFN (if one exists).
//!          │                                  │    │ This is a custom function that handles things like:
//!          │                                  │    │
//!          │                                  │    │   1. Decrypting Packets.
//!          │                                  │    │   2. Logging info not through trace.
//!          │                                  │    │
//!          │                                  │    │ Generally this should not be used for anything else,
//!          │                                  │    │ as it happens before ALL processing.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Check how long this client has been attempting to
//!          │                                  │    │ send us a single packet.
//!          │                                  │    │
//!          │                                  │    │ Compare this with `SLOWLORIS_TIMEOUT`,
//!          │                                  │    │ and error/close the stream if it's taken too long to
//!          │                                  │    │ send us a packet.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Look at "nagle cache", this is any data that was previously sent
//!          │                                  │    │ on this TCP Stream, that has not yet been processed.
//!          │                                  │    │
//!          │                                  │    │ Create final `buff`.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Look at "nagle guard", this determines when one packet begins/ends.
//!          │                                  │    │ Process any # of packets that may have come in.
//!          │                                  │    │
//!          │                                  │    │ If any data is left over put it in "nagle cache".
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Process Packet.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ If reply was sent, and wasn't empty...
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Look at "chunked output size", and chunk packet if necessary.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ For each packet apply PostNagleFN (if one exists).
//!          │                                  │    │ This is a custom function that handles things like:
//!          │                                  │    │
//!          │                                  │    │   1. Encrypting Packets.
//!          │                                  │    │   2. Logging info not through trace.
//!          │                                  │    │   3. Sleeping to not overwhelm a CAT-DEV.
//!          │                                  │<───┘
//!          │                                  │
//!          │      Data for the Client.        │
//!          │<─────────────────────────────────│
//!     ┌─────────┐                        ┌─────────┐
//!     │TCPClient│                        │TCPServer│
//!     └─────────┘                        └─────────┘
//! ```
//!
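//! ## Hook Shapes
//!
//! A hedged sketch of the NAGLE hook shapes, inferred purely from how the
//! hooks are invoked further down in this file (pre hooks mutate the raw
//! inbound buffer in place; post hooks map each outbound payload; the
//! `my_*` names are hypothetical):
//!
//! ```ignore
//! use bytes::{Bytes, BytesMut};
//!
//! // Runs before any NAGLE/SLOWLORIS logic, e.g. to decrypt in place.
//! fn my_pre_hook(stream_id: u64, buff: &mut BytesMut) {
//! 	let _ = (stream_id, buff); // decrypt `buff` in place here
//! }
//!
//! // Runs right before `send`, e.g. to encrypt the outbound payload.
//! fn my_post_hook(stream_id: u64, payload: Bytes) -> Bytes {
//! 	let _ = stream_id;
//! 	payload // return the transformed payload
//! }
//! ```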
use crate::{
	errors::{CatBridgeError, NetworkError},
	net::{
		DEFAULT_SLOWLORIS_TIMEOUT, SERVER_ID, STREAM_ID, TCP_READ_BUFFER_SIZE,
		errors::{CommonNetAPIError, CommonNetNetworkError},
		handlers::{
			OnResponseStreamBeginHandler, OnResponseStreamEndHandler,
			OnStreamBeginHandlerAsService, OnStreamEndHandlerAsService,
		},
		models::{NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
		now,
		server::models::{
			DisconnectAsyncDropServer, ResponseStreamEvent, ResponseStreamMessage,
			UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
		},
	},
};
use bytes::{Bytes, BytesMut};
use fnv::FnvHashSet;
use futures::future::join_all;
use scc::HashMap as ConcurrentMap;
use std::{
	convert::Infallible,
	fmt::{Debug, Formatter, Result as FmtResult},
	net::{IpAddr, SocketAddr},
	sync::{Arc, LazyLock, atomic::Ordering},
	time::{Duration, SystemTime},
};
use tokio::{
	io::{AsyncReadExt, AsyncWriteExt},
	net::{TcpListener, TcpStream, ToSocketAddrs, lookup_host},
	sync::{
		Mutex,
		mpsc::{Sender as BoundedSender, channel as bounded_channel},
	},
	task::{Builder as TaskBuilder, block_in_place},
	time::sleep,
};
use tower::{Layer, Service, util::BoxCloneService};
use tracing::{Instrument, debug, error_span, trace, warn};
use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};

#[cfg(debug_assertions)]
use crate::net::SPRIG_TRACE_IO;

/// A map of streams and the channels to queue a response packet out on, in a
/// completely out-of-band way.
static OUT_OF_BAND_SENDERS: LazyLock<
	ConcurrentMap<(u64, u64), BoundedSender<ResponseStreamMessage>>,
> = LazyLock::new(ConcurrentMap::new);

/// An implementation of a TCP server.
///
/// This is the "base" TCP Server class that any of the servers in cat-dev
/// will be built on. It has extensions for handling any potential type of
/// server that we will ever need to support.
///
/// This server accepts lots of parameters, so look at the module's
/// documentation for general info about how processing works.
pub struct TCPServer<State: Clone + Send + Sync + 'static = ()> {
	/// The address that the server will be bound to when [`bind`]
	/// is called. It may also be a server address to connect to when
	/// we are a server who connects to our client.
	///
	/// Currently our server can only bind to a very specific address, and does
	/// not support binding to multiple interfaces besides the special
	/// `0.0.0.0` addresses that get handled by the OS.
	address_to_bind_or_connect_to: SocketAddr,
	/// CAT-DEVs need load-bearing sleeps, as they can "ACK" a packet but
	/// throw away the bytes and pretend it never arrived.
	///
	/// This is used for sleeping before those sends. By default this isn't
	/// set; cat-dev services will call: [`TCPServer::set_cat_dev_slowdown`].
	cat_dev_slowdown: Option<Duration>,
	/// For devices that can't receive too much data at once.
	///
	/// When devices *cough* MION *cough* can't receive too much data at once,
	/// we need to chunk it, and rest between those chunks. This will chunk those
	/// packets for us.
	chunk_output_at_size: Option<usize>,
	/// The ID of this particular server.
	///
	/// Used to do out of band sends/broadcasts to this particular server.
	id: u64,
	/// The [`Service`] that gets called when a packet is received.
	///
	/// This will almost always be some sort of router that knows how to route
	/// packets to handlers. However, it could just be a single tower service.
	initial_service: BoxCloneService<Request<State>, Response, Infallible>,
	/// Determines when a packet starts, and ends.
	nagle_guard: NagleGuard,
	/// A tower service to call when a particular stream starts.
	///
	/// This is effectively like an "on connect" hook that you can use to
	/// call functions.
	on_stream_begin: Option<UnderlyingOnStreamBeginService<State>>,
	/// A tower service to call when a particular stream ends.
	///
	/// This is effectively like an "on disconnect" hook that you can use
	/// to call functions.
	on_stream_end: Option<UnderlyingOnStreamEndService<State>>,
	/// A function to apply some sort of processing before doing any sort
	/// of NAGLE/SLOWLORIS logic.
	///
	/// This is best used for encryption/decryption that needs to be handled
	/// before anything else.
	pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
	/// A function to apply some sort of processing before sending a packet
	/// out.
	///
	/// This is the very last thing before `send` is actually called. This is
	/// best used for encryption/decryption that needs to be only processed
	/// at the end.
	post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
	/// The name of the service being provided by this server, to attach to logs.
	service_name: &'static str,
	/// The "Slowloris Detection" timeout.
	slowloris_timeout: Duration,
	/// Type-safe server wide state.
	///
	/// This is mostly used by our servers to attach state that is meant to be
	/// server wide. Such as a `HostFilesystem` type to abstract out all the
	/// filesystem stuff that many of our servers have to do.
	state: State,
	/// If we should log all packet requests/responses when compiled with debug
	/// assertions.
	#[cfg(debug_assertions)]
	trace_during_debug: bool,
}

impl TCPServer<()> {
	/// Create a new TCP Server without having to specify any particular state.
	///
	/// ## Errors
	///
	/// If we cannot look up one specific address to bind to.
	pub async fn new<AddrTy, ServiceTy>(
		service_name: &'static str,
		bind_addr: AddrTy,
		initial_service: ServiceTy,
		nagle_hooks: (
			Option<&'static dyn PreNagleFnTy>,
			Option<&'static dyn PostNagleFnTy>,
		),
		guard: impl Into<NagleGuard>,
		trace_io_during_debug: bool,
	) -> Result<Self, CommonNetAPIError>
	where
		AddrTy: ToSocketAddrs,
		ServiceTy:
			Clone + Send + Service<Request<()>, Response = Response, Error = Infallible> + 'static,
		ServiceTy::Future: Send + 'static,
	{
		Self::new_with_state(
			service_name,
			bind_addr,
			initial_service,
			nagle_hooks,
			guard,
			(),
			trace_io_during_debug,
		)
		.await
	}
}

impl<State: Clone + Send + Sync + 'static> TCPServer<State> {
	/// Send a message completely out of band to a particular open stream.
	///
	/// ## Errors
	///
	/// - If the stream is not actively open.
	/// - If we could not actively queue a packet to go out.
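	///
	/// A hedged usage sketch (`server_id`/`stream_id` are placeholders for
	/// IDs you obtained from a real server and stream):
	///
	/// ```ignore
	/// TCPServer::<()>::out_of_bound_send(
	/// 	server_id,
	/// 	stream_id,
	/// 	ResponseStreamMessage::Disconnect,
	/// )
	/// .await?;
	/// ```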
	pub async fn out_of_bound_send(
		server_id: u64,
		stream_id: u64,
		message: ResponseStreamMessage,
	) -> Result<(), CatBridgeError> {
		if let Some(stream) = OUT_OF_BAND_SENDERS.get_async(&(server_id, stream_id)).await {
			stream
				.send(message)
				.await
				.map_err(NetworkError::SendQueueMessageFailure)?;
			Ok(())
		} else {
			Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
		}
	}

	/// Send a message completely out of band to all open streams.
	pub async fn out_of_bound_broadcast(
		server_id: u64,
		message: ResponseStreamMessage,
	) -> Vec<Result<(), CatBridgeError>> {
		let mut ids = FnvHashSet::default();
		// Scan all senders at a point in time...
		OUT_OF_BAND_SENDERS
			.iter_async(|key, _value| {
				if key.0 == server_id {
					ids.insert(key.1);
				}
				true
			})
			.await;

		// Now we broadcast to those streams...
		let mut tasks = Vec::with_capacity(ids.len());
		for id in ids {
			tasks.push(Self::out_of_bound_send(server_id, id, message.clone()));
		}

		join_all(tasks).await
	}

	/// Create a new TCP Server along with the state for this server.
	///
	/// ## Errors
	///
	/// If we cannot look up one specific address to bind to.
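	///
	/// A minimal construction sketch (mirroring the unit tests below;
	/// `router` is any service mapping `Request<State>` to `Response`, and
	/// `MyState` is a hypothetical state type):
	///
	/// ```ignore
	/// let server = TCPServer::new_with_state(
	/// 	"my-service",
	/// 	"127.0.0.1:7974", // any single resolvable address works
	/// 	router,
	/// 	(None, None), // no pre/post NAGLE hooks
	/// 	NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
	/// 	MyState::default(),
	/// 	false, // don't trace I/O during debug builds
	/// )
	/// .await?;
	/// ```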
	#[allow(unused)]
	pub async fn new_with_state<AddrTy, ServiceTy>(
		service_name: &'static str,
		bind_addr: AddrTy,
		initial_service: ServiceTy,
		nagle_hooks: (
			Option<&'static dyn PreNagleFnTy>,
			Option<&'static dyn PostNagleFnTy>,
		),
		guard: impl Into<NagleGuard>,
		state: State,
		trace_io_during_debug: bool,
	) -> Result<Self, CommonNetAPIError>
	where
		AddrTy: ToSocketAddrs,
		ServiceTy: Clone
			+ Send
			+ Service<Request<State>, Response = Response, Error = Infallible>
			+ 'static,
		ServiceTy::Future: Send + 'static,
	{
		let hosts = lookup_host(bind_addr)
			.await
			.map_err(CommonNetAPIError::AddressLookupError)?
			.collect::<Vec<_>>();
		if hosts.len() != 1 {
			return Err(CommonNetAPIError::WrongAmountOfAddressesToBindToo(hosts));
		}

		#[cfg(not(debug_assertions))]
		{
			if trace_io_during_debug {
				warn!(
					"Trace IO was turned on, but debug assertions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
				);
			}
		}

		Ok(Self {
			address_to_bind_or_connect_to: hosts[0],
			cat_dev_slowdown: None,
			chunk_output_at_size: None,
			id: SERVER_ID.fetch_add(1, Ordering::SeqCst),
			initial_service: BoxCloneService::new(initial_service),
			nagle_guard: guard.into(),
			on_stream_begin: None,
			on_stream_end: None,
			pre_nagle_hook: nagle_hooks.0,
			post_nagle_hook: nagle_hooks.1,
			service_name,
			slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
			state,
			#[cfg(debug_assertions)]
			trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
		})
	}

	/// Get the particular ID for this server.
	#[must_use]
	pub const fn id(&self) -> u64 {
		self.id
	}

	/// The IP address we will bind to, or connect to.
	#[must_use]
	pub const fn ip(&self) -> IpAddr {
		self.address_to_bind_or_connect_to.ip()
	}

	/// Get the port that we're either binding to, or connecting to.
	#[must_use]
	pub const fn port(&self) -> u16 {
		self.address_to_bind_or_connect_to.port()
	}

	/// Set the slowdown to apply before sending bytes from this server.
	pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
		self.cat_dev_slowdown = slowdown;
	}

	/// Get the size that output packets get chunked at, if any.
	#[must_use]
	pub const fn chunk_output_at_size(&self) -> Option<usize> {
		self.chunk_output_at_size
	}

	/// Set the size to chunk output packets at, or `None` to disable chunking.
	pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
		self.chunk_output_at_size = new_size;
	}

	/// Get the "Slowloris Detection" timeout.
	#[must_use]
	pub const fn slowloris_timeout(&self) -> Duration {
		self.slowloris_timeout
	}

	/// Set the "Slowloris Detection" timeout.
	pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
		self.slowloris_timeout = slowloris_timeout;
	}

	/// Get the hook that runs when a stream begins, if one is registered.
	#[must_use]
	pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<State>> {
		self.on_stream_begin.as_ref()
	}

	/// Set a hook to run when a stream has been created.
	///
	/// This is what happens when a new machine connects, so it may also be
	/// referred to as "on connect". This assumes you're assigning a service
	/// that already exists and is in the raw storage type; you may want to
	/// look into [`Self::set_on_stream_begin`], or
	/// [`Self::set_on_stream_begin_service`].
	///
	/// ## Errors
	///
	/// If the stream beginning hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_raw_on_stream_begin(
		&mut self,
		on_start: Option<UnderlyingOnStreamBeginService<State>>,
	) -> Result<(), CommonNetAPIError> {
		if self.on_stream_begin.is_some() {
			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
		}

		self.on_stream_begin = on_start;
		Ok(())
	}

	/// Set a function hook to run when a stream has been created.
	///
	/// This is what happens when a new machine connects, so it may also be
	/// referred to as "on connect". This assumes you're assigning a function
	/// to on stream begin; otherwise use [`Self::set_raw_on_stream_begin`],
	/// or [`Self::set_on_stream_begin_service`].
	///
	/// ## Errors
	///
	/// If the stream beginning hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
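	///
	/// A sketch mirroring the unit tests below (the `Extension` layer
	/// supplies the handler's extractor argument; `server`/`flag` are
	/// placeholders):
	///
	/// ```ignore
	/// async fn on_connection(
	/// 	Extension(flag): Extension<Arc<Mutex<bool>>>,
	/// ) -> Result<bool, CatBridgeError> {
	/// 	*flag.lock().expect("lock poisoned") = true;
	/// 	Ok(true) // returning `false` rejects the stream
	/// }
	///
	/// server.set_on_stream_begin(on_connection)?;
	/// server.layer_on_stream_begin(Extension(flag.clone()))?;
	/// ```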
	pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
		&mut self,
		handler: HandlerTy,
	) -> Result<(), CommonNetAPIError>
	where
		HandlerParamsTy: Send + 'static,
		HandlerTy: OnResponseStreamBeginHandler<HandlerParamsTy, State> + Clone + Send + 'static,
	{
		if self.on_stream_begin.is_some() {
			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
		}

		let boxed = BoxCloneService::new(OnStreamBeginHandlerAsService::new(handler));
		self.on_stream_begin = Some(boxed);
		Ok(())
	}

	/// Set a [`Service`] hook to run when a stream has been created.
	///
	/// This is what happens when a new machine connects, so it may also be
	/// referred to as "on connect". This assumes you're assigning a [`Service`]
	/// to on stream begin; otherwise use [`Self::set_raw_on_stream_begin`],
	/// or [`Self::set_on_stream_begin`].
	///
	/// ## Errors
	///
	/// If the stream beginning hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_begin_service<ServiceTy>(
		&mut self,
		service_ty: ServiceTy,
	) -> Result<(), CommonNetAPIError>
	where
		ServiceTy: Clone
			+ Send
			+ Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
			+ 'static,
		ServiceTy::Future: Send + 'static,
	{
		if self.on_stream_begin.is_some() {
			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
		}

		self.on_stream_begin = Some(BoxCloneService::new(service_ty));
		Ok(())
	}

	/// Add a layer to the service to process when a stream begins, or a new
	/// connection is created.
	///
	/// ## Errors
	///
	/// If there is no on stream begin handler that is currently active.
	pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
		&mut self,
		layer: LayerTy,
	) -> Result<(), CommonNetAPIError>
	where
		LayerTy: Layer<UnderlyingOnStreamBeginService<State>, Service = ServiceTy>,
		ServiceTy: Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
			+ Clone
			+ Send
			+ 'static,
		<LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
	{
		let Some(srvc) = self.on_stream_begin.take() else {
			return Err(CommonNetAPIError::OnStreamBeginNotRegistered);
		};

		self.on_stream_begin = Some(BoxCloneService::new(layer.layer(srvc)));
		Ok(())
	}

	/// Get the hook that runs when a stream ends, if one is registered.
	#[must_use]
	pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<State>> {
		self.on_stream_end.as_ref()
	}

	/// Set a hook to run when a stream is being destroyed.
	///
	/// This is what happens when a machine disconnects, so it may also be
	/// referred to as "on disconnect". This assumes you're assigning a service
	/// that already exists and is in the raw storage type; you may want to
	/// look into [`Self::set_on_stream_end`], or
	/// [`Self::set_on_stream_end_service`].
	///
	/// ## Errors
	///
	/// If the stream ending hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_raw_on_stream_end(
		&mut self,
		on_end: Option<UnderlyingOnStreamEndService<State>>,
	) -> Result<(), CommonNetAPIError> {
		if self.on_stream_end.is_some() {
			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
		}

		self.on_stream_end = on_end;
		Ok(())
	}

	/// Set a function hook to run when a stream is being destroyed.
	///
	/// This is what happens when a machine disconnects, so it may also be
	/// referred to as "on disconnect". This assumes you're assigning a function
	/// to on stream end; otherwise use [`Self::set_raw_on_stream_end`],
	/// or [`Self::set_on_stream_end_service`].
	///
	/// ## Errors
	///
	/// If the stream ending hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
		&mut self,
		handler: HandlerTy,
	) -> Result<(), CommonNetAPIError>
	where
		HandlerParamsTy: Send + 'static,
		HandlerTy: OnResponseStreamEndHandler<HandlerParamsTy, State> + Clone + Send + 'static,
	{
		if self.on_stream_end.is_some() {
			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
		}

		let boxed = BoxCloneService::new(OnStreamEndHandlerAsService::new(handler));
		self.on_stream_end = Some(boxed);
		Ok(())
	}

	/// Set a [`Service`] hook to run when a stream is being destroyed.
	///
	/// This is what happens when a machine disconnects, so it may also be
	/// referred to as "on disconnect". This assumes you're assigning a [`Service`]
	/// to on stream end; otherwise use [`Self::set_raw_on_stream_end`],
	/// or [`Self::set_on_stream_end`].
	///
	/// ## Errors
	///
	/// If the stream ending hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_end_service<ServiceTy>(
		&mut self,
		service_ty: ServiceTy,
	) -> Result<(), CommonNetAPIError>
	where
		ServiceTy: Clone
			+ Send
			+ Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
			+ 'static,
		ServiceTy::Future: Send + 'static,
	{
		if self.on_stream_end.is_some() {
			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
		}

		self.on_stream_end = Some(BoxCloneService::new(service_ty));
		Ok(())
	}

	/// Add a layer to the service to process when a stream ends, or a
	/// connection is destroyed.
	///
	/// ## Errors
	///
	/// If there is no on stream end handler that is currently active.
	pub fn layer_on_stream_end<LayerTy, ServiceTy>(
		&mut self,
		layer: LayerTy,
	) -> Result<(), CommonNetAPIError>
	where
		LayerTy: Layer<UnderlyingOnStreamEndService<State>, Service = ServiceTy>,
		ServiceTy: Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
			+ Clone
			+ Send
			+ 'static,
		<LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
	{
		let Some(srvc) = self.on_stream_end.take() else {
			return Err(CommonNetAPIError::OnStreamEndNotRegistered);
		};

		self.on_stream_end = Some(BoxCloneService::new(layer.layer(srvc)));
		Ok(())
	}

	/// Get the initial packet-handling [`Service`] for this server.
	#[must_use]
	pub const fn initial_service(&self) -> &BoxCloneService<Request<State>, Response, Infallible> {
		&self.initial_service
	}

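	/// Wrap the packet-handling service in another [`Layer`].
	///
	/// If you have many layers to apply, prefer grouping them with
	/// [`tower::ServiceBuilder`] first (see the module docs on layering).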
	pub fn layer_initial_service<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
	where
		LayerTy: Layer<BoxCloneService<Request<State>, Response, Infallible>, Service = ServiceTy>,
		ServiceTy: Service<Request<State>, Response = Response, Error = Infallible>
			+ Clone
			+ Send
			+ 'static,
		<LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
	{
		self.initial_service = BoxCloneService::new(layer.layer(self.initial_service.clone()));
	}

	/// Get a reference to the current state of the server.
	#[must_use]
	pub const fn state(&self) -> &State {
		&self.state
	}

696	/// "Connect" to our remote client address, and start serving ourselves to
697	/// whoever we connect too.
698	///
699	/// NOTE: this will ***NOT*** return until the service either runs into an
700	/// error, or the task is cancelled.
701	///
702	/// ## Errors
703	///
704	/// If we cannot connect to the remote source, fetch the local address, or
705	/// handle the connection in some other way.
706	pub async fn connect(self) -> Result<(), CatBridgeError> {
707		loop {
708			// Copy the address we're bound to, to use in log statements later on.
709			let client_address = self.address_to_bind_or_connect_to;
710			let stream = TcpStream::connect(self.address_to_bind_or_connect_to)
711				.await
712				.map_err(NetworkError::IO)?;
713			let loggable_address = stream.local_addr().map_err(NetworkError::IO)?;
714			let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
715			trace!(
716				server.address = %loggable_address,
717				client.address = %client_address,
718				stream.id = stream_id,
719				stream.stream_type = "server",
720				"cat_dev::net::tcp_server::connect(): started connection (TcpStream::connect())",
721			);
722
723			if let Err(cause) = Self::handle_tcp_connection(
724				self.on_stream_begin.clone(),
725				self.on_stream_end.clone(),
726				self.nagle_guard.clone(),
727				self.slowloris_timeout,
728				self.initial_service.clone(),
729				stream,
730				client_address,
731				self.pre_nagle_hook,
732				self.post_nagle_hook,
733				self.chunk_output_at_size,
734				self.state.clone(),
735				self.id,
736				stream_id,
737				self.cat_dev_slowdown,
738				#[cfg(debug_assertions)]
739				self.trace_during_debug,
740			)
741			.instrument(error_span!(
742				"CatDevTCPServerConnect",
743				client.address = %client_address,
744				server.address = %loggable_address,
745				server.service = self.service_name,
746				stream.id = stream_id,
747				stream.stream_type = "server",
748			))
749			.await
750			{
751				warn!(
752					?cause,
753					client.address = %client_address,
754					server.address = %loggable_address,
755					server.service = self.service_name,
756					"Error escaped while handling TCP connection.",
757				);
758			}
759		}
760	}
761
	/// Bind this server, and start listening on the specified address.
	///
	/// NOTE: this will ***NOT*** return until the service either runs into an
	/// error, or gets cancelled.
	///
	/// ## Errors
	///
	/// - If we cannot bind to the requested address, and port.
	/// - If we run into an error accept'ing a TCP connection.
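	///
	/// Because it never returns on success, `bind` is typically spawned on
	/// its own task (mirroring the unit tests below):
	///
	/// ```ignore
	/// tokio::task::spawn(async move {
	/// 	srv.bind().await.expect("Failed to bind server!");
	/// });
	/// ```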
	pub async fn bind(self) -> Result<(), CatBridgeError> {
		// Copy the address we're bound to, to use in log statements later on.
		let loggable_address = self.address_to_bind_or_connect_to;
		let listener = TcpListener::bind(self.address_to_bind_or_connect_to)
			.await
			.map_err(NetworkError::IO)?;

		loop {
			let (stream, client_address) = listener.accept().await.map_err(NetworkError::IO)?;
			trace!(
				server.address = %loggable_address,
				client.address = %client_address,
				"cat_dev::net::tcp_server::bind(): received connection (listener.accept())",
			);

			// We need to clone all of our stuff, so they can be owned by the
			// underlying connection task.
			//
			// We spawn a task per connection that way one connection doesn't block
			// the accept'ing of any new connections. If we did, we'd only be able
			// to handle one connection at a time which is lame.
			//
			// Tasks are also incredibly cheap, so spawning one per connection even
			// for long lived tasks is nbd.
			let cloned_begin_handler = self.on_stream_begin.clone();
			let cloned_end_handler = self.on_stream_end.clone();
			let cloned_nagle_guard = self.nagle_guard.clone();
			let cloned_handler = self.initial_service.clone();
			let cloned_state = self.state.clone();
			let copied_pre_nagle_hook = self.pre_nagle_hook;
			let copied_post_nagle_hook = self.post_nagle_hook;
			let copied_chunk_on_size = self.chunk_output_at_size;
			let copied_service_name = self.service_name;
			let copied_slowloris_timeout = self.slowloris_timeout;
			let copied_server_id = self.id;
			let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
			let copied_slowdown = self.cat_dev_slowdown;
			#[cfg(debug_assertions)]
			let trace_io = self.trace_during_debug;

			TaskBuilder::new()
				.name("cat_dev::net::tcp_server::bind().connection.handle")
				.spawn(async move {
					if let Err(cause) = Self::handle_tcp_connection(
						cloned_begin_handler,
						cloned_end_handler,
						cloned_nagle_guard,
						copied_slowloris_timeout,
						cloned_handler,
						stream,
						client_address,
						copied_pre_nagle_hook,
						copied_post_nagle_hook,
						copied_chunk_on_size,
						cloned_state,
						copied_server_id,
						stream_id,
						copied_slowdown,
						#[cfg(debug_assertions)]
						trace_io,
					)
					.instrument(error_span!(
						"CatDevTCPServerAccept",
						client.address = %client_address,
						server.address = %loggable_address,
						server.service = copied_service_name,
						server.stream_id = stream_id,
					))
					.await
					{
						warn!(
							?cause,
							client.address = %client_address,
							server.address = %loggable_address,
							server.service = %copied_service_name,
							"Error escaped while handling TCP connection.",
						);
					}
				})
				.map_err(CatBridgeError::SpawnFailure)?;
		}
	}

	/// Attempt to handle a TCP connection, or "stream".
	///
	/// This will be called immediately from the task that spawns on
	/// connections. It will automatically spin down "shortly" after
	/// the connection is gone. We basically wait for an error back from the OS.
	///
	/// ## Errors
	///
	/// If reading, or writing to the underlying TCP Stream runs into any errors.
	/// Individual packet errors do not necessarily correlate to a closed
	/// connection. However, most services will link the two together.
	#[allow(
		// all of our parameters are very well named, and types are not close to
		// overlapping with each other.
		//
		// we also just fundamentally have a lot of state thanks to the complexity
		// of all the things we have to handle for a TCP connection, e.g. NAGLE,
		// delimiters, caches, etc.
		//
		// it is also only ever called internally, so it's not part of our
		// public facing api.
		clippy::too_many_arguments,
	)]
	async fn handle_tcp_connection(
		on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
		on_stream_end_handler: Option<UnderlyingOnStreamEndService<State>>,
		nagle_guard: NagleGuard,
		slowloris_timeout: Duration,
		handler: BoxCloneService<Request<State>, Response, Infallible>,
		mut tcp_stream: TcpStream,
		client_address: SocketAddr,
		pre_hook_cloned: Option<&'static dyn PreNagleFnTy>,
		post_hook_cloned: Option<&'static dyn PostNagleFnTy>,
		chunk_output_at_size: Option<usize>,
		state: State,
		server_id: u64,
		stream_id: u64,
		cat_dev_slowdown: Option<Duration>,
		#[cfg(debug_assertions)] trace_io: bool,
	) -> Result<(), CatBridgeError> {
		let (mut send_responses, mut packets_left_to_send) =
			bounded_channel::<ResponseStreamMessage>(128);

		// Run connection handlers and let them tell us if they want to allow this
		// stream setup to continue.
		if Self::initialize_stream(
			on_stream_begin_handler,
			&mut send_responses,
			&client_address,
			&state,
			&mut tcp_stream,
			server_id,
			stream_id,
		)
		.await?
		{
			return Ok(());
		}

		// Connection has been "approved", so set up the on disconnect handler.
		let _guard = on_stream_end_handler.map(|service| {
			DisconnectAsyncDropServer::new(service, state.clone(), client_address, stream_id)
		});

		// Any previously saved data that was a victim of NAGLE's algorithm, or
		// similar.
		let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;

		loop {
			let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
			tokio::select! {
				received = packets_left_to_send.recv() => {
					if Self::handle_server_write_to_connection(
						&mut tcp_stream,
						chunk_output_at_size,
						received,
						post_hook_cloned,
						stream_id,
						cat_dev_slowdown,
						#[cfg(debug_assertions)] trace_io,
					).await? {
						break;
					}
				}
				res_size = tcp_stream.read_buf(&mut buff) => {
					let size = res_size.map_err(NetworkError::IO)?;
					buff.truncate(size);
					if buff.is_empty() {
						continue;
					}

					let (should_break, returned_stream) = Self::handle_server_read_from_connection(
						tcp_stream,
						buff,
						send_responses.clone(),
						&nagle_guard,
						slowloris_timeout,
						handler.clone(),
						&mut nagle_cache,
						client_address,
						pre_hook_cloned,
						state.clone(),
						stream_id,
						#[cfg(debug_assertions)] trace_io,
					).await?;
					tcp_stream = returned_stream;
					if should_break {
						break;
					}
				}
			}
		}

		OUT_OF_BAND_SENDERS
			.remove_async(&(server_id, stream_id))
			.await;
		packets_left_to_send.close();
		std::mem::drop(tcp_stream.shutdown().await);

		Ok(())
	}

	/// Register the out-of-band sender for this stream, and run the stream
	/// begin hook (if any). Returns `Ok(true)` if the hook rejected the
	/// stream, and it should be closed.
	async fn initialize_stream(
		on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
		send_channel: &mut BoundedSender<ResponseStreamMessage>,
		source_address: &SocketAddr,
		state: &State,
		tcp_stream: &mut TcpStream,
		server_id: u64,
		stream_id: u64,
	) -> Result<bool, CatBridgeError> {
		tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;
		OUT_OF_BAND_SENDERS
			.upsert_async((server_id, stream_id), send_channel.clone())
			.await;

		if let Some(mut handle) = on_stream_begin_handler
			&& !handle
				.call(ResponseStreamEvent::new_with_state(
					send_channel.clone(),
					*source_address,
					Some(stream_id),
					state.clone(),
				))
				.await?
		{
			trace!("handler failed on stream begin hook");
			return Ok(true);
		}

		Ok(false)
	}

	/// Write a single queued message out to the client, chunking and
	/// post-processing it as configured. Returns `Ok(true)` if the
	/// connection should be closed afterwards.
	async fn handle_server_write_to_connection(
		tcp_stream: &mut TcpStream,
		chunk_output_on_size: Option<usize>,
		to_send_to_client_opt: Option<ResponseStreamMessage>,
		post_hook: Option<&'static dyn PostNagleFnTy>,
		stream_id: u64,
		cat_dev_slowdown: Option<Duration>,
		#[cfg(debug_assertions)] trace_io: bool,
	) -> Result<bool, CatBridgeError> {
		let Some(to_send_to_client) = to_send_to_client_opt else {
			return Ok(false);
		};

		match to_send_to_client {
			ResponseStreamMessage::Disconnect => {
				debug!("stream-disconnect-message");
				Ok(true)
			}
			ResponseStreamMessage::Response(resp) => {
				if let Some(body) = resp.body()
					&& !body.is_empty()
				{
					let messages = if let Some(size) = chunk_output_on_size {
						body.chunks(size)
							.map(Bytes::copy_from_slice)
							.collect::<Vec<_>>()
					} else {
						vec![body.clone()]
					};

					for message in messages {
						#[cfg(debug_assertions)]
						if trace_io {
							debug!(
								body.hex = format!("{message:02x?}"),
								body.str = String::from_utf8_lossy(&message).to_string(),
								"cat-dev-trace-output-tcp-server",
							);
						}

						let mut full_response = message.clone();
						if let Some(post) = post_hook {
							full_response = block_in_place(|| post(stream_id, full_response));
						}
						if let Some(slowdown) = cat_dev_slowdown {
							sleep(slowdown).await;
						}

						tcp_stream.writable().await.map_err(NetworkError::IO)?;
						tcp_stream
							.write_all(&full_response)
							.await
							.map_err(NetworkError::IO)?;
					}
				}

				if resp.request_connection_close() {
					trace!("response-requested-connection-close");
					Ok(true)
				} else {
					Ok(false)
				}
			}
		}
	}

	/// Recover any complete packets from a read, handle them, and cache any
	/// leftover bytes for the next read. Returns `(should_close, stream)`.
	#[allow(
		// All of our types are very differently typed, and well named, so the
		// chance of confusion is low.
		//
		// Not to mention this is an internal only method.
		clippy::too_many_arguments,
	)]
	async fn handle_server_read_from_connection<'data>(
		mut stream: TcpStream,
		mut buff: BytesMut,
		channel: BoundedSender<ResponseStreamMessage>,
		nagle_guard: &'data NagleGuard,
		slowloris_timeout: Duration,
		mut handler: BoxCloneService<Request<State>, Response, Infallible>,
		nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
		client_address: SocketAddr,
		cloned_pre_nagle: Option<&'static dyn PreNagleFnTy>,
		state: State,
		stream_id: u64,
		#[cfg(debug_assertions)] trace_io: bool,
	) -> Result<(bool, TcpStream), CatBridgeError> {
		if let Some(convert_fn) = cloned_pre_nagle {
			block_in_place(|| {
				(*convert_fn)(stream_id, &mut buff);
			});
		}

		#[cfg(debug_assertions)]
		{
			if trace_io {
				debug!(
					body.hex = format!("{:02x?}", buff),
					body.str = String::from_utf8_lossy(&buff).to_string(),
					"cat-dev-trace-input-tcp-server",
				);
			}
		}

		// We may be NAGLE'd, so we need to recover/split, and potentially buffer
		// any packets. Also watch out for slowloris-esque attacks.
		let start_time = now();
		if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
			// If we can't calculate the duration (it's negative, or no duration
			// has passed yet), just treat it as 0.
			let total_duration = start_time
				.duration_since(old_start_time)
				.unwrap_or(Duration::from_secs(0));
			if total_duration > slowloris_timeout {
				debug!(
					cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
					"slowloris-detected",
				);
				return Ok((true, stream));
			}

			existing_buff.extend(buff.freeze());
			buff = existing_buff;
		}

		while let Some((start_of_packet, end_of_packet)) = nagle_guard.split(&buff)? {
			let remaining_buff = buff.split_off(end_of_packet);
			let _start_of_buff = buff.split_to(start_of_packet);
			let req_body = buff.freeze();
			buff = remaining_buff;

			let lockable_stream = Arc::new(Mutex::new(Some((Some(buff), stream))));
			let mut request_object = Request::new_with_state_and_stream(
				req_body,
				client_address,
				state.clone(),
				Some(stream_id),
				lockable_stream.clone(),
			);
			request_object.extensions_mut().insert(channel.clone());
			if let Err(cause) = match handler.call(request_object).await {
				Ok(ref resp) => {
					channel
						.send(ResponseStreamMessage::Response(resp.clone()))
						.await
				}
				Err(cause) => {
					warn!(?cause, "request handler failed, will close connection.");
					channel.send(ResponseStreamMessage::Disconnect).await
				}
			} {
				warn!(
					?cause,
					"internal queue failure; will not send disconnect/response."
				);
			}

			{
				let mut done_lock = lockable_stream.lock().await;
				if let Some((newer_buff, strm)) = done_lock.take() {
					if let Some(newest_buff) = newer_buff {
						buff = newest_buff;
					} else {
						return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
					}
					stream = strm;
				} else {
					return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
				}
			}
		}

		if !buff.is_empty() {
			_ = nagle_cache.insert((buff, start_time));
		}

		Ok((false, stream))
	}
}

impl<State: Clone + Debug + Send + Sync + 'static> Debug for TCPServer<State> {
	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
		let mut dbg_struct = fmt.debug_struct("TCPServer");
		dbg_struct
			.field(
				"address_to_bind_or_connect_to",
				&self.address_to_bind_or_connect_to,
			)
			.field("cat_dev_slowdown", &self.cat_dev_slowdown)
			.field("chunk_output_at_size", &self.chunk_output_at_size)
			.field("id", &self.id)
			.field("initial_service", &self.initial_service)
			.field("nagle_guard", &self.nagle_guard)
			.field("on_stream_begin", &self.on_stream_begin)
			.field("on_stream_end", &self.on_stream_end)
			.field("has_pre_nagle_hook", &self.pre_nagle_hook.is_some())
			.field("has_post_nagle_hook", &self.post_nagle_hook.is_some())
			.field("service_name", &self.service_name)
			.field("slowloris_timeout", &self.slowloris_timeout)
			.field("state", &self.state);

		#[cfg(debug_assertions)]
		{
			dbg_struct.field("trace_during_debug", &self.trace_during_debug);
		}

		dbg_struct.finish()
	}
}

const TCP_SERVER_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("address_to_bind_or_connect_to"),
	NamedField::new("cat_dev_slowdown"),
	NamedField::new("chunk_output_at_size"),
	NamedField::new("initial_service"),
	NamedField::new("nagle_guard"),
	NamedField::new("on_stream_begin"),
	NamedField::new("on_stream_end"),
	NamedField::new("has_pre_nagle_hook"),
	NamedField::new("has_post_nagle_hook"),
	NamedField::new("service_name"),
	NamedField::new("slowloris_timeout"),
	NamedField::new("state"),
	#[cfg(debug_assertions)]
	NamedField::new("trace_during_debug"),
];

impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Structable for TCPServer<State> {
	fn definition(&self) -> StructDef<'_> {
		StructDef::new_static("TcpServer", Fields::Named(TCP_SERVER_FIELDS))
	}
}

impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Valuable for TCPServer<State> {
	fn as_value(&self) -> Value<'_> {
		Value::Structable(self)
	}

	fn visit(&self, visitor: &mut dyn Visit) {
		visitor.visit_named_fields(&NamedValues::new(
			TCP_SERVER_FIELDS,
			&[
				Valuable::as_value(&format!("{}", self.address_to_bind_or_connect_to)),
				Valuable::as_value(&if let Some(slowdown) = self.cat_dev_slowdown {
					format!("{}ms", slowdown.as_millis())
				} else {
					"<none>".to_string()
				}),
				Valuable::as_value(&self.chunk_output_at_size),
				Valuable::as_value(&format!("{:?}", self.initial_service)),
				Valuable::as_value(&self.nagle_guard),
				Valuable::as_value(&format!("{:?}", self.on_stream_begin)),
				Valuable::as_value(&format!("{:?}", self.on_stream_end)),
				Valuable::as_value(&self.pre_nagle_hook.is_some()),
				Valuable::as_value(&self.post_nagle_hook.is_some()),
				Valuable::as_value(&self.service_name),
				Valuable::as_value(&format!("{:?}", self.slowloris_timeout)),
				Valuable::as_value(&self.state),
				#[cfg(debug_assertions)]
				Valuable::as_value(&self.trace_during_debug),
			],
		));
	}
}

#[cfg(test)]
pub mod test_helpers {
	use super::*;
	use std::net::{Ipv4Addr, SocketAddrV4};

	/// Get a free TCP Port for IPv4.
	///
	/// This will attempt to get a free tcp port on IPv4 by actually binding to
	/// port `0`, which should make the OS automatically assign a free unused
	/// port.
	pub async fn get_free_tcp_v4_port() -> Option<u16> {
		let addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
		if let Ok(bound) = TcpListener::bind(addr).await {
			if let Ok(local) = bound.local_addr() {
				return Some(local.port());
			}
		}
		None
	}
}

#[cfg(test)]
mod unit_tests {
	use super::*;
	use crate::net::{
		CURRENT_TIME,
		server::{Router, requestable::Extension, test_helpers::*},
	};
	use bytes::Bytes;
	use std::{
		net::{Ipv4Addr, SocketAddrV4},
		sync::{
			Arc, Mutex,
			atomic::{AtomicU8, Ordering},
		},
		time::Duration,
	};
	use tokio::time::timeout;

	fn set_now(new_time: SystemTime) {
		CURRENT_TIME.with(|time_lazy| {
			*time_lazy.write().expect("RwLock is poisoned?") = new_time;
		})
	}

	#[tokio::test]
	pub async fn test_full_server() {
		let connected_fired = Arc::new(Mutex::new(false));
		let on_disconnect_fired = Arc::new(Mutex::new(false));
		let request_fired = Arc::new(Mutex::new(false));

		async fn on_connection(
			Extension(connected): Extension<Arc<Mutex<bool>>>,
		) -> Result<bool, CatBridgeError> {
			let mut locked = connected
				.lock()
				.expect("Failed to lock connected fired extension");
			*locked = true;
			Ok(true)
		}
		async fn on_disconnect(
			Extension(disconnected): Extension<Arc<Mutex<bool>>>,
		) -> Result<(), CatBridgeError> {
			let mut locked = disconnected
				.lock()
				.expect("Failed to lock disconnected fired extension");
			*locked = true;
			Ok(())
		}
		async fn on_request(
			Extension(request): Extension<Arc<Mutex<bool>>>,
		) -> Result<Response, CatBridgeError> {
			let mut locked = request
				.lock()
				.expect("Failed to lock request fired extension");
			*locked = true;

			let mut resp = Response::new_with_body(Bytes::from(vec![0x1]));
			resp.should_close_connection();
			Ok(resp)
		}

		let mut router = Router::new();
		router
			.add_route(&[0x1, 0x2, 0x3], on_request)
			.expect("Failed to add a route!");
		router.layer(Extension(request_fired.clone()));

		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
			.await
			.expect("Timed out trying to find free port!")
			.expect("Failed to find free TCP port on system.");

		let mut srv = timeout(
			Duration::from_secs(5),
			TCPServer::new_with_state(
				"test",
				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				router,
				(None, None),
				NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
				(),
				#[cfg(debug_assertions)]
				true,
			),
		)
		.await
		.expect("Timed out starting server")
		.expect("Failed to create TCP Server.");

		srv.set_on_stream_begin(on_connection)
			.expect("Failed to register stream begin handler!");
		srv.layer_on_stream_begin(Extension(connected_fired.clone()))
			.expect("Failed to add layer to on stream begin!");
		srv.set_on_stream_end(on_disconnect)
			.expect("Failed to register stream end handler!");
		srv.layer_on_stream_end(Extension(on_disconnect_fired.clone()))
			.expect("Failed to add layer to on_disconnect!");

		let spawned =
			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
		{
			loop {
				let client_stream_res = timeout(
					Duration::from_secs(10),
					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				)
				.await
				.expect("Service timed out waiting for connection!");
				// Server hasn't bound yet!
				if client_stream_res.is_err() {
					continue;
				}
				let mut client_stream = client_stream_res.unwrap();
				client_stream
					.write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
					.await
					.expect("Failed to write to client stream");
				// Wait until we get a response.
				let mut buff = [0_u8; 1];
				timeout(Duration::from_secs(5), client_stream.read(&mut buff))
					.await
					.expect("Timed out reading from client stream")
					.expect("Failed to read data from client stream");
				timeout(Duration::from_secs(5), client_stream.shutdown())
					.await
					.expect("Timed out shutting down client stream")
					.expect("Failed to shutdown client stream.");
				break;
			}
		}
		// Drop our handle to the server task (this detaches it; the runtime
		// tears it down when the test ends).
		std::mem::drop(spawned);

		let locked_connect = connected_fired
			.lock()
			.expect("Failed to lock second connect");
		let locked_disconnect = on_disconnect_fired
			.lock()
			.expect("Failed to lock second on_disconnect");
		let locked_request = request_fired.lock().expect("Failed to lock second request");

		assert!(*locked_connect, "on connection handler never fired!");
		assert!(*locked_disconnect, "on disconnect handler never fired!");
		assert!(*locked_request, "on request handler never fired!");
	}

	#[tokio::test]
	pub async fn test_nagled() {
		let requests_fired = Arc::new(AtomicU8::new(0));

		async fn on_request(
			Extension(request): Extension<Arc<AtomicU8>>,
		) -> Result<Response, CatBridgeError> {
			request.fetch_add(1, Ordering::SeqCst);
			let resp = Response::new_with_body(Bytes::from(vec![0x1]));
			Ok(resp)
		}

		let mut router = Router::new();
		router
			.add_route(&[0x1, 0x2, 0x3], on_request)
			.expect("Failed to add a route!");
		router.layer(Extension(requests_fired.clone()));

		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
			.await
			.expect("Timed out finding port to bind to!")
			.expect("Failed to find any free tcp v4 port on system!");
		let srv = timeout(
			Duration::from_secs(5),
			TCPServer::new_with_state(
				"test",
				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				router,
				(None, None),
				NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
				(),
				#[cfg(debug_assertions)]
				true,
			),
		)
		.await
		.expect("Timed out starting TCP Server for test")
		.expect("Failed to create local tcp server!");

		let spawned =
			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
		{
			loop {
				let client_stream_res = timeout(
					Duration::from_secs(10),
					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				)
				.await
				.expect("Service timed out waiting for connection!");
				// Server hasn't bound yet!
				if client_stream_res.is_err() {
					continue;
				}
				let mut client_stream = client_stream_res.unwrap();

				client_stream
					.write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF])
					.await
					.expect("Failed to write to client_stream");
				client_stream
					.flush()
					.await
					.expect("Failed to flush client_stream");
				client_stream
					.write_all(&[0xFF, 0xFF, 0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
					.await
					.expect("Failed to issue second write call to client_stream");
				let mut buff = [0_u8; 2];
				let read_bytes = timeout(Duration::from_secs(5), client_stream.read(&mut buff))
					.await
					.expect("Timed out reading from client_stream")
					.expect("Failed to read from client_stream!");
				if read_bytes == 1 {
					timeout(Duration::from_secs(5), client_stream.read(&mut buff[1..]))
						.await
						.expect("Timed out reading from client_stream")
						.expect("Failed to read from client_stream!");
				}

				timeout(Duration::from_secs(5), client_stream.shutdown())
					.await
					.expect("Timed out shutting down client stream")
					.expect("Failed to shutdown client stream.");
				break;
			}
		}
		// Drop our handle to the server task.
		std::mem::drop(spawned);

		assert_eq!(
			requests_fired.load(Ordering::SeqCst),
			2,
			"on request did not fire the correct amount of times!",
		);
	}

	#[tokio::test]
	pub async fn test_slowloris_blocking() {
		let mut router = Router::new();
		router
			.add_route(&[0x1, 0x2, 0x3], || async {
				Ok(Response::new_with_body(Bytes::from(vec![0x1])))
			})
			.expect("Failed to add a route!");

		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
			.await
			.expect("Timed out finding port to bind to!")
			.expect("Failed to find any free tcp v4 port on system!");
		let srv = timeout(
			Duration::from_secs(5),
			TCPServer::new_with_state(
				"test",
				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				router,
				(None, None),
				NagleGuard::EndSigilSearch(&[0x10, 0x11, 0x12]),
				(),
				#[cfg(debug_assertions)]
				true,
			),
		)
		.await
		.expect("Timed out starting TCP Server for test")
		.expect("Failed to create local tcp server!");

		let spawned =
			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
		let read_bytes;
		{
			loop {
				let client_stream_res = timeout(
					Duration::from_secs(10),
					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				)
				.await
				.expect("Service timed out waiting for connection!");
				// Server hasn't bound yet!
				if client_stream_res.is_err() {
					continue;
				}
				let mut client_stream = client_stream_res.unwrap();

				client_stream
					.write_all(&[0x1, 0x2, 0x3, 0x10])
					.await
					.expect("Failed to write to client_stream");
				client_stream
					.flush()
					.await
					.expect("Failed to flush client_stream");
				// This ensures the server running in the background has for sure
				// cached the time we started processing.
				tokio::time::sleep(Duration::from_secs(5)).await;
				set_now(
					SystemTime::now()
						.checked_add(Duration::from_secs(900_00_000))
						.expect("Failed to add time to systemtime"),
				);
				// We do properly finish off the packet, but it's too late...
				client_stream
					.write_all(&[0x11, 0x12])
					.await
					.expect("Failed to write to client_stream");

				let mut buff = [0_u8; 1];
				// This read should return 0 bytes as the server closes our
				// connection, so we shouldn't hit the timeout.
				read_bytes = timeout(Duration::from_secs(10), client_stream.read(&mut buff))
					.await
					.expect("timed out trying to wait for disconnect")
					.expect("failure reading from stream");
				break;
			}
		}
		std::mem::drop(spawned);

		assert_eq!(read_bytes, 0, "Client didn't error on slowloris'd packet");
	}
}