msf_ice/
lib.rs

1#[macro_use]
2mod log;
3
4mod candidate;
5mod channel;
6mod check;
7mod checklist;
8mod session;
9mod socket;
10mod utils;
11
12use std::{
13    future::Future,
14    net::{IpAddr, SocketAddr},
15    ops::Deref,
16    pin::Pin,
17    task::{Context, Poll},
18    time::{Duration, Instant},
19};
20
21use futures::{channel::mpsc, ready, FutureExt, StreamExt};
22use tokio::time::Sleep;
23
24#[cfg(feature = "slog")]
25use slog::{o, Discard, Logger};
26
27#[cfg(not(feature = "slog"))]
28use self::log::Logger;
29
30use self::{channel::Channel, session::Session};
31
32pub use self::{
33    candidate::{CandidateKind, LocalCandidate, RemoteCandidate},
34    channel::{ChannelBuilder, Component},
35    session::Credentials,
36    socket::Packet,
37};
38
/// ICE agent role.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum AgentRole {
    /// Agents who initiated the connection.
    Controlling,
    /// Agents who did not initiate the connection.
    Controlled,
}

impl AgentRole {
    /// Return the opposite role: controlling becomes controlled and vice
    /// versa.
    fn reverse(self) -> Self {
        match self {
            Self::Controlling => Self::Controlled,
            Self::Controlled => Self::Controlling,
        }
    }
}
57
58/// ICE agent builder.
59pub struct AgentBuilder {
60    #[cfg(feature = "slog")]
61    logger: Logger,
62    agent_role: AgentRole,
63    local_addresses: Vec<IpAddr>,
64    stun_servers: Vec<SocketAddr>,
65    channels: Vec<ChannelBuilder>,
66    check_interval: Duration,
67}
68
69impl AgentBuilder {
70    /// Create a new builder.
71    fn new(agent_role: AgentRole) -> Self {
72        Self {
73            #[cfg(feature = "slog")]
74            logger: Logger::root(Discard, o!()),
75            agent_role,
76            local_addresses: Vec::new(),
77            stun_servers: Vec::new(),
78            channels: Vec::new(),
79            check_interval: Duration::from_millis(50),
80        }
81    }
82
83    /// Use a given logger.
84    #[cfg(feature = "slog")]
85    #[inline]
86    pub fn logger(&mut self, logger: Logger) -> &mut Self {
87        self.logger = logger;
88        self
89    }
90
91    /// Add given local address.
92    #[inline]
93    pub fn local_address(&mut self, addr: IpAddr) -> &mut Self {
94        self.local_addresses.push(addr);
95        self
96    }
97
98    /// Use a given STUN server.
99    #[inline]
100    pub fn stun_server(&mut self, addr: SocketAddr) -> &mut Self {
101        self.stun_servers.push(addr);
102        self
103    }
104
105    /// Add a new channel.
106    ///
107    /// The method returns a channel builder where components can be created.
108    #[inline]
109    pub fn channel(&mut self) -> &mut ChannelBuilder {
110        let create = self
111            .channels
112            .last()
113            .map(|last| !last.is_empty())
114            .unwrap_or(true);
115
116        if create {
117            self.channels.push(Channel::builder(self.channels.len()));
118        }
119
120        self.channels.last_mut().unwrap()
121    }
122
123    /// Build the agent.
124    pub fn build(mut self) -> Agent {
125        self.local_addresses.sort_unstable();
126        self.local_addresses.dedup();
127
128        self.stun_servers.sort_unstable();
129        self.stun_servers.dedup();
130
131        let session = Session::new(self.agent_role, self.channels.len());
132
133        #[cfg(feature = "slog")]
134        let logger = self.logger;
135
136        #[cfg(not(feature = "slog"))]
137        let logger = Logger;
138
139        let channels = self
140            .channels
141            .into_iter()
142            .filter(|channel| !channel.is_empty())
143            .map(|channel| {
144                channel.build(
145                    logger.clone(),
146                    session.clone(),
147                    &self.local_addresses,
148                    &self.stun_servers,
149                )
150            })
151            .collect();
152
153        let (local_candidate_tx, local_candidate_rx) = mpsc::unbounded();
154        let (remote_candidate_tx, remote_candidate_rx) = mpsc::unbounded();
155
156        let task = AgentTask {
157            session: session.clone(),
158            channels,
159            remote_candidate_rx,
160            local_candidate_tx: Some(local_candidate_tx),
161            last_check: Instant::now(),
162            next_check: Box::pin(tokio::time::sleep(self.check_interval)),
163            check_interval: self.check_interval,
164            check_tokens: 1,
165        };
166
167        let channel_count = task.channels.len();
168
169        tokio::spawn(task);
170
171        Agent {
172            session,
173            channels: channel_count,
174            local_candidate_rx,
175            remote_candidate_tx,
176        }
177    }
178}
179
180/// ICE agent.
181///
182/// # Usage
183/// 0. Get all components and prepare them for data/media transmission.
184/// 1. Get the local credentials for all channels and send them over to a
185///    remote agent.
186/// 2. Get all local candidates and send them over to the remote agent.
187/// 3. Set remote credentials for all channels (required to be done before
188///    adding remote candidates).
189/// 4. Add remote candidates.
190/// 5. If there are no more remote candidates, conclude connectivity checks.
191pub struct Agent {
192    session: Session,
193    channels: usize,
194    local_candidate_rx: mpsc::UnboundedReceiver<LocalCandidate>,
195    remote_candidate_tx: mpsc::UnboundedSender<NewRemoteCandidate>,
196}
197
198impl Agent {
199    /// Get an ICE agent builder.
200    #[inline]
201    pub fn builder(agent_role: AgentRole) -> AgentBuilder {
202        AgentBuilder::new(agent_role)
203    }
204
205    /// Get the next local candidate.
206    #[inline]
207    pub fn poll_next_local_candidate(
208        &mut self,
209        cx: &mut Context<'_>,
210    ) -> Poll<Option<LocalCandidate>> {
211        if let Some(candidate) = ready!(self.local_candidate_rx.poll_next_unpin(cx)) {
212            Poll::Ready(Some(candidate))
213        } else {
214            Poll::Ready(None)
215        }
216    }
217
218    /// Get the next local candidate.
219    #[inline]
220    pub async fn next_local_candidate(&mut self) -> Option<LocalCandidate> {
221        futures::future::poll_fn(|cx| self.poll_next_local_candidate(cx)).await
222    }
223
224    /// Get the number of channels.
225    #[inline]
226    pub fn channels(&self) -> usize {
227        self.channels
228    }
229
230    /// Get local credentials of a given channel.
231    #[inline]
232    pub fn get_local_credentials(&self, channel: usize) -> Credentials {
233        self.session.get_local_credentials(channel)
234    }
235
236    /// Get remote credentials of a given channel (if known).
237    #[inline]
238    pub fn get_remote_credentials(&self, channel: usize) -> Option<Credentials> {
239        self.session.get_remote_credentials(channel)
240    }
241
242    /// Set remote credentials for a given channel.
243    #[inline]
244    pub fn set_remote_credentials(&mut self, channel: usize, credentials: Credentials) {
245        self.session.set_remote_credentials(channel, credentials);
246    }
247
248    /// Add a given remote candidate.
249    ///
250    /// # Panics
251    /// The method will panic if the remote credentials for the corresponding
252    /// channel have not been set.
253    pub fn add_remote_candidate(
254        &mut self,
255        candidate: RemoteCandidate,
256        username_fragment: Option<&str>,
257    ) {
258        let channel = candidate.channel();
259
260        if channel >= self.channels {
261            return;
262        }
263
264        self.session
265            .lock()
266            .get_remote_credentials(channel)
267            .expect("missing remote credentials");
268
269        self.remote_candidate_tx
270            .unbounded_send(NewRemoteCandidate::new(candidate, username_fragment))
271            .unwrap()
272    }
273}
274
275/// Background task of the corresponding ICE agent.
276struct AgentTask {
277    session: Session,
278    channels: Vec<Channel>,
279    remote_candidate_rx: mpsc::UnboundedReceiver<NewRemoteCandidate>,
280    local_candidate_tx: Option<mpsc::UnboundedSender<LocalCandidate>>,
281    last_check: Instant,
282    next_check: Pin<Box<Sleep>>,
283    check_interval: Duration,
284    check_tokens: u32,
285}
286
287impl AgentTask {
288    /// Process a given remote candidate.
289    fn process_remote_candidate(&mut self, candidate: NewRemoteCandidate) {
290        // drop the candidate if the channel index is out of bounds or if the
291        // username fragment does not match
292        if let Some(channel) = self.channels.get_mut(candidate.channel()) {
293            let is_from_current_session = self
294                .session
295                .lock()
296                .get_remote_credentials(candidate.channel())
297                .map(|credentials| {
298                    candidate
299                        .username_fragment()
300                        .map(|username| username == credentials.username())
301                        .unwrap_or(true)
302                })
303                .unwrap_or(false);
304
305            if is_from_current_session {
306                channel.process_remote_candidate(candidate.into());
307            }
308        }
309    }
310
311    /// Process new local candidates.
312    fn process_local_candidates(&mut self, cx: &mut Context<'_>) {
313        if let Some(candidate_tx) = self.local_candidate_tx.as_mut() {
314            let mut resolved = 0;
315
316            for channel in &mut self.channels {
317                while let Poll::Ready(r) = channel.poll_next_local_candidate(cx) {
318                    if let Some(candidate) = r {
319                        candidate_tx.unbounded_send(candidate).unwrap_or_default();
320                    } else {
321                        // mark the channel as resolved
322                        resolved += 1;
323
324                        // ... and stop polling it
325                        break;
326                    }
327                }
328            }
329
330            if resolved == self.channels.len() {
331                self.local_candidate_tx = None;
332            }
333        }
334    }
335
336    /// Drive channels.
337    fn drive_channels(&mut self, cx: &mut Context<'_>) {
338        for channel in &mut self.channels {
339            channel.drive_channel(cx);
340        }
341    }
342
343    /// Schedule connectivity checks.
344    fn schedule_checks(&mut self, cx: &mut Context<'_>) {
345        // get the number of available tokens
346        let elapsed = self.last_check.elapsed();
347
348        let n = (elapsed.as_millis() / self.check_interval.as_millis()) as u32;
349
350        self.check_tokens = self.check_tokens.saturating_add(n);
351
352        self.last_check += n * self.check_interval;
353
354        // schedule the next time event
355        loop {
356            let poll = self.next_check.poll_unpin(cx);
357
358            if poll.is_pending() {
359                break;
360            }
361
362            let mut next = self.last_check;
363
364            while next < Instant::now() {
365                next += self.check_interval;
366            }
367
368            let pinned = self.next_check.as_mut();
369
370            pinned.reset(next.into());
371        }
372
373        // and schedule as many checks as possible
374        for channel in &mut self.channels {
375            while self.check_tokens > 0 {
376                if channel.schedule_check() {
377                    self.check_tokens -= 1;
378                } else {
379                    break;
380                }
381            }
382        }
383    }
384}
385
386impl Future for AgentTask {
387    type Output = ();
388
389    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
390        while let Poll::Ready(next) = self.remote_candidate_rx.poll_next_unpin(cx) {
391            if let Some(candidate) = next {
392                self.process_remote_candidate(candidate);
393            } else {
394                return Poll::Ready(());
395            }
396        }
397
398        self.schedule_checks(cx);
399        self.process_local_candidates(cx);
400        self.drive_channels(cx);
401
402        Poll::Pending
403    }
404}
405
406/// New remote candidate.
407struct NewRemoteCandidate {
408    candidate: RemoteCandidate,
409    username_fragment: Option<String>,
410}
411
412impl NewRemoteCandidate {
413    /// Create a new remote candidate.
414    fn new(candidate: RemoteCandidate, username_fragment: Option<&str>) -> Self {
415        Self {
416            username_fragment: username_fragment.map(|v| v.to_string()),
417            candidate,
418        }
419    }
420
421    /// Get the username fragment.
422    fn username_fragment(&self) -> Option<&str> {
423        self.username_fragment.as_deref()
424    }
425}
426
427impl Deref for NewRemoteCandidate {
428    type Target = RemoteCandidate;
429
430    fn deref(&self) -> &Self::Target {
431        &self.candidate
432    }
433}
434
435impl From<NewRemoteCandidate> for RemoteCandidate {
436    fn from(c: NewRemoteCandidate) -> Self {
437        c.candidate
438    }
439}