use crate::{
memory::BufferDescriptor,
message::{payload::Payload, Message, MessageHeader},
prelude::MemoryManager,
types::MessageToken,
};
use core::{mem, slice};
/// A read-only view over a contiguous run of messages that form one batch.
///
/// `Batch` only stores a shared slice reference, so copying it is always cheap.
#[derive(Debug)]
pub struct Batch<'a, P: Payload> {
    messages: &'a [Message<P>],
}

// Written by hand instead of derived: `#[derive(Clone, Copy)]` would require `P: Clone` /
// `P: Copy`, even though only a `&'a [Message<P>]` is stored and that reference is always `Copy`.
impl<'a, P: Payload> Clone for Batch<'a, P> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, P: Payload> Copy for Batch<'a, P> {}

impl<'a, P: Payload> Batch<'a, P> {
    /// Wraps `messages` as a batch without copying them.
    #[inline]
    pub const fn new(messages: &'a [Message<P>]) -> Self {
        Self { messages }
    }

    /// Returns the underlying message slice with the full borrow lifetime `'a`.
    #[inline]
    pub fn messages(&self) -> &'a [Message<P>] {
        self.messages
    }

    /// Number of messages in the batch.
    #[inline]
    pub fn len(&self) -> usize {
        self.messages.len()
    }

    /// `true` when the batch holds no messages.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }

    /// Sum of `header.payload_size_bytes` over all messages.
    ///
    /// Header sizes are not included; see the `Payload` impl for the combined figure.
    pub fn total_payload_bytes(&self) -> usize {
        self.messages
            .iter()
            .map(|m| m.header.payload_size_bytes)
            .sum()
    }

    /// Iterates the messages in order.
    ///
    /// The iterator borrows for `'a` (like [`Batch::messages`]) rather than for the
    /// lifetime of `&self`, so it may outlive a temporary `Batch` value.
    #[inline]
    pub fn iter(&self) -> core::slice::Iter<'a, Message<P>> {
        self.messages.iter()
    }

    /// `true` if the batch is non-empty and its first message reports `flags.is_first()`.
    #[inline]
    pub fn first_flagged(&self) -> bool {
        self.messages
            .first()
            .map(|m| m.header.flags.is_first())
            .unwrap_or(false)
    }

    /// `true` if the batch is non-empty and its last message reports `flags.is_last()`.
    #[inline]
    pub fn last_flagged(&self) -> bool {
        self.messages
            .last()
            .map(|m| m.header.flags.is_last())
            .unwrap_or(false)
    }

    /// Debug-checks the batch boundary flags.
    ///
    /// An empty batch is always accepted. Otherwise, the first message must carry the
    /// "first" flag, the last message must carry the "last" flag, and no message strictly
    /// between them may carry either. A single-message batch must carry both flags on that
    /// one message.
    ///
    /// # Panics
    ///
    /// Only when `debug_assertions` are enabled and one of the checks above fails. In builds
    /// without `debug_assertions` this never panics.
    #[inline]
    pub fn assert_flags_consistent(&self) {
        if self.is_empty() {
            return;
        }
        debug_assert!(
            self.first_flagged(),
            "batch: first item missing FIRST_IN_BATCH"
        );
        debug_assert!(
            self.last_flagged(),
            "batch: last item missing LAST_IN_BATCH"
        );
        // The interior is everything strictly between the first and last message. The slice
        // pattern only matches with two or more messages, so a single-message batch (which
        // has no interior) never builds an inverted `1..0` range.
        if let [_, interior @ .., _] = self.messages {
            for m in interior {
                debug_assert!(
                    !m.header.flags.is_first() && !m.header.flags.is_last(),
                    "batch: internal item has boundary flag"
                );
            }
        }
    }
}
impl<'a, P: Payload> Payload for Batch<'a, P> {
    /// Reports the batch's byte size as the sum of every message's recorded
    /// `payload_size_bytes` plus one `MessageHeader` per message.
    #[inline]
    fn buffer_descriptor(&self) -> BufferDescriptor {
        let per_header = mem::size_of::<MessageHeader>();
        let header_bytes = per_header * self.messages.len();
        BufferDescriptor::new(self.total_payload_bytes() + header_bytes)
    }
}
impl<'a, P: Payload> Payload for &'a Batch<'a, P> {
    /// Forwards to the `Batch` implementation of the referenced batch.
    #[inline]
    fn buffer_descriptor(&self) -> BufferDescriptor {
        // Fully qualified so it is evident this targets `Batch`, not this reference impl.
        <Batch<'a, P> as Payload>::buffer_descriptor(*self)
    }
}
/// A batch of items that is either owned (`alloc` feature) or borrowed from a caller buffer.
#[derive(Debug)]
pub enum BatchView<'a, I> {
    /// Owns its items; every element of the vector is part of the view.
    #[cfg(feature = "alloc")]
    Owned(alloc::vec::Vec<I>),
    /// Borrows a buffer of which only the first `usize` elements are part of the view.
    Borrowed(&'a mut [I], usize),
}

impl<'a, I> BatchView<'a, I> {
    /// Builds a view that owns `v`.
    #[cfg(feature = "alloc")]
    #[inline]
    pub fn from_owned(v: alloc::vec::Vec<I>) -> Self {
        Self::Owned(v)
    }

    /// Builds a view over the first `len` elements of `buf`.
    ///
    /// `len <= buf.len()` is only checked with `debug_assertions`; a larger `len` makes the
    /// slicing accessors panic later.
    #[inline]
    pub fn from_borrowed(buf: &'a mut [I], len: usize) -> Self {
        debug_assert!(len <= buf.len());
        Self::Borrowed(buf, len)
    }

    /// Number of items in the view (the stored count for the borrowed variant).
    #[inline]
    pub fn len(&self) -> usize {
        match *self {
            #[cfg(feature = "alloc")]
            Self::Owned(ref items) => items.len(),
            Self::Borrowed(_, count) => count,
        }
    }

    /// `true` when the view holds no items.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Shared iterator over the items in the view.
    #[inline]
    pub fn iter(&self) -> slice::Iter<'_, I> {
        self.as_slice().iter()
    }

    /// Mutable iterator over the items in the view; elements past the count are untouched.
    #[inline]
    pub fn iter_mut(&mut self) -> slice::IterMut<'_, I> {
        let live: &mut [I] = match self {
            #[cfg(feature = "alloc")]
            Self::Owned(items) => items.as_mut_slice(),
            Self::Borrowed(buf, count) => &mut buf[..*count],
        };
        live.iter_mut()
    }

    /// The items in the view as a shared slice.
    #[inline]
    pub fn as_slice(&self) -> &[I] {
        match self {
            #[cfg(feature = "alloc")]
            Self::Owned(items) => items.as_slice(),
            Self::Borrowed(buf, count) => &buf[..*count],
        }
    }
}
impl<'a, P: Payload> BatchView<'a, Message<P>> {
    /// Borrows the messages in the view as a [`Batch`].
    #[inline]
    pub fn as_batch(&self) -> Batch<'_, P> {
        Batch::new(self.as_slice())
    }

    /// Mutable access to the first message's header, or `None` for an empty view.
    #[inline]
    pub fn first_header_mut(&mut self) -> Option<&mut MessageHeader> {
        self.iter_mut().next().map(|m| m.header_mut())
    }

    /// Mutable access to the last message's header, or `None` for an empty view.
    #[inline]
    pub fn last_header_mut(&mut self) -> Option<&mut MessageHeader> {
        self.iter_mut().next_back().map(|m| m.header_mut())
    }

    /// Same as [`BatchView::as_batch`]; despite the `into_` prefix it only borrows `self`.
    #[inline]
    pub fn into_batch_ref(&self) -> Batch<'_, P> {
        Batch::new(self.as_slice())
    }

    /// Converts into an owned view, cloning the live messages if currently borrowed.
    #[cfg(feature = "alloc")]
    #[inline]
    pub fn into_owned<'b>(self) -> BatchView<'b, Message<P>>
    where
        Message<P>: Clone,
    {
        let owned = match self {
            BatchView::Owned(v) => v,
            BatchView::Borrowed(buf, n) => buf[..n].to_vec(),
        };
        BatchView::Owned(owned)
    }

    /// Extracts the messages as a vector, cloning them if currently borrowed.
    #[cfg(feature = "alloc")]
    #[inline]
    pub fn into_vec(self) -> alloc::vec::Vec<Message<P>>
    where
        P: Clone,
    {
        match self {
            BatchView::Owned(v) => v,
            BatchView::Borrowed(buf, n) => buf[..n].to_vec(),
        }
    }
}
impl<'a, P: Payload> Payload for BatchView<'a, Message<P>> {
    /// Reports the byte size of the messages in the view: the sum of each message's
    /// `payload_size_bytes` plus one `MessageHeader` per message. Elements of a borrowed
    /// buffer beyond the stored count are not counted.
    #[inline]
    fn buffer_descriptor(&self) -> BufferDescriptor {
        let live = self.as_slice();
        let payload_bytes = live
            .iter()
            .map(|m| m.header().payload_size_bytes)
            .sum::<usize>();
        let header_bytes = live.len() * mem::size_of::<MessageHeader>();
        BufferDescriptor::new(payload_bytes + header_bytes)
    }
}
impl<'a, P: Payload> Payload for &'a BatchView<'a, Message<P>> {
    /// Forwards to the `BatchView` implementation of the referenced view.
    #[inline]
    fn buffer_descriptor(&self) -> BufferDescriptor {
        // Fully qualified so it is evident this targets `BatchView`, not this reference impl.
        <BatchView<'a, Message<P>> as Payload>::buffer_descriptor(*self)
    }
}
/// Iterator that resolves a sequence of message tokens through a [`MemoryManager`].
///
/// `stride` and `len` are stored exactly as given and only reported back through the
/// accessors below; iteration itself walks `tokens` one at a time.
pub struct BatchMessageIter<'edge, 'mgr, P: Payload, M: MemoryManager<P>> {
    tokens: core::slice::Iter<'edge, MessageToken>,
    manager: &'mgr M,
    stride: usize,
    len: usize,
    _pd: core::marker::PhantomData<P>,
}

impl<'edge, 'mgr, P: Payload, M: MemoryManager<P>> BatchMessageIter<'edge, 'mgr, P, M> {
    /// Creates an iterator over `tokens`, reading each one from `manager`.
    #[inline]
    pub fn new(
        tokens: core::slice::Iter<'edge, MessageToken>,
        manager: &'mgr M,
        stride: usize,
        len: usize,
    ) -> Self {
        Self {
            len,
            stride,
            manager,
            tokens,
            _pd: core::marker::PhantomData,
        }
    }

    /// The stride value this iterator was created with.
    #[inline]
    pub fn stride(&self) -> usize {
        self.stride
    }

    /// The `len` value this iterator was created with (not the number of remaining tokens).
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the configured `len` is zero.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when `stride < len`.
    ///
    /// NOTE(review): presumably this means consecutive batches overlap (sliding window);
    /// confirm against the code that constructs this iterator.
    #[inline]
    pub fn is_sliding(&self) -> bool {
        self.len > self.stride
    }
}
impl<'edge, 'mgr, P: Payload, M: MemoryManager<P>> Iterator
    for BatchMessageIter<'edge, 'mgr, P, M>
{
    type Item = M::ReadGuard<'mgr>;

    /// Advances to the next token and reads it through the manager.
    ///
    /// Returns `None` both when the tokens are exhausted and when `manager.read` fails for
    /// the current token. In the latter case the token has already been consumed, and a
    /// subsequent call continues with the following token, so this iterator is not fused.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let &token = self.tokens.next()?;
        self.manager.read(token).ok()
    }

    /// Bounds on the number of remaining items.
    ///
    /// The upper bound is the number of remaining tokens. The lower bound is 0, because a
    /// failed read ends iteration early via `None`. Reporting the token count as the lower
    /// bound would break the `Iterator::size_hint` contract.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (_, upper) = self.tokens.size_hint();
        (0, upper)
    }
}
#[cfg(test)]
mod tests {
    use crate::prelude::{create_test_tensor_filled_with, TestTensor, TEST_TENSOR_BYTE_COUNT};

    use super::*;

    /// Builds a message from `MessageHeader::empty()` and a tensor filled with `v`.
    fn make_msg_tensor(v: u32) -> Message<TestTensor> {
        Message::new(MessageHeader::empty(), create_test_tensor_filled_with(v))
    }

    #[test]
    fn batch_basic_props() {
        let arr: [Message<TestTensor>; 3] = [
            make_msg_tensor(10),
            make_msg_tensor(11),
            make_msg_tensor(12),
        ];
        // Only the first two messages belong to the batch.
        let batch = Batch::new(&arr[..2]);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
        assert_eq!(batch.messages().len(), 2);
        assert_eq!(batch.total_payload_bytes(), 2 * TEST_TENSOR_BYTE_COUNT);
        assert!(!batch.first_flagged());
        assert!(!batch.last_flagged());
    }

    #[test]
    fn batchview_borrowed_basic_and_mutation() {
        let mut arr: [Message<TestTensor>; 4] = [
            make_msg_tensor(100),
            make_msg_tensor(101),
            make_msg_tensor(102),
            make_msg_tensor(103),
        ];
        let mut bv = BatchView::from_borrowed(&mut arr, 3);
        assert_eq!(bv.len(), 3);
        assert!(!bv.is_empty());
        for (i, m) in bv.iter_mut().enumerate() {
            *m.payload_mut() = create_test_tensor_filled_with(200 + (i as u32));
        }

        let batch = bv.as_batch();
        assert_eq!(batch.len(), 3);
        let mut vals = [TestTensor::default(); 3];
        for (slot, m) in vals.iter_mut().zip(batch.iter()) {
            *slot = *m.payload();
        }
        assert_eq!(
            vals,
            [
                create_test_tensor_filled_with(200),
                create_test_tensor_filled_with(201),
                create_test_tensor_filled_with(202),
            ]
        );

        bv.first_header_mut()
            .expect("first header")
            .set_first_in_batch();
        bv.last_header_mut()
            .expect("last header")
            .set_last_in_batch();
        let batch2 = bv.as_batch();
        assert!(batch2.first_flagged());
        assert!(batch2.last_flagged());
    }

    #[test]
    fn batchview_empty_has_no_boundary_headers() {
        let mut arr: [Message<TestTensor>; 1] = [make_msg_tensor(0)];
        // A zero-length view over a non-empty buffer exposes nothing.
        let mut bv = BatchView::from_borrowed(&mut arr, 0);
        assert!(bv.is_empty());
        assert!(bv.first_header_mut().is_none());
        assert!(bv.last_header_mut().is_none());

        let batch = bv.as_batch();
        assert!(batch.is_empty());
        assert!(!batch.first_flagged());
        assert!(!batch.last_flagged());
        batch.assert_flags_consistent();
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn batchview_owned_basic_and_into_owned() {
        use alloc::vec::Vec;

        let mut bv = BatchView::from_owned(alloc::vec![
            make_msg_tensor(1),
            make_msg_tensor(2),
            make_msg_tensor(3),
        ]);
        assert_eq!(bv.len(), 3);
        assert!(!bv.is_empty());
        *bv.iter_mut()
            .nth(2)
            .expect("third message")
            .payload_mut() = create_test_tensor_filled_with(42);

        let xs: Vec<TestTensor> = bv.as_batch().iter().map(|m| *m.payload()).collect();
        assert_eq!(
            xs.as_slice(),
            &[
                create_test_tensor_filled_with(1),
                create_test_tensor_filled_with(2),
                create_test_tensor_filled_with(42),
            ]
        );

        bv.first_header_mut()
            .expect("first header")
            .set_first_in_batch();
        bv.last_header_mut()
            .expect("last header")
            .set_last_in_batch();
        let batch2 = bv.as_batch();
        assert!(batch2.first_flagged());
        assert!(batch2.last_flagged());

        let ov = bv.into_vec();
        assert_eq!(ov.len(), 3);
        assert_eq!(
            *ov.last().expect("non-empty").payload(),
            create_test_tensor_filled_with(42)
        );
    }

    #[test]
    fn batch_assert_flags_consistent_no_panic_when_correct() {
        let mut arr: [Message<TestTensor>; 2] = [make_msg_tensor(7), make_msg_tensor(8)];
        let mut bv = BatchView::from_borrowed(&mut arr, 2);
        bv.first_header_mut().unwrap().set_first_in_batch();
        bv.last_header_mut().unwrap().set_last_in_batch();
        bv.as_batch().assert_flags_consistent();
    }
}