1use std::{convert::TryInto, net::SocketAddr, time::Instant};
2
3use crate::{packet::*, protocol::handshake::Handshake, settings::*};
4
5use super::{
6 cookie::gen_cookie, hsv5::gen_access_control_response, hsv5::GenHsv5Result,
7 AccessControlRequest, AccessControlResponse, ConnectError, Connection, ConnectionReject,
8 ConnectionResult,
9};
10
11use ConnectionResult::*;
12use ListenState::*;
13
14#[derive(Debug)]
15pub struct Listen {
16 init_settings: ConnInitSettings,
17 state: ListenState,
18 enable_access_control: bool,
19}
20
21#[derive(Clone, Debug)]
22pub struct ConclusionWaitState {
23 from: SocketAddr,
24 cookie: i32,
25 induction_response: Packet,
26 induction_time: Instant,
27}
28
29#[derive(Clone, Debug)]
30#[allow(clippy::large_enum_variant)]
31enum ListenState {
32 InductionWait,
33 ConclusionWait(ConclusionWaitState),
34 AccessControlRequested(
35 ConclusionWaitState,
36 TimeStamp,
37 HandshakeControlInfo,
38 HsV5Info,
39 ),
40}
41
42impl Listen {
43 pub fn new(init_settings: ConnInitSettings, enable_access_control: bool) -> Listen {
44 Listen {
45 state: InductionWait,
46 init_settings,
47 enable_access_control,
48 }
49 }
50
51 pub fn settings(&self) -> &ConnInitSettings {
52 &self.init_settings
53 }
54
55 pub fn handle_packet(&mut self, now: Instant, packet: ReceivePacketResult) -> ConnectionResult {
56 use ReceivePacketError::*;
57 match packet {
58 Ok((packet, from)) => match packet {
59 Packet::Control(control) => self.handle_control_packets(now, from, control),
60 Packet::Data(data) => NotHandled(ConnectError::ControlExpected(data)),
61 },
62 Err(Io(error)) => Failure(error),
63 Err(Parse(e)) => NotHandled(ConnectError::ParseFailed(e)),
64 }
65 }
66
67 pub fn handle_access_control_response(
68 &mut self,
69 now: Instant,
70 response: AccessControlResponse,
71 ) -> ConnectionResult {
72 match self.state.clone() {
73 InductionWait | ConclusionWait(_) => NotHandled(ConnectError::ExpectedHsReq),
75 AccessControlRequested(state, timestamp, shake, info) => {
76 use AccessControlResponse::*;
77 match response {
78 Accepted(key_settings) => {
79 self.accept_connection(now, &state, timestamp, shake, info, key_settings)
80 }
81 Rejected(rr) => self.make_rejection(
82 &shake,
83 state.from,
84 timestamp,
85 ConnectionReject::Rejecting(rr),
86 ),
87 Dropped => self.make_rejection(
88 &shake,
89 state.from,
90 timestamp,
91 ConnectionReject::Rejecting(RejectReason::Core(CoreRejectReason::Peer)),
92 ),
93 }
94 }
95 }
96 }
97
98 pub fn handle_timer(&self, _now: Instant) -> ConnectionResult {
99 NoAction
100 }
101
102 fn handle_control_packets(
103 &mut self,
104 now: Instant,
105 from: SocketAddr,
106 control: ControlPacket,
107 ) -> ConnectionResult {
108 match (self.state.clone(), control.control_type) {
109 (InductionWait, ControlTypes::Handshake(shake)) => {
110 self.wait_for_induction(from, control.timestamp, shake, now)
111 }
112 (ConclusionWait(state), ControlTypes::Handshake(shake)) => self.wait_for_conclusion(
113 now,
114 from,
115 control.dest_sockid,
116 control.timestamp,
117 state,
118 shake,
119 ),
120 (AccessControlRequested(_, _, _, _), _) => {
121 NotHandled(ConnectError::ExpectedAccessControlResponse)
122 }
123 (InductionWait, control_type) | (ConclusionWait(_), control_type) => {
124 NotHandled(ConnectError::HandshakeExpected(control_type))
125 }
126 }
127 }
128
129 fn wait_for_induction(
130 &mut self,
131 from: SocketAddr,
132 timestamp: TimeStamp,
133 shake: HandshakeControlInfo,
134 now: Instant,
135 ) -> ConnectionResult {
136 match shake.shake_type {
137 ShakeType::Induction => {
138 let cookie = gen_cookie(&from);
146
147 let induction_response = Packet::Control(ControlPacket {
150 timestamp,
151 dest_sockid: shake.socket_id,
152 control_type: ControlTypes::Handshake(HandshakeControlInfo {
153 syn_cookie: cookie,
154 socket_id: self.init_settings.local_sockid,
155 info: HandshakeVsInfo::V5(HsV5Info::default()),
156 ..shake
157 }),
158 });
159
160 let save_induction_response = induction_response.clone();
162 self.state = ConclusionWait(ConclusionWaitState {
163 from,
164 cookie,
165 induction_response: save_induction_response,
166 induction_time: now,
167 });
168 SendPacket((induction_response, from))
169 }
170 _ => NotHandled(ConnectError::InductionExpected(shake)),
171 }
172 }
173
174 fn wait_for_conclusion(
175 &mut self,
176 now: Instant,
177 from: SocketAddr,
178 local_socket_id: SocketId,
179 timestamp: TimeStamp,
180 state: ConclusionWaitState,
181 shake: HandshakeControlInfo,
182 ) -> ConnectionResult {
183 const VERSION_5: u32 = 5;
194
195 match (shake.shake_type, shake.info.version(), shake.syn_cookie) {
196 (ShakeType::Induction, _, _) => SendPacket((state.induction_response, from)),
197 (ShakeType::Conclusion, VERSION_5, syn_cookie) if syn_cookie == state.cookie => {
199 let incoming = match &shake.info {
200 HandshakeVsInfo::V5(hs) => hs,
201 _ => {
202 let r = ConnectionReject::Rejecting(
203 ServerRejectReason::Version.into(),
205 );
206 return self.make_rejection(&shake, from, timestamp, r);
207 }
208 }
209 .clone();
210
211 if self.enable_access_control {
212 self.request_access(from, local_socket_id, timestamp, state, shake, incoming)
213 } else {
214 let key_settings = self.settings().key_settings.clone();
215 self.accept_connection(now, &state, timestamp, shake, incoming, key_settings)
216 }
217 }
218 (ShakeType::Conclusion, VERSION_5, syn_cookie) => NotHandled(
219 ConnectError::InvalidHandshakeCookie(state.cookie, syn_cookie),
220 ),
221 (ShakeType::Conclusion, version, _) => {
222 NotHandled(ConnectError::UnsupportedProtocolVersion(version))
223 }
224 (_, _, _) => NotHandled(ConnectError::ConclusionExpected(shake)),
225 }
226 }
227
228 fn request_access(
229 &mut self,
230 remote: SocketAddr,
231 local_socket_id: SocketId,
232 timestamp: TimeStamp,
233 state: ConclusionWaitState,
234 shake: HandshakeControlInfo,
235 incoming: HsV5Info,
236 ) -> ConnectionResult {
237 let stream_id = incoming.sid.clone().and_then(|s| s.try_into().ok());
239 let remote_socket_id = shake.socket_id;
240 let key_size = incoming.key_size;
241
242 self.state = AccessControlRequested(state, timestamp, shake, incoming);
243
244 RequestAccess(AccessControlRequest {
245 local_socket_id,
246 remote,
247 remote_socket_id,
248 stream_id,
249 key_size,
250 })
251 }
252
253 fn accept_connection(
254 &mut self,
255 now: Instant,
256 state: &ConclusionWaitState,
257 timestamp: TimeStamp,
258 shake: HandshakeControlInfo,
259 info: HsV5Info,
260 key_settings: Option<KeySettings>,
261 ) -> ConnectionResult {
262 let response = gen_access_control_response(
263 now,
264 &mut self.init_settings,
265 state.from,
266 state.induction_time,
267 shake.clone(),
268 info,
269 key_settings,
270 );
271 let (hsv5, settings) = match response {
272 GenHsv5Result::Accept(h, c) => (h, c),
273 GenHsv5Result::NotHandled(e) => return NotHandled(e),
274 GenHsv5Result::Reject(r) => {
275 return self.make_rejection(&shake, state.from, timestamp, r);
276 }
277 };
278
279 let resp_handshake = ControlPacket {
280 timestamp,
281 dest_sockid: shake.socket_id,
282 control_type: ControlTypes::Handshake(HandshakeControlInfo {
283 syn_cookie: state.cookie,
284 socket_id: self.init_settings.local_sockid,
285 info: hsv5,
286 shake_type: ShakeType::Conclusion,
287 ..shake }),
289 };
290
291 Connected(
293 Some((resp_handshake.clone().into(), state.from)),
294 Connection {
295 settings,
296 handshake: Handshake::Listener(resp_handshake.control_type),
297 },
298 )
299 }
300
301 fn make_rejection(
302 &mut self,
303 response_to: &HandshakeControlInfo,
304 from: SocketAddr,
305 timestamp: TimeStamp,
306 r: ConnectionReject,
307 ) -> ConnectionResult {
308 self.state = InductionWait;
309 Reject(
310 Some((
311 ControlPacket {
312 timestamp,
313 dest_sockid: response_to.socket_id,
314 control_type: ControlTypes::Handshake(HandshakeControlInfo {
315 shake_type: ShakeType::Rejection(r.reason()),
316 socket_id: self.init_settings.local_sockid,
317 ..response_to.clone()
318 }),
319 }
320 .into(),
321 from,
322 )),
323 r,
324 )
325 }
326}
327
#[cfg(test)]
mod test {
    use std::{
        net::{IpAddr, Ipv4Addr},
        time::Duration,
    };

    use assert_matches::assert_matches;
    use bytes::Bytes;
    use rand::random;

    use crate::options::*;

    use super::*;

    /// Address every simulated peer sends from.
    fn conn_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8765)
    }

    /// Listener with default settings and access control disabled.
    fn test_listen() -> Listen {
        Listen::new(ConnInitSettings::default(), false)
    }

    /// Induction handshake as a caller would send it (no cookie yet).
    fn test_induction() -> HandshakeControlInfo {
        HandshakeControlInfo {
            init_seq_num: random(),
            max_packet_size: PacketSize(1316),
            max_flow_size: PacketCount(256_000),
            shake_type: ShakeType::Induction,
            socket_id: random(),
            syn_cookie: 0,
            peer_addr: IpAddr::from([127, 0, 0, 1]),
            info: HandshakeVsInfo::V5(HsV5Info::default()),
        }
    }

    /// HSv5 conclusion handshake carrying the cookie the listener hands out to
    /// `conn_addr()` and an SRT handshake-request extension.
    fn test_conclusion() -> HandshakeControlInfo {
        HandshakeControlInfo {
            init_seq_num: random(),
            max_packet_size: PacketSize(1316),
            max_flow_size: PacketCount(256_000),
            shake_type: ShakeType::Conclusion,
            socket_id: random(),
            syn_cookie: gen_cookie(&conn_addr()),
            peer_addr: IpAddr::from([127, 0, 0, 1]),
            info: HandshakeVsInfo::V5(HsV5Info {
                key_size: KeySize::Unspecified,
                ext_hs: Some(SrtControlPacket::HandshakeRequest(SrtHandshake {
                    version: SrtVersion::CURRENT,
                    flags: SrtShakeFlags::SUPPORTED,
                    send_latency: Duration::from_secs(1),
                    recv_latency: Duration::from_secs(2),
                })),
                ext_km: None,
                ext_group: None,
                sid: None,
            }),
        }
    }

    /// Wraps a handshake in a control packet with a random destination socket id.
    fn build_hs_pack(i: HandshakeControlInfo) -> Packet {
        Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: random(),
            control_type: ControlTypes::Handshake(i),
        })
    }

    #[test]
    fn correct() {
        let mut l = test_listen();

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_conclusion()), conn_addr())),
        );
        assert_matches!(
            resp,
            Connected(
                Some(_),
                Connection {
                    handshake: Handshake::Listener(ControlTypes::Handshake(HandshakeControlInfo {
                        info: HandshakeVsInfo::V5(HsV5Info {
                            ext_hs: Some(_),
                            ..
                        }),
                        ..
                    })),
                    ..
                },
            )
        );
    }

    #[test]
    fn send_data_packet() {
        let mut l = test_listen();

        let dp = DataPacket {
            seq_number: random(),
            message_loc: PacketLocation::ONLY,
            in_order_delivery: false,
            encryption: DataEncryption::None,
            retransmitted: false,
            message_number: random(),
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: random(),
            payload: Bytes::from(&b"asdf"[..]),
        };
        assert_matches!(
            l.handle_packet(Instant::now(), Ok((Packet::Data(dp.clone()), conn_addr()))),
            NotHandled(ConnectError::ControlExpected(d)) if d == dp
        );
    }

    #[test]
    fn send_ack2() {
        let mut l = test_listen();

        // `saturating_add` keeps the number non-zero without the debug-build
        // overflow panic that `random::<u32>() + 1` hits when `random` yields
        // `u32::MAX`.
        let ack_number = random::<u32>().saturating_add(1);
        let a2 = ControlTypes::Ack2(FullAckSeqNumber::new(ack_number).unwrap());
        assert_matches!(
            l.handle_packet(
                Instant::now(),
                Ok((
                    Packet::Control(ControlPacket {
                        timestamp: TimeStamp::from_micros(0),
                        dest_sockid: random(),
                        control_type: a2.clone(),
                    }),
                    conn_addr(),
                )),
            ),
            NotHandled(ConnectError::HandshakeExpected(pack)) if pack == a2
        );
    }

    #[test]
    fn send_wrong_handshake() {
        let mut l = test_listen();

        // A conclusion before any induction must be refused.
        let shake = test_conclusion();
        assert_matches!(
            l.handle_packet(Instant::now(), Ok((build_hs_pack(shake.clone()), conn_addr()))),
            NotHandled(ConnectError::InductionExpected(s)) if s == shake
        );
    }

    #[test]
    fn send_induction_twice() {
        // Despite the name, the second handshake here is a Waveahand: in
        // ConclusionWait anything other than an induction or conclusion is
        // refused with `ConclusionExpected`.
        let mut l = test_listen();

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));

        let mut shake = test_induction();
        shake.shake_type = ShakeType::Waveahand;
        assert_matches!(
            l.handle_packet(Instant::now(), Ok((build_hs_pack(shake.clone()), conn_addr()))),
            NotHandled(ConnectError::ConclusionExpected(nc)) if nc == shake
        )
    }

    #[test]
    fn resend_induction_response() {
        let mut l = test_listen();

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));

        // A repeated induction while waiting for the conclusion is answered
        // again (with the same cookie) instead of being refused.
        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(
            resp,
            SendPacket((
                Packet::Control(ControlPacket {
                    control_type: ControlTypes::Handshake(HandshakeControlInfo {
                        syn_cookie,
                        ..
                    }),
                    ..
                }),
                addr,
            )) if syn_cookie == gen_cookie(&conn_addr()) && addr == conn_addr()
        );
    }

    #[test]
    fn send_v4_conclusion() {
        let mut l = test_listen();

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));

        let mut c = test_conclusion();
        c.info = HandshakeVsInfo::V4(SocketType::Datagram);

        let resp = l.handle_packet(Instant::now(), Ok((build_hs_pack(c), conn_addr())));

        assert_matches!(
            resp,
            NotHandled(ConnectError::UnsupportedProtocolVersion(4))
        );
    }

    #[test]
    fn send_no_ext_hs_conclusion() {
        let mut l = test_listen();

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));

        let mut c = test_conclusion();
        c.info = HandshakeVsInfo::V5(HsV5Info::default());

        let resp = l.handle_packet(Instant::now(), Ok((build_hs_pack(c), conn_addr())));

        assert_matches!(resp, NotHandled(ConnectError::ExpectedExtFlags));
    }

    #[test]
    fn reject() {
        let mut l = Listen::new(ConnInitSettings::default(), true);

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_conclusion()), conn_addr())),
        );
        assert_matches!(resp, RequestAccess(_));

        let resp = l.handle_access_control_response(
            Instant::now(),
            AccessControlResponse::Rejected(RejectReason::Server(ServerRejectReason::Overload)),
        );
        assert_matches!(
            resp,
            Reject(
                _,
                ConnectionReject::Rejecting(RejectReason::Server(ServerRejectReason::Overload)),
            )
        );
    }

    #[test]
    fn advertise_key_size() {
        let mut l = Listen::new(ConnInitSettings::default(), true);

        l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );

        let hs_key_size = KeySize::AES256;

        let shake = HandshakeControlInfo {
            info: HandshakeVsInfo::V5(HsV5Info {
                key_size: hs_key_size,
                ext_hs: Some(SrtControlPacket::HandshakeRequest(SrtHandshake {
                    version: SrtVersion::CURRENT,
                    flags: SrtShakeFlags::SUPPORTED,
                    send_latency: Duration::from_secs(1),
                    recv_latency: Duration::from_secs(2),
                })),
                ext_km: None,
                ext_group: None,
                sid: None,
            }),
            ..test_conclusion()
        };

        let hs_packet = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: random(),
            control_type: ControlTypes::Handshake(shake),
        });

        let RequestAccess(request_access) =
            l.handle_packet(Instant::now(), Ok((hs_packet, conn_addr())))
        else {
            panic!("expected a ConnectionResult::RequestAccess");
        };

        assert_eq!(request_access.key_size, hs_key_size);
    }
}