Skip to main content

shaperail_runtime/handlers/
controller.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use shaperail_core::ShaperailError;
7
8use crate::auth::extractor::AuthenticatedUser;
9
10/// Context passed to controller functions for synchronous in-request business logic.
11///
12/// Controllers receive a mutable `Context` and can:
13/// - Modify `input` before the DB operation (before-controllers)
14/// - Read/modify `data` after the DB operation (after-controllers)
15/// - Return `Err(...)` to halt the request with an error response
16///
17/// # Example
18/// ```rust,ignore
19/// pub async fn validate_org(ctx: &mut Context) -> Result<(), ShaperailError> {
20///     if let Some(email) = ctx.input.get("email").and_then(|v| v.as_str()) {
21///         ctx.input["email"] = serde_json::json!(email.to_lowercase());
22///     }
23///     Ok(())
24/// }
25/// ```
26pub struct Context {
27    /// Mutable input data. Before-controllers can modify what gets written to DB.
28    pub input: serde_json::Map<String, serde_json::Value>,
29    /// DB result data. `None` in before-controllers, `Some(...)` in after-controllers.
30    pub data: Option<serde_json::Value>,
31    /// The authenticated user, if present.
32    pub user: Option<AuthenticatedUser>,
33    /// Database pool for custom queries within the controller.
34    pub pool: sqlx::PgPool,
35    /// Request headers (read-only).
36    pub headers: HashMap<String, String>,
37    /// Extra response headers the controller wants to add.
38    pub response_headers: Vec<(String, String)>,
39}
40
41/// Type alias for controller function results.
42pub type ControllerResult = Result<(), ShaperailError>;
43
44/// Trait for controller functions that can be stored in the registry.
45pub trait ControllerHandler: Send + Sync {
46    fn call<'a>(
47        &'a self,
48        ctx: &'a mut Context,
49    ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>>;
50}
51
52/// Blanket implementation for named async functions that take `&mut Context`.
53///
54/// Use named async functions (not closures) for controller registration:
55///
56/// ```rust,ignore
57/// async fn normalize_email(ctx: &mut Context) -> ControllerResult {
58///     // modify ctx.input...
59///     Ok(())
60/// }
61/// map.register("users", "normalize_email", normalize_email);
62/// ```
63impl<F> ControllerHandler for F
64where
65    F: for<'a> AsyncControllerFn<'a> + Send + Sync,
66{
67    fn call<'a>(
68        &'a self,
69        ctx: &'a mut Context,
70    ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>> {
71        Box::pin(self.call_async(ctx))
72    }
73}
74
75/// Helper trait to express the async function signature with proper lifetimes.
76pub trait AsyncControllerFn<'a> {
77    type Fut: Future<Output = ControllerResult> + Send + 'a;
78    fn call_async(&self, ctx: &'a mut Context) -> Self::Fut;
79}
80
81impl<'a, F, Fut> AsyncControllerFn<'a> for F
82where
83    F: Fn(&'a mut Context) -> Fut + Send + Sync,
84    Fut: Future<Output = ControllerResult> + Send + 'a,
85{
86    type Fut = Fut;
87    fn call_async(&self, ctx: &'a mut Context) -> Self::Fut {
88        (self)(ctx)
89    }
90}
91
92/// Registry that maps (resource_name, function_name) to controller functions.
93///
94/// Follows the same pattern as `StoreRegistry` — generated code populates this
95/// at startup, and handlers look up controllers by name at request time.
96pub struct ControllerMap {
97    fns: HashMap<(String, String), Arc<dyn ControllerHandler>>,
98}
99
100impl ControllerMap {
101    /// Creates an empty controller registry.
102    pub fn new() -> Self {
103        Self {
104            fns: HashMap::new(),
105        }
106    }
107
108    /// Registers a controller function for a resource.
109    pub fn register<F>(&mut self, resource: &str, name: &str, f: F)
110    where
111        F: ControllerHandler + 'static,
112    {
113        self.fns
114            .insert((resource.to_string(), name.to_string()), Arc::new(f));
115    }
116
117    /// Calls a controller function by resource and name.
118    ///
119    /// Returns `Ok(())` if no controller is registered for this (resource, name) pair.
120    pub async fn call(&self, resource: &str, name: &str, ctx: &mut Context) -> ControllerResult {
121        if let Some(f) = self.fns.get(&(resource.to_string(), name.to_string())) {
122            f.call(ctx).await
123        } else {
124            Err(ShaperailError::Internal(format!(
125                "Controller '{name}' not found for resource '{resource}'. \
126                 Ensure the function exists in resources/{resource}.controller.rs"
127            )))
128        }
129    }
130
131    /// Returns true if a controller is registered for this (resource, name) pair.
132    pub fn has(&self, resource: &str, name: &str) -> bool {
133        self.fns
134            .contains_key(&(resource.to_string(), name.to_string()))
135    }
136}
137
138impl Default for ControllerMap {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Before-controller used by the tests: lowercases `input["email"]` if present.
    async fn normalize_email(ctx: &mut Context) -> ControllerResult {
        if let Some(email) = ctx.input.get("email").and_then(|v| v.as_str()) {
            let lower = email.to_lowercase();
            ctx.input["email"] = serde_json::json!(lower);
        }
        Ok(())
    }

    /// Controller that does nothing and always succeeds.
    async fn noop(_ctx: &mut Context) -> ControllerResult {
        Ok(())
    }

    fn test_pool() -> sqlx::PgPool {
        // `connect_lazy` does not open a connection up front, so these tests
        // do not need a running database.
        sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap()
    }

    /// Builds a before-controller style context (no DB data, no user) around `input`.
    fn test_ctx(input: serde_json::Map<String, serde_json::Value>) -> Context {
        Context {
            input,
            data: None,
            user: None,
            pool: test_pool(),
            headers: HashMap::new(),
            response_headers: vec![],
        }
    }

    #[tokio::test]
    async fn controller_map_register_and_call() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);

        let mut input = serde_json::Map::new();
        input.insert("email".to_string(), serde_json::json!("USER@EXAMPLE.COM"));
        let mut ctx = test_ctx(input);

        map.call("users", "normalize_email", &mut ctx)
            .await
            .unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("user@example.com"));
    }

    #[tokio::test]
    async fn controller_map_missing_returns_error() {
        let map = ControllerMap::new();
        let mut ctx = test_ctx(serde_json::Map::new());

        let result = map.call("users", "nonexistent", &mut ctx).await;
        // The error must be `Internal` and name both the function and the resource.
        assert!(
            matches!(
                &result,
                Err(ShaperailError::Internal(msg))
                    if msg.contains("nonexistent") && msg.contains("users")
            ),
            "expected an Internal error naming the missing controller"
        );
    }

    #[tokio::test]
    async fn controller_map_reregister_replaces_previous() {
        let mut map = ControllerMap::new();
        map.register("users", "hook", noop);
        map.register("users", "hook", normalize_email);

        let mut input = serde_json::Map::new();
        input.insert("email".to_string(), serde_json::json!("MiXeD@Example.com"));
        let mut ctx = test_ctx(input);

        map.call("users", "hook", &mut ctx).await.unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("mixed@example.com"));
    }

    #[test]
    fn controller_map_has() {
        let mut map = ControllerMap::new();
        assert!(!map.has("users", "check"));
        map.register("users", "check", noop);
        assert!(map.has("users", "check"));
        // Registrations are scoped to both the resource and the function name.
        assert!(!map.has("posts", "check"));
        assert!(!map.has("users", "other"));
    }
}