use crate::Result;
use crate::run::RuntimeContext;
use crate::script::lua_script::DEFAULT_MARKERS;
use crate::script::lua_script::helpers::to_vec_of_strings;
use crate::support::Extrude;
use crate::support::html::decode_html_entities;
use crate::support::text::{self, EnsureOptions, truncate_with_ellipsis};
use crate::support::text::{LineBlockIter, LineBlockIterOptions};
use mlua::{FromLua, Lua, MultiValue, String as LuaString, Table, Value};
use std::borrow::Cow;
/// Builds the Lua table that exposes the text utility functions.
///
/// Each Rust function is wrapped with `lua.create_function` and stored under its
/// Lua-facing name. The runtime context is accepted but not used here.
///
/// # Errors
///
/// Returns an error if creating the table, creating a function, or setting a
/// table entry fails.
pub fn init_module(lua: &Lua, _runtime_context: &RuntimeContext) -> Result<Table> {
    let table = lua.create_table()?;

    // Registers every `"lua_name" => rust_fn` pair on `table`, propagating failures with `?`.
    macro_rules! register {
        ($($lua_name:literal => $rust_fn:expr),+ $(,)?) => {
            $( table.set($lua_name, lua.create_function($rust_fn)?)?; )+
        };
    }

    register! {
        "escape_decode" => escape_decode,
        "escape_decode_if_needed" => escape_decode_if_needed,
        "split_first" => split_first,
        "remove_first_line" => remove_first_line,
        "remove_first_lines" => remove_first_lines,
        "remove_last_lines" => remove_last_lines,
        "remove_last_line" => remove_last_line,
        "trim" => trim,
        "trim_start" => trim_start,
        "trim_end" => trim_end,
        "truncate" => truncate,
        "replace_markers" => replace_markers_with_default_parkers,
        "ensure" => ensure,
        "ensure_single_ending_newline" => ensure_single_ending_newline,
        "extract_line_blocks" => extract_line_blocks,
    }

    Ok(table)
}
impl FromLua for EnsureOptions {
fn from_lua(value: Value, _lua: &Lua) -> mlua::Result<Self> {
let table = value.as_table().ok_or_else(|| {
mlua::Error::runtime(
"Ensure argument needs to be a table with the format {start = string, end = string} (both optional",
)
})?;
let prefix = table.get::<String>("prefix").ok();
let suffix = table.get::<String>("suffix").ok();
for (key, _value) in table.pairs::<Value, Value>().flatten() {
if let Some(key) = key.as_str() {
if key != "prefix" && key != "suffix" {
let msg = format!(
"Ensure argument contains invalid table property `{key}`. Can only contain `prefix` and/or `suffix`"
);
return Err(mlua::Error::RuntimeError(msg));
}
}
}
Ok(EnsureOptions { prefix, suffix })
}
}
/// Lua binding for `utils.text.ensure(content, {prefix = .., suffix = ..})`.
///
/// Parses the second argument into [`EnsureOptions`] and delegates to
/// `crate::support::text::ensure`, returning its result as an owned string.
///
/// # Errors
///
/// Fails when the options argument cannot be converted into [`EnsureOptions`].
fn ensure(lua: &Lua, (content, inst): (String, Value)) -> mlua::Result<String> {
    let options = EnsureOptions::from_lua(inst, lua)?;
    Ok(crate::support::text::ensure(&content, options).to_string())
}
/// Lua binding for `utils.text.ensure_single_ending_newline(content)`.
///
/// Thin wrapper that hands the content to
/// `crate::support::text::ensure_single_ending_newline` and returns its result.
fn ensure_single_ending_newline(_lua: &Lua, content: String) -> mlua::Result<String> {
    let normalized = crate::support::text::ensure_single_ending_newline(content);
    Ok(normalized)
}
/// Lua binding for `utils.text.replace_markers(content, new_sections)`.
///
/// Converts `new_sections` into a list of strings and passes them, together with
/// `DEFAULT_MARKERS`, to `text::replace_markers`.
///
/// # Errors
///
/// Fails when `new_sections` cannot be converted to a list of strings, or when
/// `text::replace_markers` reports an error.
fn replace_markers_with_default_parkers(_lua: &Lua, (content, new_sections): (String, Value)) -> mlua::Result<String> {
    let sections = to_vec_of_strings(new_sections, "new_sections")?;
    // `replace_markers` is handed string slices, so borrow from the owned strings.
    let sections: Vec<&str> = sections.iter().map(String::as_str).collect();
    let new_content = text::replace_markers(&content, &sections, DEFAULT_MARKERS)?;
    Ok(new_content)
}
fn truncate(_lua: &Lua, (content, max_len, ellipsis): (String, usize, Option<String>)) -> mlua::Result<String> {
let ellipsis = ellipsis.unwrap_or_default();
match truncate_with_ellipsis(&content, max_len, &ellipsis) {
Cow::Borrowed(txt) => Ok(txt.to_string()),
Cow::Owned(txt) => Ok(txt),
}
}
/// Lua binding for `utils.text.split_first(content, sep)`.
///
/// Splits `content` at the first occurrence of `sep` and returns two values: the text
/// before the separator and the text after it. When `sep` does not occur, the original
/// string and `nil` are returned.
///
/// # Errors
///
/// Fails when either argument is not valid UTF-8 or a Lua string cannot be created.
fn split_first(lua: &Lua, (content, sep): (LuaString, LuaString)) -> mlua::Result<MultiValue> {
    let haystack = content.to_str()?;
    let needle = sep.to_str()?;

    match haystack.split_once(&*needle) {
        Some((head, tail)) => {
            let head = Value::String(lua.create_string(head)?);
            let tail = Value::String(lua.create_string(tail)?);
            Ok(MultiValue::from_vec(vec![head, tail]))
        }
        // No separator: hand back the original Lua string untouched, plus nil.
        None => Ok(MultiValue::from_vec(vec![Value::String(content), Value::Nil])),
    }
}
/// Lua binding for `utils.text.trim(content)`.
///
/// Strips leading and trailing whitespace. When nothing is stripped, the original Lua
/// string is returned as-is instead of creating a new one.
fn trim(lua: &Lua, content: LuaString) -> mlua::Result<Value> {
    let whole = content.to_str()?;
    let stripped = whole.trim();

    // A length change is the only way trimming can have altered the text.
    if stripped.len() != whole.len() {
        return lua.create_string(stripped).map(Value::String);
    }

    Ok(Value::String(content))
}
/// Lua binding for `utils.text.trim_end(content)`.
///
/// Strips trailing whitespace. When nothing is stripped, the original Lua string is
/// returned as-is instead of creating a new one.
fn trim_end(lua: &Lua, content: LuaString) -> mlua::Result<Value> {
    let whole = content.to_str()?;
    let stripped = whole.trim_end();

    // A length change is the only way trimming can have altered the text.
    if stripped.len() != whole.len() {
        return lua.create_string(stripped).map(Value::String);
    }

    Ok(Value::String(content))
}
/// Lua binding for `utils.text.trim_start(content)`.
///
/// Strips leading whitespace. When nothing is stripped, the original Lua string is
/// returned as-is instead of creating a new one.
fn trim_start(lua: &Lua, content: LuaString) -> mlua::Result<Value> {
    let whole = content.to_str()?;
    let stripped = whole.trim_start();

    // A length change is the only way trimming can have altered the text.
    if stripped.len() != whole.len() {
        return lua.create_string(stripped).map(Value::String);
    }

    Ok(Value::String(content))
}
/// Lua binding for `utils.text.remove_first_line(content)`.
///
/// Returns `content` without its first line, as computed by
/// `remove_first_lines_impl` with a count of one.
fn remove_first_line(_lua: &Lua, content: String) -> mlua::Result<String> {
    let remainder = remove_first_lines_impl(&content, 1);
    Ok(String::from(remainder))
}
/// Lua binding for `utils.text.remove_first_lines(content, num_of_lines)`.
///
/// Returns `content` without its first `num_of_lines` lines.
///
/// NOTE(review): the count is converted with a plain `as` cast, so a negative value
/// turns into a very large count on typical 64-bit targets and yields an empty
/// result. Confirm whether that is the intended behaviour.
fn remove_first_lines(_lua: &Lua, (content, num_of_lines): (String, i64)) -> mlua::Result<String> {
    let count = num_of_lines as usize;
    let remainder = remove_first_lines_impl(&content, count);
    Ok(String::from(remainder))
}
/// Returns `content` with its first `num_of_lines` lines removed.
///
/// Lines are delimited by `'\n'`. A count of zero returns `content` unchanged. When
/// `content` holds fewer than `num_of_lines` newline characters, the empty string is
/// returned.
fn remove_first_lines_impl(content: &str, num_of_lines: usize) -> &str {
    if num_of_lines == 0 {
        return content;
    }

    // The remainder starts right after the `num_of_lines`-th newline, if there is one.
    match content.match_indices('\n').nth(num_of_lines - 1) {
        Some((newline_pos, _)) => &content[newline_pos + 1..],
        None => "",
    }
}
/// Lua binding for `utils.text.remove_last_line(content)`.
///
/// Returns `content` without its last line, as computed by
/// `remove_last_lines_impl` with a count of one.
fn remove_last_line(_lua: &Lua, content: String) -> mlua::Result<String> {
    let remainder = remove_last_lines_impl(&content, 1);
    Ok(String::from(remainder))
}
/// Lua binding for `utils.text.remove_last_lines(content, num_of_lines)`.
///
/// Returns `content` without its last `num_of_lines` lines.
///
/// NOTE(review): the count is converted with a plain `as` cast, so a negative value
/// turns into a very large count on typical 64-bit targets and yields an empty
/// result. Confirm whether that is the intended behaviour.
fn remove_last_lines(_lua: &Lua, (content, num_of_lines): (String, i64)) -> mlua::Result<String> {
    let count = num_of_lines as usize;
    let remainder = remove_last_lines_impl(&content, count);
    Ok(String::from(remainder))
}
/// Returns `content` with its last `num_of_lines` lines removed.
///
/// Lines are delimited by `'\n'`, counted from the end. The result stops just before
/// the `num_of_lines`-th newline from the end, so that newline is dropped as well. A
/// count of zero returns `content` unchanged. When `content` holds fewer than
/// `num_of_lines` newline characters, the empty string is returned.
fn remove_last_lines_impl(content: &str, num_of_lines: usize) -> &str {
    if num_of_lines == 0 {
        return content;
    }

    // Walk the newlines backwards and cut at the `num_of_lines`-th one, if there is one.
    match content.rmatch_indices('\n').nth(num_of_lines - 1) {
        Some((newline_pos, _)) => &content[..newline_pos],
        None => "",
    }
}
/// Lua binding for `utils.text.escape_decode_if_needed(content)`.
///
/// Decodes HTML entities only when the content appears to be HTML-escaped. The check
/// is a cheap heuristic: the content is decoded only if it contains the `&lt;` entity;
/// otherwise it is returned unchanged.
fn escape_decode_if_needed(lua: &Lua, content: String) -> mlua::Result<String> {
    // Probe for the escaped form. Looking for a literal `<` would skip exactly the
    // escaped inputs this function exists to decode.
    if !content.contains("&lt;") {
        Ok(content)
    } else {
        escape_decode(lua, content)
    }
}
/// Lua binding for `utils.text.escape_decode(content)`.
///
/// Passes the content through `decode_html_entities` and returns the result.
fn escape_decode(_lua: &Lua, content: String) -> mlua::Result<String> {
    let decoded = decode_html_entities(&content);
    Ok(decoded)
}
/// Lua binding for `utils.text.extract_line_blocks(content, options)`.
///
/// Recognised `options` keys:
/// - `starts_with` (required string): handed to `LineBlockIter` as the line prefix.
/// - `extrude` (optional): when equal to `"content"`, the extruded content is returned
///   as the second value; any other value (or none) yields `nil` there.
/// - `first` (optional integer): stop after collecting that many blocks.
///
/// Returns two values: a Lua list of the blocks, and the extruded content or `nil`.
///
/// # Errors
///
/// Fails when `starts_with` is missing or unreadable, or when a Lua value cannot be
/// created.
fn extract_line_blocks(lua: &Lua, (content, options): (String, Table)) -> mlua::Result<MultiValue> {
    let Some(starts_with) = options.get::<Option<String>>("starts_with")? else {
        return Err(crate::Error::custom(
            r#"utils.text.extract_line_blocks requires to options with {starts_with = ".."} "#,
        )
        .into());
    };

    // Unreadable or absent optional settings are treated as "not set".
    let wants_extruded = options.get::<String>("extrude").ok().as_deref() == Some("content");
    let block_limit = options.get::<i64>("first").ok().map(|n| n as usize);

    let mut iterator = LineBlockIter::new(
        content.as_str(),
        LineBlockIterOptions {
            starts_with: &starts_with,
            extrude: wants_extruded.then_some(Extrude::Content),
        },
    );

    let (blocks, extruded_content) = match block_limit {
        None => iterator.collect_blocks_and_extruded_content(),
        Some(limit) => {
            // Pull at most `limit` blocks, stopping early when the iterator runs dry.
            let mut taken = Vec::new();
            while taken.len() < limit {
                let Some(block) = iterator.next() else { break };
                taken.push(block);
            }
            // Only ask for the remains when the caller wants the extruded content.
            let leftover = if wants_extruded {
                iterator.collect_remains().1
            } else {
                String::new()
            };
            (taken, leftover)
        }
    };

    let blocks_table = lua.create_table()?;
    for block in &blocks {
        blocks_table.push(block.as_str())?;
    }

    let extruded_value = match wants_extruded {
        true => Value::String(lua.create_string(&extruded_content)?),
        false => Value::Nil,
    };

    Ok(MultiValue::from_vec(vec![Value::Table(blocks_table), extruded_value]))
}
#[cfg(test)]
mod tests {
    // Lua-level tests for the `utils.text` functions registered by `init_module`.
    // Each test builds a Lua state with the module mounted as `text`, evaluates a small
    // Lua script, and checks the value returned by `eval_lua`.
    type Result<T> = core::result::Result<T, Box<dyn std::error::Error>>;
    use crate::_test_support::{assert_contains, eval_lua, setup_lua};
    use value_ext::JsonValueExt as _;

    /// `split_first` returns the parts around the first separator, or the whole content
    /// and nil when the separator is absent.
    #[tokio::test]
    async fn test_lua_text_split_first_simple() -> Result<()> {
        let lua = setup_lua(super::init_module, "text")?;
        // (content, separator, (expected first, expected second))
        let data = [
            (
                "some first content\n===\nsecond content",
                "===",
                ("some first content\n", Some("\nsecond content")),
            ),
            ("some first content\n", "===", ("some first content\n", None)),
            ("some first content\n===", "===", ("some first content\n", Some(""))),
        ];
        for (content, sep, expected) in data {
            let script = format!(
                r#"
local first, second = utils.text.split_first({content:?}, "{sep}")
return {{first, second}}
"#
            );
            let res = eval_lua(&lua, &script)?;
            let values = res.as_array().ok_or("Should have returned an array")?;
            let first = values
                .first()
                .ok_or("Should always have at least a first return")?
                .as_str()
                .ok_or("First should be string")?;
            assert_eq!(expected.0, first);
            let second = values.get(1);
            if let Some(exp_second) = expected.1 {
                let second = second.ok_or("Should have second")?;
                assert_eq!(exp_second, second)
            } else {
                // A nil second return leaves the returned array with a single element.
                assert!(second.is_none(), "Second should have been none");
            }
        }
        Ok(())
    }

    /// `ensure` adds the prefix and/or suffix only when the content lacks them.
    #[tokio::test]
    async fn test_lua_text_ensure_simple() -> Result<()> {
        let lua = setup_lua(super::init_module, "text")?;
        // (content, Lua options table, expected result)
        let data = [
            (
                "some- ! -path",
                r#"{prefix = "./", suffix = ".md"}"#,
                "./some- ! -path.md",
            ),
            ("some- ! -path", r#"{suffix = ".md"}"#, "some- ! -path.md"),
            (" ~ some- ! -path", r#"{prefix = " ~ "}"#, " ~ some- ! -path"),
            ("~ some- ! -path", r#"{prefix = " ~ "}"#, " ~ ~ some- ! -path"),
        ];
        for (content, arg, expected) in data {
            let script = format!("return utils.text.ensure(\"{content}\", {arg})");
            let res = eval_lua(&lua, &script)?;
            assert_eq!(res, expected);
        }
        Ok(())
    }

    /// Without `first`, all blocks are returned along with the extruded content.
    #[tokio::test]
    async fn test_lua_text_extract_line_blocks_simple() -> Result<()> {
        let lua = setup_lua(super::init_module, "text")?;
        let lua_code = r#"
local content = [[
> one
> two
Some line A
> 3
The end
]]
local a, b = utils.text.extract_line_blocks(content, { starts_with = ">", extrude = "content" })
return {blocks = a, extruded = b}
"#;
        let res = eval_lua(&lua, lua_code)?;
        let block = res.x_get_str("/blocks/0")?;
        assert_eq!(block, "> one\n> two\n");
        let block = res.x_get_str("/blocks/1")?;
        assert_eq!(block, "> 3\n");
        let content = res.x_get_str("/extruded")?;
        assert_contains(content, "Some line A");
        assert_contains(content, "The end");
        Ok(())
    }

    /// With `first = 2` and extrusion on, later blocks end up in the extruded content.
    #[tokio::test]
    async fn test_lua_text_extract_line_blocks_with_first_extrude() -> Result<()> {
        let lua = setup_lua(super::init_module, "text")?;
        let lua_code = r#"
local content = [[
> one
> two
line1
> three
line2
> four
line3
]]
local a, b = utils.text.extract_line_blocks(content, { starts_with = ">", extrude = "content", first = 2 })
return { blocks = a, extruded = b }
"#;
        let res = eval_lua(&lua, lua_code)?;
        let block1 = res.x_get_str("/blocks/0")?;
        assert_eq!(block1, "> one\n> two\n");
        let block2 = res.x_get_str("/blocks/1")?;
        assert_eq!(block2, "> three\n");
        let extruded = res.x_get_str("/extruded")?;
        assert_eq!(extruded, "line1\nline2\n> four\nline3\n");
        Ok(())
    }

    /// With `first = 2` and no extrusion, only two blocks come back and extruded is nil.
    #[tokio::test]
    async fn test_lua_text_extract_line_blocks_with_first_no_extrude() -> Result<()> {
        let lua = setup_lua(super::init_module, "text")?;
        let lua_code = r#"
local content = [[
> one
> two
line1
> three
line2
> four
line3
]]
local a, b = utils.text.extract_line_blocks(content, { starts_with = ">", first = 2 })
return { blocks = a, extruded = b }
"#;
        let res = eval_lua(&lua, lua_code)?;
        let blocks = res.x_get_as::<Vec<&str>>("blocks")?;
        assert_eq!(blocks.len(), 2, "should have only 2 blocks");
        let block1 = res.x_get_str("/blocks/0")?;
        assert_eq!(block1, "> one\n> two\n");
        let block2 = res.x_get_str("/blocks/1")?;
        assert_eq!(block2, "> three\n");
        let extruded = res.get("extruded");
        assert!(
            extruded.is_none(),
            "extruded should be nil when extrude option is not set"
        );
        Ok(())
    }
}