1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
23pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
24 register_http(lua, client)?;
25 register_json(lua)?;
26 register_yaml(lua)?;
27 register_toml(lua)?;
28 register_assert(lua)?;
29 register_log(lua)?;
30 register_env(lua)?;
31 register_sleep(lua)?;
32 register_time(lua)?;
33 register_fs(lua)?;
34 register_base64(lua)?;
35 register_crypto(lua)?;
36 register_regex(lua)?;
37 register_async(lua)?;
38 register_db(lua)?;
39 register_ws(lua)?;
40 register_template(lua)?;
41 Ok(())
42}
43
44fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
45 let http_table = lua.create_table()?;
46
47 for method in ["get", "post", "put", "patch", "delete"] {
48 let method_client = client.clone();
49 let method_name = method.to_string();
50 let has_body = method != "get" && method != "delete";
51
52 let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
53 let client = method_client.clone();
54 let method_name = method_name.clone();
55 async move {
56 let mut args_iter = args.into_iter();
57 let url: String = match args_iter.next() {
58 Some(Value::String(s)) => s.to_str()?.to_string(),
59 _ => {
60 return Err(mlua::Error::runtime(format!(
61 "http.{method_name}: first argument must be a URL string"
62 )));
63 }
64 };
65
66 let (body_str, auto_json, opts) = if has_body {
67 let (body, is_json) = match args_iter.next() {
68 Some(Value::String(s)) => (s.to_str()?.to_string(), false),
69 Some(Value::Table(t)) => {
70 let json_val = lua_table_to_json(&t)?;
71 let serialized = serde_json::to_string(&json_val).map_err(|e| {
72 mlua::Error::runtime(format!(
73 "http.{method_name}: JSON encode failed: {e}"
74 ))
75 })?;
76 (serialized, true)
77 }
78 Some(Value::Nil) | None => (String::new(), false),
79 _ => {
80 return Err(mlua::Error::runtime(format!(
81 "http.{method_name}: second argument must be a string, table, or nil"
82 )));
83 }
84 };
85 let opts = match args_iter.next() {
86 Some(Value::Table(t)) => Some(t),
87 Some(Value::Nil) | None => None,
88 _ => {
89 return Err(mlua::Error::runtime(format!(
90 "http.{method_name}: third argument must be a table or nil"
91 )));
92 }
93 };
94 (body, is_json, opts)
95 } else {
96 let opts = match args_iter.next() {
97 Some(Value::Table(t)) => Some(t),
98 Some(Value::Nil) | None => None,
99 _ => {
100 return Err(mlua::Error::runtime(format!(
101 "http.{method_name}: second argument must be a table or nil"
102 )));
103 }
104 };
105 (String::new(), false, opts)
106 };
107
108 let mut req = match method_name.as_str() {
109 "get" => client.get(&url),
110 "post" => client.post(&url),
111 "put" => client.put(&url),
112 "patch" => client.patch(&url),
113 "delete" => client.delete(&url),
114 _ => unreachable!(),
115 };
116
117 if has_body && !body_str.is_empty() {
118 req = req.body(body_str);
119 }
120 if auto_json {
121 req = req.header("Content-Type", "application/json");
122 }
123 if let Some(ref opts_table) = opts
124 && let Ok(headers_table) = opts_table.get::<Table>("headers")
125 {
126 for pair in headers_table.pairs::<String, String>() {
127 let (k, v) = pair?;
128 req = req.header(k, v);
129 }
130 }
131
132 let resp = req.send().await.map_err(|e| {
133 mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
134 })?;
135 let status = resp.status().as_u16();
136 let resp_headers = resp.headers().clone();
137 let body = resp.text().await.map_err(|e| {
138 mlua::Error::runtime(format!(
139 "http.{method_name}: reading body failed: {e}"
140 ))
141 })?;
142
143 let result = lua.create_table()?;
144 result.set("status", status)?;
145 result.set("body", body)?;
146
147 let headers_out = lua.create_table()?;
148 for (name, value) in &resp_headers {
149 if let Ok(v) = value.to_str() {
150 headers_out.set(name.as_str().to_string(), v.to_string())?;
151 }
152 }
153 result.set("headers", headers_out)?;
154
155 Ok(Value::Table(result))
156 }
157 })?;
158 http_table.set(method, func)?;
159 }
160
161 let serve_fn =
162 lua.create_async_function(|lua, args: mlua::MultiValue| async move {
163 let mut args_iter = args.into_iter();
164
165 let port: u16 = match args_iter.next() {
166 Some(Value::Integer(n)) => n as u16,
167 _ => {
168 return Err::<(), _>(mlua::Error::runtime(
169 "http.serve: first argument must be a port number",
170 ));
171 }
172 };
173
174 let routes_table = match args_iter.next() {
175 Some(Value::Table(t)) => t,
176 _ => {
177 return Err::<(), _>(mlua::Error::runtime(
178 "http.serve: second argument must be a routes table",
179 ));
180 }
181 };
182
183 let routes = Rc::new(parse_routes(&routes_table)?);
184
185 let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
186 .await
187 .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
188
189 loop {
190 let (stream, _addr) = listener.accept().await.map_err(|e| {
191 mlua::Error::runtime(format!("http.serve: accept failed: {e}"))
192 })?;
193
194 let routes = routes.clone();
195 let lua_clone = lua.clone();
196
197 tokio::task::spawn_local(async move {
198 let io = hyper_util::rt::TokioIo::new(stream);
199 let routes = routes.clone();
200 let lua = lua_clone.clone();
201
202 let service = service_fn(move |req: Request<Incoming>| {
203 let routes = routes.clone();
204 let lua = lua.clone();
205 async move { handle_request(&lua, &routes, req).await }
206 });
207
208 if let Err(e) = http1::Builder::new().serve_connection(io, service).await
209 && !e.to_string().contains("connection closed")
210 {
211 error!("http.serve: connection error: {e}");
212 }
213 });
214 }
215 })?;
216 http_table.set("serve", serve_fn)?;
217
218 lua.globals().set("http", http_table)?;
219 Ok(())
220}
221
222fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
223 let mut routes = HashMap::new();
224 for method_pair in routes_table.pairs::<String, Table>() {
225 let (method, paths_table) = method_pair?;
226 let method_upper = method.to_uppercase();
227 for path_pair in paths_table.pairs::<String, mlua::Function>() {
228 let (path, func) = path_pair?;
229 routes.insert((method_upper.clone(), path), func);
230 }
231 }
232 Ok(routes)
233}
234
235async fn handle_request(
236 lua: &Lua,
237 routes: &HashMap<(String, String), mlua::Function>,
238 req: Request<Incoming>,
239) -> Result<Response<Full<Bytes>>, hyper::Error> {
240 let method = req.method().to_string();
241 let path = req.uri().path().to_string();
242 let query = req.uri().query().unwrap_or("").to_string();
243 let headers: Vec<(String, String)> = req
244 .headers()
245 .iter()
246 .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
247 .collect();
248
249 let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
250 Ok(collected) => collected.to_bytes(),
251 Err(_) => Bytes::new(),
252 };
253 let body_str = String::from_utf8_lossy(&body_bytes).to_string();
254
255 let key = (method.clone(), path.clone());
256 let handler = match routes.get(&key) {
257 Some(f) => f,
258 None => {
259 return Ok(Response::builder()
260 .status(StatusCode::NOT_FOUND)
261 .header("content-type", "text/plain")
262 .body(Full::new(Bytes::from("not found")))
263 .unwrap());
264 }
265 };
266
267 match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
268 Ok(lua_resp) => lua_response_to_http(&lua_resp),
269 Err(e) => Ok(Response::builder()
270 .status(StatusCode::INTERNAL_SERVER_ERROR)
271 .header("content-type", "text/plain")
272 .body(Full::new(Bytes::from(format!("handler error: {e}"))))
273 .unwrap()),
274 }
275}
276
277fn build_lua_request_and_call(
278 lua: &Lua,
279 handler: &mlua::Function,
280 method: &str,
281 path: &str,
282 query: &str,
283 headers: &[(String, String)],
284 body: &str,
285) -> mlua::Result<Table> {
286 let req_table = lua.create_table()?;
287 req_table.set("method", method.to_string())?;
288 req_table.set("path", path.to_string())?;
289 req_table.set("query", query.to_string())?;
290 req_table.set("body", body.to_string())?;
291
292 let headers_table = lua.create_table()?;
293 for (k, v) in headers {
294 headers_table.set(k.as_str(), v.as_str())?;
295 }
296 req_table.set("headers", headers_table)?;
297
298 handler.call::<Table>(req_table)
299}
300
301fn lua_response_to_http(
302 resp_table: &Table,
303) -> Result<Response<Full<Bytes>>, hyper::Error> {
304 let status = resp_table
305 .get::<Option<u16>>("status")
306 .unwrap_or(None)
307 .unwrap_or(200);
308
309 let mut builder =
310 Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
311
312 if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
313 for (k, v) in headers_table.pairs::<String, String>().flatten() {
314 builder = builder.header(k, v);
315 }
316 }
317
318 let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
319 let json_val =
320 lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
321 let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
322 builder = builder.header("content-type", "application/json");
323 Bytes::from(serialized)
324 } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
325 builder = builder.header("content-type", "text/plain");
326 Bytes::from(body_str)
327 } else {
328 builder = builder.header("content-type", "text/plain");
329 Bytes::new()
330 };
331
332 Ok(builder.body(Full::new(body_bytes)).unwrap())
333}
334
335fn register_json(lua: &Lua) -> mlua::Result<()> {
336 let json_table = lua.create_table()?;
337
338 let parse_fn = lua.create_function(|lua, s: String| {
339 let value: serde_json::Value = serde_json::from_str(&s)
340 .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
341 json_value_to_lua(lua, &value)
342 })?;
343 json_table.set("parse", parse_fn)?;
344
345 let encode_fn = lua.create_function(|_, val: Value| {
346 let json_val = lua_value_to_json(&val)?;
347 serde_json::to_string(&json_val)
348 .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
349 })?;
350 json_table.set("encode", encode_fn)?;
351
352 lua.globals().set("json", json_table)?;
353 Ok(())
354}
355
356fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
357 let mut is_array = true;
358 let mut max_index: i64 = 0;
359 let mut count: i64 = 0;
360
361 for pair in table.clone().pairs::<Value, Value>() {
362 let (key, _) = pair?;
363 count += 1;
364 match key {
365 Value::Integer(i) if i >= 1 => {
366 if i > max_index {
367 max_index = i;
368 }
369 }
370 _ => {
371 is_array = false;
372 break;
373 }
374 }
375 }
376
377 if is_array && max_index == count {
378 let mut arr = Vec::with_capacity(max_index as usize);
379 for i in 1..=max_index {
380 let val: Value = table.get(i)?;
381 arr.push(lua_value_to_json(&val)?);
382 }
383 Ok(serde_json::Value::Array(arr))
384 } else {
385 let mut map = serde_json::Map::new();
386 for pair in table.clone().pairs::<Value, Value>() {
387 let (key, val) = pair?;
388 let key_str = match key {
389 Value::String(s) => s.to_str()?.to_string(),
390 Value::Integer(i) => i.to_string(),
391 Value::Number(f) => f.to_string(),
392 _ => {
393 return Err(mlua::Error::runtime(format!(
394 "unsupported table key type: {}",
395 key.type_name()
396 )));
397 }
398 };
399 map.insert(key_str, lua_value_to_json(&val)?);
400 }
401 Ok(serde_json::Value::Object(map))
402 }
403}
404
405fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
406 match val {
407 Value::Nil => Ok(serde_json::Value::Null),
408 Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
409 Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
410 Value::Number(f) => serde_json::Number::from_f64(*f)
411 .map(serde_json::Value::Number)
412 .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
413 Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
414 Value::Table(t) => lua_table_to_json(t),
415 _ => Err(mlua::Error::runtime(format!(
416 "unsupported Lua type for JSON: {}",
417 val.type_name()
418 ))),
419 }
420}
421
422pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
423 match val {
424 serde_json::Value::Null => Ok(Value::Nil),
425 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
426 serde_json::Value::Number(n) => {
427 if let Some(i) = n.as_i64() {
428 Ok(Value::Integer(i))
429 } else if let Some(f) = n.as_f64() {
430 Ok(Value::Number(f))
431 } else {
432 Ok(Value::Nil)
433 }
434 }
435 serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
436 serde_json::Value::Array(arr) => {
437 let table = lua.create_table()?;
438 for (i, item) in arr.iter().enumerate() {
439 table.set(i + 1, json_value_to_lua(lua, item)?)?;
440 }
441 Ok(Value::Table(table))
442 }
443 serde_json::Value::Object(map) => {
444 let table = lua.create_table()?;
445 for (k, v) in map {
446 table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
447 }
448 Ok(Value::Table(table))
449 }
450 }
451}
452
453fn register_assert(lua: &Lua) -> mlua::Result<()> {
454 let assert_table = lua.create_table()?;
455
456 let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
457 let mut args_iter = args.into_iter();
458 let a = args_iter.next().unwrap_or(Value::Nil);
459 let b = args_iter.next().unwrap_or(Value::Nil);
460 let msg = extract_string_arg(lua, args_iter.next());
461
462 if !lua_values_equal(&a, &b) {
463 let detail = format!(
464 "assert.eq failed: {:?} != {:?}{}",
465 format_lua_value(&a),
466 format_lua_value(&b),
467 msg.map(|m| format!(" - {m}")).unwrap_or_default()
468 );
469 return Err(mlua::Error::runtime(detail));
470 }
471 Ok(())
472 })?;
473 assert_table.set("eq", eq_fn)?;
474
475 let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
476 let mut args_iter = args.into_iter();
477 let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
478 let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
479 let msg = extract_string_arg(lua, args_iter.next());
480
481 match (a, b) {
482 (Some(va), Some(vb)) if va > vb => Ok(()),
483 (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
484 "assert.gt failed: {va} is not > {vb}{}",
485 msg.map(|m| format!(" - {m}")).unwrap_or_default()
486 ))),
487 _ => Err(mlua::Error::runtime(
488 "assert.gt: both arguments must be numbers",
489 )),
490 }
491 })?;
492 assert_table.set("gt", gt_fn)?;
493
494 let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
495 let mut args_iter = args.into_iter();
496 let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
497 let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
498 let msg = extract_string_arg(lua, args_iter.next());
499
500 match (a, b) {
501 (Some(va), Some(vb)) if va < vb => Ok(()),
502 (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
503 "assert.lt failed: {va} is not < {vb}{}",
504 msg.map(|m| format!(" - {m}")).unwrap_or_default()
505 ))),
506 _ => Err(mlua::Error::runtime(
507 "assert.lt: both arguments must be numbers",
508 )),
509 }
510 })?;
511 assert_table.set("lt", lt_fn)?;
512
513 let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
514 let mut args_iter = args.into_iter();
515 let haystack: String = match args_iter.next() {
516 Some(Value::String(s)) => s.to_str()?.to_string(),
517 _ => {
518 return Err(mlua::Error::runtime(
519 "assert.contains: first argument must be a string",
520 ));
521 }
522 };
523 let needle: String = match args_iter.next() {
524 Some(Value::String(s)) => s.to_str()?.to_string(),
525 _ => {
526 return Err(mlua::Error::runtime(
527 "assert.contains: second argument must be a string",
528 ));
529 }
530 };
531 let msg = extract_string_arg(lua, args_iter.next());
532
533 if !haystack.contains(&needle) {
534 return Err(mlua::Error::runtime(format!(
535 "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
536 msg.map(|m| format!(" - {m}")).unwrap_or_default()
537 )));
538 }
539 Ok(())
540 })?;
541 assert_table.set("contains", contains_fn)?;
542
543 let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
544 let mut args_iter = args.into_iter();
545 let val = args_iter.next().unwrap_or(Value::Nil);
546 let msg = extract_string_arg(lua, args_iter.next());
547
548 if val == Value::Nil {
549 return Err(mlua::Error::runtime(format!(
550 "assert.not_nil failed: value is nil{}",
551 msg.map(|m| format!(" - {m}")).unwrap_or_default()
552 )));
553 }
554 Ok(())
555 })?;
556 assert_table.set("not_nil", not_nil_fn)?;
557
558 let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
559 let mut args_iter = args.into_iter();
560 let text: String = match args_iter.next() {
561 Some(Value::String(s)) => s.to_str()?.to_string(),
562 _ => {
563 return Err(mlua::Error::runtime(
564 "assert.matches: first argument must be a string",
565 ));
566 }
567 };
568 let pattern: String = match args_iter.next() {
569 Some(Value::String(s)) => s.to_str()?.to_string(),
570 _ => {
571 return Err(mlua::Error::runtime(
572 "assert.matches: second argument must be a pattern string",
573 ));
574 }
575 };
576 let msg = extract_string_arg(lua, args_iter.next());
577
578 let found: bool = lua
579 .load(format!(
580 "return string.find({}, {}) ~= nil",
581 lua_string_literal(&text),
582 lua_string_literal(&pattern)
583 ))
584 .eval()
585 .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
586
587 if !found {
588 return Err(mlua::Error::runtime(format!(
589 "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
590 msg.map(|m| format!(" - {m}")).unwrap_or_default()
591 )));
592 }
593 Ok(())
594 })?;
595 assert_table.set("matches", matches_fn)?;
596
597 lua.globals().set("assert", assert_table)?;
598 Ok(())
599}
600
601fn register_log(lua: &Lua) -> mlua::Result<()> {
602 let log_table = lua.create_table()?;
603
604 let info_fn = lua.create_function(|_, msg: String| {
605 info!(target: "lua", "{}", msg);
606 Ok(())
607 })?;
608 log_table.set("info", info_fn)?;
609
610 let warn_fn = lua.create_function(|_, msg: String| {
611 warn!(target: "lua", "{}", msg);
612 Ok(())
613 })?;
614 log_table.set("warn", warn_fn)?;
615
616 let error_fn = lua.create_function(|_, msg: String| {
617 error!(target: "lua", "{}", msg);
618 Ok(())
619 })?;
620 log_table.set("error", error_fn)?;
621
622 lua.globals().set("log", log_table)?;
623 Ok(())
624}
625
626fn register_env(lua: &Lua) -> mlua::Result<()> {
627 let env_table = lua.create_table()?;
628
629 let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
630 Ok(val) => Ok(Some(val)),
631 Err(_) => Ok(None),
632 })?;
633 env_table.set("_process_get", process_get_fn)?;
634 env_table.set("_check_env", lua.create_table()?)?;
635
636 lua.globals().set("env", env_table)?;
637
638 lua.load(
639 r#"
640 function env.get(name)
641 local val = env._check_env[name]
642 if val ~= nil then return val end
643 return env._process_get(name)
644 end
645 "#,
646 )
647 .exec()?;
648
649 Ok(())
650}
651
652fn register_sleep(lua: &Lua) -> mlua::Result<()> {
653 let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
654 let duration = std::time::Duration::from_secs_f64(seconds);
655 tokio::time::sleep(duration).await;
656 Ok(())
657 })?;
658 lua.globals().set("sleep", sleep_fn)?;
659 Ok(())
660}
661
662fn register_time(lua: &Lua) -> mlua::Result<()> {
663 let time_fn = lua.create_function(|_, ()| {
664 let secs = SystemTime::now()
665 .duration_since(UNIX_EPOCH)
666 .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
667 .as_secs_f64();
668 Ok(secs)
669 })?;
670 lua.globals().set("time", time_fn)?;
671 Ok(())
672}
673
674fn register_fs(lua: &Lua) -> mlua::Result<()> {
675 let fs_table = lua.create_table()?;
676
677 let read_fn = lua.create_function(|_, path: String| {
678 std::fs::read_to_string(&path)
679 .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
680 })?;
681 fs_table.set("read", read_fn)?;
682
683 let write_fn = lua.create_function(|_, (path, content): (String, String)| {
684 let p = std::path::Path::new(&path);
685 if let Some(parent) = p.parent() {
686 std::fs::create_dir_all(parent).map_err(|e| {
687 mlua::Error::runtime(format!(
688 "fs.write: failed to create directories for {path:?}: {e}"
689 ))
690 })?;
691 }
692 std::fs::write(&path, &content)
693 .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
694 })?;
695 fs_table.set("write", write_fn)?;
696
697 lua.globals().set("fs", fs_table)?;
698 Ok(())
699}
700
701fn register_base64(lua: &Lua) -> mlua::Result<()> {
702 let b64_table = lua.create_table()?;
703
704 let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
705 b64_table.set("encode", encode_fn)?;
706
707 let decode_fn = lua.create_function(|_, input: String| {
708 let bytes = BASE64
709 .decode(input.as_bytes())
710 .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
711 String::from_utf8(bytes)
712 .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
713 })?;
714 b64_table.set("decode", decode_fn)?;
715
716 lua.globals().set("base64", b64_table)?;
717 Ok(())
718}
719
720fn register_crypto(lua: &Lua) -> mlua::Result<()> {
721 let crypto_table = lua.create_table()?;
722
723 let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
724 let mut args_iter = args.into_iter();
725
726 let claims_table = match args_iter.next() {
727 Some(Value::Table(t)) => t,
728 _ => {
729 return Err(mlua::Error::runtime(
730 "crypto.jwt_sign: first argument must be a claims table",
731 ));
732 }
733 };
734
735 let pem_key: String = match args_iter.next() {
736 Some(Value::String(s)) => s.to_str()?.to_string(),
737 _ => {
738 return Err(mlua::Error::runtime(
739 "crypto.jwt_sign: second argument must be a PEM key string",
740 ));
741 }
742 };
743
744 let algorithm = match args_iter.next() {
745 Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
746 "RS256" => Algorithm::RS256,
747 "RS384" => Algorithm::RS384,
748 "RS512" => Algorithm::RS512,
749 other => {
750 return Err(mlua::Error::runtime(format!(
751 "crypto.jwt_sign: unsupported algorithm: {other}"
752 )));
753 }
754 },
755 Some(Value::Nil) | None => Algorithm::RS256,
756 _ => {
757 return Err(mlua::Error::runtime(
758 "crypto.jwt_sign: third argument must be an algorithm string or nil",
759 ));
760 }
761 };
762
763 let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
764 let pem_bytes = Zeroizing::new(pem_key.into_bytes());
765 let key = EncodingKey::from_rsa_pem(&pem_bytes)
766 .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
767
768 let header = Header::new(algorithm);
769 let token = jsonwebtoken::encode(&header, &claims_json, &key)
770 .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
771
772 Ok(token)
773 })?;
774 crypto_table.set("jwt_sign", jwt_sign_fn)?;
775
776 let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
777 let mut args_iter = args.into_iter();
778
779 let input: String = match args_iter.next() {
780 Some(Value::String(s)) => s.to_str()?.to_string(),
781 _ => {
782 return Err(mlua::Error::runtime(
783 "crypto.hash: first argument must be a string",
784 ));
785 }
786 };
787
788 let algorithm: String = match args_iter.next() {
789 Some(Value::String(s)) => s.to_str()?.to_lowercase(),
790 Some(Value::Nil) | None => "sha256".to_string(),
791 _ => {
792 return Err(mlua::Error::runtime(
793 "crypto.hash: second argument must be an algorithm string or nil",
794 ));
795 }
796 };
797
798 let hex = match algorithm.as_str() {
799 "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
800 "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
801 "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
802 "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
803 "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
804 "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
805 "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
806 "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
807 other => {
808 return Err(mlua::Error::runtime(format!(
809 "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
810 )));
811 }
812 };
813
814 Ok(hex)
815 })?;
816 crypto_table.set("hash", hash_fn)?;
817
818 let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
819 let mut args_iter = args.into_iter();
820
821 let length: usize = match args_iter.next() {
822 Some(Value::Integer(n)) if n > 0 => n as usize,
823 Some(Value::Integer(n)) => {
824 return Err(mlua::Error::runtime(format!(
825 "crypto.random: length must be positive, got {n}"
826 )));
827 }
828 Some(Value::Nil) | None => 32,
829 _ => {
830 return Err(mlua::Error::runtime(
831 "crypto.random: first argument must be a positive integer or nil",
832 ));
833 }
834 };
835
836 let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
837 let mut rng = rand::rng();
838 let result: String = (0..length)
839 .map(|_| charset[rng.random_range(..charset.len())] as char)
840 .collect();
841
842 Ok(result)
843 })?;
844 crypto_table.set("random", random_fn)?;
845
846 lua.globals().set("crypto", crypto_table)?;
847 Ok(())
848}
849
850fn register_yaml(lua: &Lua) -> mlua::Result<()> {
851 let yaml_table = lua.create_table()?;
852
853 let parse_fn = lua.create_function(|lua, s: String| {
854 let json_val: serde_json::Value = serde_yml::from_str(&s)
855 .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
856 json_value_to_lua(lua, &json_val)
857 })?;
858 yaml_table.set("parse", parse_fn)?;
859
860 let encode_fn = lua.create_function(|_, val: Value| {
861 let json_val = lua_value_to_json(&val)?;
862 serde_yml::to_string(&json_val)
863 .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
864 })?;
865 yaml_table.set("encode", encode_fn)?;
866
867 lua.globals().set("yaml", yaml_table)?;
868 Ok(())
869}
870
871fn register_toml(lua: &Lua) -> mlua::Result<()> {
872 let toml_table = lua.create_table()?;
873
874 let parse_fn = lua.create_function(|lua, s: String| {
875 let toml_val: toml::Value = toml::from_str(&s)
876 .map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
877 let json_val = serde_json::to_value(&toml_val)
878 .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
879 json_value_to_lua(lua, &json_val)
880 })?;
881 toml_table.set("parse", parse_fn)?;
882
883 let encode_fn = lua.create_function(|_, val: Value| {
884 let json_val = lua_value_to_json(&val)?;
885 let toml_val: toml::Value = serde_json::from_value(json_val)
886 .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
887 toml::to_string_pretty(&toml_val)
888 .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
889 })?;
890 toml_table.set("encode", encode_fn)?;
891
892 lua.globals().set("toml", toml_table)?;
893 Ok(())
894}
895
896fn register_regex(lua: &Lua) -> mlua::Result<()> {
897 let regex_table = lua.create_table()?;
898
899 let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
900 let re = regex_lite::Regex::new(&pattern)
901 .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
902 Ok(re.is_match(&text))
903 })?;
904 regex_table.set("match", match_fn)?;
905
906 let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
907 let re = regex_lite::Regex::new(&pattern)
908 .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
909 match re.captures(&text) {
910 Some(caps) => {
911 let result = lua.create_table()?;
912 let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
913 result.set("match", full_match.to_string())?;
914 let groups = lua.create_table()?;
915 for i in 1..caps.len() {
916 if let Some(m) = caps.get(i) {
917 groups.set(i, m.as_str().to_string())?;
918 }
919 }
920 result.set("groups", groups)?;
921 Ok(Value::Table(result))
922 }
923 None => Ok(Value::Nil),
924 }
925 })?;
926 regex_table.set("find", find_fn)?;
927
928 let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
929 let re = regex_lite::Regex::new(&pattern)
930 .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
931 let results = lua.create_table()?;
932 for (i, m) in re.find_iter(&text).enumerate() {
933 results.set(i + 1, m.as_str().to_string())?;
934 }
935 Ok(results)
936 })?;
937 regex_table.set("find_all", find_all_fn)?;
938
939 let replace_fn = lua.create_function(
940 |_, (text, pattern, replacement): (String, String, String)| {
941 let re = regex_lite::Regex::new(&pattern).map_err(|e| {
942 mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
943 })?;
944 Ok(re.replace_all(&text, replacement.as_str()).into_owned())
945 },
946 )?;
947 regex_table.set("replace", replace_fn)?;
948
949 lua.globals().set("regex", regex_table)?;
950 Ok(())
951}
952
953fn register_async(lua: &Lua) -> mlua::Result<()> {
954 let async_table = lua.create_table()?;
955
956 let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
957 let thread = lua.create_thread(func)?;
958 let async_thread = thread.into_async::<mlua::MultiValue>(())?;
959 let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
960 tokio::task::spawn_local(async move {
961 let values = async_thread.await.map_err(|e| e.to_string())?;
962 Ok(values.into_vec())
963 });
964
965 let handle = lua.create_table()?;
966 let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
967 let cell_clone = cell.clone();
968
969 let await_fn = lua.create_async_function(move |lua, ()| {
970 let cell = cell_clone.clone();
971 async move {
972 let join_handle = cell
973 .borrow_mut()
974 .take()
975 .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
976 let result = join_handle.await.map_err(|e| {
977 mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
978 })?;
979 match result {
980 Ok(values) => {
981 let tbl = lua.create_table()?;
982 for (i, v) in values.into_iter().enumerate() {
983 tbl.set(i + 1, v)?;
984 }
985 Ok(Value::Table(tbl))
986 }
987 Err(msg) => Err(mlua::Error::runtime(msg)),
988 }
989 }
990 })?;
991 handle.set("await", await_fn)?;
992
993 Ok(handle)
994 })?;
995 async_table.set("spawn", spawn_fn)?;
996
997 let spawn_interval_fn =
998 lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
999 if seconds <= 0.0 {
1000 return Err(mlua::Error::runtime(
1001 "async.spawn_interval: interval must be positive",
1002 ));
1003 }
1004
1005 let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1006 let cancel_clone = cancel.clone();
1007
1008 tokio::task::spawn_local({
1009 let cancel = cancel_clone.clone();
1010 async move {
1011 let mut interval =
1012 tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1013 interval.tick().await;
1014 loop {
1015 interval.tick().await;
1016 if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1017 break;
1018 }
1019 if let Err(e) = func.call_async::<()>(()).await {
1020 error!("async.spawn_interval: callback error: {e}");
1021 break;
1022 }
1023 }
1024 }
1025 });
1026
1027 let handle = lua.create_table()?;
1028 let cancel_fn = lua.create_function(move |_, ()| {
1029 cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1030 Ok(())
1031 })?;
1032 handle.set("cancel", cancel_fn)?;
1033
1034 Ok(handle)
1035 })?;
1036 async_table.set("spawn_interval", spawn_interval_fn)?;
1037
1038 lua.globals().set("async", async_table)?;
1039 Ok(())
1040}
1041
1042struct DbPool(Arc<AnyPool>);
1043impl UserData for DbPool {}
1044
1045fn register_db(lua: &Lua) -> mlua::Result<()> {
1046 sqlx::any::install_default_drivers();
1047
1048 let db_table = lua.create_table()?;
1049
1050 let connect_fn = lua.create_async_function(|lua, url: String| async move {
1051 let pool = sqlx::any::AnyPoolOptions::new()
1052 .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1053 .connect(&url)
1054 .await
1055 .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1056 lua.create_any_userdata(DbPool(Arc::new(pool)))
1057 })?;
1058 db_table.set("connect", connect_fn)?;
1059
1060 let query_fn =
1061 lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1062 let mut args_iter = args.into_iter();
1063
1064 let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1065 let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1066 let params = extract_params(&args_iter.next())?;
1067
1068 let mut query = sqlx::query(&sql);
1069 for p in ¶ms {
1070 query = bind_param(query, p);
1071 }
1072
1073 let rows: Vec<AnyRow> = query
1074 .fetch_all(&*pool)
1075 .await
1076 .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1077
1078 let result = lua.create_table()?;
1079 for (i, row) in rows.iter().enumerate() {
1080 let row_table = any_row_to_lua_table(&lua, row)?;
1081 result.set(i + 1, row_table)?;
1082 }
1083 Ok(Value::Table(result))
1084 })?;
1085 db_table.set("query", query_fn)?;
1086
1087 let execute_fn =
1088 lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1089 let mut args_iter = args.into_iter();
1090
1091 let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1092 let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1093 let params = extract_params(&args_iter.next())?;
1094
1095 let mut query = sqlx::query(&sql);
1096 for p in ¶ms {
1097 query = bind_param(query, p);
1098 }
1099
1100 let result = query
1101 .execute(&*pool)
1102 .await
1103 .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1104
1105 let tbl = lua.create_table()?;
1106 tbl.set("rows_affected", result.rows_affected() as i64)?;
1107 Ok(Value::Table(tbl))
1108 })?;
1109 db_table.set("execute", execute_fn)?;
1110
1111 let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1112 let mut args_iter = args.into_iter();
1113 let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1114 pool.close().await;
1115 Ok(())
1116 })?;
1117 db_table.set("close", close_fn)?;
1118
1119 lua.globals().set("db", db_table)?;
1120 Ok(())
1121}
1122
1123fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1124 match val {
1125 Some(Value::UserData(ud)) => {
1126 let db = ud
1127 .borrow::<DbPool>()
1128 .map_err(|_| mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection")))?;
1129 Ok(db.0.clone())
1130 }
1131 _ => Err(mlua::Error::runtime(format!(
1132 "{fn_name}: first argument must be a db connection"
1133 ))),
1134 }
1135}
1136
1137fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1138 match val {
1139 Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1140 _ => Err(mlua::Error::runtime(format!(
1141 "{fn_name}: second argument must be a SQL string"
1142 ))),
1143 }
1144}
1145
/// An owned SQL parameter converted from a Lua value by `extract_params`.
///
/// Holding owned data lets `bind_param` lend each value to a query for as long as
/// the parameter list lives.
#[derive(Clone)]
enum DbParam {
    /// Lua `nil`, bound as SQL NULL.
    Null,
    /// Lua boolean.
    Bool(bool),
    /// Lua integer.
    Int(i64),
    /// Lua floating-point number.
    Float(f64),
    /// Lua string (UTF-8).
    Text(String),
}
1154
1155fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1156 match val {
1157 Some(Value::Table(t)) => {
1158 let mut params = Vec::new();
1159 let len = t.len()?;
1160 for i in 1..=len {
1161 let v: Value = t.get(i)?;
1162 let param = match v {
1163 Value::Nil => DbParam::Null,
1164 Value::Boolean(b) => DbParam::Bool(b),
1165 Value::Integer(n) => DbParam::Int(n),
1166 Value::Number(f) => DbParam::Float(f),
1167 Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1168 _ => {
1169 return Err(mlua::Error::runtime(format!(
1170 "db: unsupported parameter type: {}",
1171 v.type_name()
1172 )));
1173 }
1174 };
1175 params.push(param);
1176 }
1177 Ok(params)
1178 }
1179 Some(Value::Nil) | None => Ok(Vec::new()),
1180 _ => Err(mlua::Error::runtime(
1181 "db: params must be a table (array) or nil",
1182 )),
1183 }
1184}
1185
1186fn bind_param<'q>(
1187 query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1188 param: &'q DbParam,
1189) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1190 match param {
1191 DbParam::Null => query.bind(None::<String>),
1192 DbParam::Bool(b) => query.bind(*b),
1193 DbParam::Int(n) => query.bind(*n),
1194 DbParam::Float(f) => query.bind(*f),
1195 DbParam::Text(s) => query.bind(s.as_str()),
1196 }
1197}
1198
1199fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1200 let table = lua.create_table()?;
1201 for col in row.columns() {
1202 let name = col.name();
1203 let val: Value = any_column_to_lua_value(lua, row, col)?;
1204 table.set(name.to_string(), val)?;
1205 }
1206 Ok(table)
1207}
1208
1209fn any_column_to_lua_value<C: Column>(
1210 lua: &Lua,
1211 row: &AnyRow,
1212 col: &C,
1213) -> mlua::Result<Value> {
1214 let ordinal = col.ordinal();
1215 let type_info = col.type_info();
1216 let type_name = type_info.to_string();
1217 let type_name = type_name.to_uppercase();
1218
1219 if row.try_get_raw(ordinal).map(|v| v.is_null()).unwrap_or(true) {
1220 return Ok(Value::Nil);
1221 }
1222
1223 match type_name.as_str() {
1224 "BOOLEAN" | "BOOL" => {
1225 let v: bool = row
1226 .try_get(ordinal)
1227 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1228 Ok(Value::Boolean(v))
1229 }
1230 "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1231 let v: i64 = row
1232 .try_get(ordinal)
1233 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1234 Ok(Value::Integer(v))
1235 }
1236 "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1237 let v: f64 = row
1238 .try_get(ordinal)
1239 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1240 Ok(Value::Number(v))
1241 }
1242 _ => {
1243 let v: String = row
1244 .try_get(ordinal)
1245 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1246 Ok(Value::String(lua.create_string(&v)?))
1247 }
1248 }
1249}
1250
1251fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1252 match val {
1253 Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1254 Some(other) => {
1255 let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1256 result
1257 }
1258 None => None,
1259 }
1260}
1261
1262fn lua_values_equal(a: &Value, b: &Value) -> bool {
1263 match (a, b) {
1264 (Value::Nil, Value::Nil) => true,
1265 (Value::Boolean(a), Value::Boolean(b)) => a == b,
1266 (Value::Integer(a), Value::Integer(b)) => a == b,
1267 (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1268 (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1269 (*a as f64 - b).abs() < f64::EPSILON
1270 }
1271 (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1272 _ => false,
1273 }
1274}
1275
1276fn lua_value_to_f64(val: Value) -> Option<f64> {
1277 match val {
1278 Value::Integer(i) => Some(i as f64),
1279 Value::Number(f) => Some(f),
1280 _ => None,
1281 }
1282}
1283
1284fn format_lua_value(val: &Value) -> String {
1285 match val {
1286 Value::Nil => "nil".to_string(),
1287 Value::Boolean(b) => b.to_string(),
1288 Value::Integer(i) => i.to_string(),
1289 Value::Number(f) => f.to_string(),
1290 Value::String(s) => match s.to_str() {
1291 Ok(v) => v.to_string(),
1292 Err(_) => "<invalid utf-8>".to_string(),
1293 },
1294 _ => format!("<{}>", val.type_name()),
1295 }
1296}
1297
1298type WsSink = Rc<
1299 tokio::sync::Mutex<
1300 futures_util::stream::SplitSink<
1301 tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1302 tokio_tungstenite::tungstenite::Message,
1303 >,
1304 >,
1305>;
1306type WsStream = Rc<
1307 tokio::sync::Mutex<
1308 futures_util::stream::SplitStream<
1309 tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1310 >,
1311 >,
1312>;
1313
1314struct WsConn {
1315 sink: WsSink,
1316 stream: WsStream,
1317}
1318impl UserData for WsConn {}
1319
1320fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1321 let ud = match val {
1322 Value::UserData(ud) => ud,
1323 _ => {
1324 return Err(mlua::Error::runtime(format!(
1325 "{fn_name}: first argument must be a ws connection"
1326 )));
1327 }
1328 };
1329 let ws = ud.borrow::<WsConn>().map_err(|_| {
1330 mlua::Error::runtime(format!(
1331 "{fn_name}: first argument must be a ws connection"
1332 ))
1333 })?;
1334 Ok((ws.sink.clone(), ws.stream.clone()))
1335}
1336
1337fn register_ws(lua: &Lua) -> mlua::Result<()> {
1338 let ws_table = lua.create_table()?;
1339
1340 let connect_fn = lua.create_async_function(|lua, url: String| async move {
1341 let (stream, _response) = tokio_tungstenite::connect_async(&url)
1342 .await
1343 .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1344 let (sink, read) = stream.split();
1345 lua.create_any_userdata(WsConn {
1346 sink: Rc::new(tokio::sync::Mutex::new(sink)),
1347 stream: Rc::new(tokio::sync::Mutex::new(read)),
1348 })
1349 })?;
1350 ws_table.set("connect", connect_fn)?;
1351
1352 let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1353 let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1354 sink.lock()
1355 .await
1356 .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1357 .await
1358 .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1359 Ok(())
1360 })?;
1361 ws_table.set("send", send_fn)?;
1362
1363 let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1364 let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1365 loop {
1366 let msg = stream
1367 .lock()
1368 .await
1369 .next()
1370 .await
1371 .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1372 .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1373 match msg {
1374 tokio_tungstenite::tungstenite::Message::Text(t) => {
1375 return Ok(t.to_string());
1376 }
1377 tokio_tungstenite::tungstenite::Message::Binary(b) => {
1378 return String::from_utf8(b.into()).map_err(|e| {
1379 mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}"))
1380 });
1381 }
1382 tokio_tungstenite::tungstenite::Message::Close(_) => {
1383 return Err(mlua::Error::runtime("ws.recv: connection closed"));
1384 }
1385 _ => continue,
1386 }
1387 }
1388 })?;
1389 ws_table.set("recv", recv_fn)?;
1390
1391 let close_fn = lua.create_async_function(|_, conn: Value| async move {
1392 let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1393 sink.lock()
1394 .await
1395 .close()
1396 .await
1397 .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1398 Ok(())
1399 })?;
1400 ws_table.set("close", close_fn)?;
1401
1402 lua.globals().set("ws", ws_table)?;
1403 Ok(())
1404}
1405
1406fn register_template(lua: &Lua) -> mlua::Result<()> {
1407 let tmpl_table = lua.create_table()?;
1408
1409 let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1410 let json_vars = match &vars {
1411 Value::Table(_) => lua_value_to_json(&vars)?,
1412 Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1413 _ => {
1414 return Err(mlua::Error::runtime(
1415 "template.render_string: second argument must be a table or nil",
1416 ));
1417 }
1418 };
1419 let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1420 let env = minijinja::Environment::new();
1421 let tmpl = env
1422 .template_from_str(&template_str)
1423 .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1424 tmpl.render(mini_vars)
1425 .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1426 })?;
1427 tmpl_table.set("render_string", render_string_fn)?;
1428
1429 let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1430 let content = std::fs::read_to_string(&file_path).map_err(|e| {
1431 mlua::Error::runtime(format!("template.render: failed to read {file_path:?}: {e}"))
1432 })?;
1433 let json_vars = match &vars {
1434 Value::Table(_) => lua_value_to_json(&vars)?,
1435 Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1436 _ => {
1437 return Err(mlua::Error::runtime(
1438 "template.render: second argument must be a table or nil",
1439 ));
1440 }
1441 };
1442 let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1443 let env = minijinja::Environment::new();
1444 let tmpl = env
1445 .template_from_str(&content)
1446 .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1447 tmpl.render(mini_vars)
1448 .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1449 })?;
1450 tmpl_table.set("render", render_fn)?;
1451
1452 lua.globals().set("template", tmpl_table)?;
1453 Ok(())
1454}
1455
/// Renders `s` as a double-quoted Lua string literal.
///
/// Backslash, double quote, `\n` and `\r` are escaped. NUL is written as the
/// three-digit decimal escape `\000`: Lua reads up to three digits after `\`, so a
/// shorter `\0` would absorb any digit that follows it in `s`. All other characters
/// are emitted verbatim.
fn lua_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\0' => out.push_str("\\000"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
1465
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard base64 round-trip of a short ASCII string, including padding.
    #[test]
    fn test_base64_roundtrip() {
        let plain = "hello world";
        let encoded = BASE64.encode(plain.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let bytes = BASE64.decode(encoded.as_bytes()).unwrap();
        assert_eq!(String::from_utf8(bytes).unwrap(), plain);
    }

    /// Empty input encodes to and decodes from the empty string.
    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }

    #[test]
    fn test_lua_value_to_json_nil() {
        let json = lua_value_to_json(&Value::Nil).unwrap();
        assert_eq!(json, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        let json = lua_value_to_json(&Value::Boolean(true)).unwrap();
        assert_eq!(json, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        let json = lua_value_to_json(&Value::Integer(42)).unwrap();
        assert_eq!(json, serde_json::json!(42));
    }

    #[test]
    fn test_lua_value_to_json_number() {
        let json = lua_value_to_json(&Value::Number(1.5)).unwrap();
        assert_eq!(json, serde_json::json!(1.5));
    }

    #[test]
    fn test_lua_values_equal_nil() {
        assert!(lua_values_equal(&Value::Nil, &Value::Nil));
    }

    /// An integer equals the float with the same numeric value.
    #[test]
    fn test_lua_values_equal_int_float() {
        assert!(lua_values_equal(&Value::Integer(42), &Value::Number(42.0)));
    }

    #[test]
    fn test_lua_values_not_equal() {
        assert!(!lua_values_equal(&Value::Integer(1), &Value::Integer(2)));
    }

    #[test]
    fn test_format_lua_value() {
        let cases = [
            (Value::Nil, "nil"),
            (Value::Boolean(true), "true"),
            (Value::Integer(42), "42"),
        ];
        for (value, expected) in &cases {
            assert_eq!(format_lua_value(value), *expected);
        }
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        let cases = [
            ("hello", "\"hello\""),
            ("line\nnew", "\"line\\nnew\""),
            ("quote\"here", "\"quote\\\"here\""),
        ];
        for (raw, expected) in cases {
            assert_eq!(lua_string_literal(raw), expected);
        }
    }
}