Skip to main content

lua_stdlib/
utf8_lib.rs

1//! UTF-8 standard library for Lua 5.4.
2//!
3//! Port of `lutf8lib.c` (291 lines, 9 functions).
4//!
5//! Provides the `utf8` module with `char`, `codepoint`, `codes`, `len`,
6//! `offset`, and `charpattern`. Supports both strict (Unicode-conformant)
7//! and lax (extended UTF-8, up to `MAX_UTF = 0x7FFFFFFF`) decoding modes.
8//!
9//! Strict mode rejects surrogates (U+D800..U+DFFF) and values above U+10FFFF.
10//! Lax mode accepts any well-formed byte sequence with a value ≤ MAX_UTF.
11
12use lua_types::error::LuaError;
13use lua_types::value::LuaValue;
14use crate::state_stub::{LuaState, LuaStateStubExt as _};
15
/// Largest Unicode scalar value; strict decoding rejects anything above it.
const MAX_UNICODE: u32 = 0x0010_FFFF;

/// Largest value expressible in extended (lax) UTF-8: all 31 low bits set.
const MAX_UTF: u32 = 0x7FFF_FFFF;

/// Integer type used for decoded code points. `MAX_UTF` needs 31 bits, so a
/// `u32` is wide enough on every target.
type UtfInt = u32;

/// Lua pattern matching exactly one UTF-8 byte sequence (`utf8.charpattern`).
///
/// The literal is 14 bytes long and embeds a NUL byte as the lower bound of
/// the first character class.
const UTF8_PATT: &[u8] = b"[\x00-\x7F\xC2-\xFD][\x80-\xBF]*";
25
26// ── Internal helpers ───────────────────────────────────────────────────────
27
/// Normalises a Lua string position against a string of `len` bytes.
///
/// Non-negative positions are returned as-is. A negative position `-k`
/// counts back from the end (`-1` is the last byte) and maps to
/// `len - k + 1`; if it reaches before the start of the string, `0` is
/// returned instead.
fn pos_relat(pos: i64, len: usize) -> i64 {
    if pos < 0 {
        // `unsigned_abs` yields the magnitude without overflowing on
        // `i64::MIN` (the C code uses `0u - (size_t)pos` for the same purpose).
        let back = pos.unsigned_abs();
        if back <= len as u64 {
            len as i64 + pos + 1
        } else {
            0
        }
    } else {
        pos
    }
}
43
/// Tests whether `c` is a UTF-8 continuation byte, i.e. has the bit pattern
/// `10xxxxxx` (0x80..=0xBF).
#[inline]
fn is_cont(c: u8) -> bool {
    // The top two bits must be exactly `10`.
    c >> 6 == 0b10
}
50
51/// Return `true` if the byte at 0-based index `pos` in `s` is a continuation
52/// byte, treating out-of-bounds positions as non-continuation.
53///
54/// C strings carry a NUL terminator that is never a continuation byte;
55/// the bounds-check here replaces that guarantee.
56#[inline]
57fn is_cont_at(s: &[u8], pos: i64) -> bool {
58    if pos < 0 {
59        return false;
60    }
61    s.get(pos as usize).map_or(false, |&b| is_cont(b))
62}
63
64/// Decode one UTF-8 sequence from the start of `s`.
65///
66/// Returns `None` if the byte sequence is invalid.
67/// Returns `Some((remaining_slice, codepoint))` on success.
68///
69/// When `strict` is `true`, surrogates and values above `MAX_UNICODE` are
70/// rejected. When `false`, any value ≤ `MAX_UTF` is accepted (extended UTF-8).
71///
72fn utf8_decode(s: &[u8], strict: bool) -> Option<(&[u8], UtfInt)> {
73    // LIMITS[count] is the minimum value for a sequence with `count` continuation bytes.
74    // LIMITS[0] = u32::MAX forces an error when a non-ASCII byte has no continuation bytes.
75    const LIMITS: [UtfInt; 6] = [u32::MAX, 0x80, 0x800, 0x10000, 0x200000, 0x4000000];
76
77    if s.is_empty() {
78        return None;
79    }
80
81    let mut c = s[0] as u32;
82    let res: UtfInt;
83    let advance: usize;
84
85    if c < 0x80 {
86        // ASCII fast path — no continuation bytes needed.
87        res = c;
88        advance = 1;
89    } else {
90        let mut count: usize = 0;
91        let mut r: UtfInt = 0;
92
93        // The C for-loop runs the body first, then applies `c <<= 1` as the update.
94        while c & 0x40 != 0 {
95            count += 1;
96            if count >= s.len() {
97                return None; // string too short for the indicated sequence length
98            }
99            let cc = s[count] as u32;
100
101            if (cc & 0xC0) != 0x80 {
102                return None; // expected continuation byte, got something else
103            }
104
105            r = (r << 6) | (cc & 0x3F);
106
107            // C for-loop update: c <<= 1
108            c <<= 1;
109        }
110
111        r |= (c & 0x7F) << (count as u32 * 5);
112
113        if count > 5 || r > MAX_UTF || r < LIMITS[count] {
114            return None; // invalid (overlong, too large, or excess continuation bytes)
115        }
116
117        res = r;
118        advance = count + 1;
119        if advance > s.len() {
120            return None;
121        }
122    }
123
124    if strict && (res > MAX_UNICODE || (0xD800 <= res && res <= 0xDFFF)) {
125        return None; // surrogate or out-of-Unicode-range value in strict mode
126    }
127
128    Some((&s[advance..], res))
129}
130
131/// Encode a codepoint (≤ `MAX_UTF`) as extended UTF-8 bytes.
132///
133/// Mirrors `luaO_utf8esc` from `lobject.c`, which fills a fixed buffer backwards.
134/// This Rust version builds the bytes naturally and returns a `Vec<u8>`.
135///
136fn encode_utf8_codepoint(code: u32) -> Vec<u8> {
137    debug_assert!(code <= MAX_UTF);
138
139    if code < 0x80 {
140        return vec![code as u8];
141    }
142
143    let mut x = code;
144    let mut mfb: u32 = 0x3F;
145    // Continuation bytes built in reverse, then reversed at the end.
146    let mut bytes_rev: Vec<u8> = Vec::with_capacity(6);
147
148    //    while (x > mfb);
149    loop {
150        bytes_rev.push(0x80 | (x & 0x3F) as u8);
151        x >>= 6;
152        mfb >>= 1;
153        if x <= mfb {
154            break;
155        }
156    }
157
158    // wrapping_shl avoids a Rust debug-mode overflow panic on `!mfb << 1`
159    // (e.g., !0x1Fu32 = 0xFFFF_FFE0; << 1 = 0xFFFF_FFC0; as u8 = 0xC0).
160    let leading = ((!mfb).wrapping_shl(1) as u8) | (x as u8);
161
162    let mut result = Vec::with_capacity(bytes_rev.len() + 1);
163    result.push(leading);
164    for &b in bytes_rev.iter().rev() {
165        result.push(b);
166    }
167    result
168}
169
170// ── Library functions ──────────────────────────────────────────────────────
171
172/// `utf8.len(s [, i [, j [, lax]]])` → integer | (nil, integer)
173///
174/// Returns the number of UTF-8 characters that start in the byte range `[i,j]`
175/// of string `s` (1-based, defaulting to the whole string).
176/// On a malformed sequence, returns `(nil, position)` where `position` is the
177/// 1-based byte offset of the first bad byte.
178///
179fn utf_len(state: &mut LuaState) -> Result<usize, LuaError> {
180    // Clone to avoid holding a borrow across subsequent mutable state calls.
181    let s: Vec<u8> = state.check_arg_string(1)?.to_vec();
182    let len = s.len();
183
184    let raw_posi: i64 = state.opt_arg_integer(2, 1)?;
185    let mut posi: i64 = pos_relat(raw_posi, len);
186
187    let raw_posj: i64 = state.opt_arg_integer(3, -1)?;
188    let mut posj: i64 = pos_relat(raw_posj, len);
189
190    let lax: bool = state.to_boolean(4);
191
192    let is_v53 = state.global().lua_version == lua_types::LuaVersion::V53;
193    let initial_msg: &[u8] = if is_v53 {
194        b"initial position out of string"
195    } else {
196        b"initial position out of bounds"
197    };
198    let final_msg: &[u8] = if is_v53 {
199        b"final position out of string"
200    } else {
201        b"final position out of bounds"
202    };
203
204    // Note: C short-circuits, so --posi only executes when 1 <= posi.
205    if posi < 1 {
206        return Err(lua_vm::debug::arg_error_impl(state, 2, initial_msg));
207    }
208    posi -= 1; // 1-based → 0-based
209    if posi > len as i64 {
210        return Err(lua_vm::debug::arg_error_impl(state, 2, initial_msg));
211    }
212
213    posj -= 1; // 1-based → 0-based (always decremented, no short-circuit)
214    if posj >= len as i64 {
215        return Err(lua_vm::debug::arg_error_impl(state, 3, final_msg));
216    }
217
218    let mut n: i64 = 0;
219
220    while posi <= posj {
221        match utf8_decode(&s[posi as usize..], !lax) {
222            None => {
223                state.push(LuaValue::Nil); // luaL_pushfail
224                state.push(LuaValue::Int(posi + 1)); // 1-based position of failure
225                return Ok(2);
226            }
227            Some((remaining, _)) => {
228                posi = (len - remaining.len()) as i64;
229                n += 1;
230            }
231        }
232    }
233
234    state.push(LuaValue::Int(n));
235    Ok(1)
236}
237
238/// `utf8.codepoint(s [, i [, j [, lax]]])` → integer, ...
239///
240/// Returns the codepoints (as integers) for all characters starting in `s[i..j]`.
241///
242fn codepoint(state: &mut LuaState) -> Result<usize, LuaError> {
243    let s: Vec<u8> = state.check_arg_string(1)?.to_vec();
244    let len = s.len();
245
246    let raw_posi: i64 = state.opt_arg_integer(2, 1)?;
247    let posi: i64 = pos_relat(raw_posi, len);
248
249    // Default for the end position is posi (1-based), giving a single character.
250    let raw_pose: i64 = state.opt_arg_integer(3, posi)?;
251    let pose: i64 = pos_relat(raw_pose, len);
252
253    let lax: bool = state.to_boolean(4);
254
255    let bounds_msg: &[u8] = if state.global().lua_version == lua_types::LuaVersion::V53 {
256        b"out of range"
257    } else {
258        b"out of bounds"
259    };
260    if posi < 1 {
261        return Err(lua_vm::debug::arg_error_impl(state, 2, bounds_msg));
262    }
263
264    if pose > len as i64 {
265        return Err(lua_vm::debug::arg_error_impl(state, 3, bounds_msg));
266    }
267
268    if posi > pose {
269        return Ok(0); // empty interval: no values
270    }
271
272    if pose - posi >= i32::MAX as i64 {
273        return Err(LuaError::runtime(format_args!("string slice too long")));
274    }
275
276    let n_max = (pose - posi + 1) as i32;
277    state.ensure_stack(n_max, "string slice too long")?;
278
279    // 0-based: start at (posi - 1), stop before byte index `pose`.
280    let mut pos: usize = (posi - 1) as usize; // 0-based start
281    let end: usize = pose as usize; // 0-based exclusive end
282    let mut count: usize = 0;
283
284    while pos < end {
285        match utf8_decode(&s[pos..], !lax) {
286            None => return Err(LuaError::runtime(format_args!("invalid UTF-8 code"))),
287            Some((remaining, code)) => {
288                state.push(LuaValue::Int(code as i64));
289                count += 1;
290                pos = len - remaining.len(); // advance by decoded character width
291            }
292        }
293    }
294
295    Ok(count)
296}
297
298/// Encode the codepoint at stack argument `arg` and return the UTF-8 bytes.
299///
300/// `Vec<u8>` directly rather than pushing to the stack, avoiding the push/pop
301/// dance that `luaL_Buffer` required.
302///
303/// PORT NOTE: C's `pushutfchar` called `lua_pushfstring(L, "%U", code)` to encode
304/// and push in one step. Here the encoding is extracted so `utf_char` can build
305/// the concatenated result without intermediate stack operations.
306fn get_utf_char_bytes(state: &mut LuaState, arg: i32) -> Result<Vec<u8>, LuaError> {
307    let code = state.check_arg_integer(arg)? as u64;
308
309    let max_code: u64 = if state.global().lua_version == lua_types::LuaVersion::V53 {
310        0x10FFFF
311    } else {
312        MAX_UTF as u64
313    };
314    if code > max_code {
315        return crate::auxlib::arg_error(state, arg, b"value out of range").map(|_| Vec::new());
316    }
317
318    Ok(encode_utf8_codepoint(code as u32))
319}
320
321/// `utf8.char(n1, n2, ...)` → string
322///
323/// Returns a string formed by the UTF-8 encoding of the given codepoints.
324///
325fn utf_char(state: &mut LuaState) -> Result<usize, LuaError> {
326    let n: i32 = state.stack_top() as i32;
327
328    if n == 1 {
329        let bytes = get_utf_char_bytes(state, 1)?;
330        let s = state.intern_str(&bytes)?;
331        state.push(LuaValue::Str(s));
332    } else {
333        //    for (i = 1; i <= n; i++) { pushutfchar(L, i); luaL_addvalue(&b); }
334        //    luaL_pushresult(&b);
335        // PORT NOTE: luaL_Buffer replaced by Vec<u8>; codepoints are encoded
336        // directly into the accumulator without intermediate stack push/pop.
337        let mut buf: Vec<u8> = Vec::new();
338        for i in 1..=n {
339            buf.extend_from_slice(&get_utf_char_bytes(state, i)?);
340        }
341        let s = state.intern_str(&buf)?;
342        state.push(LuaValue::Str(s));
343    }
344
345    Ok(1)
346}
347
348/// `utf8.offset(s, n [, i])` → integer | nil
349///
350/// Returns the byte offset where the n-th character (counting from position `i`)
351/// starts. Negative `n` counts from the end. `n == 0` returns the start of the
352/// character that contains position `i`.
353/// Returns `nil` if the character cannot be found.
354///
355fn byte_offset(state: &mut LuaState) -> Result<usize, LuaError> {
356    let s: Vec<u8> = state.check_arg_string(1)?.to_vec();
357    let len = s.len();
358
359    let n: i64 = state.check_arg_integer(2)?;
360
361    let default_posi: i64 = if n >= 0 { 1 } else { len as i64 + 1 };
362
363    let raw_posi: i64 = state.opt_arg_integer(3, default_posi)?;
364    let posi_1based: i64 = pos_relat(raw_posi, len);
365
366    let pos_msg: &[u8] = if state.global().lua_version == lua_types::LuaVersion::V53 {
367        b"position out of range"
368    } else {
369        b"position out of bounds"
370    };
371    if posi_1based < 1 {
372        return Err(lua_vm::debug::arg_error_impl(state, 3, pos_msg));
373    }
374    let mut posi: i64 = posi_1based - 1; // 1-based → 0-based
375    if posi > len as i64 {
376        return Err(lua_vm::debug::arg_error_impl(state, 3, pos_msg));
377    }
378
379    // `count` is a mutable copy of `n`; driven to 0 when the target character is found.
380    let mut count = n;
381
382    if count == 0 {
383        // Scan backward to find the start of the character containing `posi`.
384        while posi > 0 && is_cont_at(&s, posi) {
385            posi -= 1;
386        }
387        // count remains 0
388    } else {
389        if is_cont_at(&s, posi) {
390            return Err(LuaError::runtime(format_args!(
391                "initial position is a continuation byte"
392            )));
393        }
394
395        if count < 0 {
396            //      do { posi--; } while (posi > 0 && iscontp(s + posi));
397            //      n++;
398            //    }
399            while count < 0 && posi > 0 {
400                // do-while: always decrements at least once, then skips back over
401                // any continuation bytes to land on a leading byte.
402                loop {
403                    posi -= 1;
404                    if posi == 0 || !is_cont_at(&s, posi) {
405                        break;
406                    }
407                }
408                count += 1;
409            }
410        } else {
411            //    while (n > 0 && posi < (lua_Integer)len) {
412            //      do { posi++; } while (iscontp(s + posi));  /* cannot pass '\0' */
413            //      n--;
414            //    }
415            count -= 1; // do not move for the 1st character
416            while count > 0 && posi < len as i64 {
417                // C relies on the NUL terminator to stop the inner do-while.
418                // Rust uses an explicit bounds check instead.
419                loop {
420                    posi += 1;
421                    if !is_cont_at(&s, posi) {
422                        break;
423                    }
424                }
425                count -= 1;
426            }
427        }
428    }
429
430    if count != 0 {
431        state.push(LuaValue::Nil); // luaL_pushfail: character not found
432        return Ok(1);
433    }
434
435    state.push(LuaValue::Int(posi + 1)); // 0-based → 1-based (initial position)
436
437    // Lua 5.5 additionally returns the byte position where the character ends
438    // (inclusive). 5.3/5.4 return only the start.
439    if state.global().lua_version != lua_types::LuaVersion::V55 {
440        return Ok(1);
441    }
442
443    // Multi-byte character? (high bit set on the leading byte)
444    if s.get(posi as usize).is_some_and(|&b| b & 0x80 != 0) {
445        // A continuation byte at the start means the position is mid-character;
446        // mirror the C guard. (Practically unreachable on the success branch.)
447        if is_cont_at(&s, posi) {
448            return Err(LuaError::runtime(format_args!(
449                "initial position is a continuation byte"
450            )));
451        }
452        // Skip forward over trailing continuation bytes to land on the last
453        // byte of this character.
454        while is_cont_at(&s, posi + 1) {
455            posi += 1;
456        }
457    }
458    // One-byte character: final position equals the initial position.
459    state.push(LuaValue::Int(posi + 1)); // 0-based → 1-based (final position)
460    Ok(2)
461}
462
463/// Internal iterator body shared by `iter_aux_strict` and `iter_aux_lax`.
464///
465/// Stack on entry (from the generic for): (1) string, (2) current byte position
466/// (0-based; initially pushed as 0 by `iter_codes`).
467///
468/// Advances past any leading continuation bytes, decodes the next character,
469/// and returns `(next_1based_pos, codepoint)`.  Returns nothing (0) when the
470/// string is exhausted.
471///
472fn iter_aux(state: &mut LuaState, strict: bool) -> Result<usize, LuaError> {
473    let s: Vec<u8> = state.check_arg_string(1)?.to_vec();
474    let len = s.len();
475
476    let mut n: u64 = state.to_integer(2).unwrap_or(0) as u64;
477
478    if (n as usize) < len {
479        while (n as usize) < len && is_cont(s[n as usize]) {
480            n += 1;
481        }
482    }
483
484    if (n as usize) >= len {
485        return Ok(0); // no more codepoints
486    }
487
488    //    if (next == NULL || iscontp(next)) return luaL_error(L, MSGInvalid);
489    match utf8_decode(&s[n as usize..], strict) {
490        None => Err(lua_vm::debug::c_api_runtime(state, b"invalid UTF-8 code".to_vec())),
491        Some((remaining, code)) => {
492            let next_pos = len - remaining.len(); // 0-based index of the next character
493            // valid sequence indicates a malformed input stream.
494            if next_pos < len && is_cont(s[next_pos]) {
495                return Err(lua_vm::debug::c_api_runtime(state, b"invalid UTF-8 code".to_vec()));
496            }
497            state.push(LuaValue::Int((n + 1) as i64)); // 1-based position for next iteration
498            state.push(LuaValue::Int(code as i64));
499            Ok(2)
500        }
501    }
502}
503
504/// Strict iterator body: rejects surrogates and values > MAX_UNICODE.
505///
506fn iter_aux_strict(state: &mut LuaState) -> Result<usize, LuaError> {
507    iter_aux(state, true)
508}
509
510/// Lax iterator body: accepts extended UTF-8 up to MAX_UTF.
511///
512fn iter_aux_lax(state: &mut LuaState) -> Result<usize, LuaError> {
513    iter_aux(state, false)
514}
515
516/// `utf8.codes(s [, lax])` → function, string, integer
517///
518/// Returns the iterator triple `(f, s, 0)` for use in a generic for loop.
519/// Each call to `f(s, pos)` returns the next `(pos, codepoint)` pair.
520///
521fn iter_codes(state: &mut LuaState) -> Result<usize, LuaError> {
522    let lax: bool = state.to_boolean(2);
523
524    let s: Vec<u8> = state.check_arg_string(1)?.to_vec();
525
526    if s.first().map_or(false, |&b| is_cont(b)) {
527        return Err(LuaError::arg_error(1, "invalid UTF-8 code"));
528    }
529
530    let iter_fn: fn(&mut LuaState) -> Result<usize, LuaError> =
531        if lax { iter_aux_lax } else { iter_aux_strict };
532    state.push_c_function(iter_fn)?;
533
534    state.push_value_at(1)?;
535
536    state.push(LuaValue::Int(0));
537
538    Ok(3)
539}
540
541// ── Library registration ───────────────────────────────────────────────────
542
543/// Function registration table for the `utf8` library.
544///
545/// "charpattern" is intentionally absent here; it is a string value and is
546/// registered separately inside `open_utf8` via `lua_setfield`.
547pub const FUNCS: &[(&[u8], fn(&mut LuaState) -> Result<usize, LuaError>)] = &[
548    (b"offset", byte_offset),
549    (b"codepoint", codepoint),
550    (b"char", utf_char),
551    (b"len", utf_len),
552    (b"codes", iter_codes),
553];
554
555/// Open the `utf8` library.
556///
557/// Registers all functions from `FUNCS` into a new table, then sets
558/// `utf8.charpattern` to the byte-string pattern matching one UTF-8 sequence.
559///
560pub fn open_utf8(state: &mut LuaState) -> Result<usize, LuaError> {
561    state.new_lib(FUNCS)?;
562
563    let patt = state.intern_str(UTF8_PATT)?;
564    state.push(LuaValue::Str(patt));
565
566    state.set_field(-2, b"charpattern")?;
567
568    Ok(1)
569}
570
571// ──────────────────────────────────────────────────────────────────────────
572// PORT STATUS
573//   source:        src/lutf8lib.c  (291 lines, 9 functions)
574//   target_crate:  lua-stdlib
575//   confidence:    high
576//   todos:         0
577//   unsafe_blocks: 0   (must be 0 outside explicit unsafe-budget crates)
578//   notes:         Core UTF-8 logic (utf8_decode, encode_utf8_codepoint,
579//                  pos_relat, is_cont_at) is a faithful translation. LuaState
580//                  API names reconciled against state_stub overrides. No unsafe
581//                  blocks; NUL-terminator reliance in C replaced by Rust bounds
582//                  checks throughout.
583// ──────────────────────────────────────────────────────────────────────────