1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
use std::{
collections::HashSet,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{Future, StreamExt};
use msg_common::span::{EnterSpan as _, SpanExt as _, WithSpan};
use rustc_hash::FxHashMap;
use tokio::{
sync::mpsc::{self, error::TrySendError},
task::JoinSet,
};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, warn};
use super::{
Command, PubMessage, SocketState, SubOptions,
session::{PublisherSession, SessionCommand},
stream::{PublisherStream, TopicMessage},
};
use crate::{ConnectionHookErased, ConnectionState, ExponentialBackoff, hooks};
use msg_common::{Channel, JoinMap, channel};
use msg_transport::{Address, Transport};
use msg_wire::{compression::try_decompress_payload, pubsub};
/// Publisher channel type: the driver's end of a bidirectional channel used to
/// send [`SessionCommand`]s to a publisher session and to receive
/// [`TopicMessage`]s that are forwarded to the socket frontend.
type PubChannel = Channel<SessionCommand, TopicMessage>;
/// Driver for a subscriber socket.
///
/// Owns the transport, all publisher sessions, and the channels connecting the
/// driver task to the socket frontend. Implements [`Future`] so the whole
/// driver runs as a single spawned task that multiplexes all socket I/O.
#[allow(clippy::type_complexity)]
pub(crate) struct SubDriver<T: Transport<A>, A: Address> {
    /// Options shared with the socket.
    pub(super) options: Arc<SubOptions>,
    /// The transport for this socket.
    pub(super) transport: T,
    /// Commands from the socket.
    pub(super) from_socket: mpsc::Receiver<Command<A>>,
    /// Messages to the socket.
    pub(super) to_socket: mpsc::Sender<PubMessage<A>>,
    /// A joinmap of transport connection tasks, keyed by the address being dialed.
    pub(super) conn_tasks: JoinMap<A, Result<T::Io, T::Error>>,
    /// A joinset of connection hook tasks (run after the transport connects,
    /// before a session is spawned).
    pub(super) hook_tasks: JoinSet<WithSpan<hooks::ErasedHookResult<(T::Io, A)>>>,
    /// The set of subscribed topics.
    pub(super) subscribed_topics: HashSet<String>,
    /// All publisher sessions for this subscriber socket, keyed by address.
    pub(super) publishers: FxHashMap<A, ConnectionState<PubChannel, ExponentialBackoff, A>>,
    /// Socket state. This is shared with the backend task. Contains the unified stats struct.
    pub(super) state: Arc<SocketState<A>>,
    /// Optional connection hook.
    pub(super) hook: Option<Arc<dyn ConnectionHookErased<T::Io>>>,
}
impl<T, A> Future for SubDriver<T, A>
where
    T: Transport<A>,
    A: Address,
{
    type Output = ();

    /// This poll implementation prioritizes incoming messages over commands.
    ///
    /// Each ready source restarts the loop, so all sources are drained before
    /// the task yields; `Poll::Pending` is returned only once every source is
    /// pending (each source registers its own waker while being polled).
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            // First, poll all the publishers to handle incoming messages.
            if this.poll_publishers(cx).is_ready() {
                continue;
            }

            // Then, poll the socket for new commands.
            if let Poll::Ready(Some(cmd)) = this.from_socket.poll_recv(cx) {
                // Stats live in the shared socket state under `stats.specific`.
                this.state.stats.specific.increment_commands_received();
                // Process the command
                this.on_command(cmd);

                continue;
            }
            // NOTE(review): `Poll::Ready(None)` (frontend channel dropped) falls
            // through here without terminating the driver — confirm shutdown is
            // expected to arrive via `Command::Shutdown` instead.

            // Poll connection hook tasks. A hook that passes yields the
            // (possibly wrapped) IO stream plus the peer address.
            if let Poll::Ready(Some(Ok(hook_result))) = this.hook_tasks.poll_join_next(cx).enter() {
                match hook_result.inner {
                    Ok((io, addr)) => {
                        info!("connection hook passed");
                        this.on_connection(addr, io);
                    }
                    Err(e) => {
                        error!(err = ?e, "connection hook failed");
                    }
                }

                continue;
            }
            // NOTE(review): a `JoinError` (panicked/cancelled hook task) is
            // silently discarded by the `Some(Ok(..))` pattern — confirm intended.

            // Poll the transport connection tasks; successes continue into the
            // connection-hook stage via `on_transport_connected`.
            if let Poll::Ready(Some(Ok((addr, result)))) = this.conn_tasks.poll_join_next(cx) {
                match result {
                    Ok(io) => {
                        this.on_transport_connected(addr, io);
                    }
                    Err(e) => {
                        error!(err = ?e, ?addr, "transport connection failed");
                    }
                }

                continue;
            }

            return Poll::Pending;
        }
    }
}
impl<T, A> SubDriver<T, A>
where
T: Transport<A>,
A: Address,
{
/// De-activates a publisher by setting it to [`ConnectionState::Inactive`],
/// which initializes a fresh backoff stream so reconnects are retried with
/// exponential delay.
fn reset_publisher(&mut self, addr: A) {
    debug!("Resetting publisher at {addr:?}");

    let backoff =
        ExponentialBackoff::new(self.options.initial_backoff, self.options.retry_attempts);
    let key = addr.clone();

    self.publishers.insert(key, ConnectionState::Inactive { addr, backoff });
}
/// Returns true if we're already connected to the given publisher address,
/// i.e. the address is known AND its session is in the active state.
fn is_connected(&self, addr: &A) -> bool {
    // Idiom fix: return the boolean expression directly instead of the
    // `if cond { return true; } false` form.
    self.publishers.get(addr).is_some_and(|s| s.is_active())
}

/// Returns true if the publisher address is known at all (active or inactive).
fn is_known(&self, addr: &A) -> bool {
    self.publishers.contains_key(addr)
}
/// Subscribes to a topic on all publishers.
fn subscribe(&mut self, topic: String) {
let mut inactive = Vec::new();
if self.subscribed_topics.insert(topic.clone()) {
// Subscribe to the topic on all publishers
for (addr, publisher_state) in self.publishers.iter_mut() {
if let ConnectionState::Active { channel } = publisher_state {
// If the channel is closed on the other side, deactivate the publisher
if let Err(TrySendError::Closed(_)) =
channel.try_send(SessionCommand::Subscribe(topic.clone()))
{
warn!(publisher = ?addr, "Error trying to subscribe to topic {topic}: publisher channel closed");
inactive.push(addr.clone());
}
}
}
// Remove all inactive publishers
for addr in inactive {
// Move publisher to inactive state
self.reset_publisher(addr);
}
info!(
topic = topic.as_str(),
n_publishers = self.publishers.len(),
"Subscribed to topic"
);
} else {
debug!(topic = topic.as_str(), "Already subscribed to topic");
}
}
/// Unsubscribes from a topic on all publishers.
fn unsubscribe(&mut self, topic: String) {
let mut inactive = Vec::new();
if self.subscribed_topics.remove(&topic) {
// Unsubscribe from the topic on all publishers
for (addr, publisher_state) in self.publishers.iter_mut() {
if let ConnectionState::Active { channel } = publisher_state {
// If the channel is closed on the other side, deactivate the publisher
if let Err(TrySendError::Closed(_)) =
channel.try_send(SessionCommand::Unsubscribe(topic.clone()))
{
warn!(publisher = ?addr, "Error trying to unsubscribe from topic {topic}: publisher channel closed");
inactive.push(addr.clone());
}
}
}
// Remove all inactive publishers
for addr in inactive {
// Move publisher to inactive state
self.reset_publisher(addr);
}
info!(
topic = topic.as_str(),
n_publishers = self.publishers.len(),
"Unsubscribed from topic"
);
} else {
debug!(topic = topic.as_str(), "Not subscribed to topic");
}
}
/// Dispatches a single command received from the socket frontend.
fn on_command(&mut self, cmd: Command<A>) {
    debug!("Received command: {:?}", cmd);

    match cmd {
        Command::Subscribe { topic } => self.subscribe(topic),
        Command::Unsubscribe { topic } => self.unsubscribe(topic),
        Command::Connect { endpoint } => {
            if self.is_known(&endpoint) {
                debug!(?endpoint, "Publisher already known, ignoring connect command");
                return;
            }

            self.connect(endpoint.clone());

            // Register the endpoint as inactive as well, so that a failed first
            // connection attempt is retried by the backoff machinery in
            // `poll_publishers`.
            self.reset_publisher(endpoint);
        }
        Command::Disconnect { endpoint } => {
            if self.publishers.remove(&endpoint).is_some() {
                debug!(?endpoint, "Disconnected from publisher");
                self.state.stats.specific.remove_session(&endpoint);
            } else {
                debug!(?endpoint, "Not connected to publisher");
            }
        }
        Command::Shutdown => {
            // TODO: graceful shutdown?
            debug!("shutting down");
        }
    }
}
/// Spawns a transport connection task for `addr`, keyed by that address in
/// `conn_tasks`. The task resolves to the address paired with the dial result.
fn connect(&mut self, addr: A) {
    let pending = self.transport.connect(addr.clone());
    let task_addr = addr.clone();

    self.conn_tasks.spawn(addr, async move { (task_addr, pending.await) });
}
/// Called when a transport connection is established (before hook).
///
/// Spawns a task with the configured hook, or a task that resolves immediately
/// with the raw stream when no hook is configured. Either way the result shows
/// up in `hook_tasks` and is handled by the driver's `poll`.
fn on_transport_connected(&mut self, addr: A, io: T::Io) {
    match self.hook {
        Some(ref hook) => {
            let hook = Arc::clone(hook);
            let span = tracing::info_span!("connection_hook", ?addr);

            self.hook_tasks.spawn(
                async move {
                    let stream = hook.on_connection(io).await?;
                    Ok((stream, addr))
                }
                .with_span(span),
            );
        }
        None => {
            let span = tracing::trace_span!("connection_hook", ?addr, noop_hook = true);
            self.hook_tasks.spawn(async move { Ok((io, addr)) }.with_span(span));
        }
    }
}
/// Called once a connection has passed the (optional) hook stage.
///
/// Wraps the IO stream in the pubsub codec, spawns a [`PublisherSession`] task
/// driving it, replays every currently-subscribed topic to the new session, and
/// registers the session as active in `publishers` and in the shared stats.
fn on_connection(&mut self, addr: A, io: T::Io) {
    if self.is_connected(&addr) {
        // We're already connected to this publisher
        warn!(?addr, "Already connected to publisher");
        return;
    }

    debug!("Connection to {:?} established, spawning session", addr);

    let framed = Framed::with_capacity(io, pubsub::Codec::new(), self.options.read_buffer_size);

    // NOTE(review): the 1024 / 64 channel capacities are magic numbers —
    // consider lifting them into `SubOptions`.
    let (driver_channel, mut publisher_channel) = channel(1024, 64);

    let publisher_session =
        PublisherSession::new(addr.clone(), PublisherStream::from(framed), driver_channel);

    // Get the shared session stats.
    let session_stats = publisher_session.stats();

    // Spawn the publisher session
    tokio::spawn(publisher_session);

    // Replay existing subscriptions to the fresh session. `try_send` fails if
    // the command side is closed or its buffer is full; failures are only
    // logged, so with very many topics some subscriptions could be missed.
    for topic in self.subscribed_topics.iter() {
        if publisher_channel.try_send(SessionCommand::Subscribe(topic.clone())).is_err() {
            error!(publisher = ?addr, "Error trying to subscribe to topic {topic} on startup: publisher channel closed / full");
        }
    }

    self.publishers
        .insert(addr.clone(), ConnectionState::Active { channel: publisher_channel });
    self.state.stats.specific.insert_session(addr, session_stats);
}
/// Polls all the publisher channels for new messages. On new messages, forwards them to the
/// socket. If a publisher channel is closed, it will be removed from the list of
/// publishers.
///
/// Returns `Poll::Ready` if any progress was made and this method should be called again.
/// Returns `Poll::Pending` if no progress was made.
fn poll_publishers(&mut self, cx: &mut Context<'_>) -> Poll<()> {
let mut progress = false;
// These should be fine as Vec::new() does not allocate
let mut inactive = Vec::new();
let mut to_retry = Vec::new();
let mut to_terminate = Vec::new();
for (addr, state) in self.publishers.iter_mut() {
match state {
ConnectionState::Active { channel } => {
match channel.poll_recv(cx) {
Poll::Ready(Some(mut msg)) => {
match try_decompress_payload(msg.compression_type, msg.payload) {
Ok(decompressed) => msg.payload = decompressed,
Err(e) => {
error!(err = ?e, "Failed to decompress message");
continue;
}
};
let msg_to_send = PubMessage::new(addr.clone(), msg.topic, msg.payload);
debug!(source = ?msg_to_send.source, ?msg_to_send, "New message");
match self.to_socket.try_send(msg_to_send) {
Ok(_) => {
// Successfully sent to socket frontend
self.state.stats.specific.increment_messages_received();
}
Err(TrySendError::Full(msg_back)) => {
// Failed due to full buffer
self.state.stats.specific.increment_dropped_messages();
error!(
topic = msg_back.topic,
"Slow subscriber socket, dropping message"
);
}
Err(TrySendError::Closed(_)) => {
error!(
"SubSocket frontend channel closed unexpectedly while driver is active."
);
// Consider shutting down or marking as inactive?
// For now, just log. No counter increment.
}
}
progress = true;
}
Poll::Ready(None) => {
error!(source = ?addr, "Publisher stream closed, removing channel");
inactive.push(addr.clone());
progress = true;
}
Poll::Pending => {}
}
}
ConnectionState::Inactive { addr, backoff } => {
// Poll the backoff stream
if let Poll::Ready(item) = backoff.poll_next_unpin(cx) {
if let Some(duration) = item {
progress = true;
// Only retry if there are no active connection tasks
if !self.conn_tasks.contains_key(addr) {
debug!(backoff = ?duration, "Retrying connection to {:?}", addr);
to_retry.push(addr.clone());
} else {
debug!(backoff = ?duration, "Not retrying connection to {:?} as there is already a connection task", addr);
}
} else {
error!(
"Exceeded maximum number of retries for {:?}, terminating connection",
addr
);
to_terminate.push(addr.clone());
}
}
}
}
}
// Activate retries
for addr in to_retry {
self.connect(addr);
}
// Queue retries for all the inactive publishers.
for addr in inactive {
self.reset_publisher(addr);
}
// Terminate publishers that are unreachable.
for addr in to_terminate {
self.publishers.remove(&addr);
}
if progress { Poll::Ready(()) } else { Poll::Pending }
}
}