use std::{
cell::RefCell,
fmt::Debug,
rc::{Rc, Weak},
time::Instant,
};
use mio::{Token, net::TcpStream};
use rusty_ulid::Ulid;
use sozu_command::{logging::ansi_palette, ready::Ready};
use super::{
BackendStatus, ConnectionH1, ConnectionH2, Context, Endpoint, GlobalStreamId, MuxResult,
Position, Router,
h2::{self, H2StreamId},
};
use crate::metrics::names;
use crate::{
L7ListenerHandler, ListenerHandler, Readiness, backends::Backend, pool::Pool,
socket::SocketHandler, timer::TimeoutContainer,
};
/// Builds the ANSI-colored `MUX-CONN` prefix prepended to this module's log lines.
///
/// Only the first two entries of `ansi_palette()` (opening and reset sequences)
/// are used; the remaining palette entries are ignored.
macro_rules! log_module_context {
    () => {{
        let (open, reset, ..) = ansi_palette();
        format!("{open}MUX-CONN{reset}\t >>>")
    }};
}
/// A multiplexed connection, speaking either HTTP/1.1 or HTTP/2 over a socket
/// of type `Front`.
///
/// The same type is used on both sides of the proxy: the constructors build it
/// with `Position::Server` (frontend) or `Position::Client` (backend).
///
/// The variants are stored inline rather than boxed, so the enum is as large as
/// its biggest variant; clippy's `large_enum_variant` lint is allowed for that reason.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Connection<Front: SocketHandler> {
    /// HTTP/1.1 connection state.
    H1(ConnectionH1<Front>),
    /// HTTP/2 connection state.
    H2(ConnectionH2<Front>),
}
/// Dispatches to the inner `ConnectionH1` / `ConnectionH2` of a [`Connection`].
///
/// Supported forms (arm order matters for matching and must be preserved):
/// - `forward!(conn, method(args...))` calls `method(args...)` on the inner connection;
/// - `forward!(&conn, field)` yields `&inner.field`;
/// - `forward!(&mut conn, field)` yields `&mut inner.field`.
///
/// Both variants must expose a method or field with that name and a compatible type.
macro_rules! forward {
    ($conn:expr, $call:ident ( $($arg:tt)* )) => {
        match $conn {
            Connection::H1(inner) => inner.$call($($arg)*),
            Connection::H2(inner) => inner.$call($($arg)*),
        }
    };
    (&$conn:expr, $member:ident) => {
        match $conn {
            Connection::H1(inner) => &inner.$member,
            Connection::H2(inner) => &inner.$member,
        }
    };
    (&mut $conn:expr, $member:ident) => {
        match $conn {
            Connection::H1(inner) => &mut inner.$member,
            Connection::H2(inner) => &mut inner.$member,
        }
    };
}
impl<Front: SocketHandler> Connection<Front> {
    /// Builds a server-side (frontend, `Position::Server`) HTTP/1.1 connection.
    ///
    /// The connection starts attached to stream `0`, with zero requests served,
    /// and an initial interest of `READABLE | HUP | ERROR`.
    pub fn new_h1_server(
        session_ulid: Ulid,
        front_stream: Front,
        timeout_container: TimeoutContainer,
    ) -> Connection<Front> {
        Connection::H1(ConnectionH1 {
            socket: front_stream,
            position: Position::Server,
            readiness: Readiness {
                interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
                event: Ready::EMPTY,
            },
            requests: 0,
            stream: Some(0),
            timeout_container,
            parked_on_buffer_pressure: false,
            close_notify_sent: false,
            session_ulid,
        })
    }

    /// Builds a client-side (backend, `Position::Client`) HTTP/1.1 connection.
    ///
    /// The backend status starts as `BackendStatus::Connecting(now)`, the
    /// connection is not yet attached to any stream, and the initial interest is
    /// `WRITABLE | READABLE | HUP | ERROR`.
    pub fn new_h1_client(
        session_ulid: Ulid,
        front_stream: Front,
        cluster_id: String,
        backend: Rc<RefCell<Backend>>,
        timeout_container: TimeoutContainer,
    ) -> Connection<Front> {
        Connection::H1(ConnectionH1 {
            socket: front_stream,
            position: Position::Client(
                cluster_id,
                backend,
                BackendStatus::Connecting(Instant::now()),
            ),
            readiness: Readiness {
                interest: Ready::WRITABLE | Ready::READABLE | Ready::HUP | Ready::ERROR,
                event: Ready::EMPTY,
            },
            stream: None,
            requests: 0,
            timeout_container,
            parked_on_buffer_pressure: false,
            close_notify_sent: false,
            session_ulid,
        })
    }

    /// Builds a server-side (frontend) HTTP/2 connection.
    ///
    /// The initial read expectation is `CLIENT_PREFACE_SIZE` bytes on stream
    /// zero (presumably the client connection preface — confirm against
    /// `ConnectionH2::new`). The initial interest is `READABLE | HUP | ERROR`.
    ///
    /// Returns `None` when `ConnectionH2::new` fails.
    #[allow(clippy::too_many_arguments)] // each H2 setting is passed straight to ConnectionH2::new
    pub fn new_h2_server(
        session_ulid: Ulid,
        front_stream: Front,
        pool: Weak<RefCell<Pool>>,
        timeout_container: TimeoutContainer,
        flood_config: h2::H2FloodConfig,
        connection_config: h2::H2ConnectionConfig,
        stream_idle_timeout: std::time::Duration,
        graceful_shutdown_deadline: Option<std::time::Duration>,
    ) -> Option<Connection<Front>> {
        ConnectionH2::new(
            session_ulid,
            front_stream,
            Position::Server,
            pool,
            flood_config,
            connection_config,
            stream_idle_timeout,
            graceful_shutdown_deadline,
            timeout_container,
            Some((H2StreamId::Zero, h2::CLIENT_PREFACE_SIZE)),
            Ready::READABLE | Ready::HUP | Ready::ERROR,
        )
        .map(Connection::H2)
    }

    /// Builds a client-side (backend) HTTP/2 connection.
    ///
    /// The backend status starts as `BackendStatus::Connecting(now)`. There is
    /// no initial read expectation, and the initial interest is
    /// `WRITABLE | HUP | ERROR`.
    ///
    /// Returns `None` when `ConnectionH2::new` fails. Under `test` or the
    /// `e2e-hooks` feature, it also returns `None` when
    /// `test_hooks::FORCE_NEW_H2_CLIENT_FAILURE` is set.
    #[allow(clippy::too_many_arguments)] // each H2 setting is passed straight to ConnectionH2::new
    pub fn new_h2_client(
        session_ulid: Ulid,
        front_stream: Front,
        cluster_id: String,
        backend: Rc<RefCell<Backend>>,
        pool: Weak<RefCell<Pool>>,
        timeout_container: TimeoutContainer,
        flood_config: h2::H2FloodConfig,
        connection_config: h2::H2ConnectionConfig,
        stream_idle_timeout: std::time::Duration,
        graceful_shutdown_deadline: Option<std::time::Duration>,
    ) -> Option<Connection<Front>> {
        // Test hook: `swap(false)` clears the flag, so each arming forces exactly one failure.
        #[cfg(any(test, feature = "e2e-hooks"))]
        if test_hooks::FORCE_NEW_H2_CLIENT_FAILURE.swap(false, std::sync::atomic::Ordering::SeqCst)
        {
            return None;
        }
        ConnectionH2::new(
            session_ulid,
            front_stream,
            Position::Client(
                cluster_id,
                backend,
                BackendStatus::Connecting(Instant::now()),
            ),
            pool,
            flood_config,
            connection_config,
            stream_idle_timeout,
            graceful_shutdown_deadline,
            timeout_container,
            None,
            Ready::WRITABLE | Ready::HUP | Ready::ERROR,
        )
        .map(Connection::H2)
    }

    /// Borrows the inner connection's readiness.
    pub fn readiness(&self) -> &Readiness {
        forward!(&self, readiness)
    }

    /// Mutably borrows the inner connection's readiness.
    pub fn readiness_mut(&mut self) -> &mut Readiness {
        forward!(&mut self, readiness)
    }

    /// Borrows the inner connection's position (frontend server or backend client).
    pub fn position(&self) -> &Position {
        forward!(&self, position)
    }

    /// Mutably borrows the inner connection's position.
    pub fn position_mut(&mut self) -> &mut Position {
        forward!(&mut self, position)
    }

    /// Borrows the underlying TCP stream through `SocketHandler::socket_ref`.
    pub fn socket(&self) -> &TcpStream {
        match self {
            Connection::H1(c) => c.socket.socket_ref(),
            Connection::H2(c) => c.socket.socket_ref(),
        }
    }

    /// Mutably borrows the underlying TCP stream through `SocketHandler::socket_mut`.
    pub fn socket_mut(&mut self) -> &mut TcpStream {
        match self {
            Connection::H1(c) => c.socket.socket_mut(),
            Connection::H2(c) => c.socket.socket_mut(),
        }
    }

    /// Mutably borrows the inner connection's timeout container.
    pub fn timeout_container(&mut self) -> &mut TimeoutContainer {
        forward!(&mut self, timeout_container)
    }

    /// Returns the `(overhead_bin, overhead_bout)` byte counters.
    ///
    /// Only H2 connections track these counters; H1 always reports `(0, 0)`.
    pub fn overhead_bytes(&self) -> (usize, usize) {
        match self {
            Connection::H1(_) => (0, 0),
            Connection::H2(c) => (c.bytes.overhead_bin, c.bytes.overhead_bout),
        }
    }

    /// Forwards a readable event to the inner H1/H2 connection.
    pub(super) fn readable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
    where
        E: Endpoint,
        L: ListenerHandler + L7ListenerHandler,
    {
        forward!(self, readable(context, endpoint))
    }

    /// Forwards a writable event to the inner H1/H2 connection.
    pub(super) fn writable<E, L>(&mut self, context: &mut Context<L>, endpoint: E) -> MuxResult
    where
        E: Endpoint,
        L: ListenerHandler + L7ListenerHandler,
    {
        forward!(self, writable(context, endpoint))
    }

    /// Reports whether the kawa buffer of an H1 connection's stream is full.
    ///
    /// The buffer is `back` for `Position::Client` and `front` for
    /// `Position::Server`. Returns `false` when the H1 connection has no attached
    /// stream, and always returns `false` for H2.
    pub(super) fn has_buffer_pressure<L>(&self, context: &Context<L>) -> bool
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        match self {
            Connection::H1(c) => Self::h1_stream_buffer_has_space(c, context) == Some(false),
            Connection::H2(_) => false,
        }
    }

    /// Re-arms reading on a connection that stopped because of buffer pressure.
    ///
    /// For H1, this only acts when `parked_on_buffer_pressure` is set and the
    /// stream's kawa buffer has free space again. It then calls
    /// `readiness.signal_pending_read()` and returns `true`; otherwise it returns
    /// `false`. H2 delegates to `ConnectionH2::try_resume_reading`.
    pub(super) fn try_resume_reading<L>(&mut self, context: &Context<L>) -> bool
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        match self {
            Connection::H1(c) => {
                if !c.parked_on_buffer_pressure {
                    return false;
                }
                if Self::h1_stream_buffer_has_space(c, context) != Some(true) {
                    return false;
                }
                trace!(
                    "{} H1 try_resume_reading: re-arming READABLE",
                    log_module_context!()
                );
                c.readiness.signal_pending_read();
                true
            }
            Connection::H2(c) => c.try_resume_reading(context),
        }
    }

    /// Checks whether the kawa buffer of an H1 connection's stream has free space.
    ///
    /// The buffer is selected by position: `back` for `Position::Client`,
    /// `front` for `Position::Server`. Returns `None` when the connection is not
    /// attached to a stream. Shared by `has_buffer_pressure` and
    /// `try_resume_reading` so both always inspect the same buffer.
    fn h1_stream_buffer_has_space<L>(c: &ConnectionH1<Front>, context: &Context<L>) -> Option<bool>
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        let stream_id = c.stream?;
        let kawa = match c.position {
            Position::Client(..) => &context.streams[stream_id].back,
            Position::Server => &context.streams[stream_id].front,
        };
        Some(kawa.storage.available_space() > 0)
    }

    /// Starts a graceful GOAWAY on H2 connections; for H1 this is a no-op
    /// that returns `MuxResult::Continue`.
    pub(super) fn graceful_goaway(&mut self) -> MuxResult {
        match self {
            Connection::H1(_) => MuxResult::Continue,
            Connection::H2(c) => c.graceful_goaway(),
        }
    }

    /// Returns the H2 `drain.draining` flag; H1 connections never drain.
    pub(super) fn is_draining(&self) -> bool {
        match self {
            Connection::H1(_) => false,
            Connection::H2(c) => c.drain.draining,
        }
    }

    /// Reports whether the H2 graceful-shutdown deadline has passed; always `false` for H1.
    pub(super) fn graceful_shutdown_deadline_elapsed(&self) -> bool {
        match self {
            Connection::H1(_) => false,
            Connection::H2(c) => c.graceful_shutdown_deadline_elapsed(),
        }
    }

    /// Forwards to the inner connection's `has_pending_write`.
    pub(super) fn has_pending_write(&self) -> bool {
        forward!(self, has_pending_write())
    }

    /// Like `has_pending_write`, but H2 also takes per-stream data into account
    /// through `has_pending_write_full(context)`. H1 uses plain `has_pending_write`.
    pub(super) fn has_pending_write_including_streams<L>(&self, context: &Context<L>) -> bool
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        match self {
            Connection::H1(c) => c.has_pending_write(),
            Connection::H2(c) => c.has_pending_write_full(context),
        }
    }

    /// Forwards to the inner connection's `initiate_close_notify` and returns its result.
    pub(super) fn initiate_close_notify(&mut self) -> bool {
        forward!(self, initiate_close_notify())
    }

    /// Flushes the H2 zero buffer; H1 connections have none, so this does nothing for them.
    pub(super) fn flush_zero_buffer(&mut self) {
        if let Connection::H2(c) = self {
            c.flush_zero_buffer();
        }
    }

    /// Backend accounting performed before closing a `Position::Client` connection.
    ///
    /// It calls `dec_connections()` on the backend and decrements the
    /// `CONNECTIONS`, `POOL_SIZE` and per-backend `CONNECTIONS_PER_BACKEND`
    /// gauges, whatever the `BackendStatus`. Server-side connections are untouched.
    fn pre_close_client_bookkeeping(&self) {
        if let Position::Client(cluster_id, backend, _) = self.position() {
            let mut backend_borrow = backend.borrow_mut();
            backend_borrow.dec_connections();
            gauge_add!(names::backend::CONNECTIONS, -1);
            gauge_add!(names::backend::POOL_SIZE, -1);
            gauge_add!(
                names::backend::CONNECTIONS_PER_BACKEND,
                -1,
                Some(cluster_id),
                Some(&backend_borrow.backend_id)
            );
            trace!(
                "{} connection close: {:#?}",
                log_module_context!(),
                backend_borrow
            );
        }
    }

    /// Decrements the backend's `active_requests` (saturating at zero), but only
    /// for a client connection whose status is `BackendStatus::Connected`.
    fn pre_end_stream_client_bookkeeping(&self) {
        if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
            let mut backend_borrow = backend.borrow_mut();
            backend_borrow.active_requests = backend_borrow.active_requests.saturating_sub(1);
            trace!(
                "{} connection end stream: {:#?}",
                log_module_context!(),
                backend_borrow
            );
        }
    }

    /// Increments the backend's `active_requests`, but only for a client
    /// connection whose status is `BackendStatus::Connected`.
    fn pre_start_stream_client_bookkeeping(&self) {
        if let Position::Client(_, backend, BackendStatus::Connected) = self.position() {
            let mut backend_borrow = backend.borrow_mut();
            backend_borrow.active_requests += 1;
            trace!(
                "{} connection start stream: {:#?}",
                log_module_context!(),
                backend_borrow
            );
        }
    }

    /// Closes the connection: backend accounting first, then the inner connection's `close`.
    pub(super) fn close<E, L>(&mut self, context: &mut Context<L>, endpoint: E)
    where
        E: Endpoint,
        L: ListenerHandler + L7ListenerHandler,
    {
        self.pre_close_client_bookkeeping();
        forward!(self, close(context, endpoint))
    }

    /// Ends `stream`: backend request accounting first, then the inner connection's `end_stream`.
    pub(super) fn end_stream<L>(&mut self, stream: GlobalStreamId, context: &mut Context<L>)
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        self.pre_end_stream_client_bookkeeping();
        forward!(self, end_stream(stream, context))
    }

    /// Attaches `stream` to this connection and returns whether that succeeded.
    ///
    /// Backend request accounting is incremented up front. If the inner
    /// `start_stream` returns `false`, the matching decrement undoes that
    /// increment, so a failed start leaves the counters unchanged.
    pub(super) fn start_stream<L>(
        &mut self,
        stream: GlobalStreamId,
        context: &mut Context<L>,
    ) -> bool
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        self.pre_start_stream_client_bookkeeping();
        let started = forward!(self, start_stream(stream, context));
        if !started {
            self.pre_end_stream_client_bookkeeping();
        }
        started
    }
}
/// [`Endpoint`] adapter over a single, mutably borrowed server-side [`Connection`].
#[derive(Debug)]
pub(super) struct EndpointServer<'a, Front>(pub &'a mut Connection<Front>)
where
    Front: SocketHandler;
/// [`Endpoint`] adapter over a mutably borrowed [`Router`], whose backend
/// connections are looked up by token.
#[derive(Debug)]
pub(super) struct EndpointClient<'r>(pub &'r mut Router);
/// The server endpoint wraps exactly one connection, so every `token`
/// argument is ignored and each call goes straight to that connection.
impl<Front: SocketHandler + Debug> Endpoint for EndpointServer<'_, Front> {
    /// Readiness of the wrapped connection.
    fn readiness(&self, _token: Token) -> &Readiness {
        let Self(frontend) = self;
        frontend.readiness()
    }

    /// Mutable readiness of the wrapped connection.
    fn readiness_mut(&mut self, _token: Token) -> &mut Readiness {
        let Self(frontend) = self;
        frontend.readiness_mut()
    }

    /// Socket of the wrapped connection; always `Some`.
    fn socket(&self, _token: Token) -> Option<&TcpStream> {
        let Self(frontend) = self;
        Some(frontend.socket())
    }

    /// Ends `stream` on the wrapped connection.
    fn end_stream<L>(&mut self, _token: Token, stream: GlobalStreamId, context: &mut Context<L>)
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        let Self(frontend) = self;
        frontend.end_stream(stream, context);
    }

    /// Starts `stream` on the wrapped connection and reports whether it succeeded.
    fn start_stream<L>(
        &mut self,
        _token: Token,
        stream: GlobalStreamId,
        context: &mut Context<L>,
    ) -> bool
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        let Self(frontend) = self;
        frontend.start_stream(stream, context)
    }
}
/// Resolves `token` against `Router::backends`. An unknown token is logged as
/// an error (except in `socket`, which returns `None` silently) and handled
/// with a harmless fallback.
impl Endpoint for EndpointClient<'_> {
    /// Readiness of the backend behind `token`, or `Router::fallback_readiness`
    /// if that token is unknown.
    fn readiness(&self, token: Token) -> &Readiness {
        let router = &*self.0;
        let Some(backend) = router.backends.get(&token) else {
            error!(
                "{} backend token {:?} missing from backends map (readiness)",
                log_module_context!(),
                token
            );
            return &router.fallback_readiness;
        };
        backend.readiness()
    }

    /// Mutable readiness of the backend behind `token`, or
    /// `Router::fallback_readiness` if that token is unknown.
    fn readiness_mut(&mut self, token: Token) -> &mut Readiness {
        let router = &mut *self.0;
        // `backends` and `fallback_readiness` are disjoint fields, so both borrows coexist.
        let Some(backend) = router.backends.get_mut(&token) else {
            error!(
                "{} backend token {:?} missing from backends map (readiness_mut)",
                log_module_context!(),
                token
            );
            return &mut router.fallback_readiness;
        };
        backend.readiness_mut()
    }

    /// Socket of the backend behind `token`; `None` (without logging) when unknown.
    fn socket(&self, token: Token) -> Option<&TcpStream> {
        let router = &*self.0;
        router.backends.get(&token).map(|backend| backend.socket())
    }

    /// Ends `stream` on the backend behind `token`; only logs if the token is unknown.
    fn end_stream<L>(&mut self, token: Token, stream: GlobalStreamId, context: &mut Context<L>)
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        if let Some(backend) = self.0.backends.get_mut(&token) {
            backend.end_stream(stream, context);
        } else {
            error!(
                "{} backend token {:?} missing from backends map (end_stream)",
                log_module_context!(),
                token
            );
        }
    }

    /// Starts `stream` on the backend behind `token`. Returns `false` (after
    /// logging) if the token is unknown.
    fn start_stream<L>(
        &mut self,
        token: Token,
        stream: GlobalStreamId,
        context: &mut Context<L>,
    ) -> bool
    where
        L: ListenerHandler + L7ListenerHandler,
    {
        if let Some(backend) = self.0.backends.get_mut(&token) {
            return backend.start_stream(stream, context);
        }
        error!(
            "{} backend token {:?} missing from backends map (start_stream)",
            log_module_context!(),
            token
        );
        false
    }
}
/// Fault-injection hooks, compiled only for tests or the `e2e-hooks` feature.
#[cfg(any(test, feature = "e2e-hooks"))]
pub mod test_hooks {
    use std::sync::atomic::{AtomicBool, Ordering};

    /// When set, the next `Connection::new_h2_client` call returns `None`
    /// and clears this flag.
    pub static FORCE_NEW_H2_CLIENT_FAILURE: AtomicBool = AtomicBool::new(false);

    /// Sets the forced-failure flag to `on` and returns its previous value.
    pub fn __test_force_h2_client_failure(on: bool) -> bool {
        let previous = FORCE_NEW_H2_CLIENT_FAILURE.swap(on, Ordering::SeqCst);
        previous
    }
}