1use crate::{
15 memory::BufferDescriptor,
16 message::{payload::Payload, Message, MessageHeader},
17 prelude::MemoryManager,
18 types::MessageToken,
19};
20
21use core::{mem, slice};
22
23#[derive(Debug, Copy, Clone)]
28pub struct Batch<'a, P: Payload> {
29 messages: &'a [Message<P>],
31}
32
33impl<'a, P: Payload> Batch<'a, P> {
34 #[inline]
36 pub const fn new(messages: &'a [Message<P>]) -> Self {
37 Self { messages }
38 }
39
40 #[inline]
42 pub fn messages(&self) -> &'a [Message<P>] {
43 self.messages
44 }
45
46 #[inline]
48 pub fn len(&self) -> usize {
49 self.messages.len()
50 }
51
52 #[inline]
54 pub fn is_empty(&self) -> bool {
55 self.messages.is_empty()
56 }
57
58 pub fn total_payload_bytes(&self) -> usize {
60 self.messages
61 .iter()
62 .map(|m| m.header.payload_size_bytes)
63 .sum()
64 }
65
66 #[inline]
68 pub fn iter(&self) -> core::slice::Iter<'_, Message<P>> {
69 self.messages.iter()
70 }
71
72 #[inline]
74 pub fn first_flagged(&self) -> bool {
75 self.messages
76 .first()
77 .map(|m| m.header.flags.is_first())
78 .unwrap_or(false)
79 }
80
81 #[inline]
83 pub fn last_flagged(&self) -> bool {
84 self.messages
85 .last()
86 .map(|m| m.header.flags.is_last())
87 .unwrap_or(false)
88 }
89
90 #[inline]
94 pub fn assert_flags_consistent(&self) {
95 if self.is_empty() {
96 return;
97 }
98 debug_assert!(
99 self.first_flagged(),
100 "batch: first item missing FIRST_IN_BATCH"
101 );
102 debug_assert!(
103 self.last_flagged(),
104 "batch: last item missing LAST_IN_BATCH"
105 );
106 for m in &self.messages[1..self.messages.len().saturating_sub(1)] {
108 debug_assert!(
109 !m.header.flags.is_first() && !m.header.flags.is_last(),
110 "batch: internal item has boundary flag"
111 );
112 }
113 }
114}
115
116impl<'a, P: Payload> Payload for Batch<'a, P> {
117 #[inline]
118 fn buffer_descriptor(&self) -> BufferDescriptor {
119 let total_payload_bytes: usize = self
121 .messages
122 .iter()
123 .map(|m| {
124 m.header.payload_size_bytes
127 })
128 .sum();
129
130 let header_bytes = self.messages.len() * mem::size_of::<MessageHeader>();
131 BufferDescriptor::new(total_payload_bytes + header_bytes)
132 }
133}
134
135impl<'a, P: Payload> Payload for &'a Batch<'a, P> {
137 #[inline]
138 fn buffer_descriptor(&self) -> BufferDescriptor {
139 (*self).buffer_descriptor()
140 }
141}
142
/// A batch of items that is either owned in a `Vec` or borrowed from
/// caller-provided storage.
///
/// In the `Borrowed` form, only the first `usize` elements of the slice belong
/// to the batch. The accessors on this type (`len`, `iter`, `as_slice`, ...)
/// only look at that prefix.
#[derive(Debug)]
pub enum BatchView<'a, I> {
    /// Heap-owned items (requires the `alloc` feature); every element is part
    /// of the batch.
    #[cfg(feature = "alloc")]
    Owned(alloc::vec::Vec<I>),

    /// Borrowed storage plus the count of leading elements that are live.
    /// The count is expected to be `<=` the slice length (debug-checked in
    /// `from_borrowed`).
    Borrowed(&'a mut [I], usize),
}
160
161impl<'a, I> BatchView<'a, I> {
162 #[cfg(feature = "alloc")]
164 #[inline]
165 pub fn from_owned(v: alloc::vec::Vec<I>) -> Self {
166 BatchView::Owned(v)
167 }
168
169 #[inline]
171 pub fn from_borrowed(buf: &'a mut [I], len: usize) -> Self {
172 debug_assert!(len <= buf.len());
173 BatchView::Borrowed(buf, len)
174 }
175
176 #[inline]
178 pub fn len(&self) -> usize {
179 match self {
180 #[cfg(feature = "alloc")]
181 BatchView::Owned(v) => v.len(),
182 BatchView::Borrowed(_, n) => *n,
183 }
184 }
185
186 #[inline]
188 pub fn is_empty(&self) -> bool {
189 self.len() == 0
190 }
191
192 #[inline]
194 pub fn iter(&self) -> slice::Iter<'_, I> {
195 match self {
196 #[cfg(feature = "alloc")]
197 BatchView::Owned(v) => v.as_slice().iter(),
198 BatchView::Borrowed(buf, n) => buf[..*n].iter(),
199 }
200 }
201
202 #[inline]
204 pub fn iter_mut(&mut self) -> slice::IterMut<'_, I> {
205 match self {
206 #[cfg(feature = "alloc")]
207 BatchView::Owned(v) => v.as_mut_slice().iter_mut(),
208 BatchView::Borrowed(buf, n) => buf[..*n].iter_mut(),
209 }
210 }
211
212 #[inline]
214 pub fn as_slice(&self) -> &[I] {
215 match self {
216 #[cfg(feature = "alloc")]
217 BatchView::Owned(v) => v.as_slice(),
218 BatchView::Borrowed(buf, n) => &buf[..*n],
219 }
220 }
221}
222
223impl<'a, P: Payload> BatchView<'a, Message<P>> {
225 #[inline]
229 pub fn as_batch(&self) -> Batch<'_, P> {
230 let slice: &[Message<P>] = match self {
231 #[cfg(feature = "alloc")]
232 BatchView::Owned(v) => v.as_slice(),
233 BatchView::Borrowed(buf, n) => &buf[..*n],
234 };
235 Batch::new(slice)
236 }
237
238 #[inline]
240 pub fn first_header_mut(&mut self) -> Option<&mut MessageHeader> {
241 if self.is_empty() {
242 return None;
243 }
244 Some(match self {
245 #[cfg(feature = "alloc")]
246 BatchView::Owned(v) => v[0].header_mut(),
247 BatchView::Borrowed(buf, _) => buf[0].header_mut(),
248 })
249 }
250
251 #[inline]
253 pub fn last_header_mut(&mut self) -> Option<&mut MessageHeader> {
254 let n = self.len();
255 if n == 0 {
256 return None;
257 }
258 Some(match self {
259 #[cfg(feature = "alloc")]
260 BatchView::Owned(v) => v[n - 1].header_mut(),
261 BatchView::Borrowed(buf, _) => buf[n - 1].header_mut(),
262 })
263 }
264
265 #[inline]
268 pub fn into_batch_ref(&self) -> Batch<'_, P> {
269 self.as_batch()
270 }
271
272 #[cfg(feature = "alloc")]
282 #[inline]
283 pub fn into_owned<'b>(self) -> BatchView<'b, Message<P>>
284 where
285 Message<P>: Clone,
286 {
287 match self {
288 BatchView::Owned(v) => BatchView::<'b, Message<P>>::Owned(v),
289 BatchView::Borrowed(buf, n) => {
290 let mut v: alloc::vec::Vec<Message<P>> = alloc::vec::Vec::with_capacity(n);
291 for m in &buf[..n] {
292 v.push(m.clone());
293 }
294 BatchView::<'b, Message<P>>::Owned(v)
295 }
296 }
297 }
298
299 #[cfg(feature = "alloc")]
306 #[inline]
307 pub fn into_vec(self) -> alloc::vec::Vec<Message<P>>
308 where
309 P: Clone,
310 {
311 match self {
312 BatchView::Owned(v) => v,
313 BatchView::Borrowed(buf, n) => {
314 let mut v = alloc::vec::Vec::with_capacity(n);
315 for m in &buf[..n] {
316 v.push(m.clone());
317 }
318 v
319 }
320 }
321 }
322}
323
324impl<'a, P: Payload> Payload for BatchView<'a, Message<P>> {
325 #[inline]
326 fn buffer_descriptor(&self) -> BufferDescriptor {
327 match self {
328 #[cfg(feature = "alloc")]
329 BatchView::Owned(v) => {
330 let total_payload_bytes: usize =
331 v.iter().map(|m| m.header().payload_size_bytes).sum();
332 let header_bytes = v.len() * mem::size_of::<MessageHeader>();
333 BufferDescriptor::new(total_payload_bytes + header_bytes)
334 }
335
336 BatchView::Borrowed(buf, n) => {
337 let total_payload_bytes: usize = buf[..*n]
338 .iter()
339 .map(|m| m.header().payload_size_bytes)
340 .sum();
341 let header_bytes = *n * mem::size_of::<MessageHeader>();
342 BufferDescriptor::new(total_payload_bytes + header_bytes)
343 }
344 }
345 }
346}
347
348impl<'a, P: Payload> Payload for &'a BatchView<'a, Message<P>> {
350 #[inline]
351 fn buffer_descriptor(&self) -> BufferDescriptor {
352 (*self).buffer_descriptor()
353 }
354}
355
/// Iterator that resolves a sequence of `MessageToken`s into read guards
/// through a `MemoryManager`.
///
/// `stride` and `len` are stored as supplied and exposed through accessors.
/// `next` does not consult them; it walks `tokens` one at a time.
pub struct BatchMessageIter<'edge, 'mgr, P: Payload, M: MemoryManager<P>> {
    // Tokens still to be resolved.
    tokens: core::slice::Iter<'edge, MessageToken>,
    // Manager whose `read` turns each token into a guard.
    manager: &'mgr M,
    // Caller-supplied stride (NOTE(review): presumably the step between
    // successive batch windows — confirm with the producer).
    stride: usize,
    // Caller-supplied batch length; not derived from `tokens`.
    len: usize,
    // Marks the payload type `P`, which is otherwise only used in bounds.
    _pd: core::marker::PhantomData<P>,
}
374
375impl<'edge, 'mgr, P: Payload, M: MemoryManager<P>> BatchMessageIter<'edge, 'mgr, P, M> {
376 #[inline]
378 pub fn new(
379 tokens: core::slice::Iter<'edge, MessageToken>,
380 manager: &'mgr M,
381 stride: usize,
382 len: usize,
383 ) -> Self {
384 Self {
385 tokens,
386 manager,
387 stride,
388 len,
389 _pd: core::marker::PhantomData,
390 }
391 }
392
393 #[inline]
395 pub fn stride(&self) -> usize {
396 self.stride
397 }
398
399 #[inline]
401 pub fn len(&self) -> usize {
402 self.len
403 }
404
405 #[inline]
407 pub fn is_empty(&self) -> bool {
408 self.len == 0
409 }
410
411 #[inline]
413 pub fn is_sliding(&self) -> bool {
414 self.stride < self.len
415 }
416}
417
418impl<'edge, 'mgr, P: Payload, M: MemoryManager<P>> Iterator
419 for BatchMessageIter<'edge, 'mgr, P, M>
420{
421 type Item = M::ReadGuard<'mgr>;
422
423 #[inline]
424 fn next(&mut self) -> Option<Self::Item> {
425 let &token = self.tokens.next()?;
426 self.manager.read(token).ok()
427 }
428
429 #[inline]
430 fn size_hint(&self) -> (usize, Option<usize>) {
431 self.tokens.size_hint()
432 }
433}
434
#[cfg(test)]
mod tests {
    use crate::prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT};

    use super::*;

    /// Builds a message with an empty header and a tensor filled with `v`.
    fn make_msg_tensor(v: u32) -> Message<TestTensor> {
        Message::new(MessageHeader::empty(), create_test_tensor_filled_with(v))
    }

    #[test]
    fn batch_basic_props() {
        let arr: [Message<TestTensor>; 3] = [10, 11, 12].map(make_msg_tensor);

        // Only the first two messages form the batch.
        let batch = Batch::new(&arr[..2]);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
        assert_eq!(batch.messages().len(), 2);
        assert_eq!(batch.total_payload_bytes(), 2 * TEST_TENSOR_BYTE_COUNT);

        // Freshly built headers carry no boundary flags.
        assert!(!batch.first_flagged());
        assert!(!batch.last_flagged());
    }

    #[test]
    fn batchview_borrowed_basic_and_mutation() {
        let mut arr: [Message<TestTensor>; 4] = [100, 101, 102, 103].map(make_msg_tensor);

        // View only the first three slots of the four-element buffer.
        let mut bv = BatchView::from_borrowed(&mut arr, 3);
        assert_eq!(bv.len(), 3);
        assert!(!bv.is_empty());

        for (offset, m) in (0u32..).zip(bv.iter_mut()) {
            *m.payload_mut() = create_test_tensor_filled_with(200 + offset);
        }

        let batch = bv.as_batch();
        let mut vals = [TestTensor::default(); 3];
        for (slot, m) in vals.iter_mut().zip(batch.iter()) {
            *slot = m.payload().clone();
        }
        assert_eq!(
            vals,
            [
                create_test_tensor_filled_with(200),
                create_test_tensor_filled_with(201),
                create_test_tensor_filled_with(202),
            ]
        );

        // Mark the batch boundaries through the mutable header accessors.
        bv.first_header_mut()
            .expect("first header")
            .set_first_in_batch();
        bv.last_header_mut()
            .expect("last header")
            .set_last_in_batch();

        let flagged = bv.as_batch();
        assert!(flagged.first_flagged());
        assert!(flagged.last_flagged());
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn batchview_owned_basic_and_into_owned() {
        use alloc::vec::Vec;

        let msgs: Vec<Message<TestTensor>> = Vec::from([1, 2, 3].map(make_msg_tensor));

        let mut bv = BatchView::from_owned(msgs);
        assert_eq!(bv.len(), 3);
        assert!(!bv.is_empty());

        // Overwrite only the third message's payload.
        if let Some(m) = bv.iter_mut().nth(2) {
            *m.payload_mut() = create_test_tensor_filled_with(42);
        }

        let xs: Vec<TestTensor> = bv
            .as_batch()
            .iter()
            .map(|m| m.payload().clone())
            .collect();
        assert_eq!(
            xs.as_slice(),
            &[
                create_test_tensor_filled_with(1),
                create_test_tensor_filled_with(2),
                create_test_tensor_filled_with(42),
            ]
        );

        bv.first_header_mut()
            .expect("first header")
            .set_first_in_batch();
        bv.last_header_mut()
            .expect("last header")
            .set_last_in_batch();

        let flagged = bv.as_batch();
        assert!(flagged.first_flagged());
        assert!(flagged.last_flagged());

        let ov = bv.into_vec();
        assert_eq!(ov.len(), 3);
        assert_eq!(
            *ov.last().unwrap().payload(),
            create_test_tensor_filled_with(42)
        );
    }

    #[test]
    fn batch_assert_flags_consistent_no_panic_when_correct() {
        let mut arr: [Message<TestTensor>; 2] = [7, 8].map(make_msg_tensor);

        let mut bv = BatchView::from_borrowed(&mut arr, 2);
        bv.first_header_mut().unwrap().set_first_in_batch();
        bv.last_header_mut().unwrap().set_last_in_batch();
        bv.as_batch().assert_flags_consistent();
    }
}