1use std::{
2 any::Any, collections::HashMap, future::Future, io, net::SocketAddr, sync::Arc, time::Duration,
3};
4
5#[cfg(doc)]
6use bytes::Bytes;
7use futures_util::sink::SinkExt;
8use parking_lot::RwLock;
9use tokio::{
10 io::AsyncWrite,
11 sync::{mpsc, oneshot},
12 task::JoinSet,
13 time::timeout,
14};
15use tokio_util::codec::{Encoder, FramedWrite};
16use tracing::*;
17
18#[cfg(doc)]
19use crate::{Config, Node, protocols::Handshake};
20use crate::{
21 Connection, ConnectionSide, Pea2Pea,
22 connections::create_connection_span,
23 node::NodeTask,
24 protocols::{DisconnectOnDrop, Protocol, ProtocolHandler, ReturnableConnection},
25};
26
27type WritingSenders = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<WrappedMessage>>>>;
28
29pub trait Writing: Pea2Pea
32where
33 Self: Clone + Send + Sync + 'static,
34{
35 const MESSAGE_QUEUE_DEPTH: usize = 64;
39
40 const INITIAL_BUFFER_SIZE: usize = 64 * 1024;
43
44 const TIMEOUT_MS: u64 = 10_000;
47
48 type Message: Send;
52
53 type Codec: Encoder<Self::Message, Error = io::Error> + Send;
55
56 fn enable_writing(&self) -> impl Future<Output = ()> {
58 async {
59 let mut setup_tasks = JoinSet::new();
61
62 let (conn_sender, mut conn_receiver) =
63 mpsc::channel(self.node().config().max_connecting as usize);
64
65 let conn_senders: WritingSenders = Default::default();
67 let senders = conn_senders.clone();
69
70 let (tx_writing, rx_writing) = oneshot::channel();
72
73 let self_clone = self.clone();
75 let writing_task = tokio::spawn(async move {
76 trace!(parent: self_clone.node().span(), "spawned the Writing handler task");
77 if tx_writing.send(()).is_err() {
78 error!(parent: self_clone.node().span(), "writing handler creation interrupted! shutting down the node");
79 self_clone.node().shut_down().await;
80 return;
81 }
82
83 loop {
84 tokio::select! {
85 maybe_conn = conn_receiver.recv() => {
87 match maybe_conn {
88 Some(returnable_conn) => {
89 let self_clone2 = self_clone.clone();
90 let senders = conn_senders.clone();
91 setup_tasks.spawn(async move {
92 self_clone2.handle_new_connection(returnable_conn, &senders).await;
93 });
94 }
95 None => break, }
97 }
98 _ = setup_tasks.join_next(), if !setup_tasks.is_empty() => {}
100 }
101 }
102 });
103 let _ = rx_writing.await;
104 self.node()
105 .tasks
106 .lock()
107 .insert(NodeTask::Writing, writing_task);
108
109 let hdl = WritingHandler {
111 handler: ProtocolHandler(conn_sender),
112 senders,
113 };
114 assert!(
115 self.node().protocols.writing.set(hdl).is_ok(),
116 "the Writing protocol was enabled more than once!"
117 );
118 }
119 }
120
121 fn codec(&self, addr: SocketAddr, side: ConnectionSide) -> Self::Codec;
124
125 fn unicast(
137 &self,
138 addr: SocketAddr,
139 message: Self::Message,
140 ) -> io::Result<oneshot::Receiver<io::Result<()>>> {
141 if let Some(handler) = self.node().protocols.writing.get() {
143 if let Some(sender) = handler.senders.read().get(&addr).cloned() {
145 let (msg, delivery) = WrappedMessage::new(Box::new(message), true);
146 let conn_span = create_connection_span(addr, self.node().span());
147 sender
148 .try_send(msg)
149 .map_err(|e| {
150 error!(parent: conn_span, "can't send a message: {e}");
151 match e {
152 mpsc::error::TrySendError::Full(_) => {
153 io::ErrorKind::QuotaExceeded.into()
154 }
155 mpsc::error::TrySendError::Closed(_) => {
156 io::ErrorKind::BrokenPipe.into()
157 }
158 }
159 })
160 .map(|_| delivery.unwrap()) } else {
162 Err(io::ErrorKind::NotConnected.into())
163 }
164 } else {
165 Err(io::ErrorKind::Unsupported.into())
166 }
167 }
168
169 fn unicast_fast(&self, addr: SocketAddr, message: Self::Message) -> io::Result<()> {
177 if let Some(handler) = self.node().protocols.writing.get() {
179 if let Some(sender) = handler.senders.read().get(&addr).cloned() {
181 let (msg, _) = WrappedMessage::new(Box::new(message), false);
182 let conn_span = create_connection_span(addr, self.node().span());
183 sender.try_send(msg).map_err(|e| {
184 error!(parent: conn_span, "can't send a message: {e}");
185 match e {
186 mpsc::error::TrySendError::Full(_) => io::ErrorKind::QuotaExceeded.into(),
187 mpsc::error::TrySendError::Closed(_) => io::ErrorKind::BrokenPipe.into(),
188 }
189 })
190 } else {
191 Err(io::ErrorKind::NotConnected.into())
192 }
193 } else {
194 Err(io::ErrorKind::Unsupported.into())
195 }
196 }
197
198 fn broadcast(&self, message: Self::Message) -> io::Result<()>
213 where
214 Self::Message: Clone,
215 {
216 if let Some(handler) = self.node().protocols.writing.get() {
218 let senders = handler.senders.read().clone();
219 for (addr, message_sender) in senders {
220 let (msg, _) = WrappedMessage::new(Box::new(message.clone()), false);
221 let conn_span = create_connection_span(addr, self.node().span());
222 let _ = message_sender.try_send(msg).map_err(|e| {
223 error!(parent: conn_span, "can't send a message: {e}");
224 });
225 }
226
227 Ok(())
228 } else {
229 Err(io::ErrorKind::Unsupported.into())
230 }
231 }
232}
233
234trait WritingInternal: Writing {
236 async fn write_to_stream<W: AsyncWrite + Unpin + Send>(
238 &self,
239 message: Self::Message,
240 writer: &mut FramedWrite<W, Self::Codec>,
241 ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error>;
242
243 async fn handle_new_connection(
245 &self,
246 conn_with_returner: ReturnableConnection,
247 conn_senders: &WritingSenders,
248 );
249}
250
251impl<W: Writing> WritingInternal for W {
252 async fn write_to_stream<A: AsyncWrite + Unpin + Send>(
253 &self,
254 message: Self::Message,
255 writer: &mut FramedWrite<A, Self::Codec>,
256 ) -> Result<usize, <Self::Codec as Encoder<Self::Message>>::Error> {
257 writer.feed(message).await?;
258 let len = writer.write_buffer().len();
259 match timeout(Duration::from_millis(W::TIMEOUT_MS), writer.flush()).await {
261 Ok(Ok(())) => Ok(len),
262 Ok(Err(e)) => Err(e),
263 Err(_) => Err(io::Error::new(io::ErrorKind::TimedOut, "write timed out")),
264 }
265 }
266
267 async fn handle_new_connection(
268 &self,
269 (mut conn, conn_returner): ReturnableConnection,
270 conn_senders: &WritingSenders,
271 ) {
272 let addr = conn.addr();
273 let codec = self.codec(addr, !conn.side());
274 let Some(writer) = conn.writer.take() else {
275 error!(parent: conn.span(), "the stream was not returned during the handshake!");
276 return;
277 };
278 let mut framed = FramedWrite::new(writer, codec);
279
280 if Self::INITIAL_BUFFER_SIZE != 0 {
281 framed.write_buffer_mut().reserve(Self::INITIAL_BUFFER_SIZE);
282 }
283
284 let (outbound_message_sender, mut outbound_message_receiver) =
285 mpsc::channel(Self::MESSAGE_QUEUE_DEPTH);
286
287 conn_senders.write().insert(addr, outbound_message_sender);
289
290 let sender_cleanup = SenderCleanup {
292 addr,
293 senders: Arc::clone(conn_senders),
294 };
295
296 let (tx_writer, rx_writer) = oneshot::channel();
298
299 let self_clone = self.clone();
301 let conn_stats = conn.stats().clone();
302 let conn_span = conn.span().clone();
303 let writer_task = tokio::spawn(Box::pin(async move {
304 let node = self_clone.node();
305 trace!(parent: &conn_span, "spawned a task for writing messages");
306 if tx_writer.send(()).is_err() {
307 error!(parent: &conn_span, "Writing was interrupted; shutting down its task");
308 return;
309 }
310
311 let _sender_cleanup = sender_cleanup;
313
314 let _conn_cleanup = DisconnectOnDrop::new(node.clone(), addr);
316
317 while let Some(wrapped_msg) = outbound_message_receiver.recv().await {
318 let msg = wrapped_msg.msg.downcast().unwrap();
319
320 match self_clone.write_to_stream(*msg, &mut framed).await {
321 Ok(len) => {
322 if let Some(tx) = wrapped_msg.delivery_notification {
323 let _ = tx.send(Ok(()));
324 }
325 conn_stats.register_sent_message(len);
326 node.stats().register_sent_message(len);
327 trace!(parent: &conn_span, "wrote {len}B");
328 }
329 Err(e) => {
330 error!(parent: &conn_span, "couldn't write: {e}");
331 if let Some(tx) = wrapped_msg.delivery_notification {
332 let _ = tx.send(Err(e));
333 }
334 break;
335 }
336 }
337 }
338 }));
339 let _ = rx_writer.await;
340 conn.tasks.push(writer_task);
341
342 let conn_span = conn.span().clone();
344 if conn_returner.send(Ok(conn)).is_err() {
345 error!(parent: &conn_span, "couldn't return a Connection from the Writing handler");
346 }
347 }
348}
349
350pub(crate) struct WrappedMessage {
352 msg: Box<dyn Any + Send>,
353 delivery_notification: Option<oneshot::Sender<io::Result<()>>>,
354}
355
356impl WrappedMessage {
357 fn new(
358 msg: Box<dyn Any + Send>,
359 confirmation: bool,
360 ) -> (Self, Option<oneshot::Receiver<io::Result<()>>>) {
361 let (tx, rx) = if confirmation {
362 let (tx, rx) = oneshot::channel();
363 (Some(tx), Some(rx))
364 } else {
365 (None, None)
366 };
367
368 let wrapped_msg = Self {
369 msg,
370 delivery_notification: tx,
371 };
372
373 (wrapped_msg, rx)
374 }
375}
376
377pub(crate) struct WritingHandler {
379 handler: ProtocolHandler<Connection, io::Result<Connection>>,
380 pub(crate) senders: WritingSenders,
381}
382
383impl Protocol<Connection, io::Result<Connection>> for WritingHandler {
384 async fn trigger(&self, item: ReturnableConnection) {
385 self.handler.trigger(item).await;
386 }
387}
388
389struct SenderCleanup {
390 addr: SocketAddr,
391 senders: WritingSenders,
392}
393
394impl Drop for SenderCleanup {
395 fn drop(&mut self) {
396 self.senders.write().remove(&self.addr);
397 }
398}