use std::{
collections::VecDeque,
future::{ready, Future, IntoFuture, Ready},
net::SocketAddr,
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
},
};
#[cfg(feature = "tls")]
use crate::tls::TlsInfo;
use crate::{
channel::Channel,
context::{
info::{ConnInfo, DatagramInfo},
ConnectionStats,
},
Result,
};
/// Read-only view of connection / datagram metadata handed to inbound processing.
///
/// Every accessor simply forwards to the wrapped [`DatagramInfo`].
pub struct InboundContext {
    info: DatagramInfo,
}
impl InboundContext {
    /// Builds the context for a connection by converting its [`ConnInfo`]
    /// through `DatagramInfo::from_conn`.
    pub(crate) fn new(info: ConnInfo) -> Self {
        let converted = DatagramInfo::from_conn(info);
        Self::new_datagram(converted)
    }

    /// Builds the context directly from datagram metadata.
    pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
        InboundContext { info }
    }

    /// Identifier reported by the underlying `DatagramInfo`.
    pub fn id(&self) -> u64 {
        let info = &self.info;
        info.id()
    }

    /// Remote address recorded in the underlying `DatagramInfo`.
    pub fn peer_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.peer_addr()
    }

    /// Local address recorded in the underlying `DatagramInfo`.
    pub fn local_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.local_addr()
    }

    /// TLS details from the underlying `DatagramInfo`, if it has any.
    #[cfg(feature = "tls")]
    pub fn tls(&self) -> Option<&TlsInfo> {
        let info = &self.info;
        info.tls()
    }
}
/// Read-only view of connection / datagram metadata handed to business logic.
///
/// Every accessor simply forwards to the wrapped [`DatagramInfo`].
pub struct BusinessContext {
    info: DatagramInfo,
}
impl BusinessContext {
    /// Builds the context for a connection by converting its [`ConnInfo`]
    /// through `DatagramInfo::from_conn`.
    pub(crate) fn new(info: ConnInfo) -> Self {
        let converted = DatagramInfo::from_conn(info);
        Self::new_datagram(converted)
    }

    /// Builds the context directly from datagram metadata.
    pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
        BusinessContext { info }
    }

    /// Identifier reported by the underlying `DatagramInfo`.
    pub fn id(&self) -> u64 {
        let info = &self.info;
        info.id()
    }

    /// Remote address recorded in the underlying `DatagramInfo`.
    pub fn peer_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.peer_addr()
    }

    /// Local address recorded in the underlying `DatagramInfo`.
    pub fn local_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.local_addr()
    }

    /// TLS details from the underlying `DatagramInfo`, if it has any.
    #[cfg(feature = "tls")]
    pub fn tls(&self) -> Option<&TlsInfo> {
        let info = &self.info;
        info.tls()
    }
}
/// Read-only view of connection / datagram metadata handed to outbound processing.
///
/// Every accessor simply forwards to the wrapped [`DatagramInfo`].
pub struct OutboundContext {
    info: DatagramInfo,
}
impl OutboundContext {
    /// Builds the context for a connection by converting its [`ConnInfo`]
    /// through `DatagramInfo::from_conn`.
    pub(crate) fn new(info: ConnInfo) -> Self {
        let converted = DatagramInfo::from_conn(info);
        Self::new_datagram(converted)
    }

    /// Builds the context directly from datagram metadata.
    pub(crate) fn new_datagram(info: DatagramInfo) -> Self {
        OutboundContext { info }
    }

    /// Identifier reported by the underlying `DatagramInfo`.
    pub fn id(&self) -> u64 {
        let info = &self.info;
        info.id()
    }

    /// Remote address recorded in the underlying `DatagramInfo`.
    pub fn peer_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.peer_addr()
    }

    /// Local address recorded in the underlying `DatagramInfo`.
    pub fn local_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.local_addr()
    }

    /// TLS details from the underlying `DatagramInfo`, if it has any.
    #[cfg(feature = "tls")]
    pub fn tls(&self) -> Option<&TlsInfo> {
        let info = &self.info;
        info.tls()
    }
}
/// Context for a stream connection.
///
/// Gives access to the connection metadata and its [`Channel`], and lets the
/// holder queue writes and flushes. Nothing here performs I/O directly: writes
/// and flushes are pushed onto an internal outbox, and [`Context::close`] only
/// records a request that the crate reads back via `close_requested()`.
pub struct Context<W> {
    info: ConnInfo,
    channel: Channel<W>,
    /// Pending write / flush commands.
    outbox: StreamOutboxHandle<W>,
    /// Set by `close`; exposed to the crate through `close_requested()`.
    close_requested: bool,
}
impl<W: Send + 'static> Context<W> {
    /// Creates a context with an empty outbox and no pending close request.
    pub(crate) fn new(info: ConnInfo, channel: Channel<W>) -> Self {
        let outbox = StreamOutboxHandle::new();
        Self {
            close_requested: false,
            outbox,
            channel,
            info,
        }
    }

    /// Identifier reported by the connection's `ConnInfo`.
    pub fn id(&self) -> u64 {
        let info = &self.info;
        info.id()
    }

    /// Remote address recorded in the connection's `ConnInfo`.
    pub fn peer_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.peer_addr()
    }

    /// Local address recorded in the connection's `ConnInfo`.
    pub fn local_addr(&self) -> SocketAddr {
        let info = &self.info;
        info.local_addr()
    }

    /// TLS details from the connection's `ConnInfo`, if it has any.
    #[cfg(feature = "tls")]
    pub fn tls(&self) -> Option<&TlsInfo> {
        let info = &self.info;
        info.tls()
    }

    /// Returns a clone of the connection's [`Channel`].
    pub fn channel(&self) -> Channel<W> {
        let channel = &self.channel;
        channel.clone()
    }

    /// Statistics as reported by the channel, if available.
    pub fn stats(&self) -> Option<ConnectionStats> {
        let channel = &self.channel;
        channel.stats()
    }

    /// Queues `msg` for writing without requesting a flush.
    ///
    /// The message is queued before this returns; the returned handle resolves
    /// to `Ok(())` immediately and does not wait for the write to happen.
    #[inline]
    pub fn write(&mut self, msg: W) -> WriteHandle {
        let outbox = &self.outbox;
        outbox.push_write(msg);
        WriteHandle { _private: () }
    }

    /// Queues a flush request right away.
    ///
    /// Awaiting the returned handle additionally waits until a flush issued at
    /// or after that point is reported complete.
    #[inline]
    pub fn flush(&mut self) -> FlushHandle<'_, W> {
        let outbox = &self.outbox;
        outbox.push_flush()
    }

    /// Queues `msg` together with a flush request right away.
    ///
    /// Awaiting the returned handle waits for the flush to be reported complete.
    #[inline]
    pub fn write_and_flush(&mut self, msg: W) -> FlushHandle<'_, W> {
        let outbox = &self.outbox;
        outbox.push_write_and_flush(msg)
    }

    /// Records a request to close the connection.
    ///
    /// Performs no I/O and always returns `Ok(())`.
    pub async fn close(&mut self) -> Result<()> {
        self.close_requested = true;
        Ok(())
    }

    /// Returns another handle to this context's outbox.
    pub(crate) fn outbox(&self) -> StreamOutboxHandle<W> {
        let outbox = &self.outbox;
        outbox.clone()
    }

    /// Whether `close` has been called.
    pub(crate) fn close_requested(&self) -> bool {
        let requested = self.close_requested;
        requested
    }

    /// Whether the channel has more than one strong reference, i.e. someone
    /// besides this context holds a clone.
    #[inline]
    pub(crate) fn has_external_channel(&self) -> bool {
        let holders = self.channel.strong_count();
        holders > 1
    }
}
/// Handle returned by [`Context::write`].
///
/// The write is already queued by the time this handle exists, so awaiting it
/// completes immediately with `Ok(())`.
pub struct WriteHandle {
    /// Prevents construction outside this module.
    _private: (),
}
impl IntoFuture for WriteHandle {
    type Output = Result<()>;
    type IntoFuture = Ready<Result<()>>;

    #[inline]
    fn into_future(self) -> Self::IntoFuture {
        let queued: Result<()> = Ok(());
        ready(queued)
    }
}
/// Handle returned by the outbox's `push_flush` / `push_write_and_flush`
/// (reached through `Context::flush` / `Context::write_and_flush`).
///
/// Those calls already queue their flush request; this handle only does extra
/// work if it is awaited. Awaiting it queues one more
/// `Flush { completion: Some(id) }` command and resolves once the outbox's
/// flush state reports a completed id that is at least `id`.
///
/// NOTE(review): as far as this file shows, nothing resolves the future other
/// than a matching completion. If the consumer never completes `id` (for example
/// on connection teardown), the await appears to never finish. Confirm that
/// shutdown handles pending flushes.
pub struct FlushHandle<'a, W> {
    /// Outbox the flush was queued on.
    outbox: &'a StreamOutboxHandle<W>,
}
impl<'a, W> IntoFuture for FlushHandle<'a, W> {
    type Output = Result<()>;
    type IntoFuture = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
    #[inline]
    fn into_future(self) -> Self::IntoFuture {
        // Allocate a flush id and queue its completion marker when `into_future`
        // runs (i.e. at `.await`), not on first poll.
        let id = self.outbox.push_flush_completion();
        let state = &self.outbox.core.flush_state;
        Box::pin(async move {
            // Advertise interest in `id`; the completer consults the awaited id
            // when deciding whether to call `notify_waiters`.
            state.mark_awaited(id);
            loop {
                // Create and enable the `Notified` *before* reading the counter so
                // a `notify_waiters` issued between the check and the `.await`
                // below is not missed.
                let notified = state.notify.notified();
                tokio::pin!(notified);
                notified.as_mut().enable();
                if state.completed_flush_id.load(Ordering::Acquire) >= id {
                    return Ok(());
                }
                // A wake-up may come from a completion that does not yet cover
                // `id`, so loop and re-check.
                notified.await;
            }
        })
    }
}
/// One instruction queued on a stream outbox.
pub(crate) enum StreamOutboxCommand<W> {
    /// A message to write, without a flush request.
    Write(W),
    /// A flush request. `Some(id)` is queued for an awaited `FlushHandle` so the
    /// consumer can report that id as complete; `None` is a fire-and-forget flush.
    Flush { completion: Option<u64> },
    /// A message to write that also requests a flush.
    WriteAndFlush { msg: W },
}

/// FIFO of pending commands.
///
/// The oldest command lives in `head`, so a single queued command never needs
/// to be pushed onto the `VecDeque`. Invariant: `tail` is only non-empty while
/// `head` is `Some`.
struct StreamOutboxState<W> {
    head: Option<StreamOutboxCommand<W>>,
    tail: VecDeque<StreamOutboxCommand<W>>,
}
impl<W> StreamOutboxState<W> {
    /// Creates an empty queue.
    fn new() -> Self {
        let tail = VecDeque::new();
        Self { head: None, tail }
    }

    /// Appends `command` at the back of the queue.
    #[inline]
    fn push(&mut self, command: StreamOutboxCommand<W>) {
        if self.head.is_some() {
            self.tail.push_back(command);
        } else {
            self.head = Some(command);
        }
    }

    /// Moves every queued command into a batch, leaving the queue empty.
    #[inline]
    fn take_batch(&mut self) -> StreamOutboxBatch<W> {
        let head = self.head.take();
        let tail = std::mem::replace(&mut self.tail, VecDeque::new());
        StreamOutboxBatch { head, tail }
    }
}

/// Commands drained from the outbox, yielded in the order they were queued.
pub(crate) struct StreamOutboxBatch<W> {
    head: Option<StreamOutboxCommand<W>>,
    tail: VecDeque<StreamOutboxCommand<W>>,
}
impl<W> Iterator for StreamOutboxBatch<W> {
    type Item = StreamOutboxCommand<W>;

    /// Yields `head` first, then the remaining commands front to back.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(first) = self.head.take() {
            return Some(first);
        }
        self.tail.pop_front()
    }
}
/// Bookkeeping that lets `FlushHandle` futures wait for a particular flush.
///
/// Ids come from [`StreamFlushState::next_id`] and start at 1, so `0` in
/// `completed_flush_id` / `awaited_flush_id` means "none yet".
///
/// Wake-up protocol:
/// - A waiter calls `mark_awaited(id)`, then repeatedly re-arms `notify` and
///   checks `completed_flush_id`.
/// - The completer calls `complete(id)`.
///
/// Each side publishes its own counter and then reads the other side's counter.
/// That is the store-buffering pattern: with only Release/Acquire orderings,
/// both sides may read stale values. The completer would then skip
/// `notify_waiters` while the waiter misses the new completed id and sleeps
/// forever. A `SeqCst` fence between the write and the read on each side rules
/// that outcome out.
struct StreamFlushState {
    /// Last id handed out by `next_id`.
    next_flush_id: AtomicU64,
    /// Highest id passed to `complete`; never decreases.
    completed_flush_id: AtomicU64,
    /// Highest id any waiter has passed to `mark_awaited`.
    awaited_flush_id: AtomicU64,
    /// Wakes waiters after a completion.
    notify: tokio::sync::Notify,
}
impl StreamFlushState {
    /// Creates the state with no ids issued, awaited or completed.
    fn new() -> Self {
        Self {
            next_flush_id: AtomicU64::new(0),
            completed_flush_id: AtomicU64::new(0),
            awaited_flush_id: AtomicU64::new(0),
            notify: tokio::sync::Notify::new(),
        }
    }

    /// Returns a fresh, strictly increasing flush id (the first is 1).
    #[inline]
    fn next_id(&self) -> u64 {
        self.next_flush_id.fetch_add(1, Ordering::Relaxed) + 1
    }

    /// Records that a task is about to wait for flush `id`.
    ///
    /// Must be called before the waiter's first read of `completed_flush_id`.
    #[inline]
    fn mark_awaited(&self, id: u64) {
        self.awaited_flush_id.fetch_max(id, Ordering::Release);
        // Pairs with the fence in `complete`: either the completer sees this
        // awaited id and notifies, or the waiter's later read sees the
        // completion.
        std::sync::atomic::fence(Ordering::SeqCst);
    }

    /// Marks every flush up to and including `id` as done and wakes waiters
    /// that may now be satisfied.
    #[inline]
    fn complete(&self, id: u64) {
        // `fetch_max` keeps the completed id monotonic even if completions are
        // reported out of order; a plain `store` could move it backwards and
        // strand a waiter that had already been satisfied.
        let previous = self.completed_flush_id.fetch_max(id, Ordering::Release);
        if id <= previous {
            // Nothing new completed, so no waiter changed state.
            return;
        }
        // Pairs with the fence in `mark_awaited`.
        std::sync::atomic::fence(Ordering::SeqCst);
        // Any waiter with an id in (previous, id] is now satisfied. Only the
        // maximum awaited id is tracked, so wake everyone whenever some waiter
        // is past `previous`; waiters still pending re-check and sleep again.
        if self.awaited_flush_id.load(Ordering::Acquire) > previous {
            self.notify.notify_waiters();
        }
    }
}
/// State shared by every clone of a [`StreamOutboxHandle`].
struct StreamOutboxCore<W> {
    /// Queued commands in submission order.
    commands: Mutex<StreamOutboxState<W>>,
    /// Set when a flush-type command is queued; cleared by `take_commands`.
    flush_requested: AtomicBool,
    /// Flush-completion bookkeeping used by `FlushHandle`.
    flush_state: StreamFlushState,
}

/// Cheaply cloneable handle to a connection's outbound command queue.
pub(crate) struct StreamOutboxHandle<W> {
    core: Arc<StreamOutboxCore<W>>,
}

// Written by hand: `#[derive(Clone)]` would add an unnecessary `W: Clone` bound.
impl<W> Clone for StreamOutboxHandle<W> {
    fn clone(&self) -> Self {
        let core = Arc::clone(&self.core);
        Self { core }
    }
}
impl<W> StreamOutboxHandle<W> {
fn new() -> Self {
Self {
core: Arc::new(StreamOutboxCore {
commands: Mutex::new(StreamOutboxState::new()),
flush_requested: AtomicBool::new(false),
flush_state: StreamFlushState::new(),
}),
}
}
#[inline]
fn push_write(&self, msg: W) {
self.core
.commands
.lock()
.expect("stream outbox lock poisoned")
.push(StreamOutboxCommand::Write(msg));
}
#[inline]
fn push_flush(&self) -> FlushHandle<'_, W> {
self.core
.commands
.lock()
.expect("stream outbox lock poisoned")
.push(StreamOutboxCommand::Flush { completion: None });
self.core.flush_requested.store(true, Ordering::Release);
FlushHandle { outbox: self }
}
#[inline]
fn push_write_and_flush(&self, msg: W) -> FlushHandle<'_, W> {
self.core
.commands
.lock()
.expect("stream outbox lock poisoned")
.push(StreamOutboxCommand::WriteAndFlush { msg });
self.core.flush_requested.store(true, Ordering::Release);
FlushHandle { outbox: self }
}
#[inline]
fn push_flush_completion(&self) -> u64 {
let id = self.core.flush_state.next_id();
self.core
.commands
.lock()
.expect("stream outbox lock poisoned")
.push(StreamOutboxCommand::Flush {
completion: Some(id),
});
self.core.flush_requested.store(true, Ordering::Release);
id
}
#[inline]
pub(crate) fn has_flush_command(&self) -> bool {
self.core.flush_requested.load(Ordering::Acquire)
}
#[inline]
pub(crate) fn take_commands(&self) -> StreamOutboxBatch<W> {
self.core.flush_requested.store(false, Ordering::Release);
self.core
.commands
.lock()
.expect("stream outbox lock poisoned")
.take_batch()
}
#[inline]
pub(crate) fn complete_flush(&self, id: u64) {
self.core.flush_state.complete(id);
}
}