lading/generator/
tcp.rs

1//! The TCP protocol speaking generator.
2//!
3//! ## Metrics
4//!
5//! `bytes_written`: Bytes sent successfully
6//! `packets_sent`: Packets sent successfully
7//! `request_failure`: Number of failed writes; each occurrence causes a reconnect
8//! `connection_failure`: Number of connection failures
9//! `bytes_per_second`: Configured rate to send data
10//!
11//! Additional metrics may be emitted by this generator's [throttle].
12//!
13
14use std::{
15    net::{SocketAddr, ToSocketAddrs},
16    num::{NonZeroU32, NonZeroUsize},
17};
18
19use byte_unit::{Byte, ByteUnit};
20use metrics::{counter, gauge, register_counter};
21use rand::{rngs::StdRng, SeedableRng};
22use serde::Deserialize;
23use tokio::{io::AsyncWriteExt, net::TcpStream};
24use tracing::{info, trace};
25
26use crate::{
27    block::{self, chunk_bytes, construct_block_cache, Block},
28    payload,
29    signals::Shutdown,
30    throttle::{self, Throttle},
31};
32
33use super::General;
34
#[derive(Debug, Deserialize, PartialEq)]
/// Configuration of this generator.
///
/// Deserialized with `serde`; `throttle` falls back to its `Default` when it
/// is absent from the input, every other field except `block_sizes` is
/// required.
pub struct Config {
    /// The seed for random operations against this target. Used to seed the
    /// RNG that drives block chunking and block cache construction.
    pub seed: [u8; 32],
    /// The address for the target. Resolved with
    /// [`ToSocketAddrs`](std::net::ToSocketAddrs) when the generator is
    /// constructed; the first address produced by the resolution is used.
    pub addr: String,
    /// The payload variant
    pub variant: payload::Config,
    /// The bytes per second to send to the target. Must be non-zero.
    pub bytes_per_second: byte_unit::Byte,
    /// The block sizes for messages to this target. When unset, a built-in
    /// set of sizes ranging from 1/32 MB up to 4 MB is used. Every size must
    /// be non-zero.
    pub block_sizes: Option<Vec<byte_unit::Byte>>,
    /// The maximum size in bytes of the cache of prebuilt messages. Must be
    /// non-zero.
    pub maximum_prebuild_cache_size_bytes: byte_unit::Byte,
    /// The load throttle configuration
    #[serde(default)]
    pub throttle: throttle::Config,
}
54
#[derive(thiserror::Error, Debug, Copy, Clone)]
/// Errors produced by [`Tcp`].
pub enum Error {
    /// Creation of payload blocks failed. Wraps the [`block::Error`]
    /// reported while the prebuild cache is being chunked into blocks.
    #[error("Block creation error: {0}")]
    Block(block::Error),
}
62
63impl From<block::Error> for Error {
64    fn from(error: block::Error) -> Self {
65        Error::Block(error)
66    }
67}
68
#[derive(Debug)]
/// The TCP generator.
///
/// This generator is responsible for connecting to the target via TCP
/// and writing prebuilt blocks of payload data to it at a throttled rate.
pub struct Tcp {
    // Resolved address of the target that connections are opened against.
    addr: SocketAddr,
    // Rate limiter consulted with each block's size before the block is written.
    throttle: Throttle,
    // Prebuilt payload blocks; cycled through endlessly while running.
    block_cache: Vec<Block>,
    // Labels attached to every metric this generator emits.
    metric_labels: Vec<(String, String)>,
    // Signal that ends the run loop.
    shutdown: Shutdown,
}
80
81impl Tcp {
82    /// Create a new [`Tcp`] instance
83    ///
84    /// # Errors
85    ///
86    /// Creation will fail if the underlying governor capacity exceeds u32.
87    ///
88    /// # Panics
89    ///
90    /// Function will panic if user has passed zero values for any byte
91    /// values. Sharp corners.
92    #[allow(clippy::cast_possible_truncation)]
93    pub fn new(general: General, config: &Config, shutdown: Shutdown) -> Result<Self, Error> {
94        let mut rng = StdRng::from_seed(config.seed);
95        let block_sizes: Vec<NonZeroUsize> = config
96            .block_sizes
97            .clone()
98            .unwrap_or_else(|| {
99                vec![
100                    Byte::from_unit(1.0 / 32.0, ByteUnit::MB).unwrap(),
101                    Byte::from_unit(1.0 / 16.0, ByteUnit::MB).unwrap(),
102                    Byte::from_unit(1.0 / 8.0, ByteUnit::MB).unwrap(),
103                    Byte::from_unit(1.0 / 4.0, ByteUnit::MB).unwrap(),
104                    Byte::from_unit(1.0 / 2.0, ByteUnit::MB).unwrap(),
105                    Byte::from_unit(1_f64, ByteUnit::MB).unwrap(),
106                    Byte::from_unit(2_f64, ByteUnit::MB).unwrap(),
107                    Byte::from_unit(4_f64, ByteUnit::MB).unwrap(),
108                ]
109            })
110            .iter()
111            .map(|sz| NonZeroUsize::new(sz.get_bytes() as usize).expect("bytes must be non-zero"))
112            .collect();
113        let mut labels = vec![
114            ("component".to_string(), "generator".to_string()),
115            ("component_name".to_string(), "tcp".to_string()),
116        ];
117        if let Some(id) = general.id {
118            labels.push(("id".to_string(), id));
119        }
120
121        let bytes_per_second = NonZeroU32::new(config.bytes_per_second.get_bytes() as u32).unwrap();
122        gauge!(
123            "bytes_per_second",
124            f64::from(bytes_per_second.get()),
125            &labels
126        );
127
128        let block_chunks = chunk_bytes(
129            &mut rng,
130            NonZeroUsize::new(config.maximum_prebuild_cache_size_bytes.get_bytes() as usize)
131                .expect("bytes must be non-zero"),
132            &block_sizes,
133        )?;
134        let block_cache = construct_block_cache(&mut rng, &config.variant, &block_chunks, &labels);
135
136        let addr = config
137            .addr
138            .to_socket_addrs()
139            .expect("could not convert to socket")
140            .next()
141            .unwrap();
142        Ok(Self {
143            addr,
144            block_cache,
145            throttle: Throttle::new_with_config(config.throttle, bytes_per_second, labels.clone()),
146            metric_labels: labels,
147            shutdown,
148        })
149    }
150
151    /// Run [`Tcp`] to completion or until a shutdown signal is received.
152    ///
153    /// # Errors
154    ///
155    /// Function will return an error when the TCP socket cannot be written to.
156    ///
157    /// # Panics
158    ///
159    /// Function will panic if underlying byte capacity is not available.
160    pub async fn spin(mut self) -> Result<(), Error> {
161        let mut connection = None;
162        let mut blocks = self.block_cache.iter().cycle().peekable();
163
164        let bytes_written = register_counter!("bytes_written", &self.metric_labels);
165        let packets_sent = register_counter!("packets_sent", &self.metric_labels);
166
167        loop {
168            let blk = blocks.peek().unwrap();
169            let total_bytes = blk.total_bytes;
170
171            tokio::select! {
172                conn = TcpStream::connect(self.addr), if connection.is_none() => {
173                    match conn {
174                        Ok(client) => {
175                            connection = Some(client);
176                        }
177                        Err(err) => {
178                            trace!("connection to {} failed: {}", self.addr, err);
179
180                            let mut error_labels = self.metric_labels.clone();
181                            error_labels.push(("error".to_string(), err.to_string()));
182                            counter!("connection_failure", 1, &error_labels);
183                            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
184                        }
185                    }
186                }
187                _ = self.throttle.wait_for(total_bytes), if connection.is_some() => {
188                    let mut client = connection.unwrap();
189                    let blk = blocks.next().unwrap(); // actually advance through the blocks
190                    match client.write_all(&blk.bytes).await {
191                        Ok(()) => {
192                            bytes_written.increment(u64::from(blk.total_bytes.get()));
193                            packets_sent.increment(1);
194                            connection = Some(client);
195                        }
196                        Err(err) => {
197                            trace!("write failed: {}", err);
198
199                            let mut error_labels = self.metric_labels.clone();
200                            error_labels.push(("error".to_string(), err.to_string()));
201                            counter!("request_failure", 1, &error_labels);
202                            connection = None;
203                        }
204                    }
205                }
206                _ = self.shutdown.recv() => {
207                    info!("shutdown signal received");
208                    return Ok(());
209                },
210            }
211        }
212    }
213}