cat_dev/net/client/
tcp.rs

1//! Utilities for building a client to communicate over the TCP layer 4
2//! protocol.
3//!
//! Note this "TCP Client" allows connecting to multiple 'servers' at once, as
//! well as "binding" to listen as a 'host'. This is mostly because we have to
//! support SDIO, where the side that calls "connect" acts as the server.
8//!
9//! ## TCP and NAGLE'ing
10//!
//! TCP is fundamentally a 'stream' protocol, not a packet-oriented one. For a
//! client this affects us in two places:
13//!
//! 1. When reading responses back. Similar to a server, we have to know where
//!    a packet starts/ends in the response. We use the same nagle guard method
//!    that our servers use.
//! 2. When writing our requests. Just because our servers are well behaved
//!    doesn't mean that all server implementations are. So we automatically
//!    turn on `no_delay` to prevent nagle's algorithm from getting in our way,
//!    so that each of our writes will hopefully arrive as a single read on the
//!    other side.
22//!
23//! ## Notes about Concurrency
24//!
//! This TCP Client unfortunately has to make a sacrifice: it processes one
//! packet per stream at a time. You can have as many TCP streams as you want,
//! and we should be able to handle many of them at the same time, but the
//! ordered nature of TCP, along with some protocol designs implemented by
//! Nintendo, means this client must process only one packet per TCP stream at
//! a time.
31//!
//! Most notably this comes from the fact that our file servers will
//! consistently break their normal "NAGLE" framing, and we have to do raw
//! reads of N bytes from the stream (in both directions) _BEFORE_ processing
//! another request.
36//!
37//! ## API Notes
38//!
//! Most TCP clients only expect to connect to a single server, and have just
//! a single TCP stream: one connection to `send` to and `recv` from.
//! HOWEVER, this simple model doesn't work for us, for two reasons:
42//!
//! 1. SDIO's TCP "Client" is actually a TCP Server that allows multiple folks
//!    to connect to it. Although in reality folks are only supposed to connect
//!    once, it technically supports multiple connections.
46//! 2. Scientists. One goal of all of our clients/servers is they allow
47//!    debugging with scientists. Scientists allow connecting to many upstreams
48//!    at once, and diff'ing between them. So our TCP Client needs to fit in
49//!    there as well.
50//!
//! In order to keep the APIs as simple as possible (keeping one "send" and
//! one "recv" interface), while also supporting multiple connections that we
//! need to broadcast to, we introduce the concept of a "primary" connection
//! whose responses are treated as "canonical". By default this is whoever we
//! connect to, or whoever connects to us, first. You can always re-assign it
//! manually, but by default the first connection is "primary". If the primary
//! connection is dropped, we will switch to the next oldest connection
//! automatically.
58//!
//! The way this works is that when you call `send`, the response you get back
//! will always be from this "primary" connection. The packet is still sent to
//! (and the response still read from) every other connection, but those other
//! responses are silently dropped. This way stateful protocols keep working
//! on every connection.
64//!
65//! ### Scientists
66//!
//! Scientists, on the other hand, can call [`TCPClient::should_keep_all_responses`],
//! which makes the client behave the same way, except that instead of dropping
//! the responses from the other upstreams it keeps them until they are removed
//! later with [`TCPClient::take_all_response_for_request_id`].
72
73use crate::{
74	errors::{CatBridgeError, NetworkError},
75	net::{
76		DEFAULT_SLOWLORIS_TIMEOUT, STREAM_ID, TCP_READ_BUFFER_SIZE,
77		additions::RequestID,
78		client::{
79			errors::CommonNetClientNetworkError,
80			models::{
81				DisconnectAsyncDropClient, RequestStreamEvent, RequestStreamMessage,
82				UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
83			},
84		},
85		errors::{CommonNetAPIError, CommonNetNetworkError},
86		handlers::{
87			OnRequestStreamBeginHandler, OnRequestStreamEndHandler, OnStreamBeginHandlerAsService,
88			OnStreamEndHandlerAsService,
89		},
90		models::{FromRequestParts, NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
91		now,
92	},
93};
94use bytes::{Bytes, BytesMut};
95use fnv::{FnvHashMap, FnvHashSet};
96use futures::future::join_all;
97use miette::miette;
98use scc::HashMap as ConcurrentHashMap;
99use std::{
100	collections::VecDeque,
101	fmt::{Debug, Formatter, Result as FmtResult},
102	hash::BuildHasherDefault,
103	net::{Ipv4Addr, SocketAddr, SocketAddrV4},
104	sync::{
105		Arc,
106		atomic::{AtomicU64, Ordering},
107	},
108	time::{Duration, Instant, SystemTime},
109};
110use tokio::{
111	io::{AsyncReadExt, AsyncWriteExt},
112	net::{TcpListener, TcpStream, ToSocketAddrs},
113	sync::mpsc::{
114		Receiver as BoundedReceiver, Sender as BoundedSender, channel as bounded_channel,
115		error::SendTimeoutError,
116	},
117	task::{Builder as TaskBuilder, block_in_place},
118	time::{sleep, timeout},
119};
120use tower::{Layer, Service, util::BoxCloneService};
121use tracing::{Instrument, debug, error_span, trace, warn};
122use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
123
124#[cfg(debug_assertions)]
125use crate::net::SPRIG_TRACE_IO;
126
/// A zero-length [`Duration`].
const EMPTY_TIMEOUT: Duration = Duration::ZERO;
128
129/// A generic TCP client that can handle connections to multiple TCP Servers.
130///
131/// This client allows for connecting to many servers all at once, or creating
132/// a real "TCP Server" that allows people to connect to it (for SDIO). Because
133/// of the weirdness of our SDIO client, and the hooks for scientists this
134/// client ends up accepting a lot of parameters, see the module documentation
135/// for more information.
136pub struct TCPClient {
137	/// Cat-dev's need load-bearing sleeps as they can "ACK" a ppacket, but
138	/// throw away the bytes and pretend it never got implied.
139	///
140	/// This is used for sleeping before then. By default this isn't set,
141	/// cat-dev services will call: [`TCPClient::set_cat_dev_slowdown`].
142	cat_dev_slowdown: Option<Duration>,
143	/// For devices that can't receive too much data at once.
144	///
145	/// When devices *cough* MION *cough* can't receive too much data at once,
146	/// we need to chunk it, and rest between those chunks. This will chunk those
147	/// packets for us.
148	chunk_output_at_size: Option<usize>,
149	/// Keep responses for all streams, mostly used for scientist like code.
150	keep_all_responses: bool,
151	/// Determines when a packet starts, and ends.
152	nagle_guard: NagleGuard,
153	/// A tower service to call when a particular stream starts.
154	///
155	/// This is effectively like an "on connect" hook that you can use to
156	/// call functions.
157	on_stream_begin: Option<UnderlyingOnStreamBeginService<()>>,
158	/// A tower service to call when a particular stream ends.
159	///
160	/// This is effectively like an "on disconnect" hook that you can use
161	/// to call functions.
162	on_stream_end: Option<UnderlyingOnStreamEndService<()>>,
163	/// A function to apply some sort of processing before doing any sort
164	/// of NAGLE/SLOWLORIS logic.
165	///
166	/// This is best used for encryption/decryption that needs to be handled
167	/// before anything else.
168	pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
169	/// A function to apply some sort of processing before sending a packet
170	/// out.
171	///
172	/// This is the very last thing before `send` is actually called. This is
173	/// best used for encryption/decryption that needs to be only processed
174	/// at the end.
175	post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
176	/// The stream id of the current "active" stream.
177	primary_stream_id: Arc<AtomicU64>,
178	/// The list of active client streams.
179	streams: Arc<ConcurrentHashMap<u64, TCPClientStream>>,
180	/// The name of the service being provided by this server, to attach to logs.
181	service_name: &'static str,
182	/// The "Slowloris Detection" timeout.
183	slowloris_timeout: Duration,
184	/// If we should log all packet requests/responses when compiled with debug
185	/// assertions.
186	#[cfg(debug_assertions)]
187	trace_during_debug: bool,
188}
189
190impl TCPClient {
191	/// Construct a new TCP Client.
192	///
193	/// This TCP client will by default not connect to anything/listen to any
194	/// ports. Instead you will have to call [`TCPClient::connect`], and
195	/// [`TCPClient::bind`] manually to setup connections.
196	///
197	/// Remember the default [`TCPClient::send`], and [`TCPClient::receive`] will
198	/// only return responses to the "primary" connection (whichever connection
199	/// has the oldest connection). That is unless you call
200	/// [`TCPClient::set_primary_stream`].
201	#[must_use]
202	pub fn new(
203		service_name: &'static str,
204		guard: impl Into<NagleGuard>,
205		nagle_hooks: (
206			Option<&'static dyn PreNagleFnTy>,
207			Option<&'static dyn PostNagleFnTy>,
208		),
209		trace_io_during_debug: bool,
210	) -> Self {
211		#[cfg(not(debug_assertions))]
212		{
213			if trace_io_during_debug {
214				warn!(
215					"Trace IO was turned on, but debug assertsions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
216				);
217			}
218		}
219
220		Self {
221			cat_dev_slowdown: None,
222			chunk_output_at_size: None,
223			keep_all_responses: false,
224			nagle_guard: guard.into(),
225			on_stream_begin: None,
226			on_stream_end: None,
227			pre_nagle_hook: nagle_hooks.0,
228			post_nagle_hook: nagle_hooks.1,
229			primary_stream_id: Arc::new(AtomicU64::new(0)),
230			service_name,
231			slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
232			streams: Arc::new(ConcurrentHashMap::default()),
233			#[cfg(debug_assertions)]
234			trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
235		}
236	}
237
238	/// Set the slowdown to before sending bytes from this server.
239	pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
240		self.cat_dev_slowdown = slowdown;
241	}
242
243	/// Mark that this client should keep all responses (e.g. those from the non
244	/// active upstreams).
245	pub const fn should_keep_all_responses(&mut self) {
246		self.keep_all_responses = true;
247	}
248
249	/// Set directly whether or not this client should keep all responses.
250	pub const fn set_keep_all_responses(&mut self, keep: bool) {
251		self.keep_all_responses = keep;
252	}
253
254	/// Set the primary stream to receive responses from.
255	pub fn set_primary_stream(&mut self, stream_id: u64) {
256		self.primary_stream_id.store(stream_id, Ordering::Release);
257	}
258
259	#[must_use]
260	pub const fn chunk_output_at_size(&self) -> Option<usize> {
261		self.chunk_output_at_size
262	}
263
264	pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
265		self.chunk_output_at_size = new_size;
266	}
267
268	#[must_use]
269	pub const fn slowloris_timeout(&self) -> Duration {
270		self.slowloris_timeout
271	}
272	pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
273		self.slowloris_timeout = slowloris_timeout;
274	}
275
276	#[must_use]
277	pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<()>> {
278		self.on_stream_begin.as_ref()
279	}
280
281	/// Set a hook to run when a stream has been created.
282	///
283	/// This is what happens when a new machine connects. So it may also be
284	/// refered to as "on connect". This assumes you're assigning a service
285	/// that already exists and is in the raw storage type, you may want to
286	/// look into [`Self::set_on_stream_begin`], or
287	/// [`Self::set_on_stream_begin_service`].
288	///
289	/// ## Errors
290	///
291	/// If the stream beginning hook has already been registered. If you're
292	/// looking to perform multiple actions at once, look into layer-ing.
293	pub fn set_raw_on_stream_begin(
294		&mut self,
295		on_start: Option<UnderlyingOnStreamBeginService<()>>,
296	) -> Result<(), CommonNetAPIError> {
297		if self.on_stream_begin.is_some() {
298			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
299		}
300
301		self.on_stream_begin = on_start;
302		Ok(())
303	}
304
305	/// Set a function hook to run when a stream has been created.
306	///
307	/// This is what happens when a new machine connects. So it may also be
308	/// refered to as "on connect". This assumes you're assigning a function
309	/// to on stream begin otherwise use [`Self::set_raw_on_stream_begin`],
310	/// or [`Self::set_on_stream_begin_service`].
311	///
312	/// ## Errors
313	///
314	/// If the stream beginning hook has already been registered. If you're
315	/// looking to perform multiple actions at once, look into layer-ing.
316	pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
317		&mut self,
318		handler: HandlerTy,
319	) -> Result<(), CommonNetAPIError>
320	where
321		HandlerParamsTy: Send + 'static,
322		HandlerTy: OnRequestStreamBeginHandler<HandlerParamsTy, ()> + Clone + Send + 'static,
323	{
324		if self.on_stream_begin.is_some() {
325			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
326		}
327
328		let boxed = BoxCloneService::new(OnStreamBeginHandlerAsService::new(handler));
329		self.on_stream_begin = Some(boxed);
330		Ok(())
331	}
332
333	/// Set a function hook to run when a stream has been created.
334	///
335	/// This is what happens when a new machine connects. So it may also be
336	/// refered to as "on connect". This assumes you're assigning a [`Service`]
337	/// to on stream begin otherwise use [`Self::set_raw_on_stream_begin`],
338	/// or [`Self::set_on_stream_begin`].
339	///
340	/// ## Errors
341	///
342	/// If the stream beginning hook has already been registered. If you're
343	/// looking to perform multiple actions at once, look into layer-ing.
344	pub fn set_on_stream_begin_service<ServiceTy>(
345		&mut self,
346		service_ty: ServiceTy,
347	) -> Result<(), CommonNetAPIError>
348	where
349		ServiceTy: Clone
350			+ Send
351			+ Service<RequestStreamEvent<()>, Response = bool, Error = CatBridgeError>
352			+ 'static,
353		ServiceTy::Future: Send + 'static,
354	{
355		if self.on_stream_begin.is_some() {
356			return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
357		}
358
359		self.on_stream_begin = Some(BoxCloneService::new(service_ty));
360		Ok(())
361	}
362
363	/// Add a layer to the service to process when a stream begins, or a new
364	/// connection is created.
365	///
366	/// ## Errors
367	///
368	/// If there is no on stream begin handler that is currently active.
369	pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
370		&mut self,
371		layer: LayerTy,
372	) -> Result<(), CommonNetAPIError>
373	where
374		LayerTy: Layer<UnderlyingOnStreamBeginService<()>, Service = ServiceTy>,
375		ServiceTy: Service<RequestStreamEvent<()>, Response = bool, Error = CatBridgeError>
376			+ Clone
377			+ Send
378			+ 'static,
379		<LayerTy::Service as Service<RequestStreamEvent<()>>>::Future: Send + 'static,
380	{
381		let Some(srvc) = self.on_stream_begin.take() else {
382			return Err(CommonNetAPIError::OnStreamBeginNotRegistered);
383		};
384
385		self.on_stream_begin = Some(BoxCloneService::new(layer.layer(srvc)));
386		Ok(())
387	}
388
389	#[must_use]
390	pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<()>> {
391		self.on_stream_end.as_ref()
392	}
393
394	/// Set a hook to run when a stream is being destroyed.
395	///
396	/// This is what happens when a machine disconnects. So it may also be
397	/// refered to as "on disconnect". This assumes you're assigning a service
398	/// that already exists and is in the raw storage type, you may want to
399	/// look into [`Self::set_on_stream_end`], or
400	/// [`Self::set_on_stream_end_service`].
401	///
402	/// ## Errors
403	///
404	/// If the stream ending hook has already been registered. If you're
405	/// looking to perform multiple actions at once, look into layer-ing.
406	pub fn set_raw_on_stream_end(
407		&mut self,
408		on_end: Option<UnderlyingOnStreamEndService<()>>,
409	) -> Result<(), CommonNetAPIError> {
410		if self.on_stream_end.is_some() {
411			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
412		}
413
414		self.on_stream_end = on_end;
415		Ok(())
416	}
417
418	/// Set a function hook to run when a stream is being destroyed.
419	///
420	/// This is what happens when a machine disconnects. So it may also be
421	/// refered to as "on disconnect". This assumes you're assigning a function
422	/// to on stream end otherwise use [`Self::set_raw_on_stream_end`],
423	/// or [`Self::set_on_stream_end_service`].
424	///
425	/// ## Errors
426	///
427	/// If the stream ending hook has already been registered. If you're
428	/// looking to perform multiple actions at once, look into layer-ing.
429	pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
430		&mut self,
431		handler: HandlerTy,
432	) -> Result<(), CommonNetAPIError>
433	where
434		HandlerParamsTy: Send + 'static,
435		HandlerTy: OnRequestStreamEndHandler<HandlerParamsTy, ()> + Clone + Send + 'static,
436	{
437		if self.on_stream_end.is_some() {
438			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
439		}
440
441		let boxed = BoxCloneService::new(OnStreamEndHandlerAsService::new(handler));
442		self.on_stream_end = Some(boxed);
443		Ok(())
444	}
445
446	/// Set a function hook to run when a stream is being destroyed.
447	///
448	/// This is what happens when a machine disconnects. So it may also be
449	/// refered to as "on disconnect". This assumes you're assigning a [`Service`]
450	/// to on stream end otherwise use [`Self::set_raw_on_stream_end`],
451	/// or [`Self::set_on_stream_end`].
452	///
453	/// ## Errors
454	///
455	/// If the stream beginning hook has already been registered. If you're
456	/// looking to perform multiple actions at once, look into layer-ing.
457	pub fn set_on_stream_end_service<ServiceTy>(
458		&mut self,
459		service_ty: ServiceTy,
460	) -> Result<(), CommonNetAPIError>
461	where
462		ServiceTy: Clone
463			+ Send
464			+ Service<RequestStreamEvent<()>, Response = (), Error = CatBridgeError>
465			+ 'static,
466		ServiceTy::Future: Send + 'static,
467	{
468		if self.on_stream_end.is_some() {
469			return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
470		}
471
472		self.on_stream_end = Some(BoxCloneService::new(service_ty));
473		Ok(())
474	}
475
476	/// Add a layer to the service to process when a stream ends, or a new
477	/// connection is destroyed.
478	///
479	/// ## Errors
480	///
481	/// If there is no on stream end handler that is currently active.
482	pub fn layer_on_stream_end<LayerTy, ServiceTy>(
483		&mut self,
484		layer: LayerTy,
485	) -> Result<(), CommonNetAPIError>
486	where
487		LayerTy: Layer<UnderlyingOnStreamEndService<()>, Service = ServiceTy>,
488		ServiceTy: Service<RequestStreamEvent<()>, Response = (), Error = CatBridgeError>
489			+ Clone
490			+ Send
491			+ 'static,
492		<LayerTy::Service as Service<RequestStreamEvent<()>>>::Future: Send + 'static,
493	{
494		let Some(srvc) = self.on_stream_end.take() else {
495			return Err(CommonNetAPIError::OnStreamEndNotRegistered);
496		};
497
498		self.on_stream_end = Some(BoxCloneService::new(layer.layer(srvc)));
499		Ok(())
500	}
501
502	/// "Bind" this client to listen on a specific address.
503	///
504	/// I know this is unusual for a "TCP Client", binding is for servers? I mean
505	/// who exactly would be binding to an address as a client? However, certain
506	/// protocols used for PCFS & the like are "reverse". E.g. the pc the server
507	/// connects to the "client" the MION listening on a server.
508	///
509	/// TCP Servers also have a "connect" method for this very reason.
510	///
511	/// ## Errors
512	///
513	/// If we cannot spin up a server to listen on this host. This doesn't mean
514	/// someone is connected, just that we're listening. You may want to use:
515	/// [`TCPClient::wait_for_connection`].
516	pub async fn bind<AddrTy: ToSocketAddrs>(&self, address: AddrTy) -> Result<(), CatBridgeError> {
517		let listener = TcpListener::bind(address).await.map_err(NetworkError::IO)?;
518
519		let client_address = listener.local_addr().map_err(NetworkError::IO)?;
520		let cloned_stream_begin = self.on_stream_begin.clone();
521		let cloned_stream_end = self.on_stream_end.clone();
522		let cloned_nagle_guard = self.nagle_guard.clone();
523		let cloned_slowerloris_timeout = self.slowloris_timeout;
524		let streams_ref = self.streams.clone();
525		let primary_stream_id_ref = self.primary_stream_id.clone();
526		let cloned_chunk_output_at_size = self.chunk_output_at_size;
527		let cloned_pre_nagle_hook = self.pre_nagle_hook;
528		let cloned_post_nagle_hook = self.post_nagle_hook;
529		#[cfg(debug_assertions)]
530		let cloned_trace = self.trace_during_debug;
531		let cloned_service_name = self.service_name;
532		let cloned_cat_dev_slowdown = self.cat_dev_slowdown;
533
534		TaskBuilder::new()
535			.name("cat_dev::net::tcp_client::bind().loop")
536			.spawn(async move {
537				loop {
538					let (stream, server_address) = match listener.accept().await {
539						Ok(tuple) => tuple,
540						Err(cause) => {
541							warn!(
542								?cause,
543								client.address = %client_address,
544								"cat_dev::net::tcp_client::bind(): Failed to accept connection!",
545							);
546							continue;
547						}
548					};
549					let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
550
551					let cloned_cloned_stream_begin = cloned_stream_begin.clone();
552					let cloned_cloned_stream_end = cloned_stream_end.clone();
553					let cloned_cloned_nagle_guard = cloned_nagle_guard.clone();
554					let cloned_streams_ref = streams_ref.clone();
555					let cloned_primary_stream_id_ref = primary_stream_id_ref.clone();
556
557					if let Err(cause) = TaskBuilder::new()
558						.name("cat_dev::net::tcp_client::bind().connection.handle")
559						.spawn(async move {
560							if let Err(cause) = Self::handle_tcp_stream(
561								stream,
562								stream_id,
563								server_address,
564								cloned_cloned_stream_begin,
565								cloned_cloned_stream_end,
566								cloned_cloned_nagle_guard,
567								cloned_slowerloris_timeout,
568								cloned_streams_ref,
569								cloned_primary_stream_id_ref,
570								cloned_chunk_output_at_size,
571								cloned_pre_nagle_hook,
572								cloned_post_nagle_hook,
573								cloned_cat_dev_slowdown,
574								#[cfg(debug_assertions)]
575								cloned_trace,
576							)
577							.instrument(error_span!(
578								"CatDevTCPClientConnect",
579								client.address = %client_address,
580								server.address = %server_address,
581								client.service = cloned_service_name,
582								stream.id = stream_id,
583								stream.stream_type = "client",
584							))
585							.await
586							{
587								warn!(
588									?cause,
589									client.address = %client_address,
590									server.address = %server_address,
591									client.service = cloned_service_name,
592									"Error escaped while handling TCP Connection.",
593								);
594							}
595						}) {
596						warn!(
597							?cause,
598							client.address = %client_address,
599							server.address = %server_address,
600							client.service = cloned_service_name,
601							"Error handling client connection, no task could be allocated.",
602						);
603					}
604
605					trace!(
606						server.address = %server_address,
607						client.address = %client_address,
608						"cat_dev::net::tcp_client::bind(): received connection (listener.accept())",
609					);
610				}
611			})
612			.map_err(CatBridgeError::SpawnFailure)?;
613
614		Ok(())
615	}
616
617	/// This will be an async function that will not return until at least one
618	/// server has connected to our client.
619	pub async fn wait_for_connection(&self) {
620		// Loop until we have an active....
621		while self.get_active_sid().await.is_err() {
622			sleep(Duration::from_secs(1)).await;
623		}
624	}
625
626	/// Connect to a new server as a raw TCP client.
627	///
628	/// This method will return the stream id to use for later requests to
629	/// specific streams.
630	///
631	/// ## Errors
632	///
633	/// If we can't connect to the remote TCP Server completing the three way
634	/// handshake, or if we the stream begin handler returns an error/failure of
635	/// some kind.
636	pub async fn connect<AddrTy: ToSocketAddrs>(
637		&self,
638		address: AddrTy,
639	) -> Result<u64, CatBridgeError> {
640		let raw_stream = TcpStream::connect(address)
641			.await
642			.map_err(NetworkError::IO)?;
643		let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
644		let remote_address = raw_stream.peer_addr().map_err(NetworkError::IO)?;
645		let local_address = raw_stream.local_addr().map_err(NetworkError::IO)?;
646		trace!(
647			server.address = %remote_address,
648			client.address = %local_address,
649			stream.id = stream_id,
650			stream.stream_type = "client",
651			"cat_dev::net::tcp_client::connect(): started connection (TcpStream::connect())",
652		);
653
654		let cloned_stream_begin = self.on_stream_begin.clone();
655		let cloned_stream_end = self.on_stream_end.clone();
656		let cloned_nagle_guard = self.nagle_guard.clone();
657		let cloned_slowerloris_timeout = self.slowloris_timeout;
658		let streams_ref = self.streams.clone();
659		let primary_stream_id_ref = self.primary_stream_id.clone();
660		let cloned_chunk_output_at_size = self.chunk_output_at_size;
661		let cloned_pre_nagle_hook = self.pre_nagle_hook;
662		let cloned_post_nagle_hook = self.post_nagle_hook;
663		#[cfg(debug_assertions)]
664		let cloned_trace = self.trace_during_debug;
665		let cloned_service_name = self.service_name;
666		let cloned_cat_dev_slowdown = self.cat_dev_slowdown;
667
668		TaskBuilder::new()
669			.name("cat_dev::net::tcp_client::connect().connection.handle")
670			.spawn(async move {
671				if let Err(cause) = Self::handle_tcp_stream(
672					raw_stream,
673					stream_id,
674					remote_address,
675					cloned_stream_begin,
676					cloned_stream_end,
677					cloned_nagle_guard,
678					cloned_slowerloris_timeout,
679					streams_ref,
680					primary_stream_id_ref,
681					cloned_chunk_output_at_size,
682					cloned_pre_nagle_hook,
683					cloned_post_nagle_hook,
684					cloned_cat_dev_slowdown,
685					#[cfg(debug_assertions)]
686					cloned_trace,
687				)
688				.instrument(error_span!(
689					"CatDevTCPClientConnect",
690					client.address = %local_address,
691					server.address = %remote_address,
692					client.service = cloned_service_name,
693					stream.id = stream_id,
694					stream.stream_type = "client",
695				))
696				.await
697				{
698					warn!(
699						?cause,
700						client.address = %local_address,
701						server.address = %remote_address,
702						client.service = cloned_service_name,
703						"Error escaped while handling TCP Connection.",
704					);
705				}
706			})
707			.map_err(CatBridgeError::SpawnFailure)?;
708
709		Ok(stream_id)
710	}
711
712	/// Send a series of bytes over the wire potentially receiving responses back.
713	///
714	/// This will always only take the response from the 'primary server'. Even if
715	/// other servers respond first or at all. It will always be the primary
716	/// response.
717	///
718	/// If you want to truly access all the client streams at once use
719	/// [`Self::broadcast_send`] to send to all, and receive all their responses
720	/// back.
721	///
722	/// This _will_ return the stream that was used as the 'primary', the
723	/// request-id to wait for a response later or otherwise, and the optional
724	/// response if we waited for one.
725	///
726	/// ## Errors
727	///
728	/// This function will error if we run into any issues writing or reading the
729	/// bytes from the stream in a timely manner.
730	pub async fn send<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
731		&self,
732		body: BodyTy,
733		wait_for_response_timeout: Option<Duration>,
734	) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
735		// This will be cloned, and modified for each stream we send out too.
736		let mut request = Request::new_with_state(
737			body.try_into().map_err(|cause| {
738				CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
739			})?,
740			SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
741			(),
742			None,
743		);
744		let req_id = RequestID::generate();
745		request.extensions_mut().insert(req_id.clone());
746
747		self.common_send(request, req_id, wait_for_response_timeout)
748			.await
749	}
750
751	/// Send a series of bytes over the wire potentially receiving responses back.
752	///
753	/// This will always only take the response from the 'primary server'. Even if
754	/// other servers respond first or at all. It will always be the primary
755	/// response.
756	///
757	/// If you want to truly access all the client streams at once use
758	/// [`Self::broadcast_send`] to send to all, and receive all their responses
759	/// back.
760	///
761	/// This _will_ return the stream that was used as the 'primary', the
762	/// request-id to wait for a response later or otherwise, and the optional
763	/// response if we waited for one.
764	///
765	/// ## Errors
766	///
767	/// This function will error if we run into any issues writing or reading the
768	/// bytes from the stream in a timely manner.
769	pub async fn send_with_read_amount<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
770		&self,
771		body: BodyTy,
772		wait_for_response_timeout: Option<Duration>,
773		explicit_read_amount: usize,
774	) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
775		// This will be cloned, and modified for each stream we send out too.
776		let mut request = Request::new_with_state_and_read_amount(
777			body.try_into().map_err(|cause| {
778				CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
779			})?,
780			SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
781			(),
782			None,
783			explicit_read_amount,
784		);
785		let req_id = RequestID::generate();
786		request.extensions_mut().insert(req_id.clone());
787
788		self.common_send(request, req_id, wait_for_response_timeout)
789			.await
790	}
791
792	/// The equivalent of [`send`], but get all the responses back out of this
793	/// client.
794	///
795	/// ## Errors
796	///
797	/// If we timeout, or run into any sort of error sending or receiving content
798	/// from a stream.
799	pub async fn broadcast_send<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
800		&self,
801		body: BodyTy,
802		wait_for_response_timeout: Duration,
803	) -> Result<FnvHashMap<u64, Option<Response>>, CatBridgeError> {
804		// This will be cloned, and modified for each stream we send out too.
805		let mut request = Request::new_with_state(
806			body.try_into().map_err(|cause| {
807				CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
808			})?,
809			SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
810			(),
811			None,
812		);
813		let req_id = RequestID::generate();
814		request.extensions_mut().insert(req_id.clone());
815
816		let mut ids = FnvHashSet::default();
817		self.streams
818			.iter_async(|stream_id, _stream| {
819				ids.insert(*stream_id);
820				true
821			})
822			.await;
823
824		let mut tasks = Vec::with_capacity(ids.len());
825		for id in &ids {
826			tasks.push(self.send_to_stream(*id, request.clone(), wait_for_response_timeout));
827		}
828		// join all gives us a Vec<Result>, we can collect that into one final
829		// result, to use.
830		join_all(tasks)
831			.await
832			.into_iter()
833			.collect::<Result<(), NetworkError>>()?;
834
835		let mut response_tasks = Vec::with_capacity(ids.len());
836		for id in &ids {
837			response_tasks.push(self.get_response_from_stream(*id, req_id.clone()));
838		}
839		let responses = timeout(wait_for_response_timeout, join_all(response_tasks))
840			.await
841			.map_err(|_| NetworkError::Timeout(wait_for_response_timeout))?;
842
843		let mut map =
844			FnvHashMap::with_capacity_and_hasher(ids.len(), BuildHasherDefault::default());
845		for (got_stream_id, response) in responses {
846			map.insert(got_stream_id, response);
847		}
848		Ok(map)
849	}
850
851	/// Receive a packet from the primary stream.
852	///
853	/// Used for out-of-band receiving of packets that don't have an associated
854	/// [`Self::send`] call.
855	///
856	/// ## Errors
857	///
858	/// If we timeout, or have another series of failures reading bytes from our
859	/// stream.
860	pub async fn receive(&self, wait_until: Duration) -> Result<Option<Response>, NetworkError> {
861		let active_sid = self.get_active_sid().await?;
862
863		let mut tasks;
864		if self.keep_all_responses {
865			tasks = vec![self.get_any_response_from_stream(active_sid)];
866		} else {
867			let mut ids = FnvHashSet::default();
868			self.streams
869				.iter_async(|stream_id, _stream| {
870					ids.insert(*stream_id);
871					true
872				})
873				.await;
874
875			tasks = Vec::with_capacity(ids.len());
876			for id in ids {
877				tasks.push(self.get_any_response_from_stream(id));
878			}
879		}
880		let responses = timeout(wait_until, join_all(tasks))
881			.await
882			.map_err(|_| NetworkError::Timeout(wait_until))?;
883
884		for (got_stream_id, _, response) in responses {
885			if got_stream_id == active_sid {
886				return Ok(response);
887			}
888		}
889
890		Ok(None)
891	}
892
893	/// Get all the responses for a particular request id.
894	///
895	/// In order to keep memory from ballooning up, and the fact that TCP is
896	/// ordered. This API also expects to be called in an 'ordered' way. Previous
897	/// requests that do not match the request id will be dropped.
898	pub async fn take_all_response_for_request_id(
899		&self,
900		request_id: &RequestID,
901		wait_for: Duration,
902	) -> FnvHashMap<u64, Option<Response>> {
903		let mut ids = FnvHashSet::default();
904		self.streams
905			.iter_async(|stream_id, _stream| {
906				ids.insert(*stream_id);
907				true
908			})
909			.await;
910
911		let mut tasks = Vec::with_capacity(ids.len());
912		for id in &ids {
913			tasks.push(timeout(
914				wait_for,
915				self.get_response_from_stream(*id, request_id.clone()),
916			));
917		}
918
919		let mut results: FnvHashMap<u64, Option<Response>> =
920			join_all(tasks).await.into_iter().flatten().collect();
921		for id in ids {
922			results.entry(id).or_insert(None);
923		}
924		results
925	}
926
927	async fn common_send(
928		&self,
929		mock_req: Request<()>,
930		req_id: RequestID,
931		wait_for_response_timeout: Option<Duration>,
932	) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
933		let active_sid = self.get_active_sid().await?;
934
935		let mut ids = FnvHashSet::default();
936		self.streams
937			.iter_async(|stream_id, _stream| {
938				ids.insert(*stream_id);
939				true
940			})
941			.await;
942
943		let mut tasks = Vec::with_capacity(ids.len());
944		for id in &ids {
945			tasks.push(self.send_to_stream(
946				*id,
947				mock_req.clone(),
948				wait_for_response_timeout.unwrap_or(DEFAULT_SLOWLORIS_TIMEOUT),
949			));
950		}
951		// join all gives us a Vec<Result>, we can collect that into one final
952		// result, to use.
953		join_all(tasks)
954			.await
955			.into_iter()
956			.collect::<Result<(), NetworkError>>()?;
957
958		match wait_for_response_timeout {
959			// Don't drain/wait for responses when there are none.
960			None | Some(EMPTY_TIMEOUT) => Ok((active_sid, req_id, None)),
961			Some(duration) => {
962				let mut tasks;
963				// If we keep all responsese
964				if self.keep_all_responses {
965					tasks = vec![self.get_response_from_stream(active_sid, req_id.clone())];
966				} else {
967					tasks = Vec::with_capacity(ids.len());
968					for id in ids {
969						tasks.push(self.get_response_from_stream(id, req_id.clone()));
970					}
971				}
972				let responses = timeout(duration, join_all(tasks))
973					.await
974					.map_err(|_| NetworkError::Timeout(duration))?;
975
976				for (got_stream_id, response) in responses {
977					if got_stream_id == active_sid {
978						return Ok((active_sid, req_id, response));
979					}
980				}
981
982				Ok((active_sid, req_id, None))
983			}
984		}
985	}
986
987	#[allow(
988		// all of our parameters are very well named, and types are not close to
989		// overlapping with each other.
990		//
991		// we also just fundamenetally have a lot of state thanks to the complexity
992		// of all the things we have to handle for a TCP connection, e.g. NAGLE,
993		// delimiters, caches, etc.
994		//
995		// it is also only ever called from one internal function, so it's not like
996		// part of our public facing api.
997		clippy::too_many_arguments,
998	)]
999	async fn handle_tcp_stream(
1000		mut stream: TcpStream,
1001		stream_id: u64,
1002		remote_address: SocketAddr,
1003		on_stream_begin: Option<UnderlyingOnStreamBeginService<()>>,
1004		on_stream_end: Option<UnderlyingOnStreamEndService<()>>,
1005		nagle_guard: NagleGuard,
1006		slowloris_timeout: Duration,
1007		stream_lists: Arc<ConcurrentHashMap<u64, TCPClientStream>>,
1008		active_stream_ptr: Arc<AtomicU64>,
1009		chunk_output_on_size: Option<usize>,
1010		pre_hook: Option<&'static dyn PreNagleFnTy>,
1011		post_hook: Option<&'static dyn PostNagleFnTy>,
1012		cat_dev_slowdown: Option<Duration>,
1013		#[cfg(debug_assertions)] trace_io: bool,
1014	) -> Result<(), CatBridgeError> {
1015		// We drop the 'sender' to cancel the stream. This means we can't hold a
1016		// copy of any 'Sender' for a long life. So we use this small little block
1017		// to do that for us.
1018		let mut receive_packets_to_send: BoundedReceiver<RequestStreamMessage>;
1019		let (response_sink_send, response_sink_recv) = bounded_channel(128);
1020		{
1021			let (mut sender, receiver) = bounded_channel(128);
1022
1023			// First perform any initialization necessary....
1024			//
1025			// And make sure they tell us to continue before doing stuff.
1026			if Self::initialize_stream(
1027				on_stream_begin,
1028				&mut sender,
1029				&remote_address,
1030				&stream,
1031				stream_id,
1032			)
1033			.await?
1034			{
1035				return Ok(());
1036			}
1037
1038			let mut active_stream =
1039				TCPClientStream::new(remote_address, sender, receiver, response_sink_recv);
1040			receive_packets_to_send = active_stream
1041				.steal_send_requests_receiver()
1042				.ok_or_else(|| CatBridgeError::ClosedChannel)?;
1043
1044			std::mem::drop(stream_lists.insert_async(stream_id, active_stream).await);
1045			// Update the active stream pointer if need be.
1046			_ = active_stream_ptr.compare_exchange(
1047				0,
1048				stream_id,
1049				Ordering::AcqRel,
1050				Ordering::Acquire,
1051			);
1052		}
1053
1054		// Connection has been "approved", setup the on disconnect handler.
1055		let _guard = on_stream_end
1056			.map(|service| DisconnectAsyncDropClient::new(service, (), remote_address, stream_id));
1057
1058		let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
1059		// Any previously saved data that was a victim of NAGLE's algorithim, or
1060		// similar.
1061		let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;
1062		let mut cached_request_id: Option<RequestID> = None;
1063		let mut nagle_overrides: VecDeque<Option<NagleGuard>> = VecDeque::with_capacity(128);
1064
1065		loop {
1066			tokio::select! {
1067				opt = receive_packets_to_send.recv() => {
1068					// Sender is closed, shutdown our channel cleanly.
1069					if Self::handle_client_write_to_connection(
1070						chunk_output_on_size,
1071						opt,
1072						pre_hook,
1073						&mut cached_request_id,
1074						stream_id,
1075						&mut stream,
1076						&mut nagle_overrides,
1077						cat_dev_slowdown,
1078						#[cfg(debug_assertions)]
1079						trace_io,
1080					).await? {
1081						break;
1082					}
1083				}
1084				read_res = stream.read_buf(&mut buff) => {
1085					let read_bytes = read_res.map_err(NetworkError::IO)?;
1086					buff.truncate(read_bytes);
1087
1088					if buff.is_empty() {
1089						buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
1090						continue;
1091					}
1092
1093					if Self::handle_client_read_from_connection(
1094						buff,
1095						&nagle_guard,
1096						&mut nagle_overrides,
1097						slowloris_timeout,
1098						&mut nagle_cache,
1099						response_sink_send.clone(),
1100						post_hook,
1101						&mut cached_request_id,
1102						stream_id,
1103						#[cfg(debug_assertions)]
1104						trace_io,
1105					).await? {
1106						break;
1107					}
1108					buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
1109				}
1110			}
1111		}
1112
1113		Ok(())
1114	}
1115
1116	async fn initialize_stream(
1117		on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<()>>,
1118		send_channel: &mut BoundedSender<RequestStreamMessage>,
1119		remote_address: &SocketAddr,
1120		tcp_stream: &TcpStream,
1121		stream_id: u64,
1122	) -> Result<bool, CatBridgeError> {
1123		tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;
1124
1125		if let Some(mut handle) = on_stream_begin_handler
1126			&& !handle
1127				.call(RequestStreamEvent::new_with_state(
1128					send_channel.clone(),
1129					*remote_address,
1130					Some(stream_id),
1131					(),
1132				))
1133				.await?
1134		{
1135			trace!("handler failed on stream begin hook");
1136			return Ok(true);
1137		}
1138
1139		Ok(false)
1140	}
1141
	/// Process one batch of bytes read from the socket, turning every complete
	/// packet in it into a response.
	///
	/// Steps, in order:
	/// 1. Run the optional post-nagle hook over the raw bytes.
	/// 2. Prepend any partial packet left over from earlier reads
	///    (`nagle_cache`). Bail out if it has been waiting longer than
	///    `slowloris_timeout`.
	/// 3. Repeatedly split complete packets off the buffer, and send each one to
	///    `response_output`.
	///    - The guard used is the one at the front of `nagle_overrides`, or the
	///      stream's default `nagle_guard` when there is none.
	///    - The first packet carries `cached_request_id`; later ones get `None`.
	///    - One override, if any remain, is popped per packet.
	/// 4. Stash whatever is left into `nagle_cache` for the next read.
	///
	/// Returns `Ok(true)` if the stream should be closed (slowloris stall
	/// detected), or `Ok(false)` to keep going.
	///
	/// ## Errors
	///
	/// If the nagle guard fails to split the buffer.
	#[allow(
		// All of our types are very differently typed, and well named, so chance
		// of confusion is low.
		//
		// Not to mention this is an internal only method.
		clippy::too_many_arguments,
	)]
	async fn handle_client_read_from_connection<'data>(
		mut buff: BytesMut,
		nagle_guard: &'data NagleGuard,
		nagle_overrides: &mut VecDeque<Option<NagleGuard>>,
		slowloris_timeout: Duration,
		nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
		response_output: BoundedSender<(Option<RequestID>, Response)>,
		cloned_post_nagle: Option<&'static dyn PostNagleFnTy>,
		cached_request_id: &mut Option<RequestID>,
		stream_id: u64,
		#[cfg(debug_assertions)] trace_io: bool,
	) -> Result<bool, CatBridgeError> {
		// Let the post hook transform the raw bytes before any framing happens.
		if let Some(convert_fn) = cloned_post_nagle {
			buff = BytesMut::from(block_in_place(|| (*convert_fn)(stream_id, buff.freeze())));
		}

		#[cfg(debug_assertions)]
		{
			if trace_io {
				debug!(
					body.hex = format!("{:02x?}", buff),
					body.str = String::from_utf8_lossy(&buff).to_string(),
					"cat-dev-trace-input-tcp-client",
				);
			}
		}

		// We may be NAGLE'd, so we need to recover/split, and potentially buffer
		// any packets. Also watch out for slowloris-esque attacks.
		let start_time = now();
		if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
			// `duration_since` errors if the clock went backwards since the
			// partial packet was cached.
			//
			// Just treat it as 0.
			let total_duration = start_time
				.duration_since(old_start_time)
				.unwrap_or(Duration::from_secs(0));
			if total_duration > slowloris_timeout {
				debug!(
					cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
					"slowloris-detected",
				);
				return Ok(true);
			}

			// Older partial data goes first; the new bytes are appended after it.
			existing_buff.extend(buff);
			buff = existing_buff;
		}

		// The write path pushes one override entry per request written. So the
		// front entry is presumably the guard for the oldest outstanding request.
		// `None`, or an empty queue, means use the stream's default guard.
		let mut current_nagle_guard = if let Some(Some(guard)) = nagle_overrides.front() {
			guard
		} else {
			nagle_guard
		};

		while let Some((start_of_packet, end_of_packet)) = current_nagle_guard.split(&buff)? {
			// Carve `buff` into [bytes before packet][packet][rest]. The leading
			// bytes are discarded, and the rest is kept for the next iteration.
			let remaining_buff = buff.split_off(end_of_packet);
			let _start_of_buff = buff.split_to(start_of_packet);
			let req_body = buff.freeze();
			buff = remaining_buff;

			// `take` means only the first packet gets the cached request id. A
			// failed send is logged, not treated as fatal.
			if let Err(cause) = response_output
				.send((cached_request_id.take(), Response::new_with_body(req_body)))
				.await
			{
				warn!(
					?cause,
					"internal queue failure will not send disconnect/response."
				);
			}

			// This packet consumed the front override; move on to the next one.
			if !nagle_overrides.is_empty() {
				nagle_overrides.pop_front();
				current_nagle_guard = if let Some(Some(guard)) = nagle_overrides.front() {
					guard
				} else {
					nagle_guard
				};
			}
		}

		// Keep any trailing partial packet around for the next read.
		//
		// NOTE(review): the leftover is re-stashed with *this* read's
		// `start_time`, not the original `old_start_time`. So the slowloris check
		// above measures the gap since the previous read, not the total time
		// spent assembling one packet. Confirm this is intended.
		if !buff.is_empty() {
			_ = nagle_cache.insert((buff, start_time));
		}

		Ok(false)
	}
1237
1238	#[allow(
1239		// Well typed arguments, lots to do and all that.
1240		clippy::too_many_arguments,
1241	)]
1242	async fn handle_client_write_to_connection(
1243		chunk_output_on_size: Option<usize>,
1244		to_send_to_client_opt: Option<RequestStreamMessage>,
1245		pre_hook: Option<&'static dyn PreNagleFnTy>,
1246		cached_request_id: &mut Option<RequestID>,
1247		stream_id: u64,
1248		raw_stream: &mut TcpStream,
1249		nagle_overrides: &mut VecDeque<Option<NagleGuard>>,
1250		cat_dev_slowdown: Option<Duration>,
1251		#[cfg(debug_assertions)] trace_io: bool,
1252	) -> Result<bool, CatBridgeError> {
1253		let Some(to_send_to_client) = to_send_to_client_opt else {
1254			return Ok(true);
1255		};
1256
1257		match to_send_to_client {
1258			RequestStreamMessage::Disconnect => {
1259				// Clear a value, it's a disconnect.
1260				_ = cached_request_id.take();
1261				trace!("stream-disconnect-message");
1262				Ok(true)
1263			}
1264			RequestStreamMessage::Request(mut req) => {
1265				if let Some(explicit_read) = req.explicit_read_amount() {
1266					nagle_overrides.push_back(Some(NagleGuard::StaticSize(explicit_read)));
1267				} else {
1268					nagle_overrides.push_back(None);
1269				}
1270				if !req.body().is_empty() {
1271					if let Ok(req_id) = RequestID::from_request_parts(&mut req).await {
1272						_ = cached_request_id.insert(req_id);
1273					}
1274					let messages = if let Some(size) = chunk_output_on_size {
1275						req.body_owned()
1276							.chunks(size)
1277							.map(BytesMut::from)
1278							.collect::<Vec<_>>()
1279					} else {
1280						vec![BytesMut::from(req.body_owned())]
1281					};
1282
1283					for message in messages {
1284						#[cfg(debug_assertions)]
1285						if trace_io {
1286							debug!(
1287								body.hex = format!("{message:02x?}"),
1288								body.str = String::from_utf8_lossy(&message).to_string(),
1289								"cat-dev-trace-output-tcp-client",
1290							);
1291						}
1292
1293						let mut full_response = message.clone();
1294						if let Some(pre) = pre_hook {
1295							block_in_place(|| pre(stream_id, &mut full_response));
1296						}
1297						if let Some(slowdown) = cat_dev_slowdown {
1298							sleep(slowdown).await;
1299						}
1300
1301						raw_stream.writable().await.map_err(NetworkError::IO)?;
1302						raw_stream
1303							.write_all(&full_response)
1304							.await
1305							.map_err(NetworkError::IO)?;
1306					}
1307				}
1308
1309				Ok(false)
1310			}
1311		}
1312	}
1313
1314	/// Send a request to a stream if it exists. It may not exist.
1315	///
1316	/// ## Errors
1317	///
1318	/// If we fail to queue up a packet to be sent out over a stream.
1319	async fn send_to_stream(
1320		&self,
1321		stream_id: u64,
1322		mut base_request: Request<()>,
1323		timeout: Duration,
1324	) -> Result<(), NetworkError> {
1325		if let Some(stream) = self.streams.get_async(&stream_id).await {
1326			base_request.update_request_source(stream.server_address(), Some(stream_id));
1327			stream
1328				.send_timeout(RequestStreamMessage::Request(base_request), timeout)
1329				.await
1330				.map_err(|cause| {
1331					CommonNetClientNetworkError::CannotQueueSend(format!("{cause:?}")).into()
1332				})
1333		} else {
1334			// Stream must've gotten removed since we got our list.
1335			Ok(())
1336		}
1337	}
1338
1339	/// Receive a response from a particular stream if it exists, it may not
1340	/// exist.
1341	async fn get_any_response_from_stream(
1342		&self,
1343		stream_id: u64,
1344	) -> (u64, Option<RequestID>, Option<Response>) {
1345		if let Some(mut stream) = self.streams.get_async(&stream_id).await {
1346			let Some((opt_req_id, response)) = stream.response_channel_mut().recv().await else {
1347				return (stream_id, None, None);
1348			};
1349
1350			(stream_id, opt_req_id, Some(response))
1351		} else {
1352			// Stream must've gotten removed since we got our list.
1353			(stream_id, None, None)
1354		}
1355	}
1356
1357	/// Receive a response from a particular stream if it exists, it may not
1358	/// exist.
1359	async fn get_response_from_stream(
1360		&self,
1361		stream_id: u64,
1362		request_id: RequestID,
1363	) -> (u64, Option<Response>) {
1364		if let Some(mut stream) = self.streams.get_async(&stream_id).await {
1365			while let Some((opt_req_id, response)) = stream.response_channel_mut().recv().await {
1366				if let Some(got_req_id) = opt_req_id
1367					&& got_req_id == request_id
1368				{
1369					return (stream_id, Some(response));
1370				}
1371			}
1372
1373			(stream_id, None)
1374		} else {
1375			// Stream must've gotten removed since we got our list.
1376			(stream_id, None)
1377		}
1378	}
1379
1380	/// Get the active stream id to use for connections.
1381	async fn get_active_sid(&self) -> Result<u64, CommonNetClientNetworkError> {
1382		let active_sid = self.primary_stream_id.load(Ordering::Acquire);
1383		if active_sid == 0 {
1384			return Err(CommonNetClientNetworkError::NotConnectedToServer);
1385		}
1386
1387		if !self.streams.contains_async(&active_sid).await {
1388			let mut oldest_stream = None;
1389
1390			self.streams
1391				.iter_async(|stream_id, stream| {
1392					if let Some((_strm_id, strm_created_at)) = oldest_stream {
1393						if stream.opened_at() < strm_created_at {
1394							_ = oldest_stream.insert((*stream_id, stream.opened_at()));
1395						}
1396					} else {
1397						_ = oldest_stream.insert((*stream_id, stream.opened_at()));
1398					}
1399					true
1400				})
1401				.await;
1402		}
1403
1404		Ok(active_sid)
1405	}
1406}
1407
1408impl Debug for TCPClient {
1409	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
1410		let mut tcp_dbg_struct = fmt.debug_struct("TCPClient");
1411
1412		tcp_dbg_struct
1413			.field("cat_dev_slowdown", &self.cat_dev_slowdown)
1414			.field("chunk_output_at_size", &self.chunk_output_at_size)
1415			.field("keep_all_responses", &self.keep_all_responses)
1416			.field("nagle_guard", &self.nagle_guard)
1417			.field("has_on_stream_begin", &self.on_stream_begin.is_some())
1418			.field("has_on_stream_end", &self.on_stream_end.is_some())
1419			.field("has_pre_nagle_hook", &self.pre_nagle_hook.is_some())
1420			.field("has_post_nagle_hook", &self.post_nagle_hook.is_some())
1421			.field(
1422				"primary_stream_id",
1423				&self.primary_stream_id.load(Ordering::Relaxed),
1424			)
1425			.field("streams", &self.streams)
1426			.field("service_name", &self.service_name)
1427			.field("slowloris_timeout", &self.slowloris_timeout);
1428
1429		#[cfg(debug_assertions)]
1430		{
1431			tcp_dbg_struct.field("trace_during_debug", &self.trace_during_debug);
1432		}
1433
1434		tcp_dbg_struct.finish()
1435	}
1436}
1437
/// Field names reported for [`TCPClient`] through `valuable`.
///
/// `Valuable::visit` for `TCPClient` pairs these names with its values by
/// position. Keep the order, and the `cfg(debug_assertions)`-gated last entry,
/// in sync with that value list.
const TCP_CLIENT_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("cat_dev_slowdown"),
	NamedField::new("chunk_output_at_size"),
	NamedField::new("keep_all_responses"),
	NamedField::new("nagle_guard"),
	NamedField::new("has_on_stream_begin"),
	NamedField::new("has_on_stream_end"),
	NamedField::new("has_pre_nagle_hook"),
	NamedField::new("has_post_nagle_hook"),
	NamedField::new("primary_stream_id"),
	NamedField::new("streams"),
	NamedField::new("service_name"),
	NamedField::new("slowloris_timeout"),
	#[cfg(debug_assertions)]
	NamedField::new("trace_during_debug"),
];
1454
1455impl Structable for TCPClient {
1456	fn definition(&self) -> StructDef<'_> {
1457		StructDef::new_static("TCPClient", Fields::Named(TCP_CLIENT_FIELDS))
1458	}
1459}
1460
1461impl Valuable for TCPClient {
1462	fn as_value(&self) -> Value<'_> {
1463		Value::Structable(self)
1464	}
1465
1466	fn visit(&self, visitor: &mut dyn Visit) {
1467		let mut valuable_map = FnvHashMap::default();
1468		self.streams.iter_sync(|stream_id, stream| {
1469			valuable_map.insert(*stream_id, stream.to_valuable());
1470			true
1471		});
1472
1473		visitor.visit_named_fields(&NamedValues::new(
1474			TCP_CLIENT_FIELDS,
1475			&[
1476				Valuable::as_value(&if let Some(slowdown) = self.cat_dev_slowdown {
1477					format!("{}ms", slowdown.as_millis())
1478				} else {
1479					"<none>".to_string()
1480				}),
1481				Valuable::as_value(&self.chunk_output_at_size),
1482				Valuable::as_value(&self.keep_all_responses),
1483				Valuable::as_value(&self.nagle_guard),
1484				Valuable::as_value(&self.on_stream_begin.is_some()),
1485				Valuable::as_value(&self.on_stream_end.is_some()),
1486				Valuable::as_value(&self.pre_nagle_hook.is_some()),
1487				Valuable::as_value(&self.post_nagle_hook.is_some()),
1488				Valuable::as_value(&self.primary_stream_id.load(Ordering::Relaxed)),
1489				Valuable::as_value(&valuable_map),
1490				Valuable::as_value(&self.service_name),
1491				Valuable::as_value(&self.slowloris_timeout.as_secs()),
1492				#[cfg(debug_assertions)]
1493				Valuable::as_value(&self.trace_during_debug),
1494			],
1495		));
1496	}
1497}
1498
/// An active TCP Client stream.
///
/// This represents a stream that is actively processing. When it drops (and by
/// proxy drops its Sender), the task will notice its receiver has closed, and
/// will automatically close itself.
///
/// NOTE(review): the `on_stream_begin` hook is handed a clone of the sender.
/// If a hook keeps that clone alive, dropping this struct presumably will not
/// close the channel on its own.
struct TCPClientStream {
	/// The remote address of this stream.
	remote_address: SocketAddr,
	/// The actual responses coming from this TCP Stream.
	response_channel: BoundedReceiver<(Option<RequestID>, Response)>,
	/// The receiver for this particular stream. Gets stolen by the worker
	/// actually handling tasks.
	send_requests_receiver: Option<BoundedReceiver<RequestStreamMessage>>,
	/// The sender to send requests to this stream.
	send_requests: BoundedSender<RequestStreamMessage>,
	/// The instant this stream was opened, used for ordering streams (e.g.
	/// picking the oldest stream as the new primary).
	time_opened: Instant,
}
1517
1518impl TCPClientStream {
1519	/// Create a new TCP Client stream given a sender/receiver pair.
1520	#[must_use]
1521	pub fn new(
1522		remote_address: SocketAddr,
1523		sender: BoundedSender<RequestStreamMessage>,
1524		receiver: BoundedReceiver<RequestStreamMessage>,
1525		response_channel: BoundedReceiver<(Option<RequestID>, Response)>,
1526	) -> Self {
1527		Self {
1528			remote_address,
1529			response_channel,
1530			send_requests_receiver: Some(receiver),
1531			send_requests: sender,
1532			time_opened: Instant::now(),
1533		}
1534	}
1535
1536	#[must_use]
1537	pub const fn to_valuable(&self) -> TCPClientStreamValuable {
1538		TCPClientStreamValuable {
1539			receiver_stolen: self.send_requests_receiver.is_none(),
1540			time_opened: self.time_opened,
1541		}
1542	}
1543
1544	/// The address of the server on the other side.
1545	pub const fn server_address(&self) -> SocketAddr {
1546		self.remote_address
1547	}
1548
1549	#[must_use]
1550	pub const fn response_channel_mut(
1551		&mut self,
1552	) -> &mut BoundedReceiver<(Option<RequestID>, Response)> {
1553		&mut self.response_channel
1554	}
1555
1556	/// Steal the receiver if it already hasn't been stolen already.
1557	#[must_use]
1558	pub fn steal_send_requests_receiver(
1559		&mut self,
1560	) -> Option<BoundedReceiver<RequestStreamMessage>> {
1561		self.send_requests_receiver.take()
1562	}
1563
1564	/// Send a message to the client stream with a given timeout.
1565	pub async fn send_timeout(
1566		&self,
1567		message: RequestStreamMessage,
1568		timeout: Duration,
1569	) -> Result<(), SendTimeoutError<RequestStreamMessage>> {
1570		self.send_requests.send_timeout(message, timeout).await
1571	}
1572
1573	/// When this stream was opened.
1574	pub const fn opened_at(&self) -> Instant {
1575		self.time_opened
1576	}
1577}
1578
1579impl Debug for TCPClientStream {
1580	fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
1581		fmt.debug_struct("TCPClientStream")
1582			.field("receiver_stolen", &self.send_requests_receiver.is_none())
1583			.field("time_opened", &self.time_opened)
1584			.finish_non_exhaustive()
1585	}
1586}
1587
1588impl PartialEq for TCPClientStream {
1589	fn eq(&self, other: &Self) -> bool {
1590		self.time_opened == other.time_opened
1591	}
1592}
1593
1594impl PartialOrd for TCPClientStream {
1595	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1596		Some(self.time_opened.cmp(&other.time_opened))
1597	}
1598}
1599
/// Field names reported through `valuable` for both [`TCPClientStream`] and
/// [`TCPClientStreamValuable`].
///
/// Each `visit` implementation pairs these with its values by position, so keep
/// the order in sync.
const TCP_CLIENT_STREAM_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("receiver_stolen"),
	NamedField::new("time_opened"),
];
1604
1605impl Structable for TCPClientStream {
1606	fn definition(&self) -> StructDef<'_> {
1607		StructDef::new_static("TCPClientStream", Fields::Named(TCP_CLIENT_STREAM_FIELDS))
1608	}
1609}
1610
1611impl Valuable for TCPClientStream {
1612	fn as_value(&self) -> Value<'_> {
1613		Value::Structable(self)
1614	}
1615
1616	fn visit(&self, visitor: &mut dyn Visit) {
1617		visitor.visit_named_fields(&NamedValues::new(
1618			TCP_CLIENT_STREAM_FIELDS,
1619			&[
1620				Valuable::as_value(&self.send_requests_receiver.is_none()),
1621				Valuable::as_value(
1622					&SystemTime::now()
1623						.checked_add(self.time_opened.elapsed())
1624						.unwrap_or_else(SystemTime::now)
1625						.duration_since(SystemTime::UNIX_EPOCH)
1626						.unwrap_or_default()
1627						.as_secs(),
1628				),
1629			],
1630		));
1631	}
1632}
1633
/// An owned snapshot of a [`TCPClientStream`]'s loggable state, produced by
/// `TCPClientStream::to_valuable`.
struct TCPClientStreamValuable {
	/// Whether the worker has already taken the stream's request receiver.
	receiver_stolen: bool,
	/// When the stream was opened.
	time_opened: Instant,
}
1638
1639impl Structable for TCPClientStreamValuable {
1640	fn definition(&self) -> StructDef<'_> {
1641		StructDef::new_static("TCPClientStream", Fields::Named(TCP_CLIENT_STREAM_FIELDS))
1642	}
1643}
1644
1645impl Valuable for TCPClientStreamValuable {
1646	fn as_value(&self) -> Value<'_> {
1647		Value::Structable(self)
1648	}
1649
1650	fn visit(&self, visitor: &mut dyn Visit) {
1651		visitor.visit_named_fields(&NamedValues::new(
1652			TCP_CLIENT_STREAM_FIELDS,
1653			&[
1654				Valuable::as_value(&self.receiver_stolen),
1655				Valuable::as_value(
1656					&SystemTime::now()
1657						.checked_add(self.time_opened.elapsed())
1658						.unwrap_or_else(SystemTime::now)
1659						.duration_since(SystemTime::UNIX_EPOCH)
1660						.unwrap_or_default()
1661						.as_secs(),
1662				),
1663			],
1664		));
1665	}
1666}