use std::convert::Infallible;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use futures::{Stream, StreamExt};
use ruststream::memory::{MemoryBroker, MemoryMessage, MemorySubscriber};
use ruststream::runtime::{
AppInfo, Context, Handler, HandlerExt, HandlerMetadata, HandlerResult, Layer, RustStream,
Settle,
};
use ruststream::{
AckError, BuildContext, Field, FieldMut, Headers, IncomingMessage, OutgoingMessage, Publisher,
};
use tokio::sync::Notify;
/// A `MemoryMessage` paired with a tag assigned by `TaggedSubscriber`.
///
/// Lets the tests exercise a broker whose message type carries extra metadata that a
/// `BuildContext` implementation can lift into the per-delivery context.
struct TaggedMessage {
    /// The wrapped in-memory delivery; payload, headers and settlement all forward to it.
    inner: MemoryMessage,
    /// Tag assigned when the message was pulled off `TaggedSubscriber`'s stream.
    tag: u32,
}

/// Every operation forwards unchanged to the wrapped `MemoryMessage`; the tag plays no part
/// in payload access or settlement.
impl IncomingMessage for TaggedMessage {
    fn payload(&self) -> &[u8] {
        self.inner.payload()
    }

    fn headers(&self) -> &Headers {
        self.inner.headers()
    }

    async fn ack(self) -> Result<(), AckError> {
        // Discard the tag and settle the underlying delivery.
        let Self { inner, .. } = self;
        inner.ack().await
    }

    async fn nack(self, requeue: bool) -> Result<(), AckError> {
        let Self { inner, .. } = self;
        inner.nack(requeue).await
    }
}
struct TagContext {
tag: u32,
}
impl BuildContext<TaggedMessage> for TagContext {
fn build(msg: &TaggedMessage) -> Self {
Self { tag: msg.tag }
}
}
#[derive(Clone, Copy)]
struct Tag;
impl Field<TagContext> for Tag {
type Value<'a> = u32;
fn get(self, cx: &TagContext) -> u32 {
cx.tag
}
}
/// Wraps a `MemorySubscriber` and stamps each yielded item with an incrementing tag.
struct TaggedSubscriber {
    inner: MemorySubscriber,
    /// Incremented before each item is tagged, so it holds the tag most recently assigned.
    next_tag: u32,
}

impl ruststream::Subscriber for TaggedSubscriber {
    type Message = TaggedMessage;
    type Error = Infallible;

    /// Streams the inner subscriber's items. The counter is advanced for every item,
    /// including `Err` items, and each `Ok` message is then wrapped in a `TaggedMessage`
    /// carrying the new tag.
    fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
        // Split the borrow so the stream borrows `inner` while the closure owns `&mut next_tag`.
        let Self {
            inner: source,
            next_tag,
        } = self;
        source.stream().map(move |item| {
            *next_tag += 1;
            let tag = *next_tag;
            item.map(|message| TaggedMessage {
                inner: message,
                tag,
            })
        })
    }
}
/// Repeatedly evaluates `cond` until it returns `true`.
///
/// # Panics
///
/// Panics with `condition not met within {timeout:?}` if `cond` is still `false` once
/// `timeout` has elapsed.
async fn wait_for(mut cond: impl FnMut() -> bool, timeout: Duration) {
    // Sleep briefly between checks instead of spinning on `yield_now`. A yield-only loop
    // re-polls immediately, which keeps the thread driving this future busy for the whole
    // wait. A 1 ms pause adds negligible latency to the tests.
    const POLL_INTERVAL: Duration = Duration::from_millis(1);

    let result = tokio::time::timeout(timeout, async {
        while !cond() {
            tokio::time::sleep(POLL_INTERVAL).await;
        }
    })
    .await;
    assert!(result.is_ok(), "condition not met within {timeout:?}");
}
/// A field contributed by the broker's own context type (`TagContext`, filled from the tag
/// that `TaggedSubscriber` assigns) is readable in the handler through the `Tag` key.
/// The tags are checked in delivery order.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn broker_contributed_field_reaches_handler_by_key() {
    let broker = MemoryBroker::new();
    let publisher = broker.publisher();

    let tags: Arc<Mutex<Vec<u32>>> = Arc::new(Mutex::new(Vec::new()));
    let handler_tags = Arc::clone(&tags);

    let app = RustStream::new(AppInfo::new("svc", "0.1.0")).with_broker(broker, |b| {
        let tagged = TaggedSubscriber {
            inner: b.broker().subscribe("orders"),
            next_tag: 0,
        };
        let record_tag = move |_msg: &TaggedMessage, ctx: &mut Context<'_, TagContext>| {
            // Read the field synchronously; the returned future owns only a `u32` and an `Arc`.
            let tag = ctx.context(Tag);
            let sink = Arc::clone(&handler_tags);
            async move {
                sink.lock().expect("poisoned").push(tag);
                HandlerResult::Ack
            }
        };
        b.handle(tagged, record_tag, HandlerMetadata::raw("orders"));
    });

    let stop = Arc::new(Notify::new());
    let stop_waiter = Arc::clone(&stop);
    let run = tokio::spawn(app.run_until(async move { stop_waiter.notified().await }));

    for payload in [b"a", b"b"] {
        publisher
            .publish(OutgoingMessage::new("orders", payload))
            .await
            .unwrap();
    }

    wait_for(|| tags.lock().expect("poisoned").len() >= 2, Duration::from_secs(5)).await;
    assert_eq!(*tags.lock().expect("poisoned"), vec![1, 2]);

    stop.notify_one();
    run.await.unwrap().unwrap();
}
/// Mutable per-delivery scratch space that middleware can write and handlers can read.
#[derive(Default)]
struct Scratch {
    /// Written through the `Stamp` key; `None` until something sets it.
    stamp: Option<u32>,
}

/// Builds an empty `Scratch` for any message type; nothing is read from the message.
impl<M: ?Sized> BuildContext<M> for Scratch {
    fn build(_msg: &M) -> Self {
        Scratch { stamp: None }
    }
}
#[derive(Clone, Copy)]
struct Stamp;
impl Field<Scratch> for Stamp {
type Value<'a> = Option<&'a u32>;
fn get(self, cx: &Scratch) -> Option<&u32> {
cx.stamp.as_ref()
}
}
impl FieldMut<Scratch> for Stamp {
type Owned = u32;
fn set(self, cx: &mut Scratch, value: u32) {
cx.stamp = Some(value);
}
}
/// Middleware layer whose handlers stamp each delivery's `Scratch` with the next value of a
/// shared counter before calling the wrapped handler.
struct StampLayer {
    counter: Arc<std::sync::atomic::AtomicU32>,
}

/// Handler produced by `StampLayer`; holds its own handle to the layer's counter.
struct StampHandler<H> {
    inner: H,
    counter: Arc<std::sync::atomic::AtomicU32>,
}

impl<H> Layer<H> for StampLayer {
    type Handler = StampHandler<H>;

    /// Wraps `inner`. The counter is shared rather than copied, so every handler built from
    /// this layer draws from the same sequence.
    fn layer(&self, inner: H) -> StampHandler<H> {
        let counter = Arc::clone(&self.counter);
        StampHandler { inner, counter }
    }
}
/// Checks that the scratch starts empty and stamps it with the next counter value (starting
/// from the counter's current value). Then hands the delivery to the wrapped handler and
/// returns its settlement unchanged.
impl<M, H> Handler<M, Scratch> for StampHandler<H>
where
    M: Sync,
    H: Handler<M, Scratch>,
{
    async fn handle(&self, msg: &M, ctx: &mut Context<'_, Scratch>) -> Settle {
        // A stamp already present means scratch state survived from an earlier delivery.
        let leaked = ctx.context(Stamp).is_some();
        assert!(!leaked, "scratch leaked across deliveries");

        let stamp = self
            .counter
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        ctx.set(Stamp, stamp);

        self.inner.handle(msg, ctx).await
    }
}
/// A layer can write a scratch field that the downstream handler then reads. Each delivery
/// starts from a fresh `Scratch`; `StampHandler` asserts the stamp is unset on entry.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn middleware_written_scratch_reaches_downstream_handler_and_is_isolated() {
    let broker = MemoryBroker::new();
    let publisher = broker.publisher();
    // Record exactly what the handler observed, including `None`. If the layer never writes
    // the stamp, the final assertion then fails with the observed values. Previously the
    // handler recorded nothing in that case, and the failure only surfaced as a
    // `wait_for` timeout.
    let seen: Arc<Mutex<Vec<Option<u32>>>> = Arc::new(Mutex::new(Vec::new()));
    let seen_clone = Arc::clone(&seen);
    let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
    let app = RustStream::new(AppInfo::new("svc", "0.1.0")).with_broker(broker, |b| {
        let subscriber = b.broker().subscribe("orders");
        let handler = {
            let layer = StampLayer {
                counter: Arc::clone(&counter),
            };
            (move |_msg: &MemoryMessage, ctx: &mut Context<'_, Scratch>| {
                let stamp = ctx.context(Stamp).copied();
                let seen = Arc::clone(&seen_clone);
                async move {
                    seen.lock().expect("poisoned").push(stamp);
                    HandlerResult::Ack
                }
            })
            .with(layer)
        };
        b.handle(subscriber, handler, HandlerMetadata::raw("orders"));
    });
    let shutdown = Arc::new(Notify::new());
    let shutdown_signal = Arc::clone(&shutdown);
    let run = tokio::spawn(app.run_until(async move { shutdown_signal.notified().await }));
    publisher
        .publish(OutgoingMessage::new("orders", b"a"))
        .await
        .unwrap();
    publisher
        .publish(OutgoingMessage::new("orders", b"b"))
        .await
        .unwrap();
    wait_for(
        || seen.lock().expect("poisoned").len() >= 2,
        Duration::from_secs(5),
    )
    .await;
    // The counter starts at 0, so the two deliveries are stamped 0 and 1 in order.
    assert_eq!(*seen.lock().expect("poisoned"), vec![Some(0), Some(1)]);
    shutdown.notify_one();
    run.await.unwrap().unwrap();
}
/// Application state produced by the `on_startup` hook in the test below.
struct AppPrefix(String);

/// A handler whose delivery context is `()` can still read the app state built by
/// `on_startup`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn state_reaches_app_state_independently_of_the_delivery_context() {
    let broker = MemoryBroker::new();
    let publisher = broker.publisher();

    let observed: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let handler_observed = Arc::clone(&observed);

    let app = RustStream::new(AppInfo::new("svc", "0.1.0"))
        .on_startup(|()| async { Ok::<_, Infallible>(AppPrefix("svc".to_owned())) })
        .with_broker(broker, |b| {
            let orders = b.broker().subscribe("orders");
            b.handle(
                orders,
                move |_msg: &MemoryMessage, ctx: &mut Context<'_, (), AppPrefix>| {
                    // Copy the prefix out now so the future does not borrow the context.
                    let prefix = ctx.state().0.clone();
                    let sink = Arc::clone(&handler_observed);
                    async move {
                        *sink.lock().expect("poisoned") = Some(prefix);
                        HandlerResult::Ack
                    }
                },
                HandlerMetadata::raw("orders"),
            );
        });

    let stop = Arc::new(Notify::new());
    let stop_waiter = Arc::clone(&stop);
    let run = tokio::spawn(app.run_until(async move { stop_waiter.notified().await }));

    publisher
        .publish(OutgoingMessage::new("orders", b"x"))
        .await
        .unwrap();

    wait_for(|| observed.lock().expect("poisoned").is_some(), Duration::from_secs(5)).await;
    assert_eq!(*observed.lock().expect("poisoned"), Some("svc".to_owned()));

    stop.notify_one();
    run.await.unwrap().unwrap();
}