use async_trait::async_trait;
use serde_json::{Map, Value};
use crate::core::exceptions::OperonError;
/// Per-invocation context handed (by reference) to every [`Middleware`] hook.
///
/// All fields are public and default to empty values (`String::new()` /
/// an empty [`Map`]) via the derived [`Default`] impl.
#[derive(Debug)]
#[derive(Clone, Default)]
pub struct MiddlewareContext {
    /// User identifier; free-form string, empty by default.
    pub user_id: String,
    /// Session identifier; free-form string, empty by default.
    pub session_id: String,
    /// Request identifier; free-form string, empty by default.
    pub request_id: String,
    /// Arbitrary additional JSON key/value pairs, empty by default.
    pub extra: Map<String, Value>,
}
/// Async hooks that can inspect or transform a run's inputs, its result, or
/// an error it produced.
///
/// Every hook ships with a pass-through default, so an implementor only
/// overrides the hooks it cares about. The `Send + Sync` supertraits let
/// implementations be shared across tasks and keep the `async_trait`-boxed
/// futures `Send`.
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Pre-run hook: takes ownership of the inputs and returns the inputs to
    /// use in their place.
    ///
    /// The default implementation hands `inputs` back unchanged.
    ///
    /// # Errors
    ///
    /// The default never fails; overrides may return an [`OperonError`].
    async fn before_run(
        &self,
        inputs: Map<String, Value>,
        _context: &MiddlewareContext,
    ) -> Result<Map<String, Value>, OperonError> {
        Ok(inputs)
    }

    /// Post-run hook: sees the inputs by reference, takes ownership of the
    /// result, and returns the result to use in its place.
    ///
    /// The default implementation hands `result` back unchanged.
    ///
    /// # Errors
    ///
    /// The default never fails; overrides may return an [`OperonError`].
    async fn after_run(
        &self,
        _run_inputs: &Map<String, Value>,
        result: Map<String, Value>,
        _context: &MiddlewareContext,
    ) -> Result<Map<String, Value>, OperonError> {
        Ok(result)
    }

    /// Error hook: receives the inputs by reference and takes ownership of
    /// the error.
    ///
    /// What an `Ok(())` return means (e.g. whether the error counts as
    /// handled) is decided by the caller, which is not part of this file.
    ///
    /// # Errors
    ///
    /// The default implementation returns the received `error` unchanged as
    /// `Err(error)`.
    async fn on_error(
        &self,
        _run_inputs: &Map<String, Value>,
        error: OperonError,
        _context: &MiddlewareContext,
    ) -> Result<(), OperonError> {
        Err(error)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Overrides only `before_run`, adding a `"tagged": true` entry.
    struct TagInput;

    #[async_trait]
    impl Middleware for TagInput {
        async fn before_run(
            &self,
            inputs: Map<String, Value>,
            _ctx: &MiddlewareContext,
        ) -> Result<Map<String, Value>, OperonError> {
            let mut tagged = inputs;
            tagged.insert("tagged".to_string(), Value::Bool(true));
            Ok(tagged)
        }
    }

    /// Relies entirely on the trait's default hooks.
    struct Noop;

    #[async_trait]
    impl Middleware for Noop {}

    #[tokio::test]
    async fn default_hooks_passthrough() {
        let ctx = MiddlewareContext::default();

        let mut inputs = Map::new();
        inputs.insert("a".to_string(), Value::from(1));
        let passed_inputs = Noop.before_run(inputs.clone(), &ctx).await.unwrap();
        assert_eq!(passed_inputs, inputs);

        let mut result = Map::new();
        result.insert("b".to_string(), Value::from(2));
        let passed_result = Noop
            .after_run(&inputs, result.clone(), &ctx)
            .await
            .unwrap();
        assert_eq!(passed_result, result);
    }

    #[tokio::test]
    async fn before_run_can_mutate_inputs() {
        let tagged = TagInput
            .before_run(Map::new(), &MiddlewareContext::default())
            .await
            .unwrap();
        assert_eq!(tagged.get("tagged"), Some(&Value::Bool(true)));
    }
}