crab_net/
lib.rs

1use std::{fs, io::Error, iter::repeat_with, net::SocketAddr, path::Path, time::Duration};
2
3use derive_new::new;
4use log::error;
5use openssl::ssl::{SslContext, SslMethod};
6use sender::{sender_task_dtls, sender_task_tcp, sender_task_udp};
7use statistics::stats_task;
8use tokio::{
9    io::AsyncWrite,
10    net::{TcpSocket, TcpStream, UdpSocket},
11    task::JoinSet,
12    time::sleep,
13};
14use tokio_dtls_stream_sink::{Client, Session};
15use tokio_native_tls::native_tls::{Certificate, TlsConnector};
16
17mod sender;
18mod statistics;
19
20pub async fn manager(params: Parameters) {
21    let (udp, (use_tls, ca_file)) = params.connection_type;
22    if use_tls && ca_file.is_none() {
23        error!("DTLS requires CA file to verify server credentials");
24        return;
25    }
26
27    let stats_tx = stats_task(params.connections);
28
29    let mut tasks = JoinSet::new();
30    let mut start_port = params.start_port;
31
32    for id in 0..params.connections {
33        start_port += id;
34        let payload = generate_payloads(params.len);
35        let stats_tx_cloned = stats_tx.clone();
36        let ca_file = ca_file.clone();
37        if use_tls {
38            if udp {
39                let session =
40                    setup_dtls_session(start_port, params.server_addr, ca_file.unwrap()).await;
41                tasks.spawn(async move {
42                    sender_task_dtls(id, session, payload, params.rate, stats_tx_cloned).await
43                });
44            } else {
45                let stream =
46                    setup_tls_stream(start_port, params.server_addr, ca_file.unwrap()).await;
47                tasks.spawn(async move {
48                    sender_task_tcp(id, stream, payload, params.rate, stats_tx_cloned).await;
49                });
50            }
51        } else if udp {
52            let socket = setup_udp_socket(params.server_addr, start_port).await;
53            tasks.spawn(async move {
54                sender_task_udp(id, socket, payload, params.rate, stats_tx_cloned).await
55            });
56        } else {
57            let stream = setup_tcp_stream(params.server_addr, start_port).await;
58            tasks.spawn(async move {
59                sender_task_tcp(id, stream, payload, params.rate, stats_tx_cloned).await;
60            });
61        }
62        sleep(Duration::from_millis(params.sleep)).await;
63    }
64    while (tasks.join_next().await).is_some() {}
65}
66
67async fn setup_udp_socket(addr: SocketAddr, port: usize) -> UdpSocket {
68    let socket = UdpSocket::bind("0.0.0.0:".to_owned() + &port.to_string())
69        .await
70        .unwrap();
71    socket.connect(addr).await.unwrap();
72    socket
73}
74
75async fn setup_tcp_stream(addr: SocketAddr, port: usize) -> Box<TcpStream> {
76    let local_addr = ("0.0.0.0:".to_owned() + &port.to_string()).parse().unwrap();
77    let socket = TcpSocket::new_v4().unwrap();
78    socket.bind(local_addr).unwrap();
79    Box::new(socket.connect(addr).await.unwrap())
80}
81
82async fn setup_dtls_session(port: usize, addr: SocketAddr, ca_file: String) -> DtlsSession {
83    let mut ctx = SslContext::builder(SslMethod::dtls()).unwrap();
84    ctx.set_ca_file(ca_file).unwrap();
85    let socket = UdpSocket::bind("0.0.0.0:".to_owned() + &port.to_string())
86        .await
87        .unwrap();
88    let client = Client::new(socket);
89    let session = client.connect(addr, Some(ctx.build())).await.unwrap();
90    DtlsSession::new(client, session)
91}
92
93async fn setup_tls_stream(
94    port: usize,
95    addr: SocketAddr,
96    ca_file: String,
97) -> Box<dyn AsyncWrite + Unpin + Send> {
98    let pem = fs::read(Path::new(&ca_file)).unwrap();
99    let cert = Certificate::from_pem(&pem).unwrap();
100    let connector = TlsConnector::builder()
101        .add_root_certificate(cert)
102        .danger_accept_invalid_hostnames(true)
103        .build()
104        .unwrap();
105    let connector = tokio_native_tls::TlsConnector::from(connector);
106    let tcp_stream = setup_tcp_stream(addr, port).await;
107    Box::new(
108        connector
109            .connect(addr.ip().to_string().as_str(), tcp_stream)
110            .await
111            .unwrap(),
112    )
113}
114
115fn generate_payloads(len: usize) -> Vec<u8> {
116    repeat_with(|| fastrand::u8(..)).take(len).collect()
117}
118
119#[derive(new)]
120pub struct Parameters {
121    server_addr: SocketAddr,
122    rate: usize,
123    connections: usize,
124    len: usize,
125    start_port: usize,
126    sleep: u64,
127    connection_type: (bool, (bool, Option<String>)),
128}
129
/// An established DTLS session together with the `Client` it was opened from.
///
/// Built by `setup_dtls_session` via the derived `DtlsSession::new(client, session)`; the
/// field order below therefore defines the constructor's argument order.
#[derive(new)]
pub struct DtlsSession {
    // Never read; held so the `Client` lives as long as the session.
    // NOTE(review): presumably the session needs its client kept alive — confirm.
    _client: Client,
    // Session that `write` forwards payloads to.
    session: Session,
}
135
136impl DtlsSession {
137    pub async fn write(&mut self, buf: &[u8]) -> Result<(), Error> {
138        self.session.write(buf).await
139    }
140}