use std::{
borrow::Cow,
collections::{HashMap, hash_map::Entry},
sync::Arc,
};
use schemars::{JsonSchema, Schema, SchemaGenerator};
use serde::{Serialize, de::DeserializeOwned};
use super::{
DiagramContext, DiagramErrorCode, DynForkResult, DynInputSlot, DynOutput, JsonMessage,
MessageRegistration, MessageRegistry, TypeInfo, TypeMismatch, supported::*,
};
use crate::JsonBuffer;
#[cfg(feature = "trace")]
use crate::Trace;
/// Static description of a message type: a name plus a JSON schema.
///
/// Both functions are associated functions (no `self`), so they describe the
/// type itself rather than a particular value.
pub trait DynType {
/// Returns the name of this type, either borrowed or owned.
fn type_name() -> Cow<'static, str>;
/// Produces the JSON schema of this type using the supplied `generator`.
fn json_schema(generator: &mut SchemaGenerator) -> Schema;
}
impl<T> DynType for T
where
T: JsonSchema,
{
fn type_name() -> Cow<'static, str> {
<T>::schema_name()
}
fn json_schema(generator: &mut SchemaGenerator) -> Schema {
generator.subschema_for::<T>()
}
}
/// Strategy trait for registering how messages of type `T` are serialized.
///
/// The implementing type acts as a marker selecting the strategy; `T` is the
/// message type being registered.
pub trait SerializeMessage<T> {
/// Registers serialization for `T`.
///
/// * `messages` - table of per-type registrations that may be updated.
/// * `schema_generator` - generator available for producing the schema of `T`.
fn register_serialize(
messages: &mut HashMap<TypeInfo, MessageRegistration>,
schema_generator: &mut SchemaGenerator,
);
}
/// JSON serialization support for any message type `T` that implements
/// `Serialize` and describes itself through [`DynType`].
impl<T> SerializeMessage<T> for Supported
where
    T: Serialize + DynType + Send + Sync + 'static,
{
    /// Installs the serialization operation for `T` in `messages`.
    ///
    /// The registration for `T` is created if it does not exist yet. Its
    /// `serialize_impl` is always overwritten, whereas the schema is generated
    /// through `schema_generator` only when none has been stored so far.
    fn register_serialize(
        messages: &mut HashMap<TypeInfo, MessageRegistration>,
        schema_generator: &mut SchemaGenerator,
    ) {
        // `or_insert_with` defers building the registration to the vacant
        // case, so nothing is constructed and immediately discarded when `T`
        // is already registered.
        let reg = messages
            .entry(TypeInfo::of::<T>())
            .or_insert_with(MessageRegistration::new::<T>);

        reg.operations.serialize_impl = Some(|builder| {
            // Convert the message to a JSON value; failures become strings.
            let serialize = builder.create_map_block(|message: T| {
                serde_json::to_value(message).map_err(|err| err.to_string())
            });

            // Route successful conversions and errors to separate outputs.
            let (ok, err) = builder
                .chain(serialize.output)
                .fork_result(|ok| ok.output(), |err| err.output());

            Ok(DynForkResult {
                input: serialize.input.into(),
                ok: ok.into(),
                err: err.into(),
            })
        });

        #[cfg(feature = "trace")]
        {
            reg.operations.enable_trace_serialization =
                Some(Trace::enable_value_serialization::<T>);
        }

        // Another registration for `T` (e.g. deserialization) may already
        // have stored a schema, so only generate one when it is missing.
        if reg.schema.is_none() {
            reg.schema = Some(T::json_schema(schema_generator));
        }
    }
}
/// Strategy trait for registering how messages of type `T` are deserialized.
///
/// The implementing type acts as a marker selecting the strategy; `T` is the
/// message type being registered.
pub trait DeserializeMessage<T> {
/// Registers deserialization for `T`.
///
/// * `messages` - table of per-type registrations that may be updated.
/// * `schema_generator` - generator available for producing the schema of `T`.
fn register_deserialize(
messages: &mut HashMap<TypeInfo, MessageRegistration>,
schema_generator: &mut SchemaGenerator,
);
}
/// JSON deserialization support for any message type `T` that implements
/// `DeserializeOwned` and describes itself through [`DynType`].
impl<T> DeserializeMessage<T> for Supported
where
    T: 'static + Send + Sync + DeserializeOwned + DynType,
{
    /// Installs the deserialization operation for `T` in `messages`.
    ///
    /// The registration for `T` is created if it does not exist yet. Its
    /// `deserialize_impl` is always overwritten, whereas the schema is
    /// generated through `schema_generator` only when none has been stored so
    /// far.
    fn register_deserialize(
        messages: &mut HashMap<TypeInfo, MessageRegistration>,
        schema_generator: &mut SchemaGenerator,
    ) {
        // `or_insert_with` defers building the registration to the vacant
        // case, so nothing is constructed and immediately discarded when `T`
        // is already registered.
        let reg = messages
            .entry(TypeInfo::of::<T>())
            .or_insert_with(MessageRegistration::new::<T>);

        reg.operations.deserialize_impl = Some(|builder| {
            // Parse the JSON value into `T`; failures become strings.
            let deserialize = builder.create_map_block(|message: JsonMessage| {
                serde_json::from_value::<T>(message).map_err(|err| err.to_string())
            });

            // Route successful conversions and errors to separate outputs.
            let (ok, err) = builder
                .chain(deserialize.output)
                .fork_result(|ok| ok.output(), |err| err.output());

            Ok(DynForkResult {
                input: deserialize.input.into(),
                ok: ok.into(),
                err: err.into(),
            })
        });

        // Another registration for `T` (e.g. serialization) may already have
        // stored a schema, so only generate one when it is missing.
        if reg.schema.is_none() {
            reg.schema = Some(T::json_schema(schema_generator));
        }
    }
}
/// No-op serialization registration, selected with the `NotSupported` marker.
impl<T> SerializeMessage<T> for NotSupported {
    /// Intentionally does nothing; both arguments are left untouched.
    fn register_serialize(
        _messages: &mut HashMap<TypeInfo, MessageRegistration>,
        _schema_generator: &mut SchemaGenerator,
    ) {
    }
}
/// No-op deserialization registration, selected with the `NotSupported`
/// marker.
impl<T> DeserializeMessage<T> for NotSupported {
    /// Intentionally does nothing; both arguments are left untouched.
    fn register_deserialize(
        _messages: &mut HashMap<TypeInfo, MessageRegistration>,
        _schema_generator: &mut SchemaGenerator,
    ) {
    }
}
/// Hook for performing JSON-related registration for the message type `T`.
///
/// The implementing type selects what, if anything, gets registered.
pub trait RegisterJson<T> {
/// Performs the registration for `T`; takes no arguments and returns nothing.
fn register_json();
}
/// Type-level pair of a `Serializer` and a `Deserializer` marker, used to
/// select a [`RegisterJson`] implementation.
pub struct JsonRegistration<Serializer, Deserializer> {
// Zero-sized marker that ties the two type parameters to the struct without
// storing values of either type.
_ignore: std::marker::PhantomData<fn(Serializer, Deserializer)>,
}
impl<T> RegisterJson<T> for JsonRegistration<Supported, Supported>
where
T: 'static + Send + Sync + Serialize + DeserializeOwned,
{
fn register_json() {
JsonBuffer::register_for::<T>();
}
}
/// Serialization only: nothing is registered for `T`.
impl<T> RegisterJson<T> for JsonRegistration<Supported, NotSupported> {
    /// Intentionally a no-op.
    fn register_json() {}
}
/// Deserialization only: nothing is registered for `T`.
impl<T> RegisterJson<T> for JsonRegistration<NotSupported, Supported> {
    /// Intentionally a no-op.
    fn register_json() {}
}
/// Neither direction is supported: nothing is registered for `T`.
impl<T> RegisterJson<T> for JsonRegistration<NotSupported, NotSupported> {
    /// Intentionally a no-op.
    fn register_json() {}
}
/// Runs the [`RegisterJson`] implementation selected by the
/// `Serializer`/`Deserializer` marker pair for the message type `T`.
pub(super) fn register_json<T, Serializer, Deserializer>()
where
    JsonRegistration<Serializer, Deserializer>: RegisterJson<T>,
{
    <JsonRegistration<Serializer, Deserializer> as RegisterJson<T>>::register_json();
}
/// Connects outputs of arbitrary message types to a single input slot that
/// accepts `JsonMessage`, inserting serialization where needed.
pub struct ImplicitSerialization {
// Input slot of the serialization operation already created for each
// incoming message type, so each type is set up only once.
incoming_types: HashMap<TypeInfo, DynInputSlot>,
// The `JsonMessage` input slot that all serialized messages are routed to.
serialized_input: Arc<DynInputSlot>,
}
impl ImplicitSerialization {
pub fn new(serialized_input: DynInputSlot) -> Result<Self, DiagramErrorCode> {
if serialized_input.message_info() != &TypeInfo::of::<JsonMessage>() {
return Err(TypeMismatch {
source_type: TypeInfo::of::<JsonMessage>(),
target_type: *serialized_input.message_info(),
}
.into());
}
Ok(Self {
serialized_input: Arc::new(serialized_input),
incoming_types: Default::default(),
})
}
pub fn try_implicit_serialize(
&mut self,
incoming: DynOutput,
ctx: &mut DiagramContext,
) -> Result<Result<(), DynOutput>, DiagramErrorCode> {
if incoming.message_info() == &TypeInfo::of::<JsonMessage>() {
incoming.connect_to(&self.serialized_input, ctx.builder)?;
return Ok(Ok(()));
}
let input = match self.incoming_types.entry(*incoming.message_info()) {
Entry::Occupied(input_slot) => input_slot.get().clone(),
Entry::Vacant(vacant) => {
let Some(serialize) = ctx
.registry
.messages
.try_serialize(incoming.message_info(), ctx.builder)?
else {
return Ok(Err(incoming));
};
serialize
.ok
.connect_to(&self.serialized_input, ctx.builder)?;
let error_target = ctx.get_implicit_error_target();
ctx.add_output_into_target(error_target, serialize.err);
vacant.insert(serialize.input).clone()
}
};
incoming.connect_to(&input, ctx.builder)?;
Ok(Ok(()))
}
pub fn implicit_serialize(
&mut self,
incoming: DynOutput,
ctx: &mut DiagramContext,
) -> Result<(), DiagramErrorCode> {
self.try_implicit_serialize(incoming, ctx)?
.map_err(|incoming| DiagramErrorCode::NotSerializable(*incoming.message_info()))
}
pub fn serialized_input_slot(&self) -> &Arc<DynInputSlot> {
&self.serialized_input
}
}
/// Connects outputs to a typed input slot, inserting deserialization when the
/// output carries `JsonMessage`.
pub struct ImplicitDeserialization {
// The typed input slot that messages ultimately reach.
deserialized_input: Arc<DynInputSlot>,
// Input slot of the deserialization operation. Starts as `None` and is
// filled in the first time a `JsonMessage` output is connected.
serialized_input: Option<DynInputSlot>,
}
impl ImplicitDeserialization {
    /// Creates the helper for `deserialized_input` if its message type has a
    /// deserialization operation registered.
    ///
    /// Returns `Ok(None)` when no registration exists for the type or the
    /// registration has no `deserialize_impl`. No `Err` is produced by the
    /// current implementation.
    pub fn try_new(
        deserialized_input: DynInputSlot,
        registration: &MessageRegistry,
    ) -> Result<Option<Self>, DiagramErrorCode> {
        let has_deserializer = registration
            .messages
            .get(&deserialized_input.message_info())
            .and_then(|reg| reg.operations.deserialize_impl.as_ref())
            .is_some();

        Ok(has_deserializer.then(|| Self {
            deserialized_input: Arc::new(deserialized_input),
            serialized_input: None,
        }))
    }

    /// Connects `incoming` to the deserialized input slot.
    ///
    /// * An output whose type equals the slot's type is connected directly.
    /// * A `JsonMessage` output is connected through a deserialization
    ///   operation, which is created on first use and reused afterwards.
    ///
    /// # Errors
    ///
    /// Returns a [`TypeMismatch`] for any other message type, and forwards
    /// failures from creating the operation or from connecting.
    pub fn implicit_deserialize(
        &mut self,
        incoming: DynOutput,
        ctx: &mut DiagramContext,
    ) -> Result<(), DiagramErrorCode> {
        let incoming_type = *incoming.message_info();
        let target_type = *self.deserialized_input.message_info();

        if incoming_type == target_type {
            // Types already agree: no conversion needed.
            incoming.connect_to(&self.deserialized_input, ctx.builder)?;
            return Ok(());
        }

        if incoming_type != TypeInfo::of::<JsonMessage>() {
            let mismatch = TypeMismatch {
                source_type: incoming_type,
                target_type,
            };
            return Err(mismatch.into());
        }

        let serialized_input = if let Some(existing) = self.serialized_input {
            existing
        } else {
            let deserialize = ctx
                .registry
                .messages
                .deserialize(self.deserialized_input.message_info(), ctx.builder)?;

            // Successfully parsed messages flow into the typed slot.
            deserialize
                .ok
                .connect_to(&self.deserialized_input, ctx.builder)?;

            // Failures go to whatever `get_implicit_error_target` returns.
            let error_target = ctx.get_implicit_error_target();
            ctx.add_output_into_target(error_target, deserialize.err);

            // Remember the operation's input so later JSON outputs share it.
            let created = deserialize.input;
            self.serialized_input = Some(created);
            created
        };

        incoming.connect_to(&serialized_input, ctx.builder)?;
        Ok(())
    }

    /// Returns the shared typed input slot this helper feeds.
    pub fn deserialized_input_slot(&self) -> &Arc<DynInputSlot> {
        &self.deserialized_input
    }
}
/// Connects outputs of arbitrary message types to a single input slot that
/// accepts `String`, inserting a to-string conversion where needed.
pub struct ImplicitStringify {
// Input slot of the to-string operation already created for each incoming
// message type, so each type is set up only once.
incoming_types: HashMap<TypeInfo, DynInputSlot>,
// The `String` input slot that all converted messages are routed to.
string_input: DynInputSlot,
}
impl ImplicitStringify {
pub fn new(string_input: DynInputSlot) -> Result<Self, DiagramErrorCode> {
if string_input.message_info() != &TypeInfo::of::<String>() {
return Err(TypeMismatch {
source_type: TypeInfo::of::<String>(),
target_type: *string_input.message_info(),
}
.into());
}
Ok(Self {
string_input,
incoming_types: Default::default(),
})
}
pub fn try_implicit_stringify(
&mut self,
incoming: DynOutput,
ctx: &mut DiagramContext,
) -> Result<Result<(), DynOutput>, DiagramErrorCode> {
if incoming.message_info() == &TypeInfo::of::<String>() {
incoming.connect_to(&self.string_input, ctx.builder)?;
return Ok(Ok(()));
}
let input = match self.incoming_types.entry(*incoming.message_info()) {
Entry::Occupied(input_slot) => input_slot.get().clone(),
Entry::Vacant(vacant) => {
let Some(stringify) = ctx
.registry
.messages
.try_to_string(incoming.message_info(), ctx.builder)?
else {
return Ok(Err(incoming));
};
stringify
.output
.connect_to(&self.string_input, ctx.builder)?;
vacant.insert(stringify.input).clone()
}
};
incoming.connect_to(&input, ctx.builder)?;
Ok(Ok(()))
}
}