1use rearch::{SideEffect, SideEffectRegistrar};
2use std::{
3 any::Any,
4 cell::{Cell, OnceCell},
5 sync::Arc,
6};
7
/// One-shot closure handed to a [`SideEffectTxnRunner`] to be executed.
type SideEffectTxn<'f> = Box<dyn FnOnce() + 'f>;
/// Shared, `Send + Sync` runner that executes a [`SideEffectTxn`].
type SideEffectTxnRunner = Arc<dyn Fn(SideEffectTxn) + Send + Sync>;
/// One-shot closure that mutates a single side effect's type-erased state.
type SideEffectStateMutation<'f> = Box<dyn FnOnce(&mut dyn Any) + 'f>;

/// One-shot closure that mutates the whole slice of `multi` slots at once.
type MultiSideEffectStateMutation<'f> =
    Box<dyn FnOnce(&mut [OnceCell<Box<dyn Any + Send>>]) + 'f>;
/// Shared, `Send + Sync` runner that applies a [`MultiSideEffectStateMutation`].
type MultiSideEffectStateMutationRunner =
    Arc<dyn Fn(MultiSideEffectStateMutation) + Send + Sync>;
14
15pub fn multi<const LENGTH: usize>(
25) -> impl for<'a> SideEffect<Api<'a> = MultiSideEffectRegistrar<'a>> {
26 MultiEffectLifetimeFixer(multi_impl::<LENGTH>)
27}
28
29fn multi_impl<const LENGTH: usize>(register: SideEffectRegistrar) -> MultiSideEffectRegistrar {
30 let default_array: [OnceCell<Box<dyn Any + Send>>; LENGTH] =
31 std::array::from_fn(|_| OnceCell::new());
32 let (curr_slice, mutation_runner, run_txn) = register.raw(default_array);
33 let multi_mutation_runner = Arc::new(move |mutation: MultiSideEffectStateMutation| {
34 mutation_runner(Box::new(move |data| mutation(data)));
35 });
36 MultiSideEffectRegistrar {
37 curr_index: Cell::new(0),
38 curr_slice: Cell::new(curr_slice),
39 multi_mutation_runner,
40 run_txn,
41 }
42}
43
44#[allow(clippy::module_name_repetitions)] pub struct MultiSideEffectRegistrar<'a> {
49 curr_index: Cell<usize>,
51 curr_slice: Cell<&'a mut [OnceCell<Box<dyn Any + Send>>]>,
52 multi_mutation_runner: MultiSideEffectStateMutationRunner,
53 run_txn: SideEffectTxnRunner,
54}
55
56impl<'a> MultiSideEffectRegistrar<'a> {
57 pub fn register<S: SideEffect>(&'a self, effect: S) -> S::Api<'a> {
63 let (curr_data, rest_slice) =
64 self.curr_slice.take().split_first_mut().unwrap_or_else(|| {
65 panic!(
66 "multi was not given a long enough length; it should be at least {}",
67 self.curr_index.get() + 1
68 );
69 });
70
71 let mutation_runner = {
72 let curr_index = self.curr_index.get();
73 let multi_mutation_runner = Arc::clone(&self.multi_mutation_runner);
74 Arc::new(move |mutation: SideEffectStateMutation| {
75 multi_mutation_runner(Box::new(|multi_data_slice| {
76 let data = &mut **multi_data_slice[curr_index]
77 .get_mut()
78 .expect("To trigger rebuild, side effect must've been registered");
79 mutation(data);
80 }));
81 })
82 };
83
84 self.curr_index.set(self.curr_index.get() + 1);
85 self.curr_slice.replace(rest_slice);
86
87 SideEffectRegistrar::new(curr_data, mutation_runner, Arc::clone(&self.run_txn))
88 .register(effect)
89 }
90}
91
92struct MultiEffectLifetimeFixer<F>(F);
94impl<F> SideEffect for MultiEffectLifetimeFixer<F>
95where
96 F: FnOnce(SideEffectRegistrar) -> MultiSideEffectRegistrar,
97{
98 type Api<'a> = MultiSideEffectRegistrar<'a>;
99 fn build(self, registrar: SideEffectRegistrar) -> Self::Api<'_> {
100 self.0(registrar)
101 }
102}
103
#[cfg(test)]
mod tests {
    use crate::*;
    use rearch::{CapsuleHandle, Container};

    /// Registering one side effect through `multi::<0>` must panic and report
    /// that a length of at least 1 is required.
    #[test]
    #[should_panic(expected = "multi was not given a long enough length; it should be at least 1")]
    fn multi_register_undersized() {
        fn capsule(CapsuleHandle { register, .. }: CapsuleHandle) -> bool {
            let registrar = register.register(multi::<0>());
            registrar.register(is_first_build())
        }

        let container = Container::new();
        container.read(capsule);
    }

    /// Using exactly `LENGTH` slots works, and the nested `is_first_build`
    /// reports `true` on the first read.
    #[test]
    fn multi_register_right_size() {
        fn capsule(CapsuleHandle { register, .. }: CapsuleHandle) -> bool {
            let registrar = register.register(multi::<1>());
            registrar.register(is_first_build())
        }

        let first_build = Container::new().read(capsule);
        assert!(first_build);
    }

    /// Using fewer slots than `LENGTH` works, and nested state survives the
    /// rebuild triggered by the setter.
    #[test]
    fn multi_register_oversized() {
        fn capsule(
            CapsuleHandle { register, .. }: CapsuleHandle,
        ) -> (u32, u32, impl CData + Fn(u32)) {
            let registrar = register.register(multi::<16>());
            let (count, set_count) = registrar.register(state::<Cloned<_>>(0));
            let build_counter = registrar.register(value::<MutRef<_>>(0));
            *build_counter += 1;
            (*build_counter, count, set_count)
        }

        let container = Container::new();

        let (builds, count, set_count) = container.read(capsule);
        assert_eq!((builds, count), (1, 0));

        set_count(123);

        let (builds, count, _) = container.read(capsule);
        assert_eq!((builds, count), (2, 123));
    }
}