1use crate::channel::{Channel, RecvMessage};
8use crate::message::{
9 ErrorResponse, Indication, InvalidMessage, Request, Response, SuccessResponse,
10};
11use crate::transport::{StunTcpTransporter, StunTransport, StunUdpTransporter};
12use crate::{Error, ErrorKind, Result};
13use bytecodec::marker::Never;
14use factory::DefaultFactory;
15use factory::Factory;
16use fibers::sync::mpsc;
17use fibers::{BoxSpawn, Spawn};
18use fibers_transport::{FixedPeerTransporter, TcpTransport, UdpTransport};
19use futures::{Async, Future, Poll, Stream};
20use std::fmt;
21use std::net::SocketAddr;
22use stun_codec::rfc5389;
23use stun_codec::{Attribute, MessageDecoder, MessageEncoder};
24
/// Default STUN server port for connections that do not use TLS
/// (the non-TLS counterpart of [`DEFAULT_TLS_PORT`]).
pub const DEFAULT_PORT: u16 = 3478;

/// Default STUN server port for TLS connections.
pub const DEFAULT_TLS_PORT: u16 = 5349;

/// UDP transporter whose datagrams are encoded/decoded as STUN messages carrying
/// attributes of type `A`.
type UdpTransporter<A> = fibers_transport::UdpTransporter<MessageEncoder<A>, MessageDecoder<A>>;
32
/// STUN server that serves messages arriving on a single UDP socket.
///
/// `UdpServer` is a future: the server only makes progress while it is polled.
/// It never resolves successfully (its item type is [`Never`]). It resolves with an
/// error if the underlying driver fails or stops.
#[derive(Debug)]
#[must_use = "future do nothing unless polled"]
pub struct UdpServer<H: HandleMessage> {
    // Dispatches messages received on the UDP-backed STUN channel to the handler.
    driver: HandlerDriver<H, StunUdpTransporter<H::Attribute, UdpTransporter<H::Attribute>>>,
}
39impl<H: HandleMessage> UdpServer<H> {
40 pub fn start<S>(
42 spawner: S,
43 bind_addr: SocketAddr,
44 handler: H,
45 ) -> impl Future<Item = Self, Error = Error>
46 where
47 S: Spawn + Send + 'static,
48 {
49 UdpTransporter::bind(bind_addr)
50 .map_err(|e| track!(Error::from(e)))
51 .map(move |transporter| {
52 let channel = Channel::new(StunUdpTransporter::new(transporter));
53 let driver = HandlerDriver::new(spawner.boxed(), handler, channel, true);
54 UdpServer { driver }
55 })
56 }
57
58 pub fn local_addr(&self) -> SocketAddr {
60 self.driver
61 .channel
62 .transporter_ref()
63 .inner_ref()
64 .local_addr()
65 }
66}
67impl<H: HandleMessage> Future for UdpServer<H> {
68 type Item = Never;
69 type Error = Error;
70
71 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
72 if let Async::Ready(()) = track!(self.driver.poll())? {
73 track_panic!(ErrorKind::Other, "STUN UDP server unexpectedly terminated");
74 }
75 Ok(Async::NotReady)
76 }
77}
78
/// TCP listener whose accepted connections exchange STUN messages carrying
/// attributes of type `A`.
// NOTE(review): `DefaultFactory` presumably builds a fresh encoder/decoder per
// accepted connection — confirm against `fibers_transport::TcpListener`.
type TcpListener<A> = fibers_transport::TcpListener<
    DefaultFactory<MessageEncoder<A>>,
    DefaultFactory<MessageDecoder<A>>,
>;
83
/// STUN server that accepts TCP connections and serves each one with its own handler.
///
/// `TcpServer` is a future: connections are accepted only while it is polled.
/// It never resolves successfully (its item type is [`Never`]).
#[must_use = "future do nothing unless polled"]
pub struct TcpServer<S, H>
where
    H: Factory,
    H::Item: HandleMessage,
{
    // Spawns one handler driver per accepted connection.
    spawner: S,
    // Creates a fresh handler for every accepted connection.
    handler_factory: H,
    listener: TcpListener<<H::Item as HandleMessage>::Attribute>,
}
95impl<S, H> TcpServer<S, H>
96where
97 S: Spawn + Clone + Send + 'static,
98 H: Factory,
99 H::Item: HandleMessage,
100{
101 pub fn start(
103 spawner: S,
104 bind_addr: SocketAddr,
105 handler_factory: H,
106 ) -> impl Future<Item = Self, Error = Error> {
107 TcpListener::listen(bind_addr)
108 .map_err(|e| track!(Error::from(e)))
109 .map(move |listener| TcpServer {
110 spawner,
111 handler_factory,
112 listener,
113 })
114 }
115
116 pub fn local_addr(&self) -> SocketAddr {
118 self.listener.local_addr()
119 }
120}
121impl<S, H> Future for TcpServer<S, H>
122where
123 S: Spawn + Clone + Send + 'static,
124 H: Factory,
125 H::Item: HandleMessage + Send + 'static,
126 <<H::Item as HandleMessage>::Attribute as Attribute>::Decoder: Send + 'static,
127 <<H::Item as HandleMessage>::Attribute as Attribute>::Encoder: Send + 'static,
128{
129 type Item = Never;
130 type Error = Error;
131
132 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
133 while let Async::Ready(transporter) = track!(self.listener.poll())? {
134 if let Some(transporter) = transporter {
135 let peer_addr = transporter.peer_addr();
136 let transporter =
137 FixedPeerTransporter::new(peer_addr, (), StunTcpTransporter::new(transporter));
138 let channel = Channel::new(transporter);
139 let handler = self.handler_factory.create();
140 let future =
141 HandlerDriver::new(self.spawner.clone().boxed(), handler, channel, false);
142 self.spawner.spawn(future.map_err(|_| ()));
143 } else {
144 track_panic!(ErrorKind::Other, "STUN TCP server unexpectedly terminated");
145 }
146 }
147 Ok(Async::NotReady)
148 }
149}
150impl<S, H> fmt::Debug for TcpServer<S, H>
151where
152 H: Factory,
153 H::Item: HandleMessage,
154{
155 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
156 write!(f, "TcpServer {{ .. }}")
157 }
158}
159
/// What the server should do after a [`HandleMessage`] method has handled a message.
///
/// `T` is the reply type: a [`Response`] for requests and invalid messages, and
/// [`Never`] for indications (which cannot be answered).
pub enum Action<T> {
    /// Send `T` back to the peer right away.
    Reply(T),

    /// Run the future in the background and send its item to the peer once it completes.
    FutureReply(Box<dyn Future<Item = T, Error = Never> + Send + 'static>),

    /// Send nothing.
    NoReply,

    /// Send nothing, but run the future in the background (e.g. for side effects).
    FutureNoReply(Box<dyn Future<Item = (), Error = Never> + Send + 'static>),
}
174impl<T: fmt::Debug> fmt::Debug for Action<T> {
175 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
176 match self {
177 Action::Reply(t) => write!(f, "Reply({t:?})"),
178 Action::FutureReply(_) => write!(f, "FutureReply(_)"),
179 Action::NoReply => write!(f, "NoReply"),
180 Action::FutureNoReply(_) => write!(f, "FutureNoReply(_)"),
181 }
182 }
183}
184
/// Handler of the STUN messages received by [`UdpServer`] and [`TcpServer`].
///
/// Every method has a default implementation, so implementors only override
/// the ones they care about.
// The default method bodies ignore their arguments; the lint is allowed so the
// parameters keep descriptive (non-underscored) names, which also appear in rustdoc.
#[allow(unused_variables)]
pub trait HandleMessage {
    /// Attribute type of the STUN requests, indications and responses this handler deals with.
    type Attribute: Attribute + Send + 'static;

    /// Handles a request message sent by `peer`.
    ///
    /// The returned [`Action`] decides whether, and when, a response is sent back.
    /// The default implementation never replies.
    fn handle_call(
        &mut self,
        peer: SocketAddr,
        request: Request<Self::Attribute>,
    ) -> Action<Response<Self::Attribute>> {
        Action::NoReply
    }

    /// Handles an indication message sent by `peer`.
    ///
    /// Indications cannot be answered, hence the [`Never`] reply type.
    /// The default implementation does nothing.
    fn handle_cast(
        &mut self,
        peer: SocketAddr,
        indication: Indication<Self::Attribute>,
    ) -> Action<Never> {
        Action::NoReply
    }

    /// Handles a message from `peer` that the channel delivered as
    /// `RecvMessage::Invalid` (NOTE(review): presumably one that failed to decode —
    /// confirm in `crate::channel`).
    ///
    /// The returned action may still send a response (for example an error response).
    /// The default implementation never replies.
    fn handle_invalid_message(
        &mut self,
        peer: SocketAddr,
        message: InvalidMessage,
    ) -> Action<Response<Self::Attribute>> {
        Action::NoReply
    }

    /// Called whenever the server's channel reports an error while receiving or sending.
    ///
    /// The default implementation ignores the error.
    fn handle_channel_error(&mut self, error: &Error) {}
}
232
/// Drives a [`HandleMessage`] implementation over a single STUN [`Channel`].
///
/// Received messages are dispatched to `handler`. Responses produced
/// asynchronously (`Action::FutureReply`) come back through the
/// `response_tx`/`response_rx` queue and are replied to from `poll`.
#[derive(Debug)]
struct HandlerDriver<H, T>
where
    H: HandleMessage,
    T: StunTransport<H::Attribute, PeerAddr = SocketAddr>,
{
    // Runs the background futures returned by the handler.
    spawner: BoxSpawn,
    handler: H,
    channel: Channel<H::Attribute, T>,
    // Cloned into each spawned `FutureReply` task so it can hand its response back.
    response_tx: mpsc::Sender<(SocketAddr, Response<H::Attribute>)>,
    // Drained by `poll`; each `(peer, response)` is replied over `channel`.
    response_rx: mpsc::Receiver<(SocketAddr, Response<H::Attribute>)>,
    // `true`: receive errors are reported to the handler and then ignored (the UDP
    // server passes `true`). `false`: they terminate the driver (TCP connections).
    recoverable_channel: bool,
}
246impl<H, T> HandlerDriver<H, T>
247where
248 H: HandleMessage,
249 T: StunTransport<H::Attribute, PeerAddr = SocketAddr>,
250{
251 fn new(
252 spawner: BoxSpawn,
253 handler: H,
254 channel: Channel<H::Attribute, T>,
255 recoverable_channel: bool,
256 ) -> Self {
257 let (response_tx, response_rx) = mpsc::channel();
258 HandlerDriver {
259 spawner,
260 handler,
261 channel,
262 response_tx,
263 response_rx,
264 recoverable_channel,
265 }
266 }
267
268 fn handle_message(
269 &mut self,
270 peer: SocketAddr,
271 message: RecvMessage<H::Attribute>,
272 ) -> Result<()> {
273 match message {
274 RecvMessage::Indication(m) => self.handle_indication(peer, m),
275 RecvMessage::Request(m) => track!(self.handle_request(peer, m))?,
276 RecvMessage::Invalid(m) => track!(self.handle_invalid_message(peer, m))?,
277 }
278 Ok(())
279 }
280
281 fn handle_indication(&mut self, peer: SocketAddr, indication: Indication<H::Attribute>) {
282 match self.handler.handle_cast(peer, indication) {
283 Action::NoReply => {}
284 Action::FutureNoReply(future) => self.spawner.spawn(future.map_err(|_| unreachable!())),
285 _ => unreachable!(),
286 }
287 }
288
289 fn handle_request(&mut self, peer: SocketAddr, request: Request<H::Attribute>) -> Result<()> {
290 match self.handler.handle_call(peer, request) {
291 Action::NoReply => {}
292 Action::FutureNoReply(future) => self.spawner.spawn(future.map_err(|_| unreachable!())),
293 Action::Reply(m) => track!(self.channel.reply(peer, m))?,
294 Action::FutureReply(future) => {
295 let tx = self.response_tx.clone();
296 self.spawner.spawn(
297 future
298 .map(move |response| {
299 let _ = tx.send((peer, response));
300 })
301 .map_err(|_| unreachable!()),
302 );
303 }
304 }
305 Ok(())
306 }
307
308 fn handle_invalid_message(&mut self, peer: SocketAddr, message: InvalidMessage) -> Result<()> {
309 match self.handler.handle_invalid_message(peer, message) {
310 Action::NoReply => {}
311 Action::FutureNoReply(future) => self.spawner.spawn(future.map_err(|_| unreachable!())),
312 Action::Reply(m) => track!(self.channel.reply(peer, m))?,
313 Action::FutureReply(future) => {
314 let tx = self.response_tx.clone();
315 self.spawner.spawn(
316 future
317 .map(move |response| {
318 let _ = tx.send((peer, response));
319 })
320 .map_err(|_| unreachable!()),
321 );
322 }
323 }
324 Ok(())
325 }
326}
impl<H, T> Future for HandlerDriver<H, T>
where
    H: HandleMessage,
    T: StunTransport<H::Attribute, PeerAddr = SocketAddr>,
{
    /// Resolves with `()` once the channel's receive side reports end-of-stream.
    type Item = ();
    type Error = Error;

    /// Pumps the channel until no further progress can be made in this call.
    ///
    /// Each iteration (1) receives at most one message and dispatches it to the
    /// handler, (2) drives the channel's send side (`poll_send`), and (3) replies with
    /// at most one response produced by a `FutureReply` future. The loop repeats while
    /// any step reported progress.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let mut did_something = true;
        while did_something {
            did_something = false;

            match track!(self.channel.poll_recv()) {
                Err(e) => {
                    // Every receive error is reported to the handler; only
                    // non-recoverable channels (TCP) stop the driver because of it.
                    self.handler.handle_channel_error(&e);
                    if !self.recoverable_channel {
                        return Err(e);
                    }
                    // NOTE(review): if `poll_recv` keeps failing, this loops without
                    // yielding — confirm the channel cannot fail persistently.
                    did_something = true;
                }
                Ok(Async::NotReady) => {}
                // NOTE(review): returns without a final `poll_send` and drops any pending
                // `FutureReply` responses — confirm this is intended on end-of-stream.
                Ok(Async::Ready(None)) => return Ok(Async::Ready(())),
                Ok(Async::Ready(Some((peer, message)))) => {
                    track!(self.handle_message(peer, message))?;
                    did_something = true;
                }
            }
            // Send-side errors are fatal regardless of `recoverable_channel`.
            if let Err(e) = track!(self.channel.poll_send()) {
                self.handler.handle_channel_error(&e);
                return Err(e);
            }
            // Presumably the fibers mpsc receiver cannot error — TODO confirm. It
            // cannot yield `None` (end of stream) either, because `self.response_tx`
            // keeps a sender alive.
            if let Async::Ready(item) = self.response_rx.poll().expect("never fails") {
                let (peer, response) = item.expect("never fails");
                track!(self.channel.reply(peer, response))?;
                did_something = true;
            }
        }
        Ok(Async::NotReady)
    }
}
368
369#[derive(Debug, Default, Clone)]
373pub struct BindingHandler;
374impl HandleMessage for BindingHandler {
375 type Attribute = rfc5389::Attribute;
376
377 fn handle_call(
378 &mut self,
379 peer: SocketAddr,
380 request: Request<Self::Attribute>,
381 ) -> Action<Response<Self::Attribute>> {
382 if request.method() == rfc5389::methods::BINDING {
383 let mut response = SuccessResponse::new(&request);
384 response.add_attribute(rfc5389::attributes::XorMappedAddress::new(peer).into());
385 Action::Reply(Ok(response))
386 } else {
387 let response = ErrorResponse::new(&request, rfc5389::errors::BadRequest.into());
388 Action::Reply(Err(response))
389 }
390 }
391
392 fn handle_channel_error(&mut self, error: &Error) {
393 eprintln!("[ERROR] {error}");
394 }
395}