use crate::Level;
use std::cell::Cell;
use std::collections::HashMap;
use std::fmt::Display;
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::Poll;
/// Counter from which every `TaskID` is drawn via `fetch_add`, so each task
/// created in the process receives a distinct ID.
static TASK_ID: AtomicU64 = AtomicU64::new(0);
/// Counter from which `ContextID`s are drawn for contexts created at runtime.
///
/// Starts at 1 because ID 0 is hard-coded for every thread's default context
/// (the `CONTEXT` thread-local). Starting at 0 would give the first explicitly
/// created context the same ID as the default context, so `Context::pop` could
/// match (and try to pop) the wrong frame.
static CONTEXT_ID: AtomicU64 = AtomicU64::new(1);
/// Opaque identifier of a [`Task`], unique within the process.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct TaskID(u64);

impl Display for TaskID {
    /// Renders the bare numeric ID (formatter width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let TaskID(raw) = *self;
        write!(f, "{raw}")
    }
}
/// Identifier of a single context frame; used by `Context::pop` to find the
/// frame to remove from the current thread's context chain.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub struct ContextID(u64);
impl Task {
    /// Adds `duration` to the running total stored under `key` for this task,
    /// creating the entry on first use.
    ///
    /// The accumulated totals are reported when the task is dropped.
    #[inline]
    fn add_task_interval(&self, key: &'static str, duration: crate::sys::Duration) {
        // The statistics are best-effort diagnostics. Recover from a poisoned
        // lock instead of panicking, otherwise one earlier panic would make
        // every later call (and the task's `Drop`) panic as well.
        let mut mutable = self
            .mutable
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        mutable
            .interval_statistics
            .entry(key)
            .and_modify(|total| *total += duration)
            .or_insert(duration);
    }
}
impl Drop for Task {
    /// Emits end-of-task log records to every global logger:
    ///
    /// * if any interval statistics were recorded, a `PerfWarn` record listing
    ///   each key with its accumulated duration;
    /// * unless the task is labelled `"Default task"`, an `Info` record saying
    ///   the task finished.
    fn drop(&mut self) {
        /// Hands a copy of `record` to each registered global logger.
        fn broadcast(record: &crate::log_record::LogRecord) {
            for logger in crate::global_logger::global_loggers() {
                logger.finish_log_record(record.clone());
            }
        }

        // `&mut self` guarantees exclusive access, so no locking is needed.
        // Tolerate poisoning: panicking inside `drop` while already unwinding
        // would abort the process.
        let mutable = self
            .mutable
            .get_mut()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if !mutable.interval_statistics.is_empty() {
            let mut record = crate::log_record::LogRecord::new(Level::PerfWarn);
            record.log_owned(format!("{} ", self.task_id.0));
            record.log("PERFWARN: statistics[");
            for (key, duration) in &mutable.interval_statistics {
                record.log(key);
                record.log_owned(format!(": {:?},", duration));
            }
            record.log("]");
            broadcast(&record);
        }

        // Each thread's implicit default task is not announced when it ends.
        if self.label != "Default task" {
            let mut record = crate::log_record::LogRecord::new(Level::Info);
            record.log_owned(format!("{} ", self.task_id.0));
            record.log("Finished task `");
            record.log(&self.label);
            record.log("`");
            broadcast(&record);
        }
    }
}
/// The part of a [`Task`] that is updated after creation; kept behind a `Mutex`.
#[derive(Debug, Clone)]
struct TaskMutable {
    /// Accumulated durations keyed by a static interval name.
    interval_statistics: HashMap<&'static str, crate::sys::Duration>,
}
/// A labelled unit of work that log contexts belong to.
#[derive(Debug)]
pub struct Task {
    /// Process-unique identifier of this task.
    task_id: TaskID,
    /// Statistics recorded while the task runs.
    mutable: Mutex<TaskMutable>,
    /// Human-readable name, shown in log output.
    label: String,
}
impl Task {
fn new(label: String) -> Task {
Task {
task_id: TaskID(TASK_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed)),
mutable: Mutex::new(TaskMutable {
interval_statistics: HashMap::new(),
}),
label,
}
}
}
/// Shared state behind a [`Context`] handle.
#[derive(Debug)]
struct ContextInner {
    /// Enclosing context; `None` for a root context.
    parent: Option<Context>,
    /// Raw identifier compared by `Context::pop`.
    context_id: u64,
    /// The task this context starts, or `None` when it inherits its parent's task.
    define_task: Option<Task>,
    /// Whether tracing is enabled for this context.
    is_tracing: AtomicBool,
}
/// A cheaply clonable handle to one frame in a chain of logging contexts.
///
/// Cloning copies only the inner `Arc`; all clones refer to the same frame.
#[derive(Clone, Debug)]
pub struct Context {
    inner: Arc<ContextInner>,
}
impl PartialEq for Context {
fn eq(&self, other: &Self) -> bool {
Arc::ptr_eq(&self.inner, &other.inner)
}
}
impl Eq for Context {}
impl Hash for Context {
fn hash<H: Hasher>(&self, state: &mut H) {
Arc::as_ptr(&self.inner).hash(state);
}
}
impl Display for Context {
    /// Writes one space per nesting level, then `<task id> (<task label>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let indent = " ".repeat(self.nesting_level());
        let id = self.task_id();
        let label = &self.task().label;
        write!(f, "{indent}{id} ({label})")
    }
}
impl AsRef<Task> for Context {
fn as_ref(&self) -> &Task {
self.task()
}
}
thread_local! {
    /// The context that is current on this thread.
    ///
    /// Every thread starts with its own root context for a task labelled
    /// `"Default task"`, created with the fixed context ID 0.
    static CONTEXT: Cell<Context> = Cell::new(Context::new_task_internal(
        None,
        "Default task".to_string(),
        0,
    ));
}
impl Context {
    /// Returns a handle to this thread's current context.
    ///
    /// This is cheap: cloning a `Context` only clones its inner `Arc`.
    #[inline]
    pub fn current() -> Context {
        CONTEXT.with(|cell| {
            // SAFETY: `Cell` never hands out references to its contents, so no
            // other reference to the stored `Context` exists. The shared borrow
            // is used only for `Context::clone` (an `Arc` clone), which cannot
            // re-enter `CONTEXT` and replace the value while it is borrowed.
            unsafe { &*cell.as_ptr() }.clone()
        })
    }

    /// Returns the task this context belongs to. This is the task the context
    /// defines itself or, failing that, the nearest ancestor's task.
    ///
    /// # Panics
    /// If neither this context nor any ancestor defines a task. Contexts built
    /// with `new_task` define one, and `from_parent` always records a parent,
    /// so this indicates a broken invariant.
    pub fn task(&self) -> &Task {
        if let Some(task) = &self.inner.define_task {
            task
        } else {
            self.inner
                .parent
                .as_ref()
                .expect("No parent context")
                .task()
        }
    }

    /// Creates a context that starts a new task labelled `label`. It is nested
    /// under `parent`, or is a root when `parent` is `None`.
    ///
    /// The new context is not made current, and tracing starts disabled
    /// regardless of the parent's tracing state.
    #[inline]
    pub fn new_task(parent: Option<Context>, label: String) -> Context {
        let context_id = CONTEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self::new_task_internal(parent, label, context_id)
    }

    /// Builds a task-defining context with an explicitly chosen `context_id`.
    #[inline]
    fn new_task_internal(parent: Option<Context>, label: String, context_id: u64) -> Context {
        Context {
            inner: Arc::new(ContextInner {
                parent,
                context_id,
                define_task: Some(Task::new(label)),
                is_tracing: AtomicBool::new(false),
            }),
        }
    }

    /// Replaces this thread's current context with a fresh root context that
    /// starts a task labelled `label`.
    #[inline]
    pub fn reset(label: String) {
        let new_context = Context::new_task(None, label);
        new_context.set_current();
    }

    /// Creates a child of `context` that belongs to the same task.
    ///
    /// The child copies the parent's tracing flag as it is at creation time.
    pub fn from_parent(context: Context) -> Context {
        let is_tracing = context.inner.is_tracing.load(Ordering::Relaxed);
        Context {
            inner: Arc::new(ContextInner {
                parent: Some(context),
                context_id: CONTEXT_ID.fetch_add(1, Ordering::Relaxed),
                define_task: None,
                is_tracing: AtomicBool::new(is_tracing),
            }),
        }
    }

    /// ID of the task this context belongs to.
    #[inline]
    pub fn task_id(&self) -> TaskID {
        self.task().task_id
    }

    /// Whether tracing is enabled for this context.
    #[inline]
    pub fn is_tracing(&self) -> bool {
        self.inner.is_tracing.load(Ordering::Relaxed)
    }

    /// Whether tracing is enabled for this thread's current context, without
    /// cloning the context handle.
    #[inline]
    pub fn currently_tracing() -> bool {
        CONTEXT.with(|cell| {
            // SAFETY: see `Context::current`; the borrow only reads an atomic
            // flag and cannot observe a concurrent replacement of the cell.
            unsafe { &*cell.as_ptr() }
                .inner
                .is_tracing
                .load(Ordering::Relaxed)
        })
    }

    /// Enables tracing on the current context and logs "Begin trace".
    pub fn begin_trace() {
        Context::current()
            .inner
            .is_tracing
            .store(true, Ordering::Relaxed);
        logwise::trace_sync!("Begin trace");
    }

    /// Installs `self` as this thread's current context; the previously
    /// current handle is dropped.
    pub fn set_current(self) {
        CONTEXT.replace(self);
    }

    /// Number of ancestors of this context (0 for a root context).
    pub fn nesting_level(&self) -> usize {
        let mut level = 0;
        let mut current = self;
        while let Some(parent) = &current.inner.parent {
            level += 1;
            current = parent;
        }
        level
    }

    /// Identifier of this context frame.
    #[inline]
    pub fn context_id(&self) -> ContextID {
        ContextID(self.inner.context_id)
    }

    /// Pops the frame identified by `id` from this thread's context chain and
    /// makes its parent current.
    ///
    /// The chain is searched from the current context upwards through its
    /// parents. In two cases a warning is logged and the current context is
    /// left unchanged: when `id` is not on the chain, and when it names a root
    /// context that has no parent to return to.
    pub fn pop(id: ContextID) {
        let mut current = Context::current();
        loop {
            if current.context_id() == id {
                match current.inner.parent.clone() {
                    Some(parent) => parent.set_current(),
                    None => {
                        logwise::warn_sync!(
                            "Tried to pop context with ID {id}, but it is a root context with no parent.",
                            id = id.0
                        );
                    }
                }
                return;
            }
            match current.inner.parent.as_ref() {
                None => {
                    logwise::warn_sync!(
                        "Tried to pop context with ID {id}, but it was not found in the current context chain.",
                        id = id.0
                    );
                    return;
                }
                Some(ctx) => current = ctx.clone(),
            }
        }
    }

    /// Writes the per-record prefix. This is `T` when tracing (a space
    /// otherwise), then one space per nesting level, then the task ID
    /// followed by a space.
    #[doc(hidden)]
    #[inline]
    pub fn _log_prelude(&self, record: &mut crate::log_record::LogRecord) {
        let prefix = if self.is_tracing() { "T" } else { " " };
        record.log(prefix);
        for _ in 0..self.nesting_level() {
            record.log(" ");
        }
        record.log_owned(format!("{} ", self.task_id()));
    }

    /// Adds `duration` to the interval statistic `key` of this context's task.
    #[doc(hidden)]
    #[inline]
    pub fn _add_task_interval(&self, key: &'static str, duration: crate::sys::Duration) {
        self.task().add_task_interval(key, duration);
    }
}
/// Future adapter that installs a [`Context`] as the thread's current context
/// while the wrapped future is being polled.
pub struct ApplyContext<F>(Context, F);

impl<F> ApplyContext<F> {
    /// Wraps `f` so that each poll runs with `context` as the current context.
    pub fn new(context: Context, f: F) -> Self {
        ApplyContext(context, f)
    }
}
impl<F> Future for ApplyContext<F>
where
F: Future,
{
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let (context, fut) = unsafe {
let d = self.get_unchecked_mut();
(d.0.clone(), Pin::new_unchecked(&mut d.1))
};
let prior_context = Context::current();
context.set_current();
let r = fut.poll(cx);
prior_context.set_current();
r
}
}
#[cfg(test)]
mod tests {
    use super::{Context, Task, TaskID};
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::*;
    #[cfg(target_arch = "wasm32")]
    wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

    /// Pushing a child context and popping it again by ID must not panic.
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_new_context() {
        Context::reset("test_new_context".to_string());
        let child = Context::from_parent(Context::current());
        let child_id = child.context_id();
        child.set_current();
        Context::pop(child_id);
    }

    /// Clones compare equal; independently created contexts do not.
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_context_equality() {
        Context::reset("test_context_equality".to_string());
        let first = Context::current();
        let same = first.clone();
        let other = Context::new_task(None, "different_task".to_string());
        assert_eq!(first, same);
        assert_ne!(first, other);
        assert_ne!(same, other);
    }

    /// Hashing agrees with equality, so contexts work as `HashMap` keys.
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_context_hash() {
        use std::collections::hash_map::DefaultHasher;
        use std::collections::HashMap;
        use std::hash::{Hash, Hasher};

        let digest = |ctx: &Context| {
            let mut hasher = DefaultHasher::new();
            ctx.hash(&mut hasher);
            hasher.finish()
        };

        Context::reset("test_context_hash".to_string());
        let first = Context::current();
        let same = first.clone();
        let other = Context::new_task(None, "different_task".to_string());

        assert_eq!(digest(&first), digest(&same));
        assert_ne!(digest(&first), digest(&other));

        let mut map = HashMap::new();
        map.insert(first.clone(), "value1");
        map.insert(other.clone(), "value3");
        assert_eq!(map.get(&first), Some(&"value1"));
        assert_eq!(map.get(&same), Some(&"value1"));
        assert_eq!(map.get(&other), Some(&"value3"));
        assert_eq!(map.len(), 2);
    }

    /// `Display` indents by nesting depth and shows the owning task's ID and label.
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_context_display() {
        Context::reset("root_task".to_string());
        let root = Context::current();
        let root_text = format!("{}", root);
        assert!(root_text.starts_with(&format!("{} (root_task)", root.task_id())));
        assert!(!root_text.starts_with(" "));

        let child = Context::from_parent(root.clone());
        child.clone().set_current();
        let child_text = format!("{}", child);
        assert!(child_text.starts_with(" "));
        assert!(child_text.contains(&format!("{} (root_task)", root.task_id())));

        let task_ctx = Context::new_task(Some(child.clone()), "child_task".to_string());
        task_ctx.clone().set_current();
        let task_text = format!("{}", task_ctx);
        assert!(task_text.starts_with(" "));
        assert!(task_text.contains(&format!("{} (child_task)", task_ctx.task_id())));

        let grandchild = Context::from_parent(task_ctx.clone());
        grandchild.clone().set_current();
        let grandchild_text = format!("{}", grandchild);
        assert!(grandchild_text.starts_with(" "));
        assert!(grandchild_text.contains(&format!("{} (child_task)", task_ctx.task_id())));
    }

    /// `AsRef<Task>` resolves to the defining task, directly or via ancestors.
    #[cfg_attr(not(target_arch = "wasm32"), test)]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_context_as_ref_task() {
        fn id_via_ref(task: &Task) -> TaskID {
            task.task_id
        }
        fn id_via_as_ref<T: AsRef<Task>>(item: T) -> TaskID {
            item.as_ref().task_id
        }

        Context::reset("test_as_ref".to_string());
        let ctx = Context::current();
        let task: &Task = ctx.as_ref();
        assert_eq!(task.task_id, ctx.task_id());
        assert_eq!(task.label, "test_as_ref");

        let by_ref = id_via_ref(ctx.as_ref());
        assert_eq!(by_ref, ctx.task_id());
        let by_borrow = id_via_as_ref(&ctx);
        let by_value = id_via_as_ref(ctx.clone());
        assert_eq!(by_ref, by_borrow);
        assert_eq!(by_borrow, by_value);

        let child = Context::from_parent(ctx.clone());
        let child_task: &Task = child.as_ref();
        assert_eq!(child_task.task_id, ctx.task_id());
        assert_eq!(child_task.label, "test_as_ref");

        let fresh = Context::new_task(Some(ctx.clone()), "new_task".to_string());
        let fresh_task: &Task = fresh.as_ref();
        assert_ne!(fresh_task.task_id, ctx.task_id());
        assert_eq!(fresh_task.label, "new_task");
    }
}