1use crate::async_carrier::{self, AsyncCommandSender, DemandBatcher};
17use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
18use std::net::SocketAddr;
19use std::sync::{Arc, Mutex, mpsc as std_mpsc};
20use tokio::net::{ToSocketAddrs, UdpSocket};
21use tokio::runtime::Handle;
22use tokio::sync::mpsc as tokio_mpsc;
23use tokio::task::JoinHandle;
24
/// Default length, in bytes, of the buffer each received datagram is read into (64 KiB).
pub const DEFAULT_MAX_DATAGRAM_SIZE: usize = 64 * 1024;

/// Default bound on received datagrams queued between the socket task and a source.
pub const DEFAULT_RECEIVE_BUFFER: usize = 8 * 8;
34
/// A single UDP payload paired with the peer address it came from or is going to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Datagram {
    /// Raw datagram bytes.
    pub payload: Vec<u8>,
    /// Remote socket address (sender for received datagrams, destination for sent ones).
    pub remote: SocketAddr,
}

impl Datagram {
    /// Builds a datagram from anything convertible into a byte vector.
    #[must_use]
    pub fn new(payload: impl Into<Vec<u8>>, remote: SocketAddr) -> Self {
        let bytes: Vec<u8> = payload.into();
        Self {
            payload: bytes,
            remote,
        }
    }

    /// Borrows the payload bytes.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Returns the remote socket address.
    #[must_use]
    pub fn remote(&self) -> SocketAddr {
        let Self { remote, .. } = self;
        *remote
    }

    /// Splits the datagram into its payload and remote address.
    #[must_use]
    pub fn into_parts(self) -> (Vec<u8>, SocketAddr) {
        let Self { payload, remote } = self;
        (payload, remote)
    }

    /// Consumes the datagram, keeping only the payload.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let (payload, _remote) = self.into_parts();
        payload
    }
}
78
/// Materialized value of a bound UDP source or flow: the address the socket is bound to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdpBinding {
    /// Local address reported by the socket after binding.
    pub local_addr: SocketAddr,
}

impl UdpBinding {
    /// Returns the bound local address.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
91
/// Materialized value of a connected UDP flow: the local and peer addresses of the socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdpConnection {
    /// Local address reported by the socket.
    pub local_addr: SocketAddr,
    /// Peer address the socket is connected to.
    pub remote_addr: SocketAddr,
}

impl UdpConnection {
    /// Returns the local address of the connected socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Returns the peer address of the connected socket.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
110
/// Tokio-backed UDP stream constructors.
///
/// This is a unit type: all functionality is exposed as associated functions
/// (`bind`, `send_sink`, `bind_flow`, `connect` and their `_default` variants).
pub struct TokioUdp;

/// Shorthand alias for [`TokioUdp`].
pub type Udp = TokioUdp;
116
/// Message placed on the receive queue by the carrier task.
enum ReceiveResponse<T> {
    /// A received element (a `Datagram` or a connected-socket payload).
    Item(T),
    /// A failure reported by the carrier task; the receive source fails with it.
    Error(StreamError),
}

/// Result of trying to enqueue a received element without blocking.
enum QueueOutcome {
    /// The element was placed on the receive queue.
    Queued,
    /// The receive queue was full, so the element was discarded.
    Dropped,
    /// There is no receive queue, or its reader has gone away.
    Closed,
}
127
/// State owned by a running receive source.
struct ReceiveResource<T> {
    /// Queue filled by the carrier task.
    receiver: std_mpsc::Receiver<ReceiveResponse<T>>,
    /// Handle used to request more reads and to close the read side.
    carrier: UdpCarrier<T>,
    /// Tracks consumed elements and decides when to request more demand.
    demand: DemandBatcher,
}

impl<T> Drop for ReceiveResource<T> {
    // Always tell the carrier task to stop reading when the resource goes away.
    // NOTE(review): presumably this also covers teardown paths that skip the
    // explicit close callback — confirm against `Source::unfold_resource`.
    fn drop(&mut self) {
        self.carrier.close_read();
    }
}
139
/// Commands sent from the stream side to the socket-owning carrier task.
enum UdpCarrierCommand<T> {
    /// Adds to the number of elements the task may read and forward.
    Demand(usize),
    /// Sends a single element.
    SendOne(T),
    /// Sends several elements in order, stopping at the first failure.
    SendBatch(Vec<T>),
    /// Stops reading from the socket.
    CloseRead,
    /// Closes the write side; the task replies on `ack` once it has done so.
    CloseWrite {
        ack: std_mpsc::Sender<StreamResult<()>>,
    },
}
149
150#[derive(Clone)]
151struct UdpCarrier<T> {
152 inner: Arc<UdpCarrierInner<T>>,
153}
154
155struct UdpCarrierInner<T> {
156 commands: AsyncCommandSender<UdpCarrierCommand<T>>,
157 send_errors: Mutex<std_mpsc::Receiver<StreamError>>,
158 task: Mutex<Option<JoinHandle<()>>>,
159}
160
161impl<T> Drop for UdpCarrierInner<T> {
162 fn drop(&mut self) {
163 if let Some(task) = self.task.lock().expect("UDP carrier task poisoned").take() {
164 task.abort();
165 }
166 }
167}
168
impl<T> UdpCarrier<T> {
    /// Asks the carrier task to stop reading from the socket.
    ///
    /// Best effort: the command goes through `try_send` and its result is ignored.
    /// NOTE(review): if the command cannot be queued, the task should still stop
    /// forwarding once the receive queue's reader is gone (`QueueOutcome::Closed`).
    fn close_read(&self) {
        let _ = self.inner.commands.try_send(UdpCarrierCommand::CloseRead);
    }
}
174
175impl<T> UdpCarrier<T>
176where
177 T: Send + 'static,
178{
179 fn request_demand(&self, demand: usize) -> StreamResult<()> {
180 self.inner
181 .commands
182 .send_or_blocking(UdpCarrierCommand::Demand(demand))
183 }
184
185 fn send_items(&self, items: Vec<T>) -> StreamResult<()> {
186 self.inner
187 .commands
188 .send_or_blocking(UdpCarrierCommand::SendBatch(items))
189 }
190
191 fn send_one(&self, item: T) -> StreamResult<()> {
192 self.inner
193 .commands
194 .send_or_blocking(UdpCarrierCommand::SendOne(item))
195 }
196
197 fn close_write(&self) -> StreamResult<()> {
198 self.check_send_error()?;
199 let (ack_sender, ack_receiver) = std_mpsc::channel();
200 self.inner
201 .commands
202 .send_or_blocking(UdpCarrierCommand::CloseWrite { ack: ack_sender })?;
203 match ack_receiver.recv() {
204 Ok(result) => result,
205 Err(_) => Err(abrupt_termination()),
206 }?;
207 self.check_send_error()
208 }
209
210 fn check_send_error(&self) -> StreamResult<()> {
211 match self
212 .inner
213 .send_errors
214 .lock()
215 .expect("UDP carrier send error receiver poisoned")
216 .try_recv()
217 {
218 Ok(error) => Err(error),
219 Err(std_mpsc::TryRecvError::Empty) | Err(std_mpsc::TryRecvError::Disconnected) => {
220 Ok(())
221 }
222 }
223 }
224}
225
/// Per-stream state of a send flow.
struct SendResource<T> {
    /// Handle to the carrier task that performs the actual socket writes.
    carrier: UdpCarrier<T>,
    /// Items buffered until `batch_size` is reached (only used when `batch_size > 1`).
    pending: Vec<T>,
    /// Number of items sent per batch; values `<= 1` send each item immediately.
    batch_size: usize,
}

/// A started carrier plus its receive queue; the queue is `None` when the read side
/// was not opened.
type UdpCarrierParts<T> = (
    UdpCarrier<T>,
    Option<std_mpsc::Receiver<ReceiveResponse<T>>>,
);
236
237fn io_error(error: std::io::Error) -> StreamError {
238 StreamError::Failed(error.to_string())
239}
240
241fn abrupt_termination() -> StreamError {
242 StreamError::AbruptTermination
243}
244
245impl TokioUdp {
246 #[must_use]
254 pub fn bind<A>(
255 addr: A,
256 max_datagram_size: usize,
257 receive_buffer: usize,
258 ) -> Source<Datagram, StreamCompletion<UdpBinding>>
259 where
260 A: ToSocketAddrs + Clone + Send + Sync + 'static,
261 {
262 assert!(
263 max_datagram_size > 0,
264 "maximum datagram size must be greater than zero"
265 );
266 assert!(
267 receive_buffer > 0,
268 "receive buffer must be greater than zero"
269 );
270 Source::lazy_future_source(move || {
271 let addr = addr.clone();
272 async move {
273 let handle = Handle::current();
274 let socket = UdpSocket::bind(addr).await.map_err(io_error)?;
275 let local_addr = socket.local_addr().map_err(io_error)?;
276 Ok(datagram_source_from_socket(
277 socket,
278 local_addr,
279 handle,
280 max_datagram_size,
281 receive_buffer,
282 ))
283 }
284 })
285 }
286
287 #[must_use]
290 pub fn bind_default<A>(addr: A) -> Source<Datagram, StreamCompletion<UdpBinding>>
291 where
292 A: ToSocketAddrs + Clone + Send + Sync + 'static,
293 {
294 Self::bind(addr, DEFAULT_MAX_DATAGRAM_SIZE, DEFAULT_RECEIVE_BUFFER)
295 }
296
297 #[must_use]
304 pub fn send_sink<A>(local_addr: A) -> Sink<Datagram, StreamCompletion<NotUsed>>
305 where
306 A: ToSocketAddrs + Clone + Send + Sync + 'static,
307 {
308 Flow::<Datagram, NotUsed>::future_flow(move || {
309 let local_addr = local_addr.clone();
310 async move {
311 let handle = Handle::current();
312 let socket = UdpSocket::bind(local_addr).await.map_err(io_error)?;
313 let carrier = start_datagram_carrier(
314 socket,
315 handle,
316 DEFAULT_MAX_DATAGRAM_SIZE,
317 1,
318 false,
319 true,
320 );
321 Ok(datagram_send_flow_from_carrier(carrier.0, 1))
322 }
323 })
324 .to_mat(Sink::ignore(), Keep::right)
325 }
326
327 #[must_use]
334 pub fn bind_flow<A>(
335 addr: A,
336 max_datagram_size: usize,
337 receive_buffer: usize,
338 ) -> Flow<Datagram, Datagram, StreamCompletion<UdpBinding>>
339 where
340 A: ToSocketAddrs + Clone + Send + Sync + 'static,
341 {
342 assert!(
343 max_datagram_size > 0,
344 "maximum datagram size must be greater than zero"
345 );
346 assert!(
347 receive_buffer > 0,
348 "receive buffer must be greater than zero"
349 );
350 Flow::<Datagram, Datagram>::future_flow(move || {
351 let addr = addr.clone();
352 async move {
353 let handle = Handle::current();
354 let socket = UdpSocket::bind(addr).await.map_err(io_error)?;
355 let local_addr = socket.local_addr().map_err(io_error)?;
356 let (carrier, receiver) = start_datagram_carrier(
357 socket,
358 handle,
359 max_datagram_size,
360 receive_buffer,
361 true,
362 true,
363 );
364 let sink = datagram_send_flow_from_carrier(carrier.clone(), 1)
365 .to_mat(Sink::ignore(), Keep::right);
366 let source = datagram_source_from_carrier(
367 carrier,
368 receiver.expect("UDP bind_flow receiver exists"),
369 local_addr,
370 receive_buffer,
371 );
372 Ok(Flow::from_sink_and_source(sink, source)
373 .map_materialized_value(move |_| UdpBinding { local_addr }))
374 }
375 })
376 }
377
378 #[must_use]
381 pub fn bind_flow_default<A>(addr: A) -> Flow<Datagram, Datagram, StreamCompletion<UdpBinding>>
382 where
383 A: ToSocketAddrs + Clone + Send + Sync + 'static,
384 {
385 Self::bind_flow(addr, DEFAULT_MAX_DATAGRAM_SIZE, DEFAULT_RECEIVE_BUFFER)
386 }
387
388 #[must_use]
395 pub fn connect<A, P>(
396 local_addr: A,
397 peer: P,
398 max_datagram_size: usize,
399 receive_buffer: usize,
400 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<UdpConnection>>
401 where
402 A: ToSocketAddrs + Clone + Send + Sync + 'static,
403 P: ToSocketAddrs + Clone + Send + Sync + 'static,
404 {
405 assert!(
406 max_datagram_size > 0,
407 "maximum datagram size must be greater than zero"
408 );
409 assert!(
410 receive_buffer > 0,
411 "receive buffer must be greater than zero"
412 );
413 Flow::<Vec<u8>, Vec<u8>>::future_flow(move || {
414 let local_addr = local_addr.clone();
415 let peer = peer.clone();
416 async move {
417 let handle = Handle::current();
418 let socket = UdpSocket::bind(local_addr).await.map_err(io_error)?;
419 socket.connect(peer).await.map_err(io_error)?;
420 let connection = UdpConnection {
421 local_addr: socket.local_addr().map_err(io_error)?,
422 remote_addr: socket.peer_addr().map_err(io_error)?,
423 };
424 let (carrier, receiver) = start_connected_carrier(
425 socket,
426 handle,
427 max_datagram_size,
428 receive_buffer,
429 true,
430 true,
431 );
432 let sink = connected_send_flow_from_carrier(carrier.clone(), 1)
433 .to_mat(Sink::ignore(), Keep::right);
434 let source = connected_source_from_carrier(
435 carrier,
436 receiver.expect("connected UDP receiver exists"),
437 receive_buffer,
438 );
439 Ok(Flow::from_sink_and_source(sink, source)
440 .map_materialized_value(move |_| connection))
441 }
442 })
443 }
444
445 #[must_use]
448 pub fn connect_default<A, P>(
449 local_addr: A,
450 peer: P,
451 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<UdpConnection>>
452 where
453 A: ToSocketAddrs + Clone + Send + Sync + 'static,
454 P: ToSocketAddrs + Clone + Send + Sync + 'static,
455 {
456 Self::connect(
457 local_addr,
458 peer,
459 DEFAULT_MAX_DATAGRAM_SIZE,
460 DEFAULT_RECEIVE_BUFFER,
461 )
462 }
463}
464
465fn datagram_source_from_socket(
466 socket: UdpSocket,
467 local_addr: SocketAddr,
468 handle: Handle,
469 max_datagram_size: usize,
470 receive_buffer: usize,
471) -> Source<Datagram, UdpBinding> {
472 let (carrier, receiver) = start_datagram_carrier(
473 socket,
474 handle,
475 max_datagram_size,
476 receive_buffer,
477 true,
478 false,
479 );
480 datagram_source_from_carrier(
481 carrier,
482 receiver.expect("UDP bind receiver exists"),
483 local_addr,
484 receive_buffer,
485 )
486}
487
488fn datagram_source_from_carrier(
489 carrier: UdpCarrier<Datagram>,
490 receiver: std_mpsc::Receiver<ReceiveResponse<Datagram>>,
491 local_addr: SocketAddr,
492 receive_buffer: usize,
493) -> Source<Datagram, UdpBinding> {
494 let receiver = Arc::new(Mutex::new(Some(receiver)));
495 Source::unfold_resource(
496 move || {
497 let receiver = receiver
498 .lock()
499 .expect("UDP receive resource receiver poisoned")
500 .take()
501 .ok_or_else(|| StreamError::Failed("UDP receive source already used".to_owned()))?;
502 let demand = DemandBatcher::new(receive_buffer);
503 carrier.request_demand(demand.initial())?;
504 Ok(ReceiveResource {
505 receiver,
506 carrier: carrier.clone(),
507 demand,
508 })
509 },
510 receive_next_item,
511 close_receive_resource,
512 )
513 .map_materialized_value(move |_| UdpBinding { local_addr })
514}
515
516fn connected_source_from_carrier(
517 carrier: UdpCarrier<Vec<u8>>,
518 receiver: std_mpsc::Receiver<ReceiveResponse<Vec<u8>>>,
519 receive_buffer: usize,
520) -> Source<Vec<u8>, NotUsed> {
521 let receiver = Arc::new(Mutex::new(Some(receiver)));
522 Source::unfold_resource(
523 move || {
524 let receiver = receiver
525 .lock()
526 .expect("connected UDP receive resource receiver poisoned")
527 .take()
528 .ok_or_else(|| {
529 StreamError::Failed("connected UDP receive source already used".to_owned())
530 })?;
531 let demand = DemandBatcher::new(receive_buffer);
532 carrier.request_demand(demand.initial())?;
533 Ok(ReceiveResource {
534 receiver,
535 carrier: carrier.clone(),
536 demand,
537 })
538 },
539 receive_next_item,
540 close_receive_resource,
541 )
542}
543
544fn receive_next_item<T>(resource: &mut ReceiveResource<T>) -> StreamResult<Option<T>>
545where
546 T: Send + 'static,
547{
548 match resource.receiver.recv() {
549 Ok(ReceiveResponse::Item(item)) => {
550 if let Some(demand) = resource.demand.record_consumed() {
551 resource.carrier.request_demand(demand)?;
552 }
553 Ok(Some(item))
554 }
555 Ok(ReceiveResponse::Error(error)) => Err(error),
556 Err(_) => Err(abrupt_termination()),
557 }
558}
559
560fn close_receive_resource<T>(resource: ReceiveResource<T>) -> StreamResult<()>
561where
562 T: Send + 'static,
563{
564 resource.carrier.close_read();
565 Ok(())
566}
567
568fn start_datagram_carrier(
569 socket: UdpSocket,
570 handle: Handle,
571 max_datagram_size: usize,
572 receive_buffer: usize,
573 read_open: bool,
574 write_open: bool,
575) -> UdpCarrierParts<Datagram> {
576 let command_capacity = async_carrier::DEFAULT_COMMAND_BUFFER.max(receive_buffer);
577 let (commands, command_receiver) = async_carrier::command_channel(command_capacity, "UDP");
578 let (send_error_sender, send_error_receiver) = std_mpsc::channel();
579 let (receive_sender, receive_receiver) = if read_open {
580 let (sender, receiver) = std_mpsc::sync_channel(receive_buffer.saturating_add(1));
581 (Some(sender), Some(receiver))
582 } else {
583 (None, None)
584 };
585 let task = handle.spawn(run_datagram_carrier_task(
586 socket,
587 max_datagram_size,
588 receive_sender,
589 send_error_sender,
590 command_receiver,
591 read_open,
592 write_open,
593 ));
594 (
595 UdpCarrier {
596 inner: Arc::new(UdpCarrierInner {
597 commands,
598 send_errors: Mutex::new(send_error_receiver),
599 task: Mutex::new(Some(task)),
600 }),
601 },
602 receive_receiver,
603 )
604}
605
606fn start_connected_carrier(
607 socket: UdpSocket,
608 handle: Handle,
609 max_datagram_size: usize,
610 receive_buffer: usize,
611 read_open: bool,
612 write_open: bool,
613) -> UdpCarrierParts<Vec<u8>> {
614 let command_capacity = async_carrier::DEFAULT_COMMAND_BUFFER.max(receive_buffer);
615 let (commands, command_receiver) =
616 async_carrier::command_channel(command_capacity, "connected UDP");
617 let (send_error_sender, send_error_receiver) = std_mpsc::channel();
618 let (receive_sender, receive_receiver) = if read_open {
619 let (sender, receiver) = std_mpsc::sync_channel(receive_buffer.saturating_add(1));
620 (Some(sender), Some(receiver))
621 } else {
622 (None, None)
623 };
624 let task = handle.spawn(run_connected_carrier_task(
625 socket,
626 max_datagram_size,
627 receive_sender,
628 send_error_sender,
629 command_receiver,
630 read_open,
631 write_open,
632 ));
633 (
634 UdpCarrier {
635 inner: Arc::new(UdpCarrierInner {
636 commands,
637 send_errors: Mutex::new(send_error_receiver),
638 task: Mutex::new(Some(task)),
639 }),
640 },
641 receive_receiver,
642 )
643}
644
/// Socket-owning task behind a datagram (`send_to`/`recv_from`) carrier.
///
/// On each iteration it either waits only for a command (when reading is closed or
/// no read demand is outstanding), or races the next command against the next
/// received datagram. It returns when any of these happens:
/// - both halves are closed;
/// - the command channel is closed;
/// - a command handler reports failure;
/// - receiving fails with an error other than `Interrupted`.
async fn run_datagram_carrier_task(
    socket: UdpSocket,
    max_datagram_size: usize,
    receive_sender: Option<std_mpsc::SyncSender<ReceiveResponse<Datagram>>>,
    send_error_sender: std_mpsc::Sender<StreamError>,
    mut commands: tokio_mpsc::Receiver<UdpCarrierCommand<Datagram>>,
    mut read_open: bool,
    mut write_open: bool,
) {
    // Reused for every receive; its length caps how many bytes are read per datagram.
    let mut buffer = vec![0_u8; max_datagram_size];
    // Elements the reader has asked for that have not yet been queued.
    let mut requested = 0_usize;
    loop {
        if !read_open && !write_open {
            return;
        }

        if read_open && requested > 0 {
            tokio::select! {
                // Poll commands before the socket so pending commands are handled
                // ahead of the next receive.
                biased;
                command = commands.recv() => {
                    let Some(command) = command else { return; };
                    if !handle_datagram_command(
                        &socket,
                        command,
                        &receive_sender,
                        &send_error_sender,
                        &mut read_open,
                        &mut write_open,
                        &mut requested,
                    ).await {
                        return;
                    }
                }
                received = socket.recv_from(&mut buffer) => {
                    match received {
                        Ok((read, remote)) => {
                            let datagram = Datagram::new(buffer[..read].to_vec(), remote);
                            match try_send_received_item(&receive_sender, datagram) {
                                QueueOutcome::Queued => {
                                    requested = requested.saturating_sub(1);
                                }
                                QueueOutcome::Dropped => {
                                    // The reader's queue is full, so this datagram was
                                    // discarded. Reset demand and discard anything else
                                    // already readable; reading resumes only after the
                                    // reader requests more.
                                    requested = 0;
                                    if let Err(error) = drain_ready_datagrams(&socket, &mut buffer) {
                                        report_carrier_error(&receive_sender, &send_error_sender, error);
                                        return;
                                    }
                                }
                                QueueOutcome::Closed => {
                                    // No reader any more: stop reading but keep serving writes.
                                    read_open = false;
                                }
                            }
                        }
                        Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
                        Err(error) => {
                            report_carrier_error(&receive_sender, &send_error_sender, io_error(error));
                            return;
                        }
                    }
                }
            }
        } else {
            let Some(command) = commands.recv().await else {
                return;
            };
            if !handle_datagram_command(
                &socket,
                command,
                &receive_sender,
                &send_error_sender,
                &mut read_open,
                &mut write_open,
                &mut requested,
            )
            .await
            {
                return;
            }
        }
    }
}
726
/// Socket-owning task behind a connected (`send`/`recv`) carrier.
///
/// This mirrors `run_datagram_carrier_task`, but forwards raw payloads without
/// addresses. It returns when any of these happens:
/// - both halves are closed;
/// - the command channel is closed;
/// - a command handler reports failure;
/// - receiving fails with an error other than `Interrupted`.
async fn run_connected_carrier_task(
    socket: UdpSocket,
    max_datagram_size: usize,
    receive_sender: Option<std_mpsc::SyncSender<ReceiveResponse<Vec<u8>>>>,
    send_error_sender: std_mpsc::Sender<StreamError>,
    mut commands: tokio_mpsc::Receiver<UdpCarrierCommand<Vec<u8>>>,
    mut read_open: bool,
    mut write_open: bool,
) {
    // Reused for every receive; its length caps how many bytes are read per datagram.
    let mut buffer = vec![0_u8; max_datagram_size];
    // Elements the reader has asked for that have not yet been queued.
    let mut requested = 0_usize;
    loop {
        if !read_open && !write_open {
            return;
        }

        if read_open && requested > 0 {
            tokio::select! {
                // Poll commands before the socket so pending commands are handled
                // ahead of the next receive.
                biased;
                command = commands.recv() => {
                    let Some(command) = command else { return; };
                    if !handle_connected_command(
                        &socket,
                        command,
                        &receive_sender,
                        &send_error_sender,
                        &mut read_open,
                        &mut write_open,
                        &mut requested,
                    ).await {
                        return;
                    }
                }
                received = socket.recv(&mut buffer) => {
                    match received {
                        Ok(read) => {
                            match try_send_received_item(&receive_sender, buffer[..read].to_vec()) {
                                QueueOutcome::Queued => {
                                    requested = requested.saturating_sub(1);
                                }
                                QueueOutcome::Dropped => {
                                    // The reader's queue is full, so this payload was
                                    // discarded. Reset demand and discard anything else
                                    // already readable; reading resumes only after the
                                    // reader requests more.
                                    requested = 0;
                                    if let Err(error) = drain_ready_connected_datagrams(&socket, &mut buffer) {
                                        report_carrier_error(&receive_sender, &send_error_sender, error);
                                        return;
                                    }
                                }
                                QueueOutcome::Closed => {
                                    // No reader any more: stop reading but keep serving writes.
                                    read_open = false;
                                }
                            }
                        }
                        Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
                        Err(error) => {
                            report_carrier_error(&receive_sender, &send_error_sender, io_error(error));
                            return;
                        }
                    }
                }
            }
        } else {
            let Some(command) = commands.recv().await else {
                return;
            };
            if !handle_connected_command(
                &socket,
                command,
                &receive_sender,
                &send_error_sender,
                &mut read_open,
                &mut write_open,
                &mut requested,
            )
            .await
            {
                return;
            }
        }
    }
}
807
808async fn handle_datagram_command(
809 socket: &UdpSocket,
810 command: UdpCarrierCommand<Datagram>,
811 receive_sender: &Option<std_mpsc::SyncSender<ReceiveResponse<Datagram>>>,
812 send_error_sender: &std_mpsc::Sender<StreamError>,
813 read_open: &mut bool,
814 write_open: &mut bool,
815 requested: &mut usize,
816) -> bool {
817 match command {
818 UdpCarrierCommand::Demand(demand) => {
819 *requested = requested.saturating_add(demand);
820 true
821 }
822 UdpCarrierCommand::SendOne(datagram) => {
823 if !*write_open {
824 let error = StreamError::Failed("UDP write side is closed".to_owned());
825 report_carrier_error(receive_sender, send_error_sender, error);
826 return false;
827 }
828 send_one_datagram(socket, receive_sender, send_error_sender, datagram).await
829 }
830 UdpCarrierCommand::SendBatch(datagrams) => {
831 if !*write_open {
832 let error = StreamError::Failed("UDP write side is closed".to_owned());
833 report_carrier_error(receive_sender, send_error_sender, error);
834 return false;
835 }
836 for datagram in datagrams {
837 if !send_one_datagram(socket, receive_sender, send_error_sender, datagram).await {
838 return false;
839 }
840 }
841 true
842 }
843 UdpCarrierCommand::CloseRead => {
844 *read_open = false;
845 true
846 }
847 UdpCarrierCommand::CloseWrite { ack } => {
848 *write_open = false;
849 let _ = ack.send(Ok(()));
850 true
851 }
852 }
853}
854
855async fn handle_connected_command(
856 socket: &UdpSocket,
857 command: UdpCarrierCommand<Vec<u8>>,
858 receive_sender: &Option<std_mpsc::SyncSender<ReceiveResponse<Vec<u8>>>>,
859 send_error_sender: &std_mpsc::Sender<StreamError>,
860 read_open: &mut bool,
861 write_open: &mut bool,
862 requested: &mut usize,
863) -> bool {
864 match command {
865 UdpCarrierCommand::Demand(demand) => {
866 *requested = requested.saturating_add(demand);
867 true
868 }
869 UdpCarrierCommand::SendOne(payload) => {
870 if !*write_open {
871 let error = StreamError::Failed("connected UDP write side is closed".to_owned());
872 report_carrier_error(receive_sender, send_error_sender, error);
873 return false;
874 }
875 send_one_connected_payload(socket, receive_sender, send_error_sender, payload).await
876 }
877 UdpCarrierCommand::SendBatch(payloads) => {
878 if !*write_open {
879 let error = StreamError::Failed("connected UDP write side is closed".to_owned());
880 report_carrier_error(receive_sender, send_error_sender, error);
881 return false;
882 }
883 for payload in payloads {
884 if !send_one_connected_payload(socket, receive_sender, send_error_sender, payload)
885 .await
886 {
887 return false;
888 }
889 }
890 true
891 }
892 UdpCarrierCommand::CloseRead => {
893 *read_open = false;
894 true
895 }
896 UdpCarrierCommand::CloseWrite { ack } => {
897 *write_open = false;
898 let _ = ack.send(Ok(()));
899 true
900 }
901 }
902}
903
904async fn send_one_datagram(
905 socket: &UdpSocket,
906 receive_sender: &Option<std_mpsc::SyncSender<ReceiveResponse<Datagram>>>,
907 send_error_sender: &std_mpsc::Sender<StreamError>,
908 datagram: Datagram,
909) -> bool {
910 let expected = datagram.payload.len();
911 match socket.send_to(&datagram.payload, datagram.remote).await {
912 Ok(sent) if sent == expected => true,
913 Ok(sent) => {
914 report_carrier_error(
915 receive_sender,
916 send_error_sender,
917 short_send_error(sent, expected),
918 );
919 false
920 }
921 Err(error) => {
922 report_carrier_error(receive_sender, send_error_sender, io_error(error));
923 false
924 }
925 }
926}
927
928async fn send_one_connected_payload(
929 socket: &UdpSocket,
930 receive_sender: &Option<std_mpsc::SyncSender<ReceiveResponse<Vec<u8>>>>,
931 send_error_sender: &std_mpsc::Sender<StreamError>,
932 payload: Vec<u8>,
933) -> bool {
934 let expected = payload.len();
935 match socket.send(&payload).await {
936 Ok(sent) if sent == expected => true,
937 Ok(sent) => {
938 report_carrier_error(
939 receive_sender,
940 send_error_sender,
941 short_send_error(sent, expected),
942 );
943 false
944 }
945 Err(error) => {
946 report_carrier_error(receive_sender, send_error_sender, io_error(error));
947 false
948 }
949 }
950}
951
952fn try_send_received_item<T>(
953 sender: &Option<std_mpsc::SyncSender<ReceiveResponse<T>>>,
954 item: T,
955) -> QueueOutcome
956where
957 T: Send + 'static,
958{
959 let Some(sender) = sender else {
960 return QueueOutcome::Closed;
961 };
962 match sender.try_send(ReceiveResponse::Item(item)) {
963 Ok(()) => QueueOutcome::Queued,
964 Err(std_mpsc::TrySendError::Full(_)) => QueueOutcome::Dropped,
965 Err(std_mpsc::TrySendError::Disconnected(_)) => QueueOutcome::Closed,
966 }
967}
968
969fn drain_ready_datagrams(socket: &UdpSocket, buffer: &mut [u8]) -> StreamResult<()> {
970 loop {
971 match socket.try_recv_from(buffer) {
972 Ok((_read, _remote)) => {}
973 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
974 Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
975 Err(error) => return Err(io_error(error)),
976 }
977 }
978}
979
980fn drain_ready_connected_datagrams(socket: &UdpSocket, buffer: &mut [u8]) -> StreamResult<()> {
981 loop {
982 match socket.try_recv(buffer) {
983 Ok(_read) => {}
984 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
985 Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
986 Err(error) => return Err(io_error(error)),
987 }
988 }
989}
990
991fn report_carrier_error<T>(
992 receive_sender: &Option<std_mpsc::SyncSender<ReceiveResponse<T>>>,
993 send_error_sender: &std_mpsc::Sender<StreamError>,
994 error: StreamError,
995) where
996 T: Send + 'static,
997{
998 let _ = send_error_sender.send(error.clone());
999 if let Some(receive_sender) = receive_sender {
1000 let _ = receive_sender.try_send(ReceiveResponse::Error(error));
1001 }
1002}
1003
1004fn datagram_send_flow_from_carrier(
1005 carrier: UdpCarrier<Datagram>,
1006 batch_size: usize,
1007) -> Flow<Datagram, NotUsed, NotUsed> {
1008 Flow::<Datagram, Datagram>::identity().map_with_resource(
1009 move || {
1010 Ok(SendResource {
1011 carrier: carrier.clone(),
1012 pending: Vec::with_capacity(batch_size),
1013 batch_size,
1014 })
1015 },
1016 |resource, datagram| {
1017 send_datagram(resource, datagram)?;
1018 Ok(NotUsed)
1019 },
1020 close_send_resource,
1021 )
1022}
1023
1024fn connected_send_flow_from_carrier(
1025 carrier: UdpCarrier<Vec<u8>>,
1026 batch_size: usize,
1027) -> Flow<Vec<u8>, NotUsed, NotUsed> {
1028 Flow::<Vec<u8>, Vec<u8>>::identity().map_with_resource(
1029 move || {
1030 Ok(SendResource {
1031 carrier: carrier.clone(),
1032 pending: Vec::with_capacity(batch_size),
1033 batch_size,
1034 })
1035 },
1036 |resource, payload| {
1037 send_connected_payload(resource, payload)?;
1038 Ok(NotUsed)
1039 },
1040 close_send_resource,
1041 )
1042}
1043
1044fn close_send_resource<T>(mut resource: SendResource<T>) -> StreamResult<Option<NotUsed>>
1045where
1046 T: Send + 'static,
1047{
1048 flush_send_resource(&mut resource)?;
1049 resource.carrier.close_write()?;
1050 Ok(None)
1051}
1052
1053fn send_datagram(resource: &mut SendResource<Datagram>, datagram: Datagram) -> StreamResult<()> {
1054 send_item(resource, datagram)
1055}
1056
1057fn send_connected_payload(
1058 resource: &mut SendResource<Vec<u8>>,
1059 payload: Vec<u8>,
1060) -> StreamResult<()> {
1061 send_item(resource, payload)
1062}
1063
1064fn send_item<T>(resource: &mut SendResource<T>, item: T) -> StreamResult<()>
1065where
1066 T: Send + 'static,
1067{
1068 if resource.batch_size <= 1 {
1069 return resource.carrier.send_one(item);
1070 }
1071 resource.pending.push(item);
1072 if resource.pending.len() >= resource.batch_size {
1073 flush_send_resource(resource)?;
1074 }
1075 Ok(())
1076}
1077
1078fn flush_send_resource<T>(resource: &mut SendResource<T>) -> StreamResult<()>
1079where
1080 T: Send + 'static,
1081{
1082 if resource.pending.is_empty() {
1083 return resource.carrier.check_send_error();
1084 }
1085 let pending = std::mem::take(&mut resource.pending);
1086 resource.carrier.send_items(pending)
1087}
1088
1089fn short_send_error(sent: usize, expected: usize) -> StreamError {
1090 StreamError::Failed(format!(
1091 "UDP socket sent {sent} bytes from {expected}-byte datagram"
1092 ))
1093}