Skip to main content

coreutils_rs/comm/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// Policy for verifying that each input file is sorted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OrderCheck {
    /// No flag given: verify order, warn once per file, keep going, and
    /// finish with exit status 1.
    Default,
    /// `--check-order`: verify order and stop immediately with an error on
    /// the first violation.
    Strict,
    /// `--nocheck-order`: perform no order verification at all.
    None,
}
14
15/// Configuration for the comm command.
16pub struct CommConfig {
17    pub suppress_col1: bool,
18    pub suppress_col2: bool,
19    pub suppress_col3: bool,
20    pub case_insensitive: bool,
21    pub order_check: OrderCheck,
22    pub output_delimiter: Option<Vec<u8>>,
23    pub total: bool,
24    pub zero_terminated: bool,
25}
26
27impl Default for CommConfig {
28    fn default() -> Self {
29        Self {
30            suppress_col1: false,
31            suppress_col2: false,
32            suppress_col3: false,
33            case_insensitive: false,
34            order_check: OrderCheck::Default,
35            output_delimiter: None,
36            total: false,
37            zero_terminated: false,
38        }
39    }
40}
41
/// Result of the comm operation.
///
/// Plain counters plus a flag, so the type is `Copy` and can be compared and
/// printed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CommResult {
    /// Number of lines found only in the first input.
    pub count1: usize,
    /// Number of lines found only in the second input.
    pub count2: usize,
    /// Number of lines found in both inputs.
    pub count3: usize,
    /// Whether a sort-order violation was detected in either input.
    pub had_order_error: bool,
}
49
/// Order two lines byte-wise.
///
/// With `case_insensitive` set, every byte is folded to ASCII lower case
/// before comparison; a line that is a strict prefix of the other sorts first.
/// Without it, this is the plain lexicographic slice ordering.
#[inline(always)]
fn compare_lines(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    // `Iterator::cmp` is lexicographic and treats the shorter sequence as
    // smaller when one is a prefix of the other.
    let folded_a = a.iter().map(u8::to_ascii_lowercase);
    let folded_b = b.iter().map(u8::to_ascii_lowercase);
    folded_a.cmp(folded_b)
}
65
/// Append `prefix`, then `line`, then the single byte `delim` to `buf` by
/// writing straight into its spare capacity.
///
/// # Safety
///
/// `buf` must already have at least `prefix.len() + line.len() + 1` bytes of
/// spare capacity (`capacity() - len()`); this function never grows the vector.
#[inline(always)]
unsafe fn write_line(buf: &mut Vec<u8>, prefix: &[u8], line: &[u8], delim: u8) {
    let old_len = buf.len();
    let prefix_len = prefix.len();
    let line_len = line.len();
    // SAFETY: the caller guarantees the spare capacity covers every byte
    // written below. The sources are borrowed slices and cannot overlap the
    // exclusively borrowed `buf`. Zero-length copies are fine because slice
    // pointers are always non-null and aligned.
    unsafe {
        let mut cursor = buf.as_mut_ptr().add(old_len);
        std::ptr::copy_nonoverlapping(prefix.as_ptr(), cursor, prefix_len);
        cursor = cursor.add(prefix_len);
        std::ptr::copy_nonoverlapping(line.as_ptr(), cursor, line_len);
        cursor = cursor.add(line_len);
        cursor.write(delim);
        // Every byte in old_len..new_len has just been initialised.
        buf.set_len(old_len + prefix_len + line_len + 1);
    }
}
84
/// Make sure `buf` can take `needed` more bytes without reallocating.
///
/// When the spare capacity is already sufficient nothing happens; otherwise
/// room for `needed` plus an extra 64 KiB is reserved so that the following
/// lines do not trigger a reallocation each.
#[inline(always)]
fn ensure_capacity(buf: &mut Vec<u8>, needed: usize) {
    let spare = buf.capacity() - buf.len();
    if spare >= needed {
        return;
    }
    buf.reserve(needed + 64 * 1024);
}
93
94/// Fast path for identical inputs: all lines go to column 3.
95/// Avoids the merge loop entirely — single memchr scan with direct output.
96fn comm_identical(
97    data: &[u8],
98    config: &CommConfig,
99    delim: u8,
100    sep: &[u8],
101    out: &mut impl Write,
102) -> io::Result<CommResult> {
103    let show3 = !config.suppress_col3;
104
105    // Count lines for the result
106    let stripped = if !data.is_empty() && data.last() == Some(&delim) {
107        &data[..data.len() - 1]
108    } else {
109        data
110    };
111    let line_count = if stripped.is_empty() {
112        0
113    } else {
114        memchr::memchr_iter(delim, stripped).count() + 1
115    };
116
117    if show3 {
118        // Build column 3 prefix
119        let mut prefix = Vec::new();
120        if !config.suppress_col1 {
121            prefix.extend_from_slice(sep);
122        }
123        if !config.suppress_col2 {
124            prefix.extend_from_slice(sep);
125        }
126
127        // Stream output in 256KB chunks
128        let mut buf: Vec<u8> = Vec::with_capacity(256 * 1024);
129        let mut pos = 0;
130        for nl_pos in memchr::memchr_iter(delim, stripped) {
131            let line = &stripped[pos..nl_pos];
132            let needed = prefix.len() + line.len() + 1;
133            if buf.len() + needed > 192 * 1024 {
134                out.write_all(&buf)?;
135                buf.clear();
136            }
137            if buf.capacity() - buf.len() < needed {
138                buf.reserve(needed + 64 * 1024);
139            }
140            unsafe {
141                write_line(&mut buf, &prefix, line, delim);
142            }
143            pos = nl_pos + 1;
144        }
145        // Handle last line without trailing delimiter
146        if pos < stripped.len() {
147            let line = &stripped[pos..];
148            let needed = prefix.len() + line.len() + 1;
149            if buf.capacity() - buf.len() < needed {
150                buf.reserve(needed + 1024);
151            }
152            unsafe {
153                write_line(&mut buf, &prefix, line, delim);
154            }
155        }
156        if !buf.is_empty() {
157            out.write_all(&buf)?;
158        }
159    }
160
161    Ok(CommResult {
162        count1: 0,
163        count2: 0,
164        count3: line_count,
165        had_order_error: false,
166    })
167}
168
169/// Run the comm merge algorithm on two sorted inputs.
170pub fn comm(
171    data1: &[u8],
172    data2: &[u8],
173    config: &CommConfig,
174    tool_name: &str,
175    out: &mut impl Write,
176) -> io::Result<CommResult> {
177    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
178    let sep = config.output_delimiter.as_deref().unwrap_or(b"\t");
179
180    // Fast path: identical inputs → all lines are common (column 3).
181    // Avoids per-line comparison entirely. Uses single memchr scan.
182    // Only safe when order checking is disabled: for unsorted-but-identical files,
183    // GNU comm still reports a sort-order violation, so we must fall through to
184    // the merge loop which detects out-of-order adjacent lines.
185    if data1 == data2
186        && !config.case_insensitive
187        && !config.total
188        && config.order_check == OrderCheck::None
189    {
190        return comm_identical(data1, config, delim, sep, out);
191    }
192
193    // Build column prefixes.
194    let prefix1: &[u8] = &[];
195    let prefix2_owned: Vec<u8> = if !config.suppress_col1 {
196        sep.to_vec()
197    } else {
198        Vec::new()
199    };
200    let mut prefix3_owned: Vec<u8> = Vec::new();
201    if !config.suppress_col1 {
202        prefix3_owned.extend_from_slice(sep);
203    }
204    if !config.suppress_col2 {
205        prefix3_owned.extend_from_slice(sep);
206    }
207
208    let show1 = !config.suppress_col1;
209    let show2 = !config.suppress_col2;
210    let show3 = !config.suppress_col3;
211    let ci = config.case_insensitive;
212    let check_order = config.order_check != OrderCheck::None;
213    let strict = config.order_check == OrderCheck::Strict;
214
215    // Use a 256KB output buffer to minimize page faults on first fill.
216    // 256KB = 64 pages — faulted once, then stays warm in L2/TLB for reuse.
217    // Flushed ~40x for 10MB output vs 2x for 4MB, but each flush is fast
218    // (~2µs) and we save ~1000 page faults * ~4µs = ~4ms.
219    let buf_cap = 256 * 1024;
220    let mut buf: Vec<u8> = Vec::with_capacity(buf_cap);
221    let flush_threshold = 192 * 1024;
222
223    let mut count1 = 0usize;
224    let mut count2 = 0usize;
225    let mut count3 = 0usize;
226    let mut had_order_error = false;
227    let mut warned1 = false;
228    let mut warned2 = false;
229
230    // Strip trailing delimiter to avoid empty final line
231    let len1 = if !data1.is_empty() && data1.last() == Some(&delim) {
232        data1.len() - 1
233    } else {
234        data1.len()
235    };
236    let len2 = if !data2.is_empty() && data2.last() == Some(&delim) {
237        data2.len() - 1
238    } else {
239        data2.len()
240    };
241
242    // Use memchr_iter for amortized SIMD line scanning instead of per-line memchr.
243    // Each iterator maintains internal SIMD state across the entire file, eliminating
244    // per-line setup overhead (~150K saved function calls for 10MB files).
245    let mut iter1 = memchr::memchr_iter(delim, &data1[..len1]);
246    let mut iter2 = memchr::memchr_iter(delim, &data2[..len2]);
247    let mut pos1 = 0usize;
248    let mut pos2 = 0usize;
249    let mut end1 = iter1.next().unwrap_or(len1);
250    let mut end2 = iter2.next().unwrap_or(len2);
251
252    // Previous line tracking for order checking
253    let mut prev1: &[u8] = &[];
254    let mut has_prev1 = false;
255    let mut prev2: &[u8] = &[];
256    let mut has_prev2 = false;
257
258    // Main merge loop: both files have remaining lines
259    while pos1 < len1 && pos2 < len2 {
260        let line1 = &data1[pos1..end1];
261        let line2 = &data2[pos2..end2];
262
263        match compare_lines(line1, line2, ci) {
264            Ordering::Less => {
265                if check_order
266                    && !warned1
267                    && has_prev1
268                    && compare_lines(line1, prev1, ci) == Ordering::Less
269                {
270                    had_order_error = true;
271                    warned1 = true;
272                    eprintln!("{}: file {} is not in sorted order", tool_name, 1);
273                    if strict {
274                        out.write_all(&buf)?;
275                        return Ok(CommResult {
276                            count1,
277                            count2,
278                            count3,
279                            had_order_error,
280                        });
281                    }
282                }
283                if show1 {
284                    ensure_capacity(&mut buf, prefix1.len() + line1.len() + 1);
285                    unsafe {
286                        write_line(&mut buf, prefix1, line1, delim);
287                    }
288                }
289                count1 += 1;
290                prev1 = line1;
291                has_prev1 = true;
292                pos1 = end1 + 1;
293                end1 = iter1.next().unwrap_or(len1);
294            }
295            Ordering::Greater => {
296                if check_order
297                    && !warned2
298                    && has_prev2
299                    && compare_lines(line2, prev2, ci) == Ordering::Less
300                {
301                    had_order_error = true;
302                    warned2 = true;
303                    eprintln!("{}: file {} is not in sorted order", tool_name, 2);
304                    if strict {
305                        out.write_all(&buf)?;
306                        return Ok(CommResult {
307                            count1,
308                            count2,
309                            count3,
310                            had_order_error,
311                        });
312                    }
313                }
314                if show2 {
315                    ensure_capacity(&mut buf, prefix2_owned.len() + line2.len() + 1);
316                    unsafe {
317                        write_line(&mut buf, &prefix2_owned, line2, delim);
318                    }
319                }
320                count2 += 1;
321                prev2 = line2;
322                has_prev2 = true;
323                pos2 = end2 + 1;
324                end2 = iter2.next().unwrap_or(len2);
325            }
326            Ordering::Equal => {
327                if show3 {
328                    ensure_capacity(&mut buf, prefix3_owned.len() + line1.len() + 1);
329                    unsafe {
330                        write_line(&mut buf, &prefix3_owned, line1, delim);
331                    }
332                }
333                count3 += 1;
334                prev1 = line1;
335                has_prev1 = true;
336                prev2 = line2;
337                has_prev2 = true;
338                pos1 = end1 + 1;
339                end1 = iter1.next().unwrap_or(len1);
340                pos2 = end2 + 1;
341                end2 = iter2.next().unwrap_or(len2);
342            }
343        }
344
345        if buf.len() >= flush_threshold {
346            out.write_all(&buf)?;
347            buf.clear();
348        }
349    }
350
351    // Drain remaining from file 1
352    // Fast path: if showing col1 and order check is done (or disabled), bulk copy
353    if pos1 < len1 && show1 && (!check_order || warned1) && prefix1.is_empty() {
354        let remaining = &data1[pos1..len1];
355        let line_count = memchr::memchr_iter(delim, remaining).count();
356        let has_trailing = !remaining.is_empty() && remaining.last() != Some(&delim);
357        count1 += line_count + if has_trailing { 1 } else { 0 };
358
359        if !buf.is_empty() {
360            out.write_all(&buf)?;
361            buf.clear();
362        }
363        out.write_all(remaining)?;
364        if has_trailing {
365            out.write_all(&[delim])?;
366        }
367        pos1 = len1;
368    }
369    while pos1 < len1 {
370        let line1 = &data1[pos1..end1];
371        if check_order && !warned1 && has_prev1 && compare_lines(line1, prev1, ci) == Ordering::Less
372        {
373            had_order_error = true;
374            warned1 = true;
375            eprintln!("{}: file 1 is not in sorted order", tool_name);
376            if strict {
377                out.write_all(&buf)?;
378                return Ok(CommResult {
379                    count1,
380                    count2,
381                    count3,
382                    had_order_error,
383                });
384            }
385        }
386        if show1 {
387            ensure_capacity(&mut buf, line1.len() + 1);
388            unsafe {
389                write_line(&mut buf, prefix1, line1, delim);
390            }
391        }
392        count1 += 1;
393        prev1 = line1;
394        has_prev1 = true;
395        pos1 = end1 + 1;
396        end1 = iter1.next().unwrap_or(len1);
397        if buf.len() >= flush_threshold {
398            out.write_all(&buf)?;
399            buf.clear();
400        }
401    }
402
403    // Drain remaining from file 2
404    // Fast path: bulk copy when order check is done and we have lines to drain
405    if pos2 < len2
406        && show2
407        && (!check_order || warned2)
408        && (config.suppress_col1 || prefix2_owned.is_empty())
409    {
410        let remaining = &data2[pos2..len2];
411        if prefix2_owned.is_empty() {
412            let line_count = memchr::memchr_iter(delim, remaining).count();
413            let has_trailing = !remaining.is_empty() && remaining.last() != Some(&delim);
414            count2 += line_count + if has_trailing { 1 } else { 0 };
415            if !buf.is_empty() {
416                out.write_all(&buf)?;
417                buf.clear();
418            }
419            out.write_all(remaining)?;
420            if has_trailing {
421                out.write_all(&[delim])?;
422            }
423            pos2 = len2;
424        }
425    }
426    while pos2 < len2 {
427        let line2 = &data2[pos2..end2];
428        if check_order && !warned2 && has_prev2 && compare_lines(line2, prev2, ci) == Ordering::Less
429        {
430            had_order_error = true;
431            warned2 = true;
432            eprintln!("{}: file 2 is not in sorted order", tool_name);
433            if strict {
434                out.write_all(&buf)?;
435                return Ok(CommResult {
436                    count1,
437                    count2,
438                    count3,
439                    had_order_error,
440                });
441            }
442        }
443        if show2 {
444            ensure_capacity(&mut buf, prefix2_owned.len() + line2.len() + 1);
445            unsafe {
446                write_line(&mut buf, &prefix2_owned, line2, delim);
447            }
448        }
449        count2 += 1;
450        prev2 = line2;
451        has_prev2 = true;
452        pos2 = end2 + 1;
453        end2 = iter2.next().unwrap_or(len2);
454        if buf.len() >= flush_threshold {
455            out.write_all(&buf)?;
456            buf.clear();
457        }
458    }
459
460    // Total summary line
461    if config.total {
462        let mut itoa_buf = itoa::Buffer::new();
463        buf.extend_from_slice(itoa_buf.format(count1).as_bytes());
464        buf.extend_from_slice(sep);
465        buf.extend_from_slice(itoa_buf.format(count2).as_bytes());
466        buf.extend_from_slice(sep);
467        buf.extend_from_slice(itoa_buf.format(count3).as_bytes());
468        buf.extend_from_slice(sep);
469        buf.extend_from_slice(b"total");
470        buf.push(delim);
471    }
472
473    if had_order_error && config.order_check == OrderCheck::Default {
474        eprintln!("{}: input is not in sorted order", tool_name);
475    }
476
477    out.write_all(&buf)?;
478    Ok(CommResult {
479        count1,
480        count2,
481        count3,
482        had_order_error,
483    })
484}