use std::{any::TypeId, collections::BTreeMap, fmt::Debug};
use inventory::{Collect, Registry};
use thiserror::Error;
/// Produces fresh, heap-allocated instances of a (possibly unsized) product type.
///
/// `Product` may be `?Sized` so that it can be a trait object such as
/// `dyn MyTrait`; products are therefore always handed back boxed.
pub trait Factory<Product: ?Sized> {
    /// Builds and returns a new product instance.
    fn create(&self) -> Box<Product>;
}
#[derive(Debug, Error)]
pub enum FactoryError {
#[error("factory with ID '{0}' not found")]
FactoryNotFound(String),
#[error("empty ID provided without fallback")]
EmptyIdNoFallback,
#[error("no factories available")]
NoFactoriesAvailable,
}
/// How `SimpleFactory::create` chooses a factory when it is given an empty ID.
///
/// A non-empty ID is always looked up exactly; this strategy is consulted only
/// for the empty ID.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FactoryFallback {
    /// Use the factory whose ID sorts first.
    First,
    /// Use the factory whose ID sorts last.
    Last,
    /// Do not fall back: an empty ID is reported as an error.
    NoFallback,
}
pub struct SimpleFactory<T: ?Sized + 'static>(
BTreeMap<&'static str, &'static (dyn Factory<T> + Sync)>,
);
impl<T> SimpleFactory<T>
where
    T: ?Sized + 'static,
{
    /// Builds a product with the factory registered under `id`.
    ///
    /// A non-empty `id` is looked up exactly, and `strategy` is ignored in that
    /// case. An empty `id` selects a factory according to `strategy`:
    /// `First`/`Last` pick the factory with the smallest/largest ID, and
    /// `NoFallback` refuses.
    ///
    /// # Errors
    ///
    /// * [`FactoryError::FactoryNotFound`] — `id` is non-empty and not registered.
    /// * [`FactoryError::EmptyIdNoFallback`] — `id` is empty and `strategy` is
    ///   [`FactoryFallback::NoFallback`].
    /// * [`FactoryError::NoFactoriesAvailable`] — `id` is empty, a fallback was
    ///   requested, but the table is empty.
    pub fn create(
        &self,
        id: impl AsRef<str>,
        strategy: FactoryFallback,
    ) -> Result<Box<T>, FactoryError> {
        let key = id.as_ref();
        let selected = if key.is_empty() {
            // No explicit ID: consult the fallback strategy. A BTreeMap iterates
            // in key order, so the front/back values belong to the first/last IDs.
            let mut candidates = self.0.values();
            let picked = match strategy {
                FactoryFallback::First => candidates.next(),
                FactoryFallback::Last => candidates.next_back(),
                FactoryFallback::NoFallback => return Err(FactoryError::EmptyIdNoFallback),
            };
            picked.ok_or(FactoryError::NoFactoriesAvailable)?
        } else {
            self.0
                .get(key)
                .ok_or_else(|| FactoryError::FactoryNotFound(key.to_owned()))?
        };
        Ok(selected.create())
    }
}
/// One factory registration, submitted to `inventory` and later gathered by
/// `FactoryRegistry::simple_factory`.
///
/// # Layout
///
/// `Collect::registry` for this type returns a `static` declared inside a
/// generic function. Rust creates exactly one such static, not one per `T`, so
/// registrations for every product type end up in the same `inventory`
/// registry. `FactoryRegistry::<A>::simple_factory` therefore reads the
/// `type_id` of entries that were actually submitted as `FactoryRegistry<B>`.
/// That only works if the field layout is identical for every `T`. The default
/// `repr(Rust)` makes no such promise; `repr(C)` fixes the field order, and
/// because no field's size or alignment depends on `T`, it fixes the offsets too.
#[repr(C)]
pub struct FactoryRegistry<T>
where
    T: ?Sized + 'static,
{
    /// Registration ID; becomes the lookup key in [`SimpleFactory`].
    id: &'static str,
    /// The registered factory. Its trait-object metadata is only meaningful
    /// when `type_id` equals `TypeId::of::<T>()`.
    factory: &'static (dyn Factory<T> + Sync),
    /// `TypeId` of the product type `T`; distinguishes entries in the shared registry.
    type_id: TypeId,
}
// `Collect` is implemented by hand instead of through `inventory::collect!`,
// presumably because that macro does not accept generic types — TODO confirm.
//
// A `static` declared inside a generic function is a single item, not one per
// monomorphization. Every `FactoryRegistry<T>` therefore shares this one
// registry, and `FactoryRegistry::simple_factory` separates the entries by
// their stored `type_id`.
impl<T> Collect for FactoryRegistry<T>
where
    T: ?Sized + 'static,
{
    fn registry() -> &'static Registry {
        static REGISTRY: Registry = Registry::new();
        &REGISTRY
    }
}
impl<T> FactoryRegistry<T>
where
    T: ?Sized + 'static,
{
    /// Creates a registration of `factory` under `id`, tagged with `TypeId::of::<T>()`.
    ///
    /// This is a `const fn`; `register_factory!` calls it inside `submit!`.
    #[inline]
    pub const fn new(id: &'static str, factory: &'static (dyn Factory<T> + Sync)) -> Self {
        Self {
            id,
            factory,
            type_id: TypeId::of::<T>(),
        }
    }

    /// Builds a [`SimpleFactory`] from every registration whose product type is `T`.
    ///
    /// All `FactoryRegistry<_>` instantiations share one `inventory` registry,
    /// because the `static` returned by `Collect::registry` is not duplicated
    /// per `T`. Entries are therefore filtered by their stored `TypeId`.
    ///
    /// Duplicate IDs registered for the same `T` are not reported: collecting
    /// into a `BTreeMap` keeps only one entry per key.
    pub fn simple_factory() -> SimpleFactory<T> {
        let wanted = TypeId::of::<T>();
        let factories = inventory::iter::<Self>()
            // Check `type_id` before touching `factory`. An entry submitted for
            // another product type carries trait-object metadata that is not a
            // `dyn Factory<T>` vtable, so it must not be read as one.
            .filter(|reg| reg.type_id == wanted)
            .map(|reg| (reg.id, reg.factory))
            .collect();
        SimpleFactory(factories)
    }
}
/// Registers `$implement` as a factory for `$product` under the ID `$id`.
///
/// The generated factory builds each product with `Box::<$implement>::default()`,
/// which is then coerced to `Box<$product>`. Typically `$product` is `dyn Trait`
/// and `$implement: Trait + Default`.
///
/// Two checks run at compile time: `$id` must be a non-empty string literal,
/// and `$implement` must implement `Default`. A trailing comma is accepted.
///
/// Registered factories are retrieved with
/// `FactoryRegistry::<$product>::simple_factory()`.
///
/// # Examples
///
/// ```ignore
/// trait Shape { fn area(&self) -> f64; }
///
/// #[derive(Default)]
/// struct Unit;
/// impl Shape for Unit { fn area(&self) -> f64 { 1.0 } }
///
/// register_factory!(dyn Shape, "unit", Unit);
///
/// let shapes = FactoryRegistry::<dyn Shape>::simple_factory();
/// let shape = shapes.create("unit", FactoryFallback::NoFallback).unwrap();
/// ```
#[macro_export]
macro_rules! register_factory {
    ($product:ty, $id:literal, $implement:ty $(,)?) => {
        $crate::const_assert!(!$id.is_empty());
        $crate::assert_impl_one!($implement: Default);
        const _: () = {
            // The anonymous `const _` scope gives every invocation its own
            // `ConcreteFactory`, so repeated invocations in one module don't clash.
            struct ConcreteFactory;
            impl $crate::Factory<$product> for ConcreteFactory {
                fn create(&self) -> Box<$product> {
                    Box::<$implement>::default()
                }
            }
            $crate::submit! {
                $crate::FactoryRegistry::new(
                    $id,
                    &ConcreteFactory as &'static (dyn $crate::Factory<$product> + Sync),
                )
            }
        };
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Product interface used to exercise the registry. It is private to this
    /// module, so the only factories registered for `dyn TestProduct` are the
    /// two `register_factory!` invocations below.
    trait TestProduct {
        fn get_value(&self) -> &str;
    }

    struct ProductA {
        value: String,
    }

    impl ProductA {
        // Not called by the tests (factories construct via `Default`); kept as an
        // example of a product with its own constructor, hence the allowance.
        #[allow(dead_code)]
        fn new(value: &str) -> Self {
            Self {
                value: value.to_string(),
            }
        }
    }

    impl TestProduct for ProductA {
        fn get_value(&self) -> &str {
            &self.value
        }
    }

    impl Default for ProductA {
        fn default() -> Self {
            Self {
                value: "default_a".to_string(),
            }
        }
    }

    struct ProductB {
        value: String,
    }

    impl ProductB {
        // Not called by the tests (factories construct via `Default`); kept as an
        // example of a product with its own constructor, hence the allowance.
        #[allow(dead_code)]
        fn new(value: &str) -> Self {
            Self {
                value: value.to_string(),
            }
        }
    }

    impl TestProduct for ProductB {
        fn get_value(&self) -> &str {
            &self.value
        }
    }

    impl Default for ProductB {
        fn default() -> Self {
            Self {
                value: "default_b".to_string(),
            }
        }
    }

    register_factory!(dyn TestProduct, "product_a", ProductA);
    register_factory!(dyn TestProduct, "product_b", ProductB);

    /// Returns the error from a `create` result, panicking if a product was built instead.
    ///
    /// `Result::unwrap_err` can't be used because the boxed products are not `Debug`.
    fn expect_err<P: ?Sized>(result: Result<Box<P>, FactoryError>) -> FactoryError {
        match result {
            Ok(_) => panic!("expected an error, but a product was created"),
            Err(err) => err,
        }
    }

    #[test]
    fn test_factory_registration() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        assert!(
            factory.create("product_a", FactoryFallback::NoFallback).is_ok(),
            "product_a factory should exist"
        );
        assert!(
            factory.create("product_b", FactoryFallback::NoFallback).is_ok(),
            "product_b factory should exist"
        );
    }

    #[test]
    fn test_factory_creation() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        let product = factory
            .create("product_a", FactoryFallback::NoFallback)
            .expect("product_a factory should exist");
        assert_eq!(product.get_value(), "default_a");
        let product = factory
            .create("product_b", FactoryFallback::NoFallback)
            .expect("product_b factory should exist");
        assert_eq!(product.get_value(), "default_b");
    }

    #[test]
    fn test_factory_error_cases() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        let err = expect_err(factory.create("non_existent", FactoryFallback::NoFallback));
        assert!(
            matches!(&err, FactoryError::FactoryNotFound(id) if id == "non_existent"),
            "expected FactoryNotFound, got {err:?}"
        );
        let err = expect_err(factory.create("", FactoryFallback::NoFallback));
        assert!(
            matches!(err, FactoryError::EmptyIdNoFallback),
            "expected EmptyIdNoFallback, got {err:?}"
        );
    }

    #[test]
    fn test_factory_fallback_first() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        // IDs are kept in a BTreeMap, so "product_a" sorts before "product_b".
        let product = factory
            .create("", FactoryFallback::First)
            .expect("fallback should pick the first registered factory");
        assert_eq!(product.get_value(), "default_a");
        // A non-empty ID never triggers the fallback, even when it is unknown.
        let err = expect_err(factory.create("invalid_id", FactoryFallback::First));
        assert!(
            matches!(&err, FactoryError::FactoryNotFound(id) if id == "invalid_id"),
            "expected FactoryNotFound, got {err:?}"
        );
    }

    #[test]
    fn test_factory_fallback_last() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        // IDs are kept in a BTreeMap, so "product_b" sorts last.
        let product = factory
            .create("", FactoryFallback::Last)
            .expect("fallback should pick the last registered factory");
        assert_eq!(product.get_value(), "default_b");
        // A non-empty ID never triggers the fallback, even when it is unknown.
        let err = expect_err(factory.create("invalid_id", FactoryFallback::Last));
        assert!(
            matches!(&err, FactoryError::FactoryNotFound(id) if id == "invalid_id"),
            "expected FactoryNotFound, got {err:?}"
        );
    }

    #[test]
    fn test_factory_no_factories_available() {
        // Nothing is ever registered for this product type.
        trait EmptyProduct {
            // Never called; the trait only serves as a product type with no factories.
            #[allow(dead_code)]
            fn dummy(&self);
        }
        let factory = FactoryRegistry::<dyn EmptyProduct>::simple_factory();
        for strategy in [FactoryFallback::First, FactoryFallback::Last] {
            let err = expect_err(factory.create("", strategy));
            assert!(
                matches!(err, FactoryError::NoFactoriesAvailable),
                "expected NoFactoriesAvailable for {strategy:?}, got {err:?}"
            );
        }
        let err = expect_err(factory.create("", FactoryFallback::NoFallback));
        assert!(
            matches!(err, FactoryError::EmptyIdNoFallback),
            "expected EmptyIdNoFallback, got {err:?}"
        );
        let err = expect_err(factory.create("missing", FactoryFallback::First));
        assert!(
            matches!(&err, FactoryError::FactoryNotFound(id) if id == "missing"),
            "expected FactoryNotFound, got {err:?}"
        );
    }

    #[test]
    fn test_factory_registry_new() {
        struct TestFactory;
        impl Factory<String> for TestFactory {
            fn create(&self) -> Box<String> {
                Box::new("test".to_string())
            }
        }
        let factory = &TestFactory as &'static (dyn Factory<String> + Sync);
        let registry = FactoryRegistry::new("test_id", factory);
        assert_eq!(registry.id, "test_id");
        assert_eq!(registry.type_id, TypeId::of::<String>());
        assert_eq!(*registry.factory.create(), "test");
    }

    #[test]
    fn test_factory_error_display() {
        assert_eq!(
            FactoryError::FactoryNotFound("test_id".to_string()).to_string(),
            "factory with ID 'test_id' not found"
        );
        assert_eq!(
            FactoryError::EmptyIdNoFallback.to_string(),
            "empty ID provided without fallback"
        );
        assert_eq!(
            FactoryError::NoFactoriesAvailable.to_string(),
            "no factories available"
        );
    }

    #[test]
    fn test_factory_fallback_debug() {
        assert_eq!(format!("{:?}", FactoryFallback::First), "First");
        assert_eq!(format!("{:?}", FactoryFallback::Last), "Last");
        assert_eq!(format!("{:?}", FactoryFallback::NoFallback), "NoFallback");
    }

    #[test]
    fn test_factory_fallback_eq() {
        assert_eq!(FactoryFallback::First, FactoryFallback::First);
        assert_eq!(FactoryFallback::Last, FactoryFallback::Last);
        assert_eq!(FactoryFallback::NoFallback, FactoryFallback::NoFallback);
        assert_ne!(FactoryFallback::First, FactoryFallback::Last);
        assert_ne!(FactoryFallback::First, FactoryFallback::NoFallback);
    }

    /// Smoke test: a simple factory built from the registry resolves a known ID.
    #[test]
    fn test_simple_factory_debug() {
        let factory = FactoryRegistry::<dyn TestProduct>::simple_factory();
        let result = factory.create("product_a", FactoryFallback::NoFallback);
        assert!(result.is_ok());
    }
}