1use core::{borrow::BorrowMut, net::SocketAddr, ops::Range};
2
3use log::trace;
4use rand_core::CryptoRngCore;
5use sha2::{
6 digest::{generic_array::GenericArray, OutputSizeUser},
7 Sha256,
8};
9
10use crate::{
11 close_connection,
12 handshake::{ClientState, ServerState},
13 open_connection,
14 parsing_utility::ParseBuffer,
15 record_parsing::{EncodeCiphertextRecord, RecordContentType},
16 stage_alert, try_open_new_handshake, try_pass_packet_to_connection,
17 try_pass_packet_to_handshake, ConnectionId, DeferredAction, DtlsConnection, DtlsError,
18 DtlsPoll, EpochState, HandshakeSlot, HandshakeSlotState, HandshakeState, TimeStampMs,
19};
20
21use super::handshake::{process_client_sync, process_server_sync};
22
/// Front-end that ties together a fixed table of DTLS connections, a random
/// source, a scratch buffer and a callback used to transmit bytes to peers.
///
/// `CONNECTIONS` is the number of connection slots held by the stack.
pub struct DtlsStack<'a, const CONNECTIONS: usize> {
    /// Connection slots; a slot holding `None` has no connection in it.
    connections: [Option<DtlsConnection<'a>>; CONNECTIONS],

    /// Random source; fills `cookie_key` in `new` and is handed to the
    /// handshake processing functions in `poll`.
    rng: &'a mut dyn rand_core::CryptoRngCore,
    /// Caller-provided scratch buffer in which outgoing records are encoded
    /// before they are passed to `send_to_peer`.
    staging_buffer: &'a mut [u8],

    /// Callback invoked with the destination address and the bytes to send.
    send_to_peer: &'a mut dyn FnMut(&SocketAddr, &[u8]),

    /// Forwarded to `try_open_new_handshake`; initialised to `true` in `new`.
    require_cookie: bool,
    /// Random key generated in `new` and forwarded to
    /// `try_open_new_handshake`; sized like a SHA-256 output.
    cookie_key: GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>,
}
36
37impl<'a, const CONNECTIONS: usize> DtlsStack<'a, CONNECTIONS> {
38 pub fn new(
39 rng: &'a mut dyn CryptoRngCore,
40 staging_buffer: &'a mut [u8],
41 send_to_peer: &'a mut dyn FnMut(&SocketAddr, &[u8]),
42 ) -> Result<Self, DtlsError> {
43 let mut me = Self {
44 connections: [const { None }; CONNECTIONS],
45 rng,
46 staging_buffer,
47 send_to_peer,
48 cookie_key: GenericArray::default(),
49 require_cookie: true,
50 };
51 me.rng
52 .try_fill_bytes(&mut me.cookie_key)
53 .map_err(|_| DtlsError::RngError)?;
54 Ok(me)
55 }
56
57 pub fn poll(
58 &mut self,
59 handshakes: &mut [HandshakeSlot],
60 now_ms: TimeStampMs,
61 ) -> Result<DtlsPoll, DtlsError> {
62 let mut return_poll = DtlsPoll::Wait;
63 for handshake in handshakes {
64 let poll = match &mut handshake.state {
65 HandshakeSlotState::Running {
66 state,
67 handshake: ctx,
68 } => {
69 let conn = ctx.connection(&mut self.connections);
70 let mut new_state = *state;
71 let addr = conn.addr;
72 let poll = match &mut new_state {
73 HandshakeState::Client(c) => process_client_sync(
74 c,
75 &now_ms,
76 ctx,
77 &mut handshake.rt_queue,
78 conn,
79 self.rng,
80 self.staging_buffer,
81 &mut |bytes| (self.send_to_peer)(&addr, bytes),
82 ),
83 HandshakeState::Server(s) => process_server_sync(
84 s,
85 &now_ms,
86 ctx,
87 &mut handshake.rt_queue,
88 conn,
89 self.rng,
90 self.staging_buffer,
91 &mut |bytes| (self.send_to_peer)(&addr, bytes),
92 ),
93 };
94 try_send_alert_sync(
95 &poll,
96 self.staging_buffer,
97 &mut |b| (self.send_to_peer)(&addr, b),
98 &mut conn.epochs,
99 &conn.current_epoch,
100 );
101 let poll = poll?;
102 *state = new_state;
103 if matches!(
104 state,
105 HandshakeState::Client(ClientState::FinishedHandshake)
106 | HandshakeState::Server(ServerState::FinishedHandshake)
107 ) {
108 handshake.finish_handshake(conn);
109 }
110 poll
111 }
112 HandshakeSlotState::Empty => DtlsPoll::Wait,
113 HandshakeSlotState::Finished(_) => DtlsPoll::FinishedHandshake,
114 };
115 return_poll = return_poll.merge(poll);
116 }
117 Ok(return_poll)
118 }
119
120 pub fn open_connection(&mut self, slot: &mut HandshakeSlot, addr: &SocketAddr) -> bool {
121 open_connection(&mut self.connections, slot, addr)
122 }
123
124 pub fn close_connection(&mut self, connection_id: ConnectionId) -> bool {
126 let addr = self
127 .connections
128 .get(connection_id.0)
129 .and_then(|c| c.as_ref().map(|c| c.addr));
130 match (
131 addr,
132 close_connection(connection_id, self.staging_buffer, &mut self.connections),
133 ) {
134 (Some(addr), DeferredAction::Send(buf)) => {
135 (self.send_to_peer)(&addr, buf);
136 true
137 }
138 (_, DeferredAction::None) => false,
139 _ => unreachable!(),
140 }
141 }
142
143 pub fn send_dtls_packet(
144 &mut self,
145 connection_id: ConnectionId,
146 packet: &[u8],
147 ) -> Result<(), DtlsError> {
148 if let Some(connection) = &mut self.connections[connection_id.0] {
149 let mut buffer = ParseBuffer::init(self.staging_buffer.borrow_mut());
150 let epoch_index = connection.current_epoch as usize & 3;
151 let mut record = EncodeCiphertextRecord::new(
152 &mut buffer,
153 &connection.epochs[epoch_index],
154 &connection.current_epoch,
155 )?;
156 record.payload_buffer().write_slice_checked(packet)?;
157 record.finish(
158 &mut connection.epochs[epoch_index],
159 RecordContentType::ApplicationData,
160 )?;
161 (self.send_to_peer)(&connection.addr, buffer.as_ref());
162 Ok(())
163 } else {
164 Err(DtlsError::UnknownConnection)
165 }
166 }
167
168 pub fn staging_buffer(&mut self) -> &mut [u8] {
169 self.staging_buffer
170 }
171
172 pub fn handle_dtls_packet(
173 &mut self,
174 handshakes: &mut [HandshakeSlot],
175 addr: &SocketAddr,
176 packet_len: usize,
177 handle_app_data: &mut dyn FnMut(ConnectionId, Range<usize>, &mut Self),
178 ) -> Result<(), DtlsError> {
179 let mut handled = true;
180 match try_pass_packet_to_connection(
181 self.staging_buffer,
182 &mut self.connections,
183 addr,
184 packet_len,
185 )? {
186 DeferredAction::Send(buf) => (self.send_to_peer)(addr, buf),
187 DeferredAction::AppData(id, range) => handle_app_data(id, range, self),
188 DeferredAction::None => {}
189 DeferredAction::Unhandled => handled = false,
190 }
191 if handled {
192 return Ok(());
193 }
194 handled = true;
195 trace!("Could not match packet to connection");
196 match try_pass_packet_to_handshake(
197 self.staging_buffer,
198 &mut self.connections,
199 handshakes,
200 addr,
201 packet_len,
202 )? {
203 DeferredAction::Send(buf) => (self.send_to_peer)(addr, buf),
204 DeferredAction::AppData(id, range) => handle_app_data(id, range, self),
205 DeferredAction::None => {}
206 DeferredAction::Unhandled => handled = false,
207 }
208 if handled {
209 return Ok(());
210 }
211 trace!("Could not match packet to handshake");
212 if let Some(buf) = try_open_new_handshake(
213 self.staging_buffer,
214 self.require_cookie,
215 &self.cookie_key,
216 handshakes,
217 &mut self.connections,
218 addr,
219 packet_len,
220 )? {
221 (self.send_to_peer)(addr, buf);
222 };
223 Ok(())
224 }
225
226 pub fn require_cookie(&mut self, require_cookie: bool) {
227 self.require_cookie = require_cookie;
228 }
229}
230
231fn try_send_alert_sync<T>(
232 error: &Result<T, DtlsError>,
233 staging_buffer: &mut [u8],
234 send_bytes: &mut dyn FnMut(&[u8]),
235 epoch_states: &mut [EpochState],
236 epoch: &u64,
237) {
238 if let Err(DtlsError::Alert(alert)) = error {
239 if let Ok(buf) = stage_alert(staging_buffer, epoch_states, epoch, *alert) {
240 send_bytes(buf);
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use crate::{DtlsError, DtlsStack, HandshakeSlot};
    use core::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Peer address used by every test: IPv4 loopback on port 0.
    fn peer_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)
    }

    /// With a single connection slot, a second `open_connection` must fail.
    #[test]
    fn fail_open_more_handshakes_than_connections() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut stack = DtlsStack::<1>::new(&mut rng, &mut [], &mut send_to_peer).unwrap();
        let mut slots = [
            HandshakeSlot::new(&[], &mut []),
            HandshakeSlot::new(&[], &mut []),
        ];

        assert!(stack.open_connection(&mut slots[0], &peer_addr()));
        assert!(!stack.open_connection(&mut slots[1], &peer_addr()));
    }

    /// Closing a finished connection succeeds and the slot can be opened again.
    #[test]
    fn closing_connections_works() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut staging = [0; 250];
        let mut stack = DtlsStack::<1>::new(&mut rng, &mut staging, &mut send_to_peer).unwrap();
        let mut slots = [HandshakeSlot::new(&[], &mut [])];

        assert!(stack.open_connection(&mut slots[0], &peer_addr()));

        slots[0].finish_handshake(stack.connections[0].as_mut().unwrap());
        assert!(slots[0].try_take_connection_id().is_some());

        assert!(stack.close_connection(crate::ConnectionId(0)));
        assert!(stack.open_connection(&mut slots[0], &peer_addr()));
    }

    /// Closing a connection whose handshake was never finished reports `false`.
    #[test]
    fn try_close_non_open_connection() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut stack = DtlsStack::<1>::new(&mut rng, &mut [], &mut send_to_peer).unwrap();
        let mut slots = [HandshakeSlot::new(&[], &mut [])];

        assert!(stack.open_connection(&mut slots[0], &peer_addr()));
        assert!(!stack.close_connection(crate::ConnectionId(0)));
    }

    /// A payload that does not fit into the staging buffer yields `OutOfMemory`.
    #[test]
    fn overflow_stage_buffer() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut stage_buffer = [0u8; 250];
        let mut stack =
            DtlsStack::<1>::new(&mut rng, &mut stage_buffer, &mut send_to_peer).unwrap();
        let mut slots = [HandshakeSlot::new(&[], &mut [])];

        assert!(stack.open_connection(&mut slots[0], &peer_addr()));
        slots[0].finish_handshake(stack.connections[0].as_mut().unwrap());
        let cid = slots[0].try_take_connection_id().unwrap();

        let payload = [1u8; 249];
        let outcome = stack.send_dtls_packet(cid, &payload);
        assert!(matches!(outcome, Err(DtlsError::OutOfMemory)));
    }
}