use std::any;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fmt;
use std::marker::PhantomData;
#[cfg(feature = "server")]
use std::time::Duration;
#[cfg(feature = "server")]
use ciborium;
use serde::{Serialize, de::DeserializeOwned};
use crate::path::Path;
use crate::ports::EventSource;
use crate::simulation::{EventId, EventIdErased, SchedulerRegistry};
#[cfg(feature = "server")]
use crate::simulation::{Event, EventKey};
use super::{EndpointError, Message, MessageSchema};
/// Error returned by `ciborium::from_reader` when a serialized event argument
/// cannot be deserialized into the event source's argument type.
#[cfg(feature = "server")]
type DeserializationError = ciborium::de::Error<std::io::Error>;
/// Registry of event sources, keyed by the [`Path`] they were registered under.
///
/// Entries are stored type-erased behind [`EventSourceEntryAny`] so that sources
/// with different argument types can share a single map.
#[derive(Default)]
pub(crate) struct EventSourceRegistry(HashMap<Path, Box<dyn EventSourceEntryAny>>);
impl EventSourceRegistry {
    /// Registers `source` under `path`, describing its argument with `T::schema`.
    ///
    /// The source is also added to the scheduler `registry`, which hands back
    /// the id stored in the new entry.
    ///
    /// # Errors
    ///
    /// If an entry already exists at `path`, nothing is registered. The existing
    /// key and the untouched `source` are returned to the caller.
    pub(crate) fn add<T>(
        &mut self,
        source: EventSource<T>,
        path: Path,
        registry: &mut SchedulerRegistry,
    ) -> Result<(), (Path, EventSource<T>)>
    where
        T: Message + Serialize + DeserializeOwned + Clone + Send + 'static,
    {
        self.add_any(source, path, T::schema, registry)
    }

    /// Registers `source` under `path` without a meaningful schema.
    ///
    /// The schema generator is `String::new`, so the reported schema is empty.
    ///
    /// # Errors
    ///
    /// Same as [`EventSourceRegistry::add`]: on a duplicate `path`, the existing
    /// key and the source are handed back.
    pub(crate) fn add_raw<T>(
        &mut self,
        source: EventSource<T>,
        path: Path,
        registry: &mut SchedulerRegistry,
    ) -> Result<(), (Path, EventSource<T>)>
    where
        T: Serialize + DeserializeOwned + Clone + Send + 'static,
    {
        self.add_any(source, path, String::new, registry)
    }

    /// Shared implementation of `add` / `add_raw`, parameterized by the schema
    /// generator.
    ///
    /// The scheduler registry is only touched when `path` is still free.
    fn add_any<T, F>(
        &mut self,
        source: EventSource<T>,
        path: Path,
        schema_gen: F,
        registry: &mut SchedulerRegistry,
    ) -> Result<(), (Path, EventSource<T>)>
    where
        T: Serialize + DeserializeOwned + Clone + Send + 'static,
        F: Fn() -> MessageSchema + Send + Sync + 'static,
    {
        let slot = match self.0.entry(path) {
            Entry::Occupied(taken) => return Err((taken.key().clone(), source)),
            Entry::Vacant(free) => free,
        };

        let new_entry = EventSourceEntry {
            inner: registry.add_event_source(source),
            schema_gen,
        };
        slot.insert(Box::new(new_entry));

        Ok(())
    }

    /// Returns the type-erased entry registered at `path`.
    ///
    /// # Errors
    ///
    /// `EndpointError::EventSourceNotFound` if no source is registered there.
    pub(crate) fn get(&self, path: &Path) -> Result<&dyn EventSourceEntryAny, EndpointError> {
        match self.0.get(path) {
            Some(boxed) => Ok(&**boxed),
            None => Err(EndpointError::EventSourceNotFound { path: path.clone() }),
        }
    }

    /// Returns the typed event id of the source registered at `path`.
    ///
    /// # Errors
    ///
    /// - `EndpointError::EventSourceNotFound` if no source is registered there.
    /// - `EndpointError::InvalidEventSourceType` if the registered argument type
    ///   is not `T`.
    pub(crate) fn get_source_id<T>(&self, path: &Path) -> Result<EventId<T>, EndpointError>
    where
        T: Serialize + DeserializeOwned + Clone + Send + 'static,
    {
        let entry = self.get(path)?;

        if entry.event_type_id() != TypeId::of::<T>() {
            // NOTE(review): `found_event_type` receives the caller-requested `T`
            // and `expected_event_type` the registered type. Confirm this matches
            // how `EndpointError` renders these two fields.
            return Err(EndpointError::InvalidEventSourceType {
                path: path.clone(),
                found_event_type: any::type_name::<T>(),
                expected_event_type: entry.event_type_name(),
            });
        }

        // The type check above guarantees the erased id really belongs to a
        // source of `T`, so re-attaching the type marker is sound.
        let erased = entry.get_event_id();
        Ok(EventId(erased.0, PhantomData))
    }

    /// Iterates over the paths of all registered sources, in unspecified order.
    pub(crate) fn list_sources(&self) -> impl Iterator<Item = &Path> {
        self.0.iter().map(|(path, _)| path)
    }

    /// Returns the schema of the source registered at `path`.
    ///
    /// # Errors
    ///
    /// `EndpointError::EventSourceNotFound` if no source is registered there.
    pub(crate) fn get_source_schema(&self, path: &Path) -> Result<MessageSchema, EndpointError> {
        self.get(path).map(|entry| entry.event_schema())
    }

    /// Iterates over `(path, schema)` pairs for all registered sources.
    ///
    /// Each schema is generated lazily as the iterator is advanced.
    #[cfg(feature = "server")]
    pub(crate) fn list_schemas(&self) -> impl Iterator<Item = (&Path, MessageSchema)> {
        self.0.iter().map(|(path, entry)| {
            let schema = entry.event_schema();
            (path, schema)
        })
    }
}
impl fmt::Debug for EventSourceRegistry {
    /// Summarizes the registry by its size only; entries are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let source_count = self.0.len();
        write!(f, "EventSourceRegistry ({} sources)", source_count)
    }
}
/// Type-erased interface to a registered event source.
///
/// Lets sources with different argument types be stored side by side while
/// still exposing their argument type identity, schema and event id. The
/// server-only methods build scheduler events from serialized arguments.
pub(crate) trait EventSourceEntryAny: Any + Send + Sync + 'static {
    /// `TypeId` of the event argument type.
    fn event_type_id(&self) -> TypeId;

    /// Human-readable name of the event argument type, e.g. for error messages.
    fn event_type_name(&self) -> &'static str;

    /// Schema describing the event argument type.
    fn event_schema(&self) -> MessageSchema;

    /// Type-erased id of the underlying event source.
    fn get_event_id(&self) -> EventIdErased;

    /// Deserializes `serialized_arg` into a boxed value of the argument type.
    #[cfg(feature = "server")]
    fn deserialize_arg(&self, serialized_arg: &[u8]) -> Result<Box<dyn Any>, DeserializationError>;

    /// Builds an event for this source from a serialized argument.
    #[cfg(feature = "server")]
    fn event(&self, serialized_arg: &[u8]) -> Result<Event, DeserializationError>;

    /// Like `event`, but also attaches a freshly created `EventKey`.
    ///
    /// The key is returned to the caller alongside the event.
    #[cfg(feature = "server")]
    fn keyed_event(
        &self,
        serialized_arg: &[u8],
    ) -> Result<(Event, EventKey), DeserializationError>;

    /// Builds an event for this source with the given `period` applied.
    #[cfg(feature = "server")]
    fn periodic_event(
        &self,
        period: Duration,
        serialized_arg: &[u8],
    ) -> Result<Event, DeserializationError>;

    /// Like `periodic_event`, but also attaches a freshly created `EventKey`.
    ///
    /// The key is returned to the caller alongside the event.
    #[cfg(feature = "server")]
    fn keyed_periodic_event(
        &self,
        period: Duration,
        serialized_arg: &[u8],
    ) -> Result<(Event, EventKey), DeserializationError>;
}
/// Concrete registry entry: a typed event source id plus a schema generator.
struct EventSourceEntry<
    T: Serialize + DeserializeOwned + Clone + Send + 'static,
    F: Fn() -> MessageSchema,
> {
    /// Id of the event source as registered with the scheduler registry.
    inner: EventId<T>,
    /// Called on demand to produce the schema of the argument type `T`.
    schema_gen: F,
}
impl<T, F> EventSourceEntryAny for EventSourceEntry<T, F>
where
    T: Serialize + DeserializeOwned + Clone + Send + 'static,
    F: Fn() -> MessageSchema + Send + Sync + 'static,
{
    /// Decodes `serialized_arg` with `ciborium` as a `T` and boxes it as `dyn Any`.
    #[cfg(feature = "server")]
    fn deserialize_arg(&self, serialized_arg: &[u8]) -> Result<Box<dyn Any>, DeserializationError> {
        let decoded: T = ciborium::from_reader(serialized_arg)?;
        Ok(Box::new(decoded) as Box<dyn Any>)
    }

    /// Decodes `serialized_arg` as a `T` and wraps it in an event for `self.inner`.
    #[cfg(feature = "server")]
    fn event(&self, serialized_arg: &[u8]) -> Result<Event, DeserializationError> {
        let decoded: T = ciborium::from_reader(serialized_arg)?;
        Ok(Event::new(&self.inner, decoded))
    }

    /// Builds the event via `event` and attaches a fresh key.
    ///
    /// The key is created before decoding is attempted.
    #[cfg(feature = "server")]
    fn keyed_event(
        &self,
        serialized_arg: &[u8],
    ) -> Result<(Event, EventKey), DeserializationError> {
        let event_key = EventKey::new();
        self.event(serialized_arg)
            .map(|built| (built.with_key(event_key.clone()), event_key))
    }

    /// Decodes `serialized_arg` as a `T` and builds an event with `period` applied.
    #[cfg(feature = "server")]
    fn periodic_event(
        &self,
        period: Duration,
        serialized_arg: &[u8],
    ) -> Result<Event, DeserializationError> {
        let decoded: T = ciborium::from_reader(serialized_arg)?;
        Ok(Event::new(&self.inner, decoded).with_period(period))
    }

    /// Builds the event via `periodic_event` and attaches a fresh key.
    ///
    /// The key is created before decoding is attempted.
    #[cfg(feature = "server")]
    fn keyed_periodic_event(
        &self,
        period: Duration,
        serialized_arg: &[u8],
    ) -> Result<(Event, EventKey), DeserializationError> {
        let event_key = EventKey::new();
        let built = self.periodic_event(period, serialized_arg)?;
        Ok((built.with_key(event_key.clone()), event_key))
    }

    fn event_type_id(&self) -> TypeId {
        TypeId::of::<T>()
    }

    fn event_type_name(&self) -> &'static str {
        any::type_name::<T>()
    }

    fn event_schema(&self) -> MessageSchema {
        let generate = &self.schema_gen;
        generate()
    }

    fn get_event_id(&self) -> EventIdErased {
        self.inner.into()
    }
}