lz_shared_udp/
lib.rs

extern crate futures;
#[macro_use]
extern crate log;
#[macro_use]
extern crate tokio_core;

use futures::{Async, AsyncSink, Poll, Sink, StartSend, Stream};
use std::fmt;
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::ops::Deref;
use tokio_core::net::{UdpCodec, UdpSocket};

/// A framed `Stream` + `Sink` over a UDP socket that is reached through a
/// cloneable handle `R` (the `Stream`/`Sink` impls require
/// `R: Deref<Target = UdpSocket>`).
///
/// Received datagrams are decoded with the codec `C`. Outgoing items are encoded
/// into an internal buffer and sent as one datagram when the sink is flushed.
#[must_use = "sinks do nothing unless polled"]
pub struct SharedUdpFramed<R, C> {
    /// Handle through which the socket is accessed.
    socket: R,
    /// Codec that decodes received datagrams and encodes outgoing items.
    codec: C,
    /// Destination of the datagram currently held in `wr`.
    out_addr: SocketAddr,
    /// Receive buffer that each incoming datagram is read into.
    rd: Vec<u8>,
    /// Encoded bytes of the outgoing datagram that has not been sent yet.
    wr: Vec<u8>,
    /// `false` while `wr` holds a datagram that still has to be sent.
    flushed: bool,
}

// Hand-written rather than derived: a derived impl prints every byte of the
// receive buffer (64 KiB as allocated by `new`), which swamps logs. Buffers are
// summarised by their lengths instead; the trait bounds match what
// `#[derive(Debug)]` would have required.
impl<R: fmt::Debug, C: fmt::Debug> fmt::Debug for SharedUdpFramed<R, C> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("SharedUdpFramed")
            .field("socket", &self.socket)
            .field("codec", &self.codec)
            .field("out_addr", &self.out_addr)
            .field("rd_len", &self.rd.len())
            .field("wr_len", &self.wr.len())
            .field("flushed", &self.flushed)
            .finish()
    }
}

24impl<R: Clone, C: Clone> Clone for SharedUdpFramed<R, C> {
25    fn clone(&self) -> Self {
26        Self::new(self.socket.clone(), self.codec.clone())
27    }
28}
29
30impl<R, C> SharedUdpFramed<R, C> {
31    pub(crate) fn new(socket: R, codec: C) -> Self {
32        Self {
33            socket: socket,
34            codec: codec,
35            out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
36            rd: vec![0; 64 * 1024],
37            wr: Vec::with_capacity(8 * 1024),
38            flushed: true,
39        }
40    }
41
42    pub fn socket(&self) -> &R {
43        &self.socket
44    }
45
46    pub fn socket_mut(&mut self) -> &mut R {
47        &mut self.socket
48    }
49
50    pub fn into_socket(self) -> R {
51        self.socket
52    }
53
54    pub fn codec(&self) -> &C {
55        &self.codec
56    }
57
58    pub fn codec_mut(&mut self) -> &mut C {
59        &mut self.codec
60    }
61
62    pub fn into_codec(self) -> C {
63        self.codec
64    }
65}
66
67impl<R: Deref<Target = UdpSocket>, C: UdpCodec> Stream for SharedUdpFramed<R, C> {
68    type Item = C::In;
69    type Error = IoError;
70
71    fn poll(&mut self) -> Poll<Option<C::In>, IoError> {
72        let (n, addr) = try_nb!(self.socket.recv_from(&mut self.rd));
73        trace!("received {} bytes, decoding", n);
74        let frame = try!(self.codec.decode(&addr, &self.rd[..n]));
75        trace!("frame decoded from buffer");
76        Ok(Async::Ready(Some(frame)))
77    }
78}
79
80impl<R: Deref<Target = UdpSocket>, C: UdpCodec> Sink for SharedUdpFramed<R, C> {
81    type SinkItem = C::Out;
82    type SinkError = IoError;
83
84    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
85        trace!("sending frame");
86
87        if !self.flushed {
88            match self.poll_complete()? {
89                Async::Ready(()) => {}
90                Async::NotReady => return Ok(AsyncSink::NotReady(item)),
91            }
92        }
93
94        self.out_addr = self.codec.encode(item, &mut self.wr);
95        self.flushed = false;
96        trace!("frame encoded; length={}", self.wr.len());
97
98        Ok(AsyncSink::Ready)
99    }
100
101    fn poll_complete(&mut self) -> Poll<(), IoError> {
102        if self.flushed {
103            return Ok(Async::Ready(()));
104        }
105
106        trace!("flushing frame; length={}", self.wr.len());
107        let n = try_nb!(self.socket.send_to(&self.wr, &self.out_addr));
108        trace!("written {}", n);
109
110        let wrote_all = n == self.wr.len();
111        self.wr.clear();
112        self.flushed = true;
113
114        if wrote_all {
115            Ok(Async::Ready(()))
116        } else {
117            Err(IoError::new(
118                IoErrorKind::Other,
119                "failed to write entire datagram to socket",
120            ))
121        }
122    }
123}
124
125pub trait SharedUdpSocket {
126    fn framed<C: UdpCodec>(&self, codec: C) -> SharedUdpFramed<Self, C>
127    where
128        Self: Sized;
129}
130
131impl<R> SharedUdpSocket for R
132where
133    R: Clone + Deref<Target = UdpSocket>,
134{
135    fn framed<C: UdpCodec>(&self, codec: C) -> SharedUdpFramed<Self, C>
136    where
137        Self: Sized,
138    {
139        SharedUdpFramed::new(self.clone(), codec)
140    }
141}
142
#[cfg(test)]
mod tests {
    use super::SharedUdpSocket;
    use futures::{Future, Sink, Stream};
    use std::io::Result as IoResult;
    use std::net::SocketAddr;
    use std::ops::Deref;
    use std::rc::Rc;
    use std::sync::Arc;
    use tokio_core::net::{UdpCodec, UdpSocket};
    use tokio_core::reactor::{Core, Handle};

    /// Binds two UDP sockets to ephemeral ports on `0.0.0.0`.
    fn bind_pair(handle: &Handle) -> (UdpSocket, UdpSocket) {
        let wildcard: SocketAddr = "0.0.0.0:0".parse().unwrap();
        (
            UdpSocket::bind(&wildcard, handle).unwrap(),
            UdpSocket::bind(&wildcard, handle).unwrap(),
        )
    }

    /// Test codec: decodes datagrams as lossy UTF-8 text and encodes
    /// `(destination, text)` pairs as raw UTF-8 bytes.
    struct Utf8Codec;

    impl UdpCodec for Utf8Codec {
        type In = String;
        type Out = (SocketAddr, String);

        fn decode(&mut self, _sender: &SocketAddr, buf: &[u8]) -> IoResult<String> {
            let text = String::from_utf8_lossy(buf);
            Ok(text.into_owned())
        }

        fn encode(&mut self, (destination, text): (SocketAddr, String), buf: &mut Vec<u8>) -> SocketAddr {
            buf.extend_from_slice(text.as_bytes());
            destination
        }
    }

    #[test]
    fn works_for_ref_udp_socket() {
        let core = Core::new().unwrap();
        let (sender, receiver) = bind_pair(&core.handle());
        round_trip(core, &sender, &receiver);
    }

    #[test]
    fn works_for_rc_udp_socket() {
        let core = Core::new().unwrap();
        let (sender, receiver) = bind_pair(&core.handle());
        round_trip(core, Rc::new(sender), Rc::new(receiver));
    }

    #[test]
    fn works_for_arc_udp_socket() {
        let core = Core::new().unwrap();
        let (sender, receiver) = bind_pair(&core.handle());
        round_trip(core, Arc::new(sender), Arc::new(receiver));
    }

    /// Sends "Hello" from `sender` to `receiver` over loopback and asserts that
    /// the receiving adapter yields the same text.
    fn round_trip<R>(mut core: Core, sender: R, receiver: R)
    where
        R: Clone + Deref<Target = UdpSocket>,
    {
        // The receiver is bound to 0.0.0.0, so address it explicitly via loopback.
        let mut target = receiver.local_addr().unwrap();
        target.set_ip("127.0.0.1".parse().unwrap());

        let incoming = receiver.framed(Utf8Codec);
        let expected = "Hello";

        let outgoing = sender.framed(Utf8Codec).send((target, expected.to_owned()));
        let exchange = outgoing.and_then(move |_| {
            incoming
                .into_future()
                .map(|(first, _rest)| first.unwrap())
                .map_err(|(err, _rest)| err)
        });

        let got = core.run(exchange).unwrap();
        assert_eq!(got, expected)
    }
}
234}