// coreutils_rs/tac/core.rs

use rayon::prelude::*;
use std::io::{self, IoSlice, Write};

/// Maximum number of iovecs per writev() call (Linux IOV_MAX is 1024).
const IOV_BATCH: usize = 1024;

/// Minimum data size for parallel processing.
const PAR_THRESHOLD: usize = 2 * 1024 * 1024;

/// Write all IoSlices to the writer, handling partial writes.
/// For large numbers of slices, batches into IOV_BATCH-sized groups.
fn write_all_slices(out: &mut impl Write, slices: &[IoSlice<'_>]) -> io::Result<()> {
    // Small number of slices: use simple write_all for each
    if slices.len() <= 4 {
        for s in slices {
            out.write_all(s)?;
        }
        return Ok(());
    }

    let mut offset = 0;
    while offset < slices.len() {
        let end = (offset + IOV_BATCH).min(slices.len());
        let n = out.write_vectored(&slices[offset..end])?;
        if n == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "failed to write any data",
            ));
        }
        // Advance past the slices that write_vectored consumed in full.
        let mut remaining = n;
        while offset < end && remaining >= slices[offset].len() {
            remaining -= slices[offset].len();
            offset += 1;
        }
        // Finish the slice that was only partially written, if any.
        if remaining > 0 && offset < end {
            out.write_all(&slices[offset][remaining..])?;
            offset += 1;
        }
    }
    Ok(())
}

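// A minimal usage sketch for write_all_slices (hypothetical test module, not
// part of the original file): with more than four slices, the batched
// write_vectored path runs; a Vec<u8> sink is assumed as the writer.
#[cfg(test)]
mod write_all_slices_sketch {
    use super::*;

    #[test]
    fn concatenates_slices_in_order() {
        let parts: [&[u8]; 5] = [b"one\n", b"two\n", b"three\n", b"four\n", b"five\n"];
        let slices: Vec<IoSlice<'_>> = parts.iter().map(|&p| IoSlice::new(p)).collect();
        let mut out: Vec<u8> = Vec::new();
        write_all_slices(&mut out, &slices).unwrap();
        assert_eq!(out, b"one\ntwo\nthree\nfour\nfive\n");
    }
}
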
/// Reverse records separated by a single byte.
/// Uses forward SIMD scan (memchr_iter) to collect all separator positions,
/// then fills output buffer in reverse order with parallel copy for large data.
/// Single write_all at the end for minimum syscall overhead.
pub fn tac_bytes(data: &[u8], separator: u8, before: bool, out: &mut impl Write) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    // Forward SIMD scan: collect all separator positions in one pass.
    // This is faster than memrchr_iter for building the complete positions list
    // because forward scanning has better hardware prefetch behavior.
    let positions: Vec<usize> = memchr::memchr_iter(separator, data).collect();

    if positions.is_empty() {
        return out.write_all(data);
    }

    // Build list of (src_start, src_end) record ranges in reversed output order.
    // This allows us to compute exact output positions for parallel copy.
    let mut records: Vec<(usize, usize)> = Vec::with_capacity(positions.len() + 2);

    if !before {
        // separator-after mode: records end with separator
        let last_sep = *positions.last().unwrap();

        // Trailing content without separator: output first
        if last_sep + 1 < data.len() {
            records.push((last_sep + 1, data.len()));
        }

        // Records in reverse: each record is from (prev_sep+1) to (cur_sep+1)
        for i in (0..positions.len()).rev() {
            let start = if i == 0 { 0 } else { positions[i - 1] + 1 };
            let end = positions[i] + 1;
            records.push((start, end));
        }
    } else {
        // separator-before mode: records start with separator
        for i in (0..positions.len()).rev() {
            let start = positions[i];
            let end = if i + 1 < positions.len() {
                positions[i + 1]
            } else {
                data.len()
            };
            records.push((start, end));
        }

        // Leading content before first separator
        if positions[0] > 0 {
            records.push((0, positions[0]));
        }
    }

    // Compute output offsets (prefix sum of record lengths)
    let mut out_offsets: Vec<usize> = Vec::with_capacity(records.len());
    let mut total = 0usize;
    for &(start, end) in &records {
        out_offsets.push(total);
        total += end - start;
    }

    // Allocate output buffer
    #[allow(clippy::uninit_vec)]
    let mut outbuf: Vec<u8> = unsafe {
        let mut v = Vec::with_capacity(total);
        v.set_len(total); // SAFETY: fully overwritten by copy_nonoverlapping below
        v
    };

    // For large data: parallel copy using rayon
    if data.len() >= PAR_THRESHOLD && records.len() > 64 {
        // Use usize to pass addresses across threads (raw ptrs aren't Send)
        let out_base = outbuf.as_mut_ptr() as usize;
        let data_base = data.as_ptr() as usize;
        // SAFETY: Each record writes to a non-overlapping region of outbuf.
        // out_offsets are monotonically increasing and non-overlapping.
        records.par_iter().zip(out_offsets.par_iter()).for_each(
            |(&(src_start, src_end), &dst_offset)| {
                let len = src_end - src_start;
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        (data_base as *const u8).add(src_start),
                        (out_base as *mut u8).add(dst_offset),
                        len,
                    );
                }
            },
        );
    } else {
        // Small data: sequential copy using ptr::copy_nonoverlapping
        let out_ptr: *mut u8 = outbuf.as_mut_ptr();
        let data_ptr: *const u8 = data.as_ptr();
        for (i, &(src_start, src_end)) in records.iter().enumerate() {
            let len = src_end - src_start;
            unsafe {
                std::ptr::copy_nonoverlapping(
                    data_ptr.add(src_start),
                    out_ptr.add(out_offsets[i]),
                    len,
                );
            }
        }
    }

    out.write_all(&outbuf)
}

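// A minimal usage sketch for tac_bytes (hypothetical test module, not part of
// the original file), showing both separator-after and separator-before modes.
#[cfg(test)]
mod tac_bytes_sketch {
    use super::*;

    #[test]
    fn separator_after_reverses_newline_records() {
        let mut out = Vec::new();
        tac_bytes(b"a\nb\nc\n", b'\n', false, &mut out).unwrap();
        // Each record keeps its trailing newline.
        assert_eq!(out, b"c\nb\na\n");
    }

    #[test]
    fn separator_before_attaches_separator_to_following_record() {
        let mut out = Vec::new();
        tac_bytes(b"a\nb\nc", b'\n', true, &mut out).unwrap();
        // Records are "a", "\nb", "\nc"; reversed: "\nc", "\nb", "a".
        assert_eq!(out, b"\nc\nba");
    }
}
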
/// Reverse records using a multi-byte string separator.
/// Uses SIMD-accelerated memmem for substring search + parallel reverse copy.
pub fn tac_string_separator(
    data: &[u8],
    separator: &[u8],
    before: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    if separator.len() == 1 {
        return tac_bytes(data, separator[0], before, out);
    }

    // Find all occurrences of the separator using SIMD-accelerated memmem
    let positions: Vec<usize> = memchr::memmem::find_iter(data, separator).collect();

    if positions.is_empty() {
        return out.write_all(data);
    }

    let sep_len = separator.len();

    // Build record ranges in reversed output order
    let mut records: Vec<(usize, usize)> = Vec::with_capacity(positions.len() + 2);

    if !before {
        let last_end = positions.last().unwrap() + sep_len;
        if last_end < data.len() {
            records.push((last_end, data.len()));
        }
        for i in (0..positions.len()).rev() {
            let rec_start = if i == 0 {
                0
            } else {
                positions[i - 1] + sep_len
            };
            records.push((rec_start, positions[i] + sep_len));
        }
    } else {
        for i in (0..positions.len()).rev() {
            let start = positions[i];
            let end = if i + 1 < positions.len() {
                positions[i + 1]
            } else {
                data.len()
            };
            records.push((start, end));
        }
        if positions[0] > 0 {
            records.push((0, positions[0]));
        }
    }

    // Compute output offsets
    let mut out_offsets: Vec<usize> = Vec::with_capacity(records.len());
    let mut total = 0usize;
    for &(start, end) in &records {
        out_offsets.push(total);
        total += end - start;
    }

    // Allocate and fill output buffer
    #[allow(clippy::uninit_vec)]
    let mut outbuf: Vec<u8> = unsafe {
        let mut v = Vec::with_capacity(total);
        v.set_len(total); // SAFETY: fully overwritten by copy_nonoverlapping below
        v
    };

    if data.len() >= PAR_THRESHOLD && records.len() > 64 {
        let out_base = outbuf.as_mut_ptr() as usize;
        let data_base = data.as_ptr() as usize;
        records.par_iter().zip(out_offsets.par_iter()).for_each(
            |(&(src_start, src_end), &dst_offset)| {
                let len = src_end - src_start;
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        (data_base as *const u8).add(src_start),
                        (out_base as *mut u8).add(dst_offset),
                        len,
                    );
                }
            },
        );
    } else {
        let out_ptr: *mut u8 = outbuf.as_mut_ptr();
        let data_ptr: *const u8 = data.as_ptr();
        for (i, &(src_start, src_end)) in records.iter().enumerate() {
            let len = src_end - src_start;
            unsafe {
                std::ptr::copy_nonoverlapping(
                    data_ptr.add(src_start),
                    out_ptr.add(out_offsets[i]),
                    len,
                );
            }
        }
    }

    out.write_all(&outbuf)
}

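// A minimal usage sketch for tac_string_separator (hypothetical test module,
// not part of the original file), using a two-byte separator.
#[cfg(test)]
mod tac_string_separator_sketch {
    use super::*;

    #[test]
    fn reverses_records_with_multibyte_separator() {
        let mut out = Vec::new();
        tac_string_separator(b"one--two--three", b"--", false, &mut out).unwrap();
        // Records are "one--", "two--", "three"; the unterminated trailing
        // record is emitted first.
        assert_eq!(out, b"threetwo--one--");
    }
}
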
/// Find regex matches using backward scanning, matching GNU tac's re_search behavior.
/// GNU tac scans backward from the end, finding the rightmost starting position first.
/// This produces different matches than forward scanning for patterns like [0-9]+.
/// The matches are returned in left-to-right order.
fn find_regex_matches_backward(data: &[u8], re: &regex::bytes::Regex) -> Vec<(usize, usize)> {
    let mut matches = Vec::new();
    let mut past_end = data.len();

    while past_end > 0 {
        let buf = &data[..past_end];
        let mut found = false;

        // Scan backward: try positions from past_end-1 down to 0.
        // We want the rightmost match start in buf, so we probe from the end.
        let mut pos = past_end;
        while pos > 0 {
            pos -= 1;
            if let Some(m) = re.find_at(buf, pos) {
                if m.start() == pos {
                    // Match starts at exactly this position: this is the
                    // rightmost match start in buf.
                    matches.push((m.start(), m.end()));
                    past_end = m.start();
                    found = true;
                    break;
                }
                // The leftmost match at or after pos starts later than pos.
                // We only accept a match starting exactly at pos, so keep
                // decrementing to look for an earlier start.
            }
            // If None, no match starts at pos or later; one may still start
            // earlier, since find_at only searches forward from pos.
        }

        if !found {
            break;
        }
    }

    matches.reverse(); // Convert from backward order to left-to-right order
    matches
}

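// A sketch illustrating why backward scanning matters (hypothetical test
// module, not part of the original file): for a greedy pattern like [0-9]+,
// a forward scan over "ab1234cd" would find the single match "1234", while
// the backward scan accepts the rightmost match start first and therefore
// yields one match per digit, mirroring GNU tac's re_search behavior.
#[cfg(test)]
mod backward_scan_sketch {
    use super::*;

    #[test]
    fn backward_scan_splits_greedy_matches() {
        let re = regex::bytes::Regex::new("[0-9]+").unwrap();
        let matches = find_regex_matches_backward(b"ab1234cd", &re);
        assert_eq!(matches, vec![(2, 3), (3, 4), (4, 5), (5, 6)]);
    }
}
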
/// Reverse records using a regex separator.
/// Uses regex::bytes for direct byte-level matching (no UTF-8 conversion needed).
/// NOTE: GNU tac treats the separator as a POSIX Basic Regular Expression (BRE);
/// patterns are converted to ERE-style regex-crate syntax before reaching this function.
/// Uses backward scanning to match GNU tac's re_search behavior.
pub fn tac_regex_separator(
    data: &[u8],
    pattern: &str,
    before: bool,
    out: &mut impl Write,
) -> io::Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    let re = match regex::bytes::Regex::new(pattern) {
        Ok(r) => r,
        Err(e) => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid regex '{}': {}", pattern, e),
            ));
        }
    };

    // Use backward scanning to match GNU tac's re_search behavior
    let matches = find_regex_matches_backward(data, &re);

    if matches.is_empty() {
        out.write_all(data)?;
        return Ok(());
    }

    // Small data: contiguous buffer + single write (avoids IoSlice/writev overhead)
    if data.len() < 16 * 1024 * 1024 {
        let mut outbuf = Vec::with_capacity(data.len());

        if !before {
            let last_end = matches.last().unwrap().1;

            if last_end < data.len() {
                outbuf.extend_from_slice(&data[last_end..]);
            }

            let mut i = matches.len();
            while i > 0 {
                i -= 1;
                let rec_start = if i == 0 { 0 } else { matches[i - 1].1 };
                outbuf.extend_from_slice(&data[rec_start..matches[i].1]);
            }
        } else {
            let mut i = matches.len();
            while i > 0 {
                i -= 1;
                let start = matches[i].0;
                let end = if i + 1 < matches.len() {
                    matches[i + 1].0
                } else {
                    data.len()
                };
                outbuf.extend_from_slice(&data[start..end]);
            }
            if matches[0].0 > 0 {
                outbuf.extend_from_slice(&data[..matches[0].0]);
            }
        }
        return out.write_all(&outbuf);
    }

    // Large data: batched IoSlice/writev for zero-copy output
    let mut batch: Vec<IoSlice<'_>> = Vec::with_capacity(IOV_BATCH);

    if !before {
        let last_end = matches.last().unwrap().1;
        let has_trailing_sep = last_end == data.len();

        if !has_trailing_sep {
            batch.push(IoSlice::new(&data[last_end..]));
        }

        let mut i = matches.len();
        while i > 0 {
            i -= 1;
            let rec_start = if i == 0 { 0 } else { matches[i - 1].1 };
            let rec_end = matches[i].1;
            batch.push(IoSlice::new(&data[rec_start..rec_end]));
            if batch.len() == IOV_BATCH {
                write_all_slices(out, &batch)?;
                batch.clear();
            }
        }
    } else {
        let mut i = matches.len();
        while i > 0 {
            i -= 1;
            let start = matches[i].0;
            let end = if i + 1 < matches.len() {
                matches[i + 1].0
            } else {
                data.len()
            };
            batch.push(IoSlice::new(&data[start..end]));
            if batch.len() == IOV_BATCH {
                write_all_slices(out, &batch)?;
                batch.clear();
            }
        }

        if matches[0].0 > 0 {
            batch.push(IoSlice::new(&data[..matches[0].0]));
        }
    }

    if !batch.is_empty() {
        write_all_slices(out, &batch)?;
    }

    Ok(())
}
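
// A minimal usage sketch for tac_regex_separator (hypothetical test module,
// not part of the original file). The input here is small, so the contiguous
// buffer path is taken rather than the batched writev path.
#[cfg(test)]
mod tac_regex_separator_sketch {
    use super::*;

    #[test]
    fn reverses_records_split_on_regex() {
        let mut out = Vec::new();
        tac_regex_separator(b"a:b;c", "[:;]", false, &mut out).unwrap();
        // Records are "a:", "b;", "c"; the unterminated "c" is emitted first.
        assert_eq!(out, b"cb;a:");
    }
}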