Skip to main content

mlua_batteries/
string.rs

//! Extended string operations (Unicode-aware).
//!
//! Complements Lua's built-in `string` library with operations that are
//! either missing or only ASCII-aware in standard Lua. Every length in this
//! module (padding widths, truncation limits, `char_count`) and the units used
//! by `chars` and `reverse` are Unicode scalar values (Rust `char`s). They are
//! not bytes, and not user-perceived grapheme clusters.
//!
//! ```lua
//! local str = std.string
//! str.trim("  hello  ")            --> "hello"
//! str.split("a,b,c", ",")          --> {"a", "b", "c"}
//! str.starts_with("hello", "he")   --> true
//! str.replace("abab", "ab", "x")   --> "xab"
//! str.replace_all("abab", "ab", "x") --> "xx"
//! str.upper("café")                --> "CAFÉ"
//! str.pad_start("42", 5, "0")      --> "00042"
//! ```
16
17use mlua::prelude::*;
18
19pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
20    let t = lua.create_table()?;
21
22    // ─── Trimming ─────────────────────────────────────────
23
24    t.set(
25        "trim",
26        lua.create_function(|_, s: String| Ok(s.trim().to_string()))?,
27    )?;
28
29    t.set(
30        "trim_start",
31        lua.create_function(|_, s: String| Ok(s.trim_start().to_string()))?,
32    )?;
33
34    t.set(
35        "trim_end",
36        lua.create_function(|_, s: String| Ok(s.trim_end().to_string()))?,
37    )?;
38
39    // ─── Split ────────────────────────────────────────────
40
41    t.set(
42        "split",
43        lua.create_function(|lua, (s, sep): (String, String)| {
44            if sep.is_empty() {
45                return Err(LuaError::external(
46                    "string.split: separator must not be empty",
47                ));
48            }
49            let table = lua.create_table()?;
50            for (i, part) in s.split(&*sep).enumerate() {
51                table.set(i + 1, part)?;
52            }
53            Ok(table)
54        })?,
55    )?;
56
57    // ─── Predicates ───────────────────────────────────────
58
59    t.set(
60        "starts_with",
61        lua.create_function(|_, (s, prefix): (String, String)| Ok(s.starts_with(&*prefix)))?,
62    )?;
63
64    t.set(
65        "ends_with",
66        lua.create_function(|_, (s, suffix): (String, String)| Ok(s.ends_with(&*suffix)))?,
67    )?;
68
69    t.set(
70        "contains",
71        lua.create_function(|_, (s, needle): (String, String)| Ok(s.contains(&*needle)))?,
72    )?;
73
74    // ─── Replace (non-regex) ──────────────────────────────
75
76    t.set(
77        "replace",
78        lua.create_function(|_, (s, from, to): (String, String, String)| {
79            Ok(s.replacen(&*from, &to, 1))
80        })?,
81    )?;
82
83    t.set(
84        "replace_all",
85        lua.create_function(|_, (s, from, to): (String, String, String)| {
86            Ok(s.replace(&*from, &to))
87        })?,
88    )?;
89
90    // ─── Padding ──────────────────────────────────────────
91
92    t.set(
93        "pad_start",
94        lua.create_function(|_, (s, width, fill): (String, usize, Option<String>)| {
95            let fill_char = parse_fill_char(&fill)?;
96            let char_count = s.chars().count();
97            if char_count >= width {
98                return Ok(s);
99            }
100            let padding: String = std::iter::repeat(fill_char)
101                .take(width - char_count)
102                .collect();
103            Ok(format!("{padding}{s}"))
104        })?,
105    )?;
106
107    t.set(
108        "pad_end",
109        lua.create_function(|_, (s, width, fill): (String, usize, Option<String>)| {
110            let fill_char = parse_fill_char(&fill)?;
111            let char_count = s.chars().count();
112            if char_count >= width {
113                return Ok(s);
114            }
115            let padding: String = std::iter::repeat(fill_char)
116                .take(width - char_count)
117                .collect();
118            Ok(format!("{s}{padding}"))
119        })?,
120    )?;
121
122    // ─── Truncate ─────────────────────────────────────────
123
124    t.set(
125        "truncate",
126        lua.create_function(|_, (s, max_len, suffix): (String, usize, Option<String>)| {
127            let suffix = suffix.unwrap_or_default();
128            let char_count = s.chars().count();
129            if char_count <= max_len {
130                return Ok(s);
131            }
132            let suffix_len = suffix.chars().count();
133            if max_len <= suffix_len {
134                return Ok(suffix.chars().take(max_len).collect());
135            }
136            let keep = max_len - suffix_len;
137            let truncated: String = s.chars().take(keep).collect();
138            Ok(format!("{truncated}{suffix}"))
139        })?,
140    )?;
141
142    // ─── Unicode-aware case conversion ────────────────────
143
144    t.set(
145        "upper",
146        lua.create_function(|_, s: String| Ok(s.to_uppercase()))?,
147    )?;
148
149    t.set(
150        "lower",
151        lua.create_function(|_, s: String| Ok(s.to_lowercase()))?,
152    )?;
153
154    // ─── Unicode utilities ────────────────────────────────
155
156    t.set(
157        "chars",
158        lua.create_function(|lua, s: String| {
159            let table = lua.create_table()?;
160            for (i, ch) in s.chars().enumerate() {
161                let mut buf = [0u8; 4];
162                table.set(i + 1, &*ch.encode_utf8(&mut buf))?;
163            }
164            Ok(table)
165        })?,
166    )?;
167
168    t.set(
169        "char_count",
170        lua.create_function(|_, s: String| Ok(s.chars().count()))?,
171    )?;
172
173    t.set(
174        "reverse",
175        lua.create_function(|_, s: String| Ok(s.chars().rev().collect::<String>()))?,
176    )?;
177
178    Ok(t)
179}
180
181fn parse_fill_char(fill: &Option<String>) -> LuaResult<char> {
182    match fill {
183        None => Ok(' '),
184        Some(s) => {
185            let mut chars = s.chars();
186            match (chars.next(), chars.next()) {
187                (Some(c), None) => Ok(c),
188                _ => Err(LuaError::external("fill must be a single character")),
189            }
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use crate::util::test_eval as eval;

    /// Runs `code` in a fresh interpreter with the batteries registered under
    /// `std`. Returns only success or failure, for tests that expect an error.
    fn exec(code: &str) -> mlua::Result<()> {
        let lua = mlua::Lua::new();
        crate::register_all(&lua, "std").unwrap();
        lua.load(code).exec()
    }

    // ─── trim ─────────────────────────────────────────────

    #[test]
    fn trim_whitespace() {
        let s: String = eval(r#"return std.string.trim("  hello  ")"#);
        assert_eq!(s, "hello");
    }

    #[test]
    fn trim_start_whitespace() {
        let s: String = eval(r#"return std.string.trim_start("  hello  ")"#);
        assert_eq!(s, "hello  ");
    }

    #[test]
    fn trim_end_whitespace() {
        let s: String = eval(r#"return std.string.trim_end("  hello  ")"#);
        assert_eq!(s, "  hello");
    }

    #[test]
    fn trim_empty_string() {
        let s: String = eval(r#"return std.string.trim("")"#);
        assert_eq!(s, "");
    }

    #[test]
    fn trim_no_whitespace() {
        let s: String = eval(r#"return std.string.trim("hello")"#);
        assert_eq!(s, "hello");
    }

    #[test]
    fn trim_unicode_whitespace() {
        // U+00A0 NO-BREAK SPACE is Unicode whitespace but not ASCII whitespace.
        let s: String = eval("return std.string.trim(\"\u{a0}hi\u{a0}\")");
        assert_eq!(s, "hi");
    }

    // ─── split ────────────────────────────────────────────

    #[test]
    fn split_by_comma() {
        let s: String = eval(
            r#"
            local parts = std.string.split("a,b,c", ",")
            return parts[1] .. "|" .. parts[2] .. "|" .. parts[3]
        "#,
        );
        assert_eq!(s, "a|b|c");
    }

    #[test]
    fn split_no_match() {
        let n: i64 = eval(
            r#"
            local parts = std.string.split("hello", ",")
            return #parts
        "#,
        );
        assert_eq!(n, 1);
    }

    #[test]
    fn split_empty_parts() {
        let n: i64 = eval(
            r#"
            local parts = std.string.split(",a,,b,", ",")
            return #parts
        "#,
        );
        assert_eq!(n, 5);
    }

    #[test]
    fn split_empty_input_yields_one_empty_part() {
        let s: String = eval(
            r#"
            local parts = std.string.split("", ",")
            return #parts .. ":" .. parts[1]
        "#,
        );
        assert_eq!(s, "1:");
    }

    #[test]
    fn split_empty_separator_returns_error() {
        assert!(exec(r#"return std.string.split("abc", "")"#).is_err());
    }

    #[test]
    fn split_multi_char_separator() {
        let s: String = eval(
            r#"
            local parts = std.string.split("a::b::c", "::")
            return parts[1] .. "|" .. parts[2] .. "|" .. parts[3]
        "#,
        );
        assert_eq!(s, "a|b|c");
    }

    // ─── predicates ───────────────────────────────────────

    #[test]
    fn starts_with_true() {
        let b: bool = eval(r#"return std.string.starts_with("hello world", "hello")"#);
        assert!(b);
    }

    #[test]
    fn starts_with_false() {
        let b: bool = eval(r#"return std.string.starts_with("hello world", "world")"#);
        assert!(!b);
    }

    #[test]
    fn ends_with_true() {
        let b: bool = eval(r#"return std.string.ends_with("hello world", "world")"#);
        assert!(b);
    }

    #[test]
    fn ends_with_false() {
        let b: bool = eval(r#"return std.string.ends_with("hello world", "hello")"#);
        assert!(!b);
    }

    #[test]
    fn contains_true() {
        let b: bool = eval(r#"return std.string.contains("hello world", "lo wo")"#);
        assert!(b);
    }

    #[test]
    fn contains_false() {
        let b: bool = eval(r#"return std.string.contains("hello world", "xyz")"#);
        assert!(!b);
    }

    // ─── replace ──────────────────────────────────────────

    #[test]
    fn replace_first_only() {
        let s: String = eval(r#"return std.string.replace("abab", "ab", "x")"#);
        assert_eq!(s, "xab");
    }

    #[test]
    fn replace_all_occurrences() {
        let s: String = eval(r#"return std.string.replace_all("abab", "ab", "x")"#);
        assert_eq!(s, "xx");
    }

    #[test]
    fn replace_no_match() {
        let s: String = eval(r#"return std.string.replace("hello", "xyz", "!")"#);
        assert_eq!(s, "hello");
    }

    // ─── pad ──────────────────────────────────────────────

    #[test]
    fn pad_start_with_zeros() {
        let s: String = eval(r#"return std.string.pad_start("42", 5, "0")"#);
        assert_eq!(s, "00042");
    }

    #[test]
    fn pad_start_default_space() {
        let s: String = eval(r#"return std.string.pad_start("hi", 5)"#);
        assert_eq!(s, "   hi");
    }

    #[test]
    fn pad_start_already_long() {
        let s: String = eval(r#"return std.string.pad_start("hello", 3, "x")"#);
        assert_eq!(s, "hello");
    }

    #[test]
    fn pad_start_multibyte_fill() {
        let s: String = eval(r#"return std.string.pad_start("7", 3, "é")"#);
        assert_eq!(s, "éé7");
    }

    #[test]
    fn pad_end_with_dots() {
        let s: String = eval(r#"return std.string.pad_end("hi", 5, ".")"#);
        assert_eq!(s, "hi...");
    }

    #[test]
    fn pad_end_default_space() {
        let s: String = eval(r#"return std.string.pad_end("hi", 5)"#);
        assert_eq!(s, "hi   ");
    }

    #[test]
    fn pad_end_already_long() {
        let s: String = eval(r#"return std.string.pad_end("hello", 3)"#);
        assert_eq!(s, "hello");
    }

    #[test]
    fn pad_counts_chars_not_bytes() {
        // "é" is two UTF-8 bytes but one character, so two dots are added.
        let s: String = eval(r#"return std.string.pad_end("é", 3, ".")"#);
        assert_eq!(s, "é..");
    }

    #[test]
    fn pad_fill_multi_char_returns_error() {
        assert!(exec(r#"return std.string.pad_start("x", 5, "ab")"#).is_err());
    }

    #[test]
    fn pad_fill_empty_returns_error() {
        assert!(exec(r#"return std.string.pad_end("x", 5, "")"#).is_err());
    }

    // ─── truncate ─────────────────────────────────────────

    #[test]
    fn truncate_with_ellipsis() {
        let s: String = eval(r#"return std.string.truncate("hello world", 8, "...")"#);
        assert_eq!(s, "hello...");
    }

    #[test]
    fn truncate_no_suffix() {
        let s: String = eval(r#"return std.string.truncate("hello world", 5)"#);
        assert_eq!(s, "hello");
    }

    #[test]
    fn truncate_already_short() {
        let s: String = eval(r#"return std.string.truncate("hi", 10, "...")"#);
        assert_eq!(s, "hi");
    }

    #[test]
    fn truncate_max_equals_suffix_len() {
        let s: String = eval(r#"return std.string.truncate("hello", 3, "...")"#);
        assert_eq!(s, "...");
    }

    #[test]
    fn truncate_max_less_than_suffix_len() {
        let s: String = eval(r#"return std.string.truncate("hello", 2, "...")"#);
        assert_eq!(s, "..");
    }

    #[test]
    fn truncate_counts_chars_not_bytes() {
        let s: String = eval(r#"return std.string.truncate("héllo wörld", 7, "…")"#);
        assert_eq!(s, "héllo …");
    }

    // ─── Unicode-aware case conversion ────────────────────

    #[test]
    fn upper_unicode() {
        let s: String = eval(r#"return std.string.upper("café")"#);
        assert_eq!(s, "CAFÉ");
    }

    #[test]
    fn lower_unicode() {
        let s: String = eval(r#"return std.string.lower("CAFÉ")"#);
        assert_eq!(s, "café");
    }

    #[test]
    fn upper_german_eszett() {
        let s: String = eval(r#"return std.string.upper("straße")"#);
        assert_eq!(s, "STRASSE");
    }

    // ─── Unicode utilities ────────────────────────────────

    #[test]
    fn chars_ascii() {
        let s: String = eval(
            r#"
            local cs = std.string.chars("abc")
            return cs[1] .. cs[2] .. cs[3]
        "#,
        );
        assert_eq!(s, "abc");
    }

    #[test]
    fn chars_multibyte() {
        let n: i64 = eval(r#"return #std.string.chars("café")"#);
        assert_eq!(n, 4);
    }

    #[test]
    fn chars_multibyte_element_is_whole_char() {
        let s: String = eval(r#"return std.string.chars("café")[4]"#);
        assert_eq!(s, "é");
    }

    #[test]
    fn char_count_ascii() {
        let n: i64 = eval(r#"return std.string.char_count("hello")"#);
        assert_eq!(n, 5);
    }

    #[test]
    fn char_count_multibyte() {
        let n: i64 = eval(r#"return std.string.char_count("café")"#);
        assert_eq!(n, 4);
    }

    #[test]
    fn char_count_emoji() {
        let n: i64 = eval(r#"return std.string.char_count("👋🌍")"#);
        assert_eq!(n, 2);
    }

    #[test]
    fn reverse_ascii() {
        let s: String = eval(r#"return std.string.reverse("hello")"#);
        assert_eq!(s, "olleh");
    }

    #[test]
    fn reverse_unicode() {
        let s: String = eval(r#"return std.string.reverse("café")"#);
        assert_eq!(s, "éfac");
    }

    #[test]
    fn reverse_emoji() {
        let s: String = eval(r#"return std.string.reverse("👋🌍")"#);
        assert_eq!(s, "🌍👋");
    }

    #[test]
    fn reverse_empty() {
        let s: String = eval(r#"return std.string.reverse("")"#);
        assert_eq!(s, "");
    }
}