use std::{
any::{Any, TypeId},
fmt::Debug,
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
};
pub use async_trait::async_trait;
use futures::stream::Stream;
/// Marker trait for payloads that may be moved to another thread, shared between
/// threads, and that hold no non-`'static` borrows.
///
/// It is implemented automatically, through the blanket impl right below, for
/// every type meeting those three bounds.
pub trait Data: 'static + Send + Sync {}

impl<T> Data for T where T: 'static + Send + Sync {}
/// Heap-allocated, pinned, `Send` future resolving to a single `T`.
pub type DataUnary<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Heap-allocated, pinned, `Send` stream yielding `T` items.
pub type DataStream<T> = Pin<Box<dyn Stream<Item = T> + Send>>;

/// Shared, dynamically dispatched engine handle.
pub type Engine<Req, Resp, E> = Arc<dyn AsyncEngine<Req, Resp, E>>;
/// Shared handle to an engine context.
pub type Context = Arc<dyn AsyncEngineContext>;

/// Pinned, boxed unary response that can also report its context.
pub type EngineUnary<Resp> = Pin<Box<dyn AsyncEngineUnary<Resp>>>;
/// Pinned, boxed streaming response that can also report its context.
pub type EngineStream<Resp> = Pin<Box<dyn AsyncEngineStream<Resp>>>;
/// Erases a context-carrying [`EngineStream`] down to a plain [`DataStream`].
///
/// The engine stream is placed inside a second pinned box. The resulting type no
/// longer exposes the stream's context.
impl<T> From<EngineStream<T>> for DataStream<T>
where
    T: Data,
{
    fn from(stream: EngineStream<T>) -> Self {
        let outer = Box::new(stream);
        Box::into_pin(outer)
    }
}
/// Marker trait with no methods; implementors must be thread-safe.
///
/// NOTE(review): nothing in this module uses it — confirm its intended role with callers.
pub trait AsyncEngineController
where
    Self: Send + Sync,
{
}
/// Control handle for an in-flight engine call (presumably one per request —
/// confirm with implementors).
///
/// NOTE(review): the stop/kill semantics below are inferred from the method names.
/// The authoritative behavior lives in the implementations.
#[async_trait]
pub trait AsyncEngineContext: Send + Sync + Debug {
    /// Identifier of this context.
    fn id(&self) -> &str;

    /// Returns whether the context is currently in the stopped state.
    fn is_stopped(&self) -> bool;
    /// Resolves once the context is stopped.
    async fn stopped(&self);
    /// Asks the engine to stop producing further output (presumably a softer
    /// request than `stop` — confirm).
    fn stop_generating(&self);
    /// Requests that the context stop.
    fn stop(&self);

    /// Returns whether the context is currently in the killed state.
    fn is_killed(&self) -> bool;
    /// Resolves once the context is killed.
    async fn killed(&self);
    /// Requests that the context be killed.
    fn kill(&self);

    /// Links `child` to this context.
    ///
    /// How (or whether) signals propagate to the child is up to the implementor.
    fn link_child(&self, child: Arc<dyn AsyncEngineContext>);
}
/// Implemented by values that can hand out the [`AsyncEngineContext`] they are
/// associated with.
pub trait AsyncEngineContextProvider
where
    Self: Send + Debug,
{
    /// Returns a shared handle to the associated context.
    fn context(&self) -> Arc<dyn AsyncEngineContext>;
}
/// Future resolving to a single `Resp` that can also report the
/// [`AsyncEngineContext`] it belongs to.
///
/// This is a marker trait: types opt in with an empty impl.
pub trait AsyncEngineUnary<Resp>:
    Future<Output = Resp> + AsyncEngineContextProvider + Send
where
    Resp: Data,
{
}
/// Stream of `Resp` items that can also report the [`AsyncEngineContext`] it
/// belongs to.
///
/// This is a marker trait: types opt in with an empty impl, e.g. `ResponseStream`.
pub trait AsyncEngineStream<Resp>:
    Stream<Item = Resp> + AsyncEngineContextProvider + Send
where
    Resp: Data,
{
}
/// Asynchronous request-to-response engine.
///
/// [`AsyncEngine::generate`] consumes one `Req` and resolves to one of two values:
/// - a `Resp`, which can report its [`AsyncEngineContext`];
/// - an `E`.
#[async_trait]
pub trait AsyncEngine<Req, Resp, E>: Send + Sync
where
    Req: Send + Sync + 'static,
    Resp: AsyncEngineContextProvider,
    E: Data,
{
    /// Processes `request`, yielding the engine's response or its error.
    async fn generate(&self, request: Req) -> Result<Resp, E>;
}
/// Stream of `R` items paired with an [`AsyncEngineContext`].
pub struct ResponseStream<R>
where
    R: Data,
{
    /// Underlying item stream; polling is delegated to it.
    stream: DataStream<R>,
    /// Context returned by this stream's `AsyncEngineContextProvider` impl.
    ctx: Arc<dyn AsyncEngineContext>,
}
impl<R> ResponseStream<R>
where
    R: Data,
{
    /// Pairs `stream` with the context `ctx` and returns the result pinned on the heap.
    pub fn new(stream: DataStream<R>, ctx: Arc<dyn AsyncEngineContext>) -> Pin<Box<Self>> {
        let response = Self { stream, ctx };
        Box::pin(response)
    }
}
impl<R: Data> Stream for ResponseStream<R> {
    type Item = R;

    /// Polls the wrapped stream; the context plays no part in polling.
    #[inline]
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Both fields are `Unpin`, so `ResponseStream` is too. `self` can therefore be
        // dereferenced mutably to reach the already-pinned inner box.
        self.stream.as_mut().poll_next(cx)
    }

    /// Forwards the inner stream's length bounds instead of the default `(0, None)`.
    /// Consumers such as `collect` can then pre-size their buffers.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
/// Opts `ResponseStream` into the [`AsyncEngineStream`] marker trait.
///
/// Its `Stream` and `AsyncEngineContextProvider` impls provide the required supertraits.
impl<R> AsyncEngineStream<R> for ResponseStream<R> where R: Data {}
impl<R> AsyncEngineContextProvider for ResponseStream<R>
where
    R: Data,
{
    /// Hands out another reference to the context this stream was created with.
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        Arc::clone(&self.ctx)
    }
}
impl<R: Data> Debug for ResponseStream<R> {
    /// Formats the context only.
    ///
    /// The boxed stream (`dyn Stream + Send`) has no `Debug` bound, so it cannot be
    /// printed. It is reported as elided (`..`) rather than silently dropped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResponseStream")
            .field("ctx", &self.ctx)
            .finish_non_exhaustive()
    }
}
/// Lets a pinned, boxed unary response act as a context provider.
///
/// The call is delegated to the value inside the box.
impl<T> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineUnary<T>>>
where
    T: Data,
{
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        // Borrow through the pin down to the trait object. The call then dispatches
        // through its vtable instead of recursing into this impl.
        let inner = Pin::get_ref(Pin::as_ref(self));
        AsyncEngineContextProvider::context(inner)
    }
}
/// Lets a pinned, boxed streaming response act as a context provider.
///
/// The call is delegated to the value inside the box.
impl<T> AsyncEngineContextProvider for Pin<Box<dyn AsyncEngineStream<T>>>
where
    T: Data,
{
    fn context(&self) -> Arc<dyn AsyncEngineContext> {
        // Borrow through the pin down to the trait object. The call then dispatches
        // through its vtable instead of recursing into this impl.
        let inner = Pin::get_ref(Pin::as_ref(self));
        AsyncEngineContextProvider::context(inner)
    }
}
/// Object-safe, type-erased view of an [`AsyncEngine`].
///
/// Engines with different `Req`/`Resp`/`E` parameters can be stored side by side,
/// e.g. as values in one map.
pub trait AnyAsyncEngine: Send + Sync {
    /// Exposes the erased value so callers can `downcast_ref` it.
    fn as_any(&self) -> &dyn Any;
    /// [`TypeId`] of the erased engine's request type.
    fn request_type_id(&self) -> TypeId;
    /// [`TypeId`] of the erased engine's response type.
    fn response_type_id(&self) -> TypeId;
    /// [`TypeId`] of the erased engine's error type.
    fn error_type_id(&self) -> TypeId;
}
/// Type-erasure adapter around a typed engine.
///
/// It lets the engine's generic parameters be reported through [`AnyAsyncEngine`].
struct AnyEngineWrapper<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data> {
    /// The wrapped engine; this `Arc` is also the value exposed for downcasting.
    engine: Arc<dyn AsyncEngine<Req, Resp, E>>,
    // NOTE(review): `engine` already uses all three parameters, and a
    // `PhantomData<fn(..)>` is always `Send + Sync`. This marker therefore looks
    // redundant; it is kept so existing struct literals stay valid.
    _phantom: PhantomData<fn(Req, Resp, E)>,
}
/// Reports the wrapper's generic parameters and exposes the inner `Arc` for downcasting.
impl<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data> AnyAsyncEngine
    for AnyEngineWrapper<Req, Resp, E>
{
    fn request_type_id(&self) -> TypeId {
        TypeId::of::<Req>()
    }

    fn response_type_id(&self) -> TypeId {
        TypeId::of::<Resp>()
    }

    fn error_type_id(&self) -> TypeId {
        TypeId::of::<E>()
    }

    /// Erases the stored `Arc<dyn AsyncEngine<Req, Resp, E>>` itself, not the wrapper.
    ///
    /// A downcast must therefore target that `Arc` type.
    fn as_any(&self) -> &dyn Any {
        &self.engine as &dyn Any
    }
}
/// Converts a typed engine handle into a type-erased [`AnyAsyncEngine`] handle.
pub trait AsAnyAsyncEngine {
    /// Consumes `self` and returns a shared, type-erased handle to it.
    fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine>;
}
/// Erases a shared typed engine by moving it into an `AnyEngineWrapper`.
impl<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data> AsAnyAsyncEngine
    for Arc<dyn AsyncEngine<Req, Resp, E>>
{
    fn into_any_engine(self) -> Arc<dyn AnyAsyncEngine> {
        let wrapper = AnyEngineWrapper {
            engine: self,
            _phantom: PhantomData,
        };
        Arc::new(wrapper)
    }
}
/// Recovers a typed engine handle from a type-erased one.
pub trait DowncastAnyAsyncEngine {
    /// Attempts to view the erased engine as `Arc<dyn AsyncEngine<Req, Resp, E>>`.
    ///
    /// The impl for `Arc<dyn AnyAsyncEngine>` in this module returns `None` when
    /// any of the three type parameters differs from what the engine was erased with.
    fn downcast<Req: Data, Resp: Data + AsyncEngineContextProvider, E: Data>(
        &self,
    ) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>;
}
impl DowncastAnyAsyncEngine for Arc<dyn AnyAsyncEngine> {
    /// Returns a clone of the typed engine `Arc` when `Req`, `Resp` and `E` all match
    /// the erased engine's reported types, otherwise `None`.
    fn downcast<Req, Resp, E>(&self) -> Option<Arc<dyn AsyncEngine<Req, Resp, E>>>
    where
        Req: Data,
        Resp: Data + AsyncEngineContextProvider,
        E: Data,
    {
        // Cheap up-front rejection using the reported TypeIds; short-circuits on the
        // first mismatching parameter.
        let signature_matches = self.request_type_id() == TypeId::of::<Req>()
            && self.response_type_id() == TypeId::of::<Resp>()
            && self.error_type_id() == TypeId::of::<E>();
        if !signature_matches {
            return None;
        }

        let engine = self
            .as_any()
            .downcast_ref::<Arc<dyn AsyncEngine<Req, Resp, E>>>()?;
        Some(Arc::clone(engine))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[derive(Debug, PartialEq)]
    struct Req1(String);

    #[derive(Debug, PartialEq)]
    struct Resp1(String);

    // The tests never ask a response for its context, so these impls are stubs.
    impl AsyncEngineContextProvider for Resp1 {
        fn context(&self) -> Arc<dyn AsyncEngineContext> {
            unimplemented!()
        }
    }

    #[derive(Debug)]
    struct Err1;

    #[derive(Debug)]
    struct Req2;

    #[derive(Debug)]
    struct Resp2;

    impl AsyncEngineContextProvider for Resp2 {
        fn context(&self) -> Arc<dyn AsyncEngineContext> {
            unimplemented!()
        }
    }

    /// Alternate error type used to check that the `E` position is compared too.
    #[derive(Debug)]
    struct Err2;

    /// Engine for `Req1 -> Result<Resp1, Err1>` that echoes the request text.
    struct MockEngine;

    #[async_trait]
    impl AsyncEngine<Req1, Resp1, Err1> for MockEngine {
        async fn generate(&self, request: Req1) -> Result<Resp1, Err1> {
            Ok(Resp1(format!("response to {}", request.0)))
        }
    }

    /// Builds a type-erased handle around a fresh [`MockEngine`].
    fn erased_mock() -> Arc<dyn AnyAsyncEngine> {
        let typed_engine: Arc<dyn AsyncEngine<Req1, Resp1, Err1>> = Arc::new(MockEngine);
        typed_engine.into_any_engine()
    }

    #[tokio::test]
    async fn test_engine_type_erasure_and_downcast() {
        let any_engine = erased_mock();
        assert_eq!(any_engine.request_type_id(), TypeId::of::<Req1>());
        assert_eq!(any_engine.response_type_id(), TypeId::of::<Resp1>());
        assert_eq!(any_engine.error_type_id(), TypeId::of::<Err1>());

        let downcasted_engine = any_engine
            .downcast::<Req1, Resp1, Err1>()
            .expect("downcast to the original signature must succeed");
        let response = downcasted_engine
            .generate(Req1("hello".to_string()))
            .await;
        assert_eq!(response.unwrap(), Resp1("response to hello".to_string()));

        let failed_downcast = any_engine.downcast::<Req2, Resp2, Err1>();
        assert!(failed_downcast.is_none());

        // Erased engines must survive being stored in a heterogeneous collection.
        let mut engine_map: HashMap<String, Arc<dyn AnyAsyncEngine>> = HashMap::new();
        engine_map.insert("mock".to_string(), any_engine);
        let retrieved_engine = engine_map.get("mock").expect("engine was just inserted");
        let final_engine = retrieved_engine
            .downcast::<Req1, Resp1, Err1>()
            .expect("stored engine keeps its signature");
        let final_response = final_engine.generate(Req1("world".to_string())).await;
        assert_eq!(
            final_response.unwrap(),
            Resp1("response to world".to_string())
        );
    }

    #[test]
    fn test_downcast_rejects_each_mismatched_parameter() {
        let any_engine = erased_mock();
        // Each generic position must be checked independently, including the error type.
        assert!(any_engine.downcast::<Req2, Resp1, Err1>().is_none());
        assert!(any_engine.downcast::<Req1, Resp2, Err1>().is_none());
        assert!(any_engine.downcast::<Req1, Resp1, Err2>().is_none());
        assert!(any_engine.downcast::<Req2, Resp2, Err2>().is_none());
        assert!(any_engine.downcast::<Req1, Resp1, Err1>().is_some());
    }
}