use std::any::{Any, TypeId};
use std::collections::BTreeMap;
use std::fmt; use std::sync::Arc;
#[cfg(not(feature = "tokio"))]
use std::sync::Mutex;
#[cfg(feature = "tokio")]
use tokio::sync::Mutex;
#[cfg(feature = "tokio")]
use std::{
pin::Pin,
future::Future
};
/// Error type reported by the service container and by service lifecycle hooks.
///
/// Wraps a single human-readable message. Its `Display` output is that message
/// prefixed with `"RsService Error: "`.
#[derive(Debug)]
pub struct RsServiceError(String);

impl fmt::Display for RsServiceError {
    /// Writes the fixed prefix followed by the wrapped message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        f.write_fmt(format_args!("RsService Error: {}", message))
    }
}

// No underlying cause is tracked, so the default `Error` methods are sufficient.
impl std::error::Error for RsServiceError {}
// Type-erased storage slot for one registered service. `RSContextBuilder::register`
// boxes an `Arc<Mutex<T>>` into it and `RSContext::call` downcasts back to that
// exact type.
type ContainerStruct = Box<dyn Any + Send + Sync + 'static>;
// Service registry keyed by the `TypeId` of the concrete service type, so each
// service type has at most one entry.
type MapForContainer = BTreeMap<TypeId, ContainerStruct>;
/// Lifecycle contract for a service managed by `RSContextBuilder` / `RSContext`
/// (synchronous variant, compiled when the `tokio` feature is disabled).
#[cfg(not(feature = "tokio"))]
pub trait RSContextService: Any + Send + Sync + 'static {
/// Constructs the service instance; invoked by `RSContextBuilder::register`.
// NOTE(review): "crate" in the name looks like a typo for "create"; it is part
// of the public interface, so it is left unchanged here.
fn on_register_crate_instance() -> Self where Self: Sized;
/// Called by `register` right after construction and before the instance is
/// stored, with the builder in its current state. Returning `Err` aborts the
/// registration.
fn on_service_created(&mut self, builder: &RSContextBuilder) -> Result<(), RsServiceError>;
/// Called by `RSContextBuilder::build` with the finished context, after every
/// service has been registered. Returning `Err` aborts `build`.
fn on_all_services_built(&self, context: &RSContext) -> Result<(), RsServiceError>;
}
/// Output type of the futures returned by the asynchronous service hooks
/// (`tokio` feature only).
#[cfg(feature = "tokio")]
pub type AsyncHooksResult = Result<(), RsServiceError>;
/// Lifecycle contract for a service managed by `RSContextBuilder` / `RSContext`
/// (asynchronous variant, compiled when the `tokio` feature is enabled).
#[cfg(feature = "tokio")]
pub trait RSContextService: Any + Send + Sync + 'static {
/// Asynchronously constructs the service instance; awaited by
/// `RSContextBuilder::register`.
// NOTE(review): "crate" in the name looks like a typo for "create"; it is part
// of the public interface, so it is left unchanged here.
async fn on_register_crate_instance() -> Self where Self: Sized;
/// Awaited by `register` right after construction and before the instance is
/// stored, with the builder in its current state. Resolving to `Err` aborts
/// the registration. The returned future must be `Send`.
fn on_service_created(&mut self, builder: &RSContextBuilder) -> impl std::future::Future<Output = AsyncHooksResult> + Send;
/// Invoked from `RSContextBuilder::build` with the finished context, after
/// every service has been registered. The returned future must be `Send`.
fn on_all_services_built(&self, context: &RSContext) -> impl std::future::Future<Output = AsyncHooksResult> + Send;
}
// One-shot callback queued by `register` and run by `build` against the finished
// context (synchronous variant). `FnOnce` because each hook is consumed when it
// is called.
#[cfg(not(feature = "tokio"))]
type AfterBuildHook = Box<
dyn FnOnce(&RSContext) ->
Result<(), RsServiceError>
+ Send
+ Sync
>;
// Callback queued by `register` and invoked by `build` (asynchronous variant).
// It receives a shared handle to the finished context and returns a boxed,
// `Send` future resolving to the hook's result.
#[cfg(feature = "tokio")]
type AfterAsyncBuildHook = Box<
dyn Fn(Arc<RSContext>) -> Pin<Box<dyn Future<Output = Result<(), RsServiceError>> + Send>>
+ Send
+ Sync
>;
/// Builder that collects services and produces an `RSContext`
/// (synchronous variant, compiled when the `tokio` feature is disabled).
#[cfg(not(feature = "tokio"))]
pub struct RSContextBuilder {
// Services registered so far, keyed by concrete type; becomes the context's map.
pending_services: MapForContainer,
// One hook per registered service, in registration order; run by `build`.
after_build_hooks: Vec<AfterBuildHook>,
}
/// Builder that collects services and produces an `RSContext`
/// (asynchronous variant, compiled when the `tokio` feature is enabled).
#[cfg(feature = "tokio")]
pub struct RSContextBuilder {
// Services registered so far, keyed by concrete type; becomes the context's map.
pending_services: MapForContainer,
// One hook per registered service, in registration order; invoked by `build`.
after_build_async_hooks: Vec<AfterAsyncBuildHook>,
}
impl RSContextBuilder {
    /// Creates an empty builder with no services and no pending hooks.
    #[cfg(feature = "tokio")]
    pub fn new() -> Self {
        RSContextBuilder {
            pending_services: BTreeMap::new(),
            after_build_async_hooks: Vec::new(),
        }
    }

    /// Constructs a service of type `T`, runs its `on_service_created` hook and
    /// stores it in the builder.
    ///
    /// The builder is taken and returned by value so calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns an error if a service of type `T` is already registered, or if
    /// `on_service_created` fails (the hook's error is wrapped together with the
    /// type name). In both cases the builder is consumed.
    #[cfg(feature = "tokio")]
    pub async fn register<T>(mut self) -> Result<Self, RsServiceError>
    where
        T: RSContextService + Send + Sync + 'static,
    {
        let type_id = TypeId::of::<T>();
        if self.pending_services.contains_key(&type_id) {
            return Err(RsServiceError(format!(
                "Service type {:?} already registered.",
                std::any::type_name::<T>()
            )));
        }

        let mut instance = T::on_register_crate_instance().await;
        instance.on_service_created(&self).await.map_err(|e| {
            RsServiceError(format!(
                "on_service_created hook failed for {}: {}",
                std::any::type_name::<T>(),
                e
            ))
        })?;

        let service: Arc<Mutex<T>> = Arc::new(Mutex::new(instance));
        self.pending_services
            .insert(type_id, Box::new(Arc::clone(&service)) as ContainerStruct);

        // The hook owns its own handle to the service instead of looking it up in
        // the context again, so there is no "service not found" path that could
        // panic. The service's mutex stays locked while its hook runs.
        let hook: AfterAsyncBuildHook = Box::new(move |ctx: Arc<RSContext>| {
            let service = Arc::clone(&service);
            Box::pin(async move { service.lock().await.on_all_services_built(&ctx).await })
                as Pin<Box<dyn Future<Output = Result<(), RsServiceError>> + Send>>
        });
        self.after_build_async_hooks.push(hook);

        Ok(self)
    }

    /// Finalizes the builder into an `RSContext` and runs every service's
    /// `on_all_services_built` hook, one after another in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first hook error encountered; later hooks are not run.
    /// Also returns an error if a handle to the shared context is still alive
    /// after all hooks completed, so the context cannot be handed back by value.
    #[cfg(feature = "tokio")]
    pub async fn build(self) -> Result<RSContext, RsServiceError> {
        let context = Arc::new(RSContext {
            service_map: self.pending_services,
        });

        for hook in self.after_build_async_hooks {
            // Propagate hook failures instead of silently discarding them.
            hook(Arc::clone(&context)).await?;
        }

        Arc::try_unwrap(context).map_err(|_| {
            RsServiceError("Failed to unwrap Arc<RSContext> in build()".to_string())
        })
    }

    /// Creates an empty builder with no services and no pending hooks.
    #[cfg(not(feature = "tokio"))]
    pub fn new() -> Self {
        RSContextBuilder {
            pending_services: BTreeMap::new(),
            after_build_hooks: Vec::new(),
        }
    }

    /// Constructs a service of type `T`, runs its `on_service_created` hook and
    /// stores it in the builder.
    ///
    /// The builder is taken and returned by value so calls can be chained.
    ///
    /// # Errors
    ///
    /// Returns an error if a service of type `T` is already registered, or if
    /// `on_service_created` fails (the hook's error is wrapped together with the
    /// type name). In both cases the builder is consumed.
    #[cfg(not(feature = "tokio"))]
    pub fn register<T>(mut self) -> Result<Self, RsServiceError>
    where
        T: RSContextService,
    {
        let type_id = TypeId::of::<T>();
        if self.pending_services.contains_key(&type_id) {
            return Err(RsServiceError(format!(
                "Service type {:?} already registered.",
                std::any::type_name::<T>()
            )));
        }

        let mut instance = T::on_register_crate_instance();
        instance.on_service_created(&self).map_err(|e| {
            RsServiceError(format!(
                "on_service_created hook failed for {}: {}",
                std::any::type_name::<T>(),
                e
            ))
        })?;

        let service: Arc<Mutex<T>> = Arc::new(Mutex::new(instance));
        self.pending_services
            .insert(type_id, Box::new(Arc::clone(&service)) as ContainerStruct);

        // The hook owns its own handle to the service instead of looking it up in
        // the context again, so the hook can never be silently skipped. The
        // service's mutex stays locked while its hook runs.
        self.after_build_hooks.push(Box::new(move |ctx: &RSContext| {
            let guard = service
                .lock()
                .map_err(|_| RsServiceError("Mutex poisoned".to_string()))?;
            guard.on_all_services_built(ctx)
        }));

        Ok(self)
    }

    /// Finalizes the builder into an `RSContext` and runs every service's
    /// `on_all_services_built` hook in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first hook error encountered; later hooks are not run.
    #[cfg(not(feature = "tokio"))]
    pub fn build(self) -> Result<RSContext, RsServiceError> {
        let context = RSContext {
            service_map: self.pending_services,
        };

        for hook in self.after_build_hooks {
            hook(&context)?;
        }

        Ok(context)
    }
}
/// Finished service container produced by `RSContextBuilder::build`.
/// Services are retrieved by concrete type through `RSContext::call`.
pub struct RSContext {
// Registered services keyed by `TypeId`; each value boxes an `Arc<Mutex<T>>`.
service_map: MapForContainer,
}
impl RSContext {
    /// Looks up the service registered for type `T`.
    ///
    /// Returns a new shared handle (`Arc` clone) to the service's mutex, or
    /// `None` when no entry exists for `T` or the stored value is not an
    /// `Arc<Mutex<T>>`. `Mutex` is the std or the tokio mutex, depending on
    /// whether the `tokio` feature is enabled.
    pub fn call<T>(&self) -> Option<Arc<Mutex<T>>>
    where
        T: RSContextService,
    {
        let key = TypeId::of::<T>();
        let slot = self.service_map.get(&key)?;
        let handle = slot.downcast_ref::<Arc<Mutex<T>>>()?;
        Some(handle.clone())
    }
}