Skip to main content

coreutils_rs/join/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// Policy for verifying that each input is sorted on its join field.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OrderCheck {
    /// Report the first out-of-order line of each file on stderr, keep
    /// joining, and signal the problem through the return value.
    Default,
    /// Like `Default`, but stop at the first out-of-order line after
    /// writing whatever output has been produced so far.
    Strict,
    /// Perform no order checking at all.
    None,
}
11
/// One column of an explicit output format (the `-o` option).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum OutputSpec {
    /// The join field itself (written as `0` in a format list).
    JoinField,
    /// A specific field of one input: `(file_index, field_index)`, both
    /// 0-based. File index `0` selects the first input; the writers treat
    /// any other value as the second input.
    FileField(usize, usize),
}
20
21/// Configuration for the join command.
22pub struct JoinConfig {
23    /// Join field for file 1 (0-indexed)
24    pub field1: usize,
25    /// Join field for file 2 (0-indexed)
26    pub field2: usize,
27    /// Also print unpairable lines from file 1 (-a 1)
28    pub print_unpaired1: bool,
29    /// Also print unpairable lines from file 2 (-a 2)
30    pub print_unpaired2: bool,
31    /// Print ONLY unpairable lines from file 1 (-v 1)
32    pub only_unpaired1: bool,
33    /// Print ONLY unpairable lines from file 2 (-v 2)
34    pub only_unpaired2: bool,
35    /// Replace missing fields with this string (-e)
36    pub empty_filler: Option<Vec<u8>>,
37    /// Ignore case in key comparison (-i)
38    pub case_insensitive: bool,
39    /// Output format (-o)
40    pub output_format: Option<Vec<OutputSpec>>,
41    /// Auto output format (-o auto)
42    pub auto_format: bool,
43    /// Field separator (-t). None = whitespace mode.
44    pub separator: Option<u8>,
45    /// Order checking
46    pub order_check: OrderCheck,
47    /// Treat first line as header (--header)
48    pub header: bool,
49    /// Use NUL as line delimiter (-z)
50    pub zero_terminated: bool,
51}
52
53impl Default for JoinConfig {
54    fn default() -> Self {
55        Self {
56            field1: 0,
57            field2: 0,
58            print_unpaired1: false,
59            print_unpaired2: false,
60            only_unpaired1: false,
61            only_unpaired2: false,
62            empty_filler: None,
63            case_insensitive: false,
64            output_format: None,
65            auto_format: false,
66            separator: None,
67            order_check: OrderCheck::Default,
68            header: false,
69            zero_terminated: false,
70        }
71    }
72}
73
74/// Split data into lines by delimiter using SIMD scanning.
75fn split_lines<'a>(data: &'a [u8], delim: u8) -> Vec<&'a [u8]> {
76    if data.is_empty() {
77        return Vec::new();
78    }
79    let count = memchr::memchr_iter(delim, data).count();
80    let has_trailing = data.last() == Some(&delim);
81    let cap = if has_trailing { count } else { count + 1 };
82    let mut lines = Vec::with_capacity(cap);
83    let mut start = 0;
84    for pos in memchr::memchr_iter(delim, data) {
85        lines.push(&data[start..pos]);
86        start = pos + 1;
87    }
88    if start < data.len() {
89        lines.push(&data[start..]);
90    }
91    lines
92}
93
/// Split a line into whitespace-delimited fields.
///
/// Any run of spaces and/or tabs acts as a single separator, and leading or
/// trailing whitespace never produces an empty field.
fn split_fields_whitespace<'a>(line: &'a [u8]) -> Vec<&'a [u8]> {
    line.split(|&byte| byte == b' ' || byte == b'\t')
        // Adjacent or edge separators show up as empty pieces; drop them.
        .filter(|field| !field.is_empty())
        .collect()
}
115
116/// Split a line into fields by exact single character.
117fn split_fields_char<'a>(line: &'a [u8], sep: u8) -> Vec<&'a [u8]> {
118    let count = memchr::memchr_iter(sep, line).count();
119    let mut fields = Vec::with_capacity(count + 1);
120    let mut start = 0;
121    for pos in memchr::memchr_iter(sep, line) {
122        fields.push(&line[start..pos]);
123        start = pos + 1;
124    }
125    fields.push(&line[start..]);
126    fields
127}
128
129/// Split a line into fields based on the separator setting.
130#[inline]
131fn split_fields<'a>(line: &'a [u8], separator: Option<u8>) -> Vec<&'a [u8]> {
132    if let Some(sep) = separator {
133        split_fields_char(line, sep)
134    } else {
135        split_fields_whitespace(line)
136    }
137}
138
139/// Extract a single field from a line without allocating a Vec.
140#[inline]
141fn extract_field<'a>(line: &'a [u8], field_index: usize, separator: Option<u8>) -> &'a [u8] {
142    if let Some(sep) = separator {
143        let mut count = 0;
144        let mut start = 0;
145        for pos in memchr::memchr_iter(sep, line) {
146            if count == field_index {
147                return &line[start..pos];
148            }
149            count += 1;
150            start = pos + 1;
151        }
152        if count == field_index {
153            return &line[start..];
154        }
155        b""
156    } else {
157        let mut count = 0;
158        let mut i = 0;
159        let len = line.len();
160        while i < len {
161            while i < len && (line[i] == b' ' || line[i] == b'\t') {
162                i += 1;
163            }
164            if i >= len {
165                break;
166            }
167            let start = i;
168            while i < len && line[i] != b' ' && line[i] != b'\t' {
169                i += 1;
170            }
171            if count == field_index {
172                return &line[start..i];
173            }
174            count += 1;
175        }
176        b""
177    }
178}
179
/// Compare two join keys byte-wise, optionally ignoring ASCII case.
///
/// With `case_insensitive` set, both keys are folded to ASCII *uppercase*
/// before comparison. The fold direction matters for the bytes that sit
/// between `Z` and `a` (``[ \ ] ^ _ ` ``): GNU `join -i` and `sort -f` both
/// fold to uppercase, so folding to lowercase here would order such keys
/// differently from the input's sort order and trigger spurious
/// "is not sorted" diagnostics or missed matches.
///
/// When one key is a prefix of the other, the shorter key orders first.
#[inline]
fn compare_keys(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    // Iterator::cmp is lexicographic: first differing element decides,
    // otherwise the shorter sequence is `Less`.
    a.iter()
        .map(u8::to_ascii_uppercase)
        .cmp(b.iter().map(u8::to_ascii_uppercase))
}
195
/// Append one joined line in the default layout to `buf`.
///
/// The layout is: `join_key`, then every field of `fields1` except index
/// `field1`, then every field of `fields2` except index `field2`. Fields are
/// separated by `out_sep` and the line is terminated by `delim`.
fn write_paired_default(
    fields1: &[&[u8]],
    fields2: &[&[u8]],
    join_key: &[u8],
    field1: usize,
    field2: usize,
    out_sep: u8,
    delim: u8,
    buf: &mut Vec<u8>,
) {
    buf.extend_from_slice(join_key);

    // Non-key fields of file 1 followed by non-key fields of file 2.
    let rest1 = fields1
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != field1);
    let rest2 = fields2
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != field2);

    for (_, field) in rest1.chain(rest2) {
        buf.push(out_sep);
        buf.extend_from_slice(field);
    }
    buf.push(delim);
}
224
225/// Write a paired output line with -o format.
226fn write_paired_format(
227    fields1: &[&[u8]],
228    fields2: &[&[u8]],
229    join_key: &[u8],
230    specs: &[OutputSpec],
231    empty: &[u8],
232    out_sep: u8,
233    delim: u8,
234    buf: &mut Vec<u8>,
235) {
236    for (i, spec) in specs.iter().enumerate() {
237        if i > 0 {
238            buf.push(out_sep);
239        }
240        match spec {
241            OutputSpec::JoinField => buf.extend_from_slice(join_key),
242            OutputSpec::FileField(file_num, field_idx) => {
243                let fields = if *file_num == 0 { fields1 } else { fields2 };
244                if let Some(f) = fields.get(*field_idx) {
245                    buf.extend_from_slice(f);
246                } else {
247                    buf.extend_from_slice(empty);
248                }
249            }
250        }
251    }
252    buf.push(delim);
253}
254
/// Append one unpaired line in the default layout to `buf`.
///
/// The join field (index `join_field`, or nothing if the line has no such
/// field) comes first, followed by all remaining fields in their original
/// order, each preceded by `out_sep`. The line is terminated by `delim`.
fn write_unpaired_default(
    fields: &[&[u8]],
    join_field: usize,
    out_sep: u8,
    delim: u8,
    buf: &mut Vec<u8>,
) {
    if let Some(key) = fields.get(join_field) {
        buf.extend_from_slice(key);
    }

    let others = fields
        .iter()
        .enumerate()
        .filter(|&(idx, _)| idx != join_field);
    for (_, field) in others {
        buf.push(out_sep);
        buf.extend_from_slice(field);
    }
    buf.push(delim);
}
274
275/// Write an unpaired output line with -o format.
276fn write_unpaired_format(
277    fields: &[&[u8]],
278    file_num: usize,
279    join_field: usize,
280    specs: &[OutputSpec],
281    empty: &[u8],
282    out_sep: u8,
283    delim: u8,
284    buf: &mut Vec<u8>,
285) {
286    let key = fields.get(join_field).copied().unwrap_or(b"");
287    for (i, spec) in specs.iter().enumerate() {
288        if i > 0 {
289            buf.push(out_sep);
290        }
291        match spec {
292            OutputSpec::JoinField => buf.extend_from_slice(key),
293            OutputSpec::FileField(fnum, fidx) => {
294                if *fnum == file_num {
295                    if let Some(f) = fields.get(*fidx) {
296                        buf.extend_from_slice(f);
297                    } else {
298                        buf.extend_from_slice(empty);
299                    }
300                } else {
301                    buf.extend_from_slice(empty);
302                }
303            }
304        }
305    }
306    buf.push(delim);
307}
308
309/// Run the join merge algorithm on two sorted inputs.
310pub fn join(
311    data1: &[u8],
312    data2: &[u8],
313    config: &JoinConfig,
314    tool_name: &str,
315    file1_name: &str,
316    file2_name: &str,
317    out: &mut impl Write,
318) -> io::Result<bool> {
319    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
320    let out_sep = config.separator.unwrap_or(b' ');
321    let empty = config.empty_filler.as_deref().unwrap_or(b"");
322    let ci = config.case_insensitive;
323
324    let print_paired = !config.only_unpaired1 && !config.only_unpaired2;
325    let show_unpaired1 = config.print_unpaired1 || config.only_unpaired1;
326    let show_unpaired2 = config.print_unpaired2 || config.only_unpaired2;
327
328    let lines1 = split_lines(data1, delim);
329    let lines2 = split_lines(data2, delim);
330
331    let mut i1 = 0usize;
332    let mut i2 = 0usize;
333    let mut had_order_error = false;
334    let mut warned1 = false;
335    let mut warned2 = false;
336
337    const FLUSH_THRESHOLD: usize = 256 * 1024;
338    let mut buf = Vec::with_capacity((data1.len() + data2.len()).min(FLUSH_THRESHOLD * 2));
339
340    // Handle -o auto: build format from first lines
341    let auto_specs: Option<Vec<OutputSpec>> = if config.auto_format {
342        let fc1 = if !lines1.is_empty() {
343            split_fields(lines1[0], config.separator).len()
344        } else {
345            1
346        };
347        let fc2 = if !lines2.is_empty() {
348            split_fields(lines2[0], config.separator).len()
349        } else {
350            1
351        };
352        let mut specs = Vec::new();
353        specs.push(OutputSpec::JoinField);
354        for i in 0..fc1 {
355            if i != config.field1 {
356                specs.push(OutputSpec::FileField(0, i));
357            }
358        }
359        for i in 0..fc2 {
360            if i != config.field2 {
361                specs.push(OutputSpec::FileField(1, i));
362            }
363        }
364        Some(specs)
365    } else {
366        None
367    };
368
369    let format = config.output_format.as_deref().or(auto_specs.as_deref());
370
371    // Handle --header: join first lines without sort check
372    if config.header && !lines1.is_empty() && !lines2.is_empty() {
373        let fields1 = split_fields(lines1[0], config.separator);
374        let fields2 = split_fields(lines2[0], config.separator);
375        let key = fields1.get(config.field1).copied().unwrap_or(b"");
376
377        if let Some(specs) = format {
378            write_paired_format(
379                &fields1, &fields2, key, specs, empty, out_sep, delim, &mut buf,
380            );
381        } else {
382            write_paired_default(
383                &fields1,
384                &fields2,
385                key,
386                config.field1,
387                config.field2,
388                out_sep,
389                delim,
390                &mut buf,
391            );
392        }
393        i1 = 1;
394        i2 = 1;
395    } else if config.header {
396        // One or both files empty — skip header
397        if !lines1.is_empty() {
398            i1 = 1;
399        }
400        if !lines2.is_empty() {
401            i2 = 1;
402        }
403    }
404
405    while i1 < lines1.len() && i2 < lines2.len() {
406        let key1 = extract_field(lines1[i1], config.field1, config.separator);
407        let key2 = extract_field(lines2[i2], config.field2, config.separator);
408
409        // Order checks
410        if config.order_check != OrderCheck::None {
411            if !warned1 && i1 > (if config.header { 1 } else { 0 }) {
412                let prev_key = extract_field(lines1[i1 - 1], config.field1, config.separator);
413                if compare_keys(key1, prev_key, ci) == Ordering::Less {
414                    had_order_error = true;
415                    warned1 = true;
416                    eprintln!(
417                        "{}: {}:{}: is not sorted: {}",
418                        tool_name,
419                        file1_name,
420                        i1 + 1,
421                        String::from_utf8_lossy(lines1[i1])
422                    );
423                    if config.order_check == OrderCheck::Strict {
424                        out.write_all(&buf)?;
425                        return Ok(true);
426                    }
427                }
428            }
429            if !warned2 && i2 > (if config.header { 1 } else { 0 }) {
430                let prev_key = extract_field(lines2[i2 - 1], config.field2, config.separator);
431                if compare_keys(key2, prev_key, ci) == Ordering::Less {
432                    had_order_error = true;
433                    warned2 = true;
434                    eprintln!(
435                        "{}: {}:{}: is not sorted: {}",
436                        tool_name,
437                        file2_name,
438                        i2 + 1,
439                        String::from_utf8_lossy(lines2[i2])
440                    );
441                    if config.order_check == OrderCheck::Strict {
442                        out.write_all(&buf)?;
443                        return Ok(true);
444                    }
445                }
446            }
447        }
448
449        match compare_keys(key1, key2, ci) {
450            Ordering::Less => {
451                if show_unpaired1 {
452                    let fields1 = split_fields(lines1[i1], config.separator);
453                    if let Some(specs) = format {
454                        write_unpaired_format(
455                            &fields1,
456                            0,
457                            config.field1,
458                            specs,
459                            empty,
460                            out_sep,
461                            delim,
462                            &mut buf,
463                        );
464                    } else {
465                        write_unpaired_default(&fields1, config.field1, out_sep, delim, &mut buf);
466                    }
467                }
468                i1 += 1;
469            }
470            Ordering::Greater => {
471                if show_unpaired2 {
472                    let fields2 = split_fields(lines2[i2], config.separator);
473                    if let Some(specs) = format {
474                        write_unpaired_format(
475                            &fields2,
476                            1,
477                            config.field2,
478                            specs,
479                            empty,
480                            out_sep,
481                            delim,
482                            &mut buf,
483                        );
484                    } else {
485                        write_unpaired_default(&fields2, config.field2, out_sep, delim, &mut buf);
486                    }
487                }
488                i2 += 1;
489
490                // Periodic flush to limit memory usage for large inputs
491                if buf.len() >= FLUSH_THRESHOLD {
492                    out.write_all(&buf)?;
493                    buf.clear();
494                }
495            }
496            Ordering::Equal => {
497                // Find all consecutive file2 lines with the same key
498                let group_start = i2;
499                let current_key = key2;
500                i2 += 1;
501                while i2 < lines2.len() {
502                    let next_key = extract_field(lines2[i2], config.field2, config.separator);
503                    if compare_keys(next_key, current_key, ci) != Ordering::Equal {
504                        break;
505                    }
506                    i2 += 1;
507                }
508
509                // For each file1 line with the same key, cross-product with file2 group
510                loop {
511                    if print_paired {
512                        let fields1 = split_fields(lines1[i1], config.separator);
513                        let key = fields1.get(config.field1).copied().unwrap_or(b"");
514                        for j in group_start..i2 {
515                            let fields2 = split_fields(lines2[j], config.separator);
516                            if let Some(specs) = format {
517                                write_paired_format(
518                                    &fields1, &fields2, key, specs, empty, out_sep, delim, &mut buf,
519                                );
520                            } else {
521                                write_paired_default(
522                                    &fields1,
523                                    &fields2,
524                                    key,
525                                    config.field1,
526                                    config.field2,
527                                    out_sep,
528                                    delim,
529                                    &mut buf,
530                                );
531                            }
532                        }
533                    }
534                    i1 += 1;
535                    if i1 >= lines1.len() {
536                        break;
537                    }
538                    let next_key = extract_field(lines1[i1], config.field1, config.separator);
539                    let cmp = compare_keys(next_key, current_key, ci);
540                    if cmp != Ordering::Equal {
541                        // Check order: next_key should be > current_key
542                        if config.order_check != OrderCheck::None
543                            && !warned1
544                            && cmp == Ordering::Less
545                        {
546                            had_order_error = true;
547                            warned1 = true;
548                            eprintln!(
549                                "{}: {}:{}: is not sorted: {}",
550                                tool_name,
551                                file1_name,
552                                i1 + 1,
553                                String::from_utf8_lossy(lines1[i1])
554                            );
555                            if config.order_check == OrderCheck::Strict {
556                                out.write_all(&buf)?;
557                                return Ok(true);
558                            }
559                        }
560                        break;
561                    }
562                }
563            }
564        }
565    }
566
567    // Drain remaining from file 1
568    while i1 < lines1.len() {
569        // Check sort order even when draining (GNU join does this)
570        if config.order_check != OrderCheck::None
571            && !warned1
572            && i1 > (if config.header { 1 } else { 0 })
573        {
574            let key1 = extract_field(lines1[i1], config.field1, config.separator);
575            let prev_key = extract_field(lines1[i1 - 1], config.field1, config.separator);
576            if compare_keys(key1, prev_key, ci) == Ordering::Less {
577                had_order_error = true;
578                warned1 = true;
579                eprintln!(
580                    "{}: {}:{}: is not sorted: {}",
581                    tool_name,
582                    file1_name,
583                    i1 + 1,
584                    String::from_utf8_lossy(lines1[i1])
585                );
586                if config.order_check == OrderCheck::Strict {
587                    out.write_all(&buf)?;
588                    return Ok(true);
589                }
590            }
591        }
592        if show_unpaired1 {
593            let fields1 = split_fields(lines1[i1], config.separator);
594            if let Some(specs) = format {
595                write_unpaired_format(
596                    &fields1,
597                    0,
598                    config.field1,
599                    specs,
600                    empty,
601                    out_sep,
602                    delim,
603                    &mut buf,
604                );
605            } else {
606                write_unpaired_default(&fields1, config.field1, out_sep, delim, &mut buf);
607            }
608        }
609        i1 += 1;
610    }
611
612    // Drain remaining from file 2
613    while i2 < lines2.len() {
614        // Check sort order even when draining (GNU join does this)
615        if config.order_check != OrderCheck::None
616            && !warned2
617            && i2 > (if config.header { 1 } else { 0 })
618        {
619            let key2 = extract_field(lines2[i2], config.field2, config.separator);
620            let prev_key = extract_field(lines2[i2 - 1], config.field2, config.separator);
621            if compare_keys(key2, prev_key, ci) == Ordering::Less {
622                had_order_error = true;
623                warned2 = true;
624                eprintln!(
625                    "{}: {}:{}: is not sorted: {}",
626                    tool_name,
627                    file2_name,
628                    i2 + 1,
629                    String::from_utf8_lossy(lines2[i2])
630                );
631                if config.order_check == OrderCheck::Strict {
632                    out.write_all(&buf)?;
633                    return Ok(true);
634                }
635            }
636        }
637        if show_unpaired2 {
638            let fields2 = split_fields(lines2[i2], config.separator);
639            if let Some(specs) = format {
640                write_unpaired_format(
641                    &fields2,
642                    1,
643                    config.field2,
644                    specs,
645                    empty,
646                    out_sep,
647                    delim,
648                    &mut buf,
649                );
650            } else {
651                write_unpaired_default(&fields2, config.field2, out_sep, delim, &mut buf);
652            }
653        }
654        i2 += 1;
655    }
656
657    out.write_all(&buf)?;
658    Ok(had_order_error)
659}