use std::cell::Cell;
use std::fmt;
use std::marker::PhantomData;
use std::sync::Arc;
#[cfg(debug_assertions)]
use std::sync::atomic::{AtomicU64, Ordering};
use indexmap::IndexSet;
use crate::passive::Passive;
use crate::registry::VarRegistry;
// Monotonic source of tape generation ids (debug builds only). It starts at 1,
// so freshly issued ids differ from the value 0 that `ACTIVE_GEN` starts at and
// is reset to by `NamedForwardTape::deactivate_all`.
#[cfg(debug_assertions)]
static NEXT_GEN: AtomicU64 = AtomicU64::new(1);
// Per-thread bookkeeping for the currently active named forward-mode tape/scope.
thread_local! {
// Generation id of the currently active tape (debug builds only); 0 initially
// and after `NamedForwardTape::deactivate_all`.
#[cfg(debug_assertions)]
static ACTIVE_GEN: Cell<u64> = const { Cell::new(0) };
// Registry of the currently active `NamedForwardScope`, or null when none is
// active. Installed by `NamedForwardTape::into_scope`, read via
// `with_active_registry`.
static ACTIVE_REGISTRY: Cell<*const VarRegistry> = const { Cell::new(std::ptr::null()) };
}
/// Runs `f` with the registry of this thread's active named forward scope,
/// or `None` when no scope is active.
///
/// The borrow handed to `f` is only valid for the duration of the call; the
/// higher-ranked closure signature prevents `f` from returning it.
pub(crate) fn with_active_registry<R>(f: impl FnOnce(Option<&VarRegistry>) -> R) -> R {
    ACTIVE_REGISTRY.with(|slot| {
        // SAFETY: `<*const T>::as_ref` yields `None` for a null pointer. For a
        // non-null pointer this relies on the module invariant that the slot
        // only ever holds a pointer to a live `VarRegistry` (installed by
        // `NamedForwardTape::into_scope`); the scope bookkeeping in this
        // module must uphold that invariant.
        let active = unsafe { slot.get().as_ref() };
        f(active)
    })
}
#[cfg(debug_assertions)]
#[inline(always)]
pub(crate) fn current_gen() -> u64 {
ACTIVE_GEN.with(|c| c.get())
}
/// Debug-build guard that both operands of a forward-mode operation were
/// created under the same `NamedForwardTape` generation.
///
/// `#[track_caller]` makes the panic report the location of the offending
/// call instead of a line inside this helper.
///
/// # Panics
///
/// Panics when `lhs != rhs`.
#[cfg(debug_assertions)]
#[inline(always)]
#[track_caller]
pub(crate) fn check_gen(lhs: u64, rhs: u64) {
    assert_eq!(
        lhs, rhs,
        "xad_rs::named: cross-registry forward-mode op detected (lhs tape generation = {lhs}, rhs tape generation = {rhs}). \
         Both operands must come from the same NamedForwardTape scope."
    );
}

/// Release-build stand-in: generation tracking is compiled out, so the
/// arguments are unit values and the check is a no-op.
#[cfg(not(debug_assertions))]
#[inline(always)]
// NOTE(review): presumably allowed because not every build configuration has
// a caller — confirm.
#[allow(dead_code)]
pub(crate) fn check_gen(_lhs: (), _rhs: ()) {}
pub struct Jet1Handle<T: Passive = f64> {
idx: usize,
_t: PhantomData<fn() -> T>,
}
impl<T: Passive> Clone for Jet1Handle<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T: Passive> Copy for Jet1Handle<T> {}
impl<T: Passive> fmt::Debug for Jet1Handle<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Jet1Handle").field("idx", &self.idx).finish()
}
}
impl<T: Passive> PartialEq for Jet1Handle<T> {
fn eq(&self, other: &Self) -> bool {
self.idx == other.idx
}
}
impl<T: Passive> Eq for Jet1Handle<T> {}
pub struct Jet1VecHandle<T: Passive = f64> {
idx: usize,
_t: PhantomData<fn() -> T>,
}
impl<T: Passive> Clone for Jet1VecHandle<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T: Passive> Copy for Jet1VecHandle<T> {}
impl<T: Passive> fmt::Debug for Jet1VecHandle<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Jet1VecHandle")
.field("idx", &self.idx)
.finish()
}
}
impl<T: Passive> PartialEq for Jet1VecHandle<T> {
fn eq(&self, other: &Self) -> bool {
self.idx == other.idx
}
}
impl<T: Passive> Eq for Jet1VecHandle<T> {}
pub struct Jet2Handle<T: Passive = f64> {
idx: usize,
_t: PhantomData<fn() -> T>,
}
impl<T: Passive> Clone for Jet2Handle<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T: Passive> Copy for Jet2Handle<T> {}
impl<T: Passive> fmt::Debug for Jet2Handle<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Jet2Handle").field("idx", &self.idx).finish()
}
}
impl<T: Passive> PartialEq for Jet2Handle<T> {
fn eq(&self, other: &Self) -> bool {
self.idx == other.idx
}
}
impl<T: Passive> Eq for Jet2Handle<T> {}
/// Builder for a named forward-mode computation: declare inputs by name, then
/// call `into_scope` to freeze the names into a `VarRegistry` and obtain the
/// materialized inputs.
pub struct NamedForwardTape<T: Passive = f64> {
// Distinct input names in first-declaration order; frozen into the
// `VarRegistry` by `into_scope`.
builder: IndexSet<String>,
// `(name, value)` per `declare_jet1` call; a handle's `idx` is its position here.
pending_jet1: Vec<(String, T)>,
// `(name, value)` per `declare_jet1vec` call; a handle's `idx` is its position here.
pending_jet1vec: Vec<(String, T)>,
// `(name, value)` per `declare_jet2` call; a handle's `idx` is its position here.
pending_jet2: Vec<(String, T)>,
// Generation id issued to this tape (debug builds only). Stored but not read
// in this file, hence the `dead_code` allowance.
#[cfg(debug_assertions)]
#[allow(dead_code)]
gen_id: u64,
// Generation that was active before `new`; restored when the tape is dropped
// unconsumed, or handed to the scope by `into_scope`.
#[cfg(debug_assertions)]
prev_gen: u64,
// Set by `into_scope`; tells `Drop` that the scope now owns the restore duty.
consumed: bool,
// Raw-pointer marker makes the tape `!Send`/`!Sync`: it manipulates
// thread-local state and must be dropped on the thread that created it.
_not_send: PhantomData<*const ()>,
}
impl<T: Passive> NamedForwardTape<T> {
pub fn new() -> Self {
#[cfg(debug_assertions)]
let new_gen = NEXT_GEN.fetch_add(1, Ordering::Relaxed);
#[cfg(debug_assertions)]
let prev_gen = ACTIVE_GEN.with(|c| {
let p = c.get();
c.set(new_gen);
p
});
Self {
builder: IndexSet::new(),
pending_jet1: Vec::new(),
pending_jet1vec: Vec::new(),
pending_jet2: Vec::new(),
#[cfg(debug_assertions)]
gen_id: new_gen,
#[cfg(debug_assertions)]
prev_gen,
consumed: false,
_not_send: PhantomData,
}
}
pub fn declare_jet1(&mut self, name: &str, value: T) -> Jet1Handle<T> {
assert!(
!self.consumed,
"NamedForwardTape::declare_jet1({:?}) called after into_scope",
name
);
if !self.builder.contains(name) {
self.builder.insert(name.to_string());
}
let idx = self.pending_jet1.len();
self.pending_jet1.push((name.to_string(), value));
Jet1Handle {
idx,
_t: PhantomData,
}
}
pub fn declare_jet1vec(&mut self, name: &str, value: T) -> Jet1VecHandle<T> {
assert!(
!self.consumed,
"NamedForwardTape::declare_jet1vec({:?}) called after into_scope",
name
);
if !self.builder.contains(name) {
self.builder.insert(name.to_string());
}
let idx = self.pending_jet1vec.len();
self.pending_jet1vec.push((name.to_string(), value));
Jet1VecHandle {
idx,
_t: PhantomData,
}
}
pub fn declare_jet2(&mut self, name: &str, value: T) -> Jet2Handle<T> {
assert!(
!self.consumed,
"NamedForwardTape::declare_jet2({:?}) called after into_scope",
name
);
if !self.builder.contains(name) {
self.builder.insert(name.to_string());
}
let idx = self.pending_jet2.len();
self.pending_jet2.push((name.to_string(), value));
Jet2Handle {
idx,
_t: PhantomData,
}
}
pub fn into_scope(mut self) -> NamedForwardScope<T> {
assert!(
!self.consumed,
"NamedForwardTape::into_scope called twice"
);
let reg = Arc::new(VarRegistry::from_names(self.builder.iter().cloned()));
let new_ptr: *const VarRegistry = Arc::as_ptr(®);
let prev_registry = ACTIVE_REGISTRY.with(|c| {
let prev = c.get();
c.set(new_ptr);
prev
});
self.consumed = true;
#[cfg(debug_assertions)]
let prev_gen = self.prev_gen;
let pending_jet1 = std::mem::take(&mut self.pending_jet1);
let n_jet1 = pending_jet1.len();
let mut jet1s: Vec<crate::forward::jet1::NamedJet1<T>> = Vec::with_capacity(n_jet1);
for (_name, value) in pending_jet1.into_iter() {
let inner = crate::forward::jet1::Jet1::<T>::new(value, T::one());
jet1s.push(crate::forward::jet1::NamedJet1::<T>::__from_inner(inner));
}
let pending_jet1vec = std::mem::take(&mut self.pending_jet1vec);
let n_jet1vec = pending_jet1vec.len();
let mut jet1vecs: Vec<crate::forward::jet1vec::NamedJet1Vec<T>> =
Vec::with_capacity(n_jet1vec);
for (name, value) in pending_jet1vec.into_iter() {
let reg_idx = reg
.index_of(&name)
.expect("declared name missing from frozen registry");
let inner = crate::forward::jet1vec::Jet1Vec::<T>::variable(value, reg_idx, n_jet1vec);
jet1vecs.push(crate::forward::jet1vec::NamedJet1Vec::<T>::__from_inner(
inner,
));
}
let pending_jet2 = std::mem::take(&mut self.pending_jet2);
let mut jet2s: Vec<crate::forward::jet2::NamedJet2<T>> =
Vec::with_capacity(pending_jet2.len());
for (name, value) in pending_jet2.into_iter() {
let reg_idx = reg
.index_of(&name)
.expect("declared name missing from frozen registry");
let inner = crate::forward::jet2::Jet2::<T>::variable(value);
jet2s.push(crate::forward::jet2::NamedJet2::<T>::__from_parts(
inner,
Some(reg_idx),
));
}
NamedForwardScope {
registry: reg,
jet1s,
jet1vecs,
jet2s,
prev_registry,
#[cfg(debug_assertions)]
prev_gen,
_not_send: PhantomData,
}
}
pub fn deactivate_all() {
ACTIVE_REGISTRY.with(|c| c.set(std::ptr::null()));
#[cfg(debug_assertions)]
ACTIVE_GEN.with(|c| c.set(0));
}
}
impl<T: Passive> Default for NamedForwardTape<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Passive> fmt::Debug for NamedForwardTape<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NamedForwardTape")
.field("inputs", &self.builder.len())
.field("pending_jet1", &self.pending_jet1.len())
.field("pending_jet1vec", &self.pending_jet1vec.len())
.field("pending_jet2", &self.pending_jet2.len())
.field("consumed", &self.consumed)
.finish()
}
}
impl<T: Passive> Drop for NamedForwardTape<T> {
fn drop(&mut self) {
if !self.consumed {
#[cfg(debug_assertions)]
ACTIVE_GEN.with(|c| c.set(self.prev_gen));
}
}
}
pub struct NamedForwardScope<T: Passive = f64> {
registry: Arc<VarRegistry>,
jet1s: Vec<crate::forward::jet1::NamedJet1<T>>,
jet1vecs: Vec<crate::forward::jet1vec::NamedJet1Vec<T>>,
jet2s: Vec<crate::forward::jet2::NamedJet2<T>>,
prev_registry: *const VarRegistry,
#[cfg(debug_assertions)]
prev_gen: u64,
_not_send: PhantomData<*const ()>,
}
impl<T: Passive> NamedForwardScope<T> {
#[inline]
pub fn jet1(&self, handle: Jet1Handle<T>) -> &crate::forward::jet1::NamedJet1<T> {
&self.jet1s[handle.idx]
}
#[inline]
pub fn jet1vec(&self, handle: Jet1VecHandle<T>) -> &crate::forward::jet1vec::NamedJet1Vec<T> {
&self.jet1vecs[handle.idx]
}
#[inline]
pub fn jet2(&self, handle: Jet2Handle<T>) -> &crate::forward::jet2::NamedJet2<T> {
&self.jet2s[handle.idx]
}
#[inline]
pub fn constant_jet1(&self, value: T) -> crate::forward::jet1::NamedJet1<T> {
let inner = crate::forward::jet1::Jet1::<T>::constant(value);
crate::forward::jet1::NamedJet1::<T>::__from_inner(inner)
}
#[inline]
pub fn constant_jet1vec(&self, value: T) -> crate::forward::jet1vec::NamedJet1Vec<T> {
let inner = crate::forward::jet1vec::Jet1Vec::<T>::constant(value, self.registry.len());
crate::forward::jet1vec::NamedJet1Vec::<T>::__from_inner(inner)
}
#[inline]
pub fn constant_jet2(&self, value: T) -> crate::forward::jet2::NamedJet2<T> {
let inner = crate::forward::jet2::Jet2::<T>::constant(value);
crate::forward::jet2::NamedJet2::<T>::__from_parts(inner, None)
}
#[inline]
pub fn registry(&self) -> &Arc<VarRegistry> {
&self.registry
}
}
impl<T: Passive> fmt::Debug for NamedForwardScope<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NamedForwardScope")
.field("registry_len", &self.registry.len())
.field("jet1s", &self.jet1s.len())
.field("jet1vecs", &self.jet1vecs.len())
.field("jet2s", &self.jet2s.len())
.finish()
}
}
impl<T: Passive> Drop for NamedForwardScope<T> {
fn drop(&mut self) {
ACTIVE_REGISTRY.with(|c| c.set(self.prev_registry));
#[cfg(debug_assertions)]
ACTIVE_GEN.with(|c| c.set(self.prev_gen));
}
}