pea2pea/protocols/
writing.rs

1use std::{any::Any, collections::HashMap, future::Future, io, net::SocketAddr, sync::Arc};
2
3use futures_util::sink::SinkExt;
4use parking_lot::RwLock;
5use tokio::{
6    io::AsyncWrite,
7    sync::{mpsc, oneshot},
8};
9use tokio_util::codec::{Encoder, FramedWrite};
10use tracing::*;
11
12use crate::{
13    node::NodeTask,
14    protocols::{Protocol, ProtocolHandler, ReturnableConnection},
15    Connection, ConnectionSide, Pea2Pea,
16};
17#[cfg(doc)]
18use crate::{protocols::Handshake, Config, Node};
19
20type WritingSenders = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
21
22/// Can be used to specify and enable writing, i.e. sending outbound messages. If the [`Handshake`]
23/// protocol is enabled too, it goes into force only after the handshake has been concluded.
24pub trait Writing: Pea2Pea
25where
26    Self: Clone + Send + Sync + 'static,
27{
28    /// The depth of per-connection queues used to send outbound messages; the greater it is, the more outbound
29    /// messages the node can enqueue. Setting it to a large value is not recommended, as doing it might
30    /// obscure potential issues with your implementation (like slow serialization) or network.
31    ///
32    /// The default value is 64.
33    const MESSAGE_QUEUE_DEPTH: usize = 64;
34
35    /// The initial size of a per-connection buffer for writing outbound messages. Can be set to the maximum expected size
36    /// of the outbound message in order to only allocate it once.
37    ///
38    /// The default value is 64KiB.
39    const INITIAL_BUFFER_SIZE: usize = 64 * 1024;
40
41    /// The type of the outbound messages; unless their serialization is expensive and the message
42    /// is broadcasted (in which case it would get serialized multiple times), serialization should
43    /// be done in the implementation of [`Self::Codec`].
44    type Message: Send;
45
46    /// The user-supplied [`Encoder`] used to write outbound messages to the target stream.
47    type Codec: Encoder<Self::Message, Error = io::Error> + Send;
48
49    /// Prepares the node to send messages.
50    fn enable_writing(&self) -> impl Future<Output = ()> {
51        async {
52            let (conn_sender, mut conn_receiver) = mpsc::unbounded_channel();
53
54            // the conn_senders are used to send messages from the Node to individual connections
55            let conn_senders: WritingSenders = Default::default();
56            // procure a clone to create the WritingHandler with
57            let senders = conn_senders.clone();
58
59            // use a channel to know when the writing task is ready
60            let (tx_writing, rx_writing) = oneshot::channel();
61
62            // the task spawning tasks sending messages to all the streams
63            let self_clone = self.clone();
64            let writing_task = tokio::spawn(async move {
65                trace!(parent: self_clone.node().span(), "spawned the Writing handler task");
66                if tx_writing.send(()).is_err() {
67                    error!(parent: self_clone.node().span(), "Writing handler creation interrupted! shutting down the node");
68                    self_clone.node().shut_down().await;
69                    return;
70                }
71
72                // these objects are sent from `Node::adapt_stream`
73                while let Some(returnable_conn) = conn_receiver.recv().await {
74                    self_clone
75                        .handle_new_connection(returnable_conn, &conn_senders)
76                        .await;
77                }
78            });
79            let _ = rx_writing.await;
80            self.node()
81                .tasks
82                .lock()
83                .insert(NodeTask::Writing, writing_task);
84
85            // register the WritingHandler with the Node
86            let hdl = WritingHandler {
87                handler: ProtocolHandler(conn_sender),
88                senders,
89            };
90            assert!(
91                self.node().protocols.writing.set(hdl).is_ok(),
92                "the Writing protocol was enabled more than once!"
93            );
94        }
95    }
96
97    /// Creates an [`Encoder`] used to write the outbound messages to the target stream.
98    /// The `side` param indicates the connection side **from the node's perspective**.
99    fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
100
101    /// Sends the provided message to the specified [`SocketAddr`]. Returns as soon as the message is queued to
102    /// be sent, without waiting for the actual delivery; instead, the caller is provided with a [`oneshot::Receiver`]
103    /// which can be used to determine when and whether the message has been delivered.
104    ///
105    /// # Errors
106    ///
107    /// The following errors can be returned:
108    /// - [`io::ErrorKind::NotConnected`] if the node is not connected to the provided address
109    /// - [`io::ErrorKind::Other`] if the outbound message queue for this address is full
110    /// - [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet
111    fn unicast(
112        &self,
113        addr: SocketAddr,
114        message: Self::Message,
115    ) -> io::Result<oneshot::Receiver<io::Result<()>>> {
116        // access the protocol handler
117        if let Some(handler) = self.node().protocols.writing.get() {
118            // find the message sender for the given address
119            if let Some(sender) = handler.senders.read().get(&addr).cloned() {
120                let (msg, delivery) = WrappedMessage::new(Box::new(message));
121                sender.try_send(msg).map_err(|e| {
122                    error!(parent: self.node().span(), "can't send a message to {}: {}", addr, e);
123                    io::ErrorKind::Other.into()
124                }).map(|_| delivery)
125            } else {
126                Err(io::ErrorKind::NotConnected.into())
127            }
128        } else {
129            Err(io::ErrorKind::Unsupported.into())
130        }
131    }
132
133    /// Broadcasts the provided message to all connected peers. Returns as soon as the message is queued to
134    /// be sent to all the peers, without waiting for the actual delivery. This method doesn't provide the
135    /// means to check when and if the messages actually get delivered; you can achieve that by calling
136    /// [`Writing::unicast`] for each address returned by [`Node::connected_addrs`].
137    ///
138    /// # Errors
139    ///
140    /// Returns [`io::ErrorKind::Unsupported`] if [`Writing::enable_writing`] hadn't been called yet.
141    fn broadcast(&self, message: Self::Message) -> io::Result<()>
142    where
143        Self::Message: Clone,
144    {
145        // access the protocol handler
146        if let Some(handler) = self.node().protocols.writing.get() {
147            let senders = handler.senders.read().clone();
148            for (addr, message_sender) in senders {
149                let (msg, _delivery) = WrappedMessage::new(Box::new(message.clone()));
150                let _ = message_sender.try_send(msg).map_err(|e| {
151                    error!(parent: self.node().span(), "can't send a message to {}: {}", addr, e);
152                });
153            }
154
155            Ok(())
156        } else {
157            Err(io::ErrorKind::Unsupported.into())
158        }
159    }
160}
161
162/// This trait is used to restrict access to methods that would otherwise be public in [`Writing`].
163trait WritingInternal: Writing {
164    /// Writes the given message to the network stream and returns the number of written bytes.
165    async fn write_to_stream<W: AsyncWrite + Unpin + Send>(
166        &self,
167        message: Self::Message,
168        writer: &mut FramedWrite<W, Self::Codec>,
169    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;
170
171    /// Applies the [`Writing`] protocol to a single connection.
172    async fn handle_new_connection(
173        &self,
174        conn_with_returner: ReturnableConnection,
175        conn_senders: &WritingSenders,
176    );
177}
178
179impl<W: Writing> WritingInternal for W {
180    async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
181        &self,
182        message: Self::Message,
183        writer: &mut FramedWrite<A, Self::Codec>,
184    ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
185        writer.feed(message).await?;
186        let len = writer.write_buffer().len();
187        writer.flush().await?;
188
189        Ok(len)
190    }
191
192    async fn handle_new_connection(
193        &self,
194        (mut conn, conn_returner): ReturnableConnection,
195        conn_senders: &WritingSenders,
196    ) {
197        let addr = conn.addr();
198        let codec = self.codec(addr, !conn.side());
199        let writer = conn.writer.take().expect("missing connection writer!");
200        let mut framed = FramedWrite::new(writer, codec);
201
202        if Self::INITIAL_BUFFER_SIZE != 0 {
203            framed.write_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
204        }
205
206        let (outbound_message_sender, mut outbound_message_receiver) =
207            mpsc::channel(Self::MESSAGE_QUEUE_DEPTH);
208
209        // register the connection's message sender with the Writing protocol handler
210        conn_senders.write().insert(addr, outbound_message_sender);
211
212        // this will automatically drop the sender upon a disconnect
213        let auto_cleanup = SenderCleanup {
214            addr,
215            senders: Arc::clone(conn_senders),
216        };
217
218        // use a channel to know when the writer task is ready
219        let (tx_writer, rx_writer) = oneshot::channel();
220
221        // the task for writing outbound messages
222        let self_clone = self.clone();
223        let conn_stats = conn.stats().clone();
224        let writer_task = tokio::spawn(async move {
225            let node = self_clone.node();
226            trace!(parent: node.span(), "spawned a task for writing messages to {}", addr);
227            if tx_writer.send(()).is_err() {
228                error!(parent: node.span(), "Writing for {} was interrupted; shutting down its task", addr);
229                return;
230            }
231
232            // move the cleanup into the task that gets aborted on disconnect
233            let _auto_cleanup = auto_cleanup;
234
235            while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
236                let msg = wrapped_msg.msg.downcast().unwrap();
237
238                match self_clone.write_to_stream(*msg, &mut framed).await {
239                    Ok(len) => {
240                        let _ = wrapped_msg.delivery_notification.send(Ok(()));
241                        conn_stats.register_sent_message(len);
242                        node.stats().register_sent_message(len);
243                        trace!(parent: node.span(), "sent {}B to {}", len, addr);
244                    }
245                    Err(e) => {
246                        error!(parent: node.span(), "couldn't send a message to {}: {}", addr, e);
247                        let is_fatal = node.config().fatal_io_errors.contains(&e.kind());
248                        let _ = wrapped_msg.delivery_notification.send(Err(e));
249                        if is_fatal {
250                            break;
251                        }
252                    }
253                }
254            }
255
256            let _ = node.disconnect(addr).await;
257        });
258        let _ = rx_writer.await;
259        conn.tasks.push(writer_task);
260
261        // return the Connection to the Node, resuming Node::adapt_stream
262        if conn_returner.send(Ok(conn)).is_err() {
263            error!(parent: self.node().span(), "couldn't return a Connection with {} from the Writing handler", addr);
264        }
265    }
266}
267
268/// Used to queue messages for delivery.
269struct WrappedMessage {
270    msg: Box<dyn Any + Send>,
271    delivery_notification: oneshot::Sender<io::Result<()>>,
272}
273
274impl WrappedMessage {
275    fn new(msg: Box<dyn Any + Send>) -> (Self, oneshot::Receiver<io::Result<()>>) {
276        let (tx, rx) = oneshot::channel();
277        let wrapped_msg = Self {
278            msg,
279            delivery_notification: tx,
280        };
281
282        (wrapped_msg, rx)
283    }
284}
285
286/// The handler object dedicated to the [`Writing`] protocol.
287pub(crate) struct WritingHandler {
288    handler: ProtocolHandler<Connection, io::Result<Connection>>,
289    senders: WritingSenders,
290}
291
292impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
293    fn trigger(&self, item: ReturnableConnection) {
294        self.handler.trigger(item);
295    }
296}
297
298struct SenderCleanup {
299    addr: SocketAddr,
300    senders: WritingSenders,
301}
302
303impl Drop for SenderCleanup {
304    fn drop(&mut self) {
305        self.senders.write().remove(&self.addr);
306    }
307}