#![forbid(unsafe_code)]
#![doc(
html_root_url = "https://docs.rs/pyo3-log/0.2.1/pyo3-log/",
test(attr(deny(warnings))),
test(attr(allow(unknown_lints, non_local_definitions)))
)]
#![warn(missing_docs)]
use std::cmp;
use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use log::{Level, LevelFilter, Log, Metadata, Record, SetLoggerError};
use pyo3::prelude::*;
use pyo3::types::PyTuple;
#[derive(Clone, Debug)]
pub struct ResetHandle(Arc<ArcSwap<CacheNode>>);
impl ResetHandle {
pub fn reset(&self) {
self.0.store(Default::default());
}
}
/// How much Python-side state a [`Logger`] remembers between records.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum Caching {
    /// Cache nothing: the Python logger is fetched with `logging.getLogger` and its
    /// level queried for every record.
    Nothing,
    /// Cache the Python logger object per target, but still ask it on every record
    /// whether the record's level is enabled.
    Loggers,
    /// Cache the Python logger and the most verbose level it accepted (the default).
    ///
    /// Records below the cached level are dropped on the Rust side without calling into
    /// Python, so making a Python logger more verbose has no effect until the cache is
    /// reset through a [`ResetHandle`].
    #[default]
    LoggersAndLevels,
}
/// A Python logger remembered for one target, together with its cached level filter.
#[derive(Debug)]
struct CacheEntry {
    /// Most verbose level the Python logger accepted when it was cached
    /// ([`LevelFilter::max`] when levels are not cached or could not be read).
    filter: LevelFilter,
    /// The Python logger object for the target.
    logger: Py<PyAny>,
}

impl CacheEntry {
    /// Duplicates the entry; the copy holds a new reference to the same Python logger.
    fn clone_ref(&self, py: Python<'_>) -> Self {
        let logger = self.logger.clone_ref(py);
        Self {
            filter: self.filter,
            logger,
        }
    }
}
/// One node of the cache tree; children are keyed by `::`-separated target segments.
///
/// Nodes are not modified after creation: storing an entry builds fresh copies of the
/// nodes along the path, while untouched subtrees are shared through their `Arc`s.
#[derive(Debug, Default)]
struct CacheNode {
    /// Entry for the target that ends at this node, if one has been cached.
    local: Option<CacheEntry>,
    /// Child nodes, one per following target segment.
    children: HashMap<String, Arc<CacheNode>>,
}

impl CacheNode {
    /// Returns a copy of this node with `entry` stored at the end of `path`.
    ///
    /// Intermediate nodes that do not exist yet are created empty.
    fn store_to_cache_recursive<'a, P>(
        &self,
        py: Python<'_>,
        mut path: P,
        entry: CacheEntry,
    ) -> Arc<Self>
    where
        P: Iterator<Item = &'a str>,
    {
        let mut children = self.children.clone();
        let mut local = self.local.as_ref().map(|existing| existing.clone_ref(py));

        if let Some(segment) = path.next() {
            // Rebuild the child on this path; siblings stay shared.
            let slot = children.entry(segment.to_owned()).or_default();
            *slot = slot.store_to_cache_recursive(py, path, entry);
        } else {
            // End of the path: this node is the target itself.
            local = Some(entry);
        }

        Arc::new(Self { local, children })
    }
}
/// A [`log`] backend that forwards Rust log records to Python's `logging` module.
///
/// Build one with [`Logger::new`] (or [`Logger::default`]), configure it with the
/// builder-style methods and activate it with [`Logger::install`].
#[derive(Debug)]
pub struct Logger {
    /// Default Rust-side filter, used for targets that have no entry in `filters`.
    top_filter: LevelFilter,
    /// Per-target Rust-side filters keyed by `::`-separated target; the longest
    /// `::`-delimited prefix of a target that has an entry wins.
    filters: HashMap<String, LevelFilter>,
    /// Optional prefix for Python logger names, already converted to `.` separators.
    prefix: Option<String>,
    /// The imported Python `logging` module.
    logging: Py<PyModule>,
    /// Which Python-side state is cached between records.
    caching: Caching,
    /// Cache of Python loggers (and possibly levels), shared with every `ResetHandle`.
    cache: Arc<ArcSwap<CacheNode>>,
}
impl Logger {
pub fn new(py: Python<'_>, caching: Caching) -> PyResult<Self> {
let logging = py.import("logging")?;
Ok(Self {
top_filter: LevelFilter::Debug,
filters: HashMap::new(),
prefix: None,
logging: logging.into(),
caching,
cache: Default::default(),
})
}
pub fn install(self) -> Result<ResetHandle, SetLoggerError> {
let handle = self.reset_handle();
let level = cmp::max(
self.top_filter,
self.filters
.values()
.copied()
.max()
.unwrap_or(LevelFilter::Off),
);
log::set_boxed_logger(Box::new(self))?;
log::set_max_level(level);
Ok(handle)
}
pub fn reset_handle(&self) -> ResetHandle {
ResetHandle(Arc::clone(&self.cache))
}
pub fn filter(mut self, filter: LevelFilter) -> Self {
self.top_filter = filter;
self
}
pub fn filter_target(mut self, target: String, filter: LevelFilter) -> Self {
self.filters.insert(target, filter);
self
}
pub fn set_prefix(mut self, prefix: &str) -> Self {
self.prefix = Some(prefix.replace("::", "."));
self
}
fn lookup(&self, target: &str) -> Option<Arc<CacheNode>> {
if self.caching == Caching::Nothing {
return None;
}
let root = self.cache.load();
let mut node: &Arc<CacheNode> = &root;
for segment in target.split("::") {
match node.children.get(segment) {
Some(sub) => node = sub,
None => return None,
}
}
Some(Arc::clone(node))
}
fn log_inner(
&self,
py: Python<'_>,
record: &Record,
cache: &Option<Arc<CacheNode>>,
) -> PyResult<Option<Py<PyAny>>> {
let msg = format!("{}", record.args());
let log_level = map_level(record.level());
let mut target = record.target().replace("::", ".");
target = match &self.prefix {
Some(prefix) => format!("{}.{}", prefix, target),
None => target,
};
let cached_logger = cache
.as_ref()
.and_then(|node| node.local.as_ref())
.map(|local| &local.logger);
let (logger, cached) = match cached_logger {
Some(cached) => (cached.bind(py).clone(), true),
None => (
self.logging
.bind(py)
.getattr("getLogger")?
.call1((&target,))?,
false,
),
};
if is_enabled_for(&logger, record.level())? {
let none = py.None();
#[allow(unused_mut)]
let mut extra = py.None().into_bound(py);
#[cfg(feature = "kv")]
if record.key_values().count() > 0 {
use log::kv::{Key, Value, VisitSource};
use pyo3::types::{PyDict, PyString};
struct PyDictVisitor<'p> {
dict: Bound<'p, PyDict>,
}
impl<'kvs, 'p> VisitSource<'kvs> for PyDictVisitor<'p> {
fn visit_pair(
&mut self,
key: Key<'kvs>,
value: Value<'kvs>,
) -> Result<(), log::kv::Error> {
let py_key = PyString::new(self.dict.py(), key.as_str());
let py_value = PyString::new(self.dict.py(), &value.to_string());
let _ = self.dict.set_item(py_key, py_value);
Ok(())
}
}
let mut visitor = PyDictVisitor {
dict: PyDict::new(py),
};
let _ = record.key_values().visit(&mut visitor);
extra = visitor.dict.into_any();
}
let record = logger.call_method1(
"makeRecord",
(
target,
log_level,
record.file(),
record.line().unwrap_or_default(),
msg,
PyTuple::empty(py), &none, &none, extra, ),
)?;
logger.call_method1("handle", (record,))?;
}
let cache_logger = if !cached && self.caching != Caching::Nothing {
Some(logger.into())
} else {
None
};
Ok(cache_logger)
}
fn filter_for(&self, target: &str) -> LevelFilter {
let mut start = 0;
let mut filter = self.top_filter;
while let Some(end) = target[start..].find("::") {
if let Some(f) = self.filters.get(&target[..start + end]) {
filter = *f;
}
start += end + 2;
}
if let Some(f) = self.filters.get(target) {
filter = *f;
}
filter
}
fn enabled_inner(&self, metadata: &Metadata, cache: &Option<Arc<CacheNode>>) -> bool {
let cache_filter = cache
.as_ref()
.and_then(|node| node.local.as_ref())
.map(|local| local.filter)
.unwrap_or_else(LevelFilter::max);
metadata.level() <= cache_filter && metadata.level() <= self.filter_for(metadata.target())
}
fn store_to_cache(&self, py: Python<'_>, target: &str, entry: CacheEntry) {
let path = target.split("::");
let orig = self.cache.load();
let new = orig.store_to_cache_recursive(py, path, entry);
self.cache.compare_and_swap(orig, new);
}
}
impl Default for Logger {
    /// Creates a logger with [`Caching::LoggersAndLevels`], attaching to the Python
    /// interpreter to import `logging`.
    ///
    /// # Panics
    ///
    /// Panics if Python's `logging` module cannot be imported.
    fn default() -> Self {
        Python::attach(|py| {
            let created = Self::new(py, Caching::default());
            created.expect("Failed to initialize python logging")
        })
    }
}
impl Log for Logger {
    /// Reports whether a record with `metadata` passes the Rust-side filters and the
    /// cached Python level for its target, if one is cached.
    fn enabled(&self, metadata: &Metadata) -> bool {
        self.enabled_inner(metadata, &self.lookup(metadata.target()))
    }

    /// Forwards `record` to Python and caches the Python logger it used, when caching
    /// is enabled.
    ///
    /// Python errors cannot be returned from here, so they are left set as the current
    /// Python exception instead.
    fn log(&self, record: &Record) {
        let node = self.lookup(record.target());
        if !self.enabled_inner(record.metadata(), &node) {
            return;
        }
        Python::attach(|py| {
            // Park an exception that was already pending before logging started.
            let pending_before = PyErr::take(py);

            match self.log_inner(py, record, &node) {
                Ok(None) => {}
                Ok(Some(logger)) => {
                    let filter = match self.caching {
                        Caching::Loggers => LevelFilter::max(),
                        Caching::LoggersAndLevels => match extract_max_level(logger.bind(py)) {
                            Ok(level) => level,
                            Err(err) => {
                                // Surface the error, but cache the logger anyway with
                                // the most permissive filter.
                                err.restore(py);
                                LevelFilter::max()
                            }
                        },
                        // `log_inner` never hands back a logger when caching is off.
                        Caching::Nothing => unreachable!(),
                    };
                    self.store_to_cache(py, record.target(), CacheEntry { filter, logger });
                }
                Err(err) => err.restore(py),
            }

            // Restoring the earlier exception last makes it the one the caller sees.
            if let Some(err) = pending_before {
                err.restore(py);
            }
        })
    }

    /// Nothing is buffered on the Rust side, so there is nothing to flush.
    fn flush(&self) {}
}
/// Maps a Rust log level to the numeric level of Python's `logging` module.
///
/// Python has no TRACE level, so `Trace` maps to 5, below DEBUG (10).
fn map_level(level: Level) -> usize {
    match level {
        Level::Trace => 5,
        Level::Debug => 10,
        Level::Info => 20,
        Level::Warn => 30,
        Level::Error => 40,
    }
}
/// Asks the Python logger (`logger.isEnabledFor(...)`) whether it accepts `level`.
///
/// # Errors
///
/// Propagates a Python error from the call or from evaluating its truthiness.
fn is_enabled_for(logger: &Bound<'_, PyAny>, level: Level) -> PyResult<bool> {
    let answer = logger.call_method1("isEnabledFor", (map_level(level),))?;
    answer.is_truthy()
}
fn extract_max_level(logger: &Bound<'_, PyAny>) -> PyResult<LevelFilter> {
use Level::*;
for l in &[Trace, Debug, Info, Warn, Error] {
if is_enabled_for(logger, *l)? {
return Ok(l.to_level_filter());
}
}
Ok(LevelFilter::Off)
}
/// Installs a [`Logger`] with default settings as the global logger.
///
/// Returns a [`ResetHandle`] for clearing the logger's cache.
///
/// # Errors
///
/// Fails if a global logger has already been set.
///
/// # Panics
///
/// Panics if Python's `logging` module cannot be imported.
pub fn try_init() -> Result<ResetHandle, SetLoggerError> {
    let logger = Logger::default();
    logger.install()
}
/// Installs a [`Logger`] with default settings as the global logger.
///
/// Returns a [`ResetHandle`] for clearing the logger's cache.
///
/// # Panics
///
/// Panics if a global logger has already been set, or if Python's `logging` module
/// cannot be imported. Use [`try_init`] to handle the first case instead.
pub fn init() -> ResetHandle {
    try_init().unwrap()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the effective Rust-side filter for each `(target, expected)` pair.
    fn assert_filters(logger: &Logger, cases: &[(&str, LevelFilter)]) {
        for &(target, expected) in cases {
            assert_eq!(logger.filter_for(target), expected, "target {:?}", target);
        }
    }

    #[test]
    fn default_filter() {
        let logger = Logger::default();
        assert_filters(
            &logger,
            &[
                ("hello_world", LevelFilter::Debug),
                ("hello_world::sub", LevelFilter::Debug),
            ],
        );
    }

    #[test]
    fn set_filter() {
        let logger = Logger::default().filter(LevelFilter::Info);
        assert_filters(
            &logger,
            &[
                ("hello_world", LevelFilter::Info),
                ("hello_world::sub", LevelFilter::Info),
            ],
        );
    }

    #[test]
    fn filter_specific() {
        let logger = Logger::default()
            .filter(LevelFilter::Warn)
            .filter_target("hello_world".to_owned(), LevelFilter::Debug)
            .filter_target("hello_world::sub".to_owned(), LevelFilter::Trace);
        assert_filters(
            &logger,
            &[
                ("hello_world", LevelFilter::Debug),
                ("hello_world::sub", LevelFilter::Trace),
                ("hello_world::sub::multi::level", LevelFilter::Trace),
                ("hello_world::another", LevelFilter::Debug),
                ("hello_world::another::level", LevelFilter::Debug),
                ("other", LevelFilter::Warn),
            ],
        );
    }
}