use crate::{
block::{self, chunk_bytes, construct_block_cache, Block},
payload,
signals::Shutdown,
throttle::{self, Throttle},
};
use byte_unit::{Byte, ByteUnit};
use metrics::{counter, gauge, register_counter};
use rand::{rngs::StdRng, SeedableRng};
use serde::Deserialize;
use std::{
num::{NonZeroU32, NonZeroUsize},
path::PathBuf,
};
use tokio::{net, task::JoinError};
use tracing::{debug, error, info};
use super::General;
/// Configuration for the Unix stream generator.
#[derive(Debug, Deserialize, PartialEq)]
pub struct Config {
    /// Seed for the random number generator used to build the payload blocks.
    pub seed: [u8; 32],
    /// Filesystem path of the Unix domain socket the generator connects to.
    pub path: PathBuf,
    /// Payload variant used to construct the block cache.
    pub variant: payload::Config,
    /// Target write rate in bytes per second; must be non-zero.
    pub bytes_per_second: byte_unit::Byte,
    /// Sizes of the blocks to pre-build; each must be non-zero. When absent, a
    /// default set ranging from 1/32 MB up to 4 MB is used.
    pub block_sizes: Option<Vec<byte_unit::Byte>>,
    /// Total size budget for the pre-built block cache; must be non-zero.
    pub maximum_prebuild_cache_size_bytes: byte_unit::Byte,
    /// Throttle configuration; falls back to its `Default` when omitted.
    #[serde(default)]
    pub throttle: throttle::Config,
}
/// Errors produced by the Unix stream generator.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Creating the payload block cache failed.
    #[error("Creation of payload blocks failed: {0}")]
    Block(#[from] block::Error),
    /// An I/O operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Wraps a tokio `JoinError` from a subtask.
    #[error("Subtask failure: {0}")]
    Subtask(#[from] JoinError),
}
/// Generator that writes pre-built payload blocks to a Unix domain stream
/// socket at a throttled rate.
#[derive(Debug)]
pub struct UnixStream {
    /// Socket path to connect to.
    path: PathBuf,
    /// Rate limiter; waited on for each block's byte count before it is written.
    throttle: Throttle,
    /// Pre-built payload blocks, cycled through in order.
    block_cache: Vec<Block>,
    /// Labels attached to every metric this generator emits.
    metric_labels: Vec<(String, String)>,
    /// Signal that stops the generator's run loop.
    shutdown: Shutdown,
}
impl UnixStream {
    /// Create a new [`UnixStream`] generator.
    ///
    /// Seeds a [`StdRng`] from `config.seed`, pre-builds the block cache from
    /// `config.variant`, records the `bytes_per_second` gauge and constructs the
    /// throttle. No socket is opened here; connecting happens in
    /// [`UnixStream::spin`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::Block`] if `chunk_bytes` fails to split the prebuild
    /// cache budget into blocks.
    ///
    /// # Panics
    ///
    /// Panics if any block size or `maximum_prebuild_cache_size_bytes` is zero,
    /// or if `bytes_per_second` is zero or does not fit in a `u32`.
    // The remaining `as usize` casts narrow configured byte counts to the
    // platform word size; truncation there is accepted.
    #[allow(clippy::cast_possible_truncation)]
    pub fn new(general: General, config: Config, shutdown: Shutdown) -> Result<Self, Error> {
        let mut rng = StdRng::from_seed(config.seed);
        // `config` is owned, so the field can be moved out rather than cloned.
        let block_sizes: Vec<NonZeroUsize> = config
            .block_sizes
            .unwrap_or_else(|| {
                vec![
                    Byte::from_unit(1.0 / 32.0, ByteUnit::MB).unwrap(),
                    Byte::from_unit(1.0 / 16.0, ByteUnit::MB).unwrap(),
                    Byte::from_unit(1.0 / 8.0, ByteUnit::MB).unwrap(),
                    Byte::from_unit(1.0 / 4.0, ByteUnit::MB).unwrap(),
                    Byte::from_unit(1.0 / 2.0, ByteUnit::MB).unwrap(),
                    Byte::from_unit(1_f64, ByteUnit::MB).unwrap(),
                    Byte::from_unit(2_f64, ByteUnit::MB).unwrap(),
                    Byte::from_unit(4_f64, ByteUnit::MB).unwrap(),
                ]
            })
            .iter()
            .map(|sz| NonZeroUsize::new(sz.get_bytes() as usize).expect("bytes must be non-zero"))
            .collect();

        let mut labels = vec![
            ("component".to_string(), "generator".to_string()),
            ("component_name".to_string(), "unix_stream".to_string()),
        ];
        if let Some(id) = general.id {
            labels.push(("id".to_string(), id));
        }

        // Checked conversion: an `as u32` cast would silently wrap rates of
        // 4 GiB/s and above, possibly down to zero.
        let bytes_per_second =
            <u32 as std::convert::TryFrom<_>>::try_from(config.bytes_per_second.get_bytes())
                .ok()
                .and_then(NonZeroU32::new)
                .expect("bytes_per_second must be non-zero and fit in a u32");
        gauge!(
            "bytes_per_second",
            f64::from(bytes_per_second.get()),
            &labels
        );

        let block_chunks = chunk_bytes(
            &mut rng,
            NonZeroUsize::new(config.maximum_prebuild_cache_size_bytes.get_bytes() as usize)
                .expect("bytes must be non-zero"),
            &block_sizes,
        )?;
        let block_cache = construct_block_cache(&mut rng, &config.variant, &block_chunks, &labels);

        Ok(Self {
            path: config.path,
            block_cache,
            throttle: Throttle::new_with_config(config.throttle, bytes_per_second, labels.clone()),
            metric_labels: labels,
            shutdown,
        })
    }

    /// Run the generator until a shutdown signal is received.
    ///
    /// Cycles through the pre-built block cache forever. While disconnected it
    /// tries to connect to `path` (counting `connection_failure` on error).
    /// While connected it waits on the throttle for the next block's size, then
    /// writes that whole block, handling partial writes. A failed write is
    /// counted as `request_failure` and drops the socket, so the next iteration
    /// reconnects.
    ///
    /// NOTE(review): shutdown is only observed between blocks. A write that
    /// stays blocked (e.g. the peer stops reading) is not interrupted by it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Io`] if polling the socket for write readiness fails.
    ///
    /// # Panics
    ///
    /// Panics if the block cache is empty.
    pub async fn spin(mut self) -> Result<(), Error> {
        debug!("UnixStream generator running");
        let mut blocks = self.block_cache.iter().cycle().peekable();
        let mut unix_stream = Option::<net::UnixStream>::None;

        let bytes_written = register_counter!("bytes_written", &self.metric_labels);
        let packets_sent = register_counter!("packets_sent", &self.metric_labels);

        loop {
            // Peek rather than advance: a block is only consumed once the
            // throttle admits it and a connection exists to write it to.
            let total_bytes = blocks
                .peek()
                .expect("block cache must not be empty")
                .total_bytes;

            tokio::select! {
                sock = net::UnixStream::connect(&self.path), if unix_stream.is_none() => {
                    match sock {
                        Ok(stream) => {
                            debug!("UDS socket opened for writing.");
                            unix_stream = Some(stream);
                        }
                        Err(err) => {
                            error!("Opening UDS path failed: {}", err);

                            let mut error_labels = self.metric_labels.clone();
                            error_labels.push(("error".to_string(), err.to_string()));
                            counter!("connection_failure", 1, &error_labels);
                        }
                    }
                }
                _ = self.throttle.wait_for(total_bytes), if unix_stream.is_some() => {
                    let blk = blocks.next().expect("block cache must not be empty");
                    let stream = unix_stream
                        .take()
                        .expect("branch is guarded by unix_stream.is_some()");
                    let blk_max: usize = total_bytes.get() as usize;
                    let mut blk_offset = 0;
                    let mut write_failed = false;

                    while blk_offset < blk_max {
                        // `?` in a select! handler returns from `spin` itself.
                        let ready = stream.ready(tokio::io::Interest::WRITABLE).await?;
                        if !ready.is_writable() {
                            continue;
                        }
                        // Readiness can be a false positive, so `try_write` may
                        // still report `WouldBlock`.
                        match stream.try_write(&blk.bytes[blk_offset..]) {
                            Ok(bytes) => {
                                bytes_written.increment(bytes as u64);
                                packets_sent.increment(1);
                                // Accumulate: a partial write leaves the rest of
                                // the block still to be sent.
                                blk_offset += bytes;
                            }
                            Err(ref e) if e.kind() == tokio::io::ErrorKind::WouldBlock => {
                                // Avoid hogging the executor while the peer is
                                // not draining the socket.
                                tokio::task::yield_now().await;
                            }
                            Err(err) => {
                                debug!("write failed: {}", err);

                                let mut error_labels = self.metric_labels.clone();
                                error_labels.push(("error".to_string(), err.to_string()));
                                counter!("request_failure", 1, &error_labels);
                                write_failed = true;
                                break;
                            }
                        }
                    }

                    // Keep the connection unless a write failed; leaving
                    // `unix_stream` as `None` makes the next iteration reconnect.
                    if !write_failed {
                        unix_stream = Some(stream);
                    }
                }
                _ = self.shutdown.recv() => {
                    info!("shutdown signal received");
                    return Ok(());
                },
            }
        }
    }
}