Skip to main content

aimdb_core/transform/
join.rs

1use core::any::Any;
2use core::fmt::Debug;
3use core::marker::PhantomData;
4
5use alloc::{
6    boxed::Box,
7    string::{String, ToString},
8    sync::Arc,
9    vec::Vec,
10};
11
12use aimdb_executor::{ExecutorResult, JoinFanInRuntime, JoinQueue, JoinReceiver, JoinSender};
13
14use crate::transform::TransformDescriptor;
15use crate::typed_record::BoxFuture;
16
17// ============================================================================
18// JoinTrigger
19// ============================================================================
20
21/// Identifies which input produced a value in a multi-input join transform.
22///
23/// Passed to the event loop inside the closure registered with [`JoinBuilder::on_triggers`].
24/// Use [`JoinTrigger::index`] to branch on the source input and
25/// [`JoinTrigger::as_input`] to downcast the value to the concrete type.
26pub enum JoinTrigger {
27    Input {
28        index: usize,
29        value: Box<dyn Any + Send>,
30    },
31}
32
33impl JoinTrigger {
34    pub fn as_input<T: 'static>(&self) -> Option<&T> {
35        match self {
36            JoinTrigger::Input { value, .. } => value.downcast_ref::<T>(),
37        }
38    }
39
40    pub fn index(&self) -> usize {
41        match self {
42            JoinTrigger::Input { index, .. } => *index,
43        }
44    }
45}
46
47// ============================================================================
48// JoinEventRx — type-erased trigger receiver
49// ============================================================================
50
51/// Type-erased receiver for join trigger events.
52///
53/// Obtained as the first argument to the [`JoinBuilder::on_triggers`] closure.
54/// Call `.recv().await` in a loop to consume trigger events from all input forwarders.
55/// Returns `Err` when all input forwarders have exited and the channel is closed.
56///
57/// ```rust,ignore
58/// .on_triggers(|mut rx, producer| async move {
59///     let mut last_a: Option<f32> = None;
60///     let mut last_b: Option<f32> = None;
61///     while let Ok(trigger) = rx.recv().await {
62///         match trigger.index() {
63///             0 => last_a = trigger.as_input::<InputA>().copied(),
64///             1 => last_b = trigger.as_input::<InputB>().copied(),
65///             _ => {}
66///         }
67///         if let (Some(a), Some(b)) = (last_a, last_b) {
68///             producer.produce(compute(a, b)).await.ok();
69///         }
70///     }
71/// })
72/// ```
73pub struct JoinEventRx {
74    inner: Box<dyn DynJoinRx + Send>,
75}
76
77impl JoinEventRx {
78    fn new<R: JoinReceiver<JoinTrigger> + Send + 'static>(inner: R) -> Self {
79        Self {
80            inner: Box::new(inner),
81        }
82    }
83
84    /// Receive the next trigger event.
85    ///
86    /// Returns `Ok(JoinTrigger)` when an input fires, or `Err` when all inputs are closed.
87    ///
88    /// # Runtime portability
89    ///
90    /// On Tokio and WASM, the channel closes once every input forwarder has
91    /// dropped its sender, and `recv` returns `Err`, ending any
92    /// `while let Ok(_) = rx.recv().await` loop.
93    ///
94    /// On Embassy the channel **never** closes — this branch is unreachable
95    /// and the loop runs for the device lifetime. Portable handlers should
96    /// not rely on the loop exiting to release resources.
97    pub async fn recv(&mut self) -> ExecutorResult<JoinTrigger> {
98        self.inner.recv_boxed().await
99    }
100}
101
102trait DynJoinRx: Send {
103    fn recv_boxed<'a>(&'a mut self) -> BoxFuture<'a, ExecutorResult<JoinTrigger>>;
104}
105
106impl<R: JoinReceiver<JoinTrigger> + Send> DynJoinRx for R {
107    fn recv_boxed<'a>(&'a mut self) -> BoxFuture<'a, ExecutorResult<JoinTrigger>> {
108        Box::pin(self.recv())
109    }
110}
111
// ============================================================================
// JoinBuilder → JoinPipeline
// ============================================================================

/// Sending half of the fan-in queue that runtime `R` creates for join triggers.
#[cfg(feature = "alloc")]
type JoinTriggerTx<R> =
    <<R as JoinFanInRuntime>::JoinQueue<JoinTrigger> as JoinQueue<JoinTrigger>>::Sender;

/// Type-erased, one-shot factory that builds the forwarder task for one join input.
///
/// It is called with the database handle, the input's zero-based index and its
/// own clone of the fan-in sender. It returns the forwarder future, ready to be
/// spawned.
#[cfg(feature = "alloc")]
type JoinInputFactory<R> = Box<
    dyn FnOnce(Arc<crate::AimDb<R>>, usize, JoinTriggerTx<R>) -> BoxFuture<'static, ()>
        + Send
        + Sync,
>;

/// Builder for a multi-input join transform.
///
/// Usable on any runtime that implements [`aimdb_executor::JoinFanInRuntime`].
/// The fan-in queue is the bounded channel between the per-input forwarders
/// and the trigger loop. It is created through the runtime adapter when the
/// join transform starts.
///
/// Its capacity is an internal constant chosen by each adapter:
/// Tokio 64, Embassy 8, WASM 64.
///
/// Obtain one via [`RecordRegistrar::transform_join`], add inputs with
/// [`JoinBuilder::input`], and finish with [`JoinBuilder::on_triggers`].
#[cfg(feature = "alloc")]
pub struct JoinBuilder<O, R>
where
    R: JoinFanInRuntime + 'static,
{
    /// One `(record key, forwarder factory)` pair per input, in insertion order.
    inputs: Vec<(String, JoinInputFactory<R>)>,
    _phantom: PhantomData<(O, R)>,
}

#[cfg(feature = "alloc")]
impl<O, R> JoinBuilder<O, R>
where
    O: Send + Sync + Clone + Debug + 'static,
    R: JoinFanInRuntime + 'static,
{
    /// Creates a builder with no inputs.
    pub(crate) fn new() -> Self {
        Self {
            inputs: Vec::new(),
            _phantom: PhantomData,
        }
    }

    /// Adds a typed input to the join.
    ///
    /// `key` names the upstream record and `I` is its value type. Inputs are
    /// numbered from zero in the order this method is called. That number is
    /// what [`JoinTrigger::index`] reports for values from this input, and `I`
    /// is the type to request from [`JoinTrigger::as_input`].
    ///
    /// Nothing is subscribed here. When the transform starts, a forwarder task
    /// subscribes to `key` and pushes every received value into the join's
    /// fan-in queue.
    ///
    /// The forwarder stops when the upstream reader returns an error, or when
    /// sending into the fan-in queue fails. If the subscription itself fails,
    /// the error is reported and this input never fires.
    #[must_use = "`input` returns the extended builder; dropping it discards the join configuration"]
    pub fn input<I>(mut self, key: impl crate::RecordKey) -> Self
    where
        I: Send + Sync + Clone + Debug + 'static,
    {
        let key_str = key.as_str().to_string();
        let key_for_factory = key_str.clone();

        type Tx<R> =
            <<R as JoinFanInRuntime>::JoinQueue<JoinTrigger> as JoinQueue<JoinTrigger>>::Sender;

        let factory: JoinInputFactory<R> = Box::new(
            move |db: Arc<crate::AimDb<R>>, index: usize, tx: Tx<R>| {
                Box::pin(async move {
                    let consumer =
                        crate::typed_api::Consumer::<I, R>::new(db, key_for_factory.clone());
                    let mut reader = match consumer.subscribe() {
                        Ok(r) => r,
                        Err(_e) => {
                            #[cfg(feature = "tracing")]
                            tracing::error!(
                                "🔄 Join input '{}' (index {}) subscription failed: {:?}",
                                key_for_factory,
                                index,
                                _e
                            );
                            #[cfg(all(feature = "std", not(feature = "tracing")))]
                            eprintln!(
                                "AIMDB TRANSFORM ERROR: Join input '{}' (index {}) subscription failed: {:?}",
                                key_for_factory, index, _e
                            );
                            return;
                        }
                    };

                    loop {
                        // NOTE(review): any reader error ends this forwarder for good.
                        // Confirm that `recv` only errors once the upstream is closed.
                        // A transient error (e.g. a lagging buffer) would otherwise
                        // silence this input permanently.
                        let value = match reader.recv().await {
                            Ok(value) => value,
                            Err(_) => {
                                #[cfg(feature = "tracing")]
                                tracing::debug!(
                                    "🔄 Join input '{}' (index {}) reader ended, forwarder exiting",
                                    key_for_factory,
                                    index
                                );
                                break;
                            }
                        };

                        let trigger = JoinTrigger::Input {
                            index,
                            value: Box::new(value),
                        };
                        if tx.send(trigger).await.is_err() {
                            // A send error presumably means the receiving side (the
                            // user handler's `JoinEventRx`) was dropped, so there is
                            // nothing left to forward to.
                            #[cfg(feature = "tracing")]
                            tracing::debug!(
                                "🔄 Join input '{}' (index {}) join queue closed, forwarder exiting",
                                key_for_factory,
                                index
                            );
                            break;
                        }
                    }
                }) as BoxFuture<'static, ()>
            },
        );

        self.inputs.push((key_str, factory));
        self
    }

    /// Completes the pipeline with the async task that owns the event loop and its state.
    ///
    /// `handler` receives a [`JoinEventRx`] for reading trigger events and a
    /// [`crate::Producer`] for emitting output values. Both are passed by value.
    /// They can therefore be moved into an `async move` block and held across
    /// `.await` points, alongside any state the handler keeps.
    ///
    /// The transform task lives exactly as long as the future returned by
    /// `handler`. With the usual `while let Ok(..) = rx.recv().await` loop:
    ///
    /// - On runtimes whose channel closes (Tokio, WASM), the task ends once
    ///   every input forwarder has exited.
    /// - On Embassy the channel never closes, so the task runs for the lifetime
    ///   of the device.
    ///
    /// If no inputs were added, no forwarders are spawned.
    ///
    /// ```rust,ignore
    /// .on_triggers(|mut rx, producer| async move {
    ///     let mut last_a: Option<InputA> = None;
    ///     let mut last_b: Option<InputB> = None;
    ///     while let Ok(trigger) = rx.recv().await {
    ///         match trigger.index() {
    ///             0 => last_a = trigger.as_input::<InputA>().copied(),
    ///             1 => last_b = trigger.as_input::<InputB>().copied(),
    ///             _ => {}
    ///         }
    ///         if let (Some(a), Some(b)) = (last_a, last_b) {
    ///             producer.produce(compute(a, b)).await.ok();
    ///         }
    ///     }
    /// })
    /// ```
    #[must_use = "the returned `JoinPipeline` must be registered for the join to run"]
    pub fn on_triggers<F, Fut>(self, handler: F) -> JoinPipeline<O, R>
    where
        F: FnOnce(JoinEventRx, crate::Producer<O, R>) -> Fut + Send + 'static,
        Fut: core::future::Future<Output = ()> + Send + 'static,
    {
        let inputs = self.inputs;
        // The descriptor is given the input keys up front. The factories
        // themselves move into the spawn closure, which runs the transform.
        let input_keys_for_descriptor: Vec<String> =
            inputs.iter().map(|(k, _)| k.clone()).collect();

        JoinPipeline {
            spawn_factory: Box::new(move |_| TransformDescriptor {
                input_keys: input_keys_for_descriptor,
                spawn_fn: Box::new(move |producer, db, ctx| {
                    Box::pin(run_join_transform(db, inputs, producer, handler, ctx))
                }),
            }),
        }
    }
}

/// A finished multi-input join: the output of [`JoinBuilder::on_triggers`],
/// ready to be attached to a record.
///
/// [`RecordRegistrar::transform_join`] consumes it. There is normally no reason
/// to build one by hand.
#[cfg(feature = "alloc")]
pub struct JoinPipeline<O, R>
where
    O: Send + Sync + Clone + Debug + 'static,
    R: JoinFanInRuntime + 'static,
{
    /// One-shot closure that produces this pipeline's transform descriptor.
    pub(crate) spawn_factory: Box<dyn FnOnce(()) -> TransformDescriptor<O, R> + Send>,
}

#[cfg(feature = "alloc")]
impl<O, R> JoinPipeline<O, R>
where
    O: Send + Sync + Clone + Debug + 'static,
    R: JoinFanInRuntime + 'static,
{
    /// Consumes the pipeline and builds its [`TransformDescriptor`].
    pub(crate) fn into_descriptor(self) -> TransformDescriptor<O, R> {
        let JoinPipeline { spawn_factory } = self;
        spawn_factory(())
    }
}

// ============================================================================
// Join Transform Task Runner
// ============================================================================

/// Runs one join transform to completion.
///
/// 1. Downcasts `runtime_ctx` to the concrete runtime `R`.
/// 2. Creates the fan-in queue with [`JoinFanInRuntime::create_join_queue`].
/// 3. Spawns one forwarder task per entry of `inputs`. The entry's position is
///    the [`JoinTrigger::index`] its values carry, and each forwarder gets its
///    own clone of the queue's sender.
/// 4. Drops the local sender, then awaits `handler(JoinEventRx, producer)`.
///
/// A failure in steps 1–3 is fatal for this transform. It is reported through
/// `tracing`, or through `eprintln!` on `std` builds without `tracing`, and the
/// function returns without calling `handler`.
///
/// Forwarders spawned before a spawn failure are not cancelled here.
/// NOTE(review): they appear to stop only when a send to the dropped receiver
/// fails. Confirm that this holds on every runtime.
#[cfg(feature = "alloc")]
async fn run_join_transform<O, R, F, Fut>(
    db: Arc<crate::AimDb<R>>,
    inputs: Vec<(String, JoinInputFactory<R>)>,
    producer: crate::Producer<O, R>,
    handler: F,
    runtime_ctx: Arc<dyn Any + Send + Sync>,
) where
    O: Send + Sync + Clone + Debug + 'static,
    R: JoinFanInRuntime + 'static,
    F: FnOnce(JoinEventRx, crate::Producer<O, R>) -> Fut + Send + 'static,
    Fut: core::future::Future<Output = ()> + Send + 'static,
{
    // Only the diagnostics below read the output key. Compiling it out when no
    // diagnostic sink is enabled avoids both the allocation and an
    // unused-variable warning.
    #[cfg(any(feature = "tracing", feature = "std"))]
    let output_key = producer.key().to_string();

    #[cfg(feature = "tracing")]
    {
        let input_keys: Vec<&str> = inputs.iter().map(|(k, _)| k.as_str()).collect();
        tracing::info!(
            "🔄 Join transform started: {:?} → '{}'",
            input_keys,
            output_key
        );
        if input_keys.is_empty() {
            tracing::warn!(
                "🔄 Join transform '{}' has no inputs; its handler will never receive a trigger",
                output_key
            );
        }
    }

    let runtime: &R = match runtime_ctx.downcast_ref::<R>() {
        Some(r) => r,
        None => {
            #[cfg(feature = "tracing")]
            tracing::error!(
                "🔄 Join transform '{}' FATAL: runtime context downcast failed",
                output_key
            );
            #[cfg(all(feature = "std", not(feature = "tracing")))]
            eprintln!(
                "AIMDB TRANSFORM ERROR: Join transform '{}' runtime context downcast failed",
                output_key
            );
            return;
        }
    };

    let queue = match runtime.create_join_queue::<JoinTrigger>() {
        Ok(q) => q,
        Err(_e) => {
            #[cfg(feature = "tracing")]
            tracing::error!(
                "🔄 Join transform '{}' FATAL: failed to create join queue: {:?}",
                output_key,
                _e
            );
            #[cfg(all(feature = "std", not(feature = "tracing")))]
            eprintln!(
                "AIMDB TRANSFORM ERROR: Join transform '{}' failed to create join queue: {:?}",
                output_key, _e
            );
            return;
        }
    };
    let (tx, rx) = queue.split();

    for (index, (_key, factory)) in inputs.into_iter().enumerate() {
        let forwarder = factory(Arc::clone(&db), index, tx.clone());
        if let Err(_e) = runtime.spawn(forwarder) {
            #[cfg(feature = "tracing")]
            tracing::error!(
                "🔄 Join transform '{}' FATAL: failed to spawn forwarder for input '{}' (index {}): {:?}",
                output_key,
                _key,
                index,
                _e
            );
            #[cfg(all(feature = "std", not(feature = "tracing")))]
            eprintln!(
                "AIMDB TRANSFORM ERROR: Join transform '{}' failed to spawn forwarder for input '{}' (index {}): {:?}",
                output_key, _key, index, _e
            );
            return;
        }
    }

    // Release this task's sender. The forwarders' clones must be the only
    // senders, so that the channel can close once the last forwarder exits.
    drop(tx);

    #[cfg(feature = "tracing")]
    tracing::debug!(
        "✅ Join transform '{}' all forwarders spawned, handing receiver to user task",
        output_key
    );

    handler(JoinEventRx::new(rx), producer).await;

    #[cfg(feature = "tracing")]
    tracing::warn!("🔄 Join transform '{}' user task exited", output_key);
}