Skip to main content

coreutils_rs/comm/
core.rs

1use std::cmp::Ordering;
2use std::io::{self, Write};
3
/// Sort-order verification policy applied while merging the two inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderCheck {
    /// Default behavior: disorder is reported at most once per input and
    /// processing continues; the run should end with exit status 1.
    Default,
    /// `--check-order`: the first disorder is an error and processing stops.
    Strict,
    /// `--nocheck-order`: the inputs' order is not verified at all.
    None,
}
14
15/// Configuration for the comm command.
16pub struct CommConfig {
17    pub suppress_col1: bool,
18    pub suppress_col2: bool,
19    pub suppress_col3: bool,
20    pub case_insensitive: bool,
21    pub order_check: OrderCheck,
22    pub output_delimiter: Option<Vec<u8>>,
23    pub total: bool,
24    pub zero_terminated: bool,
25}
26
27impl Default for CommConfig {
28    fn default() -> Self {
29        Self {
30            suppress_col1: false,
31            suppress_col2: false,
32            suppress_col3: false,
33            case_insensitive: false,
34            order_check: OrderCheck::Default,
35            output_delimiter: None,
36            total: false,
37            zero_terminated: false,
38        }
39    }
40}
41
/// Outcome of a `comm` run: per-column line counts plus sort-order status.
pub struct CommResult {
    /// Number of lines found only in the first input (column 1).
    pub count1: usize,
    /// Number of lines found only in the second input (column 2).
    pub count2: usize,
    /// Number of lines found in both inputs (column 3).
    pub count3: usize,
    /// Whether a sort-order violation was detected in either input.
    pub had_order_error: bool,
}
49
/// Order two lines byte-wise.
///
/// When `case_insensitive` is set, ASCII letters are folded to lowercase
/// before comparing; all other bytes compare by value. A line that is a
/// prefix of the other sorts first.
#[inline]
fn compare_lines(a: &[u8], b: &[u8], case_insensitive: bool) -> Ordering {
    if !case_insensitive {
        return a.cmp(b);
    }
    // Lexicographic comparison of the folded byte streams.
    a.iter()
        .map(u8::to_ascii_lowercase)
        .cmp(b.iter().map(u8::to_ascii_lowercase))
}
65
66/// Split data into lines by delimiter, using SIMD-accelerated scanning.
67/// Does NOT include a trailing empty line if data ends with the delimiter.
68fn split_lines<'a>(data: &'a [u8], delim: u8) -> Vec<&'a [u8]> {
69    if data.is_empty() {
70        return Vec::new();
71    }
72    let count = memchr::memchr_iter(delim, data).count();
73    let has_trailing = data.last() == Some(&delim);
74    let cap = if has_trailing { count } else { count + 1 };
75    let mut lines = Vec::with_capacity(cap);
76    let mut start = 0;
77    for pos in memchr::memchr_iter(delim, data) {
78        lines.push(&data[start..pos]);
79        start = pos + 1;
80    }
81    if start < data.len() {
82        lines.push(&data[start..]);
83    }
84    lines
85}
86
87/// Run the comm merge algorithm on two sorted inputs.
88pub fn comm(
89    data1: &[u8],
90    data2: &[u8],
91    config: &CommConfig,
92    tool_name: &str,
93    out: &mut impl Write,
94) -> io::Result<CommResult> {
95    let delim = if config.zero_terminated { b'\0' } else { b'\n' };
96    let sep = config.output_delimiter.as_deref().unwrap_or(b"\t");
97
98    // Build column prefixes. Each shown column before the current one
99    // contributes one copy of the separator.
100    // Column 1: always empty prefix.
101    let prefix2: Vec<u8> = if !config.suppress_col1 {
102        sep.to_vec()
103    } else {
104        Vec::new()
105    };
106    let mut prefix3: Vec<u8> = Vec::new();
107    if !config.suppress_col1 {
108        prefix3.extend_from_slice(sep);
109    }
110    if !config.suppress_col2 {
111        prefix3.extend_from_slice(sep);
112    }
113
114    let lines1 = split_lines(data1, delim);
115    let lines2 = split_lines(data2, delim);
116
117    let mut i1 = 0usize;
118    let mut i2 = 0usize;
119    let mut count1 = 0usize;
120    let mut count2 = 0usize;
121    let mut count3 = 0usize;
122    let mut had_order_error = false;
123    let mut warned1 = false;
124    let mut warned2 = false;
125    let ci = config.case_insensitive;
126
127    let mut buf = Vec::with_capacity((data1.len() + data2.len()).min(4 * 1024 * 1024));
128    let flush_threshold = 4 * 1024 * 1024; // Flush output buffer at 4MB to limit memory
129
130    // Macro to check sort order of a file and handle warnings/errors.
131    macro_rules! check_order {
132        ($warned:ident, $lines:ident, $idx:ident, $file_num:expr) => {
133            if config.order_check != OrderCheck::None
134                && !$warned
135                && $idx > 0
136                && compare_lines($lines[$idx], $lines[$idx - 1], ci) == Ordering::Less
137            {
138                had_order_error = true;
139                $warned = true;
140                eprintln!("{}: file {} is not in sorted order", tool_name, $file_num);
141                if config.order_check == OrderCheck::Strict {
142                    out.write_all(&buf)?;
143                    return Ok(CommResult {
144                        count1,
145                        count2,
146                        count3,
147                        had_order_error,
148                    });
149                }
150            }
151        };
152    }
153
154    while i1 < lines1.len() && i2 < lines2.len() {
155        match compare_lines(lines1[i1], lines2[i2], ci) {
156            Ordering::Less => {
157                // File1 line is unique — check file1 sort order before consuming
158                check_order!(warned1, lines1, i1, 1);
159                if !config.suppress_col1 {
160                    buf.extend_from_slice(lines1[i1]);
161                    buf.push(delim);
162                }
163                count1 += 1;
164                i1 += 1;
165            }
166            Ordering::Greater => {
167                // File2 line is unique — check file2 sort order before consuming
168                check_order!(warned2, lines2, i2, 2);
169                if !config.suppress_col2 {
170                    buf.extend_from_slice(&prefix2);
171                    buf.extend_from_slice(lines2[i2]);
172                    buf.push(delim);
173                }
174                count2 += 1;
175                i2 += 1;
176            }
177            Ordering::Equal => {
178                // Lines match — no sort check needed (GNU comm behavior)
179                if !config.suppress_col3 {
180                    buf.extend_from_slice(&prefix3);
181                    buf.extend_from_slice(lines1[i1]);
182                    buf.push(delim);
183                }
184                count3 += 1;
185                i1 += 1;
186                i2 += 1;
187            }
188        }
189
190        // Periodic flush to limit memory usage for large files
191        if buf.len() >= flush_threshold {
192            out.write_all(&buf)?;
193            buf.clear();
194        }
195    }
196
197    // Drain remaining from file 1
198    while i1 < lines1.len() {
199        if config.order_check != OrderCheck::None
200            && !warned1
201            && i1 > 0
202            && compare_lines(lines1[i1], lines1[i1 - 1], ci) == Ordering::Less
203        {
204            had_order_error = true;
205            warned1 = true;
206            eprintln!("{}: file 1 is not in sorted order", tool_name);
207            if config.order_check == OrderCheck::Strict {
208                out.write_all(&buf)?;
209                return Ok(CommResult {
210                    count1,
211                    count2,
212                    count3,
213                    had_order_error,
214                });
215            }
216        }
217        if !config.suppress_col1 {
218            buf.extend_from_slice(lines1[i1]);
219            buf.push(delim);
220        }
221        count1 += 1;
222        i1 += 1;
223    }
224
225    // Drain remaining from file 2
226    while i2 < lines2.len() {
227        if config.order_check != OrderCheck::None
228            && !warned2
229            && i2 > 0
230            && compare_lines(lines2[i2], lines2[i2 - 1], ci) == Ordering::Less
231        {
232            had_order_error = true;
233            warned2 = true;
234            eprintln!("{}: file 2 is not in sorted order", tool_name);
235            if config.order_check == OrderCheck::Strict {
236                out.write_all(&buf)?;
237                return Ok(CommResult {
238                    count1,
239                    count2,
240                    count3,
241                    had_order_error,
242                });
243            }
244        }
245        if !config.suppress_col2 {
246            buf.extend_from_slice(&prefix2);
247            buf.extend_from_slice(lines2[i2]);
248            buf.push(delim);
249        }
250        count2 += 1;
251        i2 += 1;
252    }
253
254    // Total summary line — use itoa for fast integer formatting
255    if config.total {
256        let mut itoa_buf = itoa::Buffer::new();
257        buf.extend_from_slice(itoa_buf.format(count1).as_bytes());
258        buf.extend_from_slice(sep);
259        buf.extend_from_slice(itoa_buf.format(count2).as_bytes());
260        buf.extend_from_slice(sep);
261        buf.extend_from_slice(itoa_buf.format(count3).as_bytes());
262        buf.extend_from_slice(sep);
263        buf.extend_from_slice(b"total");
264        buf.push(delim);
265    }
266
267    // In Default mode, print a final summary message (matches GNU comm behavior)
268    if had_order_error && config.order_check == OrderCheck::Default {
269        eprintln!("{}: input is not in sorted order", tool_name);
270    }
271
272    out.write_all(&buf)?;
273    Ok(CommResult {
274        count1,
275        count2,
276        count3,
277        had_order_error,
278    })
279}