Skip to main content

coreutils_rs/comm/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// Policy applied when an input turns out not to be sorted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OrderCheck {
    /// No flag given: verify ordering, warn once per file, keep going,
    /// and let the caller exit with status 1.
    Default,
    /// `--check-order`: verify ordering and stop at the first violation.
    Strict,
    /// `--nocheck-order`: skip ordering verification entirely.
    None,
}
14
15/// Configuration for the comm command.
16pub struct CommConfig {
17    pub suppress_col1: bool,
18    pub suppress_col2: bool,
19    pub suppress_col3: bool,
20    pub case_insensitive: bool,
21    pub order_check: OrderCheck,
22    pub output_delimiter: Option<Vec<u8>>,
23    pub total: bool,
24    pub zero_terminated: bool,
25}
26
27impl Default for CommConfig {
28    fn default() -> Self {
29        Self {
30            suppress_col1: false,
31            suppress_col2: false,
32            suppress_col3: false,
33            case_insensitive: false,
34            order_check: OrderCheck::Default,
35            output_delimiter: None,
36            total: false,
37            zero_terminated: false,
38        }
39    }
40}
41
/// Result of the comm operation.
///
/// Derives the common traits so results can be printed with `{:?}`,
/// copied freely and compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CommResult {
    /// Number of records found only in the first input.
    pub count1: usize,
    /// Number of records found only in the second input.
    pub count2: usize,
    /// Number of records found in both inputs.
    pub count3: usize,
    /// Whether an out-of-order record was detected in either input.
    pub had_order_error: bool,
}
49
/// Order two records bytewise; with `case_insensitive` set, ASCII letters
/// are folded to lowercase before comparison (non-ASCII bytes are compared
/// unchanged). A record that is a strict prefix of the other sorts first.
#[inline(always)]
fn compare_lines(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    // Lexicographic comparison over the folded bytes; `Iterator::cmp`
    // treats the shorter sequence as smaller when one is a prefix.
    let folded_a = a.iter().map(u8::to_ascii_lowercase);
    let folded_b = b.iter().map(u8::to_ascii_lowercase);
    folded_a.cmp(folded_b)
}
65
66/// Find the next line from data starting at `pos`, delimited by `delim`.
67/// Returns (line_slice, next_pos). If no delimiter found, returns remaining data.
68#[inline(always)]
69fn next_line(data: &[u8], pos: usize, delim: u8) -> (&[u8], usize) {
70    let remaining = &data[pos..];
71    match memchr::memchr(delim, remaining) {
72        Some(offset) => (&data[pos..pos + offset], pos + offset + 1),
73        std::option::Option::None => (remaining, data.len()),
74    }
75}
76
/// Append `prefix`, then `line`, then one `delim` byte to `buf` using raw
/// pointer writes. The vector is never grown by this function.
///
/// # Safety
///
/// `buf` must already have at least `prefix.len() + line.len() + 1` bytes
/// of spare capacity (`buf.capacity() - buf.len()`). Calling this without
/// that headroom writes past the allocation, which is undefined behavior.
/// Debug builds check the requirement with a `debug_assert!`.
#[inline(always)]
unsafe fn write_line(buf: &mut Vec<u8>, prefix: &[u8], line: &[u8], delim: u8) {
    let start = buf.len();
    let total = prefix.len() + line.len() + 1;
    debug_assert!(
        buf.capacity() - start >= total,
        "write_line: insufficient spare capacity"
    );
    // SAFETY: the caller guarantees `total` bytes of spare capacity, so every
    // write below lands inside `buf`'s allocation. `prefix` and `line` are
    // shared borrows and therefore cannot overlap the exclusively borrowed
    // `buf`. All `total` bytes are initialized before `set_len` exposes them.
    unsafe {
        let dst = buf.as_mut_ptr().add(start);
        if !prefix.is_empty() {
            std::ptr::copy_nonoverlapping(prefix.as_ptr(), dst, prefix.len());
        }
        if !line.is_empty() {
            std::ptr::copy_nonoverlapping(line.as_ptr(), dst.add(prefix.len()), line.len());
        }
        *dst.add(prefix.len() + line.len()) = delim;
        buf.set_len(start + total);
    }
}
95
96/// Ensure buf has at least `needed` bytes of spare capacity.
97#[inline(always)]
98fn ensure_capacity(buf: &mut Vec<u8>, needed: usize) {
99    let avail = buf.capacity() - buf.len();
100    if avail < needed {
101        buf.reserve(needed + 4 * 1024 * 1024);
102    }
103}
104
105/// Run the comm merge algorithm on two sorted inputs.
106pub fn comm(
107    data1: &[u8],
108    data2: &[u8],
109    config: &CommConfig,
110    tool_name: &str,
111    out: &mut impl Write,
112) -> io::Result<CommResult> {
113    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
114    let sep = config.output_delimiter.as_deref().unwrap_or(b"\t");
115
116    // Build column prefixes.
117    let prefix1: &[u8] = &[];
118    let prefix2_owned: Vec<u8> = if !config.suppress_col1 {
119        sep.to_vec()
120    } else {
121        Vec::new()
122    };
123    let mut prefix3_owned: Vec<u8> = Vec::new();
124    if !config.suppress_col1 {
125        prefix3_owned.extend_from_slice(sep);
126    }
127    if !config.suppress_col2 {
128        prefix3_owned.extend_from_slice(sep);
129    }
130
131    let show1 = !config.suppress_col1;
132    let show2 = !config.suppress_col2;
133    let show3 = !config.suppress_col3;
134    let ci = config.case_insensitive;
135    let check_order = config.order_check != OrderCheck::None;
136    let strict = config.order_check == OrderCheck::Strict;
137
138    // Pre-allocate output buffer generously
139    let total_input = data1.len() + data2.len();
140    let buf_cap = total_input.min(8 * 1024 * 1024);
141    let mut buf: Vec<u8> = Vec::with_capacity(buf_cap);
142    let flush_threshold = 4 * 1024 * 1024;
143
144    let mut count1 = 0usize;
145    let mut count2 = 0usize;
146    let mut count3 = 0usize;
147    let mut had_order_error = false;
148    let mut warned1 = false;
149    let mut warned2 = false;
150
151    // Streaming merge: track position and previous line for each file
152    let mut pos1 = 0usize;
153    let mut pos2 = 0usize;
154
155    // Strip trailing delimiter to avoid empty final line
156    let len1 = if !data1.is_empty() && data1.last() == Some(&delim) {
157        data1.len() - 1
158    } else {
159        data1.len()
160    };
161    let len2 = if !data2.is_empty() && data2.last() == Some(&delim) {
162        data2.len() - 1
163    } else {
164        data2.len()
165    };
166
167    // Previous line tracking for order checking
168    let mut prev1: &[u8] = &[];
169    let mut has_prev1 = false;
170    let mut prev2: &[u8] = &[];
171    let mut has_prev2 = false;
172
173    // Main merge loop: both files have remaining lines
174    while pos1 < len1 && pos2 < len2 {
175        let (line1, next1) = next_line(&data1[..len1], pos1, delim);
176        let (line2, next2) = next_line(&data2[..len2], pos2, delim);
177
178        match compare_lines(line1, line2, ci) {
179            Ordering::Less => {
180                // Check file1 order
181                if check_order
182                    && !warned1
183                    && has_prev1
184                    && compare_lines(line1, prev1, ci) == Ordering::Less
185                {
186                    had_order_error = true;
187                    warned1 = true;
188                    eprintln!("{}: file {} is not in sorted order", tool_name, 1);
189                    if strict {
190                        out.write_all(&buf)?;
191                        return Ok(CommResult {
192                            count1,
193                            count2,
194                            count3,
195                            had_order_error,
196                        });
197                    }
198                }
199                if show1 {
200                    ensure_capacity(&mut buf, prefix1.len() + line1.len() + 1);
201                    unsafe {
202                        write_line(&mut buf, prefix1, line1, delim);
203                    }
204                }
205                count1 += 1;
206                prev1 = line1;
207                has_prev1 = true;
208                pos1 = next1;
209            }
210            Ordering::Greater => {
211                // Check file2 order
212                if check_order
213                    && !warned2
214                    && has_prev2
215                    && compare_lines(line2, prev2, ci) == Ordering::Less
216                {
217                    had_order_error = true;
218                    warned2 = true;
219                    eprintln!("{}: file {} is not in sorted order", tool_name, 2);
220                    if strict {
221                        out.write_all(&buf)?;
222                        return Ok(CommResult {
223                            count1,
224                            count2,
225                            count3,
226                            had_order_error,
227                        });
228                    }
229                }
230                if show2 {
231                    ensure_capacity(&mut buf, prefix2_owned.len() + line2.len() + 1);
232                    unsafe {
233                        write_line(&mut buf, &prefix2_owned, line2, delim);
234                    }
235                }
236                count2 += 1;
237                prev2 = line2;
238                has_prev2 = true;
239                pos2 = next2;
240            }
241            Ordering::Equal => {
242                if show3 {
243                    ensure_capacity(&mut buf, prefix3_owned.len() + line1.len() + 1);
244                    unsafe {
245                        write_line(&mut buf, &prefix3_owned, line1, delim);
246                    }
247                }
248                count3 += 1;
249                prev1 = line1;
250                has_prev1 = true;
251                prev2 = line2;
252                has_prev2 = true;
253                pos1 = next1;
254                pos2 = next2;
255            }
256        }
257
258        if buf.len() >= flush_threshold {
259            out.write_all(&buf)?;
260            buf.clear();
261        }
262    }
263
264    // Drain remaining from file 1
265    // Fast path: if showing col1 and order check is done (or disabled), bulk copy
266    if pos1 < len1 && show1 && (!check_order || warned1) && prefix1.is_empty() {
267        // Bulk copy remainder — no per-line processing needed
268        let remaining = &data1[pos1..len1];
269        let line_count = memchr::memchr_iter(delim, remaining).count();
270        let has_trailing = !remaining.is_empty() && remaining.last() != Some(&delim);
271        count1 += line_count + if has_trailing { 1 } else { 0 };
272
273        // Flush current buffer, then write remainder directly
274        if !buf.is_empty() {
275            out.write_all(&buf)?;
276            buf.clear();
277        }
278        out.write_all(remaining)?;
279        if has_trailing {
280            out.write_all(&[delim])?;
281        }
282        pos1 = len1;
283    }
284    while pos1 < len1 {
285        let (line1, next1) = next_line(&data1[..len1], pos1, delim);
286        if check_order && !warned1 && has_prev1 && compare_lines(line1, prev1, ci) == Ordering::Less
287        {
288            had_order_error = true;
289            warned1 = true;
290            eprintln!("{}: file 1 is not in sorted order", tool_name);
291            if strict {
292                out.write_all(&buf)?;
293                return Ok(CommResult {
294                    count1,
295                    count2,
296                    count3,
297                    had_order_error,
298                });
299            }
300        }
301        if show1 {
302            ensure_capacity(&mut buf, line1.len() + 1);
303            unsafe {
304                write_line(&mut buf, prefix1, line1, delim);
305            }
306        }
307        count1 += 1;
308        prev1 = line1;
309        has_prev1 = true;
310        pos1 = next1;
311        if buf.len() >= flush_threshold {
312            out.write_all(&buf)?;
313            buf.clear();
314        }
315    }
316
317    // Drain remaining from file 2
318    // Fast path: bulk copy when order check is done and we have lines to drain
319    if pos2 < len2
320        && show2
321        && (!check_order || warned2)
322        && (config.suppress_col1 || prefix2_owned.is_empty())
323    {
324        let remaining = &data2[pos2..len2];
325        // Only bulk if no prefix needed (single column output) — otherwise per-line
326        if prefix2_owned.is_empty() {
327            let line_count = memchr::memchr_iter(delim, remaining).count();
328            let has_trailing = !remaining.is_empty() && remaining.last() != Some(&delim);
329            count2 += line_count + if has_trailing { 1 } else { 0 };
330            if !buf.is_empty() {
331                out.write_all(&buf)?;
332                buf.clear();
333            }
334            out.write_all(remaining)?;
335            if has_trailing {
336                out.write_all(&[delim])?;
337            }
338            pos2 = len2;
339        }
340    }
341    while pos2 < len2 {
342        let (line2, next2) = next_line(&data2[..len2], pos2, delim);
343        if check_order && !warned2 && has_prev2 && compare_lines(line2, prev2, ci) == Ordering::Less
344        {
345            had_order_error = true;
346            warned2 = true;
347            eprintln!("{}: file 2 is not in sorted order", tool_name);
348            if strict {
349                out.write_all(&buf)?;
350                return Ok(CommResult {
351                    count1,
352                    count2,
353                    count3,
354                    had_order_error,
355                });
356            }
357        }
358        if show2 {
359            ensure_capacity(&mut buf, prefix2_owned.len() + line2.len() + 1);
360            unsafe {
361                write_line(&mut buf, &prefix2_owned, line2, delim);
362            }
363        }
364        count2 += 1;
365        prev2 = line2;
366        has_prev2 = true;
367        pos2 = next2;
368        if buf.len() >= flush_threshold {
369            out.write_all(&buf)?;
370            buf.clear();
371        }
372    }
373
374    // Total summary line
375    if config.total {
376        let mut itoa_buf = itoa::Buffer::new();
377        buf.extend_from_slice(itoa_buf.format(count1).as_bytes());
378        buf.extend_from_slice(sep);
379        buf.extend_from_slice(itoa_buf.format(count2).as_bytes());
380        buf.extend_from_slice(sep);
381        buf.extend_from_slice(itoa_buf.format(count3).as_bytes());
382        buf.extend_from_slice(sep);
383        buf.extend_from_slice(b"total");
384        buf.push(delim);
385    }
386
387    if had_order_error && config.order_check == OrderCheck::Default {
388        eprintln!("{}: input is not in sorted order", tool_name);
389    }
390
391    out.write_all(&buf)?;
392    Ok(CommResult {
393        count1,
394        count2,
395        count3,
396        had_order_error,
397    })
398}