use conjure_object::Any;
use pin_project::{pin_project, pinned_drop};
use serde::Serialize;
use std::cell::RefCell;
use std::collections::{hash_map, HashMap};
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll};
/// Shared empty [`Map`], built on first use by `Map::default` so that every empty map
/// points at the same `Arc` allocation.
static EMPTY: OnceLock<Map> = OnceLock::new();

thread_local! {
    // Per-thread MDC state; the public functions of this module all operate on it.
    static MDC: RefCell<Snapshot> = RefCell::new(Snapshot::default());
}
/// Stores `value` under `key` in the safe half of the current thread's MDC.
///
/// Returns the value previously stored under `key`, if there was one.
///
/// # Panics
///
/// Panics if `value` cannot be serialized.
///
/// NOTE(review): `value` is serialized while the thread-local MDC is mutably borrowed,
/// so a `Serialize` impl that itself reads or writes the MDC would panic.
pub fn insert_safe<T>(key: &'static str, value: T) -> Option<Any>
where
    T: Serialize,
{
    MDC.with(|cell| {
        let mut mdc = cell.borrow_mut();
        mdc.safe_mut().insert(key, value)
    })
}

/// Stores `value` under `key` in the unsafe half of the current thread's MDC.
///
/// Returns the value previously stored under `key`, if there was one.
///
/// # Panics
///
/// Panics if `value` cannot be serialized.
///
/// NOTE(review): `value` is serialized while the thread-local MDC is mutably borrowed,
/// so a `Serialize` impl that itself reads or writes the MDC would panic.
pub fn insert_unsafe<T>(key: &'static str, value: T) -> Option<Any>
where
    T: Serialize,
{
    MDC.with(|cell| {
        let mut mdc = cell.borrow_mut();
        mdc.unsafe_mut().insert(key, value)
    })
}
/// Removes `key` from the safe half of the current thread's MDC, returning its value if present.
pub fn remove_safe(key: &str) -> Option<Any> {
    MDC.with(|cell| {
        let mut mdc = cell.borrow_mut();
        mdc.safe_mut().remove(key)
    })
}

/// Removes `key` from the unsafe half of the current thread's MDC, returning its value if present.
pub fn remove_unsafe(key: &str) -> Option<Any> {
    MDC.with(|cell| {
        let mut mdc = cell.borrow_mut();
        mdc.unsafe_mut().remove(key)
    })
}
pub fn snapshot() -> Snapshot {
MDC.with(|v| v.borrow().clone())
}
pub fn clear() {
MDC.with(|v| {
let mut mdc = v.borrow_mut();
mdc.safe_mut().clear();
mdc.unsafe_mut().clear();
});
}
/// Replaces the current thread's MDC with `snapshot`, returning the previous contents.
pub fn set(snapshot: Snapshot) -> Snapshot {
    let mut previous = snapshot;
    swap(&mut previous);
    previous
}

/// Exchanges the current thread's MDC with `snapshot` in place.
pub fn swap(snapshot: &mut Snapshot) {
    MDC.with(|cell| {
        let mut current = cell.borrow_mut();
        mem::swap(&mut *current, snapshot);
    });
}
/// Wraps `future` so that each of its polls, and its final drop, run with the MDC as
/// it is at the time of this call.
///
/// Changes the future makes to the MDC carry over between its polls but are not
/// visible to the code polling it.
pub fn bind<F>(future: F) -> Bind<F> {
    let captured = snapshot();
    Bind {
        snapshot: captured,
        future: Some(future),
    }
}

/// Returns a guard that puts the current thread's MDC back as it is now when dropped.
///
/// Any changes made to the MDC while the guard is alive are discarded at that point.
pub fn scope() -> Scope {
    let old = snapshot();
    Scope { old }
}
/// A map from `&'static str` keys to serialized [`Any`] values.
///
/// The entries live behind an `Arc`, so cloning a `Map` is cheap; mutating methods
/// copy the entries first if they are shared with another clone.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Map {
    map: Arc<HashMap<&'static str, Any>>,
}

impl Default for Map {
    /// Returns an empty map backed by the single shared empty allocation.
    #[inline]
    fn default() -> Self {
        let empty = EMPTY.get_or_init(|| Map {
            map: Arc::new(HashMap::new()),
        });
        Map::clone(empty)
    }
}
impl Map {
    /// Creates an empty map; equivalent to [`Map::default`].
    #[inline]
    pub fn new() -> Self {
        Map::default()
    }

    /// Removes every entry.
    ///
    /// Clears in place when this is the only handle to the entries. Otherwise this
    /// handle is re-pointed at the shared empty map, and other clones are left untouched.
    #[inline]
    pub fn clear(&mut self) {
        match Arc::get_mut(&mut self.map) {
            Some(map) => map.clear(),
            None => *self = Map::new(),
        }
    }

    /// Returns the number of entries.
    #[inline]
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if the map has no entries.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Returns the value stored under `key`, if any.
    #[inline]
    pub fn get(&self, key: &str) -> Option<&Any> {
        self.map.get(key)
    }

    /// Returns `true` if a value is stored under `key`.
    #[inline]
    pub fn contains_key(&self, key: &str) -> bool {
        self.map.contains_key(key)
    }

    /// Serializes `value` and stores it under `key`, returning the previous value, if any.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized.
    #[inline]
    pub fn insert<V>(&mut self, key: &'static str, value: V) -> Option<Any>
    where
        V: Serialize,
    {
        let value = Any::new(value).expect("value failed to serialize");
        Arc::make_mut(&mut self.map).insert(key, value)
    }

    /// Removes `key`, returning its value if it was present.
    #[inline]
    pub fn remove(&mut self, key: &str) -> Option<Any> {
        // `Arc::make_mut` deep-copies the entries when they are shared (e.g. with a
        // snapshot). Skip it when there is nothing to remove.
        if !self.map.contains_key(key) {
            return None;
        }
        Arc::make_mut(&mut self.map).remove(key)
    }

    /// Returns an iterator over `(key, &value)` pairs in arbitrary order.
    #[inline]
    pub fn iter(&self) -> Iter<'_> {
        Iter {
            it: self.map.iter(),
        }
    }
}
impl<'a> IntoIterator for &'a Map {
type Item = (&'static str, &'a Any);
type IntoIter = Iter<'a>;
#[inline]
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// Iterator over the entries of a [`Map`], yielding `(key, &value)` pairs in arbitrary order.
#[derive(Clone, Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Iter<'a> {
    it: hash_map::Iter<'a, &'static str, Any>,
}

impl<'a> Iterator for Iter<'a> {
    type Item = (&'static str, &'a Any);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Copy the `&'static str` key out so callers get `&'static str`, not `&&'static str`.
        self.it.next().map(|(k, v)| (*k, v))
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.it.size_hint()
    }
}

impl ExactSizeIterator for Iter<'_> {
    #[inline]
    fn len(&self) -> usize {
        self.it.len()
    }
}

// The wrapped `hash_map::Iter` is fused, so once this yields `None` it keeps doing so.
impl std::iter::FusedIterator for Iter<'_> {}
/// A full copy of MDC state: one map populated by the `*_safe` functions and one
/// populated by the `*_unsafe` functions.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Snapshot {
    safe: Map,
    unsafe_: Map,
}

impl Snapshot {
    /// Creates a snapshot whose maps are both empty.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the safe map.
    #[inline]
    pub fn safe(&self) -> &Map {
        let map = &self.safe;
        map
    }

    /// Returns the safe map for modification.
    #[inline]
    pub fn safe_mut(&mut self) -> &mut Map {
        let map = &mut self.safe;
        map
    }

    /// Returns the unsafe map.
    #[inline]
    pub fn unsafe_(&self) -> &Map {
        let map = &self.unsafe_;
        map
    }

    /// Returns the unsafe map for modification.
    #[inline]
    pub fn unsafe_mut(&mut self) -> &mut Map {
        let map = &mut self.unsafe_;
        map
    }
}
/// Guard returned by [`scope`]. When dropped, it restores the MDC captured at its creation.
#[must_use = "the MDC is restored as soon as the `Scope` is dropped"]
pub struct Scope {
    old: Snapshot,
}

impl Drop for Scope {
    fn drop(&mut self) {
        // Swap the saved state back into the thread-local. Whatever the MDC held while
        // the guard was alive ends up in `self.old` and is dropped with `self`.
        swap(&mut self.old);
    }
}
#[pin_project(PinnedDrop)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
/// Future returned by [`bind`]; polls and drops the wrapped future under its own MDC.
pub struct Bind<F> {
    // `Some` until the wrapper itself is dropped. It is taken then so the inner future
    // is dropped while its MDC is installed.
    #[pin]
    future: Option<F>,
    // The MDC installed for the duration of each poll and of the final drop.
    snapshot: Snapshot,
}
#[pinned_drop]
impl<F> PinnedDrop for Bind<F> {
    // Drops the inner future with the bound MDC installed, so MDC access from the inner
    // future's destructor sees the same context its polls did.
    fn drop(self: Pin<&mut Self>) {
        let mut this = self.project();
        // Install the bound snapshot; `_guard` swaps it back out when this function returns.
        let _guard = scope_with(this.snapshot);
        // Must run while `_guard` is alive so the inner future's destructor runs in context.
        this.future.set(None);
    }
}
impl<F> Future for Bind<F>
where
    F: Future,
{
    type Output = F::Output;

    /// Polls the inner future with the bound MDC installed. Changes the inner future
    /// makes to the MDC are saved back into the wrapper for the next poll.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        // Keep the bound MDC installed until after the inner poll has returned.
        let _restore = scope_with(projected.snapshot);
        let inner = projected.future.as_pin_mut().unwrap();
        inner.poll(cx)
    }
}
/// Installs `*snapshot` as the current thread's MDC and returns a guard.
///
/// When dropped, the guard swaps the MDC back, leaving any changes made in the
/// meantime in `*snapshot`.
fn scope_with(snapshot: &mut Snapshot) -> ScopeWith<'_> {
    swap(&mut *snapshot);
    ScopeWith { snapshot }
}

/// Guard created by `scope_with`.
struct ScopeWith<'a> {
    // While the guard is alive, this holds the MDC that was current before `scope_with`.
    snapshot: &'a mut Snapshot,
}

impl Drop for ScopeWith<'_> {
    fn drop(&mut self) {
        let saved = &mut *self.snapshot;
        swap(saved);
    }
}
#[cfg(test)]
mod test {
    use crate::mdc;
    use conjure_object::Any;

    /// Returns the current thread's safe MDC value for `key`, panicking if it is absent.
    fn safe_value(key: &str) -> Any {
        mdc::snapshot().safe().get(key).unwrap().clone()
    }

    #[test]
    fn scope() {
        mdc::clear();
        mdc::insert_safe("foo", "bar");

        let guard = mdc::scope();
        mdc::insert_safe("foo", "baz");
        assert_eq!(safe_value("foo"), Any::new("baz").unwrap());

        // Dropping the guard must restore the value seen before the scope began.
        drop(guard);
        assert_eq!(safe_value("foo"), Any::new("bar").unwrap());
    }

    #[test]
    fn bind() {
        mdc::clear();
        mdc::insert_safe("foo", "bar");

        let bound = mdc::bind(async {
            mdc::insert_safe("foo", "baz");
            assert_eq!(safe_value("foo"), Any::new("baz").unwrap());
        });
        futures_executor::block_on(bound);

        // The bound future's changes must not leak into the caller's MDC.
        assert_eq!(safe_value("foo"), Any::new("bar").unwrap());
    }
}