mod acknowledgements;
use std::{
num::{NonZeroU32, NonZeroUsize},
time::Duration,
};
use acknowledgements::Channels;
use byte_unit::{Byte, ByteUnit};
use http::{
header::{AUTHORIZATION, CONTENT_LENGTH},
Method, Request, Uri,
};
use hyper::{client::HttpConnector, Body, Client};
use metrics::{counter, gauge};
use once_cell::sync::OnceCell;
use rand::{prelude::StdRng, SeedableRng};
use serde::Deserialize;
use tokio::{
sync::{Semaphore, SemaphorePermit},
time::timeout,
};
use tracing::info;
use crate::{
block::{self, chunk_bytes, construct_block_cache, Block},
generator::splunk_hec::acknowledgements::Channel,
payload,
payload::SplunkHecEncoding,
signals::Shutdown,
throttle::{self, Throttle},
};
use super::General;
/// Process-wide limit on in-flight HEC requests.
///
/// Set once in `SplunkHec::new` with `parallel_connections` permits; `spin`
/// acquires a permit per request and `send_hec_request` holds it until the
/// request finishes.
// NOTE(review): because this is a process-wide `OnceCell`, a second call to
// `SplunkHec::new` fails to `set` it and panics, so only one Splunk HEC
// generator can exist per process as written. Consider a per-instance
// `Arc<Semaphore>` if multiple generators are ever needed.
static CONNECTION_SEMAPHORE: OnceCell<Semaphore> = OnceCell::new();
/// HEC endpoint path that JSON-encoded payloads are POSTed to.
const SPLUNK_HEC_JSON_PATH: &str = "/services/collector/event";
/// HEC endpoint path that raw text payloads are POSTed to.
const SPLUNK_HEC_TEXT_PATH: &str = "/services/collector/raw";
/// HEC acknowledgement endpoint path; handed to the channels when
/// acknowledgements are enabled.
const SPLUNK_HEC_ACKNOWLEDGEMENTS_PATH: &str = "/services/collector/ack";
/// Request header carrying the channel identifier of each request.
const SPLUNK_HEC_CHANNEL_HEADER: &str = "x-splunk-request-channel";
/// Settings that turn on HEC acknowledgement tracking.
///
/// When present in [`Config`], these are passed verbatim to
/// `Channels::enable_acknowledgements` together with the ack endpoint URI.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
pub struct AckSettings {
    /// Presumably the delay between ack status queries, in seconds — TODO
    /// confirm against `acknowledgements`.
    pub ack_query_interval_seconds: u64,
    /// Presumably how long to wait for an ack before giving up, in seconds —
    /// TODO confirm against `acknowledgements`.
    pub ack_timeout_seconds: u64,
}
/// User-facing configuration for the Splunk HEC generator.
#[derive(Debug, PartialEq, Deserialize)]
pub struct Config {
    /// Seed for the RNG that builds the payload block cache.
    pub seed: [u8; 32],
    /// Target Splunk instance. Only its authority (host and port) is used;
    /// the scheme is forced to `http` and the path is chosen from `format`.
    #[serde(with = "http_serde::uri")]
    pub target_uri: Uri,
    /// Payload encoding; also selects the HEC endpoint (raw vs. event).
    pub format: SplunkHecEncoding,
    /// HEC token, sent as `Authorization: Splunk <token>`.
    pub token: String,
    /// Enables acknowledgement tracking when set.
    pub acknowledgements: Option<AckSettings>,
    /// Total size of the pre-built block cache; must be non-zero.
    pub maximum_prebuild_cache_size_bytes: byte_unit::Byte,
    /// Target throughput handed to the throttle; must be non-zero.
    pub bytes_per_second: byte_unit::Byte,
    /// Block sizes to build; `SplunkHec::new` supplies defaults when absent.
    pub block_sizes: Option<Vec<byte_unit::Byte>>,
    /// Maximum number of concurrent HEC requests.
    pub parallel_connections: u16,
    /// Throttle configuration; uses `throttle::Config::default()` when omitted.
    #[serde(default)]
    pub throttle: throttle::Config,
}
#[derive(thiserror::Error, Debug, Clone, Copy)]
pub enum Error {
#[error("User supplied HEC path is not valid")]
InvalidHECPath,
#[error("Interior error: {0}")]
Acknowledgements(acknowledgements::Error),
#[error("Block creation error: {0}")]
Block(#[from] block::Error),
}
/// Load generator that POSTs pre-built payload blocks to a Splunk HEC endpoint.
#[derive(Debug)]
pub struct SplunkHec {
    /// Full HEC endpoint URI that every request is POSTed to.
    uri: Uri,
    /// HEC token used to build the `Authorization` header.
    token: String,
    /// Concurrency limit; also sizes the client's idle pool per host.
    parallel_connections: u16,
    /// Rate limiter gating how many bytes are sent per second.
    throttle: Throttle,
    /// Pre-built payload blocks, cycled through endlessly while running.
    block_cache: Vec<Block>,
    /// Labels attached to every metric this generator emits.
    metric_labels: Vec<(String, String)>,
    /// Request channels, cycled through one per request.
    channels: Channels,
    /// Shutdown signal that stops the generator and in-flight requests.
    shutdown: Shutdown,
}
/// Build the HEC endpoint URI for `format` on the host of `base_uri`.
///
/// Only the authority (host and optional port) of `base_uri` is kept. The
/// scheme is always `http`, matching the plain-HTTP client built in
/// `SplunkHec::spin`. The path is replaced by the endpoint for `format`:
/// `/services/collector/raw` for text, `/services/collector/event` for JSON.
///
/// # Panics
///
/// Panics with a descriptive message if `base_uri` has no authority (for
/// example a bare path such as `/foo`).
fn get_uri_by_format(base_uri: &Uri, format: payload::SplunkHecEncoding) -> Uri {
    let path = match format {
        payload::SplunkHecEncoding::Text => SPLUNK_HEC_TEXT_PATH,
        payload::SplunkHecEncoding::Json => SPLUNK_HEC_JSON_PATH,
    };
    // `target_uri` is user-supplied, so fail with an actionable message rather
    // than a bare `Option::unwrap` panic.
    let authority = base_uri
        .authority()
        .expect("target_uri must include a host, e.g. http://localhost:8088");
    Uri::builder()
        // Borrow the already-validated authority instead of allocating a String.
        .authority(authority.as_str())
        .scheme("http")
        .path_and_query(path)
        .build()
        .expect("an authority taken from a parsed Uri plus a static HEC path is a valid Uri")
}
impl SplunkHec {
#[allow(clippy::cast_possible_truncation)]
pub fn new(general: General, config: Config, shutdown: Shutdown) -> Result<Self, Error> {
let mut rng = StdRng::from_seed(config.seed);
let block_sizes: Vec<NonZeroUsize> = config
.block_sizes
.unwrap_or_else(|| {
vec![
Byte::from_unit(1.0 / 8.0, ByteUnit::MB).unwrap(),
Byte::from_unit(1.0 / 4.0, ByteUnit::MB).unwrap(),
Byte::from_unit(1.0 / 2.0, ByteUnit::MB).unwrap(),
Byte::from_unit(1_f64, ByteUnit::MB).unwrap(),
Byte::from_unit(2_f64, ByteUnit::MB).unwrap(),
Byte::from_unit(4_f64, ByteUnit::MB).unwrap(),
]
})
.iter()
.map(|sz| NonZeroUsize::new(sz.get_bytes() as usize).expect("bytes must be non-zero"))
.collect();
let mut labels = vec![
("component".to_string(), "generator".to_string()),
("component_name".to_string(), "splunk_hec".to_string()),
];
if let Some(id) = general.id {
labels.push(("id".to_string(), id));
}
let bytes_per_second = NonZeroU32::new(config.bytes_per_second.get_bytes() as u32).unwrap();
gauge!(
"bytes_per_second",
f64::from(bytes_per_second.get()),
&labels
);
let uri = get_uri_by_format(&config.target_uri, config.format);
let block_chunks = chunk_bytes(
&mut rng,
NonZeroUsize::new(config.maximum_prebuild_cache_size_bytes.get_bytes() as usize)
.expect("bytes must be non-zero"),
&block_sizes,
)?;
let payload_config = payload::Config::SplunkHec {
encoding: config.format,
};
let block_cache = construct_block_cache(&mut rng, &payload_config, &block_chunks, &labels);
let mut channels = Channels::new(config.parallel_connections);
if let Some(ack_settings) = config.acknowledgements {
let ack_uri = Uri::builder()
.authority(uri.authority().unwrap().to_string())
.scheme("http")
.path_and_query(SPLUNK_HEC_ACKNOWLEDGEMENTS_PATH)
.build()
.unwrap();
channels.enable_acknowledgements(ack_uri, config.token.clone(), ack_settings);
}
CONNECTION_SEMAPHORE
.set(Semaphore::new(config.parallel_connections as usize))
.unwrap();
Ok(Self {
channels,
parallel_connections: config.parallel_connections,
uri,
token: config.token,
block_cache,
throttle: Throttle::new_with_config(config.throttle, bytes_per_second, labels.clone()),
metric_labels: labels,
shutdown,
})
}
pub async fn spin(mut self) -> Result<(), Error> {
let client: Client<HttpConnector, Body> = Client::builder()
.pool_max_idle_per_host(self.parallel_connections as usize)
.retry_canceled_requests(false)
.set_host(false)
.build_http();
let uri = self.uri;
let labels = self.metric_labels;
gauge!(
"maximum_requests",
f64::from(self.parallel_connections),
&labels
);
let mut blocks = self.block_cache.iter().cycle().peekable();
let mut channels = self.channels.iter().cycle();
loop {
let channel: Channel = channels.next().unwrap().clone();
let blk = blocks.peek().unwrap();
let total_bytes = blk.total_bytes;
tokio::select! {
_ = self.throttle.wait_for(total_bytes) => {
let client = client.clone();
let labels = labels.clone();
let uri = uri.clone();
let blk = blocks.next().unwrap(); let body = Body::from(blk.bytes.clone());
let block_length = blk.bytes.len();
let request: Request<Body> = Request::builder()
.method(Method::POST)
.uri(uri)
.header(AUTHORIZATION, format!("Splunk {}", self.token))
.header(CONTENT_LENGTH, block_length)
.header(SPLUNK_HEC_CHANNEL_HEADER, channel.id())
.body(body)
.unwrap();
let permit = CONNECTION_SEMAPHORE.get().unwrap().acquire().await.unwrap();
tokio::spawn(send_hec_request(permit, block_length, labels, channel, client, request, self.shutdown.clone()));
}
_ = self.shutdown.recv() => {
info!("shutdown signal received");
return Ok(());
},
}
}
}
}
/// Send one HEC request and record its outcome in metrics.
///
/// The request is abandoned if no response arrives within one second or if
/// `shutdown` fires first. `permit` is held until this function returns, so
/// the connection semaphore bounds the number of requests in flight.
///
/// Metrics emitted:
/// - `requests_sent`: always.
/// - `bytes_written` and `request_ok`: on any HTTP response; `request_ok` is
///   labelled with `status_code`.
/// - `request_failure`: on a transport error.
/// - `request_timeout`: on timeout.
///
/// For 2xx responses, `channel.send` receives a future that reads the body and
/// extracts its `ackId`. Whether that future is awaited is up to `Channel::send`.
///
/// # Panics
///
/// If the ack future is awaited, it panics when a 2xx body cannot be read or
/// does not deserialize into [`HecAckResponse`].
async fn send_hec_request(
    permit: SemaphorePermit<'_>,
    block_length: usize,
    labels: Vec<(String, String)>,
    channel: Channel,
    client: Client<HttpConnector>,
    request: Request<Body>,
    mut shutdown: Shutdown,
) {
    counter!("requests_sent", 1, &labels);
    let work = client.request(request);
    tokio::select! {
        tm = timeout(Duration::from_secs(1), work) => {
            match tm {
                Ok(Ok(response)) => {
                    counter!("bytes_written", block_length as u64, &labels);
                    let (parts, body) = response.into_parts();
                    let status = parts.status;
                    let mut status_labels = labels.clone();
                    status_labels.push(("status_code".to_string(), status.as_u16().to_string()));
                    counter!("request_ok", 1, &status_labels);
                    // Only successful responses carry an `ackId`. Error bodies
                    // such as `{"text":"Server is busy","code":9}` lack it, and
                    // feeding them to the channel made the parse below panic.
                    if status.is_success() {
                        channel
                            .send(async {
                                let body_bytes = hyper::body::to_bytes(body)
                                    .await
                                    .expect("failed to read HEC response body");
                                let hec_ack_response =
                                    serde_json::from_slice::<HecAckResponse>(&body_bytes)
                                        .expect("successful HEC response must contain an ackId");
                                hec_ack_response.ack_id
                            })
                            .await;
                    }
                }
                Ok(Err(err)) => {
                    let mut error_labels = labels;
                    error_labels.push(("error".to_string(), err.to_string()));
                    counter!("request_failure", 1, &error_labels);
                }
                Err(err) => {
                    let mut error_labels = labels;
                    error_labels.push(("error".to_string(), err.to_string()));
                    counter!("request_timeout", 1, &error_labels);
                }
            }
        }
        _ = shutdown.recv() => {},
    }
    // Release the connection slot only once the request is fully finished.
    drop(permit);
}
/// Body of a HEC ingest response when acknowledgements are enabled.
///
/// Field names are mapped to camelCase on the wire, so `ack_id` is read from
/// the JSON key `ackId`; `text` and `code` are unchanged by the mapping.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct HecAckResponse {
    /// Human-readable status; deserialized but never read.
    #[allow(dead_code)]
    text: String,
    /// Numeric HEC status code; deserialized but never read.
    #[allow(dead_code)]
    code: u8,
    /// Acknowledgement ID handed to the channel for later ack queries.
    ack_id: u64,
}