rusty_dtls/sync/connection.rs

1use core::{borrow::BorrowMut, net::SocketAddr, ops::Range};
2
3use log::trace;
4use rand_core::CryptoRngCore;
5use sha2::{
6    digest::{generic_array::GenericArray, OutputSizeUser},
7    Sha256,
8};
9
10use crate::{
11    close_connection,
12    handshake::{ClientState, ServerState},
13    open_connection,
14    parsing_utility::ParseBuffer,
15    record_parsing::{EncodeCiphertextRecord, RecordContentType},
16    stage_alert, try_open_new_handshake, try_pass_packet_to_connection,
17    try_pass_packet_to_handshake, ConnectionId, DeferredAction, DtlsConnection, DtlsError,
18    DtlsPoll, EpochState, HandshakeSlot, HandshakeSlotState, HandshakeState, TimeStampMs,
19};
20
21use super::handshake::{process_client_sync, process_server_sync};
22
23pub struct DtlsStack<'a, const CONNECTIONS: usize> {
24    connections: [Option<DtlsConnection<'a>>; CONNECTIONS],
25
26    rng: &'a mut dyn rand_core::CryptoRngCore,
27    staging_buffer: &'a mut [u8],
28
29    send_to_peer: &'a mut dyn FnMut(&SocketAddr, &[u8]),
30
31    require_cookie: bool,
32    // In any case the minimal recommended length for K is L bytes (as the hash output
33    // length) RFC 2104
34    cookie_key: GenericArray<u8, <Sha256 as OutputSizeUser>::OutputSize>,
35}
36
37impl<'a, const CONNECTIONS: usize> DtlsStack<'a, CONNECTIONS> {
38    pub fn new(
39        rng: &'a mut dyn CryptoRngCore,
40        staging_buffer: &'a mut [u8],
41        send_to_peer: &'a mut dyn FnMut(&SocketAddr, &[u8]),
42    ) -> Result<Self, DtlsError> {
43        let mut me = Self {
44            connections: [const { None }; CONNECTIONS],
45            rng,
46            staging_buffer,
47            send_to_peer,
48            cookie_key: GenericArray::default(),
49            require_cookie: true,
50        };
51        me.rng
52            .try_fill_bytes(&mut me.cookie_key)
53            .map_err(|_| DtlsError::RngError)?;
54        Ok(me)
55    }
56
57    pub fn poll(
58        &mut self,
59        handshakes: &mut [HandshakeSlot],
60        now_ms: TimeStampMs,
61    ) -> Result<DtlsPoll, DtlsError> {
62        let mut return_poll = DtlsPoll::Wait;
63        for handshake in handshakes {
64            let poll = match &mut handshake.state {
65                HandshakeSlotState::Running {
66                    state,
67                    handshake: ctx,
68                } => {
69                    let conn = ctx.connection(&mut self.connections);
70                    let mut new_state = *state;
71                    let addr = conn.addr;
72                    let poll = match &mut new_state {
73                        HandshakeState::Client(c) => process_client_sync(
74                            c,
75                            &now_ms,
76                            ctx,
77                            &mut handshake.rt_queue,
78                            conn,
79                            self.rng,
80                            self.staging_buffer,
81                            &mut |bytes| (self.send_to_peer)(&addr, bytes),
82                        ),
83                        HandshakeState::Server(s) => process_server_sync(
84                            s,
85                            &now_ms,
86                            ctx,
87                            &mut handshake.rt_queue,
88                            conn,
89                            self.rng,
90                            self.staging_buffer,
91                            &mut |bytes| (self.send_to_peer)(&addr, bytes),
92                        ),
93                    };
94                    try_send_alert_sync(
95                        &poll,
96                        self.staging_buffer,
97                        &mut |b| (self.send_to_peer)(&addr, b),
98                        &mut conn.epochs,
99                        &conn.current_epoch,
100                    );
101                    let poll = poll?;
102                    *state = new_state;
103                    if matches!(
104                        state,
105                        HandshakeState::Client(ClientState::FinishedHandshake)
106                            | HandshakeState::Server(ServerState::FinishedHandshake)
107                    ) {
108                        handshake.finish_handshake(conn);
109                    }
110                    poll
111                }
112                HandshakeSlotState::Empty => DtlsPoll::Wait,
113                HandshakeSlotState::Finished(_) => DtlsPoll::FinishedHandshake,
114            };
115            return_poll = return_poll.merge(poll);
116        }
117        Ok(return_poll)
118    }
119
120    pub fn open_connection(&mut self, slot: &mut HandshakeSlot, addr: &SocketAddr) -> bool {
121        open_connection(&mut self.connections, slot, addr)
122    }
123
124    /// Returns whether the connection was closed successfully
125    pub fn close_connection(&mut self, connection_id: ConnectionId) -> bool {
126        let addr = self
127            .connections
128            .get(connection_id.0)
129            .and_then(|c| c.as_ref().map(|c| c.addr));
130        match (
131            addr,
132            close_connection(connection_id, self.staging_buffer, &mut self.connections),
133        ) {
134            (Some(addr), DeferredAction::Send(buf)) => {
135                (self.send_to_peer)(&addr, buf);
136                true
137            }
138            (_, DeferredAction::None) => false,
139            _ => unreachable!(),
140        }
141    }
142
143    pub fn send_dtls_packet(
144        &mut self,
145        connection_id: ConnectionId,
146        packet: &[u8],
147    ) -> Result<(), DtlsError> {
148        if let Some(connection) = &mut self.connections[connection_id.0] {
149            let mut buffer = ParseBuffer::init(self.staging_buffer.borrow_mut());
150            let epoch_index = connection.current_epoch as usize & 3;
151            let mut record = EncodeCiphertextRecord::new(
152                &mut buffer,
153                &connection.epochs[epoch_index],
154                &connection.current_epoch,
155            )?;
156            record.payload_buffer().write_slice_checked(packet)?;
157            record.finish(
158                &mut connection.epochs[epoch_index],
159                RecordContentType::ApplicationData,
160            )?;
161            (self.send_to_peer)(&connection.addr, buffer.as_ref());
162            Ok(())
163        } else {
164            Err(DtlsError::UnknownConnection)
165        }
166    }
167
168    pub fn staging_buffer(&mut self) -> &mut [u8] {
169        self.staging_buffer
170    }
171
172    pub fn handle_dtls_packet(
173        &mut self,
174        handshakes: &mut [HandshakeSlot],
175        addr: &SocketAddr,
176        packet_len: usize,
177        handle_app_data: &mut dyn FnMut(ConnectionId, Range<usize>, &mut Self),
178    ) -> Result<(), DtlsError> {
179        let mut handled = true;
180        match try_pass_packet_to_connection(
181            self.staging_buffer,
182            &mut self.connections,
183            addr,
184            packet_len,
185        )? {
186            DeferredAction::Send(buf) => (self.send_to_peer)(addr, buf),
187            DeferredAction::AppData(id, range) => handle_app_data(id, range, self),
188            DeferredAction::None => {}
189            DeferredAction::Unhandled => handled = false,
190        }
191        if handled {
192            return Ok(());
193        }
194        handled = true;
195        trace!("Could not match packet to connection");
196        match try_pass_packet_to_handshake(
197            self.staging_buffer,
198            &mut self.connections,
199            handshakes,
200            addr,
201            packet_len,
202        )? {
203            DeferredAction::Send(buf) => (self.send_to_peer)(addr, buf),
204            DeferredAction::AppData(id, range) => handle_app_data(id, range, self),
205            DeferredAction::None => {}
206            DeferredAction::Unhandled => handled = false,
207        }
208        if handled {
209            return Ok(());
210        }
211        trace!("Could not match packet to handshake");
212        if let Some(buf) = try_open_new_handshake(
213            self.staging_buffer,
214            self.require_cookie,
215            &self.cookie_key,
216            handshakes,
217            &mut self.connections,
218            addr,
219            packet_len,
220        )? {
221            (self.send_to_peer)(addr, buf);
222        };
223        Ok(())
224    }
225
226    pub fn require_cookie(&mut self, require_cookie: bool) {
227        self.require_cookie = require_cookie;
228    }
229}
230
231fn try_send_alert_sync<T>(
232    error: &Result<T, DtlsError>,
233    staging_buffer: &mut [u8],
234    send_bytes: &mut dyn FnMut(&[u8]),
235    epoch_states: &mut [EpochState],
236    epoch: &u64,
237) {
238    if let Err(DtlsError::Alert(alert)) = error {
239        if let Ok(buf) = stage_alert(staging_buffer, epoch_states, epoch, *alert) {
240            send_bytes(buf);
241        }
242    }
243}
244
#[cfg(test)]
mod tests {
    use crate::{DtlsError, DtlsStack, HandshakeSlot};
    use core::net::{IpAddr, Ipv4Addr, SocketAddr};

    /// Loopback address on port 0, used as the peer in every test.
    fn localhost() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)
    }

    /// A stack with a single connection entry must refuse a second connection.
    #[test]
    pub fn fail_open_more_handshakes_than_connections() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut stack = DtlsStack::<1>::new(&mut rng, &mut [], &mut send_to_peer).unwrap();
        let mut slots = [
            HandshakeSlot::new(&[], &mut []),
            HandshakeSlot::new(&[], &mut []),
        ];
        let [first, second] = &mut slots;

        assert!(stack.open_connection(first, &localhost()));
        assert!(!stack.open_connection(second, &localhost()));
    }

    /// Closing an established connection frees its entry for reuse.
    #[test]
    pub fn closing_connections_works() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut staging = [0; 250];
        let mut stack = DtlsStack::<1>::new(&mut rng, &mut staging, &mut send_to_peer).unwrap();
        let mut slot = HandshakeSlot::new(&[], &mut []);

        assert!(stack.open_connection(&mut slot, &localhost()));
        slot.finish_handshake(stack.connections[0].as_mut().unwrap());
        assert!(slot.try_take_connection_id().is_some());

        assert!(stack.close_connection(crate::ConnectionId(0)));
        assert!(stack.open_connection(&mut slot, &localhost()));
    }

    /// Closing a connection whose handshake never finished reports `false`.
    #[test]
    pub fn try_close_non_open_connection() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut stack = DtlsStack::<1>::new(&mut rng, &mut [], &mut send_to_peer).unwrap();
        let mut slot = HandshakeSlot::new(&[], &mut []);

        assert!(stack.open_connection(&mut slot, &localhost()));
        assert!(!stack.close_connection(crate::ConnectionId(0)));
    }

    /// A payload that does not fit into the staging buffer yields `OutOfMemory`.
    #[test]
    pub fn overflow_stage_buffer() {
        let mut rng = rand::thread_rng();
        let mut send_to_peer = |_: &SocketAddr, _: &[u8]| {};
        let mut stage_buffer = [0u8; 250];
        let mut stack =
            DtlsStack::<1>::new(&mut rng, &mut stage_buffer, &mut send_to_peer).unwrap();
        let mut slot = HandshakeSlot::new(&[], &mut []);

        assert!(stack.open_connection(&mut slot, &localhost()));
        slot.finish_handshake(stack.connections[0].as_mut().unwrap());
        let cid = slot.try_take_connection_id().unwrap();

        let oversized = [1u8; 249];
        let outcome = stack.send_dtls_packet(cid, &oversized);
        assert!(matches!(outcome, Err(DtlsError::OutOfMemory)));
    }
}