use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
/// Type-erased, shareable instance as stored in the DI maps.
pub type AnyProvider = Arc<dyn Any + Send + Sync>;

/// Read-only view over a map of already-available instances. It is handed to
/// each provider's `build` function so the provider can look up its dependencies.
pub struct Resolver<'a> {
    map: &'a HashMap<TypeId, AnyProvider>,
}

impl<'a> Resolver<'a> {
    /// Returns the type-erased instance stored under `ty`, if any.
    #[doc(hidden)]
    #[inline]
    pub fn get_any(&self, ty: TypeId) -> Option<&'a AnyProvider> {
        self.map.get(&ty)
    }

    /// Returns the instance of type `T`, with its lifetime extended to `'static`.
    ///
    /// # Panics
    ///
    /// Panics if the map holds no instance of `T`. The message names the
    /// missing type. This happens, for example, when a provider's `build`
    /// requests a type it did not list in its dependencies.
    #[doc(hidden)]
    #[inline]
    #[track_caller]
    pub fn get<T: Send + Sync + 'static>(&self) -> &'static T {
        let r: &'a T = match self
            .get_any(TypeId::of::<T>())
            .and_then(|a| a.downcast_ref::<T>())
        {
            Some(r) => r,
            None => panic!(
                "Arcly DI: dependency requested by a provider was not in the resolved set: `{}`",
                std::any::type_name::<T>()
            ),
        };
        // SAFETY: `r` points into the heap allocation owned by an `Arc` stored in
        // `self.map`, not into the map's own table. Rehashing the map on later
        // inserts therefore does not move it. Extending it to `'static` is sound
        // only as long as that `Arc` is never dropped.
        // As far as this file shows, a `Resolver` is only built in two places:
        // - over the in-progress map inside `DiContainerBuilder::freeze`, which
        //   only inserts new keys and is then leaked into the `FrozenDiContainer`;
        // - over an already-leaked `FrozenDiContainer`.
        // NOTE(review): if a provider's `build` panics inside `freeze` and the
        // panic is caught, the in-progress map is dropped. Any `&'static T` a
        // provider stashed elsewhere would then dangle.
        #[allow(unsafe_code)]
        unsafe {
            std::mem::transmute::<&'a T, &'static T>(r)
        }
    }
}
/// Static description of an injectable provider: which type it produces,
/// which types it needs, and how to build it.
pub struct ProviderDescriptor {
/// Human-readable provider name; used in the DI panic messages.
pub name: &'static str,
/// Returns the `TypeId` of the instance this provider builds. The built
/// instance is stored under this key.
pub type_id_fn: fn() -> TypeId,
/// Returns the `TypeId`s this provider depends on. Each must be supplied by
/// another provider or a direct registration. The list also determines the
/// build order in `DiContainerBuilder::freeze`.
pub deps_fn: fn() -> Vec<TypeId>,
/// Builds the instance. The `Resolver` exposes the instances available so far.
pub build: fn(&Resolver<'_>) -> AnyProvider,
}
/// Static description of a module. It lists the module's providers, the names
/// of its controllers and gateways, and the modules it imports.
pub struct ModuleDescriptor {
/// Module name.
pub name: &'static str,
/// Provider descriptors contributed by this module.
pub providers: &'static [&'static ProviderDescriptor],
/// Controller names.
/// NOTE(review): presumably matched against `RouteDescriptor::controller` — confirm.
pub controllers: &'static [&'static str],
/// Imported modules, each given as a function returning its descriptor.
/// NOTE(review): function pointers rather than direct references presumably
/// allow statics to refer to each other (including cycles) — confirm.
pub imports: &'static [fn() -> &'static ModuleDescriptor],
/// Gateway names. Their meaning is defined by code outside this file.
pub gateways: &'static [&'static str],
}
/// Implemented by module types to expose their static `ModuleDescriptor`.
pub trait Module: 'static {
/// Returns this module's descriptor.
fn descriptor() -> &'static ModuleDescriptor;
}
/// Mutable staging area for DI registrations. Call `freeze` to build every
/// provider and obtain the immutable `FrozenDiContainer`.
#[derive(Default)]
pub struct DiContainerBuilder {
// Provider descriptors queued via `add_provider`, in registration order.
providers: Vec<&'static ProviderDescriptor>,
// Instances registered directly via `register` / `register_override`.
direct: HashMap<TypeId, AnyProvider>,
// `std::any::type_name` of each directly registered type, used for diagnostics.
direct_names: HashMap<TypeId, &'static str>,
// Types registered with `register_override`. Providers for these types are
// skipped in `freeze`.
overridden: std::collections::HashSet<TypeId>,
}
impl DiContainerBuilder {
    /// Creates an empty builder with no providers and no direct registrations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a ready-made instance of `T`.
    ///
    /// A later `register`/`register_override` for the same `T` replaces the
    /// earlier instance. If a provider for `T` is also added and `T` was not
    /// overridden, `freeze` panics.
    #[inline]
    pub fn register<T: Send + Sync + 'static>(&mut self, v: T) -> &mut Self {
        let ty = TypeId::of::<T>();
        self.direct.insert(ty, Arc::new(v));
        self.direct_names.insert(ty, std::any::type_name::<T>());
        self
    }

    /// Registers `v` as the instance of `T`. Any provider descriptor for `T`,
    /// whether added before or after this call, is dropped at `freeze` time.
    #[inline]
    pub fn register_override<T: Send + Sync + 'static>(&mut self, v: T) -> &mut Self {
        self.overridden.insert(TypeId::of::<T>());
        self.register(v)
    }

    /// Queues a provider descriptor to be built by `freeze`.
    #[inline]
    pub fn add_provider(&mut self, d: &'static ProviderDescriptor) -> &mut Self {
        self.providers.push(d);
        self
    }

    /// Builds every queued provider in dependency order and returns the
    /// resulting container, leaked so it lives for `'static`.
    ///
    /// Providers with no dependency relationship are built in the order they
    /// were added, so the build order is reproducible across runs.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - a type is registered directly and also has a (non-overridden) provider;
    /// - two providers are added for the same type;
    /// - a provider depends on a type that has neither a provider nor a direct
    ///   registration;
    /// - the providers form a dependency cycle.
    ///
    /// Panics raised by a provider's `build` also propagate.
    pub fn freeze(self) -> &'static FrozenDiContainer {
        let Self {
            providers,
            direct,
            direct_names,
            overridden,
        } = self;

        // Injectables whose type was replaced via `register_override` are skipped.
        let providers: Vec<&'static ProviderDescriptor> = providers
            .into_iter()
            .filter(|d| !overridden.contains(&(d.type_id_fn)()))
            .collect();

        // A type may come from a direct registration or an injectable, never both.
        for d in &providers {
            let ty = (d.type_id_fn)();
            if direct.contains_key(&ty) {
                let direct_name = direct_names.get(&ty).copied().unwrap_or("<unknown>");
                panic!(
                    "Arcly DI: `{}` is registered both directly (as `{direct_name}`) and \
                     via #[Injectable] `{}`. Use `register_override` (plugins: \
                     `ArclyPluginContext::override_provider`) to replace the injectable, \
                     or drop one of the registrations.",
                    d.name, d.name
                );
            }
        }

        let order = Self::topo_sort(&providers, &direct);

        // Start from the direct registrations. Each built instance is added right
        // away so that providers built later can resolve it.
        let mut map: HashMap<TypeId, AnyProvider> = direct;
        for d in order {
            let instance = (d.build)(&Resolver { map: &map });
            map.insert((d.type_id_fn)(), instance);
        }
        // Leaked on purpose: `Resolver::get` hands out `&'static` references into
        // these instances, so the map must never be dropped.
        Box::leak(Box::new(FrozenDiContainer { map }))
    }

    /// Orders `providers` so that each one comes after every provider it
    /// depends on. This is Kahn's algorithm.
    ///
    /// Dependencies satisfied by `direct` registrations impose no ordering.
    /// The ready queue is seeded in registration order, which makes the
    /// result deterministic.
    ///
    /// # Panics
    ///
    /// Panics on:
    /// - duplicate providers for one type;
    /// - a dependency with no provider and no direct registration;
    /// - a dependency cycle. The panic lists the providers left unresolved.
    fn topo_sort(
        providers: &[&'static ProviderDescriptor],
        direct: &HashMap<TypeId, AnyProvider>,
    ) -> Vec<&'static ProviderDescriptor> {
        use std::collections::{HashSet, VecDeque};

        let mut by_ty: HashMap<TypeId, &'static ProviderDescriptor> =
            HashMap::with_capacity(providers.len());
        for p in providers {
            if by_ty.insert((p.type_id_fn)(), *p).is_some() {
                panic!("Arcly DI: provider for `{}` registered twice", p.name);
            }
        }

        // `pending[ty]` counts the provider-built dependencies of `ty` that are
        // not yet ordered. `dependents[dep]` lists the providers waiting on `dep`.
        // A dependency listed twice is counted and recorded twice, which keeps
        // the two structures consistent.
        let mut pending: HashMap<TypeId, usize> = HashMap::with_capacity(providers.len());
        let mut dependents: HashMap<TypeId, Vec<TypeId>> = HashMap::new();
        for p in providers {
            let ty = (p.type_id_fn)();
            let mut count = 0usize;
            for dep in (p.deps_fn)() {
                if by_ty.contains_key(&dep) {
                    count += 1;
                    dependents.entry(dep).or_default().push(ty);
                } else if !direct.contains_key(&dep) {
                    panic!(
                        "Arcly DI: provider `{}` depends on type `{:?}` which has no registered provider",
                        p.name, dep
                    );
                }
            }
            pending.insert(ty, count);
        }

        // Seed from `providers` rather than from iterating `pending`. HashMap
        // iteration order varies between runs, and seeding from it would make
        // the build order nondeterministic.
        let mut ready: VecDeque<TypeId> = providers
            .iter()
            .map(|p| (p.type_id_fn)())
            .filter(|ty| pending[ty] == 0)
            .collect();

        let mut order: Vec<&'static ProviderDescriptor> = Vec::with_capacity(providers.len());
        while let Some(ty) = ready.pop_front() {
            order.push(by_ty[&ty]);
            if let Some(children) = dependents.get(&ty) {
                for child in children {
                    let left = pending
                        .get_mut(child)
                        .expect("every dependent is a registered provider");
                    *left -= 1;
                    if *left == 0 {
                        ready.push_back(*child);
                    }
                }
            }
        }

        // Anything never reaching zero pending dependencies is part of, or
        // downstream of, a cycle.
        if order.len() != providers.len() {
            let resolved: HashSet<TypeId> = order.iter().map(|p| (p.type_id_fn)()).collect();
            let unresolved: Vec<&'static str> = providers
                .iter()
                .filter(|p| !resolved.contains(&(p.type_id_fn)()))
                .map(|p| p.name)
                .collect();
            panic!("Arcly DI: dependency cycle detected. Unresolved providers: {unresolved:?}");
        }
        order
    }
}
/// Immutable DI container. In this file it is only constructed, and leaked for
/// `'static`, by `DiContainerBuilder::freeze`.
pub struct FrozenDiContainer {
// Every instance keyed by its `TypeId`: direct registrations plus built providers.
map: HashMap<TypeId, AnyProvider>,
}
impl FrozenDiContainer {
    /// Returns the instance held for `T`.
    ///
    /// # Panics
    ///
    /// Panics if the container holds no instance of `T`. The message names the
    /// missing type and the reported location is the caller's. Use
    /// [`FrozenDiContainer::try_get`] for a non-panicking lookup.
    #[inline]
    #[track_caller]
    pub fn get<T: Send + Sync + 'static>(&self) -> &T {
        match self.try_get::<T>() {
            Some(v) => v,
            None => panic!(
                "Arcly DI: missing provider `{}`",
                std::any::type_name::<T>()
            ),
        }
    }

    /// Returns the instance held for `T`, or `None` if there is none.
    #[inline]
    pub fn try_get<T: Send + Sync + 'static>(&self) -> Option<&T> {
        self.map
            .get(&TypeId::of::<T>())
            .and_then(|a| a.downcast_ref::<T>())
    }

    /// Returns a [`Resolver`] over this container's instances.
    #[doc(hidden)]
    #[inline]
    pub fn resolver(&self) -> Resolver<'_> {
        Resolver { map: &self.map }
    }
}
/// HTTP methods a route can be registered for. Converts into
/// `axum::http::Method` via `From`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum HttpMethod {
GET,
POST,
PUT,
DELETE,
PATCH,
}
impl From<HttpMethod> for axum::http::Method {
    /// Maps each framework verb onto the `http` method constant of the same name.
    #[inline]
    fn from(verb: HttpMethod) -> Self {
        match verb {
            HttpMethod::GET => Self::GET,
            HttpMethod::POST => Self::POST,
            HttpMethod::PUT => Self::PUT,
            HttpMethod::DELETE => Self::DELETE,
            HttpMethod::PATCH => Self::PATCH,
        }
    }
}
/// Where a route parameter is taken from, as named by the variants: the path,
/// the query string, or a header.
#[derive(Clone, Copy, Debug)]
pub enum ParamLoc {
Path,
Query,
Header,
}
/// Declaration of a single route parameter.
pub struct ParamSpec {
/// Parameter name.
pub name: &'static str,
/// Where the parameter is read from.
pub loc: ParamLoc,
/// Whether the parameter must be present.
pub required: bool,
/// Returns a JSON value describing the parameter.
/// NOTE(review): presumably a JSON Schema used for API docs — confirm.
pub schema: fn() -> serde_json::Value,
}
/// Static per-route metadata referenced by `RouteDescriptor::spec`.
///
/// Nothing in this file reads these fields. The per-field notes below are
/// based on names and types and should be confirmed against the consumers.
/// NOTE(review): several `&'static str` fields appear to use `""`, and several
/// integer fields `0`, to mean "not set" — confirm.
pub struct RouteSpec {
/// Short summary text (presumably for generated API docs).
pub summary: &'static str,
/// Longer description text (presumably for generated API docs).
pub description: &'static str,
/// Operation identifier (presumably an OpenAPI `operationId`).
pub operation_id: &'static str,
/// Tag names attached to the route.
pub tags: &'static [&'static str],
/// Security requirement names (presumably security scheme names).
pub security: &'static [&'static str],
/// Explicit response status code, if one was declared.
pub status_code: Option<u16>,
/// Whether the route is marked deprecated.
pub deprecated: bool,
/// Declared parameters of the route.
pub params: &'static [ParamSpec],
/// Whether the route takes a request body.
pub has_body: bool,
/// Returns a JSON description of the request body, if available.
pub body_schema: Option<fn() -> serde_json::Value>,
/// Returns a JSON description of the query string, if available.
pub query_schema: Option<fn() -> serde_json::Value>,
/// Returns a JSON description of the response, if available.
pub response_schema: Option<fn() -> serde_json::Value>,
/// Response-cache TTL in seconds (unit per the field name).
pub cache_ttl_secs: u64,
/// Response-cache key. Its format is defined outside this file.
pub cache_key: &'static str,
/// API version string for the route.
pub api_version: &'static str,
/// Sunset value for the route.
/// NOTE(review): presumably a date for a `Sunset` header — confirm.
pub sunset: &'static str,
/// Idempotency window in seconds (unit per the field name).
pub idempotent_ttl_secs: u64,
/// Names of policies applied to the route (presumably authorization policies).
pub policies: &'static [&'static str],
/// Audit action name.
pub audit_action: &'static str,
/// Audit resource name.
pub audit_resource: &'static str,
/// Timeout in milliseconds (unit per the field name).
pub timeout_ms: u64,
/// Transactional flag (presumably wraps the handler in a transaction — confirm).
pub transactional: bool,
/// Field names to mask (presumably in logged or audited payloads — confirm).
pub mask_fields: &'static [&'static str],
}
/// A statically registered route: its HTTP method, path, type-erased handler,
/// and metadata.
pub struct RouteDescriptor {
/// HTTP method the route responds to.
pub method: HttpMethod,
/// Route path. Its pattern syntax is defined by the router outside this file.
pub path: &'static str,
/// Handler: takes the request context and returns a boxed `'static` future
/// that resolves to an axum response.
pub handler: fn(
crate::web::context::RequestContext,
) -> futures::future::BoxFuture<'static, axum::response::Response>,
/// Per-route metadata.
pub spec: &'static RouteSpec,
/// Name of the controller that owns this route.
pub controller: &'static str,
}
// Declare `inventory` collections of module and route descriptors.
// NOTE(review): entries are presumably added by generated code via
// `inventory::submit!` and enumerated at startup — confirm.
inventory::collect!(&'static ModuleDescriptor);
inventory::collect!(&'static RouteDescriptor);