Skip to main content

limen_core/message/
batch.rs

//! Limen-core batch types and helpers.
//!
//! This module provides:
//! - `Batch<'a, P>`: a thin, immutable view over a slice of `Message<P>` used
//!   by policies, telemetry and nodes;
//! - `BatchView<'a, I>`: an internal container (owned or borrowed) used by the
//!   runtime / NodeLink / StepContext to assemble batches before handing a
//!   borrowed `Batch<'a, P>` to nodes; and
//! - `BatchMessageIter`: a lazy iterator that resolves batch tokens into read
//!   guards through a memory manager.
//!
//! `BatchView` intentionally provides both an owned (`Vec`) variant, available
//! with the `alloc` feature, and a borrowed, stack/heapless-friendly variant, so
//! the runtime can operate in both `alloc` and no-alloc builds.

14use crate::{
15    memory::BufferDescriptor,
16    message::{payload::Payload, Message, MessageHeader},
17    prelude::MemoryManager,
18    types::MessageToken,
19};
20
21use core::{mem, slice};
22
23/// A thin batch view over a slice of messages.
24///
25/// Batch formation is runtime-specific; the core only provides
26/// a convenient immutable view for policies and telemetry.
27#[derive(Debug, Copy, Clone)]
28pub struct Batch<'a, P: Payload> {
29    /// The ordered messages in the batch.
30    messages: &'a [Message<P>],
31}
32
33impl<'a, P: Payload> Batch<'a, P> {
34    /// Construct a new batch view over a slice of messages.
35    #[inline]
36    pub const fn new(messages: &'a [Message<P>]) -> Self {
37        Self { messages }
38    }
39
40    /// Return the underlying messages slice.
41    #[inline]
42    pub fn messages(&self) -> &'a [Message<P>] {
43        self.messages
44    }
45
46    /// Return the number of messages in the batch.
47    #[inline]
48    pub fn len(&self) -> usize {
49        self.messages.len()
50    }
51
52    /// Return `true` if the batch is empty.
53    #[inline]
54    pub fn is_empty(&self) -> bool {
55        self.messages.is_empty()
56    }
57
58    /// Total byte size across message payloads.
59    pub fn total_payload_bytes(&self) -> usize {
60        self.messages
61            .iter()
62            .map(|m| m.header.payload_size_bytes)
63            .sum()
64    }
65
66    /// Iterate over messages.
67    #[inline]
68    pub fn iter(&self) -> core::slice::Iter<'_, Message<P>> {
69        self.messages.iter()
70    }
71
72    /// Convenience: is the first message marked FIRST_IN_BATCH (if present)?
73    #[inline]
74    pub fn first_flagged(&self) -> bool {
75        self.messages
76            .first()
77            .map(|m| m.header.flags.is_first())
78            .unwrap_or(false)
79    }
80
81    /// Convenience: is the last message marked LAST_IN_BATCH (if present)?
82    #[inline]
83    pub fn last_flagged(&self) -> bool {
84        self.messages
85            .last()
86            .map(|m| m.header.flags.is_last())
87            .unwrap_or(false)
88    }
89
90    /// (Optional) Validate flags are consistent with batch boundaries.
91    /// Enable only when you want assertions (e.g., in tests) via a feature flag.
92    // #[cfg(feature = "validate_batches")]
93    #[inline]
94    pub fn assert_flags_consistent(&self) {
95        if self.is_empty() {
96            return;
97        }
98        debug_assert!(
99            self.first_flagged(),
100            "batch: first item missing FIRST_IN_BATCH"
101        );
102        debug_assert!(
103            self.last_flagged(),
104            "batch: last item missing LAST_IN_BATCH"
105        );
106        // Optional: internal items should have neither FIRST nor LAST
107        for m in &self.messages[1..self.messages.len().saturating_sub(1)] {
108            debug_assert!(
109                !m.header.flags.is_first() && !m.header.flags.is_last(),
110                "batch: internal item has boundary flag"
111            );
112        }
113    }
114}
115
116impl<'a, P: Payload> Payload for Batch<'a, P> {
117    #[inline]
118    fn buffer_descriptor(&self) -> BufferDescriptor {
119        // Sum payload bytes across messages and add header size per message.
120        let total_payload_bytes: usize = self
121            .messages
122            .iter()
123            .map(|m| {
124                // Use the header field stored on message as the authoritative payload size.
125                // This avoids re-inspecting m.payload() which might be expensive for some payloads.
126                m.header.payload_size_bytes
127            })
128            .sum();
129
130        let header_bytes = self.messages.len() * mem::size_of::<MessageHeader>();
131        BufferDescriptor::new(total_payload_bytes + header_bytes)
132    }
133}
134
135// Provide also for borrowed Batch reference to match other Payload impls pattern.
136impl<'a, P: Payload> Payload for &'a Batch<'a, P> {
137    #[inline]
138    fn buffer_descriptor(&self) -> BufferDescriptor {
139        (*self).buffer_descriptor()
140    }
141}
142
/// An internal batch container used by the runtime / NodeLink / StepContext.
///
/// Generic over the stored item `I`. Commonly `I = Message<P>`, but by using a
/// stored-item generic the same type works for any Edge whose `Item` is a
/// `Payload`-implementing type.
///
/// - `Owned(Vec<I>)`: when the `alloc` feature is enabled and we can own a `Vec`.
/// - `Borrowed(&'a mut [I], len)`: stack/heapless-backed buffer whose first
///   `len` items form the batch.
///
/// For `Borrowed(buf, len)` the invariant is `len <= buf.len()`.
/// `BatchView::from_borrowed` only checks it in debug builds; if it is
/// violated, the slice and iterator accessors panic when they slice `buf`.
#[derive(Debug)]
pub enum BatchView<'a, I> {
    /// Owned variant (alloc-enabled). Stores the entire `Vec<I>`.
    #[cfg(feature = "alloc")]
    Owned(alloc::vec::Vec<I>),

    /// Borrowed variant: a mutable slice plus a length indicating the valid prefix.
    Borrowed(&'a mut [I], usize),
}

impl<'a, I> BatchView<'a, I> {
    /// Construct from an owned `Vec` (requires the `alloc` feature).
    #[cfg(feature = "alloc")]
    #[inline]
    pub fn from_owned(v: alloc::vec::Vec<I>) -> Self {
        BatchView::Owned(v)
    }

    /// Construct from a borrowed buffer whose first `len` items form the batch.
    ///
    /// # Panics
    /// In debug builds, if `len > buf.len()`.
    #[inline]
    pub fn from_borrowed(buf: &'a mut [I], len: usize) -> Self {
        debug_assert!(len <= buf.len());
        BatchView::Borrowed(buf, len)
    }

    /// Number of items in the batch.
    ///
    /// For the borrowed variant this is the stored valid length, not the
    /// capacity of the backing buffer.
    #[inline]
    pub fn len(&self) -> usize {
        match self {
            #[cfg(feature = "alloc")]
            BatchView::Owned(v) => v.len(),
            BatchView::Borrowed(_, n) => *n,
        }
    }

    /// Is the batch empty?
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Immutable slice over the valid items.
    #[inline]
    pub fn as_slice(&self) -> &[I] {
        match self {
            #[cfg(feature = "alloc")]
            BatchView::Owned(v) => v.as_slice(),
            BatchView::Borrowed(buf, n) => &buf[..*n],
        }
    }

    /// Mutable slice over the valid items.
    ///
    /// Slots past the valid prefix of a borrowed buffer are not exposed.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [I] {
        match self {
            #[cfg(feature = "alloc")]
            BatchView::Owned(v) => v.as_mut_slice(),
            BatchView::Borrowed(buf, n) => &mut buf[..*n],
        }
    }

    /// Immutable iterator over the valid items.
    #[inline]
    pub fn iter(&self) -> slice::Iter<'_, I> {
        self.as_slice().iter()
    }

    /// Mutable iterator over the valid items.
    #[inline]
    pub fn iter_mut(&mut self) -> slice::IterMut<'_, I> {
        self.as_mut_slice().iter_mut()
    }
}

223/// Special-case helpers for the common stored-item type: `Message<P>`.
224impl<'a, P: Payload> BatchView<'a, Message<P>> {
225    /// Convert to the public, borrowed `Batch<'_, P>` view.
226    ///
227    /// This is only available when the stored item is `Message<P>`.
228    #[inline]
229    pub fn as_batch(&self) -> Batch<'_, P> {
230        let slice: &[Message<P>] = match self {
231            #[cfg(feature = "alloc")]
232            BatchView::Owned(v) => v.as_slice(),
233            BatchView::Borrowed(buf, n) => &buf[..*n],
234        };
235        Batch::new(slice)
236    }
237
238    /// Mutable access to the first message header if present.
239    #[inline]
240    pub fn first_header_mut(&mut self) -> Option<&mut MessageHeader> {
241        if self.is_empty() {
242            return None;
243        }
244        Some(match self {
245            #[cfg(feature = "alloc")]
246            BatchView::Owned(v) => v[0].header_mut(),
247            BatchView::Borrowed(buf, _) => buf[0].header_mut(),
248        })
249    }
250
251    /// Mutable access to the last message header if present.
252    #[inline]
253    pub fn last_header_mut(&mut self) -> Option<&mut MessageHeader> {
254        let n = self.len();
255        if n == 0 {
256            return None;
257        }
258        Some(match self {
259            #[cfg(feature = "alloc")]
260            BatchView::Owned(v) => v[n - 1].header_mut(),
261            BatchView::Borrowed(buf, _) => buf[n - 1].header_mut(),
262        })
263    }
264
265    /// Try to convert into a borrowed `Batch<'_, P>` while keeping `self` alive.
266    /// Equivalent to `self.as_batch()` but offered for clarity.
267    #[inline]
268    pub fn into_batch_ref(&self) -> Batch<'_, P> {
269        self.as_batch()
270    }
271
272    /// Convert this batch view into an owned batch view.
273    ///
274    /// This is required by mutex-backed edges: a borrowed batch cannot escape the lock guard.
275    ///
276    /// Semantics:
277    /// - If already `Owned`, the inner `Vec` is forwarded.
278    /// - If `Borrowed`, the valid prefix is cloned into a new `Vec`.
279    ///
280    /// The returned `BatchView` is `Owned` and does not borrow from the original `'a`.
281    #[cfg(feature = "alloc")]
282    #[inline]
283    pub fn into_owned<'b>(self) -> BatchView<'b, Message<P>>
284    where
285        Message<P>: Clone,
286    {
287        match self {
288            BatchView::Owned(v) => BatchView::<'b, Message<P>>::Owned(v),
289            BatchView::Borrowed(buf, n) => {
290                let mut v: alloc::vec::Vec<Message<P>> = alloc::vec::Vec::with_capacity(n);
291                for m in &buf[..n] {
292                    v.push(m.clone());
293                }
294                BatchView::<'b, Message<P>>::Owned(v)
295            }
296        }
297    }
298
299    /// Consume and return an owned `Vec<Message<P>>`.
300    ///
301    /// - If this `BatchView` is `Owned`, returns the inner Vec without copying.
302    /// - If this `BatchView` is `Borrowed`, clones the valid prefix into a new Vec.
303    ///
304    /// Alloc-only.
305    #[cfg(feature = "alloc")]
306    #[inline]
307    pub fn into_vec(self) -> alloc::vec::Vec<Message<P>>
308    where
309        P: Clone,
310    {
311        match self {
312            BatchView::Owned(v) => v,
313            BatchView::Borrowed(buf, n) => {
314                let mut v = alloc::vec::Vec::with_capacity(n);
315                for m in &buf[..n] {
316                    v.push(m.clone());
317                }
318                v
319            }
320        }
321    }
322}
323
324impl<'a, P: Payload> Payload for BatchView<'a, Message<P>> {
325    #[inline]
326    fn buffer_descriptor(&self) -> BufferDescriptor {
327        match self {
328            #[cfg(feature = "alloc")]
329            BatchView::Owned(v) => {
330                let total_payload_bytes: usize =
331                    v.iter().map(|m| m.header().payload_size_bytes).sum();
332                let header_bytes = v.len() * mem::size_of::<MessageHeader>();
333                BufferDescriptor::new(total_payload_bytes + header_bytes)
334            }
335
336            BatchView::Borrowed(buf, n) => {
337                let total_payload_bytes: usize = buf[..*n]
338                    .iter()
339                    .map(|m| m.header().payload_size_bytes)
340                    .sum();
341                let header_bytes = *n * mem::size_of::<MessageHeader>();
342                BufferDescriptor::new(total_payload_bytes + header_bytes)
343            }
344        }
345    }
346}
347
348// Borrowed ref as well
349impl<'a, P: Payload> Payload for &'a BatchView<'a, Message<P>> {
350    #[inline]
351    fn buffer_descriptor(&self) -> BufferDescriptor {
352        (*self).buffer_descriptor()
353    }
354}
355
356/// Lazy guard-yielding iterator over messages in a batch.
357///
358/// Backed by token resolution through a memory manager. Each `next()` call
359/// takes a `ReadGuard` from the manager — no copying, no contiguous buffer
360/// required.
361///
362/// Nodes can:
363/// - iterate and process one at a time (default `step_batch`)
364/// - iterate and copy into a node-owned scratch buffer (`InferenceModel`)
365pub struct BatchMessageIter<'edge, 'mgr, P: Payload, M: MemoryManager<P>> {
366    tokens: core::slice::Iter<'edge, MessageToken>,
367    manager: &'mgr M,
368    /// Number of leading tokens that were popped (rest are peeked).
369    stride: usize,
370    /// Total number of tokens in the batch.
371    len: usize,
372    _pd: core::marker::PhantomData<P>,
373}
374
375impl<'edge, 'mgr, P: Payload, M: MemoryManager<P>> BatchMessageIter<'edge, 'mgr, P, M> {
376    /// Construct a new `BatchMessageIter` from token slice, manager, and stride.
377    #[inline]
378    pub fn new(
379        tokens: core::slice::Iter<'edge, MessageToken>,
380        manager: &'mgr M,
381        stride: usize,
382        len: usize,
383    ) -> Self {
384        Self {
385            tokens,
386            manager,
387            stride,
388            len,
389            _pd: core::marker::PhantomData,
390        }
391    }
392
393    /// How many items were popped (will be freed after the callback).
394    #[inline]
395    pub fn stride(&self) -> usize {
396        self.stride
397    }
398
399    /// Total batch length (popped + peeked).
400    #[inline]
401    pub fn len(&self) -> usize {
402        self.len
403    }
404
405    /// Whether the batch is empty.
406    #[inline]
407    pub fn is_empty(&self) -> bool {
408        self.len == 0
409    }
410
411    /// Whether this is a sliding window batch (stride < len).
412    #[inline]
413    pub fn is_sliding(&self) -> bool {
414        self.stride < self.len
415    }
416}
417
418impl<'edge, 'mgr, P: Payload, M: MemoryManager<P>> Iterator
419    for BatchMessageIter<'edge, 'mgr, P, M>
420{
421    type Item = M::ReadGuard<'mgr>;
422
423    #[inline]
424    fn next(&mut self) -> Option<Self::Item> {
425        let &token = self.tokens.next()?;
426        self.manager.read(token).ok()
427    }
428
429    #[inline]
430    fn size_hint(&self) -> (usize, Option<usize>) {
431        self.tokens.size_hint()
432    }
433}
434
#[cfg(test)]
mod tests {
    use crate::prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT};

    use super::*;

    /// Message with an empty header and a shared test-tensor payload uniformly
    /// filled with `fill`.
    fn tensor_msg(fill: u32) -> Message<TestTensor> {
        Message::new(MessageHeader::empty(), create_test_tensor_filled_with(fill))
    }

    #[test]
    fn batch_basic_props() {
        let msgs: [Message<TestTensor>; 3] = [tensor_msg(10), tensor_msg(11), tensor_msg(12)];

        // View only the first two messages.
        let batch = Batch::new(&msgs[..2]);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
        assert_eq!(batch.messages().len(), 2);

        // Two test tensors' worth of payload bytes.
        assert_eq!(batch.total_payload_bytes(), 2 * TEST_TENSOR_BYTE_COUNT);

        // Fresh headers carry no boundary flags.
        assert!(!batch.first_flagged());
        assert!(!batch.last_flagged());
    }

    #[test]
    fn batchview_borrowed_basic_and_mutation() {
        // Four backing slots; only the first three belong to the batch.
        let mut backing: [Message<TestTensor>; 4] = [
            tensor_msg(100),
            tensor_msg(101),
            tensor_msg(102),
            tensor_msg(103),
        ];
        let mut view = BatchView::from_borrowed(&mut backing, 3);
        assert_eq!(view.len(), 3);
        assert!(!view.is_empty());

        // Overwrite the in-batch payloads with 200, 201, 202.
        for (fill, msg) in (200u32..).zip(view.iter_mut()) {
            *msg.payload_mut() = create_test_tensor_filled_with(fill);
        }

        // Read the payloads back through the public `Batch` view.
        let batch = view.as_batch();
        let mut seen = [TestTensor::default(); 3];
        for (idx, msg) in batch.iter().enumerate() {
            seen[idx] = msg.payload().clone();
        }
        let expected = [
            create_test_tensor_filled_with(200),
            create_test_tensor_filled_with(201),
            create_test_tensor_filled_with(202),
        ];
        assert_eq!(seen, expected);

        // Mark the batch boundaries through the header helpers.
        view.first_header_mut()
            .expect("first header")
            .set_first_in_batch();
        view.last_header_mut()
            .expect("last header")
            .set_last_in_batch();

        let flagged = view.as_batch();
        assert!(flagged.first_flagged());
        assert!(flagged.last_flagged());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn batchview_owned_basic_and_into_owned() {
        use alloc::vec::Vec;

        let msgs: Vec<Message<TestTensor>> = (1..=3).map(tensor_msg).collect();

        let mut view = BatchView::from_owned(msgs);
        assert_eq!(view.len(), 3);
        assert!(!view.is_empty());

        // Replace only the third payload.
        if let Some(third) = view.iter_mut().nth(2) {
            *third.payload_mut() = create_test_tensor_filled_with(42);
        }

        // Collect the payloads through the `Batch` view.
        let batch = view.as_batch();
        let payloads: Vec<TestTensor> = batch.iter().map(|m| m.payload().clone()).collect();
        assert_eq!(
            payloads.as_slice(),
            &[
                create_test_tensor_filled_with(1),
                create_test_tensor_filled_with(2),
                create_test_tensor_filled_with(42),
            ]
        );

        // Set boundary flags, then consume the view.
        view.first_header_mut()
            .expect("first header")
            .set_first_in_batch();
        view.last_header_mut()
            .expect("last header")
            .set_last_in_batch();
        let flagged = view.as_batch();
        assert!(flagged.first_flagged());
        assert!(flagged.last_flagged());

        let owned = view.into_vec();
        assert_eq!(owned.len(), 3);
        // The mutated final payload (42) must survive the conversion.
        assert_eq!(
            *owned.last().unwrap().payload(),
            create_test_tensor_filled_with(42)
        );
    }

    #[test]
    fn batch_assert_flags_consistent_no_panic_when_correct() {
        let mut msgs: [Message<TestTensor>; 2] = [tensor_msg(7), tensor_msg(8)];

        let mut view = BatchView::from_borrowed(&mut msgs, 2);
        view.first_header_mut().unwrap().set_first_in_batch();
        view.last_header_mut().unwrap().set_last_in_batch();

        // Must not panic; `debug_assert!`s are active in test builds.
        view.as_batch().assert_flags_consistent();
    }
}