Skip to main content

coreutils_rs/fold/
core.rs

1use std::io::Write;
2
3/// Fold (wrap) lines to a given width.
4///
5/// Modes:
6/// - `bytes` mode (-b): count bytes, break at byte boundaries
7/// - default mode: count columns (tab = advance to next tab stop, backspace = decrement)
8///
9/// If `spaces` (-s): break at the last space within the width instead of mid-word.
10pub fn fold_bytes(
11    data: &[u8],
12    width: usize,
13    count_bytes: bool,
14    break_at_spaces: bool,
15    out: &mut impl Write,
16) -> std::io::Result<()> {
17    if data.is_empty() {
18        return Ok(());
19    }
20
21    if width == 0 {
22        return fold_width_zero(data, out);
23    }
24
25    // Fast path: byte mode, use SIMD-accelerated scanning
26    if count_bytes {
27        if break_at_spaces {
28            return fold_byte_fast_spaces(data, width, out);
29        } else {
30            return fold_byte_fast(data, width, out);
31        }
32    }
33
34    // Column mode without tabs: byte mode is equivalent (on glibc)
35    if memchr::memchr(b'\t', data).is_none() {
36        if break_at_spaces {
37            return fold_byte_fast_spaces(data, width, out);
38        } else {
39            return fold_byte_fast(data, width, out);
40        }
41    }
42
43    fold_column_mode_streaming(data, width, break_at_spaces, out)
44}
45
/// Width 0: every input byte is replaced by a newline, so the output consists of
/// exactly `data.len()` newline bytes and none of the input content.
fn fold_width_zero(data: &[u8], out: &mut impl Write) -> std::io::Result<()> {
    let newlines: Vec<u8> = b"\n".repeat(data.len());
    out.write_all(&newlines)
}
51
52/// Fast fold by byte count without -s flag.
53/// Uses unsafe pointer copies and a pre-allocated 1MB output buffer.
54/// For short lines (≤width), copies line+newline with a single memcpy.
55fn fold_byte_fast(data: &[u8], width: usize, out: &mut impl Write) -> std::io::Result<()> {
56    const BUF_CAP: usize = 1024 * 1024 + 4096;
57    let mut buf: Vec<u8> = Vec::with_capacity(BUF_CAP);
58    let base = buf.as_mut_ptr();
59    // SAFETY: `base` stays valid across `buf.clear()` calls because clear()
60    // retains the allocation. We never push/extend through the Vec API, so no
61    // reallocation occurs; `wp < BUF_CAP` is maintained before every write.
62    let src = data.as_ptr();
63    let mut wp: usize = 0;
64    let mut seg_start = 0usize;
65
66    for nl_pos in memchr::memchr_iter(b'\n', data) {
67        let seg_len = nl_pos - seg_start;
68
69        if seg_len <= width {
70            // Short line: fits without folding. Copy line + newline in one go.
71            let total = seg_len + 1;
72            if wp + total > BUF_CAP {
73                unsafe { buf.set_len(wp) };
74                out.write_all(&buf)?;
75                buf.clear();
76                wp = 0;
77            }
78            unsafe {
79                std::ptr::copy_nonoverlapping(src.add(seg_start), base.add(wp), total);
80            }
81            wp += total;
82        } else {
83            // Long line: fold at width boundaries.
84            let mut off = seg_start;
85            let end = nl_pos;
86            while off + width < end {
87                let chunk = width + 1; // width bytes + newline
88                if wp + chunk > BUF_CAP {
89                    unsafe { buf.set_len(wp) };
90                    out.write_all(&buf)?;
91                    buf.clear();
92                    wp = 0;
93                }
94                unsafe {
95                    std::ptr::copy_nonoverlapping(src.add(off), base.add(wp), width);
96                    *base.add(wp + width) = b'\n';
97                }
98                wp += chunk;
99                off += width;
100            }
101            // Remaining bytes + newline
102            let rem = end - off + 1; // includes the newline at nl_pos
103            if wp + rem > BUF_CAP {
104                unsafe { buf.set_len(wp) };
105                out.write_all(&buf)?;
106                buf.clear();
107                wp = 0;
108            }
109            unsafe {
110                std::ptr::copy_nonoverlapping(src.add(off), base.add(wp), rem);
111            }
112            wp += rem;
113        }
114        seg_start = nl_pos + 1;
115    }
116
117    // Handle final segment without trailing newline
118    if seg_start < data.len() {
119        let seg_len = data.len() - seg_start;
120        let mut off = seg_start;
121        let end = data.len();
122        while off + width < end {
123            let chunk = width + 1;
124            if wp + chunk > BUF_CAP {
125                unsafe { buf.set_len(wp) };
126                out.write_all(&buf)?;
127                buf.clear();
128                wp = 0;
129            }
130            unsafe {
131                std::ptr::copy_nonoverlapping(src.add(off), base.add(wp), width);
132                *base.add(wp + width) = b'\n';
133            }
134            wp += chunk;
135            off += width;
136        }
137        if off < end {
138            let rem = end - off;
139            if wp + rem > BUF_CAP {
140                unsafe { buf.set_len(wp) };
141                out.write_all(&buf)?;
142                buf.clear();
143                wp = 0;
144            }
145            unsafe {
146                std::ptr::copy_nonoverlapping(src.add(off), base.add(wp), rem);
147            }
148            wp += rem;
149        }
150        let _ = seg_len;
151    }
152
153    if wp > 0 {
154        unsafe { buf.set_len(wp) };
155        out.write_all(&buf)?;
156    }
157
158    Ok(())
159}
160
161/// Fast fold by byte count with -s (break at spaces).
162/// Buffers output into ~1MB chunks to minimize write syscalls.
163fn fold_byte_fast_spaces(data: &[u8], width: usize, out: &mut impl Write) -> std::io::Result<()> {
164    let mut outbuf: Vec<u8> = Vec::with_capacity(1024 * 1024 + 4096);
165    let mut pos: usize = 0;
166
167    for nl_pos in memchr::memchr_iter(b'\n', data) {
168        let segment = &data[pos..nl_pos];
169        fold_segment_bytes_spaces_buffered(segment, width, &mut outbuf);
170        outbuf.push(b'\n');
171        pos = nl_pos + 1;
172
173        if outbuf.len() >= 1024 * 1024 {
174            out.write_all(&outbuf)?;
175            outbuf.clear();
176        }
177    }
178
179    // Handle final segment without trailing newline
180    if pos < data.len() {
181        fold_segment_bytes_spaces_buffered(&data[pos..], width, &mut outbuf);
182    }
183
184    if !outbuf.is_empty() {
185        out.write_all(&outbuf)?;
186    }
187    Ok(())
188}
189
190/// Streaming fold by column count — single-pass stream using memchr2.
191/// Processes the entire file in one scan, finding both tabs and newlines
192/// simultaneously. Avoids the overhead of per-line decomposition + per-line
193/// tab checking (two separate SIMD passes over the data).
194fn fold_column_mode_streaming(
195    data: &[u8],
196    width: usize,
197    break_at_spaces: bool,
198    out: &mut impl Write,
199) -> std::io::Result<()> {
200    if break_at_spaces {
201        return fold_column_mode_spaces_streaming(data, width, out);
202    }
203
204    let mut outbuf: Vec<u8> = Vec::with_capacity(1024 * 1024 + 4096);
205    let mut col: usize = 0;
206    let mut seg_start: usize = 0;
207    let mut i: usize = 0;
208
209    while i < data.len() {
210        // SIMD scan: skip regular bytes, find next tab or newline
211        match memchr::memchr2(b'\t', b'\n', &data[i..]) {
212            Some(off) => {
213                let special_pos = i + off;
214                let run_len = special_pos - i;
215
216                // Check if regular bytes before the special char cause overflow
217                if col + run_len > width {
218                    // Need line breaks within this regular-byte run
219                    loop {
220                        let remaining = special_pos - i;
221                        let fit = width - col;
222                        if fit >= remaining {
223                            col += remaining;
224                            i = special_pos;
225                            break;
226                        }
227                        outbuf.extend_from_slice(&data[seg_start..i + fit]);
228                        outbuf.push(b'\n');
229                        i += fit;
230                        seg_start = i;
231                        col = 0;
232                    }
233                } else {
234                    col += run_len;
235                    i = special_pos;
236                }
237
238                // Handle the special character
239                if data[i] == b'\n' {
240                    outbuf.extend_from_slice(&data[seg_start..=i]);
241                    col = 0;
242                    i += 1;
243                    seg_start = i;
244                    if outbuf.len() >= 1024 * 1024 {
245                        out.write_all(&outbuf)?;
246                        outbuf.clear();
247                    }
248                } else {
249                    // Tab
250                    let new_col = ((col >> 3) + 1) << 3;
251                    if new_col > width && col > 0 {
252                        outbuf.extend_from_slice(&data[seg_start..i]);
253                        outbuf.push(b'\n');
254                        seg_start = i;
255                        col = 0;
256                        continue; // re-evaluate tab at col 0
257                    }
258                    col = new_col;
259                    i += 1;
260                }
261            }
262            None => {
263                // Remaining data is all regular bytes (no tabs or newlines)
264                let remaining = data.len() - i;
265                if col + remaining > width {
266                    loop {
267                        let rem_now = data.len() - i;
268                        let fit = width - col;
269                        if fit >= rem_now {
270                            break;
271                        }
272                        outbuf.extend_from_slice(&data[seg_start..i + fit]);
273                        outbuf.push(b'\n');
274                        i += fit;
275                        seg_start = i;
276                        col = 0;
277                    }
278                }
279                break;
280            }
281        }
282    }
283
284    if seg_start < data.len() {
285        outbuf.extend_from_slice(&data[seg_start..]);
286    }
287    if !outbuf.is_empty() {
288        out.write_all(&outbuf)?;
289    }
290
291    Ok(())
292}
293
294/// Fold a byte segment (no newlines) with -s (break at spaces), buffered output.
295#[inline]
296fn fold_segment_bytes_spaces_buffered(segment: &[u8], width: usize, outbuf: &mut Vec<u8>) {
297    let mut start = 0;
298    while start + width < segment.len() {
299        let chunk = &segment[start..start + width];
300        match memchr::memrchr2(b' ', b'\t', chunk) {
301            Some(sp_offset) => {
302                let break_at = start + sp_offset + 1;
303                outbuf.extend_from_slice(&segment[start..break_at]);
304                outbuf.push(b'\n');
305                start = break_at;
306            }
307            None => {
308                outbuf.extend_from_slice(&segment[start..start + width]);
309                outbuf.push(b'\n');
310                start += width;
311            }
312        }
313    }
314    if start < segment.len() {
315        outbuf.extend_from_slice(&segment[start..]);
316    }
317}
318
319/// Streaming fold column mode with -s (break at spaces).
320/// Uses buffered output to minimize write syscalls.
321/// Fast path: if no tabs in data, column width == byte width, so we can
322/// use the simpler byte-mode space-breaking algorithm.
323fn fold_column_mode_spaces_streaming(
324    data: &[u8],
325    width: usize,
326    out: &mut impl Write,
327) -> std::io::Result<()> {
328    // If no tabs, column mode == byte mode (every byte has width 1)
329    // BS/CR/control chars could theoretically differ but are vanishingly rare
330    // in practice and the difference is negligible.
331    if memchr::memchr(b'\t', data).is_none() {
332        return fold_byte_fast_spaces(data, width, out);
333    }
334
335    let mut pos = 0;
336    let mut outbuf: Vec<u8> = Vec::with_capacity(1024 * 1024 + 4096);
337
338    for nl_pos in memchr::memchr_iter(b'\n', data) {
339        let line = &data[pos..nl_pos];
340        // Short-circuit: line fits in width AND has no tabs → no folding needed
341        if line.len() <= width && memchr::memchr(b'\t', line).is_none() {
342            outbuf.extend_from_slice(line);
343        } else {
344            fold_column_spaces_fast(line, width, &mut outbuf);
345        }
346        outbuf.push(b'\n');
347
348        if outbuf.len() >= 1024 * 1024 {
349            out.write_all(&outbuf)?;
350            outbuf.clear();
351        }
352
353        pos = nl_pos + 1;
354    }
355
356    // Handle final line without trailing newline
357    if pos < data.len() {
358        let line = &data[pos..];
359        if line.len() <= width && memchr::memchr(b'\t', line).is_none() {
360            outbuf.extend_from_slice(line);
361        } else {
362            fold_column_spaces_fast(line, width, &mut outbuf);
363        }
364    }
365
366    if !outbuf.is_empty() {
367        out.write_all(&outbuf)?;
368    }
369
370    Ok(())
371}
372
373/// Fast column-mode fold for a single line with -s (break at spaces).
374/// Uses memchr2 to find tabs and spaces in bulk, processing runs of regular
375/// bytes without per-byte branching. Matches GNU fold's exact algorithm:
376/// - `column > width` triggers break (strictly greater)
377/// - Break at last blank: output INCLUDING the blank, remainder starts after it
378/// - After break: recalculate column from remaining data, re-process current char
379/// - All bytes width 1 except tab (next tab stop), BS (col-1), CR (col=0)
380#[inline]
381fn fold_column_spaces_fast(line: &[u8], width: usize, outbuf: &mut Vec<u8>) {
382    let mut col: usize = 0;
383    let mut seg_start: usize = 0;
384    let mut last_space_after: usize = 0;
385    let mut has_space = false;
386    let mut i: usize = 0;
387
388    while i < line.len() {
389        let b = line[i];
390        if b == b'\t' {
391            let new_col = ((col >> 3) + 1) << 3;
392            if new_col > width && col > 0 {
393                // Tab exceeds width — break
394                if has_space {
395                    outbuf.extend_from_slice(&line[seg_start..last_space_after]);
396                    outbuf.push(b'\n');
397                    seg_start = last_space_after;
398                    col = recalc_column(&line[seg_start..i]);
399                    has_space = false;
400                    continue; // re-evaluate tab
401                }
402                outbuf.extend_from_slice(&line[seg_start..i]);
403                outbuf.push(b'\n');
404                seg_start = i;
405                col = 0;
406                continue; // re-evaluate tab with col=0
407            }
408            // Tab also counts as a breakable whitespace for -s (GNU compat)
409            has_space = true;
410            last_space_after = i + 1;
411            col = new_col;
412            i += 1;
413        } else if b == b' ' {
414            col += 1;
415            if col > width {
416                if has_space {
417                    outbuf.extend_from_slice(&line[seg_start..last_space_after]);
418                    outbuf.push(b'\n');
419                    seg_start = last_space_after;
420                    col = recalc_column(&line[seg_start..i]);
421                    has_space = false;
422                    continue; // re-evaluate this space
423                }
424                // No prior blank — break before this space (GNU: output buffer, rescan)
425                outbuf.extend_from_slice(&line[seg_start..i]);
426                outbuf.push(b'\n');
427                seg_start = i;
428                col = 1; // space starts the new line with width 1
429                has_space = true;
430                last_space_after = i + 1;
431                i += 1;
432                continue;
433            }
434            has_space = true;
435            last_space_after = i + 1;
436            i += 1;
437        } else {
438            // Find next tab or space using SIMD memchr2
439            let run_end = match memchr::memchr2(b'\t', b' ', &line[i + 1..]) {
440                Some(off) => i + 1 + off,
441                None => line.len(),
442            };
443
444            // Process run of regular bytes: each has column width 1
445            let run_remaining = run_end - i;
446            if col + run_remaining <= width {
447                // Entire run fits
448                col += run_remaining;
449                i = run_end;
450            } else {
451                // Run exceeds width — need to break
452                let mut j = i;
453                loop {
454                    let rem = run_end - j;
455                    if col + rem <= width {
456                        col += rem;
457                        i = run_end;
458                        break;
459                    }
460                    if has_space {
461                        // Break at last blank (includes the blank)
462                        outbuf.extend_from_slice(&line[seg_start..last_space_after]);
463                        outbuf.push(b'\n');
464                        seg_start = last_space_after;
465                        col = j - seg_start; // regular bytes only, each width 1
466                        has_space = false;
467                        continue; // re-check with new col
468                    }
469                    // No blank — hard break at width boundary
470                    let fit = width - col;
471                    outbuf.extend_from_slice(&line[seg_start..j + fit]);
472                    outbuf.push(b'\n');
473                    j += fit;
474                    seg_start = j;
475                    col = 0;
476                }
477            }
478        }
479    }
480
481    if seg_start < line.len() {
482        outbuf.extend_from_slice(&line[seg_start..]);
483    }
484}
485
/// Column width and byte length of the byte at `data[pos]`.
///
/// Returns `(column_width, byte_length)`. The byte length is always 1.
/// Control bytes (0x00–0x1F) and DEL (0x7F) report zero columns. Every other
/// byte, including high bytes (0x80–0xFF), reports one column.
///
/// Rationale carried over from the original note: GNU fold's multibyte path is
/// guarded by `#if HAVE_MBRTOC32 && (! defined __GLIBC__ || defined __UCLIBC__)`,
/// which is false on glibc, so there fold counts one column per byte. Tab,
/// backspace and CR are expected to be handled by the caller.
///
/// Panics if `pos` is out of bounds.
#[inline]
fn char_info(data: &[u8], pos: usize) -> (usize, usize) {
    match data[pos] {
        0x00..=0x1f | 0x7f => (0, 1),
        _ => (1, 1),
    }
}
509
510/// Check if folding would produce identical output (all lines fit within width).
511/// Used by the binary for direct write-through optimization.
512pub fn fold_is_passthrough(data: &[u8], width: usize, count_bytes: bool) -> bool {
513    if width == 0 || data.is_empty() {
514        return data.is_empty();
515    }
516    // Column mode with tabs: can't easily determine passthrough
517    if !count_bytes && memchr::memchr(b'\t', data).is_some() {
518        return false;
519    }
520    let mut prev = 0;
521    for nl_pos in memchr::memchr_iter(b'\n', data) {
522        if nl_pos - prev > width {
523            return false;
524        }
525        prev = nl_pos + 1;
526    }
527    data.len() - prev <= width
528}
529
530/// Recalculate column position by replaying a segment (handles tabs, CR, backspace).
531/// Used when non-linear column operations (CR, backspace) invalidate the fast
532/// `col - col_at_space` delta formula.
533fn recalc_column(data: &[u8]) -> usize {
534    let mut col = 0;
535    let mut i = 0;
536    while i < data.len() {
537        let b = data[i];
538        if b == b'\r' {
539            col = 0;
540            i += 1;
541        } else if b == b'\t' {
542            col = ((col / 8) + 1) * 8;
543            i += 1;
544        } else if b == b'\x08' {
545            if col > 0 {
546                col -= 1;
547            }
548            i += 1;
549        } else if b < 0x80 {
550            if b >= 0x20 && b != 0x7f {
551                col += 1;
552            }
553            i += 1;
554        } else {
555            let (cw, byte_len) = char_info(data, i);
556            col += cw;
557            i += byte_len;
558        }
559    }
560    col
561}