1extern crate futures;
2#[macro_use]
3extern crate log;
4#[macro_use]
5extern crate tokio_core;
6
7use futures::{Async, AsyncSink, Poll, Sink, StartSend, Stream};
8use std::io::{Error as IoError, ErrorKind as IoErrorKind};
9use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
10use std::ops::Deref;
11use tokio_core::net::{UdpCodec, UdpSocket};
12
/// A `Stream` and `Sink` of UDP frames over a socket handle that may be shared
/// with other framed values (the tests use `&UdpSocket`, `Rc<UdpSocket>` and
/// `Arc<UdpSocket>`).
///
/// Incoming datagrams are decoded, and outgoing frames encoded, by the codec `C`.
#[must_use = "sinks do nothing unless polled"]
#[derive(Debug)]
pub struct SharedUdpFramed<R, C> {
    /// Handle through which the underlying `UdpSocket` is reached.
    socket: R,
    /// Codec that turns datagrams into frames and frames into datagrams.
    codec: C,
    /// Destination of the datagram currently held in `wr`.
    out_addr: SocketAddr,
    /// Receive buffer that incoming datagrams are read into.
    rd: Vec<u8>,
    /// Encoded bytes of the next outgoing datagram.
    wr: Vec<u8>,
    /// `true` when no encoded datagram is waiting in `wr` to be sent.
    flushed: bool,
}
23
24impl<R: Clone, C: Clone> Clone for SharedUdpFramed<R, C> {
25 fn clone(&self) -> Self {
26 Self::new(self.socket.clone(), self.codec.clone())
27 }
28}
29
30impl<R, C> SharedUdpFramed<R, C> {
31 pub(crate) fn new(socket: R, codec: C) -> Self {
32 Self {
33 socket: socket,
34 codec: codec,
35 out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)),
36 rd: vec![0; 64 * 1024],
37 wr: Vec::with_capacity(8 * 1024),
38 flushed: true,
39 }
40 }
41
42 pub fn socket(&self) -> &R {
43 &self.socket
44 }
45
46 pub fn socket_mut(&mut self) -> &mut R {
47 &mut self.socket
48 }
49
50 pub fn into_socket(self) -> R {
51 self.socket
52 }
53
54 pub fn codec(&self) -> &C {
55 &self.codec
56 }
57
58 pub fn codec_mut(&mut self) -> &mut C {
59 &mut self.codec
60 }
61
62 pub fn into_codec(self) -> C {
63 self.codec
64 }
65}
66
67impl<R: Deref<Target = UdpSocket>, C: UdpCodec> Stream for SharedUdpFramed<R, C> {
68 type Item = C::In;
69 type Error = IoError;
70
71 fn poll(&mut self) -> Poll<Option<C::In>, IoError> {
72 let (n, addr) = try_nb!(self.socket.recv_from(&mut self.rd));
73 trace!("received {} bytes, decoding", n);
74 let frame = try!(self.codec.decode(&addr, &self.rd[..n]));
75 trace!("frame decoded from buffer");
76 Ok(Async::Ready(Some(frame)))
77 }
78}
79
80impl<R: Deref<Target = UdpSocket>, C: UdpCodec> Sink for SharedUdpFramed<R, C> {
81 type SinkItem = C::Out;
82 type SinkError = IoError;
83
84 fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
85 trace!("sending frame");
86
87 if !self.flushed {
88 match self.poll_complete()? {
89 Async::Ready(()) => {}
90 Async::NotReady => return Ok(AsyncSink::NotReady(item)),
91 }
92 }
93
94 self.out_addr = self.codec.encode(item, &mut self.wr);
95 self.flushed = false;
96 trace!("frame encoded; length={}", self.wr.len());
97
98 Ok(AsyncSink::Ready)
99 }
100
101 fn poll_complete(&mut self) -> Poll<(), IoError> {
102 if self.flushed {
103 return Ok(Async::Ready(()));
104 }
105
106 trace!("flushing frame; length={}", self.wr.len());
107 let n = try_nb!(self.socket.send_to(&self.wr, &self.out_addr));
108 trace!("written {}", n);
109
110 let wrote_all = n == self.wr.len();
111 self.wr.clear();
112 self.flushed = true;
113
114 if wrote_all {
115 Ok(Async::Ready(()))
116 } else {
117 Err(IoError::new(
118 IoErrorKind::Other,
119 "failed to write entire datagram to socket",
120 ))
121 }
122 }
123}
124
125pub trait SharedUdpSocket {
126 fn framed<C: UdpCodec>(&self, codec: C) -> SharedUdpFramed<Self, C>
127 where
128 Self: Sized;
129}
130
131impl<R> SharedUdpSocket for R
132where
133 R: Clone + Deref<Target = UdpSocket>,
134{
135 fn framed<C: UdpCodec>(&self, codec: C) -> SharedUdpFramed<Self, C>
136 where
137 Self: Sized,
138 {
139 SharedUdpFramed::new(self.clone(), codec)
140 }
141}
142
#[cfg(test)]
mod tests {
    use super::SharedUdpSocket;
    use futures::{Future, Sink, Stream};
    use std::io::Result as IoResult;
    use std::net::SocketAddr;
    use std::ops::Deref;
    use std::rc::Rc;
    use std::sync::Arc;
    use tokio_core::net::{UdpCodec, UdpSocket};
    use tokio_core::reactor::{Core, Handle};

    /// Binds two UDP sockets to OS-assigned ports on all interfaces.
    fn bind_sockets(handle: &Handle) -> (UdpSocket, UdpSocket) {
        let any_address: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let bind = || UdpSocket::bind(&any_address, handle).unwrap();

        let first = bind();
        let second = bind();
        (first, second)
    }

    /// Test codec: incoming datagrams are decoded lossily as UTF-8 text, and
    /// outgoing frames carry their own destination address.
    struct Utf8Codec;

    impl UdpCodec for Utf8Codec {
        type In = String;
        type Out = (SocketAddr, String);

        fn decode(&mut self, _: &SocketAddr, buf: &[u8]) -> IoResult<Self::In> {
            let text = String::from_utf8_lossy(buf);
            Ok(text.into_owned())
        }

        fn encode(&mut self, msg: Self::Out, buf: &mut Vec<u8>) -> SocketAddr {
            let (destination, text) = msg;
            buf.extend_from_slice(text.as_bytes());
            destination
        }
    }

    #[test]
    fn works_for_ref_udp_socket() {
        let core = Core::new().unwrap();
        let (first, second) = bind_sockets(&core.handle());
        test_framed_impl(core, &first, &second);
    }

    #[test]
    fn works_for_rc_udp_socket() {
        let core = Core::new().unwrap();
        let (first, second) = bind_sockets(&core.handle());
        test_framed_impl(core, Rc::new(first), Rc::new(second));
    }

    #[test]
    fn works_for_arc_udp_socket() {
        let core = Core::new().unwrap();
        let (first, second) = bind_sockets(&core.handle());
        test_framed_impl(core, Arc::new(first), Arc::new(second));
    }

    /// Sends one text frame from `first_socket` to `second_socket` over loopback.
    /// Then asserts that the receiving framed stream yields the same text.
    fn test_framed_impl<R>(mut core: Core, first_socket: R, second_socket: R)
    where
        R: Clone + Deref<Target = UdpSocket>,
    {
        // The sockets are bound to 0.0.0.0, so address the receiver via loopback.
        let destination = {
            let mut addr = second_socket.local_addr().unwrap();
            addr.set_ip("127.0.0.1".parse().unwrap());
            addr
        };

        let incoming = SharedUdpSocket::framed(&second_socket, Utf8Codec);
        let outgoing = SharedUdpSocket::framed(&first_socket, Utf8Codec);

        let sent_message = "Hello";
        let round_trip = outgoing
            .send((destination, sent_message.to_owned()))
            .and_then(move |_| {
                incoming
                    .into_future()
                    .map(|(frame, _)| frame.unwrap())
                    .map_err(|(err, _)| err)
            });

        let received_message = core.run(round_trip).unwrap();
        assert_eq!(received_message, sent_message);
    }
}