Skip to main content

coreutils_rs/tr/
charset.rs

//! Parse a `tr` character set string into a `Vec<u8>` of expanded characters.
//!
//! Supports:
//! - Literal characters
//! - Escape sequences: `\\`, `\a`, `\b`, `\f`, `\n`, `\r`, `\t`, `\v`, `\NNN` (octal)
//! - Ranges: `a-z`, `A-Z`, `0-9`
//! - Character classes: `[:alnum:]`, `[:alpha:]`, etc.
//! - Equivalence classes: `[=c=]`
//! - Repeat: `[c*n]` or `[c*]` (SET2 only, handled by caller)
10
/// The two case-conversion character classes a set may contain:
/// `[:upper:]` and `[:lower:]`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CaseClass {
    /// The `[:upper:]` class.
    Upper,
    /// The `[:lower:]` class.
    Lower,
}
17
18/// Records the position and type of a [:upper:] or [:lower:] class in a set.
19#[derive(Debug, Clone, Copy)]
20pub struct CaseClassInfo {
21    pub class: CaseClass,
22    pub position: usize,
23}
24
/// Return every byte value that does NOT occur in `set`.
///
/// The output is in ascending order (0 through 255, minus the members of
/// `set`). Duplicates in `set` are harmless.
pub fn complement(set: &[u8]) -> Vec<u8> {
    // One flag per possible byte value.
    let mut present = [false; 256];
    for &byte in set {
        present[usize::from(byte)] = true;
    }

    let mut absent = Vec::new();
    for byte in u8::MIN..=u8::MAX {
        if !present[usize::from(byte)] {
            absent.push(byte);
        }
    }
    absent
}
36
37/// Parse a SET string into expanded bytes.
38pub fn parse_set(s: &str) -> Vec<u8> {
39    let bytes = s.as_bytes();
40    let mut result = Vec::with_capacity(bytes.len());
41    let mut i = 0;
42
43    while i < bytes.len() {
44        if bytes[i] == b'[' && i + 1 < bytes.len() {
45            // Try character class [:name:]
46            if bytes.get(i + 1) == Some(&b':') {
47                if let Some((class_bytes, end)) = parse_char_class(bytes, i) {
48                    result.extend_from_slice(&class_bytes);
49                    i = end;
50                    continue;
51                }
52            }
53            // Try equivalence class [=c=]
54            if bytes.get(i + 1) == Some(&b'=') {
55                if let Some((ch, end)) = parse_equiv_class(bytes, i) {
56                    result.push(ch);
57                    i = end;
58                    continue;
59                }
60            }
61            // Try repeat [c*n] or [c*]
62            if let Some((ch, count, end)) = parse_repeat(bytes, i) {
63                for _ in 0..count {
64                    result.push(ch);
65                }
66                i = end;
67                continue;
68            }
69        }
70
71        // Escape sequence
72        if bytes[i] == b'\\' && i + 1 < bytes.len() {
73            let (ch, advance) = parse_escape(bytes, i);
74            result.push(ch);
75            i += advance;
76            continue;
77        }
78
79        // Range: prev-next (only if we have a previous char and a next char)
80        if bytes[i] == b'-' && !result.is_empty() && i + 1 < bytes.len() {
81            let start = *result.last().unwrap();
82            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
83                let (ch, adv) = parse_escape(bytes, i + 1);
84                (ch, adv)
85            } else {
86                (bytes[i + 1], 1)
87            };
88            if end_ch >= start {
89                // Expand range (start is already in result)
90                for c in (start + 1)..=end_ch {
91                    result.push(c);
92                }
93                i += 1 + advance;
94            } else {
95                // Invalid range in GNU tr: still emit the characters
96                // GNU tr treats invalid ranges as error, but let's be compatible
97                // Actually GNU tr gives an error for descending ranges
98                // We'll just push the literal '-'
99                result.push(b'-');
100                i += 1;
101            }
102            continue;
103        }
104
105        result.push(bytes[i]);
106        i += 1;
107    }
108
109    result
110}
111
112/// Parse a SET string into expanded bytes AND track positions of [:upper:]/[:lower:] classes.
113/// This is needed for GNU-compatible validation of case class alignment.
114pub fn parse_set_with_classes(s: &str) -> (Vec<u8>, Vec<CaseClassInfo>) {
115    let bytes = s.as_bytes();
116    let mut result = Vec::with_capacity(bytes.len());
117    let mut classes = Vec::new();
118    let mut i = 0;
119
120    while i < bytes.len() {
121        if bytes[i] == b'[' && i + 1 < bytes.len() {
122            // Try character class [:name:]
123            if bytes.get(i + 1) == Some(&b':') {
124                if let Some((class_bytes, end)) = parse_char_class(bytes, i) {
125                    // Check if this is [:upper:] or [:lower:]
126                    let name_start = i + 2;
127                    let mut name_end = name_start;
128                    while name_end < bytes.len() && bytes[name_end] != b':' {
129                        name_end += 1;
130                    }
131                    let name = &bytes[name_start..name_end];
132                    if name == b"upper" {
133                        classes.push(CaseClassInfo {
134                            class: CaseClass::Upper,
135                            position: result.len(),
136                        });
137                    } else if name == b"lower" {
138                        classes.push(CaseClassInfo {
139                            class: CaseClass::Lower,
140                            position: result.len(),
141                        });
142                    }
143                    result.extend_from_slice(&class_bytes);
144                    i = end;
145                    continue;
146                }
147            }
148            // Try equivalence class [=c=]
149            if bytes.get(i + 1) == Some(&b'=') {
150                if let Some((ch, end)) = parse_equiv_class(bytes, i) {
151                    result.push(ch);
152                    i = end;
153                    continue;
154                }
155            }
156            // Try repeat [c*n] or [c*]
157            if let Some((ch, count, end)) = parse_repeat(bytes, i) {
158                for _ in 0..count {
159                    result.push(ch);
160                }
161                i = end;
162                continue;
163            }
164        }
165
166        // Escape sequence
167        if bytes[i] == b'\\' && i + 1 < bytes.len() {
168            let (ch, advance) = parse_escape(bytes, i);
169            result.push(ch);
170            i += advance;
171            continue;
172        }
173
174        // Range: prev-next (only if we have a previous char and a next char)
175        if bytes[i] == b'-' && !result.is_empty() && i + 1 < bytes.len() {
176            let start = *result.last().unwrap();
177            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
178                let (ch, adv) = parse_escape(bytes, i + 1);
179                (ch, adv)
180            } else {
181                (bytes[i + 1], 1)
182            };
183            if end_ch >= start {
184                for c in (start + 1)..=end_ch {
185                    result.push(c);
186                }
187                i += 1 + advance;
188            } else {
189                result.push(b'-');
190                i += 1;
191            }
192            continue;
193        }
194
195        result.push(bytes[i]);
196        i += 1;
197    }
198
199    (result, classes)
200}
201
202/// Parse SET2 string with class tracking, expanding to match SET1 length.
203/// Returns (expanded_bytes, case_class_positions).
204pub fn expand_set2_with_classes(set2_str: &str, set1_len: usize) -> (Vec<u8>, Vec<CaseClassInfo>) {
205    let bytes = set2_str.as_bytes();
206
207    // Check if there's a [c*] (fill repeat) in SET2
208    // If so, we handle it specially. Otherwise, use parse_set_with_classes + extend.
209    let mut has_fill = false;
210    {
211        let mut j = 0;
212        while j < bytes.len() {
213            if bytes[j] == b'[' {
214                if let Some((_ch, count, _end)) = parse_repeat(bytes, j) {
215                    if count == 0 {
216                        has_fill = true;
217                        break;
218                    }
219                    j = _end;
220                    continue;
221                }
222            }
223            if bytes[j] == b'\\' && j + 1 < bytes.len() {
224                let (_ch, adv) = parse_escape(bytes, j);
225                j += adv;
226                continue;
227            }
228            j += 1;
229        }
230    }
231
232    if has_fill {
233        // When there's a fill repeat, expand_set2 handles it but we still need classes.
234        // Parse the full set for class positions, then use expand_set2 for the bytes.
235        let expanded = expand_set2(set2_str, set1_len);
236        // Re-parse to find class positions (they won't be affected by fill repeats
237        // since fills don't generate case classes)
238        let (_raw, classes) = parse_set_with_classes(set2_str);
239        (expanded, classes)
240    } else {
241        let (mut set2, classes) = parse_set_with_classes(set2_str);
242        if set2.len() < set1_len && !set2.is_empty() {
243            let last = *set2.last().unwrap();
244            set2.resize(set1_len, last);
245        }
246        (set2, classes)
247    }
248}
249
250/// Validate that [:upper:] and [:lower:] classes are properly paired between SET1 and SET2.
251///
252/// GNU tr rules:
253/// - Every case class in SET2 MUST have a matching case class (same or opposite) at the
254///   same expanded position in SET1. If not, error.
255/// - Case classes in SET1 that don't have a corresponding class in SET2 are fine —
256///   they're just treated as expanded character sets (26 chars) with normal SET2 extension.
257///
258/// This means:
259/// - `tr '[:lower:]' 'xyz'` is fine (SET1 has class, SET2 doesn't → normal mapping)
260/// - `tr 'abc' '[:upper:]'` is an error (SET2 has class, SET1 doesn't)
261/// - `tr '[:lower:][:upper:]' '[:upper:]xyz'` is fine (SET1 upper at pos 26 maps to xyz)
262/// - `tr 'A-Z[:lower:]' 'a-y[:upper:]'` is an error (SET2 upper at pos 25, SET1 lower at pos 26)
263pub fn validate_case_classes(
264    set1_classes: &[CaseClassInfo],
265    set2_classes: &[CaseClassInfo],
266) -> Result<(), String> {
267    // Every case class in SET2 must have a case class in SET1 at the same position
268    for c2 in set2_classes {
269        let found = set1_classes.iter().any(|c1| c1.position == c2.position);
270        if !found {
271            return Err("misaligned [:upper:] and/or [:lower:] construct".to_string());
272        }
273    }
274
275    // If SET2 has no case classes, SET1 classes are fine (treated as expanded sets)
276    if set2_classes.is_empty() {
277        return Ok(());
278    }
279
280    // For each case class in SET1, if SET2 has a class at the same position,
281    // they must match (same or opposite). If SET2 has no class at that position, that's fine.
282    for c1 in set1_classes {
283        let matching_c2 = set2_classes.iter().find(|c2| c2.position == c1.position);
284        if let Some(c2) = matching_c2 {
285            // Both have a class at this position — they must be same or opposite
286            // (upper-upper, lower-lower, upper-lower, lower-upper are all valid)
287            // This is always true since CaseClass only has Upper and Lower,
288            // so any pair is either same or opposite. No additional check needed.
289            let _ = c2;
290        }
291        // If no matching c2, that's fine — SET1 class maps to normal SET2 chars
292    }
293
294    Ok(())
295}
296
297/// Check if SET2 ends with a case class and SET1 is longer than SET2 (before expansion).
298/// GNU tr: "when translating with string1 longer than string2, the latter string
299/// must not end with a character class".
300/// `set1_len` is the expanded length of SET1.
301/// `set2_raw_len` is the expanded length of SET2 before extension to match SET1.
302/// `set2_classes` are the case class positions in SET2.
303pub fn validate_set2_class_at_end(
304    set1_len: usize,
305    set2_raw_len: usize,
306    set2_classes: &[CaseClassInfo],
307) -> Result<(), String> {
308    if set1_len <= set2_raw_len || set2_classes.is_empty() {
309        return Ok(());
310    }
311    // Check if the last class in SET2 ends exactly at the end of the raw (unexpanded) SET2
312    let last_class = &set2_classes[set2_classes.len() - 1];
313    // A case class always has 26 characters
314    let class_end = last_class.position + 26;
315    if class_end == set2_raw_len {
316        return Err("when translating with string1 longer than string2,\n\
317             the latter string must not end with a character class"
318            .to_string());
319    }
320    Ok(())
321}
322
/// Decode the backslash escape that begins at `bytes[i]` (which must be `\`).
///
/// Returns `(byte_value, bytes_consumed)`, where `bytes_consumed` counts the
/// backslash itself.
///
/// * A trailing lone backslash yields `(b'\\', 1)`.
/// * `\\ \a \b \f \n \r \t \v` map to their control bytes.
/// * `\N`, `\NN`, `\NNN` are octal; a third digit is only taken when the value
///   stays within 0..=255 (so `\400` decodes as `\40` and leaves the `0`).
/// * Any other escaped byte stands for itself.
fn parse_escape(bytes: &[u8], i: usize) -> (u8, usize) {
    debug_assert_eq!(bytes[i], b'\\');

    let escaped = match bytes.get(i + 1) {
        Some(&b) => b,
        None => return (b'\\', 1),
    };

    let decoded = match escaped {
        b'\\' => b'\\',
        b'a' => 0x07,
        b'b' => 0x08,
        b'f' => 0x0C,
        b'n' => b'\n',
        b'r' => b'\r',
        b't' => b'\t',
        b'v' => 0x0B,
        b'0'..=b'7' => {
            // Take up to three octal digits, stopping early at a non-octal
            // byte or when another digit would push the value past 255.
            let mut value: u16 = 0;
            let mut digits = 0;
            for &d in bytes[i + 1..].iter().take(3) {
                if !(b'0'..=b'7').contains(&d) {
                    break;
                }
                let widened = value * 8 + u16::from(d - b'0');
                if widened > 255 {
                    break;
                }
                value = widened;
                digits += 1;
            }
            return (value as u8, 1 + digits);
        }
        // Unknown escape: the byte itself.
        other => other,
    };
    (decoded, 2)
}
360
361/// Try to parse a character class like [:alpha:] starting at position i.
362/// Returns (expanded bytes, position after the closing ']').
363fn parse_char_class(bytes: &[u8], i: usize) -> Option<(Vec<u8>, usize)> {
364    // Format: [:name:]
365    // bytes[i] = '[', bytes[i+1] = ':'
366    let start = i + 2;
367    let mut end = start;
368    while end < bytes.len() && bytes[end] != b':' {
369        end += 1;
370    }
371    // Need ':' followed by ']'
372    if end + 1 >= bytes.len() || bytes[end] != b':' || bytes[end + 1] != b']' {
373        return None;
374    }
375    let name = &bytes[start..end];
376    let chars = expand_class(name)?;
377    Some((chars, end + 2))
378}
379
/// Map a character class name (the text between `[:` and `:]`) to its bytes.
///
/// Each class is described as a list of inclusive byte ranges that are
/// concatenated in order. Returns `None` for an unrecognised name.
fn expand_class(name: &[u8]) -> Option<Vec<u8>> {
    let spans: &[(u8, u8)] = match name {
        b"alnum" => &[(b'0', b'9'), (b'A', b'Z'), (b'a', b'z')],
        b"alpha" => &[(b'A', b'Z'), (b'a', b'z')],
        b"blank" => &[(b'\t', b'\t'), (b' ', b' ')],
        b"cntrl" => &[(0, 31), (127, 127)],
        b"digit" => &[(b'0', b'9')],
        b"graph" => &[(33, 126)],
        b"lower" => &[(b'a', b'z')],
        b"print" => &[(32, 126)],
        b"punct" => &[(33, 47), (58, 64), (91, 96), (123, 126)],
        // \t \n \v \f \r are the contiguous bytes 9..=13, then space.
        b"space" => &[(b'\t', b'\r'), (b' ', b' ')],
        b"upper" => &[(b'A', b'Z')],
        b"xdigit" => &[(b'0', b'9'), (b'A', b'F'), (b'a', b'f')],
        _ => return None,
    };

    Some(spans.iter().flat_map(|&(lo, hi)| lo..=hi).collect())
}
414
/// Try to read a `[=c=]` equivalence class whose `[` is at `bytes[i]`.
///
/// The caller has already seen `[` at `i` and `=` at `i + 1`; this function
/// only inspects the three bytes that follow. Returns the byte `c` and the
/// index just past the closing `]`, or `None` when the input is too short or
/// `c` is not followed by `=]`.
fn parse_equiv_class(bytes: &[u8], i: usize) -> Option<(u8, usize)> {
    match bytes.get(i + 2..i + 5)? {
        &[ch, b'=', b']'] => Some((ch, i + 5)),
        _ => None,
    }
}
429
430/// Try to parse a repeat construct like [c*n] or [c*] starting at position i.
431/// Returns (character, count, position after ']').
432/// A count of 0 means "fill to match SET1 length" (caller handles).
433fn parse_repeat(bytes: &[u8], i: usize) -> Option<(u8, usize, usize)> {
434    // Format: [c*n] or [c*]
435    // bytes[i] = '['
436    if i + 3 >= bytes.len() {
437        return None;
438    }
439
440    // The char could be an escape
441    let (ch, char_len) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
442        let (c, adv) = parse_escape(bytes, i + 1);
443        (c, adv)
444    } else {
445        (bytes[i + 1], 1)
446    };
447
448    let star_pos = i + 1 + char_len;
449    if star_pos >= bytes.len() || bytes[star_pos] != b'*' {
450        return None;
451    }
452
453    let after_star = star_pos + 1;
454    if after_star >= bytes.len() {
455        return None;
456    }
457
458    // [c*] - repeat to fill
459    if bytes[after_star] == b']' {
460        return Some((ch, 0, after_star + 1));
461    }
462
463    // [c*n] - repeat n times
464    // n can be octal (starts with 0) or decimal
465    let mut end = after_star;
466    while end < bytes.len() && bytes[end] != b']' {
467        end += 1;
468    }
469    if end >= bytes.len() {
470        return None;
471    }
472
473    let num_str = std::str::from_utf8(&bytes[after_star..end]).ok()?;
474    let count = if num_str.starts_with('0') && num_str.len() > 1 {
475        usize::from_str_radix(num_str, 8).ok()?
476    } else {
477        num_str.parse::<usize>().ok()?
478    };
479
480    Some((ch, count, end + 1))
481}
482
483/// Expand SET2 to match SET1 length for translation.
484/// If SET2 has [c*] repeats, fill them. Otherwise repeat last char.
485pub fn expand_set2(set2_str: &str, set1_len: usize) -> Vec<u8> {
486    let bytes = set2_str.as_bytes();
487
488    // Check if there's a [c*] (fill repeat) in SET2
489    // We need to parse SET2 specially: expand everything except [c*] fills,
490    // then compute how many fill chars are needed.
491    let mut before_fill = Vec::new();
492    let mut fill_char: Option<u8> = None;
493    let mut after_fill = Vec::new();
494    let mut i = 0;
495    let mut found_fill = false;
496
497    while i < bytes.len() {
498        if bytes[i] == b'[' && i + 1 < bytes.len() {
499            if bytes.get(i + 1) == Some(&b':') {
500                if let Some((class_bytes, end)) = parse_char_class(bytes, i) {
501                    if found_fill {
502                        after_fill.extend_from_slice(&class_bytes);
503                    } else {
504                        before_fill.extend_from_slice(&class_bytes);
505                    }
506                    i = end;
507                    continue;
508                }
509            }
510            if bytes.get(i + 1) == Some(&b'=') {
511                if let Some((ch, end)) = parse_equiv_class(bytes, i) {
512                    if found_fill {
513                        after_fill.push(ch);
514                    } else {
515                        before_fill.push(ch);
516                    }
517                    i = end;
518                    continue;
519                }
520            }
521            if let Some((ch, count, end)) = parse_repeat(bytes, i) {
522                if count == 0 && !found_fill {
523                    fill_char = Some(ch);
524                    found_fill = true;
525                    i = end;
526                    continue;
527                } else {
528                    let target = if found_fill {
529                        &mut after_fill
530                    } else {
531                        &mut before_fill
532                    };
533                    for _ in 0..count {
534                        target.push(ch);
535                    }
536                    i = end;
537                    continue;
538                }
539            }
540        }
541
542        if bytes[i] == b'\\' && i + 1 < bytes.len() {
543            let (ch, advance) = parse_escape(bytes, i);
544            if found_fill {
545                after_fill.push(ch);
546            } else {
547                before_fill.push(ch);
548            }
549            i += advance;
550            continue;
551        }
552
553        if bytes[i] == b'-' && !before_fill.is_empty() && !found_fill && i + 1 < bytes.len() {
554            let start = *before_fill.last().unwrap();
555            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
556                let (ch, adv) = parse_escape(bytes, i + 1);
557                (ch, adv)
558            } else {
559                (bytes[i + 1], 1)
560            };
561            if end_ch >= start {
562                for c in (start + 1)..=end_ch {
563                    before_fill.push(c);
564                }
565                i += 1 + advance;
566            } else {
567                before_fill.push(b'-');
568                i += 1;
569            }
570            continue;
571        }
572        if bytes[i] == b'-' && !after_fill.is_empty() && found_fill && i + 1 < bytes.len() {
573            let start = *after_fill.last().unwrap();
574            let (end_ch, advance) = if bytes[i + 1] == b'\\' && i + 2 < bytes.len() {
575                let (ch, adv) = parse_escape(bytes, i + 1);
576                (ch, adv)
577            } else {
578                (bytes[i + 1], 1)
579            };
580            if end_ch >= start {
581                for c in (start + 1)..=end_ch {
582                    after_fill.push(c);
583                }
584                i += 1 + advance;
585            } else {
586                after_fill.push(b'-');
587                i += 1;
588            }
589            continue;
590        }
591
592        if found_fill {
593            after_fill.push(bytes[i]);
594        } else {
595            before_fill.push(bytes[i]);
596        }
597        i += 1;
598    }
599
600    if let Some(fc) = fill_char {
601        let fixed = before_fill.len() + after_fill.len();
602        let fill_count = if set1_len > fixed {
603            set1_len - fixed
604        } else {
605            0
606        };
607        let mut result = before_fill;
608        result.resize(result.len() + fill_count, fc);
609        result.extend_from_slice(&after_fill);
610        result
611    } else {
612        // No fill repeat — use parse_set and extend with last char
613        let mut set2 = parse_set(set2_str);
614        if set2.len() < set1_len && !set2.is_empty() {
615            let last = *set2.last().unwrap();
616            set2.resize(set1_len, last);
617        }
618        set2
619    }
620}