Skip to main content

assay/lua/
builtins.rs

1use data_encoding::BASE64;
2use digest::Digest;
3use futures_util::{SinkExt, StreamExt};
4use http_body_util::Full;
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response, StatusCode};
9use jsonwebtoken::{Algorithm, EncodingKey, Header};
10use mlua::{Lua, Table, UserData, Value};
11use rand::RngExt;
12use sqlx::any::AnyRow;
13use sqlx::{AnyPool, Column, Row, ValueRef};
14use std::collections::HashMap;
15use std::rc::Rc;
16use std::sync::Arc;
17use std::time::{SystemTime, UNIX_EPOCH};
18use tokio::net::TcpListener;
19use tokio_tungstenite::MaybeTlsStream;
20use tracing::{error, info, warn};
21use zeroize::Zeroizing;
22
23struct HttpClient(reqwest::Client);
24impl UserData for HttpClient {}
25
26pub fn register_all(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
27    register_http(lua, client)?;
28    register_json(lua)?;
29    register_yaml(lua)?;
30    register_toml(lua)?;
31    register_assert(lua)?;
32    register_log(lua)?;
33    register_env(lua)?;
34    register_sleep(lua)?;
35    register_time(lua)?;
36    register_fs(lua)?;
37    register_base64(lua)?;
38    register_crypto(lua)?;
39    register_regex(lua)?;
40    register_async(lua)?;
41    register_db(lua)?;
42    register_ws(lua)?;
43    register_template(lua)?;
44    Ok(())
45}
46
47async fn execute_http_request(
48    lua: &Lua,
49    client: &reqwest::Client,
50    method_name: &str,
51    args: mlua::MultiValue,
52) -> mlua::Result<Value> {
53    let has_body = method_name != "get" && method_name != "delete";
54
55    let mut args_iter = args.into_iter();
56    let url: String = match args_iter.next() {
57        Some(Value::String(s)) => s.to_str()?.to_string(),
58        _ => {
59            return Err(mlua::Error::runtime(format!(
60                "http.{method_name}: first argument must be a URL string"
61            )));
62        }
63    };
64
65    let (body_str, auto_json, opts) = if has_body {
66        let (body, is_json) = match args_iter.next() {
67            Some(Value::String(s)) => (s.to_str()?.to_string(), false),
68            Some(Value::Table(t)) => {
69                let json_val = lua_table_to_json(&t)?;
70                let serialized = serde_json::to_string(&json_val).map_err(|e| {
71                    mlua::Error::runtime(format!(
72                        "http.{method_name}: JSON encode failed: {e}"
73                    ))
74                })?;
75                (serialized, true)
76            }
77            Some(Value::Nil) | None => (String::new(), false),
78            _ => {
79                return Err(mlua::Error::runtime(format!(
80                    "http.{method_name}: second argument must be a string, table, or nil"
81                )));
82            }
83        };
84        let opts = match args_iter.next() {
85            Some(Value::Table(t)) => Some(t),
86            Some(Value::Nil) | None => None,
87            _ => {
88                return Err(mlua::Error::runtime(format!(
89                    "http.{method_name}: third argument must be a table or nil"
90                )));
91            }
92        };
93        (body, is_json, opts)
94    } else {
95        let opts = match args_iter.next() {
96            Some(Value::Table(t)) => Some(t),
97            Some(Value::Nil) | None => None,
98            _ => {
99                return Err(mlua::Error::runtime(format!(
100                    "http.{method_name}: second argument must be a table or nil"
101                )));
102            }
103        };
104        (String::new(), false, opts)
105    };
106
107    let mut req = match method_name {
108        "get" => client.get(&url),
109        "post" => client.post(&url),
110        "put" => client.put(&url),
111        "patch" => client.patch(&url),
112        "delete" => client.delete(&url),
113        _ => {
114            return Err(mlua::Error::runtime(format!(
115                "http: unsupported method: {method_name}"
116            )));
117        }
118    };
119
120    if has_body && !body_str.is_empty() {
121        req = req.body(body_str);
122    }
123    if auto_json {
124        req = req.header("Content-Type", "application/json");
125    }
126    if let Some(ref opts_table) = opts
127        && let Ok(headers_table) = opts_table.get::<Table>("headers")
128    {
129        for pair in headers_table.pairs::<String, String>() {
130            let (k, v) = pair?;
131            req = req.header(k, v);
132        }
133    }
134
135    let resp = req.send().await.map_err(|e| {
136        mlua::Error::runtime(format!("http.{method_name} failed: {e}"))
137    })?;
138    let status = resp.status().as_u16();
139    let resp_headers = resp.headers().clone();
140    let body = resp.text().await.map_err(|e| {
141        mlua::Error::runtime(format!(
142            "http.{method_name}: reading body failed: {e}"
143        ))
144    })?;
145
146    let result = lua.create_table()?;
147    result.set("status", status)?;
148    result.set("body", body)?;
149
150    let headers_out = lua.create_table()?;
151    for (name, value) in &resp_headers {
152        if let Ok(v) = value.to_str() {
153            headers_out.set(name.as_str().to_string(), v.to_string())?;
154        }
155    }
156    result.set("headers", headers_out)?;
157
158    Ok(Value::Table(result))
159}
160
161fn register_http(lua: &Lua, client: reqwest::Client) -> mlua::Result<()> {
162    let http_table = lua.create_table()?;
163
164    for method in ["get", "post", "put", "patch", "delete"] {
165        let method_client = client.clone();
166        let method_name = method.to_string();
167
168        let func = lua.create_async_function(move |lua, args: mlua::MultiValue| {
169            let client = method_client.clone();
170            let method_name = method_name.clone();
171            async move {
172                execute_http_request(&lua, &client, &method_name, args).await
173            }
174        })?;
175        http_table.set(method, func)?;
176    }
177
178    let client_fn = lua.create_async_function(|lua, opts: Option<Table>| async move {
179        let mut builder = reqwest::Client::builder();
180
181        let timeout_secs: f64 = opts
182            .as_ref()
183            .and_then(|t| t.get::<f64>("timeout").ok())
184            .unwrap_or(30.0);
185        builder = builder.timeout(std::time::Duration::from_secs_f64(timeout_secs));
186
187        if let Some(ref opts_table) = opts {
188            if let Ok(ca_path) = opts_table.get::<String>("ca_cert_file") {
189                let pem = std::fs::read(&ca_path).map_err(|e| {
190                    mlua::Error::runtime(format!(
191                        "http.client: failed to read CA cert file {ca_path:?}: {e}"
192                    ))
193                })?;
194                let cert = reqwest::Certificate::from_pem(&pem).map_err(|e| {
195                    mlua::Error::runtime(format!(
196                        "http.client: invalid PEM in {ca_path:?}: {e}"
197                    ))
198                })?;
199                builder = builder.add_root_certificate(cert);
200            }
201            if let Ok(ca_pem) = opts_table.get::<String>("ca_cert") {
202                let cert = reqwest::Certificate::from_pem(ca_pem.as_bytes()).map_err(|e| {
203                    mlua::Error::runtime(format!("http.client: invalid CA cert PEM: {e}"))
204                })?;
205                builder = builder.add_root_certificate(cert);
206            }
207        }
208
209        let client = builder.build().map_err(|e| {
210            mlua::Error::runtime(format!("http.client: failed to build client: {e}"))
211        })?;
212
213        let ud = lua.create_any_userdata(HttpClient(client))?;
214
215        let wrapper: Table = lua
216            .load(
217                r#"
218                local ud = ...
219                local obj = { _ud = ud }
220                setmetatable(obj, {
221                    __index = {
222                        get = function(self, url, opts)
223                            return http._client_request(self._ud, "get", url, opts)
224                        end,
225                        post = function(self, url, body, opts)
226                            return http._client_request(self._ud, "post", url, body, opts)
227                        end,
228                        put = function(self, url, body, opts)
229                            return http._client_request(self._ud, "put", url, body, opts)
230                        end,
231                        patch = function(self, url, body, opts)
232                            return http._client_request(self._ud, "patch", url, body, opts)
233                        end,
234                        delete = function(self, url, opts)
235                            return http._client_request(self._ud, "delete", url, opts)
236                        end,
237                    }
238                })
239                return obj
240            "#,
241            )
242            .call(ud)?;
243
244        Ok(Value::Table(wrapper))
245    })?;
246    http_table.set("client", client_fn)?;
247
248    let client_request_fn =
249        lua.create_async_function(|lua, args: mlua::MultiValue| async move {
250            let mut args_iter = args.into_iter();
251
252            let client = match args_iter.next() {
253                Some(Value::UserData(ud)) => {
254                    let hc = ud.borrow::<HttpClient>().map_err(|_| {
255                        mlua::Error::runtime(
256                            "http._client_request: first arg must be an http client",
257                        )
258                    })?;
259                    hc.0.clone()
260                }
261                _ => {
262                    return Err(mlua::Error::runtime(
263                        "http._client_request: first arg must be an http client",
264                    ));
265                }
266            };
267
268            let method_name: String = match args_iter.next() {
269                Some(Value::String(s)) => s.to_str()?.to_string(),
270                _ => {
271                    return Err(mlua::Error::runtime(
272                        "http._client_request: second arg must be method name",
273                    ));
274                }
275            };
276
277            let remaining: mlua::MultiValue = args_iter.collect();
278            execute_http_request(&lua, &client, &method_name, remaining).await
279        })?;
280    http_table.set("_client_request", client_request_fn)?;
281
282    let serve_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
283        let mut args_iter = args.into_iter();
284
285        let port: u16 = match args_iter.next() {
286            Some(Value::Integer(n)) => n as u16,
287            _ => {
288                return Err::<(), _>(mlua::Error::runtime(
289                    "http.serve: first argument must be a port number",
290                ));
291            }
292        };
293
294        let routes_table = match args_iter.next() {
295            Some(Value::Table(t)) => t,
296            _ => {
297                return Err::<(), _>(mlua::Error::runtime(
298                    "http.serve: second argument must be a routes table",
299                ));
300            }
301        };
302
303        let routes = Rc::new(parse_routes(&routes_table)?);
304
305        let listener = TcpListener::bind(format!("0.0.0.0:{port}"))
306            .await
307            .map_err(|e| mlua::Error::runtime(format!("http.serve: bind failed: {e}")))?;
308
309        loop {
310            let (stream, _addr) = listener
311                .accept()
312                .await
313                .map_err(|e| mlua::Error::runtime(format!("http.serve: accept failed: {e}")))?;
314
315            let routes = routes.clone();
316            let lua_clone = lua.clone();
317
318            tokio::task::spawn_local(async move {
319                let io = hyper_util::rt::TokioIo::new(stream);
320                let routes = routes.clone();
321                let lua = lua_clone.clone();
322
323                let service = service_fn(move |req: Request<Incoming>| {
324                    let routes = routes.clone();
325                    let lua = lua.clone();
326                    async move { handle_request(&lua, &routes, req).await }
327                });
328
329                if let Err(e) = http1::Builder::new().serve_connection(io, service).await
330                    && !e.to_string().contains("connection closed")
331                {
332                    error!("http.serve: connection error: {e}");
333                }
334            });
335        }
336    })?;
337    http_table.set("serve", serve_fn)?;
338
339    lua.globals().set("http", http_table)?;
340    Ok(())
341}
342
343fn parse_routes(routes_table: &Table) -> mlua::Result<HashMap<(String, String), mlua::Function>> {
344    let mut routes = HashMap::new();
345    for method_pair in routes_table.pairs::<String, Table>() {
346        let (method, paths_table) = method_pair?;
347        let method_upper = method.to_uppercase();
348        for path_pair in paths_table.pairs::<String, mlua::Function>() {
349            let (path, func) = path_pair?;
350            routes.insert((method_upper.clone(), path), func);
351        }
352    }
353    Ok(routes)
354}
355
356async fn handle_request(
357    lua: &Lua,
358    routes: &HashMap<(String, String), mlua::Function>,
359    req: Request<Incoming>,
360) -> Result<Response<Full<Bytes>>, hyper::Error> {
361    let method = req.method().to_string();
362    let path = req.uri().path().to_string();
363    let query = req.uri().query().unwrap_or("").to_string();
364    let headers: Vec<(String, String)> = req
365        .headers()
366        .iter()
367        .filter_map(|(k, v)| v.to_str().ok().map(|v| (k.to_string(), v.to_string())))
368        .collect();
369
370    let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
371        Ok(collected) => collected.to_bytes(),
372        Err(_) => Bytes::new(),
373    };
374    let body_str = String::from_utf8_lossy(&body_bytes).to_string();
375
376    let key = (method.clone(), path.clone());
377    let handler = match routes.get(&key) {
378        Some(f) => f,
379        None => {
380            return Ok(Response::builder()
381                .status(StatusCode::NOT_FOUND)
382                .header("content-type", "text/plain")
383                .body(Full::new(Bytes::from("not found")))
384                .unwrap());
385        }
386    };
387
388    match build_lua_request_and_call(lua, handler, &method, &path, &query, &headers, &body_str) {
389        Ok(lua_resp) => lua_response_to_http(&lua_resp),
390        Err(e) => Ok(Response::builder()
391            .status(StatusCode::INTERNAL_SERVER_ERROR)
392            .header("content-type", "text/plain")
393            .body(Full::new(Bytes::from(format!("handler error: {e}"))))
394            .unwrap()),
395    }
396}
397
398fn build_lua_request_and_call(
399    lua: &Lua,
400    handler: &mlua::Function,
401    method: &str,
402    path: &str,
403    query: &str,
404    headers: &[(String, String)],
405    body: &str,
406) -> mlua::Result<Table> {
407    let req_table = lua.create_table()?;
408    req_table.set("method", method.to_string())?;
409    req_table.set("path", path.to_string())?;
410    req_table.set("query", query.to_string())?;
411    req_table.set("body", body.to_string())?;
412
413    let headers_table = lua.create_table()?;
414    for (k, v) in headers {
415        headers_table.set(k.as_str(), v.as_str())?;
416    }
417    req_table.set("headers", headers_table)?;
418
419    handler.call::<Table>(req_table)
420}
421
422fn lua_response_to_http(resp_table: &Table) -> Result<Response<Full<Bytes>>, hyper::Error> {
423    let status = resp_table
424        .get::<Option<u16>>("status")
425        .unwrap_or(None)
426        .unwrap_or(200);
427
428    let mut builder =
429        Response::builder().status(StatusCode::from_u16(status).unwrap_or(StatusCode::OK));
430
431    if let Ok(Some(headers_table)) = resp_table.get::<Option<Table>>("headers") {
432        for (k, v) in headers_table.pairs::<String, String>().flatten() {
433            builder = builder.header(k, v);
434        }
435    }
436
437    let body_bytes = if let Ok(Some(json_table)) = resp_table.get::<Option<Table>>("json") {
438        let json_val =
439            lua_value_to_json(&Value::Table(json_table)).unwrap_or(serde_json::Value::Null);
440        let serialized = serde_json::to_string(&json_val).unwrap_or_else(|_| "null".to_string());
441        builder = builder.header("content-type", "application/json");
442        Bytes::from(serialized)
443    } else if let Ok(Some(body_str)) = resp_table.get::<Option<String>>("body") {
444        builder = builder.header("content-type", "text/plain");
445        Bytes::from(body_str)
446    } else {
447        builder = builder.header("content-type", "text/plain");
448        Bytes::new()
449    };
450
451    Ok(builder.body(Full::new(body_bytes)).unwrap())
452}
453
454fn register_json(lua: &Lua) -> mlua::Result<()> {
455    let json_table = lua.create_table()?;
456
457    let parse_fn = lua.create_function(|lua, s: String| {
458        let value: serde_json::Value = serde_json::from_str(&s)
459            .map_err(|e| mlua::Error::runtime(format!("json.parse: {e}")))?;
460        json_value_to_lua(lua, &value)
461    })?;
462    json_table.set("parse", parse_fn)?;
463
464    let encode_fn = lua.create_function(|_, val: Value| {
465        let json_val = lua_value_to_json(&val)?;
466        serde_json::to_string(&json_val)
467            .map_err(|e| mlua::Error::runtime(format!("json.encode: {e}")))
468    })?;
469    json_table.set("encode", encode_fn)?;
470
471    lua.globals().set("json", json_table)?;
472    Ok(())
473}
474
475fn lua_table_to_json(table: &mlua::Table) -> mlua::Result<serde_json::Value> {
476    let mut is_array = true;
477    let mut max_index: i64 = 0;
478    let mut count: i64 = 0;
479
480    for pair in table.clone().pairs::<Value, Value>() {
481        let (key, _) = pair?;
482        count += 1;
483        match key {
484            Value::Integer(i) if i >= 1 => {
485                if i > max_index {
486                    max_index = i;
487                }
488            }
489            _ => {
490                is_array = false;
491                break;
492            }
493        }
494    }
495
496    if is_array && max_index == count {
497        let mut arr = Vec::with_capacity(max_index as usize);
498        for i in 1..=max_index {
499            let val: Value = table.get(i)?;
500            arr.push(lua_value_to_json(&val)?);
501        }
502        Ok(serde_json::Value::Array(arr))
503    } else {
504        let mut map = serde_json::Map::new();
505        for pair in table.clone().pairs::<Value, Value>() {
506            let (key, val) = pair?;
507            let key_str = match key {
508                Value::String(s) => s.to_str()?.to_string(),
509                Value::Integer(i) => i.to_string(),
510                Value::Number(f) => f.to_string(),
511                _ => {
512                    return Err(mlua::Error::runtime(format!(
513                        "unsupported table key type: {}",
514                        key.type_name()
515                    )));
516                }
517            };
518            map.insert(key_str, lua_value_to_json(&val)?);
519        }
520        Ok(serde_json::Value::Object(map))
521    }
522}
523
524fn lua_value_to_json(val: &Value) -> mlua::Result<serde_json::Value> {
525    match val {
526        Value::Nil => Ok(serde_json::Value::Null),
527        Value::Boolean(b) => Ok(serde_json::Value::Bool(*b)),
528        Value::Integer(i) => Ok(serde_json::Value::Number(serde_json::Number::from(*i))),
529        Value::Number(f) => serde_json::Number::from_f64(*f)
530            .map(serde_json::Value::Number)
531            .ok_or_else(|| mlua::Error::runtime(format!("cannot encode {f} as JSON number"))),
532        Value::String(s) => Ok(serde_json::Value::String(s.to_str()?.to_string())),
533        Value::Table(t) => lua_table_to_json(t),
534        _ => Err(mlua::Error::runtime(format!(
535            "unsupported Lua type for JSON: {}",
536            val.type_name()
537        ))),
538    }
539}
540
541pub fn json_value_to_lua(lua: &Lua, val: &serde_json::Value) -> mlua::Result<Value> {
542    match val {
543        serde_json::Value::Null => Ok(Value::Nil),
544        serde_json::Value::Bool(b) => Ok(Value::Boolean(*b)),
545        serde_json::Value::Number(n) => {
546            if let Some(i) = n.as_i64() {
547                Ok(Value::Integer(i))
548            } else if let Some(f) = n.as_f64() {
549                Ok(Value::Number(f))
550            } else {
551                Ok(Value::Nil)
552            }
553        }
554        serde_json::Value::String(s) => Ok(Value::String(lua.create_string(s)?)),
555        serde_json::Value::Array(arr) => {
556            let table = lua.create_table()?;
557            for (i, item) in arr.iter().enumerate() {
558                table.set(i + 1, json_value_to_lua(lua, item)?)?;
559            }
560            Ok(Value::Table(table))
561        }
562        serde_json::Value::Object(map) => {
563            let table = lua.create_table()?;
564            for (k, v) in map {
565                table.set(k.as_str(), json_value_to_lua(lua, v)?)?;
566            }
567            Ok(Value::Table(table))
568        }
569    }
570}
571
572fn register_assert(lua: &Lua) -> mlua::Result<()> {
573    let assert_table = lua.create_table()?;
574
575    let eq_fn = lua.create_function(|lua, args: mlua::MultiValue| {
576        let mut args_iter = args.into_iter();
577        let a = args_iter.next().unwrap_or(Value::Nil);
578        let b = args_iter.next().unwrap_or(Value::Nil);
579        let msg = extract_string_arg(lua, args_iter.next());
580
581        if !lua_values_equal(&a, &b) {
582            let detail = format!(
583                "assert.eq failed: {:?} != {:?}{}",
584                format_lua_value(&a),
585                format_lua_value(&b),
586                msg.map(|m| format!(" - {m}")).unwrap_or_default()
587            );
588            return Err(mlua::Error::runtime(detail));
589        }
590        Ok(())
591    })?;
592    assert_table.set("eq", eq_fn)?;
593
594    let gt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
595        let mut args_iter = args.into_iter();
596        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
597        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
598        let msg = extract_string_arg(lua, args_iter.next());
599
600        match (a, b) {
601            (Some(va), Some(vb)) if va > vb => Ok(()),
602            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
603                "assert.gt failed: {va} is not > {vb}{}",
604                msg.map(|m| format!(" - {m}")).unwrap_or_default()
605            ))),
606            _ => Err(mlua::Error::runtime(
607                "assert.gt: both arguments must be numbers",
608            )),
609        }
610    })?;
611    assert_table.set("gt", gt_fn)?;
612
613    let lt_fn = lua.create_function(|lua, args: mlua::MultiValue| {
614        let mut args_iter = args.into_iter();
615        let a = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
616        let b = lua_value_to_f64(args_iter.next().unwrap_or(Value::Nil));
617        let msg = extract_string_arg(lua, args_iter.next());
618
619        match (a, b) {
620            (Some(va), Some(vb)) if va < vb => Ok(()),
621            (Some(va), Some(vb)) => Err(mlua::Error::runtime(format!(
622                "assert.lt failed: {va} is not < {vb}{}",
623                msg.map(|m| format!(" - {m}")).unwrap_or_default()
624            ))),
625            _ => Err(mlua::Error::runtime(
626                "assert.lt: both arguments must be numbers",
627            )),
628        }
629    })?;
630    assert_table.set("lt", lt_fn)?;
631
632    let contains_fn = lua.create_function(|lua, args: mlua::MultiValue| {
633        let mut args_iter = args.into_iter();
634        let haystack: String = match args_iter.next() {
635            Some(Value::String(s)) => s.to_str()?.to_string(),
636            _ => {
637                return Err(mlua::Error::runtime(
638                    "assert.contains: first argument must be a string",
639                ));
640            }
641        };
642        let needle: String = match args_iter.next() {
643            Some(Value::String(s)) => s.to_str()?.to_string(),
644            _ => {
645                return Err(mlua::Error::runtime(
646                    "assert.contains: second argument must be a string",
647                ));
648            }
649        };
650        let msg = extract_string_arg(lua, args_iter.next());
651
652        if !haystack.contains(&needle) {
653            return Err(mlua::Error::runtime(format!(
654                "assert.contains failed: {haystack:?} does not contain {needle:?}{}",
655                msg.map(|m| format!(" - {m}")).unwrap_or_default()
656            )));
657        }
658        Ok(())
659    })?;
660    assert_table.set("contains", contains_fn)?;
661
662    let not_nil_fn = lua.create_function(|lua, args: mlua::MultiValue| {
663        let mut args_iter = args.into_iter();
664        let val = args_iter.next().unwrap_or(Value::Nil);
665        let msg = extract_string_arg(lua, args_iter.next());
666
667        if val == Value::Nil {
668            return Err(mlua::Error::runtime(format!(
669                "assert.not_nil failed: value is nil{}",
670                msg.map(|m| format!(" - {m}")).unwrap_or_default()
671            )));
672        }
673        Ok(())
674    })?;
675    assert_table.set("not_nil", not_nil_fn)?;
676
677    let matches_fn = lua.create_function(|lua, args: mlua::MultiValue| {
678        let mut args_iter = args.into_iter();
679        let text: String = match args_iter.next() {
680            Some(Value::String(s)) => s.to_str()?.to_string(),
681            _ => {
682                return Err(mlua::Error::runtime(
683                    "assert.matches: first argument must be a string",
684                ));
685            }
686        };
687        let pattern: String = match args_iter.next() {
688            Some(Value::String(s)) => s.to_str()?.to_string(),
689            _ => {
690                return Err(mlua::Error::runtime(
691                    "assert.matches: second argument must be a pattern string",
692                ));
693            }
694        };
695        let msg = extract_string_arg(lua, args_iter.next());
696
697        let found: bool = lua
698            .load(format!(
699                "return string.find({}, {}) ~= nil",
700                lua_string_literal(&text),
701                lua_string_literal(&pattern)
702            ))
703            .eval()
704            .map_err(|e| mlua::Error::runtime(format!("assert.matches: pattern error: {e}")))?;
705
706        if !found {
707            return Err(mlua::Error::runtime(format!(
708                "assert.matches failed: {text:?} does not match pattern {pattern:?}{}",
709                msg.map(|m| format!(" - {m}")).unwrap_or_default()
710            )));
711        }
712        Ok(())
713    })?;
714    assert_table.set("matches", matches_fn)?;
715
716    lua.globals().set("assert", assert_table)?;
717    Ok(())
718}
719
720fn register_log(lua: &Lua) -> mlua::Result<()> {
721    let log_table = lua.create_table()?;
722
723    let info_fn = lua.create_function(|_, msg: String| {
724        info!(target: "lua", "{}", msg);
725        Ok(())
726    })?;
727    log_table.set("info", info_fn)?;
728
729    let warn_fn = lua.create_function(|_, msg: String| {
730        warn!(target: "lua", "{}", msg);
731        Ok(())
732    })?;
733    log_table.set("warn", warn_fn)?;
734
735    let error_fn = lua.create_function(|_, msg: String| {
736        error!(target: "lua", "{}", msg);
737        Ok(())
738    })?;
739    log_table.set("error", error_fn)?;
740
741    lua.globals().set("log", log_table)?;
742    Ok(())
743}
744
745fn register_env(lua: &Lua) -> mlua::Result<()> {
746    let env_table = lua.create_table()?;
747
748    let process_get_fn = lua.create_function(|_, name: String| match std::env::var(&name) {
749        Ok(val) => Ok(Some(val)),
750        Err(_) => Ok(None),
751    })?;
752    env_table.set("_process_get", process_get_fn)?;
753    env_table.set("_check_env", lua.create_table()?)?;
754
755    lua.globals().set("env", env_table)?;
756
757    lua.load(
758        r#"
759        function env.get(name)
760            local val = env._check_env[name]
761            if val ~= nil then return val end
762            return env._process_get(name)
763        end
764        "#,
765    )
766    .exec()?;
767
768    Ok(())
769}
770
771fn register_sleep(lua: &Lua) -> mlua::Result<()> {
772    let sleep_fn = lua.create_async_function(|_, seconds: f64| async move {
773        let duration = std::time::Duration::from_secs_f64(seconds);
774        tokio::time::sleep(duration).await;
775        Ok(())
776    })?;
777    lua.globals().set("sleep", sleep_fn)?;
778    Ok(())
779}
780
781fn register_time(lua: &Lua) -> mlua::Result<()> {
782    let time_fn = lua.create_function(|_, ()| {
783        let secs = SystemTime::now()
784            .duration_since(UNIX_EPOCH)
785            .map_err(|e| mlua::Error::runtime(format!("time(): {e}")))?
786            .as_secs_f64();
787        Ok(secs)
788    })?;
789    lua.globals().set("time", time_fn)?;
790    Ok(())
791}
792
793fn register_fs(lua: &Lua) -> mlua::Result<()> {
794    let fs_table = lua.create_table()?;
795
796    let read_fn = lua.create_function(|_, path: String| {
797        std::fs::read_to_string(&path)
798            .map_err(|e| mlua::Error::runtime(format!("fs.read: failed to read {path:?}: {e}")))
799    })?;
800    fs_table.set("read", read_fn)?;
801
802    let write_fn = lua.create_function(|_, (path, content): (String, String)| {
803        let p = std::path::Path::new(&path);
804        if let Some(parent) = p.parent() {
805            std::fs::create_dir_all(parent).map_err(|e| {
806                mlua::Error::runtime(format!(
807                    "fs.write: failed to create directories for {path:?}: {e}"
808                ))
809            })?;
810        }
811        std::fs::write(&path, &content)
812            .map_err(|e| mlua::Error::runtime(format!("fs.write: failed to write {path:?}: {e}")))
813    })?;
814    fs_table.set("write", write_fn)?;
815
816    lua.globals().set("fs", fs_table)?;
817    Ok(())
818}
819
820fn register_base64(lua: &Lua) -> mlua::Result<()> {
821    let b64_table = lua.create_table()?;
822
823    let encode_fn = lua.create_function(|_, input: String| Ok(BASE64.encode(input.as_bytes())))?;
824    b64_table.set("encode", encode_fn)?;
825
826    let decode_fn = lua.create_function(|_, input: String| {
827        let bytes = BASE64
828            .decode(input.as_bytes())
829            .map_err(|e| mlua::Error::runtime(format!("base64.decode: {e}")))?;
830        String::from_utf8(bytes)
831            .map_err(|e| mlua::Error::runtime(format!("base64.decode: invalid UTF-8: {e}")))
832    })?;
833    b64_table.set("decode", decode_fn)?;
834
835    lua.globals().set("base64", b64_table)?;
836    Ok(())
837}
838
839fn register_crypto(lua: &Lua) -> mlua::Result<()> {
840    let crypto_table = lua.create_table()?;
841
842    let jwt_sign_fn = lua.create_function(|_, args: mlua::MultiValue| {
843        let mut args_iter = args.into_iter();
844
845        let claims_table = match args_iter.next() {
846            Some(Value::Table(t)) => t,
847            _ => {
848                return Err(mlua::Error::runtime(
849                    "crypto.jwt_sign: first argument must be a claims table",
850                ));
851            }
852        };
853
854        let pem_key: String = match args_iter.next() {
855            Some(Value::String(s)) => s.to_str()?.to_string(),
856            _ => {
857                return Err(mlua::Error::runtime(
858                    "crypto.jwt_sign: second argument must be a PEM key string",
859                ));
860            }
861        };
862
863        let algorithm = match args_iter.next() {
864            Some(Value::String(s)) => match s.to_str()?.to_uppercase().as_str() {
865                "RS256" => Algorithm::RS256,
866                "RS384" => Algorithm::RS384,
867                "RS512" => Algorithm::RS512,
868                other => {
869                    return Err(mlua::Error::runtime(format!(
870                        "crypto.jwt_sign: unsupported algorithm: {other}"
871                    )));
872                }
873            },
874            Some(Value::Nil) | None => Algorithm::RS256,
875            _ => {
876                return Err(mlua::Error::runtime(
877                    "crypto.jwt_sign: third argument must be an algorithm string or nil",
878                ));
879            }
880        };
881
882        // 4th optional argument: options table with header fields (kid, typ, etc.)
883        let opts = match args_iter.next() {
884            Some(Value::Table(t)) => Some(t),
885            Some(Value::Nil) | None => None,
886            _ => {
887                return Err(mlua::Error::runtime(
888                    "crypto.jwt_sign: fourth argument must be an options table or nil",
889                ));
890            }
891        };
892
893        let claims_json = lua_value_to_json(&Value::Table(claims_table))?;
894        let pem_bytes = Zeroizing::new(pem_key.into_bytes());
895        let key = EncodingKey::from_rsa_pem(&pem_bytes)
896            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: invalid PEM key: {e}")))?;
897
898        let mut header = Header::new(algorithm);
899        if let Some(ref opts_table) = opts
900            && let Ok(kid) = opts_table.get::<String>("kid")
901        {
902            header.kid = Some(kid);
903        }
904        let token = jsonwebtoken::encode(&header, &claims_json, &key)
905            .map_err(|e| mlua::Error::runtime(format!("crypto.jwt_sign: encoding failed: {e}")))?;
906
907        Ok(token)
908    })?;
909    crypto_table.set("jwt_sign", jwt_sign_fn)?;
910
911    let hash_fn = lua.create_function(|_, args: mlua::MultiValue| {
912        let mut args_iter = args.into_iter();
913
914        let input: String = match args_iter.next() {
915            Some(Value::String(s)) => s.to_str()?.to_string(),
916            _ => {
917                return Err(mlua::Error::runtime(
918                    "crypto.hash: first argument must be a string",
919                ));
920            }
921        };
922
923        let algorithm: String = match args_iter.next() {
924            Some(Value::String(s)) => s.to_str()?.to_lowercase(),
925            Some(Value::Nil) | None => "sha256".to_string(),
926            _ => {
927                return Err(mlua::Error::runtime(
928                    "crypto.hash: second argument must be an algorithm string or nil",
929                ));
930            }
931        };
932
933        let hex = match algorithm.as_str() {
934            "sha224" => format!("{:x}", sha2::Sha224::digest(input.as_bytes())),
935            "sha256" => format!("{:x}", sha2::Sha256::digest(input.as_bytes())),
936            "sha384" => format!("{:x}", sha2::Sha384::digest(input.as_bytes())),
937            "sha512" => format!("{:x}", sha2::Sha512::digest(input.as_bytes())),
938            "sha3-224" => format!("{:x}", sha3::Sha3_224::digest(input.as_bytes())),
939            "sha3-256" => format!("{:x}", sha3::Sha3_256::digest(input.as_bytes())),
940            "sha3-384" => format!("{:x}", sha3::Sha3_384::digest(input.as_bytes())),
941            "sha3-512" => format!("{:x}", sha3::Sha3_512::digest(input.as_bytes())),
942            other => {
943                return Err(mlua::Error::runtime(format!(
944                    "crypto.hash: unsupported algorithm: {other} (supported: sha224, sha256, sha384, sha512, sha3-224, sha3-256, sha3-384, sha3-512)"
945                )));
946            }
947        };
948
949        Ok(hex)
950    })?;
951    crypto_table.set("hash", hash_fn)?;
952
953    let random_fn = lua.create_function(|_, args: mlua::MultiValue| {
954        let mut args_iter = args.into_iter();
955
956        let length: usize = match args_iter.next() {
957            Some(Value::Integer(n)) if n > 0 => n as usize,
958            Some(Value::Integer(n)) => {
959                return Err(mlua::Error::runtime(format!(
960                    "crypto.random: length must be positive, got {n}"
961                )));
962            }
963            Some(Value::Nil) | None => 32,
964            _ => {
965                return Err(mlua::Error::runtime(
966                    "crypto.random: first argument must be a positive integer or nil",
967                ));
968            }
969        };
970
971        let charset: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
972        let mut rng = rand::rng();
973        let result: String = (0..length)
974            .map(|_| charset[rng.random_range(..charset.len())] as char)
975            .collect();
976
977        Ok(result)
978    })?;
979    crypto_table.set("random", random_fn)?;
980
981    lua.globals().set("crypto", crypto_table)?;
982    Ok(())
983}
984
985fn register_yaml(lua: &Lua) -> mlua::Result<()> {
986    let yaml_table = lua.create_table()?;
987
988    let parse_fn = lua.create_function(|lua, s: String| {
989        let json_val: serde_json::Value = serde_yml::from_str(&s)
990            .map_err(|e| mlua::Error::runtime(format!("yaml.parse: {e}")))?;
991        json_value_to_lua(lua, &json_val)
992    })?;
993    yaml_table.set("parse", parse_fn)?;
994
995    let encode_fn = lua.create_function(|_, val: Value| {
996        let json_val = lua_value_to_json(&val)?;
997        serde_yml::to_string(&json_val)
998            .map_err(|e| mlua::Error::runtime(format!("yaml.encode: {e}")))
999    })?;
1000    yaml_table.set("encode", encode_fn)?;
1001
1002    lua.globals().set("yaml", yaml_table)?;
1003    Ok(())
1004}
1005
1006fn register_toml(lua: &Lua) -> mlua::Result<()> {
1007    let toml_table = lua.create_table()?;
1008
1009    let parse_fn = lua.create_function(|lua, s: String| {
1010        let toml_val: toml::Value =
1011            toml::from_str(&s).map_err(|e| mlua::Error::runtime(format!("toml.parse: {e}")))?;
1012        let json_val = serde_json::to_value(&toml_val)
1013            .map_err(|e| mlua::Error::runtime(format!("toml.parse: conversion failed: {e}")))?;
1014        json_value_to_lua(lua, &json_val)
1015    })?;
1016    toml_table.set("parse", parse_fn)?;
1017
1018    let encode_fn = lua.create_function(|_, val: Value| {
1019        let json_val = lua_value_to_json(&val)?;
1020        let toml_val: toml::Value = serde_json::from_value(json_val)
1021            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))?;
1022        toml::to_string_pretty(&toml_val)
1023            .map_err(|e| mlua::Error::runtime(format!("toml.encode: {e}")))
1024    })?;
1025    toml_table.set("encode", encode_fn)?;
1026
1027    lua.globals().set("toml", toml_table)?;
1028    Ok(())
1029}
1030
1031fn register_regex(lua: &Lua) -> mlua::Result<()> {
1032    let regex_table = lua.create_table()?;
1033
1034    let match_fn = lua.create_function(|_, (text, pattern): (String, String)| {
1035        let re = regex_lite::Regex::new(&pattern)
1036            .map_err(|e| mlua::Error::runtime(format!("regex.match: invalid pattern: {e}")))?;
1037        Ok(re.is_match(&text))
1038    })?;
1039    regex_table.set("match", match_fn)?;
1040
1041    let find_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
1042        let re = regex_lite::Regex::new(&pattern)
1043            .map_err(|e| mlua::Error::runtime(format!("regex.find: invalid pattern: {e}")))?;
1044        match re.captures(&text) {
1045            Some(caps) => {
1046                let result = lua.create_table()?;
1047                let full_match = caps.get(0).map(|m| m.as_str()).unwrap_or("");
1048                result.set("match", full_match.to_string())?;
1049                let groups = lua.create_table()?;
1050                for i in 1..caps.len() {
1051                    if let Some(m) = caps.get(i) {
1052                        groups.set(i, m.as_str().to_string())?;
1053                    }
1054                }
1055                result.set("groups", groups)?;
1056                Ok(Value::Table(result))
1057            }
1058            None => Ok(Value::Nil),
1059        }
1060    })?;
1061    regex_table.set("find", find_fn)?;
1062
1063    let find_all_fn = lua.create_function(|lua, (text, pattern): (String, String)| {
1064        let re = regex_lite::Regex::new(&pattern)
1065            .map_err(|e| mlua::Error::runtime(format!("regex.find_all: invalid pattern: {e}")))?;
1066        let results = lua.create_table()?;
1067        for (i, m) in re.find_iter(&text).enumerate() {
1068            results.set(i + 1, m.as_str().to_string())?;
1069        }
1070        Ok(results)
1071    })?;
1072    regex_table.set("find_all", find_all_fn)?;
1073
1074    let replace_fn = lua.create_function(
1075        |_, (text, pattern, replacement): (String, String, String)| {
1076            let re = regex_lite::Regex::new(&pattern).map_err(|e| {
1077                mlua::Error::runtime(format!("regex.replace: invalid pattern: {e}"))
1078            })?;
1079            Ok(re.replace_all(&text, replacement.as_str()).into_owned())
1080        },
1081    )?;
1082    regex_table.set("replace", replace_fn)?;
1083
1084    lua.globals().set("regex", regex_table)?;
1085    Ok(())
1086}
1087
1088fn register_async(lua: &Lua) -> mlua::Result<()> {
1089    let async_table = lua.create_table()?;
1090
1091    let spawn_fn = lua.create_async_function(|lua, func: mlua::Function| async move {
1092        let thread = lua.create_thread(func)?;
1093        let async_thread = thread.into_async::<mlua::MultiValue>(())?;
1094        let join_handle: tokio::task::JoinHandle<Result<Vec<Value>, String>> =
1095            tokio::task::spawn_local(async move {
1096                let values = async_thread.await.map_err(|e| e.to_string())?;
1097                Ok(values.into_vec())
1098            });
1099
1100        let handle = lua.create_table()?;
1101        let cell = std::rc::Rc::new(std::cell::RefCell::new(Some(join_handle)));
1102        let cell_clone = cell.clone();
1103
1104        let await_fn = lua.create_async_function(move |lua, ()| {
1105            let cell = cell_clone.clone();
1106            async move {
1107                let join_handle = cell
1108                    .borrow_mut()
1109                    .take()
1110                    .ok_or_else(|| mlua::Error::runtime("async handle already awaited"))?;
1111                let result = join_handle.await.map_err(|e| {
1112                    mlua::Error::runtime(format!("async.spawn: task panicked: {e}"))
1113                })?;
1114                match result {
1115                    Ok(values) => {
1116                        let tbl = lua.create_table()?;
1117                        for (i, v) in values.into_iter().enumerate() {
1118                            tbl.set(i + 1, v)?;
1119                        }
1120                        Ok(Value::Table(tbl))
1121                    }
1122                    Err(msg) => Err(mlua::Error::runtime(msg)),
1123                }
1124            }
1125        })?;
1126        handle.set("await", await_fn)?;
1127
1128        Ok(handle)
1129    })?;
1130    async_table.set("spawn", spawn_fn)?;
1131
1132    let spawn_interval_fn =
1133        lua.create_async_function(|lua, (seconds, func): (f64, mlua::Function)| async move {
1134            if seconds <= 0.0 {
1135                return Err(mlua::Error::runtime(
1136                    "async.spawn_interval: interval must be positive",
1137                ));
1138            }
1139
1140            let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
1141            let cancel_clone = cancel.clone();
1142
1143            tokio::task::spawn_local({
1144                let cancel = cancel_clone.clone();
1145                async move {
1146                    let mut interval =
1147                        tokio::time::interval(std::time::Duration::from_secs_f64(seconds));
1148                    interval.tick().await;
1149                    loop {
1150                        interval.tick().await;
1151                        if cancel.load(std::sync::atomic::Ordering::Relaxed) {
1152                            break;
1153                        }
1154                        if let Err(e) = func.call_async::<()>(()).await {
1155                            error!("async.spawn_interval: callback error: {e}");
1156                            break;
1157                        }
1158                    }
1159                }
1160            });
1161
1162            let handle = lua.create_table()?;
1163            let cancel_fn = lua.create_function(move |_, ()| {
1164                cancel.store(true, std::sync::atomic::Ordering::Relaxed);
1165                Ok(())
1166            })?;
1167            handle.set("cancel", cancel_fn)?;
1168
1169            Ok(handle)
1170        })?;
1171    async_table.set("spawn_interval", spawn_interval_fn)?;
1172
1173    lua.globals().set("async", async_table)?;
1174    Ok(())
1175}
1176
/// Lua userdata wrapper around a shared `sqlx` connection pool.
///
/// Created by `db.connect` and passed back as the first argument of the other
/// `db.*` functions, which clone the inner `Arc`.
struct DbPool(Arc<AnyPool>);
// No Lua-visible methods or fields; the value is only passed around as a handle.
impl UserData for DbPool {}
1179
1180fn register_db(lua: &Lua) -> mlua::Result<()> {
1181    sqlx::any::install_default_drivers();
1182
1183    let db_table = lua.create_table()?;
1184
1185    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1186        let pool = sqlx::any::AnyPoolOptions::new()
1187            .max_connections(if url.starts_with("sqlite:") { 1 } else { 5 })
1188            .connect(&url)
1189            .await
1190            .map_err(|e| mlua::Error::runtime(format!("db.connect: {e}")))?;
1191        lua.create_any_userdata(DbPool(Arc::new(pool)))
1192    })?;
1193    db_table.set("connect", connect_fn)?;
1194
1195    let query_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1196        let mut args_iter = args.into_iter();
1197
1198        let pool = extract_db_pool(&args_iter.next(), "db.query")?;
1199        let sql = extract_sql_string(&args_iter.next(), "db.query")?;
1200        let params = extract_params(&args_iter.next())?;
1201
1202        let mut query = sqlx::query(&sql);
1203        for p in &params {
1204            query = bind_param(query, p);
1205        }
1206
1207        let rows: Vec<AnyRow> = query
1208            .fetch_all(&*pool)
1209            .await
1210            .map_err(|e| mlua::Error::runtime(format!("db.query: {e}")))?;
1211
1212        let result = lua.create_table()?;
1213        for (i, row) in rows.iter().enumerate() {
1214            let row_table = any_row_to_lua_table(&lua, row)?;
1215            result.set(i + 1, row_table)?;
1216        }
1217        Ok(Value::Table(result))
1218    })?;
1219    db_table.set("query", query_fn)?;
1220
1221    let execute_fn = lua.create_async_function(|lua, args: mlua::MultiValue| async move {
1222        let mut args_iter = args.into_iter();
1223
1224        let pool = extract_db_pool(&args_iter.next(), "db.execute")?;
1225        let sql = extract_sql_string(&args_iter.next(), "db.execute")?;
1226        let params = extract_params(&args_iter.next())?;
1227
1228        let mut query = sqlx::query(&sql);
1229        for p in &params {
1230            query = bind_param(query, p);
1231        }
1232
1233        let result = query
1234            .execute(&*pool)
1235            .await
1236            .map_err(|e| mlua::Error::runtime(format!("db.execute: {e}")))?;
1237
1238        let tbl = lua.create_table()?;
1239        tbl.set("rows_affected", result.rows_affected() as i64)?;
1240        Ok(Value::Table(tbl))
1241    })?;
1242    db_table.set("execute", execute_fn)?;
1243
1244    let close_fn = lua.create_async_function(|_, args: mlua::MultiValue| async move {
1245        let mut args_iter = args.into_iter();
1246        let pool = extract_db_pool(&args_iter.next(), "db.close")?;
1247        pool.close().await;
1248        Ok(())
1249    })?;
1250    db_table.set("close", close_fn)?;
1251
1252    lua.globals().set("db", db_table)?;
1253    Ok(())
1254}
1255
1256fn extract_db_pool(val: &Option<Value>, fn_name: &str) -> mlua::Result<Arc<AnyPool>> {
1257    match val {
1258        Some(Value::UserData(ud)) => {
1259            let db = ud.borrow::<DbPool>().map_err(|_| {
1260                mlua::Error::runtime(format!("{fn_name}: first argument must be a db connection"))
1261            })?;
1262            Ok(db.0.clone())
1263        }
1264        _ => Err(mlua::Error::runtime(format!(
1265            "{fn_name}: first argument must be a db connection"
1266        ))),
1267    }
1268}
1269
1270fn extract_sql_string(val: &Option<Value>, fn_name: &str) -> mlua::Result<String> {
1271    match val {
1272        Some(Value::String(s)) => Ok(s.to_str()?.to_string()),
1273        _ => Err(mlua::Error::runtime(format!(
1274            "{fn_name}: second argument must be a SQL string"
1275        ))),
1276    }
1277}
1278
/// A query parameter converted from a Lua value, ready to be bound to a
/// `sqlx` query.
#[derive(Clone)]
enum DbParam {
    /// Lua `nil`.
    Null,
    /// Lua boolean.
    Bool(bool),
    /// Lua integer.
    Int(i64),
    /// Lua floating-point number.
    Float(f64),
    /// Lua string (must be valid UTF-8).
    Text(String),
}
1287
1288fn extract_params(val: &Option<Value>) -> mlua::Result<Vec<DbParam>> {
1289    match val {
1290        Some(Value::Table(t)) => {
1291            let mut params = Vec::new();
1292            let len = t.len()?;
1293            for i in 1..=len {
1294                let v: Value = t.get(i)?;
1295                let param = match v {
1296                    Value::Nil => DbParam::Null,
1297                    Value::Boolean(b) => DbParam::Bool(b),
1298                    Value::Integer(n) => DbParam::Int(n),
1299                    Value::Number(f) => DbParam::Float(f),
1300                    Value::String(s) => DbParam::Text(s.to_str()?.to_string()),
1301                    _ => {
1302                        return Err(mlua::Error::runtime(format!(
1303                            "db: unsupported parameter type: {}",
1304                            v.type_name()
1305                        )));
1306                    }
1307                };
1308                params.push(param);
1309            }
1310            Ok(params)
1311        }
1312        Some(Value::Nil) | None => Ok(Vec::new()),
1313        _ => Err(mlua::Error::runtime(
1314            "db: params must be a table (array) or nil",
1315        )),
1316    }
1317}
1318
1319fn bind_param<'q>(
1320    query: sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>>,
1321    param: &'q DbParam,
1322) -> sqlx::query::Query<'q, sqlx::Any, sqlx::any::AnyArguments<'q>> {
1323    match param {
1324        DbParam::Null => query.bind(None::<String>),
1325        DbParam::Bool(b) => query.bind(*b),
1326        DbParam::Int(n) => query.bind(*n),
1327        DbParam::Float(f) => query.bind(*f),
1328        DbParam::Text(s) => query.bind(s.as_str()),
1329    }
1330}
1331
1332fn any_row_to_lua_table(lua: &Lua, row: &AnyRow) -> mlua::Result<Table> {
1333    let table = lua.create_table()?;
1334    for col in row.columns() {
1335        let name = col.name();
1336        let val: Value = any_column_to_lua_value(lua, row, col)?;
1337        table.set(name.to_string(), val)?;
1338    }
1339    Ok(table)
1340}
1341
1342fn any_column_to_lua_value<C: Column>(lua: &Lua, row: &AnyRow, col: &C) -> mlua::Result<Value> {
1343    let ordinal = col.ordinal();
1344    let type_info = col.type_info();
1345    let type_name = type_info.to_string();
1346    let type_name = type_name.to_uppercase();
1347
1348    if row
1349        .try_get_raw(ordinal)
1350        .map(|v| v.is_null())
1351        .unwrap_or(true)
1352    {
1353        return Ok(Value::Nil);
1354    }
1355
1356    match type_name.as_str() {
1357        "BOOLEAN" | "BOOL" => {
1358            let v: bool = row
1359                .try_get(ordinal)
1360                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1361            Ok(Value::Boolean(v))
1362        }
1363        "INTEGER" | "INT" | "INT4" | "INT8" | "BIGINT" | "SMALLINT" | "TINYINT" | "INT2" => {
1364            let v: i64 = row
1365                .try_get(ordinal)
1366                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1367            Ok(Value::Integer(v))
1368        }
1369        "REAL" | "FLOAT" | "FLOAT4" | "FLOAT8" | "DOUBLE" | "DOUBLE PRECISION" | "NUMERIC" => {
1370            let v: f64 = row
1371                .try_get(ordinal)
1372                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1373            Ok(Value::Number(v))
1374        }
1375        _ => {
1376            let v: String = row
1377                .try_get(ordinal)
1378                .map_err(|e| mlua::Error::runtime(format!("db: column read error: {e}")))?;
1379            Ok(Value::String(lua.create_string(&v)?))
1380        }
1381    }
1382}
1383
1384fn extract_string_arg(lua: &Lua, val: Option<Value>) -> Option<String> {
1385    match val {
1386        Some(Value::String(s)) => s.to_str().ok().map(|s| s.to_string()),
1387        Some(other) => {
1388            let result: Option<String> = lua.load("return tostring(...)").call(other).ok();
1389            result
1390        }
1391        None => None,
1392    }
1393}
1394
1395fn lua_values_equal(a: &Value, b: &Value) -> bool {
1396    match (a, b) {
1397        (Value::Nil, Value::Nil) => true,
1398        (Value::Boolean(a), Value::Boolean(b)) => a == b,
1399        (Value::Integer(a), Value::Integer(b)) => a == b,
1400        (Value::Number(a), Value::Number(b)) => (a - b).abs() < f64::EPSILON,
1401        (Value::Integer(a), Value::Number(b)) | (Value::Number(b), Value::Integer(a)) => {
1402            (*a as f64 - b).abs() < f64::EPSILON
1403        }
1404        (Value::String(a), Value::String(b)) => a.as_bytes() == b.as_bytes(),
1405        _ => false,
1406    }
1407}
1408
1409fn lua_value_to_f64(val: Value) -> Option<f64> {
1410    match val {
1411        Value::Integer(i) => Some(i as f64),
1412        Value::Number(f) => Some(f),
1413        _ => None,
1414    }
1415}
1416
1417fn format_lua_value(val: &Value) -> String {
1418    match val {
1419        Value::Nil => "nil".to_string(),
1420        Value::Boolean(b) => b.to_string(),
1421        Value::Integer(i) => i.to_string(),
1422        Value::Number(f) => f.to_string(),
1423        Value::String(s) => match s.to_str() {
1424            Ok(v) => v.to_string(),
1425            Err(_) => "<invalid utf-8>".to_string(),
1426        },
1427        _ => format!("<{}>", val.type_name()),
1428    }
1429}
1430
/// Shared, lockable write half (`SplitSink`) of a WebSocket stream, used for
/// sending messages and closing the connection.
type WsSink = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitSink<
            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
            tokio_tungstenite::tungstenite::Message,
        >,
    >,
>;
/// Shared, lockable read half (`SplitStream`) of a WebSocket stream, used for
/// receiving messages.
type WsStream = Rc<
    tokio::sync::Mutex<
        futures_util::stream::SplitStream<
            tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
        >,
    >,
>;
1446
/// Lua userdata for an open WebSocket connection, holding both halves of the
/// split stream so they can be locked independently.
struct WsConn {
    /// Write half: used by `ws.send` and `ws.close`.
    sink: WsSink,
    /// Read half: used by `ws.recv`.
    stream: WsStream,
}
// No Lua-visible methods or fields; the value is only passed around as a handle.
impl UserData for WsConn {}
1452
1453fn extract_ws_conn(val: &Value, fn_name: &str) -> mlua::Result<(WsSink, WsStream)> {
1454    let ud = match val {
1455        Value::UserData(ud) => ud,
1456        _ => {
1457            return Err(mlua::Error::runtime(format!(
1458                "{fn_name}: first argument must be a ws connection"
1459            )));
1460        }
1461    };
1462    let ws = ud.borrow::<WsConn>().map_err(|_| {
1463        mlua::Error::runtime(format!("{fn_name}: first argument must be a ws connection"))
1464    })?;
1465    Ok((ws.sink.clone(), ws.stream.clone()))
1466}
1467
1468fn register_ws(lua: &Lua) -> mlua::Result<()> {
1469    let ws_table = lua.create_table()?;
1470
1471    let connect_fn = lua.create_async_function(|lua, url: String| async move {
1472        let (stream, _response) = tokio_tungstenite::connect_async(&url)
1473            .await
1474            .map_err(|e| mlua::Error::runtime(format!("ws.connect: {e}")))?;
1475        let (sink, read) = stream.split();
1476        lua.create_any_userdata(WsConn {
1477            sink: Rc::new(tokio::sync::Mutex::new(sink)),
1478            stream: Rc::new(tokio::sync::Mutex::new(read)),
1479        })
1480    })?;
1481    ws_table.set("connect", connect_fn)?;
1482
1483    let send_fn = lua.create_async_function(|_, (conn, msg): (Value, String)| async move {
1484        let (sink, _stream) = extract_ws_conn(&conn, "ws.send")?;
1485        sink.lock()
1486            .await
1487            .send(tokio_tungstenite::tungstenite::Message::Text(msg.into()))
1488            .await
1489            .map_err(|e| mlua::Error::runtime(format!("ws.send: {e}")))?;
1490        Ok(())
1491    })?;
1492    ws_table.set("send", send_fn)?;
1493
1494    let recv_fn = lua.create_async_function(|_, conn: Value| async move {
1495        let (_sink, stream) = extract_ws_conn(&conn, "ws.recv")?;
1496        loop {
1497            let msg = stream
1498                .lock()
1499                .await
1500                .next()
1501                .await
1502                .ok_or_else(|| mlua::Error::runtime("ws.recv: connection closed"))?
1503                .map_err(|e| mlua::Error::runtime(format!("ws.recv: {e}")))?;
1504            match msg {
1505                tokio_tungstenite::tungstenite::Message::Text(t) => {
1506                    return Ok(t.to_string());
1507                }
1508                tokio_tungstenite::tungstenite::Message::Binary(b) => {
1509                    return String::from_utf8(b.into())
1510                        .map_err(|e| mlua::Error::runtime(format!("ws.recv: invalid UTF-8: {e}")));
1511                }
1512                tokio_tungstenite::tungstenite::Message::Close(_) => {
1513                    return Err(mlua::Error::runtime("ws.recv: connection closed"));
1514                }
1515                _ => continue,
1516            }
1517        }
1518    })?;
1519    ws_table.set("recv", recv_fn)?;
1520
1521    let close_fn = lua.create_async_function(|_, conn: Value| async move {
1522        let (sink, _stream) = extract_ws_conn(&conn, "ws.close")?;
1523        sink.lock()
1524            .await
1525            .close()
1526            .await
1527            .map_err(|e| mlua::Error::runtime(format!("ws.close: {e}")))?;
1528        Ok(())
1529    })?;
1530    ws_table.set("close", close_fn)?;
1531
1532    lua.globals().set("ws", ws_table)?;
1533    Ok(())
1534}
1535
1536fn register_template(lua: &Lua) -> mlua::Result<()> {
1537    let tmpl_table = lua.create_table()?;
1538
1539    let render_string_fn = lua.create_function(|_, (template_str, vars): (String, Value)| {
1540        let json_vars = match &vars {
1541            Value::Table(_) => lua_value_to_json(&vars)?,
1542            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1543            _ => {
1544                return Err(mlua::Error::runtime(
1545                    "template.render_string: second argument must be a table or nil",
1546                ));
1547            }
1548        };
1549        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1550        let env = minijinja::Environment::new();
1551        let tmpl = env
1552            .template_from_str(&template_str)
1553            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))?;
1554        tmpl.render(mini_vars)
1555            .map_err(|e| mlua::Error::runtime(format!("template.render_string: {e}")))
1556    })?;
1557    tmpl_table.set("render_string", render_string_fn)?;
1558
1559    let render_fn = lua.create_function(|_, (file_path, vars): (String, Value)| {
1560        let content = std::fs::read_to_string(&file_path).map_err(|e| {
1561            mlua::Error::runtime(format!(
1562                "template.render: failed to read {file_path:?}: {e}"
1563            ))
1564        })?;
1565        let json_vars = match &vars {
1566            Value::Table(_) => lua_value_to_json(&vars)?,
1567            Value::Nil => serde_json::Value::Object(serde_json::Map::new()),
1568            _ => {
1569                return Err(mlua::Error::runtime(
1570                    "template.render: second argument must be a table or nil",
1571                ));
1572            }
1573        };
1574        let mini_vars = minijinja::value::Value::from_serialize(&json_vars);
1575        let env = minijinja::Environment::new();
1576        let tmpl = env
1577            .template_from_str(&content)
1578            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))?;
1579        tmpl.render(mini_vars)
1580            .map_err(|e| mlua::Error::runtime(format!("template.render: {e}")))
1581    })?;
1582    tmpl_table.set("render", render_fn)?;
1583
1584    lua.globals().set("template", tmpl_table)?;
1585    Ok(())
1586}
1587
/// Formats `s` as a double-quoted Lua string literal.
///
/// Backslashes, double quotes, newlines, carriage returns and NUL bytes are
/// escaped; every other character is copied through unchanged.
fn lua_string_literal(s: &str) -> String {
    let mut literal = String::with_capacity(s.len() + 2);
    literal.push('"');
    for ch in s.chars() {
        match ch {
            '\\' => literal.push_str("\\\\"),
            '"' => literal.push_str("\\\""),
            '\n' => literal.push_str("\\n"),
            '\r' => literal.push_str("\\r"),
            // Lua decimal escapes take up to three digits, so a bare `\0`
            // followed by a digit would be read as a different byte. The
            // fixed-width form is unambiguous.
            '\0' => literal.push_str("\\000"),
            other => literal.push(other),
        }
    }
    literal.push('"');
    literal
}
1597
#[cfg(test)]
mod tests {
    use super::*;

    // --- base64 -----------------------------------------------------------

    #[test]
    fn test_base64_roundtrip() {
        let plain = "hello world";

        let wire = BASE64.encode(plain.as_bytes());
        assert_eq!(wire, "aGVsbG8gd29ybGQ=");

        let bytes = BASE64.decode(wire.as_bytes()).unwrap();
        let text = String::from_utf8(bytes).unwrap();
        assert_eq!(text, plain);
    }

    #[test]
    fn test_base64_empty() {
        assert_eq!(BASE64.encode(b""), "");
        assert!(BASE64.decode(b"").unwrap().is_empty());
    }

    // --- Lua -> JSON conversion -------------------------------------------

    #[test]
    fn test_lua_value_to_json_nil() {
        let converted = lua_value_to_json(&Value::Nil).unwrap();
        assert_eq!(converted, serde_json::Value::Null);
    }

    #[test]
    fn test_lua_value_to_json_bool() {
        let converted = lua_value_to_json(&Value::Boolean(true)).unwrap();
        assert_eq!(converted, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_lua_value_to_json_integer() {
        let converted = lua_value_to_json(&Value::Integer(42)).unwrap();
        assert_eq!(converted, serde_json::json!(42));
    }

    #[test]
    fn test_lua_value_to_json_number() {
        let converted = lua_value_to_json(&Value::Number(1.5)).unwrap();
        assert_eq!(converted, serde_json::json!(1.5));
    }

    // --- value comparison -------------------------------------------------

    #[test]
    fn test_lua_values_equal_nil() {
        let (left, right) = (Value::Nil, Value::Nil);
        assert!(lua_values_equal(&left, &right));
    }

    #[test]
    fn test_lua_values_equal_int_float() {
        let (left, right) = (Value::Integer(42), Value::Number(42.0));
        assert!(lua_values_equal(&left, &right));
    }

    #[test]
    fn test_lua_values_not_equal() {
        let (left, right) = (Value::Integer(1), Value::Integer(2));
        assert!(!lua_values_equal(&left, &right));
    }

    // --- formatting -------------------------------------------------------

    #[test]
    fn test_format_lua_value() {
        let cases = [
            (Value::Nil, "nil"),
            (Value::Boolean(true), "true"),
            (Value::Integer(42), "42"),
        ];
        for (value, expected) in &cases {
            assert_eq!(format_lua_value(value), *expected);
        }
    }

    #[test]
    fn test_lua_string_literal_escaping() {
        let cases = [
            ("hello", "\"hello\""),
            ("line\nnew", "\"line\\nnew\""),
            ("quote\"here", "\"quote\\\"here\""),
        ];
        for (input, expected) in cases {
            assert_eq!(lua_string_literal(input), expected);
        }
    }
}