Skip to main content

barbacane_wasm/
chain.rs

1//! Middleware chain execution.
2//!
3//! Per SPEC-003 section 5, middlewares execute in order for on_request,
4//! and reverse order for on_response. A middleware returning 1 from
5//! on_request short-circuits the chain with an immediate response.
6
7use std::time::Instant;
8
9use serde::{Deserialize, Serialize};
10
11use crate::error::WasmError;
12use crate::instance::{PluginInstance, RequestContext};
13use crate::trap::{TrapContext, TrapResult};
14
/// Optional sink for per-middleware timing metrics.
///
/// When present, the callback is invoked as
/// `callback(middleware_name, phase, duration_secs, short_circuit)`, where
/// `phase` is `"request"` or `"response"`. `None` disables metrics entirely.
pub type MetricsCallback<'a> = Option<&'a (dyn Fn(&str, &str, f64, bool) + 'a)>;
18
/// Outcome of a single middleware's `on_request` call, as decided by
/// `parse_middleware_output`.
///
/// Derives `Clone`/`PartialEq`/`Eq` so outcomes can be duplicated and compared
/// directly (e.g. with `assert_eq!`) instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OnRequestResult {
    /// Continue to the next middleware with these (possibly modified) request bytes.
    Continue(Vec<u8>),
    /// Stop the chain and answer immediately with these response bytes.
    ShortCircuit(Vec<u8>),
}
27
28/// Output format from middleware on_request.
29///
30/// The plugin sets output as JSON with this structure.
31#[derive(Debug, Serialize, Deserialize)]
32struct MiddlewareOutput {
33    /// 0 = continue, 1 = short-circuit
34    action: i32,
35    /// The request (if continue) or response (if short-circuit)
36    data: serde_json::Value,
37}
38
39/// A configured middleware in the chain.
40#[derive(Debug, Clone)]
41pub struct MiddlewareConfig {
42    /// Plugin name.
43    pub name: String,
44    /// Plugin config as JSON.
45    pub config: serde_json::Value,
46}
47
48impl MiddlewareConfig {
49    /// Create a new middleware config.
50    pub fn new(name: impl Into<String>, config: serde_json::Value) -> Self {
51        Self {
52            name: name.into(),
53            config,
54        }
55    }
56}
57
58/// A middleware chain that executes multiple middlewares in sequence.
59pub struct MiddlewareChain {
60    /// Middleware configs in order of execution.
61    configs: Vec<MiddlewareConfig>,
62}
63
64impl MiddlewareChain {
65    /// Create a new empty middleware chain.
66    pub fn new() -> Self {
67        Self {
68            configs: Vec::new(),
69        }
70    }
71
72    /// Create a chain from a list of middleware configs.
73    pub fn from_configs(configs: Vec<MiddlewareConfig>) -> Self {
74        Self { configs }
75    }
76
77    /// Add a middleware to the chain.
78    pub fn push(&mut self, config: MiddlewareConfig) {
79        self.configs.push(config);
80    }
81
82    /// Get the number of middlewares in the chain.
83    pub fn len(&self) -> usize {
84        self.configs.len()
85    }
86
87    /// Check if the chain is empty.
88    pub fn is_empty(&self) -> bool {
89        self.configs.is_empty()
90    }
91
92    /// Get the middleware configs.
93    pub fn configs(&self) -> &[MiddlewareConfig] {
94        &self.configs
95    }
96}
97
98impl Default for MiddlewareChain {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104/// Result of executing the full request chain.
105#[derive(Debug)]
106pub enum ChainResult {
107    /// Chain completed, continue to dispatch.
108    Continue {
109        /// The final request after all middlewares.
110        request: Vec<u8>,
111        /// Context to pass to on_response chain.
112        context: RequestContext,
113    },
114    /// Chain short-circuited with a response.
115    ShortCircuit {
116        /// The response from the short-circuiting middleware.
117        response: Vec<u8>,
118        /// Index of the middleware that short-circuited.
119        middleware_index: usize,
120        /// Context for response chain (partial).
121        context: RequestContext,
122    },
123    /// Chain failed with an error.
124    Error {
125        /// The error that occurred.
126        error: WasmError,
127        /// The trap result for error handling.
128        trap_result: TrapResult,
129    },
130}
131
132/// Execute the on_request chain.
133///
134/// Processes middlewares in order. If any middleware returns 1 (short-circuit),
135/// stops and returns the response. If all middlewares return 0 (continue),
136/// returns the final request for dispatch.
137pub fn execute_on_request(
138    instances: &mut [PluginInstance],
139    initial_request: &[u8],
140    context: RequestContext,
141) -> ChainResult {
142    execute_on_request_with_metrics(instances, initial_request, context, None)
143}
144
145/// Execute the on_request chain with optional metrics recording.
146pub fn execute_on_request_with_metrics(
147    instances: &mut [PluginInstance],
148    initial_request: &[u8],
149    context: RequestContext,
150    metrics_callback: MetricsCallback<'_>,
151) -> ChainResult {
152    let mut current_request = initial_request.to_vec();
153    let mut current_context = context;
154
155    for (index, instance) in instances.iter_mut().enumerate() {
156        // Set context for this middleware
157        instance.set_context(current_context.clone());
158
159        // Record start time
160        let start = Instant::now();
161        let middleware_name = instance.name().to_string();
162
163        // Call on_request
164        match instance.on_request(&current_request) {
165            Ok(result_code) => {
166                let output = instance.take_output();
167
168                // Parse the output to determine action
169                match parse_middleware_output(&output, result_code) {
170                    Ok(OnRequestResult::Continue(new_request)) => {
171                        // Record metrics (not a short-circuit)
172                        if let Some(callback) = metrics_callback {
173                            callback(
174                                &middleware_name,
175                                "request",
176                                start.elapsed().as_secs_f64(),
177                                false,
178                            );
179                        }
180                        current_request = new_request;
181                        // Get context modifications from the middleware
182                        current_context = instance.get_context();
183                    }
184                    Ok(OnRequestResult::ShortCircuit(response)) => {
185                        // Record metrics (short-circuit)
186                        if let Some(callback) = metrics_callback {
187                            callback(
188                                &middleware_name,
189                                "request",
190                                start.elapsed().as_secs_f64(),
191                                true,
192                            );
193                        }
194                        // Get context modifications before short-circuit
195                        let final_context = instance.get_context();
196                        return ChainResult::ShortCircuit {
197                            response,
198                            middleware_index: index,
199                            context: final_context,
200                        };
201                    }
202                    Err(e) => {
203                        // Record metrics for error case
204                        if let Some(callback) = metrics_callback {
205                            callback(
206                                &middleware_name,
207                                "request",
208                                start.elapsed().as_secs_f64(),
209                                false,
210                            );
211                        }
212                        return ChainResult::Error {
213                            trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
214                            error: e,
215                        };
216                    }
217                }
218            }
219            Err(e) => {
220                // Record metrics for error case
221                if let Some(callback) = metrics_callback {
222                    callback(
223                        &middleware_name,
224                        "request",
225                        start.elapsed().as_secs_f64(),
226                        false,
227                    );
228                }
229                return ChainResult::Error {
230                    trap_result: TrapResult::from_error(&e, TrapContext::OnRequest),
231                    error: e,
232                };
233            }
234        }
235    }
236
237    ChainResult::Continue {
238        request: current_request,
239        context: current_context,
240    }
241}
242
243/// Execute the on_response chain.
244///
245/// Processes middlewares in reverse order. If any middleware fails,
246/// logs the error and continues with the original response (fault-tolerant).
247pub fn execute_on_response(
248    instances: &mut [PluginInstance],
249    initial_response: &[u8],
250    context: RequestContext,
251) -> Vec<u8> {
252    execute_on_response_with_metrics(instances, initial_response, context, None)
253}
254
255/// Execute the on_response chain with optional metrics recording.
256pub fn execute_on_response_with_metrics(
257    instances: &mut [PluginInstance],
258    initial_response: &[u8],
259    context: RequestContext,
260    metrics_callback: MetricsCallback<'_>,
261) -> Vec<u8> {
262    let mut current_response = initial_response.to_vec();
263
264    // Process in reverse order
265    for instance in instances.iter_mut().rev() {
266        instance.set_context(context.clone());
267
268        // Record start time
269        let start = Instant::now();
270        let middleware_name = instance.name().to_string();
271
272        match instance.on_response(&current_response) {
273            Ok(_result_code) => {
274                // Record metrics
275                if let Some(callback) = metrics_callback {
276                    callback(
277                        &middleware_name,
278                        "response",
279                        start.elapsed().as_secs_f64(),
280                        false,
281                    );
282                }
283                let output = instance.take_output();
284                if !output.is_empty() {
285                    current_response = output;
286                }
287            }
288            Err(e) => {
289                // Record metrics for error case
290                if let Some(callback) = metrics_callback {
291                    callback(
292                        &middleware_name,
293                        "response",
294                        start.elapsed().as_secs_f64(),
295                        false,
296                    );
297                }
298                // Fault-tolerant: log and continue with current response
299                let trap_result = TrapResult::from_error(&e, TrapContext::OnResponse);
300                tracing::warn!(
301                    error = %trap_result.message(),
302                    "Middleware on_response failed, continuing with original response"
303                );
304            }
305        }
306    }
307
308    current_response
309}
310
311/// Execute on_response for a partial chain (after short-circuit).
312///
313/// Only processes middlewares up to (but not including) the short-circuiting one.
314pub fn execute_on_response_partial(
315    instances: &mut [PluginInstance],
316    response: &[u8],
317    short_circuit_index: usize,
318    context: RequestContext,
319) -> Vec<u8> {
320    if short_circuit_index == 0 {
321        return response.to_vec();
322    }
323
324    let partial_instances = &mut instances[..short_circuit_index];
325    execute_on_response(partial_instances, response, context)
326}
327
328/// Parse middleware output to determine the action.
329pub fn parse_middleware_output(
330    output: &[u8],
331    result_code: i32,
332) -> Result<OnRequestResult, WasmError> {
333    // If no output, use result code as simple continue/short-circuit
334    if output.is_empty() {
335        return if result_code == 0 {
336            Ok(OnRequestResult::Continue(Vec::new()))
337        } else {
338            Err(WasmError::InitFailed(
339                "middleware returned short-circuit without output".into(),
340            ))
341        };
342    }
343
344    // Try to parse as MiddlewareOutput JSON
345    match serde_json::from_slice::<MiddlewareOutput>(output) {
346        Ok(parsed) => {
347            let data = serde_json::to_vec(&parsed.data)
348                .map_err(|e| WasmError::InitFailed(format!("failed to serialize output: {}", e)))?;
349
350            if parsed.action == 0 || result_code == 0 {
351                Ok(OnRequestResult::Continue(data))
352            } else {
353                Ok(OnRequestResult::ShortCircuit(data))
354            }
355        }
356        Err(_) => {
357            // If not structured output, use raw output with result code
358            if result_code == 0 {
359                Ok(OnRequestResult::Continue(output.to_vec()))
360            } else {
361                Ok(OnRequestResult::ShortCircuit(output.to_vec()))
362            }
363        }
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn middleware_config_new() {
        let config = MiddlewareConfig::new("rate-limit", json!({"quota": 100}));
        assert_eq!(config.name, "rate-limit");
        assert_eq!(config.config["quota"], 100);
    }

    #[test]
    fn chain_new_is_empty() {
        let chain = MiddlewareChain::new();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);
    }

    #[test]
    fn chain_push() {
        let mut chain = MiddlewareChain::new();
        chain.push(MiddlewareConfig::new("auth", json!({})));
        chain.push(MiddlewareConfig::new("rate-limit", json!({})));

        assert_eq!(chain.len(), 2);
        assert_eq!(chain.configs()[0].name, "auth");
        assert_eq!(chain.configs()[1].name, "rate-limit");
    }

    #[test]
    fn chain_from_configs() {
        let configs = vec![
            MiddlewareConfig::new("auth", json!({})),
            MiddlewareConfig::new("cors", json!({})),
        ];
        let chain = MiddlewareChain::from_configs(configs);

        assert_eq!(chain.len(), 2);
        // Order must be preserved.
        assert_eq!(chain.configs()[0].name, "auth");
        assert_eq!(chain.configs()[1].name, "cors");
    }

    #[test]
    fn parse_continue_output() {
        let output = serde_json::to_vec(&json!({
            "action": 0,
            "data": {"method": "GET", "path": "/api"}
        }))
        .unwrap();

        match parse_middleware_output(&output, 0).unwrap() {
            OnRequestResult::Continue(data) => {
                let value: serde_json::Value = serde_json::from_slice(&data).unwrap();
                assert_eq!(value, json!({"method": "GET", "path": "/api"}));
            }
            other => panic!("expected Continue, got {:?}", other),
        }
    }

    #[test]
    fn parse_short_circuit_output() {
        let output = serde_json::to_vec(&json!({
            "action": 1,
            "data": {"status": 401, "body": "Unauthorized"}
        }))
        .unwrap();

        match parse_middleware_output(&output, 1).unwrap() {
            OnRequestResult::ShortCircuit(data) => {
                let value: serde_json::Value = serde_json::from_slice(&data).unwrap();
                assert_eq!(value, json!({"status": 401, "body": "Unauthorized"}));
            }
            other => panic!("expected ShortCircuit, got {:?}", other),
        }
    }

    /// A structured `action: 1` only short-circuits when the result code agrees.
    #[test]
    fn parse_structured_short_circuit_needs_nonzero_result_code() {
        let output = serde_json::to_vec(&json!({
            "action": 1,
            "data": {"status": 401}
        }))
        .unwrap();

        let result = parse_middleware_output(&output, 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(_)));
    }

    #[test]
    fn parse_raw_output_continue() {
        let output = b"raw request data";
        let result = parse_middleware_output(output, 0).unwrap();
        assert!(matches!(
            result,
            OnRequestResult::Continue(ref data) if data.as_slice() == &output[..]
        ));
    }

    #[test]
    fn parse_raw_output_short_circuit() {
        let output = b"error response";
        let result = parse_middleware_output(output, 1).unwrap();
        assert!(matches!(
            result,
            OnRequestResult::ShortCircuit(ref data) if data.as_slice() == &output[..]
        ));
    }

    #[test]
    fn parse_empty_continue() {
        let result = parse_middleware_output(&[], 0).unwrap();
        assert!(matches!(result, OnRequestResult::Continue(data) if data.is_empty()));
    }

    #[test]
    fn parse_empty_short_circuit_fails() {
        let result = parse_middleware_output(&[], 1);
        assert!(result.is_err());
    }

    /// Middleware output with Request metadata (body travels via side-channel).
    /// Ensures metadata survives the parse → serialize cycle in parse_middleware_output.
    #[test]
    fn parse_continue_with_request_metadata() {
        use barbacane_plugin_sdk::types::Request;
        use std::collections::BTreeMap;

        let req = Request {
            method: "POST".into(),
            path: "/upload".into(),
            query: None,
            headers: {
                let mut h = BTreeMap::new();
                h.insert("content-type".into(), "application/octet-stream".into());
                h
            },
            body: None, // Body travels via side-channel, not in JSON
            client_ip: "127.0.0.1".into(),
            path_params: BTreeMap::new(),
        };

        // Build the structured output the macro produces
        let output = serde_json::to_vec(&json!({
            "action": 0,
            "data": req
        }))
        .unwrap();

        match parse_middleware_output(&output, 0).unwrap() {
            OnRequestResult::Continue(data) => {
                let parsed: Request = serde_json::from_slice(&data).unwrap();
                assert_eq!(parsed.method, "POST");
                assert_eq!(parsed.path, "/upload");
                assert_eq!(parsed.body, None); // body is serde(skip)
            }
            other => panic!("expected Continue, got {:?}", other),
        }
    }

    /// Middleware short-circuit Response metadata (body travels via side-channel).
    #[test]
    fn parse_short_circuit_with_response_metadata() {
        use barbacane_plugin_sdk::types::Response;
        use std::collections::BTreeMap;

        let resp = Response {
            status: 403,
            headers: {
                let mut h = BTreeMap::new();
                h.insert("content-type".into(), "application/json".into());
                h
            },
            body: None, // Body travels via side-channel
        };

        let output = serde_json::to_vec(&json!({
            "action": 1,
            "data": resp
        }))
        .unwrap();

        match parse_middleware_output(&output, 1).unwrap() {
            OnRequestResult::ShortCircuit(data) => {
                let parsed: Response = serde_json::from_slice(&data).unwrap();
                assert_eq!(parsed.status, 403);
                assert_eq!(parsed.body, None); // body is serde(skip)
            }
            other => panic!("expected ShortCircuit, got {:?}", other),
        }
    }

    #[test]
    fn metrics_callback_type_accepts_closure() {
        use std::cell::RefCell;
        use std::rc::Rc;

        // Verify the callback type works with a recording closure
        let invocations = Rc::new(RefCell::new(Vec::new()));
        let invocations_clone = invocations.clone();

        let callback = move |name: &str, phase: &str, duration: f64, short_circuit: bool| {
            invocations_clone.borrow_mut().push((
                name.to_string(),
                phase.to_string(),
                duration,
                short_circuit,
            ));
        };

        // Verify the callback can be used as MetricsCallback
        let metrics_callback: MetricsCallback<'_> = Some(&callback);
        assert!(metrics_callback.is_some());

        if let Some(cb) = metrics_callback {
            cb("test-middleware", "request", 0.001, false);
            cb("test-middleware", "response", 0.002, true);
        }

        // Verify invocations were recorded with their arguments intact
        let recorded = invocations.borrow();
        assert_eq!(recorded.len(), 2);
        assert_eq!(recorded[0].0, "test-middleware");
        assert_eq!(recorded[0].1, "request");
        assert_eq!(recorded[0].2, 0.001);
        assert!(!recorded[0].3); // not short-circuit
        assert_eq!(recorded[1].1, "response");
        assert_eq!(recorded[1].2, 0.002);
        assert!(recorded[1].3); // short-circuit
    }

    #[test]
    fn execute_on_request_empty_instances_returns_continue() {
        let mut instances: Vec<PluginInstance> = vec![];
        let request = b"test request";
        let context = RequestContext::default();

        match execute_on_request(&mut instances, request, context) {
            ChainResult::Continue { request: req, .. } => assert_eq!(req, request.to_vec()),
            other => panic!("expected Continue, got {:?}", other),
        }
    }

    #[test]
    fn execute_on_response_empty_instances_returns_input() {
        let mut instances: Vec<PluginInstance> = vec![];
        let response = b"test response";
        let context = RequestContext::default();

        let result = execute_on_response(&mut instances, response, context);
        assert_eq!(result, response.to_vec());
    }

    #[test]
    fn execute_with_metrics_none_callback_works() {
        let mut instances: Vec<PluginInstance> = vec![];
        let request = b"test";
        let context = RequestContext::default();

        // Verify None callback doesn't cause issues
        match execute_on_request_with_metrics(&mut instances, request, context.clone(), None) {
            ChainResult::Continue { request: req, .. } => assert_eq!(req, request.to_vec()),
            other => panic!("expected Continue, got {:?}", other),
        }

        let response = execute_on_response_with_metrics(&mut instances, request, context, None);
        assert_eq!(response, request.to_vec());
    }
}