Skip to main content

coreutils_rs/sort/
compare.rs

//! Comparison functions for different sort modes.
//! All comparison functions are allocation-free for maximum sort performance.
3use std::cmp::Ordering;
4
5use super::key::KeyOpts;
6
/// Return `s` with its leading run of blanks removed.
///
/// Only the two bytes `b' '` and `b'\t'` count as blanks; any other byte
/// (including other whitespace) ends the run. An all-blank input yields an
/// empty slice.
#[inline(always)]
pub fn skip_leading_blanks(s: &[u8]) -> &[u8] {
    let blanks = s
        .iter()
        .take_while(|&&byte| byte == b' ' || byte == b'\t')
        .count();
    &s[blanks..]
}
18
/// Default sort order: plain byte-wise lexicographic comparison of the two keys.
/// A key that is a strict prefix of the other orders first.
#[inline]
pub fn compare_lexical(a: &[u8], b: &[u8]) -> Ordering {
    Ord::cmp(a, b)
}
24
25/// Numeric sort (-n): compare leading numeric strings.
26/// Handles optional leading whitespace, sign, and decimal point.
27#[inline]
28pub fn compare_numeric(a: &[u8], b: &[u8]) -> Ordering {
29    let va = parse_numeric_value(a);
30    let vb = parse_numeric_value(b);
31    va.partial_cmp(&vb).unwrap_or(Ordering::Equal)
32}
33
/// Fast custom numeric parser for `-n` keys: reads blanks, an optional sign,
/// digits, and an optional decimal fraction directly from bytes.
/// Avoids UTF-8 validation and `str::parse::<f64>()` overhead entirely.
///
/// Returns `0.0` when the key is empty or contains no digit after the
/// optional sign. Anything following the numeric prefix is ignored.
///
/// Digit runs of any length are handled without wrap-around:
/// * integer parts with up to 19 significant digits are converted exactly via
///   `u64`; longer ones are accumulated in `f64` (very long runs become
///   infinite), so a longer number never compares below a shorter one;
/// * only the first 19 fractional digits contribute — more than an `f64` can
///   represent anyway.
#[inline]
pub fn parse_numeric_value(s: &[u8]) -> f64 {
    // Skip leading blanks (space and tab only).
    let first = s
        .iter()
        .position(|&b| b != b' ' && b != b'\t')
        .unwrap_or(s.len());
    let s = &s[first..];

    let (negative, unsigned) = match s.split_first() {
        None => return 0.0,
        Some((b'-', rest)) => (true, rest),
        Some((b'+', rest)) => (false, rest),
        Some(_) => (false, s),
    };

    let int_digits = &unsigned[..numeric_digit_run(unsigned)];
    let frac_digits: &[u8] = match unsigned[int_digits.len()..].split_first() {
        Some((b'.', tail)) => &tail[..numeric_digit_run(tail)],
        _ => &[],
    };

    // "-", ".", "abc", ... : no digit on either side of the point.
    if int_digits.is_empty() && frac_digits.is_empty() {
        return 0.0;
    }

    let mut value = integer_digits_to_f64(int_digits);
    if !frac_digits.is_empty() {
        // Digits past the 19th are below f64 resolution; dropping them keeps
        // the accumulator inside u64 and the divisor inside POW10.
        let kept = &frac_digits[..frac_digits.len().min(MAX_EXACT_DIGITS)];
        value += exact_digits_to_u64(kept) as f64 / POW10[kept.len()];
    }

    if negative { -value } else { value }
}

/// Length of the leading run of ASCII digits in `bytes`.
#[inline]
fn numeric_digit_run(bytes: &[u8]) -> usize {
    bytes.iter().take_while(|b| b.is_ascii_digit()).count()
}

/// Fold at most `MAX_EXACT_DIGITS` ASCII digits into a `u64`.
/// 19 nines is below `u64::MAX`, so the plain arithmetic cannot overflow.
#[inline]
fn exact_digits_to_u64(digits: &[u8]) -> u64 {
    debug_assert!(digits.len() <= MAX_EXACT_DIGITS);
    digits
        .iter()
        .fold(0u64, |acc, &d| acc * 10 + u64::from(d - b'0'))
}

/// Convert a run of ASCII digits (the integer part) to `f64`.
#[inline]
fn integer_digits_to_f64(digits: &[u8]) -> f64 {
    // Leading zeros carry no value; strip them so they don't push short
    // numbers onto the approximate path.
    let zeros = digits.iter().take_while(|&&d| d == b'0').count();
    let significant = &digits[zeros..];
    if significant.len() <= MAX_EXACT_DIGITS {
        exact_digits_to_u64(significant) as f64
    } else {
        // Too large for u64: accumulate in floating point instead of wrapping.
        significant
            .iter()
            .fold(0.0f64, |acc, &d| acc * 10.0 + f64::from(d - b'0'))
    }
}

/// Largest digit count that is guaranteed to fit in a `u64`.
const MAX_EXACT_DIGITS: usize = 19;

/// Pre-computed powers of 10 for fast decimal conversion (index = exponent).
const POW10: [f64; MAX_EXACT_DIGITS + 1] = [
    1.0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16,
    1e17, 1e18, 1e19,
];
154
/// Fast integer-only parser for numeric sort (-n).
///
/// Skips leading blanks, reads an optional `+`/`-` sign and then a run of
/// ASCII digits; whatever follows the digits is ignored unless it is a `.`.
///
/// Returns:
/// * `Some(value)` for a pure integer prefix that fits in `i64`;
/// * `Some(0)` for an empty key or a key with no digits (it sorts as zero);
/// * `None` if a decimal point follows the digits, or if the digits do not
///   fit in `i64`. Out-of-range values are rejected rather than wrapped,
///   because a wrapped value would sort in the wrong place.
#[inline]
pub fn try_parse_integer(s: &[u8]) -> Option<i64> {
    /// Magnitude of `i64::MIN`, one more than `i64::MAX`.
    const MIN_MAGNITUDE: u64 = 1 << 63;

    // Skip leading blanks (space and tab only).
    let first = s
        .iter()
        .position(|&b| b != b' ' && b != b'\t')
        .unwrap_or(s.len());
    let s = &s[first..];

    let (negative, digits) = match s.split_first() {
        Some((b'-', rest)) => (true, rest),
        Some((b'+', rest)) => (false, rest),
        _ => (false, s),
    };

    let mut magnitude: u64 = 0;
    let mut consumed = 0usize;
    for &byte in digits {
        if !byte.is_ascii_digit() {
            break;
        }
        // `?` bails out with None as soon as the value leaves the u64 range.
        magnitude = magnitude
            .checked_mul(10)?
            .checked_add(u64::from(byte - b'0'))?;
        consumed += 1;
    }

    // A decimal point means this is not a pure integer.
    if digits.get(consumed) == Some(&b'.') {
        return None;
    }

    if consumed == 0 {
        return Some(0);
    }

    if negative {
        if magnitude < MIN_MAGNITUDE {
            Some(-(magnitude as i64))
        } else if magnitude == MIN_MAGNITUDE {
            Some(i64::MIN)
        } else {
            None
        }
    } else {
        i64::try_from(magnitude).ok()
    }
}
247
/// Map an `i64` onto a `u64` whose unsigned ordering matches the signed ordering
/// of the input.
///
/// Flipping the sign bit is the same as adding 2^63 modulo 2^64: `i64::MIN`
/// becomes 0, 0 becomes 2^63, and `i64::MAX` becomes `u64::MAX`.
#[inline]
pub fn int_to_sortable_u64(v: i64) -> u64 {
    const SIGN_BIT: u64 = 1 << 63;
    (v as u64) ^ SIGN_BIT
}
254
/// Length in bytes of the numeric prefix of `s`: an optional `+`/`-` sign,
/// digits, and an optional `.` followed by more digits.
///
/// Returns 0 when the prefix contains no digit at all (e.g. `"-"`, `"."`).
/// Leading blanks are not skipped here.
fn find_numeric_end(s: &[u8]) -> usize {
    let digits_from =
        |from: usize| s[from..].iter().take_while(|b| b.is_ascii_digit()).count();

    let sign_len = match s.first() {
        Some(b'+') | Some(b'-') => 1,
        _ => 0,
    };

    let int_len = digits_from(sign_len);
    let mut end = sign_len + int_len;
    let mut digit_total = int_len;

    if s.get(end) == Some(&b'.') {
        let frac_len = digits_from(end + 1);
        end += 1 + frac_len;
        digit_total += frac_len;
    }

    if digit_total == 0 { 0 } else { end }
}
274
275/// General numeric sort (-g): handles scientific notation, infinity, NaN.
276/// O(n) parser.
277pub fn compare_general_numeric(a: &[u8], b: &[u8]) -> Ordering {
278    let va = parse_general_numeric(a);
279    let vb = parse_general_numeric(b);
280    match (va.is_nan(), vb.is_nan()) {
281        (true, true) => Ordering::Equal,
282        (true, false) => Ordering::Less,
283        (false, true) => Ordering::Greater,
284        (false, false) => va.partial_cmp(&vb).unwrap_or(Ordering::Equal),
285    }
286}
287
/// Parse the longest floating-point prefix of `s` for general numeric sort (-g).
///
/// Leading blanks (space, tab) are skipped. Recognised forms, each with an
/// optional `+`/`-` sign:
/// * a word starting with `inf` or `nan` (case-insensitive) — the longest
///   prefix of up to 8 characters that the standard float parser accepts;
/// * digits with an optional `.` fraction and an optional exponent
///   (`e`/`E`, optional sign, at least one digit).
///
/// Returns NaN for an empty key or one without a numeric prefix.
pub fn parse_general_numeric(s: &[u8]) -> f64 {
    // Skip leading blanks (space and tab only).
    let first = s
        .iter()
        .position(|&b| b != b' ' && b != b'\t')
        .unwrap_or(s.len());
    let s = &s[first..];
    if s.is_empty() {
        return f64::NAN;
    }

    let digits_from =
        |from: usize| s[from..].iter().take_while(|b| b.is_ascii_digit()).count();

    let sign_len = usize::from(s[0] == b'+' || s[0] == b'-');
    let body = &s[sign_len..];

    // Special words: "inf", "infinity", "nan".
    if body.len() >= 3 {
        let head = [
            body[0].to_ascii_lowercase(),
            body[1].to_ascii_lowercase(),
            body[2].to_ascii_lowercase(),
        ];
        if &head == b"inf" || &head == b"nan" {
            // "infinity" is 8 chars; try the longest candidate first.
            let longest = s.len().min(sign_len + 8);
            return (sign_len + 3..=longest)
                .rev()
                .find_map(|end| std::str::from_utf8(&s[..end]).ok()?.parse::<f64>().ok())
                .unwrap_or(f64::NAN);
        }
    }

    // Mantissa: digits, optional '.', more digits.
    let int_len = digits_from(sign_len);
    let mut end = sign_len + int_len;
    let mut digit_total = int_len;
    if s.get(end) == Some(&b'.') {
        let frac_len = digits_from(end + 1);
        end += 1 + frac_len;
        digit_total += frac_len;
    }
    if digit_total == 0 {
        return f64::NAN;
    }

    // Exponent: only consumed when at least one digit follows the marker.
    if matches!(s.get(end), Some(b'e') | Some(b'E')) {
        let mut exp_pos = end + 1;
        if matches!(s.get(exp_pos), Some(b'+') | Some(b'-')) {
            exp_pos += 1;
        }
        let exp_len = digits_from(exp_pos);
        if exp_len > 0 {
            end = exp_pos + exp_len;
        }
    }

    // Hand the validated prefix to the standard library parser.
    match std::str::from_utf8(&s[..end]) {
        Ok(text) => text.parse::<f64>().unwrap_or(f64::NAN),
        Err(_) => f64::NAN,
    }
}
369
370/// Human numeric sort (-h): handles suffixes K, M, G, T, P, E, Z, Y.
371pub fn compare_human_numeric(a: &[u8], b: &[u8]) -> Ordering {
372    let va = parse_human_numeric(a);
373    let vb = parse_human_numeric(b);
374    va.partial_cmp(&vb).unwrap_or(Ordering::Equal)
375}
376
377pub fn parse_human_numeric(s: &[u8]) -> f64 {
378    let s = skip_leading_blanks(s);
379    if s.is_empty() {
380        return 0.0;
381    }
382
383    // Use the fast custom parser for the numeric part
384    let base = parse_numeric_value(s);
385    let end = find_numeric_end(s);
386
387    if end < s.len() {
388        let multiplier = match s[end] {
389            b'K' | b'k' => 1e3,
390            b'M' => 1e6,
391            b'G' => 1e9,
392            b'T' => 1e12,
393            b'P' => 1e15,
394            b'E' => 1e18,
395            b'Z' => 1e21,
396            b'Y' => 1e24,
397            _ => 1.0,
398        };
399        base * multiplier
400    } else {
401        base
402    }
403}
404
405/// Month sort (-M).
406pub fn compare_month(a: &[u8], b: &[u8]) -> Ordering {
407    let ma = parse_month(a);
408    let mb = parse_month(b);
409    ma.cmp(&mb)
410}
411
/// Map a key to its month number (1 = January … 12 = December).
///
/// Leading blanks (space, tab) are skipped, then the first three bytes are
/// matched case-insensitively against the English three-letter abbreviations.
/// Returns 0 if fewer than three bytes remain or they are not a month name.
fn parse_month(s: &[u8]) -> u8 {
    const MONTHS: [&[u8; 3]; 12] = [
        b"JAN", b"FEB", b"MAR", b"APR", b"MAY", b"JUN", b"JUL", b"AUG", b"SEP", b"OCT", b"NOV",
        b"DEC",
    ];

    let mut letters = s
        .iter()
        .skip_while(|&&b| b == b' ' || b == b'\t')
        .map(u8::to_ascii_uppercase);

    let abbrev = match (letters.next(), letters.next(), letters.next()) {
        (Some(x), Some(y), Some(z)) => [x, y, z],
        _ => return 0,
    };

    MONTHS
        .iter()
        .position(|name| **name == abbrev)
        .map_or(0, |index| index as u8 + 1)
}
438
/// Version sort (-V): natural sort of version numbers.
///
/// Walks both keys in lockstep over raw bytes. When both cursors sit on an
/// ASCII digit, the whole digit runs are compared by numeric value; otherwise
/// single bytes are compared. A key that runs out first orders first.
pub fn compare_version(a: &[u8], b: &[u8]) -> Ordering {
    let (mut ai, mut bi) = (0usize, 0usize);

    loop {
        let (ac, bc) = match (a.get(ai), b.get(bi)) {
            (None, None) => return Ordering::Equal,
            (None, Some(_)) => return Ordering::Less,
            (Some(_), None) => return Ordering::Greater,
            (Some(&x), Some(&y)) => (x, y),
        };

        let step = if ac.is_ascii_digit() && bc.is_ascii_digit() {
            let lhs = consume_number_bytes(a, &mut ai);
            let rhs = consume_number_bytes(b, &mut bi);
            lhs.cmp(&rhs)
        } else {
            ai += 1;
            bi += 1;
            ac.cmp(&bc)
        };

        if step != Ordering::Equal {
            return step;
        }
    }
}

/// Read the run of ASCII digits starting at `*pos`, advance `*pos` past it, and
/// return its value. Values too large for `u64` saturate at `u64::MAX`.
#[inline]
fn consume_number_bytes(data: &[u8], pos: &mut usize) -> u64 {
    let tail = data.get(*pos..).unwrap_or(&[]);
    let run = tail.iter().take_while(|b| b.is_ascii_digit()).count();
    *pos += run;
    tail[..run].iter().fold(0u64, |acc, &digit| {
        acc.saturating_mul(10)
            .saturating_add(u64::from(digit - b'0'))
    })
}
489
/// Random sort (-R): hash-based shuffle that groups identical keys.
///
/// Keys are ordered by their seeded hash, so identical keys always compare
/// equal and the order for a given seed is repeatable.
pub fn compare_random(a: &[u8], b: &[u8], seed: u64) -> Ordering {
    fnv1a_hash(a, seed).cmp(&fnv1a_hash(b, seed))
}

/// 64-bit FNV-1a hash of `data`, with `seed` XOR-ed into the initial state.
#[inline]
fn fnv1a_hash(data: &[u8], seed: u64) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;

    data.iter().fold(OFFSET_BASIS ^ seed, |state, &byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
507
/// Whether `b` takes part in dictionary-order comparison:
/// ASCII letters, ASCII digits, space, and tab.
#[inline]
fn is_dict_char(b: u8) -> bool {
    matches!(b, b' ' | b'\t') || b.is_ascii_alphanumeric()
}
514
/// Whether `b` is a printable ASCII byte: space (0x20) through `~` (0x7e).
#[inline]
fn is_printable(b: u8) -> bool {
    matches!(b, 0x20..=0x7e)
}
519
520fn compare_text_filtered(
521    a: &[u8],
522    b: &[u8],
523    dict: bool,
524    no_print: bool,
525    fold_case: bool,
526) -> Ordering {
527    if !dict && !no_print && !fold_case {
528        return a.cmp(b);
529    }
530
531    let mut ai = a.iter().copied();
532    let mut bi = b.iter().copied();
533
534    loop {
535        let na = next_valid(&mut ai, dict, no_print);
536        let nb = next_valid(&mut bi, dict, no_print);
537        match (na, nb) {
538            (Some(ab), Some(bb)) => {
539                let ca = if fold_case {
540                    ab.to_ascii_uppercase()
541                } else {
542                    ab
543                };
544                let cb = if fold_case {
545                    bb.to_ascii_uppercase()
546                } else {
547                    bb
548                };
549                match ca.cmp(&cb) {
550                    Ordering::Equal => continue,
551                    other => return other,
552                }
553            }
554            (Some(_), None) => return Ordering::Greater,
555            (None, Some(_)) => return Ordering::Less,
556            (None, None) => return Ordering::Equal,
557        }
558    }
559}
560
561#[inline]
562fn next_valid(iter: &mut impl Iterator<Item = u8>, dict: bool, no_print: bool) -> Option<u8> {
563    loop {
564        match iter.next() {
565            None => return None,
566            Some(b) => {
567                if dict && !is_dict_char(b) {
568                    continue;
569                }
570                if no_print && !is_printable(b) {
571                    continue;
572                }
573                return Some(b);
574            }
575        }
576    }
577}
578
579// Public wrappers for backward compatibility with tests
580pub fn compare_ignore_case(a: &[u8], b: &[u8]) -> Ordering {
581    compare_text_filtered(a, b, false, false, true)
582}
583
584pub fn compare_dictionary(a: &[u8], b: &[u8], ignore_case: bool) -> Ordering {
585    compare_text_filtered(a, b, true, false, ignore_case)
586}
587
588pub fn compare_ignore_nonprinting(a: &[u8], b: &[u8], ignore_case: bool) -> Ordering {
589    compare_text_filtered(a, b, false, true, ignore_case)
590}
591
592/// Master comparison function that dispatches based on KeyOpts.
593pub fn compare_with_opts(a: &[u8], b: &[u8], opts: &KeyOpts, random_seed: u64) -> Ordering {
594    let a = if opts.ignore_leading_blanks {
595        skip_leading_blanks(a)
596    } else {
597        a
598    };
599    let b = if opts.ignore_leading_blanks {
600        skip_leading_blanks(b)
601    } else {
602        b
603    };
604
605    let result = if opts.numeric {
606        compare_numeric(a, b)
607    } else if opts.general_numeric {
608        compare_general_numeric(a, b)
609    } else if opts.human_numeric {
610        compare_human_numeric(a, b)
611    } else if opts.month {
612        compare_month(a, b)
613    } else if opts.version {
614        compare_version(a, b)
615    } else if opts.random {
616        compare_random(a, b, random_seed)
617    } else {
618        compare_text_filtered(
619            a,
620            b,
621            opts.dictionary_order,
622            opts.ignore_nonprinting,
623            opts.ignore_case,
624        )
625    };
626
627    if opts.reverse {
628        result.reverse()
629    } else {
630        result
631    }
632}
633
634/// Concrete comparison function type. Selected once at setup time to avoid
635/// per-comparison flag checking in hot sort loops.
636pub type CompareFn = fn(&[u8], &[u8]) -> Ordering;
637
638/// Select a concrete comparison function based on KeyOpts.
639/// Returns (compare_fn, needs_leading_blank_strip, needs_reverse).
640/// The caller applies blank-stripping and reversal outside the function pointer,
641/// eliminating all per-comparison branching.
642pub fn select_comparator(opts: &KeyOpts, random_seed: u64) -> (CompareFn, bool, bool) {
643    let needs_blank = opts.ignore_leading_blanks;
644    let needs_reverse = opts.reverse;
645
646    let cmp: CompareFn = if opts.numeric {
647        compare_numeric
648    } else if opts.general_numeric {
649        compare_general_numeric
650    } else if opts.human_numeric {
651        compare_human_numeric
652    } else if opts.month {
653        compare_month
654    } else if opts.version {
655        compare_version
656    } else if opts.random {
657        // Random needs seed — wrap in a closure-like pattern
658        // Since we need random_seed, we use a special case
659        return (
660            make_random_comparator(random_seed),
661            needs_blank,
662            needs_reverse,
663        );
664    } else if opts.dictionary_order || opts.ignore_nonprinting || opts.ignore_case {
665        // Text filtering: select specialized variant
666        match (
667            opts.dictionary_order,
668            opts.ignore_nonprinting,
669            opts.ignore_case,
670        ) {
671            (false, false, true) => compare_ignore_case,
672            (true, false, false) => |a: &[u8], b: &[u8]| compare_dictionary(a, b, false),
673            (true, false, true) => |a: &[u8], b: &[u8]| compare_dictionary(a, b, true),
674            (false, true, false) => |a: &[u8], b: &[u8]| compare_ignore_nonprinting(a, b, false),
675            (false, true, true) => |a: &[u8], b: &[u8]| compare_ignore_nonprinting(a, b, true),
676            (true, true, false) => {
677                |a: &[u8], b: &[u8]| compare_text_filtered(a, b, true, true, false)
678            }
679            (true, true, true) => {
680                |a: &[u8], b: &[u8]| compare_text_filtered(a, b, true, true, true)
681            }
682            _ => |a: &[u8], b: &[u8]| a.cmp(b),
683        }
684    } else {
685        |a: &[u8], b: &[u8]| a.cmp(b)
686    };
687
688    (cmp, needs_blank, needs_reverse)
689}
690
691fn make_random_comparator(seed: u64) -> CompareFn {
692    // We can't capture the seed in a function pointer, so we use a static.
693    // This is safe because sort is single-process and seed doesn't change during a sort.
694    RANDOM_SEED.store(seed, std::sync::atomic::Ordering::Relaxed);
695    random_compare_with_static_seed
696}
697
698static RANDOM_SEED: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
699
700fn random_compare_with_static_seed(a: &[u8], b: &[u8]) -> Ordering {
701    let seed = RANDOM_SEED.load(std::sync::atomic::Ordering::Relaxed);
702    compare_random(a, b, seed)
703}