numscan 0.1.0

Extract numbers from text
Documentation
//! # numscan
//! A library to scan for numbers in text.
//!
//! - A number must have a digit in it.
//! - A number may have a minus (-) prefix.
//! - A number may contain digits, dots, and commas.
//! - A number ends before the first non-numeric character or before its second dot, whichever
//!   comes first, and it never ends on a comma or a dot.
//! - The library only looks for ASCII digits (as opposed to other Unicode digits, such as
//!   Arabic-Indic numerals).
//! - Comma separators can be used liberally, they're mostly ignored.
//!
//! # Examples
//! ```rust
//! use numscan::NumberScanner;
//!
//! let input = "1.9 - (-1.7) + 1,000 is 1,000.2.";
//! let output = NumberScanner::from(input).collect::<Vec<_>>();
//!
//! assert_eq!(
//!   output,
//!   [
//!     0..3,   // "1.9"
//!     7..11,  // "-1.7"
//!     15..20, // "1,000"
//!     24..31, // "1,000.2"
//!   ]
//! );
//! ```
//!
//! # Tip: Parsing
//! The library returns ranges which you can use for indexing the numbers.
//! In the likely scenario you want to parse them as numbers, keep in mind:
//! 
//! - You might want to use a [decimal library](https://crates.io/keywords/decimal),
//!   if precision matters at all.
//! - The strings might contain commas, which your number type's parsing function is
//!   unlikely to accept, so you might want to remove them before attempting to
//!   parse.

#![feature(portable_simd)]
#![warn(clippy::missing_inline_in_public_items)]
#![deny(missing_docs, rustdoc::broken_intra_doc_links)]
#![no_std]

use core::{simd::{SimdPartialOrd, SimdPartialEq, ToBitMask}, ops::Range};

/// An iterator over the numbers found in a piece of text.
///
/// Every item is the byte range that one number occupies in the data the
/// scanner was created from.
#[derive(Debug, Hash, Default, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
pub struct NumberScanner<'a> {
    /// The part of the data that has not been consumed yet.
    input: &'a [u8],
    /// How many bytes have been consumed so far, i.e. where `input` begins
    /// within the data the scanner was created from.
    offset: usize,
}

// Construction {{{

impl<'a> NumberScanner<'a> {
    /// Creates a scanner over the bytes of the given string.
    #[inline(always)]
    pub const fn new(input: &'a str) -> Self {
        Self::from_bytes(input.as_bytes())
    }

    /// Creates a scanner over the given bytes, positioned at their start.
    #[inline(always)]
    pub const fn from_bytes(input: &'a [u8]) -> Self {
        Self { offset: 0, input }
    }
}

impl<'a> From<&'a str> for NumberScanner<'a> {
    /// Same as [`NumberScanner::new`].
    #[inline(always)]
    fn from(input: &'a str) -> Self {
        Self::from_bytes(input.as_bytes())
    }
}

impl<'a> From<&'a [u8]> for NumberScanner<'a> {
    /// Same as [`NumberScanner::from_bytes`].
    #[inline(always)]
    fn from(input: &'a [u8]) -> Self {
        Self::from_bytes(input)
    }
}

impl<'a, const N: usize> From<&'a [u8; N]> for NumberScanner<'a> {
    /// Same as [`NumberScanner::from_bytes`], for byte-array references
    /// (e.g. `b"..."` literals).
    #[inline(always)]
    fn from(input: &'a [u8; N]) -> Self {
        Self::from_bytes(input.as_slice())
    }
}

// }}}

/// SIMD [`scan`](Self::scan) results.
///
/// The library uses SIMD to locate numbers.
#[derive(Debug, Hash, Default, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
struct SimdScanResult {
    /// The index of the first digit of the number, relative to the scanned
    /// input.
    start: usize,
    /// The index just past the last accepted byte (digit, comma or dot) of
    /// the number _in the chunk_ (see
    /// [`proper_end`](SimdScanResult::proper_end)).
    end: usize,
    /// Whether [`end`](SimdScanResult::end) points to the actual end of the
    /// number.
    /// If `false`, then [`end`](SimdScanResult::end) points to either the end
    /// of the number or an index within the number (and the end of the number
    /// can be looked for past it, i.e. it's unnecessary to look for the end
    /// from [`start`](SimdScanResult::start)).
    ///
    /// The reason we don't just get the proper end is because this type
    /// represents a result for a single SIMD chunk. If we find a number that
    /// goes on to the end of the SIMD chunk, we can't tell from that alone
    /// that there isn't more of that number past the chunk.
    proper_end: bool,
}

impl SimdScanResult {
    /// Looks for the first digit in `input`, one 32-byte SIMD chunk at a time.
    ///
    /// Only whole chunks are examined; a trailing part shorter than a chunk
    /// is left for the caller to check.
    ///
    /// Returns the [result](Self) for the first chunk containing a digit (if
    /// any), together with the number of bytes examined, which is also the
    /// index at which a follow-up search should continue.
    fn scan(input: &[u8]) -> (Option<Self>, usize) {
        type Simd = core::simd::u8x32;

        for (nth, window) in input.chunks_exact(Simd::LANES).enumerate() {
            let base = nth * Simd::LANES;
            let chunk = Simd::from_slice(window);

            // One bit per lane, set where the lane holds an ASCII digit.
            let is_digit = chunk.simd_ge(Simd::splat(b'0')) & chunk.simd_le(Simd::splat(b'9'));
            let digit_bits = is_digit.to_bitmask();

            // Lanes before the first digit; equals the lane count if there is none.
            let lead = digit_bits.trailing_zeros() as usize;
            if lead >= Simd::LANES {
                continue;
            }

            // From the first digit on, count the unbroken run of digits, commas and dots.
            let is_separator = chunk.simd_eq(Simd::splat(b',')) | chunk.simd_eq(Simd::splat(b'.'));
            let accepted_bits = digit_bits | is_separator.to_bitmask();
            let run = (accepted_bits >> lead).trailing_ones() as usize;

            let hit = Self {
                start: base + lead,
                end: base + lead + run,
                // A run that reaches the chunk boundary may continue in the next chunk.
                proper_end: lead + run < Simd::LANES,
            };
            return (Some(hit), base + Simd::LANES);
        }

        // No digit in any whole chunk: report how many bytes those chunks covered.
        (None, input.len() - input.len() % Simd::LANES)
    }
}

#[cfg(test)]
mod test_simd_scan {
    use super::*;

    /// Builds the value `scan` is expected to return when it finds a number.
    fn found(
        start: usize,
        end: usize,
        proper_end: bool,
        scanned: usize,
    ) -> (Option<SimdScanResult>, usize) {
        (Some(SimdScanResult { start, end, proper_end }), scanned)
    }

    // Inputs shorter than one chunk are never examined by the SIMD scan.

    #[test]
    fn test_empty() {
        assert_eq!(SimdScanResult::scan(b""), (None, 0));
    }

    #[test]
    fn test_short_no_num() {
        assert_eq!(SimdScanResult::scan(b"hello"), (None, 0));
    }

    // Exactly one chunk.

    #[test]
    fn test_32_no_num() {
        assert_eq!(
            SimdScanResult::scan(b"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"),
            (None, 32),
        );
    }

    #[test]
    fn test_32_num() {
        assert_eq!(
            SimdScanResult::scan(b"~~~12345~~~~~~~~~~~~~~~~~~~~~~~~"),
            found(3, 8, true, 32),
        );
    }

    #[test]
    fn test_32_num_end() {
        // The number touches the chunk boundary, so its end is not final.
        assert_eq!(
            SimdScanResult::scan(b"~~~~~~~~~~~~~~~~~~~~~~~~~~~12345"),
            found(27, 32, false, 32),
        );
    }

    // Two chunks, number in the second one.

    #[test]
    fn test_64_num() {
        assert_eq!(
            SimdScanResult::scan(b"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|~~~12345~~~~~~~~~~~~~~~~~~~~~~~~"),
            found(35, 40, true, 64),
        );
    }

    #[test]
    fn test_64_num_end() {
        assert_eq!(
            SimdScanResult::scan(b"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~~~~~~~~~~~12345"),
            found(59, 64, false, 64),
        );
    }

    // Commas and dots are accepted as part of the run.

    #[test]
    fn test_seps() {
        assert_eq!(
            SimdScanResult::scan(b"~~~12,345.05~~~~~~~~~~~~~~~~~~~~"),
            found(3, 12, true, 32),
        );
    }
}

impl<'a> Iterator for NumberScanner<'a> {
    type Item = Range<usize>;

    /// Returns the byte range of the next number, relative to the data the
    /// scanner was created from, or [`None`] if no digit is left.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // This is where the library is actually doing its thing.
        //
        // This function:
        // (1) Scans the remaining input for a digit.
        //       It uses SIMD to scan whole chunks at a time and falls back to
        //       a byte-wise search for the part shorter than a chunk.
        // (2) Looks for the end of the number.
        //       The number stops before the first byte that is not a digit,
        //       comma or dot, or before its second dot (e.g. on "1.5.8" we
        //       just want "1.5", as "1.5.8" as a whole isn't a number; "8" is
        //       picked up by the following call).
        // (3) Checks for a minus sign right before the first digit.
        // (4) Drops trailing punctuation (e.g. on "12.65," the last "," is
        //       accepted by the scan but should be excluded from the result).
        let input = self.input;

        // Step 1: find the first digit.
        let (simd, scanned) = SimdScanResult::scan(input);
        let digit = match simd {
            Some(SimdScanResult { start, .. }) => start,
            None => scanned + input[scanned..].iter().position(u8::is_ascii_digit)?,
        };

        // Step 2: find the end.
        //
        // When the SIMD scan saw the run of accepted bytes stop inside its
        // chunk, nothing past that point can belong to the number, so the
        // search is confined to it. Otherwise the run may continue to the end
        // of the input.
        //
        // The second-dot rule is applied during this very search rather than
        // afterwards: the search then never reads past the bytes this call
        // consumes, which keeps a long run such as "1.1.1.1.…" linear overall
        // instead of re-scanning the rest of the run on every call.
        let limit = match simd {
            Some(SimdScanResult { end, proper_end: true, .. }) => end,
            _ => input.len(),
        };
        let mut dots = 0u8;
        let end = input[digit..limit]
            .iter()
            .position(|&byte| match byte {
                b'0'..=b'9' | b',' => false,
                b'.' => {
                    dots += 1;
                    dots == 2
                }
                _ => true,
            })
            .map_or(limit, |n| digit + n);

        // Step 3: include a minus sign directly in front of the first digit.
        let start = if digit > 0 && input[digit - 1] == b'-' { digit - 1 } else { digit };

        // Step 4: don't end on punctuation. `input[digit]` is a digit, so at
        // least one byte always remains.
        let trailing = input[digit..end]
            .iter()
            .rev()
            .take_while(|&&byte| matches!(byte, b',' | b'.'))
            .count();
        let end = end - trailing;

        // Report the range relative to the original data and continue after
        // the number on the next call.
        let absolute_start = self.offset + start;
        self.input = &input[end..];
        self.offset += end;

        Some(absolute_start..self.offset)
    }
}

#[cfg(test)]
mod test {
    extern crate alloc;

    use alloc::vec::Vec;

    use super::*;

    /// Runs the scanner over `input` and gathers every range it yields.
    fn scan(input: &str) -> Vec<Range<usize>> {
        let mut ranges = Vec::new();
        ranges.extend(NumberScanner::from(input));
        ranges
    }

    /// Declares one `#[test]` per `name: input => expected ranges` entry.
    macro_rules! scan_tests {
        ($($case:ident: $text:expr => $ranges:expr),* $(,)?) => {
            $(
                #[test]
                fn $case() {
                    assert_eq!(scan($text), $ranges);
                }
            )*
        };
    }

    scan_tests!(
        // Nothing to find.
        empty: "" => [],
        short_no_num: "hello" => [],
        l32_no_num: "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" => [],

        // Inputs long enough for the SIMD path (one or two chunks).
        l32_num: "~~~12345~~~~~~~~~~~~~~~~~~~~~~~~" => [3..8],
        l32_num_end: "~~~~~~~~~~~~~~~~~~~~~~~~~~~12345" => [27..32],
        l64_num: "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|~~~12345~~~~~~~~~~~~~~~~~~~~~~~~" => [35..40],
        l64_num_end: "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~|~~~~~~~~~~~~~~~~~~~~~~~~~~~12345" => [59..64],
        l32_seps: "~~~12,345.05~~~~~~~~~~~~~~~~~~~~" => [3..12],

        // Inputs shorter than a chunk.
        short_single_digit: "~~8~~" => [2..3],
        single_digit: "5" => [0..1],
        short_multiple: "~~~12~34~5~" => [3..5, 6..8, 9..10],

        // Dots, commas and the minus sign.
        one_dot: "~~1.2~~" => [2..5],
        two_dots: "~~1.2.3~~" => [2..5, 6..7],
        commas_and_two_dots: "~~-1,000.5,000.8~~" => [2..14, 15..16],
        end_comma: "~~105,~~" => [2..5],

        // Punctuation without any digit is not a number.
        short_commas: ",,,,,,," => [],
        short_periods: "......." => [],
    );
}