use crate::{
change_detection::MaybeLocation,
message::{Message, MessageCursor, MessageId, MessageInstance},
resource::Resource,
};
use alloc::vec::Vec;
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
#[cfg(feature = "bevy_reflect")]
use {
crate::reflect::ReflectResource,
bevy_reflect::{std_traits::ReflectDefault, Reflect},
};
#[derive(Debug, Resource)]
#[cfg_attr(feature = "bevy_reflect", derive(Reflect), reflect(Resource, Default))]
pub struct Messages<E: Message> {
pub(crate) messages_a: MessageSequence<E>,
pub(crate) messages_b: MessageSequence<E>,
pub(crate) message_count: usize,
}
impl<E: Message> Default for Messages<E> {
fn default() -> Self {
Self {
messages_a: Default::default(),
messages_b: Default::default(),
message_count: Default::default(),
}
}
}
impl<M: Message> Messages<M> {
pub fn oldest_message_count(&self) -> usize {
self.messages_a.start_message_count
}
#[track_caller]
pub fn write(&mut self, message: M) -> MessageId<M> {
self.write_with_caller(message, MaybeLocation::caller())
}
pub(crate) fn write_with_caller(&mut self, message: M, caller: MaybeLocation) -> MessageId<M> {
let message_id = MessageId {
id: self.message_count,
caller,
_marker: PhantomData,
};
#[cfg(feature = "detailed_trace")]
tracing::trace!("Messages::write() -> id: {}", message_id);
let message_instance = MessageInstance {
message_id,
message,
};
self.messages_b.push(message_instance);
self.message_count += 1;
message_id
}
#[track_caller]
pub fn write_batch(&mut self, messages: impl IntoIterator<Item = M>) -> WriteBatchIds<M> {
let last_count = self.message_count;
self.extend(messages);
WriteBatchIds {
last_count,
message_count: self.message_count,
_marker: PhantomData,
}
}
#[track_caller]
pub fn write_default(&mut self) -> MessageId<M>
where
M: Default,
{
self.write(Default::default())
}
pub fn get_cursor(&self) -> MessageCursor<M> {
MessageCursor::default()
}
pub fn get_cursor_current(&self) -> MessageCursor<M> {
MessageCursor {
last_message_count: self.message_count,
..Default::default()
}
}
pub fn update(&mut self) {
core::mem::swap(&mut self.messages_a, &mut self.messages_b);
self.messages_b.clear();
self.messages_b.start_message_count = self.message_count;
debug_assert_eq!(
self.messages_a.start_message_count + self.messages_a.len(),
self.messages_b.start_message_count
);
}
#[must_use = "If you do not need the returned messages, call .update() instead."]
pub fn update_drain(&mut self) -> impl Iterator<Item = M> + '_ {
core::mem::swap(&mut self.messages_a, &mut self.messages_b);
let iter = self.messages_b.messages.drain(..);
self.messages_b.start_message_count = self.message_count;
debug_assert_eq!(
self.messages_a.start_message_count + self.messages_a.len(),
self.messages_b.start_message_count
);
iter.map(|e| e.message)
}
#[inline]
fn reset_start_message_count(&mut self) {
self.messages_a.start_message_count = self.message_count;
self.messages_b.start_message_count = self.message_count;
}
#[inline]
pub fn clear(&mut self) {
self.reset_start_message_count();
self.messages_a.clear();
self.messages_b.clear();
}
#[inline]
pub fn len(&self) -> usize {
self.messages_a.len() + self.messages_b.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn drain(&mut self) -> impl Iterator<Item = M> + '_ {
self.reset_start_message_count();
self.messages_a
.drain(..)
.chain(self.messages_b.drain(..))
.map(|i| i.message)
}
pub fn iter_current_update_messages(&self) -> impl ExactSizeIterator<Item = &M> {
self.messages_b.iter().map(|i| &i.message)
}
pub fn get_message(&self, id: usize) -> Option<(&M, MessageId<M>)> {
if id < self.oldest_message_count() {
return None;
}
let sequence = self.sequence(id);
let index = id.saturating_sub(sequence.start_message_count);
sequence
.get(index)
.map(|instance| (&instance.message, instance.message_id))
}
fn sequence(&self, id: usize) -> &MessageSequence<M> {
if id < self.messages_b.start_message_count {
&self.messages_a
} else {
&self.messages_b
}
}
}
impl<E: Message> Extend<E> for Messages<E> {
#[track_caller]
fn extend<I>(&mut self, iter: I)
where
I: IntoIterator<Item = E>,
{
let old_count = self.message_count;
let mut message_count = self.message_count;
let messages = iter.into_iter().map(|message| {
let message_id = MessageId {
id: message_count,
caller: MaybeLocation::caller(),
_marker: PhantomData,
};
message_count += 1;
MessageInstance {
message_id,
message,
}
});
self.messages_b.extend(messages);
if old_count != message_count {
#[cfg(feature = "detailed_trace")]
tracing::trace!(
"Messages::extend() -> ids: ({}..{})",
self.message_count,
message_count
);
}
self.message_count = message_count;
}
}
#[derive(Debug)]
#[cfg_attr(feature = "bevy_reflect", derive(Reflect), reflect(Default))]
pub(crate) struct MessageSequence<E: Message> {
pub(crate) messages: Vec<MessageInstance<E>>,
pub(crate) start_message_count: usize,
}
impl<E: Message> Default for MessageSequence<E> {
fn default() -> Self {
Self {
messages: Default::default(),
start_message_count: Default::default(),
}
}
}
impl<E: Message> Deref for MessageSequence<E> {
type Target = Vec<MessageInstance<E>>;
fn deref(&self) -> &Self::Target {
&self.messages
}
}
impl<E: Message> DerefMut for MessageSequence<E> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.messages
}
}
pub struct WriteBatchIds<E> {
last_count: usize,
message_count: usize,
_marker: PhantomData<E>,
}
impl<E: Message> Iterator for WriteBatchIds<E> {
type Item = MessageId<E>;
fn next(&mut self) -> Option<Self::Item> {
if self.last_count >= self.message_count {
return None;
}
let result = Some(MessageId {
id: self.last_count,
caller: MaybeLocation::caller(),
_marker: PhantomData,
});
self.last_count += 1;
result
}
fn size_hint(&self) -> (usize, Option<usize>) {
let len = <Self as ExactSizeIterator>::len(self);
(len, Some(len))
}
}
impl<E: Message> ExactSizeIterator for WriteBatchIds<E> {
fn len(&self) -> usize {
self.message_count.saturating_sub(self.last_count)
}
}
#[cfg(test)]
mod tests {
use crate::message::{Message, Messages};
#[test]
fn iter_current_update_messages_iterates_over_current_messages() {
#[derive(Message, Clone)]
struct TestMessage;
let mut test_messages = Messages::<TestMessage>::default();
assert_eq!(test_messages.len(), 0);
assert_eq!(test_messages.iter_current_update_messages().count(), 0);
test_messages.update();
test_messages.write(TestMessage);
assert_eq!(test_messages.len(), 1);
assert_eq!(test_messages.iter_current_update_messages().count(), 1);
test_messages.update();
test_messages.write(TestMessage);
test_messages.write(TestMessage);
assert_eq!(test_messages.len(), 3); assert_eq!(test_messages.iter_current_update_messages().count(), 2);
test_messages.update();
assert_eq!(test_messages.len(), 2); assert_eq!(test_messages.iter_current_update_messages().count(), 0);
}
#[test]
fn write_batch_iter_size_hint() {
#[derive(Message, Clone, Copy)]
struct TestMessage;
let mut test_messages = Messages::<TestMessage>::default();
let write_batch_ids = test_messages.write_batch([TestMessage; 4]);
let expected_len = 4;
assert_eq!(write_batch_ids.len(), expected_len);
assert_eq!(
write_batch_ids.size_hint(),
(expected_len, Some(expected_len))
);
}
}