use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt;
use crate::context::{Cons, Nil, Tagged};
use crate::runtime::Never;
/// A value that can be stored in a [`ServiceContext`] and retrieved by its concrete type.
///
/// The `Clone + 'static` supertraits are what let the context keep services as `Box<dyn Any>`
/// and hand out owned copies of them.
pub trait Service: Clone + 'static {
    /// Human-readable service name. It is recorded next to every stored value and reported by
    /// `ServiceContext::service_names` and by the context's `Debug` output.
    const NAME: &'static str;

    /// Wraps `self` in a brand-new [`ServiceContext`] that contains only this service.
    #[inline]
    fn to_context(self) -> ServiceContext
    where
        Self: Sized,
    {
        let mut context = ServiceContext::empty();
        context.insert(self);
        context
    }

    /// Lifts this value into a layer through `Layer::succeed`, with error type [`Never`] and
    /// environment type `()`.
    #[inline]
    fn layer(self) -> crate::layer::Layer<Self, Never, ()>
    where
        Self: Sized,
    {
        crate::layer::Layer::succeed(self)
    }

    /// Builds an effect that obtains `Self` through `Effect::service` (which requires
    /// `R: ServiceLookup<Self>`) and then continues with the effect produced by `f`.
    ///
    /// The `E: From<MissingService>` bound suggests an absent service is reported as a
    /// [`MissingService`] error. NOTE(review): confirm against `Effect::service`.
    #[inline]
    fn use_<A, E, R, F>(f: F) -> crate::Effect<A, E, R>
    where
        Self: Sized,
        A: 'static,
        E: From<MissingService> + 'static,
        R: ServiceLookup<Self> + 'static,
        F: FnOnce(Self) -> crate::Effect<A, E, R> + 'static,
    {
        let lookup = crate::Effect::<Self, E, R>::service::<Self>();
        lookup.flat_map(f)
    }

    /// Same as [`Service::use_`], except that `f` returns a plain value, which is mapped into
    /// the resulting effect.
    #[inline]
    fn use_sync<A, E, R, F>(f: F) -> crate::Effect<A, E, R>
    where
        Self: Sized,
        A: 'static,
        E: From<MissingService> + 'static,
        R: ServiceLookup<Self> + 'static,
        F: FnOnce(Self) -> A + 'static,
    {
        let lookup = crate::Effect::<Self, E, R>::service::<Self>();
        lookup.map(f)
    }
}
/// Type-erased registry of services, keyed by each service's concrete [`TypeId`].
///
/// The map holds at most one value per service type.
// `Clone` is derived: cloning the map clones every `ServiceEntry` through its manual impl below.
#[derive(Clone, Default)]
pub struct ServiceContext {
    entries: HashMap<TypeId, ServiceEntry>,
}

/// A single stored service, together with the metadata needed to copy it.
struct ServiceEntry {
    /// Display name, taken from `Service::NAME` on insertion. Used by `Debug` and
    /// `service_names`.
    name: &'static str,
    /// The service value with its concrete type erased.
    value: Box<dyn Any>,
    /// Clone routine for the concrete type held in `value`. `Box<dyn Any>` cannot be cloned
    /// directly, so the typed clone is recorded when the entry is created.
    clone_value: fn(&dyn Any) -> Box<dyn Any>,
}

impl Clone for ServiceEntry {
    /// Deep-copies the erased value through the recorded clone routine.
    fn clone(&self) -> Self {
        Self {
            name: self.name,
            value: (self.clone_value)(self.value.as_ref()),
            clone_value: self.clone_value,
        }
    }
}
impl fmt::Debug for ServiceContext {
    /// Prints `ServiceContext { services: [...] }`, listing the registered service names in
    /// sorted order. The stored values are type-erased (`Box<dyn Any>`), so only their names
    /// can be shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `service_names` so the Debug listing and the public accessor cannot drift.
        f.debug_struct("ServiceContext")
            .field("services", &self.service_names())
            .finish()
    }
}
impl ServiceContext {
#[inline]
pub fn empty() -> Self {
Self {
entries: HashMap::new(),
}
}
#[inline]
pub fn add<S>(mut self, service: S) -> Self
where
S: Service,
{
self.insert(service);
self
}
#[inline]
pub fn insert<S>(&mut self, service: S)
where
S: Service,
{
self.entries.insert(
TypeId::of::<S>(),
ServiceEntry {
name: S::NAME,
value: Box::new(service),
clone_value: |value| {
Box::new(
value
.downcast_ref::<S>()
.expect("service context entry stored under the wrong type")
.clone(),
)
},
},
);
}
#[inline]
pub fn merge(mut self, other: ServiceContext) -> Self {
self.entries.extend(other.entries);
self
}
#[inline]
pub fn get<S>(&self) -> Option<&S>
where
S: Service,
{
self
.entries
.get(&TypeId::of::<S>())
.and_then(|entry| entry.value.downcast_ref::<S>())
}
#[inline]
pub fn get_cloned<S>(&self) -> Option<S>
where
S: Service,
{
self.get::<S>().cloned()
}
#[inline]
pub fn contains<S>(&self) -> bool
where
S: Service,
{
self.entries.contains_key(&TypeId::of::<S>())
}
pub fn service_names(&self) -> Vec<&'static str> {
let mut names = self
.entries
.values()
.map(|entry| entry.name)
.collect::<Vec<_>>();
names.sort_unstable();
names
}
}
/// Abstraction over anything that can lend out a borrowed service of type `S`.
pub trait ServiceLookup<S>
where
    S: Service,
{
    /// Borrows the available `S`. The implementations in this module return `None` when the
    /// underlying context holds no `S`.
    fn service(&self) -> Option<&S>;
}
impl<S> ServiceLookup<S> for ServiceContext
where
S: Service,
{
#[inline]
fn service(&self) -> Option<&S> {
self.get::<S>()
}
}
/// Forwarding impl so that a borrowed `&ServiceContext` also works as a [`ServiceLookup`].
impl<S: Service> ServiceLookup<S> for &ServiceContext {
    /// Strips one level of reference and delegates to [`ServiceContext::get`].
    #[inline]
    fn service(&self) -> Option<&S> {
        let ctx: &ServiceContext = self;
        ctx.get::<S>()
    }
}
/// Error describing a service that was requested but is not present.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct MissingService {
    /// Name of the absent service. Presumably this is the requested type's `Service::NAME`;
    /// the construction site is not visible in this module section.
    pub name: &'static str,
}

impl fmt::Display for MissingService {
    /// Renders as ``missing service `NAME` ``.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { name } = self;
        f.write_fmt(format_args!("missing service `{}`", name))
    }
}

/// `MissingService` has no underlying cause, so the default `source()` (`None`) is kept.
impl std::error::Error for MissingService {}
/// Conversion of a value into a runtime [`ServiceContext`].
pub trait IntoServiceContext {
    /// Consumes `self` and produces the equivalent [`ServiceContext`].
    fn into_service_context(self) -> ServiceContext;
}

/// A [`ServiceContext`] already is one, so the conversion is the identity.
impl IntoServiceContext for ServiceContext {
    #[inline]
    fn into_service_context(self) -> Self {
        self
    }
}
/// The empty context list converts into an empty [`ServiceContext`].
impl IntoServiceContext for Nil {
    #[inline]
    fn into_service_context(self) -> ServiceContext {
        ServiceContext::default()
    }
}
impl<S, Tail> IntoServiceContext for Cons<Tagged<S, S>, Tail>
where
S: Service,
Tail: IntoServiceContext,
{
#[inline]
fn into_service_context(self) -> ServiceContext {
let tail = self.1.into_service_context();
tail.add(self.0.value)
}
}
/// Unwraps a [`crate::context::Context`] and converts the value it wraps.
impl<L> IntoServiceContext for crate::context::Context<L>
where
    L: IntoServiceContext,
{
    #[inline]
    fn into_service_context(self) -> ServiceContext {
        let inner = self.0;
        inner.into_service_context()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, crate::Service)]
    struct Config {
        port: u16,
    }

    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, crate::Service)]
    struct Db {
        id: u8,
    }

    mod into_service_context {
        use super::*;

        #[test]
        fn nil_produces_empty_context() {
            let names = Nil.into_service_context().service_names();
            assert!(names.is_empty());
        }

        #[test]
        fn single_service_cell_converts() {
            let expected = Config { port: 8080 };
            let list = Cons(Tagged::<Config, _>::new(expected), Nil);
            let ctx = list.into_service_context();
            assert_eq!(ctx.get_cloned::<Config>(), Some(expected));
        }

        #[test]
        fn multiple_service_cells_convert() {
            let config = Config { port: 8080 };
            let db = Db { id: 1 };
            let list = Cons(
                Tagged::<Config, _>::new(config),
                Cons(Tagged::<Db, _>::new(db), Nil),
            );
            let ctx = list.into_service_context();
            assert_eq!(ctx.get_cloned::<Config>(), Some(config));
            assert_eq!(ctx.get_cloned::<Db>(), Some(db));
        }

        #[test]
        fn context_wrapper_converts() {
            let expected = Config { port: 9090 };
            let wrapped =
                crate::context::Context::new(Cons(Tagged::<Config, _>::new(expected), Nil));
            let ctx = wrapped.into_service_context();
            assert_eq!(ctx.get_cloned::<Config>(), Some(expected));
        }

        #[test]
        fn service_context_identity() {
            let expected = Config { port: 3000 };
            let converted = ServiceContext::empty()
                .add(expected)
                .into_service_context();
            assert_eq!(converted.get_cloned::<Config>(), Some(expected));
        }

        #[test]
        fn duplicate_service_types_head_wins() {
            let head = Config { port: 1 };
            let shadowed = Config { port: 2 };
            let list = Cons(
                Tagged::<Config, _>::new(head),
                Cons(Tagged::<Config, _>::new(shadowed), Nil),
            );
            let ctx = list.into_service_context();
            assert_eq!(ctx.get_cloned::<Config>(), Some(head));
        }

        #[test]
        fn service_lookup_trait_still_works_after_conversion() {
            let expected = Config { port: 8080 };
            let ctx = Cons(Tagged::<Config, _>::new(expected), Nil).into_service_context();
            let found = <ServiceContext as ServiceLookup<Config>>::service(&ctx);
            assert_eq!(found, Some(&expected));
        }
    }
}