1use std::{
15 net::{SocketAddr, ToSocketAddrs},
16 num::{NonZeroU32, NonZeroUsize},
17};
18
19use byte_unit::{Byte, ByteUnit};
20use metrics::{counter, gauge, register_counter};
21use rand::{rngs::StdRng, SeedableRng};
22use serde::Deserialize;
23use tokio::{io::AsyncWriteExt, net::TcpStream};
24use tracing::{info, trace};
25
26use crate::{
27 block::{self, chunk_bytes, construct_block_cache, Block},
28 payload,
29 signals::Shutdown,
30 throttle::{self, Throttle},
31};
32
33use super::General;
34
35#[derive(Debug, Deserialize, PartialEq)]
36pub struct Config {
38 pub seed: [u8; 32],
40 pub addr: String,
42 pub variant: payload::Config,
44 pub bytes_per_second: byte_unit::Byte,
46 pub block_sizes: Option<Vec<byte_unit::Byte>>,
48 pub maximum_prebuild_cache_size_bytes: byte_unit::Byte,
50 #[serde(default)]
52 pub throttle: throttle::Config,
53}
54
55#[derive(thiserror::Error, Debug, Copy, Clone)]
56pub enum Error {
58 #[error("Block creation error: {0}")]
60 Block(block::Error),
61}
62
63impl From<block::Error> for Error {
64 fn from(error: block::Error) -> Self {
65 Error::Block(error)
66 }
67}
68
/// The TCP generator.
///
/// Connects to `addr` and repeatedly writes blocks from `block_cache`, pacing
/// writes with `throttle`, until `shutdown` fires.
#[derive(Debug)]
pub struct Tcp {
    /// Resolved target address.
    addr: SocketAddr,
    /// Rate limiter awaited (for the next block's size) before each write.
    throttle: Throttle,
    /// Prebuilt payload blocks, cycled through in order.
    block_cache: Vec<Block>,
    /// Base labels attached to the metrics this generator emits.
    metric_labels: Vec<(String, String)>,
    /// Signal that ends `spin`.
    shutdown: Shutdown,
}
80
81impl Tcp {
82 #[allow(clippy::cast_possible_truncation)]
93 pub fn new(general: General, config: &Config, shutdown: Shutdown) -> Result<Self, Error> {
94 let mut rng = StdRng::from_seed(config.seed);
95 let block_sizes: Vec<NonZeroUsize> = config
96 .block_sizes
97 .clone()
98 .unwrap_or_else(|| {
99 vec![
100 Byte::from_unit(1.0 / 32.0, ByteUnit::MB).unwrap(),
101 Byte::from_unit(1.0 / 16.0, ByteUnit::MB).unwrap(),
102 Byte::from_unit(1.0 / 8.0, ByteUnit::MB).unwrap(),
103 Byte::from_unit(1.0 / 4.0, ByteUnit::MB).unwrap(),
104 Byte::from_unit(1.0 / 2.0, ByteUnit::MB).unwrap(),
105 Byte::from_unit(1_f64, ByteUnit::MB).unwrap(),
106 Byte::from_unit(2_f64, ByteUnit::MB).unwrap(),
107 Byte::from_unit(4_f64, ByteUnit::MB).unwrap(),
108 ]
109 })
110 .iter()
111 .map(|sz| NonZeroUsize::new(sz.get_bytes() as usize).expect("bytes must be non-zero"))
112 .collect();
113 let mut labels = vec![
114 ("component".to_string(), "generator".to_string()),
115 ("component_name".to_string(), "tcp".to_string()),
116 ];
117 if let Some(id) = general.id {
118 labels.push(("id".to_string(), id));
119 }
120
121 let bytes_per_second = NonZeroU32::new(config.bytes_per_second.get_bytes() as u32).unwrap();
122 gauge!(
123 "bytes_per_second",
124 f64::from(bytes_per_second.get()),
125 &labels
126 );
127
128 let block_chunks = chunk_bytes(
129 &mut rng,
130 NonZeroUsize::new(config.maximum_prebuild_cache_size_bytes.get_bytes() as usize)
131 .expect("bytes must be non-zero"),
132 &block_sizes,
133 )?;
134 let block_cache = construct_block_cache(&mut rng, &config.variant, &block_chunks, &labels);
135
136 let addr = config
137 .addr
138 .to_socket_addrs()
139 .expect("could not convert to socket")
140 .next()
141 .unwrap();
142 Ok(Self {
143 addr,
144 block_cache,
145 throttle: Throttle::new_with_config(config.throttle, bytes_per_second, labels.clone()),
146 metric_labels: labels,
147 shutdown,
148 })
149 }
150
151 pub async fn spin(mut self) -> Result<(), Error> {
161 let mut connection = None;
162 let mut blocks = self.block_cache.iter().cycle().peekable();
163
164 let bytes_written = register_counter!("bytes_written", &self.metric_labels);
165 let packets_sent = register_counter!("packets_sent", &self.metric_labels);
166
167 loop {
168 let blk = blocks.peek().unwrap();
169 let total_bytes = blk.total_bytes;
170
171 tokio::select! {
172 conn = TcpStream::connect(self.addr), if connection.is_none() => {
173 match conn {
174 Ok(client) => {
175 connection = Some(client);
176 }
177 Err(err) => {
178 trace!("connection to {} failed: {}", self.addr, err);
179
180 let mut error_labels = self.metric_labels.clone();
181 error_labels.push(("error".to_string(), err.to_string()));
182 counter!("connection_failure", 1, &error_labels);
183 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
184 }
185 }
186 }
187 _ = self.throttle.wait_for(total_bytes), if connection.is_some() => {
188 let mut client = connection.unwrap();
189 let blk = blocks.next().unwrap(); match client.write_all(&blk.bytes).await {
191 Ok(()) => {
192 bytes_written.increment(u64::from(blk.total_bytes.get()));
193 packets_sent.increment(1);
194 connection = Some(client);
195 }
196 Err(err) => {
197 trace!("write failed: {}", err);
198
199 let mut error_labels = self.metric_labels.clone();
200 error_labels.push(("error".to_string(), err.to_string()));
201 counter!("request_failure", 1, &error_labels);
202 connection = None;
203 }
204 }
205 }
206 _ = self.shutdown.recv() => {
207 info!("shutdown signal received");
208 return Ok(());
209 },
210 }
211 }
212 }
213}