// msf_ice/channel.rs

1use std::{
2    collections::VecDeque,
3    io,
4    net::{IpAddr, SocketAddr},
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use bytes::Bytes;
10use futures::{channel::mpsc, ready, Sink, Stream, StreamExt};
11use msf_stun as stun;
12
13use crate::{
14    candidate::{LocalCandidate, RemoteCandidate},
15    checklist::Checklist,
16    log::Logger,
17    session::Session,
18    socket::{Binding as SocketBinding, ICESockets, LocalBinding, Packet, ReflexiveBinding},
19};
20
21/// Channel builder.
22pub struct ChannelBuilder {
23    channel: usize,
24    components: Vec<ComponentHandle>,
25}
26
27impl ChannelBuilder {
28    /// Create a new channel builder.
29    fn new(channel: usize) -> Self {
30        Self {
31            channel,
32            components: Vec::new(),
33        }
34    }
35
36    /// Check if there are any components.
37    pub(crate) fn is_empty(&self) -> bool {
38        self.components.is_empty()
39    }
40
41    /// Add a new component.
42    #[inline]
43    pub fn component(&mut self) -> Component {
44        assert!(self.components.len() < 256);
45
46        let component_id = self.components.len() as u8;
47
48        let (component, handle) = Component::new(self.channel, component_id);
49
50        self.components.push(handle);
51
52        component
53    }
54
55    /// Build the channel.
56    pub(crate) fn build(
57        self,
58        logger: Logger,
59        session: Session,
60        local_addresses: &[IpAddr],
61        stun_servers: &[SocketAddr],
62    ) -> Channel {
63        let components = self.components.len();
64
65        debug_assert!(components > 0);
66
67        let checklist = Checklist::new(session.clone(), self.channel, components);
68
69        let mut component_transports = Vec::with_capacity(components);
70
71        component_transports.resize_with(components, || {
72            ComponentTransport::new(logger.clone(), local_addresses, stun_servers)
73        });
74
75        Channel {
76            session,
77            channel_index: self.channel,
78            checklist,
79            component_transports,
80            component_handles: self.components,
81            available_candidates: VecDeque::new(),
82        }
83    }
84}
85
86/// Single data/media channel.
87pub struct Channel {
88    session: Session,
89    channel_index: usize,
90    checklist: Checklist,
91    component_transports: Vec<ComponentTransport>,
92    component_handles: Vec<ComponentHandle>,
93    available_candidates: VecDeque<LocalCandidate>,
94}
95
96impl Channel {
97    /// Get a channel builder.
98    pub fn builder(channel: usize) -> ChannelBuilder {
99        ChannelBuilder::new(channel)
100    }
101
102    /// Get the next local candidate.
103    pub fn poll_next_local_candidate(
104        &mut self,
105        cx: &mut Context<'_>,
106    ) -> Poll<Option<LocalCandidate>> {
107        loop {
108            if let Some(candidate) = self.available_candidates.pop_front() {
109                return Poll::Ready(Some(candidate));
110            } else if let Some((component, binding)) = ready!(self.poll_next_socket_binding(cx)) {
111                self.process_socket_binding(component, binding);
112            } else {
113                return Poll::Ready(None);
114            }
115        }
116    }
117
118    /// Get the next socket binding.
119    fn poll_next_socket_binding(
120        &mut self,
121        cx: &mut Context<'_>,
122    ) -> Poll<Option<(u8, SocketBinding)>> {
123        let mut pending = 0;
124
125        let transports = self.component_transports.iter_mut();
126
127        for (index, transport) in transports.enumerate() {
128            match transport.poll_next_binding(cx) {
129                Poll::Ready(Some(binding)) => return Poll::Ready(Some((index as _, binding))),
130                Poll::Ready(None) => (),
131                Poll::Pending => pending += 1,
132            }
133        }
134
135        if pending > 0 {
136            Poll::Pending
137        } else {
138            Poll::Ready(None)
139        }
140    }
141
142    /// Process a given socket binding.
143    fn process_socket_binding(&mut self, component: u8, binding: SocketBinding) {
144        match binding {
145            SocketBinding::Local(binding) => self.process_local_binding(component, binding),
146            SocketBinding::Reflexive(binding) => self.process_reflexive_binding(component, binding),
147        }
148    }
149
150    /// Process a given socket binding.
151    fn process_local_binding(&mut self, component: u8, binding: LocalBinding) {
152        let addr = binding.addr();
153
154        let candidate = LocalCandidate::host(self.channel_index, component, addr);
155
156        let foundation = self.session.assign_foundation(&candidate, None);
157
158        let candidate = candidate.with_foundation(foundation);
159
160        self.checklist.add_local_candidate(candidate);
161
162        let ip = addr.ip();
163
164        if !ip.is_unspecified() {
165            self.available_candidates.push_back(candidate);
166        }
167    }
168
169    /// Process a given socket binding.
170    fn process_reflexive_binding(&mut self, component: u8, binding: ReflexiveBinding) {
171        let candidate = LocalCandidate::server_reflexive(
172            self.channel_index,
173            component,
174            binding.base(),
175            binding.addr(),
176        );
177
178        let source = binding.source();
179
180        let foundation = self
181            .session
182            .assign_foundation(&candidate, Some(source.ip()));
183
184        let candidate = candidate.with_foundation(foundation);
185
186        self.checklist.add_local_candidate(candidate);
187        self.available_candidates.push_back(candidate);
188    }
189
190    /// Add a given remote candidate.
191    pub fn process_remote_candidate(&mut self, candidate: RemoteCandidate) {
192        // we silently drop all remote candidates with unknown component ID
193        if (candidate.component() as usize) < self.component_transports.len() {
194            self.checklist.add_remote_candidate(candidate);
195        }
196    }
197
198    /// Schedule a connectivity check.
199    ///
200    /// The method returns `true` if a check was scheduled.
201    pub fn schedule_check(&mut self) -> bool {
202        self.checklist.schedule_check()
203    }
204
205    /// Drive the channel.
206    pub fn drive_channel(&mut self, cx: &mut Context<'_>) {
207        self.drive_connectivity_checks(cx);
208        self.drive_input(cx);
209        self.drive_output(cx);
210    }
211
212    /// Drive connectivity checks.
213    fn drive_connectivity_checks(&mut self, cx: &mut Context<'_>) {
214        while let Poll::Ready(msg) = self.checklist.poll(cx) {
215            let component = msg.component();
216
217            let transport = &mut self.component_transports[component as usize];
218
219            let local_addr = msg.local_addr();
220            let remote_addr = msg.remote_addr();
221
222            transport.send_using(local_addr, remote_addr, msg.take_data());
223        }
224    }
225
226    /// Drive the input.
227    fn drive_input(&mut self, cx: &mut Context<'_>) {
228        for index in 0..self.component_transports.len() {
229            loop {
230                // we can't iterate directly over the transports because we
231                // need to also borrow self in each iteration
232                let transport = &mut self.component_transports[index];
233
234                if let Poll::Ready(packet) = transport.poll_recv(cx) {
235                    self.process_incoming_packet(index as _, packet);
236                } else {
237                    break;
238                }
239            }
240        }
241    }
242
243    /// Drive the output.
244    fn drive_output(&mut self, cx: &mut Context<'_>) {
245        for (index, transport) in self.component_transports.iter_mut().enumerate() {
246            if transport.is_bound() {
247                if let Some(handle) = self.component_handles.get_mut(index) {
248                    while let Poll::Ready(Some(data)) = handle.poll_next_output_packet(cx) {
249                        transport.send(data);
250                    }
251                }
252            }
253        }
254    }
255
256    /// Process a given incoming packet.
257    fn process_incoming_packet(&mut self, component: u8, packet: Packet) {
258        let local_addr = packet.local_addr();
259        let remote_addr = packet.remote_addr();
260        let data = packet.data();
261
262        if let Some(msg) = self.parse_stun_message(data) {
263            self.process_stun_message(component, local_addr, remote_addr, msg);
264        } else if let Some(handle) = self.component_handles.get_mut(component as usize) {
265            handle.deliver_input_packet(packet);
266        }
267    }
268
269    /// Try to parse a STUN message.
270    fn parse_stun_message(&self, data: &Bytes) -> Option<stun::Message> {
271        if let Ok(msg) = stun::Message::from_frame(data.clone()) {
272            if msg.is_rfc5389_message() && msg.check_fingerprint() {
273                return Some(msg);
274            }
275        }
276
277        None
278    }
279
280    /// Process a given STUN message.
281    fn process_stun_message(
282        &mut self,
283        component: u8,
284        local_addr: SocketAddr,
285        remote_addr: SocketAddr,
286        msg: stun::Message,
287    ) {
288        if msg.method() == stun::Method::Binding {
289            if msg.is_request() {
290                self.process_stun_request(component, local_addr, remote_addr, msg)
291            } else if msg.is_response() {
292                self.process_stun_response(component, local_addr, remote_addr, msg)
293            }
294        }
295    }
296
297    /// Process a given STUN request.
298    fn process_stun_request(
299        &mut self,
300        component: u8,
301        local_addr: SocketAddr,
302        remote_addr: SocketAddr,
303        msg: stun::Message,
304    ) {
305        let response =
306            self.checklist
307                .process_stun_request(component, local_addr, remote_addr, &msg);
308
309        let transport = &mut self.component_transports[component as usize];
310
311        transport.send_using(local_addr, remote_addr, response);
312    }
313
314    /// Process a given STUN response.
315    fn process_stun_response(
316        &mut self,
317        component: u8,
318        local_addr: SocketAddr,
319        remote_addr: SocketAddr,
320        msg: stun::Message,
321    ) {
322        if let Some(nominated) =
323            self.checklist
324                .process_stun_response(component, local_addr, remote_addr, &msg)
325        {
326            let local = nominated.local();
327            let remote = nominated.remote();
328
329            let transport = &mut self.component_transports[component as usize];
330
331            transport.bind(local.base(), remote.addr());
332        }
333    }
334}
335
336/// Component stream/sink.
337pub struct Component {
338    channel: usize,
339    component_id: u8,
340    input_packet_rx: mpsc::UnboundedReceiver<Packet>,
341    output_packet_tx: mpsc::Sender<Bytes>,
342}
343
344impl Component {
345    /// Create a new component stream/sink.
346    fn new(channel: usize, component_id: u8) -> (Self, ComponentHandle) {
347        let (input_packet_tx, input_packet_rx) = mpsc::unbounded();
348        let (output_packet_tx, output_packet_rx) = mpsc::channel(8);
349
350        let transport = Self {
351            channel,
352            component_id,
353            input_packet_rx,
354            output_packet_tx,
355        };
356
357        let handle = ComponentHandle {
358            input_packet_tx,
359            output_packet_rx,
360        };
361
362        (transport, handle)
363    }
364
365    /// Get index of the channel this component belongs to.
366    #[inline]
367    pub fn channel(&self) -> usize {
368        self.channel
369    }
370
371    /// Get the component ID (zero-based).
372    #[inline]
373    pub fn component_id(&self) -> u8 {
374        self.component_id
375    }
376}
377
378impl Stream for Component {
379    type Item = io::Result<Packet>;
380
381    #[inline]
382    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
383        match self.input_packet_rx.poll_next_unpin(cx) {
384            Poll::Ready(Some(packet)) => Poll::Ready(Some(Ok(packet))),
385            Poll::Ready(None) => Poll::Ready(None),
386            Poll::Pending => Poll::Pending,
387        }
388    }
389}
390
391impl Sink<Bytes> for Component {
392    type Error = io::Error;
393
394    #[inline]
395    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
396        Pin::new(&mut self.output_packet_tx)
397            .poll_ready(cx)
398            .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
399    }
400
401    #[inline]
402    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
403        Pin::new(&mut self.output_packet_tx)
404            .start_send(item)
405            .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
406    }
407
408    #[inline]
409    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
410        Pin::new(&mut self.output_packet_tx)
411            .poll_flush(cx)
412            .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
413    }
414
415    #[inline]
416    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
417        Pin::new(&mut self.output_packet_tx)
418            .poll_close(cx)
419            .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe))
420    }
421}
422
423/// Component handle.
424struct ComponentHandle {
425    input_packet_tx: mpsc::UnboundedSender<Packet>,
426    output_packet_rx: mpsc::Receiver<Bytes>,
427}
428
429impl ComponentHandle {
430    /// Get next output packet.
431    fn poll_next_output_packet(&mut self, cx: &mut Context<'_>) -> Poll<Option<Bytes>> {
432        self.output_packet_rx.poll_next_unpin(cx)
433    }
434
435    /// Deliver a given input packet.
436    fn deliver_input_packet(&mut self, packet: Packet) {
437        self.input_packet_tx
438            .unbounded_send(packet)
439            .unwrap_or_default();
440    }
441}
442
443/// Component transport.
444struct ComponentTransport {
445    sockets: ICESockets,
446    binding: Option<ComponentBinding>,
447}
448
449impl ComponentTransport {
450    /// Create a new component transport.
451    fn new(logger: Logger, local_addresses: &[IpAddr], stun_servers: &[SocketAddr]) -> Self {
452        Self {
453            sockets: ICESockets::new(logger, local_addresses, stun_servers),
454            binding: None,
455        }
456    }
457
458    /// Check if the transport has been bound to local/remote address pair.
459    fn is_bound(&self) -> bool {
460        self.binding.is_some()
461    }
462
463    /// Bind the transport to a given local/remote address pair.
464    fn bind(&mut self, local: SocketAddr, remote: SocketAddr) {
465        self.binding = Some(ComponentBinding::new(local, remote));
466    }
467
468    /// Get the next local binding.
469    fn poll_next_binding(&mut self, cx: &mut Context<'_>) -> Poll<Option<SocketBinding>> {
470        self.sockets.poll_next_binding(cx)
471    }
472
473    /// Read the next packet.
474    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Packet> {
475        self.sockets.poll_recv(cx)
476    }
477
478    /// Send given data from a given local binding to a given remote host.
479    fn send_using(&mut self, local_addr: SocketAddr, remote_addr: SocketAddr, data: Bytes) {
480        self.sockets.send(local_addr, remote_addr, data);
481    }
482
483    /// Send given data from the local address this transport is bound to to
484    /// the remote host that this transport is connected to.
485    fn send(&mut self, data: Bytes) {
486        if let Some(binding) = self.binding {
487            self.send_using(binding.local, binding.remote, data);
488        } else if cfg!(debug_assertions) {
489            panic!("unable to send given data packet, no binding");
490        }
491    }
492}
493
/// Local/remote socket address pair that a component transport is bound to.
#[derive(Copy, Clone)]
struct ComponentBinding {
    /// Local address that outgoing data is sent from.
    local: SocketAddr,
    /// Remote address that outgoing data is sent to.
    remote: SocketAddr,
}

impl ComponentBinding {
    /// Pair a local address with a remote address.
    fn new(local: SocketAddr, remote: SocketAddr) -> Self {
        Self { remote, local }
    }
}