// cat_dev/net/server/tcp.rs
//! Utilities for serving TCP, the layer 4 protocol.
//!
//! Note this "TCP Server" class allows both binding to a particular port, as
//! well as "connect"'ing to another host. This is mostly because we have to
//! support SDIO, where the side that "connect"'s acts as the server.
//!
//! ## Notes About Layering
//!
//! When layering things on top of a service, it's important to keep
//! performance characteristics in mind. Every single individual `layer` call
//! ***may add significant overhead***. As such, it is best to use
//! [`tower::ServiceBuilder`] to lower the cost of heavy layering.
//!
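//! The layering advice above can be sketched roughly like this (an
//! illustrative fragment, not part of this module; `TimeoutLayer` and
//! `LogLayer` are hypothetical layers):
//!
//! ```ignore
//! use tower::ServiceBuilder;
//!
//! // One builder chain composes all of the layers up front, instead of
//! // wrapping the service piecemeal with repeated individual `.layer()`
//! // calls on the finished service.
//! let service = ServiceBuilder::new()
//!     .layer(TimeoutLayer::new(Duration::from_secs(5)))
//!     .layer(LogLayer::new())
//!     .service(inner_service);
//! ```
//!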
//! ## TCP and NAGLE'ing
//!
//! TCP, the protocol, is fundamentally built on top of a 'stream', and is
//! not packet oriented. We try to guard against this by setting no-delay on
//! multiple parts of our infrastructure, but there is a whole world out
//! there that may not do what we want.
//!
//! As a result, we need to guard against weird splits of data across read
//! calls. A single read call isn't necessarily going to read a full
//! "packet", and may even read multiple packets.
//!
//! I generally refer to this in the source code as "NAGLE" protections, as
//! Nagle's algorithm is usually to blame for the weird behaviors here, where
//! a device will combine multiple small packets together.
//!
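//! As a sketch of the problem, here is extraction of whole frames from a
//! buffer that may hold a partial packet, one packet, or several at once
//! (illustrative only; this assumes a simple one-byte length prefix, which
//! is not this module's actual framing):
//!
//! ```
//! fn extract_frames(buf: &mut Vec<u8>) -> Vec<Vec<u8>> {
//!     let mut frames = Vec::new();
//!     // Keep pulling frames while a full length prefix + payload is buffered.
//!     while let Some(&len) = buf.first() {
//!         let total = 1 + len as usize;
//!         if buf.len() < total {
//!             break; // Partial packet: leave it in the "nagle cache".
//!         }
//!         frames.push(buf[1..total].to_vec());
//!         buf.drain(..total);
//!     }
//!     frames
//! }
//!
//! // Two whole packets plus a partial third arrive in one read call.
//! let mut cache = vec![2, b'h', b'i', 1, b'!', 3, b'x'];
//! let frames = extract_frames(&mut cache);
//! assert_eq!(frames, vec![b"hi".to_vec(), b"!".to_vec()]);
//! assert_eq!(cache, vec![3, b'x']); // Left over for the next read.
//! ```
//!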
//! ## Notes About Concurrency
//!
//! This TCP server unfortunately has to make the sacrifice of processing one
//! packet per stream at a time. You can have as many TCP streams as you
//! want, and we should be able to handle many at the same time! But the
//! ordered nature of TCP, along with some protocol designs implemented by
//! Nintendo, means this server must also ensure that we process one packet
//! per TCP stream at a time.
//!
//! Most notably this comes from the fact that our file servers will
//! consistently break their normal "NAGLE" protection, and we have to do
//! raw reads of N bytes from the stream (in both directions) _BEFORE_
//! processing another request.
//!
//! ## SLOW-loris
//!
//! As mentioned, TCP is built on top of a stream, so there is a chance that
//! someone can just "slow drip" us a packet, bit by bit, eating up an open
//! connection slot for a long period of time. This used to be a lot more
//! damaging when you'd allocate a thread for every single connection.
//!
//! For us, because we _don't_ allocate a thread for every single connection,
//! the cost of any single connection is low. Not to mention we don't really
//! have to deal with untrusted DOS connections in a lot of cases.
//!
//! That being said, there's no reason to allow someone to slow-drip us a
//! packet forever, so our servers will inherently time out a packet that is
//! taking too long, closing the connection.
//!
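//! The timeout check can be sketched like so (illustrative only; the real
//! server tracks this per stream against its own `slowloris_timeout`):
//!
//! ```
//! use std::time::{Duration, Instant};
//!
//! // The instant at which we started receiving the current packet.
//! let packet_started = Instant::now();
//! let slowloris_timeout = Duration::from_secs(10);
//!
//! // On every read, check how long this single packet has been in flight;
//! // if it exceeds the timeout we'd error out and close the stream.
//! let timed_out = packet_started.elapsed() > slowloris_timeout;
//! assert!(!timed_out);
//! ```
//!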
//! ## Lifecycle of a Packet
//!
//! ```ascii
//!     ┌─────────┐                        ┌─────────┐
//!     │TCPClient│                        │TCPServer│
//!     └─────────┘                        └─────────┘
//!          │Sends Arbitrary Sequence of Bytes │
//!          │─────────────────────────────────>│
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Call PreNagleFN (if one exists).
//!          │                                  │    │ This is a custom function that handles things like:
//!          │                                  │    │
//!          │                                  │    │   1. Decrypting Packets.
//!          │                                  │    │   2. Logging info not through trace.
//!          │                                  │    │
//!          │                                  │    │ Generally should not be used for anything else, as it
//!          │                                  │    │ happens before ALL processing.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Check how long this client has been attempting to
//!          │                                  │    │ send us a single packet.
//!          │                                  │    │
//!          │                                  │    │ Compare this with `SLOWLORIS_TIMEOUT`,
//!          │                                  │    │ and error/close the stream if it's taken too long to
//!          │                                  │    │ send us a packet.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Look at the "nagle cache": any data that was previously sent
//!          │                                  │    │ on this TCP stream, but has not yet been processed.
//!          │                                  │    │
//!          │                                  │    │ Create the final `buff`.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Look at the "nagle guard"; this determines when one packet begins/ends.
//!          │                                  │    │ Process any # of packets that may have come in.
//!          │                                  │    │
//!          │                                  │    │ If any data is left over, put it in the "nagle cache".
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Process Packet.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ If a reply was sent, and wasn't empty...
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ Look at "chunked output size", and chunk the packet if necessary.
//!          │                                  │<───┘
//!          │                                  │
//!          │                                  │────┐
//!          │                                  │    │ For each packet apply PostNagleFN (if one exists).
//!          │                                  │    │ This is a custom function that handles things like:
//!          │                                  │    │
//!          │                                  │    │   1. Encrypting Packets.
//!          │                                  │    │   2. Logging info not through trace.
//!          │                                  │    │   3. Sleeping to not overwhelm a CAT-DEV.
//!          │                                  │<───┘
//!          │                                  │
//!          │      Data for the Client.        │
//!          │<─────────────────────────────────│
//!     ┌─────────┐                        ┌─────────┐
//!     │TCPClient│                        │TCPServer│
//!     └─────────┘                        └─────────┘
//! ```

use crate::{
	errors::{CatBridgeError, NetworkError},
	net::{
		DEFAULT_SLOWLORIS_TIMEOUT, SERVER_ID, STREAM_ID, TCP_READ_BUFFER_SIZE,
		errors::{CommonNetAPIError, CommonNetNetworkError},
		handlers::{
			OnResponseStreamBeginHandler, OnResponseStreamEndHandler,
			OnStreamBeginHandlerAsService, OnStreamEndHandlerAsService,
		},
		models::{NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
		now,
		server::models::{
			DisconnectAsyncDropServer, ResponseStreamEvent, ResponseStreamMessage,
			UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
		},
	},
};
use bytes::{Bytes, BytesMut};
use fnv::FnvHashSet;
use futures::future::join_all;
use scc::HashMap as ConcurrentMap;
use std::{
	convert::Infallible,
	fmt::{Debug, Formatter, Result as FmtResult},
	net::{IpAddr, SocketAddr},
	sync::{Arc, LazyLock, atomic::Ordering},
	time::{Duration, SystemTime},
};
use tokio::{
	io::{AsyncReadExt, AsyncWriteExt},
	net::{TcpListener, TcpStream, ToSocketAddrs, lookup_host},
	sync::{
		Mutex,
		mpsc::{Sender as BoundedSender, channel as bounded_channel},
	},
	task::{Builder as TaskBuilder, block_in_place},
	time::sleep,
};
use tower::{Layer, Service, util::BoxCloneService};
use tracing::{Instrument, debug, error_span, trace, warn};
use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};

#[cfg(debug_assertions)]
use crate::net::SPRIG_TRACE_IO;

/// A map of streams, and the channels used to queue a response packet out on
/// them, in a completely out-of-band way.
static OUT_OF_BAND_SENDERS: LazyLock<
	ConcurrentMap<(u64, u64), BoundedSender<ResponseStreamMessage>>,
> = LazyLock::new(ConcurrentMap::new);

/// An implementation of a TCP server.
///
/// This is the "base" TCP Server class that all of the cat-dev servers are
/// built on. It has extensions for handling any potential type of server
/// that we will ever need to support.
///
/// This server accepts lots of parameters, so look at the module's
/// documentation for general info about how processing works.
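///
/// A construction sketch (illustrative only; `my_service` and `my_guard` are
/// hypothetical values you'd build for your protocol):
///
/// ```ignore
/// let server = TCPServer::new(
///     "example-service",   // Service name attached to logs.
///     "127.0.0.1:7000",    // Address to bind (or connect) to.
///     my_service,          // The tower `Service` handling packets.
///     (None, None),        // No pre/post nagle hooks.
///     my_guard,            // Something `Into<NagleGuard>`.
///     false,               // Don't trace I/O during debug.
/// )
/// .await?;
/// server.bind().await?;
/// ```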
pub struct TCPServer<State: Clone + Send + Sync + 'static = ()> {
	/// The address that the server will bind to, when [`bind`]
	/// is called. It may also be a server address to connect to when
	/// we are a server who connects to our client.
	///
	/// Currently our server can only bind to one very specific address, and
	/// does not support binding to multiple interfaces besides the special
	/// `0.0.0.0` addresses that get handled by the OS.
	address_to_bind_or_connect_to: SocketAddr,
	/// Cat-devs need load-bearing sleeps, as they can "ACK" a packet, but
	/// throw away the bytes and pretend it never got applied.
	///
	/// This is used for sleeping before sends. By default this isn't set;
	/// cat-dev services will call: [`TCPServer::set_cat_dev_slowdown`].
	cat_dev_slowdown: Option<Duration>,
	/// For devices that can't receive too much data at once.
	///
	/// When devices *cough* MION *cough* can't receive too much data at once,
	/// we need to chunk it, and rest between those chunks. This will chunk those
	/// packets for us.
	chunk_output_at_size: Option<usize>,
	/// The ID of this particular server.
	///
	/// Used to do out of band sends/broadcasts to this particular server.
	id: u64,
	/// The [`Service`] that gets called when a packet is received.
	///
	/// This will almost always be some sort of router that knows how to route
	/// packets to handlers. However, it could just be a single tower service.
	initial_service: BoxCloneService<Request<State>, Response, Infallible>,
	/// Determines when a packet starts, and ends.
	nagle_guard: NagleGuard,
	/// A tower service to call when a particular stream starts.
	///
	/// This is effectively like an "on connect" hook that you can use to
	/// call functions.
	on_stream_begin: Option<UnderlyingOnStreamBeginService<State>>,
	/// A tower service to call when a particular stream ends.
	///
	/// This is effectively like an "on disconnect" hook that you can use
	/// to call functions.
	on_stream_end: Option<UnderlyingOnStreamEndService<State>>,
	/// A function to apply some sort of processing before doing any
	/// NAGLE/SLOWLORIS logic.
	///
	/// This is best used for encryption/decryption that needs to be handled
	/// before anything else.
	pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
	/// A function to apply some sort of processing before sending a packet
	/// out.
	///
	/// This is the very last thing before `send` is actually called. This is
	/// best used for encryption/decryption that needs to be processed only
	/// at the end.
	post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
	/// The name of the service being provided by this server, to attach to logs.
	service_name: &'static str,
	/// The "Slowloris Detection" timeout.
	slowloris_timeout: Duration,
	/// Type-safe server wide state.
	///
	/// This is mostly used by our servers to attach state that is meant to be
	/// server wide, such as a `HostFilesystem` type to abstract out all the
	/// filesystem stuff that many of our servers have to do.
	state: State,
	/// If we should log all packet requests/responses when compiled with debug
	/// assertions.
	#[cfg(debug_assertions)]
	trace_during_debug: bool,
}

impl TCPServer<()> {
	/// Create a new TCP Server without having to specify any particular state.
	///
	/// ## Errors
	///
	/// If we cannot look up one specific address to bind to.
	pub async fn new<AddrTy, ServiceTy>(
		service_name: &'static str,
		bind_addr: AddrTy,
		initial_service: ServiceTy,
		nagle_hooks: (
			Option<&'static dyn PreNagleFnTy>,
			Option<&'static dyn PostNagleFnTy>,
		),
		guard: impl Into<NagleGuard>,
		trace_io_during_debug: bool,
	) -> Result<Self, CommonNetAPIError>
	where
		AddrTy: ToSocketAddrs,
		ServiceTy:
			Clone + Send + Service<Request<()>, Response = Response, Error = Infallible> + 'static,
		ServiceTy::Future: Send + 'static,
	{
		Self::new_with_state(
			service_name,
			bind_addr,
			initial_service,
			nagle_hooks,
			guard,
			(),
			trace_io_during_debug,
		)
		.await
	}
}

impl<State: Clone + Send + Sync + 'static> TCPServer<State> {
	/// Send a message completely out of band to a particular open stream.
	///
	/// ## Errors
	///
	/// - If the stream is not actively open.
	/// - If we could not actually queue a packet to go out.
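	///
	/// A usage sketch (illustrative only; assumes you already hold a valid
	/// `server_id`/`stream_id` pair and a built `message`):
	///
	/// ```ignore
	/// TCPServer::<()>::out_of_bound_send(server_id, stream_id, message).await?;
	/// ```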
	pub async fn out_of_bound_send(
		server_id: u64,
		stream_id: u64,
		message: ResponseStreamMessage,
	) -> Result<(), CatBridgeError> {
		if let Some(stream) = OUT_OF_BAND_SENDERS.get_async(&(server_id, stream_id)).await {
			stream
				.send(message)
				.await
				.map_err(NetworkError::SendQueueMessageFailure)?;
			Ok(())
		} else {
			Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
		}
	}

	/// Send a message completely out of band to all open streams.
	pub async fn out_of_bound_broadcast(
		server_id: u64,
		message: ResponseStreamMessage,
	) -> Vec<Result<(), CatBridgeError>> {
		let mut ids = FnvHashSet::default();
		// Scan all senders at a point in time.
		OUT_OF_BAND_SENDERS
			.iter_async(|key, _value| {
				if key.0 == server_id {
					ids.insert(key.1);
				}
				true
			})
			.await;

		// Now we broadcast to those streams.
		let mut tasks = Vec::with_capacity(ids.len());
		for id in ids {
			tasks.push(Self::out_of_bound_send(server_id, id, message.clone()));
		}

		join_all(tasks).await
	}

	/// Create a new TCP Server along with the state for this server.
	///
	/// ## Errors
	///
	/// If we cannot look up one specific address to bind to.
	#[allow(unused)]
	pub async fn new_with_state<AddrTy, ServiceTy>(
		service_name: &'static str,
		bind_addr: AddrTy,
		initial_service: ServiceTy,
		nagle_hooks: (
			Option<&'static dyn PreNagleFnTy>,
			Option<&'static dyn PostNagleFnTy>,
		),
		guard: impl Into<NagleGuard>,
		state: State,
		trace_io_during_debug: bool,
	) -> Result<Self, CommonNetAPIError>
	where
		AddrTy: ToSocketAddrs,
		ServiceTy: Clone
			+ Send
			+ Service<Request<State>, Response = Response, Error = Infallible>
			+ 'static,
		ServiceTy::Future: Send + 'static,
	{
		let hosts = lookup_host(bind_addr)
			.await
			.map_err(CommonNetAPIError::AddressLookupError)?
			.collect::<Vec<_>>();
		if hosts.len() != 1 {
			return Err(CommonNetAPIError::WrongAmountOfAddressesToBindToo(hosts));
		}

		#[cfg(not(debug_assertions))]
		{
			if trace_io_during_debug {
				warn!(
					"Trace IO was turned on, but debug assertions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
				);
			}
		}

		Ok(Self {
			address_to_bind_or_connect_to: hosts[0],
			cat_dev_slowdown: None,
			chunk_output_at_size: None,
			id: SERVER_ID.fetch_add(1, Ordering::SeqCst),
			initial_service: BoxCloneService::new(initial_service),
			nagle_guard: guard.into(),
			on_stream_begin: None,
			on_stream_end: None,
			pre_nagle_hook: nagle_hooks.0,
			post_nagle_hook: nagle_hooks.1,
			service_name,
			slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
			state,
			#[cfg(debug_assertions)]
			trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
		})
	}

	/// Get the particular ID for this server.
	#[must_use]
	pub const fn id(&self) -> u64 {
		self.id
	}

	/// The IP address we will either bind to, or connect to.
	#[must_use]
	pub const fn ip(&self) -> IpAddr {
		self.address_to_bind_or_connect_to.ip()
	}

	/// Get the port that we're either binding to, or connecting to.
	#[must_use]
	pub const fn port(&self) -> u16 {
		self.address_to_bind_or_connect_to.port()
	}

	/// Set the slowdown to apply before sending bytes from this server.
	pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
		self.cat_dev_slowdown = slowdown;
	}

	#[must_use]
	pub const fn chunk_output_at_size(&self) -> Option<usize> {
		self.chunk_output_at_size
	}

	pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
		self.chunk_output_at_size = new_size;
	}

	#[must_use]
	pub const fn slowloris_timeout(&self) -> Duration {
		self.slowloris_timeout
	}

	pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
		self.slowloris_timeout = slowloris_timeout;
	}

	#[must_use]
	pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<State>> {
		self.on_stream_begin.as_ref()
	}

	/// Set a hook to run when a stream has been created.
	///
	/// This is what happens when a new machine connects, so it may also be
	/// referred to as "on connect". This assumes you're assigning a service
	/// that already exists and is in the raw storage type; otherwise you may
	/// want to look into [`Self::set_on_stream_begin`], or
	/// [`Self::set_on_stream_begin_service`].
	///
	/// ## Errors
	///
	/// If the stream beginning hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_raw_on_stream_begin(
		&mut self,
		on_start: Option<UnderlyingOnStreamBeginService<State>>,
	) -> Result<(), CommonNetAPIError> {
		if self.on_stream_begin.is_some() {
			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
		}

		self.on_stream_begin = on_start;
		Ok(())
	}

	/// Set a function hook to run when a stream has been created.
	///
	/// This is what happens when a new machine connects, so it may also be
	/// referred to as "on connect". This assumes you're assigning a function
	/// to on stream begin; otherwise use [`Self::set_raw_on_stream_begin`],
	/// or [`Self::set_on_stream_begin_service`].
	///
	/// ## Errors
	///
	/// If the stream beginning hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
		&mut self,
		handler: HandlerTy,
	) -> Result<(), CommonNetAPIError>
	where
		HandlerParamsTy: Send + 'static,
		HandlerTy: OnResponseStreamBeginHandler<HandlerParamsTy, State> + Clone + Send + 'static,
	{
		if self.on_stream_begin.is_some() {
			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
		}

		let boxed = BoxCloneService::new(OnStreamBeginHandlerAsService::new(handler));
		self.on_stream_begin = Some(boxed);
		Ok(())
	}

	/// Set a [`Service`] hook to run when a stream has been created.
	///
	/// This is what happens when a new machine connects, so it may also be
	/// referred to as "on connect". This assumes you're assigning a [`Service`]
	/// to on stream begin; otherwise use [`Self::set_raw_on_stream_begin`],
	/// or [`Self::set_on_stream_begin`].
	///
	/// ## Errors
	///
	/// If the stream beginning hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_begin_service<ServiceTy>(
		&mut self,
		service_ty: ServiceTy,
	) -> Result<(), CommonNetAPIError>
	where
		ServiceTy: Clone
			+ Send
			+ Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
			+ 'static,
		ServiceTy::Future: Send + 'static,
	{
		if self.on_stream_begin.is_some() {
			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
		}

		self.on_stream_begin = Some(BoxCloneService::new(service_ty));
		Ok(())
	}

	/// Add a layer to the service to process when a stream begins, or a new
	/// connection is created.
	///
	/// ## Errors
	///
	/// If there is no on stream begin handler that is currently active.
	pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
		&mut self,
		layer: LayerTy,
	) -> Result<(), CommonNetAPIError>
	where
		LayerTy: Layer<UnderlyingOnStreamBeginService<State>, Service = ServiceTy>,
		ServiceTy: Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
			+ Clone
			+ Send
			+ 'static,
		<LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
	{
		let Some(srvc) = self.on_stream_begin.take() else {
			return Err(CommonNetAPIError::OnStreamBeginNotRegistered);
		};

		self.on_stream_begin = Some(BoxCloneService::new(layer.layer(srvc)));
		Ok(())
	}

	#[must_use]
	pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<State>> {
		self.on_stream_end.as_ref()
	}

	/// Set a hook to run when a stream is being destroyed.
	///
	/// This is what happens when a machine disconnects, so it may also be
	/// referred to as "on disconnect". This assumes you're assigning a service
	/// that already exists and is in the raw storage type; otherwise you may
	/// want to look into [`Self::set_on_stream_end`], or
	/// [`Self::set_on_stream_end_service`].
	///
	/// ## Errors
	///
	/// If the stream ending hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_raw_on_stream_end(
		&mut self,
		on_end: Option<UnderlyingOnStreamEndService<State>>,
	) -> Result<(), CommonNetAPIError> {
		if self.on_stream_end.is_some() {
			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
		}

		self.on_stream_end = on_end;
		Ok(())
	}

	/// Set a function hook to run when a stream is being destroyed.
	///
	/// This is what happens when a machine disconnects, so it may also be
	/// referred to as "on disconnect". This assumes you're assigning a function
	/// to on stream end; otherwise use [`Self::set_raw_on_stream_end`],
	/// or [`Self::set_on_stream_end_service`].
	///
	/// ## Errors
	///
	/// If the stream ending hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
		&mut self,
		handler: HandlerTy,
	) -> Result<(), CommonNetAPIError>
	where
		HandlerParamsTy: Send + 'static,
		HandlerTy: OnResponseStreamEndHandler<HandlerParamsTy, State> + Clone + Send + 'static,
	{
		if self.on_stream_end.is_some() {
			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
		}

		let boxed = BoxCloneService::new(OnStreamEndHandlerAsService::new(handler));
		self.on_stream_end = Some(boxed);
		Ok(())
	}

	/// Set a [`Service`] hook to run when a stream is being destroyed.
	///
	/// This is what happens when a machine disconnects, so it may also be
	/// referred to as "on disconnect". This assumes you're assigning a [`Service`]
	/// to on stream end; otherwise use [`Self::set_raw_on_stream_end`],
	/// or [`Self::set_on_stream_end`].
	///
	/// ## Errors
	///
	/// If the stream ending hook has already been registered. If you're
	/// looking to perform multiple actions at once, look into layering.
	pub fn set_on_stream_end_service<ServiceTy>(
		&mut self,
		service_ty: ServiceTy,
	) -> Result<(), CommonNetAPIError>
	where
		ServiceTy: Clone
			+ Send
			+ Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
			+ 'static,
		ServiceTy::Future: Send + 'static,
	{
		if self.on_stream_end.is_some() {
			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
		}

		self.on_stream_end = Some(BoxCloneService::new(service_ty));
		Ok(())
	}

	/// Add a layer to the service to process when a stream ends, or an
	/// existing connection is destroyed.
	///
	/// ## Errors
	///
	/// If there is no on stream end handler that is currently active.
	pub fn layer_on_stream_end<LayerTy, ServiceTy>(
		&mut self,
		layer: LayerTy,
	) -> Result<(), CommonNetAPIError>
	where
		LayerTy: Layer<UnderlyingOnStreamEndService<State>, Service = ServiceTy>,
		ServiceTy: Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
			+ Clone
			+ Send
			+ 'static,
		<LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
	{
		let Some(srvc) = self.on_stream_end.take() else {
			return Err(CommonNetAPIError::OnStreamEndNotRegistered);
		};

		self.on_stream_end = Some(BoxCloneService::new(layer.layer(srvc)));
		Ok(())
	}

	#[must_use]
	pub const fn initial_service(&self) -> &BoxCloneService<Request<State>, Response, Infallible> {
		&self.initial_service
	}

	pub fn layer_initial_service<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
	where
		LayerTy: Layer<BoxCloneService<Request<State>, Response, Infallible>, Service = ServiceTy>,
		ServiceTy: Service<Request<State>, Response = Response, Error = Infallible>
			+ Clone
			+ Send
			+ 'static,
		<LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
	{
		self.initial_service = BoxCloneService::new(layer.layer(self.initial_service.clone()));
	}

	/// Get a reference to the current state of the server.
	#[must_use]
	pub const fn state(&self) -> &State {
		&self.state
	}

	/// "Connect" to our remote client address, and start serving ourselves to
	/// whoever we connect to.
	///
	/// NOTE: this will ***NOT*** return until the service either runs into an
	/// error, or the task is cancelled.
	///
	/// ## Errors
	///
	/// If we cannot connect to the remote source, fetch the local address, or
	/// handle the connection in some other way.
	pub async fn connect(self) -> Result<(), CatBridgeError> {
		loop {
			// Copy the address we're connecting to, to use in log statements later on.
			let client_address = self.address_to_bind_or_connect_to;
			let stream = TcpStream::connect(self.address_to_bind_or_connect_to)
				.await
				.map_err(NetworkError::IO)?;
			let loggable_address = stream.local_addr().map_err(NetworkError::IO)?;
			let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
			trace!(
				server.address = %loggable_address,
				client.address = %client_address,
				stream.id = stream_id,
				stream.stream_type = "server",
				"cat_dev::net::tcp_server::connect(): started connection (TcpStream::connect())",
			);

			if let Err(cause) = Self::handle_tcp_connection(
				self.on_stream_begin.clone(),
				self.on_stream_end.clone(),
				self.nagle_guard.clone(),
				self.slowloris_timeout,
				self.initial_service.clone(),
				stream,
				client_address,
				self.pre_nagle_hook,
				self.post_nagle_hook,
				self.chunk_output_at_size,
				self.state.clone(),
				self.id,
				stream_id,
				self.cat_dev_slowdown,
				#[cfg(debug_assertions)]
				self.trace_during_debug,
			)
			.instrument(error_span!(
				"CatDevTCPServerConnect",
				client.address = %client_address,
				server.address = %loggable_address,
				server.service = self.service_name,
				stream.id = stream_id,
				stream.stream_type = "server",
			))
			.await
			{
				warn!(
					?cause,
					client.address = %client_address,
					server.address = %loggable_address,
					server.service = self.service_name,
					"Error escaped while handling TCP connection.",
				);
			}
		}
	}

	/// Bind this server, and start listening on the specified address.
	///
	/// NOTE: this will ***NOT*** return until the service either runs into an
	/// error, or gets cancelled.
	///
	/// ## Errors
	///
	/// - If we cannot bind to the requested address, and port.
	/// - If we run into an error accept'ing a TCP connection.
	pub async fn bind(self) -> Result<(), CatBridgeError> {
		// Copy the address we're bound to, to use in log statements later on.
		let loggable_address = self.address_to_bind_or_connect_to;
		let listener = TcpListener::bind(self.address_to_bind_or_connect_to)
			.await
			.map_err(NetworkError::IO)?;

		loop {
			let (stream, client_address) = listener.accept().await.map_err(NetworkError::IO)?;
			trace!(
				server.address = %loggable_address,
				client.address = %client_address,
				"cat_dev::net::tcp_server::bind(): received connection (listener.accept())",
			);

			// We need to clone all of our stuff, so it can be owned by the
			// underlying connection task.
			//
			// We spawn a task per connection so that one connection doesn't block
			// the accept'ing of any new connections. If it did, we'd only be able
			// to handle one connection at a time, which is lame.
			//
			// Tasks are also incredibly cheap, so spawning one per connection,
			// even for long-lived connections, is no big deal.
			let cloned_begin_handler = self.on_stream_begin.clone();
			let cloned_end_handler = self.on_stream_end.clone();
			let cloned_nagle_guard = self.nagle_guard.clone();
			let cloned_handler = self.initial_service.clone();
			let cloned_state = self.state.clone();
			let copied_pre_nagle_hook = self.pre_nagle_hook;
			let copied_post_nagle_hook = self.post_nagle_hook;
			let copied_chunk_on_size = self.chunk_output_at_size;
			let copied_service_name = self.service_name;
			let copied_slowloris_timeout = self.slowloris_timeout;
			let copied_server_id = self.id;
			let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
			let copied_slowdown = self.cat_dev_slowdown;
			#[cfg(debug_assertions)]
			let trace_io = self.trace_during_debug;

			TaskBuilder::new()
				.name("cat_dev::net::tcp_server::bind().connection.handle")
				.spawn(async move {
					if let Err(cause) = Self::handle_tcp_connection(
						cloned_begin_handler,
						cloned_end_handler,
						cloned_nagle_guard,
						copied_slowloris_timeout,
						cloned_handler,
						stream,
						client_address,
						copied_pre_nagle_hook,
						copied_post_nagle_hook,
						copied_chunk_on_size,
						cloned_state,
						copied_server_id,
						stream_id,
						copied_slowdown,
						#[cfg(debug_assertions)]
						trace_io,
					)
					.instrument(error_span!(
						"CatDevTCPServerAccept",
						client.address = %client_address,
						server.address = %loggable_address,
						server.service = copied_service_name,
						server.stream_id = stream_id,
					))
					.await
					{
						warn!(
							?cause,
							client.address = %client_address,
							server.address = %loggable_address,
							server.service = %copied_service_name,
							"Error escaped while handling TCP connection.",
						);
					}
				})
				.map_err(CatBridgeError::SpawnFailure)?;
		}
	}

854	/// Attempt to handle a TCP connection, or "stream".
855	///
856	/// This will be called immediately from the task that spawns on
857	/// connections. It will automatically spin down "shortly" after
858	/// the connection is gone. We basically wait for an error back from the OS.
859	///
860	/// ## Errors
861	///
862	/// If reading from, or writing to, the underlying TCP stream runs into any
863	/// errors. Individual packet errors do not necessarily correlate to a closed
864	/// connection. However, most services will link the two together.
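	///
	/// ## Sketch: delimiter reassembly
	///
	/// Conceptually, the read path below boils down to: append every read into
	/// a carry-over buffer, then repeatedly split complete sigil-terminated
	/// packets off the front. A self-contained sketch of just that splitting
	/// step (plain `std` `Vec`s rather than this crate's `NagleGuard` and
	/// `BytesMut`; assumes a non-empty sigil):
	///
	/// ```
	/// fn split_packets(buff: &mut Vec<u8>, sigil: &[u8]) -> Vec<Vec<u8>> {
	///     let mut packets = Vec::new();
	///     while let Some(pos) = buff
	///         .windows(sigil.len())
	///         .position(|window| window == sigil)
	///     {
	///         // Detach everything after this packet's sigil, keep it as the
	///         // new carry-over, and emit the completed packet (sans sigil).
	///         let rest = buff.split_off(pos + sigil.len());
	///         let mut packet = std::mem::replace(buff, rest);
	///         packet.truncate(pos);
	///         packets.push(packet);
	///     }
	///     packets
	/// }
	/// ```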
865	#[allow(
866		// all of our parameters are very well named, and types are not close to
867		// overlapping with each other.
868		//
869	// we also just fundamentally have a lot of state thanks to the complexity
870		// of all the things we have to handle for a TCP connection, e.g. NAGLE,
871		// delimiters, caches, etc.
872		//
873		// it is also only ever called from one internal function, so it's not like
874		// part of our public facing api.
875		clippy::too_many_arguments,
876	)]
877	async fn handle_tcp_connection(
878		on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
879		on_stream_end_handler: Option<UnderlyingOnStreamEndService<State>>,
880		nagle_guard: NagleGuard,
881		slowloris_timeout: Duration,
882		handler: BoxCloneService<Request<State>, Response, Infallible>,
883		mut tcp_stream: TcpStream,
884		client_address: SocketAddr,
885		pre_hook_cloned: Option<&'static dyn PreNagleFnTy>,
886		post_hook_cloned: Option<&'static dyn PostNagleFnTy>,
887		chunk_output_at_size: Option<usize>,
888		state: State,
889		server_id: u64,
890		stream_id: u64,
891		cat_dev_slowdown: Option<Duration>,
892		#[cfg(debug_assertions)] trace_io: bool,
893	) -> Result<(), CatBridgeError> {
894		let (mut send_responses, mut packets_left_to_send) =
895			bounded_channel::<ResponseStreamMessage>(128);
896
897		// Run connection handlers; `initialize_stream` returns `Ok(true)` when a
898		// handler rejected the stream, in which case we bail out early.
899		if Self::initialize_stream(
900			on_stream_begin_handler,
901			&mut send_responses,
902			&client_address,
903			&state,
904			&mut tcp_stream,
905			server_id,
906			stream_id,
907		)
908		.await?
909		{
910			return Ok(());
911		}
912
913		// Connection has been "approved"; set up the on-disconnect handler.
914		let _guard = on_stream_end_handler.map(|service| {
915			DisconnectAsyncDropServer::new(service, state.clone(), client_address, stream_id)
916		});
917
918		// Any previously saved data that was a victim of NAGLE's algorithm, or
919		// similar.
920		let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;
921
922		loop {
923			let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
924			tokio::select! {
925				received = packets_left_to_send.recv() => {
926					if Self::handle_server_write_to_connection(
927						&mut tcp_stream,
928						chunk_output_at_size,
929						received,
930						post_hook_cloned,
931						stream_id,
932						cat_dev_slowdown,
933						#[cfg(debug_assertions)] trace_io,
934					).await? {
935						break;
936					}
937				}
938				res_size = tcp_stream.read_buf(&mut buff) => {
939					// A read of zero bytes means the peer closed the connection
940					// (EOF); `continue`-ing here would busy-loop on a dead stream.
941					if res_size.map_err(NetworkError::IO)? == 0 {
942						break;
943					}
944
945					let (should_break, returned_stream) = Self::handle_server_read_from_connection(
946						tcp_stream,
947						buff,
948						send_responses.clone(),
949						&nagle_guard,
950						slowloris_timeout,
951						handler.clone(),
952						&mut nagle_cache,
953						client_address,
954						pre_hook_cloned,
955						state.clone(),
956						stream_id,
957						#[cfg(debug_assertions)] trace_io,
958					).await?;
959					tcp_stream = returned_stream;
960					if should_break {
961						break;
962					}
963				}
964			}
965		}
966
967		OUT_OF_BAND_SENDERS
968			.remove_async(&(server_id, stream_id))
969			.await;
970		packets_left_to_send.close();
971		std::mem::drop(tcp_stream.shutdown().await);
972
973		Ok(())
974	}
975
976	async fn initialize_stream(
977		on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
978		send_channel: &mut BoundedSender<ResponseStreamMessage>,
979		source_address: &SocketAddr,
980		state: &State,
981		tcp_stream: &mut TcpStream,
982		server_id: u64,
983		stream_id: u64,
984	) -> Result<bool, CatBridgeError> {
985		tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;
986		OUT_OF_BAND_SENDERS
987			.upsert_async((server_id, stream_id), send_channel.clone())
988			.await;
989
990		if let Some(mut handle) = on_stream_begin_handler
991			&& !handle
992				.call(ResponseStreamEvent::new_with_state(
993					send_channel.clone(),
994					*source_address,
995					Some(stream_id),
996					state.clone(),
997				))
998				.await?
999		{
1000			trace!("handler failed on stream begin hook");
1001			return Ok(true);
1002		}
1003
1004		Ok(false)
1005	}
1006
1007	async fn handle_server_write_to_connection(
1008		tcp_stream: &mut TcpStream,
1009		chunk_output_on_size: Option<usize>,
1010		to_send_to_client_opt: Option<ResponseStreamMessage>,
1011		post_hook: Option<&'static dyn PostNagleFnTy>,
1012		stream_id: u64,
1013		cat_dev_slowdown: Option<Duration>,
1014		#[cfg(debug_assertions)] trace_io: bool,
1015	) -> Result<bool, CatBridgeError> {
1016		let Some(to_send_to_client) = to_send_to_client_opt else {
1017			return Ok(false);
1018		};
1019
1020		match to_send_to_client {
1021			ResponseStreamMessage::Disconnect => {
1022				debug!("stream-disconnect-message");
1023				Ok(true)
1024			}
1025			ResponseStreamMessage::Response(resp) => {
1026				if let Some(body) = resp.body()
1027					&& !body.is_empty()
1028				{
1029					let messages = if let Some(size) = chunk_output_on_size {
1030						body.chunks(size)
1031							.map(Bytes::copy_from_slice)
1032							.collect::<Vec<_>>()
1033					} else {
1034						vec![body.clone()]
1035					};
1036
1037					for message in messages {
1038						#[cfg(debug_assertions)]
1039						if trace_io {
1040							debug!(
1041								body.hex = format!("{message:02x?}"),
1042								body.str = String::from_utf8_lossy(&message).to_string(),
1043								"cat-dev-trace-output-tcp-server",
1044							);
1045						}
1046
1047						let mut full_response = message.clone();
1048						if let Some(post) = post_hook {
1049							full_response = block_in_place(|| post(stream_id, full_response));
1050						}
1051					if let Some(slowdown) = cat_dev_slowdown {
1052						sleep(slowdown).await;
1053						}
1054
1055						tcp_stream.writable().await.map_err(NetworkError::IO)?;
1056						tcp_stream
1057							.write_all(&full_response)
1058							.await
1059							.map_err(NetworkError::IO)?;
1060					}
1061				}
1062
1063				if resp.request_connection_close() {
1064					trace!("response-requested-connection-close");
1065					Ok(true)
1066				} else {
1067					Ok(false)
1068				}
1069			}
1070		}
1071	}
1072
1073	#[allow(
1074		// All of our types are very differently typed, and well named, so chance
1075		// of confusion is low.
1076		//
1077		// Not to mention this is an internal only method.
1078		clippy::too_many_arguments,
1079	)]
1080	async fn handle_server_read_from_connection<'data>(
1081		mut stream: TcpStream,
1082		mut buff: BytesMut,
1083		channel: BoundedSender<ResponseStreamMessage>,
1084		nagle_guard: &'data NagleGuard,
1085		slowloris_timeout: Duration,
1086		mut handler: BoxCloneService<Request<State>, Response, Infallible>,
1087		nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
1088		client_address: SocketAddr,
1089		cloned_pre_nagle: Option<&'static dyn PreNagleFnTy>,
1090		state: State,
1091		stream_id: u64,
1092		#[cfg(debug_assertions)] trace_io: bool,
1093	) -> Result<(bool, TcpStream), CatBridgeError> {
1094		if let Some(convert_fn) = cloned_pre_nagle {
1095			block_in_place(|| {
1096				(*convert_fn)(stream_id, &mut buff);
1097			});
1098		}
1099
1100		#[cfg(debug_assertions)]
1101		{
1102			if trace_io {
1103				debug!(
1104					body.hex = format!("{:02x?}", buff),
1105					body.str = String::from_utf8_lossy(&buff).to_string(),
1106					"cat-dev-trace-input-tcp-server",
1107				);
1108			}
1109		}
1110
1111		// We may be NAGLE'd, so we need to recover/split, and potentially buffer
1112		// any packets. Also watch out for slowloris-esque attacks.
1113		let start_time = now();
1114		if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
1115		// If we can't calculate the duration, it's either negative or no
1116		// time has passed yet.
1117			//
1118			// Just treat it as 0.
1119			let total_duration = start_time
1120				.duration_since(old_start_time)
1121				.unwrap_or(Duration::from_secs(0));
1122			if total_duration > slowloris_timeout {
1123				debug!(
1124					cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
1125					"slowloris-detected",
1126				);
1127				return Ok((true, stream));
1128			}
1129
1130			existing_buff.extend_from_slice(&buff);
1131			buff = existing_buff;
1132		}
1133
1134		while let Some((start_of_packet, end_of_packet)) = nagle_guard.split(&buff)? {
1135			let remaining_buff = buff.split_off(end_of_packet);
1136			let _start_of_buff = buff.split_to(start_of_packet);
1137			let req_body = buff.freeze();
1138			buff = remaining_buff;
1139
1140			let lockable_stream = Arc::new(Mutex::new(Some((Some(buff), stream))));
1141			let mut request_object = Request::new_with_state_and_stream(
1142				req_body,
1143				client_address,
1144				state.clone(),
1145				Some(stream_id),
1146				lockable_stream.clone(),
1147			);
1148			request_object.extensions_mut().insert(channel.clone());
1149			if let Err(cause) = match handler.call(request_object).await {
1150				Ok(ref resp) => {
1151					channel
1152						.send(ResponseStreamMessage::Response(resp.clone()))
1153						.await
1154				}
1155				Err(cause) => {
1156					warn!(
1157						?cause,
1158						lisa.force_combine_fields = true,
1159						"request handler failed, will close connection.",
1160					);
1161					channel.send(ResponseStreamMessage::Disconnect).await
1162				}
1163			} {
1164				warn!(
1165					?cause,
1166					lisa.force_combine_fields = true,
1167					"internal queue failure; will not send disconnect/response.",
1168				);
1169			}
1170
1171			{
1172				let mut done_lock = lockable_stream.lock().await;
1173				if let Some((newer_buff, strm)) = done_lock.take() {
1174					if let Some(newest_buff) = newer_buff {
1175						buff = newest_buff;
1176					} else {
1177						return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
1178					}
1179					stream = strm;
1180				} else {
1181					return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
1182				}
1183			}
1184		}
1185
1186		if !buff.is_empty() {
1187			_ = nagle_cache.insert((buff, start_time));
1188		}
1189
1190		Ok((false, stream))
1191	}
1192}
1193
1194impl<State: Clone + Debug + Send + Sync + 'static> Debug for TCPServer<State> {
1195	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
1196		let mut dbg_struct = fmt.debug_struct("TCPServer");
1197		dbg_struct
1198			.field(
1199				"address_to_bind_or_connect_to",
1200				&self.address_to_bind_or_connect_to,
1201			)
1202			.field("cat_dev_slowdown", &self.cat_dev_slowdown)
1203			.field("chunk_output_at_size", &self.chunk_output_at_size)
1204			.field("id", &self.id)
1205			.field("initial_service", &self.initial_service)
1206			.field("nagle_guard", &self.nagle_guard)
1207			.field("on_stream_begin", &self.on_stream_begin)
1208			.field("on_stream_end", &self.on_stream_end)
1209			.field("has_pre_nagle_hook", &self.pre_nagle_hook.is_some())
1210			.field("has_post_nagle_hook", &self.post_nagle_hook.is_some())
1211			.field("service_name", &self.service_name)
1212			.field("slowloris_timeout", &self.slowloris_timeout)
1213			.field("state", &self.state);
1214
1215		#[cfg(debug_assertions)]
1216		{
1217			dbg_struct.field("trace_during_debug", &self.trace_during_debug);
1218		}
1219
1220		dbg_struct.finish()
1221	}
1222}
1223
1224const TCP_SERVER_FIELDS: &[NamedField<'static>] = &[
1225	NamedField::new("address_to_bind_or_connect_to"),
1226	NamedField::new("cat_dev_slowdown"),
1227	NamedField::new("chunk_output_at_size"),
1228	NamedField::new("initial_service"),
1229	NamedField::new("nagle_guard"),
1230	NamedField::new("on_stream_begin"),
1231	NamedField::new("on_stream_end"),
1232	NamedField::new("has_pre_nagle_hook"),
1233	NamedField::new("has_post_nagle_hook"),
1234	NamedField::new("service_name"),
1235	NamedField::new("slowloris_timeout"),
1236	NamedField::new("state"),
1237	#[cfg(debug_assertions)]
1238	NamedField::new("trace_during_debug"),
1239];
1240
1241impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Structable for TCPServer<State> {
1242	fn definition(&self) -> StructDef<'_> {
1243		StructDef::new_static("TcpServer", Fields::Named(TCP_SERVER_FIELDS))
1244	}
1245}
1246
1247impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Valuable for TCPServer<State> {
1248	fn as_value(&self) -> Value<'_> {
1249		Value::Structable(self)
1250	}
1251
1252	fn visit(&self, visitor: &mut dyn Visit) {
1253		visitor.visit_named_fields(&NamedValues::new(
1254			TCP_SERVER_FIELDS,
1255			&[
1256				Valuable::as_value(&format!("{}", self.address_to_bind_or_connect_to)),
1257				Valuable::as_value(&if let Some(slowdown) = self.cat_dev_slowdown {
1258					format!("{}ms", slowdown.as_millis())
1259				} else {
1260					"<none>".to_string()
1261				}),
1262				Valuable::as_value(&self.chunk_output_at_size),
1263				Valuable::as_value(&format!("{:?}", self.initial_service)),
1264				Valuable::as_value(&self.nagle_guard),
1265				Valuable::as_value(&format!("{:?}", self.on_stream_begin)),
1266				Valuable::as_value(&format!("{:?}", self.on_stream_end)),
1267				Valuable::as_value(&self.pre_nagle_hook.is_some()),
1268				Valuable::as_value(&self.post_nagle_hook.is_some()),
1269				Valuable::as_value(&self.service_name),
1270				Valuable::as_value(&format!("{:?}", self.slowloris_timeout)),
1271				Valuable::as_value(&self.state),
1272				#[cfg(debug_assertions)]
1273				Valuable::as_value(&self.trace_during_debug),
1274			],
1275		));
1276	}
1277}
1278
1279#[cfg(test)]
1280pub mod test_helpers {
1281	use super::*;
1282	use std::net::{Ipv4Addr, SocketAddrV4};
1283
1284	/// Get a free TCP Port for IPv4.
1285	///
1286	/// This will attempt to get a free tcp port on IPv4 by actually binding to
1287	/// port `0` which should make the OS automatically assign a free unused
1288	/// port.
1289	pub async fn get_free_tcp_v4_port() -> Option<u16> {
1290		let addr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0);
1291		if let Ok(bound) = TcpListener::bind(addr).await {
1292			if let Ok(local) = bound.local_addr() {
1293				return Some(local.port());
1294			}
1295		}
1296		None
1297	}
1298}
1299
1300#[cfg(test)]
1301mod unit_tests {
1302	use super::*;
1303	use crate::net::{
1304		CURRENT_TIME,
1305		server::{Router, requestable::Extension, test_helpers::*},
1306	};
1307	use bytes::Bytes;
1308	use std::{
1309		net::{Ipv4Addr, SocketAddrV4},
1310		sync::{
1311			Arc, Mutex,
1312			atomic::{AtomicU8, Ordering},
1313		},
1314		time::Duration,
1315	};
1316	use tokio::time::timeout;
1317
1318	fn set_now(new_time: SystemTime) {
1319		CURRENT_TIME.with(|time_lazy| {
1320			*time_lazy.write().expect("RwLock is poisoned?") = new_time;
1321		})
1322	}
1323
1324	#[tokio::test]
1325	pub async fn full_server() {
1326		let connected_fired = Arc::new(Mutex::new(false));
1327		let on_disconnect_fired = Arc::new(Mutex::new(false));
1328		let request_fired = Arc::new(Mutex::new(false));
1329
1330		async fn on_connection(
1331			Extension(connected): Extension<Arc<Mutex<bool>>>,
1332		) -> Result<bool, CatBridgeError> {
1333			let mut locked = connected
1334				.lock()
1335				.expect("Failed to lock connected fired extension");
1336			*locked = true;
1337			Ok(true)
1338		}
1339		async fn on_disconnect(
1340			Extension(disconnected): Extension<Arc<Mutex<bool>>>,
1341		) -> Result<(), CatBridgeError> {
1342			let mut locked = disconnected
1343				.lock()
1344				.expect("Failed to lock connected fired extension");
1345			*locked = true;
1346			Ok(())
1347		}
1348		async fn on_request(
1349			Extension(request): Extension<Arc<Mutex<bool>>>,
1350		) -> Result<Response, CatBridgeError> {
1351			let mut locked = request
1352				.lock()
1353				.expect("Failed to lock connected fired extension");
1354			*locked = true;
1355
1356			let mut resp = Response::new_with_body(Bytes::from(vec![0x1]));
1357			resp.should_close_connection();
1358			Ok(resp)
1359		}
1360
1361		let mut router = Router::new();
1362		router
1363			.add_route(&[0x1, 0x2, 0x3], on_request)
1364			.expect("Failed to add a route!");
1365		router.layer(Extension(request_fired.clone()));
1366
1367		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1368			.await
1369			.expect("Timed out trying to find free port!")
1370			.expect("Failed to find free TCP port on system.");
1371
1372		let mut srv = timeout(
1373			Duration::from_secs(5),
1374			TCPServer::new_with_state(
1375				"test",
1376				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1377				router,
1378				(None, None),
1379				NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
1380				(),
1381				#[cfg(debug_assertions)]
1382				true,
1383			),
1384		)
1385		.await
1386		.expect("Timed out starting server")
1387		.expect("Failed to create TCP Server.");
1388
1389		srv.set_on_stream_begin(on_connection)
1390			.expect("Failed to register stream begin handler!");
1391		srv.layer_on_stream_begin(Extension(connected_fired.clone()))
1392			.expect("Failed to add layer to on stream begin!");
1393		srv.set_on_stream_end(on_disconnect)
1394			.expect("Failed to register stream end handler!");
1395		srv.layer_on_stream_end(Extension(on_disconnect_fired.clone()))
1396			.expect("Failed to add layer to on_disconnect!");
1397
1398		let spawned =
1399			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1400		{
1401			loop {
1402				let client_stream_res = timeout(
1403					Duration::from_secs(10),
1404					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1405				)
1406				.await
1407				.expect("Service timed out waiting for connection!");
1408			// Server hasn't bound yet!
1409				if client_stream_res.is_err() {
1410					continue;
1411				}
1412				let mut client_stream = client_stream_res.unwrap();
1413				client_stream
1414					.write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
1415					.await
1416					.expect("Failed to write to client stream");
1417				// Wait til we get a response.
1418				let mut buff = [0_u8; 1];
1419				timeout(Duration::from_secs(5), client_stream.read(&mut buff))
1420					.await
1421					.expect("Timed out reading from client stream")
1422					.expect("Failed to read data from client stream");
1423				timeout(Duration::from_secs(5), client_stream.shutdown())
1424					.await
1425					.expect("Timed out shutting down client stream")
1426					.expect("Failed to shutdown client stream.");
1427				break;
1428			}
1429		}
1430		// Drop the handle; the server task dies with the test runtime.
1431		std::mem::drop(spawned);
1432
1433		let locked_connect = connected_fired
1434			.lock()
1435			.expect("Failed to lock second connect");
1436		let locked_disconnect = on_disconnect_fired
1437			.lock()
1438			.expect("Failed to lock second on_disconnect");
1439		let locked_request = request_fired.lock().expect("Failed to lock second request");
1440
1441		assert!(*locked_connect, "on connection handler never fired!");
1442		assert!(*locked_disconnect, "on disconnect handler never fired!");
1443		assert!(*locked_request, "on request handler never fired!");
1444	}
1445
1446	#[tokio::test]
1447	pub async fn nagled_logic_works() {
1448		let requests_fired = Arc::new(AtomicU8::new(0));
1449
1450		async fn on_request(
1451			Extension(request): Extension<Arc<AtomicU8>>,
1452		) -> Result<Response, CatBridgeError> {
1453			request.fetch_add(1, Ordering::SeqCst);
1454			let resp = Response::new_with_body(Bytes::from(vec![0x1]));
1455			Ok(resp)
1456		}
1457
1458		let mut router = Router::new();
1459		router
1460			.add_route(&[0x1, 0x2, 0x3], on_request)
1461			.expect("Failed to add a route!");
1462		router.layer(Extension(requests_fired.clone()));
1463
1464		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1465			.await
1466			.expect("Timed out finding port to bind to!")
1467			.expect("Failed to find any free tcp v4 port on system!");
1468		let srv = timeout(
1469			Duration::from_secs(5),
1470			TCPServer::new_with_state(
1471				"test",
1472				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1473				router,
1474				(None, None),
1475				NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
1476				(),
1477				#[cfg(debug_assertions)]
1478				true,
1479			),
1480		)
1481		.await
1482		.expect("timed out starting TCP Server for test")
1483		.expect("Failed to create local TCP server!");
1484
1485		let spawned =
1486			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1487		{
1488			loop {
1489				let client_stream_res = timeout(
1490					Duration::from_secs(10),
1491					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1492				)
1493				.await
1494				.expect("Service timed out waiting for connection!");
1495			// Server hasn't bound yet!
1496				if client_stream_res.is_err() {
1497					continue;
1498				}
1499				let mut client_stream = client_stream_res.unwrap();
1500
1501				client_stream
1502					.write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF])
1503					.await
1504					.expect("Failed to write to client_stream");
1505				client_stream
1506					.flush()
1507					.await
1508					.expect("Failed to flush client_stream");
1509				client_stream
1510					.write_all(&[0xFF, 0xFF, 0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
1511					.await
1512					.expect("Failed to issue second write call to client_stream");
1513				let mut buff = [0_u8; 2];
1514				let read_bytes = timeout(Duration::from_secs(5), client_stream.read(&mut buff))
1515					.await
1516					.expect("Timed out reading from client_stream")
1517					.expect("Failed to read from client_stream!");
1518				if read_bytes == 1 {
1519					timeout(Duration::from_secs(5), client_stream.read(&mut buff[1..]))
1520						.await
1521						.expect("Timed out reading from client_stream")
1522						.expect("Failed to read from client_stream!");
1523				}
1524
1525				timeout(Duration::from_secs(5), client_stream.shutdown())
1526					.await
1527					.expect("Timed out shutting down client stream")
1528					.expect("Failed to shutdown client stream.");
1529				break;
1530			}
1531		}
1532		// Drop the handle; the server task dies with the test runtime.
1533		std::mem::drop(spawned);
1534
1535		assert_eq!(
1536			requests_fired.load(Ordering::SeqCst),
1537			2,
1538			"on request did not fire the correct amount of times!",
1539		);
1540	}
1541
1542	#[tokio::test]
1543	pub async fn slowloris_is_blocked() {
1544		let mut router = Router::new();
1545		router
1546			.add_route(&[0x1, 0x2, 0x3], || async {
1547				Ok(Response::new_with_body(Bytes::from(vec![0x1])))
1548			})
1549			.expect("Failed to add a route!");
1550
1551		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1552			.await
1553			.expect("Timed out finding port to bind to!")
1554			.expect("Failed to find any free tcp v4 port on system!");
1555		let srv = timeout(
1556			Duration::from_secs(5),
1557			TCPServer::new_with_state(
1558				"test",
1559				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1560				router,
1561				(None, None),
1562				NagleGuard::EndSigilSearch(&[0x10, 0x11, 0x12]),
1563				(),
1564				#[cfg(debug_assertions)]
1565				true,
1566			),
1567		)
1568		.await
1569		.expect("timed out starting TCP Server for test")
1570		.expect("Failed to create local TCP server!");
1571
1572		let spawned =
1573			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1574		let read_bytes;
1575		{
1576			loop {
1577				let client_stream_res = timeout(
1578					Duration::from_secs(10),
1579					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1580				)
1581				.await
1582				.expect("Service timed out waiting for connection!");
1583			// Server hasn't bound yet!
1584				if client_stream_res.is_err() {
1585					continue;
1586				}
1587				let mut client_stream = client_stream_res.unwrap();
1588
1589				client_stream
1590					.write_all(&[0x1, 0x2, 0x3, 0x10])
1591					.await
1592					.expect("Failed to write to client_stream");
1593				client_stream
1594					.flush()
1595					.await
1596					.expect("Failed to flush client_stream");
1597				// This ensures the server running in the background has for sure
1598				// cached the time we started processing.
1599				tokio::time::sleep(Duration::from_secs(5)).await;
1600				set_now(
1601					SystemTime::now()
1602					.checked_add(Duration::from_secs(90_000_000))
1603						.expect("Failed to add time to systemtime"),
1604				);
1605				// We do properly finish off the packet but it's too late...
1606				client_stream
1607					.write_all(&[0x11, 0x12])
1608					.await
1609					.expect("Failed to write to client_stream");
1610
1611				let mut buff = [0_u8; 1];
1612			// This read should return 0 bytes because the server closes our
1613			// connection, so we shouldn't hit the timeout.
1614				read_bytes = timeout(Duration::from_secs(10), client_stream.read(&mut buff))
1615					.await
1616					.expect("timed out trying to wait for disconnect")
1617					.expect("failure reading from stream");
1618				break;
1619			}
1620		}
1621		std::mem::drop(spawned);
1622
1623		assert_eq!(read_bytes, 0, "Client didn't error on slowloris'd packet");
1624	}
1625}