use crate::clocks::{Clock, SystemClock};
use crate::{Connection, Db, Result};
use gethostname::gethostname;
use log::{Level, Log, Metadata, Record};
use std::env;
use std::str::FromStr;
use std::sync::mpsc::{self, RecvTimeoutError};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use time::OffsetDateTime;
/// Capacity of the synchronous channel carrying `Action`s to the recorder task; senders block
/// while it is full.
const CHANNEL_SIZE: usize = 128;
/// Number of buffered entries that makes the recorder hand them to a writer task as one batch.
const MAX_BATCH_SIZE: usize = 128;
/// Delay, in seconds, used by the recorder to decide when to flush buffered entries on its own.
const MAX_FLUSH_DELAY_SECS: u64 = 5;
/// Maximum log level used when `RUST_LOG` is unset or cannot be parsed as a level.
const DEFAULT_LOG_LEVEL: Level = Level::Warn;
// NOTE(review): the `LOG_ENTRY_MAX_*` limits below are not used in this module's visible code;
// presumably database backends use them to bound the corresponding `LogEntry` fields — confirm.
/// Presumed maximum length of `LogEntry::hostname` as stored by backends.
pub(crate) const LOG_ENTRY_MAX_HOSTNAME_LENGTH: usize = 64;
/// Presumed maximum length of `LogEntry::module` as stored by backends.
pub(crate) const LOG_ENTRY_MAX_MODULE_LENGTH: usize = 64;
/// Presumed maximum length of `LogEntry::filename` as stored by backends.
pub(crate) const LOG_ENTRY_MAX_FILENAME_LENGTH: usize = 256;
/// Presumed maximum length of `LogEntry::message` as stored by backends.
pub(crate) const LOG_ENTRY_MAX_MESSAGE_LENGTH: usize = 1024;
/// A single log record in the form handed to the database's `put_log_entries`.
#[derive(Debug)]
pub(crate) struct LogEntry<'a, 'b> {
    /// Time at which the record was logged, as reported by the logger's clock.
    pub(crate) timestamp: OffsetDateTime,
    /// Name of the host that emitted the record.
    pub(crate) hostname: String,
    /// Severity of the record.
    pub(crate) level: Level,
    /// Module path of the code that emitted the record, if known.
    pub(crate) module: Option<&'a str>,
    /// Source file of the code that emitted the record, if known.
    pub(crate) filename: Option<&'b str>,
    /// Source line of the code that emitted the record, if known.
    pub(crate) line: Option<u32>,
    /// The record's formatted message.
    pub(crate) message: String,
}
/// Log entry whose borrowed fields are `'static`, as built by `DbLogger` from
/// `Record::module_path_static` and `Record::file_static`, so it can be moved to the recorder.
type StaticLogEntry = LogEntry<'static, 'static>;
/// Requests sent from `DbLogger` and `Handle` to the recorder task.
#[derive(Debug)]
enum Action {
    /// Makes the recorder leave its loop; it acknowledges on the done channel before exiting.
    Stop,
    /// Makes the recorder write buffered entries, wait for all pending writes, and then
    /// acknowledge on the done channel.
    Flush,
    /// An entry to buffer for a later batched write.
    Record(StaticLogEntry),
}
/// Stores `entries` with a single `put_log_entries` call.
///
/// This runs inside detached writer tasks, so there is nobody to return an error to; failures
/// are reported on stderr instead.
async fn write_all(db: Arc<dyn Db + Send + Sync + 'static>, entries: Vec<StaticLogEntry>) {
    let outcome = db.put_log_entries(entries).await;
    if let Some(err) = outcome.err() {
        eprintln!("Failed to write log entries: {}", err);
    }
}
async fn recorder(
db: Arc<dyn Db + Send + Sync + 'static>,
action_rx: mpsc::Receiver<Action>,
done_tx: mpsc::SyncSender<()>,
) {
let mut buffer = vec![];
let mut writers = vec![];
let timeout = Duration::new(MAX_FLUSH_DELAY_SECS, 0);
loop {
let auto_flush;
let action = match action_rx.recv_timeout(timeout) {
Ok(action) => {
auto_flush = false;
action
}
Err(RecvTimeoutError::Timeout) => {
auto_flush = true;
Action::Flush
}
Err(RecvTimeoutError::Disconnected) => {
eprintln!("Failed to get log entry due to closed channel; terminating logger");
break;
}
};
match action {
Action::Stop => break,
Action::Flush => {
if !buffer.is_empty() {
let batch = buffer.split_off(0);
let db = db.clone();
writers.push(tokio::spawn(async move { write_all(db, batch).await }));
}
assert!(buffer.is_empty());
for writer in writers.split_off(0) {
if let Err(e) = writer.await {
eprintln!("Failed to write batched entries: {}", e);
}
}
assert!(writers.is_empty());
if !auto_flush {
done_tx.send(()).unwrap();
}
}
Action::Record(entry) => {
buffer.push(entry);
if buffer.len() == MAX_BATCH_SIZE {
let batch = buffer.split_off(0);
let db = db.clone();
writers.push(tokio::spawn(async move { write_all(db, batch).await }));
assert!(buffer.is_empty());
}
}
}
}
drop(db);
done_tx.send(()).unwrap();
}
/// Returns true if `record` must not be persisted through the database logger.
///
/// `DbLogger::log` only echoes such records to stderr. This covers:
/// - records without a static module path;
/// - records from modules starting with `rustls::` or `sqlx::`;
/// - trace-level records from modules starting with `async_io::`, `async_std::` or `polling`.
///
/// The exclusions are presumably there because these crates run while entries are being
/// written, so persisting their records could feed back into the logger — TODO confirm.
fn is_recorder_log(record: &Record) -> bool {
    match record.module_path_static() {
        None => true,
        Some(module) => {
            let always_skipped = ["rustls::", "sqlx::"];
            let skipped_at_trace = ["async_io::", "async_std::", "polling"];
            always_skipped.iter().any(|&prefix| module.starts_with(prefix))
                || (record.level() >= Level::Trace
                    && skipped_at_trace.iter().any(|&prefix| module.starts_with(prefix)))
        }
    }
}
/// Determines the maximum log level from the `RUST_LOG` environment variable.
///
/// Returns `DEFAULT_LOG_LEVEL` when the variable is unset. If it is set but is not valid
/// Unicode or not a single level name accepted by `Level::from_str`, a diagnostic is printed
/// to stderr and `DEFAULT_LOG_LEVEL` is returned.
fn env_rust_log() -> Level {
    let parsed = match env::var("RUST_LOG") {
        Err(env::VarError::NotPresent) => return DEFAULT_LOG_LEVEL,
        Err(var_err) => Err(var_err.to_string()),
        Ok(raw) => Level::from_str(&raw).map_err(|parse_err| parse_err.to_string()),
    };
    parsed.unwrap_or_else(|reason| {
        eprintln!("Invalid RUST_LOG value: {}", reason);
        DEFAULT_LOG_LEVEL
    })
}
/// Handle to the database logger, returned by `init`.
///
/// Gives read access to the persisted entries. Dropping it flushes pending entries and stops
/// the background recorder.
pub struct Handle {
    /// Database connection, used to read back persisted log entries.
    db: Connection,
    /// Sending side of the recorder's action channel (a clone of the installed logger's).
    action_tx: mpsc::SyncSender<Action>,
    /// Receives the recorder's acknowledgements. Requesters lock the mutex around their
    /// send/receive pair, so each one gets its own acknowledgement.
    done_rx: Arc<Mutex<mpsc::Receiver<()>>>,
}
impl Handle {
    /// Reads back every log entry persisted so far, in the string form produced by the
    /// database's `get_log_entries`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the database query reports.
    pub async fn get_log_entries(&self) -> Result<Vec<String>> {
        let backend = &self.db.0;
        backend.get_log_entries().await
    }
}
impl Drop for Handle {
    /// Flushes pending log entries and stops the recorder, blocking until the recorder
    /// acknowledges each request.
    ///
    /// This is best-effort. If the recorder has already terminated (for example because the
    /// runtime running it was shut down), the remaining steps are skipped with a note on stderr
    /// instead of panicking. A panic inside `drop` while the thread is already unwinding would
    /// abort the process.
    fn drop(&mut self) {
        // The mutex only guards a receiver, which a panicking holder cannot leave in an
        // inconsistent state, so a poisoned lock is safe to reuse.
        let done_rx = self.done_rx.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        for action in [Action::Flush, Action::Stop] {
            if self.action_tx.send(action).is_err() || done_rx.recv().is_err() {
                eprintln!(
                    "Logger recorder terminated early; some log entries may not have been persisted"
                );
                return;
            }
        }
    }
}
/// `log::Log` implementation that forwards records to the recorder task for persistence.
struct DbLogger {
    /// Hostname attached to every entry.
    hostname: String,
    /// Sending side of the recorder's action channel.
    action_tx: mpsc::SyncSender<Action>,
    /// Receives the recorder's acknowledgements; shared with the `Handle` returned by `init`.
    done_rx: Arc<Mutex<mpsc::Receiver<()>>>,
    /// Source of the timestamps attached to entries.
    clock: Arc<dyn Clock + Send + Sync + 'static>,
}
impl DbLogger {
    /// Creates a logger that tags entries with `hostname` and timestamps them using `clock`.
    ///
    /// Spawns the background `recorder` task with `tokio::spawn`, so it must run inside a
    /// Tokio runtime. The recorder receives this logger's actions and writes entries to `db`.
    async fn new(
        hostname: String,
        db: Connection,
        clock: Arc<dyn Clock + Send + Sync + 'static>,
    ) -> Self {
        // Acknowledgements are consumed one at a time by whoever holds the `done_rx` lock.
        let (done_tx, done_rx) = mpsc::sync_channel(1);
        let (action_tx, action_rx) = mpsc::sync_channel(CHANNEL_SIZE);
        tokio::spawn(async move { recorder(db.0, action_rx, done_tx).await });
        Self { hostname, action_tx, done_rx: Arc::new(Mutex::new(done_rx)), clock }
    }
}
impl Log for DbLogger {
fn enabled(&self, _metadata: &Metadata) -> bool {
true
}
fn log(&self, record: &Record) {
if !self.enabled(record.metadata()) {
return;
}
let now = self.clock.now_utc();
if is_recorder_log(record) {
eprintln!(
"Non-persisted log entry: {:?} {} {:?} {:?}:{:?} {}",
now,
record.level(),
record.module_path_static(),
record.file_static(),
record.line(),
record.args(),
);
return;
}
let entry = StaticLogEntry {
timestamp: now,
hostname: self.hostname.clone(),
level: record.level(),
module: record.module_path_static(),
filename: record.file_static(),
line: record.line(),
message: format!("{}", record.args()),
};
self.action_tx.send(Action::Record(entry)).unwrap();
}
fn flush(&self) {
let done_rx = self.done_rx.lock().unwrap();
self.action_tx.send(Action::Flush).unwrap();
done_rx.recv().unwrap();
}
}
pub async fn init(db: Connection) -> Handle {
let max_level = env_rust_log();
let hostname =
gethostname().into_string().unwrap_or_else(|_e| String::from("invalid-hostname"));
let logger = DbLogger::new(hostname, db.clone(), Arc::from(SystemClock::default())).await;
let handle =
Handle { db, action_tx: logger.action_tx.clone(), done_rx: logger.done_rx.clone() };
log::set_boxed_logger(Box::from(logger)).expect("Logger should not have been set up yet");
log::set_max_level(max_level.to_level_filter());
handle
}
// Tests for the database logger, run against an in-memory SQLite database.
#[cfg(test)]
#[cfg(feature = "sqlite")]
mod tests {
    use super::*;
    use crate::clocks::MonotonicClock;
    use crate::sqlite;
    use log::RecordBuilder;

    /// Creates a logger backed by a fresh in-memory SQLite database with its schema in place.
    ///
    /// Returns the logger together with the connection so tests can read back what was
    /// persisted. The logger uses the hostname "fake-hostname" and a `MonotonicClock` created
    /// with 1000.
    async fn setup() -> (DbLogger, Connection) {
        let db = sqlite::connect(sqlite::ConnectionOptions { uri: ":memory:".to_owned() })
            .await
            .unwrap();
        db.create_schema().await.unwrap();
        let clock = Arc::from(MonotonicClock::new(1000));
        (DbLogger::new("fake-hostname".to_owned(), db.clone(), clock).await, db)
    }

    /// Logs one record per level, from `Error` down to `Trace`, all attributed to
    /// "the-module" at "the-file":123.
    fn emit_all_log_levels(logger: &dyn Log) {
        for (level, message) in &[
            (Level::Error, "An error message"),
            (Level::Warn, "A warning message"),
            (Level::Info, "An info message"),
            (Level::Debug, "A debug message"),
            (Level::Trace, "A trace message"),
        ] {
            logger.log(
                &RecordBuilder::new()
                    .level(*level)
                    .module_path_static(Some("the-module"))
                    .file_static(Some("the-file"))
                    .line(Some(123))
                    .args(format_args!("{}", message))
                    .build(),
            );
        }
    }

    /// Checks that records of every level are persisted once flushed (the logger itself
    /// filters nothing) and read back in the backend's string format.
    ///
    /// NOTE(review): uses a multi-threaded runtime, presumably because the recorder blocks a
    /// worker thread while waiting for actions — confirm.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_all_log_levels() {
        let (logger, db) = setup().await;
        emit_all_log_levels(&logger);
        // Blocks until the recorder has written the buffered entries.
        logger.flush();
        let entries = db.0.get_log_entries().await.unwrap();
        assert_eq!(
            vec![
                "1000.0 fake-hostname 1 the-module the-file:123 An error message".to_owned(),
                "1001.0 fake-hostname 2 the-module the-file:123 A warning message".to_owned(),
                "1002.0 fake-hostname 3 the-module the-file:123 An info message".to_owned(),
                "1003.0 fake-hostname 4 the-module the-file:123 A debug message".to_owned(),
                "1004.0 fake-hostname 5 the-module the-file:123 A trace message".to_owned(),
            ],
            entries
        );
    }
}