1use std::io;
2use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use async_trait::async_trait;
7use bytes::BytesMut;
8use coap_lite::Packet;
9use futures::{Sink, Stream};
10use log::debug;
11use pin_project::pin_project;
12use tokio::net::{ToSocketAddrs, UdpSocket};
13use tokio_util::codec::{Decoder, Encoder};
14
15use crate::transport::{BoxedFramedBinding, FramedBinding, Transport, TransportError};
16use crate::udp::udp_framed_fork::UdpFramed;
17
18pub(crate) mod udp_framed_fork;
19
/// Configuration for a UDP transport; binding it yields a framed binding whose
/// endpoints are [`SocketAddr`]s.
pub struct UdpTransport<A: ToSocketAddrs> {
    /// Address(es) handed to `UdpSocket::bind` when the transport is bound.
    addresses: A,
    /// Optional MTU; copied into the binding and returned from its `mtu()`.
    mtu: Option<u32>,
    /// When `true`, binding also joins multicast groups via `multicast_joiner`.
    multicast: bool,
    /// Multicast join configuration (custom groups and default interfaces).
    multicast_joiner: MulticastJoiner,
}
28
/// Settings that control which multicast groups a bound socket joins and on
/// which interfaces.
#[derive(Default, Debug, Clone)]
struct MulticastJoiner {
    /// Explicitly requested group joins. When empty, default groups are chosen
    /// from the address family of the bound socket.
    custom_group_joins: Vec<MulticastGroupJoin>,
    /// Fallback interface for IPv4 joins that do not name one themselves.
    default_ipv4_interface: Option<Ipv4Addr>,
    /// Fallback interface value for IPv6 joins that do not name one themselves;
    /// it is passed as the `interface` argument of `join_multicast_v6`
    /// (presumably an interface index — confirm against the socket API).
    default_ipv6_interface: Option<u32>,
}
35
/// A single multicast group to join, with an optional interface override.
#[derive(Debug, Clone)]
pub enum MulticastGroupJoin {
    /// IPv4 group address, plus an optional interface address. When `None`, the
    /// configured default IPv4 interface or the socket's local IPv4 address is
    /// used instead.
    Ipv4(Ipv4Addr, Option<Ipv4Addr>),
    /// IPv6 group address, plus an optional interface value. When `None`, the
    /// configured default IPv6 interface is used, falling back to `0`.
    Ipv6(Ipv6Addr, Option<u32>),
}
41
42impl<A: ToSocketAddrs> UdpTransport<A> {
43 pub fn new(addresses: A) -> Self {
44 let (mtu, multicast, multicast_joiner) = Default::default();
45 Self {
46 addresses,
47 mtu,
48 multicast,
49 multicast_joiner,
50 }
51 }
52
53 pub fn set_mtu(mut self, mtu: u32) -> Self {
55 self.mtu = Some(mtu);
56 self
57 }
58
59 pub fn enable_multicast(mut self) -> Self {
60 self.multicast = true;
61 self
62 }
63
64 pub fn set_multicast_default_ipv4_interface(mut self, interface: Ipv4Addr) -> Self {
65 self.multicast_joiner.default_ipv4_interface = Some(interface);
66 self
67 }
68
69 pub fn set_multicast_default_ipv6_interface(mut self, interface: u32) -> Self {
70 self.multicast_joiner.default_ipv6_interface = Some(interface);
71 self
72 }
73
74 pub fn add_multicast_join(mut self, join: MulticastGroupJoin) -> Self {
75 self.multicast_joiner.custom_group_joins.push(join);
76 self
77 }
78}
79
80#[async_trait]
81impl<A: ToSocketAddrs + Sync + Send> Transport for UdpTransport<A> {
82 type Endpoint = SocketAddr;
83
84 async fn bind(self) -> Result<BoxedFramedBinding<Self::Endpoint>, TransportError> {
85 let socket = UdpSocket::bind(self.addresses).await?;
86 if self.multicast {
87 self.multicast_joiner.join(&socket)?;
88 }
89 let local_addr = socket.local_addr()?;
90 let framed_socket = UdpFramed::new(socket, Codec::default());
91 let binding = UdpBinding {
92 framed_socket,
93 local_addr,
94 mtu: self.mtu,
95 };
96 Ok(Box::pin(binding))
97 }
98}
99
100impl MulticastJoiner {
101 fn join(self, socket: &UdpSocket) -> io::Result<()> {
102 let ipv4_interface = match socket.local_addr()? {
103 SocketAddr::V4(ipv4) => Some(*ipv4.ip()),
104 _ => None,
105 };
106 let joins = if self.custom_group_joins.is_empty() {
107 self.determine_default_joins(&socket)?
108 } else {
109 self.custom_group_joins
110 };
111
112 for join in joins {
113 debug!("Joining {join:?}...");
114 match join {
115 MulticastGroupJoin::Ipv4(addr, interface) => {
116 let resolved_interface = interface
117 .or(self.default_ipv4_interface)
118 .or(ipv4_interface)
119 .unwrap();
120 socket.join_multicast_v4(addr, resolved_interface)?;
121 }
122 MulticastGroupJoin::Ipv6(addr, interface) => {
123 let resolved_interface = interface.or(self.default_ipv6_interface).unwrap_or(0);
124 socket.join_multicast_v6(&addr, resolved_interface)?;
125 }
126 }
127 }
128
129 Ok(())
130 }
131
132 fn determine_default_joins(&self, socket: &UdpSocket) -> io::Result<Vec<MulticastGroupJoin>> {
133 let mut joins = Vec::new();
134 match socket.local_addr()? {
135 SocketAddr::V4(_) => {
136 joins.push(MulticastGroupJoin::Ipv4(
137 "224.0.1.187".parse().unwrap(),
138 None,
139 ));
140 }
141 SocketAddr::V6(_) => {
142 joins.push(MulticastGroupJoin::Ipv6("ff02::fd".parse().unwrap(), None));
143 joins.push(MulticastGroupJoin::Ipv6("ff05::fd".parse().unwrap(), None));
144 }
145 }
146 Ok(joins)
147 }
148}
149
/// A bound UDP socket exposed as a framed stream/sink of packets paired with
/// peer addresses.
#[pin_project]
struct UdpBinding {
    /// The framed socket that all `Stream` and `Sink` calls are forwarded to.
    #[pin]
    framed_socket: UdpFramed<Codec>,
    /// Local address of the socket, captured at bind time.
    local_addr: SocketAddr,
    /// MTU configured on the transport, if any.
    mtu: Option<u32>,
}
157
#[async_trait]
impl FramedBinding<SocketAddr> for UdpBinding {
    /// Returns the MTU stored in this binding, or `None` if none was configured.
    fn mtu(&self) -> Option<u32> {
        self.mtu
    }
}
164
165impl Stream for UdpBinding {
166 type Item = Result<(Packet, SocketAddr), (TransportError, Option<SocketAddr>)>;
167
168 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169 self.project().framed_socket.poll_next(cx)
170 }
171}
172
173impl Sink<(Packet, SocketAddr)> for UdpBinding {
174 type Error = TransportError;
175
176 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177 self.project().framed_socket.poll_ready(cx)
178 }
179
180 fn start_send(self: Pin<&mut Self>, item: (Packet, SocketAddr)) -> Result<(), Self::Error> {
181 self.project().framed_socket.start_send(item)
182 }
183
184 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
185 self.project().framed_socket.poll_flush(cx)
186 }
187
188 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 self.project().framed_socket.poll_close(cx)
190 }
191}
192
/// Stateless codec that converts between raw datagram bytes and [`Packet`]s.
#[derive(Default)]
struct Codec;
195
196impl Decoder for Codec {
197 type Item = Packet;
198 type Error = TransportError;
199
200 fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Packet>, TransportError> {
201 if buf.is_empty() {
202 return Ok(None);
203 }
204 let result = (|| Ok(Some(Packet::from_bytes(buf)?)))();
205 buf.clear();
206 result
207 }
208}
209
210impl Encoder<Packet> for Codec {
211 type Error = TransportError;
212
213 fn encode(&mut self, my_packet: Packet, buf: &mut BytesMut) -> Result<(), TransportError> {
214 buf.extend_from_slice(&my_packet.to_bytes()?[..]);
215 Ok(())
216 }
217}