1use std::{any::Any, collections::HashMap, future::Future, io, net::SocketAddr, sync::Arc};
2
3use futures_util::sink::SinkExt;
4use parking_lot::RwLock;
5use tokio::{
6 io::AsyncWrite,
7 sync::{mpsc, oneshot},
8};
9use tokio_util::codec::{Encoder, FramedWrite};
10use tracing::*;
11
12use crate::{
13 node::NodeTask,
14 protocols::{Protocol, ProtocolHandler, ReturnableConnection},
15 Connection, ConnectionSide, Pea2Pea,
16};
17#[cfg(doc)]
18use crate::{protocols::Handshake, Config, Node};
19
20type WritingSenders = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
21
22pub trait Writing: Pea2Pea
25where
26 Self: Clone + Send + Sync + 'static,
27{
28 const MESSAGE_QUEUE_DEPTH: usize = 64;
34
35 const INITIAL_BUFFER_SIZE: usize = 64 * 1024;
40
41 type Message: Send;
45
46 type Codec: Encoder<Self::Message, Error = io::Error> + Send;
48
49 fn enable_writing(&self) -> impl Future<Output = ()> {
51 async {
52 let (conn_sender, mut conn_receiver) = mpsc::unbounded_channel();
53
54 let conn_senders: WritingSenders = Default::default();
56 let senders = conn_senders.clone();
58
59 let (tx_writing, rx_writing) = oneshot::channel();
61
62 let self_clone = self.clone();
64 let writing_task = tokio::spawn(async move {
65 trace!(parent: self_clone.node().span(), "spawned the Writing handler task");
66 if tx_writing.send(()).is_err() {
67 error!(parent: self_clone.node().span(), "Writing handler creation interrupted! shutting down the node");
68 self_clone.node().shut_down().await;
69 return;
70 }
71
72 while let Some(returnable_conn) = conn_receiver.recv().await {
74 self_clone
75 .handle_new_connection(returnable_conn, &conn_senders)
76 .await;
77 }
78 });
79 let _ = rx_writing.await;
80 self.node()
81 .tasks
82 .lock()
83 .insert(NodeTask::Writing, writing_task);
84
85 let hdl = WritingHandler {
87 handler: ProtocolHandler(conn_sender),
88 senders,
89 };
90 assert!(
91 self.node().protocols.writing.set(hdl).is_ok(),
92 "the Writing protocol was enabled more than once!"
93 );
94 }
95 }
96
97 fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
100
101 fn unicast(
112 &self,
113 addr: SocketAddr,
114 message: Self::Message,
115 ) -> io::Result<oneshot::Receiver<io::Result<()>>> {
116 if let Some(handler) = self.node().protocols.writing.get() {
118 if let Some(sender) = handler.senders.read().get(&addr).cloned() {
120 let (msg, delivery) = WrappedMessage::new(Box::new(message));
121 sender.try_send(msg).map_err(|e| {
122 error!(parent: self.node().span(), "can't send a message to {}: {}", addr, e);
123 io::ErrorKind::Other.into()
124 }).map(|_| delivery)
125 } else {
126 Err(io::ErrorKind::NotConnected.into())
127 }
128 } else {
129 Err(io::ErrorKind::Unsupported.into())
130 }
131 }
132
133 fn broadcast(&self, message: Self::Message) -> io::Result<()>
142 where
143 Self::Message: Clone,
144 {
145 if let Some(handler) = self.node().protocols.writing.get() {
147 let senders = handler.senders.read().clone();
148 for (addr, message_sender) in senders {
149 let (msg, _delivery) = WrappedMessage::new(Box::new(message.clone()));
150 let _ = message_sender.try_send(msg).map_err(|e| {
151 error!(parent: self.node().span(), "can't send a message to {}: {}", addr, e);
152 });
153 }
154
155 Ok(())
156 } else {
157 Err(io::ErrorKind::Unsupported.into())
158 }
159 }
160}
161
162trait WritingInternal: Writing {
164 async fn write_to_stream<W: AsyncWrite + Unpin + Send>(
166 &self,
167 message: Self::Message,
168 writer: &mut FramedWrite<W, Self::Codec>,
169 ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;
170
171 async fn handle_new_connection(
173 &self,
174 conn_with_returner: ReturnableConnection,
175 conn_senders: &WritingSenders,
176 );
177}
178
179impl<W: Writing> WritingInternal for W {
180 async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
181 &self,
182 message: Self::Message,
183 writer: &mut FramedWrite<A, Self::Codec>,
184 ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
185 writer.feed(message).await?;
186 let len = writer.write_buffer().len();
187 writer.flush().await?;
188
189 Ok(len)
190 }
191
192 async fn handle_new_connection(
193 &self,
194 (mut conn, conn_returner): ReturnableConnection,
195 conn_senders: &WritingSenders,
196 ) {
197 let addr = conn.addr();
198 let codec = self.codec(addr, !conn.side());
199 let writer = conn.writer.take().expect("missing connection writer!");
200 let mut framed = FramedWrite::new(writer, codec);
201
202 if Self::INITIAL_BUFFER_SIZE != 0 {
203 framed.write_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
204 }
205
206 let (outbound_message_sender, mut outbound_message_receiver) =
207 mpsc::channel(Self::MESSAGE_QUEUE_DEPTH);
208
209 conn_senders.write().insert(addr, outbound_message_sender);
211
212 let auto_cleanup = SenderCleanup {
214 addr,
215 senders: Arc::clone(conn_senders),
216 };
217
218 let (tx_writer, rx_writer) = oneshot::channel();
220
221 let self_clone = self.clone();
223 let conn_stats = conn.stats().clone();
224 let writer_task = tokio::spawn(async move {
225 let node = self_clone.node();
226 trace!(parent: node.span(), "spawned a task for writing messages to {}", addr);
227 if tx_writer.send(()).is_err() {
228 error!(parent: node.span(), "Writing for {} was interrupted; shutting down its task", addr);
229 return;
230 }
231
232 let _auto_cleanup = auto_cleanup;
234
235 while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
236 let msg = wrapped_msg.msg.downcast().unwrap();
237
238 match self_clone.write_to_stream(*msg, &mut framed).await {
239 Ok(len) => {
240 let _ = wrapped_msg.delivery_notification.send(Ok(()));
241 conn_stats.register_sent_message(len);
242 node.stats().register_sent_message(len);
243 trace!(parent: node.span(), "sent {}B to {}", len, addr);
244 }
245 Err(e) => {
246 error!(parent: node.span(), "couldn't send a message to {}: {}", addr, e);
247 let is_fatal = node.config().fatal_io_errors.contains(&e.kind());
248 let _ = wrapped_msg.delivery_notification.send(Err(e));
249 if is_fatal {
250 break;
251 }
252 }
253 }
254 }
255
256 let _ = node.disconnect(addr).await;
257 });
258 let _ = rx_writer.await;
259 conn.tasks.push(writer_task);
260
261 if conn_returner.send(Ok(conn)).is_err() {
263 error!(parent: self.node().span(), "couldn't return a Connection with {} from the Writing handler", addr);
264 }
265 }
266}
267
268struct WrappedMessage {
270 msg: Box<dyn Any + Send>,
271 delivery_notification: oneshot::Sender<io::Result<()>>,
272}
273
274impl WrappedMessage {
275 fn new(msg: Box<dyn Any + Send>) -> (Self, oneshot::Receiver<io::Result<()>>) {
276 let (tx, rx) = oneshot::channel();
277 let wrapped_msg = Self {
278 msg,
279 delivery_notification: tx,
280 };
281
282 (wrapped_msg, rx)
283 }
284}
285
286pub(crate) struct WritingHandler {
288 handler: ProtocolHandler<Connection, io::Result<Connection>>,
289 senders: WritingSenders,
290}
291
292impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
293 fn trigger(&self, item: ReturnableConnection) {
294 self.handler.trigger(item);
295 }
296}
297
298struct SenderCleanup {
299 addr: SocketAddr,
300 senders: WritingSenders,
301}
302
303impl Drop for SenderCleanup {
304 fn drop(&mut self) {
305 self.senders.write().remove(&self.addr);
306 }
307}