Skip to main content

shaperail_runtime/handlers/
controller.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use shaperail_core::{ShaperailError, WASM_HOOK_PREFIX};
7
8use crate::auth::extractor::AuthenticatedUser;
9use crate::plugins::{PluginContext, PluginUser, WasmRuntime};
10
11/// Context passed to controller functions for synchronous in-request business logic.
12///
13/// Controllers receive a mutable `Context` and can:
14/// - Modify `input` before the DB operation (before-controllers)
15/// - Read/modify `data` after the DB operation (after-controllers)
16/// - Return `Err(...)` to halt the request with an error response
17///
18/// # Example
19/// ```rust,ignore
20/// pub async fn validate_org(ctx: &mut Context) -> Result<(), ShaperailError> {
21///     if let Some(email) = ctx.input.get("email").and_then(|v| v.as_str()) {
22///         ctx.input["email"] = serde_json::json!(email.to_lowercase());
23///     }
24///     Ok(())
25/// }
26/// ```
27pub struct Context {
28    /// Mutable input data. Before-controllers can modify what gets written to DB.
29    pub input: serde_json::Map<String, serde_json::Value>,
30    /// DB result data. `None` in before-controllers, `Some(...)` in after-controllers.
31    pub data: Option<serde_json::Value>,
32    /// The authenticated user, if present.
33    pub user: Option<AuthenticatedUser>,
34    /// Database pool for custom queries within the controller.
35    pub pool: sqlx::PgPool,
36    /// Request headers (read-only).
37    pub headers: HashMap<String, String>,
38    /// Extra response headers the controller wants to add.
39    pub response_headers: Vec<(String, String)>,
40    /// The tenant ID extracted from the authenticated user (M18).
41    /// Present when the resource has `tenant_key` and the user has a `tenant_id` claim.
42    pub tenant_id: Option<String>,
43}
44
45/// Type alias for controller function results.
46pub type ControllerResult = Result<(), ShaperailError>;
47
48/// Trait for controller functions that can be stored in the registry.
49pub trait ControllerHandler: Send + Sync {
50    fn call<'a>(
51        &'a self,
52        ctx: &'a mut Context,
53    ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>>;
54}
55
56/// Blanket implementation for named async functions that take `&mut Context`.
57///
58/// Use named async functions (not closures) for controller registration:
59///
60/// ```rust,ignore
61/// async fn normalize_email(ctx: &mut Context) -> ControllerResult {
62///     // modify ctx.input...
63///     Ok(())
64/// }
65/// map.register("users", "normalize_email", normalize_email);
66/// ```
67impl<F> ControllerHandler for F
68where
69    F: for<'a> AsyncControllerFn<'a> + Send + Sync,
70{
71    fn call<'a>(
72        &'a self,
73        ctx: &'a mut Context,
74    ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>> {
75        Box::pin(self.call_async(ctx))
76    }
77}
78
79/// Helper trait to express the async function signature with proper lifetimes.
80pub trait AsyncControllerFn<'a> {
81    type Fut: Future<Output = ControllerResult> + Send + 'a;
82    fn call_async(&self, ctx: &'a mut Context) -> Self::Fut;
83}
84
85impl<'a, F, Fut> AsyncControllerFn<'a> for F
86where
87    F: Fn(&'a mut Context) -> Fut + Send + Sync,
88    Fut: Future<Output = ControllerResult> + Send + 'a,
89{
90    type Fut = Fut;
91    fn call_async(&self, ctx: &'a mut Context) -> Self::Fut {
92        (self)(ctx)
93    }
94}
95
96/// Registry that maps (resource_name, function_name) to controller functions.
97///
98/// Follows the same pattern as `StoreRegistry` — generated code populates this
99/// at startup, and handlers look up controllers by name at request time.
100pub struct ControllerMap {
101    fns: HashMap<(String, String), Arc<dyn ControllerHandler>>,
102}
103
104impl ControllerMap {
105    /// Creates an empty controller registry.
106    pub fn new() -> Self {
107        Self {
108            fns: HashMap::new(),
109        }
110    }
111
112    /// Registers a controller function for a resource.
113    pub fn register<F>(&mut self, resource: &str, name: &str, f: F)
114    where
115        F: ControllerHandler + 'static,
116    {
117        self.fns
118            .insert((resource.to_string(), name.to_string()), Arc::new(f));
119    }
120
121    /// Calls a controller function by resource and name.
122    ///
123    /// Returns `Ok(())` if no controller is registered for this (resource, name) pair.
124    pub async fn call(&self, resource: &str, name: &str, ctx: &mut Context) -> ControllerResult {
125        if let Some(f) = self.fns.get(&(resource.to_string(), name.to_string())) {
126            f.call(ctx).await
127        } else {
128            Err(ShaperailError::Internal(format!(
129                "Controller '{name}' not found for resource '{resource}'. \
130                 Ensure the function exists in resources/{resource}.controller.rs"
131            )))
132        }
133    }
134
135    /// Returns true if a controller is registered for this (resource, name) pair.
136    pub fn has(&self, resource: &str, name: &str) -> bool {
137        self.fns
138            .contains_key(&(resource.to_string(), name.to_string()))
139    }
140}
141
142impl Default for ControllerMap {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
148/// Dispatches a controller call, handling both Rust and WASM controllers.
149///
150/// If `name` starts with `wasm:`, delegates to the WASM runtime.
151/// Otherwise, looks up and calls a registered Rust controller function.
152pub async fn dispatch_controller(
153    name: &str,
154    resource: &str,
155    ctx: &mut Context,
156    controllers: Option<&ControllerMap>,
157    wasm_runtime: Option<&WasmRuntime>,
158) -> ControllerResult {
159    if let Some(wasm_path) = name.strip_prefix(WASM_HOOK_PREFIX) {
160        // WASM plugin path
161        let runtime = wasm_runtime.ok_or_else(|| {
162            ShaperailError::Internal(
163                "WASM plugin declared but no WasmRuntime configured".to_string(),
164            )
165        })?;
166
167        // Determine hook name based on whether we're in before or after phase.
168        // The caller should set ctx.data = None for before, Some(...) for after.
169        let hook_name = if ctx.data.is_none() {
170            "before_hook"
171        } else {
172            "after_hook"
173        };
174
175        let plugin_ctx = PluginContext {
176            input: ctx.input.clone(),
177            data: ctx.data.clone(),
178            user: ctx.user.as_ref().map(|u| PluginUser {
179                id: u.id.to_string(),
180                role: u.role.clone(),
181            }),
182            headers: ctx.headers.clone(),
183            tenant_id: ctx.tenant_id.clone(),
184        };
185
186        let result = runtime.call_hook(wasm_path, hook_name, &plugin_ctx).await?;
187
188        if !result.ok {
189            let msg = result
190                .error
191                .unwrap_or_else(|| "WASM plugin returned error".to_string());
192            return Err(ShaperailError::Validation(vec![
193                shaperail_core::FieldError {
194                    field: "wasm_plugin".to_string(),
195                    message: msg,
196                    code: "wasm_error".to_string(),
197                },
198            ]));
199        }
200
201        // Apply modifications from plugin back to context
202        if let Some(modified_ctx) = result.ctx {
203            ctx.input = modified_ctx.input;
204            if modified_ctx.data.is_some() {
205                ctx.data = modified_ctx.data;
206            }
207        }
208
209        Ok(())
210    } else {
211        // Rust controller function
212        let map = controllers.ok_or_else(|| {
213            ShaperailError::Internal(format!(
214                "Controller '{name}' declared for '{resource}' but no ControllerMap configured"
215            ))
216        })?;
217        map.call(resource, name, ctx).await
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    async fn normalize_email(ctx: &mut Context) -> ControllerResult {
        if let Some(email) = ctx.input.get("email").and_then(|v| v.as_str()) {
            let lower = email.to_lowercase();
            ctx.input["email"] = serde_json::json!(lower);
        }
        Ok(())
    }

    async fn noop(_ctx: &mut Context) -> ControllerResult {
        Ok(())
    }

    fn test_pool() -> sqlx::PgPool {
        sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap()
    }

    /// Builds a before-phase context (no DB data, no user) around `input`.
    fn test_ctx(input: serde_json::Map<String, serde_json::Value>) -> Context {
        Context {
            input,
            data: None,
            user: None,
            pool: test_pool(),
            headers: HashMap::new(),
            response_headers: vec![],
            tenant_id: None,
        }
    }

    /// Input map holding a single `email` field.
    fn email_input(email: &str) -> serde_json::Map<String, serde_json::Value> {
        let mut input = serde_json::Map::new();
        input.insert("email".to_string(), serde_json::json!(email));
        input
    }

    #[tokio::test]
    async fn controller_map_register_and_call() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);

        let mut ctx = test_ctx(email_input("USER@EXAMPLE.COM"));

        map.call("users", "normalize_email", &mut ctx)
            .await
            .unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("user@example.com"));
    }

    #[tokio::test]
    async fn controller_map_missing_returns_error() {
        let map = ControllerMap::new();
        let mut ctx = test_ctx(serde_json::Map::new());

        let result = map.call("users", "nonexistent", &mut ctx).await;
        assert!(matches!(result, Err(ShaperailError::Internal(_))));
    }

    #[test]
    fn controller_map_has() {
        let mut map = ControllerMap::new();
        assert!(!map.has("users", "check"));
        map.register("users", "check", noop);
        assert!(map.has("users", "check"));
        // Registration is scoped to the resource it was made for.
        assert!(!map.has("orders", "check"));
    }

    #[tokio::test]
    async fn dispatch_runs_registered_rust_controller() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);
        let mut ctx = test_ctx(email_input("USER@EXAMPLE.COM"));

        dispatch_controller("normalize_email", "users", &mut ctx, Some(&map), None)
            .await
            .unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("user@example.com"));
    }

    #[tokio::test]
    async fn dispatch_without_controller_map_errors() {
        let mut ctx = test_ctx(serde_json::Map::new());

        let result = dispatch_controller("normalize_email", "users", &mut ctx, None, None).await;
        assert!(matches!(result, Err(ShaperailError::Internal(_))));
    }

    #[tokio::test]
    async fn dispatch_wasm_without_runtime_errors() {
        let mut ctx = test_ctx(serde_json::Map::new());
        let name = format!("{}plugin.wasm", WASM_HOOK_PREFIX);

        let result = dispatch_controller(&name, "users", &mut ctx, None, None).await;
        assert!(matches!(result, Err(ShaperailError::Internal(_))));
    }
}