#![allow(unknown_lints, clippy::non_send_fields_in_send_ty)]
use super::{ConnectionApi, ConnectionApiProvider};
use crate::{
connection::{self, Connection, ConnectionInterests, InternalConnectionId},
endpoint::{
self,
connect::{self, ConnectionSender},
handle::{AcceptorSender, ConnectorReceiver},
},
stream,
};
use alloc::{collections::BTreeMap, sync::Arc};
use bytes::Bytes;
use core::{
cell::Cell,
marker::PhantomData,
ops::Deref,
pin::Pin,
sync::atomic::AtomicUsize,
task::{Context, Poll},
};
use intrusive_collections::{
intrusive_adapter, KeyAdapter, LinkedList, LinkedListLink, RBTree, RBTreeLink,
};
use s2n_quic_core::{
application,
application::ServerName,
event::supervisor,
inet::SocketAddress,
query::{Query, QueryMut},
recovery::K_GRANULARITY,
time::Timestamp,
transport,
};
// Intrusive adapters for `ConnectionNode`. Each adapter pairs a collection
// type with one link field of the node, so a single `Arc<ConnectionNode>` can
// be a member of several collections at the same time.

// List of connections that reported `finalization` and await removal.
intrusive_adapter!(DoneConnectionsAdapter<C, L> = Arc<ConnectionNode<C, L>>: ConnectionNode<C, L> {
done_connections_link: LinkedListLink
} where C: connection::Trait, L: connection::Lock<C>);
// List of connections that reported `transmission` interest.
intrusive_adapter!(WaitingForTransmissionAdapter<C, L> = Arc<ConnectionNode<C, L>>: ConnectionNode<C, L> {
waiting_for_transmission_link: LinkedListLink
} where C: connection::Trait, L: connection::Lock<C>);
// List of connections that reported `new_connection_id` interest.
intrusive_adapter!(WaitingForConnectionIdAdapter<C, L> = Arc<ConnectionNode<C, L>>: ConnectionNode<C, L> {
waiting_for_connection_id_link: LinkedListLink
} where C: connection::Trait, L: connection::Lock<C>);
// Tree of connections with a pending timeout; keyed by the node's `timeout`
// (see the `KeyAdapter` impl for this adapter).
intrusive_adapter!(WaitingForTimeoutAdapter<C, L> = Arc<ConnectionNode<C, L>>: ConnectionNode<C, L> {
waiting_for_timeout_link: RBTreeLink
} where C: connection::Trait, L: connection::Lock<C>);
// Tree of every connection in the container; keyed by
// `internal_connection_id` (see the `KeyAdapter` impl for this adapter).
intrusive_adapter!(ConnectionTreeAdapter<C, L> = Arc<ConnectionNode<C, L>>: ConnectionNode<C, L> {
tree_link: RBTreeLink
} where C: connection::Trait, L: connection::Lock<C>);
/// A connection together with the bookkeeping needed to store it in the
/// intrusive collections of the container.
///
/// Nodes are created by `ConnectionContainer::insert_connection`, which wraps
/// them in an `Arc` before linking them into any collection.
struct ConnectionNode<C: connection::Trait, L: connection::Lock<C>> {
/// The connection, wrapped in the lock type `L`. All access goes through
/// the lock's `read`/`write` closures.
inner: L,
/// Identifier of the connection; also the key of the node in the
/// connection tree.
internal_connection_id: InternalConnectionId,
/// Link for the tree of all connections (`ConnectionTreeAdapter`).
tree_link: RBTreeLink,
/// Link for the list of done connections (`DoneConnectionsAdapter`).
done_connections_link: LinkedListLink,
/// Link for the transmission list (`WaitingForTransmissionAdapter`).
waiting_for_transmission_link: LinkedListLink,
/// Link for the connection-id list (`WaitingForConnectionIdAdapter`).
waiting_for_connection_id_link: LinkedListLink,
/// Link for the timeout tree (`WaitingForTimeoutAdapter`).
waiting_for_timeout_link: RBTreeLink,
/// Timeout the node is keyed by while it is in the timeout tree. Stored
/// in a `Cell` so it can be updated through a shared reference.
timeout: Cell<Option<Timestamp>>,
/// Counter handed out via `ConnectionApiProvider::application_handle_count`;
/// starts at zero. Presumably tracks application-held handles — confirm
/// against the trait's users.
application_handle_count: AtomicUsize,
/// Ties the node to the connection type `C` without storing a `C` directly.
_connection: PhantomData<C>,
}
impl<C: connection::Trait, L: connection::Lock<C>> ConnectionNode<C, L> {
pub fn new(
connection_impl: L,
internal_connection_id: InternalConnectionId,
) -> ConnectionNode<C, L> {
ConnectionNode {
inner: connection_impl,
internal_connection_id,
tree_link: RBTreeLink::new(),
done_connections_link: LinkedListLink::new(),
waiting_for_transmission_link: LinkedListLink::new(),
waiting_for_connection_id_link: LinkedListLink::new(),
waiting_for_timeout_link: RBTreeLink::new(),
timeout: Cell::new(None),
application_handle_count: AtomicUsize::new(0),
_connection: PhantomData,
}
}
/// Produces a new owning `Arc` for the node that `self` refers to.
///
/// # Safety
///
/// `self` must be stored inside an `Arc<ConnectionNode<C, L>>`: the reference
/// is turned into a raw pointer and handed to `Arc::from_raw`, which is only
/// valid for pointers that originate from an `Arc`.
unsafe fn arc_from_ref(&self) -> Arc<Self> {
// Rebuild a temporary `Arc` from the raw pointer. `Arc::from_raw` does not
// change the reference count, so the temporary is wrapped in `ManuallyDrop`
// to keep it from decrementing the count when it goes out of scope.
let temp_node_ptr: core::mem::ManuallyDrop<Arc<ConnectionNode<C, L>>> =
core::mem::ManuallyDrop::new(Arc::<ConnectionNode<C, L>>::from_raw(
self as *const ConnectionNode<C, L>,
));
// Cloning increments the count and yields the handle that is returned.
temp_node_ptr.deref().clone()
}
fn api_write_call<F: FnOnce(&mut C) -> Result<R, E>, R, E: From<connection::Error>>(
&self,
f: F,
) -> Result<R, E> {
match self.inner.write(|conn| f(conn)) {
Ok(res) => res,
Err(_) => Err(connection::Error::unspecified().into()),
}
}
fn api_read_call<F: FnOnce(&C) -> Result<R, E>, R, E: From<connection::Error>>(
&self,
f: F,
) -> Result<R, E> {
match self.inner.read(|conn| f(conn)) {
Ok(res) => res,
Err(_) => Err(connection::Error::unspecified().into()),
}
}
fn api_poll_call<F: FnOnce(&mut C) -> Poll<Result<R, E>>, R, E: From<connection::Error>>(
&self,
f: F,
) -> Poll<Result<R, E>> {
match self.inner.write(|conn| f(conn)) {
Ok(res) => res,
Err(_) => Poll::Ready(Err(connection::Error::unspecified().into())),
}
}
/// Sanity-checks the node's list membership.
///
/// Does nothing unless debug assertions are enabled. A node in the done list
/// must not be in any waiting collection; any other node must be in at least
/// one of them.
///
/// # Panics
///
/// Panics (debug builds only) when one of those invariants is violated.
#[inline]
fn ensure_consistency(&self) {
    if cfg!(debug_assertions) {
        if self.done_connections_link.is_linked() {
            // Done connections must have dropped every other interest
            assert!(
                !self.waiting_for_connection_id_link.is_linked(),
                "A done connection should not be waiting for connection IDs"
            );
            assert!(
                !self.waiting_for_timeout_link.is_linked(),
                "A done connection should not be waiting for timeout"
            );
            assert!(
                !self.waiting_for_transmission_link.is_linked(),
                "A done connection should not be waiting for transmission"
            );
        } else {
            // A connection that is not done has to be waiting for something
            let has_interest = self.waiting_for_connection_id_link.is_linked()
                || self.waiting_for_timeout_link.is_linked()
                || self.waiting_for_transmission_link.is_linked();
            assert!(
                has_interest,
                "Active connections should express interest in at least one action"
            );
        }
    }
}
}
/// Keys the timeout tree by the node's `timeout` value.
impl<'a, C: connection::Trait, L: connection::Lock<C>> KeyAdapter<'a>
    for WaitingForTimeoutAdapter<C, L>
{
    type Key = Timestamp;

    /// Returns the timeout stored on `node`.
    ///
    /// # Panics
    ///
    /// With debug assertions enabled, panics if the node has no timeout set.
    /// Without them a zero-duration timestamp is used as a fallback key.
    fn get_key(&self, node: &'a ConnectionNode<C, L>) -> Timestamp {
        match node.timeout.get() {
            Some(timeout) => timeout,
            None if cfg!(debug_assertions) => {
                panic!("node was queried for timeout but none was set")
            }
            // NOTE(review): `Timestamp::from_duration` is unsafe and its
            // contract is not visible here — confirm a zero duration is
            // acceptable. Presumably this sorts the node to the front of
            // the tree so it is processed (and corrected) first.
            None => unsafe { Timestamp::from_duration(core::time::Duration::from_secs(0)) },
        }
    }
}
impl<'a, C: connection::Trait, L: connection::Lock<C>> KeyAdapter<'a>
for ConnectionTreeAdapter<C, L>
{
type Key = InternalConnectionId;
fn get_key(&self, node: &'a ConnectionNode<C, L>) -> InternalConnectionId {
node.internal_connection_id
}
}
// NOTE(review): the node contains a `Cell`, which is not `Sync` on its own,
// plus intrusive link fields. This impl presumably relies on those fields
// being touched only by the container that owns the collections, while other
// holders of the `Arc` go through `inner`'s lock and the atomic counter —
// confirm against all users of `ConnectionNode`.
unsafe impl<C: connection::Trait, L: connection::Lock<C>> Sync for ConnectionNode<C, L> {}
impl<C: connection::Trait, L: connection::Lock<C>> ConnectionApiProvider for ConnectionNode<C, L> {
fn application_handle_count(&self) -> &AtomicUsize {
&self.application_handle_count
}
fn poll_request(
&self,
stream_id: stream::StreamId,
request: &mut stream::ops::Request,
context: Option<&Context>,
) -> Result<stream::ops::Response, stream::StreamError> {
self.api_write_call(|conn| conn.poll_stream_request(stream_id, request, context))
}
fn poll_accept(
&self,
arc_self: &ConnectionApi,
stream_type: Option<stream::StreamType>,
context: &Context,
) -> Poll<Result<Option<stream::Stream>, connection::Error>> {
let response = self.api_poll_call(|conn| conn.poll_accept_stream(stream_type, context));
match response {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Err(e).into(),
Poll::Ready(Ok(None)) => Ok(None).into(),
Poll::Ready(Ok(Some(stream_id))) => {
let connection = arc_self.clone();
let connection = Connection::new(connection);
let stream = stream::Stream::new(connection, stream_id);
Ok(Some(stream)).into()
}
}
}
fn poll_open_stream(
&self,
arc_self: &ConnectionApi,
stream_type: stream::StreamType,
open_token: &mut connection::OpenToken,
context: &Context,
) -> Poll<Result<stream::Stream, connection::Error>> {
let response =
self.api_poll_call(|conn| conn.poll_open_stream(stream_type, open_token, context));
match response {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => Err(e).into(),
Poll::Ready(Ok(stream_id)) => {
let connection = arc_self.clone();
let connection = Connection::new(connection);
let stream = stream::Stream::new(connection, stream_id);
Ok(stream).into()
}
}
}
fn close_connection(&self, error: Option<application::Error>) {
let _: Result<(), connection::Error> = self.api_write_call(|conn| {
conn.application_close(error);
Ok(())
});
}
fn server_name(&self) -> Result<Option<ServerName>, connection::Error> {
self.api_read_call(|conn| Ok(conn.server_name()))
}
fn application_protocol(&self) -> Result<Bytes, connection::Error> {
self.api_read_call(|conn| Ok(conn.application_protocol()))
}
fn id(&self) -> u64 {
self.internal_connection_id.into()
}
fn ping(&self) -> Result<(), connection::Error> {
self.api_write_call(|conn| conn.ping())
}
fn keep_alive(&self, enabled: bool) -> Result<(), connection::Error> {
self.api_write_call(|conn| conn.keep_alive(enabled))
}
fn local_address(&self) -> Result<SocketAddress, connection::Error> {
self.api_read_call(|conn| conn.local_address())
}
fn remote_address(&self) -> Result<SocketAddress, connection::Error> {
self.api_read_call(|conn| conn.remote_address())
}
#[inline]
fn query_event_context(&self, query: &mut dyn Query) -> Result<(), connection::Error> {
self.api_read_call(|conn| {
conn.query_event_context(query);
Ok(())
})
}
#[inline]
fn query_event_context_mut(&self, query: &mut dyn QueryMut) -> Result<(), connection::Error> {
self.api_write_call(|conn| {
conn.query_event_context_mut(query);
Ok(())
})
}
#[inline]
fn datagram_mut(&self, query: &mut dyn QueryMut) -> Result<(), connection::Error> {
self.api_write_call(|conn| {
conn.datagram_mut(query);
Ok(())
})
}
}
/// The collections that record what each connection is currently waiting
/// for, plus the counters maintained alongside them.
struct InterestLists<C: connection::Trait, L: connection::Lock<C>> {
/// Connections that reported `finalization`; drained by
/// `ConnectionContainer::finalize_done_connections`.
done_connections: LinkedList<DoneConnectionsAdapter<C, L>>,
/// Connections that reported `transmission` interest.
waiting_for_transmission: LinkedList<WaitingForTransmissionAdapter<C, L>>,
/// Connections that reported `new_connection_id` interest.
waiting_for_connection_id: LinkedList<WaitingForConnectionIdAdapter<C, L>>,
/// Connections with a pending timeout, ordered by that timeout.
waiting_for_timeout: RBTree<WaitingForTimeoutAdapter<C, L>>,
/// Client side only: senders used to report the outcome of a connection
/// attempt, keyed by connection id.
waiting_for_open: BTreeMap<InternalConnectionId, ConnectionSender>,
/// Number of connections counted as handshaking. Incremented on insert,
/// decremented when a connection is accepted or finalized while still
/// handshaking.
handshake_connections: usize,
/// Number of connections in the container. Incremented on insert and
/// decremented by `remove_node`.
connection_count: usize,
}
impl<C: connection::Trait, L: connection::Lock<C>> InterestLists<C, L> {
fn new() -> Self {
Self {
done_connections: LinkedList::new(DoneConnectionsAdapter::new()),
waiting_for_transmission: LinkedList::new(WaitingForTransmissionAdapter::new()),
waiting_for_connection_id: LinkedList::new(WaitingForConnectionIdAdapter::new()),
waiting_for_timeout: RBTree::new(WaitingForTimeoutAdapter::new()),
waiting_for_open: BTreeMap::new(),
handshake_connections: 0,
connection_count: 0,
}
}
/// Brings the collections in line with the `interests` that `node` just
/// reported.
///
/// * `accept_queue` - receives the connection handle on a server when the
///   connection reports `accept`.
/// * `node` - the connection node; it is turned into an `Arc` via
///   `arc_from_ref`, so it must be stored in an `Arc`.
/// * `interests` - the interests reported by the connection.
/// * `result` - decides whether a newly interested node is pushed to the back
///   (`Continue`) or to the front (otherwise) of a list.
///
/// # Errors
///
/// Returns the lock's error if the connection cannot be locked for writing
/// while it is being marked as accepted. Collections updated before that
/// point stay updated.
fn update_interests(
&mut self,
accept_queue: &mut AcceptorSender,
node: &ConnectionNode<C, L>,
interests: ConnectionInterests,
result: ConnectionContainerIterationResult,
) -> Result<(), L::Error> {
let id = node.internal_connection_id;
// Inserts the node into `$list_name` using the method `$call`
// (`push_back`, `push_front` or `insert`). Callers check the link state
// first because a link can only be part of one collection at a time.
macro_rules! insert_interest {
($list_name:ident, $call:ident) => {
let node = unsafe {
// SAFETY: relies on every node being stored in an `Arc`, as done by
// `ConnectionContainer::insert_connection`.
node.arc_from_ref()
};
self.$list_name.$call(node);
};
}
// Removes the node from `$list_name`. Callers must have checked that the
// node is currently linked into that collection.
macro_rules! remove_interest {
($list_name:ident) => {
let mut cursor = unsafe {
// SAFETY: requires the node to be an element of exactly this
// collection; the callers check the matching link first.
self.$list_name
.cursor_mut_from_ptr(node.deref() as *const ConnectionNode<C, L>)
};
cursor.remove();
};
}
// Inserts or removes the node so that its membership in `$list_name`
// matches the boolean `$interest`.
macro_rules! sync_interests_list {
($interest:expr, $link_name:ident, $list_name:ident) => {
if $interest != node.$link_name.is_linked() {
if $interest {
if matches!(result, ConnectionContainerIterationResult::Continue) {
insert_interest!($list_name, push_back);
} else {
// The iteration is being interrupted: the node goes to the front
// and the unprocessed remainder is spliced in behind it by
// `iterate_interruptible!`.
insert_interest!($list_name, push_front);
}
} else {
remove_interest!($list_name);
}
}
debug_assert_eq!($interest, node.$link_name.is_linked());
};
}
sync_interests_list!(
interests.transmission,
waiting_for_transmission_link,
waiting_for_transmission
);
sync_interests_list!(
interests.new_connection_id,
waiting_for_connection_id_link,
waiting_for_connection_id
);
// The timeout is the tree key, so a changed value requires unlinking the
// node, updating the stored timeout and re-inserting it.
if node.timeout.get() != interests.timeout {
if node.waiting_for_timeout_link.is_linked() {
remove_interest!(waiting_for_timeout);
}
node.timeout.set(interests.timeout);
if interests.timeout.is_some() {
insert_interest!(waiting_for_timeout, insert);
}
} else {
// Unchanged timeout: the stored value must agree with tree membership.
debug_assert_eq!(
interests.timeout.is_some(),
node.waiting_for_timeout_link.is_linked()
);
}
if interests.accept {
// Mark the connection as accepted under the write lock; a lock error is
// propagated to the caller.
node.inner.write(|conn| {
debug_assert!(!conn.is_handshaking());
conn.mark_as_accepted();
})?;
// The connection is handed to the application, so it no longer counts as
// handshaking.
self.handshake_connections -= 1;
let handle = unsafe {
// SAFETY: relies on every node being stored in an `Arc`, as done by
// `ConnectionContainer::insert_connection`.
node.arc_from_ref()
};
let handle = crate::connection::api::Connection::new(handle);
match <C::Config as endpoint::Config>::ENDPOINT_TYPE {
endpoint::Type::Server => {
// If the handle cannot be queued, take it back from the error and
// close the connection.
if let Err(error) = accept_queue.unbounded_send(handle) {
error.into_inner().api.close_connection(None);
}
}
endpoint::Type::Client => {
if let Some(sender) = self.waiting_for_open.remove(&id) {
// If the handle comes back because it could not be delivered, close
// the connection.
if let Err(Ok(handle)) = sender.send(Ok(handle)) {
handle.api.close_connection(None);
}
} else {
debug_assert!(false, "client connection tried to open more than once");
}
}
}
}
if interests.finalization != node.done_connections_link.is_linked() {
if interests.finalization {
if <C::Config as endpoint::Config>::ENDPOINT_TYPE.is_client() {
// A sender that is still registered means the open request was never
// answered; report an error to it.
if let Some(sender) = self.waiting_for_open.remove(&id) {
let err = node.inner.read(|conn| conn.error());
let err = match err {
Ok(Some(err)) => {
err
}
Ok(None) => {
transport::Error::NO_ERROR.into()
}
Err(_err) => {
transport::Error::INTERNAL_ERROR
.with_reason("failed to acquire connection lock")
.into()
}
};
// The result of the send is ignored.
let _ = sender.send(Err(err));
}
}
insert_interest!(done_connections, push_back);
} else {
// The node is in the done list but no longer reports finalization.
unreachable!("Done connections should never report not done later");
}
}
node.ensure_consistency();
Ok(())
}
/// Unlinks `connection` from the transmission list, the connection-id list
/// and the timeout tree (whichever it is part of) and decrements
/// `connection_count`.
///
/// The done list and the `waiting_for_open` map are not touched here.
fn remove_node(&mut self, connection: &ConnectionNode<C, L>) {
let connection_ptr = connection as *const ConnectionNode<C, L>;
// Removes the node from `$list_name` if its `$link_name` is linked.
macro_rules! remove_connection_from_list {
($list_name:ident, $link_name:ident) => {
if connection.$link_name.is_linked() {
let mut cursor = unsafe {
// SAFETY: requires the node to be an element of this collection;
// the link was just checked to be linked.
self.$list_name.cursor_mut_from_ptr(connection_ptr)
};
let remove_result = cursor.remove();
debug_assert!(remove_result.is_some());
}
};
}
remove_connection_from_list!(waiting_for_transmission, waiting_for_transmission_link);
remove_connection_from_list!(waiting_for_connection_id, waiting_for_connection_id_link);
remove_connection_from_list!(waiting_for_timeout, waiting_for_timeout_link);
self.connection_count -= 1;
}
}
/// Stores all connections of an endpoint and tracks what each of them is
/// waiting for.
pub struct ConnectionContainer<C: connection::Trait, L: connection::Lock<C>> {
/// All connections, keyed by their internal connection id.
connection_map: RBTree<ConnectionTreeAdapter<C, L>>,
/// Per-interest collections and counters for the connections.
interest_lists: InterestLists<C, L>,
/// Server side: channel on which accepted connection handles are sent.
accept_queue: AcceptorSender,
/// Client side: channel from which connection requests are received.
connector_receiver: ConnectorReceiver,
}
/// Iterates over one interest list of a `ConnectionContainer`.
///
/// * `$sel` - the container (`self`).
/// * `$list_name` / `$link_name` - the list field of `interest_lists` and the
///   matching link field of the node.
/// * `$func` - called with `&mut C` for each connection; returns a
///   `ConnectionContainerIterationResult`.
///
/// The list is taken out of the container, each node is removed from it,
/// processed, and re-registered through `update_interests`. On
/// `BreakAndInsertAtBack` the unprocessed remainder is spliced back into the
/// container's list and iteration stops. Done connections are finalized at
/// the end.
macro_rules! iterate_interruptible {
($sel:ident, $list_name:ident, $link_name:ident, $func:expr) => {
// Work on a detached copy of the list so that `update_interests` can
// freely push nodes into the (now empty) list stored in the container.
let mut extracted_list = $sel.interest_lists.$list_name.take();
let mut cursor = extracted_list.front_mut();
while let Some(connection) = cursor.remove() {
// The node was just removed, so its link for this list is free again.
debug_assert!(!connection.$link_name.is_linked());
let (result, interests) = match connection.inner.write(|conn| {
let result = $func(conn);
let interests = conn.interests();
(result, interests)
}) {
Ok(result) => result,
Err(_) => {
// The connection could not be locked: drop it from the container.
$sel.remove_poisoned_node(&connection);
continue;
}
};
// Re-register the node according to the interests it reported.
if $sel
.interest_lists
.update_interests(&mut $sel.accept_queue, &connection, interests, result)
.is_err()
{
$sel.remove_poisoned_node(&connection);
}
match result {
ConnectionContainerIterationResult::BreakAndInsertAtBack => {
// Put the unprocessed remainder back behind the front element of the
// container's list and stop.
$sel.interest_lists
.$list_name
.front_mut()
.splice_after(extracted_list);
break;
}
ConnectionContainerIterationResult::Continue => {}
}
}
$sel.finalize_done_connections();
};
}
impl<C: connection::Trait, L: connection::Lock<C>> ConnectionContainer<C, L> {
pub(crate) fn new(accept_queue: AcceptorSender, connector_receiver: ConnectorReceiver) -> Self {
Self {
connection_map: RBTree::new(ConnectionTreeAdapter::new()),
interest_lists: InterestLists::new(),
accept_queue,
connector_receiver,
}
}
/// Returns `true` while the accept queue has not been closed.
///
/// Only meaningful for server endpoints (checked with a debug assertion).
pub fn can_accept(&self) -> bool {
    debug_assert!(<C::Config as endpoint::Config>::ENDPOINT_TYPE.is_server());

    let acceptor_closed = self.accept_queue.is_closed();
    !acceptor_closed
}
/// Returns `true` while the connection-request stream has not terminated.
///
/// Only meaningful for client endpoints (checked with a debug assertion).
fn can_connect(&self) -> bool {
    use futures_core::FusedStream;

    debug_assert!(<C::Config as endpoint::Config>::ENDPOINT_TYPE.is_client());

    let requests_terminated = self.connector_receiver.is_terminated();
    !requests_terminated
}
/// Returns `true` if the container holds no connections.
pub fn is_empty(&self) -> bool {
    let connections = &self.connection_map;
    connections.is_empty()
}
pub fn close(&mut self) {
debug_assert!(
self.is_empty(),
"close should only be called once all accepted connections have finished"
);
self.accept_queue.close_channel();
self.connector_receiver.close();
while let Ok(Some(request)) = self.connector_receiver.try_next() {
if request
.sender
.send(Err(connection::Error::endpoint_closing()))
.is_err()
{
}
}
}
/// Returns `true` if the container still holds connections, or if it can
/// still obtain new ones (`can_accept` for servers, `can_connect` for
/// clients).
pub fn is_open(&self) -> bool {
    if !self.connection_map.is_empty() {
        return true;
    }

    match <C::Config as endpoint::Config>::ENDPOINT_TYPE {
        endpoint::Type::Server => self.can_accept(),
        endpoint::Type::Client => self.can_connect(),
    }
}
/// Returns the timeout of the first node in the timeout tree, or `None` if
/// the tree is empty.
pub fn next_expiration(&self) -> Option<Timestamp> {
    let cursor = self.interest_lists.waiting_for_timeout.front();

    cursor.get().and_then(|earliest| {
        let timeout = earliest.timeout.get();
        debug_assert!(
            timeout.is_some(),
            "a connection should only be in the timeout list when the timeout field is set"
        );
        timeout
    })
}
/// Adds a connection to a server endpoint's container.
///
/// Only valid on server endpoints (checked with a debug assertion).
pub fn insert_server_connection(
    &mut self,
    connection: C,
    internal_connection_id: InternalConnectionId,
) {
    let endpoint_type = <C::Config as endpoint::Config>::ENDPOINT_TYPE;
    debug_assert!(endpoint_type.is_server());

    self.insert_connection(connection, internal_connection_id)
}
/// Adds a connection to a client endpoint's container.
///
/// `connection_sender` is stored under the connection's id before the
/// connection is inserted, so the outcome of the attempt can be reported
/// through it later.
///
/// Only valid on client endpoints (checked with a debug assertion).
#[allow(dead_code)]
pub fn insert_client_connection(
    &mut self,
    connection: C,
    internal_connection_id: InternalConnectionId,
    connection_sender: ConnectionSender,
) {
    debug_assert!(<C::Config as endpoint::Config>::ENDPOINT_TYPE.is_client());

    let pending_opens = &mut self.interest_lists.waiting_for_open;
    pending_opens.insert(internal_connection_id, connection_sender);

    self.insert_connection(connection, internal_connection_id)
}
/// Polls the connector receiver for the next connection request.
///
/// Only valid on client endpoints (checked with a debug assertion).
pub(crate) fn poll_connection_request(
    &mut self,
    cx: &mut Context,
) -> Poll<Option<connect::Request>> {
    debug_assert!(
        <C::Config as endpoint::Config>::ENDPOINT_TYPE.is_client(),
        "only clients can open connections"
    );

    let receiver = Pin::new(&mut self.connector_receiver);
    futures_core::Stream::poll_next(receiver, cx)
}
/// Wraps `connection` in a lock and an `Arc`-backed node, registers its
/// current interests and stores it in the connection map.
///
/// If registering the interests fails, the connection is not added to the
/// map and the counters stay unchanged.
fn insert_connection(&mut self, connection: C, internal_connection_id: InternalConnectionId) {
    // Read the interests before the connection moves behind the lock
    let interests = connection.interests();
    let node = Arc::new(ConnectionNode::new(
        L::new(connection),
        internal_connection_id,
    ));

    if self
        .interest_lists
        .update_interests(
            &mut self.accept_queue,
            &node,
            interests,
            ConnectionContainerIterationResult::Continue,
        )
        .is_err()
    {
        return;
    }

    self.connection_map.insert(node);
    // A newly inserted connection is counted as handshaking
    self.interest_lists.handshake_connections += 1;
    self.interest_lists.connection_count += 1;
    self.ensure_counter_consistency();
}
/// Returns the number of connections currently counted as handshaking.
pub fn handshake_connections(&self) -> usize {
    let lists = &self.interest_lists;
    lists.handshake_connections
}
/// Returns the number of connections tracked by the container.
pub fn len(&self) -> usize {
    let lists = &self.interest_lists;
    lists.connection_count
}
/// Looks up the connection with `connection_id`, runs `func` on it with write
/// access and updates the interest collections afterwards.
///
/// Returns the value produced by `func` together with the interests the
/// connection reported right after the call. Returns `None` if no connection
/// with that id exists, or if the connection could not be locked (in which
/// case it is removed from the container).
pub fn with_connection<F, R>(
&mut self,
connection_id: InternalConnectionId,
func: F,
) -> Option<(R, ConnectionInterests)>
where
F: FnOnce(&mut C) -> R,
{
let cursor = self.connection_map.find(&connection_id);
let node = cursor.get()?;
let (result, interests) = match node.inner.write(|conn| {
let result = func(conn);
let interests = conn.interests();
(result, interests)
}) {
Ok(result) => result,
Err(_) => {
// Copy the id first so `node` is no longer borrowed when the container
// is mutated.
let id = node.internal_connection_id;
self.remove_node_by_id(id);
// The removed connection's handshake state is unknown, so recount.
self.interest_lists.handshake_connections = self.count_handshaking_connections();
return None;
}
};
if self
.interest_lists
.update_interests(
&mut self.accept_queue,
node,
interests,
ConnectionContainerIterationResult::Continue,
)
.is_err()
{
// Same recovery as above; the result of `func` is still returned below.
let id = node.internal_connection_id;
self.remove_node_by_id(id);
self.interest_lists.handshake_connections = self.count_handshaking_connections();
}
self.ensure_counter_consistency();
// The call may have finished the connection; remove it if so.
self.finalize_done_connections();
Some((result, interests))
}
/// Removes every connection in the done list from the container.
///
/// A connection that is still handshaking when it is removed also decrements
/// the handshake counter. If the connection cannot be locked to find that
/// out, the counter is recomputed from the remaining connections.
pub fn finalize_done_connections(&mut self) {
    for connection in self.interest_lists.done_connections.take() {
        self.remove_node(&connection);

        let handshaking = connection.inner.read(|conn| conn.is_handshaking());
        if let Ok(still_handshaking) = handshaking {
            if still_handshaking {
                self.interest_lists.handshake_connections -= 1;
                self.ensure_counter_consistency();
            }
        } else {
            // Handshake state is unknown: recount instead of guessing
            self.interest_lists.handshake_connections = self.count_handshaking_connections();
        }
    }
}
/// Counts the connections in the map that report `is_handshaking()`.
///
/// Connections that cannot be locked for reading are not counted.
fn count_handshaking_connections(&self) -> usize {
    let mut handshaking = 0;
    for node in self.connection_map.iter() {
        let state = node.inner.read(|conn| conn.is_handshaking());
        if matches!(state, Ok(true)) {
            handshaking += 1;
        }
    }
    handshaking
}
/// Verifies both counters against the connection map.
///
/// Does nothing unless debug assertions are enabled.
///
/// # Panics
///
/// Panics (debug builds only) if the handshake counter or the connection
/// counter disagrees with the contents of the connection map.
fn ensure_counter_consistency(&self) {
    if !cfg!(debug_assertions) {
        return;
    }

    let expected = self.count_handshaking_connections();
    assert_eq!(expected, self.interest_lists.handshake_connections);
    assert_eq!(self.len(), self.connection_map.iter().count());
}
/// Calls `func` for every connection in the transmission list.
///
/// Each connection is re-registered according to the interests it reports
/// after the call. Returning `BreakAndInsertAtBack` from `func` stops the
/// iteration; see `iterate_interruptible!` for the details.
pub fn iterate_transmission_list<F>(&mut self, mut func: F)
where
F: FnMut(&mut C) -> ConnectionContainerIterationResult,
{
iterate_interruptible!(
self,
waiting_for_transmission,
waiting_for_transmission_link,
func
);
}
/// Calls `func` for every connection in the connection-id list.
///
/// Each connection is re-registered according to the interests it reports
/// after the call. Returning `BreakAndInsertAtBack` from `func` stops the
/// iteration; see `iterate_interruptible!` for the details.
pub fn iterate_new_connection_id_list<F>(&mut self, mut func: F)
where
F: FnMut(&mut C) -> ConnectionContainerIterationResult,
{
iterate_interruptible!(
self,
waiting_for_connection_id,
waiting_for_connection_id_link,
func
);
}
/// Calls `func` for every connection whose timeout has elapsed at `now`.
///
/// Connections are taken from the front of the timeout tree (earliest
/// timeout first) until the tree is empty or the first timeout has not yet
/// elapsed. `func` receives the connection and a `supervisor::Context` built
/// from the container's counters and the connection's remote address.
/// Afterwards the connection is re-registered with the interests it reports,
/// and done connections are finalized once the loop ends.
///
/// # Panics
///
/// Panics if a connection's `remote_address()` returns an error.
pub fn iterate_timeout_list<F>(&mut self, now: Timestamp, mut func: F)
where
F: FnMut(&mut C, &supervisor::Context),
{
loop {
// Always look at the front: processed nodes are removed, so the front is
// the next candidate on every pass.
let mut cursor = self.interest_lists.waiting_for_timeout.front_mut();
let connection = if let Some(connection) = cursor.get() {
connection
} else {
break;
};
match connection.timeout.get() {
// The earliest timeout has not elapsed, so none of the later ones has.
Some(v) if !v.has_elapsed(now) => break,
Some(_) => {}
None => {
// Invariant violation: a node in the tree without a timeout. Debug
// builds fail here; otherwise the node is unlinked and skipped.
debug_assert!(false, "connection was inserted without a timeout specified");
let conn = cursor.remove().unwrap();
conn.timeout.set(None);
continue;
}
}
// NOTE(review): the message below mentions a `while` condition, but the
// enclosing construct is a `loop`; non-emptiness was established by the
// `cursor.get()` check above.
let connection = cursor
.remove()
.expect("list capacity was already checked in the `while` condition");
debug_assert!(!connection.waiting_for_timeout_link.is_linked());
// Clear the stored timeout so it matches the node being out of the tree.
connection.timeout.set(None);
let mut interests = match connection.inner.write(|conn| {
let remote_address = conn
.remote_address()
.expect("Remote address should be available");
let context = supervisor::Context::new(
self.handshake_connections(),
self.len(),
&remote_address,
conn.is_handshaking(),
);
func(conn, &context);
conn.interests()
}) {
Ok(result) => result,
Err(_) => {
// The connection could not be locked: drop it from the container.
self.remove_poisoned_node(&connection);
continue;
}
};
if let Some(timeout) = interests.timeout.as_mut() {
// A new timeout that has already elapsed would be picked up again by
// this very loop; move it to `now + K_GRANULARITY` instead.
if timeout.has_elapsed(now) {
*timeout = now + K_GRANULARITY;
debug_assert!(!timeout.has_elapsed(now));
}
}
if self
.interest_lists
.update_interests(
&mut self.accept_queue,
&connection,
interests,
ConnectionContainerIterationResult::Continue,
)
.is_err()
{
self.remove_poisoned_node(&connection);
}
}
self.finalize_done_connections();
}
/// Removes the connection with `connection_id` from the connection map and,
/// if it was found, from the interest collections.
///
/// A missing connection trips a debug assertion and is otherwise ignored.
fn remove_node_by_id(&mut self, connection_id: InternalConnectionId) {
    let removed = self.connection_map.find_mut(&connection_id).remove();
    debug_assert!(removed.is_some());

    if let Some(node) = removed {
        self.interest_lists.remove_node(&node);
    }
}
/// Removes a connection that could not be locked and recomputes the
/// handshake counter.
///
/// The recount has to happen after the removal so that it only covers the
/// connections left in the map.
fn remove_poisoned_node(&mut self, connection: &ConnectionNode<C, L>) {
self.remove_node(connection);
self.interest_lists.handshake_connections = self.count_handshaking_connections();
}
/// Removes `connection` from the connection map and from the interest
/// collections.
///
/// A connection that is not in the map trips a debug assertion; the interest
/// collections are updated either way.
fn remove_node(&mut self, connection: &ConnectionNode<C, L>) {
    let id = &connection.internal_connection_id;
    // Keep the removed handle alive until the interest lists are updated
    let removed = self.connection_map.find_mut(id).remove();
    debug_assert!(removed.is_some());

    self.interest_lists.remove_node(connection);
}
}
/// Return value of the callbacks passed to the interruptible list iterations
/// (`iterate_transmission_list`, `iterate_new_connection_id_list`).
#[derive(Clone, Copy, Debug)]
pub enum ConnectionContainerIterationResult {
/// Keep iterating. A connection that is still interested is pushed to the
/// back of the list.
Continue,
/// Stop iterating. A connection that is still interested is pushed to the
/// front of the list and the unprocessed connections are spliced in after
/// the list's front element.
BreakAndInsertAtBack,
}
#[cfg(test)]
mod tests;