use std::{collections::HashMap, hash::Hash, ops::Deref};
use triomphe::Arc;
use futures::FutureExt;
use rand::{RngCore, SeedableRng, rngs::SmallRng};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, oneshot};
use crate::{
error::{CallSubscribeError, CborValueError, TryFromEventError},
event::{self, Event, EventData},
};
use super::EffectWright;
#[cfg(feature = "macros")]
pub use ioevent_macro::{ProcedureCall, procedure};
/// Type-erased value handed from an [`EventShooter`] to its waiting receiver.
///
/// A `Box<dyn Any + Send>` (instead of a raw pointer) keeps ownership visible to
/// the compiler: a value that is never delivered is dropped rather than leaked,
/// and no manual `Send`/`Sync` impls are needed.
type ShotPayload = Box<dyn std::any::Any + Send>;

/// A one-shot event waiter.
///
/// Each event offered through [`EventShooter::try_dispatch`] is passed to the
/// selector; the first event for which the selector returns `Some` has that
/// value sent to the future returned by [`EventShooter::create_with_selector`],
/// and the shooter is consumed.
pub struct EventShooter {
    /// Type-erased wrapper around the caller's selector.
    selector: Box<dyn Fn(&EventData) -> Option<ShotPayload> + Send + Sync + 'static>,
    /// Sending half of the channel the waiting future listens on.
    shooter: oneshot::Sender<ShotPayload>,
}

impl EventShooter {
    /// Creates a shooter from `selector`, together with the future that resolves
    /// to the first value the selector extracts from a dispatched event.
    ///
    /// The future resolves to `Err(RecvError)` if the shooter is dropped without
    /// ever dispatching a value.
    pub fn create_with_selector<F, T>(
        selector: F,
    ) -> (
        Self,
        impl Future<Output = Result<T, oneshot::error::RecvError>>,
    )
    where
        F: Fn(&EventData) -> Option<T> + Send + Sync + 'static,
        T: Send + 'static,
    {
        let (shooter, receiver) = oneshot::channel::<ShotPayload>();
        let erased_selector: Box<dyn Fn(&EventData) -> Option<ShotPayload> + Send + Sync> =
            Box::new(move |event_data: &EventData| {
                selector(event_data).map(|value| Box::new(value) as ShotPayload)
            });
        let result = receiver.map(|recv_result| {
            recv_result.map(|payload| {
                // Only `erased_selector` ever produces values for this channel and
                // it always boxes a `T`, so this downcast cannot fail.
                *payload
                    .downcast::<T>()
                    .expect("EventShooter payload must have the selector's output type")
            })
        });
        (
            Self {
                selector: erased_selector,
                shooter,
            },
            result,
        )
    }

    /// Offers `event` to this shooter.
    ///
    /// Returns `None` when the shooter is finished: either its receiver was
    /// already dropped, or the selector matched and the value was sent. Returns
    /// `Some(self)` when the event did not match, so the caller can keep it.
    pub fn try_dispatch(self, event: &EventData) -> Option<Self> {
        if self.shooter.is_closed() {
            return None;
        }
        let selected = (self.selector)(event);
        match selected {
            Some(payload) => {
                // If the receiver went away after the `is_closed` check, `send`
                // returns the payload and it is dropped here instead of leaking.
                let _ = self.shooter.send(payload);
                None
            }
            None => Some(self),
        }
    }
}
/// Handler state bundled with the effect bus and the pending one-shot event waiters.
///
/// Deriving `Clone` clones the `Arc`, so every clone shares the same waiter list.
#[derive(Clone)]
pub struct State<T> {
/// User state; also reachable directly through `Deref`.
pub state: T,
/// Effect bus used to emit events (e.g. procedure-call requests and responses).
pub wright: EffectWright,
/// Waiters registered by `wait_next`.
/// NOTE(review): presumably drained by the event dispatcher through
/// `EventShooter::try_dispatch`; that code is not visible here.
pub event_shooters: Arc<Mutex<Vec<EventShooter>>>,
}
/// Lets a `State<T>` be used wherever a `&T` is expected.
impl<T> Deref for State<T> {
    type Target = T;

    /// Borrows the wrapped user state.
    fn deref(&self) -> &T {
        let Self { state, .. } = self;
        state
    }
}
impl<T> State<T> {
    /// Wraps `state` together with the effect bus `bus`, starting with no waiters.
    pub fn new(state: T, bus: EffectWright) -> Self {
        let event_shooters = Arc::new(Mutex::new(Vec::new()));
        Self {
            wright: bus,
            state,
            event_shooters,
        }
    }

    /// Registers a one-shot waiter for the next event accepted by `selector`.
    ///
    /// Awaiting this method only performs the registration. The returned future
    /// then resolves to the value extracted from the first matching event, or to
    /// `RecvError` if the waiter is discarded before anything matches.
    pub async fn wait_next<F, E>(
        &self,
        selector: F,
    ) -> impl Future<Output = Result<E, oneshot::error::RecvError>>
    where
        F: for<'a> Fn(&'a EventData) -> Option<E> + Send + Sync + 'static,
        E: Send + 'static,
    {
        let (waiter, outcome) = EventShooter::create_with_selector(selector);
        self.event_shooters.lock().await.push(waiter);
        outcome
    }
}
/// Builds the event tag for a procedure call.
///
/// Layout (fields separated by `"\u{0000}|"`):
/// `internal.ProcedureCall`, the procedure `path`, a direction marker
/// (`?echo=` for requests, `!echo=` for responses), and the decimal `echo` id.
/// [`decode_procedure_call`] parses this format back.
pub fn encode_procedure_call(path: &str, echo: u64, r#type: ProcedureCallType) -> String {
    let marker = match r#type {
        ProcedureCallType::Request => "?echo=",
        ProcedureCallType::Response => "!echo=",
    };
    format!("internal.ProcedureCall\u{0000}|{path}\u{0000}|{marker}\u{0000}|{echo}")
}
/// Parses a tag produced by [`encode_procedure_call`] into `(path, echo, type)`.
///
/// # Errors
/// Checks are applied in this order, and the first failure wins:
/// - wrong number of `"\u{0000}|"`-separated fields or a wrong prefix:
///   `"Invalid procedure call path format"`;
/// - the echo field is not a `u64`: `"Invalid echo format"`;
/// - the direction marker is neither `?echo=` nor `!echo=`:
///   `"Invalid procedure call type indicator"`.
pub fn decode_procedure_call(path: &str) -> Result<(String, u64, ProcedureCallType), String> {
    let fields: Vec<&str> = path.split("\u{0000}|").collect();
    let [prefix, name, marker, echo] = fields.as_slice() else {
        return Err("Invalid procedure call path format".to_string());
    };
    if *prefix != "internal.ProcedureCall" {
        return Err("Invalid procedure call path format".to_string());
    }
    let echo: u64 = echo
        .parse()
        .map_err(|_| "Invalid echo format".to_string())?;
    let kind = match *marker {
        "?echo=" => ProcedureCallType::Request,
        "!echo=" => ProcedureCallType::Response,
        _ => return Err("Invalid procedure call type indicator".to_string()),
    };
    Ok((String::from(*name), echo, kind))
}
/// Direction of a procedure-call message.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
pub enum ProcedureCallType {
/// A call awaiting a reply; encoded as `?echo=` in the event tag.
Request,
/// The reply to a request; encoded as `!echo=` in the event tag.
Response,
}
/// A serializable value that can be carried as a procedure-call payload.
pub trait ProcedureCall: Serialize + for<'de> Deserialize<'de> {
/// Identifier of this type. It is embedded in the event tag and compared when
/// matching requests and responses.
fn path() -> String;
}
/// A procedure-call request type with an associated response type.
pub trait ProcedureCallRequest:
    ProcedureCall + TryFrom<ProcedureCallData, Error = TryFromEventError> + Sized
{
    /// The type a handler replies with for this request.
    type RESPONSE: ProcedureCallResponse;

    /// Wraps `self` into a request [`ProcedureCallData`] tagged with `echo`.
    ///
    /// # Errors
    /// Returns `CborValueError` if serializing `self` fails.
    fn upcast(&self, echo: u64) -> Result<ProcedureCallData, CborValueError> {
        let path = Self::path();
        let payload = event::Value::serialized(&self)?;
        Ok(ProcedureCallData {
            path,
            echo,
            r#type: ProcedureCallType::Request,
            payload,
        })
    }

    /// Returns `true` when `other` is a request addressed to this procedure's path.
    fn match_self(other: &ProcedureCallData) -> bool {
        let same_path = other.path == Self::path();
        same_path && matches!(other.r#type, ProcedureCallType::Request)
    }
}
/// A procedure-call response type.
pub trait ProcedureCallResponse:
    ProcedureCall + TryFrom<ProcedureCallData, Error = TryFromEventError> + Sized
{
    /// Wraps `self` into a response [`ProcedureCallData`] answering request `echo`.
    ///
    /// # Errors
    /// Returns `CborValueError` if serializing `self` fails.
    fn upcast(&self, echo: u64) -> Result<ProcedureCallData, CborValueError> {
        let path = Self::path();
        let payload = event::Value::serialized(&self)?;
        Ok(ProcedureCallData {
            path,
            echo,
            r#type: ProcedureCallType::Response,
            payload,
        })
    }

    /// Returns `true` when `other` is a response of this type to request `echo`.
    fn match_echo(other: &ProcedureCallData, echo: u64) -> bool {
        if other.path != Self::path() {
            return false;
        }
        matches!(other.r#type, ProcedureCallType::Response) && other.echo == echo
    }
}
/// Decoded form of a procedure-call event.
///
/// When converted into `EventData`, `path`, `echo` and `r#type` are encoded into
/// the tag by `encode_procedure_call`; `payload` is carried as-is.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ProcedureCallData {
/// Procedure identifier, as returned by `ProcedureCall::path()`.
pub path: String,
/// Correlation id shared by a request and its response.
pub echo: u64,
/// Whether this message is a request or a response.
pub r#type: ProcedureCallType,
/// Serialized request or response value.
pub payload: event::Value,
}
impl From<ProcedureCallData> for EventData {
fn from(value: ProcedureCallData) -> Self {
EventData {
tag: encode_procedure_call(&value.path, value.echo, value.r#type),
payload: value.payload,
}
}
}
impl Event for ProcedureCallData {
fn upcast(&self) -> Result<EventData, CborValueError> {
Ok(self.clone().into())
}
const TAG: &'static str = "internal.ProcedureCall";
const SELECTOR: crate::event::Selector =
crate::event::Selector(|e| e.tag.starts_with(Self::TAG));
}
impl TryFrom<&EventData> for ProcedureCallData {
    type Error = TryFromEventError;

    /// Parses the event tag with `decode_procedure_call` and clones the payload.
    ///
    /// # Errors
    /// Fails when the tag is not a well-formed procedure-call tag.
    fn try_from(value: &EventData) -> Result<Self, Self::Error> {
        let decoded = decode_procedure_call(&value.tag)?;
        let (path, echo, r#type) = decoded;
        let payload = value.payload.clone();
        Ok(Self {
            path,
            echo,
            r#type,
            payload,
        })
    }
}
/// Helpers for issuing and answering procedure calls over the event bus.
pub trait ProcedureCallExt {
/// Sends `procedure` as a request and resolves with its response.
///
/// In the `State<T>` implementation in this file, the response is matched by
/// path, `Response` direction, and the echo id generated for this call.
fn call<P>(
&self,
procedure: &P,
) -> impl Future<Output = Result<P::RESPONSE, CallSubscribeError>>
where
P: ProcedureCallRequest;
/// Emits `response` as the reply to the request identified by `echo`.
fn resolve<P>(
&self,
echo: u64,
response: &P::RESPONSE,
) -> impl Future<Output = Result<(), CallSubscribeError>>
where
P: ProcedureCallRequest;
}
impl<T> ProcedureCallExt for State<T>
where
    T: ProcedureCallWright,
{
    /// Emits `procedure` as a request and waits for the response with the same echo id.
    ///
    /// # Errors
    /// Fails if the request cannot be serialized or emitted, if the pending
    /// waiter is discarded before a matching response arrives, or if the
    /// response payload cannot be converted into `P::RESPONSE`.
    async fn call<P>(&self, procedure: &P) -> Result<P::RESPONSE, CallSubscribeError>
    where
        P: ProcedureCallRequest,
    {
        let echo = self.state.next_echo().await;
        let request = procedure.upcast(echo)?;
        // Register the waiter before the request goes out so an immediate
        // reply cannot slip past unobserved.
        let pending_reply = self
            .wait_next(move |event| {
                if !ProcedureCallData::SELECTOR.match_event(&event) {
                    return None;
                }
                ProcedureCallData::try_from(event)
                    .ok()
                    .filter(|data| P::RESPONSE::match_echo(data, echo))
            })
            .await;
        self.wright.emit(&request)?;
        let reply = pending_reply.await?;
        Ok(P::RESPONSE::try_from(reply)?)
    }

    /// Emits `response` as the reply to the request identified by `echo`.
    ///
    /// # Errors
    /// Fails if the response cannot be serialized or emitted.
    async fn resolve<P>(&self, echo: u64, response: &P::RESPONSE) -> Result<(), CallSubscribeError>
    where
        P: ProcedureCallRequest,
    {
        self.wright.emit(&response.upcast(echo)?)?;
        Ok(())
    }
}
/// Default `ProcedureCallWright`: draws echo ids from a shared `SmallRng`.
#[derive(Clone)]
pub struct DefaultProcedureWright {
/// Random generator behind an async mutex; clones share it through the `Arc`.
/// `Default` seeds it from the OS.
pub rng: Arc<Mutex<SmallRng>>,
}
impl Default for DefaultProcedureWright {
fn default() -> Self {
Self {
rng: Arc::new(Mutex::new(SmallRng::from_os_rng())),
}
}
}
/// Source of echo ids used to correlate procedure-call requests with responses.
pub trait ProcedureCallWright {
/// Returns the echo id for the next outgoing request.
///
/// NOTE(review): uniqueness is not enforced here; the default implementation
/// returns random `u64`s, so collisions are merely unlikely.
fn next_echo(&self) -> impl Future<Output = u64> + Send + Sync;
}
impl ProcedureCallWright for DefaultProcedureWright {
    /// Draws the next random `u64` from the shared generator.
    async fn next_echo(&self) -> u64 {
        self.rng.lock().await.next_u64()
    }
}
impl ProcedureCall for () {
fn path() -> String {
"core::Unit".to_owned()
}
}
impl ProcedureCallResponse for () {}
impl TryFrom<ProcedureCallData> for () {
type Error = TryFromEventError;
fn try_from(_: ProcedureCallData) -> Result<Self, Self::Error> {
Ok(())
}
}
impl<T, E> ProcedureCall for Result<T, E>
where
    T: ProcedureCall,
    E: ProcedureCall,
{
    /// Identified as `core::Result<T, E>` using the inner types' paths.
    fn path() -> String {
        let ok_path = T::path();
        let err_path = E::path();
        format!("core::Result<{ok_path}, {err_path}>")
    }
}

impl<T, E> ProcedureCallResponse for Result<T, E>
where
    T: ProcedureCallResponse,
    E: ProcedureCallResponse,
{
}

impl<T, E> TryFrom<ProcedureCallData> for Result<T, E>
where
    T: ProcedureCall,
    E: ProcedureCall,
{
    type Error = TryFromEventError;

    /// Deserializes the payload as a `Result<T, E>`.
    fn try_from(value: ProcedureCallData) -> Result<Self, Self::Error> {
        let decoded: Self = value.payload.deserialized()?;
        Ok(decoded)
    }
}
impl<T> ProcedureCall for Option<T>
where
    T: ProcedureCall,
{
    /// Identified as `core::Option<T>` using the inner type's path.
    fn path() -> String {
        let inner = T::path();
        format!("core::Option<{inner}>")
    }
}

impl<T> ProcedureCallResponse for Option<T> where T: ProcedureCallResponse {}

impl<T> TryFrom<ProcedureCallData> for Option<T>
where
    T: ProcedureCall,
{
    type Error = TryFromEventError;

    /// Deserializes the payload as an `Option<T>`.
    fn try_from(value: ProcedureCallData) -> Result<Self, Self::Error> {
        let decoded: Self = value.payload.deserialized()?;
        Ok(decoded)
    }
}
impl<T> ProcedureCall for Vec<T>
where
    T: ProcedureCall,
{
    /// Identified as `core::Vec<T>` using the element type's path.
    fn path() -> String {
        let element = T::path();
        format!("core::Vec<{element}>")
    }
}

impl<T> ProcedureCallResponse for Vec<T> where T: ProcedureCallResponse {}

impl<T> TryFrom<ProcedureCallData> for Vec<T>
where
    T: ProcedureCall,
{
    type Error = TryFromEventError;

    /// Deserializes the payload as a `Vec<T>`.
    fn try_from(value: ProcedureCallData) -> Result<Self, Self::Error> {
        let decoded: Self = value.payload.deserialized()?;
        Ok(decoded)
    }
}
impl<K, V> ProcedureCall for HashMap<K, V>
where
    K: ProcedureCall + Hash + Eq,
    V: ProcedureCall,
{
    /// Identified as `core::HashMap<K, V>` using the key and value paths.
    fn path() -> String {
        format!("core::HashMap<{}, {}>", K::path(), V::path())
    }
}

impl<K, V> ProcedureCallResponse for HashMap<K, V>
where
    K: ProcedureCallResponse + Hash + Eq,
    V: ProcedureCallResponse,
{
}

// Bounds mirror the `ProcedureCall` impl above and the `Vec`/`Option`/`Result`
// conversions: decoding a map only needs `K`/`V` to be deserializable procedure
// types, not response types. This is strictly looser than before, so every
// previously accepted `HashMap<K, V>` still converts.
impl<K, V> TryFrom<ProcedureCallData> for HashMap<K, V>
where
    K: ProcedureCall + Hash + Eq,
    V: ProcedureCall,
{
    type Error = TryFromEventError;

    /// Deserializes the payload as a `HashMap<K, V>`.
    fn try_from(value: ProcedureCallData) -> Result<Self, Self::Error> {
        Ok(value.payload.deserialized()?)
    }
}
/// Implements `ProcedureCall`, `ProcedureCallResponse` and
/// `TryFrom<ProcedureCallData>` for each listed type, with path `core::<Type>`.
macro_rules! impl_procedure_call {
    ($($prim:ty),*) => {
        $(
            impl ProcedureCall for $prim {
                fn path() -> String {
                    String::from(concat!("core::", stringify!($prim)))
                }
            }

            impl ProcedureCallResponse for $prim {}

            impl TryFrom<ProcedureCallData> for $prim {
                type Error = TryFromEventError;

                fn try_from(value: ProcedureCallData) -> Result<Self, Self::Error> {
                    let decoded: Self = value.payload.deserialized()?;
                    Ok(decoded)
                }
            }
        )*
    };
}

impl_procedure_call!(String, bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, char);
macro_rules! impl_procedure_call_tuple {
($($t:ident),*) => {
impl<$($t: ProcedureCall),*> ProcedureCall for ($($t,)*) {
fn path() -> String {
"core::Tuple".to_owned() + "(" + $($t::path().as_str() + ", " +)* ")"
}
}
impl<$($t: ProcedureCallResponse),*> ProcedureCallResponse for ($($t,)*) {}
impl<$($t: ProcedureCall),*> TryFrom<ProcedureCallData> for ($($t,)*) {
type Error = TryFromEventError;
fn try_from(value: ProcedureCallData) -> Result<Self, Self::Error> {
Ok(value.payload.deserialized()?)
}
}
};
}
impl_procedure_call_tuple!(P0);
impl_procedure_call_tuple!(P0, P1);
impl_procedure_call_tuple!(P0, P1, P2);
impl_procedure_call_tuple!(P0, P1, P2, P3);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7, P8);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12);
impl_procedure_call_tuple!(P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13);
impl_procedure_call_tuple!(
P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14
);
impl_procedure_call_tuple!(
P0, P1, P2, P3, P4, P5, P6, P7, P8, P9, P10, P11, P12, P13, P14, P15
);