Skip to main content

wae_effect/
lib.rs

1//! 代数效应系统
2//!
3//! 提供声明式的依赖注入能力,允许在处理请求时通过类型安全的接口获取各种依赖。
4
5#![warn(missing_docs)]
6
7use std::{any::TypeId, collections::HashMap, sync::Arc};
8
9use http::{Response, StatusCode, request::Parts};
10use wae_types::{WaeError, WaeResult};
11
12/// 依赖作用域
13///
14/// 定义依赖的生命周期和可见范围。
15#[derive(Debug, Clone, Copy, PartialEq, Eq)]
16pub enum Scope {
17    /// 单例作用域
18    ///
19    /// 依赖在整个应用生命周期内只创建一次,所有请求共享同一个实例。
20    Singleton,
21    /// 请求作用域
22    ///
23    /// 每个请求都会创建一个新的依赖实例,仅在该请求范围内可见。
24    RequestScoped,
25}
26
27impl Default for Scope {
28    fn default() -> Self {
29        Scope::Singleton
30    }
31}
32
33/// 带作用域的服务包装器
34struct ScopedService {
35    scope: Scope,
36    service: Box<dyn std::any::Any + Send + Sync>,
37}
38
39/// 依赖容器
40///
41/// 存储所有注册的服务实例。
42#[derive(Default)]
43pub struct Dependencies {
44    services: HashMap<String, ScopedService>,
45    typed_services: HashMap<TypeId, ScopedService>,
46}
47
48impl Dependencies {
49    /// 创建新的依赖容器
50    pub fn new() -> Self {
51        Self { services: HashMap::new(), typed_services: HashMap::new() }
52    }
53
54    /// 注册服务(按字符串键,默认单例作用域)
55    pub fn register<T: Send + Sync + 'static>(&mut self, name: &str, service: T) {
56        self.register_with_scope(name, service, Scope::Singleton);
57    }
58
59    /// 按作用域注册服务(按字符串键)
60    pub fn register_with_scope<T: Send + Sync + 'static>(&mut self, name: &str, service: T, scope: Scope) {
61        self.services.insert(name.to_string(), ScopedService { scope, service: Box::new(service) });
62    }
63
64    /// 获取服务(按字符串键)
65    pub fn get<T: Clone + Send + Sync + 'static>(&self, name: &str) -> WaeResult<T> {
66        self.services
67            .get(name)
68            .and_then(|s| s.service.downcast_ref::<T>())
69            .cloned()
70            .ok_or_else(|| WaeError::not_found("Dependency", name))
71    }
72
73    /// 检查服务的作用域(按字符串键)
74    pub fn get_scope(&self, name: &str) -> Option<Scope> {
75        self.services.get(name).map(|s| s.scope)
76    }
77
78    /// 按类型注册服务(默认单例作用域)
79    pub fn register_type<T: Clone + Send + Sync + 'static>(&mut self, service: T) {
80        self.register_type_with_scope(service, Scope::Singleton);
81    }
82
83    /// 按作用域注册服务(按类型)
84    pub fn register_type_with_scope<T: Clone + Send + Sync + 'static>(&mut self, service: T, scope: Scope) {
85        self.typed_services.insert(TypeId::of::<T>(), ScopedService { scope, service: Box::new(service) });
86    }
87
88    /// 按类型获取服务
89    pub fn get_type<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
90        self.typed_services
91            .get(&TypeId::of::<T>())
92            .and_then(|s| s.service.downcast_ref::<T>())
93            .cloned()
94            .ok_or_else(|| WaeError::not_found("Typed dependency", std::any::type_name::<T>()))
95    }
96
97    /// 检查服务的作用域(按类型)
98    pub fn get_type_scope<T: 'static>(&self) -> Option<Scope> {
99        self.typed_services.get(&TypeId::of::<T>()).map(|s| s.scope)
100    }
101}
102
103/// 代数效应请求上下文
104///
105/// 包含依赖容器和请求信息,用于在请求处理过程中获取依赖。
106pub struct Effectful {
107    deps: Arc<Dependencies>,
108    parts: Parts,
109    request_scoped_services: HashMap<String, Box<dyn std::any::Any + Send + Sync>>,
110    request_scoped_typed_services: HashMap<TypeId, Box<dyn std::any::Any + Send + Sync>>,
111}
112
113impl Effectful {
114    /// 创建新的 Effectful
115    pub fn new(deps: Arc<Dependencies>, parts: Parts) -> Self {
116        Self { deps, parts, request_scoped_services: HashMap::new(), request_scoped_typed_services: HashMap::new() }
117    }
118
119    /// 获取依赖(按字符串键)
120    ///
121    /// 如果依赖是 RequestScoped 作用域,将从当前请求的独立存储中获取或初始化。
122    pub fn get<T: Clone + Send + Sync + 'static>(&self, name: &str) -> WaeResult<T> {
123        if let Some(scope) = self.deps.get_scope(name) {
124            match scope {
125                Scope::Singleton => self.deps.get(name),
126                Scope::RequestScoped => {
127                    if let Some(service) = self.request_scoped_services.get(name) {
128                        service
129                            .downcast_ref::<T>()
130                            .cloned()
131                            .ok_or_else(|| WaeError::not_found("Request-scoped dependency", name))
132                    }
133                    else {
134                        self.deps.get(name)
135                    }
136                }
137            }
138        }
139        else {
140            self.deps.get(name)
141        }
142    }
143
144    /// 设置请求作用域的依赖(按字符串键)
145    ///
146    /// 仅对 RequestScoped 作用域的依赖有效。
147    pub fn set<T: Send + Sync + 'static>(&mut self, name: &str, service: T) -> WaeResult<()> {
148        if let Some(Scope::RequestScoped) = self.deps.get_scope(name) {
149            self.request_scoped_services.insert(name.to_string(), Box::new(service));
150            Ok(())
151        }
152        else {
153            Err(WaeError::invalid_params("dependency", "Can only set RequestScoped dependencies"))
154        }
155    }
156
157    /// 按类型获取依赖
158    ///
159    /// 如果依赖是 RequestScoped 作用域,将从当前请求的独立存储中获取或初始化。
160    pub fn get_type<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
161        if let Some(scope) = self.deps.get_type_scope::<T>() {
162            match scope {
163                Scope::Singleton => self.deps.get_type(),
164                Scope::RequestScoped => {
165                    if let Some(service) = self.request_scoped_typed_services.get(&TypeId::of::<T>()) {
166                        service
167                            .downcast_ref::<T>()
168                            .cloned()
169                            .ok_or_else(|| WaeError::not_found("Typed request-scoped dependency", std::any::type_name::<T>()))
170                    }
171                    else {
172                        self.deps.get_type()
173                    }
174                }
175            }
176        }
177        else {
178            self.deps.get_type()
179        }
180    }
181
182    /// 设置请求作用域的依赖(按类型)
183    ///
184    /// 仅对 RequestScoped 作用域的依赖有效。
185    pub fn set_type<T: Clone + Send + Sync + 'static>(&mut self, service: T) -> WaeResult<()> {
186        if let Some(Scope::RequestScoped) = self.deps.get_type_scope::<T>() {
187            self.request_scoped_typed_services.insert(TypeId::of::<T>(), Box::new(service));
188            Ok(())
189        }
190        else {
191            Err(WaeError::invalid_params("dependency", "Can only set RequestScoped dependencies"))
192        }
193    }
194
195    /// 获取请求头
196    pub fn header(&self, name: &str) -> Option<&str> {
197        self.parts.headers.get(name).and_then(|v| v.to_str().ok())
198    }
199
200    /// 获取请求 Parts 的引用
201    pub fn parts(&self) -> &Parts {
202        &self.parts
203    }
204
205    /// 按类型获取依赖(便捷方法)
206    pub fn use_type<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
207        self.get_type()
208    }
209
210    /// 获取配置(按类型)
211    pub fn use_config<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
212        self.get_type()
213    }
214
215    /// 获取认证服务(按类型)
216    pub fn use_auth<T: Clone + Send + Sync + 'static>(&self) -> WaeResult<T> {
217        self.get_type()
218    }
219}
220
221/// 代数效应构建器
222///
223/// 用于构建依赖容器并注册各种依赖。
224pub struct AlgebraicEffect {
225    deps: Dependencies,
226}
227
228impl Default for AlgebraicEffect {
229    fn default() -> Self {
230        Self::new()
231    }
232}
233
234impl AlgebraicEffect {
235    /// 创建新的代数效应构建器
236    pub fn new() -> Self {
237        Self { deps: Dependencies::new() }
238    }
239
240    /// 注册服务(按字符串键,默认单例作用域)
241    pub fn with<T: Send + Sync + 'static>(mut self, name: &str, service: T) -> Self {
242        self.deps.register(name, service);
243        self
244    }
245
246    /// 按作用域注册服务(按字符串键)
247    pub fn with_scope<T: Send + Sync + 'static>(mut self, name: &str, service: T, scope: Scope) -> Self {
248        self.deps.register_with_scope(name, service, scope);
249        self
250    }
251
252    /// 按类型注册服务(默认单例作用域)
253    pub fn with_type<T: Clone + Send + Sync + 'static>(mut self, service: T) -> Self {
254        self.deps.register_type(service);
255        self
256    }
257
258    /// 按作用域注册服务(按类型)
259    pub fn with_type_scope<T: Clone + Send + Sync + 'static>(mut self, service: T, scope: Scope) -> Self {
260        self.deps.register_type_with_scope(service, scope);
261        self
262    }
263
264    /// 注册配置服务(按类型)
265    pub fn with_config<T: Clone + Send + Sync + 'static>(mut self, config: T) -> Self {
266        self.deps.register_type(config);
267        self
268    }
269
270    /// 注册配置服务(按类型和作用域)
271    pub fn with_config_scope<T: Clone + Send + Sync + 'static>(mut self, config: T, scope: Scope) -> Self {
272        self.deps.register_type_with_scope(config, scope);
273        self
274    }
275
276    /// 注册认证服务(按类型)
277    pub fn with_auth<T: Clone + Send + Sync + 'static>(mut self, auth: T) -> Self {
278        self.deps.register_type(auth);
279        self
280    }
281
282    /// 注册认证服务(按类型和作用域)
283    pub fn with_auth_scope<T: Clone + Send + Sync + 'static>(mut self, auth: T, scope: Scope) -> Self {
284        self.deps.register_type_with_scope(auth, scope);
285        self
286    }
287
288    /// 构建依赖容器
289    pub fn build(self) -> Arc<Dependencies> {
290        Arc::new(self.deps)
291    }
292}
293
294/// WaeError 的包装类型,用于解决 orphan rule 问题
295pub struct WaeErrorResponse(pub WaeError);
296
297impl WaeErrorResponse {
298    /// 将 WaeError 转换为 http::Response
299    pub fn into_response<B>(self) -> Response<B>
300    where
301        B: From<String>,
302    {
303        let status = self.0.http_status();
304        let body = B::from(self.0.to_string());
305        let status_code = StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
306        Response::builder().status(status_code).body(body).unwrap()
307    }
308}