1use crate::conn::DTLSConn;
2use shared::error::{Error, Result};
3use shared::{EcnCodepoint, TransportContext};
4use shared::{TransportMessage, TransportProtocol};
5
6use crate::config::HandshakeConfig;
7use crate::state::State;
8use bytes::BytesMut;
9use std::collections::hash_map::Keys;
10use std::collections::{HashMap, VecDeque, hash_map::Entry::Vacant};
11use std::net::SocketAddr;
12use std::sync::Arc;
13use std::time::Instant;
14
15pub enum EndpointEvent {
16 HandshakeComplete,
17 ApplicationData(BytesMut),
18}
19
20pub struct Endpoint {
26 local_addr: SocketAddr,
27 transport_protocol: TransportProtocol,
28 transmits: VecDeque<TransportMessage<BytesMut>>,
29 connections: HashMap<SocketAddr, DTLSConn>,
30 server_config: Option<Arc<HandshakeConfig>>,
31}
32
33impl Endpoint {
34 pub fn new(
38 local_addr: SocketAddr,
39 protocol: TransportProtocol,
40 server_config: Option<Arc<HandshakeConfig>>,
41 ) -> Self {
42 Self {
43 local_addr,
44 transport_protocol: protocol,
45 transmits: VecDeque::new(),
46 connections: HashMap::new(),
47 server_config,
48 }
49 }
50
51 pub fn set_server_config(&mut self, server_config: Option<Arc<HandshakeConfig>>) {
53 self.server_config = server_config;
54 }
55
56 #[must_use]
58 pub fn poll_transmit(&mut self) -> Option<TransportMessage<BytesMut>> {
59 self.transmits.pop_front()
60 }
61
62 pub fn get_connections_keys(&self) -> Keys<'_, SocketAddr, DTLSConn> {
64 self.connections.keys()
65 }
66
67 pub fn get_connection_state(&self, remote: SocketAddr) -> Option<&State> {
69 if let Some(conn) = self.connections.get(&remote) {
70 Some(conn.connection_state())
71 } else {
72 None
73 }
74 }
75
76 pub fn connect(
78 &mut self,
79 remote: SocketAddr,
80 client_config: Arc<HandshakeConfig>,
81 initial_state: Option<State>,
82 ) -> Result<()> {
83 if remote.port() == 0 {
84 return Err(Error::InvalidRemoteAddress(remote));
85 }
86
87 if let Vacant(e) = self.connections.entry(remote) {
88 let mut conn = DTLSConn::new(client_config, true, initial_state);
89 conn.handshake()?;
90
91 while let Some(payload) = conn.outgoing_raw_packet() {
92 self.transmits.push_back(TransportMessage {
93 now: Instant::now(),
94 transport: TransportContext {
95 local_addr: self.local_addr,
96 peer_addr: remote,
97 ecn: None,
98 transport_protocol: self.transport_protocol,
99 },
100 message: payload,
101 });
102 }
103
104 e.insert(conn);
105 }
106
107 Ok(())
108 }
109
110 pub fn stop(&mut self, remote: SocketAddr) -> Option<DTLSConn> {
112 if let Some(conn) = self.connections.get_mut(&remote) {
113 conn.close();
114 while let Some(payload) = conn.outgoing_raw_packet() {
115 self.transmits.push_back(TransportMessage {
116 now: Instant::now(),
117 transport: TransportContext {
118 local_addr: self.local_addr,
119 peer_addr: remote,
120 ecn: None,
121 transport_protocol: self.transport_protocol,
122 },
123 message: payload,
124 });
125 }
126 }
127 self.connections.remove(&remote)
128 }
129
130 pub fn close(&mut self) -> Result<()> {
132 for (remote_addr, conn) in self.connections.iter_mut() {
133 conn.close();
134 while let Some(payload) = conn.outgoing_raw_packet() {
135 self.transmits.push_back(TransportMessage {
136 now: Instant::now(),
137 transport: TransportContext {
138 local_addr: self.local_addr,
139 peer_addr: *remote_addr,
140 ecn: None,
141 transport_protocol: self.transport_protocol,
142 },
143 message: payload,
144 });
145 }
146 }
147 self.connections.clear();
148
149 Ok(())
150 }
151
152 pub fn read(
154 &mut self,
155 now: Instant,
156 remote: SocketAddr,
157 ecn: Option<EcnCodepoint>,
158 data: BytesMut,
159 ) -> Result<Vec<EndpointEvent>> {
160 if let Vacant(e) = self.connections.entry(remote) {
161 if let Some(server_config) = &self.server_config {
162 let handshake_config = server_config.clone();
163 let conn = DTLSConn::new(handshake_config, false, None);
164 e.insert(conn);
165 } else {
166 return Err(Error::NoServerConfig);
167 }
168 }
169
170 let mut messages = vec![];
172 if let Some(conn) = self.connections.get_mut(&remote) {
173 let is_handshake_completed_before = conn.is_handshake_completed();
174 conn.read(&data)?;
175 if !conn.is_handshake_completed() {
176 conn.handshake()?;
177 conn.handle_incoming_queued_packets()?;
178 }
179 if !is_handshake_completed_before && conn.is_handshake_completed() {
180 messages.push(EndpointEvent::HandshakeComplete)
181 }
182 while let Some(message) = conn.incoming_application_data() {
183 messages.push(EndpointEvent::ApplicationData(message));
184 }
185 while let Some(payload) = conn.outgoing_raw_packet() {
186 self.transmits.push_back(TransportMessage {
187 now,
188 transport: TransportContext {
189 local_addr: self.local_addr,
190 peer_addr: remote,
191 ecn,
192 transport_protocol: self.transport_protocol,
193 },
194 message: payload,
195 });
196 }
197 }
198
199 Ok(messages)
200 }
201
202 pub fn write(&mut self, remote: SocketAddr, data: &[u8]) -> Result<()> {
203 if let Some(conn) = self.connections.get_mut(&remote) {
204 conn.write(data)?;
205 while let Some(payload) = conn.outgoing_raw_packet() {
206 self.transmits.push_back(TransportMessage {
207 now: Instant::now(),
208 transport: TransportContext {
209 local_addr: self.local_addr,
210 peer_addr: remote,
211 ecn: None,
212 transport_protocol: self.transport_protocol,
213 },
214 message: payload,
215 });
216 }
217 Ok(())
218 } else {
219 Err(Error::InvalidRemoteAddress(remote))
220 }
221 }
222
223 pub fn handle_timeout(&mut self, remote: SocketAddr, now: Instant) -> Result<()> {
224 if let Some(conn) = self.connections.get_mut(&remote) {
225 if let Some(current_retransmit_timer) = &conn.current_retransmit_timer
226 && now >= *current_retransmit_timer
227 {
228 if conn.current_retransmit_timer.take().is_some() && !conn.is_handshake_completed()
229 {
230 conn.handshake_timeout(now)?;
231 }
232 while let Some(payload) = conn.outgoing_raw_packet() {
233 self.transmits.push_back(TransportMessage {
234 now,
235 transport: TransportContext {
236 local_addr: self.local_addr,
237 peer_addr: remote,
238 ecn: None,
239 transport_protocol: self.transport_protocol,
240 },
241 message: payload,
242 });
243 }
244 }
245 Ok(())
246 } else {
247 Err(Error::InvalidRemoteAddress(remote))
248 }
249 }
250
251 pub fn poll_timeout(&self, remote: SocketAddr, eto: &mut Instant) -> Result<()> {
252 if let Some(conn) = self.connections.get(&remote) {
253 if let Some(current_retransmit_timer) = &conn.current_retransmit_timer
254 && *current_retransmit_timer < *eto
255 {
256 *eto = *current_retransmit_timer;
257 }
258 Ok(())
259 } else {
260 Err(Error::InvalidRemoteAddress(remote))
261 }
262 }
263}