1use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
17use std::net::SocketAddr;
18use std::sync::Arc;
19use tokio::net::{ToSocketAddrs, UdpSocket};
20use tokio::runtime::Handle;
21use tokio::sync::{mpsc, watch};
22use tokio::task::JoinHandle;
23
/// Default size, in bytes, of the buffer each receive task reads datagrams into.
///
/// 64 KiB is at least as large as any non-jumbogram UDP payload, so the default never truncates.
pub const DEFAULT_MAX_DATAGRAM_SIZE: usize = 64 * 1024;

/// Default capacity, in datagrams, of the queue between a receive task and its stream.
pub const DEFAULT_RECEIVE_BUFFER: usize = 1 << 6;
33
/// A single UDP datagram: its payload bytes plus the peer address.
///
/// For received datagrams `remote` is the sender. For outgoing datagrams it is the destination.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Datagram {
    /// Raw payload bytes.
    pub payload: Vec<u8>,
    /// Peer address (sender when received, destination when sent).
    pub remote: SocketAddr,
}

impl Datagram {
    /// Builds a datagram from anything convertible into an owned byte vector.
    #[must_use]
    pub fn new(payload: impl Into<Vec<u8>>, remote: SocketAddr) -> Self {
        let payload: Vec<u8> = payload.into();
        Self { remote, payload }
    }

    /// Borrows the payload bytes.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Returns the peer address.
    #[must_use]
    pub fn remote(&self) -> SocketAddr {
        self.remote
    }

    /// Consumes the datagram, returning `(payload, remote)`.
    #[must_use]
    pub fn into_parts(self) -> (Vec<u8>, SocketAddr) {
        let Self { payload, remote } = self;
        (payload, remote)
    }

    /// Consumes the datagram, keeping only its payload.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let (payload, _remote) = self.into_parts();
        payload
    }
}
77
/// Address information for a bound (unconnected) UDP socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdpBinding {
    /// Local address the socket actually bound to (e.g. the concrete port when 0 was requested).
    pub local_addr: SocketAddr,
}

impl UdpBinding {
    /// Returns the local address the socket is bound to.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
90
/// Local and peer addresses of a connected UDP socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdpConnection {
    /// Local address the socket is bound to.
    pub local_addr: SocketAddr,
    /// Peer address the socket is connected to.
    pub remote_addr: SocketAddr,
}

impl UdpConnection {
    /// Returns the local address of the connected socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Returns the peer address the socket is connected to.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
109
/// Stream-based UDP endpoints backed by Tokio sockets.
///
/// This is a zero-sized namespace type; all functionality is exposed as associated functions.
/// It derives the usual traits so it can be stored, copied and debug-printed like any other
/// public value.
#[derive(Debug, Clone, Copy, Default)]
pub struct TokioUdp;

/// Short alias for [`TokioUdp`].
pub type Udp = TokioUdp;
115
/// Message a receive task sends to the stream side over its mpsc channel.
enum ReceiveResponse<T> {
    /// A received item (a `Datagram` or a connected payload).
    Item(T),
    /// A terminal receive error; the task returns right after attempting to send it.
    Error(StreamError),
}
120
/// Result of a non-blocking attempt to hand a received item to the stream side.
enum QueueOutcome {
    /// The item was accepted into the channel.
    Queued,
    /// The channel was full; the item was discarded.
    Dropped,
    /// The receiving side is gone; the receive task should stop.
    Closed,
}
126
/// Per-stream state for a receive source: the consumer end of the item channel, the
/// cancellation signal, and the handle of the spawned receive task.
struct ReceiveResource<T> {
    // Stream-side end of the channel the receive task feeds.
    receiver: mpsc::Receiver<ReceiveResponse<T>>,
    // Setting this to `true` wakes the task's `select!` on `cancel.changed()` so it returns.
    cancel: watch::Sender<bool>,
    // Handle to the spawned receive task; used to abort it on teardown.
    task: JoinHandle<()>,
}

impl<T> Drop for ReceiveResource<T> {
    /// Stops the receive task: signals cancellation, then aborts the task outright.
    fn drop(&mut self) {
        // The send error (all receivers gone, i.e. the task already exited) is irrelevant here.
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
139
/// Per-stream state for the send flows: the (possibly shared) socket and the runtime handle
/// used to `block_on` each send.
struct SendResource {
    socket: Arc<UdpSocket>,
    handle: Handle,
}
144
145fn io_error(error: std::io::Error) -> StreamError {
146 StreamError::Failed(error.to_string())
147}
148
149fn abrupt_termination() -> StreamError {
150 StreamError::AbruptTermination
151}
152
153impl TokioUdp {
154 #[must_use]
162 pub fn bind<A>(
163 addr: A,
164 max_datagram_size: usize,
165 receive_buffer: usize,
166 ) -> Source<Datagram, StreamCompletion<UdpBinding>>
167 where
168 A: ToSocketAddrs + Clone + Send + Sync + 'static,
169 {
170 assert!(
171 max_datagram_size > 0,
172 "maximum datagram size must be greater than zero"
173 );
174 assert!(
175 receive_buffer > 0,
176 "receive buffer must be greater than zero"
177 );
178 Source::lazy_future_source(move || {
179 let addr = addr.clone();
180 async move {
181 let handle = Handle::current();
182 let socket = UdpSocket::bind(addr).await.map_err(io_error)?;
183 let local_addr = socket.local_addr().map_err(io_error)?;
184 Ok(datagram_source_from_socket(
185 Arc::new(socket),
186 local_addr,
187 handle,
188 max_datagram_size,
189 receive_buffer,
190 ))
191 }
192 })
193 }
194
195 #[must_use]
198 pub fn bind_default<A>(addr: A) -> Source<Datagram, StreamCompletion<UdpBinding>>
199 where
200 A: ToSocketAddrs + Clone + Send + Sync + 'static,
201 {
202 Self::bind(addr, DEFAULT_MAX_DATAGRAM_SIZE, DEFAULT_RECEIVE_BUFFER)
203 }
204
205 #[must_use]
212 pub fn send_sink<A>(local_addr: A) -> Sink<Datagram, StreamCompletion<NotUsed>>
213 where
214 A: ToSocketAddrs + Clone + Send + Sync + 'static,
215 {
216 Flow::<Datagram, NotUsed>::future_flow(move || {
217 let local_addr = local_addr.clone();
218 async move {
219 let handle = Handle::current();
220 let socket = UdpSocket::bind(local_addr).await.map_err(io_error)?;
221 Ok(datagram_send_flow_from_socket(Arc::new(socket), handle))
222 }
223 })
224 .to_mat(Sink::ignore(), Keep::right)
225 }
226
227 #[must_use]
234 pub fn bind_flow<A>(
235 addr: A,
236 max_datagram_size: usize,
237 receive_buffer: usize,
238 ) -> Flow<Datagram, Datagram, StreamCompletion<UdpBinding>>
239 where
240 A: ToSocketAddrs + Clone + Send + Sync + 'static,
241 {
242 assert!(
243 max_datagram_size > 0,
244 "maximum datagram size must be greater than zero"
245 );
246 assert!(
247 receive_buffer > 0,
248 "receive buffer must be greater than zero"
249 );
250 Flow::<Datagram, Datagram>::future_flow(move || {
251 let addr = addr.clone();
252 async move {
253 let handle = Handle::current();
254 let socket = Arc::new(UdpSocket::bind(addr).await.map_err(io_error)?);
255 let local_addr = socket.local_addr().map_err(io_error)?;
256 let sink = datagram_send_flow_from_socket(Arc::clone(&socket), handle.clone())
257 .to_mat(Sink::ignore(), Keep::right);
258 let source = datagram_source_from_socket(
259 Arc::clone(&socket),
260 local_addr,
261 handle,
262 max_datagram_size,
263 receive_buffer,
264 );
265 Ok(Flow::from_sink_and_source(sink, source)
266 .map_materialized_value(move |_| UdpBinding { local_addr }))
267 }
268 })
269 }
270
271 #[must_use]
274 pub fn bind_flow_default<A>(addr: A) -> Flow<Datagram, Datagram, StreamCompletion<UdpBinding>>
275 where
276 A: ToSocketAddrs + Clone + Send + Sync + 'static,
277 {
278 Self::bind_flow(addr, DEFAULT_MAX_DATAGRAM_SIZE, DEFAULT_RECEIVE_BUFFER)
279 }
280
281 #[must_use]
288 pub fn connect<A, P>(
289 local_addr: A,
290 peer: P,
291 max_datagram_size: usize,
292 receive_buffer: usize,
293 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<UdpConnection>>
294 where
295 A: ToSocketAddrs + Clone + Send + Sync + 'static,
296 P: ToSocketAddrs + Clone + Send + Sync + 'static,
297 {
298 assert!(
299 max_datagram_size > 0,
300 "maximum datagram size must be greater than zero"
301 );
302 assert!(
303 receive_buffer > 0,
304 "receive buffer must be greater than zero"
305 );
306 Flow::<Vec<u8>, Vec<u8>>::future_flow(move || {
307 let local_addr = local_addr.clone();
308 let peer = peer.clone();
309 async move {
310 let handle = Handle::current();
311 let socket = UdpSocket::bind(local_addr).await.map_err(io_error)?;
312 socket.connect(peer).await.map_err(io_error)?;
313 let connection = UdpConnection {
314 local_addr: socket.local_addr().map_err(io_error)?,
315 remote_addr: socket.peer_addr().map_err(io_error)?,
316 };
317 let socket = Arc::new(socket);
318 let sink = connected_send_flow_from_socket(Arc::clone(&socket), handle.clone())
319 .to_mat(Sink::ignore(), Keep::right);
320 let source = connected_source_from_socket(
321 Arc::clone(&socket),
322 handle,
323 max_datagram_size,
324 receive_buffer,
325 );
326 Ok(Flow::from_sink_and_source(sink, source)
327 .map_materialized_value(move |_| connection))
328 }
329 })
330 }
331
332 #[must_use]
335 pub fn connect_default<A, P>(
336 local_addr: A,
337 peer: P,
338 ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<UdpConnection>>
339 where
340 A: ToSocketAddrs + Clone + Send + Sync + 'static,
341 P: ToSocketAddrs + Clone + Send + Sync + 'static,
342 {
343 Self::connect(
344 local_addr,
345 peer,
346 DEFAULT_MAX_DATAGRAM_SIZE,
347 DEFAULT_RECEIVE_BUFFER,
348 )
349 }
350}
351
352fn datagram_source_from_socket(
353 socket: Arc<UdpSocket>,
354 local_addr: SocketAddr,
355 handle: Handle,
356 max_datagram_size: usize,
357 receive_buffer: usize,
358) -> Source<Datagram, UdpBinding> {
359 Source::unfold_resource(
360 move || {
361 let (sender, receiver) = mpsc::channel(receive_buffer);
362 let (cancel_sender, cancel_receiver) = watch::channel(false);
363 let task = handle.spawn(run_datagram_receive_task(
364 Arc::clone(&socket),
365 max_datagram_size,
366 sender,
367 cancel_receiver,
368 ));
369 Ok(ReceiveResource {
370 receiver,
371 cancel: cancel_sender,
372 task,
373 })
374 },
375 receive_next_item,
376 close_receive_resource,
377 )
378 .map_materialized_value(move |_| UdpBinding { local_addr })
379}
380
381fn connected_source_from_socket(
382 socket: Arc<UdpSocket>,
383 handle: Handle,
384 max_datagram_size: usize,
385 receive_buffer: usize,
386) -> Source<Vec<u8>, NotUsed> {
387 Source::unfold_resource(
388 move || {
389 let (sender, receiver) = mpsc::channel(receive_buffer);
390 let (cancel_sender, cancel_receiver) = watch::channel(false);
391 let task = handle.spawn(run_connected_receive_task(
392 Arc::clone(&socket),
393 max_datagram_size,
394 sender,
395 cancel_receiver,
396 ));
397 Ok(ReceiveResource {
398 receiver,
399 cancel: cancel_sender,
400 task,
401 })
402 },
403 receive_next_item,
404 close_receive_resource,
405 )
406}
407
408fn receive_next_item<T>(resource: &mut ReceiveResource<T>) -> StreamResult<Option<T>>
409where
410 T: Send + 'static,
411{
412 match resource.receiver.blocking_recv() {
413 Some(ReceiveResponse::Item(item)) => Ok(Some(item)),
414 Some(ReceiveResponse::Error(error)) => Err(error),
415 None => Err(abrupt_termination()),
416 }
417}
418
419fn close_receive_resource<T>(resource: ReceiveResource<T>) -> StreamResult<()>
420where
421 T: Send + 'static,
422{
423 let _ = resource.cancel.send(true);
424 resource.task.abort();
425 Ok(())
426}
427
428async fn run_datagram_receive_task(
429 socket: Arc<UdpSocket>,
430 max_datagram_size: usize,
431 sender: mpsc::Sender<ReceiveResponse<Datagram>>,
432 mut cancel: watch::Receiver<bool>,
433) {
434 let mut buffer = vec![0_u8; max_datagram_size];
435 loop {
436 let received = tokio::select! {
437 received = socket.recv_from(&mut buffer) => received,
438 changed = cancel.changed() => {
439 let _ = changed;
440 return;
441 }
442 };
443
444 match received {
445 Ok((read, remote)) => {
446 let datagram = Datagram::new(buffer[..read].to_vec(), remote);
447 match try_send_received_item(&sender, datagram) {
448 QueueOutcome::Queued => {}
449 QueueOutcome::Dropped => {
450 if let Err(error) = drain_ready_datagrams(&socket, &mut buffer) {
451 let _ = send_receive_error(&sender, error, &mut cancel).await;
452 return;
453 }
454 }
455 QueueOutcome::Closed => return,
456 }
457 }
458 Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
459 Err(error) => {
460 let _ = send_receive_error(&sender, io_error(error), &mut cancel).await;
461 return;
462 }
463 }
464 }
465}
466
467async fn run_connected_receive_task(
468 socket: Arc<UdpSocket>,
469 max_datagram_size: usize,
470 sender: mpsc::Sender<ReceiveResponse<Vec<u8>>>,
471 mut cancel: watch::Receiver<bool>,
472) {
473 let mut buffer = vec![0_u8; max_datagram_size];
474 loop {
475 let received = tokio::select! {
476 received = socket.recv(&mut buffer) => received,
477 changed = cancel.changed() => {
478 let _ = changed;
479 return;
480 }
481 };
482
483 match received {
484 Ok(read) => match try_send_received_item(&sender, buffer[..read].to_vec()) {
485 QueueOutcome::Queued => {}
486 QueueOutcome::Dropped => {
487 if let Err(error) = drain_ready_connected_datagrams(&socket, &mut buffer) {
488 let _ = send_receive_error(&sender, error, &mut cancel).await;
489 return;
490 }
491 }
492 QueueOutcome::Closed => return,
493 },
494 Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
495 Err(error) => {
496 let _ = send_receive_error(&sender, io_error(error), &mut cancel).await;
497 return;
498 }
499 }
500 }
501}
502
503fn try_send_received_item<T>(sender: &mpsc::Sender<ReceiveResponse<T>>, item: T) -> QueueOutcome
504where
505 T: Send + 'static,
506{
507 match sender.try_send(ReceiveResponse::Item(item)) {
508 Ok(()) => QueueOutcome::Queued,
509 Err(mpsc::error::TrySendError::Full(_)) => QueueOutcome::Dropped,
510 Err(mpsc::error::TrySendError::Closed(_)) => QueueOutcome::Closed,
511 }
512}
513
514fn drain_ready_datagrams(socket: &UdpSocket, buffer: &mut [u8]) -> StreamResult<()> {
515 loop {
516 match socket.try_recv_from(buffer) {
517 Ok((_read, _remote)) => {}
518 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
519 Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
520 Err(error) => return Err(io_error(error)),
521 }
522 }
523}
524
525fn drain_ready_connected_datagrams(socket: &UdpSocket, buffer: &mut [u8]) -> StreamResult<()> {
526 loop {
527 match socket.try_recv(buffer) {
528 Ok(_read) => {}
529 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
530 Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
531 Err(error) => return Err(io_error(error)),
532 }
533 }
534}
535
536async fn send_receive_error<T>(
537 sender: &mpsc::Sender<ReceiveResponse<T>>,
538 error: StreamError,
539 cancel: &mut watch::Receiver<bool>,
540) -> bool
541where
542 T: Send + 'static,
543{
544 tokio::select! {
545 result = sender.send(ReceiveResponse::Error(error)) => result.is_ok(),
546 changed = cancel.changed() => {
547 let _ = changed;
548 false
549 }
550 }
551}
552
553fn datagram_send_flow_from_socket(
554 socket: Arc<UdpSocket>,
555 handle: Handle,
556) -> Flow<Datagram, NotUsed, NotUsed> {
557 Flow::<Datagram, Datagram>::identity().map_with_resource(
558 move || {
559 Ok(SendResource {
560 socket: Arc::clone(&socket),
561 handle: handle.clone(),
562 })
563 },
564 |resource, datagram| {
565 send_datagram(resource, datagram)?;
566 Ok(NotUsed)
567 },
568 |_resource| Ok(None),
569 )
570}
571
572fn connected_send_flow_from_socket(
573 socket: Arc<UdpSocket>,
574 handle: Handle,
575) -> Flow<Vec<u8>, NotUsed, NotUsed> {
576 Flow::<Vec<u8>, Vec<u8>>::identity().map_with_resource(
577 move || {
578 Ok(SendResource {
579 socket: Arc::clone(&socket),
580 handle: handle.clone(),
581 })
582 },
583 |resource, payload| {
584 send_connected_payload(resource, payload)?;
585 Ok(NotUsed)
586 },
587 |_resource| Ok(None),
588 )
589}
590
591fn send_datagram(resource: &SendResource, datagram: Datagram) -> StreamResult<()> {
592 let expected = datagram.payload.len();
593 let sent = resource.handle.block_on(async {
594 resource
595 .socket
596 .send_to(&datagram.payload, datagram.remote)
597 .await
598 .map_err(io_error)
599 })?;
600 if sent == expected {
601 Ok(())
602 } else {
603 Err(short_send_error(sent, expected))
604 }
605}
606
607fn send_connected_payload(resource: &SendResource, payload: Vec<u8>) -> StreamResult<()> {
608 let expected = payload.len();
609 let sent = resource
610 .handle
611 .block_on(async { resource.socket.send(&payload).await.map_err(io_error) })?;
612 if sent == expected {
613 Ok(())
614 } else {
615 Err(short_send_error(sent, expected))
616 }
617}
618
619fn short_send_error(sent: usize, expected: usize) -> StreamError {
620 StreamError::Failed(format!(
621 "UDP socket sent {sent} bytes from {expected}-byte datagram"
622 ))
623}