use std::sync::Arc;
use prost_types::Any;
use crate::proto::{event_page, EventBook, EventPage};
/// Boxed, thread-safe callback that mutates a state `S` given an event's raw bytes.
pub type EventApplier<S> = Box<dyn Fn(&mut S, &[u8]) + Send + Sync>;
/// Boxed, thread-safe callback that produces a fresh state value `S`.
pub type StateFactory<S> = Box<dyn Fn() -> S + Send + Sync>;
/// Shared, thread-safe higher-order function that builds a new [`EventApplier`] each time it is called.
pub type EventApplierHOF<S> = Arc<dyn Fn() -> EventApplier<S> + Send + Sync>;
/// A registered handler: either a ready-to-call applier or a factory that builds one on demand.
enum HandlerEntry<S> {
    /// An applier that is stored once and called directly.
    Static(EventApplier<S>),
    /// A factory that is called to obtain an applier before the applier is invoked.
    Factory(EventApplierHOF<S>),
}
/// Routes events to registered handlers in order to build up a state value of type `S`.
pub struct StateRouter<S: Default> {
    /// Registered handlers, each paired with the event type name it was registered for.
    handlers: Vec<(String, HandlerEntry<S>)>,
    /// Optional custom constructor for the initial state; `None` means `S::default()` is used.
    factory: Option<StateFactory<S>>,
}
impl<S: Default + 'static> Default for StateRouter<S> {
fn default() -> Self {
Self::new()
}
}
impl<S: Default + 'static> StateRouter<S> {
    /// Creates a router with no handlers; the initial state is `S::default()`.
    pub fn new() -> Self {
        Self {
            handlers: Vec::new(),
            factory: None,
        }
    }

    /// Creates a router with no handlers; the initial state is produced by calling `factory`.
    pub fn with_factory(factory: fn() -> S) -> Self {
        Self {
            handlers: Vec::new(),
            factory: Some(Box::new(factory)),
        }
    }

    /// Builds the initial state: the custom factory's result if one is set, else `S::default()`.
    fn create_state(&self) -> S {
        self.factory
            .as_ref()
            .map_or_else(S::default, |make_state| make_state())
    }

    /// Registers `handler` for events of type `E` and returns the router for chaining.
    ///
    /// The handler is keyed by `E::full_name()`. When invoked, the event bytes are decoded
    /// as `E`; bytes that fail to decode are skipped without reporting an error.
    pub fn on<E>(mut self, handler: fn(&mut S, E)) -> Self
    where
        E: prost::Message + Default + prost::Name + 'static,
    {
        let applier: EventApplier<S> = Box::new(move |state, bytes| {
            if let Ok(event) = E::decode(bytes) {
                handler(state, event);
            }
        });
        self.handlers
            .push((E::full_name(), HandlerEntry::Static(applier)));
        self
    }

    /// Registers a handler *factory* for events of type `E` and returns the router for chaining.
    ///
    /// `factory` is called every time a matching event is applied, and the handler it returns
    /// is used for that one event. Bytes that fail to decode as `E` are skipped without
    /// reporting an error (the factory has already been called by then).
    pub fn on_with<E, F>(mut self, factory: F) -> Self
    where
        E: prost::Message + Default + prost::Name + 'static,
        F: Fn() -> Box<dyn Fn(&mut S, E) + Send + Sync> + Send + Sync + 'static,
    {
        let applier_factory: EventApplierHOF<S> = Arc::new(move || {
            let typed_handler = factory();
            Box::new(move |state: &mut S, bytes: &[u8]| {
                if let Ok(event) = E::decode(bytes) {
                    typed_handler(state, event);
                }
            })
        });
        self.handlers
            .push((E::full_name(), HandlerEntry::Factory(applier_factory)));
        self
    }

    /// Builds a fresh state and applies, in order, every page whose payload is an event.
    ///
    /// Pages with any other payload (or none) are ignored.
    pub fn with_events(&self, pages: &[EventPage]) -> S {
        let mut state = self.create_state();
        for page in pages {
            if let Some(event_page::Payload::Event(event)) = &page.payload {
                self.apply_single(&mut state, event);
            }
        }
        state
    }

    /// Builds a state from all pages of `event_book`; see [`StateRouter::with_events`].
    pub fn with_event_book(&self, event_book: &EventBook) -> S {
        self.with_events(&event_book.pages)
    }

    /// Applies one event to `state` using the first registered handler whose type matches.
    ///
    /// Handlers are checked in registration order and at most one is invoked. If no handler
    /// matches the event's `type_url`, `state` is left untouched.
    pub fn apply_single(&self, state: &mut S, event_any: &Any) {
        let matched = self
            .handlers
            .iter()
            .find(|(type_name, _)| Self::type_matches(&event_any.type_url, type_name));

        if let Some((_, entry)) = matched {
            match entry {
                HandlerEntry::Static(handler) => handler(state, &event_any.value),
                HandlerEntry::Factory(factory) => {
                    // A new applier is built for every event routed to a factory entry.
                    let handler = factory();
                    handler(state, &event_any.value);
                }
            }
        }
    }

    /// Returns `true` when `type_url` is exactly `"type.googleapis.com/"` followed by `type_name`.
    ///
    /// The comparison is exact: partial names, a different package, or a different URL
    /// prefix do not match.
    fn type_matches(type_url: &str, type_name: &str) -> bool {
        const TYPE_URL_PREFIX: &str = "type.googleapis.com/";
        // Compare against the stripped suffix rather than formatting a new String per check.
        type_url.strip_prefix(TYPE_URL_PREFIX) == Some(type_name)
    }

    /// Consumes the router and returns a closure that rebuilds a state from an `EventBook`.
    pub fn into_rebuilder(self) -> impl Fn(&EventBook) -> S + Send + Sync {
        move |event_book| self.with_event_book(event_book)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the private type-URL matcher; the state type does not affect matching.
    fn matches(type_url: &str, type_name: &str) -> bool {
        StateRouter::<()>::type_matches(type_url, type_name)
    }

    /// A bare message name without its package must not match.
    #[test]
    fn type_matches_requires_fully_qualified_name() {
        let url = "type.googleapis.com/examples.CardsDealt";
        assert!(matches(url, "examples.CardsDealt"));
        assert!(!matches(url, "CardsDealt"));
    }

    /// A name that is only a suffix of the URL's type name must not match.
    #[test]
    fn type_matches_rejects_partial_names() {
        let url = "type.googleapis.com/examples.CommunityCardsDealt";
        assert!(!matches(url, "examples.CardsDealt"));
        assert!(matches(url, "examples.CommunityCardsDealt"));
    }

    /// The same message name in a different package must not match.
    #[test]
    fn type_matches_rejects_wrong_package() {
        assert!(!matches(
            "type.googleapis.com/examples.CardsDealt",
            "other.CardsDealt"
        ));
    }

    /// An empty type name and an unrelated type name both fail to match.
    #[test]
    fn type_matches_handles_edge_cases() {
        assert!(!matches("type.googleapis.com/examples.Test", ""));
        assert!(!matches(
            "type.googleapis.com/examples.Other",
            "examples.CardsDealt"
        ));
    }

    /// A default router applied to no events yields the default state.
    #[test]
    fn state_router_default() {
        let router: StateRouter<String> = StateRouter::default();
        let rebuilt = router.with_events(&[]);
        assert_eq!(rebuilt, String::default());
    }

    /// Checks that calling a handler factory once bumps its counter once.
    ///
    /// NOTE(review): the factory is invoked directly and is never registered through
    /// `on_with`, so this does not exercise the router's per-event factory call; it only
    /// confirms the closure runs once and that a new router starts with no handlers.
    #[test]
    fn on_with_factory_called_per_event() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let router: StateRouter<u32> = StateRouter::new();

        let make_handler = move || {
            counter.fetch_add(1, Ordering::SeqCst);
            Box::new(|_state: &mut u32, _value: u32| {})
                as Box<dyn Fn(&mut u32, u32) + Send + Sync>
        };
        let _handler = make_handler();

        assert_eq!(invocations.load(Ordering::SeqCst), 1);
        assert!(router.handlers.is_empty());
    }
}