#![allow(clippy::return_self_not_must_use, clippy::must_use_candidate)]
use std::marker::PhantomData;
use bytes::Bytes;
use crate::RawMessage;
#[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
use crate::codec::{Codec, DefaultCodec};
use crate::runtime::HandlerResult;
use super::coordinator::{Coordinator, Outcome, Record};
/// Fluent assertions over the deliveries recorded for a single subscriber.
///
/// Holds only a borrow of the [`Coordinator`]; the recorded deliveries are
/// looked up through it each time an assertion runs, keyed by `scope_id` and
/// `name`.
#[derive(Debug)]
pub struct SubscriberAssertions<'a> {
    /// Coordinator that stores the recorded deliveries.
    coordinator: &'a Coordinator,
    /// Scope identifier passed to the coordinator to select records.
    scope_id: usize,
    /// Subscriber name: selects records and is shown in failure messages.
    name: String,
}
impl<'a> SubscriberAssertions<'a> {
    /// Creates an assertion handle for the subscriber called `name` within
    /// scope `scope_id` of `coordinator`.
    pub(crate) fn new(coordinator: &'a Coordinator, scope_id: usize, name: String) -> Self {
        Self {
            name,
            scope_id,
            coordinator,
        }
    }

    /// Lends the records the coordinator holds for this subscriber and scope
    /// to `f`, returning whatever `f` returns.
    fn with_records<R>(&self, f: impl FnOnce(&[&Record]) -> R) -> R {
        let Self {
            coordinator,
            scope_id,
            name,
        } = self;
        coordinator.with_records(*scope_id, name, f)
    }

    /// Runs `f` on the most recent record.
    ///
    /// # Panics
    /// Panics if the subscriber has no records; `what` names the property
    /// that could not be asserted.
    fn with_last<R>(&self, what: &str, f: impl FnOnce(&Record) -> R) -> R {
        self.with_records(|records| match records.last() {
            Some(latest) => f(latest),
            None => panic!(
                "subscriber {:?} was not called, cannot assert {what}",
                self.name
            ),
        })
    }

    /// Number of deliveries recorded for this subscriber in this scope.
    fn call_count(&self) -> usize {
        self.with_records(|records| records.len())
    }

    /// Decodes one raw payload with `codec`, panicking with the subscriber
    /// name and target type if decoding fails.
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    fn decode_or_panic<T, C>(&self, codec: &C, raw: &Bytes) -> T
    where
        T: serde::de::DeserializeOwned,
        C: Codec,
    {
        codec.decode(raw).unwrap_or_else(|err| {
            panic!(
                "subscriber {:?} received a payload that did not decode as {}: {err}",
                self.name,
                std::any::type_name::<T>(),
            )
        })
    }

    /// Asserts that the subscriber was called exactly once.
    ///
    /// # Panics
    /// Panics if the recorded call count is not 1.
    pub fn assert_called_once(self) -> Self {
        let count = self.call_count();
        assert_eq!(
            count, 1,
            "subscriber {:?} was called {count} times, expected exactly once",
            self.name,
        );
        self
    }

    /// Asserts that the subscriber was called exactly `times` times.
    ///
    /// # Panics
    /// Panics if the recorded call count differs from `times`.
    pub fn assert_called(self, times: usize) -> Self {
        let count = self.call_count();
        assert_eq!(
            count, times,
            "subscriber {:?} was called {count} times, expected {times}",
            self.name,
        );
        self
    }

    /// Returns a clone of the raw payload of every recorded delivery, in the
    /// order the coordinator hands the records back.
    #[must_use]
    pub fn received_raw(&self) -> Vec<Bytes> {
        self.with_records(|records| {
            let mut payloads = Vec::with_capacity(records.len());
            for record in records {
                payloads.push(record.raw.clone());
            }
            payloads
        })
    }

    /// Decodes every recorded payload as `T` using `DefaultCodec`.
    ///
    /// # Panics
    /// Panics if any payload fails to decode.
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    #[must_use]
    pub fn received<T: serde::de::DeserializeOwned>(&self) -> Vec<T> {
        let codec = DefaultCodec::default();
        self.received_with(&codec)
    }

    /// Decodes every recorded payload as `T` using `codec`.
    ///
    /// # Panics
    /// Panics on the first payload that fails to decode.
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    #[must_use]
    pub fn received_with<T, C>(&self, codec: &C) -> Vec<T>
    where
        T: serde::de::DeserializeOwned,
        C: Codec,
    {
        self.with_records(|records| {
            let mut decoded: Vec<T> = Vec::with_capacity(records.len());
            for record in records.iter() {
                decoded.push(self.decode_or_panic(codec, &record.raw));
            }
            decoded
        })
    }

    /// Asserts that the subscriber was never called.
    ///
    /// # Panics
    /// Panics if any delivery was recorded.
    pub fn assert_not_called(self) {
        let count = self.call_count();
        assert_eq!(
            count, 0,
            "subscriber {:?} was called {count} times, expected never",
            self.name,
        );
    }

    /// Asserts that the most recent payload decodes, via `DefaultCodec`, to a
    /// value equal to `expected`.
    ///
    /// # Panics
    /// Panics if the subscriber was not called, the payload fails to decode,
    /// or the decoded value differs from `expected`.
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    pub fn with<T>(self, expected: &T) -> Self
    where
        T: serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let codec = DefaultCodec::default();
        self.with_codec(&codec, expected)
    }

    /// Asserts that the most recent payload decodes, via `codec`, to a value
    /// equal to `expected`.
    ///
    /// # Panics
    /// Panics if the subscriber was not called, the payload fails to decode,
    /// or the decoded value differs from `expected`.
    #[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
    pub fn with_codec<T, C>(self, codec: &C, expected: &T) -> Self
    where
        T: serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
        C: Codec,
    {
        self.with_last("the received value", |record| {
            let actual: T = self.decode_or_panic(codec, &record.raw);
            assert_eq!(
                &actual, expected,
                "subscriber {:?} received an unexpected value",
                self.name
            );
        });
        self
    }

    /// Asserts that the most recent raw payload equals `bytes` exactly.
    ///
    /// # Panics
    /// Panics if the subscriber was not called or the bytes differ.
    pub fn with_raw(self, bytes: &[u8]) -> Self {
        self.with_last("the raw payload", |record| {
            assert_eq!(
                record.raw.as_ref(),
                bytes,
                "subscriber {:?} received unexpected raw bytes",
                self.name,
            );
        });
        self
    }

    /// Asserts that the settlement recorded for the most recent delivery is
    /// `Some(outcome)`.
    ///
    /// # Panics
    /// Panics if the subscriber was not called or the settlement differs.
    pub fn settled(self, outcome: HandlerResult) -> Self {
        let expected = Some(outcome);
        self.with_last("the settlement", |record| {
            assert_eq!(
                record.settle,
                expected,
                "subscriber {:?} settled differently than expected",
                self.name,
            );
        });
        self
    }

    /// Asserts that the handler panicked on the most recent delivery.
    ///
    /// # Panics
    /// Panics if the subscriber was not called or the last record is not
    /// flagged as panicked.
    pub fn panicked(self) -> Self {
        self.with_last("a panic", |record| {
            let did_panic = record.panicked;
            assert!(
                did_panic,
                "subscriber {:?} did not panic on its last delivery",
                self.name,
            );
        });
        self
    }

    /// Asserts that the most recent record's `outcome()` equals `expected`.
    ///
    /// # Panics
    /// Panics if the subscriber was not called or the outcome differs.
    pub fn assert_outcome(self, expected: Outcome) -> Self {
        self.with_last("the outcome", |record| {
            let actual = record.outcome();
            assert_eq!(
                actual,
                expected,
                "subscriber {:?} had an unexpected outcome",
                self.name,
            );
        });
        self
    }

    /// Asserts that the most recent delivery is flagged as a decode failure.
    ///
    /// # Panics
    /// Panics if the subscriber was not called or the last record decoded
    /// successfully.
    pub fn assert_last_failed_to_decode(self) {
        self.with_last("a decode failure", |record| {
            let failed = record.decode_failed;
            assert!(
                failed,
                "subscriber {:?} decoded its last delivery successfully",
                self.name,
            );
        });
    }
}
/// Fluent assertions over the messages captured for a single channel, with
/// `T` as the payload type the typed helpers decode into.
pub struct PublishedAssertions<T> {
    /// Channel name, shown in assertion failure messages.
    name: String,
    /// Every captured message for the channel; the last entry is the one the
    /// "last published" assertions inspect.
    messages: Vec<RawMessage>,
    /// Records the payload type without storing a value of it. `fn() -> T`
    /// keeps the struct `Send`/`Sync` and free of drop-check obligations
    /// regardless of `T`.
    _payload: PhantomData<fn() -> T>,
}

// Hand-written rather than derived: `#[derive(Debug)]` would add a `T: Debug`
// bound even though no `T` value is ever stored, making assertions over
// non-`Debug` payload types un-printable for no reason.
impl<T> std::fmt::Debug for PublishedAssertions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PublishedAssertions")
            .field("name", &self.name)
            .field("messages", &self.messages)
            .field("payload", &std::any::type_name::<T>())
            .finish()
    }
}
impl<T> PublishedAssertions<T> {
    /// Wraps the messages captured for channel `name`.
    pub(crate) fn new(name: String, messages: Vec<RawMessage>) -> Self {
        Self {
            name,
            messages,
            _payload: PhantomData,
        }
    }

    /// Asserts that exactly one message was published to the channel.
    ///
    /// # Panics
    /// Panics if the number of captured messages is not 1.
    pub fn assert_called_once(self) -> Self {
        let count = self.messages.len();
        assert_eq!(
            count, 1,
            "channel {:?} was published to {count} times, expected exactly once",
            self.name,
        );
        self
    }

    /// Asserts that exactly `times` messages were published to the channel.
    ///
    /// Counterpart of `SubscriberAssertions::assert_called`, so publish counts
    /// other than 0 or 1 can be asserted without inspecting `messages()`.
    ///
    /// # Panics
    /// Panics if the number of captured messages differs from `times`.
    pub fn assert_called(self, times: usize) -> Self {
        let count = self.messages.len();
        assert_eq!(
            count, times,
            "channel {:?} was published to {count} times, expected {times}",
            self.name,
        );
        self
    }

    /// Asserts that nothing was published to the channel.
    ///
    /// # Panics
    /// Panics if any message was captured.
    pub fn assert_not_called(self) {
        let count = self.messages.len();
        assert_eq!(
            count, 0,
            "channel {:?} was published to {count} times, expected never",
            self.name,
        );
    }

    /// Returns every captured message for the channel.
    #[must_use]
    pub fn messages(&self) -> &[RawMessage] {
        &self.messages
    }

    /// Returns the last captured message.
    ///
    /// # Panics
    /// Panics if nothing was published; `what` names the property that could
    /// not be asserted.
    fn last(&self, what: &str) -> &RawMessage {
        self.messages.last().unwrap_or_else(|| {
            panic!(
                "nothing was published to {:?}, cannot assert {what}",
                self.name
            )
        })
    }

    /// Asserts that the last published payload equals `bytes` exactly.
    ///
    /// # Panics
    /// Panics if nothing was published or the bytes differ.
    pub fn with_raw(self, bytes: &[u8]) -> Self {
        assert_eq!(
            self.last("the raw payload").payload(),
            bytes,
            "channel {:?} published unexpected raw bytes",
            self.name,
        );
        self
    }
}
#[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
impl<T> PublishedAssertions<T>
where
    T: serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
{
    /// Asserts that the last published payload decodes, via `DefaultCodec`,
    /// to a value equal to `expected`.
    ///
    /// # Panics
    /// Panics if nothing was published, the payload fails to decode, or the
    /// decoded value differs from `expected`.
    pub fn with(self, expected: &T) -> Self {
        let codec = DefaultCodec::default();
        self.with_codec(&codec, expected)
    }

    /// Asserts that the last published payload decodes, via `codec`, to a
    /// value equal to `expected`.
    ///
    /// # Panics
    /// Panics if nothing was published, the payload fails to decode, or the
    /// decoded value differs from `expected`.
    pub fn with_codec<C: Codec>(self, codec: &C, expected: &T) -> Self {
        let payload = self.last("the published value").payload();
        let actual: T = match codec.decode(payload) {
            Ok(value) => value,
            Err(err) => panic!(
                "channel {:?} published a payload that did not decode as {}: {err}",
                self.name,
                std::any::type_name::<T>(),
            ),
        };
        assert_eq!(
            &actual, expected,
            "channel {:?} published an unexpected value",
            self.name
        );
        self
    }
}
#[cfg(any(feature = "json", feature = "cbor", feature = "msgpack"))]
impl<T: serde::de::DeserializeOwned> PublishedAssertions<T> {
    /// Decodes every captured payload as `T` using `DefaultCodec`.
    ///
    /// # Panics
    /// Panics if any payload fails to decode.
    #[must_use]
    pub fn decoded(&self) -> Vec<T> {
        let codec = DefaultCodec::default();
        self.decoded_with(&codec)
    }

    /// Decodes every captured payload as `T` using `codec`, in capture order.
    ///
    /// # Panics
    /// Panics on the first payload that fails to decode.
    #[must_use]
    pub fn decoded_with<C: Codec>(&self, codec: &C) -> Vec<T> {
        let mut values: Vec<T> = Vec::with_capacity(self.messages.len());
        for message in &self.messages {
            match codec.decode(message.payload()) {
                Ok(value) => values.push(value),
                Err(err) => panic!(
                    "channel {:?} published a payload that did not decode as {}: {err}",
                    self.name,
                    std::any::type_name::<T>(),
                ),
            }
        }
        values
    }
}