Skip to main content

coreutils_rs/tr/
charset.rs

//! Parsing of `tr` character-set operands into expanded byte vectors.
//!
//! Supported syntax:
//! - Literal characters
//! - Escape sequences: `\\`, `\a`, `\b`, `\f`, `\n`, `\r`, `\t`, `\v`, `\NNN` (1–3 octal digits)
//! - Ranges: `a-z`, `A-Z`, `0-9`
//! - Character classes: `[:alnum:]`, `[:alpha:]`, etc.
//! - Equivalence classes: `[=c=]`
//! - Repeats: `[c*n]` or `[c*]` (meaningful in SET2 only; the fill form is resolved by the caller)

/// Build the complement of a character set: every byte value that does not
/// occur in `set`.
///
/// The result is sorted ascending (0, 1, 2, ..., 255 minus the members of
/// `set`). Duplicate bytes in `set` make no difference.
pub fn complement(set: &[u8]) -> Vec<u8> {
    let mut present = [false; 256];
    set.iter().for_each(|&b| present[usize::from(b)] = true);
    (0..=u8::MAX).filter(|&b| !present[usize::from(b)]).collect()
}
22
23/// Parse a SET string into expanded bytes.
24pub fn parse_set(s: &str) -> Vec<u8> {
25    let bytes = s.as_bytes();
26    let mut result = Vec::with_capacity(bytes.len());
27    let mut i = 0;
28
29    while i < bytes.len() {
30        if bytes[i] == b'[' && i + 1 < bytes.len() {
31            // Try character class [:name:]
32            if bytes.get(i + 1) == Some(&b':') {
33                if let Some((class_bytes, end)) = parse_char_class(bytes, i) {
34                    result.extend_from_slice(&class_bytes);
35                    i = end;
36                    continue;
37                }
38            }
39            // Try equivalence class [=c=]
40            if bytes.get(i + 1) == Some(&b'=') {
41                if let Some((ch, end)) = parse_equiv_class(bytes, i) {
42                    result.push(ch);
43                    i = end;
44                    continue;
45                }
46            }
47            // Try repeat [c*n] or [c*]
48            if let Some((ch, count, end)) = parse_repeat(bytes, i) {
49                for _ in 0..count {
50                    result.push(ch);
51                }
52                i = end;
53                continue;
54            }
55        }
56
57        // Escape sequence
58        if bytes[i] == b'\\' && i + 1 < bytes.len() {
59            let (ch, advance) = parse_escape(bytes, i);
60            result.push(ch);
61            i += advance;
62            continue;
63        }
64
65        // Range: prev-next (only if we have a previous char and a next char)
66        if bytes[i] == b'-' && !result.is_empty() && i + 1 < bytes.len() {
67            let start = *result.last().unwrap();
68            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
69                let (ch, adv) = parse_escape(bytes, i + 1);
70                (ch, adv)
71            } else {
72                (bytes[i + 1], 1)
73            };
74            if end_ch >= start {
75                // Expand range (start is already in result)
76                for c in (start + 1)..=end_ch {
77                    result.push(c);
78                }
79                i += 1 + advance;
80            } else {
81                // Invalid range in GNU tr: still emit the characters
82                // GNU tr treats invalid ranges as error, but let's be compatible
83                // Actually GNU tr gives an error for descending ranges
84                // We'll just push the literal '-'
85                result.push(b'-');
86                i += 1;
87            }
88            continue;
89        }
90
91        result.push(bytes[i]);
92        i += 1;
93    }
94
95    result
96}
97
/// Decode the backslash escape that starts at `bytes[i]` (which must be `\`).
///
/// Returns `(value, consumed)`, where `consumed` includes the backslash.
///
/// - A trailing lone `\` decodes to itself and consumes 1 byte.
/// - `\a \b \f \n \r \t \v` decode to their usual control bytes.
/// - `\` followed by 1–3 octal digits decodes to that value. A third digit is
///   only taken if the result still fits in a byte, so `\400` means `\40`
///   followed by a literal `0`.
/// - Any other `\c`, including `\\`, decodes to `c` itself (GNU behavior).
fn parse_escape(bytes: &[u8], i: usize) -> (u8, usize) {
    debug_assert_eq!(bytes[i], b'\\');
    let next = match bytes.get(i + 1) {
        Some(&b) => b,
        None => return (b'\\', 1),
    };

    if (b'0'..=b'7').contains(&next) {
        let mut value: u16 = 0;
        let mut digits = 0;
        for &d in bytes[i + 1..].iter().take(3) {
            if !(b'0'..=b'7').contains(&d) {
                break;
            }
            let widened = value * 8 + u16::from(d - b'0');
            if widened > 0xFF {
                break;
            }
            value = widened;
            digits += 1;
        }
        // `value` never exceeds 0xFF, so this narrowing cast is lossless.
        return (value as u8, 1 + digits);
    }

    let decoded = match next {
        b'a' => 0x07,
        b'b' => 0x08,
        b'f' => 0x0C,
        b'n' => b'\n',
        b'r' => b'\r',
        b't' => b'\t',
        b'v' => 0x0B,
        // `\\` and unknown escapes both stand for the escaped byte itself.
        other => other,
    };
    (decoded, 2)
}
135
136/// Try to parse a character class like [:alpha:] starting at position i.
137/// Returns (expanded bytes, position after the closing ']').
138fn parse_char_class(bytes: &[u8], i: usize) -> Option<(Vec<u8>, usize)> {
139    // Format: [:name:]
140    // bytes[i] = '[', bytes[i+1] = ':'
141    let start = i + 2;
142    let mut end = start;
143    while end < bytes.len() && bytes[end] != b':' {
144        end += 1;
145    }
146    // Need ':' followed by ']'
147    if end + 1 >= bytes.len() || bytes[end] != b':' || bytes[end + 1] != b']' {
148        return None;
149    }
150    let name = &bytes[start..end];
151    let chars = expand_class(name)?;
152    Some((chars, end + 2))
153}
154
/// Expand a POSIX character-class name (the text between `[:` and `:]`) into
/// its member bytes, using the C/ASCII definitions.
///
/// Members are returned in ascending byte order. Unknown names yield `None`.
fn expand_class(name: &[u8]) -> Option<Vec<u8>> {
    let is_member: fn(&u8) -> bool = match name {
        b"alnum" => u8::is_ascii_alphanumeric,
        b"alpha" => u8::is_ascii_alphabetic,
        b"blank" => |b: &u8| *b == b'\t' || *b == b' ',
        b"cntrl" => u8::is_ascii_control,
        b"digit" => u8::is_ascii_digit,
        b"graph" => u8::is_ascii_graphic,
        b"lower" => u8::is_ascii_lowercase,
        b"print" => |b: &u8| b.is_ascii_graphic() || *b == b' ',
        b"punct" => u8::is_ascii_punctuation,
        // `is_ascii_whitespace` omits vertical tab (0x0B), which the POSIX
        // `space` class includes.
        b"space" => |b: &u8| b.is_ascii_whitespace() || *b == 0x0B,
        b"upper" => u8::is_ascii_uppercase,
        b"xdigit" => u8::is_ascii_hexdigit,
        _ => return None,
    };
    Some((0..=u8::MAX).filter(is_member).collect())
}
189
/// Try to parse an equivalence class `[=c=]` whose `[` is at `bytes[i]`.
///
/// Only bytes `i + 2 ..= i + 4` are inspected: a single character followed
/// by `=]`. The caller has already checked the leading `[=`. Returns the
/// character and the index just past the closing `]`.
fn parse_equiv_class(bytes: &[u8], i: usize) -> Option<(u8, usize)> {
    match bytes.get(i + 2..i + 5)? {
        [ch, b'=', b']'] => Some((*ch, i + 5)),
        _ => None,
    }
}
204
205/// Try to parse a repeat construct like [c*n] or [c*] starting at position i.
206/// Returns (character, count, position after ']').
207/// A count of 0 means "fill to match SET1 length" (caller handles).
208fn parse_repeat(bytes: &[u8], i: usize) -> Option<(u8, usize, usize)> {
209    // Format: [c*n] or [c*]
210    // bytes[i] = '['
211    if i + 3 >= bytes.len() {
212        return None;
213    }
214
215    // The char could be an escape
216    let (ch, char_len) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
217        let (c, adv) = parse_escape(bytes, i + 1);
218        (c, adv)
219    } else {
220        (bytes[i + 1], 1)
221    };
222
223    let star_pos = i + 1 + char_len;
224    if star_pos >= bytes.len() || bytes[star_pos] != b'*' {
225        return None;
226    }
227
228    let after_star = star_pos + 1;
229    if after_star >= bytes.len() {
230        return None;
231    }
232
233    // [c*] - repeat to fill
234    if bytes[after_star] == b']' {
235        return Some((ch, 0, after_star + 1));
236    }
237
238    // [c*n] - repeat n times
239    // n can be octal (starts with 0) or decimal
240    let mut end = after_star;
241    while end < bytes.len() && bytes[end] != b']' {
242        end += 1;
243    }
244    if end >= bytes.len() {
245        return None;
246    }
247
248    let num_str = std::str::from_utf8(&bytes[after_star..end]).ok()?;
249    let count = if num_str.starts_with('0') && num_str.len() > 1 {
250        usize::from_str_radix(num_str, 8).ok()?
251    } else {
252        num_str.parse::<usize>().ok()?
253    };
254
255    Some((ch, count, end + 1))
256}
257
258/// Expand SET2 to match SET1 length for translation.
259/// If SET2 has [c*] repeats, fill them. Otherwise repeat last char.
260pub fn expand_set2(set2_str: &str, set1_len: usize) -> Vec<u8> {
261    let bytes = set2_str.as_bytes();
262
263    // Check if there's a [c*] (fill repeat) in SET2
264    // We need to parse SET2 specially: expand everything except [c*] fills,
265    // then compute how many fill chars are needed.
266    let mut before_fill = Vec::new();
267    let mut fill_char: Option<u8> = None;
268    let mut after_fill = Vec::new();
269    let mut i = 0;
270    let mut found_fill = false;
271
272    while i < bytes.len() {
273        if bytes[i] == b'[' && i + 1 < bytes.len() {
274            if bytes.get(i + 1) == Some(&b':') {
275                if let Some((class_bytes, end)) = parse_char_class(bytes, i) {
276                    if found_fill {
277                        after_fill.extend_from_slice(&class_bytes);
278                    } else {
279                        before_fill.extend_from_slice(&class_bytes);
280                    }
281                    i = end;
282                    continue;
283                }
284            }
285            if bytes.get(i + 1) == Some(&b'=') {
286                if let Some((ch, end)) = parse_equiv_class(bytes, i) {
287                    if found_fill {
288                        after_fill.push(ch);
289                    } else {
290                        before_fill.push(ch);
291                    }
292                    i = end;
293                    continue;
294                }
295            }
296            if let Some((ch, count, end)) = parse_repeat(bytes, i) {
297                if count == 0 && !found_fill {
298                    fill_char = Some(ch);
299                    found_fill = true;
300                    i = end;
301                    continue;
302                } else {
303                    let target = if found_fill {
304                        &mut after_fill
305                    } else {
306                        &mut before_fill
307                    };
308                    for _ in 0..count {
309                        target.push(ch);
310                    }
311                    i = end;
312                    continue;
313                }
314            }
315        }
316
317        if bytes[i] == b'\\' && i + 1 < bytes.len() {
318            let (ch, advance) = parse_escape(bytes, i);
319            if found_fill {
320                after_fill.push(ch);
321            } else {
322                before_fill.push(ch);
323            }
324            i += advance;
325            continue;
326        }
327
328        if bytes[i] == b'-' && !before_fill.is_empty() && !found_fill && i + 1 < bytes.len() {
329            let start = *before_fill.last().unwrap();
330            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
331                let (ch, adv) = parse_escape(bytes, i + 1);
332                (ch, adv)
333            } else {
334                (bytes[i + 1], 1)
335            };
336            if end_ch >= start {
337                for c in (start + 1)..=end_ch {
338                    before_fill.push(c);
339                }
340                i += 1 + advance;
341            } else {
342                before_fill.push(b'-');
343                i += 1;
344            }
345            continue;
346        }
347        if bytes[i] == b'-' && !after_fill.is_empty() && found_fill && i + 1 < bytes.len() {
348            let start = *after_fill.last().unwrap();
349            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
350                let (ch, adv) = parse_escape(bytes, i + 1);
351                (ch, adv)
352            } else {
353                (bytes[i + 1], 1)
354            };
355            if end_ch >= start {
356                for c in (start + 1)..=end_ch {
357                    after_fill.push(c);
358                }
359                i += 1 + advance;
360            } else {
361                after_fill.push(b'-');
362                i += 1;
363            }
364            continue;
365        }
366
367        if found_fill {
368            after_fill.push(bytes[i]);
369        } else {
370            before_fill.push(bytes[i]);
371        }
372        i += 1;
373    }
374
375    if let Some(fc) = fill_char {
376        let fixed = before_fill.len() + after_fill.len();
377        let fill_count = if set1_len > fixed {
378            set1_len - fixed
379        } else {
380            0
381        };
382        let mut result = before_fill;
383        result.resize(result.len() + fill_count, fc);
384        result.extend_from_slice(&after_fill);
385        result
386    } else {
387        // No fill repeat — use parse_set and extend with last char
388        let mut set2 = parse_set(set2_str);
389        if set2.len() < set1_len && !set2.is_empty() {
390            let last = *set2.last().unwrap();
391            set2.resize(set1_len, last);
392        }
393        set2
394    }
395}