Skip to main content

coreutils_rs/join/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// How to handle sort-order checking.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderCheck {
    /// Report the first out-of-order line of each input and keep joining.
    Default,
    /// Report the first out-of-order line, write what is buffered, and stop.
    Strict,
    /// Skip sort-order checking entirely.
    None,
}
11
/// An output field specification from -o format.
///
/// Each entry of an `-o` list selects either the join field or one numbered
/// field of one of the two inputs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputSpec {
    /// Field 0: the join field
    JoinField,
    /// (file_index 0-based, field_index 0-based)
    FileField(usize, usize),
}
20
/// Configuration for the join command.
///
/// All field indices stored here are 0-based.
pub struct JoinConfig {
    /// Join field for file 1 (0-indexed)
    pub field1: usize,
    /// Join field for file 2 (0-indexed)
    pub field2: usize,
    /// Also print unpairable lines from file 1 (-a 1)
    pub print_unpaired1: bool,
    /// Also print unpairable lines from file 2 (-a 2)
    pub print_unpaired2: bool,
    /// Print ONLY unpairable lines from file 1 (-v 1); suppresses paired output
    pub only_unpaired1: bool,
    /// Print ONLY unpairable lines from file 2 (-v 2); suppresses paired output
    pub only_unpaired2: bool,
    /// Replace missing fields with this string (-e). None = write nothing.
    pub empty_filler: Option<Vec<u8>>,
    /// Ignore case in key comparison (-i)
    pub case_insensitive: bool,
    /// Output format (-o)
    pub output_format: Option<Vec<OutputSpec>>,
    /// Auto output format (-o auto)
    pub auto_format: bool,
    /// Field separator (-t). None = whitespace mode.
    pub separator: Option<u8>,
    /// Order checking
    pub order_check: OrderCheck,
    /// Treat first line as header (--header)
    pub header: bool,
    /// Use NUL as line delimiter (-z)
    pub zero_terminated: bool,
}
52
impl Default for JoinConfig {
    /// Plain `join` with no options: join on the first field of both files,
    /// whitespace-separated fields, newline-terminated lines, paired lines
    /// only, default order checking.
    fn default() -> Self {
        Self {
            field1: 0,
            field2: 0,
            print_unpaired1: false,
            print_unpaired2: false,
            only_unpaired1: false,
            only_unpaired2: false,
            empty_filler: None,
            case_insensitive: false,
            output_format: None,
            auto_format: false,
            separator: None,
            order_check: OrderCheck::Default,
            header: false,
            zero_terminated: false,
        }
    }
}
73
74/// Split data into lines by delimiter using SIMD scanning.
75/// Uses heuristic capacity to avoid double-scan.
76fn split_lines<'a>(data: &'a [u8], delim: u8) -> Vec<&'a [u8]> {
77    if data.is_empty() {
78        return Vec::new();
79    }
80    // Heuristic: assume average line length of ~40 bytes
81    let est_lines = data.len() / 40 + 1;
82    let mut lines = Vec::with_capacity(est_lines);
83    let mut start = 0;
84    for pos in memchr::memchr_iter(delim, data) {
85        lines.push(&data[start..pos]);
86        start = pos + 1;
87    }
88    if start < data.len() {
89        lines.push(&data[start..]);
90    }
91    lines
92}
93
/// Split a line into its blank-separated fields.
///
/// Fields are separated by runs of spaces and/or tabs; leading and trailing
/// blanks are ignored, so no returned field is ever empty.
fn split_fields_whitespace<'a>(line: &'a [u8]) -> Vec<&'a [u8]> {
    let mut fields = Vec::with_capacity(8);
    fields.extend(
        line.split(|&byte| byte == b' ' || byte == b'\t')
            .filter(|field| !field.is_empty()),
    );
    fields
}
115
116/// Split a line into fields by exact single character.
117/// Single-pass: no pre-counting scan.
118fn split_fields_char<'a>(line: &'a [u8], sep: u8) -> Vec<&'a [u8]> {
119    let mut fields = Vec::with_capacity(8);
120    let mut start = 0;
121    for pos in memchr::memchr_iter(sep, line) {
122        fields.push(&line[start..pos]);
123        start = pos + 1;
124    }
125    fields.push(&line[start..]);
126    fields
127}
128
129/// Split a line into fields based on the separator setting.
130#[inline]
131fn split_fields<'a>(line: &'a [u8], separator: Option<u8>) -> Vec<&'a [u8]> {
132    if let Some(sep) = separator {
133        split_fields_char(line, sep)
134    } else {
135        split_fields_whitespace(line)
136    }
137}
138
139/// Extract a single field from a line without allocating a Vec.
140#[inline]
141fn extract_field<'a>(line: &'a [u8], field_index: usize, separator: Option<u8>) -> &'a [u8] {
142    if let Some(sep) = separator {
143        let mut count = 0;
144        let mut start = 0;
145        for pos in memchr::memchr_iter(sep, line) {
146            if count == field_index {
147                return &line[start..pos];
148            }
149            count += 1;
150            start = pos + 1;
151        }
152        if count == field_index {
153            return &line[start..];
154        }
155        b""
156    } else {
157        let mut count = 0;
158        let mut i = 0;
159        let len = line.len();
160        while i < len {
161            while i < len && (line[i] == b' ' || line[i] == b'\t') {
162                i += 1;
163            }
164            if i >= len {
165                break;
166            }
167            let start = i;
168            while i < len && line[i] != b' ' && line[i] != b'\t' {
169                i += 1;
170            }
171            if count == field_index {
172                return &line[start..i];
173            }
174            count += 1;
175        }
176        b""
177    }
178}
179
/// Order two join keys bytewise.
///
/// With `case_insensitive` set, ASCII letters are folded to lower case before
/// comparing; a key that is a prefix of the other sorts first either way.
#[inline]
fn compare_keys(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    // Lexicographic comparison of the lower-cased byte streams.
    a.iter()
        .map(u8::to_ascii_lowercase)
        .cmp(b.iter().map(u8::to_ascii_lowercase))
}
195
196/// Write a paired output line (default format: join_key + other fields).
197/// Zero-copy: writes directly from line slices without allocating field Vecs.
198fn write_paired_default_zerocopy(
199    line1: &[u8],
200    line2: &[u8],
201    join_key: &[u8],
202    field1: usize,
203    field2: usize,
204    separator: Option<u8>,
205    out_sep: u8,
206    delim: u8,
207    buf: &mut Vec<u8>,
208) {
209    buf.extend_from_slice(join_key);
210    write_other_fields(line1, field1, separator, out_sep, buf);
211    write_other_fields(line2, field2, separator, out_sep, buf);
212    buf.push(delim);
213}
214
215/// Write all fields from a line except the join field, prefixed by out_sep.
216/// Avoids allocating a Vec<&[u8]> for field splitting.
217#[inline]
218fn write_other_fields(
219    line: &[u8],
220    skip_field: usize,
221    separator: Option<u8>,
222    out_sep: u8,
223    buf: &mut Vec<u8>,
224) {
225    if let Some(sep) = separator {
226        let mut field_idx = 0;
227        let mut start = 0;
228        for pos in memchr::memchr_iter(sep, line) {
229            if field_idx != skip_field {
230                buf.push(out_sep);
231                buf.extend_from_slice(&line[start..pos]);
232            }
233            field_idx += 1;
234            start = pos + 1;
235        }
236        // Last field (no trailing separator)
237        if field_idx != skip_field {
238            buf.push(out_sep);
239            buf.extend_from_slice(&line[start..]);
240        }
241    } else {
242        // Whitespace-delimited
243        let mut field_idx = 0;
244        let mut i = 0;
245        let len = line.len();
246        while i < len {
247            while i < len && (line[i] == b' ' || line[i] == b'\t') {
248                i += 1;
249            }
250            if i >= len {
251                break;
252            }
253            let start = i;
254            while i < len && line[i] != b' ' && line[i] != b'\t' {
255                i += 1;
256            }
257            if field_idx != skip_field {
258                buf.push(out_sep);
259                buf.extend_from_slice(&line[start..i]);
260            }
261            field_idx += 1;
262        }
263    }
264}
265
266/// Write a paired output line with -o format.
267fn write_paired_format(
268    fields1: &[&[u8]],
269    fields2: &[&[u8]],
270    join_key: &[u8],
271    specs: &[OutputSpec],
272    empty: &[u8],
273    out_sep: u8,
274    delim: u8,
275    buf: &mut Vec<u8>,
276) {
277    for (i, spec) in specs.iter().enumerate() {
278        if i > 0 {
279            buf.push(out_sep);
280        }
281        match spec {
282            OutputSpec::JoinField => buf.extend_from_slice(join_key),
283            OutputSpec::FileField(file_num, field_idx) => {
284                let fields = if *file_num == 0 { fields1 } else { fields2 };
285                if let Some(f) = fields.get(*field_idx) {
286                    buf.extend_from_slice(f);
287                } else {
288                    buf.extend_from_slice(empty);
289                }
290            }
291        }
292    }
293    buf.push(delim);
294}
295
296/// Write an unpaired output line (default format), zero-copy from line.
297fn write_unpaired_default_zerocopy(
298    line: &[u8],
299    join_field: usize,
300    separator: Option<u8>,
301    out_sep: u8,
302    delim: u8,
303    buf: &mut Vec<u8>,
304) {
305    let key = extract_field(line, join_field, separator);
306    buf.extend_from_slice(key);
307    write_other_fields(line, join_field, separator, out_sep, buf);
308    buf.push(delim);
309}
310
311/// Write an unpaired output line with -o format.
312fn write_unpaired_format(
313    fields: &[&[u8]],
314    file_num: usize,
315    join_field: usize,
316    specs: &[OutputSpec],
317    empty: &[u8],
318    out_sep: u8,
319    delim: u8,
320    buf: &mut Vec<u8>,
321) {
322    let key = fields.get(join_field).copied().unwrap_or(b"");
323    for (i, spec) in specs.iter().enumerate() {
324        if i > 0 {
325            buf.push(out_sep);
326        }
327        match spec {
328            OutputSpec::JoinField => buf.extend_from_slice(key),
329            OutputSpec::FileField(fnum, fidx) => {
330                if *fnum == file_num {
331                    if let Some(f) = fields.get(*fidx) {
332                        buf.extend_from_slice(f);
333                    } else {
334                        buf.extend_from_slice(empty);
335                    }
336                } else {
337                    buf.extend_from_slice(empty);
338                }
339            }
340        }
341    }
342    buf.push(delim);
343}
344
345/// Run the join merge algorithm on two sorted inputs.
346pub fn join(
347    data1: &[u8],
348    data2: &[u8],
349    config: &JoinConfig,
350    tool_name: &str,
351    file1_name: &str,
352    file2_name: &str,
353    out: &mut impl Write,
354) -> io::Result<bool> {
355    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
356    let out_sep = config.separator.unwrap_or(b' ');
357    let empty = config.empty_filler.as_deref().unwrap_or(b"");
358    let ci = config.case_insensitive;
359
360    let print_paired = !config.only_unpaired1 && !config.only_unpaired2;
361    let show_unpaired1 = config.print_unpaired1 || config.only_unpaired1;
362    let show_unpaired2 = config.print_unpaired2 || config.only_unpaired2;
363
364    let lines1 = split_lines(data1, delim);
365    let lines2 = split_lines(data2, delim);
366
367    // Pre-compute all join keys — turns O(field_position) per comparison into O(1).
368    // Memory: 16 bytes per fat pointer × (lines1 + lines2). At 1M+1M lines ≈ 32 MB,
369    // acceptable for the >2x speedup over repeated extract_field scanning.
370    let keys1: Vec<&[u8]> = lines1
371        .iter()
372        .map(|l| extract_field(l, config.field1, config.separator))
373        .collect();
374    let keys2: Vec<&[u8]> = lines2
375        .iter()
376        .map(|l| extract_field(l, config.field2, config.separator))
377        .collect();
378
379    let mut i1 = 0usize;
380    let mut i2 = 0usize;
381    let mut had_order_error = false;
382    let mut warned1 = false;
383    let mut warned2 = false;
384
385    const FLUSH_THRESHOLD: usize = 256 * 1024;
386    let mut buf = Vec::with_capacity((data1.len() + data2.len()).min(FLUSH_THRESHOLD * 2));
387
388    // Handle -o auto: build format from first lines
389    let auto_specs: Option<Vec<OutputSpec>> = if config.auto_format {
390        let fc1 = if !lines1.is_empty() {
391            split_fields(lines1[0], config.separator).len()
392        } else {
393            1
394        };
395        let fc2 = if !lines2.is_empty() {
396            split_fields(lines2[0], config.separator).len()
397        } else {
398            1
399        };
400        let mut specs = Vec::new();
401        specs.push(OutputSpec::JoinField);
402        for i in 0..fc1 {
403            if i != config.field1 {
404                specs.push(OutputSpec::FileField(0, i));
405            }
406        }
407        for i in 0..fc2 {
408            if i != config.field2 {
409                specs.push(OutputSpec::FileField(1, i));
410            }
411        }
412        Some(specs)
413    } else {
414        None
415    };
416
417    let format = config.output_format.as_deref().or(auto_specs.as_deref());
418
419    // Handle --header: join first lines without sort check
420    if config.header && !lines1.is_empty() && !lines2.is_empty() {
421        let key = extract_field(lines1[0], config.field1, config.separator);
422
423        if let Some(specs) = format {
424            let fields1 = split_fields(lines1[0], config.separator);
425            let fields2 = split_fields(lines2[0], config.separator);
426            write_paired_format(
427                &fields1, &fields2, key, specs, empty, out_sep, delim, &mut buf,
428            );
429        } else {
430            write_paired_default_zerocopy(
431                lines1[0],
432                lines2[0],
433                key,
434                config.field1,
435                config.field2,
436                config.separator,
437                out_sep,
438                delim,
439                &mut buf,
440            );
441        }
442        i1 = 1;
443        i2 = 1;
444    } else if config.header {
445        // One or both files empty — skip header
446        if !lines1.is_empty() {
447            i1 = 1;
448        }
449        if !lines2.is_empty() {
450            i2 = 1;
451        }
452    }
453
454    while i1 < lines1.len() && i2 < lines2.len() {
455        debug_assert!(i1 < keys1.len() && i2 < keys2.len());
456        // SAFETY: keys1.len() == lines1.len() and keys2.len() == lines2.len(),
457        // guaranteed by the collect() above; loop condition ensures in-bounds.
458        let key1 = unsafe { *keys1.get_unchecked(i1) };
459        let key2 = unsafe { *keys2.get_unchecked(i2) };
460
461        // Order checks
462        if config.order_check != OrderCheck::None {
463            if !warned1 && i1 > (if config.header { 1 } else { 0 }) {
464                let prev_key = keys1[i1 - 1];
465                if compare_keys(key1, prev_key, ci) == Ordering::Less {
466                    had_order_error = true;
467                    warned1 = true;
468                    eprintln!(
469                        "{}: {}:{}: is not sorted: {}",
470                        tool_name,
471                        file1_name,
472                        i1 + 1,
473                        String::from_utf8_lossy(lines1[i1])
474                    );
475                    if config.order_check == OrderCheck::Strict {
476                        out.write_all(&buf)?;
477                        return Ok(true);
478                    }
479                }
480            }
481            if !warned2 && i2 > (if config.header { 1 } else { 0 }) {
482                let prev_key = keys2[i2 - 1];
483                if compare_keys(key2, prev_key, ci) == Ordering::Less {
484                    had_order_error = true;
485                    warned2 = true;
486                    eprintln!(
487                        "{}: {}:{}: is not sorted: {}",
488                        tool_name,
489                        file2_name,
490                        i2 + 1,
491                        String::from_utf8_lossy(lines2[i2])
492                    );
493                    if config.order_check == OrderCheck::Strict {
494                        out.write_all(&buf)?;
495                        return Ok(true);
496                    }
497                }
498            }
499        }
500
501        match compare_keys(key1, key2, ci) {
502            Ordering::Less => {
503                if show_unpaired1 {
504                    if let Some(specs) = format {
505                        let fields1 = split_fields(lines1[i1], config.separator);
506                        write_unpaired_format(
507                            &fields1,
508                            0,
509                            config.field1,
510                            specs,
511                            empty,
512                            out_sep,
513                            delim,
514                            &mut buf,
515                        );
516                    } else {
517                        write_unpaired_default_zerocopy(
518                            lines1[i1],
519                            config.field1,
520                            config.separator,
521                            out_sep,
522                            delim,
523                            &mut buf,
524                        );
525                    }
526                }
527                i1 += 1;
528                if show_unpaired1 && buf.len() >= FLUSH_THRESHOLD {
529                    out.write_all(&buf)?;
530                    buf.clear();
531                }
532            }
533            Ordering::Greater => {
534                if show_unpaired2 {
535                    if let Some(specs) = format {
536                        let fields2 = split_fields(lines2[i2], config.separator);
537                        write_unpaired_format(
538                            &fields2,
539                            1,
540                            config.field2,
541                            specs,
542                            empty,
543                            out_sep,
544                            delim,
545                            &mut buf,
546                        );
547                    } else {
548                        write_unpaired_default_zerocopy(
549                            lines2[i2],
550                            config.field2,
551                            config.separator,
552                            out_sep,
553                            delim,
554                            &mut buf,
555                        );
556                    }
557                }
558                i2 += 1;
559
560                // Periodic flush to limit memory usage for large inputs
561                if buf.len() >= FLUSH_THRESHOLD {
562                    out.write_all(&buf)?;
563                    buf.clear();
564                }
565            }
566            Ordering::Equal => {
567                // Find all consecutive file2 lines with the same key
568                let group_start = i2;
569                let current_key = key2;
570                i2 += 1;
571                while i2 < lines2.len() {
572                    debug_assert!(i2 < keys2.len());
573                    // SAFETY: i2 < lines2.len() == keys2.len()
574                    let next_key = unsafe { *keys2.get_unchecked(i2) };
575                    if compare_keys(next_key, current_key, ci) != Ordering::Equal {
576                        break;
577                    }
578                    i2 += 1;
579                }
580
581                // Pre-cache file2 group fields only for -o format (cross-product needs re-access)
582                let group2_fields: Vec<Vec<&[u8]>> = if print_paired && format.is_some() {
583                    (group_start..i2)
584                        .map(|j| split_fields(lines2[j], config.separator))
585                        .collect()
586                } else {
587                    Vec::new()
588                };
589
590                // For each file1 line with the same key, cross-product with file2 group
591                loop {
592                    if print_paired {
593                        let key = extract_field(lines1[i1], config.field1, config.separator);
594                        if let Some(specs) = format {
595                            let fields1 = split_fields(lines1[i1], config.separator);
596                            for fields2 in &group2_fields {
597                                write_paired_format(
598                                    &fields1, fields2, key, specs, empty, out_sep, delim, &mut buf,
599                                );
600                            }
601                        } else {
602                            // Zero-copy path: no field Vec allocation
603                            for j in group_start..i2 {
604                                write_paired_default_zerocopy(
605                                    lines1[i1],
606                                    lines2[j],
607                                    key,
608                                    config.field1,
609                                    config.field2,
610                                    config.separator,
611                                    out_sep,
612                                    delim,
613                                    &mut buf,
614                                );
615                            }
616                        }
617                    }
618                    // Flush inside cross-product loop to bound buffer for N×M groups
619                    if buf.len() >= FLUSH_THRESHOLD {
620                        out.write_all(&buf)?;
621                        buf.clear();
622                    }
623                    i1 += 1;
624                    if i1 >= lines1.len() {
625                        break;
626                    }
627                    debug_assert!(i1 < keys1.len());
628                    // SAFETY: i1 < lines1.len() == keys1.len() (checked above)
629                    let next_key = unsafe { *keys1.get_unchecked(i1) };
630                    let cmp = compare_keys(next_key, current_key, ci);
631                    if cmp != Ordering::Equal {
632                        // Check order: next_key should be > current_key
633                        if config.order_check != OrderCheck::None
634                            && !warned1
635                            && cmp == Ordering::Less
636                        {
637                            had_order_error = true;
638                            warned1 = true;
639                            eprintln!(
640                                "{}: {}:{}: is not sorted: {}",
641                                tool_name,
642                                file1_name,
643                                i1 + 1,
644                                String::from_utf8_lossy(lines1[i1])
645                            );
646                            if config.order_check == OrderCheck::Strict {
647                                out.write_all(&buf)?;
648                                return Ok(true);
649                            }
650                        }
651                        break;
652                    }
653                }
654            }
655        }
656    }
657
658    // Drain remaining from file 1
659    while i1 < lines1.len() {
660        // Check sort order even when draining (GNU join does this)
661        if config.order_check != OrderCheck::None
662            && !warned1
663            && i1 > (if config.header { 1 } else { 0 })
664        {
665            let key1 = keys1[i1];
666            let prev_key = keys1[i1 - 1];
667            if compare_keys(key1, prev_key, ci) == Ordering::Less {
668                had_order_error = true;
669                warned1 = true;
670                eprintln!(
671                    "{}: {}:{}: is not sorted: {}",
672                    tool_name,
673                    file1_name,
674                    i1 + 1,
675                    String::from_utf8_lossy(lines1[i1])
676                );
677                if config.order_check == OrderCheck::Strict {
678                    out.write_all(&buf)?;
679                    return Ok(true);
680                }
681            }
682        }
683        if show_unpaired1 {
684            if let Some(specs) = format {
685                let fields1 = split_fields(lines1[i1], config.separator);
686                write_unpaired_format(
687                    &fields1,
688                    0,
689                    config.field1,
690                    specs,
691                    empty,
692                    out_sep,
693                    delim,
694                    &mut buf,
695                );
696            } else {
697                write_unpaired_default_zerocopy(
698                    lines1[i1],
699                    config.field1,
700                    config.separator,
701                    out_sep,
702                    delim,
703                    &mut buf,
704                );
705            }
706        }
707        i1 += 1;
708    }
709
710    // Drain remaining from file 2
711    while i2 < lines2.len() {
712        // Check sort order even when draining (GNU join does this)
713        if config.order_check != OrderCheck::None
714            && !warned2
715            && i2 > (if config.header { 1 } else { 0 })
716        {
717            let key2 = keys2[i2];
718            let prev_key = keys2[i2 - 1];
719            if compare_keys(key2, prev_key, ci) == Ordering::Less {
720                had_order_error = true;
721                warned2 = true;
722                eprintln!(
723                    "{}: {}:{}: is not sorted: {}",
724                    tool_name,
725                    file2_name,
726                    i2 + 1,
727                    String::from_utf8_lossy(lines2[i2])
728                );
729                if config.order_check == OrderCheck::Strict {
730                    out.write_all(&buf)?;
731                    return Ok(true);
732                }
733            }
734        }
735        if show_unpaired2 {
736            if let Some(specs) = format {
737                let fields2 = split_fields(lines2[i2], config.separator);
738                write_unpaired_format(
739                    &fields2,
740                    1,
741                    config.field2,
742                    specs,
743                    empty,
744                    out_sep,
745                    delim,
746                    &mut buf,
747                );
748            } else {
749                write_unpaired_default_zerocopy(
750                    lines2[i2],
751                    config.field2,
752                    config.separator,
753                    out_sep,
754                    delim,
755                    &mut buf,
756                );
757            }
758        }
759        i2 += 1;
760    }
761
762    out.write_all(&buf)?;
763    Ok(had_order_error)
764}