Skip to main content

assay/lua/
builtins.rs

1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
23pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
24    register_http(lua, client)?;
25    register_json(lua)?;
26    register_yaml(lua)?;
27    register_toml(lua)?;
28    register_assert(lua)?;
29    register_log(lua)?;
30    register_env(lua)?;
31    register_sleep(lua)?;
32    register_time(lua)?;
33    register_fs(lua)?;
34    register_base64(lua)?;
35    register_crypto(lua)?;
36    register_regex(lua)?;
37    register_async(lua)?;
38    register_db(lua)?;
39    register_ws(lua)?;
40    register_template(lua)?;
41    Ok(())
42}
43
44fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
45    let http_table = lua.create_table()?;
46
47    for method in ["get", "post", "put", "patch", "delete"] {
48        let method_client = client.clone();
49        let method_name = method.to_string();
50        let has_body = method != "get" && method != "delete";
51
52        let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
53            let client = method_client.clone();
54            let method_name = method_name.clone();
55            async move {
56                let mut args_iter = args.into_iter();
57                let url: String = match args_iter.next() {
58                    Some(Value::String(s)) => s.to_str()?.to_string(),
59                    _ => {
60                        return Err(mlua::Error::runtime(format!(
61                            "http.{method_name}: first argument must be a URL string"
62                        )));
63                    }
64                };
65
66                let (body_str, auto_json, opts) = if has_body {
67                    let (body, is_json) = match args_iter.next() {
68                        Some(Value::String(s)) => (s.to_str()?.to_string(), false),
69                        Some(Value::Table(t)) => {
70                            let json_val = lua_table_to_json(&t)?;
71                            let serialized = serde_json::to_string(&json_val).map_err(|e| {
72                                mlua::Error::runtime(format!(
73                                    "http.{method_name}: JSON encode failed: {e}"
74                                ))
75                            })?;
76                            (serialized, true)
77                        }
78                        Some(Value::Nil) | None => (String::new(), false),
79                        _ => {
80                            return Err(mlua::Error::runtime(format!(
81                                "http.{method_name}: second argument must be a string, table, or nil"
82                            )));
83                        }
84                    };
85                    let opts = match args_iter.next() {
86                        Some(Value::Table(t)) => Some(t),
87                        Some(Value::Nil) | None => None,
88                        _ => {
89                            return Err(mlua::Error::runtime(format!(
90                                "http.{method_name}: third argument must be a table or nil"
91                            )));
92                        }
93                    };
94                    (body, is_json, opts)
95                } else {
96                    let opts = match args_iter.next() {
97                        Some(Value::Table(t)) => Some(t),
98                        Some(Value::Nil) | None => None,
99                        _ => {
100                            return Err(mlua::Error::runtime(format!(
101                                "http.{method_name}: second argument must be a table or nil"
102                            )));
103                        }
104                    };
105                    (String::new(), false, opts)
106                };
107
108                let mut req = match method_name.as_str() {
109                    "get" => client.get(&url),
110                    "post" => client.post(&url),
111                    "put" => client.put(&url),
112                    "patch" => client.patch(&url),
113                    "delete" => client.delete(&url),
114                    _ => unreachable!(),
115                };
116
117                if has_body && !body_str.is_empty() {
118                    req = req.body(body_str);
119                }
120                if auto_json {
121                    req = req.header("Content-Type", "application/json");
122                }
123                if let Some(ref opts_table) = opts
124                    && let Ok(headers_table) = opts_table.get::<Table>("headers")
125                {
126                    for pair in headers_table.pairs::<String, String>() {
127                        let (k, v) = pair?;
128                        req = req.header(k, v);
129                    }
130                }
131
132                let resp = req.send().await.map_err(|e| {
133                    mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
134                })?;
135                let status = resp.status().as_u16();
136                let resp_headers = resp.headers().clone();
137                let body = resp.text().await.map_err(|e| {
138                    mlua::Error::runtime(format!(
139                        "http.{method_name}: reading body failed: {e}"
140                    ))
141                })?;
142
143                let result = lua.create_table()?;
144                result.set("status", status)?;
145                result.set("body", body)?;
146
147                let headers_out = lua.create_table()?;
148                for (name, value) in &resp_headers {
149                    if let Ok(v) = value.to_str() {
150                        headers_out.set(name.as_str().to_string(), v.to_string())?;
151                    }
152                }
153                result.set("headers", headers_out)?;
154
155                Ok(Value::Table(result))
156            }
157        })?;
158        http_table.set(method, func)?;
159    }
160
161    let serve_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
162        let mut args_iter = args.into_iter();
163
164        let port: u16 = match args_iter.next() {
165            Some(Value::Integer(n)) => n as u16,
166            _ => {
167                return Err::<(), _>(mlua::Error::runtime(
168                    "http.serve: first argument must be a port number",
169                ));
170            }
171        };
172
173        let routes_table = match args_iter.next() {
174            Some(Value::Table(t)) => t,
175            _ => {
176                return Err::<(), _>(mlua::Error::runtime(
177                    "http.serve: second argument must be a routes table",
178                ));
179            }
180        };
181
182        let routes = Rc::new(parse_routes(&routes_table)?);
183
184        let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
185            .await
186            .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
187
188        loop {
189            let (stream, _addr) = listener
190                .accept()
191                .await
192                .map_err(|e| mlua::Error::runtime(format!("http.serve: accept failed: {e}")))?;
193
194            let routes = routes.clone();
195            let lua_clone = lua.clone();
196
197            tokio::task::spawn_local(async move {
198                let io = hyper_util::rt::TokioIo::new(stream);
199                let routes = routes.clone();
200                let lua = lua_clone.clone();
201
202                let service = service_fn(move |req: Request<Incoming>| {
203                    let routes = routes.clone();
204                    let lua = lua.clone();
205                    async move { handle_request(&lua, &routes, req).await }
206                });
207
208                if let Err(e) = http1::Builder::new().serve_connection(io, service).await
209                    && !e.to_string().contains("connection closed")
210                {
211                    error!("http.serve: connection error: {e}");
212                }
213            });
214        }
215    })?;
216    http_table.set("serve", serve_fn)?;
217
218    lua.globals().set("http", http_table)?;
219    Ok(())
220}
221
222fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
223    let mut routes = HashMap::new();
224    for method_pair in routes_table.pairs::<String, Table>() {
225        let (method, paths_table) = method_pair?;
226        let method_upper = method.to_uppercase();
227        for path_pair in paths_table.pairs::<String, mlua::Function>() {
228            let (path, func) = path_pair?;
229            routes.insert((method_upper.clone(), path), func);
230        }
231    }
232    Ok(routes)
233}
234
235async fn handle_request(
236    lua: &Lua,
237    routes: &HashMap<(String, String), mlua::Function>,
238    req: Request<Incoming>,
239) -> Result<Response<Full<Bytes>>, hyper::Error> {
240    let method = req.method().to_string();
241    let path = req.uri().path().to_string();
242    let query = req.uri().query().unwrap_or("").to_string();
243    let headers: Vec<(String, String)> = req
244        .headers()
245        .iter()
246        .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
247        .collect();
248
249    let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
250        Ok(collected) => collected.to_bytes(),
251        Err(_) => Bytes::new(),
252    };
253    let body_str = String::from_utf8_lossy(&body_bytes).to_string();
254
255    let key = (method.clone(), path.clone());
256    let handler = match routes.get(&key) {
257        Some(f) => f,
258        None => {
259            return Ok(Response::builder()
260                .status(StatusCode::NOT_FOUND)
261                .header("content-type", "text/plain")
262                .body(Full::new(Bytes::from("not found")))
263                .unwrap());
264        }
265    };
266
267    match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
268        Ok(lua_resp) => lua_response_to_http(&lua_resp),
269        Err(e) => Ok(Response::builder()
270            .status(StatusCode::INTERNAL_SERVER_ERROR)
271            .header("content-type", "text/plain")
272            .body(Full::new(Bytes::from(format!("handler error: {e}"))))
273            .unwrap()),
274    }
275}
276
277fn build_lua_request_and_call(
278    lua: &Lua,
279    handler: &mlua::Function,
280    method: &str,
281    path: &str,
282    query: &str,
283    headers: &[(String, String)],
284    body: &str,
285) -> mlua::Result<Table> {
286    let req_table = lua.create_table()?;
287    req_table.set("method", method.to_string())?;
288    req_table.set("path", path.to_string())?;
289    req_table.set("query", query.to_string())?;
290    req_table.set("body", body.to_string())?;
291
292    let headers_table = lua.create_table()?;
293    for (k, v) in headers {
294        headers_table.set(k.as_str(), v.as_str())?;
295    }
296    req_table.set("headers", headers_table)?;
297
298    handler.call::<Table>(req_table)
299}
300
301fn lua_response_to_http(resp_table: &Table) -> Result<Response<Full<Bytes>>, hyper::Error> {
302    let status = resp_table
303        .get::<Option<u16>>("status")
304        .unwrap_or(None)
305        .unwrap_or(200);
306
307    let mut builder =
308        Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
309
310    if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
311        for (k, v) in headers_table.pairs::<String, String>().flatten() {
312            builder = builder.header(k, v);
313        }
314    }
315
316    let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
317        let json_val =
318            lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
319        let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
320        builder = builder.header("content-type", "application/json");
321        Bytes::from(serialized)
322    } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
323        builder = builder.header("content-type", "text/plain");
324        Bytes::from(body_str)
325    } else {
326        builder = builder.header("content-type", "text/plain");
327        Bytes::new()
328    };
329
330    Ok(builder.body(Full::new(body_bytes)).unwrap())
331}
332
333fn register_json(lua: &Lua) -> mlua::Result<()> {
334    let json_table = lua.create_table()?;
335
336    let parse_fn = lua.create_function(|lua, s: String| {
337        let value: serde_json::Value = serde_json::from_str(&s)
338            .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
339        json_value_to_lua(lua, &value)
340    })?;
341    json_table.set("parse", parse_fn)?;
342
343    let encode_fn = lua.create_function(|_, val: Value| {
344        let json_val = lua_value_to_json(&val)?;
345        serde_json::to_string(&json_val)
346            .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
347    })?;
348    json_table.set("encode", encode_fn)?;
349
350    lua.globals().set("json", json_table)?;
351    Ok(())
352}
353
354fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
355    let mut is_array = true;
356    let mut max_index: i64 = 0;
357    let mut count: i64 = 0;
358
359    for pair in table.clone().pairs::<Value, Value>() {
360        let (key, _) = pair?;
361        count += 1;
362        match key {
363            Value::Integer(i) if i >= 1 => {
364                if i > max_index {
365                    max_index = i;
366                }
367            }
368            _ => {
369                is_array = false;
370                break;
371            }
372        }
373    }
374
375    if is_array && max_index == count {
376        let mut arr = Vec::with_capacity(max_index as usize);
377        for i in 1..=max_index {
378            let val: Value = table.get(i)?;
379            arr.push(lua_value_to_json(&val)?);
380        }
381        Ok(serde_json::Value::Array(arr))
382    } else {
383        let mut map = serde_json::Map::new();
384        for pair in table.clone().pairs::<Value, Value>() {
385            let (key, val) = pair?;
386            let key_str = match key {
387                Value::String(s) => s.to_str()?.to_string(),
388                Value::Integer(i) => i.to_string(),
389                Value::Number(f) => f.to_string(),
390                _ => {
391                    return Err(mlua::Error::runtime(format!(
392                        "unsupported table key type: {}",
393                        key.type_name()
394                    )));
395                }
396            };
397            map.insert(key_str, lua_value_to_json(&val)?);
398        }
399        Ok(serde_json::Value::Object(map))
400    }
401}
402
403fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
404    match val {
405        Value::Nil => Ok(serde_json::Value::Null),
406        Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
407        Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
408        Value::Number(f) => serde_json::Number::from_f64(*f)
409            .map(serde_json::Value::Number)
410            .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
411        Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
412        Value::Table(t) => lua_table_to_json(t),
413        _ => Err(mlua::Error::runtime(format!(
414            "unsupported Lua type for JSON: {}",
415            val.type_name()
416        ))),
417    }
418}
419
420pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
421    match val {
422        serde_json::Value::Null => Ok(Value::Nil),
423        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
424        serde_json::Value::Number(n) => {
425            if let Some(i) = n.as_i64() {
426                Ok(Value::Integer(i))
427            } else if let Some(f) = n.as_f64() {
428                Ok(Value::Number(f))
429            } else {
430                Ok(Value::Nil)
431            }
432        }
433        serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
434        serde_json::Value::Array(arr) => {
435            let table = lua.create_table()?;
436            for (i, item) in arr.iter().enumerate() {
437                table.set(i + 1, json_value_to_lua(lua, item)?)?;
438            }
439            Ok(Value::Table(table))
440        }
441        serde_json::Value::Object(map) => {
442            let table = lua.create_table()?;
443            for (k, v) in map {
444                table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
445            }
446            Ok(Value::Table(table))
447        }
448    }
449}
450
451fn register_assert(lua: &Lua) -> mlua::Result<()> {
452    let assert_table = lua.create_table()?;
453
454    let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
455        let mut args_iter = args.into_iter();
456        let a = args_iter.next().unwrap_or(Value::Nil);
457        let b = args_iter.next().unwrap_or(Value::Nil);
458        let msg = extract_string_arg(lua, args_iter.next());
459
460        if !lua_values_equal(&a, &b) {
461            let detail = format!(
462                "assert.eq failed: {:?} != {:?}{}",
463                format_lua_value(&a),
464                format_lua_value(&b),
465                msg.map(|m| format!(" - {m}")).unwrap_or_default()
466            );
467            return Err(mlua::Error::runtime(detail));
468        }
469        Ok(())
470    })?;
471    assert_table.set("eq", eq_fn)?;
472
473    let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
474        let mut args_iter = args.into_iter();
475        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
476        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
477        let msg = extract_string_arg(lua, args_iter.next());
478
479        match (a, b) {
480            (Some(va), Some(vb)) if va > vb => Ok(()),
481            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
482                "assert.gt failed: {va} is not > {vb}{}",
483                msg.map(|m| format!(" - {m}")).unwrap_or_default()
484            ))),
485            _ => Err(mlua::Error::runtime(
486                "assert.gt: both arguments must be numbers",
487            )),
488        }
489    })?;
490    assert_table.set("gt", gt_fn)?;
491
492    let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
493        let mut args_iter = args.into_iter();
494        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
495        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
496        let msg = extract_string_arg(lua, args_iter.next());
497
498        match (a, b) {
499            (Some(va), Some(vb)) if va < vb => Ok(()),
500            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
501                "assert.lt failed: {va} is not < {vb}{}",
502                msg.map(|m| format!(" - {m}")).unwrap_or_default()
503            ))),
504            _ => Err(mlua::Error::runtime(
505                "assert.lt: both arguments must be numbers",
506            )),
507        }
508    })?;
509    assert_table.set("lt", lt_fn)?;
510
511    let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
512        let mut args_iter = args.into_iter();
513        let haystack: String = match args_iter.next() {
514            Some(Value::String(s)) => s.to_str()?.to_string(),
515            _ => {
516                return Err(mlua::Error::runtime(
517                    "assert.contains: first argument must be a string",
518                ));
519            }
520        };
521        let needle: String = match args_iter.next() {
522            Some(Value::String(s)) => s.to_str()?.to_string(),
523            _ => {
524                return Err(mlua::Error::runtime(
525                    "assert.contains: second argument must be a string",
526                ));
527            }
528        };
529        let msg = extract_string_arg(lua, args_iter.next());
530
531        if !haystack.contains(&needle) {
532            return Err(mlua::Error::runtime(format!(
533                "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
534                msg.map(|m| format!(" - {m}")).unwrap_or_default()
535            )));
536        }
537        Ok(())
538    })?;
539    assert_table.set("contains", contains_fn)?;
540
541    let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
542        let mut args_iter = args.into_iter();
543        let val = args_iter.next().unwrap_or(Value::Nil);
544        let msg = extract_string_arg(lua, args_iter.next());
545
546        if val == Value::Nil {
547            return Err(mlua::Error::runtime(format!(
548                "assert.not_nil failed: value is nil{}",
549                msg.map(|m| format!(" - {m}")).unwrap_or_default()
550            )));
551        }
552        Ok(())
553    })?;
554    assert_table.set("not_nil", not_nil_fn)?;
555
556    let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
557        let mut args_iter = args.into_iter();
558        let text: String = match args_iter.next() {
559            Some(Value::String(s)) => s.to_str()?.to_string(),
560            _ => {
561                return Err(mlua::Error::runtime(
562                    "assert.matches: first argument must be a string",
563                ));
564            }
565        };
566        let pattern: String = match args_iter.next() {
567            Some(Value::String(s)) => s.to_str()?.to_string(),
568            _ => {
569                return Err(mlua::Error::runtime(
570                    "assert.matches: second argument must be a pattern string",
571                ));
572            }
573        };
574        let msg = extract_string_arg(lua, args_iter.next());
575
576        let found: bool = lua
577            .load(format!(
578                "return string.find({}, {}) ~= nil",
579                lua_string_literal(&text),
580                lua_string_literal(&pattern)
581            ))
582            .eval()
583            .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
584
585        if !found {
586            return Err(mlua::Error::runtime(format!(
587                "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
588                msg.map(|m| format!(" - {m}")).unwrap_or_default()
589            )));
590        }
591        Ok(())
592    })?;
593    assert_table.set("matches", matches_fn)?;
594
595    lua.globals().set("assert", assert_table)?;
596    Ok(())
597}
598
599fn register_log(lua: &Lua) -> mlua::Result<()> {
600    let log_table = lua.create_table()?;
601
602    let info_fn = lua.create_function(|_, msg: String| {
603        info!(target: "lua", "{}", msg);
604        Ok(())
605    })?;
606    log_table.set("info", info_fn)?;
607
608    let warn_fn = lua.create_function(|_, msg: String| {
609        warn!(target: "lua", "{}", msg);
610        Ok(())
611    })?;
612    log_table.set("warn", warn_fn)?;
613
614    let error_fn = lua.create_function(|_, msg: String| {
615        error!(target: "lua", "{}", msg);
616        Ok(())
617    })?;
618    log_table.set("error", error_fn)?;
619
620    lua.globals().set("log", log_table)?;
621    Ok(())
622}
623
624fn register_env(lua: &Lua) -> mlua::Result<()> {
625    let env_table = lua.create_table()?;
626
627    let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
628        Ok(val) => Ok(Some(val)),
629        Err(_) => Ok(None),
630    })?;
631    env_table.set("_process_get", process_get_fn)?;
632    env_table.set("_check_env", lua.create_table()?)?;
633
634    lua.globals().set("env", env_table)?;
635
636    lua.load(
637        r#"
638        function env.get(name)
639            local val = env._check_env[name]
640            if val ~= nil then return val end
641            return env._process_get(name)
642        end
643        "#,
644    )
645    .exec()?;
646
647    Ok(())
648}
649
650fn register_sleep(lua: &Lua) -> mlua::Result<()> {
651    let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
652        let duration = std::time::Duration::from_secs_f64(seconds);
653        tokio::time::sleep(duration).await;
654        Ok(())
655    })?;
656    lua.globals().set("sleep", sleep_fn)?;
657    Ok(())
658}
659
660fn register_time(lua: &Lua) -> mlua::Result<()> {
661    let time_fn = lua.create_function(|_, ()| {
662        let secs = SystemTime::now()
663            .duration_since(UNIX_EPOCH)
664            .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
665            .as_secs_f64();
666        Ok(secs)
667    })?;
668    lua.globals().set("time", time_fn)?;
669    Ok(())
670}
671
672fn register_fs(lua: &Lua) -> mlua::Result<()> {
673    let fs_table = lua.create_table()?;
674
675    let read_fn = lua.create_function(|_, path: String| {
676        std::fs::read_to_string(&path)
677            .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
678    })?;
679    fs_table.set("read", read_fn)?;
680
681    let write_fn = lua.create_function(|_, (path, content): (String, String)| {
682        let p = std::path::Path::new(&path);
683        if let Some(parent) = p.parent() {
684            std::fs::create_dir_all(parent).map_err(|e| {
685                mlua::Error::runtime(format!(
686                    "fs.write: failed to create directories for {path:?}: {e}"
687                ))
688            })?;
689        }
690        std::fs::write(&path, &content)
691            .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
692    })?;
693    fs_table.set("write", write_fn)?;
694
695    lua.globals().set("fs", fs_table)?;
696    Ok(())
697}
698
699fn register_base64(lua: &Lua) -> mlua::Result<()> {
700    let b64_table = lua.create_table()?;
701
702    let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
703    b64_table.set("encode", encode_fn)?;
704
705    let decode_fn = lua.create_function(|_, input: String| {
706        let bytes = BASE64
707            .decode(input.as_bytes())
708            .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
709        String::from_utf8(bytes)
710            .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
711    })?;
712    b64_table.set("decode", decode_fn)?;
713
714    lua.globals().set("base64", b64_table)?;
715    Ok(())
716}
717
718fn register_crypto(lua: &Lua) -> mlua::Result<()> {
719    let crypto_table = lua.create_table()?;
720
721    let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
722        let mut args_iter = args.into_iter();
723
724        let claims_table = match args_iter.next() {
725            Some(Value::Table(t)) => t,
726            _ => {
727                return Err(mlua::Error::runtime(
728                    "crypto.jwt_sign: first argument must be a claims table",
729                ));
730            }
731        };
732
733        let pem_key: String = match args_iter.next() {
734            Some(Value::String(s)) => s.to_str()?.to_string(),
735            _ => {
736                return Err(mlua::Error::runtime(
737                    "crypto.jwt_sign: second argument must be a PEM key string",
738                ));
739            }
740        };
741
742        let algorithm = match args_iter.next() {
743            Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
744                "RS256" => Algorithm::RS256,
745                "RS384" => Algorithm::RS384,
746                "RS512" => Algorithm::RS512,
747                other => {
748                    return Err(mlua::Error::runtime(format!(
749                        "crypto.jwt_sign: unsupported algorithm: {other}"
750                    )));
751                }
752            },
753            Some(Value::Nil) | None => Algorithm::RS256,
754            _ => {
755                return Err(mlua::Error::runtime(
756                    "crypto.jwt_sign: third argument must be an algorithm string or nil",
757                ));
758            }
759        };
760
761        // 4th optional argument: options table with header fields (kid, typ, etc.)
762        let opts = match args_iter.next() {
763            Some(Value::Table(t)) => Some(t),
764            Some(Value::Nil) | None => None,
765            _ => {
766                return Err(mlua::Error::runtime(
767                    "crypto.jwt_sign: fourth argument must be an options table or nil",
768                ));
769            }
770        };
771
772        let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
773        let pem_bytes = Zeroizing::new(pem_key.into_bytes());
774        let key = EncodingKey::from_rsa_pem(&pem_bytes)
775            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
776
777        let mut header = Header::new(algorithm);
778        if let Some(ref opts_table) = opts
779            && let Ok(kid) = opts_table.get::<String>("kid")
780        {
781            header.kid = Some(kid);
782        }
783        let token = jsonwebtoken::encode(&header, &claims_json, &key)
784            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
785
786        Ok(token)
787    })?;
788    crypto_table.set("jwt_sign", jwt_sign_fn)?;
789
790    let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
791        let mut args_iter = args.into_iter();
792
793        let input: String = match args_iter.next() {
794            Some(Value::String(s)) => s.to_str()?.to_string(),
795            _ => {
796                return Err(mlua::Error::runtime(
797                    "crypto.hash: first argument must be a string",
798                ));
799            }
800        };
801
802        let algorithm: String = match args_iter.next() {
803            Some(Value::String(s)) => s.to_str()?.to_lowercase(),
804            Some(Value::Nil) | None => "sha256".to_string(),
805            _ => {
806                return Err(mlua::Error::runtime(
807                    "crypto.hash: second argument must be an algorithm string or nil",
808                ));
809            }
810        };
811
812        let hex = match algorithm.as_str() {
813            "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
814            "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
815            "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
816            "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
817            "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
818            "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
819            "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
820            "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
821            other => {
822                return Err(mlua::Error::runtime(format!(
823                    "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
824                )));
825            }
826        };
827
828        Ok(hex)
829    })?;
830    crypto_table.set("hash", hash_fn)?;
831
832    let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
833        let mut args_iter = args.into_iter();
834
835        let length: usize = match args_iter.next() {
836            Some(Value::Integer(n)) if n > 0 => n as usize,
837            Some(Value::Integer(n)) => {
838                return Err(mlua::Error::runtime(format!(
839                    "crypto.random: length must be positive, got {n}"
840                )));
841            }
842            Some(Value::Nil) | None => 32,
843            _ => {
844                return Err(mlua::Error::runtime(
845                    "crypto.random: first argument must be a positive integer or nil",
846                ));
847            }
848        };
849
850        let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
851        let mut rng = rand::rng();
852        let result: String = (0..length)
853            .map(|_| charset[rng.random_range(..charset.len())] as char)
854            .collect();
855
856        Ok(result)
857    })?;
858    crypto_table.set("random", random_fn)?;
859
860    lua.globals().set("crypto", crypto_table)?;
861    Ok(())
862}
863
864fn register_yaml(lua: &Lua) -> mlua::Result<()> {
865    let yaml_table = lua.create_table()?;
866
867    let parse_fn = lua.create_function(|lua, s: String| {
868        let json_val: serde_json::Value = serde_yml::from_str(&s)
869            .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
870        json_value_to_lua(lua, &json_val)
871    })?;
872    yaml_table.set("parse", parse_fn)?;
873
874    let encode_fn = lua.create_function(|_, val: Value| {
875        let json_val = lua_value_to_json(&val)?;
876        serde_yml::to_string(&json_val)
877            .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
878    })?;
879    yaml_table.set("encode", encode_fn)?;
880
881    lua.globals().set("yaml", yaml_table)?;
882    Ok(())
883}
884
885fn register_toml(lua: &Lua) -> mlua::Result<()> {
886    let toml_table = lua.create_table()?;
887
888    let parse_fn = lua.create_function(|lua, s: String| {
889        let toml_val: toml::Value =
890            toml::from_str(&s).map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
891        let json_val = serde_json::to_value(&toml_val)
892            .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
893        json_value_to_lua(lua, &json_val)
894    })?;
895    toml_table.set("parse", parse_fn)?;
896
897    let encode_fn = lua.create_function(|_, val: Value| {
898        let json_val = lua_value_to_json(&val)?;
899        let toml_val: toml::Value = serde_json::from_value(json_val)
900            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
901        toml::to_string_pretty(&toml_val)
902            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
903    })?;
904    toml_table.set("encode", encode_fn)?;
905
906    lua.globals().set("toml", toml_table)?;
907    Ok(())
908}
909
910fn register_regex(lua: &Lua) -> mlua::Result<()> {
911    let regex_table = lua.create_table()?;
912
913    let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
914        let re = regex_lite::Regex::new(&pattern)
915            .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
916        Ok(re.is_match(&text))
917    })?;
918    regex_table.set("match", match_fn)?;
919
920    let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
921        let re = regex_lite::Regex::new(&pattern)
922            .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
923        match re.captures(&text) {
924            Some(caps) => {
925                let result = lua.create_table()?;
926                let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
927                result.set("match", full_match.to_string())?;
928                let groups = lua.create_table()?;
929                for i in 1..caps.len() {
930                    if let Some(m) = caps.get(i) {
931                        groups.set(i, m.as_str().to_string())?;
932                    }
933                }
934                result.set("groups", groups)?;
935                Ok(Value::Table(result))
936            }
937            None => Ok(Value::Nil),
938        }
939    })?;
940    regex_table.set("find", find_fn)?;
941
942    let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
943        let re = regex_lite::Regex::new(&pattern)
944            .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
945        let results = lua.create_table()?;
946        for (i, m) in re.find_iter(&text).enumerate() {
947            results.set(i + 1, m.as_str().to_string())?;
948        }
949        Ok(results)
950    })?;
951    regex_table.set("find_all", find_all_fn)?;
952
953    let replace_fn = lua.create_function(
954        |_, (text, pattern, replacement): (String, String, String)| {
955            let re = regex_lite::Regex::new(&pattern).map_err(|e| {
956                mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
957            })?;
958            Ok(re.replace_all(&text, replacement.as_str()).into_owned())
959        },
960    )?;
961    regex_table.set("replace", replace_fn)?;
962
963    lua.globals().set("regex", regex_table)?;
964    Ok(())
965}
966
967fn register_async(lua: &Lua) -> mlua::Result<()> {
968    let async_table = lua.create_table()?;
969
970    let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
971        let thread = lua.create_thread(func)?;
972        let async_thread = thread.into_async::<mlua::MultiValue>(())?;
973        let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
974            tokio::task::spawn_local(async move {
975                let values = async_thread.await.map_err(|e| e.to_string())?;
976                Ok(values.into_vec())
977            });
978
979        let handle = lua.create_table()?;
980        let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
981        let cell_clone = cell.clone();
982
983        let await_fn = lua.create_async_function(move |lua, ()| {
984            let cell = cell_clone.clone();
985            async move {
986                let join_handle = cell
987                    .borrow_mut()
988                    .take()
989                    .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
990                let result = join_handle.await.map_err(|e| {
991                    mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
992                })?;
993                match result {
994                    Ok(values) => {
995                        let tbl = lua.create_table()?;
996                        for (i, v) in values.into_iter().enumerate() {
997                            tbl.set(i + 1, v)?;
998                        }
999                        Ok(Value::Table(tbl))
1000                    }
1001                    Err(msg) => Err(mlua::Error::runtime(msg)),
1002                }
1003            }
1004        })?;
1005        handle.set("await", await_fn)?;
1006
1007        Ok(handle)
1008    })?;
1009    async_table.set("spawn", spawn_fn)?;
1010
1011    let spawn_interval_fn =
1012        lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
1013            if seconds <= 0.0 {
1014                return Err(mlua::Error::runtime(
1015                    "async.spawn_interval: interval must be positive",
1016                ));
1017            }
1018
1019            let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1020            let cancel_clone = cancel.clone();
1021
1022            tokio::task::spawn_local({
1023                let cancel = cancel_clone.clone();
1024                async move {
1025                    let mut interval =
1026                        tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1027                    interval.tick().await;
1028                    loop {
1029                        interval.tick().await;
1030                        if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1031                            break;
1032                        }
1033                        if let Err(e) = func.call_async::<()>(()).await {
1034                            error!("async.spawn_interval: callback error: {e}");
1035                            break;
1036                        }
1037                    }
1038                }
1039            });
1040
1041            let handle = lua.create_table()?;
1042            let cancel_fn = lua.create_function(move |_, ()| {
1043                cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1044                Ok(())
1045            })?;
1046            handle.set("cancel", cancel_fn)?;
1047
1048            Ok(handle)
1049        })?;
1050    async_table.set("spawn_interval", spawn_interval_fn)?;
1051
1052    lua.globals().set("async", async_table)?;
1053    Ok(())
1054}
1055
1056struct DbPool(Arc<AnyPool>);
1057impl UserData for DbPool {}
1058
1059fn register_db(lua: &Lua) -> mlua::Result<()> {
1060    sqlx::any::install_default_drivers();
1061
1062    let db_table = lua.create_table()?;
1063
1064    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1065        let pool = sqlx::any::AnyPoolOptions::new()
1066            .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1067            .connect(&url)
1068            .await
1069            .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1070        lua.create_any_userdata(DbPool(Arc::new(pool)))
1071    })?;
1072    db_table.set("connect", connect_fn)?;
1073
1074    let query_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1075        let mut args_iter = args.into_iter();
1076
1077        let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1078        let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1079        let params = extract_params(&args_iter.next())?;
1080
1081        let mut query = sqlx::query(&sql);
1082        for p in &params {
1083            query = bind_param(query, p);
1084        }
1085
1086        let rows: Vec<AnyRow> = query
1087            .fetch_all(&*pool)
1088            .await
1089            .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1090
1091        let result = lua.create_table()?;
1092        for (i, row) in rows.iter().enumerate() {
1093            let row_table = any_row_to_lua_table(&lua, row)?;
1094            result.set(i + 1, row_table)?;
1095        }
1096        Ok(Value::Table(result))
1097    })?;
1098    db_table.set("query", query_fn)?;
1099
1100    let execute_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1101        let mut args_iter = args.into_iter();
1102
1103        let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1104        let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1105        let params = extract_params(&args_iter.next())?;
1106
1107        let mut query = sqlx::query(&sql);
1108        for p in &params {
1109            query = bind_param(query, p);
1110        }
1111
1112        let result = query
1113            .execute(&*pool)
1114            .await
1115            .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1116
1117        let tbl = lua.create_table()?;
1118        tbl.set("rows_affected", result.rows_affected() as i64)?;
1119        Ok(Value::Table(tbl))
1120    })?;
1121    db_table.set("execute", execute_fn)?;
1122
1123    let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1124        let mut args_iter = args.into_iter();
1125        let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1126        pool.close().await;
1127        Ok(())
1128    })?;
1129    db_table.set("close", close_fn)?;
1130
1131    lua.globals().set("db", db_table)?;
1132    Ok(())
1133}
1134
1135fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1136    match val {
1137        Some(Value::UserData(ud)) => {
1138            let db = ud.borrow::<DbPool>().map_err(|_| {
1139                mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection"))
1140            })?;
1141            Ok(db.0.clone())
1142        }
1143        _ => Err(mlua::Error::runtime(format!(
1144            "{fn_name}: first argument must be a db connection"
1145        ))),
1146    }
1147}
1148
1149fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1150    match val {
1151        Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1152        _ => Err(mlua::Error::runtime(format!(
1153            "{fn_name}: second argument must be a SQL string"
1154        ))),
1155    }
1156}
1157
/// A positional SQL bind parameter converted from a Lua scalar by `extract_params`.
///
/// `Debug` is derived so parameter lists can be inspected or logged when a
/// statement fails.
#[derive(Clone, Debug)]
enum DbParam {
    /// SQL NULL (from Lua `nil`).
    Null,
    /// Lua boolean.
    Bool(bool),
    /// Lua integer.
    Int(i64),
    /// Lua float.
    Float(f64),
    /// Lua string (must be valid UTF-8).
    Text(String),
}
1166
1167fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1168    match val {
1169        Some(Value::Table(t)) => {
1170            let mut params = Vec::new();
1171            let len = t.len()?;
1172            for i in 1..=len {
1173                let v: Value = t.get(i)?;
1174                let param = match v {
1175                    Value::Nil => DbParam::Null,
1176                    Value::Boolean(b) => DbParam::Bool(b),
1177                    Value::Integer(n) => DbParam::Int(n),
1178                    Value::Number(f) => DbParam::Float(f),
1179                    Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1180                    _ => {
1181                        return Err(mlua::Error::runtime(format!(
1182                            "db: unsupported parameter type: {}",
1183                            v.type_name()
1184                        )));
1185                    }
1186                };
1187                params.push(param);
1188            }
1189            Ok(params)
1190        }
1191        Some(Value::Nil) | None => Ok(Vec::new()),
1192        _ => Err(mlua::Error::runtime(
1193            "db: params must be a table (array) or nil",
1194        )),
1195    }
1196}
1197
1198fn bind_param<'q>(
1199    query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1200    param: &'q DbParam,
1201) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1202    match param {
1203        DbParam::Null => query.bind(None::<String>),
1204        DbParam::Bool(b) => query.bind(*b),
1205        DbParam::Int(n) => query.bind(*n),
1206        DbParam::Float(f) => query.bind(*f),
1207        DbParam::Text(s) => query.bind(s.as_str()),
1208    }
1209}
1210
1211fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1212    let table = lua.create_table()?;
1213    for col in row.columns() {
1214        let name = col.name();
1215        let val: Value = any_column_to_lua_value(lua, row, col)?;
1216        table.set(name.to_string(), val)?;
1217    }
1218    Ok(table)
1219}
1220
1221fn any_column_to_lua_value<C: Column>(lua: &Lua, row: &AnyRow, col: &C) -> mlua::Result<Value> {
1222    let ordinal = col.ordinal();
1223    let type_info = col.type_info();
1224    let type_name = type_info.to_string();
1225    let type_name = type_name.to_uppercase();
1226
1227    if row
1228        .try_get_raw(ordinal)
1229        .map(|v| v.is_null())
1230        .unwrap_or(true)
1231    {
1232        return Ok(Value::Nil);
1233    }
1234
1235    match type_name.as_str() {
1236        "BOOLEAN" | "BOOL" => {
1237            let v: bool = row
1238                .try_get(ordinal)
1239                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1240            Ok(Value::Boolean(v))
1241        }
1242        "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1243            let v: i64 = row
1244                .try_get(ordinal)
1245                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1246            Ok(Value::Integer(v))
1247        }
1248        "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1249            let v: f64 = row
1250                .try_get(ordinal)
1251                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1252            Ok(Value::Number(v))
1253        }
1254        _ => {
1255            let v: String = row
1256                .try_get(ordinal)
1257                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1258            Ok(Value::String(lua.create_string(&v)?))
1259        }
1260    }
1261}
1262
1263fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1264    match val {
1265        Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1266        Some(other) => {
1267            let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1268            result
1269        }
1270        None => None,
1271    }
1272}
1273
1274fn lua_values_equal(a: &Value, b: &Value) -> bool {
1275    match (a, b) {
1276        (Value::Nil, Value::Nil) => true,
1277        (Value::Boolean(a), Value::Boolean(b)) => a == b,
1278        (Value::Integer(a), Value::Integer(b)) => a == b,
1279        (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1280        (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1281            (*a as f64 - b).abs() < f64::EPSILON
1282        }
1283        (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1284        _ => false,
1285    }
1286}
1287
1288fn lua_value_to_f64(val: Value) -> Option<f64> {
1289    match val {
1290        Value::Integer(i) => Some(i as f64),
1291        Value::Number(f) => Some(f),
1292        _ => None,
1293    }
1294}
1295
1296fn format_lua_value(val: &Value) -> String {
1297    match val {
1298        Value::Nil => "nil".to_string(),
1299        Value::Boolean(b) => b.to_string(),
1300        Value::Integer(i) => i.to_string(),
1301        Value::Number(f) => f.to_string(),
1302        Value::String(s) => match s.to_str() {
1303            Ok(v) => v.to_string(),
1304            Err(_) => "<invalid utf-8>".to_string(),
1305        },
1306        _ => format!("<{}>", val.type_name()),
1307    }
1308}
1309
1310type WsSink = Rc<
1311    tokio::sync::Mutex<
1312        futures_util::stream::SplitSink<
1313            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1314            tokio_tungstenite::tungstenite::Message,
1315        >,
1316    >,
1317>;
1318type WsStream = Rc<
1319    tokio::sync::Mutex<
1320        futures_util::stream::SplitStream<
1321            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
1322        >,
1323    >,
1324>;
1325
1326struct WsConn {
1327    sink: WsSink,
1328    stream: WsStream,
1329}
1330impl UserData for WsConn {}
1331
1332fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1333    let ud = match val {
1334        Value::UserData(ud) => ud,
1335        _ => {
1336            return Err(mlua::Error::runtime(format!(
1337                "{fn_name}: first argument must be a ws connection"
1338            )));
1339        }
1340    };
1341    let ws = ud.borrow::<WsConn>().map_err(|_| {
1342        mlua::Error::runtime(format!("{fn_name}: first argument must be a ws connection"))
1343    })?;
1344    Ok((ws.sink.clone(), ws.stream.clone()))
1345}
1346
1347fn register_ws(lua: &Lua) -> mlua::Result<()> {
1348    let ws_table = lua.create_table()?;
1349
1350    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1351        let (stream, _response) = tokio_tungstenite::connect_async(&url)
1352            .await
1353            .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1354        let (sink, read) = stream.split();
1355        lua.create_any_userdata(WsConn {
1356            sink: Rc::new(tokio::sync::Mutex::new(sink)),
1357            stream: Rc::new(tokio::sync::Mutex::new(read)),
1358        })
1359    })?;
1360    ws_table.set("connect", connect_fn)?;
1361
1362    let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1363        let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1364        sink.lock()
1365            .await
1366            .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1367            .await
1368            .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1369        Ok(())
1370    })?;
1371    ws_table.set("send", send_fn)?;
1372
1373    let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1374        let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1375        loop {
1376            let msg = stream
1377                .lock()
1378                .await
1379                .next()
1380                .await
1381                .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1382                .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1383            match msg {
1384                tokio_tungstenite::tungstenite::Message::Text(t) => {
1385                    return Ok(t.to_string());
1386                }
1387                tokio_tungstenite::tungstenite::Message::Binary(b) => {
1388                    return String::from_utf8(b.into())
1389                        .map_err(|e| mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}")));
1390                }
1391                tokio_tungstenite::tungstenite::Message::Close(_) => {
1392                    return Err(mlua::Error::runtime("ws.recv: connection closed"));
1393                }
1394                _ => continue,
1395            }
1396        }
1397    })?;
1398    ws_table.set("recv", recv_fn)?;
1399
1400    let close_fn = lua.create_async_function(|_, conn: Value| async move {
1401        let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1402        sink.lock()
1403            .await
1404            .close()
1405            .await
1406            .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1407        Ok(())
1408    })?;
1409    ws_table.set("close", close_fn)?;
1410
1411    lua.globals().set("ws", ws_table)?;
1412    Ok(())
1413}
1414
1415fn register_template(lua: &Lua) -> mlua::Result<()> {
1416    let tmpl_table = lua.create_table()?;
1417
1418    let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1419        let json_vars = match &vars {
1420            Value::Table(_) => lua_value_to_json(&vars)?,
1421            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1422            _ => {
1423                return Err(mlua::Error::runtime(
1424                    "template.render_string: second argument must be a table or nil",
1425                ));
1426            }
1427        };
1428        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1429        let env = minijinja::Environment::new();
1430        let tmpl = env
1431            .template_from_str(&template_str)
1432            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1433        tmpl.render(mini_vars)
1434            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1435    })?;
1436    tmpl_table.set("render_string", render_string_fn)?;
1437
1438    let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1439        let content = std::fs::read_to_string(&file_path).map_err(|e| {
1440            mlua::Error::runtime(format!(
1441                "template.render: failed to read {file_path:?}: {e}"
1442            ))
1443        })?;
1444        let json_vars = match &vars {
1445            Value::Table(_) => lua_value_to_json(&vars)?,
1446            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1447            _ => {
1448                return Err(mlua::Error::runtime(
1449                    "template.render: second argument must be a table or nil",
1450                ));
1451            }
1452        };
1453        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1454        let env = minijinja::Environment::new();
1455        let tmpl = env
1456            .template_from_str(&content)
1457            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1458        tmpl.render(mini_vars)
1459            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1460    })?;
1461    tmpl_table.set("render", render_fn)?;
1462
1463    lua.globals().set("template", tmpl_table)?;
1464    Ok(())
1465}
1466
/// Quotes `s` as a double-quoted Lua string literal that reads back as the same bytes.
///
/// The following characters are escaped:
/// * backslash and double quote;
/// * `\n` and `\r`, which Lua rejects raw inside a short string;
/// * NUL.
///
/// NUL is written as the full three-digit escape `\000`. Lua's `\ddd` escape
/// consumes up to three decimal digits, so a short `\0` directly followed by a
/// digit (e.g. NUL then `'1'`) would be read as a different byte (`\01`).
fn lua_string_literal(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\0' => out.push_str("\\000"),
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
1476
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers in this module (base64, Lua<->JSON
    //! conversion, value comparison/formatting, Lua literal quoting).
    use super::*;

    #[test]
    fn test_base64_roundtrip() {
        let plain = "hello world";
        let encoded = BASE64.encode(plain.as_bytes());
        assert_eq!(encoded, "aGVsbG8gd29ybGQ=");
        let round_tripped = BASE64
            .decode(encoded.as_bytes())
            .map(String::from_utf8)
            .unwrap()
            .unwrap();
        assert_eq!(round_tripped, plain);
    }

    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }

    #[test]
    fn test_lua_value_to_json_nil() {
        let json = lua_value_to_json(&Value::Nil).unwrap();
        assert_eq!(json, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        let json = lua_value_to_json(&Value::Boolean(true)).unwrap();
        assert_eq!(json, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        let json = lua_value_to_json(&Value::Integer(42)).unwrap();
        assert_eq!(json, serde_json::json!(42));
    }

    #[test]
    fn test_lua_value_to_json_number() {
        let json = lua_value_to_json(&Value::Number(1.5)).unwrap();
        assert_eq!(json, serde_json::json!(1.5));
    }

    #[test]
    fn test_lua_values_equal_nil() {
        let nil = Value::Nil;
        assert!(lua_values_equal(&nil, &nil));
    }

    #[test]
    fn test_lua_values_equal_int_float() {
        let (int, float) = (Value::Integer(42), Value::Number(42.0));
        assert!(lua_values_equal(&int, &float));
    }

    #[test]
    fn test_lua_values_not_equal() {
        let (one, two) = (Value::Integer(1), Value::Integer(2));
        assert!(!lua_values_equal(&one, &two));
    }

    #[test]
    fn test_format_lua_value() {
        let cases = [
            (Value::Nil, "nil"),
            (Value::Boolean(true), "true"),
            (Value::Integer(42), "42"),
        ];
        for (value, expected) in &cases {
            assert_eq!(format_lua_value(value), *expected);
        }
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        let cases = [
            ("hello", "\"hello\""),
            ("line\nnew", "\"line\\nnew\""),
            ("quote\"here", "\"quote\\\"here\""),
        ];
        for (input, expected) in cases {
            assert_eq!(lua_string_literal(input), expected);
        }
    }
}