1use async_trait::async_trait;
5
6use crate::context::Context;
7use crate::errors::ModuleError;
8
9#[derive(Debug, Clone)]
18pub struct RetrySignal {
19 pub inputs: serde_json::Value,
20}
21
22impl RetrySignal {
23 #[must_use]
24 pub fn new(inputs: serde_json::Value) -> Self {
25 Self { inputs }
26 }
27}
28
29#[derive(Debug, Clone)]
38pub enum OnErrorOutcome {
39 Recovery(serde_json::Value),
40 Retry(RetrySignal),
41}
42
43#[async_trait]
53pub trait Middleware: Send + Sync + std::fmt::Debug {
54 fn name(&self) -> &str;
56
57 fn priority(&self) -> u16 {
61 100
62 }
63
64 async fn before(
67 &self,
68 module_id: &str,
69 inputs: serde_json::Value,
70 ctx: &Context<serde_json::Value>,
71 ) -> Result<Option<serde_json::Value>, ModuleError>;
72
73 async fn after(
77 &self,
78 module_id: &str,
79 inputs: serde_json::Value,
80 output: serde_json::Value,
81 ctx: &Context<serde_json::Value>,
82 ) -> Result<Option<serde_json::Value>, ModuleError>;
83
84 async fn on_error(
92 &self,
93 module_id: &str,
94 inputs: serde_json::Value,
95 error: &ModuleError,
96 ctx: &Context<serde_json::Value>,
97 ) -> Result<Option<serde_json::Value>, ModuleError>;
98
99 async fn on_error_outcome(
108 &self,
109 module_id: &str,
110 inputs: serde_json::Value,
111 error: &ModuleError,
112 ctx: &Context<serde_json::Value>,
113 ) -> Result<Option<OnErrorOutcome>, ModuleError> {
114 Ok(self
115 .on_error(module_id, inputs, error, ctx)
116 .await?
117 .map(OnErrorOutcome::Recovery))
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;
    use serde_json::json;

    /// Middleware whose hooks all decline (`Ok(None)`), with a configurable
    /// name and priority.
    #[derive(Debug)]
    struct TestMiddleware {
        name: String,
        prio: u16,
    }

    impl TestMiddleware {
        fn new(name: &str, prio: u16) -> Self {
            Self {
                name: name.to_string(),
                prio,
            }
        }
    }

    #[async_trait]
    impl Middleware for TestMiddleware {
        fn name(&self) -> &str {
            &self.name
        }

        fn priority(&self) -> u16 {
            self.prio
        }

        async fn before(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn after(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _output: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn on_error(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _error: &ModuleError,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }
    }

    /// Middleware whose `on_error` either recovers with a fixed value or
    /// fails, used to exercise the default `on_error_outcome`.
    #[derive(Debug)]
    struct OnErrorStub {
        fail: bool,
    }

    #[async_trait]
    impl Middleware for OnErrorStub {
        fn name(&self) -> &str {
            "on_error_stub"
        }

        async fn before(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn after(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _output: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn on_error(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _error: &ModuleError,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            if self.fail {
                Err(ModuleError::new(ErrorCode::ModuleExecuteError, "hook failed"))
            } else {
                Ok(Some(json!({"recovered": true})))
            }
        }
    }

    #[test]
    fn test_middleware_default_priority() {
        #[derive(Debug)]
        struct DefaultPrio;

        #[async_trait]
        impl Middleware for DefaultPrio {
            fn name(&self) -> &'static str {
                "default"
            }
            async fn before(
                &self,
                _: &str,
                _: serde_json::Value,
                _: &Context<serde_json::Value>,
            ) -> Result<Option<serde_json::Value>, ModuleError> {
                Ok(None)
            }
            async fn after(
                &self,
                _: &str,
                _: serde_json::Value,
                _: serde_json::Value,
                _: &Context<serde_json::Value>,
            ) -> Result<Option<serde_json::Value>, ModuleError> {
                Ok(None)
            }
            async fn on_error(
                &self,
                _: &str,
                _: serde_json::Value,
                _: &ModuleError,
                _: &Context<serde_json::Value>,
            ) -> Result<Option<serde_json::Value>, ModuleError> {
                Ok(None)
            }
        }

        let mw = DefaultPrio;
        assert_eq!(mw.priority(), 100);
    }

    #[test]
    fn test_middleware_custom_priority() {
        let mw = TestMiddleware::new("high_priority", 500);
        assert_eq!(mw.priority(), 500);
        assert_eq!(mw.name(), "high_priority");
    }

    #[tokio::test]
    async fn test_middleware_before_returns_none() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let result = mw.before("mod.a", json!({"x": 1}), &ctx).await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_middleware_after_returns_none() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let result = mw
            .after("mod.a", json!({}), json!({"result": true}), &ctx)
            .await
            .unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_middleware_on_error_returns_none() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let result = mw.on_error("mod.a", json!({}), &err, &ctx).await.unwrap();
        assert_eq!(result, None);
    }

    // The default `on_error_outcome` must pass a declined (`None`) on_error
    // through unchanged.
    #[tokio::test]
    async fn test_on_error_outcome_default_passes_none_through() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let outcome = mw
            .on_error_outcome("mod.a", json!({}), &err, &ctx)
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    // The default `on_error_outcome` must wrap a recovery value in
    // `OnErrorOutcome::Recovery`.
    #[tokio::test]
    async fn test_on_error_outcome_default_wraps_recovery() {
        let mw = OnErrorStub { fail: false };
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let outcome = mw
            .on_error_outcome("mod.a", json!({}), &err, &ctx)
            .await
            .unwrap();
        match outcome {
            Some(OnErrorOutcome::Recovery(value)) => {
                assert_eq!(value, json!({"recovered": true}));
            }
            other => panic!("expected Some(Recovery(..)), got {:?}", other),
        }
    }

    // The default `on_error_outcome` must propagate an error raised by
    // `on_error` instead of swallowing it.
    #[tokio::test]
    async fn test_on_error_outcome_default_propagates_hook_error() {
        let mw = OnErrorStub { fail: true };
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let result = mw.on_error_outcome("mod.a", json!({}), &err, &ctx).await;
        assert!(result.is_err());
    }
}