use std::sync::Arc;
use std::thread::{self, JoinHandle};

use flume::{self, Receiver, Sender};
use log::{LevelFilter, Log, Metadata, Record};
use serializable_log_record::{into_log_record, SerializableLogRecord};
/// Messages sent from [`ParallelLogger`] to its background logger thread(s).
enum MsgType {
    /// Tells the receiving thread to leave its receive loop.
    Shutdown,
    /// Tells the receiving thread to flush its actual logger(s).
    Flush,
    /// A log record to hand to the actual logger(s).
    Data(SerializableLogRecord),
}
/// Selects how [`ParallelLogger`] spreads work across background threads.
///
/// Besides `Debug`/`Copy`/`Clone`, the common comparison traits are derived so callers can
/// compare modes, or use them as map keys, e.g. when choosing a mode from configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ParallelMode {
    /// A single background thread passes each record to every actual logger, one after another.
    Sequential,
    /// Each actual logger runs on its own background thread.
    Parallel,
}
/// The global [`log`] logger installed by [`ParallelLogger::init`].
///
/// It never calls the actual loggers itself. Each record it receives is converted into a
/// `SerializableLogRecord` and sent over a channel to the background logger thread(s).
#[derive(Debug)]
pub struct ParallelLogger {
    /// Sending half of the channel that `init` connects to the background thread(s).
    tx: Sender<MsgType>,
    /// Level threshold used by `enabled`; `init` also passes it to `log::set_max_level`.
    log_level: LevelFilter,
}
/// Handle returned by [`ParallelLogger::init`] for stopping the background logger threads.
pub struct ShutdownHandle {
    /// Join handles of the logger threads started by `init`.
    join_handles: Vec<JoinHandle<()>>,
    /// Sender used to deliver the `Shutdown` message.
    tx: Sender<MsgType>,
}

impl ShutdownHandle {
    /// Sends a single `Shutdown` message, then waits for every logger thread to finish.
    ///
    /// Returns the number of threads that were joined successfully. A thread that panicked is
    /// reported on stderr and not counted. If the `Shutdown` message cannot be sent, the error
    /// is printed to stderr and `0` is returned without joining any thread.
    pub fn shutdown(self) -> i32 {
        let Self { join_handles, tx } = self;
        match tx.send(MsgType::Shutdown) {
            Err(e) => {
                eprintln!("An internal error occurred in ParallelLogger: {e}");
                0
            }
            Ok(()) => join_handles
                .into_iter()
                .fold(0, |joined, handle| match handle.join() {
                    Ok(()) => joined + 1,
                    Err(e) => {
                        eprintln!("An internal error occurred while shutting down ParallelLogger: {e:?}");
                        joined
                    }
                }),
        }
    }
}
impl ParallelLogger {
pub fn init(log_level: LevelFilter, mode: ParallelMode, actual_loggers: Vec<Box<dyn Log>>) -> ShutdownHandle {
assert!(
!actual_loggers.is_empty(),
"Failed to initialize ParallelLogger: No actual loggers provided"
);
let (tx, rx) = flume::unbounded();
let mut join_handles = Vec::with_capacity(actual_loggers.len());
match mode {
ParallelMode::Sequential => {
let join_handle = Self::start_thread(rx, actual_loggers);
join_handles.push(join_handle);
}
ParallelMode::Parallel => {
let mut counter = 0;
for logger in actual_loggers {
let join_handle = Self::start_thread_single(rx.clone(), logger, counter);
join_handles.push(join_handle);
counter += 1;
}
}
};
let tpl = Self {
tx: tx.clone(),
log_level,
};
log::set_boxed_logger(Box::new(tpl)).unwrap();
log::set_max_level(log_level);
ShutdownHandle { tx, join_handles }
}
fn start_thread_single(rx: Receiver<MsgType>, actual_logger: Box<dyn Log>, counter: i32) -> JoinHandle<()> {
thread::Builder::new()
.name(format!("ParallelLogger-Thread-{counter}"))
.spawn(move || {
while let Ok(message) = rx.recv() {
match message {
MsgType::Data(message) => Self::log_record(&message, &*actual_logger),
MsgType::Flush => actual_logger.flush(),
MsgType::Shutdown => break,
};
}
})
.unwrap()
}
fn start_thread(rx: Receiver<MsgType>, actual_loggers: Vec<Box<dyn Log>>) -> JoinHandle<()> {
thread::Builder::new()
.name("ParallelLogger-Thread-0".to_owned())
.spawn(move || {
while let Ok(message) = rx.recv() {
match message {
MsgType::Data(message) => {
for actual_logger in &actual_loggers {
Self::log_record(&message, actual_logger);
}
}
MsgType::Flush => {
for actual_logger in &actual_loggers {
actual_logger.flush();
}
}
MsgType::Shutdown => break,
};
}
})
.unwrap()
}
fn log_record(message: &SerializableLogRecord, actual_logger: &dyn Log) {
let mut builder = Record::builder();
actual_logger.log(&into_log_record!(builder, message));
}
fn send(&self, msg: MsgType) {
if let Err(e) = self.tx.send(msg) {
eprintln!("An internal error occurred in ParallelLogger: {e}");
}
}
}
impl Log for ParallelLogger {
    /// Returns `true` when the level in `metadata` is at or below the configured `log_level`.
    fn enabled(&self, metadata: &Metadata) -> bool {
        metadata.level() <= self.log_level
    }

    /// Queues `record` for the background logger thread(s), unless its level is filtered out.
    ///
    /// The `log` macros only compare against the global max level and never call `enabled`.
    /// `Log::log` can also be called directly, and the `log` crate documents that `enabled` is
    /// not necessarily called first. The level check is therefore done here, before paying for
    /// the conversion and the channel send.
    fn log(&self, record: &Record) {
        if self.enabled(record.metadata()) {
            self.send(MsgType::Data(SerializableLogRecord::from(record)));
        }
    }

    /// Queues a flush request.
    ///
    /// The actual loggers' `flush` runs on the background thread(s); this call only enqueues
    /// the request and does not wait for it.
    fn flush(&self) {
        self.send(MsgType::Flush);
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use log::{LevelFilter, Log, Metadata, Record};
    use std::{
        sync::mpsc::{Receiver, Sender},
        time::Duration,
    };

    /// Test logger that sends every enabled record over an `mpsc` channel, so the test can
    /// inspect exactly what reached it.
    struct ChannelLogger {
        level: LevelFilter,
        sender: Sender<SerializableLogRecord>,
    }

    impl ChannelLogger {
        pub fn new(level: LevelFilter, sender: Sender<SerializableLogRecord>) -> Box<Self> {
            Box::new(Self { level, sender })
        }
    }

    impl Log for ChannelLogger {
        fn enabled(&self, metadata: &Metadata) -> bool {
            metadata.level() <= self.level
        }

        fn log(&self, record: &Record) {
            if self.enabled(record.metadata()) {
                let msg = SerializableLogRecord::from(record);
                if self.sender.send(msg).is_err() {
                    eprintln!("Failed to send message through channel");
                }
            }
        }

        fn flush(&self) {}
    }

    /// Waits up to two seconds for the record emitted in `test_regular_log_message` and checks
    /// every forwarded field.
    #[track_caller]
    fn assert_test_message(rx: &Receiver<SerializableLogRecord>) {
        let msg = rx
            .recv_timeout(Duration::from_secs(2))
            .expect("the record should have been forwarded to this logger");
        assert_eq!(msg.level, "INFO");
        assert_eq!(msg.args, "Test message");
        assert_eq!(msg.module_path, Some("parallel_logger::test".into()));
        assert_eq!(msg.target, "parallel_logger::test");
        assert_eq!(msg.file, Some("src/lib.rs".to_owned()));
        assert!(msg.line.is_some());
    }

    // `log` accepts a global logger only once per process, so this is the only test that may
    // initialize successfully.
    #[test]
    fn test_regular_log_message() {
        let (tx, rx) = std::sync::mpsc::channel();
        let (tx2, rx2) = std::sync::mpsc::channel();
        let (tx3, rx3) = std::sync::mpsc::channel();
        let logger = ChannelLogger::new(LevelFilter::Info, tx);
        let logger2 = ChannelLogger::new(LevelFilter::Info, tx2);
        let logger3 = ChannelLogger::new(LevelFilter::Error, tx3);
        let shutdown_handle = ParallelLogger::init(
            LevelFilter::Info,
            ParallelMode::Sequential,
            vec![logger, logger2, logger3],
        );
        log::info!("Test message");
        assert_test_message(&rx);
        assert_test_message(&rx2);
        // Joining the single sequential thread guarantees the Error-level logger has already
        // been offered the record. Its channel can then be checked immediately instead of
        // waiting out a timeout.
        assert_eq!(shutdown_handle.shutdown(), 1);
        assert!(rx3.try_recv().is_err());
    }

    // Panics on the empty-input assertion, before any global logger is installed.
    #[test]
    #[should_panic(expected = "No actual loggers provided")]
    fn test_parallel_logger_no_actual_loggers() {
        ParallelLogger::init(LevelFilter::Info, ParallelMode::Sequential, vec![]);
    }
}