rtc_dtls/
endpoint.rs

1use crate::conn::DTLSConn;
2use shared::error::{Error, Result};
3use shared::{EcnCodepoint, TransportContext};
4use shared::{TransportMessage, TransportProtocol};
5
6use crate::config::HandshakeConfig;
7use crate::state::State;
8use bytes::BytesMut;
9use std::collections::hash_map::Keys;
10use std::collections::{HashMap, VecDeque, hash_map::Entry::Vacant};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::Instant;
14
15pub enum EndpointEvent {
16    HandshakeComplete,
17    ApplicationData(BytesMut),
18}
19
20/// The main entry point to the library
21///
22/// This object performs no I/O whatsoever. Instead, it generates a stream of packets to send via
23/// `poll_transmit`, and consumes incoming packets and connections-generated events via `handle` and
24/// `handle_event`.
25pub struct Endpoint {
26    local_addr: SocketAddr,
27    transport_protocol: TransportProtocol,
28    transmits: VecDeque<TransportMessage<BytesMut>>,
29    connections: HashMap<SocketAddr, DTLSConn>,
30    server_config: Option<Arc<HandshakeConfig>>,
31}
32
33impl Endpoint {
34    /// Create a new endpoint
35    ///
36    /// Returns `Err` if the configuration is invalid.
37    pub fn new(
38        local_addr: SocketAddr,
39        protocol: TransportProtocol,
40        server_config: Option<Arc<HandshakeConfig>>,
41    ) -> Self {
42        Self {
43            local_addr,
44            transport_protocol: protocol,
45            transmits: VecDeque::new(),
46            connections: HashMap::new(),
47            server_config,
48        }
49    }
50
51    /// Replace the server configuration, affecting new incoming associations only
52    pub fn set_server_config(&mut self, server_config: Option<Arc<HandshakeConfig>>) {
53        self.server_config = server_config;
54    }
55
56    /// Get the next packet to transmit
57    #[must_use]
58    pub fn poll_transmit(&mut self) -> Option<TransportMessage<BytesMut>> {
59        self.transmits.pop_front()
60    }
61
62    /// Get keys of Connections
63    pub fn get_connections_keys(&self) -> Keys<'_, SocketAddr, DTLSConn> {
64        self.connections.keys()
65    }
66
67    /// Get Connection State
68    pub fn get_connection_state(&self, remote: SocketAddr) -> Option<&State> {
69        if let Some(conn) = self.connections.get(&remote) {
70            Some(conn.connection_state())
71        } else {
72            None
73        }
74    }
75
76    /// Initiate an Association
77    pub fn connect(
78        &mut self,
79        remote: SocketAddr,
80        client_config: Arc<HandshakeConfig>,
81        initial_state: Option<State>,
82    ) -> Result<()> {
83        if remote.port() == 0 {
84            return Err(Error::InvalidRemoteAddress(remote));
85        }
86
87        if let Vacant(e) = self.connections.entry(remote) {
88            let mut conn = DTLSConn::new(client_config, true, initial_state);
89            conn.handshake()?;
90
91            while let Some(payload) = conn.outgoing_raw_packet() {
92                self.transmits.push_back(TransportMessage {
93                    now: Instant::now(),
94                    transport: TransportContext {
95                        local_addr: self.local_addr,
96                        peer_addr: remote,
97                        ecn: None,
98                        transport_protocol: self.transport_protocol,
99                    },
100                    message: payload,
101                });
102            }
103
104            e.insert(conn);
105        }
106
107        Ok(())
108    }
109
110    /// Process stop remote
111    pub fn stop(&mut self, remote: SocketAddr) -> Option<DTLSConn> {
112        if let Some(conn) = self.connections.get_mut(&remote) {
113            conn.close();
114            while let Some(payload) = conn.outgoing_raw_packet() {
115                self.transmits.push_back(TransportMessage {
116                    now: Instant::now(),
117                    transport: TransportContext {
118                        local_addr: self.local_addr,
119                        peer_addr: remote,
120                        ecn: None,
121                        transport_protocol: self.transport_protocol,
122                    },
123                    message: payload,
124                });
125            }
126        }
127        self.connections.remove(&remote)
128    }
129
130    /// Process close
131    pub fn close(&mut self) -> Result<()> {
132        for (remote_addr, conn) in self.connections.iter_mut() {
133            conn.close();
134            while let Some(payload) = conn.outgoing_raw_packet() {
135                self.transmits.push_back(TransportMessage {
136                    now: Instant::now(),
137                    transport: TransportContext {
138                        local_addr: self.local_addr,
139                        peer_addr: *remote_addr,
140                        ecn: None,
141                        transport_protocol: self.transport_protocol,
142                    },
143                    message: payload,
144                });
145            }
146        }
147        self.connections.clear();
148
149        Ok(())
150    }
151
152    /// Process an incoming UDP datagram
153    pub fn read(
154        &mut self,
155        now: Instant,
156        remote: SocketAddr,
157        ecn: Option<EcnCodepoint>,
158        data: BytesMut,
159    ) -> Result<Vec<EndpointEvent>> {
160        if let Vacant(e) = self.connections.entry(remote) {
161            if let Some(server_config) = &self.server_config {
162                let handshake_config = server_config.clone();
163                let conn = DTLSConn::new(handshake_config, false, None);
164                e.insert(conn);
165            } else {
166                return Err(Error::NoServerConfig);
167            }
168        }
169
170        // Handle packet on existing association, if any
171        let mut messages = vec![];
172        if let Some(conn) = self.connections.get_mut(&remote) {
173            let is_handshake_completed_before = conn.is_handshake_completed();
174            conn.read(&data)?;
175            if !conn.is_handshake_completed() {
176                conn.handshake()?;
177                conn.handle_incoming_queued_packets()?;
178            }
179            if !is_handshake_completed_before && conn.is_handshake_completed() {
180                messages.push(EndpointEvent::HandshakeComplete)
181            }
182            while let Some(message) = conn.incoming_application_data() {
183                messages.push(EndpointEvent::ApplicationData(message));
184            }
185            while let Some(payload) = conn.outgoing_raw_packet() {
186                self.transmits.push_back(TransportMessage {
187                    now,
188                    transport: TransportContext {
189                        local_addr: self.local_addr,
190                        peer_addr: remote,
191                        ecn,
192                        transport_protocol: self.transport_protocol,
193                    },
194                    message: payload,
195                });
196            }
197        }
198
199        Ok(messages)
200    }
201
202    pub fn write(&mut self, remote: SocketAddr, data: &[u8]) -> Result<()> {
203        if let Some(conn) = self.connections.get_mut(&remote) {
204            conn.write(data)?;
205            while let Some(payload) = conn.outgoing_raw_packet() {
206                self.transmits.push_back(TransportMessage {
207                    now: Instant::now(),
208                    transport: TransportContext {
209                        local_addr: self.local_addr,
210                        peer_addr: remote,
211                        ecn: None,
212                        transport_protocol: self.transport_protocol,
213                    },
214                    message: payload,
215                });
216            }
217            Ok(())
218        } else {
219            Err(Error::InvalidRemoteAddress(remote))
220        }
221    }
222
223    pub fn handle_timeout(&mut self, remote: SocketAddr, now: Instant) -> Result<()> {
224        if let Some(conn) = self.connections.get_mut(&remote) {
225            if let Some(current_retransmit_timer) = &conn.current_retransmit_timer
226                && now >= *current_retransmit_timer
227            {
228                if conn.current_retransmit_timer.take().is_some() && !conn.is_handshake_completed()
229                {
230                    conn.handshake_timeout(now)?;
231                }
232                while let Some(payload) = conn.outgoing_raw_packet() {
233                    self.transmits.push_back(TransportMessage {
234                        now,
235                        transport: TransportContext {
236                            local_addr: self.local_addr,
237                            peer_addr: remote,
238                            ecn: None,
239                            transport_protocol: self.transport_protocol,
240                        },
241                        message: payload,
242                    });
243                }
244            }
245            Ok(())
246        } else {
247            Err(Error::InvalidRemoteAddress(remote))
248        }
249    }
250
251    pub fn poll_timeout(&self, remote: SocketAddr, eto: &mut Instant) -> Result<()> {
252        if let Some(conn) = self.connections.get(&remote) {
253            if let Some(current_retransmit_timer) = &conn.current_retransmit_timer
254                && *current_retransmit_timer < *eto
255            {
256                *eto = *current_retransmit_timer;
257            }
258            Ok(())
259        } else {
260            Err(Error::InvalidRemoteAddress(remote))
261        }
262    }
263}