1use crate::{
132 errors::{CatBridgeError, NetworkError},
133 net::{
134 DEFAULT_SLOWLORIS_TIMEOUT, SERVER_ID, STREAM_ID, TCP_READ_BUFFER_SIZE,
135 errors::{CommonNetAPIError, CommonNetNetworkError},
136 handlers::{
137 OnResponseStreamBeginHandler, OnResponseStreamEndHandler,
138 OnStreamBeginHandlerAsService, OnStreamEndHandlerAsService,
139 },
140 models::{NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
141 now,
142 server::models::{
143 DisconnectAsyncDropServer, ResponseStreamEvent, ResponseStreamMessage,
144 UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
145 },
146 },
147};
148use bytes::{Bytes, BytesMut};
149use fnv::FnvHashSet;
150use futures::future::join_all;
151use scc::HashMap as ConcurrentMap;
152use std::{
153 convert::Infallible,
154 fmt::{Debug, Formatter, Result as FmtResult},
155 net::{IpAddr, SocketAddr},
156 sync::{Arc, LazyLock, atomic::Ordering},
157 time::{Duration, SystemTime},
158};
159use tokio::{
160 io::{AsyncReadExt, AsyncWriteExt},
161 net::{TcpListener, TcpStream, ToSocketAddrs, lookup_host},
162 sync::{
163 Mutex,
164 mpsc::{Sender as BoundedSender, channel as bounded_channel},
165 },
166 task::{Builder as TaskBuilder, block_in_place},
167 time::sleep,
168};
169use tower::{Layer, Service, util::BoxCloneService};
170use tracing::{Instrument, debug, error_span, trace, warn};
171use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
172
173#[cfg(debug_assertions)]
174use crate::net::SPRIG_TRACE_IO;
175
/// Process-wide registry of response queues for every stream currently being
/// served, keyed by `(server id, stream id)`.
///
/// Entries are upserted when a stream is initialised and removed when its
/// connection handler finishes; `TCPServer::out_of_bound_send` and
/// `TCPServer::out_of_bound_broadcast` look senders up here so code outside of
/// a request handler can push messages onto a live stream.
static OUT_OF_BAND_SENDERS: LazyLock<
	ConcurrentMap<(u64, u64), BoundedSender<ResponseStreamMessage>>,
> = LazyLock::new(ConcurrentMap::new);
181
/// A TCP server that splits incoming bytes into packets with a [`NagleGuard`],
/// dispatches every complete packet to a tower service, and writes the queued
/// responses back to the peer.
///
/// Besides accepting connections (`bind`), it can also serve a connection it
/// opened itself (`connect`); in that case `address_to_bind_or_connect_to` is
/// the remote address rather than a local one.
pub struct TCPServer<State: Clone + Send + Sync + 'static = ()> {
	/// The single address resolved by the constructor: bound by `bind`, or
	/// dialled by `connect`.
	address_to_bind_or_connect_to: SocketAddr,
	/// Optional delay slept before every outgoing write.
	cat_dev_slowdown: Option<Duration>,
	/// When set, outgoing response bodies are written in pieces of at most this
	/// many bytes.
	chunk_output_at_size: Option<usize>,
	/// Identifier taken from the global `SERVER_ID` counter at construction; it
	/// is half of the key used for out-of-band senders.
	id: u64,
	/// Service that every complete incoming packet is dispatched to.
	initial_service: BoxCloneService<Request<State>, Response, Infallible>,
	/// Locates packet boundaries inside the buffered input.
	nagle_guard: NagleGuard,
	/// Optional service called when a stream starts; when it answers `false`
	/// the stream is closed without being served.
	on_stream_begin: Option<UnderlyingOnStreamBeginService<State>>,
	/// Optional service handed to `DisconnectAsyncDropServer` for every served
	/// stream, i.e. run when that stream is torn down.
	on_stream_end: Option<UnderlyingOnStreamEndService<State>>,
	/// Hook applied (inside `block_in_place`) to every raw read before packet
	/// splitting.
	pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
	/// Hook applied (inside `block_in_place`) to every outgoing chunk right
	/// before it is written.
	post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
	/// Human readable name attached to tracing spans and log events.
	service_name: &'static str,
	/// Maximum age of cached, not-yet-complete input; checked when more data
	/// arrives, and the connection is dropped when it is exceeded.
	slowloris_timeout: Duration,
	/// State cloned into every request and stream event.
	state: State,
	/// Debug builds only: log every byte read from / written to the peer.
	#[cfg(debug_assertions)]
	trace_during_debug: bool,
}
260
261impl TCPServer<()> {
262 pub async fn new<AddrTy, ServiceTy>(
268 service_name: &'static str,
269 bind_addr: AddrTy,
270 initial_service: ServiceTy,
271 nagle_hooks: (
272 Option<&'static dyn PreNagleFnTy>,
273 Option<&'static dyn PostNagleFnTy>,
274 ),
275 guard: impl Into<NagleGuard>,
276 trace_io_during_debug: bool,
277 ) -> Result<Self, CommonNetAPIError>
278 where
279 AddrTy: ToSocketAddrs,
280 ServiceTy:
281 Clone + Send + Service<Request<()>, Response = Response, Error = Infallible> + 'static,
282 ServiceTy::Future: Send + 'static,
283 {
284 Self::new_with_state(
285 service_name,
286 bind_addr,
287 initial_service,
288 nagle_hooks,
289 guard,
290 (),
291 trace_io_during_debug,
292 )
293 .await
294 }
295}
296
297impl<State: Clone + Send + Sync + 'static> TCPServer<State> {
298 pub async fn out_of_bound_send(
305 server_id: u64,
306 stream_id: u64,
307 message: ResponseStreamMessage,
308 ) -> Result<(), CatBridgeError> {
309 if let Some(stream) = OUT_OF_BAND_SENDERS.get_async(&(server_id, stream_id)).await {
310 stream
311 .send(message)
312 .await
313 .map_err(NetworkError::SendQueueMessageFailure)?;
314 Ok(())
315 } else {
316 Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
317 }
318 }
319
320 pub async fn out_of_bound_broadcast(
322 server_id: u64,
323 message: ResponseStreamMessage,
324 ) -> Vec<Result<(), CatBridgeError>> {
325 let mut ids = FnvHashSet::default();
326 OUT_OF_BAND_SENDERS
328 .iter_async(|key, _value| {
329 if key.0 == server_id {
330 ids.insert(key.1);
331 }
332 true
333 })
334 .await;
335
336 let mut tasks = Vec::with_capacity(ids.len());
338 for id in ids {
339 tasks.push(Self::out_of_bound_send(server_id, id, message.clone()));
340 }
341
342 join_all(tasks).await
343 }
344
345 #[allow(unused)]
351 pub async fn new_with_state<AddrTy, ServiceTy>(
352 service_name: &'static str,
353 bind_addr: AddrTy,
354 initial_service: ServiceTy,
355 nagle_hooks: (
356 Option<&'static dyn PreNagleFnTy>,
357 Option<&'static dyn PostNagleFnTy>,
358 ),
359 guard: impl Into<NagleGuard>,
360 state: State,
361 trace_io_during_debug: bool,
362 ) -> Result<Self, CommonNetAPIError>
363 where
364 AddrTy: ToSocketAddrs,
365 ServiceTy: Clone
366 + Send
367 + Service<Request<State>, Response = Response, Error = Infallible>
368 + 'static,
369 ServiceTy::Future: Send + 'static,
370 {
371 let hosts = lookup_host(bind_addr)
372 .await
373 .map_err(CommonNetAPIError::AddressLookupError)?
374 .collect::<Vec<_>>();
375 if hosts.len() != 1 {
376 return Err(CommonNetAPIError::WrongAmountOfAddressesToBindToo(hosts));
377 }
378
379 #[cfg(not(debug_assertions))]
380 {
381 if trace_io_during_debug {
382 warn!(
383 "Trace IO was turned on, but debug assertsions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
384 );
385 }
386 }
387
388 Ok(Self {
389 address_to_bind_or_connect_to: hosts[0],
390 cat_dev_slowdown: None,
391 chunk_output_at_size: None,
392 id: SERVER_ID.fetch_add(1, Ordering::SeqCst),
393 initial_service: BoxCloneService::new(initial_service),
394 nagle_guard: guard.into(),
395 on_stream_begin: None,
396 on_stream_end: None,
397 pre_nagle_hook: nagle_hooks.0,
398 post_nagle_hook: nagle_hooks.1,
399 service_name,
400 slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
401 state,
402 #[cfg(debug_assertions)]
403 trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
404 })
405 }
406
407 #[must_use]
409 pub const fn id(&self) -> u64 {
410 self.id
411 }
412
413 #[must_use]
415 pub const fn ip(&self) -> IpAddr {
416 self.address_to_bind_or_connect_to.ip()
417 }
418
419 #[must_use]
421 pub const fn port(&self) -> u16 {
422 self.address_to_bind_or_connect_to.port()
423 }
424
425 pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
427 self.cat_dev_slowdown = slowdown;
428 }
429
430 #[must_use]
431 pub const fn chunk_output_at_size(&self) -> Option<usize> {
432 self.chunk_output_at_size
433 }
434
435 pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
436 self.chunk_output_at_size = new_size;
437 }
438
439 #[must_use]
440 pub const fn slowloris_timeout(&self) -> Duration {
441 self.slowloris_timeout
442 }
443 pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
444 self.slowloris_timeout = slowloris_timeout;
445 }
446
447 #[must_use]
448 pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<State>> {
449 self.on_stream_begin.as_ref()
450 }
451
452 pub fn set_raw_on_stream_begin(
465 &mut self,
466 on_start: Option<UnderlyingOnStreamBeginService<State>>,
467 ) -> Result<(), CommonNetAPIError> {
468 if self.on_stream_begin.is_some() {
469 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
470 }
471
472 self.on_stream_begin = on_start;
473 Ok(())
474 }
475
476 pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
488 &mut self,
489 handler: HandlerTy,
490 ) -> Result<(), CommonNetAPIError>
491 where
492 HandlerParamsTy: Send + 'static,
493 HandlerTy: OnResponseStreamBeginHandler<HandlerParamsTy, State> + Clone + Send + 'static,
494 {
495 if self.on_stream_begin.is_some() {
496 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
497 }
498
499 let boxed = BoxCloneService::new(OnStreamBeginHandlerAsService::new(handler));
500 self.on_stream_begin = Some(boxed);
501 Ok(())
502 }
503
504 pub fn set_on_stream_begin_service<ServiceTy>(
516 &mut self,
517 service_ty: ServiceTy,
518 ) -> Result<(), CommonNetAPIError>
519 where
520 ServiceTy: Clone
521 + Send
522 + Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
523 + 'static,
524 ServiceTy::Future: Send + 'static,
525 {
526 if self.on_stream_begin.is_some() {
527 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
528 }
529
530 self.on_stream_begin = Some(BoxCloneService::new(service_ty));
531 Ok(())
532 }
533
534 pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
541 &mut self,
542 layer: LayerTy,
543 ) -> Result<(), CommonNetAPIError>
544 where
545 LayerTy: Layer<UnderlyingOnStreamBeginService<State>, Service = ServiceTy>,
546 ServiceTy: Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
547 + Clone
548 + Send
549 + 'static,
550 <LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
551 {
552 let Some(srvc) = self.on_stream_begin.take() else {
553 return Err(CommonNetAPIError::OnStreamBeginNotRegistered);
554 };
555
556 self.on_stream_begin = Some(BoxCloneService::new(layer.layer(srvc)));
557 Ok(())
558 }
559
560 #[must_use]
561 pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<State>> {
562 self.on_stream_end.as_ref()
563 }
564
565 pub fn set_raw_on_stream_end(
578 &mut self,
579 on_end: Option<UnderlyingOnStreamEndService<State>>,
580 ) -> Result<(), CommonNetAPIError> {
581 if self.on_stream_end.is_some() {
582 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
583 }
584
585 self.on_stream_end = on_end;
586 Ok(())
587 }
588
589 pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
601 &mut self,
602 handler: HandlerTy,
603 ) -> Result<(), CommonNetAPIError>
604 where
605 HandlerParamsTy: Send + 'static,
606 HandlerTy: OnResponseStreamEndHandler<HandlerParamsTy, State> + Clone + Send + 'static,
607 {
608 if self.on_stream_end.is_some() {
609 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
610 }
611
612 let boxed = BoxCloneService::new(OnStreamEndHandlerAsService::new(handler));
613 self.on_stream_end = Some(boxed);
614 Ok(())
615 }
616
617 pub fn set_on_stream_end_service<ServiceTy>(
629 &mut self,
630 service_ty: ServiceTy,
631 ) -> Result<(), CommonNetAPIError>
632 where
633 ServiceTy: Clone
634 + Send
635 + Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
636 + 'static,
637 ServiceTy::Future: Send + 'static,
638 {
639 if self.on_stream_end.is_some() {
640 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
641 }
642
643 self.on_stream_end = Some(BoxCloneService::new(service_ty));
644 Ok(())
645 }
646
647 pub fn layer_on_stream_end<LayerTy, ServiceTy>(
654 &mut self,
655 layer: LayerTy,
656 ) -> Result<(), CommonNetAPIError>
657 where
658 LayerTy: Layer<UnderlyingOnStreamEndService<State>, Service = ServiceTy>,
659 ServiceTy: Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
660 + Clone
661 + Send
662 + 'static,
663 <LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
664 {
665 let Some(srvc) = self.on_stream_end.take() else {
666 return Err(CommonNetAPIError::OnStreamEndNotRegistered);
667 };
668
669 self.on_stream_end = Some(BoxCloneService::new(layer.layer(srvc)));
670 Ok(())
671 }
672
673 #[must_use]
674 pub const fn initial_service(&self) -> &BoxCloneService<Request<State>, Response, Infallible> {
675 &self.initial_service
676 }
677
678 pub fn layer_initial_service<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
679 where
680 LayerTy: Layer<BoxCloneService<Request<State>, Response, Infallible>, Service = ServiceTy>,
681 ServiceTy: Service<Request<State>, Response = Response, Error = Infallible>
682 + Clone
683 + Send
684 + 'static,
685 <LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
686 {
687 self.initial_service = BoxCloneService::new(layer.layer(self.initial_service.clone()));
688 }
689
690 #[must_use]
692 pub const fn state(&self) -> &State {
693 &self.state
694 }
695
696 pub async fn connect(self) -> Result<(), CatBridgeError> {
707 loop {
708 let client_address = self.address_to_bind_or_connect_to;
710 let stream = TcpStream::connect(self.address_to_bind_or_connect_to)
711 .await
712 .map_err(NetworkError::IO)?;
713 let loggable_address = stream.local_addr().map_err(NetworkError::IO)?;
714 let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
715 trace!(
716 server.address = %loggable_address,
717 client.address = %client_address,
718 stream.id = stream_id,
719 stream.stream_type = "server",
720 "cat_dev::net::tcp_server::connect(): started connection (TcpStream::connect())",
721 );
722
723 if let Err(cause) = Self::handle_tcp_connection(
724 self.on_stream_begin.clone(),
725 self.on_stream_end.clone(),
726 self.nagle_guard.clone(),
727 self.slowloris_timeout,
728 self.initial_service.clone(),
729 stream,
730 client_address,
731 self.pre_nagle_hook,
732 self.post_nagle_hook,
733 self.chunk_output_at_size,
734 self.state.clone(),
735 self.id,
736 stream_id,
737 self.cat_dev_slowdown,
738 #[cfg(debug_assertions)]
739 self.trace_during_debug,
740 )
741 .instrument(error_span!(
742 "CatDevTCPServerConnect",
743 client.address = %client_address,
744 server.address = %loggable_address,
745 server.service = self.service_name,
746 stream.id = stream_id,
747 stream.stream_type = "server",
748 ))
749 .await
750 {
751 warn!(
752 ?cause,
753 client.address = %client_address,
754 server.address = %loggable_address,
755 server.service = self.service_name,
756 "Error escaped while handling TCP connection.",
757 );
758 }
759 }
760 }
761
762 pub async fn bind(self) -> Result<(), CatBridgeError> {
772 let loggable_address = self.address_to_bind_or_connect_to;
774 let listener = TcpListener::bind(self.address_to_bind_or_connect_to)
775 .await
776 .map_err(NetworkError::IO)?;
777
778 loop {
779 let (stream, client_address) = listener.accept().await.map_err(NetworkError::IO)?;
780 trace!(
781 server.address = %loggable_address,
782 client.address = %client_address,
783 "cat_dev::net::tcp_server::bind(): received connection (listener.accept())",
784 );
785
786 let cloned_begin_handler = self.on_stream_begin.clone();
796 let cloned_end_handler = self.on_stream_end.clone();
797 let cloned_nagle_guard = self.nagle_guard.clone();
798 let cloned_handler = self.initial_service.clone();
799 let cloned_state = self.state.clone();
800 let copied_pre_nagle_hook = self.pre_nagle_hook;
801 let copied_post_nagle_hook = self.post_nagle_hook;
802 let copied_chunk_on_size = self.chunk_output_at_size;
803 let copied_service_name = self.service_name;
804 let copied_slowloris_timeout = self.slowloris_timeout;
805 let copied_server_id = self.id;
806 let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
807 let copied_slowdown = self.cat_dev_slowdown;
808 #[cfg(debug_assertions)]
809 let trace_io = self.trace_during_debug;
810
811 TaskBuilder::new()
812 .name("cat_dev::net::tcp_server::bind().connection.handle")
813 .spawn(async move {
814 if let Err(cause) = Self::handle_tcp_connection(
815 cloned_begin_handler,
816 cloned_end_handler,
817 cloned_nagle_guard,
818 copied_slowloris_timeout,
819 cloned_handler,
820 stream,
821 client_address,
822 copied_pre_nagle_hook,
823 copied_post_nagle_hook,
824 copied_chunk_on_size,
825 cloned_state,
826 copied_server_id,
827 stream_id,
828 copied_slowdown,
829 #[cfg(debug_assertions)]
830 trace_io,
831 )
832 .instrument(error_span!(
833 "CatDevTCPServerAccept",
834 client.address = %client_address,
835 server.address = %loggable_address,
836 server.service = copied_service_name,
837 server.stream_id = stream_id,
838 ))
839 .await
840 {
841 warn!(
842 ?cause,
843 client.address = %client_address,
844 server.address = %loggable_address,
845 server.service = %copied_service_name,
846 "Error escaped while handling TCP connection.",
847 );
848 }
849 })
850 .map_err(CatBridgeError::SpawnFailure)?;
851 }
852 }
853
854 #[allow(
866 clippy::too_many_arguments,
876 )]
877 async fn handle_tcp_connection(
878 on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
879 on_stream_end_handler: Option<UnderlyingOnStreamEndService<State>>,
880 nagle_guard: NagleGuard,
881 slowloris_timeout: Duration,
882 handler: BoxCloneService<Request<State>, Response, Infallible>,
883 mut tcp_stream: TcpStream,
884 client_address: SocketAddr,
885 pre_hook_cloned: Option<&'static dyn PreNagleFnTy>,
886 post_hook_cloned: Option<&'static dyn PostNagleFnTy>,
887 chunk_output_at_size: Option<usize>,
888 state: State,
889 server_id: u64,
890 stream_id: u64,
891 cat_dev_slowdown: Option<Duration>,
892 #[cfg(debug_assertions)] trace_io: bool,
893 ) -> Result<(), CatBridgeError> {
894 let (mut send_responses, mut packets_left_to_send) =
895 bounded_channel::<ResponseStreamMessage>(128);
896
897 if Self::initialize_stream(
900 on_stream_begin_handler,
901 &mut send_responses,
902 &client_address,
903 &state,
904 &mut tcp_stream,
905 server_id,
906 stream_id,
907 )
908 .await?
909 {
910 return Ok(());
911 }
912
913 let _guard = on_stream_end_handler.map(|service| {
915 DisconnectAsyncDropServer::new(service, state.clone(), client_address, stream_id)
916 });
917
918 let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;
921
922 loop {
923 let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
924 tokio::select! {
925 received = packets_left_to_send.recv() => {
926 if Self::handle_server_write_to_connection(
927 &mut tcp_stream,
928 chunk_output_at_size,
929 received,
930 post_hook_cloned,
931 stream_id,
932 cat_dev_slowdown,
933 #[cfg(debug_assertions)] trace_io,
934 ).await? {
935 break;
936 }
937 }
938 res_size = tcp_stream.read_buf(&mut buff) => {
939 let size = res_size.map_err(NetworkError::IO)?;
940 buff.truncate(size);
941 if buff.is_empty() {
942 continue;
943 }
944
945 let (should_break, returned_stream) = Self::handle_server_read_from_connection(
946 tcp_stream,
947 buff,
948 send_responses.clone(),
949 &nagle_guard,
950 slowloris_timeout,
951 handler.clone(),
952 &mut nagle_cache,
953 client_address,
954 pre_hook_cloned,
955 state.clone(),
956 stream_id,
957 #[cfg(debug_assertions)] trace_io,
958 ).await?;
959 tcp_stream = returned_stream;
960 if should_break {
961 break;
962 }
963 }
964 }
965 }
966
967 OUT_OF_BAND_SENDERS
968 .remove_async(&(server_id, stream_id))
969 .await;
970 packets_left_to_send.close();
971 std::mem::drop(tcp_stream.shutdown().await);
972
973 Ok(())
974 }
975
976 async fn initialize_stream(
977 on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
978 send_channel: &mut BoundedSender<ResponseStreamMessage>,
979 source_address: &SocketAddr,
980 state: &State,
981 tcp_stream: &mut TcpStream,
982 server_id: u64,
983 stream_id: u64,
984 ) -> Result<bool, CatBridgeError> {
985 tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;
986 OUT_OF_BAND_SENDERS
987 .upsert_async((server_id, stream_id), send_channel.clone())
988 .await;
989
990 if let Some(mut handle) = on_stream_begin_handler
991 && !handle
992 .call(ResponseStreamEvent::new_with_state(
993 send_channel.clone(),
994 *source_address,
995 Some(stream_id),
996 state.clone(),
997 ))
998 .await?
999 {
1000 trace!("handler failed on stream begin hook");
1001 return Ok(true);
1002 }
1003
1004 Ok(false)
1005 }
1006
1007 async fn handle_server_write_to_connection(
1008 tcp_stream: &mut TcpStream,
1009 chunk_output_on_size: Option<usize>,
1010 to_send_to_client_opt: Option<ResponseStreamMessage>,
1011 post_hook: Option<&'static dyn PostNagleFnTy>,
1012 stream_id: u64,
1013 cat_dev_slowdown: Option<Duration>,
1014 #[cfg(debug_assertions)] trace_io: bool,
1015 ) -> Result<bool, CatBridgeError> {
1016 let Some(to_send_to_client) = to_send_to_client_opt else {
1017 return Ok(false);
1018 };
1019
1020 match to_send_to_client {
1021 ResponseStreamMessage::Disconnect => {
1022 debug!("stream-disconnect-message");
1023 Ok(true)
1024 }
1025 ResponseStreamMessage::Response(resp) => {
1026 if let Some(body) = resp.body()
1027 && !body.is_empty()
1028 {
1029 let messages = if let Some(size) = chunk_output_on_size {
1030 body.chunks(size)
1031 .map(Bytes::copy_from_slice)
1032 .collect::<Vec<_>>()
1033 } else {
1034 vec![body.clone()]
1035 };
1036
1037 for message in messages {
1038 #[cfg(debug_assertions)]
1039 if trace_io {
1040 debug!(
1041 body.hex = format!("{message:02x?}"),
1042 body.str = String::from_utf8_lossy(&message).to_string(),
1043 "cat-dev-trace-output-tcp-server",
1044 );
1045 }
1046
1047 let mut full_response = message.clone();
1048 if let Some(post) = post_hook {
1049 full_response = block_in_place(|| post(stream_id, full_response));
1050 }
1051 if let Some(slowdown_ms) = cat_dev_slowdown {
1052 sleep(slowdown_ms).await;
1053 }
1054
1055 tcp_stream.writable().await.map_err(NetworkError::IO)?;
1056 tcp_stream
1057 .write_all(&full_response)
1058 .await
1059 .map_err(NetworkError::IO)?;
1060 }
1061 }
1062
1063 if resp.request_connection_close() {
1064 trace!("response-requested-connection-close");
1065 Ok(true)
1066 } else {
1067 Ok(false)
1068 }
1069 }
1070 }
1071 }
1072
1073 #[allow(
1074 clippy::too_many_arguments,
1079 )]
1080 async fn handle_server_read_from_connection<'data>(
1081 mut stream: TcpStream,
1082 mut buff: BytesMut,
1083 channel: BoundedSender<ResponseStreamMessage>,
1084 nagle_guard: &'data NagleGuard,
1085 slowloris_timeout: Duration,
1086 mut handler: BoxCloneService<Request<State>, Response, Infallible>,
1087 nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
1088 client_address: SocketAddr,
1089 cloned_pre_nagle: Option<&'static dyn PreNagleFnTy>,
1090 state: State,
1091 stream_id: u64,
1092 #[cfg(debug_assertions)] trace_io: bool,
1093 ) -> Result<(bool, TcpStream), CatBridgeError> {
1094 if let Some(convert_fn) = cloned_pre_nagle {
1095 block_in_place(|| {
1096 (*convert_fn)(stream_id, &mut buff);
1097 });
1098 }
1099
1100 #[cfg(debug_assertions)]
1101 {
1102 if trace_io {
1103 debug!(
1104 body.hex = format!("{:02x?}", buff),
1105 body.str = String::from_utf8_lossy(&buff).to_string(),
1106 "cat-dev-trace-input-tcp-server",
1107 );
1108 }
1109 }
1110
1111 let start_time = now();
1114 if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
1115 let total_duration = start_time
1120 .duration_since(old_start_time)
1121 .unwrap_or(Duration::from_secs(0));
1122 if total_duration > slowloris_timeout {
1123 debug!(
1124 cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
1125 "slowloris-detected",
1126 );
1127 return Ok((true, stream));
1128 }
1129
1130 existing_buff.extend(buff.freeze());
1131 buff = existing_buff;
1132 }
1133
1134 while let Some((start_of_packet, end_of_packet)) = nagle_guard.split(&buff)? {
1135 let remaining_buff = buff.split_off(end_of_packet);
1136 let _start_of_buff = buff.split_to(start_of_packet);
1137 let req_body = buff.freeze();
1138 buff = remaining_buff;
1139
1140 let lockable_stream = Arc::new(Mutex::new(Some((Some(buff), stream))));
1141 let mut request_object = Request::new_with_state_and_stream(
1142 req_body,
1143 client_address,
1144 state.clone(),
1145 Some(stream_id),
1146 lockable_stream.clone(),
1147 );
1148 request_object.extensions_mut().insert(channel.clone());
1149 if let Err(cause) = match handler.call(request_object).await {
1150 Ok(ref resp) => {
1151 channel
1152 .send(ResponseStreamMessage::Response(resp.clone()))
1153 .await
1154 }
1155 Err(cause) => {
1156 warn!(?cause, "request handler failed, will close connection.");
1157 channel.send(ResponseStreamMessage::Disconnect).await
1158 }
1159 } {
1160 warn!(
1161 ?cause,
1162 "internal queue failure will not send disconnect/response."
1163 );
1164 }
1165
1166 {
1167 let mut done_lock = lockable_stream.lock().await;
1168 if let Some((newer_buff, strm)) = done_lock.take() {
1169 if let Some(newest_buff) = newer_buff {
1170 buff = newest_buff;
1171 } else {
1172 return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
1173 }
1174 stream = strm;
1175 } else {
1176 return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
1177 }
1178 }
1179 }
1180
1181 if !buff.is_empty() {
1182 _ = nagle_cache.insert((buff, start_time));
1183 }
1184
1185 Ok((false, stream))
1186 }
1187}
1188
1189impl<State: Clone + Debug + Send + Sync + 'static> Debug for TCPServer<State> {
1190 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
1191 let mut dbg_struct = fmt.debug_struct("TCPServer");
1192 dbg_struct
1193 .field(
1194 "address_to_bind_or_connect_to",
1195 &self.address_to_bind_or_connect_to,
1196 )
1197 .field("cat_dev_slowdown", &self.cat_dev_slowdown)
1198 .field("chunk_output_at_size", &self.chunk_output_at_size)
1199 .field("id", &self.id)
1200 .field("initial_service", &self.initial_service)
1201 .field("nagle_guard", &self.nagle_guard)
1202 .field("on_stream_begin", &self.on_stream_begin)
1203 .field("on_stream_end", &self.on_stream_end)
1204 .field("has_pre_nagle_hook", &self.pre_nagle_hook.is_some())
1205 .field("has_post_nagle_hook", &self.post_nagle_hook.is_some())
1206 .field("service_name", &self.service_name)
1207 .field("slowloris_timeout", &self.slowloris_timeout)
1208 .field("state", &self.state);
1209
1210 #[cfg(debug_assertions)]
1211 {
1212 dbg_struct.field("trace_during_debug", &self.trace_during_debug);
1213 }
1214
1215 dbg_struct.finish()
1216 }
1217}
1218
/// Field names reported for `TCPServer` through `valuable`.
///
/// `NamedValues` pairs these names with values by position, so this list must
/// stay in the same order (and under the same `cfg`) as the values emitted by
/// `TCPServer`'s `Valuable::visit`. Unlike the `Debug` output, `id` is not
/// included.
const TCP_SERVER_FIELDS: &[NamedField<'static>] = &[
	NamedField::new("address_to_bind_or_connect_to"),
	NamedField::new("cat_dev_slowdown"),
	NamedField::new("chunk_output_at_size"),
	NamedField::new("initial_service"),
	NamedField::new("nagle_guard"),
	NamedField::new("on_stream_begin"),
	NamedField::new("on_stream_end"),
	NamedField::new("has_pre_nagle_hook"),
	NamedField::new("has_post_nagle_hook"),
	NamedField::new("service_name"),
	NamedField::new("slowloris_timeout"),
	NamedField::new("state"),
	#[cfg(debug_assertions)]
	NamedField::new("trace_during_debug"),
];
1235
1236impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Structable for TCPServer<State> {
1237 fn definition(&self) -> StructDef<'_> {
1238 StructDef::new_static("TcpServer", Fields::Named(TCP_SERVER_FIELDS))
1239 }
1240}
1241
1242impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Valuable for TCPServer<State> {
1243 fn as_value(&self) -> Value<'_> {
1244 Value::Structable(self)
1245 }
1246
1247 fn visit(&self, visitor: &mut dyn Visit) {
1248 visitor.visit_named_fields(&NamedValues::new(
1249 TCP_SERVER_FIELDS,
1250 &[
1251 Valuable::as_value(&format!("{}", self.address_to_bind_or_connect_to)),
1252 Valuable::as_value(&if let Some(slowdown) = self.cat_dev_slowdown {
1253 format!("{}ms", slowdown.as_millis())
1254 } else {
1255 "<none>".to_string()
1256 }),
1257 Valuable::as_value(&self.chunk_output_at_size),
1258 Valuable::as_value(&format!("{:?}", self.initial_service)),
1259 Valuable::as_value(&self.nagle_guard),
1260 Valuable::as_value(&format!("{:?}", self.on_stream_begin)),
1261 Valuable::as_value(&format!("{:?}", self.on_stream_end)),
1262 Valuable::as_value(&self.pre_nagle_hook.is_some()),
1263 Valuable::as_value(&self.post_nagle_hook.is_some()),
1264 Valuable::as_value(&self.service_name),
1265 Valuable::as_value(&format!("{:?}", self.slowloris_timeout)),
1266 Valuable::as_value(&self.state),
1267 #[cfg(debug_assertions)]
1268 Valuable::as_value(&self.trace_during_debug),
1269 ],
1270 ));
1271 }
1272}
1273
#[cfg(test)]
pub mod test_helpers {
	use super::*;
	use std::net::{Ipv4Addr, SocketAddrV4};

	/// Ask the OS for a currently unused TCP port on `127.0.0.1`.
	///
	/// Binds a throwaway listener to port `0` and reports the port the OS
	/// assigned. The listener is dropped before returning, so the port is only
	/// *likely* to still be free when the caller binds it. Returns `None` if
	/// binding or reading the local address fails.
	pub async fn get_free_tcp_v4_port() -> Option<u16> {
		let probe = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
			.await
			.ok()?;
		probe.local_addr().ok().map(|local| local.port())
	}
}
1294
1295#[cfg(test)]
1296mod unit_tests {
1297 use super::*;
1298 use crate::net::{
1299 CURRENT_TIME,
1300 server::{Router, requestable::Extension, test_helpers::*},
1301 };
1302 use bytes::Bytes;
1303 use std::{
1304 net::{Ipv4Addr, SocketAddrV4},
1305 sync::{
1306 Arc, Mutex,
1307 atomic::{AtomicU8, Ordering},
1308 },
1309 time::Duration,
1310 };
1311 use tokio::time::timeout;
1312
1313 fn set_now(new_time: SystemTime) {
1314 CURRENT_TIME.with(|time_lazy| {
1315 *time_lazy.write().expect("RwLock is poisioned?") = new_time;
1316 })
1317 }
1318
1319 #[tokio::test]
1320 pub async fn test_full_server() {
1321 let connected_fired = Arc::new(Mutex::new(false));
1322 let on_disconnect_fired = Arc::new(Mutex::new(false));
1323 let request_fired = Arc::new(Mutex::new(false));
1324
1325 async fn on_connection(
1326 Extension(connected): Extension<Arc<Mutex<bool>>>,
1327 ) -> Result<bool, CatBridgeError> {
1328 let mut locked = connected
1329 .lock()
1330 .expect("Failed to lock connected fired extension");
1331 *locked = true;
1332 Ok(true)
1333 }
1334 async fn on_disconnect(
1335 Extension(disconnected): Extension<Arc<Mutex<bool>>>,
1336 ) -> Result<(), CatBridgeError> {
1337 let mut locked = disconnected
1338 .lock()
1339 .expect("Failed to lock connected fired extension");
1340 *locked = true;
1341 Ok(())
1342 }
1343 async fn on_request(
1344 Extension(request): Extension<Arc<Mutex<bool>>>,
1345 ) -> Result<Response, CatBridgeError> {
1346 let mut locked = request
1347 .lock()
1348 .expect("Failed to lock connected fired extension");
1349 *locked = true;
1350
1351 let mut resp = Response::new_with_body(Bytes::from(vec![0x1]));
1352 resp.should_close_connection();
1353 Ok(resp)
1354 }
1355
1356 let mut router = Router::new();
1357 router
1358 .add_route(&[0x1, 0x2, 0x3], on_request)
1359 .expect("Failed to add a route!");
1360 router.layer(Extension(request_fired.clone()));
1361
1362 let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1363 .await
1364 .expect("Timed out trying to find free port!")
1365 .expect("Failed to find free TCP port on system.");
1366
1367 let mut srv = timeout(
1368 Duration::from_secs(5),
1369 TCPServer::new_with_state(
1370 "test",
1371 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1372 router,
1373 (None, None),
1374 NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
1375 (),
1376 #[cfg(debug_assertions)]
1377 true,
1378 ),
1379 )
1380 .await
1381 .expect("Timed out starting server")
1382 .expect("Failed to create TCP Server.");
1383
1384 srv.set_on_stream_begin(on_connection)
1385 .expect("Failed to register stream begin handler!");
1386 srv.layer_on_stream_begin(Extension(connected_fired.clone()))
1387 .expect("Failed to add layer to on stream begin!");
1388 srv.set_on_stream_end(on_disconnect)
1389 .expect("Failed to register stream end handler!");
1390 srv.layer_on_stream_end(Extension(on_disconnect_fired.clone()))
1391 .expect("Failed to add layer to on_disconnect!");
1392
1393 let spawned =
1394 tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1395 {
1396 loop {
1397 let client_stream_res = timeout(
1398 Duration::from_secs(10),
1399 TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1400 )
1401 .await
1402 .expect("Service timed out waiting for connection!");
1403 if client_stream_res.is_err() {
1405 continue;
1406 }
1407 let mut client_stream = client_stream_res.unwrap();
1408 client_stream
1409 .write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
1410 .await
1411 .expect("Failed to write to client stream");
1412 let mut buff = [0_u8; 1];
1414 timeout(Duration::from_secs(5), client_stream.read(&mut buff))
1415 .await
1416 .expect("Timed out reading from client stream")
1417 .expect("Failed to read data from client stream");
1418 timeout(Duration::from_secs(5), client_stream.shutdown())
1419 .await
1420 .expect("Timed out shutting down client stream")
1421 .expect("Failed to shutdown client stream.");
1422 break;
1423 }
1424 }
1425 std::mem::drop(spawned);
1427
1428 let locked_connect = connected_fired
1429 .lock()
1430 .expect("Failed to lock second connect");
1431 let locked_disconnect = on_disconnect_fired
1432 .lock()
1433 .expect("Failed to lock second on_disconnect");
1434 let locked_request = request_fired.lock().expect("Failed to lock second request");
1435
1436 assert!(*locked_connect, "on connection handler never fired!");
1437 assert!(*locked_disconnect, "on disconnect handler never fired!");
1438 assert!(*locked_request, "on request handler never fired!");
1439 }
1440
1441 #[tokio::test]
1442 pub async fn test_nagled() {
1443 let requests_fired = Arc::new(AtomicU8::new(0));
1444
1445 async fn on_request(
1446 Extension(request): Extension<Arc<AtomicU8>>,
1447 ) -> Result<Response, CatBridgeError> {
1448 request.fetch_add(1, Ordering::SeqCst);
1449 let resp = Response::new_with_body(Bytes::from(vec![0x1]));
1450 Ok(resp)
1451 }
1452
1453 let mut router = Router::new();
1454 router
1455 .add_route(&[0x1, 0x2, 0x3], on_request)
1456 .expect("Failed to add a route!");
1457 router.layer(Extension(requests_fired.clone()));
1458
1459 let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1460 .await
1461 .expect("Timed out finding port to bind too!")
1462 .expect("Failed to find any free tcp v4 port on system!");
1463 let srv = timeout(
1464 Duration::from_secs(5),
1465 TCPServer::new_with_state(
1466 "test",
1467 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1468 router,
1469 (None, None),
1470 NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
1471 (),
1472 #[cfg(debug_assertions)]
1473 true,
1474 ),
1475 )
1476 .await
1477 .expect("timed out starting TCP Server for test")
1478 .expect("falied to create local tcp server!");
1479
1480 let spawned =
1481 tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1482 {
1483 loop {
1484 let client_stream_res = timeout(
1485 Duration::from_secs(10),
1486 TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1487 )
1488 .await
1489 .expect("Service timed out waiting for connection!");
1490 if client_stream_res.is_err() {
1492 continue;
1493 }
1494 let mut client_stream = client_stream_res.unwrap();
1495
1496 client_stream
1497 .write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF])
1498 .await
1499 .expect("Failed to write to client_stream");
1500 client_stream
1501 .flush()
1502 .await
1503 .expect("Failed to flush client_stream");
1504 client_stream
1505 .write_all(&[0xFF, 0xFF, 0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
1506 .await
1507 .expect("Failed to issue second write call to client_stream");
1508 let mut buff = [0_u8; 2];
1509 let read_bytes = timeout(Duration::from_secs(5), client_stream.read(&mut buff))
1510 .await
1511 .expect("Timed out reading from client_stream")
1512 .expect("Failed to read from client_stream!");
1513 if read_bytes == 1 {
1514 timeout(Duration::from_secs(5), client_stream.read(&mut buff[1..]))
1515 .await
1516 .expect("Timed out reading from client_stream")
1517 .expect("Failed to read from client_stream!");
1518 }
1519
1520 timeout(Duration::from_secs(5), client_stream.shutdown())
1521 .await
1522 .expect("Timed out shutting down client stream")
1523 .expect("Failed to shutdown client stream.");
1524 break;
1525 }
1526 }
1527 std::mem::drop(spawned);
1529
1530 assert_eq!(
1531 requests_fired.load(Ordering::SeqCst),
1532 2,
1533 "on request did not fire the correct amount of times!",
1534 );
1535 }
1536
1537 #[tokio::test]
1538 pub async fn test_slowloris_blocking() {
1539 let mut router = Router::new();
1540 router
1541 .add_route(&[0x1, 0x2, 0x3], || async {
1542 Ok(Response::new_with_body(Bytes::from(vec![0x1])))
1543 })
1544 .expect("Failed to add a route!");
1545
1546 let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1547 .await
1548 .expect("Timed out finding port to bind too!")
1549 .expect("Failed to find any free tcp v4 port on system!");
1550 let srv = timeout(
1551 Duration::from_secs(5),
1552 TCPServer::new_with_state(
1553 "test",
1554 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1555 router,
1556 (None, None),
1557 NagleGuard::EndSigilSearch(&[0x10, 0x11, 0x12]),
1558 (),
1559 #[cfg(debug_assertions)]
1560 true,
1561 ),
1562 )
1563 .await
1564 .expect("timed out starting TCP Server for test")
1565 .expect("falied to create local tcp server!");
1566
1567 let spawned =
1568 tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1569 let read_bytes;
1570 {
1571 loop {
1572 let client_stream_res = timeout(
1573 Duration::from_secs(10),
1574 TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1575 )
1576 .await
1577 .expect("Service timed out waiting for connection!");
1578 if client_stream_res.is_err() {
1580 continue;
1581 }
1582 let mut client_stream = client_stream_res.unwrap();
1583
1584 client_stream
1585 .write_all(&[0x1, 0x2, 0x3, 0x10])
1586 .await
1587 .expect("Failed to write to client_stream");
1588 client_stream
1589 .flush()
1590 .await
1591 .expect("Failed to flush client_stream");
1592 tokio::time::sleep(Duration::from_secs(5)).await;
1595 set_now(
1596 SystemTime::now()
1597 .checked_add(Duration::from_secs(900_00_000))
1598 .expect("Failed to add time to systemtime"),
1599 );
1600 client_stream
1602 .write_all(&[0x11, 0x12])
1603 .await
1604 .expect("Failed to write to client_stream");
1605
1606 let mut buff = [0_u8; 1];
1607 read_bytes = timeout(Duration::from_secs(10), client_stream.read(&mut buff))
1610 .await
1611 .expect("timed out trying to wait for disconnect")
1612 .expect("failure reading from stream");
1613 break;
1614 }
1615 }
1616 std::mem::drop(spawned);
1617
1618 assert_eq!(read_bytes, 0, "Client didn't error on slowloris'd packet");
1619 }
1620}