1use crate::{
74 errors::{CatBridgeError, NetworkError},
75 net::{
76 DEFAULT_SLOWLORIS_TIMEOUT, STREAM_ID, TCP_READ_BUFFER_SIZE,
77 additions::RequestID,
78 client::{
79 errors::CommonNetClientNetworkError,
80 models::{
81 DisconnectAsyncDropClient, RequestStreamEvent, RequestStreamMessage,
82 UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
83 },
84 },
85 errors::{CommonNetAPIError, CommonNetNetworkError},
86 handlers::{
87 OnRequestStreamBeginHandler, OnRequestStreamEndHandler, OnStreamBeginHandlerAsService,
88 OnStreamEndHandlerAsService,
89 },
90 models::{FromRequestParts, NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
91 now,
92 },
93};
94use bytes::{Bytes, BytesMut};
95use fnv::{FnvHashMap, FnvHashSet};
96use futures::future::join_all;
97use miette::miette;
98use scc::HashMap as ConcurrentHashMap;
99use std::{
100 collections::VecDeque,
101 fmt::{Debug, Formatter, Result as FmtResult},
102 hash::BuildHasherDefault,
103 net::{Ipv4Addr, SocketAddr, SocketAddrV4},
104 sync::{
105 Arc,
106 atomic::{AtomicU64, Ordering},
107 },
108 time::{Duration, Instant, SystemTime},
109};
110use tokio::{
111 io::{AsyncReadExt, AsyncWriteExt},
112 net::{TcpListener, TcpStream, ToSocketAddrs},
113 sync::mpsc::{
114 Receiver as BoundedReceiver, Sender as BoundedSender, channel as bounded_channel,
115 error::SendTimeoutError,
116 },
117 task::{Builder as TaskBuilder, block_in_place},
118 time::{sleep, timeout},
119};
120use tower::{Layer, Service, util::BoxCloneService};
121use tracing::{Instrument, debug, error_span, trace, warn};
122use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
123
124#[cfg(debug_assertions)]
125use crate::net::SPRIG_TRACE_IO;
126
/// A zero-length wait. `common_send` treats `Some(EMPTY_TIMEOUT)` exactly like
/// `None`: the request is queued but no response is awaited.
const EMPTY_TIMEOUT: Duration = Duration::from_secs(0);
128
/// A TCP client that can hold several connections ("streams") at once, one of
/// which is treated as the primary stream for request/response traffic.
pub struct TCPClient {
    /// Optional delay slept before each outgoing chunk is written.
    cat_dev_slowdown: Option<Duration>,
    /// When `Some(n)`, outgoing request bodies are written in chunks of `n` bytes.
    chunk_output_at_size: Option<usize>,
    /// When `true`, only the primary stream's response channel is read while
    /// waiting for replies; when `false`, every stream's channel is read and
    /// only the primary stream's reply is returned to the caller.
    keep_all_responses: bool,
    /// Default framing used to split incoming bytes into responses.
    nagle_guard: NagleGuard,
    /// Optional hook called when a stream opens; a `false` result drops the stream.
    on_stream_begin: Option<UnderlyingOnStreamBeginService<()>>,
    /// Optional hook wrapped in a `DisconnectAsyncDropClient` guard for the
    /// life of each connection (presumably invoked when the guard drops).
    on_stream_end: Option<UnderlyingOnStreamEndService<()>>,
    /// Optional hook applied to every outgoing chunk before it is written.
    pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
    /// Optional hook applied to every raw read before framing.
    post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
    /// Id of the primary stream; `0` means "none yet". Shared with connection
    /// tasks, which claim it for their stream if it is still `0`.
    primary_stream_id: Arc<AtomicU64>,
    /// All registered streams, keyed by stream id.
    streams: Arc<ConcurrentHashMap<u64, TCPClientStream>>,
    /// Service name recorded in tracing spans and warnings.
    service_name: &'static str,
    /// Time limit applied to partially received (not yet framed) data before
    /// the connection is dropped.
    slowloris_timeout: Duration,
    /// Whether raw I/O is traced (only present in debug builds).
    #[cfg(debug_assertions)]
    trace_during_debug: bool,
}
189
190impl TCPClient {
191 #[must_use]
202 pub fn new(
203 service_name: &'static str,
204 guard: impl Into<NagleGuard>,
205 nagle_hooks: (
206 Option<&'static dyn PreNagleFnTy>,
207 Option<&'static dyn PostNagleFnTy>,
208 ),
209 trace_io_during_debug: bool,
210 ) -> Self {
211 #[cfg(not(debug_assertions))]
212 {
213 if trace_io_during_debug {
214 warn!(
215 "Trace IO was turned on, but debug assertsions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
216 );
217 }
218 }
219
220 Self {
221 cat_dev_slowdown: None,
222 chunk_output_at_size: None,
223 keep_all_responses: false,
224 nagle_guard: guard.into(),
225 on_stream_begin: None,
226 on_stream_end: None,
227 pre_nagle_hook: nagle_hooks.0,
228 post_nagle_hook: nagle_hooks.1,
229 primary_stream_id: Arc::new(AtomicU64::new(0)),
230 service_name,
231 slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
232 streams: Arc::new(ConcurrentHashMap::default()),
233 #[cfg(debug_assertions)]
234 trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
235 }
236 }
237
238 pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
240 self.cat_dev_slowdown = slowdown;
241 }
242
243 pub const fn should_keep_all_responses(&mut self) {
246 self.keep_all_responses = true;
247 }
248
249 pub const fn set_keep_all_responses(&mut self, keep: bool) {
251 self.keep_all_responses = keep;
252 }
253
254 pub fn set_primary_stream(&mut self, stream_id: u64) {
256 self.primary_stream_id.store(stream_id, Ordering::Release);
257 }
258
259 #[must_use]
260 pub const fn chunk_output_at_size(&self) -> Option<usize> {
261 self.chunk_output_at_size
262 }
263
264 pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
265 self.chunk_output_at_size = new_size;
266 }
267
268 #[must_use]
269 pub const fn slowloris_timeout(&self) -> Duration {
270 self.slowloris_timeout
271 }
272 pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
273 self.slowloris_timeout = slowloris_timeout;
274 }
275
276 #[must_use]
277 pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<()>> {
278 self.on_stream_begin.as_ref()
279 }
280
281 pub fn set_raw_on_stream_begin(
294 &mut self,
295 on_start: Option<UnderlyingOnStreamBeginService<()>>,
296 ) -> Result<(), CommonNetAPIError> {
297 if self.on_stream_begin.is_some() {
298 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
299 }
300
301 self.on_stream_begin = on_start;
302 Ok(())
303 }
304
305 pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
317 &mut self,
318 handler: HandlerTy,
319 ) -> Result<(), CommonNetAPIError>
320 where
321 HandlerParamsTy: Send + 'static,
322 HandlerTy: OnRequestStreamBeginHandler<HandlerParamsTy, ()> + Clone + Send + 'static,
323 {
324 if self.on_stream_begin.is_some() {
325 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
326 }
327
328 let boxed = BoxCloneService::new(OnStreamBeginHandlerAsService::new(handler));
329 self.on_stream_begin = Some(boxed);
330 Ok(())
331 }
332
333 pub fn set_on_stream_begin_service<ServiceTy>(
345 &mut self,
346 service_ty: ServiceTy,
347 ) -> Result<(), CommonNetAPIError>
348 where
349 ServiceTy: Clone
350 + Send
351 + Service<RequestStreamEvent<()>, Response = bool, Error = CatBridgeError>
352 + 'static,
353 ServiceTy::Future: Send + 'static,
354 {
355 if self.on_stream_begin.is_some() {
356 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
357 }
358
359 self.on_stream_begin = Some(BoxCloneService::new(service_ty));
360 Ok(())
361 }
362
363 pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
370 &mut self,
371 layer: LayerTy,
372 ) -> Result<(), CommonNetAPIError>
373 where
374 LayerTy: Layer<UnderlyingOnStreamBeginService<()>, Service = ServiceTy>,
375 ServiceTy: Service<RequestStreamEvent<()>, Response = bool, Error = CatBridgeError>
376 + Clone
377 + Send
378 + 'static,
379 <LayerTy::Service as Service<RequestStreamEvent<()>>>::Future: Send + 'static,
380 {
381 let Some(srvc) = self.on_stream_begin.take() else {
382 return Err(CommonNetAPIError::OnStreamBeginNotRegistered);
383 };
384
385 self.on_stream_begin = Some(BoxCloneService::new(layer.layer(srvc)));
386 Ok(())
387 }
388
389 #[must_use]
390 pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<()>> {
391 self.on_stream_end.as_ref()
392 }
393
394 pub fn set_raw_on_stream_end(
407 &mut self,
408 on_end: Option<UnderlyingOnStreamEndService<()>>,
409 ) -> Result<(), CommonNetAPIError> {
410 if self.on_stream_end.is_some() {
411 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
412 }
413
414 self.on_stream_end = on_end;
415 Ok(())
416 }
417
418 pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
430 &mut self,
431 handler: HandlerTy,
432 ) -> Result<(), CommonNetAPIError>
433 where
434 HandlerParamsTy: Send + 'static,
435 HandlerTy: OnRequestStreamEndHandler<HandlerParamsTy, ()> + Clone + Send + 'static,
436 {
437 if self.on_stream_end.is_some() {
438 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
439 }
440
441 let boxed = BoxCloneService::new(OnStreamEndHandlerAsService::new(handler));
442 self.on_stream_end = Some(boxed);
443 Ok(())
444 }
445
446 pub fn set_on_stream_end_service<ServiceTy>(
458 &mut self,
459 service_ty: ServiceTy,
460 ) -> Result<(), CommonNetAPIError>
461 where
462 ServiceTy: Clone
463 + Send
464 + Service<RequestStreamEvent<()>, Response = (), Error = CatBridgeError>
465 + 'static,
466 ServiceTy::Future: Send + 'static,
467 {
468 if self.on_stream_end.is_some() {
469 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
470 }
471
472 self.on_stream_end = Some(BoxCloneService::new(service_ty));
473 Ok(())
474 }
475
476 pub fn layer_on_stream_end<LayerTy, ServiceTy>(
483 &mut self,
484 layer: LayerTy,
485 ) -> Result<(), CommonNetAPIError>
486 where
487 LayerTy: Layer<UnderlyingOnStreamEndService<()>, Service = ServiceTy>,
488 ServiceTy: Service<RequestStreamEvent<()>, Response = (), Error = CatBridgeError>
489 + Clone
490 + Send
491 + 'static,
492 <LayerTy::Service as Service<RequestStreamEvent<()>>>::Future: Send + 'static,
493 {
494 let Some(srvc) = self.on_stream_end.take() else {
495 return Err(CommonNetAPIError::OnStreamEndNotRegistered);
496 };
497
498 self.on_stream_end = Some(BoxCloneService::new(layer.layer(srvc)));
499 Ok(())
500 }
501
502 pub async fn bind<AddrTy: ToSocketAddrs>(&self, address: AddrTy) -> Result<(), CatBridgeError> {
517 let listener = TcpListener::bind(address).await.map_err(NetworkError::IO)?;
518
519 let client_address = listener.local_addr().map_err(NetworkError::IO)?;
520 let cloned_stream_begin = self.on_stream_begin.clone();
521 let cloned_stream_end = self.on_stream_end.clone();
522 let cloned_nagle_guard = self.nagle_guard.clone();
523 let cloned_slowerloris_timeout = self.slowloris_timeout;
524 let streams_ref = self.streams.clone();
525 let primary_stream_id_ref = self.primary_stream_id.clone();
526 let cloned_chunk_output_at_size = self.chunk_output_at_size;
527 let cloned_pre_nagle_hook = self.pre_nagle_hook;
528 let cloned_post_nagle_hook = self.post_nagle_hook;
529 #[cfg(debug_assertions)]
530 let cloned_trace = self.trace_during_debug;
531 let cloned_service_name = self.service_name;
532 let cloned_cat_dev_slowdown = self.cat_dev_slowdown;
533
534 TaskBuilder::new()
535 .name("cat_dev::net::tcp_client::bind().loop")
536 .spawn(async move {
537 loop {
538 let (stream, server_address) = match listener.accept().await {
539 Ok(tuple) => tuple,
540 Err(cause) => {
541 warn!(
542 ?cause,
543 client.address = %client_address,
544 "cat_dev::net::tcp_client::bind(): Failed to accept connection!",
545 );
546 continue;
547 }
548 };
549 let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
550
551 let cloned_cloned_stream_begin = cloned_stream_begin.clone();
552 let cloned_cloned_stream_end = cloned_stream_end.clone();
553 let cloned_cloned_nagle_guard = cloned_nagle_guard.clone();
554 let cloned_streams_ref = streams_ref.clone();
555 let cloned_primary_stream_id_ref = primary_stream_id_ref.clone();
556
557 if let Err(cause) = TaskBuilder::new()
558 .name("cat_dev::net::tcp_client::bind().connection.handle")
559 .spawn(async move {
560 if let Err(cause) = Self::handle_tcp_stream(
561 stream,
562 stream_id,
563 server_address,
564 cloned_cloned_stream_begin,
565 cloned_cloned_stream_end,
566 cloned_cloned_nagle_guard,
567 cloned_slowerloris_timeout,
568 cloned_streams_ref,
569 cloned_primary_stream_id_ref,
570 cloned_chunk_output_at_size,
571 cloned_pre_nagle_hook,
572 cloned_post_nagle_hook,
573 cloned_cat_dev_slowdown,
574 #[cfg(debug_assertions)]
575 cloned_trace,
576 )
577 .instrument(error_span!(
578 "CatDevTCPClientConnect",
579 client.address = %client_address,
580 server.address = %server_address,
581 client.service = cloned_service_name,
582 stream.id = stream_id,
583 stream.stream_type = "client",
584 ))
585 .await
586 {
587 warn!(
588 ?cause,
589 client.address = %client_address,
590 server.address = %server_address,
591 client.service = cloned_service_name,
592 "Error escaped while handling TCP Connection.",
593 );
594 }
595 }) {
596 warn!(
597 ?cause,
598 client.address = %client_address,
599 server.address = %server_address,
600 client.service = cloned_service_name,
601 "Error handling client connection, no task could be allocated.",
602 );
603 }
604
605 trace!(
606 server.address = %server_address,
607 client.address = %client_address,
608 "cat_dev::net::tcp_client::bind(): received connection (listener.accept())",
609 );
610 }
611 })
612 .map_err(CatBridgeError::SpawnFailure)?;
613
614 Ok(())
615 }
616
617 pub async fn wait_for_connection(&self) {
620 while self.get_active_sid().await.is_err() {
622 sleep(Duration::from_secs(1)).await;
623 }
624 }
625
626 pub async fn connect<AddrTy: ToSocketAddrs>(
637 &self,
638 address: AddrTy,
639 ) -> Result<u64, CatBridgeError> {
640 let raw_stream = TcpStream::connect(address)
641 .await
642 .map_err(NetworkError::IO)?;
643 let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
644 let remote_address = raw_stream.peer_addr().map_err(NetworkError::IO)?;
645 let local_address = raw_stream.local_addr().map_err(NetworkError::IO)?;
646 trace!(
647 server.address = %remote_address,
648 client.address = %local_address,
649 stream.id = stream_id,
650 stream.stream_type = "client",
651 "cat_dev::net::tcp_client::connect(): started connection (TcpStream::connect())",
652 );
653
654 let cloned_stream_begin = self.on_stream_begin.clone();
655 let cloned_stream_end = self.on_stream_end.clone();
656 let cloned_nagle_guard = self.nagle_guard.clone();
657 let cloned_slowerloris_timeout = self.slowloris_timeout;
658 let streams_ref = self.streams.clone();
659 let primary_stream_id_ref = self.primary_stream_id.clone();
660 let cloned_chunk_output_at_size = self.chunk_output_at_size;
661 let cloned_pre_nagle_hook = self.pre_nagle_hook;
662 let cloned_post_nagle_hook = self.post_nagle_hook;
663 #[cfg(debug_assertions)]
664 let cloned_trace = self.trace_during_debug;
665 let cloned_service_name = self.service_name;
666 let cloned_cat_dev_slowdown = self.cat_dev_slowdown;
667
668 TaskBuilder::new()
669 .name("cat_dev::net::tcp_client::connect().connection.handle")
670 .spawn(async move {
671 if let Err(cause) = Self::handle_tcp_stream(
672 raw_stream,
673 stream_id,
674 remote_address,
675 cloned_stream_begin,
676 cloned_stream_end,
677 cloned_nagle_guard,
678 cloned_slowerloris_timeout,
679 streams_ref,
680 primary_stream_id_ref,
681 cloned_chunk_output_at_size,
682 cloned_pre_nagle_hook,
683 cloned_post_nagle_hook,
684 cloned_cat_dev_slowdown,
685 #[cfg(debug_assertions)]
686 cloned_trace,
687 )
688 .instrument(error_span!(
689 "CatDevTCPClientConnect",
690 client.address = %local_address,
691 server.address = %remote_address,
692 client.service = cloned_service_name,
693 stream.id = stream_id,
694 stream.stream_type = "client",
695 ))
696 .await
697 {
698 warn!(
699 ?cause,
700 client.address = %local_address,
701 server.address = %remote_address,
702 client.service = cloned_service_name,
703 "Error escaped while handling TCP Connection.",
704 );
705 }
706 })
707 .map_err(CatBridgeError::SpawnFailure)?;
708
709 Ok(stream_id)
710 }
711
712 pub async fn send<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
731 &self,
732 body: BodyTy,
733 wait_for_response_timeout: Option<Duration>,
734 ) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
735 let mut request = Request::new_with_state(
737 body.try_into().map_err(|cause| {
738 CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
739 })?,
740 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
741 (),
742 None,
743 );
744 let req_id = RequestID::generate();
745 request.extensions_mut().insert(req_id.clone());
746
747 self.common_send(request, req_id, wait_for_response_timeout)
748 .await
749 }
750
751 pub async fn send_with_read_amount<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
770 &self,
771 body: BodyTy,
772 wait_for_response_timeout: Option<Duration>,
773 explicit_read_amount: usize,
774 ) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
775 let mut request = Request::new_with_state_and_read_amount(
777 body.try_into().map_err(|cause| {
778 CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
779 })?,
780 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
781 (),
782 None,
783 explicit_read_amount,
784 );
785 let req_id = RequestID::generate();
786 request.extensions_mut().insert(req_id.clone());
787
788 self.common_send(request, req_id, wait_for_response_timeout)
789 .await
790 }
791
792 pub async fn broadcast_send<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
800 &self,
801 body: BodyTy,
802 wait_for_response_timeout: Duration,
803 ) -> Result<FnvHashMap<u64, Option<Response>>, CatBridgeError> {
804 let mut request = Request::new_with_state(
806 body.try_into().map_err(|cause| {
807 CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
808 })?,
809 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
810 (),
811 None,
812 );
813 let req_id = RequestID::generate();
814 request.extensions_mut().insert(req_id.clone());
815
816 let mut ids = FnvHashSet::default();
817 self.streams
818 .iter_async(|stream_id, _stream| {
819 ids.insert(*stream_id);
820 true
821 })
822 .await;
823
824 let mut tasks = Vec::with_capacity(ids.len());
825 for id in &ids {
826 tasks.push(self.send_to_stream(*id, request.clone(), wait_for_response_timeout));
827 }
828 join_all(tasks)
831 .await
832 .into_iter()
833 .collect::<Result<(), NetworkError>>()?;
834
835 let mut response_tasks = Vec::with_capacity(ids.len());
836 for id in &ids {
837 response_tasks.push(self.get_response_from_stream(*id, req_id.clone()));
838 }
839 let responses = timeout(wait_for_response_timeout, join_all(response_tasks))
840 .await
841 .map_err(|_| NetworkError::Timeout(wait_for_response_timeout))?;
842
843 let mut map =
844 FnvHashMap::with_capacity_and_hasher(ids.len(), BuildHasherDefault::default());
845 for (got_stream_id, response) in responses {
846 map.insert(got_stream_id, response);
847 }
848 Ok(map)
849 }
850
851 pub async fn receive(&self, wait_until: Duration) -> Result<Option<Response>, NetworkError> {
861 let active_sid = self.get_active_sid().await?;
862
863 let mut tasks;
864 if self.keep_all_responses {
865 tasks = vec![self.get_any_response_from_stream(active_sid)];
866 } else {
867 let mut ids = FnvHashSet::default();
868 self.streams
869 .iter_async(|stream_id, _stream| {
870 ids.insert(*stream_id);
871 true
872 })
873 .await;
874
875 tasks = Vec::with_capacity(ids.len());
876 for id in ids {
877 tasks.push(self.get_any_response_from_stream(id));
878 }
879 }
880 let responses = timeout(wait_until, join_all(tasks))
881 .await
882 .map_err(|_| NetworkError::Timeout(wait_until))?;
883
884 for (got_stream_id, _, response) in responses {
885 if got_stream_id == active_sid {
886 return Ok(response);
887 }
888 }
889
890 Ok(None)
891 }
892
893 pub async fn take_all_response_for_request_id(
899 &self,
900 request_id: &RequestID,
901 wait_for: Duration,
902 ) -> FnvHashMap<u64, Option<Response>> {
903 let mut ids = FnvHashSet::default();
904 self.streams
905 .iter_async(|stream_id, _stream| {
906 ids.insert(*stream_id);
907 true
908 })
909 .await;
910
911 let mut tasks = Vec::with_capacity(ids.len());
912 for id in &ids {
913 tasks.push(timeout(
914 wait_for,
915 self.get_response_from_stream(*id, request_id.clone()),
916 ));
917 }
918
919 let mut results: FnvHashMap<u64, Option<Response>> =
920 join_all(tasks).await.into_iter().flatten().collect();
921 for id in ids {
922 results.entry(id).or_insert(None);
923 }
924 results
925 }
926
927 async fn common_send(
928 &self,
929 mock_req: Request<()>,
930 req_id: RequestID,
931 wait_for_response_timeout: Option<Duration>,
932 ) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
933 let active_sid = self.get_active_sid().await?;
934
935 let mut ids = FnvHashSet::default();
936 self.streams
937 .iter_async(|stream_id, _stream| {
938 ids.insert(*stream_id);
939 true
940 })
941 .await;
942
943 let mut tasks = Vec::with_capacity(ids.len());
944 for id in &ids {
945 tasks.push(self.send_to_stream(
946 *id,
947 mock_req.clone(),
948 wait_for_response_timeout.unwrap_or(DEFAULT_SLOWLORIS_TIMEOUT),
949 ));
950 }
951 join_all(tasks)
954 .await
955 .into_iter()
956 .collect::<Result<(), NetworkError>>()?;
957
958 match wait_for_response_timeout {
959 None | Some(EMPTY_TIMEOUT) => Ok((active_sid, req_id, None)),
961 Some(duration) => {
962 let mut tasks;
963 if self.keep_all_responses {
965 tasks = vec![self.get_response_from_stream(active_sid, req_id.clone())];
966 } else {
967 tasks = Vec::with_capacity(ids.len());
968 for id in ids {
969 tasks.push(self.get_response_from_stream(id, req_id.clone()));
970 }
971 }
972 let responses = timeout(duration, join_all(tasks))
973 .await
974 .map_err(|_| NetworkError::Timeout(duration))?;
975
976 for (got_stream_id, response) in responses {
977 if got_stream_id == active_sid {
978 return Ok((active_sid, req_id, response));
979 }
980 }
981
982 Ok((active_sid, req_id, None))
983 }
984 }
985 }
986
987 #[allow(
988 clippy::too_many_arguments,
998 )]
999 async fn handle_tcp_stream(
1000 mut stream: TcpStream,
1001 stream_id: u64,
1002 remote_address: SocketAddr,
1003 on_stream_begin: Option<UnderlyingOnStreamBeginService<()>>,
1004 on_stream_end: Option<UnderlyingOnStreamEndService<()>>,
1005 nagle_guard: NagleGuard,
1006 slowloris_timeout: Duration,
1007 stream_lists: Arc<ConcurrentHashMap<u64, TCPClientStream>>,
1008 active_stream_ptr: Arc<AtomicU64>,
1009 chunk_output_on_size: Option<usize>,
1010 pre_hook: Option<&'static dyn PreNagleFnTy>,
1011 post_hook: Option<&'static dyn PostNagleFnTy>,
1012 cat_dev_slowdown: Option<Duration>,
1013 #[cfg(debug_assertions)] trace_io: bool,
1014 ) -> Result<(), CatBridgeError> {
1015 let mut receive_packets_to_send: BoundedReceiver<RequestStreamMessage>;
1019 let (response_sink_send, response_sink_recv) = bounded_channel(128);
1020 {
1021 let (mut sender, receiver) = bounded_channel(128);
1022
1023 if Self::initialize_stream(
1027 on_stream_begin,
1028 &mut sender,
1029 &remote_address,
1030 &stream,
1031 stream_id,
1032 )
1033 .await?
1034 {
1035 return Ok(());
1036 }
1037
1038 let mut active_stream =
1039 TCPClientStream::new(remote_address, sender, receiver, response_sink_recv);
1040 receive_packets_to_send = active_stream
1041 .steal_send_requests_receiver()
1042 .ok_or_else(|| CatBridgeError::ClosedChannel)?;
1043
1044 std::mem::drop(stream_lists.insert_async(stream_id, active_stream).await);
1045 _ = active_stream_ptr.compare_exchange(
1047 0,
1048 stream_id,
1049 Ordering::AcqRel,
1050 Ordering::Acquire,
1051 );
1052 }
1053
1054 let _guard = on_stream_end
1056 .map(|service| DisconnectAsyncDropClient::new(service, (), remote_address, stream_id));
1057
1058 let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
1059 let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;
1062 let mut cached_request_id: Option<RequestID> = None;
1063 let mut nagle_overrides: VecDeque<Option<NagleGuard>> = VecDeque::with_capacity(128);
1064
1065 loop {
1066 tokio::select! {
1067 opt = receive_packets_to_send.recv() => {
1068 if Self::handle_client_write_to_connection(
1070 chunk_output_on_size,
1071 opt,
1072 pre_hook,
1073 &mut cached_request_id,
1074 stream_id,
1075 &mut stream,
1076 &mut nagle_overrides,
1077 cat_dev_slowdown,
1078 #[cfg(debug_assertions)]
1079 trace_io,
1080 ).await? {
1081 break;
1082 }
1083 }
1084 read_res = stream.read_buf(&mut buff) => {
1085 let read_bytes = read_res.map_err(NetworkError::IO)?;
1086 buff.truncate(read_bytes);
1087
1088 if buff.is_empty() {
1089 buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
1090 continue;
1091 }
1092
1093 if Self::handle_client_read_from_connection(
1094 buff,
1095 &nagle_guard,
1096 &mut nagle_overrides,
1097 slowloris_timeout,
1098 &mut nagle_cache,
1099 response_sink_send.clone(),
1100 post_hook,
1101 &mut cached_request_id,
1102 stream_id,
1103 #[cfg(debug_assertions)]
1104 trace_io,
1105 ).await? {
1106 break;
1107 }
1108 buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
1109 }
1110 }
1111 }
1112
1113 Ok(())
1114 }
1115
1116 async fn initialize_stream(
1117 on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<()>>,
1118 send_channel: &mut BoundedSender<RequestStreamMessage>,
1119 remote_address: &SocketAddr,
1120 tcp_stream: &TcpStream,
1121 stream_id: u64,
1122 ) -> Result<bool, CatBridgeError> {
1123 tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;
1124
1125 if let Some(mut handle) = on_stream_begin_handler
1126 && !handle
1127 .call(RequestStreamEvent::new_with_state(
1128 send_channel.clone(),
1129 *remote_address,
1130 Some(stream_id),
1131 (),
1132 ))
1133 .await?
1134 {
1135 trace!("handler failed on stream begin hook");
1136 return Ok(true);
1137 }
1138
1139 Ok(false)
1140 }
1141
1142 #[allow(
1143 clippy::too_many_arguments,
1148 )]
1149 async fn handle_client_read_from_connection<'data>(
1150 mut buff: BytesMut,
1151 nagle_guard: &'data NagleGuard,
1152 nagle_overrides: &mut VecDeque<Option<NagleGuard>>,
1153 slowloris_timeout: Duration,
1154 nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
1155 response_output: BoundedSender<(Option<RequestID>, Response)>,
1156 cloned_post_nagle: Option<&'static dyn PostNagleFnTy>,
1157 cached_request_id: &mut Option<RequestID>,
1158 stream_id: u64,
1159 #[cfg(debug_assertions)] trace_io: bool,
1160 ) -> Result<bool, CatBridgeError> {
1161 if let Some(convert_fn) = cloned_post_nagle {
1162 buff = BytesMut::from(block_in_place(|| (*convert_fn)(stream_id, buff.freeze())));
1163 }
1164
1165 #[cfg(debug_assertions)]
1166 {
1167 if trace_io {
1168 debug!(
1169 body.hex = format!("{:02x?}", buff),
1170 body.str = String::from_utf8_lossy(&buff).to_string(),
1171 "cat-dev-trace-input-tcp-client",
1172 );
1173 }
1174 }
1175
1176 let start_time = now();
1179 if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
1180 let total_duration = start_time
1185 .duration_since(old_start_time)
1186 .unwrap_or(Duration::from_secs(0));
1187 if total_duration > slowloris_timeout {
1188 debug!(
1189 cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
1190 "slowloris-detected",
1191 );
1192 return Ok(true);
1193 }
1194
1195 existing_buff.extend(buff);
1196 buff = existing_buff;
1197 }
1198
1199 let mut current_nagle_guard = if let Some(Some(guard)) = nagle_overrides.front() {
1200 guard
1201 } else {
1202 nagle_guard
1203 };
1204
1205 while let Some((start_of_packet, end_of_packet)) = current_nagle_guard.split(&buff)? {
1206 let remaining_buff = buff.split_off(end_of_packet);
1207 let _start_of_buff = buff.split_to(start_of_packet);
1208 let req_body = buff.freeze();
1209 buff = remaining_buff;
1210
1211 if let Err(cause) = response_output
1212 .send((cached_request_id.take(), Response::new_with_body(req_body)))
1213 .await
1214 {
1215 warn!(
1216 ?cause,
1217 "internal queue failure will not send disconnect/response."
1218 );
1219 }
1220
1221 if !nagle_overrides.is_empty() {
1222 nagle_overrides.pop_front();
1223 current_nagle_guard = if let Some(Some(guard)) = nagle_overrides.front() {
1224 guard
1225 } else {
1226 nagle_guard
1227 };
1228 }
1229 }
1230
1231 if !buff.is_empty() {
1232 _ = nagle_cache.insert((buff, start_time));
1233 }
1234
1235 Ok(false)
1236 }
1237
1238 #[allow(
1239 clippy::too_many_arguments,
1241 )]
1242 async fn handle_client_write_to_connection(
1243 chunk_output_on_size: Option<usize>,
1244 to_send_to_client_opt: Option<RequestStreamMessage>,
1245 pre_hook: Option<&'static dyn PreNagleFnTy>,
1246 cached_request_id: &mut Option<RequestID>,
1247 stream_id: u64,
1248 raw_stream: &mut TcpStream,
1249 nagle_overrides: &mut VecDeque<Option<NagleGuard>>,
1250 cat_dev_slowdown: Option<Duration>,
1251 #[cfg(debug_assertions)] trace_io: bool,
1252 ) -> Result<bool, CatBridgeError> {
1253 let Some(to_send_to_client) = to_send_to_client_opt else {
1254 return Ok(true);
1255 };
1256
1257 match to_send_to_client {
1258 RequestStreamMessage::Disconnect => {
1259 _ = cached_request_id.take();
1261 trace!("stream-disconnect-message");
1262 Ok(true)
1263 }
1264 RequestStreamMessage::Request(mut req) => {
1265 if let Some(explicit_read) = req.explicit_read_amount() {
1266 nagle_overrides.push_back(Some(NagleGuard::StaticSize(explicit_read)));
1267 } else {
1268 nagle_overrides.push_back(None);
1269 }
1270 if !req.body().is_empty() {
1271 if let Ok(req_id) = RequestID::from_request_parts(&mut req).await {
1272 _ = cached_request_id.insert(req_id);
1273 }
1274 let messages = if let Some(size) = chunk_output_on_size {
1275 req.body_owned()
1276 .chunks(size)
1277 .map(BytesMut::from)
1278 .collect::<Vec<_>>()
1279 } else {
1280 vec![BytesMut::from(req.body_owned())]
1281 };
1282
1283 for message in messages {
1284 #[cfg(debug_assertions)]
1285 if trace_io {
1286 debug!(
1287 body.hex = format!("{message:02x?}"),
1288 body.str = String::from_utf8_lossy(&message).to_string(),
1289 "cat-dev-trace-output-tcp-client",
1290 );
1291 }
1292
1293 let mut full_response = message.clone();
1294 if let Some(pre) = pre_hook {
1295 block_in_place(|| pre(stream_id, &mut full_response));
1296 }
1297 if let Some(slowdown) = cat_dev_slowdown {
1298 sleep(slowdown).await;
1299 }
1300
1301 raw_stream.writable().await.map_err(NetworkError::IO)?;
1302 raw_stream
1303 .write_all(&full_response)
1304 .await
1305 .map_err(NetworkError::IO)?;
1306 }
1307 }
1308
1309 Ok(false)
1310 }
1311 }
1312 }
1313
1314 async fn send_to_stream(
1320 &self,
1321 stream_id: u64,
1322 mut base_request: Request<()>,
1323 timeout: Duration,
1324 ) -> Result<(), NetworkError> {
1325 if let Some(stream) = self.streams.get_async(&stream_id).await {
1326 base_request.update_request_source(stream.server_address(), Some(stream_id));
1327 stream
1328 .send_timeout(RequestStreamMessage::Request(base_request), timeout)
1329 .await
1330 .map_err(|cause| {
1331 CommonNetClientNetworkError::CannotQueueSend(format!("{cause:?}")).into()
1332 })
1333 } else {
1334 Ok(())
1336 }
1337 }
1338
1339 async fn get_any_response_from_stream(
1342 &self,
1343 stream_id: u64,
1344 ) -> (u64, Option<RequestID>, Option<Response>) {
1345 if let Some(mut stream) = self.streams.get_async(&stream_id).await {
1346 let Some((opt_req_id, response)) = stream.response_channel_mut().recv().await else {
1347 return (stream_id, None, None);
1348 };
1349
1350 (stream_id, opt_req_id, Some(response))
1351 } else {
1352 (stream_id, None, None)
1354 }
1355 }
1356
1357 async fn get_response_from_stream(
1360 &self,
1361 stream_id: u64,
1362 request_id: RequestID,
1363 ) -> (u64, Option<Response>) {
1364 if let Some(mut stream) = self.streams.get_async(&stream_id).await {
1365 while let Some((opt_req_id, response)) = stream.response_channel_mut().recv().await {
1366 if let Some(got_req_id) = opt_req_id
1367 && got_req_id == request_id
1368 {
1369 return (stream_id, Some(response));
1370 }
1371 }
1372
1373 (stream_id, None)
1374 } else {
1375 (stream_id, None)
1377 }
1378 }
1379
1380 async fn get_active_sid(&self) -> Result<u64, CommonNetClientNetworkError> {
1382 let active_sid = self.primary_stream_id.load(Ordering::Acquire);
1383 if active_sid == 0 {
1384 return Err(CommonNetClientNetworkError::NotConnectedToServer);
1385 }
1386
1387 if !self.streams.contains_async(&active_sid).await {
1388 let mut oldest_stream = None;
1389
1390 self.streams
1391 .iter_async(|stream_id, stream| {
1392 if let Some((_strm_id, strm_created_at)) = oldest_stream {
1393 if stream.opened_at() < strm_created_at {
1394 _ = oldest_stream.insert((*stream_id, stream.opened_at()));
1395 }
1396 } else {
1397 _ = oldest_stream.insert((*stream_id, stream.opened_at()));
1398 }
1399 true
1400 })
1401 .await;
1402 }
1403
1404 Ok(active_sid)
1405 }
1406}
1407
1408impl Debug for TCPClient {
1409 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
1410 let mut tcp_dbg_struct = fmt.debug_struct("TCPClient");
1411
1412 tcp_dbg_struct
1413 .field("cat_dev_slowdown", &self.cat_dev_slowdown)
1414 .field("chunk_output_at_size", &self.chunk_output_at_size)
1415 .field("keep_all_responses", &self.keep_all_responses)
1416 .field("nagle_guard", &self.nagle_guard)
1417 .field("has_on_stream_begin", &self.on_stream_begin.is_some())
1418 .field("has_on_stream_end", &self.on_stream_end.is_some())
1419 .field("has_pre_nagle_hook", &self.pre_nagle_hook.is_some())
1420 .field("has_post_nagle_hook", &self.post_nagle_hook.is_some())
1421 .field(
1422 "primary_stream_id",
1423 &self.primary_stream_id.load(Ordering::Relaxed),
1424 )
1425 .field("streams", &self.streams)
1426 .field("service_name", &self.service_name)
1427 .field("slowloris_timeout", &self.slowloris_timeout);
1428
1429 #[cfg(debug_assertions)]
1430 {
1431 tcp_dbg_struct.field("trace_during_debug", &self.trace_during_debug);
1432 }
1433
1434 tcp_dbg_struct.finish()
1435 }
1436}
1437
/// Field names exposed by `TCPClient`'s `Structable`/`Valuable` impls.
///
/// The value array built in `<TCPClient as Valuable>::visit` lists its values in
/// this exact order (including the `debug_assertions`-only trailing entry), so
/// any addition, removal, or reordering here must be mirrored there.
const TCP_CLIENT_FIELDS: &[NamedField<'static>] = &[
    NamedField::new("cat_dev_slowdown"),
    NamedField::new("chunk_output_at_size"),
    NamedField::new("keep_all_responses"),
    NamedField::new("nagle_guard"),
    NamedField::new("has_on_stream_begin"),
    NamedField::new("has_on_stream_end"),
    NamedField::new("has_pre_nagle_hook"),
    NamedField::new("has_post_nagle_hook"),
    NamedField::new("primary_stream_id"),
    NamedField::new("streams"),
    NamedField::new("service_name"),
    NamedField::new("slowloris_timeout"),
    // Only present in debug builds; the matching value in `visit` is cfg-gated identically.
    #[cfg(debug_assertions)]
    NamedField::new("trace_during_debug"),
];
1454
1455impl Structable for TCPClient {
1456 fn definition(&self) -> StructDef<'_> {
1457 StructDef::new_static("TCPClient", Fields::Named(TCP_CLIENT_FIELDS))
1458 }
1459}
1460
1461impl Valuable for TCPClient {
1462 fn as_value(&self) -> Value<'_> {
1463 Value::Structable(self)
1464 }
1465
1466 fn visit(&self, visitor: &mut dyn Visit) {
1467 let mut valuable_map = FnvHashMap::default();
1468 self.streams.iter_sync(|stream_id, stream| {
1469 valuable_map.insert(*stream_id, stream.to_valuable());
1470 true
1471 });
1472
1473 visitor.visit_named_fields(&NamedValues::new(
1474 TCP_CLIENT_FIELDS,
1475 &[
1476 Valuable::as_value(&if let Some(slowdown) = self.cat_dev_slowdown {
1477 format!("{}ms", slowdown.as_millis())
1478 } else {
1479 "<none>".to_string()
1480 }),
1481 Valuable::as_value(&self.chunk_output_at_size),
1482 Valuable::as_value(&self.keep_all_responses),
1483 Valuable::as_value(&self.nagle_guard),
1484 Valuable::as_value(&self.on_stream_begin.is_some()),
1485 Valuable::as_value(&self.on_stream_end.is_some()),
1486 Valuable::as_value(&self.pre_nagle_hook.is_some()),
1487 Valuable::as_value(&self.post_nagle_hook.is_some()),
1488 Valuable::as_value(&self.primary_stream_id.load(Ordering::Relaxed)),
1489 Valuable::as_value(&valuable_map),
1490 Valuable::as_value(&self.service_name),
1491 Valuable::as_value(&self.slowloris_timeout.as_secs()),
1492 #[cfg(debug_assertions)]
1493 Valuable::as_value(&self.trace_during_debug),
1494 ],
1495 ));
1496 }
1497}
1498
/// Per-connection state held in `TCPClient`'s `streams` map.
struct TCPClientStream {
    /// Remote peer of this stream; returned by `server_address`.
    remote_address: SocketAddr,
    /// Receives `(request ID, response)` pairs for this stream; the ID is
    /// optional, so some responses may arrive untagged.
    response_channel: BoundedReceiver<(Option<RequestID>, Response)>,
    /// Receiving half of the outgoing-request channel. `None` once it has been
    /// taken with `steal_send_requests_receiver`.
    send_requests_receiver: Option<BoundedReceiver<RequestStreamMessage>>,
    /// Sending half of the outgoing-request channel, used by `send_timeout`.
    send_requests: BoundedSender<RequestStreamMessage>,
    /// Monotonic timestamp captured in `new`; drives ordering and `opened_at`.
    time_opened: Instant,
}
1517
1518impl TCPClientStream {
1519 #[must_use]
1521 pub fn new(
1522 remote_address: SocketAddr,
1523 sender: BoundedSender<RequestStreamMessage>,
1524 receiver: BoundedReceiver<RequestStreamMessage>,
1525 response_channel: BoundedReceiver<(Option<RequestID>, Response)>,
1526 ) -> Self {
1527 Self {
1528 remote_address,
1529 response_channel,
1530 send_requests_receiver: Some(receiver),
1531 send_requests: sender,
1532 time_opened: Instant::now(),
1533 }
1534 }
1535
1536 #[must_use]
1537 pub const fn to_valuable(&self) -> TCPClientStreamValuable {
1538 TCPClientStreamValuable {
1539 receiver_stolen: self.send_requests_receiver.is_none(),
1540 time_opened: self.time_opened,
1541 }
1542 }
1543
1544 pub const fn server_address(&self) -> SocketAddr {
1546 self.remote_address
1547 }
1548
1549 #[must_use]
1550 pub const fn response_channel_mut(
1551 &mut self,
1552 ) -> &mut BoundedReceiver<(Option<RequestID>, Response)> {
1553 &mut self.response_channel
1554 }
1555
1556 #[must_use]
1558 pub fn steal_send_requests_receiver(
1559 &mut self,
1560 ) -> Option<BoundedReceiver<RequestStreamMessage>> {
1561 self.send_requests_receiver.take()
1562 }
1563
1564 pub async fn send_timeout(
1566 &self,
1567 message: RequestStreamMessage,
1568 timeout: Duration,
1569 ) -> Result<(), SendTimeoutError<RequestStreamMessage>> {
1570 self.send_requests.send_timeout(message, timeout).await
1571 }
1572
1573 pub const fn opened_at(&self) -> Instant {
1575 self.time_opened
1576 }
1577}
1578
1579impl Debug for TCPClientStream {
1580 fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
1581 fmt.debug_struct("TCPClientStream")
1582 .field("receiver_stolen", &self.send_requests_receiver.is_none())
1583 .field("time_opened", &self.time_opened)
1584 .finish_non_exhaustive()
1585 }
1586}
1587
1588impl PartialEq for TCPClientStream {
1589 fn eq(&self, other: &Self) -> bool {
1590 self.time_opened == other.time_opened
1591 }
1592}
1593
1594impl PartialOrd for TCPClientStream {
1595 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
1596 Some(self.time_opened.cmp(&other.time_opened))
1597 }
1598}
1599
/// Field names shared by the `Structable`/`Valuable` impls of both
/// `TCPClientStream` and `TCPClientStreamValuable`.
///
/// Both `visit` impls supply their values in this order: the "receiver taken"
/// flag, then `time_opened` as whole seconds since the Unix epoch.
const TCP_CLIENT_STREAM_FIELDS: &[NamedField<'static>] = &[
    NamedField::new("receiver_stolen"),
    NamedField::new("time_opened"),
];
1604
1605impl Structable for TCPClientStream {
1606 fn definition(&self) -> StructDef<'_> {
1607 StructDef::new_static("TCPClientStream", Fields::Named(TCP_CLIENT_STREAM_FIELDS))
1608 }
1609}
1610
1611impl Valuable for TCPClientStream {
1612 fn as_value(&self) -> Value<'_> {
1613 Value::Structable(self)
1614 }
1615
1616 fn visit(&self, visitor: &mut dyn Visit) {
1617 visitor.visit_named_fields(&NamedValues::new(
1618 TCP_CLIENT_STREAM_FIELDS,
1619 &[
1620 Valuable::as_value(&self.send_requests_receiver.is_none()),
1621 Valuable::as_value(
1622 &SystemTime::now()
1623 .checked_add(self.time_opened.elapsed())
1624 .unwrap_or_else(SystemTime::now)
1625 .duration_since(SystemTime::UNIX_EPOCH)
1626 .unwrap_or_default()
1627 .as_secs(),
1628 ),
1629 ],
1630 ));
1631 }
1632}
1633
/// Detached snapshot of a `TCPClientStream`'s reportable state, as produced by
/// `TCPClientStream::to_valuable`.
struct TCPClientStreamValuable {
    /// Monotonic instant at which the source stream was opened.
    time_opened: Instant,
    /// Whether the source stream's request receiver had already been taken.
    receiver_stolen: bool,
}
1638
1639impl Structable for TCPClientStreamValuable {
1640 fn definition(&self) -> StructDef<'_> {
1641 StructDef::new_static("TCPClientStream", Fields::Named(TCP_CLIENT_STREAM_FIELDS))
1642 }
1643}
1644
1645impl Valuable for TCPClientStreamValuable {
1646 fn as_value(&self) -> Value<'_> {
1647 Value::Structable(self)
1648 }
1649
1650 fn visit(&self, visitor: &mut dyn Visit) {
1651 visitor.visit_named_fields(&NamedValues::new(
1652 TCP_CLIENT_STREAM_FIELDS,
1653 &[
1654 Valuable::as_value(&self.receiver_stolen),
1655 Valuable::as_value(
1656 &SystemTime::now()
1657 .checked_add(self.time_opened.elapsed())
1658 .unwrap_or_else(SystemTime::now)
1659 .duration_since(SystemTime::UNIX_EPOCH)
1660 .unwrap_or_default()
1661 .as_secs(),
1662 ),
1663 ],
1664 ));
1665 }
1666}