use std::any;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::fmt;
use std::marker::PhantomData;
#[cfg(feature = "server")]
use ciborium;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::path::Path;
use crate::ports::QuerySource;
use crate::simulation::{QueryId, QueryIdErased, SchedulerRegistry};
#[cfg(feature = "server")]
use crate::ports::{ReplyReader, ReplyWriter};
use super::{EndpointError, Message, MessageSchema};
#[cfg(feature = "server")]
type SerializationError = ciborium::ser::Error<std::io::Error>;
#[cfg(feature = "server")]
type DeserializationError = ciborium::de::Error<std::io::Error>;
#[derive(Default)]
pub(crate) struct QuerySourceRegistry(HashMap<Path, Box<dyn QuerySourceEntryAny>>);
impl QuerySourceRegistry {
pub(crate) fn add<T, R>(
&mut self,
source: QuerySource<T, R>,
path: Path,
registry: &mut SchedulerRegistry,
) -> Result<(), (Path, QuerySource<T, R>)>
where
T: Message + Serialize + DeserializeOwned + Clone + Send + 'static,
R: Message + Serialize + Send + 'static,
{
self.add_any(source, path, || (T::schema(), R::schema()), registry)
}
pub(crate) fn add_raw<T, R>(
&mut self,
source: QuerySource<T, R>,
path: Path,
registry: &mut SchedulerRegistry,
) -> Result<(), (Path, QuerySource<T, R>)>
where
T: Serialize + DeserializeOwned + Clone + Send + 'static,
R: Serialize + Send + 'static,
{
self.add_any(source, path, || (String::new(), String::new()), registry)
}
fn add_any<T, R, F>(
&mut self,
source: QuerySource<T, R>,
path: Path,
schema_gen: F,
registry: &mut SchedulerRegistry,
) -> Result<(), (Path, QuerySource<T, R>)>
where
T: Serialize + DeserializeOwned + Clone + Send + 'static,
R: Serialize + Send + 'static,
F: Fn() -> (MessageSchema, MessageSchema) + Send + Sync + 'static,
{
match self.0.entry(path) {
Entry::Vacant(s) => {
let query_id = registry.add_query_source(source);
let entry = QuerySourceEntry {
inner: query_id,
schema_gen,
};
s.insert(Box::new(entry));
Ok(())
}
Entry::Occupied(e) => Err((e.key().clone(), source)),
}
}
pub(crate) fn get(&self, path: &Path) -> Result<&dyn QuerySourceEntryAny, EndpointError> {
self.0
.get(path)
.map(|s| s.as_ref())
.ok_or_else(|| EndpointError::QuerySourceNotFound { path: path.clone() })
}
pub(crate) fn get_source_id<T, R>(&self, path: &Path) -> Result<QueryId<T, R>, EndpointError>
where
T: Serialize + DeserializeOwned + Clone + Send + 'static,
R: Send + 'static,
{
let query_id = self.get(path).and_then(|entry| {
if entry.request_type_id() == TypeId::of::<T>()
&& entry.reply_type_id() == TypeId::of::<R>()
{
Ok(entry.get_query_id())
} else {
Err(EndpointError::InvalidQuerySourceType {
path: path.clone(),
found_request_type: any::type_name::<T>(),
found_reply_type: any::type_name::<R>(),
expected_request_type: entry.request_type_name(),
expected_reply_type: entry.reply_type_name(),
})
}
})?;
Ok(QueryId(query_id.0, PhantomData))
}
pub(crate) fn list_sources(&self) -> impl Iterator<Item = &Path> {
self.0.keys()
}
pub(crate) fn get_source_schema(
&self,
path: &Path,
) -> Result<(MessageSchema, MessageSchema), EndpointError> {
Ok(self.get(path)?.query_schema())
}
#[cfg(feature = "server")]
pub(crate) fn list_schemas(
&self,
) -> impl Iterator<Item = (&Path, MessageSchema, MessageSchema)> {
self.0.iter().map(|(path, src)| {
let schemas = src.query_schema();
(path, schemas.0, schemas.1)
})
}
}
impl fmt::Debug for QuerySourceRegistry {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "QuerySourceRegistry ({} query sources)", self.0.len(),)
}
}
pub(crate) trait QuerySourceEntryAny: Any + Send + Sync + 'static {
#[cfg(feature = "server")]
fn deserialize_arg(&self, serialized_arg: &[u8]) -> Result<Box<dyn Any>, DeserializationError>;
fn request_type_id(&self) -> TypeId;
fn request_type_name(&self) -> &'static str;
fn reply_type_id(&self) -> TypeId;
fn reply_type_name(&self) -> &'static str;
fn query_schema(&self) -> (MessageSchema, MessageSchema);
fn get_query_id(&self) -> QueryIdErased;
#[cfg(feature = "server")]
fn replier(&self) -> (Box<dyn ReplyWriterAny>, Box<dyn ReplyReaderAny>);
}
struct QuerySourceEntry<T, R, F>
where
T: Serialize + DeserializeOwned + Clone + Send + 'static,
R: Serialize + Send + 'static,
F: Fn() -> (MessageSchema, MessageSchema),
{
inner: QueryId<T, R>,
schema_gen: F,
}
impl<T, R, F> QuerySourceEntryAny for QuerySourceEntry<T, R, F>
where
T: Serialize + DeserializeOwned + Clone + Send + 'static,
R: Serialize + Send + 'static,
F: Fn() -> (MessageSchema, MessageSchema) + Send + Sync + 'static,
{
#[cfg(feature = "server")]
fn deserialize_arg(&self, serialized_arg: &[u8]) -> Result<Box<dyn Any>, DeserializationError> {
ciborium::from_reader(serialized_arg).map(|arg: T| Box::new(arg) as Box<dyn Any>)
}
fn request_type_id(&self) -> TypeId {
TypeId::of::<T>()
}
fn request_type_name(&self) -> &'static str {
any::type_name::<T>()
}
fn reply_type_id(&self) -> TypeId {
TypeId::of::<R>()
}
fn reply_type_name(&self) -> &'static str {
any::type_name::<R>()
}
fn query_schema(&self) -> (MessageSchema, MessageSchema) {
(self.schema_gen)()
}
fn get_query_id(&self) -> QueryIdErased {
self.inner.into()
}
#[cfg(feature = "server")]
fn replier(&self) -> (Box<dyn ReplyWriterAny>, Box<dyn ReplyReaderAny>) {
use crate::ports::query_replier;
let (tx, rx) = query_replier::<R>();
(Box::new(tx), Box::new(rx))
}
}
#[cfg(feature = "server")]
pub(crate) trait ReplyWriterAny: Any + Send {}
#[cfg(feature = "server")]
impl<R: Send + 'static> ReplyWriterAny for ReplyWriter<R> {}
#[cfg(feature = "server")]
pub(crate) trait ReplyReaderAny {
fn take_collect(&mut self) -> Option<Result<Vec<Vec<u8>>, SerializationError>>;
}
#[cfg(feature = "server")]
impl<R: Serialize + Send + 'static> ReplyReaderAny for ReplyReader<R> {
fn take_collect(&mut self) -> Option<Result<Vec<Vec<u8>>, SerializationError>> {
let replies = self.try_read()?;
let encoded_replies = (move || {
let mut encoded_replies = Vec::new();
for reply in replies {
let mut encoded_reply = Vec::new();
ciborium::into_writer(&reply, &mut encoded_reply)?;
encoded_replies.push(encoded_reply);
}
Ok(encoded_replies)
})();
Some(encoded_replies)
}
}