Skip to main content

assay_lua/lua/
builtins.rs

1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
23pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
24    register_http(lua, client)?;
25    register_json(lua)?;
26    register_yaml(lua)?;
27    register_toml(lua)?;
28    register_assert(lua)?;
29    register_log(lua)?;
30    register_env(lua)?;
31    register_sleep(lua)?;
32    register_time(lua)?;
33    register_fs(lua)?;
34    register_base64(lua)?;
35    register_crypto(lua)?;
36    register_regex(lua)?;
37    register_async(lua)?;
38    register_db(lua)?;
39    register_ws(lua)?;
40    register_template(lua)?;
41    Ok(())
42}
43
44fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
45    let http_table = lua.create_table()?;
46
47    for method in ["get", "post", "put", "patch", "delete"] {
48        let method_client = client.clone();
49        let method_name = method.to_string();
50        let has_body = method != "get" && method != "delete";
51
52        let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
53            let client = method_client.clone();
54            let method_name = method_name.clone();
55            async move {
56                let mut args_iter = args.into_iter();
57                let url: String = match args_iter.next() {
58                    Some(Value::String(s)) => s.to_str()?.to_string(),
59                    _ => {
60                        return Err(mlua::Error::runtime(format!(
61                            "http.{method_name}: first argument must be a URL string"
62                        )));
63                    }
64                };
65
66                let (body_str, auto_json, opts) = if has_body {
67                    let (body, is_json) = match args_iter.next() {
68                        Some(Value::String(s)) => (s.to_str()?.to_string(), false),
69                        Some(Value::Table(t)) => {
70                            let json_val = lua_table_to_json(&t)?;
71                            let serialized = serde_json::to_string(&json_val).map_err(|e| {
72                                mlua::Error::runtime(format!(
73                                    "http.{method_name}: JSON encode failed: {e}"
74                                ))
75                            })?;
76                            (serialized, true)
77                        }
78                        Some(Value::Nil) | None => (String::new(), false),
79                        _ => {
80                            return Err(mlua::Error::runtime(format!(
81                                "http.{method_name}: second argument must be a string, table, or nil"
82                            )));
83                        }
84                    };
85                    let opts = match args_iter.next() {
86                        Some(Value::Table(t)) => Some(t),
87                        Some(Value::Nil) | None => None,
88                        _ => {
89                            return Err(mlua::Error::runtime(format!(
90                                "http.{method_name}: third argument must be a table or nil"
91                            )));
92                        }
93                    };
94                    (body, is_json, opts)
95                } else {
96                    let opts = match args_iter.next() {
97                        Some(Value::Table(t)) => Some(t),
98                        Some(Value::Nil) | None => None,
99                        _ => {
100                            return Err(mlua::Error::runtime(format!(
101                                "http.{method_name}: second argument must be a table or nil"
102                            )));
103                        }
104                    };
105                    (String::new(), false, opts)
106                };
107
108                let mut req = match method_name.as_str() {
109                    "get" => client.get(&url),
110                    "post" => client.post(&url),
111                    "put" => client.put(&url),
112                    "patch" => client.patch(&url),
113                    "delete" => client.delete(&url),
114                    _ => unreachable!(),
115                };
116
117                if has_body && !body_str.is_empty() {
118                    req = req.body(body_str);
119                }
120                if auto_json {
121                    req = req.header("Content-Type", "application/json");
122                }
123                if let Some(ref opts_table) = opts
124                    && let Ok(headers_table) = opts_table.get::<Table>("headers")
125                {
126                    for pair in headers_table.pairs::<String, String>() {
127                        let (k, v) = pair?;
128                        req = req.header(k, v);
129                    }
130                }
131
132                let resp = req.send().await.map_err(|e| {
133                    mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
134                })?;
135                let status = resp.status().as_u16();
136                let resp_headers = resp.headers().clone();
137                let body = resp.text().await.map_err(|e| {
138                    mlua::Error::runtime(format!(
139                        "http.{method_name}: reading body failed: {e}"
140                    ))
141                })?;
142
143                let result = lua.create_table()?;
144                result.set("status", status)?;
145                result.set("body", body)?;
146
147                let headers_out = lua.create_table()?;
148                for (name, value) in &resp_headers {
149                    if let Ok(v) = value.to_str() {
150                        headers_out.set(name.as_str().to_string(), v.to_string())?;
151                    }
152                }
153                result.set("headers", headers_out)?;
154
155                Ok(Value::Table(result))
156            }
157        })?;
158        http_table.set(method, func)?;
159    }
160
161    let serve_fn =
162        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
163            let mut args_iter = args.into_iter();
164
165            let port: u16 = match args_iter.next() {
166                Some(Value::Integer(n)) => n as u16,
167                _ => {
168                    return Err::<(), _>(mlua::Error::runtime(
169                        "http.serve: first argument must be a port number",
170                    ));
171                }
172            };
173
174            let routes_table = match args_iter.next() {
175                Some(Value::Table(t)) => t,
176                _ => {
177                    return Err::<(), _>(mlua::Error::runtime(
178                        "http.serve: second argument must be a routes table",
179                    ));
180                }
181            };
182
183            let routes = Rc::new(parse_routes(&routes_table)?);
184
185            let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
186                .await
187                .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
188
189            loop {
190                let (stream, _addr) = listener.accept().await.map_err(|e| {
191                    mlua::Error::runtime(format!("http.serve: accept failed: {e}"))
192                })?;
193
194                let routes = routes.clone();
195                let lua_clone = lua.clone();
196
197                tokio::task::spawn_local(async move {
198                    let io = hyper_util::rt::TokioIo::new(stream);
199                    let routes = routes.clone();
200                    let lua = lua_clone.clone();
201
202                    let service = service_fn(move |req: Request<Incoming>| {
203                        let routes = routes.clone();
204                        let lua = lua.clone();
205                        async move { handle_request(&lua, &routes, req).await }
206                    });
207
208                    if let Err(e) = http1::Builder::new().serve_connection(io, service).await
209                        && !e.to_string().contains("connection closed")
210                    {
211                        error!("http.serve: connection error: {e}");
212                    }
213                });
214            }
215        })?;
216    http_table.set("serve", serve_fn)?;
217
218    lua.globals().set("http", http_table)?;
219    Ok(())
220}
221
222fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
223    let mut routes = HashMap::new();
224    for method_pair in routes_table.pairs::<String, Table>() {
225        let (method, paths_table) = method_pair?;
226        let method_upper = method.to_uppercase();
227        for path_pair in paths_table.pairs::<String, mlua::Function>() {
228            let (path, func) = path_pair?;
229            routes.insert((method_upper.clone(), path), func);
230        }
231    }
232    Ok(routes)
233}
234
235async fn handle_request(
236    lua: &Lua,
237    routes: &HashMap<(String, String), mlua::Function>,
238    req: Request<Incoming>,
239) -> Result<Response<Full<Bytes>>, hyper::Error> {
240    let method = req.method().to_string();
241    let path = req.uri().path().to_string();
242    let query = req.uri().query().unwrap_or("").to_string();
243    let headers: Vec<(String, String)> = req
244        .headers()
245        .iter()
246        .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
247        .collect();
248
249    let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
250        Ok(collected) => collected.to_bytes(),
251        Err(_) => Bytes::new(),
252    };
253    let body_str = String::from_utf8_lossy(&body_bytes).to_string();
254
255    let key = (method.clone(), path.clone());
256    let handler = match routes.get(&key) {
257        Some(f) => f,
258        None => {
259            return Ok(Response::builder()
260                .status(StatusCode::NOT_FOUND)
261                .header("content-type", "text/plain")
262                .body(Full::new(Bytes::from("not found")))
263                .unwrap());
264        }
265    };
266
267    match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
268        Ok(lua_resp) => lua_response_to_http(&lua_resp),
269        Err(e) => Ok(Response::builder()
270            .status(StatusCode::INTERNAL_SERVER_ERROR)
271            .header("content-type", "text/plain")
272            .body(Full::new(Bytes::from(format!("handler error: {e}"))))
273            .unwrap()),
274    }
275}
276
277fn build_lua_request_and_call(
278    lua: &Lua,
279    handler: &mlua::Function,
280    method: &str,
281    path: &str,
282    query: &str,
283    headers: &[(String, String)],
284    body: &str,
285) -> mlua::Result<Table> {
286    let req_table = lua.create_table()?;
287    req_table.set("method", method.to_string())?;
288    req_table.set("path", path.to_string())?;
289    req_table.set("query", query.to_string())?;
290    req_table.set("body", body.to_string())?;
291
292    let headers_table = lua.create_table()?;
293    for (k, v) in headers {
294        headers_table.set(k.as_str(), v.as_str())?;
295    }
296    req_table.set("headers", headers_table)?;
297
298    handler.call::<Table>(req_table)
299}
300
301fn lua_response_to_http(
302    resp_table: &Table,
303) -> Result<Response<Full<Bytes>>, hyper::Error> {
304    let status = resp_table
305        .get::<Option<u16>>("status")
306        .unwrap_or(None)
307        .unwrap_or(200);
308
309    let mut builder =
310        Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
311
312    if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
313        for (k, v) in headers_table.pairs::<String, String>().flatten() {
314            builder = builder.header(k, v);
315        }
316    }
317
318    let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
319        let json_val =
320            lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
321        let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
322        builder = builder.header("content-type", "application/json");
323        Bytes::from(serialized)
324    } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
325        builder = builder.header("content-type", "text/plain");
326        Bytes::from(body_str)
327    } else {
328        builder = builder.header("content-type", "text/plain");
329        Bytes::new()
330    };
331
332    Ok(builder.body(Full::new(body_bytes)).unwrap())
333}
334
335fn register_json(lua: &Lua) -> mlua::Result<()> {
336    let json_table = lua.create_table()?;
337
338    let parse_fn = lua.create_function(|lua, s: String| {
339        let value: serde_json::Value = serde_json::from_str(&s)
340            .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
341        json_value_to_lua(lua, &value)
342    })?;
343    json_table.set("parse", parse_fn)?;
344
345    let encode_fn = lua.create_function(|_, val: Value| {
346        let json_val = lua_value_to_json(&val)?;
347        serde_json::to_string(&json_val)
348            .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
349    })?;
350    json_table.set("encode", encode_fn)?;
351
352    lua.globals().set("json", json_table)?;
353    Ok(())
354}
355
356fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
357    let mut is_array = true;
358    let mut max_index: i64 = 0;
359    let mut count: i64 = 0;
360
361    for pair in table.clone().pairs::<Value, Value>() {
362        let (key, _) = pair?;
363        count += 1;
364        match key {
365            Value::Integer(i) if i >= 1 => {
366                if i > max_index {
367                    max_index = i;
368                }
369            }
370            _ => {
371                is_array = false;
372                break;
373            }
374        }
375    }
376
377    if is_array && max_index == count {
378        let mut arr = Vec::with_capacity(max_index as usize);
379        for i in 1..=max_index {
380            let val: Value = table.get(i)?;
381            arr.push(lua_value_to_json(&val)?);
382        }
383        Ok(serde_json::Value::Array(arr))
384    } else {
385        let mut map = serde_json::Map::new();
386        for pair in table.clone().pairs::<Value, Value>() {
387            let (key, val) = pair?;
388            let key_str = match key {
389                Value::String(s) => s.to_str()?.to_string(),
390                Value::Integer(i) => i.to_string(),
391                Value::Number(f) => f.to_string(),
392                _ => {
393                    return Err(mlua::Error::runtime(format!(
394                        "unsupported table key type: {}",
395                        key.type_name()
396                    )));
397                }
398            };
399            map.insert(key_str, lua_value_to_json(&val)?);
400        }
401        Ok(serde_json::Value::Object(map))
402    }
403}
404
405fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
406    match val {
407        Value::Nil => Ok(serde_json::Value::Null),
408        Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
409        Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
410        Value::Number(f) => serde_json::Number::from_f64(*f)
411            .map(serde_json::Value::Number)
412            .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
413        Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
414        Value::Table(t) => lua_table_to_json(t),
415        _ => Err(mlua::Error::runtime(format!(
416            "unsupported Lua type for JSON: {}",
417            val.type_name()
418        ))),
419    }
420}
421
422pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
423    match val {
424        serde_json::Value::Null => Ok(Value::Nil),
425        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
426        serde_json::Value::Number(n) => {
427            if let Some(i) = n.as_i64() {
428                Ok(Value::Integer(i))
429            } else if let Some(f) = n.as_f64() {
430                Ok(Value::Number(f))
431            } else {
432                Ok(Value::Nil)
433            }
434        }
435        serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
436        serde_json::Value::Array(arr) => {
437            let table = lua.create_table()?;
438            for (i, item) in arr.iter().enumerate() {
439                table.set(i + 1, json_value_to_lua(lua, item)?)?;
440            }
441            Ok(Value::Table(table))
442        }
443        serde_json::Value::Object(map) => {
444            let table = lua.create_table()?;
445            for (k, v) in map {
446                table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
447            }
448            Ok(Value::Table(table))
449        }
450    }
451}
452
453fn register_assert(lua: &Lua) -> mlua::Result<()> {
454    let assert_table = lua.create_table()?;
455
456    let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
457        let mut args_iter = args.into_iter();
458        let a = args_iter.next().unwrap_or(Value::Nil);
459        let b = args_iter.next().unwrap_or(Value::Nil);
460        let msg = extract_string_arg(lua, args_iter.next());
461
462        if !lua_values_equal(&a, &b) {
463            let detail = format!(
464                "assert.eq failed: {:?} != {:?}{}",
465                format_lua_value(&a),
466                format_lua_value(&b),
467                msg.map(|m| format!(" - {m}")).unwrap_or_default()
468            );
469            return Err(mlua::Error::runtime(detail));
470        }
471        Ok(())
472    })?;
473    assert_table.set("eq", eq_fn)?;
474
475    let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
476        let mut args_iter = args.into_iter();
477        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
478        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
479        let msg = extract_string_arg(lua, args_iter.next());
480
481        match (a, b) {
482            (Some(va), Some(vb)) if va > vb => Ok(()),
483            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
484                "assert.gt failed: {va} is not > {vb}{}",
485                msg.map(|m| format!(" - {m}")).unwrap_or_default()
486            ))),
487            _ => Err(mlua::Error::runtime(
488                "assert.gt: both arguments must be numbers",
489            )),
490        }
491    })?;
492    assert_table.set("gt", gt_fn)?;
493
494    let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
495        let mut args_iter = args.into_iter();
496        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
497        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
498        let msg = extract_string_arg(lua, args_iter.next());
499
500        match (a, b) {
501            (Some(va), Some(vb)) if va < vb => Ok(()),
502            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
503                "assert.lt failed: {va} is not < {vb}{}",
504                msg.map(|m| format!(" - {m}")).unwrap_or_default()
505            ))),
506            _ => Err(mlua::Error::runtime(
507                "assert.lt: both arguments must be numbers",
508            )),
509        }
510    })?;
511    assert_table.set("lt", lt_fn)?;
512
513    let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
514        let mut args_iter = args.into_iter();
515        let haystack: String = match args_iter.next() {
516            Some(Value::String(s)) => s.to_str()?.to_string(),
517            _ => {
518                return Err(mlua::Error::runtime(
519                    "assert.contains: first argument must be a string",
520                ));
521            }
522        };
523        let needle: String = match args_iter.next() {
524            Some(Value::String(s)) => s.to_str()?.to_string(),
525            _ => {
526                return Err(mlua::Error::runtime(
527                    "assert.contains: second argument must be a string",
528                ));
529            }
530        };
531        let msg = extract_string_arg(lua, args_iter.next());
532
533        if !haystack.contains(&needle) {
534            return Err(mlua::Error::runtime(format!(
535                "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
536                msg.map(|m| format!(" - {m}")).unwrap_or_default()
537            )));
538        }
539        Ok(())
540    })?;
541    assert_table.set("contains", contains_fn)?;
542
543    let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
544        let mut args_iter = args.into_iter();
545        let val = args_iter.next().unwrap_or(Value::Nil);
546        let msg = extract_string_arg(lua, args_iter.next());
547
548        if val == Value::Nil {
549            return Err(mlua::Error::runtime(format!(
550                "assert.not_nil failed: value is nil{}",
551                msg.map(|m| format!(" - {m}")).unwrap_or_default()
552            )));
553        }
554        Ok(())
555    })?;
556    assert_table.set("not_nil", not_nil_fn)?;
557
558    let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
559        let mut args_iter = args.into_iter();
560        let text: String = match args_iter.next() {
561            Some(Value::String(s)) => s.to_str()?.to_string(),
562            _ => {
563                return Err(mlua::Error::runtime(
564                    "assert.matches: first argument must be a string",
565                ));
566            }
567        };
568        let pattern: String = match args_iter.next() {
569            Some(Value::String(s)) => s.to_str()?.to_string(),
570            _ => {
571                return Err(mlua::Error::runtime(
572                    "assert.matches: second argument must be a pattern string",
573                ));
574            }
575        };
576        let msg = extract_string_arg(lua, args_iter.next());
577
578        let found: bool = lua
579            .load(format!(
580                "return string.find({}, {}) ~= nil",
581                lua_string_literal(&text),
582                lua_string_literal(&pattern)
583            ))
584            .eval()
585            .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
586
587        if !found {
588            return Err(mlua::Error::runtime(format!(
589                "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
590                msg.map(|m| format!(" - {m}")).unwrap_or_default()
591            )));
592        }
593        Ok(())
594    })?;
595    assert_table.set("matches", matches_fn)?;
596
597    lua.globals().set("assert", assert_table)?;
598    Ok(())
599}
600
601fn register_log(lua: &Lua) -> mlua::Result<()> {
602    let log_table = lua.create_table()?;
603
604    let info_fn = lua.create_function(|_, msg: String| {
605        info!(target: "lua", "{}", msg);
606        Ok(())
607    })?;
608    log_table.set("info", info_fn)?;
609
610    let warn_fn = lua.create_function(|_, msg: String| {
611        warn!(target: "lua", "{}", msg);
612        Ok(())
613    })?;
614    log_table.set("warn", warn_fn)?;
615
616    let error_fn = lua.create_function(|_, msg: String| {
617        error!(target: "lua", "{}", msg);
618        Ok(())
619    })?;
620    log_table.set("error", error_fn)?;
621
622    lua.globals().set("log", log_table)?;
623    Ok(())
624}
625
626fn register_env(lua: &Lua) -> mlua::Result<()> {
627    let env_table = lua.create_table()?;
628
629    let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
630        Ok(val) => Ok(Some(val)),
631        Err(_) => Ok(None),
632    })?;
633    env_table.set("_process_get", process_get_fn)?;
634    env_table.set("_check_env", lua.create_table()?)?;
635
636    lua.globals().set("env", env_table)?;
637
638    lua.load(
639        r#"
640        function env.get(name)
641            local val = env._check_env[name]
642            if val ~= nil then return val end
643            return env._process_get(name)
644        end
645        "#,
646    )
647    .exec()?;
648
649    Ok(())
650}
651
652fn register_sleep(lua: &Lua) -> mlua::Result<()> {
653    let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
654        let duration = std::time::Duration::from_secs_f64(seconds);
655        tokio::time::sleep(duration).await;
656        Ok(())
657    })?;
658    lua.globals().set("sleep", sleep_fn)?;
659    Ok(())
660}
661
662fn register_time(lua: &Lua) -> mlua::Result<()> {
663    let time_fn = lua.create_function(|_, ()| {
664        let secs = SystemTime::now()
665            .duration_since(UNIX_EPOCH)
666            .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
667            .as_secs_f64();
668        Ok(secs)
669    })?;
670    lua.globals().set("time", time_fn)?;
671    Ok(())
672}
673
674fn register_fs(lua: &Lua) -> mlua::Result<()> {
675    let fs_table = lua.create_table()?;
676
677    let read_fn = lua.create_function(|_, path: String| {
678        std::fs::read_to_string(&path)
679            .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
680    })?;
681    fs_table.set("read", read_fn)?;
682
683    let write_fn = lua.create_function(|_, (path, content): (String, String)| {
684        let p = std::path::Path::new(&path);
685        if let Some(parent) = p.parent() {
686            std::fs::create_dir_all(parent).map_err(|e| {
687                mlua::Error::runtime(format!(
688                    "fs.write: failed to create directories for {path:?}: {e}"
689                ))
690            })?;
691        }
692        std::fs::write(&path, &content)
693            .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
694    })?;
695    fs_table.set("write", write_fn)?;
696
697    lua.globals().set("fs", fs_table)?;
698    Ok(())
699}
700
701fn register_base64(lua: &Lua) -> mlua::Result<()> {
702    let b64_table = lua.create_table()?;
703
704    let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
705    b64_table.set("encode", encode_fn)?;
706
707    let decode_fn = lua.create_function(|_, input: String| {
708        let bytes = BASE64
709            .decode(input.as_bytes())
710            .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
711        String::from_utf8(bytes)
712            .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
713    })?;
714    b64_table.set("decode", decode_fn)?;
715
716    lua.globals().set("base64", b64_table)?;
717    Ok(())
718}
719
720fn register_crypto(lua: &Lua) -> mlua::Result<()> {
721    let crypto_table = lua.create_table()?;
722
723    let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
724        let mut args_iter = args.into_iter();
725
726        let claims_table = match args_iter.next() {
727            Some(Value::Table(t)) => t,
728            _ => {
729                return Err(mlua::Error::runtime(
730                    "crypto.jwt_sign: first argument must be a claims table",
731                ));
732            }
733        };
734
735        let pem_key: String = match args_iter.next() {
736            Some(Value::String(s)) => s.to_str()?.to_string(),
737            _ => {
738                return Err(mlua::Error::runtime(
739                    "crypto.jwt_sign: second argument must be a PEM key string",
740                ));
741            }
742        };
743
744        let algorithm = match args_iter.next() {
745            Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
746                "RS256" => Algorithm::RS256,
747                "RS384" => Algorithm::RS384,
748                "RS512" => Algorithm::RS512,
749                other => {
750                    return Err(mlua::Error::runtime(format!(
751                        "crypto.jwt_sign: unsupported algorithm: {other}"
752                    )));
753                }
754            },
755            Some(Value::Nil) | None => Algorithm::RS256,
756            _ => {
757                return Err(mlua::Error::runtime(
758                    "crypto.jwt_sign: third argument must be an algorithm string or nil",
759                ));
760            }
761        };
762
763        let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
764        let pem_bytes = Zeroizing::new(pem_key.into_bytes());
765        let key = EncodingKey::from_rsa_pem(&pem_bytes)
766            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
767
768        let header = Header::new(algorithm);
769        let token = jsonwebtoken::encode(&header, &claims_json, &key)
770            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
771
772        Ok(token)
773    })?;
774    crypto_table.set("jwt_sign", jwt_sign_fn)?;
775
776    let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
777        let mut args_iter = args.into_iter();
778
779        let input: String = match args_iter.next() {
780            Some(Value::String(s)) => s.to_str()?.to_string(),
781            _ => {
782                return Err(mlua::Error::runtime(
783                    "crypto.hash: first argument must be a string",
784                ));
785            }
786        };
787
788        let algorithm: String = match args_iter.next() {
789            Some(Value::String(s)) => s.to_str()?.to_lowercase(),
790            Some(Value::Nil) | None => "sha256".to_string(),
791            _ => {
792                return Err(mlua::Error::runtime(
793                    "crypto.hash: second argument must be an algorithm string or nil",
794                ));
795            }
796        };
797
798        let hex = match algorithm.as_str() {
799            "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
800            "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
801            "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
802            "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
803            "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
804            "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
805            "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
806            "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
807            other => {
808                return Err(mlua::Error::runtime(format!(
809                    "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
810                )));
811            }
812        };
813
814        Ok(hex)
815    })?;
816    crypto_table.set("hash", hash_fn)?;
817
818    let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
819        let mut args_iter = args.into_iter();
820
821        let length: usize = match args_iter.next() {
822            Some(Value::Integer(n)) if n > 0 => n as usize,
823            Some(Value::Integer(n)) => {
824                return Err(mlua::Error::runtime(format!(
825                    "crypto.random: length must be positive, got {n}"
826                )));
827            }
828            Some(Value::Nil) | None => 32,
829            _ => {
830                return Err(mlua::Error::runtime(
831                    "crypto.random: first argument must be a positive integer or nil",
832                ));
833            }
834        };
835
836        let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
837        let mut rng = rand::rng();
838        let result: String = (0..length)
839            .map(|_| charset[rng.random_range(..charset.len())] as char)
840            .collect();
841
842        Ok(result)
843    })?;
844    crypto_table.set("random", random_fn)?;
845
846    lua.globals().set("crypto", crypto_table)?;
847    Ok(())
848}
849
850fn register_yaml(lua: &Lua) -> mlua::Result<()> {
851    let yaml_table = lua.create_table()?;
852
853    let parse_fn = lua.create_function(|lua, s: String| {
854        let json_val: serde_json::Value = serde_yml::from_str(&s)
855            .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
856        json_value_to_lua(lua, &json_val)
857    })?;
858    yaml_table.set("parse", parse_fn)?;
859
860    let encode_fn = lua.create_function(|_, val: Value| {
861        let json_val = lua_value_to_json(&val)?;
862        serde_yml::to_string(&json_val)
863            .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
864    })?;
865    yaml_table.set("encode", encode_fn)?;
866
867    lua.globals().set("yaml", yaml_table)?;
868    Ok(())
869}
870
871fn register_toml(lua: &Lua) -> mlua::Result<()> {
872    let toml_table = lua.create_table()?;
873
874    let parse_fn = lua.create_function(|lua, s: String| {
875        let toml_val: toml::Value = toml::from_str(&s)
876            .map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
877        let json_val = serde_json::to_value(&toml_val)
878            .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
879        json_value_to_lua(lua, &json_val)
880    })?;
881    toml_table.set("parse", parse_fn)?;
882
883    let encode_fn = lua.create_function(|_, val: Value| {
884        let json_val = lua_value_to_json(&val)?;
885        let toml_val: toml::Value = serde_json::from_value(json_val)
886            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
887        toml::to_string_pretty(&toml_val)
888            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
889    })?;
890    toml_table.set("encode", encode_fn)?;
891
892    lua.globals().set("toml", toml_table)?;
893    Ok(())
894}
895
896fn register_regex(lua: &Lua) -> mlua::Result<()> {
897    let regex_table = lua.create_table()?;
898
899    let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
900        let re = regex_lite::Regex::new(&pattern)
901            .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
902        Ok(re.is_match(&text))
903    })?;
904    regex_table.set("match", match_fn)?;
905
906    let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
907        let re = regex_lite::Regex::new(&pattern)
908            .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
909        match re.captures(&text) {
910            Some(caps) => {
911                let result = lua.create_table()?;
912                let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
913                result.set("match", full_match.to_string())?;
914                let groups = lua.create_table()?;
915                for i in 1..caps.len() {
916                    if let Some(m) = caps.get(i) {
917                        groups.set(i, m.as_str().to_string())?;
918                    }
919                }
920                result.set("groups", groups)?;
921                Ok(Value::Table(result))
922            }
923            None => Ok(Value::Nil),
924        }
925    })?;
926    regex_table.set("find", find_fn)?;
927
928    let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
929        let re = regex_lite::Regex::new(&pattern)
930            .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
931        let results = lua.create_table()?;
932        for (i, m) in re.find_iter(&text).enumerate() {
933            results.set(i + 1, m.as_str().to_string())?;
934        }
935        Ok(results)
936    })?;
937    regex_table.set("find_all", find_all_fn)?;
938
939    let replace_fn = lua.create_function(
940        |_, (text, pattern, replacement): (String, String, String)| {
941            let re = regex_lite::Regex::new(&pattern).map_err(|e| {
942                mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
943            })?;
944            Ok(re.replace_all(&text, replacement.as_str()).into_owned())
945        },
946    )?;
947    regex_table.set("replace", replace_fn)?;
948
949    lua.globals().set("regex", regex_table)?;
950    Ok(())
951}
952
953fn register_async(lua: &Lua) -> mlua::Result<()> {
954    let async_table = lua.create_table()?;
955
956    let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
957        let thread = lua.create_thread(func)?;
958        let async_thread = thread.into_async::<mlua::MultiValue>(())?;
959        let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
960            tokio::task::spawn_local(async move {
961                let values = async_thread.await.map_err(|e| e.to_string())?;
962                Ok(values.into_vec())
963            });
964
965        let handle = lua.create_table()?;
966        let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
967        let cell_clone = cell.clone();
968
969        let await_fn = lua.create_async_function(move |lua, ()| {
970            let cell = cell_clone.clone();
971            async move {
972                let join_handle = cell
973                    .borrow_mut()
974                    .take()
975                    .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
976                let result = join_handle.await.map_err(|e| {
977                    mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
978                })?;
979                match result {
980                    Ok(values) => {
981                        let tbl = lua.create_table()?;
982                        for (i, v) in values.into_iter().enumerate() {
983                            tbl.set(i + 1, v)?;
984                        }
985                        Ok(Value::Table(tbl))
986                    }
987                    Err(msg) => Err(mlua::Error::runtime(msg)),
988                }
989            }
990        })?;
991        handle.set("await", await_fn)?;
992
993        Ok(handle)
994    })?;
995    async_table.set("spawn", spawn_fn)?;
996
997    let spawn_interval_fn =
998        lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
999            if seconds <= 0.0 {
1000                return Err(mlua::Error::runtime(
1001                    "async.spawn_interval: interval must be positive",
1002                ));
1003            }
1004
1005            let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1006            let cancel_clone = cancel.clone();
1007
1008            tokio::task::spawn_local({
1009                let cancel = cancel_clone.clone();
1010                async move {
1011                    let mut interval =
1012                        tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1013                    interval.tick().await;
1014                    loop {
1015                        interval.tick().await;
1016                        if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1017                            break;
1018                        }
1019                        if let Err(e) = func.call_async::<()>(()).await {
1020                            error!("async.spawn_interval: callback error: {e}");
1021                            break;
1022                        }
1023                    }
1024                }
1025            });
1026
1027            let handle = lua.create_table()?;
1028            let cancel_fn = lua.create_function(move |_, ()| {
1029                cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1030                Ok(())
1031            })?;
1032            handle.set("cancel", cancel_fn)?;
1033
1034            Ok(handle)
1035        })?;
1036    async_table.set("spawn_interval", spawn_interval_fn)?;
1037
1038    lua.globals().set("async", async_table)?;
1039    Ok(())
1040}
1041
/// Lua userdata wrapping a shared sqlx `AnyPool`.
///
/// Created by `db.connect` and passed as the first argument to `db.query`,
/// `db.execute` and `db.close`, which unwrap it with `extract_db_pool`.
struct DbPool(Arc<AnyPool>);
// Exposes no Lua-visible methods or fields; it is an opaque handle.
impl UserData for DbPool {}
1044
1045fn register_db(lua: &Lua) -> mlua::Result<()> {
1046    sqlx::any::install_default_drivers();
1047
1048    let db_table = lua.create_table()?;
1049
1050    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1051        let pool = sqlx::any::AnyPoolOptions::new()
1052            .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1053            .connect(&url)
1054            .await
1055            .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1056        lua.create_any_userdata(DbPool(Arc::new(pool)))
1057    })?;
1058    db_table.set("connect", connect_fn)?;
1059
1060    let query_fn =
1061        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1062            let mut args_iter = args.into_iter();
1063
1064            let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1065            let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1066            let params = extract_params(&args_iter.next())?;
1067
1068            let mut query = sqlx::query(&sql);
1069            for p in &params {
1070                query = bind_param(query, p);
1071            }
1072
1073            let rows: Vec<AnyRow> = query
1074                .fetch_all(&*pool)
1075                .await
1076                .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1077
1078            let result = lua.create_table()?;
1079            for (i, row) in rows.iter().enumerate() {
1080                let row_table = any_row_to_lua_table(&lua, row)?;
1081                result.set(i + 1, row_table)?;
1082            }
1083            Ok(Value::Table(result))
1084        })?;
1085    db_table.set("query", query_fn)?;
1086
1087    let execute_fn =
1088        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1089            let mut args_iter = args.into_iter();
1090
1091            let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1092            let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1093            let params = extract_params(&args_iter.next())?;
1094
1095            let mut query = sqlx::query(&sql);
1096            for p in &params {
1097                query = bind_param(query, p);
1098            }
1099
1100            let result = query
1101                .execute(&*pool)
1102                .await
1103                .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1104
1105            let tbl = lua.create_table()?;
1106            tbl.set("rows_affected", result.rows_affected() as i64)?;
1107            Ok(Value::Table(tbl))
1108        })?;
1109    db_table.set("execute", execute_fn)?;
1110
1111    let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1112        let mut args_iter = args.into_iter();
1113        let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1114        pool.close().await;
1115        Ok(())
1116    })?;
1117    db_table.set("close", close_fn)?;
1118
1119    lua.globals().set("db", db_table)?;
1120    Ok(())
1121}
1122
1123fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1124    match val {
1125        Some(Value::UserData(ud)) => {
1126            let db = ud
1127                .borrow::<DbPool>()
1128                .map_err(|_| mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection")))?;
1129            Ok(db.0.clone())
1130        }
1131        _ => Err(mlua::Error::runtime(format!(
1132            "{fn_name}: first argument must be a db connection"
1133        ))),
1134    }
1135}
1136
1137fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1138    match val {
1139        Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1140        _ => Err(mlua::Error::runtime(format!(
1141            "{fn_name}: second argument must be a SQL string"
1142        ))),
1143    }
1144}
1145
/// A single positional bind parameter, converted from a Lua value by
/// `extract_params` and bound onto a query by `bind_param`.
#[derive(Clone)]
enum DbParam {
    /// SQL NULL (from Lua `nil`).
    Null,
    /// From a Lua boolean.
    Bool(bool),
    /// From a Lua integer.
    Int(i64),
    /// From a Lua float.
    Float(f64),
    /// From a Lua string (must be valid UTF-8).
    Text(String),
}
1154
1155fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1156    match val {
1157        Some(Value::Table(t)) => {
1158            let mut params = Vec::new();
1159            let len = t.len()?;
1160            for i in 1..=len {
1161                let v: Value = t.get(i)?;
1162                let param = match v {
1163                    Value::Nil => DbParam::Null,
1164                    Value::Boolean(b) => DbParam::Bool(b),
1165                    Value::Integer(n) => DbParam::Int(n),
1166                    Value::Number(f) => DbParam::Float(f),
1167                    Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1168                    _ => {
1169                        return Err(mlua::Error::runtime(format!(
1170                            "db: unsupported parameter type: {}",
1171                            v.type_name()
1172                        )));
1173                    }
1174                };
1175                params.push(param);
1176            }
1177            Ok(params)
1178        }
1179        Some(Value::Nil) | None => Ok(Vec::new()),
1180        _ => Err(mlua::Error::runtime(
1181            "db: params must be a table (array) or nil",
1182        )),
1183    }
1184}
1185
1186fn bind_param<'q>(
1187    query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1188    param: &'q DbParam,
1189) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1190    match param {
1191        DbParam::Null => query.bind(None::<String>),
1192        DbParam::Bool(b) => query.bind(*b),
1193        DbParam::Int(n) => query.bind(*n),
1194        DbParam::Float(f) => query.bind(*f),
1195        DbParam::Text(s) => query.bind(s.as_str()),
1196    }
1197}
1198
1199fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1200    let table = lua.create_table()?;
1201    for col in row.columns() {
1202        let name = col.name();
1203        let val: Value = any_column_to_lua_value(lua, row, col)?;
1204        table.set(name.to_string(), val)?;
1205    }
1206    Ok(table)
1207}
1208
1209fn any_column_to_lua_value<C: Column>(
1210    lua: &Lua,
1211    row: &AnyRow,
1212    col: &C,
1213) -> mlua::Result<Value> {
1214    let ordinal = col.ordinal();
1215    let type_info = col.type_info();
1216    let type_name = type_info.to_string();
1217    let type_name = type_name.to_uppercase();
1218
1219    if row.try_get_raw(ordinal).map(|v| v.is_null()).unwrap_or(true) {
1220        return Ok(Value::Nil);
1221    }
1222
1223    match type_name.as_str() {
1224        "BOOLEAN" | "BOOL" => {
1225            let v: bool = row
1226                .try_get(ordinal)
1227                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1228            Ok(Value::Boolean(v))
1229        }
1230        "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1231            let v: i64 = row
1232                .try_get(ordinal)
1233                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1234            Ok(Value::Integer(v))
1235        }
1236        "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1237            let v: f64 = row
1238                .try_get(ordinal)
1239                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1240            Ok(Value::Number(v))
1241        }
1242        _ => {
1243            let v: String = row
1244                .try_get(ordinal)
1245                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1246            Ok(Value::String(lua.create_string(&v)?))
1247        }
1248    }
1249}
1250
1251fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1252    match val {
1253        Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1254        Some(other) => {
1255            let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1256            result
1257        }
1258        None => None,
1259    }
1260}
1261
1262fn lua_values_equal(a: &Value, b: &Value) -> bool {
1263    match (a, b) {
1264        (Value::Nil, Value::Nil) => true,
1265        (Value::Boolean(a), Value::Boolean(b)) => a == b,
1266        (Value::Integer(a), Value::Integer(b)) => a == b,
1267        (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1268        (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1269            (*a as f64 - b).abs() < f64::EPSILON
1270        }
1271        (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1272        _ => false,
1273    }
1274}
1275
1276fn lua_value_to_f64(val: Value) -> Option<f64> {
1277    match val {
1278        Value::Integer(i) => Some(i as f64),
1279        Value::Number(f) => Some(f),
1280        _ => None,
1281    }
1282}
1283
1284fn format_lua_value(val: &Value) -> String {
1285    match val {
1286        Value::Nil => "nil".to_string(),
1287        Value::Boolean(b) => b.to_string(),
1288        Value::Integer(i) => i.to_string(),
1289        Value::Number(f) => f.to_string(),
1290        Value::String(s) => match s.to_str() {
1291            Ok(v) => v.to_string(),
1292            Err(_) => "<invalid utf-8>".to_string(),
1293        },
1294        _ => format!("<{}>", val.type_name()),
1295    }
1296}
1297
/// Write half of a client WebSocket connection, behind a tokio async mutex so
/// `ws.send` and `ws.close` can hold the lock across `.await`. `Rc` (not `Arc`)
/// makes the handle non-`Send`.
type WsSink = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitSink<
            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
            tokio_tungstenite::tungstenite::Message,
        >,
    >,
>;
/// Read half of a client WebSocket connection. `ws.recv` locks it while it
/// waits for the next message.
type WsStream = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
        >,
    >,
>;
1313
/// Lua userdata for a WebSocket connection returned by `ws.connect`.
struct WsConn {
    /// Write half, used by `ws.send` and `ws.close`.
    sink: WsSink,
    /// Read half, used by `ws.recv`.
    stream: WsStream,
}
// No Lua-visible methods; the `ws.*` functions reach the halves via `extract_ws_conn`.
impl UserData for WsConn {}
1319
1320fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1321    let ud = match val {
1322        Value::UserData(ud) => ud,
1323        _ => {
1324            return Err(mlua::Error::runtime(format!(
1325                "{fn_name}: first argument must be a ws connection"
1326            )));
1327        }
1328    };
1329    let ws = ud.borrow::<WsConn>().map_err(|_| {
1330        mlua::Error::runtime(format!(
1331            "{fn_name}: first argument must be a ws connection"
1332        ))
1333    })?;
1334    Ok((ws.sink.clone(), ws.stream.clone()))
1335}
1336
1337fn register_ws(lua: &Lua) -> mlua::Result<()> {
1338    let ws_table = lua.create_table()?;
1339
1340    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1341        let (stream, _response) = tokio_tungstenite::connect_async(&url)
1342            .await
1343            .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1344        let (sink, read) = stream.split();
1345        lua.create_any_userdata(WsConn {
1346            sink: Rc::new(tokio::sync::Mutex::new(sink)),
1347            stream: Rc::new(tokio::sync::Mutex::new(read)),
1348        })
1349    })?;
1350    ws_table.set("connect", connect_fn)?;
1351
1352    let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1353        let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1354        sink.lock()
1355            .await
1356            .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1357            .await
1358            .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1359        Ok(())
1360    })?;
1361    ws_table.set("send", send_fn)?;
1362
1363    let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1364        let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1365        loop {
1366            let msg = stream
1367                .lock()
1368                .await
1369                .next()
1370                .await
1371                .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1372                .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1373            match msg {
1374                tokio_tungstenite::tungstenite::Message::Text(t) => {
1375                    return Ok(t.to_string());
1376                }
1377                tokio_tungstenite::tungstenite::Message::Binary(b) => {
1378                    return String::from_utf8(b.into()).map_err(|e| {
1379                        mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}"))
1380                    });
1381                }
1382                tokio_tungstenite::tungstenite::Message::Close(_) => {
1383                    return Err(mlua::Error::runtime("ws.recv: connection closed"));
1384                }
1385                _ => continue,
1386            }
1387        }
1388    })?;
1389    ws_table.set("recv", recv_fn)?;
1390
1391    let close_fn = lua.create_async_function(|_, conn: Value| async move {
1392        let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1393        sink.lock()
1394            .await
1395            .close()
1396            .await
1397            .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1398        Ok(())
1399    })?;
1400    ws_table.set("close", close_fn)?;
1401
1402    lua.globals().set("ws", ws_table)?;
1403    Ok(())
1404}
1405
1406fn register_template(lua: &Lua) -> mlua::Result<()> {
1407    let tmpl_table = lua.create_table()?;
1408
1409    let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1410        let json_vars = match &vars {
1411            Value::Table(_) => lua_value_to_json(&vars)?,
1412            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1413            _ => {
1414                return Err(mlua::Error::runtime(
1415                    "template.render_string: second argument must be a table or nil",
1416                ));
1417            }
1418        };
1419        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1420        let env = minijinja::Environment::new();
1421        let tmpl = env
1422            .template_from_str(&template_str)
1423            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1424        tmpl.render(mini_vars)
1425            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1426    })?;
1427    tmpl_table.set("render_string", render_string_fn)?;
1428
1429    let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1430        let content = std::fs::read_to_string(&file_path).map_err(|e| {
1431            mlua::Error::runtime(format!("template.render: failed to read {file_path:?}: {e}"))
1432        })?;
1433        let json_vars = match &vars {
1434            Value::Table(_) => lua_value_to_json(&vars)?,
1435            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1436            _ => {
1437                return Err(mlua::Error::runtime(
1438                    "template.render: second argument must be a table or nil",
1439                ));
1440            }
1441        };
1442        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1443        let env = minijinja::Environment::new();
1444        let tmpl = env
1445            .template_from_str(&content)
1446            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1447        tmpl.render(mini_vars)
1448            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1449    })?;
1450    tmpl_table.set("render", render_fn)?;
1451
1452    lua.globals().set("template", tmpl_table)?;
1453    Ok(())
1454}
1455
/// Renders `s` as a double-quoted Lua string literal that evaluates back to `s`.
///
/// Backslashes, double quotes, newlines, carriage returns and NUL are escaped;
/// every other character is copied through unchanged. NUL is written as the
/// three-digit decimal escape `\000` rather than `\0`. Lua reads up to three
/// digits after a backslash, so `\0` followed by a digit (e.g. `"\01"`) would
/// decode to a different byte.
fn lua_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\0' => out.push_str("\\000"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
1465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base64_roundtrip() {
        let input = "hello world";
        let encoded = BASE64.encode(input.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let decoded = BASE64.decode(encoded.as_bytes()).unwrap();
        assert_eq!(String::from_utf8(decoded).unwrap(), input);
    }

    #[test]
    fn test_base64_empty() {
        let encoded = BASE64.encode(b"");
        assert_eq!(encoded, "");
        let decoded = BASE64.decode(b"").unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_base64_decode_rejects_invalid_input() {
        assert!(BASE64.decode(b"not base64!").is_err());
    }

    #[test]
    fn test_lua_value_to_json_nil() {
        let result = lua_value_to_json(&Value::Nil).unwrap();
        assert_eq!(result, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        assert_eq!(
            lua_value_to_json(&Value::Boolean(true)).unwrap(),
            serde_json::Value::Bool(true)
        );
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        assert_eq!(
            lua_value_to_json(&Value::Integer(42)).unwrap(),
            serde_json::json!(42)
        );
    }

    #[test]
    fn test_lua_value_to_json_number() {
        assert_eq!(
            lua_value_to_json(&Value::Number(1.5)).unwrap(),
            serde_json::json!(1.5)
        );
    }

    #[test]
    fn test_lua_values_equal_nil() {
        assert!(lua_values_equal(&Value::Nil, &Value::Nil));
    }

    #[test]
    fn test_lua_values_equal_int_float() {
        assert!(lua_values_equal(&Value::Integer(42), &Value::Number(42.0)));
    }

    #[test]
    fn test_lua_values_not_equal() {
        assert!(!lua_values_equal(&Value::Integer(1), &Value::Integer(2)));
    }

    #[test]
    fn test_lua_values_equal_mismatched_types() {
        assert!(!lua_values_equal(&Value::Boolean(true), &Value::Integer(1)));
        assert!(!lua_values_equal(&Value::Nil, &Value::Boolean(false)));
    }

    #[test]
    fn test_lua_values_equal_float_tolerance() {
        // Differences below f64::EPSILON (here ~5.6e-17) count as equal.
        assert!(lua_values_equal(&Value::Number(0.1 + 0.2), &Value::Number(0.3)));
    }

    #[test]
    fn test_lua_values_equal_strings() {
        let lua = Lua::new();
        let a = Value::String(lua.create_string("abc").unwrap());
        let b = Value::String(lua.create_string("abc").unwrap());
        let c = Value::String(lua.create_string("abd").unwrap());
        assert!(lua_values_equal(&a, &b));
        assert!(!lua_values_equal(&a, &c));
    }

    #[test]
    fn test_lua_value_to_f64() {
        assert_eq!(lua_value_to_f64(Value::Integer(3)), Some(3.0));
        assert_eq!(lua_value_to_f64(Value::Number(2.5)), Some(2.5));
        assert_eq!(lua_value_to_f64(Value::Nil), None);
    }

    #[test]
    fn test_format_lua_value() {
        assert_eq!(format_lua_value(&Value::Nil), "nil");
        assert_eq!(format_lua_value(&Value::Boolean(true)), "true");
        assert_eq!(format_lua_value(&Value::Integer(42)), "42");
    }

    #[test]
    fn test_format_lua_value_number_and_string() {
        let lua = Lua::new();
        assert_eq!(format_lua_value(&Value::Number(1.5)), "1.5");
        let s = Value::String(lua.create_string("hi").unwrap());
        assert_eq!(format_lua_value(&s), "hi");
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        assert_eq!(lua_string_literal("hello"), "\"hello\"");
        assert_eq!(lua_string_literal("line\nnew"), "\"line\\nnew\"");
        assert_eq!(lua_string_literal("quote\"here"), "\"quote\\\"here\"");
    }

    #[test]
    fn test_lua_string_literal_backslash_and_cr() {
        assert_eq!(lua_string_literal("a\\b"), "\"a\\\\b\"");
        assert_eq!(lua_string_literal("a\rb"), "\"a\\rb\"");
    }
}
1545}