// coreutils_rs/sort/compare.rs

//! Comparison functions for different sort modes.
//! All comparison functions are allocation-free for maximum sort performance.
3use std::cmp::Ordering;
4
5use super::key::KeyOpts;
6
/// Return the tail of `s` that starts at its first non-blank byte.
///
/// Only ASCII space and tab count as blanks. If `s` consists solely of
/// blanks (or is empty) the result is an empty slice.
#[inline]
pub fn skip_leading_blanks(s: &[u8]) -> &[u8] {
    let first_kept = s
        .iter()
        .position(|&byte| byte != b' ' && byte != b'\t')
        .unwrap_or(s.len());
    &s[first_kept..]
}
16
/// Default sort order: plain byte-wise lexicographic comparison of `a` and `b`.
#[inline]
pub fn compare_lexical(a: &[u8], b: &[u8]) -> Ordering {
    Ord::cmp(a, b)
}
22
23/// Numeric sort (-n): compare leading numeric strings.
24/// Handles optional leading whitespace, sign, and decimal point.
25pub fn compare_numeric(a: &[u8], b: &[u8]) -> Ordering {
26    let va = parse_numeric_value(a);
27    let vb = parse_numeric_value(b);
28    va.partial_cmp(&vb).unwrap_or(Ordering::Equal)
29}
30
31/// Fast custom numeric parser: parses sign + digits + optional decimal directly from bytes.
32/// Avoids UTF-8 validation and str::parse::<f64>() overhead entirely.
33pub fn parse_numeric_value(s: &[u8]) -> f64 {
34    let s = skip_leading_blanks(s);
35    if s.is_empty() {
36        return 0.0;
37    }
38
39    let mut i = 0;
40    let negative = if s[i] == b'-' {
41        i += 1;
42        true
43    } else {
44        if s[i] == b'+' {
45            i += 1;
46        }
47        false
48    };
49
50    // Parse integer part
51    let mut integer: u64 = 0;
52    let mut has_digits = false;
53    while i < s.len() && s[i].is_ascii_digit() {
54        integer = integer.wrapping_mul(10).wrapping_add((s[i] - b'0') as u64);
55        has_digits = true;
56        i += 1;
57    }
58
59    // Parse fractional part
60    if i < s.len() && s[i] == b'.' {
61        i += 1;
62        let frac_start = i;
63        let mut frac_val: u64 = 0;
64        while i < s.len() && s[i].is_ascii_digit() {
65            frac_val = frac_val.wrapping_mul(10).wrapping_add((s[i] - b'0') as u64);
66            has_digits = true;
67            i += 1;
68        }
69        if !has_digits {
70            return 0.0;
71        }
72        let frac_digits = i - frac_start;
73        let result = if frac_digits > 0 {
74            // Use pre-computed powers of 10 for common cases
75            let divisor = POW10[frac_digits.min(POW10.len() - 1)];
76            integer as f64 + frac_val as f64 / divisor
77        } else {
78            integer as f64
79        };
80        return if negative { -result } else { result };
81    }
82
83    if !has_digits {
84        return 0.0;
85    }
86
87    let result = integer as f64;
88    if negative { -result } else { result }
89}
90
91/// Pre-computed powers of 10 for fast decimal conversion.
92const POW10: [f64; 20] = [
93    1.0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16,
94    1e17, 1e18, 1e19,
95];
96
/// Length of the leading `[+-]digits[.digits]` prefix of `s`.
///
/// A `.` directly after the integer digits is counted even when no digit
/// follows it. Returns 0 when the prefix contains no digit at all. Leading
/// blanks are not skipped here.
fn find_numeric_end(s: &[u8]) -> usize {
    // Number of consecutive ASCII digits starting at `from` (`from <= s.len()`).
    let digit_run = |from: usize| {
        s[from..]
            .iter()
            .take_while(|byte| byte.is_ascii_digit())
            .count()
    };

    let sign_len = usize::from(matches!(s.first(), Some(b'+' | b'-')));
    let int_len = digit_run(sign_len);
    let mut end = sign_len + int_len;

    let mut frac_len = 0;
    if s.get(end) == Some(&b'.') {
        frac_len = digit_run(end + 1);
        end += 1 + frac_len;
    }

    if int_len + frac_len > 0 {
        end
    } else {
        0
    }
}
116
117/// General numeric sort (-g): handles scientific notation, infinity, NaN.
118/// O(n) parser.
119pub fn compare_general_numeric(a: &[u8], b: &[u8]) -> Ordering {
120    let va = parse_general_numeric(a);
121    let vb = parse_general_numeric(b);
122    match (va.is_nan(), vb.is_nan()) {
123        (true, true) => Ordering::Equal,
124        (true, false) => Ordering::Less,
125        (false, true) => Ordering::Greater,
126        (false, false) => va.partial_cmp(&vb).unwrap_or(Ordering::Equal),
127    }
128}
129
130pub fn parse_general_numeric(s: &[u8]) -> f64 {
131    let s = skip_leading_blanks(s);
132    if s.is_empty() {
133        return f64::NAN;
134    }
135
136    // Find the longest valid float prefix
137    let mut i = 0;
138
139    // Handle "inf", "-inf", "+inf", "nan" etc.
140    let start = if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
141        i += 1;
142        i - 1
143    } else {
144        i
145    };
146
147    // Check for "inf"/"infinity"/"nan" prefix (case-insensitive)
148    if i + 2 < s.len() {
149        let c0 = s[i].to_ascii_lowercase();
150        let c1 = s[i + 1].to_ascii_lowercase();
151        let c2 = s[i + 2].to_ascii_lowercase();
152        if (c0 == b'i' && c1 == b'n' && c2 == b'f') || (c0 == b'n' && c1 == b'a' && c2 == b'n') {
153            // Try parsing the prefix as a special float
154            let end = s.len().min(i + 8); // "infinity" is 8 chars
155            for e in (i + 3..=end).rev() {
156                if let Ok(text) = std::str::from_utf8(&s[start..e]) {
157                    if let Ok(v) = text.parse::<f64>() {
158                        return v;
159                    }
160                }
161            }
162            return f64::NAN;
163        }
164    }
165
166    // Reset i for numeric parsing
167    i = start;
168    if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
169        i += 1;
170    }
171
172    // Digits before decimal
173    let mut has_digits = false;
174    while i < s.len() && s[i].is_ascii_digit() {
175        i += 1;
176        has_digits = true;
177    }
178    // Decimal point
179    if i < s.len() && s[i] == b'.' {
180        i += 1;
181        while i < s.len() && s[i].is_ascii_digit() {
182            i += 1;
183            has_digits = true;
184        }
185    }
186    if !has_digits {
187        return f64::NAN;
188    }
189    // Exponent
190    if i < s.len() && (s[i] == b'e' || s[i] == b'E') {
191        let save = i;
192        i += 1;
193        if i < s.len() && (s[i] == b'+' || s[i] == b'-') {
194            i += 1;
195        }
196        if i < s.len() && s[i].is_ascii_digit() {
197            while i < s.len() && s[i].is_ascii_digit() {
198                i += 1;
199            }
200        } else {
201            i = save;
202        }
203    }
204
205    // Parse the numeric prefix using standard library
206    std::str::from_utf8(&s[start..i])
207        .ok()
208        .and_then(|s| s.parse::<f64>().ok())
209        .unwrap_or(f64::NAN)
210}
211
212/// Human numeric sort (-h): handles suffixes K, M, G, T, P, E, Z, Y.
213pub fn compare_human_numeric(a: &[u8], b: &[u8]) -> Ordering {
214    let va = parse_human_numeric(a);
215    let vb = parse_human_numeric(b);
216    va.partial_cmp(&vb).unwrap_or(Ordering::Equal)
217}
218
219pub fn parse_human_numeric(s: &[u8]) -> f64 {
220    let s = skip_leading_blanks(s);
221    if s.is_empty() {
222        return 0.0;
223    }
224
225    // Use the fast custom parser for the numeric part
226    let base = parse_numeric_value(s);
227    let end = find_numeric_end(s);
228
229    if end < s.len() {
230        let multiplier = match s[end] {
231            b'K' | b'k' => 1e3,
232            b'M' => 1e6,
233            b'G' => 1e9,
234            b'T' => 1e12,
235            b'P' => 1e15,
236            b'E' => 1e18,
237            b'Z' => 1e21,
238            b'Y' => 1e24,
239            _ => 1.0,
240        };
241        base * multiplier
242    } else {
243        base
244    }
245}
246
247/// Month sort (-M).
248pub fn compare_month(a: &[u8], b: &[u8]) -> Ordering {
249    let ma = parse_month(a);
250    let mb = parse_month(b);
251    ma.cmp(&mb)
252}
253
254fn parse_month(s: &[u8]) -> u8 {
255    let s = skip_leading_blanks(s);
256    if s.len() < 3 {
257        return 0;
258    }
259    let m = [
260        s[0].to_ascii_uppercase(),
261        s[1].to_ascii_uppercase(),
262        s[2].to_ascii_uppercase(),
263    ];
264    match &m {
265        b"JAN" => 1,
266        b"FEB" => 2,
267        b"MAR" => 3,
268        b"APR" => 4,
269        b"MAY" => 5,
270        b"JUN" => 6,
271        b"JUL" => 7,
272        b"AUG" => 8,
273        b"SEP" => 9,
274        b"OCT" => 10,
275        b"NOV" => 11,
276        b"DEC" => 12,
277        _ => 0,
278    }
279}
280
/// Version sort (-V): natural sort of version numbers.
///
/// The keys are walked in parallel, directly on the byte slices. Where both
/// keys have an ASCII digit, the two complete digit runs are compared by
/// numeric value (leading zeros are ignored, so `01` equals `1`); otherwise a
/// single byte from each key is compared. When every compared piece is equal,
/// the key that runs out first sorts first.
///
/// Digit runs are compared as digit strings rather than being converted to a
/// fixed-width integer, so arbitrarily long numbers are ordered correctly.
pub fn compare_version(a: &[u8], b: &[u8]) -> Ordering {
    let (mut ai, mut bi) = (0usize, 0usize);

    while ai < a.len() && bi < b.len() {
        let (ac, bc) = (a[ai], b[bi]);

        let step = if ac.is_ascii_digit() && bc.is_ascii_digit() {
            let a_run = take_digit_run(a, &mut ai);
            let b_run = take_digit_run(b, &mut bi);
            compare_digit_runs(a_run, b_run)
        } else {
            ai += 1;
            bi += 1;
            ac.cmp(&bc)
        };

        if step != Ordering::Equal {
            return step;
        }
    }

    // At least one key is exhausted here, so comparing what is left orders
    // the shorter key first (and yields Equal when both are exhausted).
    (a.len() - ai).cmp(&(b.len() - bi))
}

/// Return the run of ASCII digits starting at `*pos` and advance `*pos` past it.
#[inline]
fn take_digit_run<'a>(data: &'a [u8], pos: &mut usize) -> &'a [u8] {
    let start = *pos;
    let len = data[start..]
        .iter()
        .take_while(|byte| byte.is_ascii_digit())
        .count();
    *pos = start + len;
    &data[start..*pos]
}

/// Compare two ASCII digit runs by numeric value without converting them.
#[inline]
fn compare_digit_runs(x: &[u8], y: &[u8]) -> Ordering {
    fn strip_leading_zeros(digits: &[u8]) -> &[u8] {
        let first_nonzero = digits
            .iter()
            .position(|&d| d != b'0')
            .unwrap_or(digits.len());
        &digits[first_nonzero..]
    }

    let (x, y) = (strip_leading_zeros(x), strip_leading_zeros(y));
    // Without leading zeros, more digits means a larger number; for equal
    // lengths, byte order equals numeric order.
    x.len().cmp(&y.len()).then_with(|| x.cmp(y))
}
331
332/// Random sort (-R): hash-based shuffle that groups identical keys.
333pub fn compare_random(a: &[u8], b: &[u8], seed: u64) -> Ordering {
334    let ha = fnv1a_hash(a, seed);
335    let hb = fnv1a_hash(b, seed);
336    ha.cmp(&hb)
337}
338
/// 64-bit FNV-1a hash of `data`, with `seed` XOR-ed into the initial state.
///
/// For each byte the state is XOR-ed with the byte and then multiplied
/// (wrapping) by the FNV prime.
#[inline]
fn fnv1a_hash(data: &[u8], seed: u64) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;

    data.iter().fold(OFFSET_BASIS ^ seed, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
349
// Text-filtered comparison (-d, -i, -f flags in any combination).
// Allocation-free: bytes are filtered through iterators instead of being copied.
/// True for the bytes kept in dictionary order (-d): ASCII letters, ASCII
/// digits, space and tab.
#[inline]
fn is_dict_char(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b' ' | b'\t')
}
356
/// True for bytes in the range 0x20 (space) through 0x7e (`~`) inclusive.
#[inline]
fn is_printable(b: u8) -> bool {
    (0x20..0x7f).contains(&b)
}
361
362fn compare_text_filtered(
363    a: &[u8],
364    b: &[u8],
365    dict: bool,
366    no_print: bool,
367    fold_case: bool,
368) -> Ordering {
369    if !dict && !no_print && !fold_case {
370        return a.cmp(b);
371    }
372
373    let mut ai = a.iter().copied();
374    let mut bi = b.iter().copied();
375
376    loop {
377        let na = next_valid(&mut ai, dict, no_print);
378        let nb = next_valid(&mut bi, dict, no_print);
379        match (na, nb) {
380            (Some(ab), Some(bb)) => {
381                let ca = if fold_case {
382                    ab.to_ascii_uppercase()
383                } else {
384                    ab
385                };
386                let cb = if fold_case {
387                    bb.to_ascii_uppercase()
388                } else {
389                    bb
390                };
391                match ca.cmp(&cb) {
392                    Ordering::Equal => continue,
393                    other => return other,
394                }
395            }
396            (Some(_), None) => return Ordering::Greater,
397            (None, Some(_)) => return Ordering::Less,
398            (None, None) => return Ordering::Equal,
399        }
400    }
401}
402
403#[inline]
404fn next_valid(iter: &mut impl Iterator<Item = u8>, dict: bool, no_print: bool) -> Option<u8> {
405    loop {
406        match iter.next() {
407            None => return None,
408            Some(b) => {
409                if dict && !is_dict_char(b) {
410                    continue;
411                }
412                if no_print && !is_printable(b) {
413                    continue;
414                }
415                return Some(b);
416            }
417        }
418    }
419}
420
421// Public wrappers for backward compatibility with tests
422pub fn compare_ignore_case(a: &[u8], b: &[u8]) -> Ordering {
423    compare_text_filtered(a, b, false, false, true)
424}
425
426pub fn compare_dictionary(a: &[u8], b: &[u8], ignore_case: bool) -> Ordering {
427    compare_text_filtered(a, b, true, false, ignore_case)
428}
429
430pub fn compare_ignore_nonprinting(a: &[u8], b: &[u8], ignore_case: bool) -> Ordering {
431    compare_text_filtered(a, b, false, true, ignore_case)
432}
433
434/// Master comparison function that dispatches based on KeyOpts.
435pub fn compare_with_opts(a: &[u8], b: &[u8], opts: &KeyOpts, random_seed: u64) -> Ordering {
436    let a = if opts.ignore_leading_blanks {
437        skip_leading_blanks(a)
438    } else {
439        a
440    };
441    let b = if opts.ignore_leading_blanks {
442        skip_leading_blanks(b)
443    } else {
444        b
445    };
446
447    let result = if opts.numeric {
448        compare_numeric(a, b)
449    } else if opts.general_numeric {
450        compare_general_numeric(a, b)
451    } else if opts.human_numeric {
452        compare_human_numeric(a, b)
453    } else if opts.month {
454        compare_month(a, b)
455    } else if opts.version {
456        compare_version(a, b)
457    } else if opts.random {
458        compare_random(a, b, random_seed)
459    } else {
460        compare_text_filtered(
461            a,
462            b,
463            opts.dictionary_order,
464            opts.ignore_nonprinting,
465            opts.ignore_case,
466        )
467    };
468
469    if opts.reverse {
470        result.reverse()
471    } else {
472        result
473    }
474}
475
476/// Concrete comparison function type. Selected once at setup time to avoid
477/// per-comparison flag checking in hot sort loops.
478pub type CompareFn = fn(&[u8], &[u8]) -> Ordering;
479
480/// Select a concrete comparison function based on KeyOpts.
481/// Returns (compare_fn, needs_leading_blank_strip, needs_reverse).
482/// The caller applies blank-stripping and reversal outside the function pointer,
483/// eliminating all per-comparison branching.
484pub fn select_comparator(opts: &KeyOpts, random_seed: u64) -> (CompareFn, bool, bool) {
485    let needs_blank = opts.ignore_leading_blanks;
486    let needs_reverse = opts.reverse;
487
488    let cmp: CompareFn = if opts.numeric {
489        compare_numeric
490    } else if opts.general_numeric {
491        compare_general_numeric
492    } else if opts.human_numeric {
493        compare_human_numeric
494    } else if opts.month {
495        compare_month
496    } else if opts.version {
497        compare_version
498    } else if opts.random {
499        // Random needs seed — wrap in a closure-like pattern
500        // Since we need random_seed, we use a special case
501        return (
502            make_random_comparator(random_seed),
503            needs_blank,
504            needs_reverse,
505        );
506    } else if opts.dictionary_order || opts.ignore_nonprinting || opts.ignore_case {
507        // Text filtering: select specialized variant
508        match (
509            opts.dictionary_order,
510            opts.ignore_nonprinting,
511            opts.ignore_case,
512        ) {
513            (false, false, true) => compare_ignore_case,
514            (true, false, false) => |a: &[u8], b: &[u8]| compare_dictionary(a, b, false),
515            (true, false, true) => |a: &[u8], b: &[u8]| compare_dictionary(a, b, true),
516            (false, true, false) => |a: &[u8], b: &[u8]| compare_ignore_nonprinting(a, b, false),
517            (false, true, true) => |a: &[u8], b: &[u8]| compare_ignore_nonprinting(a, b, true),
518            (true, true, false) => {
519                |a: &[u8], b: &[u8]| compare_text_filtered(a, b, true, true, false)
520            }
521            (true, true, true) => {
522                |a: &[u8], b: &[u8]| compare_text_filtered(a, b, true, true, true)
523            }
524            _ => |a: &[u8], b: &[u8]| a.cmp(b),
525        }
526    } else {
527        |a: &[u8], b: &[u8]| a.cmp(b)
528    };
529
530    (cmp, needs_blank, needs_reverse)
531}
532
533fn make_random_comparator(seed: u64) -> CompareFn {
534    // We can't capture the seed in a function pointer, so we use a static.
535    // This is safe because sort is single-process and seed doesn't change during a sort.
536    RANDOM_SEED.store(seed, std::sync::atomic::Ordering::Relaxed);
537    random_compare_with_static_seed
538}
539
540static RANDOM_SEED: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
541
542fn random_compare_with_static_seed(a: &[u8], b: &[u8]) -> Ordering {
543    let seed = RANDOM_SEED.load(std::sync::atomic::Ordering::Relaxed);
544    compare_random(a, b, seed)
545}