use std::any::TypeId;
use std::any::type_name;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::hash::BuildHasherDefault;
use std::marker::PhantomData;
use std::ops::Deref;
use std::ops::DerefMut;
use std::sync::Arc;
use arc_swap::ArcSwap;
use arc_swap::Guard;
use vortex_error::VortexExpect;
use vortex_error::vortex_panic;
use vortex_utils::aliases::hash_map::HashMap;
use crate::IdHasher;
use crate::SessionExt;
use crate::SessionVar;
use crate::UnknownPluginPolicy;
/// Object-safe form of a session variable, as stored (type-erased) inside a [`VortexSession`].
///
/// It is implemented automatically for every type that is both `SessionVar` and `Clone`.
pub trait VortexSessionVar: SessionVar {}

impl<V> VortexSessionVar for V where V: SessionVar + Clone {}

/// Type-erased session variables, keyed by the `TypeId` of each variable's concrete type.
type SessionVars = HashMap<TypeId, Arc<dyn VortexSessionVar>, BuildHasherDefault<IdHasher>>;
/// Read-only handle to a session variable of type `V`.
///
/// The guard pins the snapshot of the variable map that was current when it was created. It
/// keeps observing that value even if the session is written afterwards, and holding it does
/// not prevent other callers from updating the session.
pub struct SessionGuard<'a, V> {
    /// Pinned snapshot of the variable map. `get_opt` only builds a guard when this snapshot
    /// contains a `V`.
    snapshot: Guard<Arc<SessionVars>>,
    /// Ties the guard's lifetime to the session it was obtained from.
    _session: PhantomData<&'a VortexSession>,
    /// Records the variable type without owning a `V`.
    _marker: PhantomData<fn() -> V>,
}

impl<V: VortexSessionVar> Deref for SessionGuard<'_, V> {
    type Target = V;

    /// Looks up `V` in the pinned snapshot and downcasts it to the concrete type.
    fn deref(&self) -> &V {
        let key = TypeId::of::<V>();
        let erased = self
            .snapshot
            .get(&key)
            .vortex_expect("SessionGuard invariant: variable present in snapshot");
        // Entries are keyed by the value's own `TypeId`, so this downcast cannot fail.
        erased
            .as_any()
            .downcast_ref::<V>()
            .vortex_expect("Type mismatch - this is a bug")
    }
}

impl<V: VortexSessionVar> Debug for SessionGuard<'_, V> {
    /// Formats the guarded value itself, not the guard.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <V as Debug>::fmt(Deref::deref(self), f)
    }
}
/// Mutable handle to a session variable of type `V`.
///
/// The handle owns a private clone of the variable. Mutations stay invisible to the session
/// until the handle is dropped. At that point the possibly modified value is published back,
/// replacing whatever `V` the session holds at that moment.
pub struct SessionMut<'a, V: VortexSessionVar> {
    /// Session the value is published back to on drop.
    session: &'a VortexSession,
    /// `Some` for the whole lifetime of the handle; taken only in `Drop`.
    value: Option<V>,
}

impl<V: VortexSessionVar> Deref for SessionMut<'_, V> {
    type Target = V;

    fn deref(&self) -> &V {
        let slot = &self.value;
        slot.as_ref()
            .vortex_expect("SessionMut invariant: value present until drop")
    }
}

impl<V: VortexSessionVar> DerefMut for SessionMut<'_, V> {
    fn deref_mut(&mut self) -> &mut V {
        let slot = &mut self.value;
        slot.as_mut()
            .vortex_expect("SessionMut invariant: value present until drop")
    }
}

impl<V: VortexSessionVar> Drop for SessionMut<'_, V> {
    /// Publishes the local copy back into the session.
    ///
    /// NOTE(review): this also runs while unwinding, so a value left half-mutated by a
    /// panicking caller is still published. Confirm that this is intended.
    fn drop(&mut self) {
        let Some(value) = self.value.take() else {
            return;
        };
        self.session.register(value);
    }
}

impl<V: VortexSessionVar> Debug for SessionMut<'_, V> {
    /// Formats the local copy of the value, not the handle.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        <V as Debug>::fmt(Deref::deref(self), f)
    }
}
/// Shared, lock-free store of session variables keyed by type.
///
/// Cloning is cheap and yields a handle to the *same* store. A write through any clone becomes
/// visible to all of them, because writers publish a new copy of the variable map via
/// `ArcSwap`.
#[derive(Clone)]
pub struct VortexSession(Arc<ArcSwap<SessionVars>>);

impl Debug for VortexSession {
    /// Formats the variable map from the current snapshot.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let snapshot = self.0.load();
        let vars: &SessionVars = &snapshot;
        f.debug_tuple("VortexSession").field(vars).finish()
    }
}
impl VortexSession {
    /// Creates a session with an empty variable store.
    ///
    /// Every call allocates a new store, so separate `empty()` sessions never see each other's
    /// variables. Clones of one session share, and mutate, the same store.
    pub fn empty() -> Self {
        Self(Arc::new(ArcSwap::from_pointee(SessionVars::default())))
    }

    /// Publishes `V::default()` unless a `V` is already present.
    ///
    /// The default value is built *before* the RCU update starts. A `Default` impl may therefore
    /// call back into this session (for example, to fetch another variable) without nesting
    /// updates. If another writer installs a `V` first, that value is kept and the default is
    /// discarded.
    fn insert_default<V: VortexSessionVar + Default>(&self) {
        let default: Arc<dyn VortexSessionVar> = Arc::new(V::default());
        self.0.rcu(|current| {
            let mut next = SessionVars::clone(current);
            next.entry(TypeId::of::<V>())
                .or_insert_with(|| Arc::clone(&default));
            next
        });
    }

    /// Publishes `var`, replacing any existing `V` in the shared store.
    fn register<V: VortexSessionVar>(&self, var: V) {
        let var: Arc<dyn VortexSessionVar> = Arc::new(var);
        self.0.rcu(|current| {
            let mut next = SessionVars::clone(current);
            next.insert(TypeId::of::<V>(), Arc::clone(&var));
            next
        });
    }

    /// Builder-style: registers `V::default()` and returns the session.
    ///
    /// # Panics
    ///
    /// Panics if the session already contains a `V`.
    pub fn with<V: VortexSessionVar + Default>(self) -> Self {
        self.with_some(V::default())
    }

    /// Builder-style: registers `var` and returns the session.
    ///
    /// The store is shared, so the new variable is also visible through every clone of this
    /// session, including clones taken before this call.
    ///
    /// # Panics
    ///
    /// Panics if the session already contains a `V`.
    pub fn with_some<V: VortexSessionVar>(self, var: V) -> Self {
        let var: Arc<dyn VortexSessionVar> = Arc::new(var);
        self.0.rcu(|current| {
            // The duplicate check runs against the same snapshot the insert is derived from; if
            // the swap loses a race, the closure is re-run on the newer snapshot and re-checks.
            // A separate check followed by `register` would let a concurrent writer install a
            // `V` in between, which would then be silently overwritten instead of panicking.
            if current.contains_key(&TypeId::of::<V>()) {
                vortex_panic!(
                    "Session variable of type {} already exists",
                    type_name::<V>()
                );
            }
            let mut next = SessionVars::clone(current);
            next.insert(TypeId::of::<V>(), Arc::clone(&var));
            next
        });
        self
    }

    /// Returns whether an `UnknownPluginPolicy` is present with `allow_unknown` set.
    ///
    /// Returns `false` when no policy has been registered.
    pub fn allows_unknown(&self) -> bool {
        self.get_opt::<UnknownPluginPolicy>()
            .is_some_and(|p| p.allow_unknown)
    }

    /// Sets `UnknownPluginPolicy::allow_unknown` to `true` and returns the session.
    ///
    /// A default policy is inserted first if none exists. The change is published when the
    /// temporary mutable handle is dropped at the end of the statement, and every clone of the
    /// session sees it.
    pub fn allow_unknown(self) -> Self {
        self.get_mut::<UnknownPluginPolicy>().allow_unknown = true;
        self
    }
}
impl SessionExt for VortexSession {
    /// Returns a handle that shares this session's store.
    fn session(&self) -> VortexSession {
        self.clone()
    }

    /// Returns a guard for `V`, inserting `V::default()` first if it is absent.
    ///
    /// When `V` already exists this costs a single snapshot load. The default path re-reads
    /// after inserting, so it returns whichever `V` ended up stored, even if another writer won
    /// the race.
    fn get<V: VortexSessionVar + Default>(&self) -> SessionGuard<'_, V> {
        if let Some(guard) = self.get_opt::<V>() {
            return guard;
        }
        self.insert_default::<V>();
        self.get_opt::<V>()
            .vortex_expect("variable was just inserted")
    }

    /// Returns a guard for `V` if it is present in the current snapshot, without inserting
    /// anything.
    fn get_opt<V: VortexSessionVar>(&self) -> Option<SessionGuard<'_, V>> {
        let snapshot = self.0.load();
        // Only build the guard when the snapshot it pins actually holds a `V`; this is the
        // invariant `SessionGuard::deref` relies on.
        snapshot
            .contains_key(&TypeId::of::<V>())
            .then(|| SessionGuard {
                snapshot,
                _session: PhantomData,
                _marker: PhantomData,
            })
    }

    /// Returns a mutable handle holding a clone of `V`, inserting `V::default()` first if
    /// needed.
    ///
    /// The clone is written back when the handle is dropped, and it replaces the stored value
    /// outright. If another writer updates `V` while the handle is alive, that update is
    /// overwritten, so the last handle to be dropped wins.
    fn get_mut<V: VortexSessionVar + Default + Clone>(&self) -> SessionMut<'_, V> {
        // The read guard is a temporary, so it is released before the handle is built.
        let value = (*self.get::<V>()).clone();
        SessionMut {
            session: self,
            value: Some(value),
        }
    }
}
// Unit tests for `VortexSession`: lookup and insertion semantics, sharing between clones,
// snapshot isolation of held guards, re-entrancy of `Default` impls, and `get_mut` publishing.
#[cfg(test)]
mod tests {
    use std::any::Any;

    use super::VortexSession;
    use crate::SessionExt;
    use crate::SessionVar;

    // Simple value-carrying variable used by most tests.
    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct Counter {
        count: u32,
    }

    impl SessionVar for Counter {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
    }

    // Zero-sized variable used to check that distinct types occupy distinct slots.
    #[derive(Clone, Debug, Default)]
    struct Other;

    impl SessionVar for Other {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
    }

    thread_local! {
        // Session that `Reentrant::default` calls back into, if the current test has set one.
        static REENTRANT_SESSION: std::cell::RefCell<Option<VortexSession>> =
            const { std::cell::RefCell::new(None) };
    }

    // Variable whose `Default` impl fetches (and thereby inserts) a `Counter` from the session.
    // It proves that default construction happens outside the session's update loop.
    #[derive(Clone, Debug)]
    struct Reentrant {
        inner: u32,
    }

    impl Default for Reentrant {
        fn default() -> Self {
            REENTRANT_SESSION.with(|s| {
                if let Some(session) = s.borrow().as_ref() {
                    // Re-enter the session; the guard is dropped immediately.
                    drop(session.get::<Counter>());
                }
            });
            Reentrant { inner: 7 }
        }
    }

    impl SessionVar for Reentrant {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
    }

    // A registered value is returned unchanged, and unrelated types stay absent.
    #[test]
    fn with_some_round_trip() {
        let session = VortexSession::empty().with_some(Counter { count: 1 });
        assert_eq!(*session.get::<Counter>(), Counter { count: 1 });
        assert!(session.get_opt::<Other>().is_none());
    }

    // `get` lazily inserts `Default` and persists it.
    #[test]
    fn get_inserts_default() {
        let session = VortexSession::empty();
        assert!(session.get_opt::<Counter>().is_none());
        assert_eq!(session.get::<Counter>().count, 0);
        assert!(session.get_opt::<Counter>().is_some());
    }

    // Clones share one store, so a write through one handle is seen through another.
    #[test]
    fn register_is_visible_through_clones() {
        let session = VortexSession::empty();
        let clone = session.clone();
        session.register(Counter { count: 7 });
        assert_eq!(clone.get::<Counter>().count, 7);
    }

    // The builder-style `with_some` mutates the shared store rather than a private copy.
    #[test]
    fn with_some_mutates_shared_store() {
        let session = VortexSession::empty();
        let clone = session.clone();
        let configured = session.with_some(Counter { count: 5 });
        assert_eq!(configured.get::<Counter>().count, 5);
        assert_eq!(clone.get::<Counter>().count, 5);
    }

    // `allow_unknown` is visible to clones even when its return value is discarded.
    #[test]
    fn allow_unknown_mutates_shared_store() {
        let session = VortexSession::empty();
        let clone = session.clone();
        assert!(!clone.allows_unknown());
        session.allow_unknown();
        assert!(clone.allows_unknown());
    }

    // Each `empty()` call creates its own store.
    #[test]
    fn empty_sessions_are_independent() {
        let session = VortexSession::empty().with_some(Counter { count: 1 });
        let other = VortexSession::empty().with_some(Counter { count: 2 });
        session.register(Counter { count: 9 });
        assert_eq!(session.get::<Counter>().count, 9);
        assert_eq!(other.get::<Counter>().count, 2);
    }

    // Registering the same type twice through the builder API panics.
    #[test]
    #[should_panic(expected = "already exists")]
    fn with_some_duplicate_panics() {
        VortexSession::empty()
            .with::<Counter>()
            .with_some(Counter { count: 1 });
    }

    // Unknown plugins are rejected by default and allowed only after opting in.
    #[test]
    fn allow_unknown_flag_is_opt_in() {
        let session = VortexSession::empty();
        assert!(!session.allows_unknown());
        let session = session.allow_unknown();
        assert!(session.allows_unknown());
    }

    // `get_opt` is a pure lookup and never inserts.
    #[test]
    fn get_opt_does_not_insert_a_default() {
        let session = VortexSession::empty();
        assert!(session.get_opt::<Counter>().is_none());
        assert!(session.get_opt::<Counter>().is_none());
    }

    // A write (default insertion) succeeds while a read guard is still held.
    #[test]
    fn inserting_a_default_while_holding_a_guard_succeeds() {
        let session = VortexSession::empty().with_some(Counter { count: 5 });
        let counter = session.get::<Counter>();
        let other = session.get::<Other>();
        assert_eq!(counter.count, 5);
        let _: &Other = &other;
        assert!(session.get_opt::<Other>().is_some());
    }

    // Guards pin their snapshot: later writes do not change what a held guard sees.
    #[test]
    fn a_held_guard_keeps_observing_its_own_snapshot_after_a_write() {
        let session = VortexSession::empty().with_some(Counter { count: 1 });
        let held = session.get::<Counter>();
        session.register(Counter { count: 2 });
        assert_eq!(held.count, 1);
        assert_eq!(session.get::<Counter>().count, 2);
    }

    // A `Default` impl that calls back into the session must not deadlock, and both the outer
    // and inner insertions must land.
    #[test]
    fn default_insertion_may_reenter_the_session_without_deadlocking() {
        let session = VortexSession::empty();
        REENTRANT_SESSION.with(|s| *s.borrow_mut() = Some(session.clone()));
        assert_eq!(session.get::<Reentrant>().inner, 7);
        assert!(session.get_opt::<Reentrant>().is_some());
        assert!(session.get_opt::<Counter>().is_some());
        // Clear the thread-local so other tests on this thread are unaffected.
        REENTRANT_SESSION.with(|s| *s.borrow_mut() = None);
    }

    // Mutations through `get_mut` are published when the temporary handle drops.
    #[test]
    fn get_mut_publishes_on_drop() {
        let session = VortexSession::empty();
        session.register(Counter { count: 1 });
        session.get_mut::<Counter>().count = 42;
        assert_eq!(session.get::<Counter>().count, 42);
    }

    // `get_mut` on a missing variable starts from `Default`.
    #[test]
    fn get_mut_inserts_default_then_mutates() {
        let session = VortexSession::empty();
        assert!(session.get_opt::<Counter>().is_none());
        session.get_mut::<Counter>().count += 5;
        assert_eq!(session.get::<Counter>().count, 5);
    }

    // Values published by `get_mut` are visible through clones.
    #[test]
    fn get_mut_mutation_is_visible_through_clones() {
        let session = VortexSession::empty().with_some(Counter { count: 1 });
        let clone = session.clone();
        session.get_mut::<Counter>().count = 9;
        assert_eq!(clone.get::<Counter>().count, 9);
    }
}