use std::{
sync::mpsc::{Receiver, Sender},
thread::{self, JoinHandle},
};
use log::{Level, LevelFilter, Log, Metadata, Record};
/// Owned copy of the parts of a [`Record`] that the worker thread needs.
///
/// Every field is an owned value so the message can be sent over the channel to the worker
/// thread and rebuilt into a [`Record`] there.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RecordMsg {
    /// Level of the original record.
    level: Level,
    /// Message text, already formatted from the record's arguments.
    args: String,
    /// Module path of the call site, if the record carried one.
    module_path: Option<String>,
    /// Target of the original record.
    target: String,
    /// Source file of the call site, if the record carried one.
    file: Option<String>,
    /// Source line of the call site, if the record carried one.
    line: Option<u32>,
}

impl RecordMsg {
    /// Creates a `RecordMsg` from its already-owned parts.
    const fn new(
        level: Level,
        args: String,
        module_path: Option<String>,
        target: String,
        file: Option<String>,
        line: Option<u32>,
    ) -> Self {
        Self {
            level,
            args,
            module_path,
            target,
            file,
            line,
        }
    }
}
/// Commands sent from [`ParallelLogger`] to its worker thread.
enum MsgType {
    /// Ends the worker's receive loop.
    Shutdown,
    /// Asks the worker to call `flush` on every wrapped logger.
    Flush,
    /// A record to forward to every wrapped logger.
    Data(RecordMsg),
}
/// A logger that hands records off to a background worker thread.
///
/// Records are copied into owned messages and sent over a channel; the worker thread forwards
/// them to the wrapped loggers. Install it with [`ParallelLogger::init`].
#[derive(Debug)]
pub struct ParallelLogger {
    /// Sending half of the channel read by the worker thread.
    tx: Sender<MsgType>,
    /// Maximum level this logger accepts; compared against in `enabled`.
    log_level: LevelFilter,
    /// Handle to the worker thread; taken and joined when the logger is dropped.
    join_handle: Option<JoinHandle<()>>,
}
impl ParallelLogger {
    /// Installs a `ParallelLogger` as the global logger of the [`log`] facade.
    ///
    /// Spawns a worker thread that receives every record accepted by this logger and forwards
    /// it, in order, to each of `actual_loggers`. The global maximum level is set to
    /// `log_level`.
    ///
    /// # Panics
    ///
    /// Panics if `actual_loggers` is empty, or if a global logger has already been installed.
    pub fn init(log_level: LevelFilter, actual_loggers: Vec<Box<dyn Log>>) {
        assert!(
            !actual_loggers.is_empty(),
            "Failed to initialize ParallelLogger: No actual loggers provided"
        );
        let (tx, rx) = std::sync::mpsc::channel();
        let join_handle = Self::start_thread(rx, actual_loggers);
        let tpl = Self {
            tx,
            log_level,
            join_handle: Some(join_handle),
        };
        log::set_boxed_logger(Box::new(tpl))
            .expect("Failed to initialize ParallelLogger: a global logger is already set");
        log::set_max_level(log_level);
    }

    /// Spawns the worker thread that drains `rx` and dispatches to `actual_loggers`.
    ///
    /// The worker stops when it receives `MsgType::Shutdown` or when every sender has been
    /// dropped (`recv` returns `Err`). Either way, it flushes all wrapped loggers one last time
    /// before exiting so that buffered output is not lost.
    fn start_thread(rx: Receiver<MsgType>, actual_loggers: Vec<Box<dyn Log>>) -> JoinHandle<()> {
        thread::spawn(move || {
            while let Ok(message) = rx.recv() {
                match message {
                    MsgType::Data(record) => {
                        for actual_logger in &actual_loggers {
                            Self::log_record(&record, actual_logger);
                        }
                    }
                    MsgType::Flush => Self::flush_all(&actual_loggers),
                    MsgType::Shutdown => break,
                }
            }
            Self::flush_all(&actual_loggers);
        })
    }

    /// Calls `flush` on every wrapped logger, in order.
    fn flush_all(actual_loggers: &[Box<dyn Log>]) {
        for actual_logger in actual_loggers {
            actual_logger.flush();
        }
    }

    /// Rebuilds a [`Record`] that borrows from `message` and passes it to `actual_logger`.
    fn log_record(message: &RecordMsg, actual_logger: &dyn Log) {
        actual_logger.log(
            &Record::builder()
                .level(message.level)
                .args(format_args!("{}", message.args))
                .module_path(message.module_path.as_deref())
                .target(message.target.as_str())
                .file(message.file.as_deref())
                .line(message.line)
                .build(),
        );
    }

    /// Sends `msg` to the worker thread.
    ///
    /// A failed send is reported on stderr and otherwise ignored, so logging never panics the
    /// calling thread.
    fn send(&self, msg: MsgType) {
        if let Err(e) = self.tx.send(msg) {
            eprintln!("An internal error occurred in ParallelLogger: {e}");
        }
    }

    /// Copies the fields of `record` into an owned [`RecordMsg`].
    ///
    /// The message arguments are formatted to a `String` here, on the calling thread.
    fn convert_msg(record: &Record) -> RecordMsg {
        RecordMsg::new(
            record.level(),
            record.args().to_string(),
            record.module_path().map(str::to_owned),
            record.target().to_owned(),
            record.file().map(str::to_owned),
            record.line(),
        )
    }
}
impl Log for ParallelLogger {
    /// Returns `true` when `metadata`'s level is at or below the level given to `init`.
    fn enabled(&self, metadata: &Metadata) -> bool {
        metadata.level() <= self.log_level
    }

    /// Queues `record` for the worker thread if it passes [`Log::enabled`].
    ///
    /// `enabled` is not guaranteed to be consulted before `log` is called. Examples are direct
    /// `log::logger().log(..)` calls, or the global max level being raised later. Filtering
    /// here avoids formatting and shipping records this logger should discard.
    fn log(&self, record: &Record) {
        if self.enabled(record.metadata()) {
            self.send(MsgType::Data(Self::convert_msg(record)));
        }
    }

    /// Queues a flush request for the worker thread.
    ///
    /// This returns as soon as the request is enqueued. It does not wait for the wrapped
    /// loggers to finish flushing.
    fn flush(&self) {
        self.send(MsgType::Flush);
    }
}
impl Drop for ParallelLogger {
    /// Asks the worker thread to stop, then waits for it to finish.
    ///
    /// If the worker thread panicked, its panic message is reported on stderr.
    fn drop(&mut self) {
        self.send(MsgType::Shutdown);
        if let Some(join_handle) = self.join_handle.take() {
            if let Err(payload) = join_handle.join() {
                // Panic payloads are normally `&'static str` or `String`. Formatting the boxed
                // `Any` with `Debug` would only print `Any { .. }`.
                let reason = payload
                    .downcast_ref::<&str>()
                    .copied()
                    .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
                    .unwrap_or("unknown panic payload");
                eprintln!(
                    "An internal error occurred while shutting down ParallelLogger: {reason}"
                );
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;

    /// Test logger that sends every record it accepts, as a `RecordMsg`, over a channel so the
    /// test can inspect what the worker thread delivered.
    struct ChannelLogger {
        level: LevelFilter,
        sender: Sender<RecordMsg>,
    }

    impl ChannelLogger {
        pub fn new(level: LevelFilter, sender: Sender<RecordMsg>) -> Box<Self> {
            Box::new(Self { level, sender })
        }
    }

    impl Log for ChannelLogger {
        fn enabled(&self, metadata: &Metadata) -> bool {
            metadata.level() <= self.level
        }

        fn log(&self, record: &Record) {
            if self.enabled(record.metadata()) {
                let msg = ParallelLogger::convert_msg(record);
                if self.sender.send(msg).is_err() {
                    eprintln!("Failed to send message through channel");
                }
            }
        }

        fn flush(&self) {}
    }

    #[test]
    fn test_regular_log_message() {
        let (tx, rx) = std::sync::mpsc::channel();
        let (tx2, rx2) = std::sync::mpsc::channel();
        let logger = ChannelLogger::new(LevelFilter::Info, tx);
        let logger2 = ChannelLogger::new(LevelFilter::Error, tx2);
        ParallelLogger::init(LevelFilter::Info, vec![logger, logger2]);
        log::info!("Test message");

        let msg = rx
            .recv_timeout(Duration::from_secs(2))
            .expect("the Info-level logger should receive the record");
        assert_eq!(msg.level, Level::Info);
        assert_eq!(msg.args, "Test message");
        assert_eq!(msg.module_path, Some("parallel_logger::test".into()));
        assert_eq!(msg.target, "parallel_logger::test");
        assert_eq!(msg.file, Some("src/lib.rs".to_owned()));
        assert!(msg.line.is_some());

        // The Error-level logger must filter the Info record out.
        assert!(rx2.recv_timeout(Duration::from_secs(2)).is_err());
    }

    #[test]
    #[should_panic(expected = "No actual loggers provided")]
    fn test_parallel_logger_no_actual_loggers() {
        ParallelLogger::init(LevelFilter::Info, vec![]);
    }
}