1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
/// Lua userdata holding a dedicated `reqwest::Client` built by `http.client(...)`.
/// `http._client_request` borrows it back out to issue requests.
struct HttpClient(reqwest::Client);
// No Lua-visible methods: the wrapper table created in `http.client` forwards every call
// to `http._client_request`, passing this userdata as the first argument.
impl UserData for HttpClient {}
25
/// Installs every scripting module into the globals of `lua`.
///
/// `client` is handed to `register_http`, where it backs the one-shot `http.get`,
/// `http.post`, ... helpers. Modules are registered in the order listed below. The first
/// registration error is returned immediately, and the modules after it are left
/// unregistered.
pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
    register_http(lua, client)?;
    register_json(lua)?;
    register_yaml(lua)?;
    register_toml(lua)?;
    register_assert(lua)?;
    register_log(lua)?;
    register_env(lua)?;
    register_sleep(lua)?;
    register_time(lua)?;
    register_fs(lua)?;
    register_base64(lua)?;
    register_crypto(lua)?;
    register_regex(lua)?;
    register_async(lua)?;
    register_db(lua)?;
    register_ws(lua)?;
    register_template(lua)?;
    Ok(())
}
46
47async fn execute_http_request(
48 lua: &Lua,
49 client: &reqwest::Client,
50 method_name: &str,
51 args: mlua::MultiValue,
52) -> mlua::Result<Value> {
53 let has_body = method_name != "get" && method_name != "delete";
54
55 let mut args_iter = args.into_iter();
56 let url: String = match args_iter.next() {
57 Some(Value::String(s)) => s.to_str()?.to_string(),
58 _ => {
59 return Err(mlua::Error::runtime(format!(
60 "http.{method_name}: first argument must be a URL string"
61 )));
62 }
63 };
64
65 let (body_str, auto_json, opts) = if has_body {
66 let (body, is_json) = match args_iter.next() {
67 Some(Value::String(s)) => (s.to_str()?.to_string(), false),
68 Some(Value::Table(t)) => {
69 let json_val = lua_table_to_json(&t)?;
70 let serialized = serde_json::to_string(&json_val).map_err(|e| {
71 mlua::Error::runtime(format!(
72 "http.{method_name}: JSON encode failed: {e}"
73 ))
74 })?;
75 (serialized, true)
76 }
77 Some(Value::Nil) | None => (String::new(), false),
78 _ => {
79 return Err(mlua::Error::runtime(format!(
80 "http.{method_name}: second argument must be a string, table, or nil"
81 )));
82 }
83 };
84 let opts = match args_iter.next() {
85 Some(Value::Table(t)) => Some(t),
86 Some(Value::Nil) | None => None,
87 _ => {
88 return Err(mlua::Error::runtime(format!(
89 "http.{method_name}: third argument must be a table or nil"
90 )));
91 }
92 };
93 (body, is_json, opts)
94 } else {
95 let opts = match args_iter.next() {
96 Some(Value::Table(t)) => Some(t),
97 Some(Value::Nil) | None => None,
98 _ => {
99 return Err(mlua::Error::runtime(format!(
100 "http.{method_name}: second argument must be a table or nil"
101 )));
102 }
103 };
104 (String::new(), false, opts)
105 };
106
107 let mut req = match method_name {
108 "get" => client.get(&url),
109 "post" => client.post(&url),
110 "put" => client.put(&url),
111 "patch" => client.patch(&url),
112 "delete" => client.delete(&url),
113 _ => {
114 return Err(mlua::Error::runtime(format!(
115 "http: unsupported method: {method_name}"
116 )));
117 }
118 };
119
120 if has_body && !body_str.is_empty() {
121 req = req.body(body_str);
122 }
123 if auto_json {
124 req = req.header("Content-Type", "application/json");
125 }
126 if let Some(ref opts_table) = opts
127 && let Ok(headers_table) = opts_table.get::<Table>("headers")
128 {
129 for pair in headers_table.pairs::<String, String>() {
130 let (k, v) = pair?;
131 req = req.header(k, v);
132 }
133 }
134
135 let resp = req.send().await.map_err(|e| {
136 mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
137 })?;
138 let status = resp.status().as_u16();
139 let resp_headers = resp.headers().clone();
140 let body = resp.text().await.map_err(|e| {
141 mlua::Error::runtime(format!(
142 "http.{method_name}: reading body failed: {e}"
143 ))
144 })?;
145
146 let result = lua.create_table()?;
147 result.set("status", status)?;
148 result.set("body", body)?;
149
150 let headers_out = lua.create_table()?;
151 for (name, value) in &resp_headers {
152 if let Ok(v) = value.to_str() {
153 headers_out.set(name.as_str().to_string(), v.to_string())?;
154 }
155 }
156 result.set("headers", headers_out)?;
157
158 Ok(Value::Table(result))
159}
160
161fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
162 let http_table = lua.create_table()?;
163
164 for method in ["get", "post", "put", "patch", "delete"] {
165 let method_client = client.clone();
166 let method_name = method.to_string();
167
168 let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
169 let client = method_client.clone();
170 let method_name = method_name.clone();
171 async move {
172 execute_http_request(&lua, &client, &method_name, args).await
173 }
174 })?;
175 http_table.set(method, func)?;
176 }
177
178 let client_fn = lua.create_async_function(|lua, opts: Option<Table>| async move {
179 let mut builder = reqwest::Client::builder();
180
181 let timeout_secs: f64 = opts
182 .as_ref()
183 .and_then(|t| t.get::<f64>("timeout").ok())
184 .unwrap_or(30.0);
185 builder = builder.timeout(std::time::Duration::from_secs_f64(timeout_secs));
186
187 if let Some(ref opts_table) = opts {
188 if let Ok(ca_path) = opts_table.get::<String>("ca_cert_file") {
189 let pem = std::fs::read(&ca_path).map_err(|e| {
190 mlua::Error::runtime(format!(
191 "http.client: failed to read CA cert file {ca_path:?}: {e}"
192 ))
193 })?;
194 let cert = reqwest::Certificate::from_pem(&pem).map_err(|e| {
195 mlua::Error::runtime(format!(
196 "http.client: invalid PEM in {ca_path:?}: {e}"
197 ))
198 })?;
199 builder = builder.add_root_certificate(cert);
200 }
201 if let Ok(ca_pem) = opts_table.get::<String>("ca_cert") {
202 let cert = reqwest::Certificate::from_pem(ca_pem.as_bytes()).map_err(|e| {
203 mlua::Error::runtime(format!("http.client: invalid CA cert PEM: {e}"))
204 })?;
205 builder = builder.add_root_certificate(cert);
206 }
207 }
208
209 let client = builder.build().map_err(|e| {
210 mlua::Error::runtime(format!("http.client: failed to build client: {e}"))
211 })?;
212
213 let ud = lua.create_any_userdata(HttpClient(client))?;
214
215 let wrapper: Table = lua
216 .load(
217 r#"
218 local ud = ...
219 local obj = { _ud = ud }
220 setmetatable(obj, {
221 __index = {
222 get = function(self, url, opts)
223 return http._client_request(self._ud, "get", url, opts)
224 end,
225 post = function(self, url, body, opts)
226 return http._client_request(self._ud, "post", url, body, opts)
227 end,
228 put = function(self, url, body, opts)
229 return http._client_request(self._ud, "put", url, body, opts)
230 end,
231 patch = function(self, url, body, opts)
232 return http._client_request(self._ud, "patch", url, body, opts)
233 end,
234 delete = function(self, url, opts)
235 return http._client_request(self._ud, "delete", url, opts)
236 end,
237 }
238 })
239 return obj
240 "#,
241 )
242 .call(ud)?;
243
244 Ok(Value::Table(wrapper))
245 })?;
246 http_table.set("client", client_fn)?;
247
248 let client_request_fn =
249 lua.create_async_function(|lua, args: mlua::MultiValue| async move {
250 let mut args_iter = args.into_iter();
251
252 let client = match args_iter.next() {
253 Some(Value::UserData(ud)) => {
254 let hc = ud.borrow::<HttpClient>().map_err(|_| {
255 mlua::Error::runtime(
256 "http._client_request: first arg must be an http client",
257 )
258 })?;
259 hc.0.clone()
260 }
261 _ => {
262 return Err(mlua::Error::runtime(
263 "http._client_request: first arg must be an http client",
264 ));
265 }
266 };
267
268 let method_name: String = match args_iter.next() {
269 Some(Value::String(s)) => s.to_str()?.to_string(),
270 _ => {
271 return Err(mlua::Error::runtime(
272 "http._client_request: second arg must be method name",
273 ));
274 }
275 };
276
277 let remaining: mlua::MultiValue = args_iter.collect();
278 execute_http_request(&lua, &client, &method_name, remaining).await
279 })?;
280 http_table.set("_client_request", client_request_fn)?;
281
282 let serve_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
283 let mut args_iter = args.into_iter();
284
285 let port: u16 = match args_iter.next() {
286 Some(Value::Integer(n)) => n as u16,
287 _ => {
288 return Err::<(), _>(mlua::Error::runtime(
289 "http.serve: first argument must be a port number",
290 ));
291 }
292 };
293
294 let routes_table = match args_iter.next() {
295 Some(Value::Table(t)) => t,
296 _ => {
297 return Err::<(), _>(mlua::Error::runtime(
298 "http.serve: second argument must be a routes table",
299 ));
300 }
301 };
302
303 let routes = Rc::new(parse_routes(&routes_table)?);
304
305 let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
306 .await
307 .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
308
309 loop {
310 let (stream, _addr) = listener
311 .accept()
312 .await
313 .map_err(|e| mlua::Error::runtime(format!("http.serve: accept failed: {e}")))?;
314
315 let routes = routes.clone();
316 let lua_clone = lua.clone();
317
318 tokio::task::spawn_local(async move {
319 let io = hyper_util::rt::TokioIo::new(stream);
320 let routes = routes.clone();
321 let lua = lua_clone.clone();
322
323 let service = service_fn(move |req: Request<Incoming>| {
324 let routes = routes.clone();
325 let lua = lua.clone();
326 async move { handle_request(&lua, &routes, req).await }
327 });
328
329 if let Err(e) = http1::Builder::new().serve_connection(io, service).await
330 && !e.to_string().contains("connection closed")
331 {
332 error!("http.serve: connection error: {e}");
333 }
334 });
335 }
336 })?;
337 http_table.set("serve", serve_fn)?;
338
339 lua.globals().set("http", http_table)?;
340 Ok(())
341}
342
343fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
344 let mut routes = HashMap::new();
345 for method_pair in routes_table.pairs::<String, Table>() {
346 let (method, paths_table) = method_pair?;
347 let method_upper = method.to_uppercase();
348 for path_pair in paths_table.pairs::<String, mlua::Function>() {
349 let (path, func) = path_pair?;
350 routes.insert((method_upper.clone(), path), func);
351 }
352 }
353 Ok(routes)
354}
355
356async fn handle_request(
357 lua: &Lua,
358 routes: &HashMap<(String, String), mlua::Function>,
359 req: Request<Incoming>,
360) -> Result<Response<Full<Bytes>>, hyper::Error> {
361 let method = req.method().to_string();
362 let path = req.uri().path().to_string();
363 let query = req.uri().query().unwrap_or("").to_string();
364 let headers: Vec<(String, String)> = req
365 .headers()
366 .iter()
367 .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
368 .collect();
369
370 let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
371 Ok(collected) => collected.to_bytes(),
372 Err(_) => Bytes::new(),
373 };
374 let body_str = String::from_utf8_lossy(&body_bytes).to_string();
375
376 let key = (method.clone(), path.clone());
377 let handler = match routes.get(&key) {
378 Some(f) => f,
379 None => {
380 return Ok(Response::builder()
381 .status(StatusCode::NOT_FOUND)
382 .header("content-type", "text/plain")
383 .body(Full::new(Bytes::from("not found")))
384 .unwrap());
385 }
386 };
387
388 match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
389 Ok(lua_resp) => lua_response_to_http(&lua_resp),
390 Err(e) => Ok(Response::builder()
391 .status(StatusCode::INTERNAL_SERVER_ERROR)
392 .header("content-type", "text/plain")
393 .body(Full::new(Bytes::from(format!("handler error: {e}"))))
394 .unwrap()),
395 }
396}
397
398fn build_lua_request_and_call(
399 lua: &Lua,
400 handler: &mlua::Function,
401 method: &str,
402 path: &str,
403 query: &str,
404 headers: &[(String, String)],
405 body: &str,
406) -> mlua::Result<Table> {
407 let req_table = lua.create_table()?;
408 req_table.set("method", method.to_string())?;
409 req_table.set("path", path.to_string())?;
410 req_table.set("query", query.to_string())?;
411 req_table.set("body", body.to_string())?;
412
413 let headers_table = lua.create_table()?;
414 for (k, v) in headers {
415 headers_table.set(k.as_str(), v.as_str())?;
416 }
417 req_table.set("headers", headers_table)?;
418
419 handler.call::<Table>(req_table)
420}
421
422fn lua_response_to_http(resp_table: &Table) -> Result<Response<Full<Bytes>>, hyper::Error> {
423 let status = resp_table
424 .get::<Option<u16>>("status")
425 .unwrap_or(None)
426 .unwrap_or(200);
427
428 let mut builder =
429 Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
430
431 if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
432 for (k, v) in headers_table.pairs::<String, String>().flatten() {
433 builder = builder.header(k, v);
434 }
435 }
436
437 let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
438 let json_val =
439 lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
440 let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
441 builder = builder.header("content-type", "application/json");
442 Bytes::from(serialized)
443 } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
444 builder = builder.header("content-type", "text/plain");
445 Bytes::from(body_str)
446 } else {
447 builder = builder.header("content-type", "text/plain");
448 Bytes::new()
449 };
450
451 Ok(builder.body(Full::new(body_bytes)).unwrap())
452}
453
454fn register_json(lua: &Lua) -> mlua::Result<()> {
455 let json_table = lua.create_table()?;
456
457 let parse_fn = lua.create_function(|lua, s: String| {
458 let value: serde_json::Value = serde_json::from_str(&s)
459 .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
460 json_value_to_lua(lua, &value)
461 })?;
462 json_table.set("parse", parse_fn)?;
463
464 let encode_fn = lua.create_function(|_, val: Value| {
465 let json_val = lua_value_to_json(&val)?;
466 serde_json::to_string(&json_val)
467 .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
468 })?;
469 json_table.set("encode", encode_fn)?;
470
471 lua.globals().set("json", json_table)?;
472 Ok(())
473}
474
475fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
476 let mut is_array = true;
477 let mut max_index: i64 = 0;
478 let mut count: i64 = 0;
479
480 for pair in table.clone().pairs::<Value, Value>() {
481 let (key, _) = pair?;
482 count += 1;
483 match key {
484 Value::Integer(i) if i >= 1 => {
485 if i > max_index {
486 max_index = i;
487 }
488 }
489 _ => {
490 is_array = false;
491 break;
492 }
493 }
494 }
495
496 if is_array && max_index == count {
497 let mut arr = Vec::with_capacity(max_index as usize);
498 for i in 1..=max_index {
499 let val: Value = table.get(i)?;
500 arr.push(lua_value_to_json(&val)?);
501 }
502 Ok(serde_json::Value::Array(arr))
503 } else {
504 let mut map = serde_json::Map::new();
505 for pair in table.clone().pairs::<Value, Value>() {
506 let (key, val) = pair?;
507 let key_str = match key {
508 Value::String(s) => s.to_str()?.to_string(),
509 Value::Integer(i) => i.to_string(),
510 Value::Number(f) => f.to_string(),
511 _ => {
512 return Err(mlua::Error::runtime(format!(
513 "unsupported table key type: {}",
514 key.type_name()
515 )));
516 }
517 };
518 map.insert(key_str, lua_value_to_json(&val)?);
519 }
520 Ok(serde_json::Value::Object(map))
521 }
522}
523
524fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
525 match val {
526 Value::Nil => Ok(serde_json::Value::Null),
527 Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
528 Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
529 Value::Number(f) => serde_json::Number::from_f64(*f)
530 .map(serde_json::Value::Number)
531 .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
532 Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
533 Value::Table(t) => lua_table_to_json(t),
534 _ => Err(mlua::Error::runtime(format!(
535 "unsupported Lua type for JSON: {}",
536 val.type_name()
537 ))),
538 }
539}
540
541pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
542 match val {
543 serde_json::Value::Null => Ok(Value::Nil),
544 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
545 serde_json::Value::Number(n) => {
546 if let Some(i) = n.as_i64() {
547 Ok(Value::Integer(i))
548 } else if let Some(f) = n.as_f64() {
549 Ok(Value::Number(f))
550 } else {
551 Ok(Value::Nil)
552 }
553 }
554 serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
555 serde_json::Value::Array(arr) => {
556 let table = lua.create_table()?;
557 for (i, item) in arr.iter().enumerate() {
558 table.set(i + 1, json_value_to_lua(lua, item)?)?;
559 }
560 Ok(Value::Table(table))
561 }
562 serde_json::Value::Object(map) => {
563 let table = lua.create_table()?;
564 for (k, v) in map {
565 table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
566 }
567 Ok(Value::Table(table))
568 }
569 }
570}
571
572fn register_assert(lua: &Lua) -> mlua::Result<()> {
573 let assert_table = lua.create_table()?;
574
575 let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
576 let mut args_iter = args.into_iter();
577 let a = args_iter.next().unwrap_or(Value::Nil);
578 let b = args_iter.next().unwrap_or(Value::Nil);
579 let msg = extract_string_arg(lua, args_iter.next());
580
581 if !lua_values_equal(&a, &b) {
582 let detail = format!(
583 "assert.eq failed: {:?} != {:?}{}",
584 format_lua_value(&a),
585 format_lua_value(&b),
586 msg.map(|m| format!(" - {m}")).unwrap_or_default()
587 );
588 return Err(mlua::Error::runtime(detail));
589 }
590 Ok(())
591 })?;
592 assert_table.set("eq", eq_fn)?;
593
594 let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
595 let mut args_iter = args.into_iter();
596 let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
597 let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
598 let msg = extract_string_arg(lua, args_iter.next());
599
600 match (a, b) {
601 (Some(va), Some(vb)) if va > vb => Ok(()),
602 (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
603 "assert.gt failed: {va} is not > {vb}{}",
604 msg.map(|m| format!(" - {m}")).unwrap_or_default()
605 ))),
606 _ => Err(mlua::Error::runtime(
607 "assert.gt: both arguments must be numbers",
608 )),
609 }
610 })?;
611 assert_table.set("gt", gt_fn)?;
612
613 let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
614 let mut args_iter = args.into_iter();
615 let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
616 let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
617 let msg = extract_string_arg(lua, args_iter.next());
618
619 match (a, b) {
620 (Some(va), Some(vb)) if va < vb => Ok(()),
621 (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
622 "assert.lt failed: {va} is not < {vb}{}",
623 msg.map(|m| format!(" - {m}")).unwrap_or_default()
624 ))),
625 _ => Err(mlua::Error::runtime(
626 "assert.lt: both arguments must be numbers",
627 )),
628 }
629 })?;
630 assert_table.set("lt", lt_fn)?;
631
632 let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
633 let mut args_iter = args.into_iter();
634 let haystack: String = match args_iter.next() {
635 Some(Value::String(s)) => s.to_str()?.to_string(),
636 _ => {
637 return Err(mlua::Error::runtime(
638 "assert.contains: first argument must be a string",
639 ));
640 }
641 };
642 let needle: String = match args_iter.next() {
643 Some(Value::String(s)) => s.to_str()?.to_string(),
644 _ => {
645 return Err(mlua::Error::runtime(
646 "assert.contains: second argument must be a string",
647 ));
648 }
649 };
650 let msg = extract_string_arg(lua, args_iter.next());
651
652 if !haystack.contains(&needle) {
653 return Err(mlua::Error::runtime(format!(
654 "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
655 msg.map(|m| format!(" - {m}")).unwrap_or_default()
656 )));
657 }
658 Ok(())
659 })?;
660 assert_table.set("contains", contains_fn)?;
661
662 let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
663 let mut args_iter = args.into_iter();
664 let val = args_iter.next().unwrap_or(Value::Nil);
665 let msg = extract_string_arg(lua, args_iter.next());
666
667 if val == Value::Nil {
668 return Err(mlua::Error::runtime(format!(
669 "assert.not_nil failed: value is nil{}",
670 msg.map(|m| format!(" - {m}")).unwrap_or_default()
671 )));
672 }
673 Ok(())
674 })?;
675 assert_table.set("not_nil", not_nil_fn)?;
676
677 let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
678 let mut args_iter = args.into_iter();
679 let text: String = match args_iter.next() {
680 Some(Value::String(s)) => s.to_str()?.to_string(),
681 _ => {
682 return Err(mlua::Error::runtime(
683 "assert.matches: first argument must be a string",
684 ));
685 }
686 };
687 let pattern: String = match args_iter.next() {
688 Some(Value::String(s)) => s.to_str()?.to_string(),
689 _ => {
690 return Err(mlua::Error::runtime(
691 "assert.matches: second argument must be a pattern string",
692 ));
693 }
694 };
695 let msg = extract_string_arg(lua, args_iter.next());
696
697 let found: bool = lua
698 .load(format!(
699 "return string.find({}, {}) ~= nil",
700 lua_string_literal(&text),
701 lua_string_literal(&pattern)
702 ))
703 .eval()
704 .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
705
706 if !found {
707 return Err(mlua::Error::runtime(format!(
708 "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
709 msg.map(|m| format!(" - {m}")).unwrap_or_default()
710 )));
711 }
712 Ok(())
713 })?;
714 assert_table.set("matches", matches_fn)?;
715
716 lua.globals().set("assert", assert_table)?;
717 Ok(())
718}
719
720fn register_log(lua: &Lua) -> mlua::Result<()> {
721 let log_table = lua.create_table()?;
722
723 let info_fn = lua.create_function(|_, msg: String| {
724 info!(target: "lua", "{}", msg);
725 Ok(())
726 })?;
727 log_table.set("info", info_fn)?;
728
729 let warn_fn = lua.create_function(|_, msg: String| {
730 warn!(target: "lua", "{}", msg);
731 Ok(())
732 })?;
733 log_table.set("warn", warn_fn)?;
734
735 let error_fn = lua.create_function(|_, msg: String| {
736 error!(target: "lua", "{}", msg);
737 Ok(())
738 })?;
739 log_table.set("error", error_fn)?;
740
741 lua.globals().set("log", log_table)?;
742 Ok(())
743}
744
745fn register_env(lua: &Lua) -> mlua::Result<()> {
746 let env_table = lua.create_table()?;
747
748 let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
749 Ok(val) => Ok(Some(val)),
750 Err(_) => Ok(None),
751 })?;
752 env_table.set("_process_get", process_get_fn)?;
753 env_table.set("_check_env", lua.create_table()?)?;
754
755 lua.globals().set("env", env_table)?;
756
757 lua.load(
758 r#"
759 function env.get(name)
760 local val = env._check_env[name]
761 if val ~= nil then return val end
762 return env._process_get(name)
763 end
764 "#,
765 )
766 .exec()?;
767
768 Ok(())
769}
770
771fn register_sleep(lua: &Lua) -> mlua::Result<()> {
772 let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
773 let duration = std::time::Duration::from_secs_f64(seconds);
774 tokio::time::sleep(duration).await;
775 Ok(())
776 })?;
777 lua.globals().set("sleep", sleep_fn)?;
778 Ok(())
779}
780
781fn register_time(lua: &Lua) -> mlua::Result<()> {
782 let time_fn = lua.create_function(|_, ()| {
783 let secs = SystemTime::now()
784 .duration_since(UNIX_EPOCH)
785 .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
786 .as_secs_f64();
787 Ok(secs)
788 })?;
789 lua.globals().set("time", time_fn)?;
790 Ok(())
791}
792
793fn register_fs(lua: &Lua) -> mlua::Result<()> {
794 let fs_table = lua.create_table()?;
795
796 let read_fn = lua.create_function(|_, path: String| {
797 std::fs::read_to_string(&path)
798 .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
799 })?;
800 fs_table.set("read", read_fn)?;
801
802 let write_fn = lua.create_function(|_, (path, content): (String, String)| {
803 let p = std::path::Path::new(&path);
804 if let Some(parent) = p.parent() {
805 std::fs::create_dir_all(parent).map_err(|e| {
806 mlua::Error::runtime(format!(
807 "fs.write: failed to create directories for {path:?}: {e}"
808 ))
809 })?;
810 }
811 std::fs::write(&path, &content)
812 .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
813 })?;
814 fs_table.set("write", write_fn)?;
815
816 lua.globals().set("fs", fs_table)?;
817 Ok(())
818}
819
820fn register_base64(lua: &Lua) -> mlua::Result<()> {
821 let b64_table = lua.create_table()?;
822
823 let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
824 b64_table.set("encode", encode_fn)?;
825
826 let decode_fn = lua.create_function(|_, input: String| {
827 let bytes = BASE64
828 .decode(input.as_bytes())
829 .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
830 String::from_utf8(bytes)
831 .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
832 })?;
833 b64_table.set("decode", decode_fn)?;
834
835 lua.globals().set("base64", b64_table)?;
836 Ok(())
837}
838
839fn register_crypto(lua: &Lua) -> mlua::Result<()> {
840 let crypto_table = lua.create_table()?;
841
842 let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
843 let mut args_iter = args.into_iter();
844
845 let claims_table = match args_iter.next() {
846 Some(Value::Table(t)) => t,
847 _ => {
848 return Err(mlua::Error::runtime(
849 "crypto.jwt_sign: first argument must be a claims table",
850 ));
851 }
852 };
853
854 let pem_key: String = match args_iter.next() {
855 Some(Value::String(s)) => s.to_str()?.to_string(),
856 _ => {
857 return Err(mlua::Error::runtime(
858 "crypto.jwt_sign: second argument must be a PEM key string",
859 ));
860 }
861 };
862
863 let algorithm = match args_iter.next() {
864 Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
865 "RS256" => Algorithm::RS256,
866 "RS384" => Algorithm::RS384,
867 "RS512" => Algorithm::RS512,
868 other => {
869 return Err(mlua::Error::runtime(format!(
870 "crypto.jwt_sign: unsupported algorithm: {other}"
871 )));
872 }
873 },
874 Some(Value::Nil) | None => Algorithm::RS256,
875 _ => {
876 return Err(mlua::Error::runtime(
877 "crypto.jwt_sign: third argument must be an algorithm string or nil",
878 ));
879 }
880 };
881
882 let opts = match args_iter.next() {
884 Some(Value::Table(t)) => Some(t),
885 Some(Value::Nil) | None => None,
886 _ => {
887 return Err(mlua::Error::runtime(
888 "crypto.jwt_sign: fourth argument must be an options table or nil",
889 ));
890 }
891 };
892
893 let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
894 let pem_bytes = Zeroizing::new(pem_key.into_bytes());
895 let key = EncodingKey::from_rsa_pem(&pem_bytes)
896 .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
897
898 let mut header = Header::new(algorithm);
899 if let Some(ref opts_table) = opts
900 && let Ok(kid) = opts_table.get::<String>("kid")
901 {
902 header.kid = Some(kid);
903 }
904 let token = jsonwebtoken::encode(&header, &claims_json, &key)
905 .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
906
907 Ok(token)
908 })?;
909 crypto_table.set("jwt_sign", jwt_sign_fn)?;
910
911 let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
912 let mut args_iter = args.into_iter();
913
914 let input: String = match args_iter.next() {
915 Some(Value::String(s)) => s.to_str()?.to_string(),
916 _ => {
917 return Err(mlua::Error::runtime(
918 "crypto.hash: first argument must be a string",
919 ));
920 }
921 };
922
923 let algorithm: String = match args_iter.next() {
924 Some(Value::String(s)) => s.to_str()?.to_lowercase(),
925 Some(Value::Nil) | None => "sha256".to_string(),
926 _ => {
927 return Err(mlua::Error::runtime(
928 "crypto.hash: second argument must be an algorithm string or nil",
929 ));
930 }
931 };
932
933 let hex = match algorithm.as_str() {
934 "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
935 "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
936 "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
937 "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
938 "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
939 "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
940 "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
941 "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
942 other => {
943 return Err(mlua::Error::runtime(format!(
944 "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
945 )));
946 }
947 };
948
949 Ok(hex)
950 })?;
951 crypto_table.set("hash", hash_fn)?;
952
953 let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
954 let mut args_iter = args.into_iter();
955
956 let length: usize = match args_iter.next() {
957 Some(Value::Integer(n)) if n > 0 => n as usize,
958 Some(Value::Integer(n)) => {
959 return Err(mlua::Error::runtime(format!(
960 "crypto.random: length must be positive, got {n}"
961 )));
962 }
963 Some(Value::Nil) | None => 32,
964 _ => {
965 return Err(mlua::Error::runtime(
966 "crypto.random: first argument must be a positive integer or nil",
967 ));
968 }
969 };
970
971 let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
972 let mut rng = rand::rng();
973 let result: String = (0..length)
974 .map(|_| charset[rng.random_range(..charset.len())] as char)
975 .collect();
976
977 Ok(result)
978 })?;
979 crypto_table.set("random", random_fn)?;
980
981 lua.globals().set("crypto", crypto_table)?;
982 Ok(())
983}
984
985fn register_yaml(lua: &Lua) -> mlua::Result<()> {
986 let yaml_table = lua.create_table()?;
987
988 let parse_fn = lua.create_function(|lua, s: String| {
989 let json_val: serde_json::Value = serde_yml::from_str(&s)
990 .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
991 json_value_to_lua(lua, &json_val)
992 })?;
993 yaml_table.set("parse", parse_fn)?;
994
995 let encode_fn = lua.create_function(|_, val: Value| {
996 let json_val = lua_value_to_json(&val)?;
997 serde_yml::to_string(&json_val)
998 .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
999 })?;
1000 yaml_table.set("encode", encode_fn)?;
1001
1002 lua.globals().set("yaml", yaml_table)?;
1003 Ok(())
1004}
1005
1006fn register_toml(lua: &Lua) -> mlua::Result<()> {
1007 let toml_table = lua.create_table()?;
1008
1009 let parse_fn = lua.create_function(|lua, s: String| {
1010 let toml_val: toml::Value =
1011 toml::from_str(&s).map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
1012 let json_val = serde_json::to_value(&toml_val)
1013 .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
1014 json_value_to_lua(lua, &json_val)
1015 })?;
1016 toml_table.set("parse", parse_fn)?;
1017
1018 let encode_fn = lua.create_function(|_, val: Value| {
1019 let json_val = lua_value_to_json(&val)?;
1020 let toml_val: toml::Value = serde_json::from_value(json_val)
1021 .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
1022 toml::to_string_pretty(&toml_val)
1023 .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
1024 })?;
1025 toml_table.set("encode", encode_fn)?;
1026
1027 lua.globals().set("toml", toml_table)?;
1028 Ok(())
1029}
1030
1031fn register_regex(lua: &Lua) -> mlua::Result<()> {
1032 let regex_table = lua.create_table()?;
1033
1034 let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
1035 let re = regex_lite::Regex::new(&pattern)
1036 .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
1037 Ok(re.is_match(&text))
1038 })?;
1039 regex_table.set("match", match_fn)?;
1040
1041 let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
1042 let re = regex_lite::Regex::new(&pattern)
1043 .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
1044 match re.captures(&text) {
1045 Some(caps) => {
1046 let result = lua.create_table()?;
1047 let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
1048 result.set("match", full_match.to_string())?;
1049 let groups = lua.create_table()?;
1050 for i in 1..caps.len() {
1051 if let Some(m) = caps.get(i) {
1052 groups.set(i, m.as_str().to_string())?;
1053 }
1054 }
1055 result.set("groups", groups)?;
1056 Ok(Value::Table(result))
1057 }
1058 None => Ok(Value::Nil),
1059 }
1060 })?;
1061 regex_table.set("find", find_fn)?;
1062
1063 let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
1064 let re = regex_lite::Regex::new(&pattern)
1065 .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
1066 let results = lua.create_table()?;
1067 for (i, m) in re.find_iter(&text).enumerate() {
1068 results.set(i + 1, m.as_str().to_string())?;
1069 }
1070 Ok(results)
1071 })?;
1072 regex_table.set("find_all", find_all_fn)?;
1073
1074 let replace_fn = lua.create_function(
1075 |_, (text, pattern, replacement): (String, String, String)| {
1076 let re = regex_lite::Regex::new(&pattern).map_err(|e| {
1077 mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
1078 })?;
1079 Ok(re.replace_all(&text, replacement.as_str()).into_owned())
1080 },
1081 )?;
1082 regex_table.set("replace", replace_fn)?;
1083
1084 lua.globals().set("regex", regex_table)?;
1085 Ok(())
1086}
1087
1088fn register_async(lua: &Lua) -> mlua::Result<()> {
1089 let async_table = lua.create_table()?;
1090
1091 let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
1092 let thread = lua.create_thread(func)?;
1093 let async_thread = thread.into_async::<mlua::MultiValue>(())?;
1094 let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
1095 tokio::task::spawn_local(async move {
1096 let values = async_thread.await.map_err(|e| e.to_string())?;
1097 Ok(values.into_vec())
1098 });
1099
1100 let handle = lua.create_table()?;
1101 let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
1102 let cell_clone = cell.clone();
1103
1104 let await_fn = lua.create_async_function(move |lua, ()| {
1105 let cell = cell_clone.clone();
1106 async move {
1107 let join_handle = cell
1108 .borrow_mut()
1109 .take()
1110 .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
1111 let result = join_handle.await.map_err(|e| {
1112 mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
1113 })?;
1114 match result {
1115 Ok(values) => {
1116 let tbl = lua.create_table()?;
1117 for (i, v) in values.into_iter().enumerate() {
1118 tbl.set(i + 1, v)?;
1119 }
1120 Ok(Value::Table(tbl))
1121 }
1122 Err(msg) => Err(mlua::Error::runtime(msg)),
1123 }
1124 }
1125 })?;
1126 handle.set("await", await_fn)?;
1127
1128 Ok(handle)
1129 })?;
1130 async_table.set("spawn", spawn_fn)?;
1131
1132 let spawn_interval_fn =
1133 lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
1134 if seconds <= 0.0 {
1135 return Err(mlua::Error::runtime(
1136 "async.spawn_interval: interval must be positive",
1137 ));
1138 }
1139
1140 let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1141 let cancel_clone = cancel.clone();
1142
1143 tokio::task::spawn_local({
1144 let cancel = cancel_clone.clone();
1145 async move {
1146 let mut interval =
1147 tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1148 interval.tick().await;
1149 loop {
1150 interval.tick().await;
1151 if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1152 break;
1153 }
1154 if let Err(e) = func.call_async::<()>(()).await {
1155 error!("async.spawn_interval: callback error: {e}");
1156 break;
1157 }
1158 }
1159 }
1160 });
1161
1162 let handle = lua.create_table()?;
1163 let cancel_fn = lua.create_function(move |_, ()| {
1164 cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1165 Ok(())
1166 })?;
1167 handle.set("cancel", cancel_fn)?;
1168
1169 Ok(handle)
1170 })?;
1171 async_table.set("spawn_interval", spawn_interval_fn)?;
1172
1173 lua.globals().set("async", async_table)?;
1174 Ok(())
1175}
1176
/// Lua userdata wrapping a shared `sqlx` `Any` connection pool.
///
/// Created by `db.connect`; `db.query`, `db.execute` and `db.close` get the
/// pool back out of it via `extract_db_pool`, which clones the inner `Arc`.
struct DbPool(Arc<AnyPool>);
impl UserData for DbPool {}
1179
1180fn register_db(lua: &Lua) -> mlua::Result<()> {
1181 sqlx::any::install_default_drivers();
1182
1183 let db_table = lua.create_table()?;
1184
1185 let connect_fn = lua.create_async_function(|lua, url: String| async move {
1186 let pool = sqlx::any::AnyPoolOptions::new()
1187 .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1188 .connect(&url)
1189 .await
1190 .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1191 lua.create_any_userdata(DbPool(Arc::new(pool)))
1192 })?;
1193 db_table.set("connect", connect_fn)?;
1194
1195 let query_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1196 let mut args_iter = args.into_iter();
1197
1198 let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1199 let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1200 let params = extract_params(&args_iter.next())?;
1201
1202 let mut query = sqlx::query(&sql);
1203 for p in ¶ms {
1204 query = bind_param(query, p);
1205 }
1206
1207 let rows: Vec<AnyRow> = query
1208 .fetch_all(&*pool)
1209 .await
1210 .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1211
1212 let result = lua.create_table()?;
1213 for (i, row) in rows.iter().enumerate() {
1214 let row_table = any_row_to_lua_table(&lua, row)?;
1215 result.set(i + 1, row_table)?;
1216 }
1217 Ok(Value::Table(result))
1218 })?;
1219 db_table.set("query", query_fn)?;
1220
1221 let execute_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1222 let mut args_iter = args.into_iter();
1223
1224 let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1225 let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1226 let params = extract_params(&args_iter.next())?;
1227
1228 let mut query = sqlx::query(&sql);
1229 for p in ¶ms {
1230 query = bind_param(query, p);
1231 }
1232
1233 let result = query
1234 .execute(&*pool)
1235 .await
1236 .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1237
1238 let tbl = lua.create_table()?;
1239 tbl.set("rows_affected", result.rows_affected() as i64)?;
1240 Ok(Value::Table(tbl))
1241 })?;
1242 db_table.set("execute", execute_fn)?;
1243
1244 let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1245 let mut args_iter = args.into_iter();
1246 let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1247 pool.close().await;
1248 Ok(())
1249 })?;
1250 db_table.set("close", close_fn)?;
1251
1252 lua.globals().set("db", db_table)?;
1253 Ok(())
1254}
1255
1256fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1257 match val {
1258 Some(Value::UserData(ud)) => {
1259 let db = ud.borrow::<DbPool>().map_err(|_| {
1260 mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection"))
1261 })?;
1262 Ok(db.0.clone())
1263 }
1264 _ => Err(mlua::Error::runtime(format!(
1265 "{fn_name}: first argument must be a db connection"
1266 ))),
1267 }
1268}
1269
1270fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1271 match val {
1272 Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1273 _ => Err(mlua::Error::runtime(format!(
1274 "{fn_name}: second argument must be a SQL string"
1275 ))),
1276 }
1277}
1278
/// A positional SQL bind parameter converted from a Lua value by `extract_params`.
///
/// `Debug`/`PartialEq` are derived so parameters can be logged and compared in tests.
#[derive(Clone, Debug, PartialEq)]
enum DbParam {
    /// Lua `nil`; bound as SQL NULL.
    Null,
    /// Lua boolean.
    Bool(bool),
    /// Lua integer subtype.
    Int(i64),
    /// Lua float subtype.
    Float(f64),
    /// Lua string (must be valid UTF-8).
    Text(String),
}
1287
1288fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1289 match val {
1290 Some(Value::Table(t)) => {
1291 let mut params = Vec::new();
1292 let len = t.len()?;
1293 for i in 1..=len {
1294 let v: Value = t.get(i)?;
1295 let param = match v {
1296 Value::Nil => DbParam::Null,
1297 Value::Boolean(b) => DbParam::Bool(b),
1298 Value::Integer(n) => DbParam::Int(n),
1299 Value::Number(f) => DbParam::Float(f),
1300 Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1301 _ => {
1302 return Err(mlua::Error::runtime(format!(
1303 "db: unsupported parameter type: {}",
1304 v.type_name()
1305 )));
1306 }
1307 };
1308 params.push(param);
1309 }
1310 Ok(params)
1311 }
1312 Some(Value::Nil) | None => Ok(Vec::new()),
1313 _ => Err(mlua::Error::runtime(
1314 "db: params must be a table (array) or nil",
1315 )),
1316 }
1317}
1318
1319fn bind_param<'q>(
1320 query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1321 param: &'q DbParam,
1322) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1323 match param {
1324 DbParam::Null => query.bind(None::<String>),
1325 DbParam::Bool(b) => query.bind(*b),
1326 DbParam::Int(n) => query.bind(*n),
1327 DbParam::Float(f) => query.bind(*f),
1328 DbParam::Text(s) => query.bind(s.as_str()),
1329 }
1330}
1331
1332fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1333 let table = lua.create_table()?;
1334 for col in row.columns() {
1335 let name = col.name();
1336 let val: Value = any_column_to_lua_value(lua, row, col)?;
1337 table.set(name.to_string(), val)?;
1338 }
1339 Ok(table)
1340}
1341
1342fn any_column_to_lua_value<C: Column>(lua: &Lua, row: &AnyRow, col: &C) -> mlua::Result<Value> {
1343 let ordinal = col.ordinal();
1344 let type_info = col.type_info();
1345 let type_name = type_info.to_string();
1346 let type_name = type_name.to_uppercase();
1347
1348 if row
1349 .try_get_raw(ordinal)
1350 .map(|v| v.is_null())
1351 .unwrap_or(true)
1352 {
1353 return Ok(Value::Nil);
1354 }
1355
1356 match type_name.as_str() {
1357 "BOOLEAN" | "BOOL" => {
1358 let v: bool = row
1359 .try_get(ordinal)
1360 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1361 Ok(Value::Boolean(v))
1362 }
1363 "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1364 let v: i64 = row
1365 .try_get(ordinal)
1366 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1367 Ok(Value::Integer(v))
1368 }
1369 "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1370 let v: f64 = row
1371 .try_get(ordinal)
1372 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1373 Ok(Value::Number(v))
1374 }
1375 _ => {
1376 let v: String = row
1377 .try_get(ordinal)
1378 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1379 Ok(Value::String(lua.create_string(&v)?))
1380 }
1381 }
1382}
1383
1384fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1385 match val {
1386 Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1387 Some(other) => {
1388 let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1389 result
1390 }
1391 None => None,
1392 }
1393}
1394
1395fn lua_values_equal(a: &Value, b: &Value) -> bool {
1396 match (a, b) {
1397 (Value::Nil, Value::Nil) => true,
1398 (Value::Boolean(a), Value::Boolean(b)) => a == b,
1399 (Value::Integer(a), Value::Integer(b)) => a == b,
1400 (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1401 (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1402 (*a as f64 - b).abs() < f64::EPSILON
1403 }
1404 (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1405 _ => false,
1406 }
1407}
1408
1409fn lua_value_to_f64(val: Value) -> Option<f64> {
1410 match val {
1411 Value::Integer(i) => Some(i as f64),
1412 Value::Number(f) => Some(f),
1413 _ => None,
1414 }
1415}
1416
1417fn format_lua_value(val: &Value) -> String {
1418 match val {
1419 Value::Nil => "nil".to_string(),
1420 Value::Boolean(b) => b.to_string(),
1421 Value::Integer(i) => i.to_string(),
1422 Value::Number(f) => f.to_string(),
1423 Value::String(s) => match s.to_str() {
1424 Ok(v) => v.to_string(),
1425 Err(_) => "<invalid utf-8>".to_string(),
1426 },
1427 _ => format!("<{}>", val.type_name()),
1428 }
1429}
1430
/// Write half of a client WebSocket connection.
///
/// Shared through `Rc` between clones taken from the Lua userdata. Guarded by a
/// `tokio::sync::Mutex` because `ws.send` / `ws.close` keep the lock across
/// `.await`.
type WsSink = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitSink<
            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
            tokio_tungstenite::tungstenite::Message,
        >,
    >,
>;
/// Read half of a client WebSocket connection.
///
/// Shared and locked the same way as [`WsSink`]; `ws.recv` holds the lock while
/// awaiting the next message.
type WsStream = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
        >,
    >,
>;
1446
/// Lua userdata for a WebSocket connection created by `ws.connect`.
///
/// The socket is split so that reading and writing lock independent halves.
struct WsConn {
    /// Outgoing half; used by `ws.send` and `ws.close`.
    sink: WsSink,
    /// Incoming half; used by `ws.recv`.
    stream: WsStream,
}
impl UserData for WsConn {}
1452
1453fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1454 let ud = match val {
1455 Value::UserData(ud) => ud,
1456 _ => {
1457 return Err(mlua::Error::runtime(format!(
1458 "{fn_name}: first argument must be a ws connection"
1459 )));
1460 }
1461 };
1462 let ws = ud.borrow::<WsConn>().map_err(|_| {
1463 mlua::Error::runtime(format!("{fn_name}: first argument must be a ws connection"))
1464 })?;
1465 Ok((ws.sink.clone(), ws.stream.clone()))
1466}
1467
1468fn register_ws(lua: &Lua) -> mlua::Result<()> {
1469 let ws_table = lua.create_table()?;
1470
1471 let connect_fn = lua.create_async_function(|lua, url: String| async move {
1472 let (stream, _response) = tokio_tungstenite::connect_async(&url)
1473 .await
1474 .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1475 let (sink, read) = stream.split();
1476 lua.create_any_userdata(WsConn {
1477 sink: Rc::new(tokio::sync::Mutex::new(sink)),
1478 stream: Rc::new(tokio::sync::Mutex::new(read)),
1479 })
1480 })?;
1481 ws_table.set("connect", connect_fn)?;
1482
1483 let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1484 let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1485 sink.lock()
1486 .await
1487 .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1488 .await
1489 .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1490 Ok(())
1491 })?;
1492 ws_table.set("send", send_fn)?;
1493
1494 let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1495 let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1496 loop {
1497 let msg = stream
1498 .lock()
1499 .await
1500 .next()
1501 .await
1502 .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1503 .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1504 match msg {
1505 tokio_tungstenite::tungstenite::Message::Text(t) => {
1506 return Ok(t.to_string());
1507 }
1508 tokio_tungstenite::tungstenite::Message::Binary(b) => {
1509 return String::from_utf8(b.into())
1510 .map_err(|e| mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}")));
1511 }
1512 tokio_tungstenite::tungstenite::Message::Close(_) => {
1513 return Err(mlua::Error::runtime("ws.recv: connection closed"));
1514 }
1515 _ => continue,
1516 }
1517 }
1518 })?;
1519 ws_table.set("recv", recv_fn)?;
1520
1521 let close_fn = lua.create_async_function(|_, conn: Value| async move {
1522 let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1523 sink.lock()
1524 .await
1525 .close()
1526 .await
1527 .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1528 Ok(())
1529 })?;
1530 ws_table.set("close", close_fn)?;
1531
1532 lua.globals().set("ws", ws_table)?;
1533 Ok(())
1534}
1535
1536fn register_template(lua: &Lua) -> mlua::Result<()> {
1537 let tmpl_table = lua.create_table()?;
1538
1539 let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1540 let json_vars = match &vars {
1541 Value::Table(_) => lua_value_to_json(&vars)?,
1542 Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1543 _ => {
1544 return Err(mlua::Error::runtime(
1545 "template.render_string: second argument must be a table or nil",
1546 ));
1547 }
1548 };
1549 let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1550 let env = minijinja::Environment::new();
1551 let tmpl = env
1552 .template_from_str(&template_str)
1553 .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1554 tmpl.render(mini_vars)
1555 .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1556 })?;
1557 tmpl_table.set("render_string", render_string_fn)?;
1558
1559 let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1560 let content = std::fs::read_to_string(&file_path).map_err(|e| {
1561 mlua::Error::runtime(format!(
1562 "template.render: failed to read {file_path:?}: {e}"
1563 ))
1564 })?;
1565 let json_vars = match &vars {
1566 Value::Table(_) => lua_value_to_json(&vars)?,
1567 Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1568 _ => {
1569 return Err(mlua::Error::runtime(
1570 "template.render: second argument must be a table or nil",
1571 ));
1572 }
1573 };
1574 let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1575 let env = minijinja::Environment::new();
1576 let tmpl = env
1577 .template_from_str(&content)
1578 .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1579 tmpl.render(mini_vars)
1580 .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1581 })?;
1582 tmpl_table.set("render", render_fn)?;
1583
1584 lua.globals().set("template", tmpl_table)?;
1585 Ok(())
1586}
1587
/// Quotes `s` as a double-quoted Lua string literal.
///
/// Escapes backslash, double quote, LF, CR and NUL; every other character is
/// copied verbatim. NUL is written as the three-digit decimal escape `\000`.
/// Lua's `\ddd` escape consumes up to three digits, so the previous `\0`
/// followed by a digit (e.g. "\0" + "1") was read back as a different byte.
fn lua_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\0' => out.push_str("\\000"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
1597
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `value` converts to JSON successfully and equals `expected`.
    fn assert_json(value: Value, expected: serde_json::Value) {
        assert_eq!(lua_value_to_json(&value).unwrap(), expected);
    }

    #[test]
    fn test_base64_roundtrip() {
        let plain = "hello world";

        let encoded = BASE64.encode(plain.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");

        let round_tripped =
            String::from_utf8(BASE64.decode(encoded.as_bytes()).unwrap()).unwrap();
        assert_eq!(round_tripped, plain);
    }

    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }

    #[test]
    fn test_lua_value_to_json_nil() {
        assert_json(Value::Nil, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        assert_json(Value::Boolean(true), serde_json::Value::Bool(true));
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        assert_json(Value::Integer(42), serde_json::json!(42));
    }

    #[test]
    fn test_lua_value_to_json_number() {
        assert_json(Value::Number(1.5), serde_json::json!(1.5));
    }

    #[test]
    fn test_lua_values_equal_nil() {
        assert!(lua_values_equal(&Value::Nil, &Value::Nil));
    }

    #[test]
    fn test_lua_values_equal_int_float() {
        assert!(lua_values_equal(&Value::Integer(42), &Value::Number(42.0)));
    }

    #[test]
    fn test_lua_values_not_equal() {
        assert!(!lua_values_equal(&Value::Integer(1), &Value::Integer(2)));
    }

    #[test]
    fn test_format_lua_value() {
        let cases = [
            (Value::Nil, "nil"),
            (Value::Boolean(true), "true"),
            (Value::Integer(42), "42"),
        ];
        for (value, expected) in &cases {
            assert_eq!(format_lua_value(value), *expected);
        }
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        let cases = [
            ("hello", "\"hello\""),
            ("line\nnew", "\"line\\nnew\""),
            ("quote\"here", "\"quote\\\"here\""),
        ];
        for (input, expected) in cases {
            assert_eq!(lua_string_literal(input), expected);
        }
    }
}