use std::cell::RefCell;
use std::marker::PhantomData;
use std::rc::Rc;
use proc_macro2::Span;
use quote::quote;
use stageleft::runtime_support::{FreeVariableWithContextWithProps, QuoteTokens};
use crate::compile::ir::{AccessCounter, HydroNode, SharedNode};
use crate::location::Location;
/// The shape of the value that a handoff reference points at.
///
/// The handle types generated further down in this file pair each variant with a
/// borrowed output type: `Singleton` with `&T` / `&mut T`, `Optional` with
/// `&Option<T>` / `&mut Option<T>`, and `Vec` with `&Vec<T>` / `&mut Vec<T>`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum HandoffRefKind {
/// Used by `SingletonRef` / `SingletonMut`.
Singleton,
/// Used by `OptionalRef` / `OptionalMut`.
Optional,
/// Used by `StreamRef` / `StreamMut`.
Vec,
}
// Per-thread state of the handoff-reference capture scope.
//
// `None` means no capture scope is active on this thread. `Some(refs)` holds, in
// registration order, every reference registered since the scope was opened; each
// entry pairs a `HydroNode::Reference` with the `is_mut` flag it was registered with.
// The position of an entry in the vector is the index used to name its identifier.
thread_local! {
static CAPTURED_REFS: RefCell<Option<Vec<(HydroNode, bool)>>> = const { RefCell::new(None) };
}
/// Builds the identifier that names the captured handoff reference at position `index`.
///
/// The identifier is `__hydro_singleton_ref_<index>` with a call-site span. The same
/// prefix is used for every [`HandoffRefKind`], not only singletons.
pub(crate) fn handoff_ref_ident(index: usize) -> syn::Ident {
    let name = format!("__hydro_singleton_ref_{index}");
    syn::Ident::new(&name, Span::call_site())
}
/// Runs `f` inside a fresh handoff-reference capture scope and bundles the expression
/// it returns with every reference that was registered while it ran.
///
/// The scope is stored in a thread-local, so `f` must register its references on the
/// calling thread. The scope is always closed when this function exits, including when
/// `f` panics, so a failed call does not leave the thread in a state where every later
/// call is rejected as "nested".
///
/// # Panics
///
/// Panics if a capture scope is already active on this thread; nested scopes are not
/// supported. The already-active scope is left untouched in that case.
pub fn with_ref_capture(
    f: impl FnOnce() -> crate::compile::ir::DebugExpr,
) -> crate::compile::ir::ClosureExpr {
    /// Closes the capture scope when dropped, so unwinding out of `f` cannot leak it.
    struct CloseScopeOnDrop;

    impl Drop for CloseScopeOnDrop {
        fn drop(&mut self) {
            // This may run while unwinding, where a second panic would abort the
            // process, so only the non-panicking accessors are used here.
            let _ = CAPTURED_REFS.try_with(|cell| {
                if let Ok(mut slot) = cell.try_borrow_mut() {
                    *slot = None;
                }
            });
        }
    }

    CAPTURED_REFS.with(|cell| {
        let mut slot = cell.borrow_mut();
        // Check before installing the new scope so that a rejected nested call does
        // not discard the references already captured by the outer scope.
        assert!(
            slot.is_none(),
            "nested handoff reference capture scopes are not supported"
        );
        *slot = Some(Vec::new());
    });
    // Armed only after the scope was opened by this call; a rejected nested call
    // therefore never closes the outer scope.
    let _close_scope = CloseScopeOnDrop;

    let expr = f();
    let captured_refs = CAPTURED_REFS.with(|cell| cell.borrow_mut().take().unwrap());
    crate::compile::ir::ClosureExpr::new(expr, captured_refs)
}
/// Records one use of a handoff reference in the active capture scope and returns the
/// identifier under which the quoted code refers to it.
///
/// If the node behind `ir_node` is not yet a `HydroNode::Reference`, it is wrapped in
/// place: the original node moves into a shared cell and `ir_node` becomes a
/// `Reference` pointing at it. The entry pushed onto the capture scope is a second
/// `Reference` that shares the same inner node and carries the access group obtained
/// from `next_group(is_mut)`.
///
/// # Panics
///
/// Panics if no capture scope is active on this thread.
fn register_handoff_ref(
    ir_node: &RefCell<HydroNode>,
    is_mut: bool,
    kind: HandoffRefKind,
) -> syn::Ident {
    CAPTURED_REFS.with(|scope_cell| {
        let mut scope = scope_cell.borrow_mut();
        let captured = scope.as_mut().expect(
            "HandoffRef used inside q!() but no reference capture scope is active. \
             This is a bug — reference capture should be set up by the operator that uses q!().",
        );

        // The entry's position in the scope determines the identifier's numeric suffix.
        let slot = captured.len();
        let metadata = ir_node.borrow().metadata().clone();

        // First use of this node: turn it into a `Reference` wrapping the original.
        let already_wrapped = matches!(&*ir_node.borrow(), HydroNode::Reference { .. });
        if !already_wrapped {
            let mut node = ir_node.borrow_mut();
            let original = std::mem::replace(&mut *node, HydroNode::Placeholder);
            *node = HydroNode::Reference {
                inner: SharedNode(Rc::new(RefCell::new(original))),
                kind,
                access_counter: AccessCounter::new(),
                metadata: metadata.clone(),
            };
        }

        let node = ir_node.borrow();
        let (shared, counter) = match &*node {
            HydroNode::Reference {
                inner,
                access_counter,
                ..
            } => (inner, access_counter),
            // The node was either a `Reference` already or was just wrapped above.
            _ => unreachable!(),
        };

        let captured_node = HydroNode::Reference {
            inner: SharedNode(Rc::clone(&shared.0)),
            kind,
            access_counter: counter.next_group(is_mut),
            metadata,
        };
        captured.push((captured_node, is_mut));

        handoff_ref_ident(slot)
    })
}
/// Generates one handoff-reference handle type per entry.
///
/// Each entry is `Name, is_mut, kind, OutputType`, optionally preceded by attributes
/// that are forwarded to the generated struct. `OutputType` may mention `'a` and `T`,
/// which are the generic parameters of the generated trait impl. For every entry the
/// macro emits:
///
/// * a `Copy` struct holding a borrowed `RefCell<HydroNode>`,
/// * a crate-private `new` constructor, and
/// * a `FreeVariableWithContextWithProps` impl whose `to_tokens` registers the
///   reference with the active capture scope and expands to the returned identifier.
macro_rules! define_handoff_ref {
    (
        $(
            $(#[$attr:meta])*
            $handle:ident, $mutable:expr, $ref_kind:expr, $borrowed:ty
        )+
    ) => {
        $(
            $(#[$attr])*
            pub struct $handle<'a, 'slf, T, L> {
                pub(crate) ir_node: &'slf RefCell<HydroNode>,
                _phantom: PhantomData<(&'a T, L)>,
            }

            impl<'slf, T, L> $handle<'_, 'slf, T, L> {
                pub(crate) fn new(ir_node: &'slf RefCell<HydroNode>) -> Self {
                    Self {
                        ir_node,
                        _phantom: PhantomData,
                    }
                }
            }

            // Implemented by hand rather than derived so that no `Clone`/`Copy`
            // bounds are placed on `T` or `L`.
            impl<T, L> Clone for $handle<'_, '_, T, L> {
                fn clone(&self) -> Self {
                    *self
                }
            }

            impl<T, L> Copy for $handle<'_, '_, T, L> {}

            impl<'a, 'slf, T, L> FreeVariableWithContextWithProps<L, ()> for $handle<'a, 'slf, T, L>
            where
                T: 'a,
                L: Location<'a>,
            {
                type O = $borrowed;

                fn to_tokens(self, _ctx: &L) -> (QuoteTokens, ()) {
                    let ref_ident = register_handoff_ref(self.ir_node, $mutable, $ref_kind);
                    let tokens = QuoteTokens {
                        prelude: None,
                        expr: Some(quote!(#ref_ident)),
                    };
                    (tokens, ())
                }
            }
        )+
    };
}
// Instantiates the six handoff-reference handle types. Each row is
// `Name, is_mut, kind, output type`: the `*Ref` types register an immutable access and
// yield a shared borrow, the `*Mut` types register a mutable access and yield a
// mutable borrow.
// NOTE(review): `stageleft::export` presumably makes the listed type names visible to
// stageleft's path rewriting — confirm against the stageleft documentation.
#[stageleft::export(
SingletonRef,
SingletonMut,
OptionalRef,
OptionalMut,
StreamRef,
StreamMut
)]
define_handoff_ref!(
SingletonRef, false, HandoffRefKind::Singleton, &'a T
SingletonMut, true, HandoffRefKind::Singleton, &'a mut T
OptionalRef, false, HandoffRefKind::Optional, &'a Option<T>
OptionalMut, true, HandoffRefKind::Optional, &'a mut Option<T>
StreamRef, false, HandoffRefKind::Vec, &'a Vec<T>
StreamMut, true, HandoffRefKind::Vec, &'a mut Vec<T>
);
// Smoke tests for handoff references used inside `q!` closures.
//
// Every test builds a small flow on a single process, uses a `by_ref()` / `by_mut()`
// handle inside the closure of one operator, also consumes the referenced collection
// directly, and then finalizes the flow. None of the tests contains an assertion, so a
// test passes when building and finalizing the flow completes without panicking.
#[cfg(test)]
#[cfg(feature = "build")]
mod tests {
use stageleft::q;
use crate::compile::builder::FlowBuilder;
use crate::location::Location;
// Marker type used as the process tag in `flow.process::<P1>()`.
struct P1 {}
// Shared read of a folded `i32` from inside a `map` closure.
#[test]
fn singleton_by_ref_compiles() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_ref = my_count.by_ref();
node.source_iter(q!(1..=3i32))
.map(q!(|x| x + *count_ref))
.for_each(q!(|_| {}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read of a folded `Vec<i32>` (a non-`Copy` value) from inside a `map` closure.
#[test]
fn singleton_by_ref_non_copy() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_vec = node.source_iter(q!(0..5i32)).fold(
q!(|| Vec::<i32>::new()),
q!(|acc: &mut Vec<i32>, x| acc.push(x)),
);
let vec_ref = my_vec.by_ref();
node.source_iter(q!(1..=3i32))
.map(q!(|x| x + vec_ref.len() as i32))
.for_each(q!(|_| {}));
my_vec.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read from inside a `filter` predicate.
#[test]
fn singleton_by_ref_filter() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let threshold = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let threshold_ref = threshold.by_ref();
node.source_iter(q!(1..=10i32))
.filter(q!(|x| *x > *threshold_ref))
.for_each(q!(|_| {}));
threshold.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read from inside a `flat_map_ordered` closure.
#[test]
fn singleton_by_ref_flat_map() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let count = node
.source_iter(q!(0..3i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
let count_ref = count.by_ref();
node.source_iter(q!(1..=2i32))
.flat_map_ordered(q!(|x| (0..*count_ref).map(move |i| x + i)))
.for_each(q!(|_| {}));
count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read from inside an `inspect` closure.
#[test]
fn singleton_by_ref_inspect() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, _| *acc += 1));
let count_ref = count.by_ref();
node.source_iter(q!(1..=3i32))
.inspect(q!(|x| println!("count={}, x={}", *count_ref, x)))
.for_each(q!(|_| {}));
count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read from inside a `partition` predicate; both outputs are consumed directly.
#[test]
fn singleton_by_ref_partition() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let threshold = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let threshold_ref = threshold.by_ref();
let (above, below) = node
.source_iter(q!(1..=10i32))
.partition(q!(|x| *x > *threshold_ref));
above.for_each(q!(|_| {}));
below.for_each(q!(|_| {}));
threshold.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Same as the partition test above, but each output goes through a `map` before
// being consumed.
#[test]
fn singleton_by_ref_partition_with_downstream_ops() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let threshold = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let threshold_ref = threshold.by_ref();
let (above, below) = node
.source_iter(q!(1..=10i32))
.partition(q!(|x| *x > *threshold_ref));
above.map(q!(|x| x * 2)).for_each(q!(|_| {}));
below.map(q!(|x| x + 100)).for_each(q!(|_| {}));
threshold.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Mutation of a folded `i32` through a `by_mut()` handle inside a `map` closure.
#[test]
fn singleton_by_mut_compiles() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_mut = my_count.by_mut();
node.source_iter(q!(1..=3i32))
.map(q!(|x| {
*count_mut += x;
x
}))
.for_each(q!(|_| {}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read of a `reduce` result from inside a `map` closure; the closure calls
// `unwrap_or(0)` on the handle.
#[test]
fn optional_by_ref_compiles() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_opt = node.source_iter(q!(0..5i32)).reduce(q!(|a, b| *a += b));
let opt_ref = my_opt.by_ref();
node.source_iter(q!(1..=3i32))
.map(q!(|x| x + opt_ref.unwrap_or(0)))
.for_each(q!(|_| {}));
my_opt.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read of a stream from inside a `map` closure; the closure calls `len()` on
// the handle.
#[test]
fn stream_by_ref_compiles() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_stream = node.source_iter(q!(0..5i32));
let stream_ref = my_stream.by_ref();
node.source_iter(q!(1..=3i32))
.map(q!(|x| x + stream_ref.len() as i32))
.for_each(q!(|_| {}));
my_stream.for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Mutation through a `by_mut()` handle inside a `filter` predicate.
#[test]
fn singleton_by_mut_filter() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_mut = my_count.by_mut();
node.source_iter(q!(1..=3i32))
.filter(q!(|x| {
*count_mut += *x;
*count_mut > 0
}))
.for_each(q!(|_| {}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Mutation through a `by_mut()` handle inside a `flat_map_ordered` closure.
#[test]
fn singleton_by_mut_flat_map() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_mut = my_count.by_mut();
node.source_iter(q!(1..=3i32))
.flat_map_ordered(q!(|x| {
*count_mut += x;
vec![*count_mut]
}))
.for_each(q!(|_| {}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Mutation through a `by_mut()` handle inside a `filter_map` closure.
#[test]
fn singleton_by_mut_filter_map() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_mut = my_count.by_mut();
node.source_iter(q!(1..=3i32))
.filter_map(q!(|x| {
*count_mut += x;
Some(*count_mut)
}))
.for_each(q!(|_| {}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Mutation through a `by_mut()` handle inside an `inspect` closure.
#[test]
fn singleton_by_mut_inspect() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_mut = my_count.by_mut();
node.source_iter(q!(1..=3i32))
.inspect(q!(|x| {
*count_mut += *x;
}))
.for_each(q!(|_| {}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Shared read from inside a terminal `for_each` closure.
#[test]
fn singleton_by_ref_for_each() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_ref = my_count.by_ref();
node.source_iter(q!(1..=3i32))
.for_each(q!(|x| println!("{}", x + *count_ref)));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
// Mutation through a `by_mut()` handle inside a terminal `for_each` closure.
#[test]
fn singleton_by_mut_for_each() {
let mut flow = FlowBuilder::new();
let node = flow.process::<P1>();
let my_count = node
.source_iter(q!(0..5i32))
.fold(q!(|| 0i32), q!(|acc: &mut i32, x| *acc += x));
let count_mut = my_count.by_mut();
node.source_iter(q!(1..=3i32)).for_each(q!(|x| {
*count_mut += x;
}));
my_count.into_stream().for_each(q!(|_| {}));
let _built = flow.finalize();
}
}