// pea2pea/protocols/writing.rs

1use std::{
2    any::Any, collections::HashMap, future::Future, io, net::SocketAddr, sync::Arc, time::Duration,
3};
4
5#[cfg(doc)]
6use bytes::Bytes;
7use futures_util::sink::SinkExt;
8use parking_lot::RwLock;
9use tokio::{
10    io::AsyncWrite,
11    sync::{mpsc, oneshot},
12    task::JoinSet,
13    time::timeout,
14};
15use tokio_util::codec::{Encoder, FramedWrite};
16use tracing::*;
17
18#[cfg(doc)]
19use crate::{Config, Node, protocols::Handshake};
20use crate::{
21    Connection, ConnectionSide, Pea2Pea,
22    connections::create_connection_span,
23    node::NodeTask,
24    protocols::{DisconnectOnDrop, Protocol, ProtocolHandler, ReturnableConnection},
25};
26
/// A shared, lock-protected map from a peer's address to the sending half of the queue through
/// which outbound messages are handed to that peer's connection.
type WritingSenders = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
28
29/// Can be used to specify and enable writing, i.e. sending outbound messages. If the [`Handshake`]
30/// protocol is enabled too, it goes into force only after the handshake has been concluded.
31pub trait Writing: Pea2Pea
32where
33    Self: Clone + Send + Sync + 'static,
34{
35    /// The depth of per-connection queues used to send outbound messages; the greater it is, the more outbound
36    /// messages the node can enqueue. Setting it to a large value is not recommended, as doing it might
37    /// obscure potential issues with your implementation (like slow serialization) or network.
38    const MESSAGE_QUEUE_DEPTH: usize = 64;
39
40    /// The initial size of a per-connection buffer for writing outbound messages. Can be set to the maximum expected size
41    /// of the outbound message in order to only allocate it once.
42    const INITIAL_BUFFER_SIZE: usize = 64 * 1024;
43
44    /// The maximum time (in milliseconds) allowed for a single message write to flush
45    /// to the underlying stream before the connection is considered dead.
46    const TIMEOUT_MS: u64 = 10_000;
47
48    /// The type of the outbound messages; unless their serialization is expensive and the message
49    /// is broadcasted (in which case it would get serialized multiple times), serialization should
50    /// be done in the implementation of [`Writing::Codec`].
51    type Message: Send;
52
53    /// The user-supplied [`Encoder`] used to write outbound messages to the target stream.
54    type Codec: Encoder<Self::Message, Error = io::Error> + Send;
55
56    /// Prepares the node to send messages.
57    fn enable_writing(&self) -> impl Future<Output = ()> {
58        async {
59            // create a JoinSet to track all in-flight setup tasks
60            let mut setup_tasks = JoinSet::new();
61
62            let (conn_sender, mut conn_receiver) =
63                mpsc::channel(self.node().config().max_connecting as usize);
64
65            // the conn_senders are used to send messages from the Node to individual connections
66            let conn_senders: WritingSenders = Default::default();
67            // procure a clone to create the WritingHandler with
68            let senders = conn_senders.clone();
69
70            // use a channel to know when the writing task is ready
71            let (tx_writing, rx_writing) = oneshot::channel();
72
73            // the task spawning tasks sending messages to all the streams
74            let self_clone = self.clone();
75            let writing_task = tokio::spawn(async move {
76                trace!(parent: self_clone.node().span(), "spawned the Writing handler task");
77                if tx_writing.send(()).is_err() {
78                    error!(parent: self_clone.node().span(), "writing handler creation interrupted! shutting down the node");
79                    self_clone.node().shut_down().await;
80                    return;
81                }
82
83                loop {
84                    tokio::select! {
85                        // handle new connections from `Node::adapt_stream`
86                        maybe_conn = conn_receiver.recv() => {
87                            match maybe_conn {
88                                Some(returnable_conn) => {
89                                    let self_clone2 = self_clone.clone();
90                                    let senders = conn_senders.clone();
91                                    setup_tasks.spawn(async move {
92                                        self_clone2.handle_new_connection(returnable_conn, &senders).await;
93                                    });
94                                }
95                                None => break, // channel closed
96                            }
97                        }
98                        // task set cleanups
99                        _ = setup_tasks.join_next(), if !setup_tasks.is_empty() => {}
100                    }
101                }
102            });
103            let _ = rx_writing.await;
104            self.node()
105                .tasks
106                .lock()
107                .insert(NodeTask::Writing, writing_task);
108
109            // register the WritingHandler with the Node
110            let hdl = WritingHandler {
111                handler: ProtocolHandler(conn_sender),
112                senders,
113            };
114            assert!(
115                self.node().protocols.writing.set(hdl).is_ok(),
116                "the Writing protocol was enabled more than once!"
117            );
118        }
119    }
120
121    /// Creates an [`Encoder`] used to write the outbound messages to the target stream.
122    /// The `side` param indicates the connection side **from the node's perspective**.
123    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
124
125    /// Sends the provided message to the specified [`SocketAddr`]. Returns as soon as the message is queued to
126    /// be sent, without waiting for the actual delivery; instead, the caller is provided with a [`oneshot::Receiver`]
127    /// which can be used to determine when and whether the message has been delivered.
128    ///
129    /// # Errors
130    ///
131    /// The following errors can be returned:
132    /// - [`io::ErrorKind::BrokenPipe`] if the outbound message channel is down
133    /// - [`io::ErrorKind::NotConnected`] if the node is not connected to the provided address
134    /// - [`io::ErrorKind::QuotaExceeded`] if the outbound message queue for this address is full
135    /// - [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet
136    fn unicast(
137        &self,
138        addr: SocketAddr,
139        message: Self::Message,
140    ) -> io::Result<oneshot::Receiver<io::Result<()>>> {
141        // access the protocol handler
142        if let Some(handler) = self.node().protocols.writing.get() {
143            // find the message sender for the given address
144            if let Some(sender) = handler.senders.read().get(&addr).cloned() {
145                let (msg, delivery) = WrappedMessage::new(Box::new(message), true);
146                let conn_span = create_connection_span(addr, self.node().span());
147                sender
148                    .try_send(msg)
149                    .map_err(|e| {
150                        error!(parent: conn_span, "can't send a message: {e}");
151                        match e {
152                            mpsc::error::TrySendError::Full(_) => {
153                                io::ErrorKind::QuotaExceeded.into()
154                            }
155                            mpsc::error::TrySendError::Closed(_) => {
156                                io::ErrorKind::BrokenPipe.into()
157                            }
158                        }
159                    })
160                    .map(|_| delivery.unwrap()) // infallible
161            } else {
162                Err(io::ErrorKind::NotConnected.into())
163            }
164        } else {
165            Err(io::ErrorKind::Unsupported.into())
166        }
167    }
168
169    /// Sends the provided message to the specified [`SocketAddr`], and returns as soon as the
170    /// message is queued to be sent, without waiting for the actual delivery (as opposed to
171    /// [`Writing::unicast`], which does provide delivery feedback).
172    ///
173    /// # Errors
174    ///
175    /// See the error section for [`Writing::unicast`].
176    fn unicast_fast(&self, addr: SocketAddr, message: Self::Message) -> io::Result<()> {
177        // access the protocol handler
178        if let Some(handler) = self.node().protocols.writing.get() {
179            // find the message sender for the given address
180            if let Some(sender) = handler.senders.read().get(&addr).cloned() {
181                let (msg, _) = WrappedMessage::new(Box::new(message), false);
182                let conn_span = create_connection_span(addr, self.node().span());
183                sender.try_send(msg).map_err(|e| {
184                    error!(parent: conn_span, "can't send a message: {e}");
185                    match e {
186                        mpsc::error::TrySendError::Full(_) => io::ErrorKind::QuotaExceeded.into(),
187                        mpsc::error::TrySendError::Closed(_) => io::ErrorKind::BrokenPipe.into(),
188                    }
189                })
190            } else {
191                Err(io::ErrorKind::NotConnected.into())
192            }
193        } else {
194            Err(io::ErrorKind::Unsupported.into())
195        }
196    }
197
198    /// Broadcasts the provided message to all connected peers. Returns as soon as the message is queued to
199    /// be sent to all the peers, without waiting for the actual delivery. This method doesn't provide the
200    /// means to check when and if the messages actually get delivered; you can achieve that by calling
201    /// [`Writing::unicast`] for each address returned by [`Node::connected_addrs`].
202    ///
203    /// note: This method clones the message for every connected peer, and serialization via
204    /// [`Writing::Codec`] happens individually for each connection. If your serialization is
205    /// expensive (e.g., large JSON/Bincode structs), manually serialize your message into
206    /// [`Bytes`] *before* calling broadcast. This ensures serialization happens only once,
207    /// rather than N times.
208    ///
209    /// # Errors
210    ///
211    /// Returns [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet.
212    fn broadcast(&self, message: Self::Message) -> io::Result<()>
213    where
214        Self::Message: Clone,
215    {
216        // access the protocol handler
217        if let Some(handler) = self.node().protocols.writing.get() {
218            let senders = handler.senders.read().clone();
219            for (addr, message_sender) in senders {
220                let (msg, _) = WrappedMessage::new(Box::new(message.clone()), false);
221                let conn_span = create_connection_span(addr, self.node().span());
222                let _ = message_sender.try_send(msg).map_err(|e| {
223                    error!(parent: conn_span, "can't send a message: {e}");
224                });
225            }
226
227            Ok(())
228        } else {
229            Err(io::ErrorKind::Unsupported.into())
230        }
231    }
232}
233
/// This trait is used to restrict access to methods that would otherwise be public in [`Writing`].
///
/// It is private to this module and is implemented for every [`Writing`] implementor.
trait WritingInternal: Writing {
    /// Writes the given message to the network stream and returns the number of written bytes.
    ///
    /// The returned count is the size of the writer's buffer once the message has been queued in it.
    ///
    /// # Errors
    ///
    /// Returns an error if the message can't be queued or flushed, or an error of kind
    /// [`io::ErrorKind::TimedOut`] if the flush doesn't conclude within [`Writing::TIMEOUT_MS`].
    async fn write_to_stream<W: AsyncWrite + Unpin + Send>(
        &self,
        message: Self::Message,
        writer: &mut FramedWrite<W, Self::Codec>,
    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;

    /// Applies the [`Writing`] protocol to a single connection.
    ///
    /// `conn_with_returner` carries the connection together with the channel used to hand it
    /// back once set up; `conn_senders` is the registry in which the connection's outbound
    /// message sender gets stored.
    async fn handle_new_connection(
        &self,
        conn_with_returner: ReturnableConnection,
        conn_senders: &WritingSenders,
    );
}
250
251impl<W: Writing> WritingInternal for W {
252    async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
253        &self,
254        message: Self::Message,
255        writer: &mut FramedWrite<A, Self::Codec>,
256    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
257        writer.feed(message).await?;
258        let len = writer.write_buffer().len();
259        // guard against write starvation
260        match timeout(Duration::from_millis(W::TIMEOUT_MS), writer.flush()).await {
261            Ok(Ok(())) => Ok(len),
262            Ok(Err(e)) => Err(e),
263            Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "write timed out")),
264        }
265    }
266
267    async fn handle_new_connection(
268        &self,
269        (mut conn, conn_returner): ReturnableConnection,
270        conn_senders: &WritingSenders,
271    ) {
272        let addr = conn.addr();
273        let codec = self.codec(addr, !conn.side());
274        let Some(writer) = conn.writer.take() else {
275            error!(parent: conn.span(), "the stream was not returned during the handshake!");
276            return;
277        };
278        let mut framed = FramedWrite::new(writer, codec);
279
280        if Self::INITIAL_BUFFER_SIZE != 0 {
281            framed.write_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
282        }
283
284        let (outbound_message_sender, mut outbound_message_receiver) =
285            mpsc::channel(Self::MESSAGE_QUEUE_DEPTH);
286
287        // register the connection's message sender with the Writing protocol handler
288        conn_senders.write().insert(addr, outbound_message_sender);
289
290        // this will automatically drop the sender upon a disconnect
291        let sender_cleanup = SenderCleanup {
292            addr,
293            senders: Arc::clone(conn_senders),
294        };
295
296        // use a channel to know when the writer task is ready
297        let (tx_writer, rx_writer) = oneshot::channel();
298
299        // the task for writing outbound messages
300        let self_clone = self.clone();
301        let conn_stats = conn.stats().clone();
302        let conn_span = conn.span().clone();
303        let writer_task = tokio::spawn(Box::pin(async move {
304            let node = self_clone.node();
305            trace!(parent: &conn_span, "spawned a task for writing messages");
306            if tx_writer.send(()).is_err() {
307                error!(parent: &conn_span, "Writing was interrupted; shutting down its task");
308                return;
309            }
310
311            // move the sender cleanup into ths task
312            let _sender_cleanup = sender_cleanup;
313
314            // disconnect automatically regardless of how this task concludes
315            let _conn_cleanup = DisconnectOnDrop::new(node.clone(), addr);
316
317            while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
318                let msg = wrapped_msg.msg.downcast().unwrap();
319
320                match self_clone.write_to_stream(*msg, &mut framed).await {
321                    Ok(len) => {
322                        if let Some(tx) = wrapped_msg.delivery_notification {
323                            let _ = tx.send(Ok(()));
324                        }
325                        conn_stats.register_sent_message(len);
326                        node.stats().register_sent_message(len);
327                        trace!(parent: &conn_span, "wrote {len}B");
328                    }
329                    Err(e) => {
330                        error!(parent: &conn_span, "couldn't write: {e}");
331                        if let Some(tx) = wrapped_msg.delivery_notification {
332                            let _ = tx.send(Err(e));
333                        }
334                        break;
335                    }
336                }
337            }
338        }));
339        let _ = rx_writer.await;
340        conn.tasks.push(writer_task);
341
342        // return the Connection to the Node, resuming Node::adapt_stream
343        let conn_span = conn.span().clone();
344        if conn_returner.send(Ok(conn)).is_err() {
345            error!(parent: &conn_span, "couldn't return a Connection from the Writing handler");
346        }
347    }
348}
349
350/// Used to queue messages for delivery and return its confirmation.
351pub(crate) struct WrappedMessage {
352    msg: Box<dyn Any + Send>,
353    delivery_notification: Option<oneshot::Sender<io::Result<()>>>,
354}
355
356impl WrappedMessage {
357    fn new(
358        msg: Box<dyn Any + Send>,
359        confirmation: bool,
360    ) -> (Self, Option<oneshot::Receiver<io::Result<()>>>) {
361        let (tx, rx) = if confirmation {
362            let (tx, rx) = oneshot::channel();
363            (Some(tx), Some(rx))
364        } else {
365            (None, None)
366        };
367
368        let wrapped_msg = Self {
369            msg,
370            delivery_notification: tx,
371        };
372
373        (wrapped_msg, rx)
374    }
375}
376
377/// The handler object dedicated to the [`Writing`] protocol.
378pub(crate) struct WritingHandler {
379    handler: ProtocolHandler<Connection, io::Result<Connection>>,
380    pub(crate) senders: WritingSenders,
381}
382
383impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
384    async fn trigger(&self, item: ReturnableConnection) {
385        self.handler.trigger(item).await;
386    }
387}
388
389struct SenderCleanup {
390    addr: SocketAddr,
391    senders: WritingSenders,
392}
393
394impl Drop for SenderCleanup {
395    fn drop(&mut self) {
396        self.senders.write().remove(&self.addr);
397    }
398}