1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{Future, StreamExt, stream::FuturesUnordered};
use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan};
use tokio::{sync::broadcast, task::JoinSet};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use super::{
PubError, PubMessage, PubOptions, SocketState, session::SubscriberSession, trie::PrefixTrie,
};
use crate::{ConnectionHookErased, hooks};
use msg_transport::{Address, PeerAddress, Transport};
use msg_wire::pubsub;
/// The driver for the publisher socket. This is responsible for accepting incoming connections,
/// running connection hooks, and spawning new [`SubscriberSession`]s for each connection.
///
/// The driver is itself a [`Future`] that never completes on its own: it keeps accepting
/// connections and returns `Poll::Pending` whenever there is no more work to do.
// `hook_tasks` nests several generic wrappers (`JoinSet<WithSpan<ErasedHookResult<(Io, A)>>>`),
// which trips `clippy::type_complexity`; the type is only used internally, so the lint is
// silenced rather than introducing a one-off type alias.
#[allow(clippy::type_complexity)]
pub(crate) struct PubDriver<T: Transport<A>, A: Address> {
    /// Session ID counter. Holds the ID that the next spawned session will receive; it is
    /// advanced with wrapping arithmetic after every session spawn.
    pub(super) id_counter: u32,
    /// The server transport used to accept incoming connections.
    pub(super) transport: T,
    /// The publisher options (shared with the socket). Read for `max_clients`,
    /// `write_buffer_linger` and `write_buffer_size`.
    pub(super) options: Arc<PubOptions>,
    /// The publisher socket state, shared with the socket front-end. Its stats track the number
    /// of active clients, which the driver increments on accept and decrements on failures.
    pub(crate) state: Arc<SocketState>,
    /// Optional connection hook. When set, every accepted connection is passed through it
    /// before a session is spawned; when `None`, sessions are spawned directly.
    pub(super) hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
    /// A set of pending incoming connections, represented by [`Transport::Accept`]. Each entry
    /// has already been counted as an active client when it was pushed.
    pub(super) conn_tasks: FuturesUnordered<T::Accept>,
    /// A joinset of connection hook tasks. Each task yields the hooked stream together with the
    /// peer address, wrapped in the tracing span it ran under.
    pub(super) hook_tasks: JoinSet<WithSpan<hooks::ErasedHookResult<(T::Io, A)>>>,
    /// The receiver end of the message broadcast channel. The sender half is stored by
    /// [`PubSocket`](super::PubSocket). Every new session gets its own receiver via
    /// `resubscribe`.
    pub(super) from_socket_bcast: broadcast::Receiver<PubMessage>,
}
impl<T, A> Future for PubDriver<T, A>
where
T: Transport<A>,
A: Address,
{
type Output = Result<(), PubError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
loop {
// First, poll the joinset of hook tasks. If a new connection has been handled
// we spawn a new session for it.
if let Poll::Ready(Some(Ok(hook_result))) = this.hook_tasks.poll_join_next(cx).enter() {
match hook_result.inner {
Ok((stream, _addr)) => {
info!("connection hook passed");
let framed = Framed::new(stream, pubsub::Codec::new());
let session = SubscriberSession {
seq: 0,
session_id: this.id_counter,
from_socket_bcast: this.from_socket_bcast.resubscribe().into(),
state: Arc::clone(&this.state),
pending_egress: None,
conn: framed,
topic_filter: PrefixTrie::new(),
linger_timer: this
.options
.write_buffer_linger
.map(tokio::time::interval),
write_buffer_size: this.options.write_buffer_size,
};
tokio::spawn(session);
this.id_counter = this.id_counter.wrapping_add(1);
}
Err(e) => {
error!(err = ?e, "Error in connection hook");
this.state.stats.specific.decrement_active_clients();
}
}
continue;
}
// Then poll the incoming connection tasks. If a new connection has been accepted, spawn
// a new hook task for it (or directly spawn a session if no hook is configured).
if let Poll::Ready(Some(incoming)) = this.conn_tasks.poll_next_unpin(cx) {
match incoming {
Ok(io) => {
if let Err(e) = this.on_incoming(io) {
error!(err = ?e, "Error accepting incoming connection");
this.state.stats.specific.decrement_active_clients();
}
}
Err(e) => {
error!(err = ?e, "Error accepting incoming connection");
// Active clients have already been incremented in the initial call to
// `poll_accept`, so we need to decrement them here.
this.state.stats.specific.decrement_active_clients();
}
}
continue;
}
// Finally, poll the transport for new incoming connection futures and push them to the
// incoming connection tasks.
if let Poll::Ready(accept) = Pin::new(&mut this.transport).poll_accept(cx) {
if let Some(max) = this.options.max_clients &&
this.state.stats.specific.active_clients() >= max
{
warn!("Max connections reached ({}), rejecting incoming connection", max);
continue;
}
// Increment the active clients counter. If the hook fails,
// this counter will be decremented.
this.state.stats.specific.increment_active_clients();
this.conn_tasks.push(accept);
continue;
}
return Poll::Pending;
}
}
}
impl<T, A> PubDriver<T, A>
where
T: Transport<A>,
A: Address,
{
/// Handles an incoming connection. If this returns an error, the active connections counter
/// should be decremented.
fn on_incoming(&mut self, io: T::Io) -> Result<(), std::io::Error> {
let addr = io.peer_addr()?;
info!("New connection from {:?}", addr);
// If a connection hook is configured, run it
if let Some(ref hook) = self.hook {
let hook = Arc::clone(hook);
let span = tracing::info_span!("connection_hook", ?addr);
let fut = async move {
let stream = hook.on_connection(io).await?;
Ok((stream, addr))
};
self.hook_tasks.spawn(fut.with_span(span));
} else {
let framed = Framed::new(io, pubsub::Codec::new());
let session = SubscriberSession {
seq: 0,
session_id: self.id_counter,
from_socket_bcast: self.from_socket_bcast.resubscribe().into(),
state: Arc::clone(&self.state),
pending_egress: None,
conn: framed,
topic_filter: PrefixTrie::new(),
linger_timer: self.options.write_buffer_linger.map(tokio::time::interval),
write_buffer_size: self.options.write_buffer_size,
};
tokio::spawn(session);
self.id_counter = self.id_counter.wrapping_add(1);
debug!("New connection from {:?}, session ID {}", addr, self.id_counter);
}
Ok(())
}
}