1use crate::{
132 errors::{CatBridgeError, NetworkError},
133 net::{
134 DEFAULT_SLOWLORIS_TIMEOUT, SERVER_ID, STREAM_ID, TCP_READ_BUFFER_SIZE,
135 errors::{CommonNetAPIError, CommonNetNetworkError},
136 handlers::{
137 OnResponseStreamBeginHandler, OnResponseStreamEndHandler,
138 OnStreamBeginHandlerAsService, OnStreamEndHandlerAsService,
139 },
140 models::{NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
141 now,
142 server::models::{
143 DisconnectAsyncDropServer, ResponseStreamEvent, ResponseStreamMessage,
144 UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
145 },
146 },
147};
148use bytes::{Bytes, BytesMut};
149use fnv::FnvHashSet;
150use futures::future::join_all;
151use scc::HashMap as ConcurrentMap;
152use std::{
153 convert::Infallible,
154 fmt::{Debug, Formatter, Result as FmtResult},
155 net::{IpAddr, SocketAddr},
156 sync::{Arc, LazyLock, atomic::Ordering},
157 time::{Duration, SystemTime},
158};
159use tokio::{
160 io::{AsyncReadExt, AsyncWriteExt},
161 net::{TcpListener, TcpStream, ToSocketAddrs, lookup_host},
162 sync::{
163 Mutex,
164 mpsc::{Sender as BoundedSender, channel as bounded_channel},
165 },
166 task::{Builder as TaskBuilder, block_in_place},
167 time::sleep,
168};
169use tower::{Layer, Service, util::BoxCloneService};
170use tracing::{Instrument, debug, error_span, trace, warn};
171use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
172
173#[cfg(debug_assertions)]
174use crate::net::SPRIG_TRACE_IO;
175
/// Process-wide registry of response-stream senders, keyed by `(server_id, stream_id)`.
///
/// `TCPServer::initialize_stream` inserts the sender for every new connection, and
/// `TCPServer::handle_tcp_connection` removes it again once its connection loop ends.
/// `TCPServer::out_of_bound_send` / `TCPServer::out_of_bound_broadcast` look senders up
/// here so code outside of a request handler can push messages onto a live stream.
static OUT_OF_BAND_SENDERS: LazyLock<
	ConcurrentMap<(u64, u64), BoundedSender<ResponseStreamMessage>>,
> = LazyLock::new(ConcurrentMap::new);
181
/// A TCP server (or outbound-connecting "server") that frames incoming bytes with a
/// [`NagleGuard`], dispatches each framed packet to a tower service, and writes the
/// queued responses back to the peer.
///
/// `State` is cloned into every request and stream event; it defaults to `()`.
pub struct TCPServer<State: Clone + Send + Sync + 'static = ()> {
	/// Address used by `bind` (`TcpListener::bind`) or `connect` (`TcpStream::connect`).
	address_to_bind_or_connect_to: SocketAddr,
	/// When set, the writer sleeps this long before writing each outgoing chunk.
	cat_dev_slowdown: Option<Duration>,
	/// When set, response bodies are written in pieces of at most this many bytes.
	chunk_output_at_size: Option<usize>,
	/// Unique id taken from the `SERVER_ID` counter; first half of the
	/// `OUT_OF_BAND_SENDERS` key for this server's streams.
	id: u64,
	/// Service that handles every framed request packet.
	initial_service: BoxCloneService<Request<State>, Response, Infallible>,
	/// Splits the incoming byte stream into request packets.
	nagle_guard: NagleGuard,
	/// Optional hook run when a stream starts; if it yields `false` the stream is closed.
	on_stream_begin: Option<UnderlyingOnStreamBeginService<State>>,
	/// Optional hook handed to a `DisconnectAsyncDropServer` guard for each stream
	/// (presumably run when that guard is dropped — confirm in `server::models`).
	on_stream_end: Option<UnderlyingOnStreamEndService<State>>,
	/// Applied (inside `block_in_place`) to every raw read buffer before framing.
	pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
	/// Applied (inside `block_in_place`) to every outgoing chunk before it is written.
	post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
	/// Human-readable name recorded as `server.service` in tracing spans/events.
	service_name: &'static str,
	/// Maximum gap allowed between a cached partial packet and the read that extends it.
	slowloris_timeout: Duration,
	/// User state cloned into each request / stream event.
	state: State,
	/// Whether to emit hex/string dumps of all I/O at `debug` level (debug builds only).
	#[cfg(debug_assertions)]
	trace_during_debug: bool,
}
260
261impl TCPServer<()> {
262 pub async fn new<AddrTy, ServiceTy>(
268 service_name: &'static str,
269 bind_addr: AddrTy,
270 initial_service: ServiceTy,
271 nagle_hooks: (
272 Option<&'static dyn PreNagleFnTy>,
273 Option<&'static dyn PostNagleFnTy>,
274 ),
275 guard: impl Into<NagleGuard>,
276 trace_io_during_debug: bool,
277 ) -> Result<Self, CommonNetAPIError>
278 where
279 AddrTy: ToSocketAddrs,
280 ServiceTy:
281 Clone + Send + Service<Request<()>, Response = Response, Error = Infallible> + 'static,
282 ServiceTy::Future: Send + 'static,
283 {
284 Self::new_with_state(
285 service_name,
286 bind_addr,
287 initial_service,
288 nagle_hooks,
289 guard,
290 (),
291 trace_io_during_debug,
292 )
293 .await
294 }
295}
296
297impl<State: Clone + Send + Sync + 'static> TCPServer<State> {
298 pub async fn out_of_bound_send(
305 server_id: u64,
306 stream_id: u64,
307 message: ResponseStreamMessage,
308 ) -> Result<(), CatBridgeError> {
309 if let Some(stream) = OUT_OF_BAND_SENDERS.get_async(&(server_id, stream_id)).await {
310 stream
311 .send(message)
312 .await
313 .map_err(NetworkError::SendQueueMessageFailure)?;
314 Ok(())
315 } else {
316 Err(CommonNetNetworkError::StreamNoLongerProcessing.into())
317 }
318 }
319
320 pub async fn out_of_bound_broadcast(
322 server_id: u64,
323 message: ResponseStreamMessage,
324 ) -> Vec<Result<(), CatBridgeError>> {
325 let mut ids = FnvHashSet::default();
326 OUT_OF_BAND_SENDERS
328 .iter_async(|key, _value| {
329 if key.0 == server_id {
330 ids.insert(key.1);
331 }
332 true
333 })
334 .await;
335
336 let mut tasks = Vec::with_capacity(ids.len());
338 for id in ids {
339 tasks.push(Self::out_of_bound_send(server_id, id, message.clone()));
340 }
341
342 join_all(tasks).await
343 }
344
345 #[allow(unused)]
351 pub async fn new_with_state<AddrTy, ServiceTy>(
352 service_name: &'static str,
353 bind_addr: AddrTy,
354 initial_service: ServiceTy,
355 nagle_hooks: (
356 Option<&'static dyn PreNagleFnTy>,
357 Option<&'static dyn PostNagleFnTy>,
358 ),
359 guard: impl Into<NagleGuard>,
360 state: State,
361 trace_io_during_debug: bool,
362 ) -> Result<Self, CommonNetAPIError>
363 where
364 AddrTy: ToSocketAddrs,
365 ServiceTy: Clone
366 + Send
367 + Service<Request<State>, Response = Response, Error = Infallible>
368 + 'static,
369 ServiceTy::Future: Send + 'static,
370 {
371 let hosts = lookup_host(bind_addr)
372 .await
373 .map_err(CommonNetAPIError::AddressLookupError)?
374 .collect::<Vec<_>>();
375 if hosts.len() != 1 {
376 return Err(CommonNetAPIError::WrongAmountOfAddressesToBindToo(hosts));
377 }
378
379 #[cfg(not(debug_assertions))]
380 {
381 if trace_io_during_debug {
382 warn!(
383 "Trace IO was turned on, but debug assertsions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
384 );
385 }
386 }
387
388 Ok(Self {
389 address_to_bind_or_connect_to: hosts[0],
390 cat_dev_slowdown: None,
391 chunk_output_at_size: None,
392 id: SERVER_ID.fetch_add(1, Ordering::SeqCst),
393 initial_service: BoxCloneService::new(initial_service),
394 nagle_guard: guard.into(),
395 on_stream_begin: None,
396 on_stream_end: None,
397 pre_nagle_hook: nagle_hooks.0,
398 post_nagle_hook: nagle_hooks.1,
399 service_name,
400 slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
401 state,
402 #[cfg(debug_assertions)]
403 trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
404 })
405 }
406
407 #[must_use]
409 pub const fn id(&self) -> u64 {
410 self.id
411 }
412
413 #[must_use]
415 pub const fn ip(&self) -> IpAddr {
416 self.address_to_bind_or_connect_to.ip()
417 }
418
419 #[must_use]
421 pub const fn port(&self) -> u16 {
422 self.address_to_bind_or_connect_to.port()
423 }
424
425 pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
427 self.cat_dev_slowdown = slowdown;
428 }
429
430 #[must_use]
431 pub const fn chunk_output_at_size(&self) -> Option<usize> {
432 self.chunk_output_at_size
433 }
434
435 pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
436 self.chunk_output_at_size = new_size;
437 }
438
439 #[must_use]
440 pub const fn slowloris_timeout(&self) -> Duration {
441 self.slowloris_timeout
442 }
443 pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
444 self.slowloris_timeout = slowloris_timeout;
445 }
446
447 #[must_use]
448 pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<State>> {
449 self.on_stream_begin.as_ref()
450 }
451
452 pub fn set_raw_on_stream_begin(
465 &mut self,
466 on_start: Option<UnderlyingOnStreamBeginService<State>>,
467 ) -> Result<(), CommonNetAPIError> {
468 if self.on_stream_begin.is_some() {
469 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
470 }
471
472 self.on_stream_begin = on_start;
473 Ok(())
474 }
475
476 pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
488 &mut self,
489 handler: HandlerTy,
490 ) -> Result<(), CommonNetAPIError>
491 where
492 HandlerParamsTy: Send + 'static,
493 HandlerTy: OnResponseStreamBeginHandler<HandlerParamsTy, State> + Clone + Send + 'static,
494 {
495 if self.on_stream_begin.is_some() {
496 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
497 }
498
499 let boxed = BoxCloneService::new(OnStreamBeginHandlerAsService::new(handler));
500 self.on_stream_begin = Some(boxed);
501 Ok(())
502 }
503
504 pub fn set_on_stream_begin_service<ServiceTy>(
516 &mut self,
517 service_ty: ServiceTy,
518 ) -> Result<(), CommonNetAPIError>
519 where
520 ServiceTy: Clone
521 + Send
522 + Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
523 + 'static,
524 ServiceTy::Future: Send + 'static,
525 {
526 if self.on_stream_begin.is_some() {
527 return Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered);
528 }
529
530 self.on_stream_begin = Some(BoxCloneService::new(service_ty));
531 Ok(())
532 }
533
534 pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
541 &mut self,
542 layer: LayerTy,
543 ) -> Result<(), CommonNetAPIError>
544 where
545 LayerTy: Layer<UnderlyingOnStreamBeginService<State>, Service = ServiceTy>,
546 ServiceTy: Service<ResponseStreamEvent<State>, Response = bool, Error = CatBridgeError>
547 + Clone
548 + Send
549 + 'static,
550 <LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
551 {
552 let Some(srvc) = self.on_stream_begin.take() else {
553 return Err(CommonNetAPIError::OnStreamBeginNotRegistered);
554 };
555
556 self.on_stream_begin = Some(BoxCloneService::new(layer.layer(srvc)));
557 Ok(())
558 }
559
560 #[must_use]
561 pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<State>> {
562 self.on_stream_end.as_ref()
563 }
564
565 pub fn set_raw_on_stream_end(
578 &mut self,
579 on_end: Option<UnderlyingOnStreamEndService<State>>,
580 ) -> Result<(), CommonNetAPIError> {
581 if self.on_stream_end.is_some() {
582 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
583 }
584
585 self.on_stream_end = on_end;
586 Ok(())
587 }
588
589 pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
601 &mut self,
602 handler: HandlerTy,
603 ) -> Result<(), CommonNetAPIError>
604 where
605 HandlerParamsTy: Send + 'static,
606 HandlerTy: OnResponseStreamEndHandler<HandlerParamsTy, State> + Clone + Send + 'static,
607 {
608 if self.on_stream_end.is_some() {
609 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
610 }
611
612 let boxed = BoxCloneService::new(OnStreamEndHandlerAsService::new(handler));
613 self.on_stream_end = Some(boxed);
614 Ok(())
615 }
616
617 pub fn set_on_stream_end_service<ServiceTy>(
629 &mut self,
630 service_ty: ServiceTy,
631 ) -> Result<(), CommonNetAPIError>
632 where
633 ServiceTy: Clone
634 + Send
635 + Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
636 + 'static,
637 ServiceTy::Future: Send + 'static,
638 {
639 if self.on_stream_end.is_some() {
640 return Err(CommonNetAPIError::OnStreamEndAlreadyRegistered);
641 }
642
643 self.on_stream_end = Some(BoxCloneService::new(service_ty));
644 Ok(())
645 }
646
647 pub fn layer_on_stream_end<LayerTy, ServiceTy>(
654 &mut self,
655 layer: LayerTy,
656 ) -> Result<(), CommonNetAPIError>
657 where
658 LayerTy: Layer<UnderlyingOnStreamEndService<State>, Service = ServiceTy>,
659 ServiceTy: Service<ResponseStreamEvent<State>, Response = (), Error = CatBridgeError>
660 + Clone
661 + Send
662 + 'static,
663 <LayerTy::Service as Service<ResponseStreamEvent<State>>>::Future: Send + 'static,
664 {
665 let Some(srvc) = self.on_stream_end.take() else {
666 return Err(CommonNetAPIError::OnStreamEndNotRegistered);
667 };
668
669 self.on_stream_end = Some(BoxCloneService::new(layer.layer(srvc)));
670 Ok(())
671 }
672
673 #[must_use]
674 pub const fn initial_service(&self) -> &BoxCloneService<Request<State>, Response, Infallible> {
675 &self.initial_service
676 }
677
678 pub fn layer_initial_service<LayerTy, ServiceTy>(&mut self, layer: LayerTy)
679 where
680 LayerTy: Layer<BoxCloneService<Request<State>, Response, Infallible>, Service = ServiceTy>,
681 ServiceTy: Service<Request<State>, Response = Response, Error = Infallible>
682 + Clone
683 + Send
684 + 'static,
685 <LayerTy::Service as Service<Request<State>>>::Future: Send + 'static,
686 {
687 self.initial_service = BoxCloneService::new(layer.layer(self.initial_service.clone()));
688 }
689
690 #[must_use]
692 pub const fn state(&self) -> &State {
693 &self.state
694 }
695
696 pub async fn connect(self) -> Result<(), CatBridgeError> {
707 loop {
708 let client_address = self.address_to_bind_or_connect_to;
710 let stream = TcpStream::connect(self.address_to_bind_or_connect_to)
711 .await
712 .map_err(NetworkError::IO)?;
713 let loggable_address = stream.local_addr().map_err(NetworkError::IO)?;
714 let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
715 trace!(
716 server.address = %loggable_address,
717 client.address = %client_address,
718 stream.id = stream_id,
719 stream.stream_type = "server",
720 "cat_dev::net::tcp_server::connect(): started connection (TcpStream::connect())",
721 );
722
723 if let Err(cause) = Self::handle_tcp_connection(
724 self.on_stream_begin.clone(),
725 self.on_stream_end.clone(),
726 self.nagle_guard.clone(),
727 self.slowloris_timeout,
728 self.initial_service.clone(),
729 stream,
730 client_address,
731 self.pre_nagle_hook,
732 self.post_nagle_hook,
733 self.chunk_output_at_size,
734 self.state.clone(),
735 self.id,
736 stream_id,
737 self.cat_dev_slowdown,
738 #[cfg(debug_assertions)]
739 self.trace_during_debug,
740 )
741 .instrument(error_span!(
742 "CatDevTCPServerConnect",
743 client.address = %client_address,
744 server.address = %loggable_address,
745 server.service = self.service_name,
746 stream.id = stream_id,
747 stream.stream_type = "server",
748 ))
749 .await
750 {
751 warn!(
752 ?cause,
753 client.address = %client_address,
754 server.address = %loggable_address,
755 server.service = self.service_name,
756 "Error escaped while handling TCP connection.",
757 );
758 }
759 }
760 }
761
762 pub async fn bind(self) -> Result<(), CatBridgeError> {
772 let loggable_address = self.address_to_bind_or_connect_to;
774 let listener = TcpListener::bind(self.address_to_bind_or_connect_to)
775 .await
776 .map_err(NetworkError::IO)?;
777
778 loop {
779 let (stream, client_address) = listener.accept().await.map_err(NetworkError::IO)?;
780 trace!(
781 server.address = %loggable_address,
782 client.address = %client_address,
783 "cat_dev::net::tcp_server::bind(): received connection (listener.accept())",
784 );
785
786 let cloned_begin_handler = self.on_stream_begin.clone();
796 let cloned_end_handler = self.on_stream_end.clone();
797 let cloned_nagle_guard = self.nagle_guard.clone();
798 let cloned_handler = self.initial_service.clone();
799 let cloned_state = self.state.clone();
800 let copied_pre_nagle_hook = self.pre_nagle_hook;
801 let copied_post_nagle_hook = self.post_nagle_hook;
802 let copied_chunk_on_size = self.chunk_output_at_size;
803 let copied_service_name = self.service_name;
804 let copied_slowloris_timeout = self.slowloris_timeout;
805 let copied_server_id = self.id;
806 let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
807 let copied_slowdown = self.cat_dev_slowdown;
808 #[cfg(debug_assertions)]
809 let trace_io = self.trace_during_debug;
810
811 TaskBuilder::new()
812 .name("cat_dev::net::tcp_server::bind().connection.handle")
813 .spawn(async move {
814 if let Err(cause) = Self::handle_tcp_connection(
815 cloned_begin_handler,
816 cloned_end_handler,
817 cloned_nagle_guard,
818 copied_slowloris_timeout,
819 cloned_handler,
820 stream,
821 client_address,
822 copied_pre_nagle_hook,
823 copied_post_nagle_hook,
824 copied_chunk_on_size,
825 cloned_state,
826 copied_server_id,
827 stream_id,
828 copied_slowdown,
829 #[cfg(debug_assertions)]
830 trace_io,
831 )
832 .instrument(error_span!(
833 "CatDevTCPServerAccept",
834 client.address = %client_address,
835 server.address = %loggable_address,
836 server.service = copied_service_name,
837 server.stream_id = stream_id,
838 ))
839 .await
840 {
841 warn!(
842 ?cause,
843 client.address = %client_address,
844 server.address = %loggable_address,
845 server.service = %copied_service_name,
846 "Error escaped while handling TCP connection.",
847 );
848 }
849 })
850 .map_err(CatBridgeError::SpawnFailure)?;
851 }
852 }
853
854 #[allow(
866 clippy::too_many_arguments,
876 )]
877 async fn handle_tcp_connection(
878 on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
879 on_stream_end_handler: Option<UnderlyingOnStreamEndService<State>>,
880 nagle_guard: NagleGuard,
881 slowloris_timeout: Duration,
882 handler: BoxCloneService<Request<State>, Response, Infallible>,
883 mut tcp_stream: TcpStream,
884 client_address: SocketAddr,
885 pre_hook_cloned: Option<&'static dyn PreNagleFnTy>,
886 post_hook_cloned: Option<&'static dyn PostNagleFnTy>,
887 chunk_output_at_size: Option<usize>,
888 state: State,
889 server_id: u64,
890 stream_id: u64,
891 cat_dev_slowdown: Option<Duration>,
892 #[cfg(debug_assertions)] trace_io: bool,
893 ) -> Result<(), CatBridgeError> {
894 let (mut send_responses, mut packets_left_to_send) =
895 bounded_channel::<ResponseStreamMessage>(128);
896
897 if Self::initialize_stream(
900 on_stream_begin_handler,
901 &mut send_responses,
902 &client_address,
903 &state,
904 &mut tcp_stream,
905 server_id,
906 stream_id,
907 )
908 .await?
909 {
910 return Ok(());
911 }
912
913 let _guard = on_stream_end_handler.map(|service| {
915 DisconnectAsyncDropServer::new(service, state.clone(), client_address, stream_id)
916 });
917
918 let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;
921
922 loop {
923 let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
924 tokio::select! {
925 received = packets_left_to_send.recv() => {
926 if Self::handle_server_write_to_connection(
927 &mut tcp_stream,
928 chunk_output_at_size,
929 received,
930 post_hook_cloned,
931 stream_id,
932 cat_dev_slowdown,
933 #[cfg(debug_assertions)] trace_io,
934 ).await? {
935 break;
936 }
937 }
938 res_size = tcp_stream.read_buf(&mut buff) => {
939 let size = res_size.map_err(NetworkError::IO)?;
940 buff.truncate(size);
941 if buff.is_empty() {
942 continue;
943 }
944
945 let (should_break, returned_stream) = Self::handle_server_read_from_connection(
946 tcp_stream,
947 buff,
948 send_responses.clone(),
949 &nagle_guard,
950 slowloris_timeout,
951 handler.clone(),
952 &mut nagle_cache,
953 client_address,
954 pre_hook_cloned,
955 state.clone(),
956 stream_id,
957 #[cfg(debug_assertions)] trace_io,
958 ).await?;
959 tcp_stream = returned_stream;
960 if should_break {
961 break;
962 }
963 }
964 }
965 }
966
967 OUT_OF_BAND_SENDERS
968 .remove_async(&(server_id, stream_id))
969 .await;
970 packets_left_to_send.close();
971 std::mem::drop(tcp_stream.shutdown().await);
972
973 Ok(())
974 }
975
976 async fn initialize_stream(
977 on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<State>>,
978 send_channel: &mut BoundedSender<ResponseStreamMessage>,
979 source_address: &SocketAddr,
980 state: &State,
981 tcp_stream: &mut TcpStream,
982 server_id: u64,
983 stream_id: u64,
984 ) -> Result<bool, CatBridgeError> {
985 tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;
986 OUT_OF_BAND_SENDERS
987 .upsert_async((server_id, stream_id), send_channel.clone())
988 .await;
989
990 if let Some(mut handle) = on_stream_begin_handler
991 && !handle
992 .call(ResponseStreamEvent::new_with_state(
993 send_channel.clone(),
994 *source_address,
995 Some(stream_id),
996 state.clone(),
997 ))
998 .await?
999 {
1000 trace!("handler failed on stream begin hook");
1001 return Ok(true);
1002 }
1003
1004 Ok(false)
1005 }
1006
1007 async fn handle_server_write_to_connection(
1008 tcp_stream: &mut TcpStream,
1009 chunk_output_on_size: Option<usize>,
1010 to_send_to_client_opt: Option<ResponseStreamMessage>,
1011 post_hook: Option<&'static dyn PostNagleFnTy>,
1012 stream_id: u64,
1013 cat_dev_slowdown: Option<Duration>,
1014 #[cfg(debug_assertions)] trace_io: bool,
1015 ) -> Result<bool, CatBridgeError> {
1016 let Some(to_send_to_client) = to_send_to_client_opt else {
1017 return Ok(false);
1018 };
1019
1020 match to_send_to_client {
1021 ResponseStreamMessage::Disconnect => {
1022 debug!("stream-disconnect-message");
1023 Ok(true)
1024 }
1025 ResponseStreamMessage::Response(resp) => {
1026 if let Some(body) = resp.body()
1027 && !body.is_empty()
1028 {
1029 let messages = if let Some(size) = chunk_output_on_size {
1030 body.chunks(size)
1031 .map(Bytes::copy_from_slice)
1032 .collect::<Vec<_>>()
1033 } else {
1034 vec![body.clone()]
1035 };
1036
1037 for message in messages {
1038 #[cfg(debug_assertions)]
1039 if trace_io {
1040 debug!(
1041 body.hex = format!("{message:02x?}"),
1042 body.str = String::from_utf8_lossy(&message).to_string(),
1043 "cat-dev-trace-output-tcp-server",
1044 );
1045 }
1046
1047 let mut full_response = message.clone();
1048 if let Some(post) = post_hook {
1049 full_response = block_in_place(|| post(stream_id, full_response));
1050 }
1051 if let Some(slowdown_ms) = cat_dev_slowdown {
1052 sleep(slowdown_ms).await;
1053 }
1054
1055 tcp_stream.writable().await.map_err(NetworkError::IO)?;
1056 tcp_stream
1057 .write_all(&full_response)
1058 .await
1059 .map_err(NetworkError::IO)?;
1060 }
1061 }
1062
1063 if resp.request_connection_close() {
1064 trace!("response-requested-connection-close");
1065 Ok(true)
1066 } else {
1067 Ok(false)
1068 }
1069 }
1070 }
1071 }
1072
1073 #[allow(
1074 clippy::too_many_arguments,
1079 )]
1080 async fn handle_server_read_from_connection<'data>(
1081 mut stream: TcpStream,
1082 mut buff: BytesMut,
1083 channel: BoundedSender<ResponseStreamMessage>,
1084 nagle_guard: &'data NagleGuard,
1085 slowloris_timeout: Duration,
1086 mut handler: BoxCloneService<Request<State>, Response, Infallible>,
1087 nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
1088 client_address: SocketAddr,
1089 cloned_pre_nagle: Option<&'static dyn PreNagleFnTy>,
1090 state: State,
1091 stream_id: u64,
1092 #[cfg(debug_assertions)] trace_io: bool,
1093 ) -> Result<(bool, TcpStream), CatBridgeError> {
1094 if let Some(convert_fn) = cloned_pre_nagle {
1095 block_in_place(|| {
1096 (*convert_fn)(stream_id, &mut buff);
1097 });
1098 }
1099
1100 #[cfg(debug_assertions)]
1101 {
1102 if trace_io {
1103 debug!(
1104 body.hex = format!("{:02x?}", buff),
1105 body.str = String::from_utf8_lossy(&buff).to_string(),
1106 "cat-dev-trace-input-tcp-server",
1107 );
1108 }
1109 }
1110
1111 let start_time = now();
1114 if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
1115 let total_duration = start_time
1120 .duration_since(old_start_time)
1121 .unwrap_or(Duration::from_secs(0));
1122 if total_duration > slowloris_timeout {
1123 debug!(
1124 cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
1125 "slowloris-detected",
1126 );
1127 return Ok((true, stream));
1128 }
1129
1130 existing_buff.extend(buff.freeze());
1131 buff = existing_buff;
1132 }
1133
1134 while let Some((start_of_packet, end_of_packet)) = nagle_guard.split(&buff)? {
1135 let remaining_buff = buff.split_off(end_of_packet);
1136 let _start_of_buff = buff.split_to(start_of_packet);
1137 let req_body = buff.freeze();
1138 buff = remaining_buff;
1139
1140 let lockable_stream = Arc::new(Mutex::new(Some((Some(buff), stream))));
1141 let mut request_object = Request::new_with_state_and_stream(
1142 req_body,
1143 client_address,
1144 state.clone(),
1145 Some(stream_id),
1146 lockable_stream.clone(),
1147 );
1148 request_object.extensions_mut().insert(channel.clone());
1149 if let Err(cause) = match handler.call(request_object).await {
1150 Ok(ref resp) => {
1151 channel
1152 .send(ResponseStreamMessage::Response(resp.clone()))
1153 .await
1154 }
1155 Err(cause) => {
1156 warn!(
1157 ?cause,
1158 lisa.force_combine_fields = true,
1159 "request handler failed, will close connection.",
1160 );
1161 channel.send(ResponseStreamMessage::Disconnect).await
1162 }
1163 } {
1164 warn!(
1165 ?cause,
1166 lisa.force_combine_fields = true,
1167 "internal queue failure will not send disconnect/response."
1168 );
1169 }
1170
1171 {
1172 let mut done_lock = lockable_stream.lock().await;
1173 if let Some((newer_buff, strm)) = done_lock.take() {
1174 if let Some(newest_buff) = newer_buff {
1175 buff = newest_buff;
1176 } else {
1177 return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
1178 }
1179 stream = strm;
1180 } else {
1181 return Err(CommonNetNetworkError::StreamNoLongerProcessing.into());
1182 }
1183 }
1184 }
1185
1186 if !buff.is_empty() {
1187 _ = nagle_cache.insert((buff, start_time));
1188 }
1189
1190 Ok((false, stream))
1191 }
1192}
1193
1194impl<State: Clone + Debug + Send + Sync + 'static> Debug for TCPServer<State> {
1195 fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
1196 let mut dbg_struct = fmt.debug_struct("TCPServer");
1197 dbg_struct
1198 .field(
1199 "address_to_bind_or_connect_to",
1200 &self.address_to_bind_or_connect_to,
1201 )
1202 .field("cat_dev_slowdown", &self.cat_dev_slowdown)
1203 .field("chunk_output_at_size", &self.chunk_output_at_size)
1204 .field("id", &self.id)
1205 .field("initial_service", &self.initial_service)
1206 .field("nagle_guard", &self.nagle_guard)
1207 .field("on_stream_begin", &self.on_stream_begin)
1208 .field("on_stream_end", &self.on_stream_end)
1209 .field("has_pre_nagle_hook", &self.pre_nagle_hook.is_some())
1210 .field("has_post_nagle_hook", &self.post_nagle_hook.is_some())
1211 .field("service_name", &self.service_name)
1212 .field("slowloris_timeout", &self.slowloris_timeout)
1213 .field("state", &self.state);
1214
1215 #[cfg(debug_assertions)]
1216 {
1217 dbg_struct.field("trace_during_debug", &self.trace_during_debug);
1218 }
1219
1220 dbg_struct.finish()
1221 }
1222}
1223
1224const TCP_SERVER_FIELDS: &[NamedField<'static>] = &[
1225 NamedField::new("address_to_bind_or_connect_to"),
1226 NamedField::new("cat_dev_slowdown"),
1227 NamedField::new("chunk_output_at_size"),
1228 NamedField::new("initial_service"),
1229 NamedField::new("nagle_guard"),
1230 NamedField::new("on_stream_begin"),
1231 NamedField::new("on_stream_end"),
1232 NamedField::new("has_pre_nagle_hook"),
1233 NamedField::new("has_post_nagle_hook"),
1234 NamedField::new("service_name"),
1235 NamedField::new("slowloris_timeout"),
1236 NamedField::new("state"),
1237 #[cfg(debug_assertions)]
1238 NamedField::new("trace_during_debug"),
1239];
1240
1241impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Structable for TCPServer<State> {
1242 fn definition(&self) -> StructDef<'_> {
1243 StructDef::new_static("TcpServer", Fields::Named(TCP_SERVER_FIELDS))
1244 }
1245}
1246
1247impl<State: Clone + Debug + Send + Sync + Valuable + 'static> Valuable for TCPServer<State> {
1248 fn as_value(&self) -> Value<'_> {
1249 Value::Structable(self)
1250 }
1251
1252 fn visit(&self, visitor: &mut dyn Visit) {
1253 visitor.visit_named_fields(&NamedValues::new(
1254 TCP_SERVER_FIELDS,
1255 &[
1256 Valuable::as_value(&format!("{}", self.address_to_bind_or_connect_to)),
1257 Valuable::as_value(&if let Some(slowdown) = self.cat_dev_slowdown {
1258 format!("{}ms", slowdown.as_millis())
1259 } else {
1260 "<none>".to_string()
1261 }),
1262 Valuable::as_value(&self.chunk_output_at_size),
1263 Valuable::as_value(&format!("{:?}", self.initial_service)),
1264 Valuable::as_value(&self.nagle_guard),
1265 Valuable::as_value(&format!("{:?}", self.on_stream_begin)),
1266 Valuable::as_value(&format!("{:?}", self.on_stream_end)),
1267 Valuable::as_value(&self.pre_nagle_hook.is_some()),
1268 Valuable::as_value(&self.post_nagle_hook.is_some()),
1269 Valuable::as_value(&self.service_name),
1270 Valuable::as_value(&format!("{:?}", self.slowloris_timeout)),
1271 Valuable::as_value(&self.state),
1272 #[cfg(debug_assertions)]
1273 Valuable::as_value(&self.trace_during_debug),
1274 ],
1275 ));
1276 }
1277}
1278
#[cfg(test)]
pub mod test_helpers {
	use super::*;
	use std::net::{Ipv4Addr, SocketAddrV4};

	/// Ask the OS for a currently unused TCP port on IPv4 loopback.
	///
	/// Binds port 0, reads back the assigned port, and releases the listener when this
	/// function returns. Another process may grab the port before the caller binds it.
	/// Returns `None` if binding or querying the local address fails.
	pub async fn get_free_tcp_v4_port() -> Option<u16> {
		let probe = TcpListener::bind(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
			.await
			.ok()?;
		probe.local_addr().ok().map(|bound| bound.port())
	}
}
1299
1300#[cfg(test)]
1301mod unit_tests {
1302 use super::*;
1303 use crate::net::{
1304 CURRENT_TIME,
1305 server::{Router, requestable::Extension, test_helpers::*},
1306 };
1307 use bytes::Bytes;
1308 use std::{
1309 net::{Ipv4Addr, SocketAddrV4},
1310 sync::{
1311 Arc, Mutex,
1312 atomic::{AtomicU8, Ordering},
1313 },
1314 time::Duration,
1315 };
1316 use tokio::time::timeout;
1317
1318 fn set_now(new_time: SystemTime) {
1319 CURRENT_TIME.with(|time_lazy| {
1320 *time_lazy.write().expect("RwLock is poisioned?") = new_time;
1321 })
1322 }
1323
	/// End-to-end check that the stream-begin hook, the request handler, and the
	/// stream-end hook all fire for one request/response round-trip over loopback.
	#[tokio::test]
	pub async fn full_server() {
		// Flags flipped by each hook; shared with the handlers via `Extension` layers.
		let connected_fired = Arc::new(Mutex::new(false));
		let on_disconnect_fired = Arc::new(Mutex::new(false));
		let request_fired = Arc::new(Mutex::new(false));

		// Stream-begin hook: records that it ran and accepts the stream (`Ok(true)`).
		async fn on_connection(
			Extension(connected): Extension<Arc<Mutex<bool>>>,
		) -> Result<bool, CatBridgeError> {
			let mut locked = connected
				.lock()
				.expect("Failed to lock connected fired extension");
			*locked = true;
			Ok(true)
		}
		// Stream-end hook: records that it ran.
		async fn on_disconnect(
			Extension(disconnected): Extension<Arc<Mutex<bool>>>,
		) -> Result<(), CatBridgeError> {
			let mut locked = disconnected
				.lock()
				.expect("Failed to lock connected fired extension");
			*locked = true;
			Ok(())
		}
		// Request handler: records that it ran and replies with one byte while asking the
		// server to close the connection afterwards.
		async fn on_request(
			Extension(request): Extension<Arc<Mutex<bool>>>,
		) -> Result<Response, CatBridgeError> {
			let mut locked = request
				.lock()
				.expect("Failed to lock connected fired extension");
			*locked = true;

			let mut resp = Response::new_with_body(Bytes::from(vec![0x1]));
			resp.should_close_connection();
			Ok(resp)
		}

		let mut router = Router::new();
		router
			.add_route(&[0x1, 0x2, 0x3], on_request)
			.expect("Failed to add a route!");
		router.layer(Extension(request_fired.clone()));

		let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
			.await
			.expect("Timed out trying to find free port!")
			.expect("Failed to find free TCP port on system.");

		// Packets are terminated by four 0xFF bytes.
		let mut srv = timeout(
			Duration::from_secs(5),
			TCPServer::new_with_state(
				"test",
				SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				router,
				(None, None),
				NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
				(),
				#[cfg(debug_assertions)]
				true,
			),
		)
		.await
		.expect("Timed out starting server")
		.expect("Failed to create TCP Server.");

		srv.set_on_stream_begin(on_connection)
			.expect("Failed to register stream begin handler!");
		srv.layer_on_stream_begin(Extension(connected_fired.clone()))
			.expect("Failed to add layer to on stream begin!");
		srv.set_on_stream_end(on_disconnect)
			.expect("Failed to register stream end handler!");
		srv.layer_on_stream_end(Extension(on_disconnect_fired.clone()))
			.expect("Failed to add layer to on_disconnect!");

		let spawned =
			tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
		{
			// Retry until the spawned server is accepting connections.
			loop {
				let client_stream_res = timeout(
					Duration::from_secs(10),
					TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
				)
				.await
				.expect("Service timed out waiting for connection!");
				if client_stream_res.is_err() {
					continue;
				}
				let mut client_stream = client_stream_res.unwrap();
				// Route bytes 0x01 0x02 0x03 followed by the end sigil.
				client_stream
					.write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
					.await
					.expect("Failed to write to client stream");
				// Wait for the single-byte response.
				let mut buff = [0_u8; 1];
				timeout(Duration::from_secs(5), client_stream.read(&mut buff))
					.await
					.expect("Timed out reading from client stream")
					.expect("Failed to read data from client stream");
				timeout(Duration::from_secs(5), client_stream.shutdown())
					.await
					.expect("Timed out shutting down client stream")
					.expect("Failed to shutdown client stream.");
				break;
			}
		}
		// NOTE(review): dropping a tokio `JoinHandle` detaches the server task rather than
		// aborting it. The disconnect flag is set by the stream-end hook, presumably when
		// the server drops its disconnect guard after finishing the connection. That can
		// race with the checks below — if this test is flaky, wait for the flag explicitly.
		std::mem::drop(spawned);

		let locked_connect = connected_fired
			.lock()
			.expect("Failed to lock second connect");
		let locked_disconnect = on_disconnect_fired
			.lock()
			.expect("Failed to lock second on_disconnect");
		let locked_request = request_fired.lock().expect("Failed to lock second request");

		assert!(*locked_connect, "on connection handler never fired!");
		assert!(*locked_disconnect, "on disconnect handler never fired!");
		assert!(*locked_request, "on request handler never fired!");
	}
1445
1446 #[tokio::test]
1447 pub async fn nagled_logic_works() {
1448 let requests_fired = Arc::new(AtomicU8::new(0));
1449
1450 async fn on_request(
1451 Extension(request): Extension<Arc<AtomicU8>>,
1452 ) -> Result<Response, CatBridgeError> {
1453 request.fetch_add(1, Ordering::SeqCst);
1454 let resp = Response::new_with_body(Bytes::from(vec![0x1]));
1455 Ok(resp)
1456 }
1457
1458 let mut router = Router::new();
1459 router
1460 .add_route(&[0x1, 0x2, 0x3], on_request)
1461 .expect("Failed to add a route!");
1462 router.layer(Extension(requests_fired.clone()));
1463
1464 let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1465 .await
1466 .expect("Timed out finding port to bind too!")
1467 .expect("Failed to find any free tcp v4 port on system!");
1468 let srv = timeout(
1469 Duration::from_secs(5),
1470 TCPServer::new_with_state(
1471 "test",
1472 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1473 router,
1474 (None, None),
1475 NagleGuard::EndSigilSearch(&[0xFF, 0xFF, 0xFF, 0xFF]),
1476 (),
1477 #[cfg(debug_assertions)]
1478 true,
1479 ),
1480 )
1481 .await
1482 .expect("timed out starting TCP Server for test")
1483 .expect("falied to create local tcp server!");
1484
1485 let spawned =
1486 tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1487 {
1488 loop {
1489 let client_stream_res = timeout(
1490 Duration::from_secs(10),
1491 TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1492 )
1493 .await
1494 .expect("Service timed out waiting for connection!");
1495 if client_stream_res.is_err() {
1497 continue;
1498 }
1499 let mut client_stream = client_stream_res.unwrap();
1500
1501 client_stream
1502 .write_all(&[0x1, 0x2, 0x3, 0xFF, 0xFF])
1503 .await
1504 .expect("Failed to write to client_stream");
1505 client_stream
1506 .flush()
1507 .await
1508 .expect("Failed to flush client_stream");
1509 client_stream
1510 .write_all(&[0xFF, 0xFF, 0x1, 0x2, 0x3, 0xFF, 0xFF, 0xFF, 0xFF])
1511 .await
1512 .expect("Failed to issue second write call to client_stream");
1513 let mut buff = [0_u8; 2];
1514 let read_bytes = timeout(Duration::from_secs(5), client_stream.read(&mut buff))
1515 .await
1516 .expect("Timed out reading from client_stream")
1517 .expect("Failed to read from client_stream!");
1518 if read_bytes == 1 {
1519 timeout(Duration::from_secs(5), client_stream.read(&mut buff[1..]))
1520 .await
1521 .expect("Timed out reading from client_stream")
1522 .expect("Failed to read from client_stream!");
1523 }
1524
1525 timeout(Duration::from_secs(5), client_stream.shutdown())
1526 .await
1527 .expect("Timed out shutting down client stream")
1528 .expect("Failed to shutdown client stream.");
1529 break;
1530 }
1531 }
1532 std::mem::drop(spawned);
1534
1535 assert_eq!(
1536 requests_fired.load(Ordering::SeqCst),
1537 2,
1538 "on request did not fire the correct amount of times!",
1539 );
1540 }
1541
1542 #[tokio::test]
1543 pub async fn slowloris_is_blocked() {
1544 let mut router = Router::new();
1545 router
1546 .add_route(&[0x1, 0x2, 0x3], || async {
1547 Ok(Response::new_with_body(Bytes::from(vec![0x1])))
1548 })
1549 .expect("Failed to add a route!");
1550
1551 let found_port = timeout(Duration::from_secs(5), get_free_tcp_v4_port())
1552 .await
1553 .expect("Timed out finding port to bind too!")
1554 .expect("Failed to find any free tcp v4 port on system!");
1555 let srv = timeout(
1556 Duration::from_secs(5),
1557 TCPServer::new_with_state(
1558 "test",
1559 SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1560 router,
1561 (None, None),
1562 NagleGuard::EndSigilSearch(&[0x10, 0x11, 0x12]),
1563 (),
1564 #[cfg(debug_assertions)]
1565 true,
1566 ),
1567 )
1568 .await
1569 .expect("timed out starting TCP Server for test")
1570 .expect("falied to create local tcp server!");
1571
1572 let spawned =
1573 tokio::task::spawn(async move { srv.bind().await.expect("Failed to bind server!") });
1574 let read_bytes;
1575 {
1576 loop {
1577 let client_stream_res = timeout(
1578 Duration::from_secs(10),
1579 TcpStream::connect(SocketAddrV4::new(Ipv4Addr::LOCALHOST, found_port)),
1580 )
1581 .await
1582 .expect("Service timed out waiting for connection!");
1583 if client_stream_res.is_err() {
1585 continue;
1586 }
1587 let mut client_stream = client_stream_res.unwrap();
1588
1589 client_stream
1590 .write_all(&[0x1, 0x2, 0x3, 0x10])
1591 .await
1592 .expect("Failed to write to client_stream");
1593 client_stream
1594 .flush()
1595 .await
1596 .expect("Failed to flush client_stream");
1597 tokio::time::sleep(Duration::from_secs(5)).await;
1600 set_now(
1601 SystemTime::now()
1602 .checked_add(Duration::from_secs(900_00_000))
1603 .expect("Failed to add time to systemtime"),
1604 );
1605 client_stream
1607 .write_all(&[0x11, 0x12])
1608 .await
1609 .expect("Failed to write to client_stream");
1610
1611 let mut buff = [0_u8; 1];
1612 read_bytes = timeout(Duration::from_secs(10), client_stream.read(&mut buff))
1615 .await
1616 .expect("timed out trying to wait for disconnect")
1617 .expect("failure reading from stream");
1618 break;
1619 }
1620 }
1621 std::mem::drop(spawned);
1622
1623 assert_eq!(read_bytes, 0, "Client didn't error on slowloris'd packet");
1624 }
1625}