1use std::{fs, io::Error, iter::repeat_with, net::SocketAddr, path::Path, time::Duration};
2
3use derive_new::new;
4use log::error;
5use openssl::ssl::{SslContext, SslMethod};
6use sender::{sender_task_dtls, sender_task_tcp, sender_task_udp};
7use statistics::stats_task;
8use tokio::{
9 io::AsyncWrite,
10 net::{TcpSocket, TcpStream, UdpSocket},
11 task::JoinSet,
12 time::sleep,
13};
14use tokio_dtls_stream_sink::{Client, Session};
15use tokio_native_tls::native_tls::{Certificate, TlsConnector};
16
17mod sender;
18mod statistics;
19
20pub async fn manager(params: Parameters) {
21 let (udp, (use_tls, ca_file)) = params.connection_type;
22 if use_tls && ca_file.is_none() {
23 error!("DTLS requires CA file to verify server credentials");
24 return;
25 }
26
27 let stats_tx = stats_task(params.connections);
28
29 let mut tasks = JoinSet::new();
30 let mut start_port = params.start_port;
31
32 for id in 0..params.connections {
33 start_port += id;
34 let payload = generate_payloads(params.len);
35 let stats_tx_cloned = stats_tx.clone();
36 let ca_file = ca_file.clone();
37 if use_tls {
38 if udp {
39 let session =
40 setup_dtls_session(start_port, params.server_addr, ca_file.unwrap()).await;
41 tasks.spawn(async move {
42 sender_task_dtls(id, session, payload, params.rate, stats_tx_cloned).await
43 });
44 } else {
45 let stream =
46 setup_tls_stream(start_port, params.server_addr, ca_file.unwrap()).await;
47 tasks.spawn(async move {
48 sender_task_tcp(id, stream, payload, params.rate, stats_tx_cloned).await;
49 });
50 }
51 } else if udp {
52 let socket = setup_udp_socket(params.server_addr, start_port).await;
53 tasks.spawn(async move {
54 sender_task_udp(id, socket, payload, params.rate, stats_tx_cloned).await
55 });
56 } else {
57 let stream = setup_tcp_stream(params.server_addr, start_port).await;
58 tasks.spawn(async move {
59 sender_task_tcp(id, stream, payload, params.rate, stats_tx_cloned).await;
60 });
61 }
62 sleep(Duration::from_millis(params.sleep)).await;
63 }
64 while (tasks.join_next().await).is_some() {}
65}
66
67async fn setup_udp_socket(addr: SocketAddr, port: usize) -> UdpSocket {
68 let socket = UdpSocket::bind("0.0.0.0:".to_owned() + &port.to_string())
69 .await
70 .unwrap();
71 socket.connect(addr).await.unwrap();
72 socket
73}
74
75async fn setup_tcp_stream(addr: SocketAddr, port: usize) -> Box<TcpStream> {
76 let local_addr = ("0.0.0.0:".to_owned() + &port.to_string()).parse().unwrap();
77 let socket = TcpSocket::new_v4().unwrap();
78 socket.bind(local_addr).unwrap();
79 Box::new(socket.connect(addr).await.unwrap())
80}
81
82async fn setup_dtls_session(port: usize, addr: SocketAddr, ca_file: String) -> DtlsSession {
83 let mut ctx = SslContext::builder(SslMethod::dtls()).unwrap();
84 ctx.set_ca_file(ca_file).unwrap();
85 let socket = UdpSocket::bind("0.0.0.0:".to_owned() + &port.to_string())
86 .await
87 .unwrap();
88 let client = Client::new(socket);
89 let session = client.connect(addr, Some(ctx.build())).await.unwrap();
90 DtlsSession::new(client, session)
91}
92
93async fn setup_tls_stream(
94 port: usize,
95 addr: SocketAddr,
96 ca_file: String,
97) -> Box<dyn AsyncWrite + Unpin + Send> {
98 let pem = fs::read(Path::new(&ca_file)).unwrap();
99 let cert = Certificate::from_pem(&pem).unwrap();
100 let connector = TlsConnector::builder()
101 .add_root_certificate(cert)
102 .danger_accept_invalid_hostnames(true)
103 .build()
104 .unwrap();
105 let connector = tokio_native_tls::TlsConnector::from(connector);
106 let tcp_stream = setup_tcp_stream(addr, port).await;
107 Box::new(
108 connector
109 .connect(addr.ip().to_string().as_str(), tcp_stream)
110 .await
111 .unwrap(),
112 )
113}
114
115fn generate_payloads(len: usize) -> Vec<u8> {
116 repeat_with(|| fastrand::u8(..)).take(len).collect()
117}
118
/// Configuration consumed by [`manager`].
///
/// Built through the `new` constructor generated by `derive_new`; its
/// arguments follow the field order below.
#[derive(new)]
pub struct Parameters {
    /// Server address that every connection targets.
    server_addr: SocketAddr,
    /// Forwarded unchanged to each sender task. Presumably the send rate;
    /// units are defined by the `sender` module (TODO confirm).
    rate: usize,
    /// Number of connections (and sender tasks) to open; also passed to `stats_task`.
    connections: usize,
    /// Payload size in bytes. Each connection gets its own randomly generated payload.
    len: usize,
    /// Base local port from which each connection's local bind port is derived.
    start_port: usize,
    /// Delay in milliseconds after each connection is set up, before the next one.
    sleep: u64,
    /// `(udp, (use_tls, ca_file))`.
    /// - `udp` selects UDP instead of TCP.
    /// - `use_tls` enables DTLS (over UDP) or TLS (over TCP).
    /// - `ca_file` is the path of the CA certificate; it is required when `use_tls` is set.
    connection_type: (bool, (bool, Option<String>)),
}
129
/// A DTLS session paired with the client that established it.
#[derive(new)]
pub struct DtlsSession {
    /// Never read. Kept alongside `session`, presumably so the client (and the
    /// UDP socket it owns) is not dropped while the session is in use — TODO confirm
    /// against `tokio_dtls_stream_sink`.
    _client: Client,
    /// Session that payloads are written to via [`DtlsSession::write`].
    session: Session,
}
135
136impl DtlsSession {
137 pub async fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
138 self.session.write(buf).await
139 }
140}