use std::any::{Any, TypeId};
use std::collections::BTreeMap;
use std::fmt; use std::sync::Arc;
#[cfg(not(feature = "tokio"))]
use std::sync::Mutex;
#[cfg(feature = "tokio")]
use tokio::sync::Mutex;
#[cfg(feature = "tokio")]
use std::{
pin::Pin,
future::Future
};
/// Error type returned by service lifecycle hooks and by `RSContextBuilder::build`.
///
/// Construct one with [`RsServiceError::new`] or through `From<String>` /
/// `From<&str>`; the tuple field is private, so these are the only ways to
/// build an error from outside this module (e.g. inside a service's hooks).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RsServiceError(String);

impl RsServiceError {
    /// Creates an error carrying `message`.
    pub fn new(message: impl Into<String>) -> Self {
        RsServiceError(message.into())
    }

    /// Returns the raw message, without the `"RsService Error: "` prefix that
    /// `Display` adds.
    pub fn message(&self) -> &str {
        &self.0
    }
}

impl fmt::Display for RsServiceError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "RsService Error: {}", self.0)
    }
}

impl std::error::Error for RsServiceError {}

impl From<String> for RsServiceError {
    fn from(message: String) -> Self {
        RsServiceError(message)
    }
}

impl From<&str> for RsServiceError {
    fn from(message: &str) -> Self {
        RsServiceError(message.to_owned())
    }
}
/// Type-erased, thread-shareable storage for a single registered service.
/// (`Box<dyn Trait>` already defaults to a `'static` object bound.)
type ContainerStruct = Box<dyn Any + Send + Sync>;

/// Service registry keyed by each service's `TypeId`.
type MapForContainer = BTreeMap<TypeId, ContainerStruct>;
/// A service that can be registered with [`RSContextBuilder`] and later
/// retrieved from an [`RSContext`] via [`RSContext::call`].
///
/// Lifecycle, as driven by `RSContextBuilder::register` and `build`:
/// 1. [`on_register_crate_instance`](Self::on_register_crate_instance)
///    constructs the instance.
/// 2. [`on_service_created`](Self::on_service_created) runs immediately
///    afterwards, against the builder as it was before this service was inserted.
/// 3. [`on_all_services_built`](Self::on_all_services_built) runs once per
///    registered service after the context has been assembled.
pub trait RSContextService: Any + Send + Sync + 'static {
    /// Constructs the instance that will be stored in the context.
    fn on_register_crate_instance() -> Self
    where
        Self: Sized;

    /// Called right after construction, before the service is stored.
    ///
    /// Returning `Err` makes `RSContextBuilder::register` panic.
    fn on_service_created(&mut self, builder: &RSContextBuilder) -> Result<(), RsServiceError>;

    /// Called after `build` has assembled the full context.
    ///
    /// The service's own mutex is held for the duration of this call, so an
    /// implementation must not try to lock itself again through `context`.
    fn on_all_services_built(&self, context: &RSContext) -> Result<(), RsServiceError>;
}
/// Synchronous post-build hook: runs once against the finished context.
#[cfg(not(feature = "tokio"))]
type AfterBuildHook = Box<dyn FnOnce(&RSContext) -> Result<(), RsServiceError> + Send + Sync>;

/// Collects services before they are assembled into an [`RSContext`]
/// (synchronous variant, used when the `tokio` feature is disabled).
#[cfg(not(feature = "tokio"))]
pub struct RSContextBuilder {
    /// Registered services, keyed by their concrete `TypeId`.
    pending_services: MapForContainer,
    /// Hooks run in registration order once the context exists.
    after_build_hooks: Vec<AfterBuildHook>,
}
/// Async post-build hook: given a shared handle to the finished context,
/// produces a boxed future that resolves to the hook's result.
#[cfg(feature = "tokio")]
type AsyncAfterBuildHook = Box<
    dyn Fn(Arc<RSContext>) -> Pin<Box<dyn Future<Output = Result<(), RsServiceError>> + Send>>
        + Send
        + Sync,
>;

/// Collects services before they are assembled into an [`RSContext`]
/// (async variant, used when the `tokio` feature is enabled).
#[cfg(feature = "tokio")]
pub struct RSContextBuilder {
    /// Registered services, keyed by their concrete `TypeId`.
    pending_services: MapForContainer,
    /// Async hooks awaited in registration order once the context exists.
    after_build_async_hooks: Vec<AsyncAfterBuildHook>,
}
impl RSContextBuilder {
    /// Creates an empty builder with no services and no queued hooks.
    #[cfg(feature = "tokio")]
    pub fn new() -> Self {
        RSContextBuilder {
            pending_services: BTreeMap::new(),
            after_build_async_hooks: Vec::new(),
        }
    }

    /// Registers service `T` and queues its async `on_all_services_built` hook.
    ///
    /// # Panics
    ///
    /// Panics if `T` is already registered or if `T::on_service_created`
    /// returns `Err`.
    #[cfg(feature = "tokio")]
    pub fn register<T>(mut self) -> Self
    where
        T: RSContextService,
    {
        self.insert_service::<T>();
        self.after_build_async_hooks.push(Box::new(|ctx: Arc<RSContext>| {
            Box::pin(async move {
                // Report a missing service as an error instead of panicking inside the hook.
                let service = match ctx.call::<T>() {
                    Some(service) => service,
                    None => return Err(Self::missing_service_error::<T>()),
                };
                let guard = service.lock().await;
                guard.on_all_services_built(&ctx)
            }) as Pin<Box<dyn Future<Output = Result<(), RsServiceError>> + Send>>
        }));
        self
    }

    /// Assembles the context, then awaits every queued hook in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a hook (later hooks are not run), or
    /// an error if the context is still shared once all hooks have finished.
    #[cfg(feature = "tokio")]
    pub async fn build(self) -> Result<RSContext, RsServiceError> {
        let context = Arc::new(RSContext {
            service_map: self.pending_services,
        });
        for hook in self.after_build_async_hooks {
            // Propagate hook failures, matching the synchronous `build`.
            hook(Arc::clone(&context)).await?;
        }
        // Each hook future is dropped once awaited, so `context` should be the
        // last owner here.
        Arc::try_unwrap(context).map_err(|_| {
            RsServiceError("Failed to unwrap Arc<RSContext> in build()".to_string())
        })
    }

    /// Creates an empty builder with no services and no queued hooks.
    #[cfg(not(feature = "tokio"))]
    pub fn new() -> Self {
        RSContextBuilder {
            pending_services: BTreeMap::new(),
            after_build_hooks: Vec::new(),
        }
    }

    /// Registers service `T` and queues its `on_all_services_built` hook.
    ///
    /// # Panics
    ///
    /// Panics if `T` is already registered or if `T::on_service_created`
    /// returns `Err`.
    #[cfg(not(feature = "tokio"))]
    pub fn register<T>(mut self) -> Self
    where
        T: RSContextService,
    {
        self.insert_service::<T>();
        self.after_build_hooks.push(Box::new(|ctx: &RSContext| {
            let service = ctx
                .call::<T>()
                .ok_or_else(Self::missing_service_error::<T>)?;
            let guard = service
                .lock()
                .map_err(|_| RsServiceError("Mutex poisoned".to_string()))?;
            guard.on_all_services_built(ctx)
        }));
        self
    }

    /// Assembles the context, then runs every queued hook in registration order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a hook; later hooks are not run.
    #[cfg(not(feature = "tokio"))]
    pub fn build(self) -> Result<RSContext, RsServiceError> {
        let context = RSContext {
            service_map: self.pending_services,
        };
        for hook in self.after_build_hooks {
            hook(&context)?;
        }
        Ok(context)
    }

    /// Instantiates `T`, runs its `on_service_created` hook against the current
    /// builder, and stores it as `Arc<Mutex<T>>` keyed by `TypeId::of::<T>()`.
    /// Shared by both `register` variants.
    ///
    /// # Panics
    ///
    /// Panics if `T` is already registered or if `on_service_created` fails.
    fn insert_service<T>(&mut self)
    where
        T: RSContextService,
    {
        let type_id = TypeId::of::<T>();
        if self.pending_services.contains_key(&type_id) {
            panic!("Service type {:?} already registered.", std::any::type_name::<T>());
        }
        let mut instance = T::on_register_crate_instance();
        instance
            .on_service_created(self)
            .map_err(|e| {
                RsServiceError(format!(
                    "on_service_created hook failed for {}: {}",
                    std::any::type_name::<T>(),
                    e
                ))
            })
            .expect("on_service_created hook failed");
        let service: Arc<Mutex<T>> = Arc::new(Mutex::new(instance));
        self.pending_services
            .insert(type_id, Box::new(service) as ContainerStruct);
    }

    /// Error reported by a post-build hook when its service is absent from the context.
    fn missing_service_error<T>() -> RsServiceError {
        RsServiceError(format!(
            "Service {} not found in built context",
            std::any::type_name::<T>()
        ))
    }
}
/// The assembled service registry produced by `RSContextBuilder::build`.
///
/// Services are looked up by type with [`RSContext::call`].
pub struct RSContext {
    /// Each service stored type-erased as `Arc<Mutex<T>>`, keyed by `TypeId::of::<T>()`.
    service_map: MapForContainer,
}
impl RSContext {
    /// Looks up the registered service of type `T`.
    ///
    /// Returns a fresh handle to the shared `Arc<Mutex<T>>`, or `None` when no
    /// entry exists for `T` (or the stored value is not an `Arc<Mutex<T>>`).
    pub fn call<T>(&self) -> Option<Arc<Mutex<T>>>
    where
        T: RSContextService,
    {
        let erased = self.service_map.get(&TypeId::of::<T>())?;
        erased.downcast_ref::<Arc<Mutex<T>>>().map(Arc::clone)
    }
}