Skip to main content

ferray_strings/
extras.rs

1// ferray-strings: numpy.strings extras
2//
3// Functions added to round out NumPy 2.x parity that don't fit neatly
4// into the existing search/case/classify/strip/etc. modules:
5//   - encode/decode (UTF-8 codec ufuncs)
6//   - expandtabs
7//   - mod_ (printf-style % formatting)
8//   - partition / rpartition
9//   - slice (string-slicing ufunc)
10//   - translate (char-translate ufunc)
11
12use std::collections::HashMap;
13
14use ferray_core::dimension::Dimension;
15use ferray_core::error::{FerrayError, FerrayResult};
16
17use crate::string_array::StringArray;
18
19// ===========================================================================
20// encode / decode (UTF-8 codec ufuncs)
21// ===========================================================================
22
23/// Encode each string as UTF-8 bytes, returning a `Vec<Vec<u8>>` per element.
24///
25/// Equivalent to `numpy.strings.encode(arr, encoding="utf-8")` for the
26/// common UTF-8 case. Other encodings are intentionally not supported —
27/// the workspace standardizes on UTF-8 strings, and supporting Latin-1 /
28/// CP1252 / Shift-JIS would pull in `encoding_rs`.
29///
30/// Errors if `encoding` is not `"utf-8"` (case-insensitive) and not empty.
31///
32/// # Errors
33/// - `FerrayError::InvalidValue` if an unsupported encoding is requested.
34pub fn encode<D: Dimension>(a: &StringArray<D>, encoding: &str) -> FerrayResult<Vec<Vec<u8>>> {
35    let enc = encoding.trim().to_ascii_lowercase();
36    if !enc.is_empty() && enc != "utf-8" && enc != "utf8" {
37        return Err(FerrayError::invalid_value(format!(
38            "encode: only UTF-8 is supported, got {encoding:?}"
39        )));
40    }
41    Ok(a.iter().map(|s| s.as_bytes().to_vec()).collect())
42}
43
44/// Decode UTF-8 byte buffers back into a [`StringArray`] preserving the
45/// input shape.
46///
47/// Equivalent to `numpy.strings.decode(arr, encoding="utf-8")` for the
48/// UTF-8 case. The number of byte-buffers must equal the array's element
49/// count. Invalid UTF-8 produces `FerrayError::InvalidValue`.
50///
51/// # Errors
52/// - `FerrayError::ShapeMismatch` if `byte_arrays.len()` != `shape_len()`.
53/// - `FerrayError::InvalidValue` for unsupported encoding or invalid UTF-8.
54pub fn decode<D: Dimension>(
55    byte_arrays: &[Vec<u8>],
56    shape: D,
57    encoding: &str,
58) -> FerrayResult<StringArray<D>> {
59    let enc = encoding.trim().to_ascii_lowercase();
60    if !enc.is_empty() && enc != "utf-8" && enc != "utf8" {
61        return Err(FerrayError::invalid_value(format!(
62            "decode: only UTF-8 is supported, got {encoding:?}"
63        )));
64    }
65    let expected: usize = shape.as_slice().iter().product();
66    if byte_arrays.len() != expected {
67        return Err(FerrayError::shape_mismatch(format!(
68            "decode: got {} byte buffers, but shape {:?} requires {}",
69            byte_arrays.len(),
70            shape.as_slice(),
71            expected,
72        )));
73    }
74    let mut data = Vec::with_capacity(byte_arrays.len());
75    for (i, bytes) in byte_arrays.iter().enumerate() {
76        match std::str::from_utf8(bytes) {
77            Ok(s) => data.push(s.to_owned()),
78            Err(e) => {
79                return Err(FerrayError::invalid_value(format!(
80                    "decode: element {i} is not valid UTF-8: {e}"
81                )));
82            }
83        }
84    }
85    StringArray::from_vec(shape, data)
86}
87
88// ===========================================================================
89// expandtabs
90// ===========================================================================
91
92/// Replace tab characters in each element with spaces, expanding to the
93/// given tab size on column boundaries.
94///
95/// Equivalent to `numpy.strings.expandtabs(arr, tabsize)`.
96///
97/// # Errors
98/// Returns an error if the internal array construction fails.
99pub fn expandtabs<D: Dimension>(
100    a: &StringArray<D>,
101    tabsize: usize,
102) -> FerrayResult<StringArray<D>> {
103    a.map(|s| {
104        let mut out = String::with_capacity(s.len());
105        let mut col = 0usize;
106        for c in s.chars() {
107            match c {
108                '\t' => {
109                    let pad = if tabsize == 0 {
110                        0
111                    } else {
112                        tabsize - (col % tabsize)
113                    };
114                    for _ in 0..pad {
115                        out.push(' ');
116                    }
117                    col += pad;
118                }
119                '\n' | '\r' => {
120                    out.push(c);
121                    col = 0;
122                }
123                other => {
124                    out.push(other);
125                    col += 1;
126                }
127            }
128        }
129        out
130    })
131}
132
133// ===========================================================================
134// mod_ (printf-style % formatting)
135// ===========================================================================
136
137/// Format each element as a printf-style template, substituting `args` in.
138///
139/// Equivalent to `numpy.strings.mod(arr, args)` (the `%`-style binding).
140/// Supports the most common conversions:
141///   - `%s` — verbatim insertion of the next argument
142///   - `%%` — literal percent sign
143///   - `%d` — integer (parses argument as i64)
144///   - `%f` — float (parses argument as f64, default 6-decimal precision)
145///   - `%.Nf` — float with explicit precision N
146///   - `%e` / `%g` — scientific / shortest-of-fe representation
147///
148/// `args` is a slice indexed in conversion-order; each `%X` consumes the
149/// next argument in `args`.
150///
151/// Width/flag specifiers beyond `%.Nf` are not supported — parsing the
152/// full printf grammar inside a ufunc isn't worth the complexity given
153/// most NumPy users reach for f-strings instead.
154///
155/// # Errors
156/// - `FerrayError::InvalidValue` if `args.len()` is fewer than the number
157///   of conversions, an unknown specifier appears, or a `%d`/`%f`/etc.
158///   argument fails to parse.
159pub fn mod_<D: Dimension>(a: &StringArray<D>, args: &[&str]) -> FerrayResult<StringArray<D>> {
160    let strings: Vec<String> = a
161        .iter()
162        .map(|s| format_template(s, args))
163        .collect::<FerrayResult<Vec<_>>>()?;
164    StringArray::from_vec(a.dim().clone(), strings)
165}
166
167fn format_template(template: &str, args: &[&str]) -> FerrayResult<String> {
168    let mut out = String::with_capacity(template.len());
169    let mut chars = template.chars().peekable();
170    let mut argi = 0usize;
171    while let Some(c) = chars.next() {
172        if c != '%' {
173            out.push(c);
174            continue;
175        }
176        // Parse conversion spec.
177        let next = chars
178            .next()
179            .ok_or_else(|| FerrayError::invalid_value("mod: trailing '%' with no spec"))?;
180        if next == '%' {
181            out.push('%');
182            continue;
183        }
184        // Optional `.N` precision.
185        let mut precision: Option<usize> = None;
186        let mut spec = next;
187        if next == '.' {
188            let mut digits = String::new();
189            while let Some(&p) = chars.peek() {
190                if p.is_ascii_digit() {
191                    digits.push(p);
192                    chars.next();
193                } else {
194                    break;
195                }
196            }
197            precision = Some(
198                digits
199                    .parse::<usize>()
200                    .map_err(|_| FerrayError::invalid_value("mod: bad precision after '%.'"))?,
201            );
202            spec = chars
203                .next()
204                .ok_or_else(|| FerrayError::invalid_value("mod: missing spec char after '%.N'"))?;
205        }
206        let arg = args.get(argi).ok_or_else(|| {
207            FerrayError::invalid_value(format!("mod: not enough args at conversion {argi}"))
208        })?;
209        argi += 1;
210        match spec {
211            's' => out.push_str(arg),
212            'd' | 'i' => {
213                let v: i64 = arg.parse().map_err(|_| {
214                    FerrayError::invalid_value(format!("mod: %d arg {arg:?} is not an integer"))
215                })?;
216                out.push_str(&v.to_string());
217            }
218            'f' => {
219                let v: f64 = arg.parse().map_err(|_| {
220                    FerrayError::invalid_value(format!("mod: %f arg {arg:?} is not a float"))
221                })?;
222                let p = precision.unwrap_or(6);
223                out.push_str(&format!("{v:.p$}"));
224            }
225            'e' => {
226                let v: f64 = arg.parse().map_err(|_| {
227                    FerrayError::invalid_value(format!("mod: %e arg {arg:?} is not a float"))
228                })?;
229                let p = precision.unwrap_or(6);
230                out.push_str(&format!("{v:.p$e}"));
231            }
232            'g' => {
233                let v: f64 = arg.parse().map_err(|_| {
234                    FerrayError::invalid_value(format!("mod: %g arg {arg:?} is not a float"))
235                })?;
236                // Pick the shorter of `{:.6}` and `{:.6e}`, mirroring printf %g's intent.
237                let fixed = format!("{v}");
238                let sci = format!("{v:e}");
239                out.push_str(if fixed.len() <= sci.len() {
240                    &fixed
241                } else {
242                    &sci
243                });
244            }
245            other => {
246                return Err(FerrayError::invalid_value(format!(
247                    "mod: unsupported conversion '%{other}'"
248                )));
249            }
250        }
251    }
252    Ok(out)
253}
254
255// ===========================================================================
256// partition / rpartition
257// ===========================================================================
258
259/// Split each element on the first occurrence of `sep`, returning a
260/// `(before, sep, after)` triple per element.
261///
262/// If `sep` is not found in an element, the result is `(element, "", "")`.
263/// Mirrors `numpy.strings.partition`.
264///
265/// # Errors
266/// - `FerrayError::InvalidValue` if `sep` is empty.
267pub fn partition<D: Dimension>(
268    a: &StringArray<D>,
269    sep: &str,
270) -> FerrayResult<Vec<(String, String, String)>> {
271    if sep.is_empty() {
272        return Err(FerrayError::invalid_value(
273            "partition: separator must not be empty",
274        ));
275    }
276    Ok(a.iter()
277        .map(|s| match s.find(sep) {
278            Some(i) => (
279                s[..i].to_owned(),
280                sep.to_owned(),
281                s[i + sep.len()..].to_owned(),
282            ),
283            None => (s.clone(), String::new(), String::new()),
284        })
285        .collect())
286}
287
288/// Split each element on the last occurrence of `sep`, returning a
289/// `(before, sep, after)` triple per element.
290///
291/// If `sep` is not found, the result is `("", "", element)` (matching
292/// Python / NumPy `rpartition`).
293///
294/// # Errors
295/// - `FerrayError::InvalidValue` if `sep` is empty.
296pub fn rpartition<D: Dimension>(
297    a: &StringArray<D>,
298    sep: &str,
299) -> FerrayResult<Vec<(String, String, String)>> {
300    if sep.is_empty() {
301        return Err(FerrayError::invalid_value(
302            "rpartition: separator must not be empty",
303        ));
304    }
305    Ok(a.iter()
306        .map(|s| match s.rfind(sep) {
307            Some(i) => (
308                s[..i].to_owned(),
309                sep.to_owned(),
310                s[i + sep.len()..].to_owned(),
311            ),
312            None => (String::new(), String::new(), s.clone()),
313        })
314        .collect())
315}
316
317// ===========================================================================
318// slice (string-slicing ufunc)
319// ===========================================================================
320
321/// Slice each element by character index — `s[start..stop]` with negative
322/// indices counting from the end (Python-style). `None` for either bound
323/// keeps the corresponding edge.
324///
325/// `step` is intentionally not supported here (numpy.strings.slice's
326/// `step` parameter requires character-level reverse iteration that is
327/// non-trivial under multi-byte UTF-8). Pass two slice calls if you need
328/// `[start:stop:step]` behavior.
329///
330/// # Errors
331/// Returns an error if the internal array construction fails.
332pub fn slice<D: Dimension>(
333    a: &StringArray<D>,
334    start: Option<isize>,
335    stop: Option<isize>,
336) -> FerrayResult<StringArray<D>> {
337    a.map(|s| slice_str(s, start, stop))
338}
339
/// Return the characters of `s` in the half-open range `[start, stop)`,
/// using Python slice rules on character (not byte) positions.
///
/// Each bound is resolved independently:
/// - a negative value is offset by the character count;
/// - the result is clamped into `0..=len`.
///
/// Missing bounds default to `0` and `len`. An empty or inverted range
/// yields `""`.
fn slice_str(s: &str, start: Option<isize>, stop: Option<isize>) -> String {
    let len = s.chars().count() as isize;
    let to_index = |bound: isize| -> usize {
        let shifted = if bound < 0 { bound + len } else { bound };
        shifted.clamp(0, len) as usize
    };
    let begin = start.map_or(0, to_index);
    let end = stop.map_or(len as usize, to_index);
    s.chars().skip(begin).take(end.saturating_sub(begin)).collect()
}
353
354// ===========================================================================
355// translate
356// ===========================================================================
357
358/// Apply a per-character translation table to each element.
359///
360/// `table` maps each input char to:
361///   - `Some(c)` — replacement character (single char)
362///   - `None` — drop the character (delete from output)
363///
364/// Characters not present in `table` are passed through unchanged.
365/// Mirrors `numpy.strings.translate(arr, table)`.
366///
367/// # Errors
368/// Returns an error if the internal array construction fails.
369pub fn translate<D: Dimension>(
370    a: &StringArray<D>,
371    table: &HashMap<char, Option<char>>,
372) -> FerrayResult<StringArray<D>> {
373    a.map(|s| {
374        let mut out = String::with_capacity(s.len());
375        for c in s.chars() {
376            match table.get(&c) {
377                Some(Some(replacement)) => out.push(*replacement),
378                Some(None) => {} // drop
379                None => out.push(c),
380            }
381        }
382        out
383    })
384}
385
386// ===========================================================================
387// Tests
388// ===========================================================================
389
#[cfg(test)]
mod tests {
    use super::*;
    use crate::string_array::array;

    /// Owned `(before, sep, after)` triple for comparing partition results.
    fn triple(before: &str, sep: &str, after: &str) -> (String, String, String) {
        (before.to_owned(), sep.to_owned(), after.to_owned())
    }

    #[test]
    fn encode_decode_roundtrip() {
        let words = array(&["hello", "world", "café"]).unwrap();
        let buffers = encode(&words, "utf-8").unwrap();
        assert_eq!(buffers.len(), 3);
        let restored = decode(&buffers, words.dim().clone(), "utf-8").unwrap();
        assert_eq!(restored.as_slice(), words.as_slice());
    }

    #[test]
    fn encode_default_encoding_empty_string_is_utf8() {
        let single = array(&["x"]).unwrap();
        assert!(encode(&single, "").is_ok());
    }

    #[test]
    fn encode_unsupported_errs() {
        let single = array(&["x"]).unwrap();
        assert!(encode(&single, "latin-1").is_err());
    }

    #[test]
    fn decode_invalid_utf8_errs() {
        use ferray_core::dimension::Ix1;
        let invalid = [vec![0xff_u8, 0xfe, 0xfd]];
        assert!(decode(&invalid, Ix1::new([1]), "utf-8").is_err());
    }

    #[test]
    fn expandtabs_basic() {
        let tabbed = array(&["a\tb\tc"]).unwrap();
        let expanded = expandtabs(&tabbed, 4).unwrap();
        // Tab stops every 4 columns: 'a' (col 0) + 3 spaces, 'b' (col 4) +
        // 3 spaces, then 'c' lands on col 8.
        assert_eq!(expanded.as_slice(), &["a   b   c"]);
    }

    #[test]
    fn expandtabs_resets_on_newline() {
        let tabbed = array(&["x\ty\n\tz"]).unwrap();
        let expanded = expandtabs(&tabbed, 4).unwrap();
        // The newline resets the column, so the second tab is a full 4 spaces.
        assert_eq!(expanded.as_slice(), &["x   y\n    z"]);
    }

    #[test]
    fn mod_simple_s() {
        let template = array(&["hello %s"]).unwrap();
        let formatted = mod_(&template, &["world"]).unwrap();
        assert_eq!(formatted.as_slice(), &["hello world"]);
    }

    #[test]
    fn mod_d_and_f() {
        let template = array(&["%d items @ %.2f each"]).unwrap();
        let formatted = mod_(&template, &["7", "2.5"]).unwrap();
        assert_eq!(formatted.as_slice(), &["7 items @ 2.50 each"]);
    }

    #[test]
    fn mod_double_percent_literal() {
        let template = array(&["100%% pure %s"]).unwrap();
        let formatted = mod_(&template, &["rust"]).unwrap();
        assert_eq!(formatted.as_slice(), &["100% pure rust"]);
    }

    #[test]
    fn mod_too_few_args_errs() {
        let template = array(&["%s and %s"]).unwrap();
        assert!(mod_(&template, &["only one"]).is_err());
    }

    #[test]
    fn partition_found() {
        let inputs = array(&["a-b-c", "no-sep-here"]).unwrap();
        let expected = vec![triple("a", "-", "b-c"), triple("no", "-", "sep-here")];
        assert_eq!(partition(&inputs, "-").unwrap(), expected);
    }

    #[test]
    fn partition_not_found() {
        let inputs = array(&["abc"]).unwrap();
        assert_eq!(partition(&inputs, "-").unwrap(), vec![triple("abc", "", "")]);
    }

    #[test]
    fn rpartition_found_picks_last() {
        let inputs = array(&["a-b-c"]).unwrap();
        assert_eq!(rpartition(&inputs, "-").unwrap(), vec![triple("a-b", "-", "c")]);
    }

    #[test]
    fn rpartition_not_found() {
        let inputs = array(&["abc"]).unwrap();
        assert_eq!(rpartition(&inputs, "-").unwrap(), vec![triple("", "", "abc")]);
    }

    #[test]
    fn partition_empty_sep_errs() {
        let inputs = array(&["x"]).unwrap();
        assert!(partition(&inputs, "").is_err());
        assert!(rpartition(&inputs, "").is_err());
    }

    #[test]
    fn slice_basic() {
        let words = array(&["hello", "world"]).unwrap();
        let sliced = slice(&words, Some(1), Some(4)).unwrap();
        assert_eq!(sliced.as_slice(), &["ell", "orl"]);
    }

    #[test]
    fn slice_negative_indices() {
        let letters = array(&["abcdef"]).unwrap();
        let tail = slice(&letters, Some(-3), None).unwrap();
        assert_eq!(tail.as_slice(), &["def"]);
        let head = slice(&letters, None, Some(-2)).unwrap();
        assert_eq!(head.as_slice(), &["abcd"]);
    }

    #[test]
    fn slice_unicode_char_aware() {
        let word = array(&["café"]).unwrap();
        // Indices count chars (c, a, f, é), not bytes, so [2..4] is "fé".
        let sliced = slice(&word, Some(2), Some(4)).unwrap();
        assert_eq!(sliced.as_slice(), &["fé"]);
    }

    #[test]
    fn slice_empty_when_stop_le_start() {
        let word = array(&["hello"]).unwrap();
        let sliced = slice(&word, Some(3), Some(2)).unwrap();
        assert_eq!(sliced.as_slice(), &[""]);
    }

    #[test]
    fn translate_replace_and_drop() {
        let word = array(&["hello"]).unwrap();
        // 'l' is replaced, 'o' is deleted.
        let table: HashMap<char, Option<char>> = HashMap::from([('l', Some('L')), ('o', None)]);
        let translated = translate(&word, &table).unwrap();
        assert_eq!(translated.as_slice(), &["heLL"]);
    }

    #[test]
    fn translate_passthrough_unmapped() {
        let word = array(&["abc"]).unwrap();
        let empty_table = HashMap::<char, Option<char>>::new();
        let translated = translate(&word, &empty_table).unwrap();
        assert_eq!(translated.as_slice(), &["abc"]);
    }
}