#![warn(missing_docs)]
use std::{any::TypeId, collections::HashMap, sync::Arc};
use http::{Response, StatusCode, request::Parts};
use wae_types::{WaeError, WaeResult};
/// Lifetime policy attached to every registered dependency.
///
/// The default is [`Scope::Singleton`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Scope {
    /// A single instance owned by the shared container; every lookup returns a clone of it.
    #[default]
    Singleton,
    /// Resolved from the shared container unless a per-request override has been installed
    /// (only entries with this scope accept overrides via `Effectful::set` / `Effectful::set_type`).
    RequestScoped,
}
/// A type-erased service instance paired with the [`Scope`] it was registered under.
struct ScopedService {
    /// The stored value; callers recover the concrete type with `downcast_ref`.
    service: Box<dyn std::any::Any + Send + Sync>,
    /// Lifetime policy chosen at registration time.
    scope: Scope,
}
/// Registry of services, addressable either by a string name or by their concrete Rust type.
///
/// Every entry records the [`Scope`] it was registered with. Named and typed entries are kept
/// in separate maps, so a value registered by name is not reachable by type (and vice versa).
#[derive(Default)]
pub struct Dependencies {
    services: HashMap<String, ScopedService>,
    typed_services: HashMap<TypeId, ScopedService>,
}

impl Dependencies {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `service` under `name` as a [`Scope::Singleton`], replacing any existing entry
    /// with the same name.
    pub fn register<T: Send + Sync + 'static>(&mut self, name: &str, service: T) {
        self.register_with_scope(name, service, Scope::Singleton);
    }

    /// Registers `service` under `name` with the given `scope`, replacing any existing entry
    /// with the same name.
    pub fn register_with_scope<T: Send + Sync + 'static>(&mut self, name: &str, service: T, scope: Scope) {
        let entry = ScopedService { scope, service: Box::new(service) };
        self.services.insert(name.to_owned(), entry);
    }

    /// Returns a clone of the service registered under `name`.
    ///
    /// # Errors
    /// Returns a not-found error when nothing is registered under `name`, or when the stored
    /// value is not of type `T`.
    pub fn get<T: Clone + Send + Sync + 'static>(&self, name: &str) -> WaeResult<T> {
        Self::clone_as::<T>(self.services.get(name)).ok_or_else(|| WaeError::not_found("Dependency", name))
    }

    /// Returns the scope `name` was registered with, or `None` if it is not registered.
    pub fn get_scope(&self, name: &str) -> Option<Scope> {
        let entry = self.services.get(name)?;
        Some(entry.scope)
    }

    /// Registers `service` keyed by its type `T` as a [`Scope::Singleton`], replacing any
    /// previous value of the same type.
    pub fn register_type<T: Clone + Send + Sync + 'static>(&mut self, service: T) {
        self.register_type_with_scope(service, Scope::Singleton);
    }

    /// Registers `service` keyed by its type `T` with the given `scope`, replacing any previous
    /// value of the same type.
    pub fn register_type_with_scope<T: Clone + Send + Sync + 'static>(&mut self, service: T, scope: Scope) {
        let entry = ScopedService { scope, service: Box::new(service) };
        self.typed_services.insert(TypeId::of::<T>(), entry);
    }

    /// Returns a clone of the value registered for type `T`.
    ///
    /// # Errors
    /// Returns a not-found error (labelled with `T`'s type name) when no value of type `T` has
    /// been registered.
    pub fn get_type<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
        let key = TypeId::of::<T>();
        Self::clone_as::<T>(self.typed_services.get(&key))
            .ok_or_else(|| WaeError::not_found("Typed dependency", std::any::type_name::<T>()))
    }

    /// Returns the scope type `T` was registered with, or `None` if it is not registered.
    pub fn get_type_scope<T: 'static>(&self) -> Option<Scope> {
        let entry = self.typed_services.get(&TypeId::of::<T>())?;
        Some(entry.scope)
    }

    /// Downcasts a looked-up entry to `T` and clones it; `None` if absent or holding another type.
    fn clone_as<T: Clone + 'static>(entry: Option<&ScopedService>) -> Option<T> {
        entry?.service.downcast_ref::<T>().cloned()
    }
}
/// Per-request view over a shared [`Dependencies`] registry.
///
/// Holds the request's [`Parts`] plus request-local overrides installed through
/// [`Effectful::set`] / [`Effectful::set_type`]. Overrides are only accepted for entries that
/// were registered as [`Scope::RequestScoped`]; lookups of such entries prefer the override and
/// fall back to the registered instance.
pub struct Effectful {
    deps: Arc<Dependencies>,
    parts: Parts,
    request_scoped_services: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
    request_scoped_typed_services: HashMap<TypeId, Box<dyn std::any::Any + Send + Sync>>,
}

impl Effectful {
    /// Creates a request context over `deps` with no overrides installed.
    pub fn new(deps: Arc<Dependencies>, parts: Parts) -> Self {
        Self {
            request_scoped_services: HashMap::new(),
            request_scoped_typed_services: HashMap::new(),
            deps,
            parts,
        }
    }

    /// Resolves the dependency named `name` as a `T`.
    ///
    /// For a [`Scope::RequestScoped`] entry with an override installed, the override is used;
    /// in every other case the registered instance is cloned from the shared registry.
    ///
    /// # Errors
    /// Returns a not-found error when `name` is unknown, or when the resolved value is not a `T`.
    pub fn get<T: Clone + Send + Sync + 'static>(&self, name: &str) -> WaeResult<T> {
        let local = match self.deps.get_scope(name) {
            Some(Scope::RequestScoped) => self.request_scoped_services.get(name),
            Some(Scope::Singleton) | None => None,
        };
        match local {
            Some(overridden) => overridden
                .downcast_ref::<T>()
                .cloned()
                .ok_or_else(|| WaeError::not_found("Request-scoped dependency", name)),
            None => self.deps.get(name),
        }
    }

    /// Installs a request-local override for the dependency named `name`.
    ///
    /// # Errors
    /// Returns an invalid-params error unless `name` is registered as [`Scope::RequestScoped`].
    pub fn set<T: Send + Sync + 'static>(&mut self, name: &str, service: T) -> WaeResult<()> {
        match self.deps.get_scope(name) {
            Some(Scope::RequestScoped) => {
                self.request_scoped_services.insert(name.to_string(), Box::new(service));
                Ok(())
            }
            Some(Scope::Singleton) | None => {
                Err(WaeError::invalid_params("dependency", "Can only set RequestScoped dependencies"))
            }
        }
    }

    /// Resolves the dependency registered for type `T`.
    ///
    /// For a [`Scope::RequestScoped`] type with an override installed, the override is used;
    /// otherwise the registered value is cloned from the shared registry.
    ///
    /// # Errors
    /// Returns a not-found error when no value of type `T` is available.
    pub fn get_type<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
        let local = match self.deps.get_type_scope::<T>() {
            Some(Scope::RequestScoped) => self.request_scoped_typed_services.get(&TypeId::of::<T>()),
            Some(Scope::Singleton) | None => None,
        };
        match local {
            Some(overridden) => overridden
                .downcast_ref::<T>()
                .cloned()
                .ok_or_else(|| WaeError::not_found("Typed request-scoped dependency", std::any::type_name::<T>())),
            None => self.deps.get_type(),
        }
    }

    /// Installs a request-local override for the dependency of type `T`.
    ///
    /// # Errors
    /// Returns an invalid-params error unless `T` is registered as [`Scope::RequestScoped`].
    pub fn set_type<T: Clone + Send + Sync + 'static>(&mut self, service: T) -> WaeResult<()> {
        match self.deps.get_type_scope::<T>() {
            Some(Scope::RequestScoped) => {
                self.request_scoped_typed_services.insert(TypeId::of::<T>(), Box::new(service));
                Ok(())
            }
            Some(Scope::Singleton) | None => {
                Err(WaeError::invalid_params("dependency", "Can only set RequestScoped dependencies"))
            }
        }
    }

    /// Returns the request header `name` as text; `None` if absent or not visible ASCII.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.parts.headers.get(name)?.to_str().ok()
    }

    /// Borrows the request head this context was created with.
    pub fn parts(&self) -> &Parts {
        &self.parts
    }

    /// Alias for [`Effectful::get_type`].
    pub fn use_type<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
        self.get_type::<T>()
    }

    /// Alias for [`Effectful::get_type`], intended for configuration values.
    pub fn use_config<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
        self.get_type::<T>()
    }

    /// Alias for [`Effectful::get_type`], intended for authentication data.
    pub fn use_auth<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
        self.get_type::<T>()
    }
}
/// Consuming builder that collects dependency registrations and freezes them into a shared
/// [`Dependencies`] registry via [`AlgebraicEffect::build`].
///
/// Every `with*` method takes `self` by value and returns the updated builder, so its result
/// must be used (hence `#[must_use]`); dropping it discards all registrations made so far.
pub struct AlgebraicEffect {
    deps: Dependencies,
}

impl Default for AlgebraicEffect {
    fn default() -> Self {
        Self::new()
    }
}

impl AlgebraicEffect {
    /// Creates a builder with no registrations.
    #[must_use]
    pub fn new() -> Self {
        Self { deps: Dependencies::new() }
    }

    /// Registers `service` under `name` as a [`Scope::Singleton`].
    #[must_use]
    pub fn with<T: Send + Sync + 'static>(mut self, name: &str, service: T) -> Self {
        self.deps.register(name, service);
        self
    }

    /// Registers `service` under `name` with an explicit `scope`.
    #[must_use]
    pub fn with_scope<T: Send + Sync + 'static>(mut self, name: &str, service: T, scope: Scope) -> Self {
        self.deps.register_with_scope(name, service, scope);
        self
    }

    /// Registers `service` keyed by its type as a [`Scope::Singleton`].
    #[must_use]
    pub fn with_type<T: Clone + Send + Sync + 'static>(mut self, service: T) -> Self {
        self.deps.register_type(service);
        self
    }

    /// Registers `service` keyed by its type with an explicit `scope`.
    #[must_use]
    pub fn with_type_scope<T: Clone + Send + Sync + 'static>(mut self, service: T, scope: Scope) -> Self {
        self.deps.register_type_with_scope(service, scope);
        self
    }

    /// Registers a configuration value keyed by its type; same as [`AlgebraicEffect::with_type`].
    #[must_use]
    pub fn with_config<T: Clone + Send + Sync + 'static>(self, config: T) -> Self {
        self.with_type(config)
    }

    /// Registers a configuration value with an explicit `scope`; same as
    /// [`AlgebraicEffect::with_type_scope`].
    #[must_use]
    pub fn with_config_scope<T: Clone + Send + Sync + 'static>(self, config: T, scope: Scope) -> Self {
        self.with_type_scope(config, scope)
    }

    /// Registers an authentication value keyed by its type; same as [`AlgebraicEffect::with_type`].
    #[must_use]
    pub fn with_auth<T: Clone + Send + Sync + 'static>(self, auth: T) -> Self {
        self.with_type(auth)
    }

    /// Registers an authentication value with an explicit `scope`; same as
    /// [`AlgebraicEffect::with_type_scope`].
    #[must_use]
    pub fn with_auth_scope<T: Clone + Send + Sync + 'static>(self, auth: T, scope: Scope) -> Self {
        self.with_type_scope(auth, scope)
    }

    /// Finishes the builder, returning the registry behind an [`Arc`] for sharing across requests.
    #[must_use]
    pub fn build(self) -> Arc<Dependencies> {
        Arc::new(self.deps)
    }
}
/// Newtype wrapper that renders a [`WaeError`] as an HTTP [`Response`].
pub struct WaeErrorResponse(pub WaeError);

impl WaeErrorResponse {
    /// Converts the wrapped error into a response.
    ///
    /// The status comes from the error's `http_status()`. If that value is rejected by
    /// [`StatusCode::from_u16`], the status falls back to `500 Internal Server Error`. The body
    /// is the error's `Display` text converted into `B`.
    ///
    /// Never panics: the response is built directly rather than through the fallible builder.
    pub fn into_response<B>(self) -> Response<B>
    where
        B: From<String>,
    {
        let status_code = StatusCode::from_u16(self.0.http_status()).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
        let body = B::from(self.0.to_string());
        let mut response = Response::new(body);
        *response.status_mut() = status_code;
        response
    }
}