1use std::{
15 net::{SocketAddr, ToSocketAddrs},
16 num::{NonZeroU32, NonZeroUsize},
17 time::Duration,
18};
19
20use byte_unit::{Byte, ByteUnit};
21use metrics::{counter, gauge, register_counter};
22use rand::{rngs::StdRng, SeedableRng};
23use serde::Deserialize;
24use tokio::net::UdpSocket;
25use tracing::{debug, info, trace};
26
27use crate::{
28 block::{self, chunk_bytes, construct_block_cache, Block},
29 payload,
30 signals::Shutdown,
31 throttle::{self, Throttle},
32};
33
34use super::General;
35
36#[derive(Debug, Deserialize, PartialEq)]
37pub struct Config {
39 pub seed: [u8; 32],
41 pub addr: String,
43 pub variant: payload::Config,
45 pub bytes_per_second: byte_unit::Byte,
47 pub block_sizes: Option<Vec<byte_unit::Byte>>,
49 pub maximum_prebuild_cache_size_bytes: byte_unit::Byte,
51 #[serde(default)]
53 pub throttle: throttle::Config,
54}
55
56#[derive(thiserror::Error, Debug)]
58pub enum Error {
59 #[error("Creation of payload blocks failed: {0}")]
61 Block(#[from] block::Error),
62 #[error("IO error: {0}")]
64 Io(#[from] std::io::Error),
65}
66
67#[derive(Debug)]
68pub struct Udp {
72 addr: SocketAddr,
73 throttle: Throttle,
74 block_cache: Vec<Block>,
75 metric_labels: Vec<(String, String)>,
76 shutdown: Shutdown,
77}
78
79impl Udp {
80 #[allow(clippy::cast_possible_truncation)]
91 pub fn new(general: General, config: &Config, shutdown: Shutdown) -> Result<Self, Error> {
92 let mut rng = StdRng::from_seed(config.seed);
93 let block_sizes: Vec<NonZeroUsize> = config
94 .block_sizes
95 .clone()
96 .unwrap_or_else(|| {
97 vec![
98 Byte::from_unit(1.0 / 32.0, ByteUnit::MB).unwrap(),
99 Byte::from_unit(1.0 / 16.0, ByteUnit::MB).unwrap(),
100 Byte::from_unit(1.0 / 8.0, ByteUnit::MB).unwrap(),
101 Byte::from_unit(1.0 / 4.0, ByteUnit::MB).unwrap(),
102 Byte::from_unit(1.0 / 2.0, ByteUnit::MB).unwrap(),
103 Byte::from_unit(1_f64, ByteUnit::MB).unwrap(),
104 Byte::from_unit(2_f64, ByteUnit::MB).unwrap(),
105 Byte::from_unit(4_f64, ByteUnit::MB).unwrap(),
106 ]
107 })
108 .iter()
109 .map(|sz| NonZeroUsize::new(sz.get_bytes() as usize).expect("bytes must be non-zero"))
110 .collect();
111 let mut labels = vec![
112 ("component".to_string(), "generator".to_string()),
113 ("component_name".to_string(), "udp".to_string()),
114 ];
115 if let Some(id) = general.id {
116 labels.push(("id".to_string(), id));
117 }
118
119 let bytes_per_second = NonZeroU32::new(config.bytes_per_second.get_bytes() as u32).unwrap();
120 gauge!(
121 "bytes_per_second",
122 f64::from(bytes_per_second.get()),
123 &labels
124 );
125
126 let block_chunks = chunk_bytes(
127 &mut rng,
128 NonZeroUsize::new(config.maximum_prebuild_cache_size_bytes.get_bytes() as usize)
129 .expect("bytes must be non-zero"),
130 &block_sizes,
131 )?;
132 let block_cache = construct_block_cache(&mut rng, &config.variant, &block_chunks, &labels);
133
134 let addr = config
135 .addr
136 .to_socket_addrs()
137 .expect("could not convert to socket")
138 .next()
139 .unwrap();
140
141 Ok(Self {
142 addr,
143 block_cache,
144 throttle: Throttle::new_with_config(config.throttle, bytes_per_second, labels.clone()),
145 metric_labels: labels,
146 shutdown,
147 })
148 }
149
150 pub async fn spin(mut self) -> Result<(), Error> {
160 debug!("UDP generator running");
161 let mut connection = Option::<UdpSocket>::None;
162 let mut blocks = self.block_cache.iter().cycle().peekable();
163
164 let bytes_written = register_counter!("bytes_written", &self.metric_labels);
165 let packets_sent = register_counter!("packets_sent", &self.metric_labels);
166
167 loop {
168 let blk = blocks.peek().unwrap();
169 let total_bytes = blk.total_bytes;
170 assert!(
171 total_bytes.get() <= 65507,
172 "UDP packet too large (over 65507 B)"
173 );
174
175 tokio::select! {
176 conn = UdpSocket::bind("127.0.0.1:0"), if connection.is_none() => {
177 match conn {
178 Ok(sock) => {
179 debug!("UDP port bound");
180 connection = Some(sock);
181 }
182 Err(err) => {
183 trace!("binding UDP port failed: {}", err);
184
185 let mut error_labels = self.metric_labels.clone();
186 error_labels.push(("error".to_string(), err.to_string()));
187 counter!("connection_failure", 1, &error_labels);
188 tokio::time::sleep(Duration::from_secs(1)).await;
189 }
190 }
191 }
192 _ = self.throttle.wait_for(total_bytes), if connection.is_some() => {
193 let sock = connection.unwrap();
194 let blk = blocks.next().unwrap(); match sock.send_to(&blk.bytes, self.addr).await {
196 Ok(bytes) => {
197 bytes_written.increment(bytes as u64);
198 packets_sent.increment(1);
199 connection = Some(sock);
200 }
201 Err(err) => {
202 debug!("write failed: {}", err);
203
204 let mut error_labels = self.metric_labels.clone();
205 error_labels.push(("error".to_string(), err.to_string()));
206 counter!("request_failure", 1, &error_labels);
207 connection = None;
208 }
209 }
210 }
211 _ = self.shutdown.recv() => {
212 info!("shutdown signal received");
213 return Ok(());
214 },
215 }
216 }
217 }
218}