use std::any::{Any, TypeId, type_name};
use std::collections::HashMap;
use std::sync::Arc;
use klauthed_macros::DomainError;
use crate::config::{Config, FromConfig};
use crate::error::ConfigError;
/// Errors raised while resolving components out of an [`AppContext`].
///
/// NOTE(review): the `domain(...)` attributes are consumed by the `DomainError` derive from
/// `klauthed_macros`; presumably they supply the error-code prefix/category used when this
/// error is reported — confirm against that macro's documentation.
#[derive(Debug, DomainError)]
#[domain(prefix = "wiring", category = "internal")]
#[non_exhaustive]
pub enum WiringError {
    /// A component was requested (e.g. via `AppContext::require`) but nothing of that type is
    /// registered. Carries `std::any::type_name` of the requested type for diagnostics.
    #[domain(category = "internal", code = "missing_component")]
    MissingComponent(&'static str),
}
impl std::fmt::Display for WiringError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WiringError::MissingComponent(ty) => {
write!(f, "no component of type `{ty}` is registered")
}
}
}
}
impl std::error::Error for WiringError {}
/// A type-keyed registry of shared application components.
///
/// Each concrete type `T` maps to at most one component. Components are stored behind an
/// `Arc`, so every lookup hands out a cheap clone that points at the same instance.
#[derive(Debug, Default)]
pub struct AppContext {
    /// Components keyed by the `TypeId` of their concrete type and type-erased for storage.
    components: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl AppContext {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn register<T: Any + Send + Sync>(&mut self, component: T) -> &mut Self {
self.register_arc(Arc::new(component))
}
pub fn register_arc<T: Any + Send + Sync>(&mut self, component: Arc<T>) -> &mut Self {
self.components.insert(TypeId::of::<T>(), component);
self
}
#[must_use]
pub fn with<T: Any + Send + Sync>(mut self, component: T) -> Self {
self.register(component);
self
}
pub fn register_from_config<T>(&mut self, config: &Config) -> Result<&mut Self, ConfigError>
where
T: FromConfig + Any + Send + Sync,
{
let component = T::from_config(config)?;
Ok(self.register(component))
}
#[must_use]
pub fn get<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
self.components.get(&TypeId::of::<T>()).and_then(|c| Arc::clone(c).downcast::<T>().ok())
}
pub fn require<T: Any + Send + Sync>(&self) -> Result<Arc<T>, WiringError> {
self.get::<T>().ok_or(WiringError::MissingComponent(type_name::<T>()))
}
#[must_use]
pub fn contains<T: Any + Send + Sync>(&self) -> bool {
self.components.contains_key(&TypeId::of::<T>())
}
#[must_use]
pub fn len(&self) -> usize {
self.components.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.components.is_empty()
}
}
/// A unit of application wiring that [`AppBuilder::build`] runs.
///
/// The implementations in this file read what they need from the [`Config`] and register
/// components into the [`AppContext`].
pub trait Starter {
    /// Identifier for this starter. `AppBuilder::build` logs it at debug level right before
    /// calling [`Starter::configure`].
    fn name(&self) -> &str;
    /// Expected to register this starter's components into `ctx` based on `config`.
    ///
    /// # Errors
    ///
    /// An error returned here aborts `AppBuilder::build`, which propagates it unchanged and
    /// does not run any later starters.
    fn configure(&self, config: &Config, ctx: &mut AppContext) -> Result<(), ConfigError>;
}
/// Assembles an [`AppContext`] from a [`Config`] by running a sequence of [`Starter`]s.
pub struct AppBuilder {
    /// Configuration passed to every starter and also registered into the built context.
    config: Config,
    /// Starters in insertion order; `build` runs them in this order.
    starters: Vec<Box<dyn Starter>>,
}
impl AppBuilder {
#[must_use]
pub fn new(config: Config) -> Self {
Self { config, starters: Vec::new() }
}
#[must_use]
pub fn with_starter<S: Starter + 'static>(mut self, starter: S) -> Self {
self.starters.push(Box::new(starter));
self
}
#[must_use]
pub fn with_boxed_starter(mut self, starter: Box<dyn Starter>) -> Self {
self.starters.push(starter);
self
}
pub fn build(self) -> Result<AppContext, ConfigError> {
let mut ctx = AppContext::new();
ctx.register(self.config.clone());
for starter in &self.starters {
tracing::debug!(starter = starter.name(), "running config starter");
starter.configure(&self.config, &mut ctx)?;
}
Ok(ctx)
}
}
/// Starter that registers each well-known configuration section present in the [`Config`]
/// (database, cache, messaging, storage, server) as its own component.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConfigSectionsStarter;
impl Starter for ConfigSectionsStarter {
    fn name(&self) -> &str {
        "config-sections"
    }

    /// Registers every well-known configuration section that is present in `config`.
    ///
    /// Sections are probed in a fixed order (database, cache, messaging, storage, server).
    /// Absent sections are skipped; the first error from `Config::get_optional` is returned
    /// immediately and the remaining sections are not probed.
    fn configure(&self, config: &Config, ctx: &mut AppContext) -> Result<(), ConfigError> {
        use crate::config::keys;
        use crate::config::{
            CacheConfig, DatabaseConfig, MessagingConfig, ServerConfig, StorageConfig,
        };

        // Every section has a different type, so a local macro (rather than a loop) keeps the
        // probe-then-register step in one place without restating `get_optional`'s bounds.
        macro_rules! register_section {
            ($ty:ty, $key:expr) => {
                if let Some(section) = config.get_optional::<$ty>($key)? {
                    ctx.register(section);
                }
            };
        }

        register_section!(DatabaseConfig, keys::DATABASE);
        register_section!(CacheConfig, keys::CACHE);
        register_section!(MessagingConfig, keys::MESSAGING);
        register_section!(StorageConfig, keys::STORAGE);
        register_section!(ServerConfig, keys::SERVER);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::provider::MemoryProvider;
    use crate::config::{ConfigBuilder, Profile};
    use serde::Deserialize;
    use serde_json::json;

    // Plain component types used as registry keys in the tests below.
    struct Db {
        url: String,
    }

    struct Cache {
        entries: u32,
    }

    #[test]
    fn register_and_resolve_by_type() {
        let mut ctx = AppContext::new();
        ctx.register(Db { url: "u".into() }).register(Cache { entries: 10 });
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.require::<Db>().unwrap().url, "u");
        assert_eq!(ctx.get::<Cache>().unwrap().entries, 10);
    }

    #[test]
    fn missing_component_errors() {
        let ctx = AppContext::new();
        assert!(!ctx.contains::<Db>());
        assert!(matches!(ctx.require::<Db>(), Err(WiringError::MissingComponent(_))));
    }

    #[test]
    fn missing_component_message_names_the_type() {
        // `Db` is not `Debug`, so `unwrap_err` is unavailable; match explicitly instead.
        let err = match AppContext::new().require::<Db>() {
            Ok(_) => panic!("expected `require` to fail on an empty context"),
            Err(err) => err,
        };
        let message = err.to_string();
        assert!(message.starts_with("no component of type `"), "{message}");
        assert!(message.contains("Db"), "{message}");
        assert!(message.ends_with("` is registered"), "{message}");
    }

    #[test]
    fn registering_same_type_replaces() {
        let mut ctx = AppContext::new();
        ctx.register(Db { url: "first".into() });
        ctx.register(Db { url: "second".into() });
        assert_eq!(ctx.len(), 1);
        assert_eq!(ctx.require::<Db>().unwrap().url, "second");
    }

    #[test]
    fn shared_arcs_are_cheap_clones() {
        let mut ctx = AppContext::new();
        ctx.register(Db { url: "u".into() });
        let a = ctx.require::<Db>().unwrap();
        let b = ctx.require::<Db>().unwrap();
        assert!(Arc::ptr_eq(&a, &b));
    }

    #[test]
    fn register_arc_shares_the_callers_instance() {
        let shared = Arc::new(Cache { entries: 3 });
        let mut ctx = AppContext::new();
        ctx.register_arc(Arc::clone(&shared));
        let resolved = ctx.require::<Cache>().unwrap();
        assert!(Arc::ptr_eq(&shared, &resolved));
        assert_eq!(resolved.entries, 3);
    }

    #[test]
    fn register_keys_by_the_exact_type_passed() {
        // Handing an `Arc<T>` to `register` stores it under `Arc<T>`, not under `T`;
        // `register_arc` is the way to register a shared `T`.
        let mut ctx = AppContext::new();
        ctx.register(Arc::new(Db { url: "wrapped".into() }));
        assert!(!ctx.contains::<Db>());
        assert!(ctx.contains::<Arc<Db>>());
    }

    #[test]
    fn with_builds_a_context_by_value() {
        assert!(AppContext::new().is_empty());
        let ctx = AppContext::new().with(Db { url: "v".into() }).with(Cache { entries: 1 });
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 2);
        assert_eq!(ctx.require::<Cache>().unwrap().entries, 1);
    }

    #[derive(Deserialize, FromConfig)]
    #[config(key = "database")]
    struct DatabaseSettings {
        host: String,
    }

    #[tokio::test]
    async fn register_from_config_binds_and_registers() {
        let config = ConfigBuilder::new(Profile::Test)
            .with_provider(MemoryProvider::new().set("database", json!({ "host": "db.internal" })))
            .build()
            .await
            .unwrap();
        let mut ctx = AppContext::new();
        ctx.register_from_config::<DatabaseSettings>(&config).unwrap();
        assert_eq!(ctx.require::<DatabaseSettings>().unwrap().host, "db.internal");
    }

    #[tokio::test]
    async fn app_builder_runs_config_sections_starter() {
        use crate::config::{DatabaseConfig, ServerConfig};
        let config = ConfigBuilder::new(Profile::Test)
            .with_provider(MemoryProvider::new().set("server", json!({ "port": 9000 })))
            .build()
            .await
            .unwrap();
        let ctx = AppBuilder::new(config).with_starter(ConfigSectionsStarter).build().unwrap();
        assert!(ctx.contains::<Config>());
        assert_eq!(ctx.require::<ServerConfig>().unwrap().port, 9000);
        assert!(!ctx.contains::<DatabaseConfig>());
    }

    #[tokio::test]
    async fn app_builder_runs_a_custom_starter() {
        struct DbStarter;
        impl Starter for DbStarter {
            fn name(&self) -> &str {
                "db"
            }
            fn configure(&self, _config: &Config, ctx: &mut AppContext) -> Result<(), ConfigError> {
                ctx.register(Db { url: "wired".into() });
                Ok(())
            }
        }
        let config = ConfigBuilder::new(Profile::Test)
            .with_provider(MemoryProvider::new().set("x", json!(1)))
            .build()
            .await
            .unwrap();
        let ctx = AppBuilder::new(config).with_starter(DbStarter).build().unwrap();
        assert_eq!(ctx.require::<Db>().unwrap().url, "wired");
    }
}