use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use std::rc::Rc;
use std::sync::Arc;
use log::{debug, trace};
use crate::{
ProviderAvailability, ProviderCreateError, ProviderDescriptor, ProviderFailure, ProviderName,
ProviderRegistryError, ProviderSelection, ServiceProvider, ServiceSpec,
};
/// Registry of [`ServiceProvider`] implementations for one service
/// specification `Spec`, addressable by any of each provider's names.
///
/// Providers are stored in registration order, and every registered name maps
/// to its provider's position, so a name lookup is one hash-map probe followed
/// by a vector access.
///
/// NOTE(review): `#[derive(Debug)]` adds a `Spec: Debug` bound to the generated
/// impl, so the registry is only `Debug` when `Spec` is.
#[derive(Debug)]
pub struct ProviderRegistry<Spec>
where
Spec: ServiceSpec + 'static,
{
/// Registered providers in the order they were added; the values stored in
/// `index` are positions in this vector.
providers: Vec<ProviderEntry<Spec>>,
/// Maps every name yielded by `ProviderDescriptor::names` for a registered
/// provider to that provider's position in `providers`.
index: HashMap<ProviderName, usize>,
/// Ties the registry to `Spec` without storing a value of it. The
/// `fn() -> Spec` form means the marker itself imposes no `Send`/`Sync` or
/// drop-check requirements on `Spec`.
marker: PhantomData<fn() -> Spec>,
}
/// One registered provider, stored together with the descriptor obtained from
/// it at registration time.
#[derive(Debug)]
struct ProviderEntry<Spec>
where
Spec: ServiceSpec + 'static,
{
/// Descriptor returned by `ServiceProvider::descriptor` when the provider was
/// registered; the registry reads the provider's id, aliases and priority
/// from this stored copy.
descriptor: ProviderDescriptor,
/// Shared handle to the provider implementation; cloning an entry clones
/// this `Arc`, so clones refer to the same provider instance.
provider: Arc<dyn ServiceProvider<Spec>>,
}
impl<Spec> ProviderRegistry<Spec>
where
Spec: ServiceSpec + 'static,
{
#[inline]
pub fn new() -> Self {
Self::default()
}
#[inline]
pub fn len(&self) -> usize {
self.providers.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.providers.is_empty()
}
#[inline]
pub fn register<P>(&mut self, provider: P) -> Result<(), ProviderRegistryError>
where
P: ServiceProvider<Spec> + 'static,
{
self.register_provider(Arc::new(provider))
}
#[inline]
pub fn register_shared<P>(&mut self, provider: Arc<P>) -> Result<(), ProviderRegistryError>
where
P: ServiceProvider<Spec> + 'static,
{
self.register_provider(provider)
}
fn register_provider(
&mut self,
provider: Arc<dyn ServiceProvider<Spec>>,
) -> Result<(), ProviderRegistryError> {
let descriptor = provider.descriptor()?;
self.validate_descriptor(&descriptor)?;
let provider_index = self.providers.len();
for name in descriptor.names() {
self.index.insert(name.clone(), provider_index);
}
self.providers.push(ProviderEntry {
descriptor,
provider,
});
debug!(
"registered provider '{}' with {} aliases and priority {}",
self.providers[provider_index].descriptor.id(),
self.providers[provider_index].descriptor.aliases().len(),
self.providers[provider_index].descriptor.priority(),
);
Ok(())
}
#[inline]
pub fn provider_names(&self) -> Vec<&str> {
self.iter_provider_names().collect()
}
#[inline]
pub fn iter_provider_names(&self) -> impl Iterator<Item = &str> + '_ {
self.providers
.iter()
.map(|entry| entry.descriptor.id().as_str())
}
#[inline]
pub fn provider_descriptors(&self) -> Vec<&ProviderDescriptor> {
self.iter_provider_descriptors().collect()
}
#[inline]
pub fn iter_provider_descriptors(&self) -> impl Iterator<Item = &ProviderDescriptor> + '_ {
self.providers.iter().map(|entry| &entry.descriptor)
}
#[inline]
pub fn find_provider(&self, name: &str) -> Option<&dyn ServiceProvider<Spec>> {
self.resolve_provider(name).ok()
}
pub fn resolve_provider(
&self,
name: &str,
) -> Result<&dyn ServiceProvider<Spec>, ProviderRegistryError> {
let name = match ProviderName::new(name) {
Ok(name) => name,
Err(error) => {
trace!("provider resolution rejected invalid name: {error}");
return Err(error);
}
};
let Some(entry) = self.find_entry_by_name(&name) else {
trace!("provider resolution missed provider '{name}'");
return Err(ProviderRegistryError::UnknownProvider { name });
};
trace!(
"provider resolution matched '{}' to registered provider '{}'",
name,
entry.descriptor.id(),
);
Ok(entry.provider.as_ref())
}
#[inline]
pub fn create_box(
&self,
name: &str,
config: &Spec::Config,
) -> Result<Box<Spec::Service>, ProviderRegistryError> {
self.create_with(name, config, |provider, config| provider.create_box(config))
}
#[inline]
pub fn create_arc(
&self,
name: &str,
config: &Spec::Config,
) -> Result<Arc<Spec::Service>, ProviderRegistryError> {
self.create_with(name, config, |provider, config| provider.create_arc(config))
}
#[inline]
pub fn create_rc(
&self,
name: &str,
config: &Spec::Config,
) -> Result<Rc<Spec::Service>, ProviderRegistryError> {
self.create_with(name, config, |provider, config| provider.create_rc(config))
}
fn create_with<Handle, Create>(
&self,
name: &str,
config: &Spec::Config,
create: Create,
) -> Result<Handle, ProviderRegistryError>
where
Create:
Fn(&dyn ServiceProvider<Spec>, &Spec::Config) -> Result<Handle, ProviderCreateError>,
{
let name = ProviderName::new(name)?;
let entry = self
.find_entry_by_name(&name)
.ok_or_else(|| ProviderRegistryError::UnknownProvider { name: name.clone() })?;
trace!("creating service from provider '{name}'");
match entry.provider.availability(config) {
ProviderAvailability::Available => match create(entry.provider.as_ref(), config) {
Ok(service) => {
debug!("provider '{name}' created service");
Ok(service)
}
Err(error) => {
trace!(
"provider '{name}' failed to create service: {}",
error.reason(),
);
Err(registry_error_from_create_error(name, error))
}
},
ProviderAvailability::Unavailable { reason } => {
trace!("provider '{name}' is unavailable: {reason}");
Err(ProviderRegistryError::ProviderUnavailable {
name,
source: ProviderCreateError::unavailable(&reason),
})
}
}
}
#[inline]
pub fn create_auto_box(
&self,
config: &Spec::Config,
) -> Result<Box<Spec::Service>, ProviderRegistryError> {
self.create_selected_box(&ProviderSelection::Auto, config)
}
#[inline]
pub fn create_auto_arc(
&self,
config: &Spec::Config,
) -> Result<Arc<Spec::Service>, ProviderRegistryError> {
self.create_selected_arc(&ProviderSelection::Auto, config)
}
#[inline]
pub fn create_auto_rc(
&self,
config: &Spec::Config,
) -> Result<Rc<Spec::Service>, ProviderRegistryError> {
self.create_selected_rc(&ProviderSelection::Auto, config)
}
#[inline]
pub fn create_selected_box(
&self,
selection: &ProviderSelection,
config: &Spec::Config,
) -> Result<Box<Spec::Service>, ProviderRegistryError> {
self.create_selected_with(selection, config, |provider, config| {
provider.create_box(config)
})
}
#[inline]
pub fn create_selected_arc(
&self,
selection: &ProviderSelection,
config: &Spec::Config,
) -> Result<Arc<Spec::Service>, ProviderRegistryError> {
self.create_selected_with(selection, config, |provider, config| {
provider.create_arc(config)
})
}
#[inline]
pub fn create_selected_rc(
&self,
selection: &ProviderSelection,
config: &Spec::Config,
) -> Result<Rc<Spec::Service>, ProviderRegistryError> {
self.create_selected_with(selection, config, |provider, config| {
provider.create_rc(config)
})
}
fn create_selected_with<Handle, Create>(
&self,
selection: &ProviderSelection,
config: &Spec::Config,
create: Create,
) -> Result<Handle, ProviderRegistryError>
where
Create:
Fn(&dyn ServiceProvider<Spec>, &Spec::Config) -> Result<Handle, ProviderCreateError>,
{
if self.providers.is_empty() {
trace!("provider selection failed because registry is empty");
return Err(ProviderRegistryError::EmptyRegistry);
}
match selection {
ProviderSelection::Auto => {
let candidates = self.auto_candidates();
trace!(
"automatic provider selection prepared {} candidate(s)",
candidates.len(),
);
self.create_from_candidates_with(candidates.iter(), config, &create)
}
ProviderSelection::Named { primary, fallbacks } => {
trace!(
"named provider selection will try primary '{}' with {} fallback(s)",
primary,
fallbacks.len(),
);
self.create_from_candidates_with(
std::iter::once(primary).chain(fallbacks.iter()),
config,
&create,
)
}
}
}
fn create_from_candidates_with<'a, I, Handle, Create>(
&self,
candidates: I,
config: &Spec::Config,
create: &Create,
) -> Result<Handle, ProviderRegistryError>
where
I: IntoIterator<Item = &'a ProviderName>,
Create:
Fn(&dyn ServiceProvider<Spec>, &Spec::Config) -> Result<Handle, ProviderCreateError>,
{
let mut failures = Vec::new();
for candidate in candidates {
match self.create_from_candidate_with(candidate, config, create) {
Ok(service) => {
debug!("provider candidate '{candidate}' created service");
return Ok(service);
}
Err(failure) => {
trace!("provider candidate failed: {failure}");
failures.push(failure);
}
}
}
trace!(
"provider selection exhausted all candidates with {} failure(s)",
failures.len(),
);
Err(ProviderRegistryError::NoAvailableProvider { failures })
}
fn create_from_candidate_with<Handle, Create>(
&self,
candidate: &ProviderName,
config: &Spec::Config,
create: &Create,
) -> Result<Handle, ProviderFailure>
where
Create:
Fn(&dyn ServiceProvider<Spec>, &Spec::Config) -> Result<Handle, ProviderCreateError>,
{
let Some(entry) = self.find_entry_by_name(candidate) else {
trace!("provider candidate '{candidate}' is unknown");
return Err(ProviderFailure::unknown_name(candidate.clone()));
};
match entry.provider.availability(config) {
ProviderAvailability::Available => create(entry.provider.as_ref(), config)
.map_err(|error| failure_from_create_error(candidate.clone(), error)),
ProviderAvailability::Unavailable { reason } => {
trace!("provider candidate '{candidate}' is unavailable: {reason}");
Err(ProviderFailure::unavailable_name(
candidate.clone(),
&reason,
))
}
}
}
fn validate_descriptor(
&self,
descriptor: &ProviderDescriptor,
) -> Result<(), ProviderRegistryError> {
let mut local_names = HashSet::with_capacity(descriptor.aliases().len() + 1);
for name in descriptor.names() {
if !local_names.insert(name.clone()) || self.index.contains_key(name) {
return Err(ProviderRegistryError::DuplicateProviderName { name: name.clone() });
}
}
Ok(())
}
fn find_entry_by_name(&self, name: &ProviderName) -> Option<&ProviderEntry<Spec>> {
self.index
.get(name)
.and_then(|provider_index| self.providers.get(*provider_index))
}
fn auto_candidates(&self) -> Vec<ProviderName> {
let mut providers: Vec<&ProviderEntry<Spec>> = self.providers.iter().collect();
providers.sort_by(|left, right| {
right
.descriptor
.priority()
.cmp(&left.descriptor.priority())
.then_with(|| left.descriptor.id().cmp(right.descriptor.id()))
});
providers
.into_iter()
.map(|entry| entry.descriptor.id().clone())
.collect()
}
}
/// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
/// `Spec: Clone` bound, which this impl does not need.
impl<Spec> Clone for ProviderRegistry<Spec>
where
    Spec: ServiceSpec + 'static,
{
    /// Copies the entry list and the name index into a new registry.
    ///
    /// Provider handles inside the entries are cloned through
    /// `ProviderEntry::clone`.
    #[inline]
    fn clone(&self) -> Self {
        let Self {
            providers, index, ..
        } = self;
        let providers = providers.clone();
        let index = index.clone();
        Self {
            providers,
            index,
            marker: PhantomData,
        }
    }
}
impl<Spec> Clone for ProviderEntry<Spec>
where
Spec: ServiceSpec + 'static,
{
#[inline]
fn clone(&self) -> Self {
Self {
descriptor: self.descriptor.clone(),
provider: self.provider.clone(),
}
}
}
/// Implemented by hand so that an empty registry is available for any `Spec`,
/// without requiring `Spec: Default`.
impl<Spec> Default for ProviderRegistry<Spec>
where
    Spec: ServiceSpec + 'static,
{
    /// Returns a registry with no providers and an empty name index.
    #[inline]
    fn default() -> Self {
        Self {
            providers: Vec::default(),
            index: HashMap::default(),
            marker: PhantomData,
        }
    }
}
/// Wraps a provider creation error for a single, directly named provider.
///
/// Errors for which `is_unavailable()` is true become
/// [`ProviderRegistryError::ProviderUnavailable`]. All others become
/// [`ProviderRegistryError::ProviderCreate`]. The original error is kept as
/// `source` in both cases.
fn registry_error_from_create_error(
    name: ProviderName,
    error: ProviderCreateError,
) -> ProviderRegistryError {
    if !error.is_unavailable() {
        return ProviderRegistryError::ProviderCreate {
            name,
            source: error,
        };
    }
    ProviderRegistryError::ProviderUnavailable {
        name,
        source: error,
    }
}
fn failure_from_create_error(name: ProviderName, error: ProviderCreateError) -> ProviderFailure {
if error.is_unavailable() {
ProviderFailure::unavailable_error(name, error)
} else {
ProviderFailure::create_failed_error(name, error)
}
}