Skip to main content

apcore/middleware/
base.rs

// APCore Protocol — Middleware base trait
// Spec reference: Middleware lifecycle (before, after, on_error)
3
4use async_trait::async_trait;
5
6use crate::context::Context;
7use crate::errors::ModuleError;
8
9/// Return value from `Middleware::on_error_outcome` requesting a retry.
10///
11/// Distinct from a plain `Recovery(value)` — `Recovery` is the *final
12/// recovery output* of the call. `RetrySignal` instead asks the executor
13/// to re-run the module with `inputs`; no recovery output is produced.
14///
15/// Cross-language parity with apcore-python `apcore.middleware.RetrySignal`
16/// and apcore-typescript `apcore-js.RetrySignal` (sync finding A-D-017).
17#[derive(Debug, Clone)]
18pub struct RetrySignal {
19    pub inputs: serde_json::Value,
20}
21
22impl RetrySignal {
23    #[must_use]
24    pub fn new(inputs: serde_json::Value) -> Self {
25        Self { inputs }
26    }
27}
28
29/// Outcome of a middleware's `on_error_outcome` hook.
30///
31/// - `Recovery(value)` — the middleware produced a recovery output; the
32///   executor returns this value to the caller and skips the rest of the
33///   error path.
34/// - `Retry(signal)` — the middleware asks for a pipeline retry with new
35///   inputs (only honored by the unary `Executor::call` path; ignored for
36///   streaming, where mid-flight retry is not well-defined).
37#[derive(Debug, Clone)]
38pub enum OnErrorOutcome {
39    Recovery(serde_json::Value),
40    Retry(RetrySignal),
41}
42
43/// Core middleware trait with `before/after/on_error` hooks.
44///
45/// All hooks return `Option<Value>`:
46/// - `Some(value)` means the middleware modified the input/output/recovery value.
47/// - `None` means "no modification" — the pipeline keeps the previous value.
48///
49/// `on_error` returns `Option<Value>` where `Some(value)` signals a recovery
50/// (the pipeline should retry with the returned inputs) and `None` means
51/// the error should propagate.
52#[async_trait]
53pub trait Middleware: Send + Sync + std::fmt::Debug {
54    /// Name of this middleware for logging/debugging.
55    fn name(&self) -> &str;
56
57    /// Priority of this middleware (higher runs first). Default is 100.
58    /// Valid range: 0-1000 (enforced by `MiddlewareManager::add`).
59    /// When two middlewares have the same priority, registration order is preserved.
60    fn priority(&self) -> u16 {
61        100
62    }
63
64    /// Called before module execution. Can modify input.
65    /// Return `Ok(None)` to pass through unchanged, `Ok(Some(v))` to modify.
66    async fn before(
67        &self,
68        module_id: &str,
69        inputs: serde_json::Value,
70        ctx: &Context<serde_json::Value>,
71    ) -> Result<Option<serde_json::Value>, ModuleError>;
72
73    /// Called after successful module execution. Can modify output.
74    /// `inputs` is the original (post-before) input for correlation.
75    /// Return `Ok(None)` to pass through unchanged, `Ok(Some(v))` to modify.
76    async fn after(
77        &self,
78        module_id: &str,
79        inputs: serde_json::Value,
80        output: serde_json::Value,
81        ctx: &Context<serde_json::Value>,
82    ) -> Result<Option<serde_json::Value>, ModuleError>;
83
84    /// Called when module execution fails.
85    /// `inputs` is the original (post-before) input for correlation.
86    /// Return `Ok(Some(v))` to signal a recovery output, or `Ok(None)` to
87    /// let the error propagate.
88    ///
89    /// To request a pipeline retry instead of a recovery, override
90    /// [`Self::on_error_outcome`] and return `OnErrorOutcome::Retry(...)`.
91    async fn on_error(
92        &self,
93        module_id: &str,
94        inputs: serde_json::Value,
95        error: &ModuleError,
96        ctx: &Context<serde_json::Value>,
97    ) -> Result<Option<serde_json::Value>, ModuleError>;
98
99    /// Extended on_error hook that can request a pipeline retry via
100    /// [`OnErrorOutcome::Retry`] in addition to producing a recovery output.
101    ///
102    /// Default implementation delegates to [`Self::on_error`] and wraps any
103    /// returned value as `OnErrorOutcome::Recovery` — existing middlewares
104    /// work unchanged. Override this method to opt into retry semantics
105    /// (cross-language parity with apcore-python and apcore-typescript
106    /// `Middleware.on_error` returning `RetrySignal`; sync finding A-D-017).
107    async fn on_error_outcome(
108        &self,
109        module_id: &str,
110        inputs: serde_json::Value,
111        error: &ModuleError,
112        ctx: &Context<serde_json::Value>,
113    ) -> Result<Option<OnErrorOutcome>, ModuleError> {
114        Ok(self
115            .on_error(module_id, inputs, error, ctx)
116            .await?
117            .map(OnErrorOutcome::Recovery))
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;
    use serde_json::json;

    /// Pass-through middleware: every hook returns `Ok(None)`.
    #[derive(Debug)]
    struct TestMiddleware {
        name: String,
        prio: u16,
    }

    impl TestMiddleware {
        fn new(name: &str, prio: u16) -> Self {
            Self {
                name: name.to_string(),
                prio,
            }
        }
    }

    #[async_trait]
    impl Middleware for TestMiddleware {
        fn name(&self) -> &str {
            &self.name
        }

        fn priority(&self) -> u16 {
            self.prio
        }

        async fn before(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn after(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _output: serde_json::Value,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn on_error(
            &self,
            _module_id: &str,
            _inputs: serde_json::Value,
            _error: &ModuleError,
            _ctx: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }
    }

    /// Middleware whose `on_error` echoes its `inputs` back as a recovery
    /// value, so tests can check that `on_error_outcome` forwards `inputs`.
    #[derive(Debug)]
    struct EchoRecovery;

    #[async_trait]
    impl Middleware for EchoRecovery {
        fn name(&self) -> &'static str {
            "echo_recovery"
        }

        async fn before(
            &self,
            _: &str,
            _: serde_json::Value,
            _: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn after(
            &self,
            _: &str,
            _: serde_json::Value,
            _: serde_json::Value,
            _: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(None)
        }

        async fn on_error(
            &self,
            _: &str,
            inputs: serde_json::Value,
            _: &ModuleError,
            _: &Context<serde_json::Value>,
        ) -> Result<Option<serde_json::Value>, ModuleError> {
            Ok(Some(inputs))
        }
    }

    #[test]
    fn test_middleware_default_priority() {
        #[derive(Debug)]
        struct DefaultPrio;

        #[async_trait]
        impl Middleware for DefaultPrio {
            fn name(&self) -> &'static str {
                "default"
            }
            async fn before(
                &self,
                _: &str,
                _: serde_json::Value,
                _: &Context<serde_json::Value>,
            ) -> Result<Option<serde_json::Value>, ModuleError> {
                Ok(None)
            }
            async fn after(
                &self,
                _: &str,
                _: serde_json::Value,
                _: serde_json::Value,
                _: &Context<serde_json::Value>,
            ) -> Result<Option<serde_json::Value>, ModuleError> {
                Ok(None)
            }
            async fn on_error(
                &self,
                _: &str,
                _: serde_json::Value,
                _: &ModuleError,
                _: &Context<serde_json::Value>,
            ) -> Result<Option<serde_json::Value>, ModuleError> {
                Ok(None)
            }
        }

        let mw = DefaultPrio;
        assert_eq!(mw.priority(), 100);
    }

    #[test]
    fn test_middleware_custom_priority() {
        let mw = TestMiddleware::new("high_priority", 500);
        assert_eq!(mw.priority(), 500);
        assert_eq!(mw.name(), "high_priority");
    }

    #[test]
    fn test_retry_signal_new_stores_inputs() {
        let signal = RetrySignal::new(json!({"x": 2}));
        assert_eq!(signal.inputs, json!({"x": 2}));
    }

    #[tokio::test]
    async fn test_middleware_before_returns_none() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let result = mw.before("mod.a", json!({"x": 1}), &ctx).await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_middleware_after_returns_none() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let result = mw
            .after("mod.a", json!({}), json!({"result": true}), &ctx)
            .await
            .unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_middleware_on_error_returns_none() {
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let result = mw.on_error("mod.a", json!({}), &err, &ctx).await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn test_on_error_outcome_default_passes_none_through() {
        // `on_error` returning `None` must map to "no outcome" (error propagates).
        let mw = TestMiddleware::new("test", 100);
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let outcome = mw
            .on_error_outcome("mod.a", json!({}), &err, &ctx)
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn test_on_error_outcome_default_wraps_recovery() {
        // `on_error` returning `Some(v)` must map to `Recovery(v)`, never `Retry`.
        let mw = EchoRecovery;
        let ctx = Context::<serde_json::Value>::anonymous();
        let err = ModuleError::new(ErrorCode::ModuleExecuteError, "boom");
        let outcome = mw
            .on_error_outcome("mod.a", json!({"x": 1}), &err, &ctx)
            .await
            .unwrap();
        match outcome {
            Some(OnErrorOutcome::Recovery(value)) => assert_eq!(value, json!({"x": 1})),
            other => panic!("expected Recovery, got {other:?}"),
        }
    }
}
258}