1use crate::bindings::{RequestBinding, ResponseBinding};
4use crate::context::ScriptContext;
5use crate::engine::RhaiEngine;
6use crate::error::Result;
7use armature_core::{HttpRequest, HttpResponse};
8use rhai::{Dynamic, Scope};
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11use tracing::{debug, instrument};
12
/// Request handler that runs a single Rhai script file for each request.
///
/// Holds a shared [`RhaiEngine`] and the path of the script to evaluate.
pub struct ScriptHandler {
    /// Shared engine used to compile and evaluate the script.
    engine: Arc<RhaiEngine>,
    /// Path of the script file passed to `RhaiEngine::compile_file`.
    script_path: PathBuf,
}
20
21impl ScriptHandler {
22 pub fn new(engine: Arc<RhaiEngine>, script_path: impl Into<PathBuf>) -> Self {
24 Self {
25 engine,
26 script_path: script_path.into(),
27 }
28 }
29
30 #[instrument(skip(self, request), fields(path = %self.script_path.display()))]
32 pub async fn handle(&self, request: HttpRequest) -> Result<HttpResponse> {
33 let req_binding = RequestBinding::from_request(&request);
35
36 let context = ScriptContext::new(req_binding);
38 let mut scope = context.into_scope();
39
40 let script = self.engine.compile_file(&self.script_path)?;
42
43 debug!("Executing script handler");
44
45 let result = self.engine.eval(&script, &mut scope)?;
46
47 self.result_to_response(result, &scope)
49 }
50
51 fn result_to_response(&self, result: Dynamic, scope: &Scope) -> Result<HttpResponse> {
53 if result.is::<ResponseBinding>() {
55 let response: ResponseBinding = result.cast();
56 return Ok(response.into_http_response());
57 }
58
59 if let Some(response) = scope.get_value::<ResponseBinding>("response") {
61 return Ok(response.into_http_response());
62 }
63
64 if result.is_string() {
66 let text: String = result.cast();
67 let mut resp = HttpResponse::new(200);
68 resp.headers.insert("content-type".to_string(), "text/plain".to_string());
69 return Ok(resp.with_body(text.into_bytes()));
70 }
71
72 if result.is_map() || result.is_array() {
74 let mut binding = ResponseBinding::new();
75 let json_resp = binding.json(result)?.into_http_response();
76 return Ok(json_resp);
77 }
78
79 Ok(HttpResponse::new(200))
81 }
82
83 pub fn script_path(&self) -> &Path {
85 &self.script_path
86 }
87}
88
/// Middleware that runs optional Rhai scripts before and/or after a request is handled.
pub struct ScriptMiddleware {
    /// Shared engine used to compile and evaluate both scripts.
    engine: Arc<RhaiEngine>,
    /// Script evaluated by `call_before`; `None` means no pre-processing.
    before_script: Option<PathBuf>,
    /// Script evaluated by `call_after`; `None` means the response passes through unchanged.
    after_script: Option<PathBuf>,
}
97
98impl ScriptMiddleware {
99 pub fn before(engine: Arc<RhaiEngine>, script_path: impl Into<PathBuf>) -> Self {
101 Self {
102 engine,
103 before_script: Some(script_path.into()),
104 after_script: None,
105 }
106 }
107
108 pub fn after(engine: Arc<RhaiEngine>, script_path: impl Into<PathBuf>) -> Self {
110 Self {
111 engine,
112 before_script: None,
113 after_script: Some(script_path.into()),
114 }
115 }
116
117 pub fn both(
119 engine: Arc<RhaiEngine>,
120 before: impl Into<PathBuf>,
121 after: impl Into<PathBuf>,
122 ) -> Self {
123 Self {
124 engine,
125 before_script: Some(before.into()),
126 after_script: Some(after.into()),
127 }
128 }
129
130 #[instrument(skip(self, request), fields(script = ?self.before_script))]
135 pub async fn call_before(&self, request: &HttpRequest) -> Result<Option<HttpResponse>> {
136 let Some(script_path) = &self.before_script else {
137 return Ok(None);
138 };
139
140 let req_binding = RequestBinding::from_request(request);
141 let context = ScriptContext::new(req_binding);
142 let mut scope = context.into_scope();
143
144 scope.push("continue", true);
146
147 let script = self.engine.compile_file(script_path)?;
148 let result = self.engine.eval(&script, &mut scope)?;
149
150 if let Some(should_continue) = scope.get_value::<bool>("continue") {
152 if !should_continue {
153 if result.is::<ResponseBinding>() {
155 let response: ResponseBinding = result.cast();
156 return Ok(Some(response.into_http_response()));
157 }
158 if let Some(response) = scope.get_value::<ResponseBinding>("response") {
159 return Ok(Some(response.into_http_response()));
160 }
161 }
162 }
163
164 Ok(None)
165 }
166
167 #[instrument(skip(self, request, response), fields(script = ?self.after_script))]
169 pub async fn call_after(
170 &self,
171 request: &HttpRequest,
172 response: HttpResponse,
173 ) -> Result<HttpResponse> {
174 let Some(script_path) = &self.after_script else {
175 return Ok(response);
176 };
177
178 let req_binding = RequestBinding::from_request(request);
179 let mut context = ScriptContext::new(req_binding);
180
181 context.set_local("status", Dynamic::from(response.status as i64));
183
184 let mut scope = context.into_scope();
185
186 let resp_binding = ResponseBinding::new();
188 scope.push("original_response", resp_binding);
189
190 let script = self.engine.compile_file(script_path)?;
191 let result = self.engine.eval(&script, &mut scope)?;
192
193 if result.is::<ResponseBinding>() {
195 let response: ResponseBinding = result.cast();
196 return Ok(response.into_http_response());
197 }
198
199 Ok(response)
201 }
202}
203
204pub type ScriptHandlerFn = Box<dyn Fn(HttpRequest) -> Result<HttpResponse> + Send + Sync>;
206
207pub fn script_handler(
209 engine: Arc<RhaiEngine>,
210 script_path: impl Into<PathBuf>,
211) -> ScriptHandlerFn {
212 let handler = Arc::new(ScriptHandler::new(engine, script_path));
213
214 Box::new(move |request| {
215 let handler = handler.clone();
218 let rt = tokio::runtime::Handle::try_current().expect("must be called from async context");
219
220 rt.block_on(handler.handle(request))
221 })
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `GET /` request for handler tests.
    fn _create_test_request() -> HttpRequest {
        let method = String::from("GET");
        let path = String::from("/");
        HttpRequest::new(method, path)
    }

    // TODO: this placeholder asserts nothing. Exercise `ScriptHandler::handle` once a
    // test `RhaiEngine` fixture is available.
    #[tokio::test]
    async fn test_script_handler_basic() {}
}