async_coap/datagram/
loopback_socket.rs

1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15
16use super::*;
17use futures::channel::mpsc::{channel, Receiver, Sender};
18use futures::lock::Mutex;
19use futures::prelude::*;
20use futures::task::Context;
21use futures::Poll;
22use std::fmt::{Debug, Display, Formatter};
23use std::pin::Pin;
24
/// Minimal stand-in for a socket address, used by [`LoopbackSocket`].
///
/// A loopback address carries no host or port data of its own; it only records whether it
/// is the unicast address or the multicast address.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum LoopbackSocketAddr {
    /// The "unicast" loopback socket address.
    Unicast,

    /// The "multicast" loopback socket address.
    Multicast,
}
35
36impl Display for LoopbackSocketAddr {
37    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
38        <Self as Debug>::fmt(self, f)
39    }
40}
41
42impl SocketAddrExt for LoopbackSocketAddr {
43    fn is_multicast(&self) -> bool {
44        match self {
45            LoopbackSocketAddr::Unicast => false,
46            LoopbackSocketAddr::Multicast => true,
47        }
48    }
49
50    fn port(&self) -> u16 {
51        0
52    }
53
54    fn conforming_to(&self, _local: Self) -> Option<Self> {
55        Some(*self)
56    }
57
58    fn addr_to_string(&self) -> String {
59        match self {
60            LoopbackSocketAddr::Unicast => "localhost",
61            LoopbackSocketAddr::Multicast => "broadcasthost",
62        }
63        .to_string()
64    }
65}
66
67impl ToSocketAddrs for LoopbackSocketAddr {
68    type Iter = std::option::IntoIter<Self::SocketAddr>;
69    type SocketAddr = Self;
70    type Error = super::Error;
71
72    fn to_socket_addrs(&self) -> Result<Self::Iter, Self::Error> {
73        Ok(Some(*self).into_iter())
74    }
75}
76
77/// An instance of [`AsyncDatagramSocket`] that implements a simple loopback interface, where
78/// all packets that are sent are looped back to the input.
79#[derive(Debug)]
80pub struct LoopbackSocket {
81    // Message is (packet_bytes, dest_addr)
82    sender: Sender<(Vec<u8>, LoopbackSocketAddr)>,
83    receiver: futures::lock::Mutex<Receiver<(Vec<u8>, LoopbackSocketAddr)>>,
84}
85
86impl LoopbackSocket {
87    /// Creates a new instance of [`LoopbackSocket`].
88    pub fn new() -> LoopbackSocket {
89        let (sender, receiver) = channel(3);
90        LoopbackSocket {
91            sender,
92            receiver: Mutex::new(receiver),
93        }
94    }
95}
96
97impl Unpin for LoopbackSocket {}
98
99impl AsyncDatagramSocket for LoopbackSocket {}
100
101impl DatagramSocketTypes for LoopbackSocket {
102    type SocketAddr = LoopbackSocketAddr;
103    type Error = super::Error;
104
105    fn local_addr(&self) -> Result<Self::SocketAddr, Self::Error> {
106        Ok(LoopbackSocketAddr::Unicast)
107    }
108
109    fn lookup_host(
110        host: &str,
111        _port: u16,
112    ) -> Result<std::vec::IntoIter<Self::SocketAddr>, Self::Error>
113    where
114        Self: Sized,
115    {
116        if host == ALL_COAP_DEVICES_HOSTNAME {
117            Ok(vec![LoopbackSocketAddr::Multicast].into_iter())
118        } else {
119            Ok(vec![LoopbackSocketAddr::Unicast].into_iter())
120        }
121    }
122}
123
124impl AsyncSendTo for LoopbackSocket {
125    fn poll_send_to<B>(
126        self: Pin<&Self>,
127        cx: &mut Context<'_>,
128        buf: &[u8],
129        addr: B,
130    ) -> Poll<Result<usize, Self::Error>>
131    where
132        B: super::ToSocketAddrs<SocketAddr = Self::SocketAddr, Error = Self::Error>,
133    {
134        if let Some(addr) = addr.to_socket_addrs()?.next() {
135            let mut sender = self.get_ref().sender.clone();
136            match sender.poll_ready(cx) {
137                Poll::Ready(Ok(())) => match sender.start_send((buf.to_vec(), addr)) {
138                    Ok(()) => Poll::Ready(Ok(buf.len())),
139                    Err(e) => {
140                        if e.is_full() {
141                            Poll::Pending
142                        } else {
143                            Poll::Ready(Err(Error::IOError))
144                        }
145                    }
146                },
147                Poll::Ready(Err(_)) => Poll::Ready(Err(Error::IOError)),
148                Poll::Pending => Poll::Pending,
149            }
150        } else {
151            Poll::Ready(Err(Error::HostNotFound))
152        }
153    }
154}
155
156impl AsyncRecvFrom for LoopbackSocket {
157    fn poll_recv_from(
158        self: Pin<&Self>,
159        cx: &mut Context<'_>,
160        buf: &mut [u8],
161    ) -> Poll<Result<(usize, Self::SocketAddr, Option<Self::SocketAddr>), Self::Error>> {
162        let mut receiver_lock_future = self.get_ref().receiver.lock();
163        let receiver_lock_future = Pin::new(&mut receiver_lock_future);
164
165        if let Poll::Ready(mut receiver_guard) = receiver_lock_future.poll(cx) {
166            let receiver: &mut Receiver<(Vec<u8>, LoopbackSocketAddr)> = &mut receiver_guard;
167            match receiver.poll_next_unpin(cx) {
168                Poll::Ready(Some((packet, addr))) => {
169                    let len = packet.len();
170                    if buf.len() >= len {
171                        buf[..len].copy_from_slice(&packet);
172                        // TODO: Handle multicast destination determination
173                        Poll::Ready(Ok((len, self.local_addr().unwrap(), Some(addr))))
174                    } else {
175                        Poll::Ready(Err(Error::IOError))
176                    }
177                }
178                Poll::Ready(None) => Poll::Ready(Err(Error::IOError)),
179                Poll::Pending => Poll::Pending,
180            }
181        } else {
182            Poll::Pending
183        }
184    }
185}
186
187impl MulticastSocket for LoopbackSocket {
188    type IpAddr = String;
189
190    fn join_multicast<A>(&self, _addr: A) -> Result<(), Self::Error>
191    where
192        A: std::convert::Into<Self::IpAddr>,
193    {
194        Ok(())
195    }
196
197    fn leave_multicast<A>(&self, _addr: A) -> Result<(), Self::Error>
198    where
199        A: std::convert::Into<Self::IpAddr>,
200    {
201        Ok(())
202    }
203}