1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
23pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
24 register_http(lua, client)?;
25 register_json(lua)?;
26 register_yaml(lua)?;
27 register_toml(lua)?;
28 register_assert(lua)?;
29 register_log(lua)?;
30 register_env(lua)?;
31 register_sleep(lua)?;
32 register_time(lua)?;
33 register_fs(lua)?;
34 register_base64(lua)?;
35 register_crypto(lua)?;
36 register_regex(lua)?;
37 register_async(lua)?;
38 register_db(lua)?;
39 register_ws(lua)?;
40 register_template(lua)?;
41 Ok(())
42}
43
44fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
45 let http_table = lua.create_table()?;
46
47 for method in ["get", "post", "put", "patch", "delete"] {
48 let method_client = client.clone();
49 let method_name = method.to_string();
50 let has_body = method != "get" && method != "delete";
51
52 let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
53 let client = method_client.clone();
54 let method_name = method_name.clone();
55 async move {
56 let mut args_iter = args.into_iter();
57 let url: String = match args_iter.next() {
58 Some(Value::String(s)) => s.to_str()?.to_string(),
59 _ => {
60 return Err(mlua::Error::runtime(format!(
61 "http.{method_name}: first argument must be a URL string"
62 )));
63 }
64 };
65
66 let (body_str, auto_json, opts) = if has_body {
67 let (body, is_json) = match args_iter.next() {
68 Some(Value::String(s)) => (s.to_str()?.to_string(), false),
69 Some(Value::Table(t)) => {
70 let json_val = lua_table_to_json(&t)?;
71 let serialized = serde_json::to_string(&json_val).map_err(|e| {
72 mlua::Error::runtime(format!(
73 "http.{method_name}: JSON encode failed: {e}"
74 ))
75 })?;
76 (serialized, true)
77 }
78 Some(Value::Nil) | None => (String::new(), false),
79 _ => {
80 return Err(mlua::Error::runtime(format!(
81 "http.{method_name}: second argument must be a string, table, or nil"
82 )));
83 }
84 };
85 let opts = match args_iter.next() {
86 Some(Value::Table(t)) => Some(t),
87 Some(Value::Nil) | None => None,
88 _ => {
89 return Err(mlua::Error::runtime(format!(
90 "http.{method_name}: third argument must be a table or nil"
91 )));
92 }
93 };
94 (body, is_json, opts)
95 } else {
96 let opts = match args_iter.next() {
97 Some(Value::Table(t)) => Some(t),
98 Some(Value::Nil) | None => None,
99 _ => {
100 return Err(mlua::Error::runtime(format!(
101 "http.{method_name}: second argument must be a table or nil"
102 )));
103 }
104 };
105 (String::new(), false, opts)
106 };
107
108 let mut req = match method_name.as_str() {
109 "get" => client.get(&url),
110 "post" => client.post(&url),
111 "put" => client.put(&url),
112 "patch" => client.patch(&url),
113 "delete" => client.delete(&url),
114 _ => unreachable!(),
115 };
116
117 if has_body && !body_str.is_empty() {
118 req = req.body(body_str);
119 }
120 if auto_json {
121 req = req.header("Content-Type", "application/json");
122 }
123 if let Some(ref opts_table) = opts
124 && let Ok(headers_table) = opts_table.get::<Table>("headers")
125 {
126 for pair in headers_table.pairs::<String, String>() {
127 let (k, v) = pair?;
128 req = req.header(k, v);
129 }
130 }
131
132 let resp = req.send().await.map_err(|e| {
133 mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
134 })?;
135 let status = resp.status().as_u16();
136 let resp_headers = resp.headers().clone();
137 let body = resp.text().await.map_err(|e| {
138 mlua::Error::runtime(format!(
139 "http.{method_name}: reading body failed: {e}"
140 ))
141 })?;
142
143 let result = lua.create_table()?;
144 result.set("status", status)?;
145 result.set("body", body)?;
146
147 let headers_out = lua.create_table()?;
148 for (name, value) in &resp_headers {
149 if let Ok(v) = value.to_str() {
150 headers_out.set(name.as_str().to_string(), v.to_string())?;
151 }
152 }
153 result.set("headers", headers_out)?;
154
155 Ok(Value::Table(result))
156 }
157 })?;
158 http_table.set(method, func)?;
159 }
160
161 let serve_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
162 let mut args_iter = args.into_iter();
163
164 let port: u16 = match args_iter.next() {
165 Some(Value::Integer(n)) => n as u16,
166 _ => {
167 return Err::<(), _>(mlua::Error::runtime(
168 "http.serve: first argument must be a port number",
169 ));
170 }
171 };
172
173 let routes_table = match args_iter.next() {
174 Some(Value::Table(t)) => t,
175 _ => {
176 return Err::<(), _>(mlua::Error::runtime(
177 "http.serve: second argument must be a routes table",
178 ));
179 }
180 };
181
182 let routes = Rc::new(parse_routes(&routes_table)?);
183
184 let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
185 .await
186 .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
187
188 loop {
189 let (stream, _addr) = listener
190 .accept()
191 .await
192 .map_err(|e| mlua::Error::runtime(format!("http.serve: accept failed: {e}")))?;
193
194 let routes = routes.clone();
195 let lua_clone = lua.clone();
196
197 tokio::task::spawn_local(async move {
198 let io = hyper_util::rt::TokioIo::new(stream);
199 let routes = routes.clone();
200 let lua = lua_clone.clone();
201
202 let service = service_fn(move |req: Request<Incoming>| {
203 let routes = routes.clone();
204 let lua = lua.clone();
205 async move { handle_request(&lua, &routes, req).await }
206 });
207
208 if let Err(e) = http1::Builder::new().serve_connection(io, service).await
209 && !e.to_string().contains("connection closed")
210 {
211 error!("http.serve: connection error: {e}");
212 }
213 });
214 }
215 })?;
216 http_table.set("serve", serve_fn)?;
217
218 lua.globals().set("http", http_table)?;
219 Ok(())
220}
221
222fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
223 let mut routes = HashMap::new();
224 for method_pair in routes_table.pairs::<String, Table>() {
225 let (method, paths_table) = method_pair?;
226 let method_upper = method.to_uppercase();
227 for path_pair in paths_table.pairs::<String, mlua::Function>() {
228 let (path, func) = path_pair?;
229 routes.insert((method_upper.clone(), path), func);
230 }
231 }
232 Ok(routes)
233}
234
235async fn handle_request(
236 lua: &Lua,
237 routes: &HashMap<(String, String), mlua::Function>,
238 req: Request<Incoming>,
239) -> Result<Response<Full<Bytes>>, hyper::Error> {
240 let method = req.method().to_string();
241 let path = req.uri().path().to_string();
242 let query = req.uri().query().unwrap_or("").to_string();
243 let headers: Vec<(String, String)> = req
244 .headers()
245 .iter()
246 .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
247 .collect();
248
249 let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
250 Ok(collected) => collected.to_bytes(),
251 Err(_) => Bytes::new(),
252 };
253 let body_str = String::from_utf8_lossy(&body_bytes).to_string();
254
255 let key = (method.clone(), path.clone());
256 let handler = match routes.get(&key) {
257 Some(f) => f,
258 None => {
259 return Ok(Response::builder()
260 .status(StatusCode::NOT_FOUND)
261 .header("content-type", "text/plain")
262 .body(Full::new(Bytes::from("not found")))
263 .unwrap());
264 }
265 };
266
267 match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
268 Ok(lua_resp) => lua_response_to_http(&lua_resp),
269 Err(e) => Ok(Response::builder()
270 .status(StatusCode::INTERNAL_SERVER_ERROR)
271 .header("content-type", "text/plain")
272 .body(Full::new(Bytes::from(format!("handler error: {e}"))))
273 .unwrap()),
274 }
275}
276
277fn build_lua_request_and_call(
278 lua: &Lua,
279 handler: &mlua::Function,
280 method: &str,
281 path: &str,
282 query: &str,
283 headers: &[(String, String)],
284 body: &str,
285) -> mlua::Result<Table> {
286 let req_table = lua.create_table()?;
287 req_table.set("method", method.to_string())?;
288 req_table.set("path", path.to_string())?;
289 req_table.set("query", query.to_string())?;
290 req_table.set("body", body.to_string())?;
291
292 let headers_table = lua.create_table()?;
293 for (k, v) in headers {
294 headers_table.set(k.as_str(), v.as_str())?;
295 }
296 req_table.set("headers", headers_table)?;
297
298 handler.call::<Table>(req_table)
299}
300
301fn lua_response_to_http(resp_table: &Table) -> Result<Response<Full<Bytes>>, hyper::Error> {
302 let status = resp_table
303 .get::<Option<u16>>("status")
304 .unwrap_or(None)
305 .unwrap_or(200);
306
307 let mut builder =
308 Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
309
310 if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
311 for (k, v) in headers_table.pairs::<String, String>().flatten() {
312 builder = builder.header(k, v);
313 }
314 }
315
316 let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
317 let json_val =
318 lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
319 let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
320 builder = builder.header("content-type", "application/json");
321 Bytes::from(serialized)
322 } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
323 builder = builder.header("content-type", "text/plain");
324 Bytes::from(body_str)
325 } else {
326 builder = builder.header("content-type", "text/plain");
327 Bytes::new()
328 };
329
330 Ok(builder.body(Full::new(body_bytes)).unwrap())
331}
332
333fn register_json(lua: &Lua) -> mlua::Result<()> {
334 let json_table = lua.create_table()?;
335
336 let parse_fn = lua.create_function(|lua, s: String| {
337 let value: serde_json::Value = serde_json::from_str(&s)
338 .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
339 json_value_to_lua(lua, &value)
340 })?;
341 json_table.set("parse", parse_fn)?;
342
343 let encode_fn = lua.create_function(|_, val: Value| {
344 let json_val = lua_value_to_json(&val)?;
345 serde_json::to_string(&json_val)
346 .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
347 })?;
348 json_table.set("encode", encode_fn)?;
349
350 lua.globals().set("json", json_table)?;
351 Ok(())
352}
353
354fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
355 let mut is_array = true;
356 let mut max_index: i64 = 0;
357 let mut count: i64 = 0;
358
359 for pair in table.clone().pairs::<Value, Value>() {
360 let (key, _) = pair?;
361 count += 1;
362 match key {
363 Value::Integer(i) if i >= 1 => {
364 if i > max_index {
365 max_index = i;
366 }
367 }
368 _ => {
369 is_array = false;
370 break;
371 }
372 }
373 }
374
375 if is_array && max_index == count {
376 let mut arr = Vec::with_capacity(max_index as usize);
377 for i in 1..=max_index {
378 let val: Value = table.get(i)?;
379 arr.push(lua_value_to_json(&val)?);
380 }
381 Ok(serde_json::Value::Array(arr))
382 } else {
383 let mut map = serde_json::Map::new();
384 for pair in table.clone().pairs::<Value, Value>() {
385 let (key, val) = pair?;
386 let key_str = match key {
387 Value::String(s) => s.to_str()?.to_string(),
388 Value::Integer(i) => i.to_string(),
389 Value::Number(f) => f.to_string(),
390 _ => {
391 return Err(mlua::Error::runtime(format!(
392 "unsupported table key type: {}",
393 key.type_name()
394 )));
395 }
396 };
397 map.insert(key_str, lua_value_to_json(&val)?);
398 }
399 Ok(serde_json::Value::Object(map))
400 }
401}
402
403fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
404 match val {
405 Value::Nil => Ok(serde_json::Value::Null),
406 Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
407 Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
408 Value::Number(f) => serde_json::Number::from_f64(*f)
409 .map(serde_json::Value::Number)
410 .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
411 Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
412 Value::Table(t) => lua_table_to_json(t),
413 _ => Err(mlua::Error::runtime(format!(
414 "unsupported Lua type for JSON: {}",
415 val.type_name()
416 ))),
417 }
418}
419
420pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
421 match val {
422 serde_json::Value::Null => Ok(Value::Nil),
423 serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
424 serde_json::Value::Number(n) => {
425 if let Some(i) = n.as_i64() {
426 Ok(Value::Integer(i))
427 } else if let Some(f) = n.as_f64() {
428 Ok(Value::Number(f))
429 } else {
430 Ok(Value::Nil)
431 }
432 }
433 serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
434 serde_json::Value::Array(arr) => {
435 let table = lua.create_table()?;
436 for (i, item) in arr.iter().enumerate() {
437 table.set(i + 1, json_value_to_lua(lua, item)?)?;
438 }
439 Ok(Value::Table(table))
440 }
441 serde_json::Value::Object(map) => {
442 let table = lua.create_table()?;
443 for (k, v) in map {
444 table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
445 }
446 Ok(Value::Table(table))
447 }
448 }
449}
450
451fn register_assert(lua: &Lua) -> mlua::Result<()> {
452 let assert_table = lua.create_table()?;
453
454 let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
455 let mut args_iter = args.into_iter();
456 let a = args_iter.next().unwrap_or(Value::Nil);
457 let b = args_iter.next().unwrap_or(Value::Nil);
458 let msg = extract_string_arg(lua, args_iter.next());
459
460 if !lua_values_equal(&a, &b) {
461 let detail = format!(
462 "assert.eq failed: {:?} != {:?}{}",
463 format_lua_value(&a),
464 format_lua_value(&b),
465 msg.map(|m| format!(" - {m}")).unwrap_or_default()
466 );
467 return Err(mlua::Error::runtime(detail));
468 }
469 Ok(())
470 })?;
471 assert_table.set("eq", eq_fn)?;
472
473 let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
474 let mut args_iter = args.into_iter();
475 let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
476 let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
477 let msg = extract_string_arg(lua, args_iter.next());
478
479 match (a, b) {
480 (Some(va), Some(vb)) if va > vb => Ok(()),
481 (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
482 "assert.gt failed: {va} is not > {vb}{}",
483 msg.map(|m| format!(" - {m}")).unwrap_or_default()
484 ))),
485 _ => Err(mlua::Error::runtime(
486 "assert.gt: both arguments must be numbers",
487 )),
488 }
489 })?;
490 assert_table.set("gt", gt_fn)?;
491
492 let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
493 let mut args_iter = args.into_iter();
494 let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
495 let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
496 let msg = extract_string_arg(lua, args_iter.next());
497
498 match (a, b) {
499 (Some(va), Some(vb)) if va < vb => Ok(()),
500 (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
501 "assert.lt failed: {va} is not < {vb}{}",
502 msg.map(|m| format!(" - {m}")).unwrap_or_default()
503 ))),
504 _ => Err(mlua::Error::runtime(
505 "assert.lt: both arguments must be numbers",
506 )),
507 }
508 })?;
509 assert_table.set("lt", lt_fn)?;
510
511 let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
512 let mut args_iter = args.into_iter();
513 let haystack: String = match args_iter.next() {
514 Some(Value::String(s)) => s.to_str()?.to_string(),
515 _ => {
516 return Err(mlua::Error::runtime(
517 "assert.contains: first argument must be a string",
518 ));
519 }
520 };
521 let needle: String = match args_iter.next() {
522 Some(Value::String(s)) => s.to_str()?.to_string(),
523 _ => {
524 return Err(mlua::Error::runtime(
525 "assert.contains: second argument must be a string",
526 ));
527 }
528 };
529 let msg = extract_string_arg(lua, args_iter.next());
530
531 if !haystack.contains(&needle) {
532 return Err(mlua::Error::runtime(format!(
533 "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
534 msg.map(|m| format!(" - {m}")).unwrap_or_default()
535 )));
536 }
537 Ok(())
538 })?;
539 assert_table.set("contains", contains_fn)?;
540
541 let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
542 let mut args_iter = args.into_iter();
543 let val = args_iter.next().unwrap_or(Value::Nil);
544 let msg = extract_string_arg(lua, args_iter.next());
545
546 if val == Value::Nil {
547 return Err(mlua::Error::runtime(format!(
548 "assert.not_nil failed: value is nil{}",
549 msg.map(|m| format!(" - {m}")).unwrap_or_default()
550 )));
551 }
552 Ok(())
553 })?;
554 assert_table.set("not_nil", not_nil_fn)?;
555
556 let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
557 let mut args_iter = args.into_iter();
558 let text: String = match args_iter.next() {
559 Some(Value::String(s)) => s.to_str()?.to_string(),
560 _ => {
561 return Err(mlua::Error::runtime(
562 "assert.matches: first argument must be a string",
563 ));
564 }
565 };
566 let pattern: String = match args_iter.next() {
567 Some(Value::String(s)) => s.to_str()?.to_string(),
568 _ => {
569 return Err(mlua::Error::runtime(
570 "assert.matches: second argument must be a pattern string",
571 ));
572 }
573 };
574 let msg = extract_string_arg(lua, args_iter.next());
575
576 let found: bool = lua
577 .load(format!(
578 "return string.find({}, {}) ~= nil",
579 lua_string_literal(&text),
580 lua_string_literal(&pattern)
581 ))
582 .eval()
583 .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
584
585 if !found {
586 return Err(mlua::Error::runtime(format!(
587 "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
588 msg.map(|m| format!(" - {m}")).unwrap_or_default()
589 )));
590 }
591 Ok(())
592 })?;
593 assert_table.set("matches", matches_fn)?;
594
595 lua.globals().set("assert", assert_table)?;
596 Ok(())
597}
598
599fn register_log(lua: &Lua) -> mlua::Result<()> {
600 let log_table = lua.create_table()?;
601
602 let info_fn = lua.create_function(|_, msg: String| {
603 info!(target: "lua", "{}", msg);
604 Ok(())
605 })?;
606 log_table.set("info", info_fn)?;
607
608 let warn_fn = lua.create_function(|_, msg: String| {
609 warn!(target: "lua", "{}", msg);
610 Ok(())
611 })?;
612 log_table.set("warn", warn_fn)?;
613
614 let error_fn = lua.create_function(|_, msg: String| {
615 error!(target: "lua", "{}", msg);
616 Ok(())
617 })?;
618 log_table.set("error", error_fn)?;
619
620 lua.globals().set("log", log_table)?;
621 Ok(())
622}
623
624fn register_env(lua: &Lua) -> mlua::Result<()> {
625 let env_table = lua.create_table()?;
626
627 let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
628 Ok(val) => Ok(Some(val)),
629 Err(_) => Ok(None),
630 })?;
631 env_table.set("_process_get", process_get_fn)?;
632 env_table.set("_check_env", lua.create_table()?)?;
633
634 lua.globals().set("env", env_table)?;
635
636 lua.load(
637 r#"
638 function env.get(name)
639 local val = env._check_env[name]
640 if val ~= nil then return val end
641 return env._process_get(name)
642 end
643 "#,
644 )
645 .exec()?;
646
647 Ok(())
648}
649
650fn register_sleep(lua: &Lua) -> mlua::Result<()> {
651 let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
652 let duration = std::time::Duration::from_secs_f64(seconds);
653 tokio::time::sleep(duration).await;
654 Ok(())
655 })?;
656 lua.globals().set("sleep", sleep_fn)?;
657 Ok(())
658}
659
660fn register_time(lua: &Lua) -> mlua::Result<()> {
661 let time_fn = lua.create_function(|_, ()| {
662 let secs = SystemTime::now()
663 .duration_since(UNIX_EPOCH)
664 .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
665 .as_secs_f64();
666 Ok(secs)
667 })?;
668 lua.globals().set("time", time_fn)?;
669 Ok(())
670}
671
672fn register_fs(lua: &Lua) -> mlua::Result<()> {
673 let fs_table = lua.create_table()?;
674
675 let read_fn = lua.create_function(|_, path: String| {
676 std::fs::read_to_string(&path)
677 .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
678 })?;
679 fs_table.set("read", read_fn)?;
680
681 let write_fn = lua.create_function(|_, (path, content): (String, String)| {
682 let p = std::path::Path::new(&path);
683 if let Some(parent) = p.parent() {
684 std::fs::create_dir_all(parent).map_err(|e| {
685 mlua::Error::runtime(format!(
686 "fs.write: failed to create directories for {path:?}: {e}"
687 ))
688 })?;
689 }
690 std::fs::write(&path, &content)
691 .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
692 })?;
693 fs_table.set("write", write_fn)?;
694
695 lua.globals().set("fs", fs_table)?;
696 Ok(())
697}
698
699fn register_base64(lua: &Lua) -> mlua::Result<()> {
700 let b64_table = lua.create_table()?;
701
702 let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
703 b64_table.set("encode", encode_fn)?;
704
705 let decode_fn = lua.create_function(|_, input: String| {
706 let bytes = BASE64
707 .decode(input.as_bytes())
708 .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
709 String::from_utf8(bytes)
710 .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
711 })?;
712 b64_table.set("decode", decode_fn)?;
713
714 lua.globals().set("base64", b64_table)?;
715 Ok(())
716}
717
718fn register_crypto(lua: &Lua) -> mlua::Result<()> {
719 let crypto_table = lua.create_table()?;
720
721 let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
722 let mut args_iter = args.into_iter();
723
724 let claims_table = match args_iter.next() {
725 Some(Value::Table(t)) => t,
726 _ => {
727 return Err(mlua::Error::runtime(
728 "crypto.jwt_sign: first argument must be a claims table",
729 ));
730 }
731 };
732
733 let pem_key: String = match args_iter.next() {
734 Some(Value::String(s)) => s.to_str()?.to_string(),
735 _ => {
736 return Err(mlua::Error::runtime(
737 "crypto.jwt_sign: second argument must be a PEM key string",
738 ));
739 }
740 };
741
742 let algorithm = match args_iter.next() {
743 Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
744 "RS256" => Algorithm::RS256,
745 "RS384" => Algorithm::RS384,
746 "RS512" => Algorithm::RS512,
747 other => {
748 return Err(mlua::Error::runtime(format!(
749 "crypto.jwt_sign: unsupported algorithm: {other}"
750 )));
751 }
752 },
753 Some(Value::Nil) | None => Algorithm::RS256,
754 _ => {
755 return Err(mlua::Error::runtime(
756 "crypto.jwt_sign: third argument must be an algorithm string or nil",
757 ));
758 }
759 };
760
761 let opts = match args_iter.next() {
763 Some(Value::Table(t)) => Some(t),
764 Some(Value::Nil) | None => None,
765 _ => {
766 return Err(mlua::Error::runtime(
767 "crypto.jwt_sign: fourth argument must be an options table or nil",
768 ));
769 }
770 };
771
772 let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
773 let pem_bytes = Zeroizing::new(pem_key.into_bytes());
774 let key = EncodingKey::from_rsa_pem(&pem_bytes)
775 .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
776
777 let mut header = Header::new(algorithm);
778 if let Some(ref opts_table) = opts
779 && let Ok(kid) = opts_table.get::<String>("kid")
780 {
781 header.kid = Some(kid);
782 }
783 let token = jsonwebtoken::encode(&header, &claims_json, &key)
784 .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
785
786 Ok(token)
787 })?;
788 crypto_table.set("jwt_sign", jwt_sign_fn)?;
789
790 let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
791 let mut args_iter = args.into_iter();
792
793 let input: String = match args_iter.next() {
794 Some(Value::String(s)) => s.to_str()?.to_string(),
795 _ => {
796 return Err(mlua::Error::runtime(
797 "crypto.hash: first argument must be a string",
798 ));
799 }
800 };
801
802 let algorithm: String = match args_iter.next() {
803 Some(Value::String(s)) => s.to_str()?.to_lowercase(),
804 Some(Value::Nil) | None => "sha256".to_string(),
805 _ => {
806 return Err(mlua::Error::runtime(
807 "crypto.hash: second argument must be an algorithm string or nil",
808 ));
809 }
810 };
811
812 let hex = match algorithm.as_str() {
813 "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
814 "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
815 "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
816 "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
817 "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
818 "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
819 "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
820 "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
821 other => {
822 return Err(mlua::Error::runtime(format!(
823 "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
824 )));
825 }
826 };
827
828 Ok(hex)
829 })?;
830 crypto_table.set("hash", hash_fn)?;
831
832 let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
833 let mut args_iter = args.into_iter();
834
835 let length: usize = match args_iter.next() {
836 Some(Value::Integer(n)) if n > 0 => n as usize,
837 Some(Value::Integer(n)) => {
838 return Err(mlua::Error::runtime(format!(
839 "crypto.random: length must be positive, got {n}"
840 )));
841 }
842 Some(Value::Nil) | None => 32,
843 _ => {
844 return Err(mlua::Error::runtime(
845 "crypto.random: first argument must be a positive integer or nil",
846 ));
847 }
848 };
849
850 let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
851 let mut rng = rand::rng();
852 let result: String = (0..length)
853 .map(|_| charset[rng.random_range(..charset.len())] as char)
854 .collect();
855
856 Ok(result)
857 })?;
858 crypto_table.set("random", random_fn)?;
859
860 lua.globals().set("crypto", crypto_table)?;
861 Ok(())
862}
863
864fn register_yaml(lua: &Lua) -> mlua::Result<()> {
865 let yaml_table = lua.create_table()?;
866
867 let parse_fn = lua.create_function(|lua, s: String| {
868 let json_val: serde_json::Value = serde_yml::from_str(&s)
869 .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
870 json_value_to_lua(lua, &json_val)
871 })?;
872 yaml_table.set("parse", parse_fn)?;
873
874 let encode_fn = lua.create_function(|_, val: Value| {
875 let json_val = lua_value_to_json(&val)?;
876 serde_yml::to_string(&json_val)
877 .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
878 })?;
879 yaml_table.set("encode", encode_fn)?;
880
881 lua.globals().set("yaml", yaml_table)?;
882 Ok(())
883}
884
885fn register_toml(lua: &Lua) -> mlua::Result<()> {
886 let toml_table = lua.create_table()?;
887
888 let parse_fn = lua.create_function(|lua, s: String| {
889 let toml_val: toml::Value =
890 toml::from_str(&s).map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
891 let json_val = serde_json::to_value(&toml_val)
892 .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
893 json_value_to_lua(lua, &json_val)
894 })?;
895 toml_table.set("parse", parse_fn)?;
896
897 let encode_fn = lua.create_function(|_, val: Value| {
898 let json_val = lua_value_to_json(&val)?;
899 let toml_val: toml::Value = serde_json::from_value(json_val)
900 .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
901 toml::to_string_pretty(&toml_val)
902 .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
903 })?;
904 toml_table.set("encode", encode_fn)?;
905
906 lua.globals().set("toml", toml_table)?;
907 Ok(())
908}
909
910fn register_regex(lua: &Lua) -> mlua::Result<()> {
911 let regex_table = lua.create_table()?;
912
913 let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
914 let re = regex_lite::Regex::new(&pattern)
915 .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
916 Ok(re.is_match(&text))
917 })?;
918 regex_table.set("match", match_fn)?;
919
920 let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
921 let re = regex_lite::Regex::new(&pattern)
922 .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
923 match re.captures(&text) {
924 Some(caps) => {
925 let result = lua.create_table()?;
926 let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
927 result.set("match", full_match.to_string())?;
928 let groups = lua.create_table()?;
929 for i in 1..caps.len() {
930 if let Some(m) = caps.get(i) {
931 groups.set(i, m.as_str().to_string())?;
932 }
933 }
934 result.set("groups", groups)?;
935 Ok(Value::Table(result))
936 }
937 None => Ok(Value::Nil),
938 }
939 })?;
940 regex_table.set("find", find_fn)?;
941
942 let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
943 let re = regex_lite::Regex::new(&pattern)
944 .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
945 let results = lua.create_table()?;
946 for (i, m) in re.find_iter(&text).enumerate() {
947 results.set(i + 1, m.as_str().to_string())?;
948 }
949 Ok(results)
950 })?;
951 regex_table.set("find_all", find_all_fn)?;
952
953 let replace_fn = lua.create_function(
954 |_, (text, pattern, replacement): (String, String, String)| {
955 let re = regex_lite::Regex::new(&pattern).map_err(|e| {
956 mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
957 })?;
958 Ok(re.replace_all(&text, replacement.as_str()).into_owned())
959 },
960 )?;
961 regex_table.set("replace", replace_fn)?;
962
963 lua.globals().set("regex", regex_table)?;
964 Ok(())
965}
966
967fn register_async(lua: &Lua) -> mlua::Result<()> {
968 let async_table = lua.create_table()?;
969
970 let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
971 let thread = lua.create_thread(func)?;
972 let async_thread = thread.into_async::<mlua::MultiValue>(())?;
973 let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
974 tokio::task::spawn_local(async move {
975 let values = async_thread.await.map_err(|e| e.to_string())?;
976 Ok(values.into_vec())
977 });
978
979 let handle = lua.create_table()?;
980 let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
981 let cell_clone = cell.clone();
982
983 let await_fn = lua.create_async_function(move |lua, ()| {
984 let cell = cell_clone.clone();
985 async move {
986 let join_handle = cell
987 .borrow_mut()
988 .take()
989 .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
990 let result = join_handle.await.map_err(|e| {
991 mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
992 })?;
993 match result {
994 Ok(values) => {
995 let tbl = lua.create_table()?;
996 for (i, v) in values.into_iter().enumerate() {
997 tbl.set(i + 1, v)?;
998 }
999 Ok(Value::Table(tbl))
1000 }
1001 Err(msg) => Err(mlua::Error::runtime(msg)),
1002 }
1003 }
1004 })?;
1005 handle.set("await", await_fn)?;
1006
1007 Ok(handle)
1008 })?;
1009 async_table.set("spawn", spawn_fn)?;
1010
1011 let spawn_interval_fn =
1012 lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
1013 if seconds <= 0.0 {
1014 return Err(mlua::Error::runtime(
1015 "async.spawn_interval: interval must be positive",
1016 ));
1017 }
1018
1019 let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1020 let cancel_clone = cancel.clone();
1021
1022 tokio::task::spawn_local({
1023 let cancel = cancel_clone.clone();
1024 async move {
1025 let mut interval =
1026 tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1027 interval.tick().await;
1028 loop {
1029 interval.tick().await;
1030 if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1031 break;
1032 }
1033 if let Err(e) = func.call_async::<()>(()).await {
1034 error!("async.spawn_interval: callback error: {e}");
1035 break;
1036 }
1037 }
1038 }
1039 });
1040
1041 let handle = lua.create_table()?;
1042 let cancel_fn = lua.create_function(move |_, ()| {
1043 cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1044 Ok(())
1045 })?;
1046 handle.set("cancel", cancel_fn)?;
1047
1048 Ok(handle)
1049 })?;
1050 async_table.set("spawn_interval", spawn_interval_fn)?;
1051
1052 lua.globals().set("async", async_table)?;
1053 Ok(())
1054}
1055
1056struct DbPool(Arc<AnyPool>);
1057impl UserData for DbPool {}
1058
1059fn register_db(lua: &Lua) -> mlua::Result<()> {
1060 sqlx::any::install_default_drivers();
1061
1062 let db_table = lua.create_table()?;
1063
1064 let connect_fn = lua.create_async_function(|lua, url: String| async move {
1065 let pool = sqlx::any::AnyPoolOptions::new()
1066 .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1067 .connect(&url)
1068 .await
1069 .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1070 lua.create_any_userdata(DbPool(Arc::new(pool)))
1071 })?;
1072 db_table.set("connect", connect_fn)?;
1073
1074 let query_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1075 let mut args_iter = args.into_iter();
1076
1077 let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1078 let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1079 let params = extract_params(&args_iter.next())?;
1080
1081 let mut query = sqlx::query(&sql);
1082 for p in ¶ms {
1083 query = bind_param(query, p);
1084 }
1085
1086 let rows: Vec<AnyRow> = query
1087 .fetch_all(&*pool)
1088 .await
1089 .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1090
1091 let result = lua.create_table()?;
1092 for (i, row) in rows.iter().enumerate() {
1093 let row_table = any_row_to_lua_table(&lua, row)?;
1094 result.set(i + 1, row_table)?;
1095 }
1096 Ok(Value::Table(result))
1097 })?;
1098 db_table.set("query", query_fn)?;
1099
1100 let execute_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1101 let mut args_iter = args.into_iter();
1102
1103 let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1104 let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1105 let params = extract_params(&args_iter.next())?;
1106
1107 let mut query = sqlx::query(&sql);
1108 for p in ¶ms {
1109 query = bind_param(query, p);
1110 }
1111
1112 let result = query
1113 .execute(&*pool)
1114 .await
1115 .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1116
1117 let tbl = lua.create_table()?;
1118 tbl.set("rows_affected", result.rows_affected() as i64)?;
1119 Ok(Value::Table(tbl))
1120 })?;
1121 db_table.set("execute", execute_fn)?;
1122
1123 let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1124 let mut args_iter = args.into_iter();
1125 let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1126 pool.close().await;
1127 Ok(())
1128 })?;
1129 db_table.set("close", close_fn)?;
1130
1131 lua.globals().set("db", db_table)?;
1132 Ok(())
1133}
1134
1135fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1136 match val {
1137 Some(Value::UserData(ud)) => {
1138 let db = ud.borrow::<DbPool>().map_err(|_| {
1139 mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection"))
1140 })?;
1141 Ok(db.0.clone())
1142 }
1143 _ => Err(mlua::Error::runtime(format!(
1144 "{fn_name}: first argument must be a db connection"
1145 ))),
1146 }
1147}
1148
1149fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1150 match val {
1151 Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1152 _ => Err(mlua::Error::runtime(format!(
1153 "{fn_name}: second argument must be a SQL string"
1154 ))),
1155 }
1156}
1157
/// A positional SQL bind value converted from a Lua value (see `extract_params`).
///
/// `Debug` and `PartialEq` are derived so parameter lists can be logged and compared.
#[derive(Debug, Clone, PartialEq)]
enum DbParam {
    /// Lua `nil`; bound as SQL NULL.
    Null,
    /// Lua boolean.
    Bool(bool),
    /// Lua integer.
    Int(i64),
    /// Lua float.
    Float(f64),
    /// Lua string (must be valid UTF-8).
    Text(String),
}
1166
1167fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1168 match val {
1169 Some(Value::Table(t)) => {
1170 let mut params = Vec::new();
1171 let len = t.len()?;
1172 for i in 1..=len {
1173 let v: Value = t.get(i)?;
1174 let param = match v {
1175 Value::Nil => DbParam::Null,
1176 Value::Boolean(b) => DbParam::Bool(b),
1177 Value::Integer(n) => DbParam::Int(n),
1178 Value::Number(f) => DbParam::Float(f),
1179 Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1180 _ => {
1181 return Err(mlua::Error::runtime(format!(
1182 "db: unsupported parameter type: {}",
1183 v.type_name()
1184 )));
1185 }
1186 };
1187 params.push(param);
1188 }
1189 Ok(params)
1190 }
1191 Some(Value::Nil) | None => Ok(Vec::new()),
1192 _ => Err(mlua::Error::runtime(
1193 "db: params must be a table (array) or nil",
1194 )),
1195 }
1196}
1197
1198fn bind_param<'q>(
1199 query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1200 param: &'q DbParam,
1201) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1202 match param {
1203 DbParam::Null => query.bind(None::<String>),
1204 DbParam::Bool(b) => query.bind(*b),
1205 DbParam::Int(n) => query.bind(*n),
1206 DbParam::Float(f) => query.bind(*f),
1207 DbParam::Text(s) => query.bind(s.as_str()),
1208 }
1209}
1210
1211fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1212 let table = lua.create_table()?;
1213 for col in row.columns() {
1214 let name = col.name();
1215 let val: Value = any_column_to_lua_value(lua, row, col)?;
1216 table.set(name.to_string(), val)?;
1217 }
1218 Ok(table)
1219}
1220
1221fn any_column_to_lua_value<C: Column>(lua: &Lua, row: &AnyRow, col: &C) -> mlua::Result<Value> {
1222 let ordinal = col.ordinal();
1223 let type_info = col.type_info();
1224 let type_name = type_info.to_string();
1225 let type_name = type_name.to_uppercase();
1226
1227 if row
1228 .try_get_raw(ordinal)
1229 .map(|v| v.is_null())
1230 .unwrap_or(true)
1231 {
1232 return Ok(Value::Nil);
1233 }
1234
1235 match type_name.as_str() {
1236 "BOOLEAN" | "BOOL" => {
1237 let v: bool = row
1238 .try_get(ordinal)
1239 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1240 Ok(Value::Boolean(v))
1241 }
1242 "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1243 let v: i64 = row
1244 .try_get(ordinal)
1245 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1246 Ok(Value::Integer(v))
1247 }
1248 "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1249 let v: f64 = row
1250 .try_get(ordinal)
1251 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1252 Ok(Value::Number(v))
1253 }
1254 _ => {
1255 let v: String = row
1256 .try_get(ordinal)
1257 .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1258 Ok(Value::String(lua.create_string(&v)?))
1259 }
1260 }
1261}
1262
1263fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1264 match val {
1265 Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1266 Some(other) => {
1267 let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1268 result
1269 }
1270 None => None,
1271 }
1272}
1273
1274fn lua_values_equal(a: &Value, b: &Value) -> bool {
1275 match (a, b) {
1276 (Value::Nil, Value::Nil) => true,
1277 (Value::Boolean(a), Value::Boolean(b)) => a == b,
1278 (Value::Integer(a), Value::Integer(b)) => a == b,
1279 (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1280 (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1281 (*a as f64 - b).abs() < f64::EPSILON
1282 }
1283 (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1284 _ => false,
1285 }
1286}
1287
1288fn lua_value_to_f64(val: Value) -> Option<f64> {
1289 match val {
1290 Value::Integer(i) => Some(i as f64),
1291 Value::Number(f) => Some(f),
1292 _ => None,
1293 }
1294}
1295
1296fn format_lua_value(val: &Value) -> String {
1297 match val {
1298 Value::Nil => "nil".to_string(),
1299 Value::Boolean(b) => b.to_string(),
1300 Value::Integer(i) => i.to_string(),
1301 Value::Number(f) => f.to_string(),
1302 Value::String(s) => match s.to_str() {
1303 Ok(v) => v.to_string(),
1304 Err(_) => "<invalid utf-8>".to_string(),
1305 },
1306 _ => format!("<{}>", val.type_name()),
1307 }
1308}
1309
1310type WsSink = Rc<
1311 tokio::sync::Mutex<
1312 futures_util::stream::SplitSink<
1313 tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1314 tokio_tungstenite::tungstenite::Message,
1315 >,
1316 >,
1317>;
1318type WsStream = Rc<
1319 tokio::sync::Mutex<
1320 futures_util::stream::SplitStream<
1321 tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1322 >,
1323 >,
1324>;
1325
1326struct WsConn {
1327 sink: WsSink,
1328 stream: WsStream,
1329}
1330impl UserData for WsConn {}
1331
1332fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1333 let ud = match val {
1334 Value::UserData(ud) => ud,
1335 _ => {
1336 return Err(mlua::Error::runtime(format!(
1337 "{fn_name}: first argument must be a ws connection"
1338 )));
1339 }
1340 };
1341 let ws = ud.borrow::<WsConn>().map_err(|_| {
1342 mlua::Error::runtime(format!("{fn_name}: first argument must be a ws connection"))
1343 })?;
1344 Ok((ws.sink.clone(), ws.stream.clone()))
1345}
1346
1347fn register_ws(lua: &Lua) -> mlua::Result<()> {
1348 let ws_table = lua.create_table()?;
1349
1350 let connect_fn = lua.create_async_function(|lua, url: String| async move {
1351 let (stream, _response) = tokio_tungstenite::connect_async(&url)
1352 .await
1353 .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1354 let (sink, read) = stream.split();
1355 lua.create_any_userdata(WsConn {
1356 sink: Rc::new(tokio::sync::Mutex::new(sink)),
1357 stream: Rc::new(tokio::sync::Mutex::new(read)),
1358 })
1359 })?;
1360 ws_table.set("connect", connect_fn)?;
1361
1362 let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1363 let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1364 sink.lock()
1365 .await
1366 .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1367 .await
1368 .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1369 Ok(())
1370 })?;
1371 ws_table.set("send", send_fn)?;
1372
1373 let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1374 let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1375 loop {
1376 let msg = stream
1377 .lock()
1378 .await
1379 .next()
1380 .await
1381 .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1382 .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1383 match msg {
1384 tokio_tungstenite::tungstenite::Message::Text(t) => {
1385 return Ok(t.to_string());
1386 }
1387 tokio_tungstenite::tungstenite::Message::Binary(b) => {
1388 return String::from_utf8(b.into())
1389 .map_err(|e| mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}")));
1390 }
1391 tokio_tungstenite::tungstenite::Message::Close(_) => {
1392 return Err(mlua::Error::runtime("ws.recv: connection closed"));
1393 }
1394 _ => continue,
1395 }
1396 }
1397 })?;
1398 ws_table.set("recv", recv_fn)?;
1399
1400 let close_fn = lua.create_async_function(|_, conn: Value| async move {
1401 let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1402 sink.lock()
1403 .await
1404 .close()
1405 .await
1406 .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1407 Ok(())
1408 })?;
1409 ws_table.set("close", close_fn)?;
1410
1411 lua.globals().set("ws", ws_table)?;
1412 Ok(())
1413}
1414
1415fn register_template(lua: &Lua) -> mlua::Result<()> {
1416 let tmpl_table = lua.create_table()?;
1417
1418 let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1419 let json_vars = match &vars {
1420 Value::Table(_) => lua_value_to_json(&vars)?,
1421 Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1422 _ => {
1423 return Err(mlua::Error::runtime(
1424 "template.render_string: second argument must be a table or nil",
1425 ));
1426 }
1427 };
1428 let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1429 let env = minijinja::Environment::new();
1430 let tmpl = env
1431 .template_from_str(&template_str)
1432 .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1433 tmpl.render(mini_vars)
1434 .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1435 })?;
1436 tmpl_table.set("render_string", render_string_fn)?;
1437
1438 let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1439 let content = std::fs::read_to_string(&file_path).map_err(|e| {
1440 mlua::Error::runtime(format!(
1441 "template.render: failed to read {file_path:?}: {e}"
1442 ))
1443 })?;
1444 let json_vars = match &vars {
1445 Value::Table(_) => lua_value_to_json(&vars)?,
1446 Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1447 _ => {
1448 return Err(mlua::Error::runtime(
1449 "template.render: second argument must be a table or nil",
1450 ));
1451 }
1452 };
1453 let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1454 let env = minijinja::Environment::new();
1455 let tmpl = env
1456 .template_from_str(&content)
1457 .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1458 tmpl.render(mini_vars)
1459 .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1460 })?;
1461 tmpl_table.set("render", render_fn)?;
1462
1463 lua.globals().set("template", tmpl_table)?;
1464 Ok(())
1465}
1466
/// Quotes `s` as a double-quoted Lua string literal that evaluates back to exactly `s`.
///
/// Backslash, double quote, LF and CR are backslash-escaped; a raw LF/CR would end a Lua short
/// string. NUL is written as the three-digit decimal escape `\000`. Lua's `\ddd` escape consumes
/// up to three digits, so a shorter `\0` followed by a digit in `s` would be misread: NUL then
/// `1` would decode as the single byte 0x01. All other characters are copied unchanged.
fn lua_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\0' => out.push_str("\\000"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
1476
#[cfg(test)]
mod tests {
    use super::*;

    // ---- base64 ----

    #[test]
    fn test_base64_roundtrip() {
        const PLAIN: &str = "hello world";
        let encoded = BASE64.encode(PLAIN.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let bytes = BASE64.decode(encoded.as_bytes()).unwrap();
        assert_eq!(String::from_utf8(bytes).unwrap(), PLAIN);
    }

    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }

    // ---- Lua -> JSON conversion ----

    #[test]
    fn test_lua_value_to_json_nil() {
        let converted = lua_value_to_json(&Value::Nil).unwrap();
        assert_eq!(converted, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        let converted = lua_value_to_json(&Value::Boolean(true)).unwrap();
        assert_eq!(converted, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        let converted = lua_value_to_json(&Value::Integer(42)).unwrap();
        assert_eq!(converted, serde_json::json!(42));
    }

    #[test]
    fn test_lua_value_to_json_number() {
        let converted = lua_value_to_json(&Value::Number(1.5)).unwrap();
        assert_eq!(converted, serde_json::json!(1.5));
    }

    // ---- equality ----

    #[test]
    fn test_lua_values_equal_nil() {
        assert!(lua_values_equal(&Value::Nil, &Value::Nil));
    }

    #[test]
    fn test_lua_values_equal_int_float() {
        let int = Value::Integer(42);
        let float = Value::Number(42.0);
        assert!(lua_values_equal(&int, &float));
    }

    #[test]
    fn test_lua_values_not_equal() {
        let (one, two) = (Value::Integer(1), Value::Integer(2));
        assert!(!lua_values_equal(&one, &two));
    }

    // ---- formatting / quoting ----

    #[test]
    fn test_format_lua_value() {
        let cases = [
            (Value::Nil, "nil"),
            (Value::Boolean(true), "true"),
            (Value::Integer(42), "42"),
        ];
        for (value, expected) in &cases {
            assert_eq!(format_lua_value(value), *expected);
        }
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        let cases = [
            ("hello", "\"hello\""),
            ("line\nnew", "\"line\\nnew\""),
            ("quote\"here", "\"quote\\\"here\""),
        ];
        for (input, expected) in cases {
            assert_eq!(lua_string_literal(input), expected);
        }
    }
}