use crate::{
errors::{CatBridgeError, NetworkError},
net::{
DEFAULT_SLOWLORIS_TIMEOUT, STREAM_ID, TCP_READ_BUFFER_SIZE,
additions::RequestID,
client::{
errors::CommonNetClientNetworkError,
models::{
DisconnectAsyncDropClient, RequestStreamEvent, RequestStreamMessage,
UnderlyingOnStreamBeginService, UnderlyingOnStreamEndService,
},
},
errors::{CommonNetAPIError, CommonNetNetworkError},
handlers::{
OnRequestStreamBeginHandler, OnRequestStreamEndHandler, OnStreamBeginHandlerAsService,
OnStreamEndHandlerAsService,
},
models::{FromRequestParts, NagleGuard, PostNagleFnTy, PreNagleFnTy, Request, Response},
now,
},
};
use bytes::{Bytes, BytesMut};
use fnv::{FnvHashMap, FnvHashSet};
use futures::future::join_all;
use miette::miette;
use scc::HashMap as ConcurrentHashMap;
use std::{
collections::VecDeque,
fmt::{Debug, Formatter, Result as FmtResult},
hash::BuildHasherDefault,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::{Duration, Instant, SystemTime},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream, ToSocketAddrs},
sync::mpsc::{
Receiver as BoundedReceiver, Sender as BoundedSender, channel as bounded_channel,
error::SendTimeoutError,
},
task::{Builder as TaskBuilder, block_in_place},
time::{sleep, timeout},
};
use tower::{Layer, Service, util::BoxCloneService};
use tracing::{Instrument, debug, error_span, trace, warn};
use valuable::{Fields, NamedField, NamedValues, StructDef, Structable, Valuable, Value, Visit};
#[cfg(debug_assertions)]
use crate::net::SPRIG_TRACE_IO;
/// A zero-length timeout. `common_send` treats `Some(EMPTY_TIMEOUT)` exactly
/// like `None`: the request is queued but no response is awaited.
const EMPTY_TIMEOUT: Duration = Duration::ZERO;
/// A TCP client that can dial out (`connect`) or accept inbound connections
/// (`bind`), tracking every connection as a "stream" keyed by a numeric id.
///
/// Requests are queued onto per-stream channels and written by a background
/// task per connection; responses are read back from per-stream channels.
pub struct TCPClient {
/// Optional delay slept before each outgoing chunk is written to the socket.
cat_dev_slowdown: Option<Duration>,
/// When set, outgoing request bodies are split into chunks of at most this
/// many bytes, each written separately.
chunk_output_at_size: Option<usize>,
/// When `true`, `receive`/`send` only read responses from the active stream
/// and leave other streams' response queues untouched; when `false` they
/// read from every tracked stream.
keep_all_responses: bool,
/// Default guard used to split inbound bytes into individual packets.
nagle_guard: NagleGuard,
/// Service called when a stream is initialised; a `false` result abandons
/// the stream before it is tracked.
on_stream_begin: Option<UnderlyingOnStreamBeginService<()>>,
/// Service handed to a `DisconnectAsyncDropClient` guard that is held for
/// the lifetime of each connection task.
on_stream_end: Option<UnderlyingOnStreamEndService<()>>,
/// Hook invoked on every outgoing chunk just before it is written.
pre_nagle_hook: Option<&'static dyn PreNagleFnTy>,
/// Hook applied to every inbound read buffer before packet splitting.
post_nagle_hook: Option<&'static dyn PostNagleFnTy>,
/// Id of the primary ("active") stream; `0` means no stream has been chosen.
primary_stream_id: Arc<AtomicU64>,
/// All tracked streams, keyed by stream id.
streams: Arc<ConcurrentHashMap<u64, TCPClientStream>>,
/// Name recorded on tracing spans and log lines for this client.
service_name: &'static str,
/// Maximum age of a cached partial packet before the connection is dropped
/// as a suspected slowloris.
slowloris_timeout: Duration,
/// Whether raw I/O is logged; only compiled in with debug assertions.
#[cfg(debug_assertions)]
trace_during_debug: bool,
}
impl TCPClient {
/// Create a client with no stream hooks registered, no chunking, no
/// slowdown, and the default slowloris timeout.
///
/// * `service_name` - label attached to spans/logs for this client.
/// * `guard` - default packet-splitting guard for inbound data.
/// * `nagle_hooks` - optional `(pre, post)` hooks applied to outgoing chunks
///   and inbound buffers respectively.
/// * `trace_io_during_debug` - request raw I/O tracing; only honoured when
///   compiled with debug assertions (a warning is logged otherwise).
#[must_use]
pub fn new(
    service_name: &'static str,
    guard: impl Into<NagleGuard>,
    nagle_hooks: (
        Option<&'static dyn PreNagleFnTy>,
        Option<&'static dyn PostNagleFnTy>,
    ),
    trace_io_during_debug: bool,
) -> Self {
    // Release builds cannot trace I/O; tell the caller their flag is ignored.
    #[cfg(not(debug_assertions))]
    {
        if trace_io_during_debug {
            warn!(
                "Trace IO was turned on, but debug assertsions were not compiled in. Tracing of I/O will not happen. Please recompile cat-dev with debug assertions to properly trace I/O.",
            );
        }
    }

    let (pre_nagle_hook, post_nagle_hook) = nagle_hooks;
    let primary_stream_id = Arc::new(AtomicU64::new(0));
    let streams = Arc::new(ConcurrentHashMap::default());

    Self {
        cat_dev_slowdown: None,
        chunk_output_at_size: None,
        keep_all_responses: false,
        nagle_guard: guard.into(),
        on_stream_begin: None,
        on_stream_end: None,
        pre_nagle_hook,
        post_nagle_hook,
        primary_stream_id,
        service_name,
        slowloris_timeout: DEFAULT_SLOWLORIS_TIMEOUT,
        streams,
        #[cfg(debug_assertions)]
        trace_during_debug: trace_io_during_debug || *SPRIG_TRACE_IO,
    }
}
/// Set (or clear, with `None`) the delay slept before each outgoing chunk is
/// written to a socket.
pub const fn set_cat_dev_slowdown(&mut self, slowdown: Option<Duration>) {
self.cat_dev_slowdown = slowdown;
}
/// Turn on `keep_all_responses`; equivalent to
/// `set_keep_all_responses(true)`.
pub const fn should_keep_all_responses(&mut self) {
self.keep_all_responses = true;
}
/// Choose whether response reads are restricted to the active stream
/// (`true`) or performed against every tracked stream (`false`).
pub const fn set_keep_all_responses(&mut self, keep: bool) {
self.keep_all_responses = keep;
}
/// Mark `stream_id` as the primary (active) stream.
///
/// The id is stored as-is; it is not checked against the streams currently
/// being tracked.
pub fn set_primary_stream(&mut self, stream_id: u64) {
self.primary_stream_id.store(stream_id, Ordering::Release);
}
/// The configured maximum size of each outgoing chunk, if chunking is on.
#[must_use]
pub const fn chunk_output_at_size(&self) -> Option<usize> {
self.chunk_output_at_size
}
/// Set (or clear, with `None`) the maximum size of each outgoing chunk.
///
/// Only affects connections started after this call, since the value is
/// copied when `bind`/`connect` spawn their handler tasks.
pub const fn set_chunk_output_at_size(&mut self, new_size: Option<usize>) {
self.chunk_output_at_size = new_size;
}
/// The configured slowloris timeout.
#[must_use]
pub const fn slowloris_timeout(&self) -> Duration {
self.slowloris_timeout
}
/// Replace the slowloris timeout.
///
/// Only affects connections started after this call, since the value is
/// copied when `bind`/`connect` spawn their handler tasks.
pub const fn set_slowloris_timeout(&mut self, slowloris_timeout: Duration) {
self.slowloris_timeout = slowloris_timeout;
}
/// Borrow the registered on-stream-begin service, if any.
#[must_use]
pub const fn on_stream_begin(&self) -> Option<&UnderlyingOnStreamBeginService<()>> {
self.on_stream_begin.as_ref()
}
/// Install an already-boxed on-stream-begin service (or `None`).
///
/// # Errors
///
/// Returns `OnStreamBeginAlreadyRegistered` when a service is already
/// present; the existing service is left untouched.
pub fn set_raw_on_stream_begin(
    &mut self,
    on_start: Option<UnderlyingOnStreamBeginService<()>>,
) -> Result<(), CommonNetAPIError> {
    if self.on_stream_begin.is_none() {
        self.on_stream_begin = on_start;
        Ok(())
    } else {
        Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered)
    }
}
/// Register a handler function as the on-stream-begin hook, wrapping it in
/// a boxed, clonable service.
///
/// # Errors
///
/// Returns `OnStreamBeginAlreadyRegistered` when a hook already exists.
pub fn set_on_stream_begin<HandlerTy, HandlerParamsTy>(
    &mut self,
    handler: HandlerTy,
) -> Result<(), CommonNetAPIError>
where
    HandlerParamsTy: Send + 'static,
    HandlerTy: OnRequestStreamBeginHandler<HandlerParamsTy, ()> + Clone + Send + 'static,
{
    if self.on_stream_begin.is_none() {
        let as_service = OnStreamBeginHandlerAsService::new(handler);
        self.on_stream_begin = Some(BoxCloneService::new(as_service));
        Ok(())
    } else {
        Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered)
    }
}
/// Register a tower service as the on-stream-begin hook.
///
/// # Errors
///
/// Returns `OnStreamBeginAlreadyRegistered` when a hook already exists.
pub fn set_on_stream_begin_service<ServiceTy>(
    &mut self,
    service_ty: ServiceTy,
) -> Result<(), CommonNetAPIError>
where
    ServiceTy: Clone
        + Send
        + Service<RequestStreamEvent<()>, Response = bool, Error = CatBridgeError>
        + 'static,
    ServiceTy::Future: Send + 'static,
{
    if self.on_stream_begin.is_none() {
        self.on_stream_begin = Some(BoxCloneService::new(service_ty));
        Ok(())
    } else {
        Err(CommonNetAPIError::OnStreamBeginAlreadyRegistered)
    }
}
/// Wrap the currently registered on-stream-begin service in `layer`.
///
/// # Errors
///
/// Returns `OnStreamBeginNotRegistered` when there is no service to wrap.
pub fn layer_on_stream_begin<LayerTy, ServiceTy>(
    &mut self,
    layer: LayerTy,
) -> Result<(), CommonNetAPIError>
where
    LayerTy: Layer<UnderlyingOnStreamBeginService<()>, Service = ServiceTy>,
    ServiceTy: Service<RequestStreamEvent<()>, Response = bool, Error = CatBridgeError>
        + Clone
        + Send
        + 'static,
    <LayerTy::Service as Service<RequestStreamEvent<()>>>::Future: Send + 'static,
{
    match self.on_stream_begin.take() {
        None => Err(CommonNetAPIError::OnStreamBeginNotRegistered),
        Some(inner) => {
            let layered = layer.layer(inner);
            self.on_stream_begin = Some(BoxCloneService::new(layered));
            Ok(())
        }
    }
}
/// Borrow the registered on-stream-end service, if any.
#[must_use]
pub const fn on_stream_end(&self) -> Option<&UnderlyingOnStreamEndService<()>> {
self.on_stream_end.as_ref()
}
/// Install an already-boxed on-stream-end service (or `None`).
///
/// # Errors
///
/// Returns `OnStreamEndAlreadyRegistered` when a service is already
/// present; the existing service is left untouched.
pub fn set_raw_on_stream_end(
    &mut self,
    on_end: Option<UnderlyingOnStreamEndService<()>>,
) -> Result<(), CommonNetAPIError> {
    if self.on_stream_end.is_none() {
        self.on_stream_end = on_end;
        Ok(())
    } else {
        Err(CommonNetAPIError::OnStreamEndAlreadyRegistered)
    }
}
/// Register a handler function as the on-stream-end hook, wrapping it in a
/// boxed, clonable service.
///
/// # Errors
///
/// Returns `OnStreamEndAlreadyRegistered` when a hook already exists.
pub fn set_on_stream_end<HandlerTy, HandlerParamsTy>(
    &mut self,
    handler: HandlerTy,
) -> Result<(), CommonNetAPIError>
where
    HandlerParamsTy: Send + 'static,
    HandlerTy: OnRequestStreamEndHandler<HandlerParamsTy, ()> + Clone + Send + 'static,
{
    if self.on_stream_end.is_none() {
        let as_service = OnStreamEndHandlerAsService::new(handler);
        self.on_stream_end = Some(BoxCloneService::new(as_service));
        Ok(())
    } else {
        Err(CommonNetAPIError::OnStreamEndAlreadyRegistered)
    }
}
/// Register a tower service as the on-stream-end hook.
///
/// # Errors
///
/// Returns `OnStreamEndAlreadyRegistered` when a hook already exists.
pub fn set_on_stream_end_service<ServiceTy>(
    &mut self,
    service_ty: ServiceTy,
) -> Result<(), CommonNetAPIError>
where
    ServiceTy: Clone
        + Send
        + Service<RequestStreamEvent<()>, Response = (), Error = CatBridgeError>
        + 'static,
    ServiceTy::Future: Send + 'static,
{
    if self.on_stream_end.is_none() {
        self.on_stream_end = Some(BoxCloneService::new(service_ty));
        Ok(())
    } else {
        Err(CommonNetAPIError::OnStreamEndAlreadyRegistered)
    }
}
/// Wrap the currently registered on-stream-end service in `layer`.
///
/// # Errors
///
/// Returns `OnStreamEndNotRegistered` when there is no service to wrap.
pub fn layer_on_stream_end<LayerTy, ServiceTy>(
    &mut self,
    layer: LayerTy,
) -> Result<(), CommonNetAPIError>
where
    LayerTy: Layer<UnderlyingOnStreamEndService<()>, Service = ServiceTy>,
    ServiceTy: Service<RequestStreamEvent<()>, Response = (), Error = CatBridgeError>
        + Clone
        + Send
        + 'static,
    <LayerTy::Service as Service<RequestStreamEvent<()>>>::Future: Send + 'static,
{
    match self.on_stream_end.take() {
        None => Err(CommonNetAPIError::OnStreamEndNotRegistered),
        Some(inner) => {
            let layered = layer.layer(inner);
            self.on_stream_end = Some(BoxCloneService::new(layered));
            Ok(())
        }
    }
}
/// Listen on `address` and treat every accepted connection as a new stream.
///
/// A background task is spawned that accepts connections forever; each
/// accepted connection is given a fresh id from `STREAM_ID` and its own
/// handler task running `handle_tcp_stream`. This function returns as soon
/// as the accept-loop task has been spawned.
///
/// Configuration (hooks, guard, timeouts, chunk size, slowdown) is copied
/// at call time, so later setter calls do not affect this listener.
///
/// # Errors
///
/// Returns an I/O error if binding or reading the local address fails, and
/// `CatBridgeError::SpawnFailure` if the accept-loop task cannot be spawned.
pub async fn bind<AddrTy: ToSocketAddrs>(&self, address: AddrTy) -> Result<(), CatBridgeError> {
let listener = TcpListener::bind(address).await.map_err(NetworkError::IO)?;
let client_address = listener.local_addr().map_err(NetworkError::IO)?;
// Snapshot everything the spawned tasks need; they cannot borrow `self`.
let cloned_stream_begin = self.on_stream_begin.clone();
let cloned_stream_end = self.on_stream_end.clone();
let cloned_nagle_guard = self.nagle_guard.clone();
let cloned_slowerloris_timeout = self.slowloris_timeout;
let streams_ref = self.streams.clone();
let primary_stream_id_ref = self.primary_stream_id.clone();
let cloned_chunk_output_at_size = self.chunk_output_at_size;
let cloned_pre_nagle_hook = self.pre_nagle_hook;
let cloned_post_nagle_hook = self.post_nagle_hook;
#[cfg(debug_assertions)]
let cloned_trace = self.trace_during_debug;
let cloned_service_name = self.service_name;
let cloned_cat_dev_slowdown = self.cat_dev_slowdown;
TaskBuilder::new()
.name("cat_dev::net::tcp_client::bind().loop")
.spawn(async move {
loop {
// A failed accept is logged and retried immediately.
// NOTE(review): there is no backoff here, so a persistent accept
// failure would retry in a tight loop — confirm this is acceptable.
let (stream, server_address) = match listener.accept().await {
Ok(tuple) => tuple,
Err(cause) => {
warn!(
?cause,
client.address = %client_address,
"cat_dev::net::tcp_client::bind(): Failed to accept connection!",
);
continue;
}
};
let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
// Per-connection copies: each handler task owns its own set.
let cloned_cloned_stream_begin = cloned_stream_begin.clone();
let cloned_cloned_stream_end = cloned_stream_end.clone();
let cloned_cloned_nagle_guard = cloned_nagle_guard.clone();
let cloned_streams_ref = streams_ref.clone();
let cloned_primary_stream_id_ref = primary_stream_id_ref.clone();
if let Err(cause) = TaskBuilder::new()
.name("cat_dev::net::tcp_client::bind().connection.handle")
.spawn(async move {
// Errors from the handler are logged, never propagated: one bad
// connection must not take down the accept loop.
if let Err(cause) = Self::handle_tcp_stream(
stream,
stream_id,
server_address,
cloned_cloned_stream_begin,
cloned_cloned_stream_end,
cloned_cloned_nagle_guard,
cloned_slowerloris_timeout,
cloned_streams_ref,
cloned_primary_stream_id_ref,
cloned_chunk_output_at_size,
cloned_pre_nagle_hook,
cloned_post_nagle_hook,
cloned_cat_dev_slowdown,
#[cfg(debug_assertions)]
cloned_trace,
)
.instrument(error_span!(
"CatDevTCPClientConnect",
client.address = %client_address,
server.address = %server_address,
client.service = cloned_service_name,
stream.id = stream_id,
stream.stream_type = "client",
))
.await
{
warn!(
?cause,
client.address = %client_address,
server.address = %server_address,
client.service = cloned_service_name,
"Error escaped while handling TCP Connection.",
);
}
}) {
// Spawning the handler failed; the accepted socket is dropped.
warn!(
?cause,
client.address = %client_address,
server.address = %server_address,
client.service = cloned_service_name,
"Error handling client connection, no task could be allocated.",
);
}
trace!(
server.address = %server_address,
client.address = %client_address,
"cat_dev::net::tcp_client::bind(): received connection (listener.accept())",
);
}
})
.map_err(CatBridgeError::SpawnFailure)?;
Ok(())
}
/// Wait until `get_active_sid` succeeds, polling once per second.
pub async fn wait_for_connection(&self) {
    loop {
        if self.get_active_sid().await.is_ok() {
            return;
        }
        sleep(Duration::from_secs(1)).await;
    }
}
/// Open an outbound TCP connection to `address` and start handling it as a
/// new stream.
///
/// A fresh id is taken from `STREAM_ID` and a handler task running
/// `handle_tcp_stream` is spawned. The id is returned as soon as the task is
/// spawned, which may be before the task has registered the stream in
/// `streams`.
///
/// # Errors
///
/// Returns an I/O error if connecting or reading the socket addresses fails,
/// and `CatBridgeError::SpawnFailure` if the handler task cannot be spawned.
pub async fn connect<AddrTy: ToSocketAddrs>(
&self,
address: AddrTy,
) -> Result<u64, CatBridgeError> {
let raw_stream = TcpStream::connect(address)
.await
.map_err(NetworkError::IO)?;
let stream_id = STREAM_ID.fetch_add(1, Ordering::SeqCst);
let remote_address = raw_stream.peer_addr().map_err(NetworkError::IO)?;
let local_address = raw_stream.local_addr().map_err(NetworkError::IO)?;
trace!(
server.address = %remote_address,
client.address = %local_address,
stream.id = stream_id,
stream.stream_type = "client",
"cat_dev::net::tcp_client::connect(): started connection (TcpStream::connect())",
);
// Snapshot everything the spawned task needs; it cannot borrow `self`.
let cloned_stream_begin = self.on_stream_begin.clone();
let cloned_stream_end = self.on_stream_end.clone();
let cloned_nagle_guard = self.nagle_guard.clone();
let cloned_slowerloris_timeout = self.slowloris_timeout;
let streams_ref = self.streams.clone();
let primary_stream_id_ref = self.primary_stream_id.clone();
let cloned_chunk_output_at_size = self.chunk_output_at_size;
let cloned_pre_nagle_hook = self.pre_nagle_hook;
let cloned_post_nagle_hook = self.post_nagle_hook;
#[cfg(debug_assertions)]
let cloned_trace = self.trace_during_debug;
let cloned_service_name = self.service_name;
let cloned_cat_dev_slowdown = self.cat_dev_slowdown;
TaskBuilder::new()
.name("cat_dev::net::tcp_client::connect().connection.handle")
.spawn(async move {
// Handler errors are logged here; they are not reported to the caller
// of `connect`, which has already returned by then.
if let Err(cause) = Self::handle_tcp_stream(
raw_stream,
stream_id,
remote_address,
cloned_stream_begin,
cloned_stream_end,
cloned_nagle_guard,
cloned_slowerloris_timeout,
streams_ref,
primary_stream_id_ref,
cloned_chunk_output_at_size,
cloned_pre_nagle_hook,
cloned_post_nagle_hook,
cloned_cat_dev_slowdown,
#[cfg(debug_assertions)]
cloned_trace,
)
.instrument(error_span!(
"CatDevTCPClientConnect",
client.address = %local_address,
server.address = %remote_address,
client.service = cloned_service_name,
stream.id = stream_id,
stream.stream_type = "client",
))
.await
{
warn!(
?cause,
client.address = %local_address,
server.address = %remote_address,
client.service = cloned_service_name,
"Error escaped while handling TCP Connection.",
);
}
})
.map_err(CatBridgeError::SpawnFailure)?;
Ok(stream_id)
}
/// Serialize `body`, tag it with a freshly generated request id, and send
/// it through `common_send`.
///
/// Returns the active stream id, the generated request id, and the matching
/// response from the active stream if one was awaited and received.
///
/// # Errors
///
/// Returns a serialization error if `body` cannot be converted to bytes,
/// plus any error produced by `common_send`.
pub async fn send<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
    &self,
    body: BodyTy,
    wait_for_response_timeout: Option<Duration>,
) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
    let payload: Bytes = body.try_into().map_err(|cause| {
        CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
    })?;
    // The real source address is filled in per stream when the request is queued.
    let placeholder_source = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
    let mut request = Request::new_with_state(payload, placeholder_source, (), None);

    let req_id = RequestID::generate();
    request.extensions_mut().insert(req_id.clone());

    self.common_send(request, req_id, wait_for_response_timeout)
        .await
}
/// Like `send`, but the request carries an explicit read amount, which the
/// write path turns into a static-size packet guard for the reply.
///
/// # Errors
///
/// Returns a serialization error if `body` cannot be converted to bytes,
/// plus any error produced by `common_send`.
pub async fn send_with_read_amount<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
    &self,
    body: BodyTy,
    wait_for_response_timeout: Option<Duration>,
    explicit_read_amount: usize,
) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
    let payload: Bytes = body.try_into().map_err(|cause| {
        CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
    })?;
    // The real source address is filled in per stream when the request is queued.
    let placeholder_source = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
    let mut request = Request::new_with_state_and_read_amount(
        payload,
        placeholder_source,
        (),
        None,
        explicit_read_amount,
    );

    let req_id = RequestID::generate();
    request.extensions_mut().insert(req_id.clone());

    self.common_send(request, req_id, wait_for_response_timeout)
        .await
}
/// Send `body` to every tracked stream and collect each stream's response
/// to that request.
///
/// `wait_for_response_timeout` bounds both the queueing of the request on
/// each stream and the combined wait for all responses.
///
/// # Errors
///
/// Returns a serialization error if `body` cannot be converted, the first
/// queueing error encountered, or `NetworkError::Timeout` if the responses
/// do not all arrive in time.
pub async fn broadcast_send<ErrorTy: Debug, BodyTy: TryInto<Bytes, Error = ErrorTy>>(
    &self,
    body: BodyTy,
    wait_for_response_timeout: Duration,
) -> Result<FnvHashMap<u64, Option<Response>>, CatBridgeError> {
    let payload: Bytes = body.try_into().map_err(|cause| {
        CommonNetClientNetworkError::SerializationError(miette!("{cause:?}"))
    })?;
    let placeholder_source = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
    let mut request = Request::new_with_state(payload, placeholder_source, (), None);
    let req_id = RequestID::generate();
    request.extensions_mut().insert(req_id.clone());

    // Snapshot the stream ids first so no map entry is held while awaiting.
    let mut stream_ids = FnvHashSet::default();
    self.streams
        .iter_async(|stream_id, _stream| {
            stream_ids.insert(*stream_id);
            true
        })
        .await;

    let queued = join_all(
        stream_ids
            .iter()
            .map(|id| self.send_to_stream(*id, request.clone(), wait_for_response_timeout)),
    )
    .await;
    queued
        .into_iter()
        .collect::<Result<(), NetworkError>>()?;

    let pending = stream_ids
        .iter()
        .map(|id| self.get_response_from_stream(*id, req_id.clone()));
    let responses = timeout(wait_for_response_timeout, join_all(pending))
        .await
        .map_err(|_| NetworkError::Timeout(wait_for_response_timeout))?;

    let mut by_stream =
        FnvHashMap::with_capacity_and_hasher(stream_ids.len(), BuildHasherDefault::default());
    by_stream.extend(responses);
    Ok(by_stream)
}
/// Wait up to `wait_until` for the next response on the active stream.
///
/// With `keep_all_responses` set only the active stream is read; otherwise
/// one response is read from every tracked stream and only the active
/// stream's is returned.
///
/// # Errors
///
/// Returns an error if there is no active stream, or
/// `NetworkError::Timeout` if the reads do not all complete in time.
pub async fn receive(&self, wait_until: Duration) -> Result<Option<Response>, NetworkError> {
    let active_sid = self.get_active_sid().await?;

    let stream_ids: Vec<u64> = if self.keep_all_responses {
        vec![active_sid]
    } else {
        let mut ids = FnvHashSet::default();
        self.streams
            .iter_async(|stream_id, _stream| {
                ids.insert(*stream_id);
                true
            })
            .await;
        ids.into_iter().collect()
    };

    let pending = stream_ids
        .into_iter()
        .map(|id| self.get_any_response_from_stream(id));
    let responses = timeout(wait_until, join_all(pending))
        .await
        .map_err(|_| NetworkError::Timeout(wait_until))?;

    let from_active = responses
        .into_iter()
        .find(|(got_stream_id, _, _)| *got_stream_id == active_sid)
        .and_then(|(_, _, response)| response);
    Ok(from_active)
}
/// Collect, from every tracked stream, the response tagged with
/// `request_id`, waiting at most `wait_for` per stream.
///
/// Every stream that was tracked when the call started has an entry in the
/// result; streams that timed out or produced no match map to `None`.
pub async fn take_all_response_for_request_id(
    &self,
    request_id: &RequestID,
    wait_for: Duration,
) -> FnvHashMap<u64, Option<Response>> {
    let mut stream_ids = FnvHashSet::default();
    self.streams
        .iter_async(|stream_id, _stream| {
            stream_ids.insert(*stream_id);
            true
        })
        .await;

    let lookups = stream_ids.iter().map(|id| {
        timeout(
            wait_for,
            self.get_response_from_stream(*id, request_id.clone()),
        )
    });
    // Timed-out lookups are dropped here and back-filled with `None` below.
    let mut results: FnvHashMap<u64, Option<Response>> = join_all(lookups)
        .await
        .into_iter()
        .filter_map(Result::ok)
        .collect();

    for id in stream_ids {
        results.entry(id).or_insert(None);
    }
    results
}
async fn common_send(
&self,
mock_req: Request<()>,
req_id: RequestID,
wait_for_response_timeout: Option<Duration>,
) -> Result<(u64, RequestID, Option<Response>), NetworkError> {
let active_sid = self.get_active_sid().await?;
let mut ids = FnvHashSet::default();
self.streams
.iter_async(|stream_id, _stream| {
ids.insert(*stream_id);
true
})
.await;
let mut tasks = Vec::with_capacity(ids.len());
for id in &ids {
tasks.push(self.send_to_stream(
*id,
mock_req.clone(),
wait_for_response_timeout.unwrap_or(DEFAULT_SLOWLORIS_TIMEOUT),
));
}
join_all(tasks)
.await
.into_iter()
.collect::<Result<(), NetworkError>>()?;
match wait_for_response_timeout {
None | Some(EMPTY_TIMEOUT) => Ok((active_sid, req_id, None)),
Some(duration) => {
let mut tasks;
if self.keep_all_responses {
tasks = vec![self.get_response_from_stream(active_sid, req_id.clone())];
} else {
tasks = Vec::with_capacity(ids.len());
for id in ids {
tasks.push(self.get_response_from_stream(id, req_id.clone()));
}
}
let responses = timeout(duration, join_all(tasks))
.await
.map_err(|_| NetworkError::Timeout(duration))?;
for (got_stream_id, response) in responses {
if got_stream_id == active_sid {
return Ok((active_sid, req_id, response));
}
}
Ok((active_sid, req_id, None))
}
}
}
/// Drive a single TCP connection until it ends.
///
/// Steps:
/// 1. Run `initialize_stream`; if the begin hook rejects the stream, return
///    immediately without tracking it.
/// 2. Register the stream in `stream_lists` and, if no primary stream has
///    been chosen yet (id `0`), make this one primary.
/// 3. Loop, either writing queued requests to the socket or reading inbound
///    bytes and turning them into responses.
///
/// The loop ends when a disconnect is requested, the request channel
/// closes, the read handler asks to stop (slowloris), or the peer closes the
/// connection (a zero-byte read).
///
/// # Errors
///
/// Returns any I/O error from the socket and any error raised by the hook
/// or the read/write helpers.
#[allow(
    // All of our parameters are very well named, and their types are not
    // close to overlapping with each other.
    //
    // We also just fundamentally have a lot of state thanks to everything a
    // TCP connection has to handle, e.g. NAGLE, delimiters, caches, etc.
    //
    // It is also only ever called from internal functions, so it is not part
    // of our public facing API.
    clippy::too_many_arguments,
)]
async fn handle_tcp_stream(
    mut stream: TcpStream,
    stream_id: u64,
    remote_address: SocketAddr,
    on_stream_begin: Option<UnderlyingOnStreamBeginService<()>>,
    on_stream_end: Option<UnderlyingOnStreamEndService<()>>,
    nagle_guard: NagleGuard,
    slowloris_timeout: Duration,
    stream_lists: Arc<ConcurrentHashMap<u64, TCPClientStream>>,
    active_stream_ptr: Arc<AtomicU64>,
    chunk_output_on_size: Option<usize>,
    pre_hook: Option<&'static dyn PreNagleFnTy>,
    post_hook: Option<&'static dyn PostNagleFnTy>,
    cat_dev_slowdown: Option<Duration>,
    #[cfg(debug_assertions)] trace_io: bool,
) -> Result<(), CatBridgeError> {
    let mut receive_packets_to_send: BoundedReceiver<RequestStreamMessage>;
    let (response_sink_send, response_sink_recv) = bounded_channel(128);

    {
        let (mut sender, receiver) = bounded_channel(128);
        if Self::initialize_stream(
            on_stream_begin,
            &mut sender,
            &remote_address,
            &stream,
            stream_id,
        )
        .await?
        {
            // The begin hook rejected this stream.
            return Ok(());
        }

        let mut active_stream =
            TCPClientStream::new(remote_address, sender, receiver, response_sink_recv);
        // This task keeps the receiving half; the map entry keeps the sender.
        receive_packets_to_send = active_stream
            .steal_send_requests_receiver()
            .ok_or_else(|| CatBridgeError::ClosedChannel)?;
        std::mem::drop(stream_lists.insert_async(stream_id, active_stream).await);
        // The first stream to register becomes primary; later ones leave an
        // existing choice alone.
        _ = active_stream_ptr.compare_exchange(
            0,
            stream_id,
            Ordering::AcqRel,
            Ordering::Acquire,
        );
    }

    // Held until this function returns, however it returns.
    let _guard = on_stream_end
        .map(|service| DisconnectAsyncDropClient::new(service, (), remote_address, stream_id));

    let mut buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
    let mut nagle_cache: Option<(BytesMut, SystemTime)> = None;
    let mut cached_request_id: Option<RequestID> = None;
    let mut nagle_overrides: VecDeque<Option<NagleGuard>> = VecDeque::with_capacity(128);

    loop {
        tokio::select! {
            opt = receive_packets_to_send.recv() => {
                if Self::handle_client_write_to_connection(
                    chunk_output_on_size,
                    opt,
                    pre_hook,
                    &mut cached_request_id,
                    stream_id,
                    &mut stream,
                    &mut nagle_overrides,
                    cat_dev_slowdown,
                    #[cfg(debug_assertions)]
                    trace_io,
                ).await? {
                    break;
                }
            }
            read_res = stream.read_buf(&mut buff) => {
                let read_bytes = read_res.map_err(NetworkError::IO)?;
                // `buff` is always empty when the read starts, so a zero-byte
                // read means the peer closed the connection. Every later read
                // would also return zero immediately, so looping again would
                // spin forever instead of ending the stream.
                if read_bytes == 0 {
                    trace!("stream-closed-by-peer");
                    break;
                }
                if Self::handle_client_read_from_connection(
                    buff,
                    &nagle_guard,
                    &mut nagle_overrides,
                    slowloris_timeout,
                    &mut nagle_cache,
                    response_sink_send.clone(),
                    post_hook,
                    &mut cached_request_id,
                    stream_id,
                    #[cfg(debug_assertions)]
                    trace_io,
                ).await? {
                    break;
                }
                // The previous buffer was consumed by the read handler.
                buff = BytesMut::with_capacity(TCP_READ_BUFFER_SIZE);
            }
        }
    }

    // NOTE(review): the entry for `stream_id` is left in `stream_lists`, so
    // already-buffered responses stay readable after the connection ends.
    // Nothing in this function removes it — confirm whether cleanup happens
    // elsewhere.
    Ok(())
}
/// Prepare a freshly opened socket and run the on-stream-begin hook.
///
/// Sets `TCP_NODELAY` on the socket, then calls the hook (if any) with a
/// clone of the request sender. Returns `Ok(true)` when the hook reported
/// `false`, meaning the caller should abandon the stream, and `Ok(false)`
/// otherwise.
///
/// # Errors
///
/// Returns an I/O error if `TCP_NODELAY` cannot be set, or whatever error
/// the hook itself returns.
async fn initialize_stream(
    on_stream_begin_handler: Option<UnderlyingOnStreamBeginService<()>>,
    send_channel: &mut BoundedSender<RequestStreamMessage>,
    remote_address: &SocketAddr,
    tcp_stream: &TcpStream,
    stream_id: u64,
) -> Result<bool, CatBridgeError> {
    tcp_stream.set_nodelay(true).map_err(NetworkError::IO)?;

    let Some(mut handle) = on_stream_begin_handler else {
        return Ok(false);
    };

    let event = RequestStreamEvent::new_with_state(
        send_channel.clone(),
        *remote_address,
        Some(stream_id),
        (),
    );
    let accepted = handle.call(event).await?;
    if accepted {
        return Ok(false);
    }

    trace!("handler failed on stream begin hook");
    Ok(true)
}
/// Turn one freshly read buffer into zero or more responses.
///
/// The buffer is passed through the post-nagle hook (if any), prepended
/// with any cached partial packet, and then split into packets. Each packet
/// is sent on `response_output` as a `Response`; whatever is left over is
/// cached for the next read.
///
/// Returns `Ok(true)` when the caller should stop handling the connection
/// (a cached partial packet is older than `slowloris_timeout`), and
/// `Ok(false)` to keep going.
///
/// # Errors
///
/// Returns whatever error the guard's `split` produces.
#[allow(
// All of our types are very differently typed, and well named, so chance
// of confusion is low.
//
// Not to mention this is an internal only method.
clippy::too_many_arguments,
)]
async fn handle_client_read_from_connection<'data>(
mut buff: BytesMut,
nagle_guard: &'data NagleGuard,
nagle_overrides: &mut VecDeque<Option<NagleGuard>>,
slowloris_timeout: Duration,
nagle_cache: &'data mut Option<(BytesMut, SystemTime)>,
response_output: BoundedSender<(Option<RequestID>, Response)>,
cloned_post_nagle: Option<&'static dyn PostNagleFnTy>,
cached_request_id: &mut Option<RequestID>,
stream_id: u64,
#[cfg(debug_assertions)] trace_io: bool,
) -> Result<bool, CatBridgeError> {
// The hook is synchronous, so it is run through `block_in_place`.
if let Some(convert_fn) = cloned_post_nagle {
buff = BytesMut::from(block_in_place(|| (*convert_fn)(stream_id, buff.freeze())));
}
#[cfg(debug_assertions)]
{
if trace_io {
debug!(
body.hex = format!("{:02x?}", buff),
body.str = String::from_utf8_lossy(&buff).to_string(),
"cat-dev-trace-input-tcp-client",
);
}
}
let start_time = now();
// Merge with any partial packet left over from an earlier read, unless it
// has been sitting around for longer than the slowloris timeout.
if let Some((mut existing_buff, old_start_time)) = nagle_cache.take() {
let total_duration = start_time
.duration_since(old_start_time)
.unwrap_or(Duration::from_secs(0));
if total_duration > slowloris_timeout {
debug!(
cause = ?CommonNetNetworkError::SlowlorisTimeout(total_duration),
"slowloris-detected",
);
return Ok(true);
}
existing_buff.extend(buff);
buff = existing_buff;
}
// A per-request override at the front of the queue wins over the default
// guard; a `None` entry means "use the default for this response".
let mut current_nagle_guard = if let Some(Some(guard)) = nagle_overrides.front() {
guard
} else {
nagle_guard
};
while let Some((start_of_packet, end_of_packet)) = current_nagle_guard.split(&buff)? {
// Keep only [start_of_packet, end_of_packet) as the packet body; the
// bytes after it go back into `buff` for the next iteration.
let remaining_buff = buff.split_off(end_of_packet);
let _start_of_buff = buff.split_to(start_of_packet);
let req_body = buff.freeze();
buff = remaining_buff;
// `take()` means only the first packet after a request carries its id.
if let Err(cause) = response_output
.send((cached_request_id.take(), Response::new_with_body(req_body)))
.await
{
warn!(
?cause,
"internal queue failure will not send disconnect/response."
);
}
// One override entry is consumed per packet produced.
if !nagle_overrides.is_empty() {
nagle_overrides.pop_front();
current_nagle_guard = if let Some(Some(guard)) = nagle_overrides.front() {
guard
} else {
nagle_guard
};
}
}
// NOTE(review): the leftover is stamped with this read's time, so the
// slowloris check measures the gap since the most recent partial read
// rather than since the first one — confirm this is intended.
if !buff.is_empty() {
_ = nagle_cache.insert((buff, start_time));
}
Ok(false)
}
/// Handle one message taken from a stream's request channel.
///
/// * A closed channel (`None`) or a `Disconnect` message returns `Ok(true)`,
///   telling the caller to stop handling the connection.
/// * A `Request` records which guard should split its reply (a static-size
///   guard when the request carries an explicit read amount, otherwise the
///   default), and, when the body is non-empty, writes it to the socket —
///   optionally in chunks, through the pre-nagle hook, and after the
///   configured slowdown. Returns `Ok(false)`.
///
/// A chunk size of `Some(0)` is treated as "no chunking", since a zero
/// chunk size cannot make progress (and `chunks(0)` panics).
///
/// # Errors
///
/// Returns an I/O error if waiting for writability or writing fails.
#[allow(
    // Well typed arguments, lots to do and all that.
    clippy::too_many_arguments,
)]
async fn handle_client_write_to_connection(
    chunk_output_on_size: Option<usize>,
    to_send_to_client_opt: Option<RequestStreamMessage>,
    pre_hook: Option<&'static dyn PreNagleFnTy>,
    cached_request_id: &mut Option<RequestID>,
    stream_id: u64,
    raw_stream: &mut TcpStream,
    nagle_overrides: &mut VecDeque<Option<NagleGuard>>,
    cat_dev_slowdown: Option<Duration>,
    #[cfg(debug_assertions)] trace_io: bool,
) -> Result<bool, CatBridgeError> {
    let Some(to_send_to_client) = to_send_to_client_opt else {
        // Every sender is gone; nothing more can ever be queued.
        return Ok(true);
    };

    match to_send_to_client {
        RequestStreamMessage::Disconnect => {
            _ = cached_request_id.take();
            trace!("stream-disconnect-message");
            Ok(true)
        }
        RequestStreamMessage::Request(mut req) => {
            // Queue the guard that should split the reply to this request.
            if let Some(explicit_read) = req.explicit_read_amount() {
                nagle_overrides.push_back(Some(NagleGuard::StaticSize(explicit_read)));
            } else {
                nagle_overrides.push_back(None);
            }

            if !req.body().is_empty() {
                // Remember the request id so the next response can be tagged.
                if let Ok(req_id) = RequestID::from_request_parts(&mut req).await {
                    _ = cached_request_id.insert(req_id);
                }

                let messages = match chunk_output_on_size {
                    // Guard against zero: `chunks(0)` would panic.
                    Some(size) if size > 0 => req
                        .body_owned()
                        .chunks(size)
                        .map(BytesMut::from)
                        .collect::<Vec<_>>(),
                    _ => vec![BytesMut::from(req.body_owned())],
                };

                for message in messages {
                    // Traced before the pre-hook runs, so this shows the
                    // unmodified bytes.
                    #[cfg(debug_assertions)]
                    if trace_io {
                        debug!(
                            body.hex = format!("{message:02x?}"),
                            body.str = String::from_utf8_lossy(&message).to_string(),
                            "cat-dev-trace-output-tcp-client",
                        );
                    }

                    // `message` is owned by this iteration, so no copy is needed.
                    let mut full_response = message;
                    if let Some(pre) = pre_hook {
                        // The hook is synchronous, so run it via `block_in_place`.
                        block_in_place(|| pre(stream_id, &mut full_response));
                    }
                    if let Some(slowdown) = cat_dev_slowdown {
                        sleep(slowdown).await;
                    }

                    raw_stream.writable().await.map_err(NetworkError::IO)?;
                    raw_stream
                        .write_all(&full_response)
                        .await
                        .map_err(NetworkError::IO)?;
                }
            }

            Ok(false)
        }
    }
}
/// Queue `base_request` on the stream with id `stream_id`, waiting at most
/// `timeout` for room in that stream's request channel.
///
/// The request's source is rewritten to the stream's server address and id
/// first. A stream id that is not tracked is silently ignored (`Ok(())`).
///
/// # Errors
///
/// Returns `CannotQueueSend` if the request cannot be queued in time or the
/// channel is closed.
async fn send_to_stream(
    &self,
    stream_id: u64,
    mut base_request: Request<()>,
    timeout: Duration,
) -> Result<(), NetworkError> {
    let Some(stream) = self.streams.get_async(&stream_id).await else {
        return Ok(());
    };

    base_request.update_request_source(stream.server_address(), Some(stream_id));
    let queued = stream
        .send_timeout(RequestStreamMessage::Request(base_request), timeout)
        .await;
    queued.map_err(|cause| {
        CommonNetClientNetworkError::CannotQueueSend(format!("{cause:?}")).into()
    })
}
/// Take the next response from a stream, whatever request it belongs to.
///
/// Returns `(stream_id, request id if known, response)`. Both options are
/// `None` when the stream is not tracked or its response channel is closed.
async fn get_any_response_from_stream(
    &self,
    stream_id: u64,
) -> (u64, Option<RequestID>, Option<Response>) {
    let Some(mut stream) = self.streams.get_async(&stream_id).await else {
        return (stream_id, None, None);
    };

    let received = stream.response_channel_mut().recv().await;
    match received {
        Some((opt_req_id, response)) => (stream_id, opt_req_id, Some(response)),
        None => (stream_id, None, None),
    }
}
async fn get_response_from_stream(
&self,
stream_id: u64,
request_id: RequestID,
) -> (u64, Option<Response>) {
if let Some(mut stream) = self.streams.get_async(&stream_id).await {
while let Some((opt_req_id, response)) = stream.response_channel_mut().recv().await {
if let Some(got_req_id) = opt_req_id
&& got_req_id == request_id
{
return (stream_id, Some(response));
}
}
(stream_id, None)
} else {
(stream_id, None)
}
}
async fn get_active_sid(&self) -> Result<u64, CommonNetClientNetworkError> {
let active_sid = self.primary_stream_id.load(Ordering::Acquire);
if active_sid == 0 {
return Err(CommonNetClientNetworkError::NotConnectedToServer);
}
if !self.streams.contains_async(&active_sid).await {
let mut oldest_stream = None;
self.streams
.iter_async(|stream_id, stream| {
if let Some((_strm_id, strm_created_at)) = oldest_stream {
if stream.opened_at() < strm_created_at {
_ = oldest_stream.insert((*stream_id, stream.opened_at()));
}
} else {
_ = oldest_stream.insert((*stream_id, stream.opened_at()));
}
true
})
.await;
}
Ok(active_sid)
}
}
/// Debug output lists configuration values directly, and reports hooks and
/// services only as "is one registered?" booleans.
impl Debug for TCPClient {
    fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
        let has_on_stream_begin = self.on_stream_begin.is_some();
        let has_on_stream_end = self.on_stream_end.is_some();
        let has_pre_nagle_hook = self.pre_nagle_hook.is_some();
        let has_post_nagle_hook = self.post_nagle_hook.is_some();
        let primary_stream_id = self.primary_stream_id.load(Ordering::Relaxed);

        let mut out = fmt.debug_struct("TCPClient");
        out.field("cat_dev_slowdown", &self.cat_dev_slowdown);
        out.field("chunk_output_at_size", &self.chunk_output_at_size);
        out.field("keep_all_responses", &self.keep_all_responses);
        out.field("nagle_guard", &self.nagle_guard);
        out.field("has_on_stream_begin", &has_on_stream_begin);
        out.field("has_on_stream_end", &has_on_stream_end);
        out.field("has_pre_nagle_hook", &has_pre_nagle_hook);
        out.field("has_post_nagle_hook", &has_post_nagle_hook);
        out.field("primary_stream_id", &primary_stream_id);
        out.field("streams", &self.streams);
        out.field("service_name", &self.service_name);
        out.field("slowloris_timeout", &self.slowloris_timeout);
        #[cfg(debug_assertions)]
        {
            out.field("trace_during_debug", &self.trace_during_debug);
        }
        out.finish()
    }
}
/// Field names reported for `TCPClient` through `valuable`.
///
/// The order must match the value list built in `TCPClient`'s
/// `Valuable::visit`, since names and values are paired by position.
const TCP_CLIENT_FIELDS: &[NamedField<'static>] = &[
NamedField::new("cat_dev_slowdown"),
NamedField::new("chunk_output_at_size"),
NamedField::new("keep_all_responses"),
NamedField::new("nagle_guard"),
NamedField::new("has_on_stream_begin"),
NamedField::new("has_on_stream_end"),
NamedField::new("has_pre_nagle_hook"),
NamedField::new("has_post_nagle_hook"),
NamedField::new("primary_stream_id"),
NamedField::new("streams"),
NamedField::new("service_name"),
NamedField::new("slowloris_timeout"),
// Only present when the matching struct field is compiled in.
#[cfg(debug_assertions)]
NamedField::new("trace_during_debug"),
];
/// Describes `TCPClient` to `valuable` as a struct with the named fields in
/// `TCP_CLIENT_FIELDS`.
impl Structable for TCPClient {
fn definition(&self) -> StructDef<'_> {
StructDef::new_static("TCPClient", Fields::Named(TCP_CLIENT_FIELDS))
}
}
/// Structured-logging view of a `TCPClient`.
impl Valuable for TCPClient {
    fn as_value(&self) -> Value<'_> {
        Value::Structable(self)
    }

    /// Visit every field in the order declared by `TCP_CLIENT_FIELDS`.
    ///
    /// Streams are reported as a map of stream id to a snapshot of each
    /// stream; the slowdown is rendered as `"<n>ms"` or `"<none>"`; and the
    /// slowloris timeout is reported in whole seconds.
    fn visit(&self, visitor: &mut dyn Visit) {
        let mut stream_snapshots = FnvHashMap::default();
        self.streams.iter_sync(|stream_id, stream| {
            stream_snapshots.insert(*stream_id, stream.to_valuable());
            true
        });

        let slowdown_label = self.cat_dev_slowdown.map_or_else(
            || "<none>".to_string(),
            |slowdown| format!("{}ms", slowdown.as_millis()),
        );
        let has_on_stream_begin = self.on_stream_begin.is_some();
        let has_on_stream_end = self.on_stream_end.is_some();
        let has_pre_nagle_hook = self.pre_nagle_hook.is_some();
        let has_post_nagle_hook = self.post_nagle_hook.is_some();
        let primary_stream_id = self.primary_stream_id.load(Ordering::Relaxed);
        let slowloris_secs = self.slowloris_timeout.as_secs();

        visitor.visit_named_fields(&NamedValues::new(
            TCP_CLIENT_FIELDS,
            &[
                Valuable::as_value(&slowdown_label),
                Valuable::as_value(&self.chunk_output_at_size),
                Valuable::as_value(&self.keep_all_responses),
                Valuable::as_value(&self.nagle_guard),
                Valuable::as_value(&has_on_stream_begin),
                Valuable::as_value(&has_on_stream_end),
                Valuable::as_value(&has_pre_nagle_hook),
                Valuable::as_value(&has_post_nagle_hook),
                Valuable::as_value(&primary_stream_id),
                Valuable::as_value(&stream_snapshots),
                Valuable::as_value(&self.service_name),
                Valuable::as_value(&slowloris_secs),
                #[cfg(debug_assertions)]
                Valuable::as_value(&self.trace_during_debug),
            ],
        ));
    }
}
/// Per-connection state stored in `TCPClient::streams`.
struct TCPClientStream {
/// Address of the peer on the other end of the connection.
remote_address: SocketAddr,
/// Receiving end of the channel the connection task pushes responses into.
response_channel: BoundedReceiver<(Option<RequestID>, Response)>,
/// Receiving end of the request channel; `None` once it has been taken
/// with `steal_send_requests_receiver`.
send_requests_receiver: Option<BoundedReceiver<RequestStreamMessage>>,
/// Sending end of the request channel, used to queue outgoing messages.
send_requests: BoundedSender<RequestStreamMessage>,
/// When this value was constructed.
time_opened: Instant,
}
impl TCPClientStream {
/// Bundle the channel ends for one connection, stamping it with the
/// current instant as its open time.
#[must_use]
pub fn new(
remote_address: SocketAddr,
sender: BoundedSender<RequestStreamMessage>,
receiver: BoundedReceiver<RequestStreamMessage>,
response_channel: BoundedReceiver<(Option<RequestID>, Response)>,
) -> Self {
Self {
remote_address,
response_channel,
send_requests_receiver: Some(receiver),
send_requests: sender,
time_opened: Instant::now(),
}
}
/// Copy out the fields needed for structured logging, so the result does
/// not borrow from the stream.
#[must_use]
pub const fn to_valuable(&self) -> TCPClientStreamValuable {
TCPClientStreamValuable {
receiver_stolen: self.send_requests_receiver.is_none(),
time_opened: self.time_opened,
}
}
/// Address of the peer this stream is connected to.
pub const fn server_address(&self) -> SocketAddr {
self.remote_address
}
/// Mutable access to the channel responses are read from.
#[must_use]
pub const fn response_channel_mut(
&mut self,
) -> &mut BoundedReceiver<(Option<RequestID>, Response)> {
&mut self.response_channel
}
/// Take the request receiver out of the stream. Returns `None` if it was
/// already taken.
#[must_use]
pub fn steal_send_requests_receiver(
&mut self,
) -> Option<BoundedReceiver<RequestStreamMessage>> {
self.send_requests_receiver.take()
}
/// Queue `message` for this stream, waiting at most `timeout` for room in
/// the channel.
pub async fn send_timeout(
&self,
message: RequestStreamMessage,
timeout: Duration,
) -> Result<(), SendTimeoutError<RequestStreamMessage>> {
self.send_requests.send_timeout(message, timeout).await
}
/// The instant at which this stream was created.
pub const fn opened_at(&self) -> Instant {
self.time_opened
}
}
/// Debug output shows whether the request receiver has been taken and when
/// the stream was opened; the channels themselves are omitted.
impl Debug for TCPClientStream {
    fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
        let receiver_stolen = self.send_requests_receiver.is_none();
        let mut out = fmt.debug_struct("TCPClientStream");
        out.field("receiver_stolen", &receiver_stolen);
        out.field("time_opened", &self.time_opened);
        out.finish_non_exhaustive()
    }
}
/// Two streams compare equal when they were opened at the same instant; no
/// other field is considered.
impl PartialEq for TCPClientStream {
fn eq(&self, other: &Self) -> bool {
self.time_opened == other.time_opened
}
}
/// Streams are ordered by the instant they were opened, earliest first.
impl PartialOrd for TCPClientStream {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let (mine, theirs) = (&self.time_opened, &other.time_opened);
        mine.partial_cmp(theirs)
    }
}
/// Field names reported for a stream through `valuable`. The order must
/// match the value lists built in the stream `Valuable::visit`
/// implementations, since names and values are paired by position.
const TCP_CLIENT_STREAM_FIELDS: &[NamedField<'static>] = &[
NamedField::new("receiver_stolen"),
NamedField::new("time_opened"),
];
/// Describes `TCPClientStream` to `valuable` as a struct with the named
/// fields in `TCP_CLIENT_STREAM_FIELDS`.
impl Structable for TCPClientStream {
fn definition(&self) -> StructDef<'_> {
StructDef::new_static("TCPClientStream", Fields::Named(TCP_CLIENT_STREAM_FIELDS))
}
}
/// Structured-logging view of a `TCPClientStream`.
impl Valuable for TCPClientStream {
    fn as_value(&self) -> Value<'_> {
        Value::Structable(self)
    }

    /// Report whether the request receiver has been taken, and the time the
    /// stream was opened as whole seconds since the Unix epoch.
    fn visit(&self, visitor: &mut dyn Visit) {
        let receiver_stolen = self.send_requests_receiver.is_none();
        // `Instant` has no wall-clock meaning, so the open time is derived
        // as "now minus the stream's age". Subtracting is what yields the
        // moment the stream was opened; adding the age would give a time in
        // the future. If the subtraction is not representable, fall back to
        // the current time.
        let opened_unix_secs = SystemTime::now()
            .checked_sub(self.time_opened.elapsed())
            .unwrap_or_else(SystemTime::now)
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        visitor.visit_named_fields(&NamedValues::new(
            TCP_CLIENT_STREAM_FIELDS,
            &[
                Valuable::as_value(&receiver_stolen),
                Valuable::as_value(&opened_unix_secs),
            ],
        ));
    }
}
/// A snapshot of the loggable parts of a `TCPClientStream`, produced by
/// `TCPClientStream::to_valuable`.
struct TCPClientStreamValuable {
/// Whether the stream's request receiver had already been taken.
receiver_stolen: bool,
/// The instant the stream was opened.
time_opened: Instant,
}
/// Reported under the name `"TCPClientStream"` with the same fields as the
/// stream itself, so snapshots and live streams look identical in logs.
impl Structable for TCPClientStreamValuable {
fn definition(&self) -> StructDef<'_> {
StructDef::new_static("TCPClientStream", Fields::Named(TCP_CLIENT_STREAM_FIELDS))
}
}
/// Structured-logging view of a stream snapshot.
impl Valuable for TCPClientStreamValuable {
    fn as_value(&self) -> Value<'_> {
        Value::Structable(self)
    }

    /// Report the captured `receiver_stolen` flag, and the time the stream
    /// was opened as whole seconds since the Unix epoch.
    fn visit(&self, visitor: &mut dyn Visit) {
        // `Instant` has no wall-clock meaning, so the open time is derived
        // as "now minus the stream's age". Subtracting is what yields the
        // moment the stream was opened; adding the age would give a time in
        // the future. If the subtraction is not representable, fall back to
        // the current time.
        let opened_unix_secs = SystemTime::now()
            .checked_sub(self.time_opened.elapsed())
            .unwrap_or_else(SystemTime::now)
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        visitor.visit_named_fields(&NamedValues::new(
            TCP_CLIENT_STREAM_FIELDS,
            &[
                Valuable::as_value(&self.receiver_stolen),
                Valuable::as_value(&opened_unix_secs),
            ],
        ));
    }
}