coap_server/udp/
mod.rs

1use std::io;
2use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use async_trait::async_trait;
7use bytes::BytesMut;
8use coap_lite::Packet;
9use futures::{Sink, Stream};
10use log::debug;
11use pin_project::pin_project;
12use tokio::net::{ToSocketAddrs, UdpSocket};
13use tokio_util::codec::{Decoder, Encoder};
14
15use crate::transport::{BoxedFramedBinding, FramedBinding, Transport, TransportError};
16use crate::udp::udp_framed_fork::UdpFramed;
17
18pub(crate) mod udp_framed_fork;
19
20/// Default CoAP transport as originally defined in RFC 7252.  Likely this is what you want if
21/// you're new to CoAP.
22pub struct UdpTransport<A: ToSocketAddrs> {
23    addresses: A,
24    mtu: Option<u32>,
25    multicast: bool,
26    multicast_joiner: MulticastJoiner,
27}
28
29#[derive(Default, Debug, Clone)]
30struct MulticastJoiner {
31    custom_group_joins: Vec<MulticastGroupJoin>,
32    default_ipv4_interface: Option<Ipv4Addr>,
33    default_ipv6_interface: Option<u32>,
34}
35
/// A single multicast group membership to request on the bound UDP socket.
///
/// The second field optionally selects the local interface for the join. When it is `None`,
/// the transport falls back to the default interface configured on [`UdpTransport`] (and then
/// to its own further fallbacks).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MulticastGroupJoin {
    /// Join an IPv4 group, optionally on the interface with the given local address.
    Ipv4(Ipv4Addr, Option<Ipv4Addr>),
    /// Join an IPv6 group, optionally on the interface with the given interface index.
    Ipv6(Ipv6Addr, Option<u32>),
}
41
42impl<A: ToSocketAddrs> UdpTransport<A> {
43    pub fn new(addresses: A) -> Self {
44        let (mtu, multicast, multicast_joiner) = Default::default();
45        Self {
46            addresses,
47            mtu,
48            multicast,
49            multicast_joiner,
50        }
51    }
52
53    /// Manually set the MTU that will be used for block-wise transfer handling purposes.
54    pub fn set_mtu(mut self, mtu: u32) -> Self {
55        self.mtu = Some(mtu);
56        self
57    }
58
59    pub fn enable_multicast(mut self) -> Self {
60        self.multicast = true;
61        self
62    }
63
64    pub fn set_multicast_default_ipv4_interface(mut self, interface: Ipv4Addr) -> Self {
65        self.multicast_joiner.default_ipv4_interface = Some(interface);
66        self
67    }
68
69    pub fn set_multicast_default_ipv6_interface(mut self, interface: u32) -> Self {
70        self.multicast_joiner.default_ipv6_interface = Some(interface);
71        self
72    }
73
74    pub fn add_multicast_join(mut self, join: MulticastGroupJoin) -> Self {
75        self.multicast_joiner.custom_group_joins.push(join);
76        self
77    }
78}
79
80#[async_trait]
81impl<A: ToSocketAddrs + Sync + Send> Transport for UdpTransport<A> {
82    type Endpoint = SocketAddr;
83
84    async fn bind(self) -> Result<BoxedFramedBinding<Self::Endpoint>, TransportError> {
85        let socket = UdpSocket::bind(self.addresses).await?;
86        if self.multicast {
87            self.multicast_joiner.join(&socket)?;
88        }
89        let local_addr = socket.local_addr()?;
90        let framed_socket = UdpFramed::new(socket, Codec::default());
91        let binding = UdpBinding {
92            framed_socket,
93            local_addr,
94            mtu: self.mtu,
95        };
96        Ok(Box::pin(binding))
97    }
98}
99
100impl MulticastJoiner {
101    fn join(self, socket: &UdpSocket) -> io::Result<()> {
102        let ipv4_interface = match socket.local_addr()? {
103            SocketAddr::V4(ipv4) => Some(*ipv4.ip()),
104            _ => None,
105        };
106        let joins = if self.custom_group_joins.is_empty() {
107            self.determine_default_joins(&socket)?
108        } else {
109            self.custom_group_joins
110        };
111
112        for join in joins {
113            debug!("Joining {join:?}...");
114            match join {
115                MulticastGroupJoin::Ipv4(addr, interface) => {
116                    let resolved_interface = interface
117                        .or(self.default_ipv4_interface)
118                        .or(ipv4_interface)
119                        .unwrap();
120                    socket.join_multicast_v4(addr, resolved_interface)?;
121                }
122                MulticastGroupJoin::Ipv6(addr, interface) => {
123                    let resolved_interface = interface.or(self.default_ipv6_interface).unwrap_or(0);
124                    socket.join_multicast_v6(&addr, resolved_interface)?;
125                }
126            }
127        }
128
129        Ok(())
130    }
131
132    fn determine_default_joins(&self, socket: &UdpSocket) -> io::Result<Vec<MulticastGroupJoin>> {
133        let mut joins = Vec::new();
134        match socket.local_addr()? {
135            SocketAddr::V4(_) => {
136                joins.push(MulticastGroupJoin::Ipv4(
137                    "224.0.1.187".parse().unwrap(),
138                    None,
139                ));
140            }
141            SocketAddr::V6(_) => {
142                joins.push(MulticastGroupJoin::Ipv6("ff02::fd".parse().unwrap(), None));
143                joins.push(MulticastGroupJoin::Ipv6("ff05::fd".parse().unwrap(), None));
144            }
145        }
146        Ok(joins)
147    }
148}
149
150#[pin_project]
151struct UdpBinding {
152    #[pin]
153    framed_socket: UdpFramed<Codec>,
154    local_addr: SocketAddr,
155    mtu: Option<u32>,
156}
157
158#[async_trait]
159impl FramedBinding<SocketAddr> for UdpBinding {
160    fn mtu(&self) -> Option<u32> {
161        self.mtu
162    }
163}
164
165impl Stream for UdpBinding {
166    type Item = Result<(Packet, SocketAddr), (TransportError, Option<SocketAddr>)>;
167
168    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169        self.project().framed_socket.poll_next(cx)
170    }
171}
172
173impl Sink<(Packet, SocketAddr)> for UdpBinding {
174    type Error = TransportError;
175
176    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
177        self.project().framed_socket.poll_ready(cx)
178    }
179
180    fn start_send(self: Pin<&mut Self>, item: (Packet, SocketAddr)) -> Result<(), Self::Error> {
181        self.project().framed_socket.start_send(item)
182    }
183
184    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
185        self.project().framed_socket.poll_flush(cx)
186    }
187
188    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189        self.project().framed_socket.poll_close(cx)
190    }
191}
192
/// Stateless codec that converts between CoAP [`Packet`]s and their raw datagram bytes.
///
/// It holds no state, so it is also `Copy`/`Clone` and `Debug`.
#[derive(Default, Debug, Clone, Copy)]
struct Codec;
195
196impl Decoder for Codec {
197    type Item = Packet;
198    type Error = TransportError;
199
200    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Packet>, TransportError> {
201        if buf.is_empty() {
202            return Ok(None);
203        }
204        let result = (|| Ok(Some(Packet::from_bytes(buf)?)))();
205        buf.clear();
206        result
207    }
208}
209
210impl Encoder<Packet> for Codec {
211    type Error = TransportError;
212
213    fn encode(&mut self, my_packet: Packet, buf: &mut BytesMut) -> Result<(), TransportError> {
214        buf.extend_from_slice(&my_packet.to_bytes()?[..]);
215        Ok(())
216    }
217}