#![allow(clippy::arc_with_non_send_sync)]
#![allow(clippy::collapsible_if, clippy::collapsible_match)]
#![allow(clippy::too_many_arguments, clippy::too_many_lines)]
use std::cell::Cell;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex, Weak};
use ahash::AHashMap;
use graphrefly_core::{Core, FnId, HandleId, Message, NodeId, Sink};
use smallvec::SmallVec;
use super::producer::{ProducerBinding, ProducerCtx, ProducerEmitter, SubGuard};
/// Boxed projection callback used by the higher-order operators in this file.
///
/// It receives a [`HandleId`] and returns a [`NodeId`]. The operators pass it to
/// [`HigherOrderBinding::register_project`] and later subscribe to the node that
/// [`HigherOrderBinding::invoke_project`] returns for each outer value handle.
/// The `Send + Sync` bounds let the boxed closure be shared across threads.
pub type ProjectFn = Box<dyn Fn(HandleId) -> NodeId + Send + Sync>;
/// Binding capabilities required by the higher-order operators, on top of
/// everything [`ProducerBinding`] already provides.
pub trait HigherOrderBinding: ProducerBinding {
/// Stores `project` with the binding and returns the [`FnId`] under which it
/// can later be invoked.
fn register_project(&self, project: ProjectFn) -> FnId;
/// Invokes the projection registered as `fn_id` with the outer value handle
/// `value` and returns the resulting inner node.
///
/// The operators in this file release their own retain on `value` right
/// after this call returns — presumably the implementation must not rely on
/// the caller keeping the handle alive afterwards; confirm against the
/// binding implementation.
fn invoke_project(&self, fn_id: FnId, value: HandleId) -> NodeId;
}
/// Builds the sink that is subscribed to an inner node on behalf of the
/// producer `producer_id`.
///
/// Each incoming batch is handled in two passes:
///
/// 1. every message is classified by `tier()`, and any payload handle that is
///    going to be forwarded (tier 3) or reported as an error (tier 5 with a
///    payload) is retained through `producer_binding`;
/// 2. the classified steps are dispatched in message order.
///
/// Dispatch rules:
/// * tier 3 with a payload — forwarded via `em.emit_or_defer`;
/// * tier 4 — a deferred `invalidate(producer_id)`;
/// * tier 5 with a payload — `on_inner_error(handle)`;
/// * tier 5 without a payload — `on_inner_complete()`;
/// * tier 6 — a deferred `teardown(producer_id)`;
/// * any other tier, and tier 3 without a payload, is ignored.
fn build_inner_sink(
    em: ProducerEmitter,
    producer_binding: Arc<dyn ProducerBinding>,
    producer_id: NodeId,
    on_inner_complete: Arc<dyn Fn()>,
    on_inner_error: Arc<dyn Fn(HandleId)>,
) -> Sink {
    /// One classified inner message, ready to be dispatched.
    enum Step {
        Forward(HandleId),
        Done,
        Fail(HandleId),
        Invalidate,
        Teardown,
    }

    Arc::new(move |batch: &[Message]| {
        // Pass 1: classify the whole batch and take the retains up front, so
        // every handle is pinned before any callback gets a chance to run.
        let steps: SmallVec<[Step; 4]> = batch
            .iter()
            .filter_map(|msg| match msg.tier() {
                3 => {
                    let handle = msg.payload_handle()?;
                    producer_binding.retain_handle(handle);
                    Some(Step::Forward(handle))
                }
                4 => Some(Step::Invalidate),
                5 => Some(match msg.payload_handle() {
                    Some(handle) => {
                        producer_binding.retain_handle(handle);
                        Step::Fail(handle)
                    }
                    None => Step::Done,
                }),
                6 => Some(Step::Teardown),
                _ => None,
            })
            .collect();

        // Pass 2: dispatch in the order the messages arrived.
        for step in steps {
            match step {
                Step::Forward(handle) => em.emit_or_defer(producer_id, handle),
                Step::Done => on_inner_complete(),
                Step::Fail(handle) => on_inner_error(handle),
                Step::Invalidate => {
                    let _ = em.defer(move |c| c.invalidate(producer_id));
                }
                Step::Teardown => {
                    let _ = em.defer(move |c| c.teardown(producer_id));
                }
            }
        }
    })
}
/// Shared mutable state of one `switch_map` producer activation.
#[derive(Default)]
struct SwitchState {
    /// Guard of the currently installed inner subscription, if any.
    inner_sub: Option<SubGuard>,
    /// Incremented for every outer value that starts a new inner; a deferred
    /// subscribe compares against it to detect that it has been superseded.
    epoch: u64,
    /// Set once the outer source has delivered a payload-less tier-5 message.
    source_done: bool,
    /// Set once the output has completed or errored; later events are ignored.
    terminated: bool,
}

impl SwitchState {
    /// Creates the initial state: no inner subscription, epoch `0`, nothing
    /// done or terminated.
    fn new() -> Self {
        Self::default()
    }
}
/// Registers a producer node that maps every value of `source` to an inner node
/// via `project` and mirrors only the most recently created inner node.
///
/// When a new outer value arrives, the previous inner subscription guard is
/// dropped, `project` is invoked for the new value and a subscribe to the
/// resulting inner node is deferred through the emitter. If several outer
/// values arrive in one batch, only the last one is projected.
///
/// Termination:
/// * a tier-5 outer message carrying a payload errors the output with that
///   handle;
/// * a payload-less tier-5 outer message marks the source as done; the output
///   completes right away if no inner subscription is installed and no value is
///   pending in the same batch, otherwise when the inner completes (see
///   `make_switch_on_complete`).
///
/// The binding is only held weakly by the producer build closure; if it has
/// been dropped by the time the producer is built, the build does nothing.
///
/// # Panics
///
/// Panics if the internal state mutex is poisoned, if `register_producer`
/// returns an error, or if the deferred inner subscribe reports a
/// partition-order violation.
#[must_use]
pub fn switch_map(
core: &Core,
binding: &Arc<dyn HigherOrderBinding>,
source: NodeId,
project: ProjectFn,
) -> NodeId {
let project_fn_id = binding.register_project(project);
// Weak references only: the build closure is handed to the binding itself
// (`register_producer_build` below), so a strong capture would make the
// binding own a reference to itself.
let binding_weak: Weak<dyn HigherOrderBinding> = Arc::downgrade(binding);
let producer_binding_weak: Weak<dyn ProducerBinding> =
Arc::downgrade(&(binding.clone() as Arc<dyn ProducerBinding>));
let build = Box::new(move |ctx: ProducerCtx<'_>| {
let producer_id = ctx.node_id();
let (Some(binding_clone), Some(producer_binding)) =
(binding_weak.upgrade(), producer_binding_weak.upgrade())
else {
return;
};
let em = ctx.emitter();
let state: Arc<Mutex<SwitchState>> = Arc::new(Mutex::new(SwitchState::new()));
let state_for_outer = state.clone();
let em_for_outer = em.clone();
let binding_for_outer = binding_clone.clone();
let producer_binding_for_outer = producer_binding.clone();
let outer_sink: Sink = Arc::new(move |msgs| {
// Decisions are collected under the state lock and acted upon after it
// is released, so no binding or emitter call runs while it is held
// (except the handle retains).
#[derive(Default)]
struct Plan {
// Last tier-3 payload handle seen in this batch.
latest_outer_h: Option<HandleId>,
// True once `latest_outer_h` has been retained for projection.
latest_retained: bool,
// Output should complete because the source finished with nothing in flight.
self_complete: bool,
// Output should error with this (already retained) handle.
self_error: Option<HandleId>,
}
let mut plan = Plan::default();
{
let mut s = state_for_outer.lock().unwrap();
if s.terminated {
return;
}
for m in msgs {
match m.tier() {
3 => {
// Latest value wins: earlier values of the same batch are overwritten.
if let Some(h) = m.payload_handle() {
plan.latest_outer_h = Some(h);
}
}
5 => {
if let Some(h) = m.payload_handle() {
// Outer error: terminate and keep the error handle alive for the emitter.
if !s.terminated {
s.terminated = true;
binding_for_outer.retain_handle(h);
plan.self_error = Some(h);
}
} else {
// Outer completion: finish now only if nothing is or will be in flight.
s.source_done = true;
if s.inner_sub.is_none()
&& plan.latest_outer_h.is_none()
&& !s.terminated
{
s.terminated = true;
plan.self_complete = true;
}
}
}
_ => {} }
}
// Retain the winning value only if the batch did not terminate the output.
if let Some(h) = plan.latest_outer_h {
if !s.terminated {
binding_for_outer.retain_handle(h);
plan.latest_retained = true;
}
}
}
if plan.latest_retained {
let outer_h = plan
.latest_outer_h
.expect("latest_retained implies latest_outer_h is Some");
// Supersede the current inner: take its guard, bump the epoch, and drop
// the guard only after the lock has been released.
let my_epoch = {
let mut s = state_for_outer.lock().unwrap();
let prev = s.inner_sub.take();
s.epoch += 1;
let e = s.epoch;
drop(s);
drop(prev); e
};
let inner_node = binding_for_outer.invoke_project(project_fn_id, outer_h);
// Balances the retain taken above.
binding_for_outer.release_handle(outer_h);
let on_complete = make_switch_on_complete(
state_for_outer.clone(),
em_for_outer.clone(),
producer_id,
);
let on_error = make_switch_on_error(
state_for_outer.clone(),
em_for_outer.clone(),
producer_id,
);
let on_complete_for_dead = on_complete.clone();
let inner_sink = build_inner_sink(
em_for_outer.clone(),
producer_binding_for_outer.clone(),
producer_id,
on_complete,
on_error,
);
let state_sub = state_for_outer.clone();
let em_guard = em_for_outer.clone();
// The subscribe itself is deferred through the emitter.
let _ = em_for_outer.defer(move |c| {
match c.try_subscribe(inner_node, inner_sink) {
Ok(sub) => {
let guard = SubGuard::new(inner_node, sub, em_guard);
// Install the guard unless the output terminated or a newer outer
// value superseded this inner in the meantime; whichever guard is
// left over is dropped after the lock is released.
// NOTE(review): an inner that completes synchronously inside
// `try_subscribe` does not change `epoch`, so its guard still gets
// installed here and `inner_sub` stays `Some` — confirm whether the
// substrate can deliver a terminal during subscribe.
let to_drop = {
let mut s = state_sub.lock().unwrap();
if s.terminated || s.epoch != my_epoch {
Some(guard)
} else {
s.inner_sub.replace(guard)
}
};
drop(to_drop);
}
// A torn-down inner node is treated like an inner that completed.
Err(graphrefly_core::SubscribeError::TornDown { .. }) => {
on_complete_for_dead();
}
Err(graphrefly_core::SubscribeError::PartitionOrderViolation(_)) => {
panic!(
"switch_map inner subscribe: partition-order \
violation inside em.defer — substrate invariant broken"
);
}
}
});
}
if plan.self_complete {
em_for_outer.complete_or_defer(producer_id);
} else if let Some(h) = plan.self_error {
em_for_outer.error_or_defer(producer_id, h);
}
});
ctx.subscribe_to(source, outer_sink);
});
let fn_id = binding.register_producer_build(build);
core.register_producer(fn_id)
.expect("invariant: register_producer has no deps; no error variants reachable")
}
/// Builds the callback that `switch_map` runs when its inner node completes.
///
/// The callback is a no-op once the output has terminated. Otherwise it takes
/// the current inner subscription guard out of the state and, if the outer
/// source is already done, marks the output terminated and completes
/// `producer_id`. The guard is dropped and the completion is emitted only after
/// the state lock has been released.
fn make_switch_on_complete(
    state: Arc<Mutex<SwitchState>>,
    em: ProducerEmitter,
    producer_id: NodeId,
) -> Arc<dyn Fn()> {
    Arc::new(move || {
        let (finished_inner, finish_output) = {
            let mut guard = state.lock().unwrap();
            if guard.terminated {
                return;
            }
            let finished = guard.inner_sub.take();
            // The inner was the last thing keeping the output open once the
            // source is done.
            let finish = guard.source_done;
            if finish {
                guard.terminated = true;
            }
            (finished, finish)
        };
        drop(finished_inner);
        if finish_output {
            em.complete_or_defer(producer_id);
        }
    })
}
/// Builds the callback that `switch_map` runs when its inner node errors.
///
/// The first call marks the output terminated, takes the inner subscription
/// guard out of the state, drops it after the lock is released and then errors
/// `producer_id` with the given handle. Calls made after termination return
/// without doing anything.
///
/// NOTE(review): `build_inner_sink` retains the error handle before invoking
/// this callback; on the already-terminated path that retain does not appear to
/// be released anywhere in this file — confirm whether that is a leak.
fn make_switch_on_error(
    state: Arc<Mutex<SwitchState>>,
    em: ProducerEmitter,
    producer_id: NodeId,
) -> Arc<dyn Fn(HandleId)> {
    Arc::new(move |err_handle: HandleId| {
        let cancelled_inner = {
            let mut guard = state.lock().unwrap();
            if guard.terminated {
                return;
            }
            guard.terminated = true;
            guard.inner_sub.take()
        };
        drop(cancelled_inner);
        em.error_or_defer(producer_id, err_handle);
    })
}
/// Shared mutable state of one `exhaust_map` producer activation.
#[derive(Default)]
struct ExhaustState {
    /// Guard of the currently installed inner subscription, if any.
    inner_sub: Option<SubGuard>,
    /// Set when an outer value has been accepted; cleared again when the inner
    /// it started completes. While set, further outer values are ignored.
    pending: bool,
    /// Set once the outer source has delivered a payload-less tier-5 message.
    source_done: bool,
    /// Set once the output has completed or errored; later events are ignored.
    terminated: bool,
}

impl ExhaustState {
    /// Creates the initial state: no inner subscription and all flags cleared.
    fn new() -> Self {
        Self::default()
    }
}
/// Registers a producer node that maps a value of `source` to an inner node via
/// `project` and mirrors that inner node, ignoring every outer value that
/// arrives while an inner is pending or subscribed.
///
/// An outer value is accepted only when no inner subscription is installed, no
/// earlier value is still pending and the output has not terminated. For an
/// accepted value, `project` is invoked and a subscribe to the resulting inner
/// node is deferred through the emitter.
///
/// Termination:
/// * a tier-5 outer message carrying a payload errors the output with that
///   handle (a value accepted earlier in the same batch is released and not
///   projected);
/// * a payload-less tier-5 outer message marks the source as done; the output
///   completes right away only if no inner is installed or pending, otherwise
///   when the inner completes (see `make_exhaust_on_complete`).
///
/// The binding is only held weakly by the producer build closure; if it has
/// been dropped by the time the producer is built, the build does nothing.
///
/// # Panics
///
/// Panics if the internal state mutex is poisoned, if `register_producer`
/// returns an error, or if the deferred inner subscribe reports a
/// partition-order violation.
#[must_use]
pub fn exhaust_map(
    core: &Core,
    binding: &Arc<dyn HigherOrderBinding>,
    source: NodeId,
    project: ProjectFn,
) -> NodeId {
    let project_fn_id = binding.register_project(project);
    // Weak references only: the build closure is handed to the binding itself
    // (`register_producer_build` below), so a strong capture would make the
    // binding own a reference to itself.
    let binding_weak: Weak<dyn HigherOrderBinding> = Arc::downgrade(binding);
    let producer_binding_weak: Weak<dyn ProducerBinding> =
        Arc::downgrade(&(binding.clone() as Arc<dyn ProducerBinding>));
    let build = Box::new(move |ctx: ProducerCtx<'_>| {
        let producer_id = ctx.node_id();
        let (Some(binding_clone), Some(producer_binding)) =
            (binding_weak.upgrade(), producer_binding_weak.upgrade())
        else {
            return;
        };
        let em = ctx.emitter();
        let state: Arc<Mutex<ExhaustState>> = Arc::new(Mutex::new(ExhaustState::new()));
        let state_for_outer = state.clone();
        let em_for_outer = em.clone();
        let binding_for_outer = binding_clone.clone();
        let producer_binding_for_outer = producer_binding.clone();
        let outer_sink: Sink = Arc::new(move |msgs| {
            // Decisions are collected under the state lock and acted upon after
            // it is released.
            #[derive(Default)]
            struct Plan {
                // Outer value accepted in this batch, if any.
                first_outer_h: Option<HandleId>,
                // True while `first_outer_h` holds a retain that must be balanced.
                first_retained: bool,
                // Output should complete because the source finished with
                // nothing in flight.
                self_complete: bool,
                // Output should error with this (already retained) handle.
                self_error: Option<HandleId>,
            }
            let mut plan = Plan::default();
            {
                let mut s = state_for_outer.lock().unwrap();
                if s.terminated {
                    return;
                }
                for m in msgs {
                    match m.tier() {
                        3 => {
                            if let Some(h) = m.payload_handle() {
                                // Accept only when idle. The `terminated` check
                                // keeps a value that follows a terminal in the
                                // same batch from being retained and projected
                                // after the output has already ended.
                                if !s.terminated
                                    && s.inner_sub.is_none()
                                    && !s.pending
                                    && plan.first_outer_h.is_none()
                                {
                                    binding_for_outer.retain_handle(h);
                                    plan.first_outer_h = Some(h);
                                    plan.first_retained = true;
                                    s.pending = true;
                                }
                            }
                        }
                        5 => {
                            if let Some(h) = m.payload_handle() {
                                if !s.terminated {
                                    s.terminated = true;
                                    // The value accepted earlier in this batch
                                    // will never be projected: give its retain
                                    // back.
                                    if plan.first_retained {
                                        if let Some(h0) = plan.first_outer_h.take() {
                                            binding_for_outer.release_handle(h0);
                                            plan.first_retained = false;
                                        }
                                    }
                                    binding_for_outer.retain_handle(h);
                                    plan.self_error = Some(h);
                                }
                            } else {
                                s.source_done = true;
                                // Complete now only if nothing is in flight. An
                                // accepted value whose subscribe has not been
                                // installed yet (`pending`) counts as in
                                // flight; its completion callback finishes the
                                // output later because `source_done` is set.
                                if s.inner_sub.is_none()
                                    && !s.pending
                                    && plan.first_outer_h.is_none()
                                    && !s.terminated
                                {
                                    s.terminated = true;
                                    plan.self_complete = true;
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            if plan.first_retained {
                let outer_h = plan
                    .first_outer_h
                    .expect("first_retained implies first_outer_h is Some");
                let inner_node = binding_for_outer.invoke_project(project_fn_id, outer_h);
                // Balances the retain taken when the value was accepted.
                binding_for_outer.release_handle(outer_h);
                let on_complete = make_exhaust_on_complete(
                    state_for_outer.clone(),
                    em_for_outer.clone(),
                    producer_id,
                );
                let on_error = make_exhaust_on_error(
                    state_for_outer.clone(),
                    em_for_outer.clone(),
                    producer_id,
                );
                let on_complete_for_dead = on_complete.clone();
                let inner_sink = build_inner_sink(
                    em_for_outer.clone(),
                    producer_binding_for_outer.clone(),
                    producer_id,
                    on_complete,
                    on_error,
                );
                let state_sub = state_for_outer.clone();
                let em_guard = em_for_outer.clone();
                // The subscribe itself is deferred through the emitter.
                let _ =
                    em_for_outer.defer(move |c| match c.try_subscribe(inner_node, inner_sink) {
                        Ok(sub) => {
                            let guard = SubGuard::new(inner_node, sub, em_guard);
                            let to_drop = {
                                let mut s = state_sub.lock().unwrap();
                                // `pending` is cleared only by the completion
                                // callback, so a cleared flag here means the
                                // inner already completed while subscribing.
                                // Installing its guard would leave `inner_sub`
                                // occupied by a finished inner and block every
                                // later outer value, so drop it instead.
                                if s.terminated || !s.pending {
                                    Some(guard)
                                } else {
                                    s.inner_sub.replace(guard)
                                }
                            };
                            // Dropped outside the state lock.
                            drop(to_drop);
                        }
                        // A torn-down inner node is treated like an inner that
                        // completed.
                        Err(graphrefly_core::SubscribeError::TornDown { .. }) => {
                            on_complete_for_dead();
                        }
                        Err(graphrefly_core::SubscribeError::PartitionOrderViolation(_)) => {
                            panic!(
                                "exhaust_map inner subscribe: partition-order \
                                 violation inside em.defer — substrate invariant broken"
                            );
                        }
                    });
            }
            if plan.self_complete {
                em_for_outer.complete_or_defer(producer_id);
            } else if let Some(h) = plan.self_error {
                em_for_outer.error_or_defer(producer_id, h);
            }
        });
        ctx.subscribe_to(source, outer_sink);
    });
    let fn_id = binding.register_producer_build(build);
    core.register_producer(fn_id)
        .expect("invariant: register_producer has no deps; no error variants reachable")
}
/// Builds the callback that `exhaust_map` runs when its inner node completes.
///
/// The callback is a no-op once the output has terminated. Otherwise it takes
/// the inner subscription guard out of the state and clears `pending`, so the
/// next outer value is accepted again. If the outer source is already done it
/// also marks the output terminated and completes `producer_id`. The guard is
/// dropped and the completion is emitted only after the state lock has been
/// released.
fn make_exhaust_on_complete(
    state: Arc<Mutex<ExhaustState>>,
    em: ProducerEmitter,
    producer_id: NodeId,
) -> Arc<dyn Fn()> {
    Arc::new(move || {
        let (finished_inner, finish_output) = {
            let mut guard = state.lock().unwrap();
            if guard.terminated {
                return;
            }
            let finished = guard.inner_sub.take();
            guard.pending = false;
            let finish = guard.source_done;
            if finish {
                guard.terminated = true;
            }
            (finished, finish)
        };
        drop(finished_inner);
        if finish_output {
            em.complete_or_defer(producer_id);
        }
    })
}
/// Builds the callback that `exhaust_map` runs when its inner node errors.
///
/// The first call marks the output terminated, takes the inner subscription
/// guard out of the state, drops it after the lock is released and then errors
/// `producer_id` with the given handle. Calls made after termination return
/// without doing anything.
///
/// NOTE(review): `build_inner_sink` retains the error handle before invoking
/// this callback; on the already-terminated path that retain does not appear to
/// be released anywhere in this file — confirm whether that is a leak.
fn make_exhaust_on_error(
    state: Arc<Mutex<ExhaustState>>,
    em: ProducerEmitter,
    producer_id: NodeId,
) -> Arc<dyn Fn(HandleId)> {
    Arc::new(move |err_handle: HandleId| {
        let cancelled_inner = {
            let mut guard = state.lock().unwrap();
            if guard.terminated {
                return;
            }
            guard.terminated = true;
            guard.inner_sub.take()
        };
        drop(cancelled_inner);
        em.error_or_defer(producer_id, err_handle);
    })
}
// Per-thread flag that is `true` while `drain_merge_buffer` is running on this
// thread. A call that finds it already set returns immediately, so a nested
// call made from inside the drain loop does not start a second loop.
// NOTE(review): the flag is per thread, not per merge_map instance, so a drain
// of one instance also suppresses nested drains of any other instance on the
// same thread — confirm this is intended.
thread_local! {
static MERGE_DRAIN_ACTIVE: Cell<bool> = const { Cell::new(false) };
}
/// Shared mutable state of one `merge_map` / `concat_map` producer activation.
struct MergeMapState {
// Number of inners that have been started and have not completed yet;
// compared against the concurrency limit in `drain_merge_buffer`.
active: u32,
// Retained outer value handles waiting to be projected, in arrival order.
buffer: VecDeque<HandleId>,
// Guards of the installed inner subscriptions, keyed by inner id.
inner_subs: AHashMap<u64, SubGuard>,
// Ids of inners that were started but whose subscription guard has not been
// installed yet; an id missing at install time means the inner already finished.
pending_inner_ids: ahash::AHashSet<u64>,
// Id handed to the next inner that is started.
next_inner_id: u64,
// Set once the outer source has delivered a payload-less tier-5 message.
source_done: bool,
// Set once the output has completed or errored; later events are ignored.
terminated: bool,
}
impl MergeMapState {
/// Creates the initial state: nothing active, buffered or subscribed, ids
/// starting at `0`, and all flags cleared.
fn new() -> Self {
Self {
active: 0,
buffer: VecDeque::new(),
inner_subs: AHashMap::new(),
pending_inner_ids: ahash::AHashSet::new(),
next_inner_id: 0,
source_done: false,
terminated: false,
}
}
}
/// Registers a producer node that projects every value of `source` to an inner
/// node via `project` and mirrors all inner nodes, with no limit on how many
/// may be active at the same time.
///
/// Equivalent to [`merge_map_with_concurrency`] with `concurrency = None`.
#[must_use]
pub fn merge_map(
    core: &Core,
    binding: &Arc<dyn HigherOrderBinding>,
    source: NodeId,
    project: ProjectFn,
) -> NodeId {
    let unbounded: Option<u32> = None;
    merge_map_with_concurrency(core, binding, source, project, unbounded)
}
/// Registers a producer node that projects every value of `source` to an inner
/// node via `project` and mirrors the inner nodes one at a time: the next
/// buffered outer value is projected only after the current inner completes.
///
/// Equivalent to [`merge_map_with_concurrency`] with `concurrency = Some(1)`.
#[must_use]
pub fn concat_map(
    core: &Core,
    binding: &Arc<dyn HigherOrderBinding>,
    source: NodeId,
    project: ProjectFn,
) -> NodeId {
    let one_at_a_time: Option<u32> = Some(1);
    merge_map_with_concurrency(core, binding, source, project, one_at_a_time)
}
/// Registers a producer node that projects every value of `source` to an inner
/// node via `project` and mirrors the inner nodes, running at most
/// `concurrency` of them at the same time.
///
/// Outer values are retained and queued; `drain_merge_buffer` starts inners
/// from the queue while the concurrency limit allows it.
///
/// * `concurrency = None` — no limit.
/// * `concurrency = Some(n)` — at most `n` inners are active at once. Note that
///   `Some(0)` never admits any inner, so queued values are never projected.
///
/// Termination:
/// * a tier-5 outer message carrying a payload errors the output with that
///   handle and releases every queued outer value;
/// * a payload-less tier-5 outer message marks the source as done; the output
///   completes once no inner is active and the queue is empty.
///
/// The binding is only held weakly by the producer build closure; if it has
/// been dropped by the time the producer is built, the build does nothing.
///
/// # Panics
///
/// Panics if the internal state mutex is poisoned or if `register_producer`
/// returns an error.
#[must_use]
pub fn merge_map_with_concurrency(
    core: &Core,
    binding: &Arc<dyn HigherOrderBinding>,
    source: NodeId,
    project: ProjectFn,
    concurrency: Option<u32>,
) -> NodeId {
    let project_fn_id = binding.register_project(project);
    // Weak references only: the build closure is handed to the binding itself
    // (`register_producer_build` below), so a strong capture would make the
    // binding own a reference to itself.
    let binding_weak: Weak<dyn HigherOrderBinding> = Arc::downgrade(binding);
    let producer_binding_weak: Weak<dyn ProducerBinding> =
        Arc::downgrade(&(binding.clone() as Arc<dyn ProducerBinding>));
    let build = Box::new(move |ctx: ProducerCtx<'_>| {
        let producer_id = ctx.node_id();
        let (Some(binding_clone), Some(producer_binding)) =
            (binding_weak.upgrade(), producer_binding_weak.upgrade())
        else {
            return;
        };
        let em = ctx.emitter();
        let state: Arc<Mutex<MergeMapState>> = Arc::new(Mutex::new(MergeMapState::new()));
        let state_for_outer = state.clone();
        let em_for_outer = em.clone();
        let binding_for_outer = binding_clone.clone();
        let producer_binding_for_outer = producer_binding.clone();
        let outer_sink: Sink = Arc::new(move |msgs| {
            let mut error_action: Option<HandleId> = None;
            let mut self_complete_now = false;
            {
                let mut s = state_for_outer.lock().unwrap();
                if s.terminated {
                    return;
                }
                for m in msgs {
                    match m.tier() {
                        3 => {
                            // Once an earlier message of this batch terminated
                            // the output, nothing will ever drain or release
                            // the queue again, so further values must not be
                            // retained and queued.
                            if !s.terminated {
                                if let Some(h) = m.payload_handle() {
                                    binding_for_outer.retain_handle(h);
                                    s.buffer.push_back(h);
                                }
                            }
                        }
                        5 => {
                            if let Some(h) = m.payload_handle() {
                                if !s.terminated {
                                    s.terminated = true;
                                    binding_for_outer.retain_handle(h);
                                    // Queued values will never be projected:
                                    // give their retains back.
                                    while let Some(q) = s.buffer.pop_front() {
                                        binding_for_outer.release_handle(q);
                                    }
                                    error_action = Some(h);
                                }
                            } else {
                                s.source_done = true;
                                if s.active == 0 && s.buffer.is_empty() && !s.terminated {
                                    s.terminated = true;
                                    self_complete_now = true;
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            // Emitter calls happen after the state lock has been released.
            if let Some(h) = error_action {
                em_for_outer.error_or_defer(producer_id, h);
                return;
            }
            if self_complete_now {
                em_for_outer.complete_or_defer(producer_id);
                return;
            }
            drain_merge_buffer(
                &state_for_outer,
                &em_for_outer,
                &binding_for_outer,
                &producer_binding_for_outer,
                producer_id,
                project_fn_id,
                concurrency,
            );
        });
        ctx.subscribe_to(source, outer_sink);
    });
    let fn_id = binding.register_producer_build(build);
    core.register_producer(fn_id)
        .expect("invariant: register_producer has no deps; no error variants reachable")
}
/// Starts inners for queued outer values of a `merge_map` producer until the
/// queue is empty, the concurrency limit is reached or the output terminates.
///
/// For every dequeued value the function counts a new active inner, assigns it
/// an id, invokes the projection, releases the outer value handle and defers a
/// subscribe to the inner node. When the queue is empty, the source is done and
/// no inner is active, it marks the output terminated and completes
/// `producer_id`.
///
/// Re-entrancy: the thread-local `MERGE_DRAIN_ACTIVE` flag is set for the
/// duration of the loop; a call that finds it already set returns immediately.
/// The flag is cleared by a drop guard, so it is released on every exit path,
/// including unwinding from a panic (for example a poisoned state mutex).
/// Without that, a single panic inside the loop would leave the flag set and
/// silently turn every later drain on this thread into a no-op.
///
/// NOTE(review): the flag is per thread rather than per `MergeMapState`, so a
/// nested drain of a *different* merge_map instance on the same thread is also
/// skipped — confirm that this is intended.
///
/// # Panics
///
/// Panics if the state mutex is poisoned, or — inside the deferred subscribe —
/// if the substrate reports a partition-order violation.
fn drain_merge_buffer(
    state: &Arc<Mutex<MergeMapState>>,
    em: &ProducerEmitter,
    binding: &Arc<dyn HigherOrderBinding>,
    producer_binding: &Arc<dyn ProducerBinding>,
    producer_id: NodeId,
    project_fn_id: FnId,
    concurrency: Option<u32>,
) {
    /// Clears `MERGE_DRAIN_ACTIVE` when dropped.
    struct DrainFlagReset;
    impl Drop for DrainFlagReset {
        fn drop(&mut self) {
            // `try_with` so that dropping can never panic itself.
            let _ = MERGE_DRAIN_ACTIVE.try_with(|f| f.set(false));
        }
    }

    if MERGE_DRAIN_ACTIVE.with(|f| f.replace(true)) {
        // A drain is already running on this thread; it will pick up the work.
        return;
    }
    let flag_reset = DrainFlagReset;
    loop {
        let h_and_id;
        let mut should_self_complete = false;
        {
            let mut s = state.lock().unwrap();
            if s.terminated {
                return;
            }
            let allowed = match concurrency {
                None => true,
                Some(n) => s.active < n,
            };
            if !allowed {
                return;
            }
            if let Some(h) = s.buffer.pop_front() {
                s.active += 1;
                let id = s.next_inner_id;
                s.next_inner_id += 1;
                // Marked pending until the deferred subscribe installs the
                // guard (or the inner finishes first).
                s.pending_inner_ids.insert(id);
                h_and_id = Some((h, id));
            } else if s.source_done && s.active == 0 && !s.terminated {
                s.terminated = true;
                should_self_complete = true;
                h_and_id = None;
            } else {
                h_and_id = None;
            }
        }
        if should_self_complete {
            // Release the flag before emitting, as the completion may run
            // arbitrary downstream code.
            drop(flag_reset);
            em.complete_or_defer(producer_id);
            return;
        }
        let Some((outer_h, inner_id)) = h_and_id else {
            return;
        };
        let inner_node = binding.invoke_project(project_fn_id, outer_h);
        // Balances the retain taken when the value was queued.
        binding.release_handle(outer_h);
        let on_complete = make_merge_on_complete(
            state.clone(),
            em.clone(),
            binding.clone(),
            producer_binding.clone(),
            producer_id,
            project_fn_id,
            inner_id,
            concurrency,
        );
        let on_error = make_merge_on_error(state.clone(), em.clone(), binding.clone(), producer_id);
        let on_complete_for_dead = on_complete.clone();
        let inner_sink = build_inner_sink(
            em.clone(),
            producer_binding.clone(),
            producer_id,
            on_complete,
            on_error,
        );
        let state_sub = state.clone();
        let em_guard = em.clone();
        let _ = em.defer(move |c| {
            match c.try_subscribe(inner_node, inner_sink) {
                Ok(sub) => {
                    let guard = SubGuard::new(inner_node, sub, em_guard);
                    let to_drop = {
                        let mut s = state_sub.lock().unwrap();
                        // Keep the guard only if the output is still live and
                        // this inner is still pending; an id that is no longer
                        // pending means the inner already finished.
                        if s.terminated || !s.pending_inner_ids.remove(&inner_id) {
                            Some(guard)
                        } else {
                            s.inner_subs.insert(inner_id, guard);
                            None
                        }
                    };
                    // Dropped outside the state lock.
                    drop(to_drop);
                }
                // A torn-down inner node is treated like an inner that
                // completed.
                Err(graphrefly_core::SubscribeError::TornDown { .. }) => {
                    on_complete_for_dead();
                }
                Err(graphrefly_core::SubscribeError::PartitionOrderViolation(_)) => {
                    panic!(
                        "merge_map inner subscribe: partition-order \
                         violation inside em.defer — substrate invariant broken"
                    );
                }
            }
        });
    }
}
/// Builds the callback that `merge_map` runs when the inner with id
/// `this_inner_id` completes.
///
/// The callback is a no-op once the output has terminated. Otherwise it
/// decrements the active-inner count, forgets the inner's pending marker and
/// subscription guard (the guard is dropped after the state lock is released)
/// and then calls `drain_merge_buffer`, which may start the next queued inner
/// or complete the output.
fn make_merge_on_complete(
    state: Arc<Mutex<MergeMapState>>,
    em: ProducerEmitter,
    binding: Arc<dyn HigherOrderBinding>,
    producer_binding: Arc<dyn ProducerBinding>,
    producer_id: NodeId,
    project_fn_id: FnId,
    this_inner_id: u64,
    concurrency: Option<u32>,
) -> Arc<dyn Fn()> {
    Arc::new(move || {
        let finished_sub = {
            let mut guard = state.lock().unwrap();
            if guard.terminated {
                return;
            }
            guard.active -= 1;
            guard.pending_inner_ids.remove(&this_inner_id);
            guard.inner_subs.remove(&this_inner_id)
        };
        drop(finished_sub);
        drain_merge_buffer(
            &state,
            &em,
            &binding,
            &producer_binding,
            producer_id,
            project_fn_id,
            concurrency,
        );
    })
}
/// Builds the callback that `merge_map` runs when one of its inner nodes
/// errors with handle `h`.
///
/// The first call marks the output terminated, removes every inner
/// subscription guard, clears the pending-inner markers and empties the queue
/// of outer values. After the state lock is released it drops the guards,
/// releases the queued outer handles and errors `producer_id` with `h`.
///
/// If the output has already terminated, the error is not forwarded. The
/// caller (`build_inner_sink`) retained `h` before invoking this callback and
/// nothing else will consume that retain, so it is released here instead of
/// being leaked.
fn make_merge_on_error(
    state: Arc<Mutex<MergeMapState>>,
    em: ProducerEmitter,
    binding: Arc<dyn HigherOrderBinding>,
    producer_id: NodeId,
) -> Arc<dyn Fn(HandleId)> {
    Arc::new(move |h| {
        let removed_subs;
        let buffered_to_release;
        {
            let mut s = state.lock().unwrap();
            if s.terminated {
                // Unlock first, then balance the retain taken by the inner sink.
                drop(s);
                binding.release_handle(h);
                return;
            }
            s.terminated = true;
            removed_subs = s.inner_subs.drain().map(|(_, sub)| sub).collect::<Vec<_>>();
            s.pending_inner_ids.clear();
            buffered_to_release = s.buffer.drain(..).collect::<Vec<_>>();
        }
        // Guards are dropped and handles released outside the state lock.
        drop(removed_subs);
        for h_b in buffered_to_release {
            binding.release_handle(h_b);
        }
        em.error_or_defer(producer_id, h);
    })
}
// Compile-time check only: the closure is never called, but it fails to
// type-check if `ProjectFn` ever stops being `Send + Sync`.
const _: fn() = || {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<ProjectFn>();
};