Skip to main content

assay/lua/
builtins.rs

1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
23pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
24    register_http(lua, client)?;
25    register_json(lua)?;
26    register_yaml(lua)?;
27    register_toml(lua)?;
28    register_assert(lua)?;
29    register_log(lua)?;
30    register_env(lua)?;
31    register_sleep(lua)?;
32    register_time(lua)?;
33    register_fs(lua)?;
34    register_base64(lua)?;
35    register_crypto(lua)?;
36    register_regex(lua)?;
37    register_async(lua)?;
38    register_db(lua)?;
39    register_ws(lua)?;
40    register_template(lua)?;
41    Ok(())
42}
43
44fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
45    let http_table = lua.create_table()?;
46
47    for method in ["get", "post", "put", "patch", "delete"] {
48        let method_client = client.clone();
49        let method_name = method.to_string();
50        let has_body = method != "get" && method != "delete";
51
52        let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
53            let client = method_client.clone();
54            let method_name = method_name.clone();
55            async move {
56                let mut args_iter = args.into_iter();
57                let url: String = match args_iter.next() {
58                    Some(Value::String(s)) => s.to_str()?.to_string(),
59                    _ => {
60                        return Err(mlua::Error::runtime(format!(
61                            "http.{method_name}: first argument must be a URL string"
62                        )));
63                    }
64                };
65
66                let (body_str, auto_json, opts) = if has_body {
67                    let (body, is_json) = match args_iter.next() {
68                        Some(Value::String(s)) => (s.to_str()?.to_string(), false),
69                        Some(Value::Table(t)) => {
70                            let json_val = lua_table_to_json(&t)?;
71                            let serialized = serde_json::to_string(&json_val).map_err(|e| {
72                                mlua::Error::runtime(format!(
73                                    "http.{method_name}: JSON encode failed: {e}"
74                                ))
75                            })?;
76                            (serialized, true)
77                        }
78                        Some(Value::Nil) | None => (String::new(), false),
79                        _ => {
80                            return Err(mlua::Error::runtime(format!(
81                                "http.{method_name}: second argument must be a string, table, or nil"
82                            )));
83                        }
84                    };
85                    let opts = match args_iter.next() {
86                        Some(Value::Table(t)) => Some(t),
87                        Some(Value::Nil) | None => None,
88                        _ => {
89                            return Err(mlua::Error::runtime(format!(
90                                "http.{method_name}: third argument must be a table or nil"
91                            )));
92                        }
93                    };
94                    (body, is_json, opts)
95                } else {
96                    let opts = match args_iter.next() {
97                        Some(Value::Table(t)) => Some(t),
98                        Some(Value::Nil) | None => None,
99                        _ => {
100                            return Err(mlua::Error::runtime(format!(
101                                "http.{method_name}: second argument must be a table or nil"
102                            )));
103                        }
104                    };
105                    (String::new(), false, opts)
106                };
107
108                let mut req = match method_name.as_str() {
109                    "get" => client.get(&url),
110                    "post" => client.post(&url),
111                    "put" => client.put(&url),
112                    "patch" => client.patch(&url),
113                    "delete" => client.delete(&url),
114                    _ => unreachable!(),
115                };
116
117                if has_body && !body_str.is_empty() {
118                    req = req.body(body_str);
119                }
120                if auto_json {
121                    req = req.header("Content-Type", "application/json");
122                }
123                if let Some(ref opts_table) = opts
124                    && let Ok(headers_table) = opts_table.get::<Table>("headers")
125                {
126                    for pair in headers_table.pairs::<String, String>() {
127                        let (k, v) = pair?;
128                        req = req.header(k, v);
129                    }
130                }
131
132                let resp = req.send().await.map_err(|e| {
133                    mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
134                })?;
135                let status = resp.status().as_u16();
136                let resp_headers = resp.headers().clone();
137                let body = resp.text().await.map_err(|e| {
138                    mlua::Error::runtime(format!(
139                        "http.{method_name}: reading body failed: {e}"
140                    ))
141                })?;
142
143                let result = lua.create_table()?;
144                result.set("status", status)?;
145                result.set("body", body)?;
146
147                let headers_out = lua.create_table()?;
148                for (name, value) in &resp_headers {
149                    if let Ok(v) = value.to_str() {
150                        headers_out.set(name.as_str().to_string(), v.to_string())?;
151                    }
152                }
153                result.set("headers", headers_out)?;
154
155                Ok(Value::Table(result))
156            }
157        })?;
158        http_table.set(method, func)?;
159    }
160
161    let serve_fn =
162        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
163            let mut args_iter = args.into_iter();
164
165            let port: u16 = match args_iter.next() {
166                Some(Value::Integer(n)) => n as u16,
167                _ => {
168                    return Err::<(), _>(mlua::Error::runtime(
169                        "http.serve: first argument must be a port number",
170                    ));
171                }
172            };
173
174            let routes_table = match args_iter.next() {
175                Some(Value::Table(t)) => t,
176                _ => {
177                    return Err::<(), _>(mlua::Error::runtime(
178                        "http.serve: second argument must be a routes table",
179                    ));
180                }
181            };
182
183            let routes = Rc::new(parse_routes(&routes_table)?);
184
185            let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
186                .await
187                .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
188
189            loop {
190                let (stream, _addr) = listener.accept().await.map_err(|e| {
191                    mlua::Error::runtime(format!("http.serve: accept failed: {e}"))
192                })?;
193
194                let routes = routes.clone();
195                let lua_clone = lua.clone();
196
197                tokio::task::spawn_local(async move {
198                    let io = hyper_util::rt::TokioIo::new(stream);
199                    let routes = routes.clone();
200                    let lua = lua_clone.clone();
201
202                    let service = service_fn(move |req: Request<Incoming>| {
203                        let routes = routes.clone();
204                        let lua = lua.clone();
205                        async move { handle_request(&lua, &routes, req).await }
206                    });
207
208                    if let Err(e) = http1::Builder::new().serve_connection(io, service).await
209                        && !e.to_string().contains("connection closed")
210                    {
211                        error!("http.serve: connection error: {e}");
212                    }
213                });
214            }
215        })?;
216    http_table.set("serve", serve_fn)?;
217
218    lua.globals().set("http", http_table)?;
219    Ok(())
220}
221
222fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
223    let mut routes = HashMap::new();
224    for method_pair in routes_table.pairs::<String, Table>() {
225        let (method, paths_table) = method_pair?;
226        let method_upper = method.to_uppercase();
227        for path_pair in paths_table.pairs::<String, mlua::Function>() {
228            let (path, func) = path_pair?;
229            routes.insert((method_upper.clone(), path), func);
230        }
231    }
232    Ok(routes)
233}
234
235async fn handle_request(
236    lua: &Lua,
237    routes: &HashMap<(String, String), mlua::Function>,
238    req: Request<Incoming>,
239) -> Result<Response<Full<Bytes>>, hyper::Error> {
240    let method = req.method().to_string();
241    let path = req.uri().path().to_string();
242    let query = req.uri().query().unwrap_or("").to_string();
243    let headers: Vec<(String, String)> = req
244        .headers()
245        .iter()
246        .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
247        .collect();
248
249    let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
250        Ok(collected) => collected.to_bytes(),
251        Err(_) => Bytes::new(),
252    };
253    let body_str = String::from_utf8_lossy(&body_bytes).to_string();
254
255    let key = (method.clone(), path.clone());
256    let handler = match routes.get(&key) {
257        Some(f) => f,
258        None => {
259            return Ok(Response::builder()
260                .status(StatusCode::NOT_FOUND)
261                .header("content-type", "text/plain")
262                .body(Full::new(Bytes::from("not found")))
263                .unwrap());
264        }
265    };
266
267    match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
268        Ok(lua_resp) => lua_response_to_http(&lua_resp),
269        Err(e) => Ok(Response::builder()
270            .status(StatusCode::INTERNAL_SERVER_ERROR)
271            .header("content-type", "text/plain")
272            .body(Full::new(Bytes::from(format!("handler error: {e}"))))
273            .unwrap()),
274    }
275}
276
277fn build_lua_request_and_call(
278    lua: &Lua,
279    handler: &mlua::Function,
280    method: &str,
281    path: &str,
282    query: &str,
283    headers: &[(String, String)],
284    body: &str,
285) -> mlua::Result<Table> {
286    let req_table = lua.create_table()?;
287    req_table.set("method", method.to_string())?;
288    req_table.set("path", path.to_string())?;
289    req_table.set("query", query.to_string())?;
290    req_table.set("body", body.to_string())?;
291
292    let headers_table = lua.create_table()?;
293    for (k, v) in headers {
294        headers_table.set(k.as_str(), v.as_str())?;
295    }
296    req_table.set("headers", headers_table)?;
297
298    handler.call::<Table>(req_table)
299}
300
301fn lua_response_to_http(
302    resp_table: &Table,
303) -> Result<Response<Full<Bytes>>, hyper::Error> {
304    let status = resp_table
305        .get::<Option<u16>>("status")
306        .unwrap_or(None)
307        .unwrap_or(200);
308
309    let mut builder =
310        Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
311
312    if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
313        for (k, v) in headers_table.pairs::<String, String>().flatten() {
314            builder = builder.header(k, v);
315        }
316    }
317
318    let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
319        let json_val =
320            lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
321        let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
322        builder = builder.header("content-type", "application/json");
323        Bytes::from(serialized)
324    } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
325        builder = builder.header("content-type", "text/plain");
326        Bytes::from(body_str)
327    } else {
328        builder = builder.header("content-type", "text/plain");
329        Bytes::new()
330    };
331
332    Ok(builder.body(Full::new(body_bytes)).unwrap())
333}
334
335fn register_json(lua: &Lua) -> mlua::Result<()> {
336    let json_table = lua.create_table()?;
337
338    let parse_fn = lua.create_function(|lua, s: String| {
339        let value: serde_json::Value = serde_json::from_str(&s)
340            .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
341        json_value_to_lua(lua, &value)
342    })?;
343    json_table.set("parse", parse_fn)?;
344
345    let encode_fn = lua.create_function(|_, val: Value| {
346        let json_val = lua_value_to_json(&val)?;
347        serde_json::to_string(&json_val)
348            .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
349    })?;
350    json_table.set("encode", encode_fn)?;
351
352    lua.globals().set("json", json_table)?;
353    Ok(())
354}
355
356fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
357    let mut is_array = true;
358    let mut max_index: i64 = 0;
359    let mut count: i64 = 0;
360
361    for pair in table.clone().pairs::<Value, Value>() {
362        let (key, _) = pair?;
363        count += 1;
364        match key {
365            Value::Integer(i) if i >= 1 => {
366                if i > max_index {
367                    max_index = i;
368                }
369            }
370            _ => {
371                is_array = false;
372                break;
373            }
374        }
375    }
376
377    if is_array && max_index == count {
378        let mut arr = Vec::with_capacity(max_index as usize);
379        for i in 1..=max_index {
380            let val: Value = table.get(i)?;
381            arr.push(lua_value_to_json(&val)?);
382        }
383        Ok(serde_json::Value::Array(arr))
384    } else {
385        let mut map = serde_json::Map::new();
386        for pair in table.clone().pairs::<Value, Value>() {
387            let (key, val) = pair?;
388            let key_str = match key {
389                Value::String(s) => s.to_str()?.to_string(),
390                Value::Integer(i) => i.to_string(),
391                Value::Number(f) => f.to_string(),
392                _ => {
393                    return Err(mlua::Error::runtime(format!(
394                        "unsupported table key type: {}",
395                        key.type_name()
396                    )));
397                }
398            };
399            map.insert(key_str, lua_value_to_json(&val)?);
400        }
401        Ok(serde_json::Value::Object(map))
402    }
403}
404
405fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
406    match val {
407        Value::Nil => Ok(serde_json::Value::Null),
408        Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
409        Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
410        Value::Number(f) => serde_json::Number::from_f64(*f)
411            .map(serde_json::Value::Number)
412            .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
413        Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
414        Value::Table(t) => lua_table_to_json(t),
415        _ => Err(mlua::Error::runtime(format!(
416            "unsupported Lua type for JSON: {}",
417            val.type_name()
418        ))),
419    }
420}
421
422pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
423    match val {
424        serde_json::Value::Null => Ok(Value::Nil),
425        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
426        serde_json::Value::Number(n) => {
427            if let Some(i) = n.as_i64() {
428                Ok(Value::Integer(i))
429            } else if let Some(f) = n.as_f64() {
430                Ok(Value::Number(f))
431            } else {
432                Ok(Value::Nil)
433            }
434        }
435        serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
436        serde_json::Value::Array(arr) => {
437            let table = lua.create_table()?;
438            for (i, item) in arr.iter().enumerate() {
439                table.set(i + 1, json_value_to_lua(lua, item)?)?;
440            }
441            Ok(Value::Table(table))
442        }
443        serde_json::Value::Object(map) => {
444            let table = lua.create_table()?;
445            for (k, v) in map {
446                table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
447            }
448            Ok(Value::Table(table))
449        }
450    }
451}
452
453fn register_assert(lua: &Lua) -> mlua::Result<()> {
454    let assert_table = lua.create_table()?;
455
456    let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
457        let mut args_iter = args.into_iter();
458        let a = args_iter.next().unwrap_or(Value::Nil);
459        let b = args_iter.next().unwrap_or(Value::Nil);
460        let msg = extract_string_arg(lua, args_iter.next());
461
462        if !lua_values_equal(&a, &b) {
463            let detail = format!(
464                "assert.eq failed: {:?} != {:?}{}",
465                format_lua_value(&a),
466                format_lua_value(&b),
467                msg.map(|m| format!(" - {m}")).unwrap_or_default()
468            );
469            return Err(mlua::Error::runtime(detail));
470        }
471        Ok(())
472    })?;
473    assert_table.set("eq", eq_fn)?;
474
475    let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
476        let mut args_iter = args.into_iter();
477        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
478        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
479        let msg = extract_string_arg(lua, args_iter.next());
480
481        match (a, b) {
482            (Some(va), Some(vb)) if va > vb => Ok(()),
483            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
484                "assert.gt failed: {va} is not > {vb}{}",
485                msg.map(|m| format!(" - {m}")).unwrap_or_default()
486            ))),
487            _ => Err(mlua::Error::runtime(
488                "assert.gt: both arguments must be numbers",
489            )),
490        }
491    })?;
492    assert_table.set("gt", gt_fn)?;
493
494    let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
495        let mut args_iter = args.into_iter();
496        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
497        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
498        let msg = extract_string_arg(lua, args_iter.next());
499
500        match (a, b) {
501            (Some(va), Some(vb)) if va < vb => Ok(()),
502            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
503                "assert.lt failed: {va} is not < {vb}{}",
504                msg.map(|m| format!(" - {m}")).unwrap_or_default()
505            ))),
506            _ => Err(mlua::Error::runtime(
507                "assert.lt: both arguments must be numbers",
508            )),
509        }
510    })?;
511    assert_table.set("lt", lt_fn)?;
512
513    let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
514        let mut args_iter = args.into_iter();
515        let haystack: String = match args_iter.next() {
516            Some(Value::String(s)) => s.to_str()?.to_string(),
517            _ => {
518                return Err(mlua::Error::runtime(
519                    "assert.contains: first argument must be a string",
520                ));
521            }
522        };
523        let needle: String = match args_iter.next() {
524            Some(Value::String(s)) => s.to_str()?.to_string(),
525            _ => {
526                return Err(mlua::Error::runtime(
527                    "assert.contains: second argument must be a string",
528                ));
529            }
530        };
531        let msg = extract_string_arg(lua, args_iter.next());
532
533        if !haystack.contains(&needle) {
534            return Err(mlua::Error::runtime(format!(
535                "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
536                msg.map(|m| format!(" - {m}")).unwrap_or_default()
537            )));
538        }
539        Ok(())
540    })?;
541    assert_table.set("contains", contains_fn)?;
542
543    let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
544        let mut args_iter = args.into_iter();
545        let val = args_iter.next().unwrap_or(Value::Nil);
546        let msg = extract_string_arg(lua, args_iter.next());
547
548        if val == Value::Nil {
549            return Err(mlua::Error::runtime(format!(
550                "assert.not_nil failed: value is nil{}",
551                msg.map(|m| format!(" - {m}")).unwrap_or_default()
552            )));
553        }
554        Ok(())
555    })?;
556    assert_table.set("not_nil", not_nil_fn)?;
557
558    let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
559        let mut args_iter = args.into_iter();
560        let text: String = match args_iter.next() {
561            Some(Value::String(s)) => s.to_str()?.to_string(),
562            _ => {
563                return Err(mlua::Error::runtime(
564                    "assert.matches: first argument must be a string",
565                ));
566            }
567        };
568        let pattern: String = match args_iter.next() {
569            Some(Value::String(s)) => s.to_str()?.to_string(),
570            _ => {
571                return Err(mlua::Error::runtime(
572                    "assert.matches: second argument must be a pattern string",
573                ));
574            }
575        };
576        let msg = extract_string_arg(lua, args_iter.next());
577
578        let found: bool = lua
579            .load(format!(
580                "return string.find({}, {}) ~= nil",
581                lua_string_literal(&text),
582                lua_string_literal(&pattern)
583            ))
584            .eval()
585            .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
586
587        if !found {
588            return Err(mlua::Error::runtime(format!(
589                "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
590                msg.map(|m| format!(" - {m}")).unwrap_or_default()
591            )));
592        }
593        Ok(())
594    })?;
595    assert_table.set("matches", matches_fn)?;
596
597    lua.globals().set("assert", assert_table)?;
598    Ok(())
599}
600
601fn register_log(lua: &Lua) -> mlua::Result<()> {
602    let log_table = lua.create_table()?;
603
604    let info_fn = lua.create_function(|_, msg: String| {
605        info!(target: "lua", "{}", msg);
606        Ok(())
607    })?;
608    log_table.set("info", info_fn)?;
609
610    let warn_fn = lua.create_function(|_, msg: String| {
611        warn!(target: "lua", "{}", msg);
612        Ok(())
613    })?;
614    log_table.set("warn", warn_fn)?;
615
616    let error_fn = lua.create_function(|_, msg: String| {
617        error!(target: "lua", "{}", msg);
618        Ok(())
619    })?;
620    log_table.set("error", error_fn)?;
621
622    lua.globals().set("log", log_table)?;
623    Ok(())
624}
625
626fn register_env(lua: &Lua) -> mlua::Result<()> {
627    let env_table = lua.create_table()?;
628
629    let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
630        Ok(val) => Ok(Some(val)),
631        Err(_) => Ok(None),
632    })?;
633    env_table.set("_process_get", process_get_fn)?;
634    env_table.set("_check_env", lua.create_table()?)?;
635
636    lua.globals().set("env", env_table)?;
637
638    lua.load(
639        r#"
640        function env.get(name)
641            local val = env._check_env[name]
642            if val ~= nil then return val end
643            return env._process_get(name)
644        end
645        "#,
646    )
647    .exec()?;
648
649    Ok(())
650}
651
652fn register_sleep(lua: &Lua) -> mlua::Result<()> {
653    let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
654        let duration = std::time::Duration::from_secs_f64(seconds);
655        tokio::time::sleep(duration).await;
656        Ok(())
657    })?;
658    lua.globals().set("sleep", sleep_fn)?;
659    Ok(())
660}
661
662fn register_time(lua: &Lua) -> mlua::Result<()> {
663    let time_fn = lua.create_function(|_, ()| {
664        let secs = SystemTime::now()
665            .duration_since(UNIX_EPOCH)
666            .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
667            .as_secs_f64();
668        Ok(secs)
669    })?;
670    lua.globals().set("time", time_fn)?;
671    Ok(())
672}
673
674fn register_fs(lua: &Lua) -> mlua::Result<()> {
675    let fs_table = lua.create_table()?;
676
677    let read_fn = lua.create_function(|_, path: String| {
678        std::fs::read_to_string(&path)
679            .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
680    })?;
681    fs_table.set("read", read_fn)?;
682
683    let write_fn = lua.create_function(|_, (path, content): (String, String)| {
684        let p = std::path::Path::new(&path);
685        if let Some(parent) = p.parent() {
686            std::fs::create_dir_all(parent).map_err(|e| {
687                mlua::Error::runtime(format!(
688                    "fs.write: failed to create directories for {path:?}: {e}"
689                ))
690            })?;
691        }
692        std::fs::write(&path, &content)
693            .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
694    })?;
695    fs_table.set("write", write_fn)?;
696
697    lua.globals().set("fs", fs_table)?;
698    Ok(())
699}
700
701fn register_base64(lua: &Lua) -> mlua::Result<()> {
702    let b64_table = lua.create_table()?;
703
704    let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
705    b64_table.set("encode", encode_fn)?;
706
707    let decode_fn = lua.create_function(|_, input: String| {
708        let bytes = BASE64
709            .decode(input.as_bytes())
710            .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
711        String::from_utf8(bytes)
712            .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
713    })?;
714    b64_table.set("decode", decode_fn)?;
715
716    lua.globals().set("base64", b64_table)?;
717    Ok(())
718}
719
720fn register_crypto(lua: &Lua) -> mlua::Result<()> {
721    let crypto_table = lua.create_table()?;
722
723    let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
724        let mut args_iter = args.into_iter();
725
726        let claims_table = match args_iter.next() {
727            Some(Value::Table(t)) => t,
728            _ => {
729                return Err(mlua::Error::runtime(
730                    "crypto.jwt_sign: first argument must be a claims table",
731                ));
732            }
733        };
734
735        let pem_key: String = match args_iter.next() {
736            Some(Value::String(s)) => s.to_str()?.to_string(),
737            _ => {
738                return Err(mlua::Error::runtime(
739                    "crypto.jwt_sign: second argument must be a PEM key string",
740                ));
741            }
742        };
743
744        let algorithm = match args_iter.next() {
745            Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
746                "RS256" => Algorithm::RS256,
747                "RS384" => Algorithm::RS384,
748                "RS512" => Algorithm::RS512,
749                other => {
750                    return Err(mlua::Error::runtime(format!(
751                        "crypto.jwt_sign: unsupported algorithm: {other}"
752                    )));
753                }
754            },
755            Some(Value::Nil) | None => Algorithm::RS256,
756            _ => {
757                return Err(mlua::Error::runtime(
758                    "crypto.jwt_sign: third argument must be an algorithm string or nil",
759                ));
760            }
761        };
762
763        // 4th optional argument: options table with header fields (kid, typ, etc.)
764        let opts = match args_iter.next() {
765            Some(Value::Table(t)) => Some(t),
766            Some(Value::Nil) | None => None,
767            _ => {
768                return Err(mlua::Error::runtime(
769                    "crypto.jwt_sign: fourth argument must be an options table or nil",
770                ));
771            }
772        };
773
774        let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
775        let pem_bytes = Zeroizing::new(pem_key.into_bytes());
776        let key = EncodingKey::from_rsa_pem(&pem_bytes)
777            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
778
779        let mut header = Header::new(algorithm);
780        if let Some(ref opts_table) = opts {
781            if let Ok(kid) = opts_table.get::<String>("kid") {
782                header.kid = Some(kid);
783            }
784        }
785        let token = jsonwebtoken::encode(&header, &claims_json, &key)
786            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
787
788        Ok(token)
789    })?;
790    crypto_table.set("jwt_sign", jwt_sign_fn)?;
791
792    let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
793        let mut args_iter = args.into_iter();
794
795        let input: String = match args_iter.next() {
796            Some(Value::String(s)) => s.to_str()?.to_string(),
797            _ => {
798                return Err(mlua::Error::runtime(
799                    "crypto.hash: first argument must be a string",
800                ));
801            }
802        };
803
804        let algorithm: String = match args_iter.next() {
805            Some(Value::String(s)) => s.to_str()?.to_lowercase(),
806            Some(Value::Nil) | None => "sha256".to_string(),
807            _ => {
808                return Err(mlua::Error::runtime(
809                    "crypto.hash: second argument must be an algorithm string or nil",
810                ));
811            }
812        };
813
814        let hex = match algorithm.as_str() {
815            "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
816            "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
817            "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
818            "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
819            "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
820            "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
821            "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
822            "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
823            other => {
824                return Err(mlua::Error::runtime(format!(
825                    "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
826                )));
827            }
828        };
829
830        Ok(hex)
831    })?;
832    crypto_table.set("hash", hash_fn)?;
833
834    let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
835        let mut args_iter = args.into_iter();
836
837        let length: usize = match args_iter.next() {
838            Some(Value::Integer(n)) if n > 0 => n as usize,
839            Some(Value::Integer(n)) => {
840                return Err(mlua::Error::runtime(format!(
841                    "crypto.random: length must be positive, got {n}"
842                )));
843            }
844            Some(Value::Nil) | None => 32,
845            _ => {
846                return Err(mlua::Error::runtime(
847                    "crypto.random: first argument must be a positive integer or nil",
848                ));
849            }
850        };
851
852        let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
853        let mut rng = rand::rng();
854        let result: String = (0..length)
855            .map(|_| charset[rng.random_range(..charset.len())] as char)
856            .collect();
857
858        Ok(result)
859    })?;
860    crypto_table.set("random", random_fn)?;
861
862    lua.globals().set("crypto", crypto_table)?;
863    Ok(())
864}
865
866fn register_yaml(lua: &Lua) -> mlua::Result<()> {
867    let yaml_table = lua.create_table()?;
868
869    let parse_fn = lua.create_function(|lua, s: String| {
870        let json_val: serde_json::Value = serde_yml::from_str(&s)
871            .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
872        json_value_to_lua(lua, &json_val)
873    })?;
874    yaml_table.set("parse", parse_fn)?;
875
876    let encode_fn = lua.create_function(|_, val: Value| {
877        let json_val = lua_value_to_json(&val)?;
878        serde_yml::to_string(&json_val)
879            .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
880    })?;
881    yaml_table.set("encode", encode_fn)?;
882
883    lua.globals().set("yaml", yaml_table)?;
884    Ok(())
885}
886
887fn register_toml(lua: &Lua) -> mlua::Result<()> {
888    let toml_table = lua.create_table()?;
889
890    let parse_fn = lua.create_function(|lua, s: String| {
891        let toml_val: toml::Value = toml::from_str(&s)
892            .map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
893        let json_val = serde_json::to_value(&toml_val)
894            .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
895        json_value_to_lua(lua, &json_val)
896    })?;
897    toml_table.set("parse", parse_fn)?;
898
899    let encode_fn = lua.create_function(|_, val: Value| {
900        let json_val = lua_value_to_json(&val)?;
901        let toml_val: toml::Value = serde_json::from_value(json_val)
902            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
903        toml::to_string_pretty(&toml_val)
904            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
905    })?;
906    toml_table.set("encode", encode_fn)?;
907
908    lua.globals().set("toml", toml_table)?;
909    Ok(())
910}
911
912fn register_regex(lua: &Lua) -> mlua::Result<()> {
913    let regex_table = lua.create_table()?;
914
915    let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
916        let re = regex_lite::Regex::new(&pattern)
917            .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
918        Ok(re.is_match(&text))
919    })?;
920    regex_table.set("match", match_fn)?;
921
922    let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
923        let re = regex_lite::Regex::new(&pattern)
924            .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
925        match re.captures(&text) {
926            Some(caps) => {
927                let result = lua.create_table()?;
928                let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
929                result.set("match", full_match.to_string())?;
930                let groups = lua.create_table()?;
931                for i in 1..caps.len() {
932                    if let Some(m) = caps.get(i) {
933                        groups.set(i, m.as_str().to_string())?;
934                    }
935                }
936                result.set("groups", groups)?;
937                Ok(Value::Table(result))
938            }
939            None => Ok(Value::Nil),
940        }
941    })?;
942    regex_table.set("find", find_fn)?;
943
944    let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
945        let re = regex_lite::Regex::new(&pattern)
946            .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
947        let results = lua.create_table()?;
948        for (i, m) in re.find_iter(&text).enumerate() {
949            results.set(i + 1, m.as_str().to_string())?;
950        }
951        Ok(results)
952    })?;
953    regex_table.set("find_all", find_all_fn)?;
954
955    let replace_fn = lua.create_function(
956        |_, (text, pattern, replacement): (String, String, String)| {
957            let re = regex_lite::Regex::new(&pattern).map_err(|e| {
958                mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
959            })?;
960            Ok(re.replace_all(&text, replacement.as_str()).into_owned())
961        },
962    )?;
963    regex_table.set("replace", replace_fn)?;
964
965    lua.globals().set("regex", regex_table)?;
966    Ok(())
967}
968
969fn register_async(lua: &Lua) -> mlua::Result<()> {
970    let async_table = lua.create_table()?;
971
972    let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
973        let thread = lua.create_thread(func)?;
974        let async_thread = thread.into_async::<mlua::MultiValue>(())?;
975        let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
976            tokio::task::spawn_local(async move {
977                let values = async_thread.await.map_err(|e| e.to_string())?;
978                Ok(values.into_vec())
979            });
980
981        let handle = lua.create_table()?;
982        let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
983        let cell_clone = cell.clone();
984
985        let await_fn = lua.create_async_function(move |lua, ()| {
986            let cell = cell_clone.clone();
987            async move {
988                let join_handle = cell
989                    .borrow_mut()
990                    .take()
991                    .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
992                let result = join_handle.await.map_err(|e| {
993                    mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
994                })?;
995                match result {
996                    Ok(values) => {
997                        let tbl = lua.create_table()?;
998                        for (i, v) in values.into_iter().enumerate() {
999                            tbl.set(i + 1, v)?;
1000                        }
1001                        Ok(Value::Table(tbl))
1002                    }
1003                    Err(msg) => Err(mlua::Error::runtime(msg)),
1004                }
1005            }
1006        })?;
1007        handle.set("await", await_fn)?;
1008
1009        Ok(handle)
1010    })?;
1011    async_table.set("spawn", spawn_fn)?;
1012
1013    let spawn_interval_fn =
1014        lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
1015            if seconds <= 0.0 {
1016                return Err(mlua::Error::runtime(
1017                    "async.spawn_interval: interval must be positive",
1018                ));
1019            }
1020
1021            let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1022            let cancel_clone = cancel.clone();
1023
1024            tokio::task::spawn_local({
1025                let cancel = cancel_clone.clone();
1026                async move {
1027                    let mut interval =
1028                        tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1029                    interval.tick().await;
1030                    loop {
1031                        interval.tick().await;
1032                        if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1033                            break;
1034                        }
1035                        if let Err(e) = func.call_async::<()>(()).await {
1036                            error!("async.spawn_interval: callback error: {e}");
1037                            break;
1038                        }
1039                    }
1040                }
1041            });
1042
1043            let handle = lua.create_table()?;
1044            let cancel_fn = lua.create_function(move |_, ()| {
1045                cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1046                Ok(())
1047            })?;
1048            handle.set("cancel", cancel_fn)?;
1049
1050            Ok(handle)
1051        })?;
1052    async_table.set("spawn_interval", spawn_interval_fn)?;
1053
1054    lua.globals().set("async", async_table)?;
1055    Ok(())
1056}
1057
/// Lua userdata wrapping a shared sqlx `AnyPool`.
///
/// `db.connect` creates it; scripts pass it back as the first argument of `db.query`,
/// `db.execute` and `db.close`. The `Arc` lets those functions clone the pool out of the
/// userdata borrow.
struct DbPool(Arc<AnyPool>);
// Empty impl: the handle exposes no Lua methods or fields of its own.
impl UserData for DbPool {}
1060
1061fn register_db(lua: &Lua) -> mlua::Result<()> {
1062    sqlx::any::install_default_drivers();
1063
1064    let db_table = lua.create_table()?;
1065
1066    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1067        let pool = sqlx::any::AnyPoolOptions::new()
1068            .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1069            .connect(&url)
1070            .await
1071            .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1072        lua.create_any_userdata(DbPool(Arc::new(pool)))
1073    })?;
1074    db_table.set("connect", connect_fn)?;
1075
1076    let query_fn =
1077        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1078            let mut args_iter = args.into_iter();
1079
1080            let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1081            let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1082            let params = extract_params(&args_iter.next())?;
1083
1084            let mut query = sqlx::query(&sql);
1085            for p in &params {
1086                query = bind_param(query, p);
1087            }
1088
1089            let rows: Vec<AnyRow> = query
1090                .fetch_all(&*pool)
1091                .await
1092                .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1093
1094            let result = lua.create_table()?;
1095            for (i, row) in rows.iter().enumerate() {
1096                let row_table = any_row_to_lua_table(&lua, row)?;
1097                result.set(i + 1, row_table)?;
1098            }
1099            Ok(Value::Table(result))
1100        })?;
1101    db_table.set("query", query_fn)?;
1102
1103    let execute_fn =
1104        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1105            let mut args_iter = args.into_iter();
1106
1107            let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1108            let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1109            let params = extract_params(&args_iter.next())?;
1110
1111            let mut query = sqlx::query(&sql);
1112            for p in &params {
1113                query = bind_param(query, p);
1114            }
1115
1116            let result = query
1117                .execute(&*pool)
1118                .await
1119                .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1120
1121            let tbl = lua.create_table()?;
1122            tbl.set("rows_affected", result.rows_affected() as i64)?;
1123            Ok(Value::Table(tbl))
1124        })?;
1125    db_table.set("execute", execute_fn)?;
1126
1127    let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1128        let mut args_iter = args.into_iter();
1129        let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1130        pool.close().await;
1131        Ok(())
1132    })?;
1133    db_table.set("close", close_fn)?;
1134
1135    lua.globals().set("db", db_table)?;
1136    Ok(())
1137}
1138
1139fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1140    match val {
1141        Some(Value::UserData(ud)) => {
1142            let db = ud
1143                .borrow::<DbPool>()
1144                .map_err(|_| mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection")))?;
1145            Ok(db.0.clone())
1146        }
1147        _ => Err(mlua::Error::runtime(format!(
1148            "{fn_name}: first argument must be a db connection"
1149        ))),
1150    }
1151}
1152
1153fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1154    match val {
1155        Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1156        _ => Err(mlua::Error::runtime(format!(
1157            "{fn_name}: second argument must be a SQL string"
1158        ))),
1159    }
1160}
1161
/// A positional SQL parameter converted from a Lua value by `extract_params`.
///
/// Each variant mirrors the Lua type it came from. `Debug` and `PartialEq` are derived so
/// parameter lists can be logged and compared in diagnostics and tests.
#[derive(Clone, Debug, PartialEq)]
enum DbParam {
    /// Lua `nil`, bound as SQL NULL.
    Null,
    /// Lua boolean.
    Bool(bool),
    /// Lua integer.
    Int(i64),
    /// Lua float.
    Float(f64),
    /// Lua string (must be valid UTF-8).
    Text(String),
}
1170
1171fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1172    match val {
1173        Some(Value::Table(t)) => {
1174            let mut params = Vec::new();
1175            let len = t.len()?;
1176            for i in 1..=len {
1177                let v: Value = t.get(i)?;
1178                let param = match v {
1179                    Value::Nil => DbParam::Null,
1180                    Value::Boolean(b) => DbParam::Bool(b),
1181                    Value::Integer(n) => DbParam::Int(n),
1182                    Value::Number(f) => DbParam::Float(f),
1183                    Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1184                    _ => {
1185                        return Err(mlua::Error::runtime(format!(
1186                            "db: unsupported parameter type: {}",
1187                            v.type_name()
1188                        )));
1189                    }
1190                };
1191                params.push(param);
1192            }
1193            Ok(params)
1194        }
1195        Some(Value::Nil) | None => Ok(Vec::new()),
1196        _ => Err(mlua::Error::runtime(
1197            "db: params must be a table (array) or nil",
1198        )),
1199    }
1200}
1201
1202fn bind_param<'q>(
1203    query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1204    param: &'q DbParam,
1205) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1206    match param {
1207        DbParam::Null => query.bind(None::<String>),
1208        DbParam::Bool(b) => query.bind(*b),
1209        DbParam::Int(n) => query.bind(*n),
1210        DbParam::Float(f) => query.bind(*f),
1211        DbParam::Text(s) => query.bind(s.as_str()),
1212    }
1213}
1214
1215fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1216    let table = lua.create_table()?;
1217    for col in row.columns() {
1218        let name = col.name();
1219        let val: Value = any_column_to_lua_value(lua, row, col)?;
1220        table.set(name.to_string(), val)?;
1221    }
1222    Ok(table)
1223}
1224
1225fn any_column_to_lua_value<C: Column>(
1226    lua: &Lua,
1227    row: &AnyRow,
1228    col: &C,
1229) -> mlua::Result<Value> {
1230    let ordinal = col.ordinal();
1231    let type_info = col.type_info();
1232    let type_name = type_info.to_string();
1233    let type_name = type_name.to_uppercase();
1234
1235    if row.try_get_raw(ordinal).map(|v| v.is_null()).unwrap_or(true) {
1236        return Ok(Value::Nil);
1237    }
1238
1239    match type_name.as_str() {
1240        "BOOLEAN" | "BOOL" => {
1241            let v: bool = row
1242                .try_get(ordinal)
1243                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1244            Ok(Value::Boolean(v))
1245        }
1246        "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1247            let v: i64 = row
1248                .try_get(ordinal)
1249                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1250            Ok(Value::Integer(v))
1251        }
1252        "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1253            let v: f64 = row
1254                .try_get(ordinal)
1255                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1256            Ok(Value::Number(v))
1257        }
1258        _ => {
1259            let v: String = row
1260                .try_get(ordinal)
1261                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1262            Ok(Value::String(lua.create_string(&v)?))
1263        }
1264    }
1265}
1266
1267fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1268    match val {
1269        Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1270        Some(other) => {
1271            let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1272            result
1273        }
1274        None => None,
1275    }
1276}
1277
1278fn lua_values_equal(a: &Value, b: &Value) -> bool {
1279    match (a, b) {
1280        (Value::Nil, Value::Nil) => true,
1281        (Value::Boolean(a), Value::Boolean(b)) => a == b,
1282        (Value::Integer(a), Value::Integer(b)) => a == b,
1283        (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1284        (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1285            (*a as f64 - b).abs() < f64::EPSILON
1286        }
1287        (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1288        _ => false,
1289    }
1290}
1291
1292fn lua_value_to_f64(val: Value) -> Option<f64> {
1293    match val {
1294        Value::Integer(i) => Some(i as f64),
1295        Value::Number(f) => Some(f),
1296        _ => None,
1297    }
1298}
1299
1300fn format_lua_value(val: &Value) -> String {
1301    match val {
1302        Value::Nil => "nil".to_string(),
1303        Value::Boolean(b) => b.to_string(),
1304        Value::Integer(i) => i.to_string(),
1305        Value::Number(f) => f.to_string(),
1306        Value::String(s) => match s.to_str() {
1307            Ok(v) => v.to_string(),
1308            Err(_) => "<invalid utf-8>".to_string(),
1309        },
1310        _ => format!("<{}>", val.type_name()),
1311    }
1312}
1313
1314type WsSink = Rc<
1315    tokio::sync::Mutex<
1316        futures_util::stream::SplitSink<
1317            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1318            tokio_tungstenite::tungstenite::Message,
1319        >,
1320    >,
1321>;
1322type WsStream = Rc<
1323    tokio::sync::Mutex<
1324        futures_util::stream::SplitStream<
1325            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1326        >,
1327    >,
1328>;
1329
/// Lua userdata for an open WebSocket connection created by `ws.connect`.
///
/// The socket is split into halves, each behind its own mutex. `ws.recv` locks only
/// `stream` and `ws.send`/`ws.close` lock only `sink`, so a pending receive does not hold
/// the lock a send needs.
struct WsConn {
    /// Write half (used by `ws.send` and `ws.close`).
    sink: WsSink,
    /// Read half (used by `ws.recv`).
    stream: WsStream,
}
// Empty impl: scripts can only pass the handle back to the `ws.*` functions.
impl UserData for WsConn {}
1335
1336fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1337    let ud = match val {
1338        Value::UserData(ud) => ud,
1339        _ => {
1340            return Err(mlua::Error::runtime(format!(
1341                "{fn_name}: first argument must be a ws connection"
1342            )));
1343        }
1344    };
1345    let ws = ud.borrow::<WsConn>().map_err(|_| {
1346        mlua::Error::runtime(format!(
1347            "{fn_name}: first argument must be a ws connection"
1348        ))
1349    })?;
1350    Ok((ws.sink.clone(), ws.stream.clone()))
1351}
1352
1353fn register_ws(lua: &Lua) -> mlua::Result<()> {
1354    let ws_table = lua.create_table()?;
1355
1356    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1357        let (stream, _response) = tokio_tungstenite::connect_async(&url)
1358            .await
1359            .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1360        let (sink, read) = stream.split();
1361        lua.create_any_userdata(WsConn {
1362            sink: Rc::new(tokio::sync::Mutex::new(sink)),
1363            stream: Rc::new(tokio::sync::Mutex::new(read)),
1364        })
1365    })?;
1366    ws_table.set("connect", connect_fn)?;
1367
1368    let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1369        let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1370        sink.lock()
1371            .await
1372            .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1373            .await
1374            .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1375        Ok(())
1376    })?;
1377    ws_table.set("send", send_fn)?;
1378
1379    let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1380        let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1381        loop {
1382            let msg = stream
1383                .lock()
1384                .await
1385                .next()
1386                .await
1387                .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1388                .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1389            match msg {
1390                tokio_tungstenite::tungstenite::Message::Text(t) => {
1391                    return Ok(t.to_string());
1392                }
1393                tokio_tungstenite::tungstenite::Message::Binary(b) => {
1394                    return String::from_utf8(b.into()).map_err(|e| {
1395                        mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}"))
1396                    });
1397                }
1398                tokio_tungstenite::tungstenite::Message::Close(_) => {
1399                    return Err(mlua::Error::runtime("ws.recv: connection closed"));
1400                }
1401                _ => continue,
1402            }
1403        }
1404    })?;
1405    ws_table.set("recv", recv_fn)?;
1406
1407    let close_fn = lua.create_async_function(|_, conn: Value| async move {
1408        let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1409        sink.lock()
1410            .await
1411            .close()
1412            .await
1413            .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1414        Ok(())
1415    })?;
1416    ws_table.set("close", close_fn)?;
1417
1418    lua.globals().set("ws", ws_table)?;
1419    Ok(())
1420}
1421
1422fn register_template(lua: &Lua) -> mlua::Result<()> {
1423    let tmpl_table = lua.create_table()?;
1424
1425    let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1426        let json_vars = match &vars {
1427            Value::Table(_) => lua_value_to_json(&vars)?,
1428            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1429            _ => {
1430                return Err(mlua::Error::runtime(
1431                    "template.render_string: second argument must be a table or nil",
1432                ));
1433            }
1434        };
1435        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1436        let env = minijinja::Environment::new();
1437        let tmpl = env
1438            .template_from_str(&template_str)
1439            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1440        tmpl.render(mini_vars)
1441            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1442    })?;
1443    tmpl_table.set("render_string", render_string_fn)?;
1444
1445    let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1446        let content = std::fs::read_to_string(&file_path).map_err(|e| {
1447            mlua::Error::runtime(format!("template.render: failed to read {file_path:?}: {e}"))
1448        })?;
1449        let json_vars = match &vars {
1450            Value::Table(_) => lua_value_to_json(&vars)?,
1451            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1452            _ => {
1453                return Err(mlua::Error::runtime(
1454                    "template.render: second argument must be a table or nil",
1455                ));
1456            }
1457        };
1458        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1459        let env = minijinja::Environment::new();
1460        let tmpl = env
1461            .template_from_str(&content)
1462            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1463        tmpl.render(mini_vars)
1464            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1465    })?;
1466    tmpl_table.set("render", render_fn)?;
1467
1468    lua.globals().set("template", tmpl_table)?;
1469    Ok(())
1470}
1471
/// Quotes `s` as a double-quoted Lua string literal that reads back as exactly `s`.
///
/// Backslash, double quote, LF, CR and NUL are escaped; all other characters are copied
/// through unchanged. The input is scanned once, with no intermediate allocations.
fn lua_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            // Lua's `\ddd` escape greedily reads up to three decimal digits. A short `\0`
            // followed by a digit (e.g. NUL then '1' -> `\01`) would decode as a different
            // byte, so always emit the full three-digit form.
            '\0' => out.push_str("\\000"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
1481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base64_roundtrip() {
        let input = "hello world";
        let encoded = BASE64.encode(input.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let decoded = BASE64.decode(encoded.as_bytes()).unwrap();
        assert_eq!(String::from_utf8(decoded).unwrap(), input);
    }

    #[test]
    fn test_base64_empty() {
        let encoded = BASE64.encode(b"");
        assert_eq!(encoded, "");
        let decoded = BASE64.decode(b"").unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_lua_value_to_json_nil() {
        let result = lua_value_to_json(&Value::Nil).unwrap();
        assert_eq!(result, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        assert_eq!(
            lua_value_to_json(&Value::Boolean(true)).unwrap(),
            serde_json::Value::Bool(true)
        );
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        assert_eq!(
            lua_value_to_json(&Value::Integer(42)).unwrap(),
            serde_json::json!(42)
        );
    }

    #[test]
    fn test_lua_value_to_json_number() {
        assert_eq!(
            lua_value_to_json(&Value::Number(1.5)).unwrap(),
            serde_json::json!(1.5)
        );
    }

    #[test]
    fn test_lua_values_equal_nil() {
        assert!(lua_values_equal(&Value::Nil, &Value::Nil));
    }

    #[test]
    fn test_lua_values_equal_int_float() {
        assert!(lua_values_equal(&Value::Integer(42), &Value::Number(42.0)));
    }

    #[test]
    fn test_lua_values_not_equal() {
        assert!(!lua_values_equal(&Value::Integer(1), &Value::Integer(2)));
    }

    #[test]
    fn test_lua_values_equal_type_mismatch() {
        // Values of different Lua types never compare equal, even when "truthy"-alike.
        assert!(!lua_values_equal(&Value::Boolean(true), &Value::Integer(1)));
        assert!(!lua_values_equal(&Value::Nil, &Value::Boolean(false)));
        assert!(!lua_values_equal(&Value::Number(0.5), &Value::Integer(0)));
    }

    #[test]
    fn test_lua_value_to_f64() {
        assert_eq!(lua_value_to_f64(Value::Integer(3)), Some(3.0));
        assert_eq!(lua_value_to_f64(Value::Number(2.5)), Some(2.5));
        assert_eq!(lua_value_to_f64(Value::Nil), None);
        assert_eq!(lua_value_to_f64(Value::Boolean(true)), None);
    }

    #[test]
    fn test_format_lua_value() {
        assert_eq!(format_lua_value(&Value::Nil), "nil");
        assert_eq!(format_lua_value(&Value::Boolean(true)), "true");
        assert_eq!(format_lua_value(&Value::Integer(42)), "42");
    }

    #[test]
    fn test_format_lua_value_numbers() {
        assert_eq!(format_lua_value(&Value::Number(1.5)), "1.5");
        assert_eq!(format_lua_value(&Value::Integer(-7)), "-7");
        assert_eq!(format_lua_value(&Value::Boolean(false)), "false");
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        assert_eq!(lua_string_literal("hello"), "\"hello\"");
        assert_eq!(lua_string_literal("line\nnew"), "\"line\\nnew\"");
        assert_eq!(lua_string_literal("quote\"here"), "\"quote\\\"here\"");
    }

    #[test]
    fn test_lua_string_literal_backslash_cr_and_empty() {
        assert_eq!(lua_string_literal("a\\b"), "\"a\\\\b\"");
        assert_eq!(lua_string_literal("x\ry"), "\"x\\ry\"");
        assert_eq!(lua_string_literal(""), "\"\"");
    }
}