use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;
use proc_macro2::Span;
use quote::quote;
use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
use crate::compile::ir::{HydroNode, SharedNode};
use crate::location::Location;
/// A lightweight, copyable handle to the IR node of a singleton, usable as a free variable
/// inside `q!()` closures.
///
/// When captured, a read-only handle (`IS_MUT = false`) is exposed to the quoted code as
/// `&'a T`, while `SingletonMut` (`IS_MUT = true`) is exposed as `&'a mut T`.
pub struct SingletonRef<'a, 'slf, T, L, const IS_MUT: bool = false> {
    // Records the element type `T`, its lifetime `'a`, and the location type `L` without
    // storing a value of any of them.
    _phantom: PhantomData<(&'a T, L)>,
    /// The IR node this handle refers to. On first capture it is rewritten in place into a
    /// `HydroNode::Singleton` so that it can be shared.
    pub(crate) ir_node: &'slf RefCell<HydroNode>,
}

/// A `SingletonRef` whose capturing closure receives mutable access (`&'a mut T`).
pub type SingletonMut<'a, 'slf, T, L> = SingletonRef<'a, 'slf, T, L, true>;
impl<'a, 'slf, T, L, const IS_MUT: bool> SingletonRef<'a, 'slf, T, L, IS_MUT> {
    /// Creates a handle that refers to the IR node stored in `ir_node`.
    pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
        Self {
            ir_node,
            _phantom: PhantomData,
        }
    }

    /// Returns a read-only handle (`IS_MUT = false`) to the same IR node.
    ///
    /// The result keeps this handle's `'a` lifetime. It is deliberately not tied to the
    /// temporary borrow of `self`, because `'a` is the lifetime that
    /// `L: Location<'a>` and the captured `&'a T` are expressed in.
    pub fn as_ref(&self) -> SingletonRef<'a, 'slf, T, L, false> {
        SingletonRef {
            ir_node: self.ir_node,
            _phantom: PhantomData,
        }
    }

    /// Returns a mutable handle (`IS_MUT = true`) to the same IR node, which a `q!()` closure
    /// captures as `&'a mut T`.
    ///
    /// Like `as_ref`, the result keeps this handle's `'a` lifetime. Because the handle is
    /// `Copy`, this only needs `&self`.
    pub fn as_mut(&self) -> SingletonRef<'a, 'slf, T, L, true> {
        SingletonRef {
            ir_node: self.ir_node,
            _phantom: PhantomData,
        }
    }
}
// `Clone` and `Copy` are implemented by hand: `#[derive]` would add `T: Clone`/`T: Copy` and
// `L: Clone`/`L: Copy` bounds, even though the handle only holds a shared reference and a
// `PhantomData`.
impl<T, L, const IS_MUT: bool> Clone for SingletonRef<'_, '_, T, L, IS_MUT> {
    fn clone(&self) -> Self {
        // Plain bitwise copy, consistent with the `Copy` impl that follows.
        *self
    }
}

impl<T, L, const IS_MUT: bool> Copy for SingletonRef<'_, '_, T, L, IS_MUT> {}
thread_local! {
/// Singleton refs collected by the currently active capture scope on this thread.
///
/// `None` means no capture scope is active. `Some(refs)` holds, in capture order, one
/// `HydroNode::Singleton` per captured ref together with its `IS_MUT` flag. A ref's position
/// in the vector is the index used to build its stand-in identifier.
static SINGLETON_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
}
/// Builds the identifier that stands in for the `index`-th captured singleton ref inside the
/// quoted closure code, e.g. `__hydro_singleton_ref_0`.
pub(crate) fn singleton_ref_ident(index: usize) -> syn::Ident {
    let name = format!("__hydro_singleton_ref_{index}");
    syn::Ident::new(&name, Span::call_site())
}
/// Runs `f` inside a fresh singleton capture scope and pairs the expression it returns with
/// every singleton ref captured while it ran.
///
/// Each `SingletonRef` that is turned into tokens (via `q!()`) during `f` registers its IR node
/// in the active scope. The collected `(node, is_mut)` pairs are passed, in capture order, to
/// `ClosureExpr::new` together with `f`'s result.
///
/// # Panics
///
/// Panics if a capture scope is already active on the current thread, because scopes do not
/// nest. In that case the enclosing scope is left untouched. Any panic raised by `f` is
/// propagated, but the scope is torn down first so that later captures on this thread still
/// work.
pub fn with_singleton_capture(
    f: impl FnOnce() -> crate::compile::ir::DebugExpr,
) -> crate::compile::ir::ClosureExpr {
    /// Deactivates the capture scope when dropped. Without it, a panic inside `f` would leave
    /// a stale scope behind that trips the nesting assertion on every later call.
    struct ScopeGuard;

    impl Drop for ScopeGuard {
        fn drop(&mut self) {
            // `try_*` variants: this may run during unwinding, where a second panic aborts.
            let _ = SINGLETON_REFS.try_with(|cell| {
                if let Ok(mut refs) = cell.try_borrow_mut() {
                    *refs = None;
                }
            });
        }
    }

    SINGLETON_REFS.with(|cell| {
        let mut refs = cell.borrow_mut();
        // Check before installing, so a rejected nested call cannot discard the refs the
        // enclosing scope has already collected.
        assert!(
            refs.is_none(),
            "nested singleton capture scopes are not supported"
        );
        *refs = Some(Vec::new());
    });
    let guard = ScopeGuard;

    let expr = f();

    // Normal exit: close the scope here; the guard then has nothing left to clean up.
    let singleton_refs = SINGLETON_REFS.with(|cell| cell.borrow_mut().take());
    drop(guard);

    crate::compile::ir::ClosureExpr::new(
        expr,
        singleton_refs.expect("singleton capture scope was closed while still active"),
    )
}
impl<'a, 'slf, T: 'a, L, const IS_MUT: bool> SingletonRef<'a, 'slf, T, L, IS_MUT>
where
    L: Location<'a>,
{
    /// Registers this ref in the active singleton capture scope and returns the tokens that
    /// stand in for it inside the quoted closure (a fresh `__hydro_singleton_ref_N` identifier).
    ///
    /// On the first capture of a node, the node is replaced in place by a
    /// `HydroNode::Singleton` that wraps the original node behind a shared `Rc`. Every capture,
    /// including the first, then records another `HydroNode::Singleton` pointing at that same
    /// inner node, paired with `IS_MUT`.
    ///
    /// # Panics
    ///
    /// Panics if no capture scope is active on the current thread.
    fn to_tokens_helper(self, _ctx: &L) -> (QuoteTokens, ()) {
        let ident = SINGLETON_REFS.with(|scope_cell| {
            let mut scope = scope_cell.borrow_mut();
            let captured = scope.as_mut().expect(
                "SingletonRef used inside q!() but no singleton capture scope is active. \
                 This is a bug — singleton capture should be set up by the operator that uses q!().",
            );

            // The next free slot doubles as this capture's identifier index.
            let ident = singleton_ref_ident(captured.len());
            let metadata = self.ir_node.borrow().metadata().clone();

            let already_wrapped = matches!(&*self.ir_node.borrow(), HydroNode::Singleton { .. });
            if !already_wrapped {
                let mut node = self.ir_node.borrow_mut();
                let original = std::mem::replace(&mut *node, HydroNode::Placeholder);
                *node = HydroNode::Singleton {
                    inner: SharedNode(Rc::new(RefCell::new(original))),
                    metadata: metadata.clone(),
                };
            }

            let shared_inner = match &*self.ir_node.borrow() {
                HydroNode::Singleton { inner, .. } => Rc::clone(&inner.0),
                _ => unreachable!(),
            };

            captured.push((
                HydroNode::Singleton {
                    inner: SharedNode(shared_inner),
                    metadata,
                },
                IS_MUT,
            ));
            ident
        });

        let tokens = QuoteTokens {
            prelude: None,
            expr: Some(quote!(#ident)),
        };
        (tokens, ())
    }
}
/// Lets a read-only `SingletonRef` be captured by `q!()`; the quoted code sees it as `&'a T`.
impl<'a, 'slf, T, L> FreeVariableWithContextWithProps<L, ()> for SingletonRef<'a, 'slf, T, L>
where
    T: 'a,
    L: Location<'a>,
{
    type O = &'a T;

    fn to_tokens(self, location: &L) -> (QuoteTokens, ()) {
        Self::to_tokens_helper(self, location)
    }
}
/// Lets a `SingletonMut` be captured by `q!()`; the quoted code sees it as `&'a mut T`.
impl<'a, 'slf, T, L> FreeVariableWithContextWithProps<L, ()> for SingletonMut<'a, 'slf, T, L>
where
    T: 'a,
    L: Location<'a>,
{
    type O = &'a mut T;

    fn to_tokens(self, location: &L) -> (QuoteTokens, ()) {
        Self::to_tokens_helper(self, location)
    }
}
#[cfg(all(test, feature = "build"))]
mod tests {
    use stageleft::q;

    use crate::compile::builder::FlowBuilder;
    use crate::location::Location;

    /// Marker type for the single process used by these tests.
    struct P1 {}

    /// A folded `i32` singleton can be read by reference from a `map` closure.
    #[test]
    fn singleton_by_ref_compiles() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let my_count = p1
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let count_ref = my_count.by_ref();

        p1.source_iter(q!(1..=3i32))
            .map(q!(|x| x + *count_ref))
            .for_each(q!(|_| {}));
        // The singleton is still consumed normally after being borrowed.
        my_count.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }

    /// A non-`Copy` singleton (`Vec<i32>`) can be borrowed and have methods called on it.
    #[test]
    fn singleton_by_ref_non_copy() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let my_vec = p1.source_iter(q!(0..5i32)).fold(
            q!(|| Vec::<i32>::new()),
            q!(|acc: &mut Vec<i32>, x| acc.push(x)),
        );
        let vec_ref = my_vec.by_ref();

        p1.source_iter(q!(1..=3i32))
            .map(q!(|x| x + vec_ref.len() as i32))
            .for_each(q!(|_| {}));
        my_vec.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }

    /// A borrowed singleton can be used inside a `filter` predicate.
    #[test]
    fn singleton_by_ref_filter() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let threshold = p1
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        p1.source_iter(q!(1..=10i32))
            .filter(q!(|x| *x > *threshold_ref))
            .for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }

    /// A borrowed singleton can bound an inner iterator produced by `flat_map_ordered`.
    #[test]
    fn singleton_by_ref_flat_map() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let count = p1
            .source_iter(q!(0..3i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        p1.source_iter(q!(1..=2i32))
            .flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
            .for_each(q!(|_| {}));
        count.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }

    /// A borrowed singleton can be read from an `inspect` side-effect closure.
    #[test]
    fn singleton_by_ref_inspect() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let count = p1
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
        let count_ref = count.by_ref();

        p1.source_iter(q!(1..=3i32))
            .inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
            .for_each(q!(|_| {}));
        count.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }

    /// A borrowed singleton can drive a `partition` predicate, with both halves consumed.
    #[test]
    fn singleton_by_ref_partition() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let threshold = p1
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = p1
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));
        above.for_each(q!(|_| {}));
        below.for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }

    /// Same as `singleton_by_ref_partition`, but each half feeds further operators.
    #[test]
    fn singleton_by_ref_partition_with_downstream_ops() {
        let mut builder = FlowBuilder::new();
        let p1 = builder.process::<P1>();

        let threshold = p1
            .source_iter(q!(0..5i32))
            .fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
        let threshold_ref = threshold.by_ref();

        let (above, below) = p1
            .source_iter(q!(1..=10i32))
            .partition(q!(|x| *x > *threshold_ref));
        above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
        below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
        threshold.into_stream().for_each(q!(|_| {}));

        let _finalized = builder.finalize();
    }
}