mod error;
use crate::error::{AppContextDroppedError, BeanError};
use once_cell::sync::OnceCell;
use std::any::{type_name, Any};
use std::collections::HashMap;
use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Weak};
/// Lookup key of a bean: the Rust type name of the bean paired with the name
/// it was registered under.
#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct BeanMetadata {
    /// Result of `std::any::type_name::<T>()` for the bean type.
    type_name: &'static str,
    /// Name supplied by the caller when the bean was requested or constructed.
    bean_name: &'static str,
}

impl BeanMetadata {
    /// Builds the key for a bean of type `T` registered as `name`.
    pub(crate) fn build_meta<T>(name: &'static str) -> BeanMetadata {
        let ty = type_name::<T>();
        Self {
            type_name: ty,
            bean_name: name,
        }
    }
}
/// Synchronous construction contract for a bean.
///
/// * `E` - extra data handed to [`BuildFromContext::build_from`].
/// * `CtxErr` - error type returned when construction fails (defaults to `()`).
/// * `InitErr` - error type returned by [`BuildFromContext::init_self`]
///   (defaults to `()`).
pub trait BuildFromContext<E, CtxErr = (), InitErr = ()> {
    /// Creates the bean, using `ctx` to obtain references to other beans and
    /// `extras` for any additional input.
    fn build_from(ctx: &AppContextBuilder, extras: E) -> Result<Self, CtxErr>
    where
        Self: Sized;

    /// Post-construction hook; the provided implementation does nothing and
    /// succeeds.
    fn init_self(&self) -> Result<(), InitErr> {
        Ok(())
    }
}
/// Asynchronous counterpart of `BuildFromContext`.
///
/// * `E` - extra data handed to [`BuildFromContextAsync::build_from`].
/// * `CtxErr` - error type returned when construction fails (defaults to `()`).
/// * `InitErr` - error type returned by [`BuildFromContextAsync::init_self`]
///   (defaults to `()`).
#[async_trait::async_trait]
pub trait BuildFromContextAsync<E, CtxErr = (), InitErr = ()> {
    /// Creates the bean, using `ctx` to obtain references to other beans and
    /// `extras` for any additional input.
    async fn build_from(ctx: &AppContextBuilder, extras: E) -> Result<Self, CtxErr>
    where
        Self: Sized;

    /// Post-construction hook; the provided implementation does nothing and
    /// succeeds.
    async fn init_self(&self) -> Result<(), InitErr> {
        Ok(())
    }
}
/// Zero-sized marker carrying a bean type `T`, so the type can be passed as
/// an ordinary argument.
pub struct BeanType<T>(PhantomData<T>);

/// Exposes the [`BeanType`] marker of `T` as an associated constant.
pub trait BeanTypeOf<T> {
    /// Marker value for `T`.
    const BEAN_TYPE: BeanType<T>;
}

/// Every type provides its own marker, e.g. `MyService::BEAN_TYPE`.
impl<T> BeanTypeOf<T> for T {
    const BEAN_TYPE: BeanType<T> = BeanType(PhantomData::<T>);
}
pub struct RefWrapper<T>(Arc<OnceCell<T>>);
impl<T> Deref for RefWrapper<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.0.get().expect("bean is not initialized properly")
}
}
/// Weak handle to a bean cell. It does not keep the bean alive; it must be
/// upgraded through [`BeanRef::try_acquire`] or [`BeanRef::acquire`] before use.
pub struct BeanRef<T> {
    inner: Weak<OnceCell<T>>,
}

impl<T> BeanRef<T> {
    /// Upgrades the weak handle into a [`RefWrapper`].
    ///
    /// # Errors
    ///
    /// Returns `AppContextDroppedError` when no strong reference to the cell
    /// remains.
    pub fn try_acquire(&self) -> Result<RefWrapper<T>, AppContextDroppedError> {
        match self.inner.upgrade() {
            Some(cell) => Ok(RefWrapper(cell)),
            None => Err(AppContextDroppedError),
        }
    }

    /// Like [`BeanRef::try_acquire`], but panics instead of returning an error.
    ///
    /// # Panics
    ///
    /// Panics when the weak handle can no longer be upgraded.
    pub fn acquire(&self) -> RefWrapper<T> {
        let acquired = self.try_acquire();
        acquired.expect("app context is dropped, all beans are not acquirable")
    }

    /// Reports whether the handle can currently be upgraded.
    pub fn is_active(&self) -> bool {
        self.inner.upgrade().is_some()
    }
}
/// Type-erased slot for one bean inside the context.
pub struct BeanWrapper {
    /// The bean cell (a `OnceCell<T>`, see `BeanWrapper::wrap`) erased to `dyn Any`.
    bean: Arc<dyn Any + Send + Sync>,
    /// Whether the cell has been filled, as tracked by this wrapper.
    initialized: AtomicBool,
    /// Key under which the bean is registered.
    meta: BeanMetadata,
}

impl Clone for BeanWrapper {
    /// Shares the same bean cell with the clone.
    ///
    /// The clone receives its own `initialized` flag seeded with the current
    /// value, so a later `set_inner` through one copy does not update the flag
    /// of the other copy.
    fn clone(&self) -> Self {
        let flag_snapshot = self.initialized.load(Ordering::Acquire);
        Self {
            bean: Arc::clone(&self.bean),
            initialized: AtomicBool::new(flag_snapshot),
            meta: self.meta.clone(),
        }
    }
}
impl BeanWrapper {
pub(crate) fn wrap<T>(bean: OnceCell<T>, meta: BeanMetadata) -> Self
where
T: Send + Sync + 'static,
{
Self {
initialized: AtomicBool::new(bean.get().is_some()),
bean: Arc::new(bean),
meta,
}
}
pub(crate) fn initialized(&self) -> bool {
self.initialized.load(Ordering::Acquire)
}
pub(crate) fn build_bean_ref<T>(&self) -> BeanRef<T>
where
T: Send + Sync + 'static,
{
let weak_arc = self
.bean
.clone()
.downcast::<OnceCell<T>>()
.ok()
.map(|c| Arc::downgrade(&c))
.expect("bean type is not matched");
BeanRef { inner: weak_arc }
}
pub(crate) fn set_inner<T>(&self, bean: T)
where
T: Send + Sync + 'static,
{
self.bean
.clone()
.downcast::<OnceCell<T>>()
.ok()
.map(|c| c.set(bean).ok())
.flatten()
.expect("bean is setted before");
self.initialized.store(true, Ordering::Release);
}
}
/// Storage shared by the builder and the finished context: every bean slot,
/// keyed by its metadata.
pub struct AppContextInner {
bean_map: HashMap<BeanMetadata, BeanWrapper>,
}
/// Builder that collects beans and their init functions before producing an
/// `AppContext`.
pub struct AppContextBuilder {
// Behind a mutex because `acquire_bean_or_init` inserts placeholder slots
// through `&self`.
inner: Mutex<AppContextInner>,
// One init function per bean registered via `construct_bean` /
// `construct_bean_async`.
init_fn_map: HashMap<BeanMetadata, InitFnEnum>,
}
impl AppContextBuilder {
pub fn new() -> Self {
Self {
inner: Mutex::new(AppContextInner {
bean_map: Default::default(),
}),
init_fn_map: Default::default(),
}
}
pub fn acquire_bean_or_init<T>(&self, _ty: BeanType<T>, name: &'static str) -> BeanRef<T>
where
T: Send + Sync + 'static,
{
let meta = BeanMetadata::build_meta::<T>(name);
self.inner
.lock()
.expect("unexpected lock")
.bean_map
.entry(meta.clone())
.or_insert(BeanWrapper::wrap(OnceCell::<T>::new(), meta))
.build_bean_ref()
}
pub fn construct_bean<T, E, Err, Err2>(
mut self,
_ty: BeanType<T>,
name: &'static str,
extras: E,
) -> Result<Self, Err>
where
T: Send + Sync + BuildFromContext<E, Err, Err2> + 'static,
Err2: Send + Sync + 'static,
{
let meta = BeanMetadata::build_meta::<T>(name);
let bean = T::build_from(&self, extras)?;
self.inner
.lock()
.expect("unexpected lock")
.bean_map
.entry(meta.clone())
.or_insert(BeanWrapper::wrap(OnceCell::<T>::new(), meta.clone()))
.set_inner(bean);
self.init_fn_map
.insert(meta.clone(), build_init_fn::<E, Err, Err2, T>());
Ok(self)
}
pub async fn construct_bean_async<T, E, Err, Err2>(
mut self,
_ty: BeanType<T>,
name: &'static str,
extras: E,
) -> Result<Self, Err>
where
T: Send + Sync + BuildFromContextAsync<E, Err, Err2> + 'static,
Err2: Send + Sync + 'static,
{
let meta = BeanMetadata::build_meta::<T>(name);
let bean = T::build_from(&self, extras).await?;
self.inner
.lock()
.expect("unexpected lock")
.bean_map
.entry(meta.clone())
.or_insert(BeanWrapper::wrap(OnceCell::<T>::new(), meta.clone()))
.set_inner(bean);
self.init_fn_map
.insert(meta.clone(), build_init_fn_async::<E, Err, Err2, T>());
Ok(self)
}
pub fn build_without_init(self) -> Result<AppContext, BeanError> {
if let Some((uninit_meta, _)) = self
.inner
.lock()
.expect("unexpected lock")
.bean_map
.iter()
.find(|(meta, bean)| !bean.initialized())
{
return Err(BeanError::NotInitialized(uninit_meta.clone()));
}
Ok(AppContext {
inner: Arc::new(self.inner.into_inner().expect("unexpected lock")),
})
}
pub fn build_non_async(self) -> Result<AppContext, BeanError> {
{
let wrapper = self.inner.lock().expect("unexpected lock");
if let Some((uninit_meta, _)) = wrapper
.bean_map
.iter()
.find(|(meta, bean)| !bean.initialized())
{
return Err(BeanError::NotInitialized(uninit_meta.clone()));
}
for (k, v) in self.init_fn_map {
let wrapper_cloned = wrapper
.bean_map
.get(&k)
.expect("unexpected meta key error")
.clone();
if let InitFnEnum::Sync(f) = v {
f(wrapper_cloned)?;
} else {
return Err(BeanError::HasAsync(k));
}
}
}
Ok(AppContext {
inner: Arc::new(self.inner.into_inner().expect("unexpected lock")),
})
}
pub async fn build_all(self) -> Result<AppContext, BeanError> {
{
let wrapper = self.inner.lock().expect("unexpected lock");
if let Some((uninit_meta, _)) = wrapper
.bean_map
.iter()
.find(|(meta, bean)| !bean.initialized())
{
return Err(BeanError::NotInitialized(uninit_meta.clone()));
}
for (k, v) in self.init_fn_map {
let wrapper_cloned = wrapper
.bean_map
.get(&k)
.expect("unexpected meta key error")
.clone();
match v {
InitFnEnum::Sync(f) => {
f(wrapper_cloned)?;
}
InitFnEnum::Async(fut) => {
fut(wrapper_cloned).await?;
}
}
}
}
Ok(AppContext {
inner: Arc::new(self.inner.into_inner().expect("unexpected lock")),
})
}
}
/// Boxed synchronous init function: receives the bean's wrapper and reports
/// failure as a `BeanError`.
pub(crate) type InitFn = Box<dyn Fn(BeanWrapper) -> Result<(), BeanError>>;
/// Boxed asynchronous init function: receives the bean's wrapper and returns a
/// boxed future resolving to the init result.
pub(crate) type InitFnAsync =
Box<dyn Fn(BeanWrapper) -> Pin<Box<dyn Future<Output = Result<(), BeanError>>>>>;
/// Init function stored per bean by the builder, tagged as sync or async.
pub(crate) enum InitFnEnum {
// Produced by `build_init_fn`.
Sync(InitFn),
// Produced by `build_init_fn_async`.
Async(InitFnAsync),
}
pub(crate) fn build_init_fn<Props, E1, E2, T>() -> InitFnEnum
where
E2: Send + Sync + 'static,
T: BuildFromContext<Props, E1, E2> + Send + Sync + 'static,
{
InitFnEnum::Sync(Box::new(|wrap| {
let r = wrap.build_bean_ref::<T>().acquire();
match r.init_self() {
Ok(_) => Ok(()),
Err(e) => Err(BeanError::DuringInit(wrap.meta.clone(), Box::new(e))),
}
}))
}
pub(crate) fn build_init_fn_async<Props, E1, E2, T>() -> InitFnEnum
where
E2: Send + Sync + 'static,
T: BuildFromContextAsync<Props, E1, E2> + Send + Sync + 'static,
{
InitFnEnum::Async(Box::new(|wrap| {
Box::pin(async move {
let r = wrap.build_bean_ref::<T>().acquire();
match r.init_self().await {
Ok(_) => Ok(()),
Err(e) => Err(BeanError::DuringInit(wrap.meta.clone(), Box::new(e))),
}
})
}))
}
/// Finished, shareable bean container.
///
/// Cloning is cheap: every clone points at the same `AppContextInner`
/// through an `Arc`.
#[derive(Clone)]
pub struct AppContext {
    inner: Arc<AppContextInner>,
}
impl AppContext {
    /// Looks up the bean of type `T` registered as `name`.
    ///
    /// Returns `None` when no slot exists under that type/name key.
    ///
    /// # Panics
    ///
    /// Panics if the slot found under the key does not hold a `OnceCell<T>`.
    pub fn try_acquire_bean<T>(&self, name: &'static str) -> Option<BeanRef<T>>
    where
        T: Send + Sync + 'static,
    {
        let meta = BeanMetadata::build_meta::<T>(name);
        // Borrowing the wrapper is enough to build a reference; no clone needed.
        self.inner
            .bean_map
            .get(&meta)
            .map(|wrapper| wrapper.build_bean_ref::<T>())
    }

    /// Like [`AppContext::try_acquire_bean`], but panics when the bean is missing.
    ///
    /// # Panics
    ///
    /// Panics if no bean of type `T` is registered as `name`.
    pub fn acquire_bean<T>(&self, _ty: BeanType<T>, name: &'static str) -> BeanRef<T>
    where
        T: Send + Sync + 'static,
    {
        self.try_acquire_bean(name)
            .expect("bean is not initialized")
    }

    /// Returns references to every bean whose recorded type name equals that
    /// of `T`, regardless of bean name.
    pub fn acquire_beans_by_type<T>(&self, _ty: BeanType<T>) -> Vec<BeanRef<T>>
    where
        T: Send + Sync + 'static,
    {
        let wanted_type = type_name::<T>();
        self.inner
            .bean_map
            .iter()
            .filter(|(meta, _)| meta.type_name == wanted_type)
            .map(|(_, wrapper)| wrapper.build_bean_ref::<T>())
            .collect()
    }

    /// Returns references to the beans registered as `name` whose recorded
    /// type name equals that of `T`.
    ///
    /// Beans that share the name but have a different type are skipped:
    /// they cannot be represented as `BeanRef<T>`, and downcasting them would
    /// panic.
    pub fn acquire_beans_by_name<T>(&self, name: &'static str) -> Vec<BeanRef<T>>
    where
        T: Send + Sync + 'static,
    {
        let wanted_type = type_name::<T>();
        self.inner
            .bean_map
            .iter()
            .filter(|(meta, _)| meta.bean_name == name && meta.type_name == wanted_type)
            .map(|(_, wrapper)| wrapper.build_bean_ref::<T>())
            .collect()
    }

    /// Reports whether this handle is the only remaining strong reference to
    /// the shared context data.
    pub fn is_last_clone(&self) -> bool {
        Arc::strong_count(&self.inner) == 1
    }
}
// End-to-end test of the bean container: builds a context with mutually
// dependent beans, checks them, and verifies handle behavior after drop.
#[cfg(test)]
mod tests {
use super::*;
use anyhow::bail;
use async_trait::async_trait;
use std::time::Duration;
/// Fixture bean that holds references to `ServiceB` and `DaoC`.
pub struct ServiceA {
svc_b: BeanRef<ServiceB>,
dao: BeanRef<DaoC>,
}
impl ServiceA {
pub fn check(&self) {
println!("svc a is ready");
}
}
impl Drop for ServiceA {
fn drop(&mut self) {
println!("svc a is dropped");
}
}
impl BuildFromContext<(), ()> for ServiceA {
fn build_from(ctx: &AppContextBuilder, extras: ()) -> Result<Self, ()> {
// Requests beans "b" and "c", creating placeholder slots if they have
// not been constructed yet.
Ok(ServiceA {
svc_b: ctx.acquire_bean_or_init(ServiceB::BEAN_TYPE, "b"),
dao: ctx.acquire_bean_or_init(DaoC::BEAN_TYPE, "c"),
})
}
}
/// Fixture bean that refers back to `ServiceA`, forming a circular
/// dependency with it, and carries a config value passed as `extras`.
pub struct ServiceB {
svc_a: BeanRef<ServiceA>,
dao: BeanRef<DaoC>,
config_val: u32,
}
impl Drop for ServiceB {
fn drop(&mut self) {
println!("svc b is dropped");
}
}
impl ServiceB {
pub fn check(&self) {
println!("svc b is ready");
}
}
// Uses a custom init error type (`anyhow::Error`).
impl BuildFromContext<u32, (), anyhow::Error> for ServiceB {
fn build_from(ctx: &AppContextBuilder, extras: u32) -> Result<Self, ()> {
Ok(ServiceB {
svc_a: ctx.acquire_bean_or_init(ServiceA::BEAN_TYPE, "a"),
dao: ctx.acquire_bean_or_init(DaoC::BEAN_TYPE, "c"),
config_val: extras,
})
}
fn init_self(&self) -> Result<(), anyhow::Error> {
Ok(())
}
}
/// Fixture bean without dependencies; built from a map passed as `extras`.
pub struct DaoC {
inner_map: HashMap<String, String>,
}
impl Drop for DaoC {
fn drop(&mut self) {
println!("dao c is dropped");
}
}
impl DaoC {
pub fn check(&self) {
println!("dao c is ready");
}
}
impl BuildFromContext<HashMap<String, String>, ()> for DaoC {
fn build_from(
ctx: &AppContextBuilder,
extras: HashMap<String, String>,
) -> Result<Self, ()> {
Ok(DaoC { inner_map: extras })
}
}
/// Fixture bean that is constructed and initialized asynchronously.
pub struct DaoD {
inner_vec: Vec<i32>,
}
impl Drop for DaoD {
fn drop(&mut self) {
println!("dao d is droped");
}
}
impl DaoD {
pub async fn check(&self) {
println!("dao d is ready");
}
}
#[async_trait]
impl BuildFromContextAsync<usize, String> for DaoD {
async fn build_from(ctx: &AppContextBuilder, extras: usize) -> Result<Self, String> {
Ok(DaoD {
inner_vec: Vec::with_capacity(extras),
})
}
// Sleeps for 500 ms before succeeding.
async fn init_self(&self) -> Result<(), ()> {
tokio::time::sleep(Duration::from_millis(500)).await;
Ok(())
}
}
#[tokio::test]
async fn it_works() -> anyhow::Result<()> {
// Part 1: build a full context inside a block and let a `BeanRef` escape
// it, to verify the reference goes inactive once the context is dropped.
let svc_a = {
let ctx = AppContextBuilder::new()
.construct_bean(ServiceA::BEAN_TYPE, "a", ())
.unwrap()
.construct_bean(ServiceB::BEAN_TYPE, "b", 32)
.unwrap()
.construct_bean(DaoC::BEAN_TYPE, "c", HashMap::new())
.unwrap()
.construct_bean_async(DaoD::BEAN_TYPE, "d", 5_usize)
.await
.unwrap()
.build_all()
.await?;
let svc_a = ctx.acquire_bean(ServiceA::BEAN_TYPE, "a");
svc_a.acquire().check();
let svc_b = ctx.acquire_bean(ServiceB::BEAN_TYPE, "b");
svc_b.acquire().check();
let dao_c = ctx.acquire_bean(DaoC::BEAN_TYPE, "c");
dao_c.acquire().check();
let dao_d = ctx.acquire_bean(DaoD::BEAN_TYPE, "d");
dao_d.acquire().check().await;
// No clone of `ctx` was made, so it is the only strong handle.
assert!(ctx.is_last_clone());
svc_a
};
// `ctx` was dropped at the end of the block; the weak ref must be dead.
assert!(!svc_a.is_active());
// Part 2: only "a" is constructed. Building it created placeholder slots
// for "b" and "c" that were never filled, so the build must fail.
let ctx = AppContextBuilder::new()
.construct_bean(ServiceA::BEAN_TYPE, "a", ())
.unwrap()
.build_without_init();
assert!(ctx.is_err());
Ok(())
}
}