shaperail_runtime/handlers/
controller.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use shaperail_core::{ShaperailError, WASM_HOOK_PREFIX};
7
8use crate::auth::extractor::AuthenticatedUser;
9#[cfg(feature = "wasm-plugins")]
10use crate::plugins::{PluginContext, PluginUser, WasmRuntime};
11
12pub struct Context {
29 pub input: serde_json::Map<String, serde_json::Value>,
31 pub data: Option<serde_json::Value>,
33 pub user: Option<AuthenticatedUser>,
35 pub pool: sqlx::PgPool,
37 pub headers: HashMap<String, String>,
39 pub response_headers: Vec<(String, String)>,
41 pub tenant_id: Option<String>,
44}
45
46pub type ControllerResult = Result<(), ShaperailError>;
48
49pub trait ControllerHandler: Send + Sync {
51 fn call<'a>(
52 &'a self,
53 ctx: &'a mut Context,
54 ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>>;
55}
56
57impl<F> ControllerHandler for F
69where
70 F: for<'a> AsyncControllerFn<'a> + Send + Sync,
71{
72 fn call<'a>(
73 &'a self,
74 ctx: &'a mut Context,
75 ) -> Pin<Box<dyn Future<Output = ControllerResult> + Send + 'a>> {
76 Box::pin(self.call_async(ctx))
77 }
78}
79
80pub trait AsyncControllerFn<'a> {
82 type Fut: Future<Output = ControllerResult> + Send + 'a;
83 fn call_async(&self, ctx: &'a mut Context) -> Self::Fut;
84}
85
86impl<'a, F, Fut> AsyncControllerFn<'a> for F
87where
88 F: Fn(&'a mut Context) -> Fut + Send + Sync,
89 Fut: Future<Output = ControllerResult> + Send + 'a,
90{
91 type Fut = Fut;
92 fn call_async(&self, ctx: &'a mut Context) -> Self::Fut {
93 (self)(ctx)
94 }
95}
96
97pub struct ControllerMap {
102 fns: HashMap<(String, String), Arc<dyn ControllerHandler>>,
103}
104
105impl ControllerMap {
106 pub fn new() -> Self {
108 Self {
109 fns: HashMap::new(),
110 }
111 }
112
113 pub fn register<F>(&mut self, resource: &str, name: &str, f: F)
115 where
116 F: ControllerHandler + 'static,
117 {
118 self.fns
119 .insert((resource.to_string(), name.to_string()), Arc::new(f));
120 }
121
122 pub async fn call(&self, resource: &str, name: &str, ctx: &mut Context) -> ControllerResult {
126 if let Some(f) = self.fns.get(&(resource.to_string(), name.to_string())) {
127 f.call(ctx).await
128 } else {
129 Err(ShaperailError::Internal(format!(
130 "Controller '{name}' not found for resource '{resource}'. \
131 Ensure the function exists in resources/{resource}.controller.rs"
132 )))
133 }
134 }
135
136 pub fn has(&self, resource: &str, name: &str) -> bool {
138 self.fns
139 .contains_key(&(resource.to_string(), name.to_string()))
140 }
141}
142
143impl Default for ControllerMap {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
/// Runs the controller named `name` for `resource` (WASM-enabled build).
///
/// If `name` starts with [`WASM_HOOK_PREFIX`], the remainder is treated as the
/// plugin path and the hook is executed on `wasm_runtime`. The hook is
/// `before_hook` when `ctx.data` is `None`, and `after_hook` otherwise. Any
/// other name is routed to a Rust controller from `controllers`.
///
/// # Errors
///
/// - [`ShaperailError::Internal`] if a WASM hook is requested but no runtime was supplied.
/// - Any error returned by the runtime's `call_hook`, propagated unchanged.
/// - [`ShaperailError::Validation`] with a single `wasm_plugin` field error when
///   the plugin reports failure.
#[cfg(feature = "wasm-plugins")]
pub async fn dispatch_controller(
    name: &str,
    resource: &str,
    ctx: &mut Context,
    controllers: Option<&ControllerMap>,
    wasm_runtime: Option<&WasmRuntime>,
) -> ControllerResult {
    let wasm_path = match name.strip_prefix(WASM_HOOK_PREFIX) {
        Some(path) => path,
        // Not a WASM hook: hand off to a registered Rust controller.
        None => return dispatch_rust_controller(name, resource, ctx, controllers).await,
    };

    let runtime = match wasm_runtime {
        Some(rt) => rt,
        None => {
            return Err(ShaperailError::Internal(
                "WASM plugin declared but no WasmRuntime configured".to_string(),
            ))
        }
    };

    // No `data` yet is taken to mean the hook runs before the operation.
    let hook_name = match ctx.data {
        None => "before_hook",
        Some(_) => "after_hook",
    };

    // The plugin sees a snapshot of the context; changes come back via `outcome.ctx`.
    let plugin_user = ctx.user.as_ref().map(|u| PluginUser {
        id: u.id.to_string(),
        role: u.role.clone(),
    });
    let plugin_ctx = PluginContext {
        input: ctx.input.clone(),
        data: ctx.data.clone(),
        user: plugin_user,
        headers: ctx.headers.clone(),
        tenant_id: ctx.tenant_id.clone(),
    };

    let outcome = runtime.call_hook(wasm_path, hook_name, &plugin_ctx).await?;

    if !outcome.ok {
        let message = outcome
            .error
            .unwrap_or_else(|| "WASM plugin returned error".to_string());
        return Err(ShaperailError::Validation(vec![shaperail_core::FieldError {
            field: "wasm_plugin".to_string(),
            message,
            code: "wasm_error".to_string(),
        }]));
    }

    if let Some(updated) = outcome.ctx {
        // Input is always replaced; data is only overwritten when the plugin supplied some.
        ctx.input = updated.input;
        if let Some(data) = updated.data {
            ctx.data = Some(data);
        }
    }

    Ok(())
}
216
217#[cfg(not(feature = "wasm-plugins"))]
221pub async fn dispatch_controller(
222 name: &str,
223 resource: &str,
224 ctx: &mut Context,
225 controllers: Option<&ControllerMap>,
226 _wasm_runtime: Option<&()>,
227) -> ControllerResult {
228 if name.starts_with(WASM_HOOK_PREFIX) {
229 return Err(ShaperailError::Internal(
230 "WASM plugin declared but the 'wasm-plugins' feature is not enabled. \
231 Add `features = [\"wasm-plugins\"]` to your shaperail-runtime dependency."
232 .to_string(),
233 ));
234 }
235 dispatch_rust_controller(name, resource, ctx, controllers).await
236}
237
238async fn dispatch_rust_controller(
240 name: &str,
241 resource: &str,
242 ctx: &mut Context,
243 controllers: Option<&ControllerMap>,
244) -> ControllerResult {
245 let map = controllers.ok_or_else(|| {
246 ShaperailError::Internal(format!(
247 "Controller '{name}' declared for '{resource}' but no ControllerMap configured"
248 ))
249 })?;
250 map.call(resource, name, ctx).await
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Lower-cases `input["email"]` when it is present and holds a string.
    async fn normalize_email(ctx: &mut Context) -> ControllerResult {
        let lowered = ctx
            .input
            .get("email")
            .and_then(serde_json::Value::as_str)
            .map(str::to_lowercase);
        if let Some(lower) = lowered {
            ctx.input["email"] = serde_json::json!(lower);
        }
        Ok(())
    }

    /// Controller that does nothing and always succeeds.
    async fn noop(_ctx: &mut Context) -> ControllerResult {
        Ok(())
    }

    fn test_pool() -> sqlx::PgPool {
        sqlx::PgPool::connect_lazy("postgres://localhost/test").unwrap()
    }

    /// Builds a context carrying `input`, with every other field empty.
    fn ctx_with_input(input: serde_json::Map<String, serde_json::Value>) -> Context {
        Context {
            input,
            data: None,
            user: None,
            pool: test_pool(),
            headers: HashMap::new(),
            response_headers: Vec::new(),
            tenant_id: None,
        }
    }

    #[tokio::test]
    async fn controller_map_register_and_call() {
        let mut map = ControllerMap::new();
        map.register("users", "normalize_email", normalize_email);

        let input: serde_json::Map<String, serde_json::Value> = std::iter::once((
            "email".to_string(),
            serde_json::json!("USER@EXAMPLE.COM"),
        ))
        .collect();
        let mut ctx = ctx_with_input(input);

        map.call("users", "normalize_email", &mut ctx).await.unwrap();
        assert_eq!(ctx.input["email"], serde_json::json!("user@example.com"));
    }

    #[tokio::test]
    async fn controller_map_missing_returns_error() {
        let map = ControllerMap::new();
        let mut ctx = ctx_with_input(serde_json::Map::new());

        let outcome = map.call("users", "nonexistent", &mut ctx).await;
        assert!(outcome.is_err());
    }

    #[test]
    fn controller_map_has() {
        let mut map = ControllerMap::new();
        assert!(!map.has("users", "check"));
        map.register("users", "check", noop);
        assert!(map.has("users", "check"));
    }
}