Skip to main content

coreutils_rs/join/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// Policy for verifying that the inputs are sorted on their join keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderCheck {
    /// Report the first out-of-order line of each file on stderr and keep going.
    Default,
    /// Report the first out-of-order line, then stop processing.
    Strict,
    /// Skip sort-order verification entirely.
    None,
}
11
/// One column of an explicit `-o` output format.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputSpec {
    /// The join field itself (written as field `0` in a format list).
    JoinField,
    /// A specific field of one input: `(file_index, field_index)`, both 0-based.
    /// File index `0` selects the first input; any other value selects the second.
    FileField(usize, usize),
}
20
21/// Configuration for the join command.
22pub struct JoinConfig {
23    /// Join field for file 1 (0-indexed)
24    pub field1: usize,
25    /// Join field for file 2 (0-indexed)
26    pub field2: usize,
27    /// Also print unpairable lines from file 1 (-a 1)
28    pub print_unpaired1: bool,
29    /// Also print unpairable lines from file 2 (-a 2)
30    pub print_unpaired2: bool,
31    /// Print ONLY unpairable lines from file 1 (-v 1)
32    pub only_unpaired1: bool,
33    /// Print ONLY unpairable lines from file 2 (-v 2)
34    pub only_unpaired2: bool,
35    /// Replace missing fields with this string (-e)
36    pub empty_filler: Option<Vec<u8>>,
37    /// Ignore case in key comparison (-i)
38    pub case_insensitive: bool,
39    /// Output format (-o)
40    pub output_format: Option<Vec<OutputSpec>>,
41    /// Auto output format (-o auto)
42    pub auto_format: bool,
43    /// Field separator (-t). None = whitespace mode.
44    pub separator: Option<u8>,
45    /// Order checking
46    pub order_check: OrderCheck,
47    /// Treat first line as header (--header)
48    pub header: bool,
49    /// Use NUL as line delimiter (-z)
50    pub zero_terminated: bool,
51}
52
53impl Default for JoinConfig {
54    fn default() -> Self {
55        Self {
56            field1: 0,
57            field2: 0,
58            print_unpaired1: false,
59            print_unpaired2: false,
60            only_unpaired1: false,
61            only_unpaired2: false,
62            empty_filler: None,
63            case_insensitive: false,
64            output_format: None,
65            auto_format: false,
66            separator: None,
67            order_check: OrderCheck::Default,
68            header: false,
69            zero_terminated: false,
70        }
71    }
72}
73
74/// Split data into lines by delimiter using SIMD scanning.
75/// Uses heuristic capacity to avoid double-scan.
76fn split_lines<'a>(data: &'a [u8], delim: u8) -> Vec<&'a [u8]> {
77    if data.is_empty() {
78        return Vec::new();
79    }
80    // Heuristic: assume average line length of ~40 bytes
81    let est_lines = data.len() / 40 + 1;
82    let mut lines = Vec::with_capacity(est_lines);
83    let mut start = 0;
84    for pos in memchr::memchr_iter(delim, data) {
85        lines.push(&data[start..pos]);
86        start = pos + 1;
87    }
88    if start < data.len() {
89        lines.push(&data[start..]);
90    }
91    lines
92}
93
/// Split a line into its whitespace-separated fields.
///
/// Any run of spaces and/or tabs acts as a single separator, and leading or
/// trailing whitespace never produces an empty field.
fn split_fields_whitespace<'a>(line: &'a [u8]) -> Vec<&'a [u8]> {
    let mut fields = Vec::with_capacity(8);
    fields.extend(
        line.split(|&byte| byte == b' ' || byte == b'\t')
            // Adjacent blanks show up as empty pieces; those are not fields.
            .filter(|piece| !piece.is_empty()),
    );
    fields
}
115
116/// Split a line into fields by exact single character.
117/// Single-pass: no pre-counting scan.
118fn split_fields_char<'a>(line: &'a [u8], sep: u8) -> Vec<&'a [u8]> {
119    let mut fields = Vec::with_capacity(8);
120    let mut start = 0;
121    for pos in memchr::memchr_iter(sep, line) {
122        fields.push(&line[start..pos]);
123        start = pos + 1;
124    }
125    fields.push(&line[start..]);
126    fields
127}
128
129/// Split a line into fields based on the separator setting.
130#[inline]
131fn split_fields<'a>(line: &'a [u8], separator: Option<u8>) -> Vec<&'a [u8]> {
132    if let Some(sep) = separator {
133        split_fields_char(line, sep)
134    } else {
135        split_fields_whitespace(line)
136    }
137}
138
139/// Extract a single field from a line without allocating a Vec.
140#[inline]
141fn extract_field<'a>(line: &'a [u8], field_index: usize, separator: Option<u8>) -> &'a [u8] {
142    if let Some(sep) = separator {
143        let mut count = 0;
144        let mut start = 0;
145        for pos in memchr::memchr_iter(sep, line) {
146            if count == field_index {
147                return &line[start..pos];
148            }
149            count += 1;
150            start = pos + 1;
151        }
152        if count == field_index {
153            return &line[start..];
154        }
155        b""
156    } else {
157        let mut count = 0;
158        let mut i = 0;
159        let len = line.len();
160        while i < len {
161            while i < len && (line[i] == b' ' || line[i] == b'\t') {
162                i += 1;
163            }
164            if i >= len {
165                break;
166            }
167            let start = i;
168            while i < len && line[i] != b' ' && line[i] != b'\t' {
169                i += 1;
170            }
171            if count == field_index {
172                return &line[start..i];
173            }
174            count += 1;
175        }
176        b""
177    }
178}
179
/// Order two join keys bytewise; with `case_insensitive` set, ASCII letters
/// are lowercased before comparison (non-ASCII bytes compare unchanged).
///
/// When one key is a prefix of the other, the shorter key sorts first.
#[inline]
fn compare_keys(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    // Lexicographic comparison of the lowercased byte streams.
    let folded_a = a.iter().map(u8::to_ascii_lowercase);
    let folded_b = b.iter().map(u8::to_ascii_lowercase);
    folded_a.cmp(folded_b)
}
195
/// Append one joined line in the default layout to `buf`.
///
/// The layout is: `join_key`, then every field of file 1 except index
/// `field1`, then every field of file 2 except index `field2`, each preceded
/// by `out_sep`, and finally `delim`.
fn write_paired_default(
    fields1: &[&[u8]],
    fields2: &[&[u8]],
    join_key: &[u8],
    field1: usize,
    field2: usize,
    out_sep: u8,
    delim: u8,
    buf: &mut Vec<u8>,
) {
    buf.extend_from_slice(join_key);

    // The join column was already written as the key, so drop it from both sides.
    let rest_of_file1 = fields1
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != field1);
    let rest_of_file2 = fields2
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != field2);

    for (_, field) in rest_of_file1.chain(rest_of_file2) {
        buf.push(out_sep);
        buf.extend_from_slice(field);
    }
    buf.push(delim);
}
224
225/// Write a paired output line with -o format.
226fn write_paired_format(
227    fields1: &[&[u8]],
228    fields2: &[&[u8]],
229    join_key: &[u8],
230    specs: &[OutputSpec],
231    empty: &[u8],
232    out_sep: u8,
233    delim: u8,
234    buf: &mut Vec<u8>,
235) {
236    for (i, spec) in specs.iter().enumerate() {
237        if i > 0 {
238            buf.push(out_sep);
239        }
240        match spec {
241            OutputSpec::JoinField => buf.extend_from_slice(join_key),
242            OutputSpec::FileField(file_num, field_idx) => {
243                let fields = if *file_num == 0 { fields1 } else { fields2 };
244                if let Some(f) = fields.get(*field_idx) {
245                    buf.extend_from_slice(f);
246                } else {
247                    buf.extend_from_slice(empty);
248                }
249            }
250        }
251    }
252    buf.push(delim);
253}
254
/// Append an unpairable line in the default layout to `buf`.
///
/// The join field (empty if the line has no field at `join_field`) comes
/// first, followed by all remaining fields in their original order, each
/// preceded by `out_sep`; `delim` terminates the line.
fn write_unpaired_default(
    fields: &[&[u8]],
    join_field: usize,
    out_sep: u8,
    delim: u8,
    buf: &mut Vec<u8>,
) {
    let key = fields.get(join_field).map_or(&b""[..], |field| *field);
    buf.extend_from_slice(key);

    let remaining = fields
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != join_field)
        .map(|(_, field)| *field);
    for field in remaining {
        buf.push(out_sep);
        buf.extend_from_slice(field);
    }
    buf.push(delim);
}
274
275/// Write an unpaired output line with -o format.
276fn write_unpaired_format(
277    fields: &[&[u8]],
278    file_num: usize,
279    join_field: usize,
280    specs: &[OutputSpec],
281    empty: &[u8],
282    out_sep: u8,
283    delim: u8,
284    buf: &mut Vec<u8>,
285) {
286    let key = fields.get(join_field).copied().unwrap_or(b"");
287    for (i, spec) in specs.iter().enumerate() {
288        if i > 0 {
289            buf.push(out_sep);
290        }
291        match spec {
292            OutputSpec::JoinField => buf.extend_from_slice(key),
293            OutputSpec::FileField(fnum, fidx) => {
294                if *fnum == file_num {
295                    if let Some(f) = fields.get(*fidx) {
296                        buf.extend_from_slice(f);
297                    } else {
298                        buf.extend_from_slice(empty);
299                    }
300                } else {
301                    buf.extend_from_slice(empty);
302                }
303            }
304        }
305    }
306    buf.push(delim);
307}
308
309/// Run the join merge algorithm on two sorted inputs.
310pub fn join(
311    data1: &[u8],
312    data2: &[u8],
313    config: &JoinConfig,
314    tool_name: &str,
315    file1_name: &str,
316    file2_name: &str,
317    out: &mut impl Write,
318) -> io::Result<bool> {
319    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
320    let out_sep = config.separator.unwrap_or(b' ');
321    let empty = config.empty_filler.as_deref().unwrap_or(b"");
322    let ci = config.case_insensitive;
323
324    let print_paired = !config.only_unpaired1 && !config.only_unpaired2;
325    let show_unpaired1 = config.print_unpaired1 || config.only_unpaired1;
326    let show_unpaired2 = config.print_unpaired2 || config.only_unpaired2;
327
328    let lines1 = split_lines(data1, delim);
329    let lines2 = split_lines(data2, delim);
330
331    // Pre-compute all join keys — turns O(field_position) per comparison into O(1).
332    // Memory: 16 bytes per fat pointer × (lines1 + lines2). At 1M+1M lines ≈ 32 MB,
333    // acceptable for the >2x speedup over repeated extract_field scanning.
334    let keys1: Vec<&[u8]> = lines1
335        .iter()
336        .map(|l| extract_field(l, config.field1, config.separator))
337        .collect();
338    let keys2: Vec<&[u8]> = lines2
339        .iter()
340        .map(|l| extract_field(l, config.field2, config.separator))
341        .collect();
342
343    let mut i1 = 0usize;
344    let mut i2 = 0usize;
345    let mut had_order_error = false;
346    let mut warned1 = false;
347    let mut warned2 = false;
348
349    const FLUSH_THRESHOLD: usize = 256 * 1024;
350    let mut buf = Vec::with_capacity((data1.len() + data2.len()).min(FLUSH_THRESHOLD * 2));
351
352    // Handle -o auto: build format from first lines
353    let auto_specs: Option<Vec<OutputSpec>> = if config.auto_format {
354        let fc1 = if !lines1.is_empty() {
355            split_fields(lines1[0], config.separator).len()
356        } else {
357            1
358        };
359        let fc2 = if !lines2.is_empty() {
360            split_fields(lines2[0], config.separator).len()
361        } else {
362            1
363        };
364        let mut specs = Vec::new();
365        specs.push(OutputSpec::JoinField);
366        for i in 0..fc1 {
367            if i != config.field1 {
368                specs.push(OutputSpec::FileField(0, i));
369            }
370        }
371        for i in 0..fc2 {
372            if i != config.field2 {
373                specs.push(OutputSpec::FileField(1, i));
374            }
375        }
376        Some(specs)
377    } else {
378        None
379    };
380
381    let format = config.output_format.as_deref().or(auto_specs.as_deref());
382
383    // Handle --header: join first lines without sort check
384    if config.header && !lines1.is_empty() && !lines2.is_empty() {
385        let fields1 = split_fields(lines1[0], config.separator);
386        let fields2 = split_fields(lines2[0], config.separator);
387        let key = fields1.get(config.field1).copied().unwrap_or(b"");
388
389        if let Some(specs) = format {
390            write_paired_format(
391                &fields1, &fields2, key, specs, empty, out_sep, delim, &mut buf,
392            );
393        } else {
394            write_paired_default(
395                &fields1,
396                &fields2,
397                key,
398                config.field1,
399                config.field2,
400                out_sep,
401                delim,
402                &mut buf,
403            );
404        }
405        i1 = 1;
406        i2 = 1;
407    } else if config.header {
408        // One or both files empty — skip header
409        if !lines1.is_empty() {
410            i1 = 1;
411        }
412        if !lines2.is_empty() {
413            i2 = 1;
414        }
415    }
416
417    while i1 < lines1.len() && i2 < lines2.len() {
418        debug_assert!(i1 < keys1.len() && i2 < keys2.len());
419        // SAFETY: keys1.len() == lines1.len() and keys2.len() == lines2.len(),
420        // guaranteed by the collect() above; loop condition ensures in-bounds.
421        let key1 = unsafe { *keys1.get_unchecked(i1) };
422        let key2 = unsafe { *keys2.get_unchecked(i2) };
423
424        // Order checks
425        if config.order_check != OrderCheck::None {
426            if !warned1 && i1 > (if config.header { 1 } else { 0 }) {
427                let prev_key = keys1[i1 - 1];
428                if compare_keys(key1, prev_key, ci) == Ordering::Less {
429                    had_order_error = true;
430                    warned1 = true;
431                    eprintln!(
432                        "{}: {}:{}: is not sorted: {}",
433                        tool_name,
434                        file1_name,
435                        i1 + 1,
436                        String::from_utf8_lossy(lines1[i1])
437                    );
438                    if config.order_check == OrderCheck::Strict {
439                        out.write_all(&buf)?;
440                        return Ok(true);
441                    }
442                }
443            }
444            if !warned2 && i2 > (if config.header { 1 } else { 0 }) {
445                let prev_key = keys2[i2 - 1];
446                if compare_keys(key2, prev_key, ci) == Ordering::Less {
447                    had_order_error = true;
448                    warned2 = true;
449                    eprintln!(
450                        "{}: {}:{}: is not sorted: {}",
451                        tool_name,
452                        file2_name,
453                        i2 + 1,
454                        String::from_utf8_lossy(lines2[i2])
455                    );
456                    if config.order_check == OrderCheck::Strict {
457                        out.write_all(&buf)?;
458                        return Ok(true);
459                    }
460                }
461            }
462        }
463
464        match compare_keys(key1, key2, ci) {
465            Ordering::Less => {
466                if show_unpaired1 {
467                    let fields1 = split_fields(lines1[i1], config.separator);
468                    if let Some(specs) = format {
469                        write_unpaired_format(
470                            &fields1,
471                            0,
472                            config.field1,
473                            specs,
474                            empty,
475                            out_sep,
476                            delim,
477                            &mut buf,
478                        );
479                    } else {
480                        write_unpaired_default(&fields1, config.field1, out_sep, delim, &mut buf);
481                    }
482                }
483                i1 += 1;
484                if show_unpaired1 && buf.len() >= FLUSH_THRESHOLD {
485                    out.write_all(&buf)?;
486                    buf.clear();
487                }
488            }
489            Ordering::Greater => {
490                if show_unpaired2 {
491                    let fields2 = split_fields(lines2[i2], config.separator);
492                    if let Some(specs) = format {
493                        write_unpaired_format(
494                            &fields2,
495                            1,
496                            config.field2,
497                            specs,
498                            empty,
499                            out_sep,
500                            delim,
501                            &mut buf,
502                        );
503                    } else {
504                        write_unpaired_default(&fields2, config.field2, out_sep, delim, &mut buf);
505                    }
506                }
507                i2 += 1;
508
509                // Periodic flush to limit memory usage for large inputs
510                if buf.len() >= FLUSH_THRESHOLD {
511                    out.write_all(&buf)?;
512                    buf.clear();
513                }
514            }
515            Ordering::Equal => {
516                // Find all consecutive file2 lines with the same key
517                let group_start = i2;
518                let current_key = key2;
519                i2 += 1;
520                while i2 < lines2.len() {
521                    debug_assert!(i2 < keys2.len());
522                    // SAFETY: i2 < lines2.len() == keys2.len()
523                    let next_key = unsafe { *keys2.get_unchecked(i2) };
524                    if compare_keys(next_key, current_key, ci) != Ordering::Equal {
525                        break;
526                    }
527                    i2 += 1;
528                }
529
530                // Pre-cache file2 group fields to avoid re-splitting in cross-product
531                let group2_fields: Vec<Vec<&[u8]>> = if print_paired {
532                    (group_start..i2)
533                        .map(|j| split_fields(lines2[j], config.separator))
534                        .collect()
535                } else {
536                    Vec::new()
537                };
538
539                // For each file1 line with the same key, cross-product with file2 group
540                loop {
541                    if print_paired {
542                        let fields1 = split_fields(lines1[i1], config.separator);
543                        let key = fields1.get(config.field1).copied().unwrap_or(b"");
544                        for fields2 in &group2_fields {
545                            if let Some(specs) = format {
546                                write_paired_format(
547                                    &fields1, fields2, key, specs, empty, out_sep, delim, &mut buf,
548                                );
549                            } else {
550                                write_paired_default(
551                                    &fields1,
552                                    fields2,
553                                    key,
554                                    config.field1,
555                                    config.field2,
556                                    out_sep,
557                                    delim,
558                                    &mut buf,
559                                );
560                            }
561                        }
562                    }
563                    // Flush inside cross-product loop to bound buffer for N×M groups
564                    if buf.len() >= FLUSH_THRESHOLD {
565                        out.write_all(&buf)?;
566                        buf.clear();
567                    }
568                    i1 += 1;
569                    if i1 >= lines1.len() {
570                        break;
571                    }
572                    debug_assert!(i1 < keys1.len());
573                    // SAFETY: i1 < lines1.len() == keys1.len() (checked above)
574                    let next_key = unsafe { *keys1.get_unchecked(i1) };
575                    let cmp = compare_keys(next_key, current_key, ci);
576                    if cmp != Ordering::Equal {
577                        // Check order: next_key should be > current_key
578                        if config.order_check != OrderCheck::None
579                            && !warned1
580                            && cmp == Ordering::Less
581                        {
582                            had_order_error = true;
583                            warned1 = true;
584                            eprintln!(
585                                "{}: {}:{}: is not sorted: {}",
586                                tool_name,
587                                file1_name,
588                                i1 + 1,
589                                String::from_utf8_lossy(lines1[i1])
590                            );
591                            if config.order_check == OrderCheck::Strict {
592                                out.write_all(&buf)?;
593                                return Ok(true);
594                            }
595                        }
596                        break;
597                    }
598                }
599            }
600        }
601    }
602
603    // Drain remaining from file 1
604    while i1 < lines1.len() {
605        // Check sort order even when draining (GNU join does this)
606        if config.order_check != OrderCheck::None
607            && !warned1
608            && i1 > (if config.header { 1 } else { 0 })
609        {
610            let key1 = keys1[i1];
611            let prev_key = keys1[i1 - 1];
612            if compare_keys(key1, prev_key, ci) == Ordering::Less {
613                had_order_error = true;
614                warned1 = true;
615                eprintln!(
616                    "{}: {}:{}: is not sorted: {}",
617                    tool_name,
618                    file1_name,
619                    i1 + 1,
620                    String::from_utf8_lossy(lines1[i1])
621                );
622                if config.order_check == OrderCheck::Strict {
623                    out.write_all(&buf)?;
624                    return Ok(true);
625                }
626            }
627        }
628        if show_unpaired1 {
629            let fields1 = split_fields(lines1[i1], config.separator);
630            if let Some(specs) = format {
631                write_unpaired_format(
632                    &fields1,
633                    0,
634                    config.field1,
635                    specs,
636                    empty,
637                    out_sep,
638                    delim,
639                    &mut buf,
640                );
641            } else {
642                write_unpaired_default(&fields1, config.field1, out_sep, delim, &mut buf);
643            }
644        }
645        i1 += 1;
646    }
647
648    // Drain remaining from file 2
649    while i2 < lines2.len() {
650        // Check sort order even when draining (GNU join does this)
651        if config.order_check != OrderCheck::None
652            && !warned2
653            && i2 > (if config.header { 1 } else { 0 })
654        {
655            let key2 = keys2[i2];
656            let prev_key = keys2[i2 - 1];
657            if compare_keys(key2, prev_key, ci) == Ordering::Less {
658                had_order_error = true;
659                warned2 = true;
660                eprintln!(
661                    "{}: {}:{}: is not sorted: {}",
662                    tool_name,
663                    file2_name,
664                    i2 + 1,
665                    String::from_utf8_lossy(lines2[i2])
666                );
667                if config.order_check == OrderCheck::Strict {
668                    out.write_all(&buf)?;
669                    return Ok(true);
670                }
671            }
672        }
673        if show_unpaired2 {
674            let fields2 = split_fields(lines2[i2], config.separator);
675            if let Some(specs) = format {
676                write_unpaired_format(
677                    &fields2,
678                    1,
679                    config.field2,
680                    specs,
681                    empty,
682                    out_sep,
683                    delim,
684                    &mut buf,
685                );
686            } else {
687                write_unpaired_default(&fields2, config.field2, out_sep, delim, &mut buf);
688            }
689        }
690        i2 += 1;
691    }
692
693    out.write_all(&buf)?;
694    Ok(had_order_error)
695}