use crate::worker::Worker;
use crate::Msg;
use crossbeam_channel::{bounded, SendTimeoutError, Sender};
use std::io;
use std::io::Write;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::Duration;
use tracing_subscriber::fmt::MakeWriter;
/// Default capacity of the bounded channel that carries log lines to the worker.
///
/// `NonBlockingBuilder::default()` uses this value as its `buffered_lines_limit`,
/// which `NonBlocking::create` passes to `bounded(..)`.
pub const DEFAULT_BUFFERED_LINES_LIMIT: usize = 128_000;
/// Guard returned together with a [`NonBlocking`] writer.
///
/// Dropping the guard sends `Msg::Shutdown` over the message channel (with a
/// timeout) and, if that succeeds, performs a timed send on the `shutdown`
/// channel. The type is `#[must_use]`, so discarding it immediately triggers a
/// compiler warning.
#[must_use]
#[derive(Debug)]
pub struct WorkerGuard {
/// Join handle of the worker thread. It is stored but not joined anywhere in
/// the code visible in this file.
_guard: Option<JoinHandle<()>>,
/// Clone of the message channel sender; used on drop to deliver `Msg::Shutdown`.
sender: Sender<Msg>,
/// Sender half of the zero-capacity channel created in `NonBlocking::create`.
/// NOTE(review): presumably a successful send means the worker has reached its
/// final receive (i.e. finished draining) — confirm against `Worker`.
shutdown: Sender<()>,
}
/// Cloneable writer handle that forwards every written buffer as a `Msg::Line`
/// over a bounded channel instead of writing it directly.
///
/// Clones share the same channel and the same dropped-line counter.
#[derive(Clone, Debug)]
pub struct NonBlocking {
/// Counter incremented each time a lossy write fails to enqueue its line.
error_counter: ErrorCounter,
/// Sending half of the bounded channel that carries `Msg` values to the worker.
channel: Sender<Msg>,
/// When `true`, writes use `try_send` and count failures instead of blocking;
/// when `false`, writes use the blocking `send`.
is_lossy: bool,
}
/// Shared counter of lines dropped by a lossy [`NonBlocking`] writer.
///
/// Wraps an `Arc<AtomicUsize>`, so clones observe the same count. Increments
/// saturate at `usize::MAX` rather than wrapping.
#[derive(Clone, Debug)]
pub struct ErrorCounter(Arc<AtomicUsize>);
impl NonBlocking {
pub fn new<T: Write + Send + 'static>(writer: T) -> (NonBlocking, WorkerGuard) {
NonBlockingBuilder::default().finish(writer)
}
fn create<T: Write + Send + 'static>(
writer: T,
buffered_lines_limit: usize,
is_lossy: bool,
thread_name: String,
) -> (NonBlocking, WorkerGuard) {
let (sender, receiver) = bounded(buffered_lines_limit);
let (shutdown_sender, shutdown_receiver) = bounded(0);
let worker = Worker::new(receiver, writer, shutdown_receiver);
let worker_guard = WorkerGuard::new(
worker.worker_thread(thread_name),
sender.clone(),
shutdown_sender,
);
(
Self {
channel: sender,
error_counter: ErrorCounter(Arc::new(AtomicUsize::new(0))),
is_lossy,
},
worker_guard,
)
}
pub fn error_counter(&self) -> ErrorCounter {
self.error_counter.clone()
}
}
/// Builder that collects the settings passed to `NonBlocking::create`.
#[derive(Debug)]
pub struct NonBlockingBuilder {
/// Capacity handed to the bounded line channel.
buffered_lines_limit: usize,
/// Whether the resulting writer drops lines (`true`) or blocks (`false`) when
/// a line cannot be enqueued immediately.
is_lossy: bool,
/// Name forwarded to `Worker::worker_thread`.
thread_name: String,
}
impl NonBlockingBuilder {
    /// Sets the capacity of the bounded line channel and returns the builder.
    pub fn buffered_lines_limit(self, buffered_lines_limit: usize) -> Self {
        Self {
            buffered_lines_limit,
            ..self
        }
    }

    /// Chooses between lossy (`true`) and blocking (`false`) writes and returns
    /// the builder.
    pub fn lossy(self, is_lossy: bool) -> Self {
        Self { is_lossy, ..self }
    }

    /// Replaces the worker thread name with an owned copy of `name` and returns
    /// the builder.
    pub fn thread_name(self, name: &str) -> Self {
        Self {
            thread_name: name.to_owned(),
            ..self
        }
    }

    /// Consumes the builder and creates the writer/guard pair for `writer`.
    pub fn finish<T: Write + Send + 'static>(self, writer: T) -> (NonBlocking, WorkerGuard) {
        let Self {
            buffered_lines_limit,
            is_lossy,
            thread_name,
        } = self;
        NonBlocking::create(writer, buffered_lines_limit, is_lossy, thread_name)
    }
}
impl Default for NonBlockingBuilder {
    /// Default settings: lossy writes, a line channel of
    /// `DEFAULT_BUFFERED_LINES_LIMIT` entries and a worker thread named
    /// `"tracing-appender"`.
    fn default() -> Self {
        Self {
            thread_name: String::from("tracing-appender"),
            is_lossy: true,
            buffered_lines_limit: DEFAULT_BUFFERED_LINES_LIMIT,
        }
    }
}
impl std::io::Write for NonBlocking {
    /// Copies `buf` into a `Msg::Line` and hands it to the channel.
    ///
    /// In lossy mode the line is enqueued with `try_send`; on failure the
    /// dropped-line counter is incremented and the call still reports the full
    /// buffer length as written. In non-lossy mode the blocking `send` is used
    /// and a send failure surfaces as an `io::ErrorKind::Other` error.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let written = buf.len();
        let line = Msg::Line(buf.to_vec());

        if !self.is_lossy {
            return self
                .channel
                .send(line)
                .map(|_| written)
                .map_err(|_| io::Error::from(io::ErrorKind::Other));
        }

        match self.channel.try_send(line) {
            Ok(_) => {}
            // The line is discarded here; only the counter records the loss.
            Err(_) => self.error_counter.incr_saturating(),
        }
        Ok(written)
    }

    /// No-op: this handle holds no buffer of its own to flush.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }

    /// Delegates to a single `write` call, so the whole buffer travels as one
    /// message.
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        self.write(buf)?;
        Ok(())
    }
}
impl<'a> MakeWriter<'a> for NonBlocking {
type Writer = NonBlocking;
fn make_writer(&'a self) -> Self::Writer {
self.clone()
}
}
impl WorkerGuard {
fn new(handle: JoinHandle<()>, sender: Sender<Msg>, shutdown: Sender<()>) -> Self {
WorkerGuard {
_guard: Some(handle),
sender,
shutdown,
}
}
}
impl Drop for WorkerGuard {
    /// Signals the worker to shut down without blocking the dropping thread
    /// indefinitely.
    ///
    /// First `Msg::Shutdown` is sent on the message channel with a 100 ms
    /// timeout. If that succeeds, a second timed send (1000 ms) is made on the
    /// zero-capacity `shutdown` channel; its result is deliberately ignored so
    /// that drop never fails. A disconnected message channel is treated as
    /// "nothing left to do".
    fn drop(&mut self) {
        let signal_timeout = Duration::from_millis(100);
        let flush_timeout = Duration::from_millis(1000);

        match self.sender.send_timeout(Msg::Shutdown, signal_timeout) {
            Ok(_) => {
                // Best effort: a timeout or disconnect here is not actionable.
                let _ = self.shutdown.send_timeout((), flush_timeout);
            }
            Err(SendTimeoutError::Disconnected(_)) => (),
            // Report on stderr rather than stdout: this writer frequently wraps
            // stdout, and a diagnostic printed there would be interleaved with
            // (and could corrupt) the program's regular output.
            Err(SendTimeoutError::Timeout(e)) => eprintln!(
                "Failed to send shutdown signal to logging worker. Error: {:?}",
                e
            ),
        }
    }
}
impl ErrorCounter {
pub fn dropped_lines(&self) -> usize {
self.0.load(Ordering::Acquire)
}
fn incr_saturating(&self) {
let mut curr = self.0.load(Ordering::Acquire);
if curr == usize::MAX {
return;
}
loop {
let val = curr.saturating_add(1);
match self
.0
.compare_exchange(curr, val, Ordering::AcqRel, Ordering::Acquire)
{
Ok(_) => return,
Err(actual) => curr = actual,
}
}
}
}
// Unit tests for the non-blocking writer. They drive a real worker through
// `NonBlockingBuilder::finish` and observe its output via `MockWriter`.
#[cfg(test)]
mod test {
use super::*;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
// Writer that forwards every buffer it receives, converted lossily to a
// `String`, over a bounded std channel so a test can read what was written.
struct MockWriter {
tx: mpsc::SyncSender<String>,
}
impl MockWriter {
// Creates the writer plus the receiving end; `capacity` bounds how many
// strings may be queued before `write` blocks on `send`.
fn new(capacity: usize) -> (Self, mpsc::Receiver<String>) {
let (tx, rx) = mpsc::sync_channel(capacity);
(Self { tx }, rx)
}
}
impl std::io::Write for MockWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let buf_len = buf.len();
// A failed send (receiver gone) is ignored; the write still reports success.
let _ = self.tx.send(String::from_utf8_lossy(buf).to_string());
Ok(buf_len)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
// Non-lossy writer with a line limit of 1: lines must never be counted as
// dropped, and both lines must arrive in order.
#[test]
fn backpressure_exerted() {
let (mock_writer, rx) = MockWriter::new(1);
let (mut non_blocking, _guard) = self::NonBlockingBuilder::default()
.lossy(false)
.buffered_lines_limit(1)
.finish(mock_writer);
let error_count = non_blocking.error_counter();
non_blocking.write_all(b"Hello").expect("Failed to write");
assert_eq!(0, error_count.dropped_lines());
// The second write happens on its own thread so that it may block without
// stalling the test.
let handle = thread::spawn(move || {
non_blocking.write_all(b", World").expect("Failed to write");
});
// Give the spawned thread time to reach its write.
thread::sleep(Duration::from_millis(100));
// Non-lossy mode never increments the dropped-line counter.
assert_eq!(0, error_count.dropped_lines());
// Reading the first line frees a slot in the mock writer's channel.
let mut line = rx.recv().unwrap();
assert_eq!(line, "Hello");
handle.join().expect("thread should not panic");
line = rx.recv().unwrap();
assert_eq!(line, ", World");
}
// Writes `msg` and then sleeps for 200 ms so the worker thread has time to
// pick the line up before the caller continues.
fn write_non_blocking(non_blocking: &mut NonBlocking, msg: &[u8]) {
non_blocking.write_all(msg).expect("Failed to write");
thread::sleep(Duration::from_millis(200));
}
// Lossy writer with a line limit of 1: once the pipeline is full, a further
// write must be counted as dropped instead of blocking.
// NOTE(review): ignored by default — presumably because the sleep-based timing
// makes it flaky; confirm before re-enabling.
#[test]
#[ignore] fn logs_dropped_if_lossy() {
let (mock_writer, rx) = MockWriter::new(1);
let (mut non_blocking, _guard) = self::NonBlockingBuilder::default()
.lossy(true)
.buffered_lines_limit(1)
.finish(mock_writer);
let error_count = non_blocking.error_counter();
write_non_blocking(&mut non_blocking, b"Hello");
assert_eq!(0, error_count.dropped_lines());
write_non_blocking(&mut non_blocking, b", World");
assert_eq!(0, error_count.dropped_lines());
write_non_blocking(&mut non_blocking, b"Test");
assert_eq!(0, error_count.dropped_lines());
let line = rx.recv().unwrap();
assert_eq!(line, "Hello");
// The test expects exactly one drop to have been recorded after this write.
write_non_blocking(&mut non_blocking, b"Universe");
assert_eq!(1, error_count.dropped_lines());
let line = rx.recv().unwrap();
assert_eq!(line, ", World");
assert_eq!(1, error_count.dropped_lines());
}
// Ten threads each emit one event through a clone of the same writer; all ten
// events must arrive and none may be counted as dropped.
#[test]
fn multi_threaded_writes() {
let (mock_writer, rx) = MockWriter::new(DEFAULT_BUFFERED_LINES_LIMIT);
let (non_blocking, _guard) = self::NonBlockingBuilder::default()
.lossy(true)
.finish(mock_writer);
let error_count = non_blocking.error_counter();
let mut join_handles: Vec<JoinHandle<()>> = Vec::with_capacity(10);
for _ in 0..10 {
let cloned_non_blocking = non_blocking.clone();
join_handles.push(thread::spawn(move || {
// Each thread installs its own subscriber for the duration of the closure.
let subscriber = tracing_subscriber::fmt().with_writer(cloned_non_blocking);
tracing::subscriber::with_default(subscriber.finish(), || {
tracing::event!(tracing::Level::INFO, "Hello");
});
}));
}
for handle in join_handles {
handle.join().expect("Failed to join thread");
}
let mut hello_count: u8 = 0;
// Drain until no further line arrives within 5 seconds.
while let Ok(event_str) = rx.recv_timeout(Duration::from_secs(5)) {
assert!(event_str.contains("Hello"));
hello_count += 1;
}
assert_eq!(10, hello_count);
assert_eq!(0, error_count.dropped_lines());
}
}