Skip to main content

shaperail_runtime/handlers/
controller.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use shaperail_core::{ShaperailError, WASM_HOOK_PREFIX};
7
8use crate::auth::extractor::AuthenticatedUser;
9#[cfg(feature = "wasm-plugins")]
10use crate::plugins::{PluginContext, PluginUser, WasmRuntime};
11
/// Context passed to controller functions for synchronous in-request business logic.
///
/// Controllers receive a mutable `Context` and can:
/// - Modify `input` before the DB operation (before-controllers)
/// - Read/modify `data` after the DB operation (after-controllers)
/// - Return `Err(...)` to halt the request with an error response
///
/// `data` also acts as the phase marker: the WASM branch of `dispatch_controller`
/// picks `"before_hook"` when `data` is `None` and `"after_hook"` otherwise, so
/// whoever builds a `Context` must keep that invariant.
///
/// # Example
/// ```rust,ignore
/// pub async fn validate_org(ctx: &mut Context) -> Result<(), ShaperailError> {
///     if let Some(email) = ctx.input.get("email").and_then(|v| v.as_str()) {
///         ctx.input["email"] = serde_json::json!(email.to_lowercase());
///     }
///     Ok(())
/// }
/// ```
pub struct Context {
    /// Mutable request input as a JSON object map. Before-controllers can modify
    /// what gets written to the DB.
    pub input: serde_json::Map<String, serde_json::Value>,
    /// DB result data. `None` in before-controllers, `Some(...)` in after-controllers.
    pub data: Option<serde_json::Value>,
    /// The authenticated user, if present.
    pub user: Option<AuthenticatedUser>,
    /// Database pool for custom queries within the controller.
    pub pool: sqlx::PgPool,
    /// Request headers. Intended to be read-only; the field is `pub`, so this is
    /// a convention rather than something the type enforces.
    pub headers: HashMap<String, String>,
    /// Extra `(name, value)` response headers the controller wants to add.
    /// NOTE(review): presumably appended to the HTTP response by the calling
    /// handler; that code is not in this file, so confirm there.
    pub response_headers: Vec<(String, String)>,
    /// The tenant ID extracted from the authenticated user (M18).
    /// Present when the resource has `tenant_key` and the user has a `tenant_id` claim.
    pub tenant_id: Option<String>,
}
45
/// Result type returned by every controller function.
///
/// `Ok(())` lets the request continue. Any `Err(...)` halts the request and is
/// turned into the error response.
pub type ControllerResult = Result<(), ShaperailError>;
48
/// Trait for controller functions that can be stored in the registry.
///
/// The trait is object-safe: `call` returns a boxed, pinned, `Send` future that
/// borrows both the handler and the [`Context`] for `'a`. This is what allows
/// controllers with different concrete future types to be stored together as
/// `Arc<dyn ControllerHandler>` inside [`ControllerMap`].
pub trait ControllerHandler: Send + Sync {
    /// Runs the controller against `ctx`, resolving to a [`ControllerResult`].
    fn call<'a>(
        &'a self,
        ctx: &'a mut Context,
    ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>>;
}
56
/// Blanket implementation for named async functions that take `&mut Context`.
///
/// Any `F` that implements [`AsyncControllerFn`] for *every* borrow lifetime
/// `'a` (the `for<'a>` bound), and is `Send + Sync`, becomes a
/// [`ControllerHandler`]. Its concrete future is boxed and pinned so it can be
/// returned through the object-safe trait method.
///
/// Use named async functions (not closures) for controller registration.
/// NOTE(review): closures generally fail the higher-ranked `for<'a>` bound,
/// because their returned future cannot be made generic over the lifetime of
/// the `&mut Context` argument.
///
/// ```rust,ignore
/// async fn normalize_email(ctx: &mut Context) -> ControllerResult {
///     // modify ctx.input...
///     Ok(())
/// }
/// map.register("users", "normalize_email", normalize_email);
/// ```
impl<F> ControllerHandler for F
where
    F: for<'a> AsyncControllerFn<'a> + Send + Sync,
{
    fn call<'a>(
        &'a self,
        ctx: &'a mut Context,
    ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>> {
        // Boxing erases the concrete `<F as AsyncControllerFn<'a>>::Fut` type so
        // the result fits the trait-object return signature.
        Box::pin(self.call_async(ctx))
    }
}
79
/// Helper trait to express the async function signature with proper lifetimes.
///
/// The associated `Fut` type names the future a controller returns for one
/// specific borrow lifetime `'a`. The `+ 'a` bound lets that future borrow the
/// `&'a mut Context`. Because the lifetime is a trait parameter, the blanket
/// [`ControllerHandler`] impl can require `for<'a> AsyncControllerFn<'a>`,
/// meaning "works for every borrow lifetime".
pub trait AsyncControllerFn<'a> {
    /// Future produced by the controller for a borrow of lifetime `'a`.
    type Fut: Future<Output = ControllerResult> + Send + 'a;
    /// Invokes the controller, returning its (not yet polled) future.
    fn call_async(&self, ctx: &'a mut Context) -> Self::Fut;
}
85
/// Implements [`AsyncControllerFn`] for any `Fn(&'a mut Context) -> Fut` whose
/// future is `Send`, may borrow for `'a`, and resolves to a [`ControllerResult`].
///
/// `async fn name(ctx: &mut Context) -> ControllerResult` items satisfy this
/// bound for every `'a`, which feeds the blanket `ControllerHandler` impl.
impl<'a, F, Fut> AsyncControllerFn<'a> for F
where
    F: Fn(&'a mut Context) -> Fut + Send + Sync,
    Fut: Future<Output = ControllerResult> + Send + 'a,
{
    type Fut = Fut;
    fn call_async(&self, ctx: &'a mut Context) -> Self::Fut {
        // Calling the function only creates the future; nothing runs until it is awaited.
        (self)(ctx)
    }
}
96
97/// Registry that maps (resource_name, function_name) to controller functions.
98///
99/// Follows the same pattern as `StoreRegistry` — generated code populates this
100/// at startup, and handlers look up controllers by name at request time.
101pub struct ControllerMap {
102    fns: HashMap<(String, String), Arc<dyn ControllerHandler>>,
103}
104
105impl ControllerMap {
106    /// Creates an empty controller registry.
107    pub fn new() -> Self {
108        Self {
109            fns: HashMap::new(),
110        }
111    }
112
113    /// Registers a controller function for a resource.
114    pub fn register<F>(&mut self, resource: &str, name: &str, f: F)
115    where
116        F: ControllerHandler + 'static,
117    {
118        self.fns
119            .insert((resource.to_string(), name.to_string()), Arc::new(f));
120    }
121
122    /// Calls a controller function by resource and name.
123    ///
124    /// Returns `Ok(())` if no controller is registered for this (resource, name) pair.
125    pub async fn call(&self, resource: &str, name: &str, ctx: &mut Context) -> ControllerResult {
126        if let Some(f) = self.fns.get(&(resource.to_string(), name.to_string())) {
127            f.call(ctx).await
128        } else {
129            Err(ShaperailError::Internal(format!(
130                "Controller '{name}' not found for resource '{resource}'. \
131                 Ensure the function exists in resources/{resource}.controller.rs"
132            )))
133        }
134    }
135
136    /// Returns true if a controller is registered for this (resource, name) pair.
137    pub fn has(&self, resource: &str, name: &str) -> bool {
138        self.fns
139            .contains_key(&(resource.to_string(), name.to_string()))
140    }
141}
142
143impl Default for ControllerMap {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
/// Routes a controller invocation to either a WASM plugin or a Rust function.
///
/// A `name` beginning with `WASM_HOOK_PREFIX` is run by `wasm_runtime`, using
/// the remainder of the name as the plugin path. Every other name is resolved
/// against the registered Rust `controllers`.
///
/// # Errors
///
/// - `ShaperailError::Internal` if a WASM controller is named but no runtime is
///   configured, or if a Rust controller cannot be found.
/// - `ShaperailError::Validation` (field `wasm_plugin`) if the plugin reports failure.
/// - Any error propagated from the WASM runtime or the Rust controller.
#[cfg(feature = "wasm-plugins")]
pub async fn dispatch_controller(
    name: &str,
    resource: &str,
    ctx: &mut Context,
    controllers: Option<&ControllerMap>,
    wasm_runtime: Option<&WasmRuntime>,
) -> ControllerResult {
    match name.strip_prefix(WASM_HOOK_PREFIX) {
        Some(plugin_path) => dispatch_wasm_controller(plugin_path, ctx, wasm_runtime).await,
        None => dispatch_rust_controller(name, resource, ctx, controllers).await,
    }
}

/// Runs one WASM plugin hook against `ctx` and writes its modifications back.
#[cfg(feature = "wasm-plugins")]
async fn dispatch_wasm_controller(
    plugin_path: &str,
    ctx: &mut Context,
    wasm_runtime: Option<&WasmRuntime>,
) -> ControllerResult {
    let runtime = match wasm_runtime {
        Some(rt) => rt,
        None => {
            return Err(ShaperailError::Internal(
                "WASM plugin declared but no WasmRuntime configured".to_string(),
            ))
        }
    };

    // The phase is inferred from `ctx.data`: callers leave it `None` before the
    // DB operation and fill it in afterwards.
    let hook_name = if ctx.data.is_some() {
        "after_hook"
    } else {
        "before_hook"
    };

    // The plugin receives a snapshot; it cannot mutate `ctx` directly.
    let plugin_user = ctx.user.as_ref().map(|u| PluginUser {
        id: u.id.to_string(),
        role: u.role.clone(),
    });
    let plugin_ctx = PluginContext {
        input: ctx.input.clone(),
        data: ctx.data.clone(),
        user: plugin_user,
        headers: ctx.headers.clone(),
        tenant_id: ctx.tenant_id.clone(),
    };

    let outcome = runtime.call_hook(plugin_path, hook_name, &plugin_ctx).await?;

    if !outcome.ok {
        let message = outcome
            .error
            .unwrap_or_else(|| "WASM plugin returned error".to_string());
        return Err(ShaperailError::Validation(vec![shaperail_core::FieldError {
            field: "wasm_plugin".to_string(),
            message,
            code: "wasm_error".to_string(),
        }]));
    }

    // Only `input` and a non-`None` `data` flow back; user, headers and tenant
    // values returned by the plugin are ignored.
    if let Some(returned) = outcome.ctx {
        ctx.input = returned.input;
        if let Some(updated_data) = returned.data {
            ctx.data = Some(updated_data);
        }
    }

    Ok(())
}
216
217/// Dispatches a controller call (Rust controllers only, WASM disabled).
218///
219/// Any `wasm:` prefix controllers return an error explaining the feature is not enabled.
220#[cfg(not(feature = "wasm-plugins"))]
221pub async fn dispatch_controller(
222    name: &str,
223    resource: &str,
224    ctx: &mut Context,
225    controllers: Option<&ControllerMap>,
226    _wasm_runtime: Option<&()>,
227) -> ControllerResult {
228    if name.starts_with(WASM_HOOK_PREFIX) {
229        return Err(ShaperailError::Internal(
230            "WASM plugin declared but the 'wasm-plugins' feature is not enabled. \
231             Add `features = [\"wasm-plugins\"]` to your shaperail-runtime dependency."
232                .to_string(),
233        ));
234    }
235    dispatch_rust_controller(name, resource, ctx, controllers).await
236}
237
238/// Shared Rust controller dispatch used by both feature-gated variants.
239async fn dispatch_rust_controller(
240    name: &str,
241    resource: &str,
242    ctx: &mut Context,
243    controllers: Option<&ControllerMap>,
244) -> ControllerResult {
245    let map = controllers.ok_or_else(|| {
246        ShaperailError::Internal(format!(
247            "Controller '{name}' declared for '{resource}' but no ControllerMap configured"
248        ))
249    })?;
250    map.call(resource, name, ctx).await
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Before-controller used in tests: lower-cases `input["email"]` when it is a string.
    async fn normalize_email(ctx: &mut Context) -> ControllerResult {
        let lowered = ctx
            .input
            .get("email")
            .and_then(serde_json::Value::as_str)
            .map(str::to_lowercase);
        if let Some(email) = lowered {
            ctx.input["email"] = serde_json::json!(email);
        }
        Ok(())
    }

    /// Controller that does nothing; used to exercise registration only.
    async fn noop(_ctx: &mut Context) -> ControllerResult {
        Ok(())
    }

    /// Lazily-connecting pool; no database connection is opened by these tests.
    fn test_pool() -> sqlx::PgPool {
        sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap()
    }

    /// Builds a before-phase context (no data, user, headers or tenant) around `input`.
    fn ctx_with_input(input: serde_json::Map<String, serde_json::Value>) -> Context {
        Context {
            input,
            data: None,
            user: None,
            pool: test_pool(),
            headers: HashMap::new(),
            response_headers: Vec::new(),
            tenant_id: None,
        }
    }

    #[tokio::test]
    async fn controller_map_register_and_call() {
        let mut registry = ControllerMap::new();
        registry.register("users", "normalize_email", normalize_email);

        let mut input = serde_json::Map::new();
        input.insert("email".to_string(), serde_json::json!("USER@EXAMPLE.COM"));
        let mut ctx = ctx_with_input(input);

        registry
            .call("users", "normalize_email", &mut ctx)
            .await
            .unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("user@example.com"));
    }

    #[tokio::test]
    async fn controller_map_missing_returns_error() {
        let registry = ControllerMap::new();
        let mut ctx = ctx_with_input(serde_json::Map::new());

        let outcome = registry.call("users", "nonexistent", &mut ctx).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn controller_map_has() {
        let mut registry = ControllerMap::new();
        assert!(!registry.has("users", "check"));
        registry.register("users", "check", noop);
        assert!(registry.has("users", "check"));
    }
}