Skip to main content

coreutils_rs/join/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// Policy for verifying that each input is sorted on its join key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderCheck {
    /// Report the first out-of-order line of each input, then keep joining.
    Default,
    /// Report the first out-of-order line and stop joining immediately.
    Strict,
    /// Never compare adjacent keys; no ordering diagnostics are produced.
    None,
}
11
/// One column of an `-o` output list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputSpec {
    /// The join field itself (written as `0` in the list).
    JoinField,
    /// A specific field of one input: `(file_index, field_index)`, both 0-based.
    FileField(usize, usize),
}
20
21/// Configuration for the join command.
22pub struct JoinConfig {
23    /// Join field for file 1 (0-indexed)
24    pub field1: usize,
25    /// Join field for file 2 (0-indexed)
26    pub field2: usize,
27    /// Also print unpairable lines from file 1 (-a 1)
28    pub print_unpaired1: bool,
29    /// Also print unpairable lines from file 2 (-a 2)
30    pub print_unpaired2: bool,
31    /// Print ONLY unpairable lines from file 1 (-v 1)
32    pub only_unpaired1: bool,
33    /// Print ONLY unpairable lines from file 2 (-v 2)
34    pub only_unpaired2: bool,
35    /// Replace missing fields with this string (-e)
36    pub empty_filler: Option<Vec<u8>>,
37    /// Ignore case in key comparison (-i)
38    pub case_insensitive: bool,
39    /// Output format (-o)
40    pub output_format: Option<Vec<OutputSpec>>,
41    /// Auto output format (-o auto)
42    pub auto_format: bool,
43    /// Field separator (-t). None = whitespace mode.
44    pub separator: Option<u8>,
45    /// Order checking
46    pub order_check: OrderCheck,
47    /// Treat first line as header (--header)
48    pub header: bool,
49    /// Use NUL as line delimiter (-z)
50    pub zero_terminated: bool,
51}
52
53impl Default for JoinConfig {
54    fn default() -> Self {
55        Self {
56            field1: 0,
57            field2: 0,
58            print_unpaired1: false,
59            print_unpaired2: false,
60            only_unpaired1: false,
61            only_unpaired2: false,
62            empty_filler: None,
63            case_insensitive: false,
64            output_format: None,
65            auto_format: false,
66            separator: None,
67            order_check: OrderCheck::Default,
68            header: false,
69            zero_terminated: false,
70        }
71    }
72}
73
74/// Split data into lines by delimiter using SIMD scanning.
75fn split_lines<'a>(data: &'a [u8], delim: u8) -> Vec<&'a [u8]> {
76    if data.is_empty() {
77        return Vec::new();
78    }
79    let count = memchr::memchr_iter(delim, data).count();
80    let has_trailing = data.last() == Some(&delim);
81    let cap = if has_trailing { count } else { count + 1 };
82    let mut lines = Vec::with_capacity(cap);
83    let mut start = 0;
84    for pos in memchr::memchr_iter(delim, data) {
85        lines.push(&data[start..pos]);
86        start = pos + 1;
87    }
88    if start < data.len() {
89        lines.push(&data[start..]);
90    }
91    lines
92}
93
/// Split a line into fields separated by runs of spaces and/or tabs.
///
/// Leading and trailing blanks are ignored, so no empty fields are produced.
fn split_fields_whitespace<'a>(line: &'a [u8]) -> Vec<&'a [u8]> {
    line.split(|&byte| byte == b' ' || byte == b'\t')
        // Adjacent blanks yield empty pieces; a field is a non-empty piece.
        .filter(|piece| !piece.is_empty())
        .collect()
}
115
116/// Split a line into fields by exact single character.
117fn split_fields_char<'a>(line: &'a [u8], sep: u8) -> Vec<&'a [u8]> {
118    let count = memchr::memchr_iter(sep, line).count();
119    let mut fields = Vec::with_capacity(count + 1);
120    let mut start = 0;
121    for pos in memchr::memchr_iter(sep, line) {
122        fields.push(&line[start..pos]);
123        start = pos + 1;
124    }
125    fields.push(&line[start..]);
126    fields
127}
128
129/// Split a line into fields based on the separator setting.
130#[inline]
131fn split_fields<'a>(line: &'a [u8], separator: Option<u8>) -> Vec<&'a [u8]> {
132    if let Some(sep) = separator {
133        split_fields_char(line, sep)
134    } else {
135        split_fields_whitespace(line)
136    }
137}
138
139/// Extract a single field from a line without allocating a Vec.
140#[inline]
141fn extract_field<'a>(line: &'a [u8], field_index: usize, separator: Option<u8>) -> &'a [u8] {
142    if let Some(sep) = separator {
143        let mut count = 0;
144        let mut start = 0;
145        for pos in memchr::memchr_iter(sep, line) {
146            if count == field_index {
147                return &line[start..pos];
148            }
149            count += 1;
150            start = pos + 1;
151        }
152        if count == field_index {
153            return &line[start..];
154        }
155        b""
156    } else {
157        let mut count = 0;
158        let mut i = 0;
159        let len = line.len();
160        while i < len {
161            while i < len && (line[i] == b' ' || line[i] == b'\t') {
162                i += 1;
163            }
164            if i >= len {
165                break;
166            }
167            let start = i;
168            while i < len && line[i] != b' ' && line[i] != b'\t' {
169                i += 1;
170            }
171            if count == field_index {
172                return &line[start..i];
173            }
174            count += 1;
175        }
176        b""
177    }
178}
179
/// Order two join keys bytewise.
///
/// When `case_insensitive` is set, ASCII letters are folded to lower case
/// before comparing; non-ASCII bytes are compared unchanged. A key that is a
/// strict prefix of the other sorts first.
#[inline]
fn compare_keys(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    let folded_a = a.iter().map(u8::to_ascii_lowercase);
    let folded_b = b.iter().map(u8::to_ascii_lowercase);
    // Lexicographic comparison; on a common prefix the shorter key is smaller.
    folded_a.cmp(folded_b)
}
195
/// Append one joined record in the default layout: the join key, then every
/// other field of the file-1 line, then every other field of the file-2
/// line, each preceded by `out_sep`, and finally `delim`.
///
/// `field1` / `field2` are the positions of the join field in each line and
/// are the fields left out of the per-file parts.
fn write_paired_default(
    fields1: &[&[u8]],
    fields2: &[&[u8]],
    join_key: &[u8],
    field1: usize,
    field2: usize,
    out_sep: u8,
    delim: u8,
    buf: &mut Vec<u8>,
) {
    buf.extend_from_slice(join_key);

    let rest_of_line1 = fields1
        .iter()
        .enumerate()
        .filter(|&(pos, _)| pos != field1);
    let rest_of_line2 = fields2
        .iter()
        .enumerate()
        .filter(|&(pos, _)| pos != field2);

    for (_, field) in rest_of_line1.chain(rest_of_line2) {
        buf.push(out_sep);
        buf.extend_from_slice(field);
    }
    buf.push(delim);
}
224
225/// Write a paired output line with -o format.
226fn write_paired_format(
227    fields1: &[&[u8]],
228    fields2: &[&[u8]],
229    join_key: &[u8],
230    specs: &[OutputSpec],
231    empty: &[u8],
232    out_sep: u8,
233    delim: u8,
234    buf: &mut Vec<u8>,
235) {
236    for (i, spec) in specs.iter().enumerate() {
237        if i > 0 {
238            buf.push(out_sep);
239        }
240        match spec {
241            OutputSpec::JoinField => buf.extend_from_slice(join_key),
242            OutputSpec::FileField(file_num, field_idx) => {
243                let fields = if *file_num == 0 { fields1 } else { fields2 };
244                if let Some(f) = fields.get(*field_idx) {
245                    buf.extend_from_slice(f);
246                } else {
247                    buf.extend_from_slice(empty);
248                }
249            }
250        }
251    }
252    buf.push(delim);
253}
254
/// Append an unpairable line in the default layout: its join field first
/// (nothing if the line has no such field), followed by each remaining field
/// preceded by `out_sep`, then `delim`.
fn write_unpaired_default(
    fields: &[&[u8]],
    join_field: usize,
    out_sep: u8,
    delim: u8,
    buf: &mut Vec<u8>,
) {
    if let Some(key) = fields.get(join_field) {
        buf.extend_from_slice(key);
    }
    let others = fields
        .iter()
        .enumerate()
        .filter(|&(pos, _)| pos != join_field);
    for (_, field) in others {
        buf.push(out_sep);
        buf.extend_from_slice(field);
    }
    buf.push(delim);
}
274
275/// Write an unpaired output line with -o format.
276fn write_unpaired_format(
277    fields: &[&[u8]],
278    file_num: usize,
279    join_field: usize,
280    specs: &[OutputSpec],
281    empty: &[u8],
282    out_sep: u8,
283    delim: u8,
284    buf: &mut Vec<u8>,
285) {
286    let key = fields.get(join_field).copied().unwrap_or(b"");
287    for (i, spec) in specs.iter().enumerate() {
288        if i > 0 {
289            buf.push(out_sep);
290        }
291        match spec {
292            OutputSpec::JoinField => buf.extend_from_slice(key),
293            OutputSpec::FileField(fnum, fidx) => {
294                if *fnum == file_num {
295                    if let Some(f) = fields.get(*fidx) {
296                        buf.extend_from_slice(f);
297                    } else {
298                        buf.extend_from_slice(empty);
299                    }
300                } else {
301                    buf.extend_from_slice(empty);
302                }
303            }
304        }
305    }
306    buf.push(delim);
307}
308
309/// Run the join merge algorithm on two sorted inputs.
310pub fn join(
311    data1: &[u8],
312    data2: &[u8],
313    config: &JoinConfig,
314    tool_name: &str,
315    file1_name: &str,
316    file2_name: &str,
317    out: &mut impl Write,
318) -> io::Result<bool> {
319    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
320    let out_sep = config.separator.unwrap_or(b' ');
321    let empty = config.empty_filler.as_deref().unwrap_or(b"");
322    let ci = config.case_insensitive;
323
324    let print_paired = !config.only_unpaired1 && !config.only_unpaired2;
325    let show_unpaired1 = config.print_unpaired1 || config.only_unpaired1;
326    let show_unpaired2 = config.print_unpaired2 || config.only_unpaired2;
327
328    let lines1 = split_lines(data1, delim);
329    let lines2 = split_lines(data2, delim);
330
331    let mut i1 = 0usize;
332    let mut i2 = 0usize;
333    let mut had_order_error = false;
334    let mut warned1 = false;
335    let mut warned2 = false;
336
337    let mut buf = Vec::with_capacity((data1.len() + data2.len()) / 2);
338
339    // Handle -o auto: build format from first lines
340    let auto_specs: Option<Vec<OutputSpec>> = if config.auto_format {
341        let fc1 = if !lines1.is_empty() {
342            split_fields(lines1[0], config.separator).len()
343        } else {
344            1
345        };
346        let fc2 = if !lines2.is_empty() {
347            split_fields(lines2[0], config.separator).len()
348        } else {
349            1
350        };
351        let mut specs = Vec::new();
352        specs.push(OutputSpec::JoinField);
353        for i in 0..fc1 {
354            if i != config.field1 {
355                specs.push(OutputSpec::FileField(0, i));
356            }
357        }
358        for i in 0..fc2 {
359            if i != config.field2 {
360                specs.push(OutputSpec::FileField(1, i));
361            }
362        }
363        Some(specs)
364    } else {
365        None
366    };
367
368    let format = config.output_format.as_deref().or(auto_specs.as_deref());
369
370    // Handle --header: join first lines without sort check
371    if config.header && !lines1.is_empty() && !lines2.is_empty() {
372        let fields1 = split_fields(lines1[0], config.separator);
373        let fields2 = split_fields(lines2[0], config.separator);
374        let key = fields1.get(config.field1).copied().unwrap_or(b"");
375
376        if let Some(specs) = format {
377            write_paired_format(
378                &fields1, &fields2, key, specs, empty, out_sep, delim, &mut buf,
379            );
380        } else {
381            write_paired_default(
382                &fields1,
383                &fields2,
384                key,
385                config.field1,
386                config.field2,
387                out_sep,
388                delim,
389                &mut buf,
390            );
391        }
392        i1 = 1;
393        i2 = 1;
394    } else if config.header {
395        // One or both files empty — skip header
396        if !lines1.is_empty() {
397            i1 = 1;
398        }
399        if !lines2.is_empty() {
400            i2 = 1;
401        }
402    }
403
404    while i1 < lines1.len() && i2 < lines2.len() {
405        let key1 = extract_field(lines1[i1], config.field1, config.separator);
406        let key2 = extract_field(lines2[i2], config.field2, config.separator);
407
408        // Order checks
409        if config.order_check != OrderCheck::None {
410            if !warned1 && i1 > (if config.header { 1 } else { 0 }) {
411                let prev_key = extract_field(lines1[i1 - 1], config.field1, config.separator);
412                if compare_keys(key1, prev_key, ci) == Ordering::Less {
413                    had_order_error = true;
414                    warned1 = true;
415                    eprintln!(
416                        "{}: {}:{}: is not sorted: {}",
417                        tool_name,
418                        file1_name,
419                        i1 + 1,
420                        String::from_utf8_lossy(lines1[i1])
421                    );
422                    if config.order_check == OrderCheck::Strict {
423                        out.write_all(&buf)?;
424                        return Ok(true);
425                    }
426                }
427            }
428            if !warned2 && i2 > (if config.header { 1 } else { 0 }) {
429                let prev_key = extract_field(lines2[i2 - 1], config.field2, config.separator);
430                if compare_keys(key2, prev_key, ci) == Ordering::Less {
431                    had_order_error = true;
432                    warned2 = true;
433                    eprintln!(
434                        "{}: {}:{}: is not sorted: {}",
435                        tool_name,
436                        file2_name,
437                        i2 + 1,
438                        String::from_utf8_lossy(lines2[i2])
439                    );
440                    if config.order_check == OrderCheck::Strict {
441                        out.write_all(&buf)?;
442                        return Ok(true);
443                    }
444                }
445            }
446        }
447
448        match compare_keys(key1, key2, ci) {
449            Ordering::Less => {
450                if show_unpaired1 {
451                    let fields1 = split_fields(lines1[i1], config.separator);
452                    if let Some(specs) = format {
453                        write_unpaired_format(
454                            &fields1,
455                            0,
456                            config.field1,
457                            specs,
458                            empty,
459                            out_sep,
460                            delim,
461                            &mut buf,
462                        );
463                    } else {
464                        write_unpaired_default(&fields1, config.field1, out_sep, delim, &mut buf);
465                    }
466                }
467                i1 += 1;
468            }
469            Ordering::Greater => {
470                if show_unpaired2 {
471                    let fields2 = split_fields(lines2[i2], config.separator);
472                    if let Some(specs) = format {
473                        write_unpaired_format(
474                            &fields2,
475                            1,
476                            config.field2,
477                            specs,
478                            empty,
479                            out_sep,
480                            delim,
481                            &mut buf,
482                        );
483                    } else {
484                        write_unpaired_default(&fields2, config.field2, out_sep, delim, &mut buf);
485                    }
486                }
487                i2 += 1;
488            }
489            Ordering::Equal => {
490                // Find all consecutive file2 lines with the same key
491                let group_start = i2;
492                let current_key = key2;
493                i2 += 1;
494                while i2 < lines2.len() {
495                    let next_key = extract_field(lines2[i2], config.field2, config.separator);
496                    if compare_keys(next_key, current_key, ci) != Ordering::Equal {
497                        break;
498                    }
499                    i2 += 1;
500                }
501
502                // For each file1 line with the same key, cross-product with file2 group
503                loop {
504                    if print_paired {
505                        let fields1 = split_fields(lines1[i1], config.separator);
506                        let key = fields1.get(config.field1).copied().unwrap_or(b"");
507                        for j in group_start..i2 {
508                            let fields2 = split_fields(lines2[j], config.separator);
509                            if let Some(specs) = format {
510                                write_paired_format(
511                                    &fields1, &fields2, key, specs, empty, out_sep, delim, &mut buf,
512                                );
513                            } else {
514                                write_paired_default(
515                                    &fields1,
516                                    &fields2,
517                                    key,
518                                    config.field1,
519                                    config.field2,
520                                    out_sep,
521                                    delim,
522                                    &mut buf,
523                                );
524                            }
525                        }
526                    }
527                    i1 += 1;
528                    if i1 >= lines1.len() {
529                        break;
530                    }
531                    let next_key = extract_field(lines1[i1], config.field1, config.separator);
532                    let cmp = compare_keys(next_key, current_key, ci);
533                    if cmp != Ordering::Equal {
534                        // Check order: next_key should be > current_key
535                        if config.order_check != OrderCheck::None
536                            && !warned1
537                            && cmp == Ordering::Less
538                        {
539                            had_order_error = true;
540                            warned1 = true;
541                            eprintln!(
542                                "{}: {}:{}: is not sorted: {}",
543                                tool_name,
544                                file1_name,
545                                i1 + 1,
546                                String::from_utf8_lossy(lines1[i1])
547                            );
548                            if config.order_check == OrderCheck::Strict {
549                                out.write_all(&buf)?;
550                                return Ok(true);
551                            }
552                        }
553                        break;
554                    }
555                }
556            }
557        }
558    }
559
560    // Drain remaining from file 1
561    while i1 < lines1.len() {
562        // Check sort order even when draining (GNU join does this)
563        if config.order_check != OrderCheck::None
564            && !warned1
565            && i1 > (if config.header { 1 } else { 0 })
566        {
567            let key1 = extract_field(lines1[i1], config.field1, config.separator);
568            let prev_key = extract_field(lines1[i1 - 1], config.field1, config.separator);
569            if compare_keys(key1, prev_key, ci) == Ordering::Less {
570                had_order_error = true;
571                warned1 = true;
572                eprintln!(
573                    "{}: {}:{}: is not sorted: {}",
574                    tool_name,
575                    file1_name,
576                    i1 + 1,
577                    String::from_utf8_lossy(lines1[i1])
578                );
579                if config.order_check == OrderCheck::Strict {
580                    out.write_all(&buf)?;
581                    return Ok(true);
582                }
583            }
584        }
585        if show_unpaired1 {
586            let fields1 = split_fields(lines1[i1], config.separator);
587            if let Some(specs) = format {
588                write_unpaired_format(
589                    &fields1,
590                    0,
591                    config.field1,
592                    specs,
593                    empty,
594                    out_sep,
595                    delim,
596                    &mut buf,
597                );
598            } else {
599                write_unpaired_default(&fields1, config.field1, out_sep, delim, &mut buf);
600            }
601        }
602        i1 += 1;
603    }
604
605    // Drain remaining from file 2
606    while i2 < lines2.len() {
607        // Check sort order even when draining (GNU join does this)
608        if config.order_check != OrderCheck::None
609            && !warned2
610            && i2 > (if config.header { 1 } else { 0 })
611        {
612            let key2 = extract_field(lines2[i2], config.field2, config.separator);
613            let prev_key = extract_field(lines2[i2 - 1], config.field2, config.separator);
614            if compare_keys(key2, prev_key, ci) == Ordering::Less {
615                had_order_error = true;
616                warned2 = true;
617                eprintln!(
618                    "{}: {}:{}: is not sorted: {}",
619                    tool_name,
620                    file2_name,
621                    i2 + 1,
622                    String::from_utf8_lossy(lines2[i2])
623                );
624                if config.order_check == OrderCheck::Strict {
625                    out.write_all(&buf)?;
626                    return Ok(true);
627                }
628            }
629        }
630        if show_unpaired2 {
631            let fields2 = split_fields(lines2[i2], config.separator);
632            if let Some(specs) = format {
633                write_unpaired_format(
634                    &fields2,
635                    1,
636                    config.field2,
637                    specs,
638                    empty,
639                    out_sep,
640                    delim,
641                    &mut buf,
642                );
643            } else {
644                write_unpaired_default(&fields2, config.field2, out_sep, delim, &mut buf);
645            }
646        }
647        i2 += 1;
648    }
649
650    out.write_all(&buf)?;
651    Ok(had_order_error)
652}