use super::*;
use crate::stats::err_printer::ErrPrinter;
use std::io::Write;
/// Spawns the stats controller on a dedicated thread named `stats_thread`.
///
/// Returns, in order: the thread's join handle, a sender producers use to report stats,
/// the end-of-processing flag, and the any-errors flag.
///
/// # Panics
///
/// Panics if the OS refuses to spawn the thread.
pub fn init_controller<C: Config + 'static>(
    config: &'static C,
) -> (
    JoinHandle<()>,
    flume::Sender<StatType>,
    Arc<AtomicBool>,
    Arc<AtomicBool>,
) {
    log::trace!("Initializing stats controller");
    let mut controller = Controller::new(config);
    // Every shared handle must be taken before the controller is moved into the thread.
    let sender = controller.send_channel();
    let stop_flag = controller.end_processing_flag();
    let errors_flag = controller.any_errors_flag();
    let handle = Builder::new()
        .name(String::from("stats_thread"))
        .spawn(move || controller.run())
        .expect("Failed to spawn stats thread");
    (handle, sender, stop_flag, errors_flag)
}
/// Receives [`StatType`] messages over a channel, aggregates them into a [`StatsCollector`],
/// and drives the terminal spinner, the error/report printouts and the stats-file output.
pub struct Controller<C: Config + 'static> {
/// Aggregates every stat received in `update`.
stats_collector: StatsCollector,
/// Captured when the controller is constructed; its elapsed time is passed to the report.
pub processing_time: Instant,
/// Global configuration this controller was created with.
config: &'static C,
/// Cached `config.max_tolerate_errors()`; when non-zero, reaching this many errors
/// raises `end_processing_flag`.
max_tolerate_errors: u32,
/// Receiving end of the stats channel, drained by `run`.
stats_recv_chan: flume::Receiver<StatType>,
/// Sender handed out (cloned) by `send_channel`; `run` sets it to `None` so the
/// controller does not hold a sender of its own while draining the channel.
stats_send_chan: Option<flume::Sender<StatType>>,
/// Set to true on a fatal error or when the tolerated-error limit is reached.
end_processing_flag: Arc<AtomicBool>,
/// Set to true when any errors were collected or input-stats validation failed.
any_errors_flag: Arc<AtomicBool>,
/// Progress spinner; `None` at construction when a view is configured.
spinner: Option<ProgressBar>,
/// Text currently shown as the spinner's message.
spinner_message: String,
}
impl<C: Config + 'static> Controller<C> {
pub fn new(global_config: &'static C) -> Self {
let (stats_send_chan, stats_recv_chan): (
flume::Sender<StatType>,
flume::Receiver<StatType>,
) = flume::unbounded();
Controller {
stats_collector: if global_config.alpide_checks_enabled() {
StatsCollector::with_alpide_stats()
} else {
StatsCollector::default()
},
config: global_config,
processing_time: Instant::now(),
max_tolerate_errors: global_config.max_tolerate_errors(),
stats_recv_chan,
stats_send_chan: Some(stats_send_chan),
end_processing_flag: Arc::new(AtomicBool::new(false)),
any_errors_flag: Arc::new(AtomicBool::new(false)),
spinner: if global_config.view().is_some() {
None
} else {
Some(new_styled_spinner())
},
spinner_message: String::new(),
}
}
pub fn send_channel(&self) -> flume::Sender<StatType> {
if self.stats_send_chan.is_none() {
log::error!("Controller send channel is none, most likely it is already running and does not accept new producers");
panic!("Controller send channel is none, most likely it is already running and does not accept new producers");
}
self.stats_send_chan.as_ref().unwrap().clone()
}
pub fn end_processing_flag(&self) -> Arc<AtomicBool> {
self.end_processing_flag.clone()
}
pub fn any_errors_flag(&self) -> Arc<AtomicBool> {
self.any_errors_flag.clone()
}
pub fn run(&mut self) {
self.stats_send_chan = None;
while let Ok(stats_update) = self.stats_recv_chan.recv() {
self.update(stats_update);
}
if self.config.custom_checks_enabled() {
self.stats_collector.validate_custom_stats(self.config);
}
if self.config.view().is_some() || self.config.output_mode() == DataOutputMode::Stdout {
log::info!("View active or output is being piped, skipping report summary printout.")
} else {
self.process_stats();
if self.stats_collector.any_rdhs_seen() {
self.new_spinner_with_prefix("Generating report".to_string());
self.print();
}
}
if self.stats_collector.any_errors() {
self.any_errors_flag.store(true, Ordering::SeqCst);
}
if self.config.stats_output_mode() != DataOutputMode::None {
self.stats_collector.write_stats(
&self.config.stats_output_mode(),
self.config.stats_output_format().unwrap(),
);
}
if let Some(input_stats) = self.config.input_stats_file() {
log::info!("Validating input stats file against collected stats");
let input_stats_str =
fs::read_to_string(input_stats).expect("Failed to read input stats file");
let input_stats_collector: StatsCollector = if input_stats.extension().unwrap()
== "json"
{
serde_json::from_str(&input_stats_str)
.expect("Failed to deserialize input stats file")
} else if input_stats.extension().unwrap() == "toml" {
toml::from_str(&input_stats_str).expect("Failed to deserialize input stats file")
} else {
panic!("Invalid input stats file extension, must be .json or .toml")
};
if self
.stats_collector
.validate_other_stats(&input_stats_collector, self.config.mute_errors())
.is_err()
{
self.any_errors_flag.store(true, Ordering::SeqCst);
log::warn!("Input stats did not match collected stats");
} else {
log::info!("Input stats matched collected stats");
}
}
}
fn update(&mut self, stat: StatType) {
match stat {
StatType::RDHSeen(_)
| StatType::RDHFiltered(_)
| StatType::PayloadSize(_)
| StatType::LinksObserved(_)
| StatType::RdhVersion(_)
| StatType::DataFormat(_)
| StatType::LayerStaveSeen { .. }
| StatType::SystemId(_)
| StatType::FeeId(_)
| StatType::TriggerType(_)
| StatType::AlpideStats(_) => {
self.stats_collector.collect(stat);
}
StatType::HBFsSeen(_) => {
self.stats_collector.collect(stat);
if self.spinner.is_some() {
self.spinner.as_mut().unwrap().set_prefix(format!(
"Analyzing {hbfs} HBFs",
hbfs = self.stats_collector.hbfs_seen()
))
};
}
StatType::RunTriggerType((raw_tt, tt_str)) => {
log::debug!("Run trigger type determined to be {raw_tt:#0x}: {tt_str}");
self.stats_collector
.collect(StatType::RunTriggerType((raw_tt, tt_str)));
}
StatType::Error(msg) => {
if self.stats_collector.any_fatal_err() {
log::trace!("Fatal error already seen, ignoring error: {msg}");
return;
}
self.stats_collector.collect(StatType::Error(msg));
self.set_spinner_msg(
format!(
"{err_cnt} Errors in data!",
err_cnt = self.stats_collector.err_count()
)
.red()
.to_string(),
);
if self.max_tolerate_errors > 0 {
log::trace!("Error count: {}", self.stats_collector.err_count());
if self.stats_collector.err_count() == self.max_tolerate_errors as u64 {
log::trace!("Errors reached maximum tolerated errors, exiting...");
self.end_processing_flag.store(true, Ordering::SeqCst);
}
}
}
StatType::Fatal(err) => {
if self.stats_collector.any_fatal_err() {
log::trace!("Fatal error already seen, ignoring error: {err}");
return;
}
self.end_processing_flag.store(true, Ordering::SeqCst);
log::error!("FATAL: {err}\nShutting down...");
self.stats_collector.collect(StatType::Fatal(err));
}
}
}
fn process_stats(&mut self) {
if self.stats_collector.err_count() > 0 {
self.new_spinner_with_prefix(
format!(
"Processing {err_count} error messages",
err_count = self.stats_collector.err_count()
)
.yellow()
.to_string(),
);
self.stats_collector.finalize(self.config.mute_errors());
self.spinner.as_mut().unwrap().abandon();
} else {
self.stats_collector.finalize(self.config.mute_errors());
}
if self.stats_collector.any_errors() && !self.config.mute_errors() {
ErrPrinter::new(
if self.config.max_tolerate_errors() > 0 {
Some(self.config.max_tolerate_errors())
} else {
None
},
self.config.error_code_filter(),
)
.print(
self.stats_collector.error_stats().errors_as_slice_iter(),
self.stats_collector.unique_error_codes_as_slice(),
);
}
}
fn print(&mut self) {
let mut report = stats::stats_report::make_report(
self.processing_time.elapsed(),
&mut self.stats_collector,
self.config.filter_target(),
);
self.append_spinner_msg("... completed");
if self.spinner.is_some() {
self.spinner.as_mut().unwrap().abandon();
}
let mut lock = io::stdout().lock();
if let Err(e) = writeln!(lock, "{}", report.format()) {
if e.kind() == io::ErrorKind::BrokenPipe {
log::warn!("Broken pipe, stdout was closed before report could be written");
} else {
log::error!("Failed to write report to stdout: {e}");
}
}
}
fn new_spinner_with_prefix(&mut self, prefix: String) {
if self.spinner.is_some() {
self.append_spinner_msg("... completed");
self.spinner.as_mut().unwrap().abandon();
self.spinner = Some(new_styled_spinner());
self.spinner_message = "".to_string();
self.spinner.as_mut().unwrap().set_prefix(prefix);
} else {
self.spinner = Some(new_styled_spinner());
self.spinner_message = "".to_string();
self.spinner.as_mut().unwrap().set_prefix(prefix);
}
}
fn set_spinner_msg(&mut self, new_msg: String) {
if self.spinner.is_some() {
self.spinner_message = new_msg;
self.spinner
.as_mut()
.unwrap()
.set_message(self.spinner_message.clone());
}
}
fn append_spinner_msg(&mut self, to_append: &str) {
if self.spinner.is_some() {
self.spinner_message = self.spinner_message.clone() + to_append + " ";
self.spinner
.as_mut()
.unwrap()
.set_message(self.spinner_message.clone());
}
}
}
/// Builds the progress spinner used by the stats controller.
///
/// The spinner ticks every 120 ms. The template shows the spinner glyph, then the bold blue
/// prefix in brackets, then the message (`wide_msg`).
///
/// # Panics
///
/// Panics if the (static) template string fails to parse.
fn new_styled_spinner() -> ProgressBar {
    const TICK_FRAMES: [&str; 7] = [
        "▹▹▹▹▹",
        "▸▹▹▹▹",
        "▹▸▹▹▹",
        "▹▹▸▹▹",
        "▹▹▹▸▹",
        "▹▹▹▹▸",
        "▪▪▪▪▪",
    ];
    let template = "{spinner} [ {prefix:.bold.blue} ] {wide_msg}";
    let style = ProgressStyle::with_template(template)
        .unwrap()
        .tick_strings(&TICK_FRAMES);
    let tick_interval = Duration::from_millis(120);
    let spinner = ProgressBar::new_spinner();
    spinner.set_style(style);
    spinner.enable_steady_tick(tick_interval);
    spinner
}
#[cfg(test)]
mod tests {
    use super::*;

    static CONFIG_TEST_INIT_CONTROLLER: OnceLock<MockConfig> = OnceLock::new();

    /// Feeds the stats thread a few stats ending in a fatal error. Checks that the stop
    /// flag starts lowered and is raised once the thread has finished.
    #[test]
    fn test_init_controller() {
        CONFIG_TEST_INIT_CONTROLLER
            .set(MockConfig::default())
            .unwrap();
        let config = CONFIG_TEST_INIT_CONTROLLER.get().unwrap();
        let (handle, sender, stop_flag, _errors_flag) = init_controller(config);
        assert!(!stop_flag.load(Ordering::SeqCst));

        let stats = vec![
            StatType::RdhVersion(7),
            StatType::DataFormat(99),
            StatType::RunTriggerType((0xBEEF, "BEEF".to_owned().into())),
            StatType::RDHSeen(1),
            StatType::Fatal("Test fatal error".to_string().into()),
        ];
        for stat in stats {
            sender.send(stat).unwrap();
        }

        // Dropping the last sender closes the channel so the stats thread can finish.
        drop(sender);
        handle.join().unwrap();
        assert!(stop_flag.load(Ordering::SeqCst));
    }
}